{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Merkle tree of Sapling and Orchard note commitments, with support for
-- appending leaves, padding to the root level, truncating, and extracting
-- Merkle paths.
module Zenith.Tree where

import Codec.Borsh
import Control.Monad.Logger (NoLoggingT, logDebugN)
import Data.HexString
import Data.Int (Int32, Int64, Int8)
import Data.List (foldl')
import Data.Maybe (fromJust, isNothing)
import qualified Data.Text as T
import qualified GHC.Generics as GHC
import qualified Generics.SOP as SOP
import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue)
import ZcashHaskell.Sapling (combineSaplingNodes, getSaplingNodeValue)
import ZcashHaskell.Types (MerklePath(..), OrchardTree(..), SaplingTree(..))

-- | Height of a node in the tree. Leaf nodes produced by 'measure' sit at
-- level 0.
type Level = Int8

-- | Level of a complete tree; 'root' pads trees up to this height and 'path'
-- expects this many sibling hashes.
maxLevel :: Level
maxLevel = 32

-- | Position of a leaf within the tree.
type Position = Int32

-- | Types that can be turned into a node value @v@, given a leaf position and
-- an 'Int64' index that is stored on the node.
-- NOTE(review): the index presumably identifies the note in the wallet
-- database; confirm against callers.
class Monoid v => Measured a v where
  measure :: a -> Position -> Int64 -> v

-- | Accessors and a constructor for the node values stored in a 'Tree'.
class Node v where
  getLevel :: v -> Level
  getHash :: v -> HexString
  getPosition :: v -> Position
  getIndex :: v -> Int64
  isFull :: v -> Bool
  isMarked :: v -> Bool
  mkNode :: Level -> Position -> HexString -> v

-- | Raw Orchard note commitment.
type OrchardCommitment = HexString

-- | Builds a full, unmarked level-0 Orchard node from a commitment. If
-- 'getOrchardNodeValue' rejects the bytes, the result is a placeholder node
-- with position 0, hash @00@ and index 0.
instance Measured OrchardCommitment OrchardNode where
  measure oc p i =
    case getOrchardNodeValue (hexBytes oc) of
      Nothing -> OrchardNode 0 (hexString "00") 0 True 0 False
      Just val -> OrchardNode p val 0 True i False

-- | Raw Sapling note commitment.
type SaplingCommitment = HexString

-- | Builds a full, unmarked level-0 Sapling node from a commitment. If
-- 'getSaplingNodeValue' rejects the bytes, the result is a placeholder node
-- with position 0, hash @00@ and index 0.
instance Measured SaplingCommitment SaplingNode where
  measure sc p i =
    case getSaplingNodeValue (hexBytes sc) of
      Nothing -> SaplingNode 0 (hexString "00") 0 True 0 False
      Just val -> SaplingNode p val 0 True i False

-- | A Merkle tree whose nodes carry values of type @v@.
--
-- * 'EmptyLeaf': an unfilled leaf slot.
-- * 'Leaf': a leaf holding a measured node value.
-- * 'PrunedBranch': a subtree represented only by its root value.
-- * 'Branch': an internal node with its combined value and two children.
-- * 'InvalidTree': sentinel returned by combinations this module rejects.
data Tree v
  = EmptyLeaf
  | Leaf !v
  | PrunedBranch !v
  | Branch !v !(Tree v) !(Tree v)
  | InvalidTree
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsEnum (Tree v)

-- | Debug rendering: leaves in parentheses, pruned branches in braces, and
-- branches as their hash in angle brackets followed by both children on
-- separate lines.
instance (Node v, Show v) => Show (Tree v) where
  show EmptyLeaf = "()"
  show (Leaf v) = "(" ++ show v ++ ")"
  show (PrunedBranch v) = "{" ++ show v ++ "}"
  show (Branch s x y) =
    "<" ++ show (getHash s) ++ ">\n" ++ show x ++ "\n" ++ show y
  show InvalidTree = "InvalidTree"

-- | Combines two trees. Combinations that cannot form a valid tree (e.g.
-- mismatched levels, or appending to a full branch) yield 'InvalidTree', and
-- 'InvalidTree' absorbs anything it is combined with.
instance (Monoid v, Node v) => Semigroup (Tree v) where
  (<>) InvalidTree _ = InvalidTree
  (<>) _ InvalidTree = InvalidTree
  (<>) EmptyLeaf EmptyLeaf = PrunedBranch $ value $ branch EmptyLeaf EmptyLeaf
  (<>) EmptyLeaf x = x
  (<>) (Leaf x) EmptyLeaf = branch (Leaf x) EmptyLeaf
  (<>) (Leaf x) (Leaf y) = branch (Leaf x) (Leaf y)
  (<>) (Leaf _) Branch {} = InvalidTree
  (<>) (Leaf _) (PrunedBranch _) = InvalidTree
  -- NOTE(review): this pairs x with itself, which presumably assumes x is the
  -- root of an empty subtree; confirm.
  (<>) (PrunedBranch x) EmptyLeaf = PrunedBranch $ x <> x
  -- NOTE(review): a non-full x is dropped here and replaced by a fresh
  -- subtree of x's level holding only the new leaf; confirm this is intended.
  (<>) (PrunedBranch x) (Leaf y) =
    if isFull x
      then InvalidTree
      else mkSubTree (getLevel x) (Leaf y)
  (<>) (PrunedBranch x) (Branch s t u) =
    if getLevel x == getLevel s
      then branch (PrunedBranch x) (Branch s t u)
      else InvalidTree
  (<>) (PrunedBranch x) (PrunedBranch y) = PrunedBranch $ x <> y
  (<>) (Branch s x y) EmptyLeaf =
    branch (Branch s x y) $ getEmptyRoot (getLevel s)
  (<>) (Branch s x y) (PrunedBranch w)
    | getLevel s == getLevel w = branch (Branch s x y) (PrunedBranch w)
    | otherwise = InvalidTree
  -- Appending a leaf fills the left child first, then the right one.
  (<>) (Branch s x y) (Leaf w)
    | isFull s = InvalidTree
    | isFull (value x) = branch x (y <> Leaf w)
    | otherwise = branch (x <> Leaf w) y
  (<>) (Branch s x y) (Branch s1 x1 y1)
    | getLevel s == getLevel s1 = branch (Branch s x y) (Branch s1 x1 y1)
    | otherwise = InvalidTree

-- | Node value at the root of a tree; 'mempty' for 'EmptyLeaf' and
-- 'InvalidTree'.
value :: Monoid v => Tree v -> v
value EmptyLeaf = mempty
value (Leaf v) = v
value (PrunedBranch v) = v
value (Branch v _ _) = v
value InvalidTree = mempty

-- | Builds a 'Branch' whose value is the '<>' of its children's values.
branch :: Monoid v => Tree v -> Tree v -> Tree v
branch x y = Branch (value x <> value y) x y

-- | Wraps a measured item as a 'Leaf' at the given position and index.
leaf :: Measured a v => a -> Int32 -> Int64 -> Tree v
leaf a p i = Leaf (measure a p i)

-- | A 'PrunedBranch' whose value is built with 'mkNode' from a level,
-- position and hash.
prunedBranch :: Monoid v => Node v => Level -> Position -> HexString -> Tree v
prunedBranch level pos val = PrunedBranch $ mkNode level pos val

-- | Returns the tree unchanged if its root is already at 'maxLevel';
-- otherwise pads it with empty subtrees up to 'maxLevel' via 'mkSubTree'.
root :: Monoid v => Node v => Tree v -> Tree v
root tree =
  if getLevel (value tree) == maxLevel
    then tree
    else mkSubTree maxLevel tree

-- | Empty subtree of the given height, built by repeatedly combining an empty
-- tree with itself starting from 'EmptyLeaf' (level 0 yields 'EmptyLeaf').
-- The level must be non-negative: a negative level makes '!!' fail.
getEmptyRoot :: Monoid v => Node v => Level -> Tree v
getEmptyRoot level = iterate (\x -> x <> x) EmptyLeaf !! fromIntegral level

-- | Appends one item with index @i@. Its position is one past the tree's
-- current root position.
append :: Monoid v => Measured a v => Node v => Tree v -> (a, Int64) -> Tree v
append tree (n, i) = tree <> leaf n p i
  where
    p = 1 + getPosition (value tree)

-- | Pads @t@ on the right with empty subtrees (@t <> EmptyLeaf@) until its
-- root reaches @level@.
--
-- Returns 'InvalidTree' instead of recursing forever when padding yields
-- 'InvalidTree', overshoots @level@, or stops raising the level (for example
-- when the node combination fails and returns its left operand unchanged).
mkSubTree :: Node v => Monoid v => Level -> Tree v -> Tree v
mkSubTree level t =
  case subtree of
    InvalidTree -> InvalidTree
    _
      | subLevel == level -> subtree
      | subLevel > level || subLevel <= getLevel (value t) -> InvalidTree
      | otherwise -> mkSubTree level subtree
  where
    subtree = t <> EmptyLeaf
    subLevel = getLevel (value subtree)

-- | Merkle path for the leaf at @pos@. Sibling hashes are collected by
-- descending from the root and are listed from the leaf level upwards.
--
-- Returns 'Nothing' unless the tree is a 'Branch' and exactly 'maxLevel'
-- sibling hashes are collected.
path :: Monoid v => Node v => Position -> Tree v -> Maybe MerklePath
path pos tree@(Branch _ _ _) =
  if length authPath == fromIntegral maxLevel
    then Just $ MerklePath pos authPath
    else Nothing
  where
    authPath = collectPath tree
    collectPath :: Monoid v => Node v => Tree v -> [HexString]
    collectPath EmptyLeaf = []
    collectPath Leaf {} = []
    collectPath PrunedBranch {} = []
    collectPath InvalidTree = []
    collectPath (Branch _ j k)
      -- The right child's position is non-zero yet below pos: this subtree
      -- does not reach pos, so the path is abandoned.
      | getPosition (value k) /= 0 && getPosition (value k) < pos = []
      | getPosition (value j) < pos = collectPath k <> [getHash (value j)]
      | otherwise = collectPath j <> [getHash (value k)]
path _ _ = Nothing

-- | A 'MerklePath' at position 0 with no sibling hashes.
nullPath :: MerklePath
nullPath = MerklePath 0 []

-- | Position of the 'Leaf' whose index equals @i@. At each 'Branch' it descends
-- into the first child whose value's index is at least @i@. Returns
-- 'Nothing' if no such leaf is reached.
getNotePosition :: Monoid v => Node v => Tree v -> Int64 -> Maybe Position
getNotePosition (Leaf x) i
  | getIndex x == i = Just $ getPosition x
  | otherwise = Nothing
getNotePosition (Branch _ x y) i
  | getIndex (value x) >= i = getNotePosition x i
  | getIndex (value y) >= i = getNotePosition y i
  | otherwise = Nothing
getNotePosition _ _ = Nothing

-- | Truncates the tree after the note with index @i@, logging each step.
--
-- * At a level-1 branch whose left leaf has index @i@, the right leaf is
--   replaced by 'EmptyLeaf'. If the right leaf has index @i@, the branch is
--   rebuilt unchanged.
-- * If the left child's index is at least @i@, the left child is truncated
--   recursively and the right child becomes an empty subtree of the same
--   level.
-- * Otherwise, if the right child's index is non-zero and at least @i@, the
--   right child is truncated recursively.
-- * Otherwise the result is 'InvalidTree'.
--
-- Trees that are not a 'Branch' are returned unchanged.
truncateTree ::
     Monoid v => Node v => Tree v -> Int64 -> NoLoggingT IO (Tree v)
truncateTree (Branch s x y) i
  | getLevel s == 1 && getIndex (value x) == i = do
    logDebugN $ T.pack $ show (getLevel s) ++ " Trunc to left leaf"
    return $ branch x EmptyLeaf
  | getLevel s == 1 && getIndex (value y) == i = do
    logDebugN $ T.pack $ show (getLevel s) ++ " Trunc to right leaf"
    return $ branch x y
  | getIndex (value x) >= i = do
    logDebugN $ T.pack $
      show (getLevel s) ++
      ": " ++ show i ++ " left i: " ++ show (getIndex (value x))
    l <- truncateTree x i
    return $ branch l (getEmptyRoot (getLevel (value x)))
  | getIndex (value y) /= 0 && getIndex (value y) >= i = do
    logDebugN $ T.pack $
      show (getLevel s) ++
      ": " ++ show i ++ " right i: " ++ show (getIndex (value y))
    r <- truncateTree y i
    return $ branch x r
  | otherwise = do
    logDebugN $ T.pack $
      show (getLevel s) ++
      ": " ++
      show (getIndex (value x)) ++ " catchall " ++ show (getIndex (value y))
    return InvalidTree
truncateTree x _ = return x

-- | Number of occupied leaves. A full subtree counts as @2 ^ level@. A
-- non-full 'PrunedBranch' counts as 0, since its contents are unknown.
countLeaves :: Node v => Tree v -> Int64
countLeaves (Branch s x y) =
  if isFull s
    then 2 ^ getLevel s
    else countLeaves x + countLeaves y
countLeaves (PrunedBranch x) =
  if isFull x
    then 2 ^ getLevel x
    else 0
countLeaves (Leaf _) = 1
countLeaves EmptyLeaf = 0
countLeaves InvalidTree = 0

-- | Appends many @(position, (item, index))@ notes at once. Notes are split
-- between the left and right children according to the free capacity on
-- the left. An 'EmptyLeaf' accepts exactly one note. Any other
-- non-empty input into a leaf-level or full subtree yields 'InvalidTree'.
batchAppend ::
     Measured a v
  => Node v => Monoid v => Tree v -> [(Int32, (a, Int64))] -> Tree v
batchAppend x [] = x
batchAppend (Branch s x y) notes
  | isFull s = InvalidTree
  | isFull (value x) = branch x (batchAppend y notes)
  | otherwise =
    branch
      (batchAppend x (take leftSide notes))
      (batchAppend y (drop leftSide notes))
  where
    -- Free leaf slots remaining in the left child.
    leftSide = fromIntegral $ 2 ^ getLevel (value x) - countLeaves x
batchAppend (PrunedBranch k) notes
  | isFull k = InvalidTree
  | otherwise =
    branch
      (batchAppend (getEmptyRoot (getLevel k - 1)) (take leftSide notes))
      (batchAppend (getEmptyRoot (getLevel k - 1)) (drop leftSide notes))
  where
    -- Capacity of one child of k.
    leftSide = fromIntegral $ 2 ^ (getLevel k - 1)
batchAppend EmptyLeaf [(p, (n, i))] = leaf n p i
batchAppend _ _ = InvalidTree

-- | Number of leaves covered by a commitment-tree frontier. Each present
-- leaf counts as 1, and a present parent at list offset @k@ (starting from 1)
-- counts as @2 ^ k@.
frontierSize :: Maybe a -> Maybe a -> [Maybe b] -> Int64
frontierSize left right parents =
  occupied left + occupied right +
  sum [2 ^ k | (k, Just _) <- zip [1 :: Int ..] parents]
  where
    occupied m = if isNothing m then 0 else 1

-- | Builds a 'Tree' from a frontier made of left/right leaves and parent
-- hashes. Each parent at list offset @k@ (starting from 1) is either
-- prepended as a level-@k@ 'PrunedBranch' or, when absent, replaced by an
-- empty subtree appended on the right.
--
-- NOTE(review): @left@ must be 'Just'; 'fromJust' fails on an empty
-- frontier. Confirm callers never pass one.
frontierToTree ::
     (Measured a v, Monoid v, Node v)
  => Maybe a
  -> Maybe a
  -> [Maybe HexString]
  -> Tree v
frontierToTree left right parents =
  foldl' addParent leafRoot (zip [1 ..] parents)
  where
    addParent t (k, Just n') = prunedBranch k 0 n' <> t
    addParent t (k, Nothing) = t <> getEmptyRoot k
    leafRoot =
      case right of
        Just r' -> leaf (fromJust left) (pos - 1) 0 <> leaf r' pos 0
        Nothing -> leaf (fromJust left) pos 0 <> EmptyLeaf
    -- Position of the last leaf in the frontier.
    pos = fromIntegral $ frontierSize left right parents - 1

-- | Sapling tree node.
data SaplingNode = SaplingNode
  { sn_position :: !Position -- ^ leaf position; combined nodes keep the max
  , sn_value :: !HexString -- ^ node hash
  , sn_level :: !Level -- ^ height of the node
  , sn_full :: !Bool -- ^ combined nodes are full only if both children are
  , sn_index :: !Int64 -- ^ index; combined nodes keep the max
  , sn_mark :: !Bool -- ^ combined nodes are marked if either child is
  }
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsStruct SaplingNode

-- | Hashes two nodes into their parent at @level + 1@ via
-- 'combineSaplingNodes'. If hashing fails, the left node is returned
-- unchanged.
instance Semigroup SaplingNode where
  (<>) x y =
    case combineSaplingNodes (sn_level x) (sn_value x) (sn_value y) of
      Nothing -> x
      Just newHash ->
        SaplingNode
          (max (sn_position x) (sn_position y))
          newHash
          (1 + sn_level x)
          (sn_full x && sn_full y)
          (max (sn_index x) (sn_index y))
          (sn_mark x || sn_mark y)

-- | 'mempty' is a non-full level-0 placeholder with hash @00@.
instance Monoid SaplingNode where
  mempty = SaplingNode 0 (hexString "00") 0 False 0 False
  mappend = (<>)

-- | 'mkNode' builds a full, unmarked node with index 0.
instance Node SaplingNode where
  getLevel = sn_level
  getHash = sn_value
  getPosition = sn_position
  getIndex = sn_index
  isFull = sn_full
  isMarked = sn_mark
  mkNode l p v = SaplingNode p v l True 0 False

-- | Shows only the node hash.
instance Show SaplingNode where
  show = show . sn_value

-- | Number of leaves covered by a Sapling frontier.
saplingSize :: SaplingTree -> Int64
saplingSize tree =
  frontierSize (st_left tree) (st_right tree) (st_parents tree)

-- | Builds a Sapling 'Tree' from a Sapling frontier (see 'frontierToTree').
mkSaplingTree :: SaplingTree -> Tree SaplingNode
mkSaplingTree tree =
  frontierToTree (st_left tree) (st_right tree) (st_parents tree)

-- | Orchard tree node.
data OrchardNode = OrchardNode
  { on_position :: !Position -- ^ leaf position; combined nodes keep the max
  , on_value :: !HexString -- ^ node hash
  , on_level :: !Level -- ^ height of the node
  , on_full :: !Bool -- ^ combined nodes are full only if both children are
  , on_index :: !Int64 -- ^ index; combined nodes keep the max
  , on_mark :: !Bool -- ^ combined nodes are marked if either child is
  }
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsStruct OrchardNode

-- | Hashes two nodes into their parent at @level + 1@ via
-- 'combineOrchardNodes'. If hashing fails, the left node is returned
-- unchanged.
instance Semigroup OrchardNode where
  (<>) x y =
    case combineOrchardNodes
           (fromIntegral $ on_level x)
           (on_value x)
           (on_value y) of
      Nothing -> x
      Just newHash ->
        OrchardNode
          (max (on_position x) (on_position y))
          newHash
          (1 + on_level x)
          (on_full x && on_full y)
          (max (on_index x) (on_index y))
          (on_mark x || on_mark y)

-- | 'mempty' is a non-full level-0 placeholder with hash @00@.
instance Monoid OrchardNode where
  mempty = OrchardNode 0 (hexString "00") 0 False 0 False
  mappend = (<>)

-- | 'mkNode' builds a full, unmarked node with index 0.
instance Node OrchardNode where
  getLevel = on_level
  getHash = on_value
  getPosition = on_position
  getIndex = on_index
  isFull = on_full
  isMarked = on_mark
  mkNode l p v = OrchardNode p v l True 0 False

-- | Shows only the node hash.
instance Show OrchardNode where
  show = show . on_value

-- | Re-measures an existing Orchard node: it keeps the hash, level, fullness
-- and mark, and takes the new position and index.
instance Measured OrchardNode OrchardNode where
  measure o p i =
    OrchardNode p (on_value o) (on_level o) (on_full o) i (on_mark o)

-- | Number of leaves covered by an Orchard frontier.
orchardSize :: OrchardTree -> Int64
orchardSize tree =
  frontierSize (ot_left tree) (ot_right tree) (ot_parents tree)

-- | Builds an Orchard 'Tree' from an Orchard frontier (see 'frontierToTree').
mkOrchardTree :: OrchardTree -> Tree OrchardNode
mkOrchardTree tree =
  frontierToTree (ot_left tree) (ot_right tree) (ot_parents tree)