diff --git a/src/Zenith/CLI.hs b/src/Zenith/CLI.hs index ab6549c..e877b43 100644 --- a/src/Zenith/CLI.hs +++ b/src/Zenith/CLI.hs @@ -832,7 +832,7 @@ scanZebra dbP zHost zPort b eChan znet = do bStatus <- liftIO $ checkBlockChain zHost zPort pool <- liftIO $ runNoLoggingT $ initPool dbP dbBlock <- liftIO $ getMaxBlock pool $ ZcashNetDB znet - chkBlock <- liftIO $ checkIntegrity dbP zHost zPort dbBlock 1 + chkBlock <- liftIO $ checkIntegrity dbP zHost zPort znet dbBlock 1 syncChk <- liftIO $ isSyncing pool if syncChk then liftIO $ BC.writeBChan eChan $ TickMsg "Sync alread in progress" @@ -844,7 +844,8 @@ scanZebra dbP zHost zPort b eChan znet = do if chkBlock == dbBlock then max dbBlock b else max chkBlock b - when (chkBlock /= dbBlock && chkBlock /= 1) $ rewindWalletData pool sb + when (chkBlock /= dbBlock && chkBlock /= 1) $ + rewindWalletData pool sb $ ZcashNetDB znet if sb > zgb_blocks bStatus || sb < 1 then do liftIO $ diff --git a/src/Zenith/Core.hs b/src/Zenith/Core.hs index 9cdb015..2dd74f2 100644 --- a/src/Zenith/Core.hs +++ b/src/Zenith/Core.hs @@ -119,10 +119,11 @@ getCommitmentTrees :: ConnectionPool -> T.Text -- ^ Host where `zebrad` is avaiable -> Int -- ^ Port where `zebrad` is available + -> ZcashNetDB -> Int -- ^ Block height -> IO ZebraTreeInfo -getCommitmentTrees pool nodeHost nodePort block = do - bh' <- getBlockHash pool block +getCommitmentTrees pool nodeHost nodePort znet block = do + bh' <- getBlockHash pool block znet case bh' of Nothing -> throwIO $ userError "couldn't get block hash" Just bh -> do @@ -293,7 +294,7 @@ findSaplingOutputs config b znet za = do let zn = getNet znet pool <- liftIO $ runNoLoggingT $ initPool dbPath tList <- liftIO $ getShieldedOutputs pool b znet - trees <- liftIO $ getCommitmentTrees pool zebraHost zebraPort (b - 1) + trees <- liftIO $ getCommitmentTrees pool zebraHost zebraPort znet (b - 1) logDebugN "getting Sapling frontier" let sT = getSaplingFrontier $ SaplingCommitmentTree $ ztiSapling trees case sT of @@ -400,7 +401,7 @@ findOrchardActions config b znet za = do let zn = getNet znet pool <- runNoLoggingT $ initPool dbPath tList <- getOrchardActions pool b znet - trees <- getCommitmentTrees pool zebraHost zebraPort (b - 1) + trees <- getCommitmentTrees pool zebraHost zebraPort znet (b - 1) let sT = getOrchardFrontier $ OrchardCommitmentTree $ ztiOrchard trees case sT of Nothing -> throwIO $ userError "Failed to read Orchard commitment tree" @@ -560,7 +561,8 @@ prepareTx pool zebraHost zebraPort zn za bh amt ua memo = do Just r1 -> (4, getBytes r1) logDebugN $ T.pack $ show recipient logDebugN $ T.pack $ "Target block: " ++ show bh - trees <- liftIO $ getCommitmentTrees pool zebraHost zebraPort bh + trees <- + liftIO $ getCommitmentTrees pool zebraHost zebraPort (ZcashNetDB zn) bh let sT = SaplingCommitmentTree $ ztiSapling trees let oT = OrchardCommitmentTree $ ztiOrchard trees case accRead of diff --git a/src/Zenith/DB.hs b/src/Zenith/DB.hs index 18882cc..5b1f125 100644 --- a/src/Zenith/DB.hs +++ b/src/Zenith/DB.hs @@ -705,25 +705,30 @@ saveBlock pool b = runNoLoggingT $ PS.retryOnBusy $ flip PS.runSqlPool pool $ do insert b -- | Read a block by height -getBlock :: ConnectionPool -> Int -> IO (Maybe (Entity ZcashBlock)) -getBlock pool b = +getBlock :: + ConnectionPool -> Int -> ZcashNetDB -> IO (Maybe (Entity ZcashBlock)) +getBlock pool b znet = runNoLoggingT $ PS.retryOnBusy $ flip PS.runSqlPool pool $ do selectOne $ do bl <- from $ table @ZcashBlock - where_ $ bl ^. ZcashBlockHeight ==. val b + where_ $ + bl ^. ZcashBlockHeight ==. val b &&. bl ^. ZcashBlockNetwork ==. + val znet pure bl -getBlockHash :: ConnectionPool -> Int -> IO (Maybe HexString) -getBlockHash pool b = do +getBlockHash :: ConnectionPool -> Int -> ZcashNetDB -> IO (Maybe HexString) +getBlockHash pool b znet = do r <- runNoLoggingT $ PS.retryOnBusy $ flip PS.runSqlPool pool $ do selectOne $ do bl <- from $ table @ZcashBlock - where_ $ bl ^. ZcashBlockHeight ==. val b + where_ $ + bl ^. ZcashBlockHeight ==. val b &&. bl ^. ZcashBlockNetwork ==. + val znet pure $ bl ^. ZcashBlockHash case r of Nothing -> return Nothing @@ -2663,8 +2668,8 @@ completeSync pool st = do return () -- | Rewind the data store to a given block height -rewindWalletData :: ConnectionPool -> Int -> LoggingT IO () -rewindWalletData pool b = do +rewindWalletData :: ConnectionPool -> Int -> ZcashNetDB -> LoggingT IO () +rewindWalletData pool b net = do logDebugN "Starting transaction rewind" liftIO $ clearWalletTransactions pool logDebugN "Completed transaction rewind" @@ -2676,7 +2681,9 @@ rewindWalletData pool b = do oldBlocks <- select $ do blk <- from $ table @ZcashBlock - where_ $ blk ^. ZcashBlockHeight >. val b + where_ + (blk ^. ZcashBlockHeight >. val b &&. blk ^. ZcashBlockNetwork ==. + val net) pure blk let oldBlkKeys = map entityKey oldBlocks oldTxs <- @@ -2696,7 +2703,9 @@ rewindWalletData pool b = do oldBlocks <- select $ do blk <- from $ table @ZcashBlock - where_ $ blk ^. ZcashBlockHeight >. val b + where_ + (blk ^. ZcashBlockHeight >. val b &&. blk ^. ZcashBlockNetwork ==. + val net) pure blk let oldBlkKeys = map entityKey oldBlocks oldTxs <- @@ -2716,7 +2725,9 @@ rewindWalletData pool b = do oldBlocks <- select $ do blk <- from $ table @ZcashBlock - where_ $ blk ^. ZcashBlockHeight >. val b + where_ + (blk ^. ZcashBlockHeight >. val b &&. blk ^. ZcashBlockNetwork ==. + val net) pure blk let oldBlkKeys = map entityKey oldBlocks oldTxs <- @@ -2736,7 +2747,9 @@ rewindWalletData pool b = do oldBlocks <- select $ do blk <- from $ table @ZcashBlock - where_ $ blk ^. ZcashBlockHeight >. val b + where_ + (blk ^. ZcashBlockHeight >. val b &&. blk ^. ZcashBlockNetwork ==. + val net) pure blk let oldBlkKeys = map entityKey oldBlocks oldTxs <- @@ -2756,7 +2769,9 @@ rewindWalletData pool b = do oldBlocks <- select $ do blk <- from $ table @ZcashBlock - where_ $ blk ^. ZcashBlockHeight >. val b + where_ + (blk ^. ZcashBlockHeight >. val b &&. blk ^. ZcashBlockNetwork ==. + val net) pure blk let oldBlkKeys = map entityKey oldBlocks oldTxs <- @@ -2776,7 +2791,9 @@ rewindWalletData pool b = do oldBlocks <- select $ do blk <- from $ table @ZcashBlock - where_ $ blk ^. ZcashBlockHeight >. val b + where_ + (blk ^. ZcashBlockHeight >. val b &&. blk ^. ZcashBlockNetwork ==. + val net) pure blk let oldBlkKeys = map entityKey oldBlocks oldTxs <- @@ -2795,5 +2812,7 @@ rewindWalletData pool b = do flip PS.runSqlPool pool $ do delete $ do blk <- from $ table @ZcashBlock - where_ $ blk ^. ZcashBlockHeight >. val b + where_ + (blk ^. ZcashBlockHeight >. val b &&. blk ^. ZcashBlockNetwork ==. + val net) logDebugN "Completed data store rewind" diff --git a/src/Zenith/GUI.hs b/src/Zenith/GUI.hs index 19003a0..24a962a 100644 --- a/src/Zenith/GUI.hs +++ b/src/Zenith/GUI.hs @@ -1627,7 +1627,7 @@ scanZebra dbPath zHost zPort net sendMsg = do pool <- runNoLoggingT $ initPool dbPath b <- liftIO $ getMinBirthdayHeight pool dbBlock <- getMaxBlock pool $ ZcashNetDB net - chkBlock <- checkIntegrity dbPath zHost zPort dbBlock 1 + chkBlock <- checkIntegrity dbPath zHost zPort net dbBlock 1 syncChk <- isSyncing pool if syncChk then sendMsg (ShowError "Sync already in progress") @@ -1637,7 +1637,7 @@ scanZebra dbPath zHost zPort net sendMsg = do then max dbBlock b else max chkBlock b unless (chkBlock == dbBlock || chkBlock == 1) $ - runStderrLoggingT $ rewindWalletData pool sb + runStderrLoggingT $ rewindWalletData pool sb $ ZcashNetDB net if sb > zgb_blocks bStatus || sb < 1 then sendMsg (ShowError "Invalid starting block for scan") else do diff --git a/src/Zenith/RPC.hs b/src/Zenith/RPC.hs index a88e014..e4d2f7a 100644 --- a/src/Zenith/RPC.hs +++ b/src/Zenith/RPC.hs @@ -889,7 +889,7 @@ scanZebra dbPath zHost zPort net = do pool <- runNoLoggingT $ initPool dbPath b <- getMinBirthdayHeight pool dbBlock <- getMaxBlock pool $ ZcashNetDB net - chkBlock <- checkIntegrity dbPath zHost zPort dbBlock 1 + chkBlock <- checkIntegrity dbPath zHost zPort net dbBlock 1 syncChk <- isSyncing pool unless syncChk $ do let sb = @@ -897,7 +897,7 @@ scanZebra dbPath zHost zPort net = do then max dbBlock b else max chkBlock b unless (chkBlock == dbBlock || chkBlock == 1) $ - runStderrLoggingT $ rewindWalletData pool sb + runStderrLoggingT $ rewindWalletData pool sb $ ZcashNetDB net unless (sb > zgb_blocks bStatus || sb < 1) $ do let bList = [(sb + 1) .. (zgb_blocks bStatus)] unless (null bList) $ do diff --git a/src/Zenith/Scanner.hs b/src/Zenith/Scanner.hs index b48045e..b36ca79 100644 --- a/src/Zenith/Scanner.hs +++ b/src/Zenith/Scanner.hs @@ -246,10 +246,11 @@ checkIntegrity :: T.Text -- ^ Database path -> T.Text -- ^ Zebra host -> Int -- ^ Zebra port + -> ZcashNet -- ^ the network to scan -> Int -- ^ The block to start the check -> Int -- ^ depth -> IO Int -checkIntegrity dbP zHost zPort b d = +checkIntegrity dbP zHost zPort znet b d = if b < 1 then return 1 else do @@ -263,10 +264,10 @@ checkIntegrity dbP zHost zPort b d = Left e -> throwIO $ userError e Right blk -> do pool <- runNoLoggingT $ initPool dbP - dbBlk <- getBlock pool b + dbBlk <- getBlock pool b $ ZcashNetDB znet case dbBlk of - Nothing -> throwIO $ userError "Block mismatch, rescan needed" + Nothing -> return 1 Just dbBlk' -> if bl_hash blk == getHex (zcashBlockHash $ entityVal dbBlk') then return b - else checkIntegrity dbP zHost zPort (b - 5 * d) (d + 1) + else checkIntegrity dbP zHost zPort znet (b - 5 * d) (d + 1)