From 04e5dd32488b471668c1eda2811cdfa576259418 Mon Sep 17 00:00:00 2001 From: Victor Denisov Date: Sun, 23 Oct 2016 23:19:49 -0700 Subject: [PATCH] Return error if listening thread is closed --- Database/MongoDB/Internal/Protocol.hs | 39 ++++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/Database/MongoDB/Internal/Protocol.hs b/Database/MongoDB/Internal/Protocol.hs index 324513d..f087730 100644 --- a/Database/MongoDB/Internal/Protocol.hs +++ b/Database/MongoDB/Internal/Protocol.hs @@ -42,12 +42,13 @@ import Data.Bits (bit, testBit) import Data.Int (Int32, Int64) import Data.IORef (IORef, newIORef, atomicModifyIORef) import System.IO (Handle) +import System.IO.Error (doesNotExistErrorType, mkIOError, ioError) import System.IO.Unsafe (unsafePerformIO) import Data.Maybe (maybeToList) import GHC.Conc (ThreadStatus(..), threadStatus) import Control.Monad (forever) import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan) -import Control.Concurrent (ThreadId, forkIO, killThread) +import Control.Concurrent (ThreadId, forkIO, killThread, forkFinally) import Control.Exception.Lifted (onException, throwIO, try) @@ -66,11 +67,11 @@ import qualified Data.Text.Encoding as TE import Database.MongoDB.Internal.Util (bitOr, byteStringHex) import Database.MongoDB.Transport (Transport) -import qualified Database.MongoDB.Transport as T +import qualified Database.MongoDB.Transport as Tr #if MIN_VERSION_base(4,6,0) import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar, - putMVar, readMVar, mkWeakMVar) + putMVar, readMVar, mkWeakMVar, isEmptyMVar) #else import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar, putMVar, readMVar, addMVarFinalizer) @@ -88,6 +89,7 @@ data Pipeline = Pipeline { vStream :: MVar Transport -- ^ Mutex on handle, so only one thread at a time can write to it , responseQueue :: Chan (MVar (Either IOError Response)) -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response. , listenThread :: ThreadId + , finished :: MVar () , serverData :: ServerData } @@ -105,19 +107,25 @@ newPipeline :: ServerData -> Transport -> IO Pipeline newPipeline serverData stream = do vStream <- newMVar stream responseQueue <- newChan + finished <- newEmptyMVar rec let pipe = Pipeline{..} - listenThread <- forkIO (listen pipe) + listenThread <- forkFinally (listen pipe) (\_ -> putMVar finished ()) _ <- mkWeakMVar vStream $ do killThread listenThread - T.close stream + Tr.close stream return pipe +isFinished :: Pipeline -> IO Bool +isFinished Pipeline {finished} = do + empty <- isEmptyMVar finished + return $ not empty + close :: Pipeline -> IO () -- ^ Close pipe and underlying connection close Pipeline{..} = do killThread listenThread - T.close =<< readMVar vStream + Tr.close =<< readMVar vStream isClosed :: Pipeline -> IO Bool isClosed Pipeline{listenThread} = do @@ -138,7 +146,7 @@ listen Pipeline{..} = do var <- readChan responseQueue putMVar var e case e of - Left err -> T.close stream >> ioError err -- close and stop looping + Left err -> Tr.close stream >> ioError err -- close and stop looping Right _ -> return () psend :: Pipeline -> Message -> IO () @@ -149,7 +157,12 @@ psend p@Pipeline{..} !message = withMVar vStream (flip writeMessage message) `on pcall :: Pipeline -> Message -> IO (IO Response) -- ^ Send message to destination and return /promise/ of response from one message only. The destination must reply to the message (otherwise promises will have the wrong responses in them). -- Throw IOError and closes pipeline if send fails, likewise for promised response. -pcall p@Pipeline{..} message = withMVar vStream doCall `onException` close p where +pcall p@Pipeline{..} message = do + finished <- isFinished p + if finished + then ioError $ mkIOError doesNotExistErrorType "Handle has been closed" Nothing Nothing + else withMVar vStream doCall `onException` close p + where doCall stream = do writeMessage stream message var <- newEmptyMVar @@ -163,7 +176,7 @@ type Pipe = Pipeline newPipe :: ServerData -> Handle -> IO Pipe -- ^ Create pipe over handle -newPipe sd handle = T.fromHandle handle >>= (newPipeWith sd) +newPipe sd handle = Tr.fromHandle handle >>= (newPipeWith sd) newPipeWith :: ServerData -> Transport -> IO Pipe -- ^ Create pipe over connection @@ -202,8 +215,8 @@ writeMessage conn (notices, mRequest) = do let s = runPut $ putRequest request requestId return $ (lenBytes s) `L.append` s - T.write conn $ L.toStrict $ L.concat $ noticeStrings ++ (maybeToList requestString) - T.flush conn + Tr.write conn $ L.toStrict $ L.concat $ noticeStrings ++ (maybeToList requestString) + Tr.flush conn where lenBytes bytes = encodeSize . toEnum . fromEnum $ L.length bytes encodeSize = runPut . putInt32 . (+ 4) @@ -215,8 +228,8 @@ readMessage :: Transport -> IO Response -- ^ read response from a connection readMessage conn = readResp where readResp = do - len <- fromEnum . decodeSize . L.fromStrict <$> T.read conn 4 - runGet getReply . L.fromStrict <$> T.read conn len + len <- fromEnum . decodeSize . L.fromStrict <$> Tr.read conn 4 + runGet getReply . L.fromStrict <$> Tr.read conn len decodeSize = subtract 4 . runGet getInt32 type FullCollection = Text