From 42f6b6becb1ade1e5713e9cc8387042429754a3f Mon Sep 17 00:00:00 2001 From: Rene Vergara Date: Tue, 19 Nov 2024 07:26:27 -0600 Subject: [PATCH] fix: tree truncation function --- src/Zenith/Tree.hs | 51 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/src/Zenith/Tree.hs b/src/Zenith/Tree.hs index c78331a..9e70530 100644 --- a/src/Zenith/Tree.hs +++ b/src/Zenith/Tree.hs @@ -9,9 +9,11 @@ module Zenith.Tree where import Codec.Borsh +import Control.Monad.Logger (LoggingT, logDebugN) import Data.HexString import Data.Int (Int32, Int64, Int8) import Data.Maybe (fromJust, isNothing) +import qualified Data.Text as T import qualified GHC.Generics as GHC import qualified Generics.SOP as SOP import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue) @@ -179,14 +181,49 @@ getNotePosition (Branch _ x y) i | otherwise = Nothing getNotePosition _ _ = Nothing -truncateTree :: Monoid v => Node v => Tree v -> Int64 -> Tree v +truncateTree :: Monoid v => Node v => Tree v -> Int64 -> LoggingT IO (Tree v) truncateTree (Branch s x y) i - | getLevel s == 1 && getIndex (value x) == i = branch x EmptyLeaf - | getLevel s == 1 && getIndex (value y) == i = branch x y - | getIndex (value x) >= i = - branch (truncateTree x i) (getEmptyRoot (getLevel s)) - | getIndex (value y) >= i = branch x (truncateTree y i) -truncateTree x _ = x + | getLevel s == 1 && getIndex (value x) == i = do + logDebugN $ T.pack $ show (getLevel s) ++ " Trunc to left leaf" + return $ branch x EmptyLeaf + | getLevel s == 1 && getIndex (value y) == i = do + logDebugN $ T.pack $ show (getLevel s) ++ " Trunc to right leaf" + return $ branch x y + | getIndex (value x) >= i = do + logDebugN $ + T.pack $ + show (getLevel s) ++ + ": " ++ show i ++ " left i: " ++ show (getIndex (value x)) + l <- truncateTree x i + return $ branch (l) (getEmptyRoot (getLevel (value x))) + | getIndex (value y) /= 0 && getIndex (value y) >= i = do + logDebugN $ + T.pack $ + show (getLevel s) ++ + ": " ++ show i ++ " right i: " ++ show (getIndex (value y)) + r <- truncateTree y i + return $ branch x (r) + | otherwise = do + logDebugN $ + T.pack $ + show (getLevel s) ++ + ": " ++ + show (getIndex (value x)) ++ " catchall " ++ show (getIndex (value y)) + return InvalidTree +truncateTree x _ = return x + +countLeaves :: Node v => Tree v -> Int64 +countLeaves (Branch s x y) = + if isFull s + then 2 ^ getLevel s + else countLeaves x + countLeaves y +countLeaves (PrunedBranch x) = + if isFull x + then 2 ^ getLevel x + else 0 +countLeaves (Leaf _) = 1 +countLeaves EmptyLeaf = 0 +countLeaves InvalidTree = 0 data SaplingNode = SaplingNode { sn_position :: !Position