Incorporate Tls implementation

This commit is contained in:
Victor Denisov 2016-04-30 20:11:44 -07:00
parent 19e631c9f4
commit f956cb2623
3 changed files with 33 additions and 78 deletions

View file

@ -3,7 +3,7 @@
-- | TLS connection to mongodb -- | TLS connection to mongodb
module Bend.Database.Mongo.Tls module Database.MongoDB.Internal.Tls
( (
connect, connect,
) )
@ -16,12 +16,15 @@ import qualified Data.Text as Text
import qualified Data.ByteString as ByteString import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy.ByteString import qualified Data.ByteString.Lazy as Lazy.ByteString
import Data.Default.Class (def) import Data.Default.Class (def)
import Control.Applicative ((<$>))
import Control.Exception (bracketOnError) import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO import System.IO
import Database.MongoDB (Pipe) import Database.MongoDB (Pipe)
import Database.MongoDB.Internal.Protocol (newPipeWith) import Database.MongoDB.Internal.Protocol (newPipeWith)
import Database.MongoDB.Internal.Connection (Connection(Connection)) import Database.MongoDB.Internal.Connection (Connection(Connection))
import qualified Database.MongoDB.Internal.Connection as Connection import qualified Database.MongoDB.Internal.Connection as Connection
import System.IO.Error (mkIOError, eofErrorType)
import qualified Network import qualified Network
import qualified Network.TLS as TLS import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS import qualified Network.TLS.Extra.Cipher as TLS
@ -53,14 +56,31 @@ tlsConnection :: TLS.Context -> IO () -> IO Connection
tlsConnection ctx close = do tlsConnection ctx close = do
restRef <- newIORef mempty restRef <- newIORef mempty
return Connection return Connection
{ Connection.read = do { Connection.readExactly = \count -> let
readSome = do
rest <- readIORef restRef rest <- readIORef restRef
writeIORef restRef mempty writeIORef restRef mempty
if ByteString.null rest if ByteString.null rest
then TLS.recvData ctx then TLS.recvData ctx
else return rest else return rest
, Connection.unread = \rest -> unread = \rest ->
modifyIORef restRef (rest <>) modifyIORef restRef (rest <>)
go acc n = do
-- read until get enough bytes
chunk <- readSome
when (ByteString.null chunk) $
ioError eof
let len = ByteString.length chunk
if len >= n
then do
let (res, rest) = ByteString.splitAt n chunk
unless (ByteString.null rest) $
unread rest
return (acc <> Lazy.ByteString.fromStrict res)
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection"
Nothing Nothing
in Lazy.ByteString.toStrict <$> go mempty count
, Connection.write = TLS.sendData ctx . Lazy.ByteString.fromStrict , Connection.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
, Connection.flush = TLS.contextFlush ctx , Connection.flush = TLS.contextFlush ctx
, Connection.close = close , Connection.close = close

View file

@ -34,11 +34,14 @@ Library
, mtl >= 2 , mtl >= 2
, cryptohash -any , cryptohash -any
, network -any , network -any
, io-region -any
, parsec -any , parsec -any
, random -any , random -any
, random-shuffle -any , random-shuffle -any
, monad-control >= 0.3.1 , monad-control >= 0.3.1
, lifted-base >= 0.1.0.3 , lifted-base >= 0.1.0.3
, tls >= 1.2.0
, data-default-class -any
, transformers-base >= 0.4.1 , transformers-base >= 0.4.1
, hashtables >= 1.1.2.0 , hashtables >= 1.1.2.0
, base16-bytestring >= 0.1.1.6 , base16-bytestring >= 0.1.1.6
@ -49,6 +52,7 @@ Library
Database.MongoDB.Admin Database.MongoDB.Admin
Database.MongoDB.Connection Database.MongoDB.Connection
Database.MongoDB.Internal.Connection Database.MongoDB.Internal.Connection
Database.MongoDB.Internal.Tls
Database.MongoDB.Internal.Protocol Database.MongoDB.Internal.Protocol
Database.MongoDB.Internal.Util Database.MongoDB.Internal.Util
Database.MongoDB.Query Database.MongoDB.Query

View file

@ -1,69 +0,0 @@
module Internal.ConnectionSpec (
spec,
) where
import Prelude hiding (read)
import Data.Monoid
import Data.IORef
import Control.Monad
import System.IO.Error (isEOFError)
import Test.Hspec
import Database.MongoDB.Internal.Connection
spec :: Spec
spec = describe "Internal.Connection" $ do
readExactlySpec
readExactlySpec :: Spec
readExactlySpec = describe "readExactly" $ do
it "should return specified number of bytes" $ do
let conn = Connection
{ read = return "12345"
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
res <- readExactly conn 3
res `shouldBe` "123"
it "should unread the rest" $ do
restRef <- newIORef mempty
let conn = Connection
{ read = return "12345"
, unread = writeIORef restRef
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
void $ readExactly conn 3
rest <- readIORef restRef
rest `shouldBe` "45"
it "should ask for more bytes if the first chunk is too small" $ do
let conn = Connection
{ read = return "12345"
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
res <- readExactly conn 8
res `shouldBe` "12345123"
it "should throw on EOF" $ do
let conn = Connection
{ read = return mempty
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
void $ readExactly conn 3
`shouldThrow` isEOFError