zenith/src/Zenith/Tree.hs

379 lines
11 KiB
Haskell
Raw Normal View History

2024-10-23 20:49:24 +00:00
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
2024-11-04 16:17:54 +00:00
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE UndecidableInstances #-}
2024-10-23 20:49:24 +00:00
module Zenith.Tree where
2024-11-04 16:17:54 +00:00
import Codec.Borsh
2024-11-19 13:26:27 +00:00
import Control.Monad.Logger (LoggingT, logDebugN)
2024-10-23 20:49:24 +00:00
import Data.HexString
2024-11-05 00:56:16 +00:00
import Data.Int (Int32, Int64, Int8)
2024-11-04 16:17:54 +00:00
import Data.Maybe (fromJust, isNothing)
2024-11-19 13:26:27 +00:00
import qualified Data.Text as T
2024-11-04 16:17:54 +00:00
import qualified GHC.Generics as GHC
import qualified Generics.SOP as SOP
import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue)
import ZcashHaskell.Sapling (combineSaplingNodes, getSaplingNodeValue)
import ZcashHaskell.Types
( MerklePath(..)
, OrchardFrontier(..)
, OrchardTree(..)
, SaplingTree(..)
)
2024-11-04 16:17:54 +00:00
-- | Height of a node within a 'Tree'. Leaves built by 'measure' sit at level 0
-- and every combination of two node values yields a value one level higher.
type Level = Int8
-- | Level at which 'root' treats a tree as complete and returns it unchanged.
maxLevel :: Level
maxLevel = 32
2024-11-05 00:56:16 +00:00
-- | Position recorded on every node value; 'append' gives a new leaf one more
-- than the position currently reported by the tree it is appended to.
type Position = Int32
2024-10-23 20:49:24 +00:00
-- | Items of type @a@ that can be summarised as a node value of type @v@.
class Monoid v => Measured a v where
  -- | Build the node value for an item, given the position and the index to
  -- record on it.
  measure :: a -> Position -> Int64 -> v
2024-10-23 20:49:24 +00:00
-- | Accessors and a constructor that every node value stored in a 'Tree'
-- has to provide.
class Node v where
  -- | Level of the node.
  getLevel :: v -> Level
  -- | Hash held by the node.
  getHash :: v -> HexString
  -- | Position recorded on the node.
  getPosition :: v -> Position
  -- | Index recorded on the node.
  getIndex :: v -> Int64
  -- | Whether the node is flagged as full.
  isFull :: v -> Bool
  -- | Whether the node is flagged as marked.
  isMarked :: v -> Bool
  -- | Build a node value from a level, a position and a hash.
  mkNode :: Level -> Position -> HexString -> v
2024-10-23 20:49:24 +00:00
-- | Hex-encoded commitment that is measured into an 'OrchardNode' leaf.
type OrchardCommitment = HexString

-- | Turns a commitment into a level-0, full, unmarked 'OrchardNode'.
instance Measured OrchardCommitment OrchardNode where
  measure oc p i =
    maybe rejected accepted (getOrchardNodeValue (hexBytes oc))
    where
      -- Used when 'getOrchardNodeValue' yields nothing: the supplied position
      -- and index are dropped and a "00" hash is stored instead.
      rejected = OrchardNode 0 (hexString "00") 0 True 0 False
      accepted val = OrchardNode p val 0 True i False
-- | Hex-encoded commitment that is measured into a 'SaplingNode' leaf.
type SaplingCommitment = HexString

-- | Turns a commitment into a level-0, full, unmarked 'SaplingNode'.
instance Measured SaplingCommitment SaplingNode where
  measure sc p i =
    maybe rejected accepted (getSaplingNodeValue (hexBytes sc))
    where
      -- Used when 'getSaplingNodeValue' yields nothing: the supplied position
      -- and index are dropped and a "00" hash is stored instead.
      rejected = SaplingNode 0 (hexString "00") 0 True 0 False
      accepted val = SaplingNode p val 0 True i False
-- | Binary tree whose leaves and branches carry a node value of type @v@.
data Tree v
  = EmptyLeaf -- ^ Slot that holds no value; 'value' reports 'mempty' for it
  | Leaf !v -- ^ Single node value
  | PrunedBranch !v -- ^ Subtree kept only as its node value
  | Branch !v !(Tree v) !(Tree v) -- ^ Node value plus left and right subtrees
  | InvalidTree -- ^ Result of combining trees that do not fit together
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsEnum (Tree v)
-- | Textual dump of a tree: leaves in parentheses, pruned branches in braces,
-- and each branch as its hash in angle brackets followed by both children on
-- their own lines.
instance (Node v, Show v) => Show (Tree v) where
  show tree =
    case tree of
      EmptyLeaf -> "()"
      Leaf v -> concat ["(", show v, ")"]
      PrunedBranch v -> concat ["{", show v, "}"]
      Branch s l r -> concat ["<", show (getHash s), ">\n", show l, "\n", show r]
      InvalidTree -> "InvalidTree"
-- | Combines two trees, the left operand being the existing tree and the
-- right operand the material added to it. Combinations that do not fit
-- together produce 'InvalidTree'. Clause order matters: the 'InvalidTree'
-- and 'EmptyLeaf' catch-alls must stay ahead of the specific cases.
instance (Monoid v, Node v) => Semigroup (Tree v) where
  -- Invalid operands are contagious.
  InvalidTree <> _ = InvalidTree
  _ <> InvalidTree = InvalidTree
  -- Empty left operand.
  EmptyLeaf <> EmptyLeaf = PrunedBranch . value $ branch EmptyLeaf EmptyLeaf
  EmptyLeaf <> t = t
  -- Leaf on the left.
  l@(Leaf _) <> EmptyLeaf = branch l EmptyLeaf
  l@(Leaf _) <> r@(Leaf _) = branch l r
  Leaf _ <> Branch {} = InvalidTree
  Leaf _ <> PrunedBranch _ = InvalidTree
  -- Pruned branch on the left.
  PrunedBranch x <> EmptyLeaf = PrunedBranch (x <> x)
  PrunedBranch x <> Leaf y
    | isFull x = InvalidTree
    | otherwise = mkSubTree (getLevel x) (Leaf y)
  l@(PrunedBranch x) <> r@(Branch s _ _)
    | getLevel x == getLevel s = branch l r
    | otherwise = InvalidTree
  PrunedBranch x <> PrunedBranch y = PrunedBranch (x <> y)
  -- Branch on the left.
  l@(Branch s _ _) <> EmptyLeaf = branch l (getEmptyRoot (getLevel s))
  l@(Branch s _ _) <> r@(PrunedBranch w)
    | getLevel s == getLevel w = branch l r
    | otherwise = InvalidTree
  Branch s x y <> Leaf w
    | isFull s = InvalidTree
    -- A full left subtree sends the new leaf to the right subtree.
    | isFull (value x) = branch x (y <> Leaf w)
    | otherwise = branch (x <> Leaf w) y
  l@(Branch s _ _) <> r@(Branch s1 _ _)
    | getLevel s == getLevel s1 = branch l r
    | otherwise = InvalidTree
-- | Node value at the top of a tree; 'mempty' for 'EmptyLeaf' and
-- 'InvalidTree'.
value :: Monoid v => Tree v -> v
value tree =
  case tree of
    Leaf v -> v
    PrunedBranch v -> v
    Branch v _ _ -> v
    EmptyLeaf -> mempty
    InvalidTree -> mempty
2024-10-23 20:49:24 +00:00
-- | Join two subtrees under a 'Branch' whose node value is the combination of
-- the subtrees' values.
branch :: Monoid v => Tree v -> Tree v -> Tree v
branch l r = Branch combined l r
  where
    combined = value l <> value r
2024-11-05 00:56:16 +00:00
-- | Wrap the measurement of an item, at the given position and index, in a
-- 'Leaf'.
leaf :: Measured a v => a -> Int32 -> Int64 -> Tree v
leaf item pos = Leaf . measure item pos
2024-10-23 20:49:24 +00:00
-- | Build a 'PrunedBranch' from a level, a position and a hash via 'mkNode'.
prunedBranch :: Monoid v => Node v => Level -> Position -> HexString -> Tree v
prunedBranch level pos = PrunedBranch . mkNode level pos
-- | Extend a tree up to 'maxLevel'. A tree that is already at that level is
-- returned untouched; anything else is handed to 'mkSubTree'.
root :: Monoid v => Node v => Tree v -> Tree v
root tree
  | getLevel (value tree) == maxLevel = tree
  | otherwise = mkSubTree maxLevel tree
-- | Tree obtained by starting from 'EmptyLeaf' and combining the current tree
-- with itself @level@ times.
getEmptyRoot :: Monoid v => Node v => Level -> Tree v
getEmptyRoot level = doublings !! fromIntegral level
  where
    doublings = iterate double EmptyLeaf
    double t = t <> t
2024-11-05 00:56:16 +00:00
-- | Add an item, paired with its index, to a tree. The new leaf is given the
-- position following the one reported by the tree's current value.
append :: Monoid v => Measured a v => Node v => Tree v -> (a, Int64) -> Tree v
append tree (item, idx) = tree <> leaf item nextPos idx
  where
    nextPos = 1 + getPosition (value tree)
-- | Grow a tree by repeatedly combining it with 'EmptyLeaf' on the right
-- until its value reports the requested level.
--
-- At least one combination is always applied. 'InvalidTree' is returned when
-- the target can no longer be reached: a step did not raise the level (for
-- instance because it produced 'InvalidTree', whose value is 'mempty'), or
-- the level went past the target. Previously these cases recursed forever.
mkSubTree :: Node v => Monoid v => Level -> Tree v -> Tree v
mkSubTree level t
  | reached == level = subtree
  | stalled || reached > level = InvalidTree
  | otherwise = mkSubTree level subtree
  where
    subtree = t <> EmptyLeaf
    reached = getLevel (value subtree)
    -- Without progress the recursion would see the same input again.
    stalled = reached <= getLevel (value t)
2024-11-05 00:56:16 +00:00
-- | Collect the hashes met while walking from the top of a 'Branch' towards
-- the given position and wrap them in a 'MerklePath'.
--
-- Returns 'Nothing' for any tree that is not a 'Branch', and when the number
-- of hashes collected differs from 'maxLevel'.
path :: Monoid v => Node v => Position -> Tree v -> Maybe MerklePath
path pos tree@Branch {}
  | length hashes == fromIntegral maxLevel = Just $ MerklePath pos hashes
  | otherwise = Nothing
  where
    -- Computed once and shared by the length check and the result.
    hashes = collectPath tree
    collectPath :: Monoid v => Node v => Tree v -> [HexString]
    collectPath EmptyLeaf = []
    collectPath Leaf {} = []
    collectPath PrunedBranch {} = []
    collectPath InvalidTree = []
    collectPath (Branch _ j k)
      -- Right subtree reports a non-zero position below the target: give up.
      | getPosition (value k) /= 0 && getPosition (value k) < pos = []
      -- Otherwise descend right or left, recording the other child's hash.
      | getPosition (value j) < pos = collectPath k <> [getHash (value j)]
      | getPosition (value j) >= pos = collectPath j <> [getHash (value k)]
      | otherwise = []
path _ _ = Nothing
2024-11-04 16:17:54 +00:00
-- | A 'MerklePath' with position 0 and no hashes.
nullPath :: MerklePath
nullPath = MerklePath 0 []
-- | Look up the position of the leaf whose index equals the one given.
--
-- Branches are searched left first: a subtree is entered when the index
-- reported by its value is at least the one sought. 'Nothing' is returned
-- when no subtree qualifies or the tree is neither a leaf nor a branch.
getNotePosition :: Monoid v => Node v => Tree v -> Int64 -> Maybe Position
getNotePosition tree i =
  case tree of
    Leaf x
      | getIndex x == i -> Just (getPosition x)
    Branch _ l r
      | getIndex (value l) >= i -> getNotePosition l i
      | getIndex (value r) >= i -> getNotePosition r i
    -- Reached both for other constructors and when every guard above fails.
    _ -> Nothing
2024-11-19 13:26:27 +00:00
-- | Cut a tree back to the leaf carrying the given index, logging each step
-- at debug level.
--
-- At level 1 the right leaf is dropped when the left leaf carries the index,
-- and the branch is kept whole when the right leaf does. Higher up, the walk
-- goes left when the left subtree's index is at least the one sought
-- (replacing the right subtree with an empty root of the left subtree's
-- level), otherwise right when the right subtree's index is non-zero and at
-- least the one sought. If neither applies the result is 'InvalidTree'.
-- Trees that are not a 'Branch' are returned unchanged.
truncateTree :: Monoid v => Node v => Tree v -> Int64 -> LoggingT IO (Tree v)
truncateTree (Branch s x y) i
  | lvl == 1 && leftIdx == i = do
      debug " Trunc to left leaf"
      pure $ branch x EmptyLeaf
  | lvl == 1 && rightIdx == i = do
      debug " Trunc to right leaf"
      pure $ branch x y
  | leftIdx >= i = do
      debug $ ": " ++ show i ++ " left i: " ++ show leftIdx
      l <- truncateTree x i
      pure $ branch l (getEmptyRoot (getLevel (value x)))
  | rightIdx /= 0 && rightIdx >= i = do
      debug $ ": " ++ show i ++ " right i: " ++ show rightIdx
      r <- truncateTree y i
      pure $ branch x r
  | otherwise = do
      debug $ ": " ++ show leftIdx ++ " catchall " ++ show rightIdx
      pure InvalidTree
  where
    lvl = getLevel s
    leftIdx = getIndex (value x)
    rightIdx = getIndex (value y)
    -- Every log line starts with the level of the branch being examined.
    debug msg = logDebugN . T.pack $ show lvl ++ msg
truncateTree t _ = pure t
-- | Number of leaves accounted for by a tree. A full branch or pruned branch
-- counts as @2 ^ level@; a pruned branch that is not full counts as 0.
countLeaves :: Node v => Tree v -> Int64
countLeaves tree =
  case tree of
    Branch s l r
      | isFull s -> 2 ^ getLevel s
      | otherwise -> countLeaves l + countLeaves r
    PrunedBranch p
      | isFull p -> 2 ^ getLevel p
      | otherwise -> 0
    Leaf _ -> 1
    EmptyLeaf -> 0
    InvalidTree -> 0
-- | Node value stored in a 'Tree' of Sapling note commitments.
data SaplingNode = SaplingNode
  { sn_position :: !Position -- ^ Reported by 'getPosition'
  , sn_value :: !HexString -- ^ Hash of the node, reported by 'getHash'
  , sn_level :: !Level -- ^ Reported by 'getLevel'
  , sn_full :: !Bool -- ^ Reported by 'isFull'
  , sn_index :: !Int64 -- ^ Reported by 'getIndex'
  , sn_mark :: !Bool -- ^ Reported by 'isMarked'
  } deriving stock (Eq, GHC.Generic)
    deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
    deriving (BorshSize, ToBorsh, FromBorsh) via AsStruct SaplingNode
-- | Combines two nodes into their parent: the hash comes from
-- 'combineSaplingNodes', the level is one above the left operand's, position
-- and index are the larger of the two, the result is full only if both
-- operands are, and marked if either is. When 'combineSaplingNodes' yields
-- nothing the left operand is returned unchanged.
instance Semigroup SaplingNode where
  l <> r =
    maybe l parent $ combineSaplingNodes (sn_level l) (sn_value l) (sn_value r)
    where
      parent newHash =
        SaplingNode
          { sn_position = max (sn_position l) (sn_position r)
          , sn_value = newHash
          , sn_level = 1 + sn_level l
          , sn_full = sn_full l && sn_full r
          , sn_index = max (sn_index l) (sn_index r)
          , sn_mark = sn_mark l || sn_mark r
          }
-- | 'mempty' is a level-0 node with hash "00" that is neither full nor marked.
instance Monoid SaplingNode where
  mempty =
    SaplingNode
      { sn_position = 0
      , sn_value = hexString "00"
      , sn_level = 0
      , sn_full = False
      , sn_index = 0
      , sn_mark = False
      }
  mappend = (<>)
-- | Field accessors exposed through 'Node'; 'mkNode' builds a full, unmarked
-- node with index 0.
instance Node SaplingNode where
  getLevel = sn_level
  getHash = sn_value
  getPosition = sn_position
  getIndex = sn_index
  isFull = sn_full
  isMarked = sn_mark
  mkNode lvl pos hash =
    SaplingNode
      { sn_position = pos
      , sn_value = hash
      , sn_level = lvl
      , sn_full = True
      , sn_index = 0
      , sn_mark = False
      }
-- | A node is shown as its hash only.
instance Show SaplingNode where
  show node = show (sn_value node)
-- | Size implied by a 'SaplingTree': one for each of the left and right
-- leaves that is present, plus @2 ^ n@ for every present entry of
-- 'st_parents', where @n@ is the entry's place in the list counted from 1.
saplingSize :: SaplingTree -> Int64
saplingSize tree = leafCount + sum parentWeights
  where
    present m =
      if isNothing m
        then 0
        else 1
    leafCount = present (st_left tree) + present (st_right tree)
    parentWeights =
      [2 ^ n | (n, Just _) <- zip [1 :: Integer ..] (st_parents tree)]
-- | Build a 'Tree' of 'SaplingNode's from a 'SaplingTree'.
--
-- The left and right leaves form the starting tree; the entries of
-- 'st_parents' are then folded in from level 1 upwards. A present parent is
-- attached as a pruned branch on the left of the tree built so far, an absent
-- one adds an empty root of that level on the right.
--
-- The left leaf is taken with 'fromJust', so a 'SaplingTree' without one
-- raises an exception when the result is evaluated.
mkSaplingTree :: SaplingTree -> Tree SaplingNode
mkSaplingTree tree = foldl attach leafRoot (zip [1 ..] (st_parents tree))
  where
    attach acc (lvl, Just parentHash) = prunedBranch lvl 0 parentHash <> acc
    attach acc (lvl, Nothing) = acc <> getEmptyRoot lvl
    -- The last leaf gets the position derived from the tree size.
    pos = fromIntegral $ saplingSize tree - 1
    leftLeaf = fromJust (st_left tree)
    leafRoot =
      maybe
        (leaf leftLeaf pos 0 <> EmptyLeaf)
        (\r' -> leaf leftLeaf (pos - 1) 0 <> leaf r' pos 0)
        (st_right tree)
-- | Orchard
2024-10-23 20:49:24 +00:00
-- Node value stored in a 'Tree' of Orchard note commitments.
data OrchardNode = OrchardNode
  { on_position :: !Position -- ^ Reported by 'getPosition'
  , on_value :: !HexString -- ^ Hash of the node, reported by 'getHash'
  , on_level :: !Level -- ^ Reported by 'getLevel'
  , on_full :: !Bool -- ^ Reported by 'isFull'
  , on_index :: !Int64 -- ^ Reported by 'getIndex'
  , on_mark :: !Bool -- ^ Reported by 'isMarked'
  } deriving stock (Eq, GHC.Generic)
    deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
    deriving (BorshSize, ToBorsh, FromBorsh) via AsStruct OrchardNode
2024-10-23 20:49:24 +00:00
-- | Combines two nodes into their parent: the hash comes from
-- 'combineOrchardNodes', the level is one above the left operand's, position
-- and index are the larger of the two, the result is full only if both
-- operands are, and marked if either is. When 'combineOrchardNodes' yields
-- nothing the left operand is returned unchanged.
instance Semigroup OrchardNode where
  l <> r = maybe l parent combinedHash
    where
      combinedHash =
        combineOrchardNodes
          (fromIntegral $ on_level l)
          (on_value l)
          (on_value r)
      parent newHash =
        OrchardNode
          { on_position = max (on_position l) (on_position r)
          , on_value = newHash
          , on_level = 1 + on_level l
          , on_full = on_full l && on_full r
          , on_index = max (on_index l) (on_index r)
          , on_mark = on_mark l || on_mark r
          }
2024-10-23 20:49:24 +00:00
-- | 'mempty' is a level-0 node with hash "00" that is neither full nor marked.
instance Monoid OrchardNode where
  mempty =
    OrchardNode
      { on_position = 0
      , on_value = hexString "00"
      , on_level = 0
      , on_full = False
      , on_index = 0
      , on_mark = False
      }
  mappend = (<>)
-- | Field accessors exposed through 'Node'; 'mkNode' builds a full, unmarked
-- node with index 0.
instance Node OrchardNode where
  getLevel = on_level
  getHash = on_value
  getPosition = on_position
  getIndex = on_index
  isFull = on_full
  isMarked = on_mark
  mkNode lvl pos hash =
    OrchardNode
      { on_position = pos
      , on_value = hash
      , on_level = lvl
      , on_full = True
      , on_index = 0
      , on_mark = False
      }
-- | A node is shown as its hash only.
instance Show OrchardNode where
  show node = show (on_value node)
-- | Measuring an existing node keeps its hash, level and flags and only
-- replaces the position and the index.
instance Measured OrchardNode OrchardNode where
  measure o p i = o {on_position = p, on_index = i}
-- | Size implied by an 'OrchardTree': one for each of the left and right
-- leaves that is present, plus @2 ^ n@ for every present entry of
-- 'ot_parents', where @n@ is the entry's place in the list counted from 1.
orchardSize :: OrchardTree -> Int64
orchardSize tree = leafCount + sum parentWeights
  where
    present m =
      if isNothing m
        then 0
        else 1
    leafCount = present (ot_left tree) + present (ot_right tree)
    parentWeights =
      [2 ^ n | (n, Just _) <- zip [1 :: Integer ..] (ot_parents tree)]
-- | Build a 'Tree' of 'OrchardNode's from an 'OrchardTree'.
--
-- The left and right leaves form the starting tree; the entries of
-- 'ot_parents' are then folded in from level 1 upwards. A present parent is
-- attached as a pruned branch on the left of the tree built so far, an absent
-- one adds an empty root of that level on the right.
--
-- The left leaf is taken with 'fromJust', so an 'OrchardTree' without one
-- raises an exception when the result is evaluated.
mkOrchardTree :: OrchardTree -> Tree OrchardNode
mkOrchardTree tree = foldl attach leafRoot (zip [1 ..] (ot_parents tree))
  where
    attach acc (lvl, Just parentHash) = prunedBranch lvl 0 parentHash <> acc
    attach acc (lvl, Nothing) = acc <> getEmptyRoot lvl
    -- The last leaf gets the position derived from the tree size.
    pos = fromIntegral $ orchardSize tree - 1
    leftLeaf = fromJust (ot_left tree)
    leafRoot =
      maybe
        (leaf leftLeaf pos 0 <> EmptyLeaf)
        (\r' -> leaf leftLeaf (pos - 1) 0 <> leaf r' pos 0)
        (ot_right tree)