abstract connection interface

rebase #13 to master
This commit is contained in:
Greg Weber 2015-03-05 11:20:02 -08:00
parent a77370f2d9
commit 98bcc2dfe8
5 changed files with 171 additions and 28 deletions

View file

@ -0,0 +1,76 @@
-- | This module defines a connection interface. It could be a regular
-- network connection, TLS connection, a mock or anything else.
module Database.MongoDB.Internal.Connection (
Connection(..),
readExactly,
writeLazy,
fromHandle,
) where
import Prelude hiding (read)
import Data.Monoid
import Data.IORef
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy (ByteString)
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Control.Monad
import System.IO
import System.IO.Error (mkIOError, eofErrorType)
-- | Abstract connection interface
--
-- `read` should return `ByteString.null` on EOF
data Connection = Connection {
read :: IO ByteString,
unread :: ByteString -> IO (),
write :: ByteString -> IO (),
flush :: IO (),
close :: IO ()}
readExactly :: Connection -> Int -> IO Lazy.ByteString
-- ^ Read specified number of bytes
--
-- If EOF is reached before N bytes then raise EOF exception.
readExactly conn count = go mempty count
where
go acc n = do
-- read until get enough bytes
chunk <- read conn
when (ByteString.null chunk) $
ioError eof
let len = ByteString.length chunk
if len >= n
then do
let (res, rest) = ByteString.splitAt n chunk
unless (ByteString.null rest) $
unread conn rest
return (acc <> Lazy.ByteString.fromStrict res)
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection"
Nothing Nothing
writeLazy :: Connection -> Lazy.ByteString -> IO ()
writeLazy conn = mapM_ (write conn) . Lazy.ByteString.toChunks
fromHandle :: Handle -> IO Connection
-- ^ Make connection form handle
fromHandle handle = do
restRef <- newIORef mempty
return Connection
{ read = do
rest <- readIORef restRef
writeIORef restRef mempty
if ByteString.null rest
-- 32k corresponds to the default chunk size
-- used in bytestring package
then ByteString.hGetSome handle (32 * 1024)
else return rest
, unread = \rest ->
modifyIORef restRef (rest <>)
, write = ByteString.hPut handle
, flush = hFlush handle
, close = hClose handle
}

View file

@ -11,7 +11,7 @@
module Database.MongoDB.Internal.Protocol ( module Database.MongoDB.Internal.Protocol (
FullCollection, FullCollection,
-- * Pipe -- * Pipe
Pipe, newPipe, send, call, Pipe, newPipe, newPipeWith, send, call,
-- ** Notice -- ** Notice
Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId, Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
-- ** Request -- ** Request
@ -30,7 +30,7 @@ import Data.Binary.Put (Put, runPut)
import Data.Bits (bit, testBit) import Data.Bits (bit, testBit)
import Data.Int (Int32, Int64) import Data.Int (Int32, Int64)
import Data.IORef (IORef, newIORef, atomicModifyIORef) import Data.IORef (IORef, newIORef, atomicModifyIORef)
import System.IO (Handle, hClose, hFlush) import System.IO (Handle)
import System.IO.Unsafe (unsafePerformIO) import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as L
@ -45,11 +45,14 @@ import qualified Crypto.Hash.MD5 as MD5
import qualified Data.Text as T import qualified Data.Text as T
import qualified Data.Text.Encoding as TE import qualified Data.Text.Encoding as TE
import Database.MongoDB.Internal.Util (whenJust, hGetN, bitOr, byteStringHex) import Database.MongoDB.Internal.Util (whenJust, bitOr, byteStringHex)
import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..)) import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..))
import qualified System.IO.Pipeline as P import qualified System.IO.Pipeline as P
import Database.MongoDB.Internal.Connection (Connection)
import qualified Database.MongoDB.Internal.Connection as Connection
-- * Pipe -- * Pipe
type Pipe = Pipeline Response Message type Pipe = Pipeline Response Message
@ -57,7 +60,13 @@ type Pipe = Pipeline Response Message
newPipe :: Handle -> IO Pipe newPipe :: Handle -> IO Pipe
-- ^ Create pipe over handle -- ^ Create pipe over handle
newPipe handle = newPipeline $ IOStream (writeMessage handle) (readMessage handle) (hClose handle) newPipe handle = Connection.fromHandle handle >>= newPipeWith
newPipeWith :: Connection -> IO Pipe
-- ^ Create pipe over connection
newPipeWith conn = newPipeline $ IOStream (writeMessage conn)
(readMessage conn)
(Connection.close conn)
send :: Pipe -> [Notice] -> IO () send :: Pipe -> [Notice] -> IO ()
-- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails. -- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails.
@ -79,16 +88,16 @@ type Message = ([Notice], Maybe (Request, RequestId))
-- ^ A write notice(s) with getLastError request, or just query request. -- ^ A write notice(s) with getLastError request, or just query request.
-- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness. -- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness.
writeMessage :: Handle -> Message -> IO () writeMessage :: Connection -> Message -> IO ()
-- ^ Write message to socket -- ^ Write message to connection
writeMessage handle (notices, mRequest) = do writeMessage conn (notices, mRequest) = do
forM_ notices $ \n -> writeReq . (Left n,) =<< genRequestId forM_ notices $ \n -> writeReq . (Left n,) =<< genRequestId
whenJust mRequest $ writeReq . (Right *** id) whenJust mRequest $ writeReq . (Right *** id)
hFlush handle Connection.flush conn
where where
writeReq (e, requestId) = do writeReq (e, requestId) = do
L.hPut handle lenBytes Connection.writeLazy conn lenBytes
L.hPut handle bytes Connection.writeLazy conn bytes
where where
bytes = runPut $ (either putNotice putRequest e) requestId bytes = runPut $ (either putNotice putRequest e) requestId
lenBytes = encodeSize . toEnum . fromEnum $ L.length bytes lenBytes = encodeSize . toEnum . fromEnum $ L.length bytes
@ -97,12 +106,12 @@ writeMessage handle (notices, mRequest) = do
type Response = (ResponseTo, Reply) type Response = (ResponseTo, Reply)
-- ^ Message received from a Mongo server in response to a Request -- ^ Message received from a Mongo server in response to a Request
readMessage :: Handle -> IO Response readMessage :: Connection -> IO Response
-- ^ read response from socket -- ^ read response from a connection
readMessage handle = readResp where readMessage conn = readResp where
readResp = do readResp = do
len <- fromEnum . decodeSize <$> hGetN handle 4 len <- fromEnum . decodeSize <$> Connection.readExactly conn 4
runGet getReply <$> hGetN handle len runGet getReply <$> Connection.readExactly conn len
decodeSize = subtract 4 . runGet getInt32 decodeSize = subtract 4 . runGet getInt32
type FullCollection = Text type FullCollection = Text

View file

@ -8,18 +8,15 @@
module Database.MongoDB.Internal.Util where module Database.MongoDB.Internal.Util where
import Control.Applicative ((<$>)) import Control.Applicative ((<$>))
import Control.Exception (assert, handle, throwIO, Exception) import Control.Exception (handle, throwIO, Exception)
import Control.Monad (liftM, liftM2) import Control.Monad (liftM, liftM2)
import Data.Bits (Bits, (.|.)) import Data.Bits (Bits, (.|.))
import Data.Word (Word8) import Data.Word (Word8)
import Network (PortID(..)) import Network (PortID(..))
import Numeric (showHex) import Numeric (showHex)
import System.IO (Handle)
import System.IO.Error (mkIOError, eofErrorType)
import System.Random (newStdGen) import System.Random (newStdGen)
import System.Random.Shuffle (shuffle') import System.Random.Shuffle (shuffle')
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString as S import qualified Data.ByteString as S
import Control.Monad.Error (MonadError(..), Error(..)) import Control.Monad.Error (MonadError(..), Error(..))
@ -108,15 +105,6 @@ true1 k doc = case valueAt k doc of
Int64 n -> n == 1 Int64 n -> n == 1
_ -> error $ "expected " ++ show k ++ " to be Num or Bool in " ++ show doc _ -> error $ "expected " ++ show k ++ " to be Num or Bool in " ++ show doc
hGetN :: Handle -> Int -> IO L.ByteString
-- ^ Read N bytes from hande, blocking until all N bytes are read. If EOF is reached before N bytes then raise EOF exception.
hGetN h n = assert (n >= 0) $ do
bytes <- L.hGet h n
let x = fromEnum $ L.length bytes
if x >= n then return bytes
else if x == 0 then ioError (mkIOError eofErrorType "hGetN" (Just h) Nothing)
else L.append bytes <$> hGetN h (n - x)
byteStringHex :: S.ByteString -> String byteStringHex :: S.ByteString -> String
-- ^ Hexadecimal string representation of a byte string. Each byte yields two hexadecimal characters. -- ^ Hexadecimal string representation of a byte string. Each byte yields two hexadecimal characters.
byteStringHex = concatMap byteHex . S.unpack byteStringHex = concatMap byteHex . S.unpack

View file

@ -44,6 +44,7 @@ Library
Exposed-modules: Database.MongoDB Exposed-modules: Database.MongoDB
Database.MongoDB.Admin Database.MongoDB.Admin
Database.MongoDB.Connection Database.MongoDB.Connection
Database.MongoDB.Internal.Connection
Database.MongoDB.Internal.Protocol Database.MongoDB.Internal.Protocol
Database.MongoDB.Internal.Util Database.MongoDB.Internal.Util
Database.MongoDB.Query Database.MongoDB.Query

View file

@ -0,0 +1,69 @@
module Internal.ConnectionTest (
spec,
) where
import Prelude hiding (read)
import Data.Monoid
import Data.IORef
import Control.Monad
import System.IO.Error (isEOFError)
import Test.Hspec
import Database.MongoDB.Internal.Connection
spec :: Spec
spec = describe "Internal.Connection" $ do
readExactlySpec
readExactlySpec :: Spec
readExactlySpec = describe "readExactly" $ do
it "should return specified number of bytes" $ do
let conn = Connection
{ read = return "12345"
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
res <- readExactly conn 3
res `shouldBe` "123"
it "should unread the rest" $ do
restRef <- newIORef mempty
let conn = Connection
{ read = return "12345"
, unread = writeIORef restRef
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
void $ readExactly conn 3
rest <- readIORef restRef
rest `shouldBe` "45"
it "should ask for more bytes if the first chunk is too small" $ do
let conn = Connection
{ read = return "12345"
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
res <- readExactly conn 8
res `shouldBe` "12345123"
it "should throw on EOF" $ do
let conn = Connection
{ read = return mempty
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
void $ readExactly conn 3
`shouldThrow` isEOFError