Merge pull request #46 from VictorDenisov/tls
Merge a simple implementation of TLS and drop System.IO.Pipeline
This commit is contained in:
commit
e683f59faa
10 changed files with 246 additions and 286 deletions
|
@ -11,11 +11,11 @@ import Database.MongoDB.Query
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
|
|
||||||
main = defaultMain [
|
main = defaultMain [
|
||||||
bgroup "insert" [ bench "100" $ nfIO doInserts ]
|
bgroup "insert" [ bench "1000" $ nfIO doInserts ]
|
||||||
]
|
]
|
||||||
|
|
||||||
doInserts = do
|
doInserts = do
|
||||||
let docs = (flip map) [0..100] $ \i ->
|
let docs = (flip map) [0..1000] $ \i ->
|
||||||
["name" M.=: (T.pack $ "name " ++ (show i))]
|
["name" M.=: (T.pack $ "name " ++ (show i))]
|
||||||
|
|
||||||
pipe <- M.connect (M.host "127.0.0.1")
|
pipe <- M.connect (M.host "127.0.0.1")
|
||||||
|
|
|
@ -2,6 +2,14 @@
|
||||||
All notable changes to this project will be documented in this file.
|
All notable changes to this project will be documented in this file.
|
||||||
This project adheres to [Package Versioning Policy](https://wiki.haskell.org/Package_versioning_policy).
|
This project adheres to [Package Versioning Policy](https://wiki.haskell.org/Package_versioning_policy).
|
||||||
|
|
||||||
|
## [2.1.0] - unreleased
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- TLS implementation. So far it is an experimental feature.
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
- System.IO.Pipeline module
|
||||||
|
|
||||||
## [2.0.10] - 2015-12-22
|
## [2.0.10] - 2015-12-22
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
|
@ -42,12 +42,11 @@ import Data.Text (Text)
|
||||||
import qualified Data.Bson as B
|
import qualified Data.Bson as B
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
|
|
||||||
import Database.MongoDB.Internal.Protocol (Pipe, newPipe)
|
import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed)
|
||||||
import Database.MongoDB.Internal.Util (untilSuccess, liftIOE,
|
import Database.MongoDB.Internal.Util (untilSuccess, liftIOE,
|
||||||
updateAssocs, shuffle, mergesortM)
|
updateAssocs, shuffle, mergesortM)
|
||||||
import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
|
import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
|
||||||
slaveOk, runCommand)
|
slaveOk, runCommand)
|
||||||
import System.IO.Pipeline (close, isClosed)
|
|
||||||
|
|
||||||
adminCommand :: Command -> Pipe -> IO Document
|
adminCommand :: Command -> Pipe -> IO Document
|
||||||
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails.
|
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails.
|
||||||
|
|
|
@ -1,72 +0,0 @@
|
||||||
|
|
||||||
-- | This module defines a connection interface. It could be a regular
|
|
||||||
-- network connection, TLS connection, a mock or anything else.
|
|
||||||
|
|
||||||
module Database.MongoDB.Internal.Connection (
|
|
||||||
Connection(..),
|
|
||||||
readExactly,
|
|
||||||
fromHandle,
|
|
||||||
) where
|
|
||||||
|
|
||||||
import Prelude hiding (read)
|
|
||||||
import Data.Monoid
|
|
||||||
import Data.IORef
|
|
||||||
import Data.ByteString (ByteString)
|
|
||||||
import qualified Data.ByteString as ByteString
|
|
||||||
import qualified Data.ByteString.Lazy as Lazy (ByteString)
|
|
||||||
import qualified Data.ByteString.Lazy as Lazy.ByteString
|
|
||||||
import Control.Monad
|
|
||||||
import System.IO
|
|
||||||
import System.IO.Error (mkIOError, eofErrorType)
|
|
||||||
|
|
||||||
-- | Abstract connection interface
|
|
||||||
--
|
|
||||||
-- `read` should return `ByteString.null` on EOF
|
|
||||||
data Connection = Connection {
|
|
||||||
read :: IO ByteString,
|
|
||||||
unread :: ByteString -> IO (),
|
|
||||||
write :: ByteString -> IO (),
|
|
||||||
flush :: IO (),
|
|
||||||
close :: IO ()}
|
|
||||||
|
|
||||||
readExactly :: Connection -> Int -> IO Lazy.ByteString
|
|
||||||
-- ^ Read specified number of bytes
|
|
||||||
--
|
|
||||||
-- If EOF is reached before N bytes then raise EOF exception.
|
|
||||||
readExactly conn count = go mempty count
|
|
||||||
where
|
|
||||||
go acc n = do
|
|
||||||
-- read until get enough bytes
|
|
||||||
chunk <- read conn
|
|
||||||
when (ByteString.null chunk) $
|
|
||||||
ioError eof
|
|
||||||
let len = ByteString.length chunk
|
|
||||||
if len >= n
|
|
||||||
then do
|
|
||||||
let (res, rest) = ByteString.splitAt n chunk
|
|
||||||
unless (ByteString.null rest) $
|
|
||||||
unread conn rest
|
|
||||||
return (acc <> Lazy.ByteString.fromStrict res)
|
|
||||||
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
|
|
||||||
eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection"
|
|
||||||
Nothing Nothing
|
|
||||||
|
|
||||||
fromHandle :: Handle -> IO Connection
|
|
||||||
-- ^ Make connection form handle
|
|
||||||
fromHandle handle = do
|
|
||||||
restRef <- newIORef mempty
|
|
||||||
return Connection
|
|
||||||
{ read = do
|
|
||||||
rest <- readIORef restRef
|
|
||||||
writeIORef restRef mempty
|
|
||||||
if ByteString.null rest
|
|
||||||
-- 32k corresponds to the default chunk size
|
|
||||||
-- used in bytestring package
|
|
||||||
then ByteString.hGetSome handle (32 * 1024)
|
|
||||||
else return rest
|
|
||||||
, unread = \rest ->
|
|
||||||
modifyIORef restRef (rest <>)
|
|
||||||
, write = ByteString.hPut handle
|
|
||||||
, flush = hFlush handle
|
|
||||||
, close = hClose handle
|
|
||||||
}
|
|
|
@ -8,6 +8,14 @@
|
||||||
{-# LANGUAGE CPP, FlexibleContexts, TupleSections, TypeSynonymInstances #-}
|
{-# LANGUAGE CPP, FlexibleContexts, TupleSections, TypeSynonymInstances #-}
|
||||||
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, UndecidableInstances #-}
|
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, UndecidableInstances #-}
|
||||||
|
|
||||||
|
{-# LANGUAGE NamedFieldPuns, ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
#if (__GLASGOW_HASKELL__ >= 706)
|
||||||
|
{-# LANGUAGE RecursiveDo #-}
|
||||||
|
#else
|
||||||
|
{-# LANGUAGE DoRec #-}
|
||||||
|
#endif
|
||||||
|
|
||||||
module Database.MongoDB.Internal.Protocol (
|
module Database.MongoDB.Internal.Protocol (
|
||||||
FullCollection,
|
FullCollection,
|
||||||
-- * Pipe
|
-- * Pipe
|
||||||
|
@ -19,13 +27,13 @@ module Database.MongoDB.Internal.Protocol (
|
||||||
-- ** Reply
|
-- ** Reply
|
||||||
Reply(..), ResponseFlag(..),
|
Reply(..), ResponseFlag(..),
|
||||||
-- * Authentication
|
-- * Authentication
|
||||||
Username, Password, Nonce, pwHash, pwKey
|
Username, Password, Nonce, pwHash, pwKey,
|
||||||
|
isClosed, close
|
||||||
) where
|
) where
|
||||||
|
|
||||||
#if !MIN_VERSION_base(4,8,0)
|
#if !MIN_VERSION_base(4,8,0)
|
||||||
import Control.Applicative ((<$>))
|
import Control.Applicative ((<$>))
|
||||||
#endif
|
#endif
|
||||||
import Control.Arrow ((***))
|
|
||||||
import Control.Monad (forM, replicateM, unless)
|
import Control.Monad (forM, replicateM, unless)
|
||||||
import Data.Binary.Get (Get, runGet)
|
import Data.Binary.Get (Get, runGet)
|
||||||
import Data.Binary.Put (Put, runPut)
|
import Data.Binary.Put (Put, runPut)
|
||||||
|
@ -35,6 +43,12 @@ import Data.IORef (IORef, newIORef, atomicModifyIORef)
|
||||||
import System.IO (Handle)
|
import System.IO (Handle)
|
||||||
import System.IO.Unsafe (unsafePerformIO)
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
import Data.Maybe (maybeToList)
|
import Data.Maybe (maybeToList)
|
||||||
|
import GHC.Conc (ThreadStatus(..), threadStatus)
|
||||||
|
import Control.Monad (forever)
|
||||||
|
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
|
||||||
|
import Control.Concurrent (ThreadId, forkIO, killThread)
|
||||||
|
|
||||||
|
import Control.Exception.Lifted (onException, throwIO, try)
|
||||||
|
|
||||||
import qualified Data.ByteString.Lazy as L
|
import qualified Data.ByteString.Lazy as L
|
||||||
|
|
||||||
|
@ -48,38 +62,111 @@ import qualified Crypto.Hash.MD5 as MD5
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
import qualified Data.Text.Encoding as TE
|
import qualified Data.Text.Encoding as TE
|
||||||
|
|
||||||
import Database.MongoDB.Internal.Util (whenJust, bitOr, byteStringHex)
|
import Database.MongoDB.Internal.Util (bitOr, byteStringHex)
|
||||||
import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..))
|
|
||||||
|
|
||||||
import qualified System.IO.Pipeline as P
|
import Database.MongoDB.Transport (Transport)
|
||||||
|
import qualified Database.MongoDB.Transport as T
|
||||||
|
|
||||||
import Database.MongoDB.Internal.Connection (Connection)
|
#if MIN_VERSION_base(4,6,0)
|
||||||
import qualified Database.MongoDB.Internal.Connection as Connection
|
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
|
||||||
|
putMVar, readMVar, mkWeakMVar)
|
||||||
|
#else
|
||||||
|
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
|
||||||
|
putMVar, readMVar, addMVarFinalizer)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if !MIN_VERSION_base(4,6,0)
|
||||||
|
mkWeakMVar :: MVar a -> IO () -> IO ()
|
||||||
|
mkWeakMVar = addMVarFinalizer
|
||||||
|
#endif
|
||||||
|
|
||||||
|
-- * Pipeline
|
||||||
|
|
||||||
|
-- | Thread-safe and pipelined connection
|
||||||
|
data Pipeline = Pipeline {
|
||||||
|
vStream :: MVar Transport, -- ^ Mutex on handle, so only one thread at a time can write to it
|
||||||
|
responseQueue :: Chan (MVar (Either IOError Response)), -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response.
|
||||||
|
listenThread :: ThreadId
|
||||||
|
}
|
||||||
|
|
||||||
|
-- | Create new Pipeline over given handle. You should 'close' pipeline when finished, which will also close handle. If pipeline is not closed but eventually garbage collected, it will be closed along with handle.
|
||||||
|
newPipeline :: Transport -> IO Pipeline
|
||||||
|
newPipeline stream = do
|
||||||
|
vStream <- newMVar stream
|
||||||
|
responseQueue <- newChan
|
||||||
|
rec
|
||||||
|
let pipe = Pipeline{..}
|
||||||
|
listenThread <- forkIO (listen pipe)
|
||||||
|
_ <- mkWeakMVar vStream $ do
|
||||||
|
killThread listenThread
|
||||||
|
T.close stream
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
close :: Pipeline -> IO ()
|
||||||
|
-- ^ Close pipe and underlying connection
|
||||||
|
close Pipeline{..} = do
|
||||||
|
killThread listenThread
|
||||||
|
T.close =<< readMVar vStream
|
||||||
|
|
||||||
|
isClosed :: Pipeline -> IO Bool
|
||||||
|
isClosed Pipeline{listenThread} = do
|
||||||
|
status <- threadStatus listenThread
|
||||||
|
return $ case status of
|
||||||
|
ThreadRunning -> False
|
||||||
|
ThreadFinished -> True
|
||||||
|
ThreadBlocked _ -> False
|
||||||
|
ThreadDied -> True
|
||||||
|
--isPipeClosed Pipeline{..} = isClosed =<< readMVar vHandle -- isClosed hangs while listen loop is waiting on read
|
||||||
|
|
||||||
|
listen :: Pipeline -> IO ()
|
||||||
|
-- ^ Listen for responses and supply them to waiting threads in order
|
||||||
|
listen Pipeline{..} = do
|
||||||
|
stream <- readMVar vStream
|
||||||
|
forever $ do
|
||||||
|
e <- try $ readMessage stream
|
||||||
|
var <- readChan responseQueue
|
||||||
|
putMVar var e
|
||||||
|
case e of
|
||||||
|
Left err -> T.close stream >> ioError err -- close and stop looping
|
||||||
|
Right _ -> return ()
|
||||||
|
|
||||||
|
psend :: Pipeline -> Message -> IO ()
|
||||||
|
-- ^ Send message to destination; the destination must not response (otherwise future 'call's will get these responses instead of their own).
|
||||||
|
-- Throw IOError and close pipeline if send fails
|
||||||
|
psend p@Pipeline{..} message = withMVar vStream (flip writeMessage message) `onException` close p
|
||||||
|
|
||||||
|
pcall :: Pipeline -> Message -> IO (IO Response)
|
||||||
|
-- ^ Send message to destination and return /promise/ of response from one message only. The destination must reply to the message (otherwise promises will have the wrong responses in them).
|
||||||
|
-- Throw IOError and closes pipeline if send fails, likewise for promised response.
|
||||||
|
pcall p@Pipeline{..} message = withMVar vStream doCall `onException` close p where
|
||||||
|
doCall stream = do
|
||||||
|
writeMessage stream message
|
||||||
|
var <- newEmptyMVar
|
||||||
|
liftIO $ writeChan responseQueue var
|
||||||
|
return $ readMVar var >>= either throwIO return -- return promise
|
||||||
|
|
||||||
-- * Pipe
|
-- * Pipe
|
||||||
|
|
||||||
type Pipe = Pipeline Response Message
|
type Pipe = Pipeline
|
||||||
-- ^ Thread-safe TCP connection with pipelined requests
|
-- ^ Thread-safe TCP connection with pipelined requests
|
||||||
|
|
||||||
newPipe :: Handle -> IO Pipe
|
newPipe :: Handle -> IO Pipe
|
||||||
-- ^ Create pipe over handle
|
-- ^ Create pipe over handle
|
||||||
newPipe handle = Connection.fromHandle handle >>= newPipeWith
|
newPipe handle = T.fromHandle handle >>= newPipeWith
|
||||||
|
|
||||||
newPipeWith :: Connection -> IO Pipe
|
newPipeWith :: Transport -> IO Pipe
|
||||||
-- ^ Create pipe over connection
|
-- ^ Create pipe over connection
|
||||||
newPipeWith conn = newPipeline $ IOStream (writeMessage conn)
|
newPipeWith conn = newPipeline conn
|
||||||
(readMessage conn)
|
|
||||||
(Connection.close conn)
|
|
||||||
|
|
||||||
send :: Pipe -> [Notice] -> IO ()
|
send :: Pipe -> [Notice] -> IO ()
|
||||||
-- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails.
|
-- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails.
|
||||||
send pipe notices = P.send pipe (notices, Nothing)
|
send pipe notices = psend pipe (notices, Nothing)
|
||||||
|
|
||||||
call :: Pipe -> [Notice] -> Request -> IO (IO Reply)
|
call :: Pipe -> [Notice] -> Request -> IO (IO Reply)
|
||||||
-- ^ Send notices and request as a contiguous batch to server and return reply promise, which will block when invoked until reply arrives. This call and resulting promise will throw IOError if connection fails.
|
-- ^ Send notices and request as a contiguous batch to server and return reply promise, which will block when invoked until reply arrives. This call and resulting promise will throw IOError if connection fails.
|
||||||
call pipe notices request = do
|
call pipe notices request = do
|
||||||
requestId <- genRequestId
|
requestId <- genRequestId
|
||||||
promise <- P.call pipe (notices, Just (request, requestId))
|
promise <- pcall pipe (notices, Just (request, requestId))
|
||||||
return $ check requestId <$> promise
|
return $ check requestId <$> promise
|
||||||
where
|
where
|
||||||
check requestId (responseTo, reply) = if requestId == responseTo then reply else
|
check requestId (responseTo, reply) = if requestId == responseTo then reply else
|
||||||
|
@ -91,7 +178,7 @@ type Message = ([Notice], Maybe (Request, RequestId))
|
||||||
-- ^ A write notice(s) with getLastError request, or just query request.
|
-- ^ A write notice(s) with getLastError request, or just query request.
|
||||||
-- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness.
|
-- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness.
|
||||||
|
|
||||||
writeMessage :: Connection -> Message -> IO ()
|
writeMessage :: Transport -> Message -> IO ()
|
||||||
-- ^ Write message to connection
|
-- ^ Write message to connection
|
||||||
writeMessage conn (notices, mRequest) = do
|
writeMessage conn (notices, mRequest) = do
|
||||||
noticeStrings <- forM notices $ \n -> do
|
noticeStrings <- forM notices $ \n -> do
|
||||||
|
@ -104,8 +191,8 @@ writeMessage conn (notices, mRequest) = do
|
||||||
let s = runPut $ putRequest request requestId
|
let s = runPut $ putRequest request requestId
|
||||||
return $ (lenBytes s) `L.append` s
|
return $ (lenBytes s) `L.append` s
|
||||||
|
|
||||||
Connection.write conn $ L.toStrict $ L.concat $ noticeStrings ++ (maybeToList requestString)
|
T.write conn $ L.toStrict $ L.concat $ noticeStrings ++ (maybeToList requestString)
|
||||||
Connection.flush conn
|
T.flush conn
|
||||||
where
|
where
|
||||||
lenBytes bytes = encodeSize . toEnum . fromEnum $ L.length bytes
|
lenBytes bytes = encodeSize . toEnum . fromEnum $ L.length bytes
|
||||||
encodeSize = runPut . putInt32 . (+ 4)
|
encodeSize = runPut . putInt32 . (+ 4)
|
||||||
|
@ -113,12 +200,12 @@ writeMessage conn (notices, mRequest) = do
|
||||||
type Response = (ResponseTo, Reply)
|
type Response = (ResponseTo, Reply)
|
||||||
-- ^ Message received from a Mongo server in response to a Request
|
-- ^ Message received from a Mongo server in response to a Request
|
||||||
|
|
||||||
readMessage :: Connection -> IO Response
|
readMessage :: Transport -> IO Response
|
||||||
-- ^ read response from a connection
|
-- ^ read response from a connection
|
||||||
readMessage conn = readResp where
|
readMessage conn = readResp where
|
||||||
readResp = do
|
readResp = do
|
||||||
len <- fromEnum . decodeSize <$> Connection.readExactly conn 4
|
len <- fromEnum . decodeSize . L.fromStrict <$> T.read conn 4
|
||||||
runGet getReply <$> Connection.readExactly conn len
|
runGet getReply . L.fromStrict <$> T.read conn len
|
||||||
decodeSize = subtract 4 . runGet getInt32
|
decodeSize = subtract 4 . runGet getInt32
|
||||||
|
|
||||||
type FullCollection = Text
|
type FullCollection = Text
|
||||||
|
|
32
Database/MongoDB/Transport.hs
Normal file
32
Database/MongoDB/Transport.hs
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
|
||||||
|
-- | This module defines a connection interface. It could be a regular
|
||||||
|
-- network connection, TLS connection, a mock or anything else.
|
||||||
|
|
||||||
|
module Database.MongoDB.Transport (
|
||||||
|
Transport(..),
|
||||||
|
fromHandle,
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Prelude hiding (read)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import qualified Data.ByteString as ByteString
|
||||||
|
import System.IO
|
||||||
|
|
||||||
|
-- | Abstract transport interface
|
||||||
|
--
|
||||||
|
-- `read` should return `ByteString.null` on EOF
|
||||||
|
data Transport = Transport {
|
||||||
|
read :: Int -> IO ByteString,
|
||||||
|
write :: ByteString -> IO (),
|
||||||
|
flush :: IO (),
|
||||||
|
close :: IO ()}
|
||||||
|
|
||||||
|
fromHandle :: Handle -> IO Transport
|
||||||
|
-- ^ Make connection form handle
|
||||||
|
fromHandle handle = do
|
||||||
|
return Transport
|
||||||
|
{ read = ByteString.hGet handle
|
||||||
|
, write = ByteString.hPut handle
|
||||||
|
, flush = hFlush handle
|
||||||
|
, close = hClose handle
|
||||||
|
}
|
88
Database/MongoDB/Transport/Tls.hs
Normal file
88
Database/MongoDB/Transport/Tls.hs
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE RecordWildCards #-}
|
||||||
|
|
||||||
|
{-|
|
||||||
|
Module : MongoDB TLS
|
||||||
|
Description : TLS transport for mongodb
|
||||||
|
Copyright : (c) Yuras Shumovich, 2016
|
||||||
|
License : Apache 2.0
|
||||||
|
Maintainer : Victor Denisov denisovenator@gmail.com
|
||||||
|
Stability : experimental
|
||||||
|
Portability : POSIX
|
||||||
|
|
||||||
|
This module is for connecting to TLS enabled mongodb servers.
|
||||||
|
ATTENTION!!! Be aware that this module is highly experimental and is
|
||||||
|
barely tested. The current implementation doesn't verify server's identity.
|
||||||
|
It only allows you to connect to a mongodb server using TLS protocol.
|
||||||
|
-}
|
||||||
|
module Database.MongoDB.Transport.Tls
|
||||||
|
(connect)
|
||||||
|
where
|
||||||
|
|
||||||
|
import Data.IORef
|
||||||
|
import Data.Monoid
|
||||||
|
import qualified Data.ByteString as ByteString
|
||||||
|
import qualified Data.ByteString.Lazy as Lazy.ByteString
|
||||||
|
import Data.Default.Class (def)
|
||||||
|
import Control.Applicative ((<$>))
|
||||||
|
import Control.Exception (bracketOnError)
|
||||||
|
import Control.Monad (when, unless)
|
||||||
|
import System.IO
|
||||||
|
import Database.MongoDB (Pipe)
|
||||||
|
import Database.MongoDB.Internal.Protocol (newPipeWith)
|
||||||
|
import Database.MongoDB.Transport (Transport(Transport))
|
||||||
|
import qualified Database.MongoDB.Transport as T
|
||||||
|
import System.IO.Error (mkIOError, eofErrorType)
|
||||||
|
import Network (connectTo, HostName, PortID)
|
||||||
|
import qualified Network.TLS as TLS
|
||||||
|
import qualified Network.TLS.Extra.Cipher as TLS
|
||||||
|
|
||||||
|
-- | Connect to mongodb using TLS
|
||||||
|
connect :: HostName -> PortID -> IO Pipe
|
||||||
|
connect host port = bracketOnError (connectTo host port) hClose $ \handle -> do
|
||||||
|
|
||||||
|
let params = (TLS.defaultParamsClient host "")
|
||||||
|
{ TLS.clientSupported = def
|
||||||
|
{ TLS.supportedCiphers = TLS.ciphersuite_all}
|
||||||
|
, TLS.clientHooks = def
|
||||||
|
{ TLS.onServerCertificate = \_ _ _ _ -> return []}
|
||||||
|
}
|
||||||
|
context <- TLS.contextNew handle params
|
||||||
|
TLS.handshake context
|
||||||
|
|
||||||
|
conn <- tlsConnection context
|
||||||
|
newPipeWith conn
|
||||||
|
|
||||||
|
tlsConnection :: TLS.Context -> IO Transport
|
||||||
|
tlsConnection ctx = do
|
||||||
|
restRef <- newIORef mempty
|
||||||
|
return Transport
|
||||||
|
{ T.read = \count -> let
|
||||||
|
readSome = do
|
||||||
|
rest <- readIORef restRef
|
||||||
|
writeIORef restRef mempty
|
||||||
|
if ByteString.null rest
|
||||||
|
then TLS.recvData ctx
|
||||||
|
else return rest
|
||||||
|
unread = \rest ->
|
||||||
|
modifyIORef restRef (rest <>)
|
||||||
|
go acc n = do
|
||||||
|
-- read until get enough bytes
|
||||||
|
chunk <- readSome
|
||||||
|
when (ByteString.null chunk) $
|
||||||
|
ioError eof
|
||||||
|
let len = ByteString.length chunk
|
||||||
|
if len >= n
|
||||||
|
then do
|
||||||
|
let (res, rest) = ByteString.splitAt n chunk
|
||||||
|
unless (ByteString.null rest) $
|
||||||
|
unread rest
|
||||||
|
return (acc <> Lazy.ByteString.fromStrict res)
|
||||||
|
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
|
||||||
|
eof = mkIOError eofErrorType "Database.MongoDB.Transport"
|
||||||
|
Nothing Nothing
|
||||||
|
in Lazy.ByteString.toStrict <$> go mempty count
|
||||||
|
, T.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
|
||||||
|
, T.flush = TLS.contextFlush ctx
|
||||||
|
, T.close = TLS.contextClose ctx
|
||||||
|
}
|
|
@ -1,118 +0,0 @@
|
||||||
{- | Pipelining is sending multiple requests over a socket and receiving the responses later in the same order (a' la HTTP pipelining). This is faster than sending one request, waiting for the response, then sending the next request, and so on. This implementation returns a /promise (future)/ response for each request that when invoked waits for the response if not already arrived. Multiple threads can send on the same pipeline (and get promises back); it will send each thread's request right away without waiting.
|
|
||||||
|
|
||||||
A pipeline closes itself when a read or write causes an error, so you can detect a broken pipeline by checking isClosed. It also closes itself when garbage collected, or you can close it explicitly. -}
|
|
||||||
|
|
||||||
{-# LANGUAGE RecordWildCards, NamedFieldPuns, ScopedTypeVariables #-}
|
|
||||||
{-# LANGUAGE CPP, FlexibleContexts #-}
|
|
||||||
|
|
||||||
#if (__GLASGOW_HASKELL__ >= 706)
|
|
||||||
{-# LANGUAGE RecursiveDo #-}
|
|
||||||
#else
|
|
||||||
{-# LANGUAGE DoRec #-}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
module System.IO.Pipeline (
|
|
||||||
-- * IOStream
|
|
||||||
IOStream(..),
|
|
||||||
-- * Pipeline
|
|
||||||
Pipeline, newPipeline, send, call, close, isClosed
|
|
||||||
) where
|
|
||||||
|
|
||||||
import Prelude hiding (length)
|
|
||||||
import Control.Concurrent (ThreadId, forkIO, killThread)
|
|
||||||
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
|
|
||||||
import Control.Monad (forever)
|
|
||||||
import GHC.Conc (ThreadStatus(..), threadStatus)
|
|
||||||
|
|
||||||
import Control.Monad.Trans (liftIO)
|
|
||||||
#if MIN_VERSION_base(4,6,0)
|
|
||||||
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
|
|
||||||
putMVar, readMVar, mkWeakMVar)
|
|
||||||
#else
|
|
||||||
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
|
|
||||||
putMVar, readMVar, addMVarFinalizer)
|
|
||||||
#endif
|
|
||||||
import Control.Exception.Lifted (onException, throwIO, try)
|
|
||||||
|
|
||||||
#if !MIN_VERSION_base(4,6,0)
|
|
||||||
mkWeakMVar :: MVar a -> IO () -> IO ()
|
|
||||||
mkWeakMVar = addMVarFinalizer
|
|
||||||
#endif
|
|
||||||
|
|
||||||
-- * IOStream
|
|
||||||
|
|
||||||
-- | An IO sink and source where value of type @o@ are sent and values of type @i@ are received.
|
|
||||||
data IOStream i o = IOStream {
|
|
||||||
writeStream :: o -> IO (),
|
|
||||||
readStream :: IO i,
|
|
||||||
closeStream :: IO () }
|
|
||||||
|
|
||||||
-- * Pipeline
|
|
||||||
|
|
||||||
-- | Thread-safe and pipelined connection
|
|
||||||
data Pipeline i o = Pipeline {
|
|
||||||
vStream :: MVar (IOStream i o), -- ^ Mutex on handle, so only one thread at a time can write to it
|
|
||||||
responseQueue :: Chan (MVar (Either IOError i)), -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response.
|
|
||||||
listenThread :: ThreadId
|
|
||||||
}
|
|
||||||
|
|
||||||
-- | Create new Pipeline over given handle. You should 'close' pipeline when finished, which will also close handle. If pipeline is not closed but eventually garbage collected, it will be closed along with handle.
|
|
||||||
newPipeline :: IOStream i o -> IO (Pipeline i o)
|
|
||||||
newPipeline stream = do
|
|
||||||
vStream <- newMVar stream
|
|
||||||
responseQueue <- newChan
|
|
||||||
rec
|
|
||||||
let pipe = Pipeline{..}
|
|
||||||
listenThread <- forkIO (listen pipe)
|
|
||||||
_ <- mkWeakMVar vStream $ do
|
|
||||||
killThread listenThread
|
|
||||||
closeStream stream
|
|
||||||
return pipe
|
|
||||||
|
|
||||||
close :: Pipeline i o -> IO ()
|
|
||||||
-- ^ Close pipe and underlying connection
|
|
||||||
close Pipeline{..} = do
|
|
||||||
killThread listenThread
|
|
||||||
closeStream =<< readMVar vStream
|
|
||||||
|
|
||||||
isClosed :: Pipeline i o -> IO Bool
|
|
||||||
isClosed Pipeline{listenThread} = do
|
|
||||||
status <- threadStatus listenThread
|
|
||||||
return $ case status of
|
|
||||||
ThreadRunning -> False
|
|
||||||
ThreadFinished -> True
|
|
||||||
ThreadBlocked _ -> False
|
|
||||||
ThreadDied -> True
|
|
||||||
--isPipeClosed Pipeline{..} = isClosed =<< readMVar vHandle -- isClosed hangs while listen loop is waiting on read
|
|
||||||
|
|
||||||
listen :: Pipeline i o -> IO ()
|
|
||||||
-- ^ Listen for responses and supply them to waiting threads in order
|
|
||||||
listen Pipeline{..} = do
|
|
||||||
stream <- readMVar vStream
|
|
||||||
forever $ do
|
|
||||||
e <- try $ readStream stream
|
|
||||||
var <- readChan responseQueue
|
|
||||||
putMVar var e
|
|
||||||
case e of
|
|
||||||
Left err -> closeStream stream >> ioError err -- close and stop looping
|
|
||||||
Right _ -> return ()
|
|
||||||
|
|
||||||
send :: Pipeline i o -> o -> IO ()
|
|
||||||
-- ^ Send message to destination; the destination must not response (otherwise future 'call's will get these responses instead of their own).
|
|
||||||
-- Throw IOError and close pipeline if send fails
|
|
||||||
send p@Pipeline{..} message = withMVar vStream (flip writeStream message) `onException` close p
|
|
||||||
|
|
||||||
call :: Pipeline i o -> o -> IO (IO i)
|
|
||||||
-- ^ Send message to destination and return /promise/ of response from one message only. The destination must reply to the message (otherwise promises will have the wrong responses in them).
|
|
||||||
-- Throw IOError and closes pipeline if send fails, likewise for promised response.
|
|
||||||
call p@Pipeline{..} message = withMVar vStream doCall `onException` close p where
|
|
||||||
doCall stream = do
|
|
||||||
writeStream stream message
|
|
||||||
var <- newEmptyMVar
|
|
||||||
liftIO $ writeChan responseQueue var
|
|
||||||
return $ readMVar var >>= either throwIO return -- return promise
|
|
||||||
|
|
||||||
|
|
||||||
{- Authors: Tony Hannan <tony@10gen.com>
|
|
||||||
Copyright 2011 10gen Inc.
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0. Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -}
|
|
|
@ -39,6 +39,8 @@ Library
|
||||||
, random-shuffle -any
|
, random-shuffle -any
|
||||||
, monad-control >= 0.3.1
|
, monad-control >= 0.3.1
|
||||||
, lifted-base >= 0.1.0.3
|
, lifted-base >= 0.1.0.3
|
||||||
|
, tls >= 1.2.0
|
||||||
|
, data-default-class -any
|
||||||
, transformers-base >= 0.4.1
|
, transformers-base >= 0.4.1
|
||||||
, hashtables >= 1.1.2.0
|
, hashtables >= 1.1.2.0
|
||||||
, base16-bytestring >= 0.1.1.6
|
, base16-bytestring >= 0.1.1.6
|
||||||
|
@ -48,11 +50,11 @@ Library
|
||||||
Exposed-modules: Database.MongoDB
|
Exposed-modules: Database.MongoDB
|
||||||
Database.MongoDB.Admin
|
Database.MongoDB.Admin
|
||||||
Database.MongoDB.Connection
|
Database.MongoDB.Connection
|
||||||
Database.MongoDB.Internal.Connection
|
|
||||||
Database.MongoDB.Internal.Protocol
|
Database.MongoDB.Internal.Protocol
|
||||||
Database.MongoDB.Internal.Util
|
Database.MongoDB.Internal.Util
|
||||||
Database.MongoDB.Query
|
Database.MongoDB.Query
|
||||||
System.IO.Pipeline
|
Database.MongoDB.Transport
|
||||||
|
Database.MongoDB.Transport.Tls
|
||||||
|
|
||||||
Source-repository head
|
Source-repository head
|
||||||
Type: git
|
Type: git
|
||||||
|
@ -82,6 +84,8 @@ Benchmark bench
|
||||||
type: exitcode-stdio-1.0
|
type: exitcode-stdio-1.0
|
||||||
Build-depends: array -any
|
Build-depends: array -any
|
||||||
, base < 5
|
, base < 5
|
||||||
|
, base64-bytestring
|
||||||
|
, base16-bytestring
|
||||||
, binary -any
|
, binary -any
|
||||||
, bson >= 0.3 && < 0.4
|
, bson >= 0.3 && < 0.4
|
||||||
, text
|
, text
|
||||||
|
@ -90,6 +94,7 @@ Benchmark bench
|
||||||
, mtl >= 2
|
, mtl >= 2
|
||||||
, cryptohash -any
|
, cryptohash -any
|
||||||
, network -any
|
, network -any
|
||||||
|
, nonce
|
||||||
, parsec -any
|
, parsec -any
|
||||||
, random -any
|
, random -any
|
||||||
, random-shuffle -any
|
, random-shuffle -any
|
||||||
|
|
|
@ -1,69 +0,0 @@
|
||||||
|
|
||||||
module Internal.ConnectionSpec (
|
|
||||||
spec,
|
|
||||||
) where
|
|
||||||
|
|
||||||
import Prelude hiding (read)
|
|
||||||
import Data.Monoid
|
|
||||||
import Data.IORef
|
|
||||||
import Control.Monad
|
|
||||||
import System.IO.Error (isEOFError)
|
|
||||||
import Test.Hspec
|
|
||||||
|
|
||||||
import Database.MongoDB.Internal.Connection
|
|
||||||
|
|
||||||
spec :: Spec
|
|
||||||
spec = describe "Internal.Connection" $ do
|
|
||||||
readExactlySpec
|
|
||||||
|
|
||||||
readExactlySpec :: Spec
|
|
||||||
readExactlySpec = describe "readExactly" $ do
|
|
||||||
it "should return specified number of bytes" $ do
|
|
||||||
let conn = Connection
|
|
||||||
{ read = return "12345"
|
|
||||||
, unread = \_ -> return ()
|
|
||||||
, write = \_ -> return ()
|
|
||||||
, flush = return ()
|
|
||||||
, close = return ()
|
|
||||||
}
|
|
||||||
|
|
||||||
res <- readExactly conn 3
|
|
||||||
res `shouldBe` "123"
|
|
||||||
|
|
||||||
it "should unread the rest" $ do
|
|
||||||
restRef <- newIORef mempty
|
|
||||||
let conn = Connection
|
|
||||||
{ read = return "12345"
|
|
||||||
, unread = writeIORef restRef
|
|
||||||
, write = \_ -> return ()
|
|
||||||
, flush = return ()
|
|
||||||
, close = return ()
|
|
||||||
}
|
|
||||||
|
|
||||||
void $ readExactly conn 3
|
|
||||||
rest <- readIORef restRef
|
|
||||||
rest `shouldBe` "45"
|
|
||||||
|
|
||||||
it "should ask for more bytes if the first chunk is too small" $ do
|
|
||||||
let conn = Connection
|
|
||||||
{ read = return "12345"
|
|
||||||
, unread = \_ -> return ()
|
|
||||||
, write = \_ -> return ()
|
|
||||||
, flush = return ()
|
|
||||||
, close = return ()
|
|
||||||
}
|
|
||||||
|
|
||||||
res <- readExactly conn 8
|
|
||||||
res `shouldBe` "12345123"
|
|
||||||
|
|
||||||
it "should throw on EOF" $ do
|
|
||||||
let conn = Connection
|
|
||||||
{ read = return mempty
|
|
||||||
, unread = \_ -> return ()
|
|
||||||
, write = \_ -> return ()
|
|
||||||
, flush = return ()
|
|
||||||
, close = return ()
|
|
||||||
}
|
|
||||||
|
|
||||||
void $ readExactly conn 3
|
|
||||||
`shouldThrow` isEOFError
|
|
Loading…
Reference in a new issue