Return error if listening thread is closed

This commit is contained in:
Victor Denisov 2016-10-23 23:19:49 -07:00
parent 2d348449bc
commit 04e5dd3248

View file

@ -42,12 +42,13 @@ import Data.Bits (bit, testBit)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.IORef (IORef, newIORef, atomicModifyIORef) import Data.IORef (IORef, newIORef, atomicModifyIORef)
import System.IO (Handle) import System.IO (Handle)
import System.IO.Error (doesNotExistErrorType, mkIOError, ioError)
import System.IO.Unsafe (unsafePerformIO) import System.IO.Unsafe (unsafePerformIO)
import Data.Maybe (maybeToList) import Data.Maybe (maybeToList)
import GHC.Conc (ThreadStatus(..), threadStatus) import GHC.Conc (ThreadStatus(..), threadStatus)
import Control.Monad (forever) import Control.Monad (forever)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan) import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent (ThreadId, forkIO, killThread) import Control.Concurrent (ThreadId, forkIO, killThread, forkFinally)
import Control.Exception.Lifted (onException, throwIO, try) import Control.Exception.Lifted (onException, throwIO, try)
@ -66,11 +67,11 @@ import qualified Data.Text.Encoding as TE
import Database.MongoDB.Internal.Util (bitOr, byteStringHex) import Database.MongoDB.Internal.Util (bitOr, byteStringHex)
import Database.MongoDB.Transport (Transport) import Database.MongoDB.Transport (Transport)
import qualified Database.MongoDB.Transport as T import qualified Database.MongoDB.Transport as Tr
#if MIN_VERSION_base(4,6,0) #if MIN_VERSION_base(4,6,0)
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar, import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
putMVar, readMVar, mkWeakMVar) putMVar, readMVar, mkWeakMVar, isEmptyMVar)
#else #else
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar, import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
putMVar, readMVar, addMVarFinalizer) putMVar, readMVar, addMVarFinalizer)
@ -88,6 +89,7 @@ data Pipeline = Pipeline
{ vStream :: MVar Transport -- ^ Mutex on handle, so only one thread at a time can write to it { vStream :: MVar Transport -- ^ Mutex on handle, so only one thread at a time can write to it
, responseQueue :: Chan (MVar (Either IOError Response)) -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response. , responseQueue :: Chan (MVar (Either IOError Response)) -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response.
, listenThread :: ThreadId , listenThread :: ThreadId
, finished :: MVar ()
, serverData :: ServerData , serverData :: ServerData
} }
@ -105,19 +107,25 @@ newPipeline :: ServerData -> Transport -> IO Pipeline
newPipeline serverData stream = do newPipeline serverData stream = do
vStream <- newMVar stream vStream <- newMVar stream
responseQueue <- newChan responseQueue <- newChan
finished <- newEmptyMVar
rec rec
let pipe = Pipeline{..} let pipe = Pipeline{..}
listenThread <- forkIO (listen pipe) listenThread <- forkFinally (listen pipe) (\_ -> putMVar finished ())
_ <- mkWeakMVar vStream $ do _ <- mkWeakMVar vStream $ do
killThread listenThread killThread listenThread
T.close stream Tr.close stream
return pipe return pipe
isFinished :: Pipeline -> IO Bool
isFinished Pipeline {finished} = do
empty <- isEmptyMVar finished
return $ not empty
close :: Pipeline -> IO () close :: Pipeline -> IO ()
-- ^ Close pipe and underlying connection -- ^ Close pipe and underlying connection
close Pipeline{..} = do close Pipeline{..} = do
killThread listenThread killThread listenThread
T.close =<< readMVar vStream Tr.close =<< readMVar vStream
isClosed :: Pipeline -> IO Bool isClosed :: Pipeline -> IO Bool
isClosed Pipeline{listenThread} = do isClosed Pipeline{listenThread} = do
@ -138,7 +146,7 @@ listen Pipeline{..} = do
var <- readChan responseQueue var <- readChan responseQueue
putMVar var e putMVar var e
case e of case e of
Left err -> T.close stream >> ioError err -- close and stop looping Left err -> Tr.close stream >> ioError err -- close and stop looping
Right _ -> return () Right _ -> return ()
psend :: Pipeline -> Message -> IO () psend :: Pipeline -> Message -> IO ()
@ -149,7 +157,12 @@ psend p@Pipeline{..} !message = withMVar vStream (flip writeMessage message) `on
pcall :: Pipeline -> Message -> IO (IO Response) pcall :: Pipeline -> Message -> IO (IO Response)
-- ^ Send message to destination and return /promise/ of response from one message only. The destination must reply to the message (otherwise promises will have the wrong responses in them). -- ^ Send message to destination and return /promise/ of response from one message only. The destination must reply to the message (otherwise promises will have the wrong responses in them).
-- Throw IOError and closes pipeline if send fails, likewise for promised response. -- Throw IOError and closes pipeline if send fails, likewise for promised response.
pcall p@Pipeline{..} message = withMVar vStream doCall `onException` close p where pcall p@Pipeline{..} message = do
finished <- isFinished p
if finished
then ioError $ mkIOError doesNotExistErrorType "Handle has been closed" Nothing Nothing
else withMVar vStream doCall `onException` close p
where
doCall stream = do doCall stream = do
writeMessage stream message writeMessage stream message
var <- newEmptyMVar var <- newEmptyMVar
@ -163,7 +176,7 @@ type Pipe = Pipeline
newPipe :: ServerData -> Handle -> IO Pipe newPipe :: ServerData -> Handle -> IO Pipe
-- ^ Create pipe over handle -- ^ Create pipe over handle
newPipe sd handle = T.fromHandle handle >>= (newPipeWith sd) newPipe sd handle = Tr.fromHandle handle >>= (newPipeWith sd)
newPipeWith :: ServerData -> Transport -> IO Pipe newPipeWith :: ServerData -> Transport -> IO Pipe
-- ^ Create pipe over connection -- ^ Create pipe over connection
@ -202,8 +215,8 @@ writeMessage conn (notices, mRequest) = do
let s = runPut $ putRequest request requestId let s = runPut $ putRequest request requestId
return $ (lenBytes s) `L.append` s return $ (lenBytes s) `L.append` s
T.write conn $ L.toStrict $ L.concat $ noticeStrings ++ (maybeToList requestString) Tr.write conn $ L.toStrict $ L.concat $ noticeStrings ++ (maybeToList requestString)
T.flush conn Tr.flush conn
where where
lenBytes bytes = encodeSize . toEnum . fromEnum $ L.length bytes lenBytes bytes = encodeSize . toEnum . fromEnum $ L.length bytes
encodeSize = runPut . putInt32 . (+ 4) encodeSize = runPut . putInt32 . (+ 4)
@ -215,8 +228,8 @@ readMessage :: Transport -> IO Response
-- ^ read response from a connection -- ^ read response from a connection
readMessage conn = readResp where readMessage conn = readResp where
readResp = do readResp = do
len <- fromEnum . decodeSize . L.fromStrict <$> T.read conn 4 len <- fromEnum . decodeSize . L.fromStrict <$> Tr.read conn 4
runGet getReply . L.fromStrict <$> T.read conn len runGet getReply . L.fromStrict <$> Tr.read conn len
decodeSize = subtract 4 . runGet getInt32 decodeSize = subtract 4 . runGet getInt32
type FullCollection = Text type FullCollection = Text