diff --git a/Database/MongoDB/Internal/Connection.hs b/Database/MongoDB/Internal/Connection.hs new file mode 100644 index 0000000..ddb23f8 --- /dev/null +++ b/Database/MongoDB/Internal/Connection.hs @@ -0,0 +1,76 @@ + +-- | This module defines a connection interface. It could be a regular +-- network connection, TLS connection, a mock or anything else. + +module Database.MongoDB.Internal.Connection ( + Connection(..), + readExactly, + writeLazy, + fromHandle, +) where + +import Prelude hiding (read) +import Data.Monoid +import Data.IORef +import Data.ByteString (ByteString) +import qualified Data.ByteString as ByteString +import qualified Data.ByteString.Lazy as Lazy (ByteString) +import qualified Data.ByteString.Lazy as Lazy.ByteString +import Control.Monad +import System.IO +import System.IO.Error (mkIOError, eofErrorType) + +-- | Abstract connection interface +-- +-- `read` should return `ByteString.null` on EOF +data Connection = Connection { + read :: IO ByteString, + unread :: ByteString -> IO (), + write :: ByteString -> IO (), + flush :: IO (), + close :: IO ()} + +readExactly :: Connection -> Int -> IO Lazy.ByteString +-- ^ Read specified number of bytes +-- +-- If EOF is reached before N bytes then raise EOF exception. +readExactly conn count = go mempty count + where + go acc n = do + -- read until get enough bytes + chunk <- read conn + when (ByteString.null chunk) $ + ioError eof + let len = ByteString.length chunk + if len >= n + then do + let (res, rest) = ByteString.splitAt n chunk + unless (ByteString.null rest) $ + unread conn rest + return (acc <> Lazy.ByteString.fromStrict res) + else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len) + eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection" + Nothing Nothing + +writeLazy :: Connection -> Lazy.ByteString -> IO () +writeLazy conn = mapM_ (write conn) . Lazy.ByteString.toChunks + +fromHandle :: Handle -> IO Connection +-- ^ Make connection form handle +fromHandle handle = do + restRef <- newIORef mempty + return Connection + { read = do + rest <- readIORef restRef + writeIORef restRef mempty + if ByteString.null rest + -- 32k corresponds to the default chunk size + -- used in bytestring package + then ByteString.hGetSome handle (32 * 1024) + else return rest + , unread = \rest -> + modifyIORef restRef (rest <>) + , write = ByteString.hPut handle + , flush = hFlush handle + , close = hClose handle + } diff --git a/Database/MongoDB/Internal/Protocol.hs b/Database/MongoDB/Internal/Protocol.hs index 0305ae8..bd0a49b 100644 --- a/Database/MongoDB/Internal/Protocol.hs +++ b/Database/MongoDB/Internal/Protocol.hs @@ -11,7 +11,7 @@ module Database.MongoDB.Internal.Protocol ( FullCollection, -- * Pipe - Pipe, newPipe, send, call, + Pipe, newPipe, newPipeWith, send, call, -- ** Notice Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId, -- ** Request @@ -30,7 +30,7 @@ import Data.Binary.Put (Put, runPut) import Data.Bits (bit, testBit) import Data.Int (Int32, Int64) import Data.IORef (IORef, newIORef, atomicModifyIORef) -import System.IO (Handle, hClose, hFlush) +import System.IO (Handle) import System.IO.Unsafe (unsafePerformIO) import qualified Data.ByteString.Lazy as L @@ -45,11 +45,14 @@ import qualified Crypto.Hash.MD5 as MD5 import qualified Data.Text as T import qualified Data.Text.Encoding as TE -import Database.MongoDB.Internal.Util (whenJust, hGetN, bitOr, byteStringHex) +import Database.MongoDB.Internal.Util (whenJust, bitOr, byteStringHex) import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..)) import qualified System.IO.Pipeline as P +import Database.MongoDB.Internal.Connection (Connection) +import qualified Database.MongoDB.Internal.Connection as Connection + -- * Pipe type Pipe = Pipeline Response Message @@ -57,7 +60,13 @@ type Pipe = Pipeline Response Message newPipe :: Handle -> IO Pipe -- ^ Create pipe over handle -newPipe handle = newPipeline $ IOStream (writeMessage handle) (readMessage handle) (hClose handle) +newPipe handle = Connection.fromHandle handle >>= newPipeWith + +newPipeWith :: Connection -> IO Pipe +-- ^ Create pipe over connection +newPipeWith conn = newPipeline $ IOStream (writeMessage conn) + (readMessage conn) + (Connection.close conn) send :: Pipe -> [Notice] -> IO () -- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails. @@ -79,16 +88,16 @@ type Message = ([Notice], Maybe (Request, RequestId)) -- ^ A write notice(s) with getLastError request, or just query request. -- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness. -writeMessage :: Handle -> Message -> IO () --- ^ Write message to socket -writeMessage handle (notices, mRequest) = do +writeMessage :: Connection -> Message -> IO () +-- ^ Write message to connection +writeMessage conn (notices, mRequest) = do forM_ notices $ \n -> writeReq . (Left n,) =<< genRequestId whenJust mRequest $ writeReq . (Right *** id) - hFlush handle + Connection.flush conn where writeReq (e, requestId) = do - L.hPut handle lenBytes - L.hPut handle bytes + Connection.writeLazy conn lenBytes + Connection.writeLazy conn bytes where bytes = runPut $ (either putNotice putRequest e) requestId lenBytes = encodeSize . toEnum . fromEnum $ L.length bytes @@ -97,12 +106,12 @@ writeMessage handle (notices, mRequest) = do type Response = (ResponseTo, Reply) -- ^ Message received from a Mongo server in response to a Request -readMessage :: Handle -> IO Response --- ^ read response from socket -readMessage handle = readResp where +readMessage :: Connection -> IO Response +-- ^ read response from a connection +readMessage conn = readResp where readResp = do - len <- fromEnum . decodeSize <$> hGetN handle 4 - runGet getReply <$> hGetN handle len + len <- fromEnum . decodeSize <$> Connection.readExactly conn 4 + runGet getReply <$> Connection.readExactly conn len decodeSize = subtract 4 . runGet getInt32 type FullCollection = Text diff --git a/Database/MongoDB/Internal/Util.hs b/Database/MongoDB/Internal/Util.hs index 166107c..2ebc754 100644 --- a/Database/MongoDB/Internal/Util.hs +++ b/Database/MongoDB/Internal/Util.hs @@ -8,18 +8,15 @@ module Database.MongoDB.Internal.Util where import Control.Applicative ((<$>)) -import Control.Exception (assert, handle, throwIO, Exception) +import Control.Exception (handle, throwIO, Exception) import Control.Monad (liftM, liftM2) import Data.Bits (Bits, (.|.)) import Data.Word (Word8) import Network (PortID(..)) import Numeric (showHex) -import System.IO (Handle) -import System.IO.Error (mkIOError, eofErrorType) import System.Random (newStdGen) import System.Random.Shuffle (shuffle') -import qualified Data.ByteString.Lazy as L import qualified Data.ByteString as S import Control.Monad.Error (MonadError(..), Error(..)) @@ -108,15 +105,6 @@ true1 k doc = case valueAt k doc of Int64 n -> n == 1 _ -> error $ "expected " ++ show k ++ " to be Num or Bool in " ++ show doc -hGetN :: Handle -> Int -> IO L.ByteString --- ^ Read N bytes from hande, blocking until all N bytes are read. If EOF is reached before N bytes then raise EOF exception. -hGetN h n = assert (n >= 0) $ do - bytes <- L.hGet h n - let x = fromEnum $ L.length bytes - if x >= n then return bytes - else if x == 0 then ioError (mkIOError eofErrorType "hGetN" (Just h) Nothing) - else L.append bytes <$> hGetN h (n - x) - byteStringHex :: S.ByteString -> String -- ^ Hexadecimal string representation of a byte string. Each byte yields two hexadecimal characters. byteStringHex = concatMap byteHex . S.unpack diff --git a/mongoDB.cabal b/mongoDB.cabal index d2fd10c..ed1206b 100644 --- a/mongoDB.cabal +++ b/mongoDB.cabal @@ -44,6 +44,7 @@ Library Exposed-modules: Database.MongoDB Database.MongoDB.Admin Database.MongoDB.Connection + Database.MongoDB.Internal.Connection Database.MongoDB.Internal.Protocol Database.MongoDB.Internal.Util Database.MongoDB.Query diff --git a/test/Internal/ConnectionTest.hs b/test/Internal/ConnectionTest.hs new file mode 100644 index 0000000..dfe79f8 --- /dev/null +++ b/test/Internal/ConnectionTest.hs @@ -0,0 +1,69 @@ + +module Internal.ConnectionTest ( + spec, +) where + +import Prelude hiding (read) +import Data.Monoid +import Data.IORef +import Control.Monad +import System.IO.Error (isEOFError) +import Test.Hspec + +import Database.MongoDB.Internal.Connection + +spec :: Spec +spec = describe "Internal.Connection" $ do + readExactlySpec + +readExactlySpec :: Spec +readExactlySpec = describe "readExactly" $ do + it "should return specified number of bytes" $ do + let conn = Connection + { read = return "12345" + , unread = \_ -> return () + , write = \_ -> return () + , flush = return () + , close = return () + } + + res <- readExactly conn 3 + res `shouldBe` "123" + + it "should unread the rest" $ do + restRef <- newIORef mempty + let conn = Connection + { read = return "12345" + , unread = writeIORef restRef + , write = \_ -> return () + , flush = return () + , close = return () + } + + void $ readExactly conn 3 + rest <- readIORef restRef + rest `shouldBe` "45" + + it "should ask for more bytes if the first chunk is too small" $ do + let conn = Connection + { read = return "12345" + , unread = \_ -> return () + , write = \_ -> return () + , flush = return () + , close = return () + } + + res <- readExactly conn 8 + res `shouldBe` "12345123" + + it "should throw on EOF" $ do + let conn = Connection + { read = return mempty + , unread = \_ -> return () + , write = \_ -> return () + , flush = return () + , close = return () + } + + void $ readExactly conn 3 + `shouldThrow` isEOFError