Incorporate Tls implementation
This commit is contained in:
parent
19e631c9f4
commit
f956cb2623
3 changed files with 33 additions and 78 deletions
|
@ -3,7 +3,7 @@
|
|||
|
||||
-- | TLS connection to mongodb
|
||||
|
||||
module Bend.Database.Mongo.Tls
|
||||
module Database.MongoDB.Internal.Tls
|
||||
(
|
||||
connect,
|
||||
)
|
||||
|
@ -16,12 +16,15 @@ import qualified Data.Text as Text
|
|||
import qualified Data.ByteString as ByteString
|
||||
import qualified Data.ByteString.Lazy as Lazy.ByteString
|
||||
import Data.Default.Class (def)
|
||||
import Control.Applicative ((<$>))
|
||||
import Control.Exception (bracketOnError)
|
||||
import Control.Monad (when, unless)
|
||||
import System.IO
|
||||
import Database.MongoDB (Pipe)
|
||||
import Database.MongoDB.Internal.Protocol (newPipeWith)
|
||||
import Database.MongoDB.Internal.Connection (Connection(Connection))
|
||||
import qualified Database.MongoDB.Internal.Connection as Connection
|
||||
import System.IO.Error (mkIOError, eofErrorType)
|
||||
import qualified Network
|
||||
import qualified Network.TLS as TLS
|
||||
import qualified Network.TLS.Extra.Cipher as TLS
|
||||
|
@ -53,14 +56,31 @@ tlsConnection :: TLS.Context -> IO () -> IO Connection
|
|||
tlsConnection ctx close = do
|
||||
restRef <- newIORef mempty
|
||||
return Connection
|
||||
{ Connection.read = do
|
||||
{ Connection.readExactly = \count -> let
|
||||
readSome = do
|
||||
rest <- readIORef restRef
|
||||
writeIORef restRef mempty
|
||||
if ByteString.null rest
|
||||
then TLS.recvData ctx
|
||||
else return rest
|
||||
, Connection.unread = \rest ->
|
||||
unread = \rest ->
|
||||
modifyIORef restRef (rest <>)
|
||||
go acc n = do
|
||||
-- read until get enough bytes
|
||||
chunk <- readSome
|
||||
when (ByteString.null chunk) $
|
||||
ioError eof
|
||||
let len = ByteString.length chunk
|
||||
if len >= n
|
||||
then do
|
||||
let (res, rest) = ByteString.splitAt n chunk
|
||||
unless (ByteString.null rest) $
|
||||
unread rest
|
||||
return (acc <> Lazy.ByteString.fromStrict res)
|
||||
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
|
||||
eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection"
|
||||
Nothing Nothing
|
||||
in Lazy.ByteString.toStrict <$> go mempty count
|
||||
, Connection.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
|
||||
, Connection.flush = TLS.contextFlush ctx
|
||||
, Connection.close = close
|
||||
|
|
|
@ -34,11 +34,14 @@ Library
|
|||
, mtl >= 2
|
||||
, cryptohash -any
|
||||
, network -any
|
||||
, io-region -any
|
||||
, parsec -any
|
||||
, random -any
|
||||
, random-shuffle -any
|
||||
, monad-control >= 0.3.1
|
||||
, lifted-base >= 0.1.0.3
|
||||
, tls >= 1.2.0
|
||||
, data-default-class -any
|
||||
, transformers-base >= 0.4.1
|
||||
, hashtables >= 1.1.2.0
|
||||
, base16-bytestring >= 0.1.1.6
|
||||
|
@ -49,6 +52,7 @@ Library
|
|||
Database.MongoDB.Admin
|
||||
Database.MongoDB.Connection
|
||||
Database.MongoDB.Internal.Connection
|
||||
Database.MongoDB.Internal.Tls
|
||||
Database.MongoDB.Internal.Protocol
|
||||
Database.MongoDB.Internal.Util
|
||||
Database.MongoDB.Query
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
|
||||
module Internal.ConnectionSpec (
|
||||
spec,
|
||||
) where
|
||||
|
||||
import Prelude hiding (read)
|
||||
import Data.Monoid
|
||||
import Data.IORef
|
||||
import Control.Monad
|
||||
import System.IO.Error (isEOFError)
|
||||
import Test.Hspec
|
||||
|
||||
import Database.MongoDB.Internal.Connection
|
||||
|
||||
spec :: Spec
|
||||
spec = describe "Internal.Connection" $ do
|
||||
readExactlySpec
|
||||
|
||||
readExactlySpec :: Spec
|
||||
readExactlySpec = describe "readExactly" $ do
|
||||
it "should return specified number of bytes" $ do
|
||||
let conn = Connection
|
||||
{ read = return "12345"
|
||||
, unread = \_ -> return ()
|
||||
, write = \_ -> return ()
|
||||
, flush = return ()
|
||||
, close = return ()
|
||||
}
|
||||
|
||||
res <- readExactly conn 3
|
||||
res `shouldBe` "123"
|
||||
|
||||
it "should unread the rest" $ do
|
||||
restRef <- newIORef mempty
|
||||
let conn = Connection
|
||||
{ read = return "12345"
|
||||
, unread = writeIORef restRef
|
||||
, write = \_ -> return ()
|
||||
, flush = return ()
|
||||
, close = return ()
|
||||
}
|
||||
|
||||
void $ readExactly conn 3
|
||||
rest <- readIORef restRef
|
||||
rest `shouldBe` "45"
|
||||
|
||||
it "should ask for more bytes if the first chunk is too small" $ do
|
||||
let conn = Connection
|
||||
{ read = return "12345"
|
||||
, unread = \_ -> return ()
|
||||
, write = \_ -> return ()
|
||||
, flush = return ()
|
||||
, close = return ()
|
||||
}
|
||||
|
||||
res <- readExactly conn 8
|
||||
res `shouldBe` "12345123"
|
||||
|
||||
it "should throw on EOF" $ do
|
||||
let conn = Connection
|
||||
{ read = return mempty
|
||||
, unread = \_ -> return ()
|
||||
, write = \_ -> return ()
|
||||
, flush = return ()
|
||||
, close = return ()
|
||||
}
|
||||
|
||||
void $ readExactly conn 3
|
||||
`shouldThrow` isEOFError
|
Loading…
Reference in a new issue