{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Merkle trees for Sapling and Orchard note commitments.
--
-- Every node of a 'Tree' carries a value that records the node's level,
-- hash, position, index and full/marked flags, so appends, truncation and
-- authentication-path extraction are decided from node values alone.
module Zenith.Tree where

import Codec.Borsh
import Data.HexString
import Data.Int (Int32, Int64, Int8)
import Data.List (foldl')
import Data.Maybe (fromJust, isNothing)
import qualified GHC.Generics as GHC
import qualified Generics.SOP as SOP
import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue)
import ZcashHaskell.Sapling (combineSaplingNodes, getSaplingNodeValue)
import ZcashHaskell.Types
  ( MerklePath(..)
  , OrchardFrontier(..)
  , OrchardTree(..)
  , SaplingTree(..)
  )

-- | Height of a node; leaves built by 'measure' sit at level 0.
type Level = Int8

-- | Level of a complete tree root (the target level of 'root').
maxLevel :: Level
maxLevel = 32

-- | Position of a leaf within the tree.
type Position = Int32

-- | Conversion of a leaf payload @a@ into a node value @v@.
class Monoid v => Measured a v where
  -- | Build the node for a payload, given its position and index.
  measure :: a -> Position -> Int64 -> v

-- | Accessors and a constructor for the values stored in 'Tree' nodes.
class Node v where
  -- | Height of the node.
  getLevel :: v -> Level
  -- | Hash stored in the node.
  getHash :: v -> HexString
  -- | Position stored in the node.
  getPosition :: v -> Position
  -- | Index stored in the node (searched by 'getNotePosition').
  getIndex :: v -> Int64
  -- | Whether the subtree under the node is full.
  isFull :: v -> Bool
  -- | Whether the node is marked.
  isMarked :: v -> Bool
  -- | Build a node from a level, a position and a hash.
  mkNode :: Level -> Position -> HexString -> v

-- | A raw Orchard note commitment.
type OrchardCommitment = HexString

-- Commitments that 'getOrchardNodeValue' rejects (returns Nothing for)
-- become a placeholder node with hash @00@; that placeholder uses position
-- 0 and index 0 rather than the supplied ones.
instance Measured OrchardCommitment OrchardNode where
  measure oc p i =
    case getOrchardNodeValue (hexBytes oc) of
      Nothing -> OrchardNode 0 (hexString "00") 0 True 0 False
      Just val -> OrchardNode p val 0 True i False

-- | A raw Sapling note commitment.
type SaplingCommitment = HexString

-- Commitments that 'getSaplingNodeValue' rejects (returns Nothing for)
-- become a placeholder node with hash @00@; that placeholder uses position
-- 0 and index 0 rather than the supplied ones.
instance Measured SaplingCommitment SaplingNode where
  measure sc p i =
    case getSaplingNodeValue (hexBytes sc) of
      Nothing -> SaplingNode 0 (hexString "00") 0 True 0 False
      Just val -> SaplingNode p val 0 True i False

-- | A Merkle tree whose nodes carry values of type @v@.
data Tree v
  = EmptyLeaf
    -- ^ A leaf with no value ('value' gives 'mempty').
  | Leaf !v
    -- ^ A filled leaf.
  | PrunedBranch !v
    -- ^ A subtree kept only as its root value.
  | Branch !v !(Tree v) !(Tree v)
    -- ^ A root value together with its left and right children.
  | InvalidTree
    -- ^ Result of an illegal combination; absorbs any further '<>'.
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsEnum (Tree v)

-- | Debug rendering: @()@ for an empty leaf, @(v)@ for a leaf, @{v}@ for a
-- pruned branch, and for a branch its hash in @<..>@ followed by both
-- children on separate lines.
instance (Node v, Show v) => Show (Tree v) where
  show EmptyLeaf = "()"
  show (Leaf v) = "(" ++ show v ++ ")"
  show (PrunedBranch v) = "{" ++ show v ++ "}"
  show (Branch s x y) =
    "<" ++ show (getHash s) ++ ">\n" ++ show x ++ "\n" ++ show y
  show InvalidTree = "InvalidTree"

-- Appending trees.  'InvalidTree' is absorbing; combinations that put a
-- leaf next to a larger subtree, or pair subtrees of different levels,
-- produce 'InvalidTree'.
instance (Monoid v, Node v) => Semigroup (Tree v) where
  (<>) InvalidTree _ = InvalidTree
  (<>) _ InvalidTree = InvalidTree
  -- Two empty leaves collapse into a pruned branch holding their combined
  -- value.
  (<>) EmptyLeaf EmptyLeaf = PrunedBranch $ value $ branch EmptyLeaf EmptyLeaf
  (<>) EmptyLeaf x = x
  (<>) (Leaf x) EmptyLeaf = branch (Leaf x) EmptyLeaf
  (<>) (Leaf x) (Leaf y) = branch (Leaf x) (Leaf y)
  (<>) (Leaf _) Branch {} = InvalidTree
  (<>) (Leaf _) (PrunedBranch _) = InvalidTree
  -- Combines the node with itself; for an empty-root pruned branch this
  -- gives the next empty root.  NOTE(review): for a full pruned branch this
  -- hashes the node with itself rather than with an empty root -- confirm
  -- that case is never reached.
  (<>) (PrunedBranch x) EmptyLeaf = PrunedBranch $ x <> x
  -- A non-full pruned branch is replaced by the new leaf grown (with empty
  -- padding) up to the pruned branch's level.
  (<>) (PrunedBranch x) (Leaf y) =
    if isFull x
      then InvalidTree
      else mkSubTree (getLevel x) (Leaf y)
  (<>) (PrunedBranch x) (Branch s t u) =
    if getLevel x == getLevel s
      then branch (PrunedBranch x) (Branch s t u)
      else InvalidTree
  (<>) (PrunedBranch x) (PrunedBranch y) = PrunedBranch $ x <> y
  -- Pair the branch with an empty root of its own level.
  (<>) (Branch s x y) EmptyLeaf =
    branch (Branch s x y) $ getEmptyRoot (getLevel s)
  (<>) (Branch s x y) (PrunedBranch w)
    | getLevel s == getLevel w = branch (Branch s x y) (PrunedBranch w)
    | otherwise = InvalidTree
  -- A full branch gets a new sibling subtree holding the leaf; otherwise
  -- the leaf goes into the left child unless that child is already full.
  (<>) (Branch s x y) (Leaf w)
    | isFull s = Branch s x y <> mkSubTree (getLevel s) (Leaf w)
    | isFull (value x) = branch x (y <> Leaf w)
    | otherwise = branch (x <> Leaf w) y
  (<>) (Branch s x y) (Branch s1 x1 y1)
    | getLevel s == getLevel s1 = branch (Branch s x y) (Branch s1 x1 y1)
    | otherwise = InvalidTree

-- | The node value stored at the root of a tree ('mempty' for
-- 'EmptyLeaf' and 'InvalidTree').
value :: Monoid v => Tree v -> v
value EmptyLeaf = mempty
value (Leaf v) = v
value (PrunedBranch v) = v
value (Branch v _ _) = v
value InvalidTree = mempty

-- | Join two subtrees under a new root whose value is the '<>' of their
-- root values.
branch :: Monoid v => Tree v -> Tree v -> Tree v
branch x y = Branch (value x <> value y) x y

-- | A leaf built from a payload via 'measure'.
leaf :: Measured a v => a -> Int32 -> Int64 -> Tree v
leaf a p i = Leaf (measure a p i)

-- | A pruned branch whose value is built with 'mkNode'.
prunedBranch :: Monoid v => Node v => Level -> Position -> HexString -> Tree v
prunedBranch level pos val = PrunedBranch $ mkNode level pos val

-- | Pad a tree with empty subtrees until its root reaches 'maxLevel'.
-- A tree already at 'maxLevel' is returned unchanged.
root :: Monoid v => Node v => Tree v -> Tree v
root tree =
  if getLevel (value tree) == maxLevel
    then tree
    else mkSubTree maxLevel tree

-- | The empty subtree of the given level: 'EmptyLeaf' combined with itself
-- @level@ times.  Non-positive levels yield 'EmptyLeaf'.
getEmptyRoot :: Monoid v => Node v => Level -> Tree v
getEmptyRoot level = go level EmptyLeaf
  where
    -- Explicit countdown instead of @iterate ... !! n@, which crashed on a
    -- negative index.  @t <> t@ shares @t@, so each level is built once.
    go n t
      | n <= 0 = t
      | otherwise = go (n - 1) (t <> t)

-- | Append a leaf for payload @n@ with index @i@, positioned one past the
-- current root's position.
append :: Monoid v => Measured a v => Node v => Tree v -> (a, Int64) -> Tree v
append tree (n, i) = tree <> leaf n p i
  where
    p = 1 + getPosition (value tree)

-- | Repeatedly append 'EmptyLeaf' to @t@ until the root reaches @level@.
-- At least one 'EmptyLeaf' is always appended.
--
-- Returns 'InvalidTree' when the target can never be reached: the tree is
-- (or becomes) 'InvalidTree', or its level has already passed @level@.
-- Levels never decrease under this padding, so both cases previously
-- recursed forever.
mkSubTree :: Node v => Monoid v => Level -> Tree v -> Tree v
mkSubTree level t =
  case t <> EmptyLeaf of
    InvalidTree -> InvalidTree
    subtree
      | reached == level -> subtree
      | reached > level -> InvalidTree
      | otherwise -> mkSubTree level subtree
      where
        reached = getLevel (value subtree)

-- | Authentication path for the leaf at @pos@.
--
-- Walks from the root towards @pos@, collecting the sibling hash at each
-- level (the deepest sibling comes first in the result).  Returns Nothing
-- unless the tree is a 'Branch' and exactly 'maxLevel' hashes were
-- collected.
path :: Monoid v => Node v => Position -> Tree v -> Maybe MerklePath
path pos tree@Branch {}
  | length hashes == fromIntegral maxLevel = Just $ MerklePath pos hashes
  | otherwise = Nothing
  where
    -- Computed once; the original evaluated collectPath twice.
    hashes = collectPath tree
    collectPath :: Monoid v => Node v => Tree v -> [HexString]
    collectPath (Branch _ j k)
      -- @pos@ lies beyond everything in this subtree: stop early, which
      -- yields a short path and makes 'path' return Nothing.
      | getPosition (value k) /= 0 && getPosition (value k) < pos = []
      | getPosition (value j) < pos = collectPath k <> [getHash (value j)]
      | getPosition (value j) >= pos = collectPath j <> [getHash (value k)]
      | otherwise = []
    collectPath _ = []
path _ _ = Nothing

-- | Position of the leaf whose index equals @i@, found by descending into
-- the left child when its index is @>= i@ and otherwise into the right.
getNotePosition :: Monoid v => Node v => Tree v -> Int64 -> Maybe Position
getNotePosition (Leaf x) i
  | getIndex x == i = Just $ getPosition x
  | otherwise = Nothing
getNotePosition (Branch _ x y) i
  | getIndex (value x) >= i = getNotePosition x i
  | getIndex (value y) >= i = getNotePosition y i
  | otherwise = Nothing
getNotePosition _ _ = Nothing

-- | Cut the tree back to the leaf with index @i@.  Right siblings of the
-- path towards that leaf are replaced by empty subtrees.  Non-branches,
-- and branches for which no guard applies, are returned unchanged.
truncateTree :: Monoid v => Node v => Tree v -> Int64 -> Tree v
truncateTree (Branch s x y) i
  | getLevel s == 1 && getIndex (value x) == i = branch x EmptyLeaf
  | getLevel s == 1 && getIndex (value y) == i = branch x y
  | getIndex (value x) >= i = branch (truncateTree x i) emptySibling
  | getIndex (value y) >= i = branch x (truncateTree y i)
  where
    -- Children of a level-L branch are at level L-1, so the empty stand-in
    -- for @y@ is the level L-1 empty root (the level-1 case above uses
    -- 'EmptyLeaf', i.e. level 0).  The original used level L here.
    emptySibling = getEmptyRoot (getLevel s - 1)
truncateTree x _ = x

-- | A node of a Sapling commitment tree.
data SaplingNode = SaplingNode
  { sn_position :: !Position
    -- ^ Position; combining keeps the larger one.
  , sn_value :: !HexString
    -- ^ Node hash.
  , sn_level :: !Level
    -- ^ Height of the node.
  , sn_full :: !Bool
    -- ^ Whether the subtree is full; combining requires both sides full.
  , sn_index :: !Int64
    -- ^ Index; combining keeps the larger one.
  , sn_mark :: !Bool
    -- ^ Mark flag; combining keeps it if either side is marked.
  }
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsStruct SaplingNode

-- Hash two nodes into their parent one level up.  If
-- 'combineSaplingNodes' returns Nothing, the left node is returned as-is.
instance Semigroup SaplingNode where
  (<>) x y =
    case combineSaplingNodes (sn_level x) (sn_value x) (sn_value y) of
      Nothing -> x
      Just newHash ->
        SaplingNode
          (max (sn_position x) (sn_position y))
          newHash
          (1 + sn_level x)
          (sn_full x && sn_full y)
          (max (sn_index x) (sn_index y))
          (sn_mark x || sn_mark y)

-- The empty node: hash @00@, level 0, not full, unmarked.
instance Monoid SaplingNode where
  mempty = SaplingNode 0 (hexString "00") 0 False 0 False
  mappend = (<>)

-- 'mkNode' builds a full, unmarked node with index 0.
instance Node SaplingNode where
  getLevel = sn_level
  getHash = sn_value
  getPosition = sn_position
  getIndex = sn_index
  isFull = sn_full
  isMarked = sn_mark
  mkNode l p v = SaplingNode p v l True 0 False

-- Shows only the node hash.
instance Show SaplingNode where
  show = show . sn_value

-- | Number of leaves covered by a Sapling frontier: 1 for each present
-- left/right leaf, plus @2^i@ for each present parent at 1-based index @i@.
saplingSize :: SaplingTree -> Int64
saplingSize tree =
  (if isNothing (st_left tree) then 0 else 1)
    + (if isNothing (st_right tree) then 0 else 1)
    + foldl'
        (\acc (i, p) ->
           case p of
             Nothing -> acc
             Just _ -> acc + 2 ^ i)
        0
        (zip [1 ..] $ st_parents tree)

-- | Rebuild a 'Tree' from a Sapling frontier.
--
-- The left/right leaves form the initial subtree.  With a right leaf they
-- sit at positions @size-2@ and @size-1@; otherwise the left leaf sits at
-- @size-1@ with an empty sibling.  Each parent at 1-based index @i@ is then
-- either prepended as a pruned branch of level @i@, or, when absent, the
-- tree is padded with the empty root of level @i@.
--
-- NOTE(review): partial -- 'fromJust' fails if the frontier has no left
-- leaf.
mkSaplingTree :: SaplingTree -> Tree SaplingNode
mkSaplingTree tree =
  foldl'
    (\t (i, n) ->
       case n of
         Just n' -> prunedBranch i 0 n' <> t
         Nothing -> t <> getEmptyRoot i)
    leafRoot
    (zip [1 ..] $ st_parents tree)
  where
    leafRoot =
      case st_right tree of
        Just r' -> leaf (fromJust $ st_left tree) (pos - 1) 0 <> leaf r' pos 0
        Nothing -> leaf (fromJust $ st_left tree) pos 0 <> EmptyLeaf
    pos = fromIntegral $ saplingSize tree - 1

-- Orchard

-- | A node of an Orchard commitment tree.
data OrchardNode = OrchardNode
  { on_position :: !Position
    -- ^ Position; combining keeps the larger one.
  , on_value :: !HexString
    -- ^ Node hash.
  , on_level :: !Level
    -- ^ Height of the node.
  , on_full :: !Bool
    -- ^ Whether the subtree is full; combining requires both sides full.
  , on_index :: !Int64
    -- ^ Index; combining keeps the larger one.
  , on_mark :: !Bool
    -- ^ Mark flag; combining keeps it if either side is marked.
  }
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsStruct OrchardNode

-- Hash two nodes into their parent one level up.  If
-- 'combineOrchardNodes' returns Nothing, the left node is returned as-is.
instance Semigroup OrchardNode where
  (<>) x y =
    case combineOrchardNodes
           (fromIntegral $ on_level x)
           (on_value x)
           (on_value y) of
      Nothing -> x
      Just newHash ->
        OrchardNode
          (max (on_position x) (on_position y))
          newHash
          (1 + on_level x)
          (on_full x && on_full y)
          (max (on_index x) (on_index y))
          (on_mark x || on_mark y)

-- The empty node: hash @00@, level 0, not full, unmarked.
instance Monoid OrchardNode where
  mempty = OrchardNode 0 (hexString "00") 0 False 0 False
  mappend = (<>)

-- 'mkNode' builds a full, unmarked node with index 0.
instance Node OrchardNode where
  getLevel = on_level
  getHash = on_value
  getPosition = on_position
  getIndex = on_index
  isFull = on_full
  isMarked = on_mark
  mkNode l p v = OrchardNode p v l True 0 False

-- Shows only the node hash.
instance Show OrchardNode where
  show = show . on_value

-- Re-measuring an existing node keeps its hash, level and flags and only
-- replaces its position and index.
instance Measured OrchardNode OrchardNode where
  measure o p i =
    OrchardNode p (on_value o) (on_level o) (on_full o) i (on_mark o)

-- | Number of leaves covered by an Orchard frontier: 1 for each present
-- left/right leaf, plus @2^i@ for each present parent at 1-based index @i@.
orchardSize :: OrchardTree -> Int64
orchardSize tree =
  (if isNothing (ot_left tree) then 0 else 1)
    + (if isNothing (ot_right tree) then 0 else 1)
    + foldl'
        (\acc (i, p) ->
           case p of
             Nothing -> acc
             Just _ -> acc + 2 ^ i)
        0
        (zip [1 ..] $ ot_parents tree)

-- | Rebuild a 'Tree' from an Orchard frontier.
--
-- The left/right leaves form the initial subtree.  With a right leaf they
-- sit at positions @size-2@ and @size-1@; otherwise the left leaf sits at
-- @size-1@ with an empty sibling.  Each parent at 1-based index @i@ is then
-- either prepended as a pruned branch of level @i@, or, when absent, the
-- tree is padded with the empty root of level @i@.
--
-- NOTE(review): partial -- 'fromJust' fails if the frontier has no left
-- leaf.
mkOrchardTree :: OrchardTree -> Tree OrchardNode
mkOrchardTree tree =
  foldl'
    (\t (i, n) ->
       case n of
         Just n' -> prunedBranch i 0 n' <> t
         Nothing -> t <> getEmptyRoot i)
    leafRoot
    (zip [1 ..] $ ot_parents tree)
  where
    leafRoot =
      case ot_right tree of
        Just r' -> leaf (fromJust $ ot_left tree) (pos - 1) 0 <> leaf r' pos 0
        Nothing -> leaf (fromJust $ ot_left tree) pos 0 <> EmptyLeaf
    pos = fromIntegral $ orchardSize tree - 1