From 93240325dff18f610ae45196f3839dff7387f16e Mon Sep 17 00:00:00 2001 From: Rene Vergara Date: Tue, 24 Sep 2024 14:34:19 -0500 Subject: [PATCH] feat!: add re-org detection and rewind --- CHANGELOG.md | 1 + src/Zenith/CLI.hs | 37 ++++++++++++++++++++----------------- src/Zenith/DB.hs | 26 ++++++++++++++++++++++++-- src/Zenith/GUI.hs | 35 ++++++++++++++++++----------------- src/Zenith/Scanner.hs | 31 +++++++++++++++++++++++++++++++ 5 files changed, 94 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 896f51a..22dcc13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Detection of changes in database schema for automatic re-scan +- Block tracking for chain re-org detection ## [0.6.0.0-beta] diff --git a/src/Zenith/CLI.hs b/src/Zenith/CLI.hs index 3de9e91..10e34bc 100644 --- a/src/Zenith/CLI.hs +++ b/src/Zenith/CLI.hs @@ -61,7 +61,7 @@ import qualified Brick.Widgets.List as L import qualified Brick.Widgets.ProgressBar as P import Control.Concurrent (forkIO, threadDelay) import Control.Exception (throw, throwIO, try) -import Control.Monad (forever, void, when) +import Control.Monad (forever, unless, void, when) import Control.Monad.IO.Class (liftIO) import Control.Monad.Logger (runFileLoggingT, runNoLoggingT) import Data.Aeson @@ -88,7 +88,7 @@ import ZcashHaskell.Types import ZcashHaskell.Utils (getBlockTime, makeZebraCall) import Zenith.Core import Zenith.DB -import Zenith.Scanner (processTx, rescanZebra, updateConfs) +import Zenith.Scanner (checkIntegrity, processTx, rescanZebra, updateConfs) import Zenith.Types ( Config(..) , HexStringDB(..) @@ -722,26 +722,29 @@ scanZebra dbP zHost zPort b eChan znet = do bStatus <- liftIO $ checkBlockChain zHost zPort pool <- runNoLoggingT $ initPool dbP dbBlock <- getMaxBlock pool $ ZcashNetDB znet + chkBlock <- checkIntegrity dbP zHost zPort dbBlock 1 + unless (chkBlock == dbBlock) $ rewindWalletData pool chkBlock + let sb = + if chkBlock == dbBlock + then max dbBlock b + else max chkBlock b + if sb > zgb_blocks bStatus || sb < 1 + then do + liftIO $ BC.writeBChan eChan $ TickMsg "Invalid starting block for scan" + else do + let bList = [(sb + 1) .. (zgb_blocks bStatus)] + if not (null bList) + then do + let step = + (1.0 :: Float) / fromIntegral (zgb_blocks bStatus - (sb + 1)) + mapM_ (processBlock pool step) bList + else liftIO $ BC.writeBChan eChan $ TickVal 1.0 confUp <- try $ updateConfs zHost zPort pool :: IO (Either IOError ()) case confUp of Left _e0 -> liftIO $ BC.writeBChan eChan $ TickMsg "Failed to update unconfirmed transactions" - Right _ -> do - let sb = max dbBlock b - if sb > zgb_blocks bStatus || sb < 1 - then do - liftIO $ - BC.writeBChan eChan $ TickMsg "Invalid starting block for scan" - else do - let bList = [(sb + 1) .. (zgb_blocks bStatus)] - if not (null bList) - then do - let step = - (1.0 :: Float) / - fromIntegral (zgb_blocks bStatus - (sb + 1)) - mapM_ (processBlock pool step) bList - else liftIO $ BC.writeBChan eChan $ TickVal 1.0 + Right _ -> return () where processBlock :: ConnectionPool -> Float -> Int -> IO () processBlock pool step bl = do diff --git a/src/Zenith/DB.hs b/src/Zenith/DB.hs index 82645ae..dd6225c 100644 --- a/src/Zenith/DB.hs +++ b/src/Zenith/DB.hs @@ -444,10 +444,10 @@ initDb dbName = do clearWalletTransactions pool clearWalletData pool m <- - try $ PS.runSqlite dbName $ runMigrationQuiet migrateAll :: IO + try $ PS.runSqlite dbName $ runMigrationUnsafeQuiet migrateAll :: IO (Either SomeException [T.Text]) case m of - Left _e2 -> return $ Left "Failed to migrate data tables" + Left e2 -> return $ Left $ "Failed to migrate data tables" ++ show e2 Right _ -> return $ Right True Right _ -> return $ Right False @@ -688,6 +688,17 @@ saveBlock :: ConnectionPool -> ZcashBlock -> IO (Key ZcashBlock) saveBlock pool b = runNoLoggingT $ PS.retryOnBusy $ flip PS.runSqlPool pool $ do insert b +-- | Read a block by height +getBlock :: ConnectionPool -> Int -> IO (Maybe (Entity ZcashBlock)) +getBlock pool b = + runNoLoggingT $ + PS.retryOnBusy $ + flip PS.runSqlPool pool $ do + selectOne $ do + bl <- from $ table @ZcashBlock + where_ $ bl ^. ZcashBlockHeight ==. val b + pure bl + -- | Save a transaction to the data model saveTransaction :: ConnectionPool -- ^ the database path @@ -2270,3 +2281,14 @@ finalizeOperation pool op status result = do , OperationResult =. val (Just result) ] where_ (ops ^. OperationId ==. val op) + +-- | Rewind the data store to a given block height +rewindWalletData :: ConnectionPool -> Int -> IO () +rewindWalletData pool b = do + runNoLoggingT $ + PS.retryOnBusy $ + flip PS.runSqlPool pool $ + delete $ do + blk <- from $ table @ZcashBlock + where_ $ blk ^. ZcashBlockHeight >. val b + clearWalletTransactions pool diff --git a/src/Zenith/GUI.hs b/src/Zenith/GUI.hs index dadb35a..16eabef 100644 --- a/src/Zenith/GUI.hs +++ b/src/Zenith/GUI.hs @@ -10,7 +10,7 @@ import Codec.QRCode import Codec.QRCode.JuicyPixels import Control.Concurrent (threadDelay) import Control.Exception (throwIO, try) -import Control.Monad (when) +import Control.Monad (unless, when) import Control.Monad.IO.Class (liftIO) import Control.Monad.Logger (runFileLoggingT, runNoLoggingT) import Data.Aeson @@ -47,12 +47,10 @@ import ZcashHaskell.Utils (getBlockTime, makeZebraCall) import Zenith.Core import Zenith.DB import Zenith.GUI.Theme -import Zenith.Scanner (processTx, rescanZebra, updateConfs) +import Zenith.Scanner (checkIntegrity, processTx, rescanZebra, updateConfs) import Zenith.Types hiding (ZcashAddress(..)) import Zenith.Utils ( displayAmount - , getZenithPath - , isEmpty , isRecipientValid , isValidString , jsonNumber @@ -60,7 +58,6 @@ import Zenith.Utils , parseAddress , showAddress , validBarValue - , validateAddressBool ) data AppEvent @@ -116,7 +113,6 @@ data AppEvent | CheckValidAddress !T.Text | CheckValidDescrip !T.Text | SaveNewABEntry - | SaveABDescription !T.Text | UpdateABEntry !T.Text !T.Text | CloseUpdABEntry | ShowMessage !T.Text @@ -1443,20 +1439,25 @@ scanZebra dbPath zHost zPort net sendMsg = do pool <- runNoLoggingT $ initPool dbPath b <- liftIO $ getMinBirthdayHeight pool dbBlock <- getMaxBlock pool $ ZcashNetDB net - let sb = max dbBlock b + chkBlock <- checkIntegrity dbPath zHost zPort dbBlock 1 + unless (chkBlock == dbBlock) $ rewindWalletData pool chkBlock + let sb = + if chkBlock == dbBlock + then max dbBlock b + else max chkBlock b + if sb > zgb_blocks bStatus || sb < 1 + then sendMsg (ShowError "Invalid starting block for scan") + else do + let bList = [(sb + 1) .. (zgb_blocks bStatus)] + if not (null bList) + then do + let step = (1.0 :: Float) / fromIntegral (length bList) + mapM_ (processBlock pool step) bList + else sendMsg (SyncVal 1.0) confUp <- try $ updateConfs zHost zPort pool :: IO (Either IOError ()) case confUp of Left _e0 -> sendMsg (ShowError "Failed to update unconfirmed transactions") - Right _ -> do - if sb > zgb_blocks bStatus || sb < 1 - then sendMsg (ShowError "Invalid starting block for scan") - else do - let bList = [(sb + 1) .. (zgb_blocks bStatus)] - if not (null bList) - then do - let step = (1.0 :: Float) / fromIntegral (length bList) - mapM_ (processBlock pool step) bList - else sendMsg (SyncVal 1.0) + Right _ -> return () where processBlock :: ConnectionPool -> Float -> Int -> IO () processBlock pool step bl = do diff --git a/src/Zenith/Scanner.hs b/src/Zenith/Scanner.hs index 8fee929..10ca49d 100644 --- a/src/Zenith/Scanner.hs +++ b/src/Zenith/Scanner.hs @@ -31,6 +31,7 @@ import Zenith.DB , ZcashBlockId , clearWalletData , clearWalletTransactions + , getBlock , getMaxBlock , getMinBirthdayHeight , getUnconfirmedBlocks @@ -214,3 +215,33 @@ clearSync config = do w' <- liftIO $ getWallets pool $ zgb_net chainInfo r <- mapM (syncWallet config) w' liftIO $ print r + +-- | Detect chain re-orgs +checkIntegrity :: + T.Text -- ^ Database path + -> T.Text -- ^ Zebra host + -> Int -- ^ Zebra port + -> Int -- ^ The block to start the check + -> Int -- ^ depth + -> IO Int +checkIntegrity dbP zHost zPort b d = + if b < 1 + then return 1 + else do + r <- + makeZebraCall + zHost + zPort + "getblock" + [Data.Aeson.String $ T.pack $ show b, jsonNumber 1] + case r of + Left e -> throwIO $ userError e + Right blk -> do + pool <- runNoLoggingT $ initPool dbP + dbBlk <- getBlock pool b + case dbBlk of + Nothing -> throwIO $ userError "Block mismatch, rescan needed" + Just dbBlk' -> + if bl_hash blk == getHex (zcashBlockHash $ entityVal dbBlk') + then return b + else checkIntegrity dbP zHost zPort (b - 5 * d) (d + 1)