abstract connection interface

rebase #13 to master
This commit is contained in:
Greg Weber 2015-03-05 11:20:02 -08:00
parent a77370f2d9
commit 98bcc2dfe8
5 changed files with 171 additions and 28 deletions

View file

@ -0,0 +1,76 @@
-- | This module defines a connection interface. It could be a regular
-- network connection, TLS connection, a mock or anything else.
module Database.MongoDB.Internal.Connection (
Connection(..),
readExactly,
writeLazy,
fromHandle,
) where
import Prelude hiding (read)
import Data.Monoid
import Data.IORef
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy (ByteString)
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Control.Monad
import System.IO
import System.IO.Error (mkIOError, eofErrorType)
-- | Abstract connection interface
--
-- `read` should return `ByteString.null` on EOF
data Connection = Connection {
read :: IO ByteString,
unread :: ByteString -> IO (),
write :: ByteString -> IO (),
flush :: IO (),
close :: IO ()}
readExactly :: Connection -> Int -> IO Lazy.ByteString
-- ^ Read specified number of bytes
--
-- If EOF is reached before N bytes then raise EOF exception.
readExactly conn count = go mempty count
where
go acc n = do
-- read until get enough bytes
chunk <- read conn
when (ByteString.null chunk) $
ioError eof
let len = ByteString.length chunk
if len >= n
then do
let (res, rest) = ByteString.splitAt n chunk
unless (ByteString.null rest) $
unread conn rest
return (acc <> Lazy.ByteString.fromStrict res)
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection"
Nothing Nothing
writeLazy :: Connection -> Lazy.ByteString -> IO ()
writeLazy conn = mapM_ (write conn) . Lazy.ByteString.toChunks
fromHandle :: Handle -> IO Connection
-- ^ Make connection form handle
fromHandle handle = do
restRef <- newIORef mempty
return Connection
{ read = do
rest <- readIORef restRef
writeIORef restRef mempty
if ByteString.null rest
-- 32k corresponds to the default chunk size
-- used in bytestring package
then ByteString.hGetSome handle (32 * 1024)
else return rest
, unread = \rest ->
modifyIORef restRef (rest <>)
, write = ByteString.hPut handle
, flush = hFlush handle
, close = hClose handle
}

View file

@ -11,7 +11,7 @@
module Database.MongoDB.Internal.Protocol (
FullCollection,
-- * Pipe
Pipe, newPipe, send, call,
Pipe, newPipe, newPipeWith, send, call,
-- ** Notice
Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
-- ** Request
@ -30,7 +30,7 @@ import Data.Binary.Put (Put, runPut)
import Data.Bits (bit, testBit)
import Data.Int (Int32, Int64)
import Data.IORef (IORef, newIORef, atomicModifyIORef)
import System.IO (Handle, hClose, hFlush)
import System.IO (Handle)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString.Lazy as L
@ -45,11 +45,14 @@ import qualified Crypto.Hash.MD5 as MD5
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Database.MongoDB.Internal.Util (whenJust, hGetN, bitOr, byteStringHex)
import Database.MongoDB.Internal.Util (whenJust, bitOr, byteStringHex)
import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..))
import qualified System.IO.Pipeline as P
import Database.MongoDB.Internal.Connection (Connection)
import qualified Database.MongoDB.Internal.Connection as Connection
-- * Pipe
type Pipe = Pipeline Response Message
@ -57,7 +60,13 @@ type Pipe = Pipeline Response Message
newPipe :: Handle -> IO Pipe
-- ^ Create pipe over handle
newPipe handle = newPipeline $ IOStream (writeMessage handle) (readMessage handle) (hClose handle)
newPipe handle = Connection.fromHandle handle >>= newPipeWith
newPipeWith :: Connection -> IO Pipe
-- ^ Create pipe over connection
newPipeWith conn = newPipeline $ IOStream (writeMessage conn)
(readMessage conn)
(Connection.close conn)
send :: Pipe -> [Notice] -> IO ()
-- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails.
@ -79,16 +88,16 @@ type Message = ([Notice], Maybe (Request, RequestId))
-- ^ A write notice(s) with getLastError request, or just query request.
-- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness.
writeMessage :: Handle -> Message -> IO ()
-- ^ Write message to socket
writeMessage handle (notices, mRequest) = do
writeMessage :: Connection -> Message -> IO ()
-- ^ Write message to connection
writeMessage conn (notices, mRequest) = do
forM_ notices $ \n -> writeReq . (Left n,) =<< genRequestId
whenJust mRequest $ writeReq . (Right *** id)
hFlush handle
Connection.flush conn
where
writeReq (e, requestId) = do
L.hPut handle lenBytes
L.hPut handle bytes
Connection.writeLazy conn lenBytes
Connection.writeLazy conn bytes
where
bytes = runPut $ (either putNotice putRequest e) requestId
lenBytes = encodeSize . toEnum . fromEnum $ L.length bytes
@ -97,12 +106,12 @@ writeMessage handle (notices, mRequest) = do
type Response = (ResponseTo, Reply)
-- ^ Message received from a Mongo server in response to a Request
readMessage :: Handle -> IO Response
-- ^ read response from socket
readMessage handle = readResp where
readMessage :: Connection -> IO Response
-- ^ read response from a connection
readMessage conn = readResp where
readResp = do
len <- fromEnum . decodeSize <$> hGetN handle 4
runGet getReply <$> hGetN handle len
len <- fromEnum . decodeSize <$> Connection.readExactly conn 4
runGet getReply <$> Connection.readExactly conn len
decodeSize = subtract 4 . runGet getInt32
type FullCollection = Text

View file

@ -8,18 +8,15 @@
module Database.MongoDB.Internal.Util where
import Control.Applicative ((<$>))
import Control.Exception (assert, handle, throwIO, Exception)
import Control.Exception (handle, throwIO, Exception)
import Control.Monad (liftM, liftM2)
import Data.Bits (Bits, (.|.))
import Data.Word (Word8)
import Network (PortID(..))
import Numeric (showHex)
import System.IO (Handle)
import System.IO.Error (mkIOError, eofErrorType)
import System.Random (newStdGen)
import System.Random.Shuffle (shuffle')
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString as S
import Control.Monad.Error (MonadError(..), Error(..))
@ -108,15 +105,6 @@ true1 k doc = case valueAt k doc of
Int64 n -> n == 1
_ -> error $ "expected " ++ show k ++ " to be Num or Bool in " ++ show doc
hGetN :: Handle -> Int -> IO L.ByteString
-- ^ Read N bytes from hande, blocking until all N bytes are read. If EOF is reached before N bytes then raise EOF exception.
hGetN h n = assert (n >= 0) $ do
bytes <- L.hGet h n
let x = fromEnum $ L.length bytes
if x >= n then return bytes
else if x == 0 then ioError (mkIOError eofErrorType "hGetN" (Just h) Nothing)
else L.append bytes <$> hGetN h (n - x)
byteStringHex :: S.ByteString -> String
-- ^ Hexadecimal string representation of a byte string. Each byte yields two hexadecimal characters.
byteStringHex = concatMap byteHex . S.unpack

View file

@ -44,6 +44,7 @@ Library
Exposed-modules: Database.MongoDB
Database.MongoDB.Admin
Database.MongoDB.Connection
Database.MongoDB.Internal.Connection
Database.MongoDB.Internal.Protocol
Database.MongoDB.Internal.Util
Database.MongoDB.Query

View file

@ -0,0 +1,69 @@
module Internal.ConnectionTest (
spec,
) where
import Prelude hiding (read)
import Data.Monoid
import Data.IORef
import Control.Monad
import System.IO.Error (isEOFError)
import Test.Hspec
import Database.MongoDB.Internal.Connection
spec :: Spec
spec = describe "Internal.Connection" $ do
readExactlySpec
readExactlySpec :: Spec
readExactlySpec = describe "readExactly" $ do
it "should return specified number of bytes" $ do
let conn = Connection
{ read = return "12345"
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
res <- readExactly conn 3
res `shouldBe` "123"
it "should unread the rest" $ do
restRef <- newIORef mempty
let conn = Connection
{ read = return "12345"
, unread = writeIORef restRef
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
void $ readExactly conn 3
rest <- readIORef restRef
rest `shouldBe` "45"
it "should ask for more bytes if the first chunk is too small" $ do
let conn = Connection
{ read = return "12345"
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
res <- readExactly conn 8
res `shouldBe` "12345123"
it "should throw on EOF" $ do
let conn = Connection
{ read = return mempty
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
void $ readExactly conn 3
`shouldThrow` isEOFError