Merge pull request #46 from VictorDenisov/tls

Merge a simple implementation of TLS and drop System.IO.Pipeline
This commit is contained in:
Victor Denisov 2016-05-06 23:50:42 -07:00
commit e683f59faa
10 changed files with 246 additions and 286 deletions

View file

@ -11,11 +11,11 @@ import Database.MongoDB.Query
import qualified Data.Text as T
main = defaultMain [
bgroup "insert" [ bench "100" $ nfIO doInserts ]
bgroup "insert" [ bench "1000" $ nfIO doInserts ]
]
doInserts = do
let docs = (flip map) [0..100] $ \i ->
let docs = (flip map) [0..1000] $ \i ->
["name" M.=: (T.pack $ "name " ++ (show i))]
pipe <- M.connect (M.host "127.0.0.1")

View file

@ -2,6 +2,14 @@
All notable changes to this project will be documented in this file.
This project adheres to [Package Versioning Policy](https://wiki.haskell.org/Package_versioning_policy).
## [2.1.0] - unreleased
### Added
- TLS implementation. So far it is an experimental feature.
### Removed
- System.IO.Pipeline module
## [2.0.10] - 2015-12-22
### Fixed

View file

@ -42,12 +42,11 @@ import Data.Text (Text)
import qualified Data.Bson as B
import qualified Data.Text as T
import Database.MongoDB.Internal.Protocol (Pipe, newPipe)
import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed)
import Database.MongoDB.Internal.Util (untilSuccess, liftIOE,
updateAssocs, shuffle, mergesortM)
import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
slaveOk, runCommand)
import System.IO.Pipeline (close, isClosed)
adminCommand :: Command -> Pipe -> IO Document
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails.

View file

@ -1,72 +0,0 @@
-- | This module defines a connection interface. It could be a regular
-- network connection, TLS connection, a mock or anything else.
module Database.MongoDB.Internal.Connection (
Connection(..),
readExactly,
fromHandle,
) where
import Prelude hiding (read)
import Data.Monoid
import Data.IORef
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy (ByteString)
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Control.Monad
import System.IO
import System.IO.Error (mkIOError, eofErrorType)
-- | Abstract connection interface
--
-- `read` should return `ByteString.null` on EOF
data Connection = Connection {
read :: IO ByteString,
unread :: ByteString -> IO (),
write :: ByteString -> IO (),
flush :: IO (),
close :: IO ()}
readExactly :: Connection -> Int -> IO Lazy.ByteString
-- ^ Read specified number of bytes
--
-- If EOF is reached before N bytes then raise EOF exception.
readExactly conn count = go mempty count
where
go acc n = do
-- read until get enough bytes
chunk <- read conn
when (ByteString.null chunk) $
ioError eof
let len = ByteString.length chunk
if len >= n
then do
let (res, rest) = ByteString.splitAt n chunk
unless (ByteString.null rest) $
unread conn rest
return (acc <> Lazy.ByteString.fromStrict res)
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
eof = mkIOError eofErrorType "Database.MongoDB.Internal.Connection"
Nothing Nothing
fromHandle :: Handle -> IO Connection
-- ^ Make connection form handle
fromHandle handle = do
restRef <- newIORef mempty
return Connection
{ read = do
rest <- readIORef restRef
writeIORef restRef mempty
if ByteString.null rest
-- 32k corresponds to the default chunk size
-- used in bytestring package
then ByteString.hGetSome handle (32 * 1024)
else return rest
, unread = \rest ->
modifyIORef restRef (rest <>)
, write = ByteString.hPut handle
, flush = hFlush handle
, close = hClose handle
}

View file

@ -8,6 +8,14 @@
{-# LANGUAGE CPP, FlexibleContexts, TupleSections, TypeSynonymInstances #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, UndecidableInstances #-}
{-# LANGUAGE NamedFieldPuns, ScopedTypeVariables #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
module Database.MongoDB.Internal.Protocol (
FullCollection,
-- * Pipe
@ -19,13 +27,13 @@ module Database.MongoDB.Internal.Protocol (
-- ** Reply
Reply(..), ResponseFlag(..),
-- * Authentication
Username, Password, Nonce, pwHash, pwKey
Username, Password, Nonce, pwHash, pwKey,
isClosed, close
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
import Control.Arrow ((***))
import Control.Monad (forM, replicateM, unless)
import Data.Binary.Get (Get, runGet)
import Data.Binary.Put (Put, runPut)
@ -35,6 +43,12 @@ import Data.IORef (IORef, newIORef, atomicModifyIORef)
import System.IO (Handle)
import System.IO.Unsafe (unsafePerformIO)
import Data.Maybe (maybeToList)
import GHC.Conc (ThreadStatus(..), threadStatus)
import Control.Monad (forever)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent (ThreadId, forkIO, killThread)
import Control.Exception.Lifted (onException, throwIO, try)
import qualified Data.ByteString.Lazy as L
@ -48,38 +62,111 @@ import qualified Crypto.Hash.MD5 as MD5
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Database.MongoDB.Internal.Util (whenJust, bitOr, byteStringHex)
import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..))
import Database.MongoDB.Internal.Util (bitOr, byteStringHex)
import qualified System.IO.Pipeline as P
import Database.MongoDB.Transport (Transport)
import qualified Database.MongoDB.Transport as T
import Database.MongoDB.Internal.Connection (Connection)
import qualified Database.MongoDB.Internal.Connection as Connection
#if MIN_VERSION_base(4,6,0)
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
putMVar, readMVar, mkWeakMVar)
#else
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
putMVar, readMVar, addMVarFinalizer)
#endif
#if !MIN_VERSION_base(4,6,0)
mkWeakMVar :: MVar a -> IO () -> IO ()
mkWeakMVar = addMVarFinalizer
#endif
-- * Pipeline
-- | Thread-safe and pipelined connection
data Pipeline = Pipeline {
vStream :: MVar Transport, -- ^ Mutex on handle, so only one thread at a time can write to it
responseQueue :: Chan (MVar (Either IOError Response)), -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response.
listenThread :: ThreadId
}
-- | Create new Pipeline over given handle. You should 'close' pipeline when finished, which will also close handle. If pipeline is not closed but eventually garbage collected, it will be closed along with handle.
newPipeline :: Transport -> IO Pipeline
newPipeline stream = do
vStream <- newMVar stream
responseQueue <- newChan
rec
let pipe = Pipeline{..}
listenThread <- forkIO (listen pipe)
_ <- mkWeakMVar vStream $ do
killThread listenThread
T.close stream
return pipe
close :: Pipeline -> IO ()
-- ^ Close pipe and underlying connection
close Pipeline{..} = do
killThread listenThread
T.close =<< readMVar vStream
isClosed :: Pipeline -> IO Bool
isClosed Pipeline{listenThread} = do
status <- threadStatus listenThread
return $ case status of
ThreadRunning -> False
ThreadFinished -> True
ThreadBlocked _ -> False
ThreadDied -> True
--isPipeClosed Pipeline{..} = isClosed =<< readMVar vHandle -- isClosed hangs while listen loop is waiting on read
listen :: Pipeline -> IO ()
-- ^ Listen for responses and supply them to waiting threads in order
listen Pipeline{..} = do
stream <- readMVar vStream
forever $ do
e <- try $ readMessage stream
var <- readChan responseQueue
putMVar var e
case e of
Left err -> T.close stream >> ioError err -- close and stop looping
Right _ -> return ()
psend :: Pipeline -> Message -> IO ()
-- ^ Send message to destination; the destination must not response (otherwise future 'call's will get these responses instead of their own).
-- Throw IOError and close pipeline if send fails
psend p@Pipeline{..} message = withMVar vStream (flip writeMessage message) `onException` close p
pcall :: Pipeline -> Message -> IO (IO Response)
-- ^ Send message to destination and return /promise/ of response from one message only. The destination must reply to the message (otherwise promises will have the wrong responses in them).
-- Throw IOError and closes pipeline if send fails, likewise for promised response.
pcall p@Pipeline{..} message = withMVar vStream doCall `onException` close p where
doCall stream = do
writeMessage stream message
var <- newEmptyMVar
liftIO $ writeChan responseQueue var
return $ readMVar var >>= either throwIO return -- return promise
-- * Pipe
type Pipe = Pipeline Response Message
type Pipe = Pipeline
-- ^ Thread-safe TCP connection with pipelined requests
newPipe :: Handle -> IO Pipe
-- ^ Create pipe over handle
newPipe handle = Connection.fromHandle handle >>= newPipeWith
newPipe handle = T.fromHandle handle >>= newPipeWith
newPipeWith :: Connection -> IO Pipe
newPipeWith :: Transport -> IO Pipe
-- ^ Create pipe over connection
newPipeWith conn = newPipeline $ IOStream (writeMessage conn)
(readMessage conn)
(Connection.close conn)
newPipeWith conn = newPipeline conn
send :: Pipe -> [Notice] -> IO ()
-- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails.
send pipe notices = P.send pipe (notices, Nothing)
send pipe notices = psend pipe (notices, Nothing)
call :: Pipe -> [Notice] -> Request -> IO (IO Reply)
-- ^ Send notices and request as a contiguous batch to server and return reply promise, which will block when invoked until reply arrives. This call and resulting promise will throw IOError if connection fails.
call pipe notices request = do
requestId <- genRequestId
promise <- P.call pipe (notices, Just (request, requestId))
promise <- pcall pipe (notices, Just (request, requestId))
return $ check requestId <$> promise
where
check requestId (responseTo, reply) = if requestId == responseTo then reply else
@ -91,7 +178,7 @@ type Message = ([Notice], Maybe (Request, RequestId))
-- ^ A write notice(s) with getLastError request, or just query request.
-- Note, that requestId will be out of order because request ids will be generated for notices after the request id supplied was generated. This is ok because the mongo server does not care about order just uniqueness.
writeMessage :: Connection -> Message -> IO ()
writeMessage :: Transport -> Message -> IO ()
-- ^ Write message to connection
writeMessage conn (notices, mRequest) = do
noticeStrings <- forM notices $ \n -> do
@ -104,8 +191,8 @@ writeMessage conn (notices, mRequest) = do
let s = runPut $ putRequest request requestId
return $ (lenBytes s) `L.append` s
Connection.write conn $ L.toStrict $ L.concat $ noticeStrings ++ (maybeToList requestString)
Connection.flush conn
T.write conn $ L.toStrict $ L.concat $ noticeStrings ++ (maybeToList requestString)
T.flush conn
where
lenBytes bytes = encodeSize . toEnum . fromEnum $ L.length bytes
encodeSize = runPut . putInt32 . (+ 4)
@ -113,12 +200,12 @@ writeMessage conn (notices, mRequest) = do
type Response = (ResponseTo, Reply)
-- ^ Message received from a Mongo server in response to a Request
readMessage :: Connection -> IO Response
readMessage :: Transport -> IO Response
-- ^ read response from a connection
readMessage conn = readResp where
readResp = do
len <- fromEnum . decodeSize <$> Connection.readExactly conn 4
runGet getReply <$> Connection.readExactly conn len
len <- fromEnum . decodeSize . L.fromStrict <$> T.read conn 4
runGet getReply . L.fromStrict <$> T.read conn len
decodeSize = subtract 4 . runGet getInt32
type FullCollection = Text

View file

@ -0,0 +1,32 @@
-- | This module defines a connection interface. It could be a regular
-- network connection, TLS connection, a mock or anything else.
module Database.MongoDB.Transport (
Transport(..),
fromHandle,
) where
import Prelude hiding (read)
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import System.IO
-- | Abstract transport interface
--
-- `read` should return `ByteString.null` on EOF
data Transport = Transport {
read :: Int -> IO ByteString,
write :: ByteString -> IO (),
flush :: IO (),
close :: IO ()}
fromHandle :: Handle -> IO Transport
-- ^ Make connection form handle
fromHandle handle = do
return Transport
{ read = ByteString.hGet handle
, write = ByteString.hPut handle
, flush = hFlush handle
, close = hClose handle
}

View file

@ -0,0 +1,88 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-|
Module : MongoDB TLS
Description : TLS transport for mongodb
Copyright : (c) Yuras Shumovich, 2016
License : Apache 2.0
Maintainer : Victor Denisov denisovenator@gmail.com
Stability : experimental
Portability : POSIX
This module is for connecting to TLS enabled mongodb servers.
ATTENTION!!! Be aware that this module is highly experimental and is
barely tested. The current implementation doesn't verify server's identity.
It only allows you to connect to a mongodb server using TLS protocol.
-}
module Database.MongoDB.Transport.Tls
(connect)
where
import Data.IORef
import Data.Monoid
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Data.Default.Class (def)
import Control.Applicative ((<$>))
import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO
import Database.MongoDB (Pipe)
import Database.MongoDB.Internal.Protocol (newPipeWith)
import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType)
import Network (connectTo, HostName, PortID)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS
-- | Connect to mongodb using TLS
connect :: HostName -> PortID -> IO Pipe
connect host port = bracketOnError (connectTo host port) hClose $ \handle -> do
let params = (TLS.defaultParamsClient host "")
{ TLS.clientSupported = def
{ TLS.supportedCiphers = TLS.ciphersuite_all}
, TLS.clientHooks = def
{ TLS.onServerCertificate = \_ _ _ _ -> return []}
}
context <- TLS.contextNew handle params
TLS.handshake context
conn <- tlsConnection context
newPipeWith conn
tlsConnection :: TLS.Context -> IO Transport
tlsConnection ctx = do
restRef <- newIORef mempty
return Transport
{ T.read = \count -> let
readSome = do
rest <- readIORef restRef
writeIORef restRef mempty
if ByteString.null rest
then TLS.recvData ctx
else return rest
unread = \rest ->
modifyIORef restRef (rest <>)
go acc n = do
-- read until get enough bytes
chunk <- readSome
when (ByteString.null chunk) $
ioError eof
let len = ByteString.length chunk
if len >= n
then do
let (res, rest) = ByteString.splitAt n chunk
unless (ByteString.null rest) $
unread rest
return (acc <> Lazy.ByteString.fromStrict res)
else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)
eof = mkIOError eofErrorType "Database.MongoDB.Transport"
Nothing Nothing
in Lazy.ByteString.toStrict <$> go mempty count
, T.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
, T.flush = TLS.contextFlush ctx
, T.close = TLS.contextClose ctx
}

View file

@ -1,118 +0,0 @@
{- | Pipelining is sending multiple requests over a socket and receiving the responses later in the same order (a' la HTTP pipelining). This is faster than sending one request, waiting for the response, then sending the next request, and so on. This implementation returns a /promise (future)/ response for each request that when invoked waits for the response if not already arrived. Multiple threads can send on the same pipeline (and get promises back); it will send each thread's request right away without waiting.
A pipeline closes itself when a read or write causes an error, so you can detect a broken pipeline by checking isClosed. It also closes itself when garbage collected, or you can close it explicitly. -}
{-# LANGUAGE RecordWildCards, NamedFieldPuns, ScopedTypeVariables #-}
{-# LANGUAGE CPP, FlexibleContexts #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
module System.IO.Pipeline (
-- * IOStream
IOStream(..),
-- * Pipeline
Pipeline, newPipeline, send, call, close, isClosed
) where
import Prelude hiding (length)
import Control.Concurrent (ThreadId, forkIO, killThread)
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Monad (forever)
import GHC.Conc (ThreadStatus(..), threadStatus)
import Control.Monad.Trans (liftIO)
#if MIN_VERSION_base(4,6,0)
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
putMVar, readMVar, mkWeakMVar)
#else
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
putMVar, readMVar, addMVarFinalizer)
#endif
import Control.Exception.Lifted (onException, throwIO, try)
#if !MIN_VERSION_base(4,6,0)
mkWeakMVar :: MVar a -> IO () -> IO ()
mkWeakMVar = addMVarFinalizer
#endif
-- * IOStream
-- | An IO sink and source where value of type @o@ are sent and values of type @i@ are received.
data IOStream i o = IOStream {
writeStream :: o -> IO (),
readStream :: IO i,
closeStream :: IO () }
-- * Pipeline
-- | Thread-safe and pipelined connection
data Pipeline i o = Pipeline {
vStream :: MVar (IOStream i o), -- ^ Mutex on handle, so only one thread at a time can write to it
responseQueue :: Chan (MVar (Either IOError i)), -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response.
listenThread :: ThreadId
}
-- | Create new Pipeline over given handle. You should 'close' pipeline when finished, which will also close handle. If pipeline is not closed but eventually garbage collected, it will be closed along with handle.
newPipeline :: IOStream i o -> IO (Pipeline i o)
newPipeline stream = do
vStream <- newMVar stream
responseQueue <- newChan
rec
let pipe = Pipeline{..}
listenThread <- forkIO (listen pipe)
_ <- mkWeakMVar vStream $ do
killThread listenThread
closeStream stream
return pipe
close :: Pipeline i o -> IO ()
-- ^ Close pipe and underlying connection
close Pipeline{..} = do
killThread listenThread
closeStream =<< readMVar vStream
isClosed :: Pipeline i o -> IO Bool
isClosed Pipeline{listenThread} = do
status <- threadStatus listenThread
return $ case status of
ThreadRunning -> False
ThreadFinished -> True
ThreadBlocked _ -> False
ThreadDied -> True
--isPipeClosed Pipeline{..} = isClosed =<< readMVar vHandle -- isClosed hangs while listen loop is waiting on read
listen :: Pipeline i o -> IO ()
-- ^ Listen for responses and supply them to waiting threads in order
listen Pipeline{..} = do
stream <- readMVar vStream
forever $ do
e <- try $ readStream stream
var <- readChan responseQueue
putMVar var e
case e of
Left err -> closeStream stream >> ioError err -- close and stop looping
Right _ -> return ()
send :: Pipeline i o -> o -> IO ()
-- ^ Send message to destination; the destination must not response (otherwise future 'call's will get these responses instead of their own).
-- Throw IOError and close pipeline if send fails
send p@Pipeline{..} message = withMVar vStream (flip writeStream message) `onException` close p
call :: Pipeline i o -> o -> IO (IO i)
-- ^ Send message to destination and return /promise/ of response from one message only. The destination must reply to the message (otherwise promises will have the wrong responses in them).
-- Throw IOError and closes pipeline if send fails, likewise for promised response.
call p@Pipeline{..} message = withMVar vStream doCall `onException` close p where
doCall stream = do
writeStream stream message
var <- newEmptyMVar
liftIO $ writeChan responseQueue var
return $ readMVar var >>= either throwIO return -- return promise
{- Authors: Tony Hannan <tony@10gen.com>
Copyright 2011 10gen Inc.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0. Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -}

View file

@ -39,6 +39,8 @@ Library
, random-shuffle -any
, monad-control >= 0.3.1
, lifted-base >= 0.1.0.3
, tls >= 1.2.0
, data-default-class -any
, transformers-base >= 0.4.1
, hashtables >= 1.1.2.0
, base16-bytestring >= 0.1.1.6
@ -48,11 +50,11 @@ Library
Exposed-modules: Database.MongoDB
Database.MongoDB.Admin
Database.MongoDB.Connection
Database.MongoDB.Internal.Connection
Database.MongoDB.Internal.Protocol
Database.MongoDB.Internal.Util
Database.MongoDB.Query
System.IO.Pipeline
Database.MongoDB.Transport
Database.MongoDB.Transport.Tls
Source-repository head
Type: git
@ -82,6 +84,8 @@ Benchmark bench
type: exitcode-stdio-1.0
Build-depends: array -any
, base < 5
, base64-bytestring
, base16-bytestring
, binary -any
, bson >= 0.3 && < 0.4
, text
@ -90,6 +94,7 @@ Benchmark bench
, mtl >= 2
, cryptohash -any
, network -any
, nonce
, parsec -any
, random -any
, random-shuffle -any

View file

@ -1,69 +0,0 @@
module Internal.ConnectionSpec (
spec,
) where
import Prelude hiding (read)
import Data.Monoid
import Data.IORef
import Control.Monad
import System.IO.Error (isEOFError)
import Test.Hspec
import Database.MongoDB.Internal.Connection
spec :: Spec
spec = describe "Internal.Connection" $ do
readExactlySpec
readExactlySpec :: Spec
readExactlySpec = describe "readExactly" $ do
it "should return specified number of bytes" $ do
let conn = Connection
{ read = return "12345"
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
res <- readExactly conn 3
res `shouldBe` "123"
it "should unread the rest" $ do
restRef <- newIORef mempty
let conn = Connection
{ read = return "12345"
, unread = writeIORef restRef
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
void $ readExactly conn 3
rest <- readIORef restRef
rest `shouldBe` "45"
it "should ask for more bytes if the first chunk is too small" $ do
let conn = Connection
{ read = return "12345"
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
res <- readExactly conn 8
res `shouldBe` "12345123"
it "should throw on EOF" $ do
let conn = Connection
{ read = return mempty
, unread = \_ -> return ()
, write = \_ -> return ()
, flush = return ()
, close = return ()
}
void $ readExactly conn 3
`shouldThrow` isEOFError