-- Source file: zenith/src/Zenith/Tree.hs
-- Last modified: 2024-10-23 20:49:24 +00:00
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
-- | Annotated binary trees used to rebuild and extend the Orchard
-- note-commitment tree from a frontier.
module Zenith.Tree where
import Data.Int (Int64)
import Data.List (foldl')

import Data.HexString
import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue)
import ZcashHaskell.Types (OrchardFrontier(..))
-- | Height of a node in the tree. Leaves built in this module sit at
-- level 0.
type Level = Integer

-- | Number of levels in the tree. 'root' grows a tree until its root
-- annotation reaches level @maxLevel - 1@.
maxLevel :: Level
maxLevel = 32

-- | Position of a leaf in the tree. Internal 'OrchardNode's carry the
-- larger of their children's positions.
type Position = Int64
-- | Items of type @a@ that can be summarised as a monoidal annotation @v@
-- once their position in the tree is known.
class Monoid v => Measured a v where
  -- | Annotation for the item placed at the given position.
  measure :: a -> Position -> v
-- | Accessors and a constructor for the annotation stored at a tree node.
class Node v where
  -- | Height of the node.
  getLevel :: v -> Level
  -- | Value carried by the node (for 'OrchardNode', its hash value).
  getTag :: v -> HexString
  -- | Position recorded on the node.
  getPosition :: v -> Position
  -- | Whether the node is marked full; the 'Tree' append uses this to
  -- decide where a new leaf goes.
  isFull :: v -> Bool
  -- | Build a node from a level, a position and a value.
  mkNode :: Level -> Position -> HexString -> v
-- | A raw Orchard note commitment in hex form.
type OrchardCommitment = HexString

-- | Measure a commitment as a level-0 leaf node marked full, at the given
-- position.
-- NOTE(review): when 'getOrchardNodeValue' returns 'Nothing', the fallback
-- node uses position 0 and value @00@, so @p@ is dropped. Confirm this is
-- intended.
instance Measured OrchardCommitment OrchardNode where
  measure oc p = maybe fallback atPosition (getOrchardNodeValue (hexBytes oc))
    where
      fallback = OrchardNode 0 (hexString "00") 0 True
      atPosition val = OrchardNode p val 0 True
-- | A binary tree annotated with values of type @v@. Internal nodes built
-- by 'branch' cache the '<>' of their children's annotations.
data Tree v
  = EmptyLeaf
  -- ^ An unoccupied slot.
  | Leaf !v
  -- ^ An occupied leaf with its annotation.
  | PrunedBranch !v
  -- ^ A subtree kept only as its root annotation; no children are stored.
  | Branch !v !(Tree v) !(Tree v)
  -- ^ An internal node: annotation, left child, right child.
  deriving Eq
-- | Debug rendering:
--
-- * an empty leaf is @()@;
-- * a leaf is @(v)@;
-- * a pruned branch is @{v}@;
-- * a branch is its node's tag in angle brackets, followed by the left and
--   right children on their own lines.
instance (Node v, Show v) => Show (Tree v) where
  show tree = case tree of
    EmptyLeaf -> "()"
    Leaf v -> concat ["(", show v, ")"]
    PrunedBranch v -> concat ["{", show v, "}"]
    Branch s l r -> concat ["<", show (getTag s), ">\n", show l, "\n", show r]
-- | Append for trees. It inserts leaves into a tree and grows a tree upward.
--
-- NOTE(review): this is not a lawful (associative) 'Semigroup'. Several
-- equations discard an operand outright. For example, @Leaf x <> Branch{}@
-- keeps only the leaf, so @Leaf a <> (Leaf b <> Leaf c)@ is just @Leaf a@,
-- while @(Leaf a <> Leaf b) <> Leaf c@ is a branch. Treat this as a directed
-- "add to this tree" operation and do not rely on semigroup laws.
instance (Monoid v, Node v) => Semigroup (Tree v) where
  -- Two empty leaves collapse into a pruned branch whose annotation is
  -- @mempty <> mempty@.
  (<>) EmptyLeaf EmptyLeaf = PrunedBranch $ value $ branch EmptyLeaf EmptyLeaf
  -- An empty left operand is dropped.
  (<>) EmptyLeaf x = x
  -- A leaf followed by an empty slot or another leaf forms a new branch.
  (<>) (Leaf x) EmptyLeaf = branch (Leaf x) EmptyLeaf
  (<>) (Leaf x) (Leaf y) = branch (Leaf x) (Leaf y)
  -- A leaf discards a branch or pruned branch on its right.
  (<>) (Leaf x) Branch {} = Leaf x
  (<>) (Leaf x) (PrunedBranch _) = Leaf x
  -- Extending a pruned branch with an empty slot combines its annotation
  -- with itself (@x <> x@).
  (<>) (PrunedBranch x) EmptyLeaf = PrunedBranch $ x <> x
  (<>) (PrunedBranch x) (Leaf _) = PrunedBranch x
  -- A pruned branch and a branch at the same level become siblings.
  -- On a level mismatch BOTH operands are discarded (result is EmptyLeaf).
  (<>) (PrunedBranch x) (Branch s t u) =
    if getLevel x == getLevel s
      then branch (PrunedBranch x) (Branch s t u)
      else EmptyLeaf
  (<>) (PrunedBranch x) (PrunedBranch y) = PrunedBranch $ x <> y
  -- Growing a branch with an empty slot pairs it with 'getEmptyRoot' taken
  -- at this branch's level.
  (<>) (Branch s x y) EmptyLeaf =
    branch (Branch s x y) $ getEmptyRoot (getLevel s)
  (<>) (Branch s x y) (PrunedBranch _) = Branch s x y
  -- Inserting a leaf:
  -- * if this branch is full, wrap the leaf with 'mkSubTree' up to this
  --   branch's level and append that subtree;
  -- * else if the left child is full, insert into the right child;
  -- * otherwise insert into the left child.
  (<>) (Branch s x y) (Leaf w)
    | isFull s = Branch s x y <> mkSubTree (getLevel s) (Leaf w)
    | isFull (value x) = branch x (y <> Leaf w)
    | otherwise = branch (x <> Leaf w) y
  -- Two branches at the same level become siblings. Otherwise the right
  -- operand is discarded.
  (<>) (Branch s x y) (Branch s1 x1 y1)
    | getLevel s == getLevel s1 = branch (Branch s x y) (Branch s1 x1 y1)
    | otherwise = Branch s x y
-- | The annotation at the root of a tree. An empty leaf has 'mempty'.
value :: Monoid v => Tree v -> v
value t = case t of
  EmptyLeaf -> mempty
  Leaf v -> v
  PrunedBranch v -> v
  Branch v _ _ -> v
-- | Join two subtrees under a new internal node. The node's annotation is
-- the left child's annotation combined (via '<>') with the right child's.
branch :: Monoid v => Tree v -> Tree v -> Tree v
branch left right = Branch combined left right
  where
    combined = value left <> value right
-- | Build a leaf by measuring an item at the given position.
leaf :: Measured a v => a -> Int64 -> Tree v
leaf item = Leaf . measure item
-- | A pruned subtree holding only a node built by 'mkNode' from the given
-- level, position and value.
prunedBranch :: (Monoid v, Node v) => Level -> Position -> HexString -> Tree v
prunedBranch lvl pos tag = PrunedBranch (mkNode lvl pos tag)
-- | Grow a tree with empty slots (via 'mkSubTree') until its root
-- annotation is at level @maxLevel - 1@. A tree already at that level is
-- returned unchanged.
root :: (Monoid v, Node v) => Tree v -> Tree v
root tree
  | getLevel (value tree) == topLevel = tree
  | otherwise = mkSubTree topLevel tree
  where
    topLevel = maxLevel - 1
-- | The tree obtained by starting from 'EmptyLeaf' and appending the tree to
-- itself @level@ times. The 'Tree' append uses it as the empty sibling when
-- growing a branch. A negative level fails at the list index.
getEmptyRoot :: (Monoid v, Node v) => Level -> Tree v
getEmptyRoot level = iterate selfAppend EmptyLeaf !! fromIntegral level
  where
    selfAppend t = t <> t
-- | Node annotation for the Orchard note-commitment tree.
data OrchardNode = OrchardNode
  { on_position :: !Position
    -- ^ Position of the node. When two nodes are combined with '<>', the
    -- result keeps the larger position.
  , on_value :: !HexString
    -- ^ Node value: a leaf's value or the hash from 'combineOrchardNodes'.
  , on_level :: !Level
    -- ^ Height of the node.
  , on_full :: !Bool
    -- ^ Whether the node is marked full.
  } deriving (Eq)
-- | Combine two nodes into a parent by hashing their values with
-- 'combineOrchardNodes' at the left node's level. The parent:
--
-- * is one level above the left node;
-- * keeps the larger of the two positions;
-- * is full only when both inputs are full.
--
-- When 'combineOrchardNodes' returns 'Nothing', the left node is returned
-- unchanged.
instance Semigroup OrchardNode where
  left <> right =
    maybe left parent (combineOrchardNodes (on_level left) (on_value left) (on_value right))
    where
      parent h =
        OrchardNode
          { on_position = max (on_position left) (on_position right)
          , on_value = h
          , on_level = on_level left + 1
          , on_full = on_full left && on_full right
          }
-- | 'mempty' is a level-0 node at position 0 with value @00@, not marked
-- full. 'mappend' uses the class default, which is '(<>)'.
-- NOTE(review): 'mempty' is not an identity for '<>'. @mempty <> y@ calls
-- 'combineOrchardNodes' and, when that returns a hash, yields a new
-- level-1 node rather than @y@.
instance Monoid OrchardNode where
  mempty = OrchardNode 0 (hexString "00") 0 False
-- | 'OrchardNode' exposes its fields through 'Node'. Nodes built with
-- 'mkNode' are always marked full.
instance Node OrchardNode where
  mkNode lvl pos val =
    OrchardNode {on_position = pos, on_value = val, on_level = lvl, on_full = True}
  getLevel = on_level
  getTag = on_value
  getPosition = on_position
  isFull = on_full
-- | Render a node as the 'show' of its value only.
instance Show OrchardNode where
  show node = show (on_value node)
-- | Measuring an existing node copies it with the position replaced.
instance Measured OrchardNode OrchardNode where
  measure node pos = node {on_position = pos}
-- | Grow a tree by repeatedly appending 'EmptyLeaf' until its root
-- annotation reaches @level@.
--
-- The level is checked before each append with @>=@. A tree that is
-- already at or above @level@ is therefore returned as is instead of being
-- grown forever, and a step that overshoots still stops. Termination still
-- relies on each append raising the root level.
mkSubTree :: Node v => Monoid v => Level -> Tree v -> Tree v
mkSubTree level t
  | getLevel (value t) >= level = t
  | otherwise = mkSubTree level (t <> EmptyLeaf)
-- | Rebuild an Orchard tree from a frontier with position @p@, leaf value
-- @l@ and ommers @o@. All leaves are created at level 0 and marked full.
--
-- * Odd @p@: the first ommer becomes the left sibling leaf at @p - 1@, and
--   the remaining ommers are attached by 'addOrchardOmmers'.
-- * Even @p@: the leaf is paired with an empty slot, and all ommers are
--   attached.
--
-- An odd position with no ommers cannot supply the sibling leaf. That case
-- raises an explicit error rather than a bare @head@ failure.
mkOrchardTree :: OrchardFrontier -> Tree OrchardNode
mkOrchardTree (OrchardFrontier p l o)
  | odd p =
      case o of
        sibling : rest ->
          addOrchardOmmers rest $
            Leaf (OrchardNode (p - 1) sibling 0 True) <>
            Leaf (OrchardNode p l 0 True)
        [] -> error "mkOrchardTree: frontier at an odd position has no ommers"
  | otherwise = addOrchardOmmers o $ Leaf (OrchardNode p l 0 True) <> EmptyLeaf
-- | Attach ommers to a tree in list order.
--
-- Each ommer becomes a 'PrunedBranch' node at the current root's level,
-- positioned @2 ^ level@ before the root's position. That node is placed
-- on the left of the tree built so far and combined with '<>'.
--
-- Uses a strict left fold so no chain of thunks builds up across ommers.
addOrchardOmmers :: [HexString] -> Tree OrchardNode -> Tree OrchardNode
addOrchardOmmers xs t = foldl' attach t xs
  where
    attach acc ommer =
      let rootNode = value acc
       in PrunedBranch (mkNode (getLevel rootNode) (siblingPos rootNode) ommer) <> acc
    -- Position of the sibling subtree one subtree-width before this node.
    siblingPos :: OrchardNode -> Position
    siblingPos (OrchardNode pos _ lvl _) = pos - (2 ^ lvl)