{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Merkle tree used to track Orchard note commitments.
--
-- Trees are assembled with the 'Semigroup' instance of 'Tree': leaves are
-- inserted into the first non-full child, subtrees known only by their hash
-- are stored as 'PrunedBranch' nodes, and empty space on the right is filled
-- with roots of empty subtrees ('getEmptyRoot').
module Zenith.Tree where

import Data.HexString
import Data.Int (Int64)
import Data.List (foldl')
import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue)
import ZcashHaskell.Types (OrchardFrontier(..))

-- | Height of a node above the leaves; leaves are at level 0 and every
-- successful combination of two nodes adds one.
type Level = Integer

-- | Tree height bound; 'root' pads a tree up to level @maxLevel - 1@.
maxLevel :: Level
maxLevel = 32

-- | Index of a leaf within the tree.
type Position = Int64

-- | Values of type @a@ that can be turned into a node annotation @v@, given
-- the position at which they are inserted.
class Monoid v => Measured a v where
  measure :: a -> Position -> v

-- | Accessors for the annotation stored at every tree node.
class Node v where
  -- | Level of the node (0 for leaves).
  getLevel :: v -> Level
  -- | Hash carried by the node.
  getTag :: v -> HexString
  -- | Position recorded for the node.
  getPosition :: v -> Position
  -- | Whether no further leaves should be inserted beneath this node.
  isFull :: v -> Bool
  -- | Build a node from its level, position and hash.
  mkNode :: Level -> Position -> HexString -> v

-- | A raw Orchard note commitment.
type OrchardCommitment = HexString

-- | Turn a commitment into a full level-0 node at the given position. When
-- 'getOrchardNodeValue' rejects the bytes, a placeholder full node with hash
-- @00@ at position 0 is produced instead.
instance Measured OrchardCommitment OrchardNode where
  measure oc p =
    maybe
      (OrchardNode 0 (hexString "00") 0 True)
      (\val -> OrchardNode p val 0 True)
      (getOrchardNodeValue (hexBytes oc))

-- | A binary Merkle tree whose nodes carry an annotation @v@.
data Tree v
  = EmptyLeaf
    -- ^ an unoccupied slot
  | Leaf !v
    -- ^ an occupied leaf
  | PrunedBranch !v
    -- ^ a subtree kept only as its root annotation
  | Branch !v !(Tree v) !(Tree v)
    -- ^ an inner node: cached annotation, left child, right child
  deriving (Eq)

instance (Node v, Show v) => Show (Tree v) where
  show EmptyLeaf = "()"
  show (Leaf v) = "(" ++ show v ++ ")"
  show (PrunedBranch v) = "{" ++ show v ++ "}"
  show (Branch s x y) =
    "<" ++ show (getTag s) ++ ">\n" ++ show x ++ "\n" ++ show y

-- | Leaf insertion and subtree joining.
--
-- This is not a lawful (associative) semigroup: @Leaf a <> (Leaf b <> Leaf c)@
-- collapses to @Leaf a@, whereas @(Leaf a <> Leaf b) <> Leaf c@ is a branch.
-- Unsupported combinations keep the left operand, except a level mismatch
-- between a 'PrunedBranch' and a 'Branch', which yields 'EmptyLeaf'.
instance (Monoid v, Node v) => Semigroup (Tree v) where
  -- Two empty slots collapse into a pruned node holding @mempty <> mempty@.
  EmptyLeaf <> EmptyLeaf = PrunedBranch $ value $ branch EmptyLeaf EmptyLeaf
  EmptyLeaf <> x = x
  Leaf x <> EmptyLeaf = branch (Leaf x) EmptyLeaf
  Leaf x <> Leaf y = branch (Leaf x) (Leaf y)
  -- A leaf does not absorb larger subtrees; the right operand is dropped.
  Leaf x <> Branch {} = Leaf x
  Leaf x <> PrunedBranch _ = Leaf x
  -- NOTE(review): combines x with itself, which is the correct parent hash
  -- only when x is itself an empty-subtree root; confirm no caller reaches
  -- this with a non-empty pruned node.
  PrunedBranch x <> EmptyLeaf = PrunedBranch $ x <> x
  PrunedBranch x <> Leaf _ = PrunedBranch x
  PrunedBranch x <> Branch s t u
    | getLevel x == getLevel s = branch (PrunedBranch x) (Branch s t u)
    | otherwise = EmptyLeaf
  PrunedBranch x <> PrunedBranch y = PrunedBranch $ x <> y
  -- Pad a branch on the right with an empty subtree of the same height.
  Branch s x y <> EmptyLeaf = branch (Branch s x y) $ getEmptyRoot (getLevel s)
  Branch s x y <> PrunedBranch _ = Branch s x y
  Branch s x y <> Leaf w
    -- Full: grow by pairing with a same-height subtree that holds w.
    | isFull s = Branch s x y <> mkSubTree (getLevel s) (Leaf w)
    -- Otherwise insert into the left child unless it is already full.
    | isFull (value x) = branch x (y <> Leaf w)
    | otherwise = branch (x <> Leaf w) y
  Branch s x y <> Branch s1 x1 y1
    | getLevel s == getLevel s1 = branch (Branch s x y) (Branch s1 x1 y1)
    | otherwise = Branch s x y

-- | Annotation of a tree's root; an empty slot is annotated with 'mempty'.
value :: Monoid v => Tree v -> v
value EmptyLeaf = mempty
value (Leaf v) = v
value (PrunedBranch v) = v
value (Branch v _ _) = v

-- | Smart constructor for 'Branch' that caches the combined annotation.
branch :: Monoid v => Tree v -> Tree v -> Tree v
branch x y = Branch (value x <> value y) x y

-- | A leaf holding the measure of @a@ at position @p@.
leaf :: Measured a v => a -> Position -> Tree v
leaf a p = Leaf (measure a p)

-- | A pruned subtree known only by its level, position and hash.
prunedBranch :: (Monoid v, Node v) => Level -> Position -> HexString -> Tree v
prunedBranch level pos val = PrunedBranch $ mkNode level pos val

-- | Pad a tree with empty right-hand subtrees until its root sits at level
-- @maxLevel - 1@. A tree already at (or above) that level is returned as is.
root :: (Monoid v, Node v) => Tree v -> Tree v
root = mkSubTree (maxLevel - 1)

-- | Root of a completely empty subtree of the given height; 'EmptyLeaf' for
-- level 0 (and, rather than failing, for negative levels).
getEmptyRoot :: (Monoid v, Node v) => Level -> Tree v
getEmptyRoot level
  | level <= 0 = EmptyLeaf
  | otherwise =
    let half = getEmptyRoot (level - 1)
     in half <> half

-- | Annotation stored at every node of an Orchard commitment tree.
data OrchardNode = OrchardNode
  { on_position :: !Position
  , on_value :: !HexString
  , on_level :: !Level
  , on_full :: !Bool
  } deriving (Eq)

-- | Parent of two siblings: hashed with 'combineOrchardNodes' at the left
-- node's level. When hashing fails the left node is returned unchanged, so
-- the level does not increase.
instance Semigroup OrchardNode where
  x <> y =
    case combineOrchardNodes (on_level x) (on_value x) (on_value y) of
      Nothing -> x
      Just newHash ->
        OrchardNode
          { on_position = max (on_position x) (on_position y)
          , on_value = newHash
          , on_level = 1 + on_level x
          , on_full = on_full x && on_full y
          }

-- | The empty node: position 0, hash @00@, level 0, not full.
instance Monoid OrchardNode where
  mempty = OrchardNode 0 (hexString "00") 0 False

instance Node OrchardNode where
  getLevel = on_level
  getTag = on_value
  getPosition = on_position
  isFull = on_full
  mkNode l p v = OrchardNode p v l True

instance Show OrchardNode where
  show = show . on_value

-- | Re-measuring a node only moves it to the new position.
instance Measured OrchardNode OrchardNode where
  measure o p = o {on_position = p}

-- | Pad a tree with empty right-hand siblings (via @t <> EmptyLeaf@) until
-- its root reaches @level@.
--
-- A tree already at or above @level@ is returned unchanged. Padding also
-- stops if a step fails to raise the root level (for 'OrchardNode', when
-- 'combineOrchardNodes' returns 'Nothing'). Both cases would otherwise
-- recurse forever.
mkSubTree :: (Node v, Monoid v) => Level -> Tree v -> Tree v
mkSubTree level t
  | current >= level = t
  | getLevel (value next) <= current = t
  | otherwise = mkSubTree level next
  where
    current = getLevel (value t)
    next = t <> EmptyLeaf

-- | Rebuild a tree from an Orchard frontier: latest leaf @l@ at position @p@
-- plus its ommers @o@.
--
-- For an odd @p@, the first ommer becomes the left sibling leaf at @p - 1@
-- and the remaining ommers are attached above it. For an even @p@, the leaf
-- is paired with an empty slot and every ommer is attached. Calls 'error'
-- when @p@ is odd but @o@ is empty.
mkOrchardTree :: OrchardFrontier -> Tree OrchardNode
mkOrchardTree (OrchardFrontier p l o)
  | odd p =
    case o of
      sibling : rest ->
        addOrchardOmmers rest $
        Leaf (OrchardNode (p - 1) sibling 0 True) <> latest
      [] ->
        error
          "mkOrchardTree: frontier at an odd position has no left-sibling ommer"
  | otherwise = addOrchardOmmers o $ latest <> EmptyLeaf
  where
    latest = Leaf (OrchardNode p l 0 True)

-- | Attach the ommers, in list order, as left siblings of the tree built so
-- far. Each ommer becomes a 'PrunedBranch' at the current root's level, with
-- position @2 ^ level@ before the current root's position.
--
-- NOTE(review): every ommer is placed on the left at the current level;
-- frontier positions with a clear bit above level 0 (e.g. position 4) would
-- seem to need empty right-hand padding instead. Confirm against the
-- frontier format produced by zcash-haskell.
addOrchardOmmers :: [HexString] -> Tree OrchardNode -> Tree OrchardNode
addOrchardOmmers xs t = foldl' attach t xs
  where
    attach acc ommer =
      let top = value acc
       in PrunedBranch (mkNode (getLevel top) (leftPosition top) ommer) <> acc
    leftPosition :: OrchardNode -> Position
    leftPosition node = on_position node - 2 ^ on_level node