parent
a77370f2d9
commit
98bcc2dfe8
5 changed files with 171 additions and 28 deletions
76
Database/MongoDB/Internal/Connection.hs
Normal file
76
Database/MongoDB/Internal/Connection.hs
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
|
||||||
|
-- | This module defines a connection interface. It could be a regular
|
||||||
|
-- network connection, TLS connection, a mock or anything else.
|
||||||
|
|
||||||
|
module Database.MongoDB.Internal.Connection (
|
||||||
|
Connection(..),
|
||||||
|
readExactly,
|
||||||
|
writeLazy,
|
||||||
|
fromHandle,
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Prelude hiding (read)
|
||||||
|
import Data.Monoid
|
||||||
|
import Data.IORef
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import qualified Data.ByteString as ByteString
|
||||||
|
import qualified Data.ByteString.Lazy as Lazy (ByteString)
|
||||||
|
import qualified Data.ByteString.Lazy as Lazy.ByteString
|
||||||
|
import Control.Monad
|
||||||
|
import System.IO
|
||||||
|
import System.IO.Error (mkIOError, eofErrorType)
|
||||||
|
|
||||||
|
-- | Abstract connection interface
|
||||||
|
--
|
||||||
|
-- `read` should return `ByteString.null` on EOF
|
||||||
|
data Connection = Connection {
|
||||||
|
read :: IO ByteString,
|
||||||
|
unread :: ByteString -> IO (),
|
||||||
|
write :: ByteString -> IO (),
|
||||||
|
flush :: IO (),
|
||||||
|
close :: IO ()}
|
||||||
|
|
||||||
|
readExactly :: Connection -> Int -> IO Lazy.ByteString
|
||||||
|
-- ^ Read specified number of bytes
|
||||||
|
--
|
||||||
|
-- If EOF is reached before N bytes then raise EOF exception.
|
||||||
|
readExactly conn count = go mempty count
|
||||||
|
where
|
||||||
|
go acc n = do
|
||||||
|
-- read until get enough bytes
|
||||||
|
chunk <- read conn
|
||||||
|
when (ByteString.null chunk) $
|
||||||
|
ioError eof
|
||||||
|
let len = ByteString.length chunk
|
||||||
|
if len >= n
|
||||||
|
then do
|
||||||
|
let (res, rest) = ByteString.splitAt n chunk
|
||||||
|
unless (ByteString.null rest) $
|
||||||
|
unread conn rest
|
||||||
|
return (acc <> Lazy.ByteString.fromStrict res)
|
||||||
|
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
|
||||||
|
eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection"
|
||||||
|
Nothing Nothing
|
||||||
|
|
||||||
|
writeLazy :: Connection -> Lazy.ByteString -> IO ()
|
||||||
|
writeLazy conn = mapM_ (write conn) . Lazy.ByteString.toChunks
|
||||||
|
|
||||||
|
fromHandle :: Handle -> IO Connection
|
||||||
|
-- ^ Make connection form handle
|
||||||
|
fromHandle handle = do
|
||||||
|
restRef <- newIORef mempty
|
||||||
|
return Connection
|
||||||
|
{ read = do
|
||||||
|
rest <- readIORef restRef
|
||||||
|
writeIORef restRef mempty
|
||||||
|
if ByteString.null rest
|
||||||
|
-- 32k corresponds to the default chunk size
|
||||||
|
-- used in bytestring package
|
||||||
|
then ByteString.hGetSome handle (32 * 1024)
|
||||||
|
else return rest
|
||||||
|
, unread = \rest ->
|
||||||
|
modifyIORef restRef (rest <>)
|
||||||
|
, write = ByteString.hPut handle
|
||||||
|
, flush = hFlush handle
|
||||||
|
, close = hClose handle
|
||||||
|
}
|
|
@ -11,7 +11,7 @@
|
||||||
module Database.MongoDB.Internal.Protocol (
|
module Database.MongoDB.Internal.Protocol (
|
||||||
FullCollection,
|
FullCollection,
|
||||||
-- * Pipe
|
-- * Pipe
|
||||||
Pipe, newPipe, send, call,
|
Pipe, newPipe, newPipeWith, send, call,
|
||||||
-- ** Notice
|
-- ** Notice
|
||||||
Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
|
Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
|
||||||
-- ** Request
|
-- ** Request
|
||||||
|
@ -30,7 +30,7 @@ import Data.Binary.Put (Put, runPut)
|
||||||
import Data.Bits (bit, testBit)
|
import Data.Bits (bit, testBit)
|
||||||
import Data.Int (Int32, Int64)
|
import Data.Int (Int32, Int64)
|
||||||
import Data.IORef (IORef, newIORef, atomicModifyIORef)
|
import Data.IORef (IORef, newIORef, atomicModifyIORef)
|
||||||
import System.IO (Handle, hClose, hFlush)
|
import System.IO (Handle)
|
||||||
import System.IO.Unsafe (unsafePerformIO)
|
import System.IO.Unsafe (unsafePerformIO)
|
||||||
|
|
||||||
import qualified Data.ByteString.Lazy as L
|
import qualified Data.ByteString.Lazy as L
|
||||||
|
@ -45,11 +45,14 @@ import qualified Crypto.Hash.MD5 as MD5
|
||||||
import qualified Data.Text as T
|
import qualified Data.Text as T
|
||||||
import qualified Data.Text.Encoding as TE
|
import qualified Data.Text.Encoding as TE
|
||||||
|
|
||||||
import Database.MongoDB.Internal.Util (whenJust, hGetN, bitOr, byteStringHex)
|
import Database.MongoDB.Internal.Util (whenJust, bitOr, byteStringHex)
|
||||||
import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..))
|
import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..))
|
||||||
|
|
||||||
import qualified System.IO.Pipeline as P
|
import qualified System.IO.Pipeline as P
|
||||||
|
|
||||||
|
import Database.MongoDB.Internal.Connection (Connection)
|
||||||
|
import qualified Database.MongoDB.Internal.Connection as Connection
|
||||||
|
|
||||||
-- * Pipe
|
-- * Pipe
|
||||||
|
|
||||||
type Pipe = Pipeline Response Message
|
type Pipe = Pipeline Response Message
|
||||||
|
@ -57,7 +60,13 @@ type Pipe = Pipeline Response Message
|
||||||
|
|
||||||
newPipe :: Handle -> IO Pipe
|
newPipe :: Handle -> IO Pipe
|
||||||
-- ^ Create pipe over handle
|
-- ^ Create pipe over handle
|
||||||
newPipe handle = newPipeline $ IOStream (writeMessage handle) (readMessage handle) (hClose handle)
|
newPipe handle = Connection.fromHandle handle >>= newPipeWith
|
||||||
|
|
||||||
|
newPipeWith :: Connection -> IO Pipe
|
||||||
|
-- ^ Create pipe over connection
|
||||||
|
newPipeWith conn = newPipeline $ IOStream (writeMessage conn)
|
||||||
|
(readMessage conn)
|
||||||
|
(Connection.close conn)
|
||||||
|
|
||||||
send :: Pipe -> [Notice] -> IO ()
|
send :: Pipe -> [Notice] -> IO ()
|
||||||
-- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails.
|
-- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails.
|
||||||
|
@ -79,16 +88,16 @@ type Message = ([Notice], Maybe (Request, RequestId))
|
||||||
-- ^ A write notice(s) with getLastError request, or just query request.
|
-- ^ A write notice(s) with getLastError request, or just query request.
|
||||||
-- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness.
|
-- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness.
|
||||||
|
|
||||||
writeMessage :: Handle -> Message -> IO ()
|
writeMessage :: Connection -> Message -> IO ()
|
||||||
-- ^ Write message to socket
|
-- ^ Write message to connection
|
||||||
writeMessage handle (notices, mRequest) = do
|
writeMessage conn (notices, mRequest) = do
|
||||||
forM_ notices $ \n -> writeReq . (Left n,) =<< genRequestId
|
forM_ notices $ \n -> writeReq . (Left n,) =<< genRequestId
|
||||||
whenJust mRequest $ writeReq . (Right *** id)
|
whenJust mRequest $ writeReq . (Right *** id)
|
||||||
hFlush handle
|
Connection.flush conn
|
||||||
where
|
where
|
||||||
writeReq (e, requestId) = do
|
writeReq (e, requestId) = do
|
||||||
L.hPut handle lenBytes
|
Connection.writeLazy conn lenBytes
|
||||||
L.hPut handle bytes
|
Connection.writeLazy conn bytes
|
||||||
where
|
where
|
||||||
bytes = runPut $ (either putNotice putRequest e) requestId
|
bytes = runPut $ (either putNotice putRequest e) requestId
|
||||||
lenBytes = encodeSize . toEnum . fromEnum $ L.length bytes
|
lenBytes = encodeSize . toEnum . fromEnum $ L.length bytes
|
||||||
|
@ -97,12 +106,12 @@ writeMessage handle (notices, mRequest) = do
|
||||||
type Response = (ResponseTo, Reply)
|
type Response = (ResponseTo, Reply)
|
||||||
-- ^ Message received from a Mongo server in response to a Request
|
-- ^ Message received from a Mongo server in response to a Request
|
||||||
|
|
||||||
readMessage :: Handle -> IO Response
|
readMessage :: Connection -> IO Response
|
||||||
-- ^ read response from socket
|
-- ^ read response from a connection
|
||||||
readMessage handle = readResp where
|
readMessage conn = readResp where
|
||||||
readResp = do
|
readResp = do
|
||||||
len <- fromEnum . decodeSize <$> hGetN handle 4
|
len <- fromEnum . decodeSize <$> Connection.readExactly conn 4
|
||||||
runGet getReply <$> hGetN handle len
|
runGet getReply <$> Connection.readExactly conn len
|
||||||
decodeSize = subtract 4 . runGet getInt32
|
decodeSize = subtract 4 . runGet getInt32
|
||||||
|
|
||||||
type FullCollection = Text
|
type FullCollection = Text
|
||||||
|
|
|
@ -8,18 +8,15 @@
|
||||||
module Database.MongoDB.Internal.Util where
|
module Database.MongoDB.Internal.Util where
|
||||||
|
|
||||||
import Control.Applicative ((<$>))
|
import Control.Applicative ((<$>))
|
||||||
import Control.Exception (assert, handle, throwIO, Exception)
|
import Control.Exception (handle, throwIO, Exception)
|
||||||
import Control.Monad (liftM, liftM2)
|
import Control.Monad (liftM, liftM2)
|
||||||
import Data.Bits (Bits, (.|.))
|
import Data.Bits (Bits, (.|.))
|
||||||
import Data.Word (Word8)
|
import Data.Word (Word8)
|
||||||
import Network (PortID(..))
|
import Network (PortID(..))
|
||||||
import Numeric (showHex)
|
import Numeric (showHex)
|
||||||
import System.IO (Handle)
|
|
||||||
import System.IO.Error (mkIOError, eofErrorType)
|
|
||||||
import System.Random (newStdGen)
|
import System.Random (newStdGen)
|
||||||
import System.Random.Shuffle (shuffle')
|
import System.Random.Shuffle (shuffle')
|
||||||
|
|
||||||
import qualified Data.ByteString.Lazy as L
|
|
||||||
import qualified Data.ByteString as S
|
import qualified Data.ByteString as S
|
||||||
|
|
||||||
import Control.Monad.Error (MonadError(..), Error(..))
|
import Control.Monad.Error (MonadError(..), Error(..))
|
||||||
|
@ -108,15 +105,6 @@ true1 k doc = case valueAt k doc of
|
||||||
Int64 n -> n == 1
|
Int64 n -> n == 1
|
||||||
_ -> error $ "expected " ++ show k ++ " to be Num or Bool in " ++ show doc
|
_ -> error $ "expected " ++ show k ++ " to be Num or Bool in " ++ show doc
|
||||||
|
|
||||||
hGetN :: Handle -> Int -> IO L.ByteString
|
|
||||||
-- ^ Read N bytes from hande, blocking until all N bytes are read. If EOF is reached before N bytes then raise EOF exception.
|
|
||||||
hGetN h n = assert (n >= 0) $ do
|
|
||||||
bytes <- L.hGet h n
|
|
||||||
let x = fromEnum $ L.length bytes
|
|
||||||
if x >= n then return bytes
|
|
||||||
else if x == 0 then ioError (mkIOError eofErrorType "hGetN" (Just h) Nothing)
|
|
||||||
else L.append bytes <$> hGetN h (n - x)
|
|
||||||
|
|
||||||
byteStringHex :: S.ByteString -> String
|
byteStringHex :: S.ByteString -> String
|
||||||
-- ^ Hexadecimal string representation of a byte string. Each byte yields two hexadecimal characters.
|
-- ^ Hexadecimal string representation of a byte string. Each byte yields two hexadecimal characters.
|
||||||
byteStringHex = concatMap byteHex . S.unpack
|
byteStringHex = concatMap byteHex . S.unpack
|
||||||
|
|
|
@ -44,6 +44,7 @@ Library
|
||||||
Exposed-modules: Database.MongoDB
|
Exposed-modules: Database.MongoDB
|
||||||
Database.MongoDB.Admin
|
Database.MongoDB.Admin
|
||||||
Database.MongoDB.Connection
|
Database.MongoDB.Connection
|
||||||
|
Database.MongoDB.Internal.Connection
|
||||||
Database.MongoDB.Internal.Protocol
|
Database.MongoDB.Internal.Protocol
|
||||||
Database.MongoDB.Internal.Util
|
Database.MongoDB.Internal.Util
|
||||||
Database.MongoDB.Query
|
Database.MongoDB.Query
|
||||||
|
|
69
test/Internal/ConnectionTest.hs
Normal file
69
test/Internal/ConnectionTest.hs
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
|
||||||
|
module Internal.ConnectionTest (
|
||||||
|
spec,
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Prelude hiding (read)
|
||||||
|
import Data.Monoid
|
||||||
|
import Data.IORef
|
||||||
|
import Control.Monad
|
||||||
|
import System.IO.Error (isEOFError)
|
||||||
|
import Test.Hspec
|
||||||
|
|
||||||
|
import Database.MongoDB.Internal.Connection
|
||||||
|
|
||||||
|
spec :: Spec
|
||||||
|
spec = describe "Internal.Connection" $ do
|
||||||
|
readExactlySpec
|
||||||
|
|
||||||
|
readExactlySpec :: Spec
|
||||||
|
readExactlySpec = describe "readExactly" $ do
|
||||||
|
it "should return specified number of bytes" $ do
|
||||||
|
let conn = Connection
|
||||||
|
{ read = return "12345"
|
||||||
|
, unread = \_ -> return ()
|
||||||
|
, write = \_ -> return ()
|
||||||
|
, flush = return ()
|
||||||
|
, close = return ()
|
||||||
|
}
|
||||||
|
|
||||||
|
res <- readExactly conn 3
|
||||||
|
res `shouldBe` "123"
|
||||||
|
|
||||||
|
it "should unread the rest" $ do
|
||||||
|
restRef <- newIORef mempty
|
||||||
|
let conn = Connection
|
||||||
|
{ read = return "12345"
|
||||||
|
, unread = writeIORef restRef
|
||||||
|
, write = \_ -> return ()
|
||||||
|
, flush = return ()
|
||||||
|
, close = return ()
|
||||||
|
}
|
||||||
|
|
||||||
|
void $ readExactly conn 3
|
||||||
|
rest <- readIORef restRef
|
||||||
|
rest `shouldBe` "45"
|
||||||
|
|
||||||
|
it "should ask for more bytes if the first chunk is too small" $ do
|
||||||
|
let conn = Connection
|
||||||
|
{ read = return "12345"
|
||||||
|
, unread = \_ -> return ()
|
||||||
|
, write = \_ -> return ()
|
||||||
|
, flush = return ()
|
||||||
|
, close = return ()
|
||||||
|
}
|
||||||
|
|
||||||
|
res <- readExactly conn 8
|
||||||
|
res `shouldBe` "12345123"
|
||||||
|
|
||||||
|
it "should throw on EOF" $ do
|
||||||
|
let conn = Connection
|
||||||
|
{ read = return mempty
|
||||||
|
, unread = \_ -> return ()
|
||||||
|
, write = \_ -> return ()
|
||||||
|
, flush = return ()
|
||||||
|
, close = return ()
|
||||||
|
}
|
||||||
|
|
||||||
|
void $ readExactly conn 3
|
||||||
|
`shouldThrow` isEOFError
|
Loading…
Reference in a new issue