diff --git a/Database/MongoDB/Internal/Tls.hs b/Database/MongoDB/Internal/Tls.hs index 191a1f4..3579a9b 100644 --- a/Database/MongoDB/Internal/Tls.hs +++ b/Database/MongoDB/Internal/Tls.hs @@ -3,7 +3,7 @@ -- | TLS connection to mongodb -module Bend.Database.Mongo.Tls +module Database.MongoDB.Internal.Tls ( connect, ) @@ -16,12 +16,15 @@ import qualified Data.Text as Text import qualified Data.ByteString as ByteString import qualified Data.ByteString.Lazy as Lazy.ByteString import Data.Default.Class (def) +import Control.Applicative ((<$>)) import Control.Exception (bracketOnError) +import Control.Monad (when, unless) import System.IO import Database.MongoDB (Pipe) import Database.MongoDB.Internal.Protocol (newPipeWith) import Database.MongoDB.Internal.Connection (Connection(Connection)) import qualified Database.MongoDB.Internal.Connection as Connection +import System.IO.Error (mkIOError, eofErrorType) import qualified Network import qualified Network.TLS as TLS import qualified Network.TLS.Extra.Cipher as TLS @@ -53,14 +56,31 @@ tlsConnection :: TLS.Context -> IO () -> IO Connection tlsConnection ctx close = do restRef <- newIORef mempty return Connection - { Connection.read = do - rest <- readIORef restRef - writeIORef restRef mempty - if ByteString.null rest - then TLS.recvData ctx - else return rest - , Connection.unread = \rest -> - modifyIORef restRef (rest <>) + { Connection.readExactly = \count -> let + readSome = do + rest <- readIORef restRef + writeIORef restRef mempty + if ByteString.null rest + then TLS.recvData ctx + else return rest + unread = \rest -> + modifyIORef restRef (rest <>) + go acc n = do + -- read until get enough bytes + chunk <- readSome + when (ByteString.null chunk) $ + ioError eof + let len = ByteString.length chunk + if len >= n + then do + let (res, rest) = ByteString.splitAt n chunk + unless (ByteString.null rest) $ + unread rest + return (acc <> Lazy.ByteString.fromStrict res) + else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len) + eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection" + Nothing Nothing + in Lazy.ByteString.toStrict <$> go mempty count , Connection.write = TLS.sendData ctx . Lazy.ByteString.fromStrict , Connection.flush = TLS.contextFlush ctx , Connection.close = close diff --git a/mongoDB.cabal b/mongoDB.cabal index 884be35..e7cc087 100644 --- a/mongoDB.cabal +++ b/mongoDB.cabal @@ -34,11 +34,14 @@ Library , mtl >= 2 , cryptohash -any , network -any + , io-region -any , parsec -any , random -any , random-shuffle -any , monad-control >= 0.3.1 , lifted-base >= 0.1.0.3 + , tls >= 1.2.0 + , data-default-class -any , transformers-base >= 0.4.1 , hashtables >= 1.1.2.0 , base16-bytestring >= 0.1.1.6 @@ -49,6 +52,7 @@ Library Database.MongoDB.Admin Database.MongoDB.Connection Database.MongoDB.Internal.Connection + Database.MongoDB.Internal.Tls Database.MongoDB.Internal.Protocol Database.MongoDB.Internal.Util Database.MongoDB.Query diff --git a/test/Internal/ConnectionSpec.hs b/test/Internal/ConnectionSpec.hs deleted file mode 100644 index 89ef85b..0000000 --- a/test/Internal/ConnectionSpec.hs +++ /dev/null @@ -1,69 +0,0 @@ - -module Internal.ConnectionSpec ( - spec, -) where - -import Prelude hiding (read) -import Data.Monoid -import Data.IORef -import Control.Monad -import System.IO.Error (isEOFError) -import Test.Hspec - -import Database.MongoDB.Internal.Connection - -spec :: Spec -spec = describe "Internal.Connection" $ do - readExactlySpec - -readExactlySpec :: Spec -readExactlySpec = describe "readExactly" $ do - it "should return specified number of bytes" $ do - let conn = Connection - { read = return "12345" - , unread = \_ -> return () - , write = \_ -> return () - , flush = return () - , close = return () - } - - res <- readExactly conn 3 - res `shouldBe` "123" - - it "should unread the rest" $ do - restRef <- newIORef mempty - let conn = Connection - { read = return "12345" - , unread = writeIORef restRef - , write = \_ -> return () - , flush = return () - , close = return () - } - - void $ readExactly conn 3 - rest <- readIORef restRef - rest `shouldBe` "45" - - it "should ask for more bytes if the first chunk is too small" $ do - let conn = Connection - { read = return "12345" - , unread = \_ -> return () - , write = \_ -> return () - , flush = return () - , close = return () - } - - res <- readExactly conn 8 - res `shouldBe` "12345123" - - it "should throw on EOF" $ do - let conn = Connection - { read = return mempty - , unread = \_ -> return () - , write = \_ -> return () - , flush = return () - , close = return () - } - - void $ readExactly conn 3 - `shouldThrow` isEOFError