Incorporate Tls implementation
This commit is contained in:
parent
19e631c9f4
commit
f956cb2623
3 changed files with 33 additions and 78 deletions
|
@ -3,7 +3,7 @@
|
||||||
|
|
||||||
-- | TLS connection to mongodb
|
-- | TLS connection to mongodb
|
||||||
|
|
||||||
module Bend.Database.Mongo.Tls
|
module Database.MongoDB.Internal.Tls
|
||||||
(
|
(
|
||||||
connect,
|
connect,
|
||||||
)
|
)
|
||||||
|
@ -16,12 +16,15 @@ import qualified Data.Text as Text
|
||||||
import qualified Data.ByteString as ByteString
|
import qualified Data.ByteString as ByteString
|
||||||
import qualified Data.ByteString.Lazy as Lazy.ByteString
|
import qualified Data.ByteString.Lazy as Lazy.ByteString
|
||||||
import Data.Default.Class (def)
|
import Data.Default.Class (def)
|
||||||
|
import Control.Applicative ((<$>))
|
||||||
import Control.Exception (bracketOnError)
|
import Control.Exception (bracketOnError)
|
||||||
|
import Control.Monad (when, unless)
|
||||||
import System.IO
|
import System.IO
|
||||||
import Database.MongoDB (Pipe)
|
import Database.MongoDB (Pipe)
|
||||||
import Database.MongoDB.Internal.Protocol (newPipeWith)
|
import Database.MongoDB.Internal.Protocol (newPipeWith)
|
||||||
import Database.MongoDB.Internal.Connection (Connection(Connection))
|
import Database.MongoDB.Internal.Connection (Connection(Connection))
|
||||||
import qualified Database.MongoDB.Internal.Connection as Connection
|
import qualified Database.MongoDB.Internal.Connection as Connection
|
||||||
|
import System.IO.Error (mkIOError, eofErrorType)
|
||||||
import qualified Network
|
import qualified Network
|
||||||
import qualified Network.TLS as TLS
|
import qualified Network.TLS as TLS
|
||||||
import qualified Network.TLS.Extra.Cipher as TLS
|
import qualified Network.TLS.Extra.Cipher as TLS
|
||||||
|
@ -53,14 +56,31 @@ tlsConnection :: TLS.Context -> IO () -> IO Connection
|
||||||
tlsConnection ctx close = do
|
tlsConnection ctx close = do
|
||||||
restRef <- newIORef mempty
|
restRef <- newIORef mempty
|
||||||
return Connection
|
return Connection
|
||||||
{ Connection.read = do
|
{ Connection.readExactly = \count -> let
|
||||||
|
readSome = do
|
||||||
rest <- readIORef restRef
|
rest <- readIORef restRef
|
||||||
writeIORef restRef mempty
|
writeIORef restRef mempty
|
||||||
if ByteString.null rest
|
if ByteString.null rest
|
||||||
then TLS.recvData ctx
|
then TLS.recvData ctx
|
||||||
else return rest
|
else return rest
|
||||||
, Connection.unread = \rest ->
|
unread = \rest ->
|
||||||
modifyIORef restRef (rest <>)
|
modifyIORef restRef (rest <>)
|
||||||
|
go acc n = do
|
||||||
|
-- read until get enough bytes
|
||||||
|
chunk <- readSome
|
||||||
|
when (ByteString.null chunk) $
|
||||||
|
ioError eof
|
||||||
|
let len = ByteString.length chunk
|
||||||
|
if len >= n
|
||||||
|
then do
|
||||||
|
let (res, rest) = ByteString.splitAt n chunk
|
||||||
|
unless (ByteString.null rest) $
|
||||||
|
unread rest
|
||||||
|
return (acc <> Lazy.ByteString.fromStrict res)
|
||||||
|
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
|
||||||
|
eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection"
|
||||||
|
Nothing Nothing
|
||||||
|
in Lazy.ByteString.toStrict <$> go mempty count
|
||||||
, Connection.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
|
, Connection.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
|
||||||
, Connection.flush = TLS.contextFlush ctx
|
, Connection.flush = TLS.contextFlush ctx
|
||||||
, Connection.close = close
|
, Connection.close = close
|
||||||
|
|
|
@ -34,11 +34,14 @@ Library
|
||||||
, mtl >= 2
|
, mtl >= 2
|
||||||
, cryptohash -any
|
, cryptohash -any
|
||||||
, network -any
|
, network -any
|
||||||
|
, io-region -any
|
||||||
, parsec -any
|
, parsec -any
|
||||||
, random -any
|
, random -any
|
||||||
, random-shuffle -any
|
, random-shuffle -any
|
||||||
, monad-control >= 0.3.1
|
, monad-control >= 0.3.1
|
||||||
, lifted-base >= 0.1.0.3
|
, lifted-base >= 0.1.0.3
|
||||||
|
, tls >= 1.2.0
|
||||||
|
, data-default-class -any
|
||||||
, transformers-base >= 0.4.1
|
, transformers-base >= 0.4.1
|
||||||
, hashtables >= 1.1.2.0
|
, hashtables >= 1.1.2.0
|
||||||
, base16-bytestring >= 0.1.1.6
|
, base16-bytestring >= 0.1.1.6
|
||||||
|
@ -49,6 +52,7 @@ Library
|
||||||
Database.MongoDB.Admin
|
Database.MongoDB.Admin
|
||||||
Database.MongoDB.Connection
|
Database.MongoDB.Connection
|
||||||
Database.MongoDB.Internal.Connection
|
Database.MongoDB.Internal.Connection
|
||||||
|
Database.MongoDB.Internal.Tls
|
||||||
Database.MongoDB.Internal.Protocol
|
Database.MongoDB.Internal.Protocol
|
||||||
Database.MongoDB.Internal.Util
|
Database.MongoDB.Internal.Util
|
||||||
Database.MongoDB.Query
|
Database.MongoDB.Query
|
||||||
|
|
|
@ -1,69 +0,0 @@
|
||||||
|
|
||||||
module Internal.ConnectionSpec (
|
|
||||||
spec,
|
|
||||||
) where
|
|
||||||
|
|
||||||
import Prelude hiding (read)
|
|
||||||
import Data.Monoid
|
|
||||||
import Data.IORef
|
|
||||||
import Control.Monad
|
|
||||||
import System.IO.Error (isEOFError)
|
|
||||||
import Test.Hspec
|
|
||||||
|
|
||||||
import Database.MongoDB.Internal.Connection
|
|
||||||
|
|
||||||
spec :: Spec
|
|
||||||
spec = describe "Internal.Connection" $ do
|
|
||||||
readExactlySpec
|
|
||||||
|
|
||||||
readExactlySpec :: Spec
|
|
||||||
readExactlySpec = describe "readExactly" $ do
|
|
||||||
it "should return specified number of bytes" $ do
|
|
||||||
let conn = Connection
|
|
||||||
{ read = return "12345"
|
|
||||||
, unread = \_ -> return ()
|
|
||||||
, write = \_ -> return ()
|
|
||||||
, flush = return ()
|
|
||||||
, close = return ()
|
|
||||||
}
|
|
||||||
|
|
||||||
res <- readExactly conn 3
|
|
||||||
res `shouldBe` "123"
|
|
||||||
|
|
||||||
it "should unread the rest" $ do
|
|
||||||
restRef <- newIORef mempty
|
|
||||||
let conn = Connection
|
|
||||||
{ read = return "12345"
|
|
||||||
, unread = writeIORef restRef
|
|
||||||
, write = \_ -> return ()
|
|
||||||
, flush = return ()
|
|
||||||
, close = return ()
|
|
||||||
}
|
|
||||||
|
|
||||||
void $ readExactly conn 3
|
|
||||||
rest <- readIORef restRef
|
|
||||||
rest `shouldBe` "45"
|
|
||||||
|
|
||||||
it "should ask for more bytes if the first chunk is too small" $ do
|
|
||||||
let conn = Connection
|
|
||||||
{ read = return "12345"
|
|
||||||
, unread = \_ -> return ()
|
|
||||||
, write = \_ -> return ()
|
|
||||||
, flush = return ()
|
|
||||||
, close = return ()
|
|
||||||
}
|
|
||||||
|
|
||||||
res <- readExactly conn 8
|
|
||||||
res `shouldBe` "12345123"
|
|
||||||
|
|
||||||
it "should throw on EOF" $ do
|
|
||||||
let conn = Connection
|
|
||||||
{ read = return mempty
|
|
||||||
, unread = \_ -> return ()
|
|
||||||
, write = \_ -> return ()
|
|
||||||
, flush = return ()
|
|
||||||
, close = return ()
|
|
||||||
}
|
|
||||||
|
|
||||||
void $ readExactly conn 3
|
|
||||||
`shouldThrow` isEOFError
|
|
Loading…
Reference in a new issue