zenith/src/Zenith/Tree.hs

231 lines
6.9 KiB
Haskell
Raw Normal View History

2024-10-23 20:49:24 +00:00
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
2024-11-04 16:17:54 +00:00
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE UndecidableInstances #-}
2024-10-23 20:49:24 +00:00
module Zenith.Tree where
2024-11-04 16:17:54 +00:00
import Codec.Borsh
2024-10-23 20:49:24 +00:00
import Data.HexString
2024-11-05 00:56:16 +00:00
import Data.Int (Int32, Int64, Int8)
2024-11-04 16:17:54 +00:00
import Data.Maybe (fromJust, isNothing)
import qualified GHC.Generics as GHC
import qualified Generics.SOP as SOP
import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue)
2024-11-04 16:17:54 +00:00
import ZcashHaskell.Types (MerklePath(..), OrchardFrontier(..), OrchardTree(..))
2024-11-04 16:17:54 +00:00
-- | Height of a node in the tree, counted from the leaves upward.
type Level = Int8

-- | Position of a leaf in the tree.
type Position = Int32

-- | Height of a complete Orchard commitment tree; 'root' extends trees up
-- to this level.
maxLevel :: Level
maxLevel = 32
-- | Types @a@ that can be summarised as a tree annotation @v@.
class Monoid v => Measured a v where
  -- | Annotation for an element placed at the given position and tagged
  -- with the given index.
  measure :: a -> Position -> Int64 -> v
-- | Accessors and a constructor for the annotation stored at tree nodes.
class Node v where
  -- | Height of the node in the tree.
  getLevel :: v -> Level
  -- | Hash carried by the node.
  getHash :: v -> HexString
  -- | Position recorded for the node.
  getPosition :: v -> Position
  -- | Whether the subtree under this node is fully populated.
  isFull :: v -> Bool
  -- | Mark flag of the node.
  isMarked :: v -> Bool
  -- | Build a node from a level, a position and a hash.
  mkNode :: Level -> Position -> HexString -> v
-- | Hex-encoded Orchard commitment, as accepted by 'getOrchardNodeValue'.
type OrchardCommitment = HexString

-- | Turn a commitment into a level-0, full, unmarked leaf node at position
-- @p@ with index @i@. If 'getOrchardNodeValue' rejects the commitment, a
-- placeholder node (position 0, index 0, value @00@) is produced instead.
instance Measured OrchardCommitment OrchardNode where
  measure oc p i = maybe rejected asLeaf (getOrchardNodeValue (hexBytes oc))
    where
      rejected = OrchardNode 0 (hexString "00") 0 True 0 False
      asLeaf val = OrchardNode p val 0 True i False
-- | A binary Merkle tree whose nodes carry an annotation of type @v@.
--
-- NOTE: the constructor order is part of the derived Borsh enum encoding
-- ('AsEnum'); do not reorder constructors.
data Tree v
  = EmptyLeaf
    -- ^ An unoccupied leaf slot; it has no annotation ('value' gives 'mempty').
  | Leaf !v
    -- ^ An occupied leaf.
  | PrunedBranch !v
    -- ^ A subtree whose children are not kept; only its annotation remains.
  | Branch !v !(Tree v) !(Tree v)
    -- ^ An inner node: annotation, left child, right child.
  | InvalidTree
    -- ^ Result of combining trees that cannot be joined.
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsEnum (Tree v)
-- | Debug rendering: leaves in parentheses, pruned branches in braces, and
-- branches as their hash in angle brackets followed by each child on its
-- own line.
instance (Node v, Show v) => Show (Tree v) where
  show tree =
    case tree of
      EmptyLeaf -> "()"
      Leaf v -> wrap "(" ")" (show v)
      PrunedBranch v -> wrap "{" "}" (show v)
      Branch s l r -> concat ["<", show (getHash s), ">\n", show l, "\n", show r]
      InvalidTree -> "InvalidTree"
    where
      wrap open close body = open ++ body ++ close
-- | Place two trees side by side (left operand on the left).
--
-- Rules, in the order they are tried:
--
-- * 'InvalidTree' on either side absorbs the other operand.
-- * Two 'EmptyLeaf's become a 'PrunedBranch' holding @mempty <> mempty@
--   (the step 'getEmptyRoot' repeats). Otherwise 'EmptyLeaf' on the left
--   returns the right operand.
-- * Leaves pair up under a 'Branch'. A leaf to the left of a larger tree is
--   invalid.
-- * A 'PrunedBranch' padded with 'EmptyLeaf' is combined with itself.
--   NOTE(review): correct for the empty roots built by 'getEmptyRoot';
--   confirm it is never hit with a full pruned node.
-- * A non-full 'PrunedBranch' followed by a leaf is replaced by a subtree of
--   the same height containing that leaf. A full one followed by a leaf is
--   invalid.
-- * Subtrees of equal height become siblings under a new 'Branch'. Unequal
--   heights are invalid.
-- * A leaf appended to a 'Branch' goes into the left child until it is full,
--   then into the right child. A full 'Branch' is paired with a same-height
--   subtree built around the leaf.
--
-- NOTE(review): these rules are not associative in general, so this is not
-- a lawful 'Semigroup'; do not rely on reassociation.
instance (Monoid v, Node v) => Semigroup (Tree v) where
  InvalidTree <> _ = InvalidTree
  _ <> InvalidTree = InvalidTree
  EmptyLeaf <> EmptyLeaf = PrunedBranch (mempty <> mempty)
  EmptyLeaf <> rhs = rhs
  Leaf a <> EmptyLeaf = branch (Leaf a) EmptyLeaf
  Leaf a <> Leaf b = branch (Leaf a) (Leaf b)
  Leaf _ <> Branch {} = InvalidTree
  Leaf _ <> PrunedBranch _ = InvalidTree
  PrunedBranch a <> EmptyLeaf = PrunedBranch (a <> a)
  PrunedBranch a <> Leaf b
    | isFull a = InvalidTree
    | otherwise = mkSubTree (getLevel a) (Leaf b)
  PrunedBranch a <> rhs@(Branch b _ _)
    | getLevel a == getLevel b = branch (PrunedBranch a) rhs
    | otherwise = InvalidTree
  PrunedBranch a <> PrunedBranch b = PrunedBranch (a <> b)
  lhs@(Branch a _ _) <> EmptyLeaf = branch lhs (getEmptyRoot (getLevel a))
  lhs@(Branch a _ _) <> rhs@(PrunedBranch b)
    | getLevel a == getLevel b = branch lhs rhs
    | otherwise = InvalidTree
  lhs@(Branch a l r) <> Leaf b
    | isFull a = lhs <> mkSubTree (getLevel a) (Leaf b)
    | isFull (value l) = branch l (r <> Leaf b)
    | otherwise = branch (l <> Leaf b) r
  lhs@(Branch a _ _) <> rhs@(Branch b _ _)
    | getLevel a == getLevel b = branch lhs rhs
    | otherwise = InvalidTree
-- | The annotation at the top of a tree. 'EmptyLeaf' and 'InvalidTree'
-- carry none and report 'mempty'.
value :: Monoid v => Tree v -> v
value tree =
  case tree of
    Leaf v -> v
    PrunedBranch v -> v
    Branch v _ _ -> v
    _ -> mempty
-- | Join two subtrees under a new 'Branch'. Its annotation is the
-- combination (@<>@) of the children's annotations, left first.
branch :: Monoid v => Tree v -> Tree v -> Tree v
branch l r =
  let ann = value l <> value r
   in Branch ann l r
-- | A 'Leaf' whose annotation is what 'measure' produces for @a@ at the
-- given position and index.
leaf :: Measured a v => a -> Int32 -> Int64 -> Tree v
leaf a p = Leaf . measure a p
-- | A 'PrunedBranch' whose annotation 'mkNode' builds from the given level,
-- position and hash.
prunedBranch :: Monoid v => Node v => Level -> Position -> HexString -> Tree v
prunedBranch lvl p = PrunedBranch . mkNode lvl p
-- | Extend a tree up to 'maxLevel' by padding it with empty subtrees.
-- A tree already at that level is returned unchanged.
root :: Monoid v => Node v => Tree v -> Tree v
root tree
  | getLevel (value tree) == maxLevel = tree
  | otherwise = mkSubTree maxLevel tree
-- | Root of an all-empty subtree of the given height.
--
-- Level 0 is 'EmptyLeaf'. Each higher level combines the empty root one
-- level below with itself. Negative levels have no empty root and yield
-- 'InvalidTree' rather than failing on a negative list index.
getEmptyRoot :: Monoid v => Node v => Level -> Tree v
getEmptyRoot level
  | level < 0 = InvalidTree
  | otherwise = go level EmptyLeaf
  where
    -- Double the subtree @n@ times. This is equivalent to
    -- @iterate (\x -> x <> x) EmptyLeaf !! n@ without the partial '!!'.
    go :: Monoid v => Node v => Level -> Tree v -> Tree v
    go 0 t = t
    go n t = go (n - 1) (t <> t)
-- | Append an element, paired with its index, as a new leaf on the right.
--
-- The new leaf's position is one more than the position recorded at the
-- tree's top. NOTE(review): for 'EmptyLeaf' the top annotation is 'mempty',
-- so for 'OrchardNode' the first leaf gets position 1, while
-- 'mkOrchardTree' numbers leaves from 0. Confirm which convention is
-- intended.
append :: Monoid v => Measured a v => Node v => Tree v -> (a, Int64) -> Tree v
append tree (n, i) =
  let nextPos = getPosition (value tree) + 1
   in tree <> leaf n nextPos i
-- | Grow a tree upward by repeatedly padding it on the right with
-- 'EmptyLeaf' (@t <> EmptyLeaf@) until its top sits at @level@.
--
-- Returns 'InvalidTree' instead of looping when:
--
-- * the tree is (or becomes) 'InvalidTree', which padding never changes;
-- * a padding step fails to raise the level (for example when the node hash
--   fails and the left annotation is returned unchanged), because every
--   later step would repeat the same failure;
-- * the level overshoots @level@, which happens when the input is already at
--   or above the target.
--
-- In all three cases the previous version could recurse indefinitely.
mkSubTree :: Node v => Monoid v => Level -> Tree v -> Tree v
mkSubTree level t =
  case t <> EmptyLeaf of
    InvalidTree -> InvalidTree
    subtree
      | reached == level -> subtree
      | reached > level || reached <= getLevel (value t) -> InvalidTree
      | otherwise -> mkSubTree level subtree
      where
        reached = getLevel (value subtree)
-- | Authentication path for the leaf at @pos@.
--
-- Walks down from a 'Branch' root. At each 'Branch' it descends into the
-- child that covers @pos@, judged from the children's recorded positions,
-- and collects the sibling's hash. Hashes closest to the leaf come first.
--
-- A path is returned only when it holds exactly one hash per level
-- ('maxLevel'). Shorter walks (the target is not in the tree, or a pruned
-- or empty node was reached early) and non-'Branch' roots give 'Nothing'.
path :: Monoid v => Node v => Position -> Tree v -> Maybe MerklePath
path pos tree@Branch {} =
  if length authPath == fromIntegral maxLevel
    then Just (MerklePath pos authPath)
    else Nothing
  where
    -- Computed once and shared by the length check and the result.
    authPath = collectPath tree
    collectPath :: Monoid v => Node v => Tree v -> [HexString]
    collectPath (Branch _ j k)
      -- Non-empty right child (position 0 is treated as empty here) whose
      -- highest position is still below pos: the target is not here.
      | getPosition (value k) /= 0 && getPosition (value k) < pos = []
      -- Everything on the left is below pos: go right; left hash is the sibling.
      | getPosition (value j) < pos = collectPath k <> [getHash (value j)]
      -- Otherwise the target is on the left; right hash is the sibling.
      | otherwise = collectPath j <> [getHash (value k)]
    collectPath _ = []
path _ _ = Nothing
-- | Annotation stored at every node of an Orchard 'Tree'.
--
-- NOTE: the field order is part of the derived Borsh struct encoding and of
-- positional 'OrchardNode' constructions elsewhere; do not reorder fields.
data OrchardNode = OrchardNode
  { on_position :: !Position
    -- ^ Position; a parent keeps the larger of its children's positions.
  , on_value :: !HexString
    -- ^ Node hash.
  , on_level :: !Level
    -- ^ Height; leaves built by 'measure' are level 0, and a parent is one
    -- above its left child.
  , on_full :: !Bool
    -- ^ Whether the subtree is fully populated; a parent is full only when
    -- both children are.
  , on_index :: !Int64
    -- ^ Caller-supplied index; a parent keeps the larger of its children's.
  , on_mark :: !Bool
    -- ^ Mark flag; a parent is marked if either child is.
  }
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsStruct OrchardNode
-- | Hash two sibling nodes into their parent with 'combineOrchardNodes',
-- using the left node's level.
--
-- The parent is one level above the left node. It takes the larger position
-- and index, is full only if both children are, and is marked if either is.
-- If hashing fails, the left node is returned unchanged.
-- NOTE(review): that silent fallback hides hashing failures from callers.
instance Semigroup OrchardNode where
  x <> y =
    maybe x merged $
      combineOrchardNodes (fromIntegral $ on_level x) (on_value x) (on_value y)
    where
      merged h =
        OrchardNode
          { on_position = max (on_position x) (on_position y)
          , on_value = h
          , on_level = 1 + on_level x
          , on_full = on_full x && on_full y
          , on_index = max (on_index x) (on_index y)
          , on_mark = on_mark x || on_mark y
          }
-- | Empty annotation: position 0, value @00@, level 0, not full, index 0,
-- unmarked. 'mappend' keeps its default definition, @(<>)@.
instance Monoid OrchardNode where
  mempty =
    OrchardNode
      { on_position = 0
      , on_value = hexString "00"
      , on_level = 0
      , on_full = False
      , on_index = 0
      , on_mark = False
      }
-- | Node accessors map directly onto the 'OrchardNode' record fields.
-- 'mkNode' builds a full, unmarked node with index 0.
instance Node OrchardNode where
  getLevel = on_level
  getHash = on_value
  getPosition = on_position
  isFull = on_full
  isMarked = on_mark
  mkNode lvl pos hash =
    OrchardNode
      { on_position = pos
      , on_value = hash
      , on_level = lvl
      , on_full = True
      , on_index = 0
      , on_mark = False
      }
-- | Nodes display as their hash only.
instance Show OrchardNode where
  show node = show (on_value node)
-- | Re-place an existing node: keep its hash, level, fullness and mark, and
-- overwrite only its position and index.
instance Measured OrchardNode OrchardNode where
  measure o p i = o {on_position = p, on_index = i}
-- | Number of leaves represented by an 'OrchardTree' frontier.
--
-- Each present left or right leaf counts 1. Each present parent at level
-- @i@ counts @2^i@; the parents list starts at level 1.
orchardSize :: OrchardTree -> Int64
orchardSize tree = leafCount + parentCount
  where
    present :: Maybe a -> Int64
    present = maybe 0 (const 1)
    leafCount = present (ot_left tree) + present (ot_right tree)
    -- A present parent at level i summarises a full subtree of 2^i leaves.
    -- A comprehension plus sum replaces the lazy foldl accumulator.
    parentCount =
      sum [2 ^ lvl | (lvl, Just _) <- zip [1 :: Int ..] (ot_parents tree)]
-- | Rebuild a 'Tree' of 'OrchardNode's from an 'OrchardTree' frontier.
--
-- The frontier's leaf (or leaf pair) forms a level-1 subtree. Then, for
-- each parent level @i@ starting at 1, a present parent hash becomes a
-- 'PrunedBranch' placed to the left of the subtree built so far, and an
-- absent one is filled on the right with 'getEmptyRoot'. The last leaf gets
-- position @orchardSize tree - 1@.
--
-- A frontier without a left leaf gives 'EmptyLeaf' when nothing else is
-- present either (an empty tree). Otherwise it gives 'InvalidTree', rather
-- than failing on 'fromJust'.
mkOrchardTree :: OrchardTree -> Tree OrchardNode
mkOrchardTree tree =
  case ot_left tree of
    Nothing
      | isNothing (ot_right tree) && all isNothing (ot_parents tree) -> EmptyLeaf
      | otherwise -> InvalidTree
    Just l -> foldl addParent (leafRoot l) (zip [1 ..] (ot_parents tree))
  where
    -- One fold step per level: a known parent hash sits to the left of the
    -- accumulated subtree; an absent one means the right side is still empty.
    addParent t (lvl, parent) =
      case parent of
        Just h -> prunedBranch lvl 0 h <> t
        Nothing -> t <> getEmptyRoot lvl
    -- Level-1 subtree holding the frontier's leaf, or its two leaves.
    leafRoot l =
      case ot_right tree of
        Just r -> leaf l (pos - 1) 0 <> leaf r pos 0
        Nothing -> leaf l pos 0 <> EmptyLeaf
    pos = fromIntegral $ orchardSize tree - 1