{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Merkle tree of Orchard commitments with support for pruned branches,
-- used to compute tree roots and Merkle paths, and to rebuild a tree from an
-- 'OrchardTree' frontier.
module Zenith.Tree where

import Codec.Borsh
import Data.HexString
import Data.Int (Int32, Int64, Int8)
import Data.List (foldl')
import Data.Maybe (fromJust, isNothing)
import qualified GHC.Generics as GHC
import qualified Generics.SOP as SOP
import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue)
import ZcashHaskell.Types
  ( MerklePath(..)
  , OrchardFrontier(..)
  , OrchardTree(..)
  )

-- | Height of a node in the tree. Leaves built from commitments sit at
-- level 0; every combination step adds one level.
type Level = Int8

-- | Level of a complete tree root (see 'root').
maxLevel :: Level
maxLevel = 32

-- | Position of a leaf in the tree. Inner nodes carry the largest position
-- found among their leaves.
type Position = Int32

-- | Values of type @a@ that can be turned into a node annotation @v@.
class Monoid v => Measured a v where
  -- | Build the annotation for a value inserted at the given position, with
  -- an extra 'Int64' index attached.
  measure :: a -> Position -> Int64 -> v

-- | Accessors and a constructor for the annotation stored at tree nodes.
class Node v where
  -- | Level of the node.
  getLevel :: v -> Level
  -- | Hash carried by the node.
  getHash :: v -> HexString
  -- | Position recorded in the node.
  getPosition :: v -> Position
  -- | Whether the subtree under this node is full.
  isFull :: v -> Bool
  -- | Whether the node is marked.
  isMarked :: v -> Bool
  -- | Build a node from its level, position and hash.
  mkNode :: Level -> Position -> HexString -> v

-- | A raw Orchard commitment, as hex.
type OrchardCommitment = HexString

-- | A commitment becomes a full level-0 node at position @p@ with index @i@.
-- If 'getOrchardNodeValue' rejects the bytes, a placeholder node
-- (position 0, hash @00@) is produced instead.
instance Measured OrchardCommitment OrchardNode where
  measure oc p i =
    case getOrchardNodeValue (hexBytes oc) of
      Nothing -> OrchardNode 0 (hexString "00") 0 True 0 False
      Just val -> OrchardNode p val 0 True i False

-- | A Merkle tree whose nodes are annotated with values of type @v@.
data Tree v
  = EmptyLeaf
    -- ^ An unfilled leaf slot.
  | Leaf !v
    -- ^ A leaf holding an annotation.
  | PrunedBranch !v
    -- ^ A subtree kept only as its root annotation; children are not stored.
  | Branch !v !(Tree v) !(Tree v)
    -- ^ An inner node: cached annotation, left child, right child.
  | InvalidTree
    -- ^ Result of an illegal combination; absorbing under '<>'.
  deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsEnum (Tree v)

-- | Debug rendering: leaves in @()@, pruned branches in @{}@, and branch
-- hashes in @<>@ followed by both children on separate lines.
instance (Node v, Show v) => Show (Tree v) where
  show EmptyLeaf = "()"
  show (Leaf v) = "(" ++ show v ++ ")"
  show (PrunedBranch v) = "{" ++ show v ++ "}"
  show (Branch s x y) =
    "<" ++ show (getHash s) ++ ">\n" ++ show x ++ "\n" ++ show y
  show InvalidTree = "InvalidTree"

-- | Combine two trees, with the right operand added to the right of the
-- left one. Shape or level mismatches yield 'InvalidTree'.
instance (Monoid v, Node v) => Semigroup (Tree v) where
  -- An invalid tree poisons every combination.
  (<>) InvalidTree _ = InvalidTree
  (<>) _ InvalidTree = InvalidTree
  -- Two empty leaves collapse into a pruned level-1 empty node.
  (<>) EmptyLeaf EmptyLeaf = PrunedBranch $ value $ branch EmptyLeaf EmptyLeaf
  (<>) EmptyLeaf x = x
  (<>) (Leaf x) EmptyLeaf = branch (Leaf x) EmptyLeaf
  (<>) (Leaf x) (Leaf y) = branch (Leaf x) (Leaf y)
  (<>) (Leaf _) Branch {} = InvalidTree
  (<>) (Leaf _) (PrunedBranch _) = InvalidTree
  -- NOTE(review): combines the pruned node with itself, presumably treating
  -- it as the root of an empty subtree - confirm.
  (<>) (PrunedBranch x) EmptyLeaf = PrunedBranch $ x <> x
  -- A non-full pruned node is replaced by a fresh subtree of the same level
  -- that holds only the new leaf.
  (<>) (PrunedBranch x) (Leaf y) =
    if isFull x
      then InvalidTree
      else mkSubTree (getLevel x) (Leaf y)
  (<>) (PrunedBranch x) (Branch s t u) =
    if getLevel x == getLevel s
      then branch (PrunedBranch x) (Branch s t u)
      else InvalidTree
  (<>) (PrunedBranch x) (PrunedBranch y) = PrunedBranch $ x <> y
  -- Pad a branch on the right with an empty subtree of its own level.
  (<>) (Branch s x y) EmptyLeaf =
    branch (Branch s x y) $ getEmptyRoot (getLevel s)
  (<>) (Branch s x y) (PrunedBranch w)
    | getLevel s == getLevel w = branch (Branch s x y) (PrunedBranch w)
    | otherwise = InvalidTree
  -- Insert a leaf into the leftmost non-full child; a full branch gets a new
  -- sibling subtree of the same level that holds the leaf.
  (<>) (Branch s x y) (Leaf w)
    | isFull s = Branch s x y <> mkSubTree (getLevel s) (Leaf w)
    | isFull (value x) = branch x (y <> Leaf w)
    | otherwise = branch (x <> Leaf w) y
  (<>) (Branch s x y) (Branch s1 x1 y1)
    | getLevel s == getLevel s1 = branch (Branch s x y) (Branch s1 x1 y1)
    | otherwise = InvalidTree

-- | Annotation at the root of a tree; 'mempty' for 'EmptyLeaf' and
-- 'InvalidTree'.
value :: Monoid v => Tree v -> v
value EmptyLeaf = mempty
value (Leaf v) = v
value (PrunedBranch v) = v
value (Branch v _ _) = v
value InvalidTree = mempty

-- | Build a 'Branch' whose annotation combines its children's annotations.
branch :: Monoid v => Tree v -> Tree v -> Tree v
branch x y = Branch (value x <> value y) x y

-- | A 'Leaf' holding the measure of @a@ at position @p@ with index @i@.
leaf :: Measured a v => a -> Int32 -> Int64 -> Tree v
leaf a p i = Leaf (measure a p i)

-- | A 'PrunedBranch' whose node is built with 'mkNode'.
prunedBranch :: Monoid v => Node v => Level -> Position -> HexString -> Tree v
prunedBranch level pos val = PrunedBranch $ mkNode level pos val

-- | Pad a tree with empty subtrees until its root reaches 'maxLevel'.
-- Trees already at 'maxLevel' are returned unchanged. 'InvalidTree', or a
-- tree above 'maxLevel', yields 'InvalidTree' (see 'mkSubTree').
root :: Monoid v => Node v => Tree v -> Tree v
root tree =
  if getLevel (value tree) == maxLevel
    then tree
    else mkSubTree maxLevel tree

-- | Root of an all-empty subtree of the given level. It is built by
-- repeatedly combining a tree with itself, starting from 'EmptyLeaf', so
-- level 0 gives 'EmptyLeaf'. A negative level makes '!!' fail.
getEmptyRoot :: Monoid v => Node v => Level -> Tree v
getEmptyRoot level = iterate (\x -> x <> x) EmptyLeaf !! fromIntegral level

-- | Append value @n@ with index @i@ as a new leaf, placed one position past
-- the tree's current position.
append :: Monoid v => Measured a v => Node v => Tree v -> (a, Int64) -> Tree v
append tree (n, i) = tree <> leaf n p i
  where
    p = 1 + getPosition (value tree)

-- | Pad a tree on the right with 'EmptyLeaf' until its root reaches @level@.
-- At least one padding step is always taken.
--
-- Returns 'InvalidTree' in two cases. The first is an 'InvalidTree' input,
-- which padding can never change. The second is when the padded level
-- overshoots @level@ (for example, when the target is at or below the tree's
-- current level). Padding never lowers the level, because an 'OrchardNode'
-- combination yields level + 1 or the left operand unchanged, so an
-- overshoot is final. Previously, both cases recursed without terminating.
mkSubTree :: Node v => Monoid v => Level -> Tree v -> Tree v
mkSubTree level t =
  case t <> EmptyLeaf of
    InvalidTree -> InvalidTree
    subtree
      | reached == level -> subtree
      | reached > level -> InvalidTree
      | otherwise -> mkSubTree level subtree
      where
        reached = getLevel (value subtree)

-- | Merkle path for the leaf at position @pos@. The result is 'Nothing'
-- unless the tree is a 'Branch' and the collected sibling hashes number
-- exactly 'maxLevel'. Hashes are ordered from the leaf level upward.
path :: Monoid v => Node v => Position -> Tree v -> Maybe MerklePath
path pos tree@(Branch {}) =
  if length hashes == fromIntegral maxLevel
    then Just $ MerklePath pos hashes
    else Nothing
  where
    -- Computed once; the original evaluated it twice.
    hashes = collectPath tree
    -- At each branch, descend toward @pos@ and record the other child's hash
    -- after the deeper hashes.
    collectPath :: Monoid v => Node v => Tree v -> [HexString]
    collectPath (Branch _ j k)
      -- A non-zero right position below @pos@: @pos@ is past this subtree.
      | getPosition (value k) /= 0 && getPosition (value k) < pos = []
      | getPosition (value j) < pos = collectPath k <> [getHash (value j)]
      | otherwise = collectPath j <> [getHash (value k)]
    collectPath _ = []
path _ _ = Nothing

-- | Annotation stored at each node of an Orchard commitment tree.
data OrchardNode = OrchardNode
  { on_position :: !Position
    -- ^ Largest leaf position under this node (combined with 'max').
  , on_value :: !HexString
    -- ^ Node hash.
  , on_level :: !Level
    -- ^ Node level; leaves are 0.
  , on_full :: !Bool
    -- ^ Full flag (combined with '&&').
  , on_index :: !Int64
    -- ^ Index supplied to 'measure' (combined with 'max').
  , on_mark :: !Bool
    -- ^ Mark flag (combined with '||').
  } deriving stock (Eq, GHC.Generic)
  deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo)
  deriving (BorshSize, ToBorsh, FromBorsh) via AsStruct OrchardNode

-- | Hash two sibling nodes into their parent one level up, using the left
-- node's level. If 'combineOrchardNodes' fails, the left node is returned
-- unchanged.
instance Semigroup OrchardNode where
  x <> y =
    case combineOrchardNodes
           (fromIntegral $ on_level x)
           (on_value x)
           (on_value y) of
      Nothing -> x
      Just newHash ->
        OrchardNode
          { on_position = max (on_position x) (on_position y)
          , on_value = newHash
          , on_level = 1 + on_level x
          , on_full = on_full x && on_full y
          , on_index = max (on_index x) (on_index y)
          , on_mark = on_mark x || on_mark y
          }

-- | Placeholder node: position 0, hash @00@, level 0, not full, unmarked.
-- 'mappend' uses the class default, which is '<>'.
instance Monoid OrchardNode where
  mempty = OrchardNode 0 (hexString "00") 0 False 0 False

-- | 'mkNode' builds a full, unmarked node with index 0.
instance Node OrchardNode where
  getLevel = on_level
  getHash = on_value
  getPosition = on_position
  isFull = on_full
  isMarked = on_mark
  mkNode l p v = OrchardNode p v l True 0 False

-- | Shows only the node hash.
instance Show OrchardNode where
  show = show . on_value

-- | Re-measuring a node replaces its position and index. Its hash, level,
-- full flag and mark are kept.
instance Measured OrchardNode OrchardNode where
  measure o p i =
    OrchardNode p (on_value o) (on_level o) (on_full o) i (on_mark o)

-- | Number of leaves described by a frontier. Each present leaf (left,
-- right) counts 1, and the parent at list index @k@ (counting from 1)
-- counts @2^k@.
orchardSize :: OrchardTree -> Int64
orchardSize tree =
  leafCount (ot_left tree) + leafCount (ot_right tree) + parentCount
  where
    leafCount :: Maybe a -> Int64
    leafCount m = if isNothing m then 0 else 1
    parentCount :: Int64
    parentCount =
      sum [2 ^ depth | (depth, Just _) <- zip [1 :: Int ..] (ot_parents tree)]

-- | Rebuild a 'Tree' from a frontier.
--
-- The left/right leaves form a level-1 subtree, with the last leaf at
-- position @orchardSize - 1@. Each entry @k@ (from 1) of 'ot_parents' then
-- raises the tree by one level: a present parent becomes a pruned sibling
-- on the left, and an absent one is padded by an empty subtree on the right.
--
-- Fails with 'error' when the frontier has no left leaf, as it did before
-- via 'fromJust'.
mkOrchardTree :: OrchardTree -> Tree OrchardNode
mkOrchardTree tree = foldl' addParent leafRoot (zip [1 ..] (ot_parents tree))
  where
    addParent t (lvl, Just node) = prunedBranch lvl 0 node <> t
    addParent t (lvl, Nothing) = t <> getEmptyRoot lvl
    leafRoot =
      case (ot_left tree, ot_right tree) of
        (Just l, Just r) -> leaf l (pos - 1) 0 <> leaf r pos 0
        (Just l, Nothing) -> leaf l pos 0 <> EmptyLeaf
        (Nothing, _) ->
          error "Zenith.Tree.mkOrchardTree: frontier has no left leaf"
    pos = fromIntegral $ orchardSize tree - 1