Add support for opening replica sets over TLS

This commit is contained in:
Neil Cowburn 2019-11-01 16:55:59 +00:00
parent 76d5f84f8a
commit d334d889ee
No known key found for this signature in database
GPG key ID: 4A07D4C2E3F34E66
3 changed files with 39 additions and 19 deletions

View file

@ -17,7 +17,7 @@ module Database.MongoDB.Connection (
Host(..), PortID(..), defaultPort, host, showHostPort, readHostPort, Host(..), PortID(..), defaultPort, host, showHostPort, readHostPort,
readHostPortM, globalConnectTimeout, connect, connect', readHostPortM, globalConnectTimeout, connect, connect',
-- * Replica Set -- * Replica Set
ReplicaSetName, openReplicaSet, openReplicaSet', ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS',
ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName
) where ) where
@ -47,12 +47,16 @@ import Data.Text (Text)
import qualified Data.Bson as B import qualified Data.Bson as B
import qualified Data.Text as T import qualified Data.Text as T
import Database.MongoDB.Internal.Network (HostName, PortID(..), connectTo) import Database.MongoDB.Internal.Network (Host(..), HostName, PortID(..), connectTo, lookupSeedList, lookupReplicaSetName)
import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed) import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed)
import Database.MongoDB.Internal.Util (untilSuccess, liftIOE, import Database.MongoDB.Internal.Util (untilSuccess, liftIOE,
updateAssocs, shuffle, mergesortM) updateAssocs, shuffle, mergesortM)
import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access, import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
slaveOk, runCommand, retrieveServerData) slaveOk, runCommand, retrieveServerData)
import qualified Database.MongoDB.Transport.Tls as TLS (connect)
flip' :: (a -> b -> c -> d) -> b -> c -> a -> d
flip' f x y z = f z x y
adminCommand :: Command -> Pipe -> IO Document adminCommand :: Command -> Pipe -> IO Document
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails. -- ^ Run command against admin database on server connected to pipe. Fail if connection fails.
@ -62,10 +66,6 @@ adminCommand cmd pipe =
failureToIOError (ConnectionFailure e) = e failureToIOError (ConnectionFailure e) = e
failureToIOError e = userError $ show e failureToIOError e = userError $ show e
-- * Host
data Host = Host HostName PortID deriving (Show, Eq, Ord)
defaultPort :: PortID defaultPort :: PortID
-- ^ Default MongoDB port = 27017 -- ^ Default MongoDB port = 27017
defaultPort = PortNumber 27017 defaultPort = PortNumber 27017
@ -124,12 +124,14 @@ connect' timeoutSecs (Host hostname port) = do
type ReplicaSetName = Text type ReplicaSetName = Text
data TransportSecurity = Secure | Insecure
-- | Maintains a connection (created on demand) to each server in the named replica set -- | Maintains a connection (created on demand) to each server in the named replica set
data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs TransportSecurity
replSetName :: ReplicaSet -> Text replSetName :: ReplicaSet -> Text
-- ^ name of connected replica set -- ^ name of connected replica set
replSetName (ReplicaSet rsName _ _) = rsName replSetName (ReplicaSet rsName _ _ _) = rsName
openReplicaSet :: (ReplicaSetName, [Host]) -> IO ReplicaSet openReplicaSet :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSet\'' instead. -- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSet\'' instead.
@ -137,19 +139,30 @@ openReplicaSet rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSet'
openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members. -- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSet' timeoutSecs (rsName, seedList) = do openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Insecure)
openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in the replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetTLS\'' instead.
openReplicaSetTLS rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSetTLS' rsSeed
openReplicaSetTLS' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetTLS' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Secure)
_openReplicaSet :: Secs -> (ReplicaSetName, [Host], TransportSecurity) -> IO ReplicaSet
_openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do
vMembers <- newMVar (map (, Nothing) seedList) vMembers <- newMVar (map (, Nothing) seedList)
let rs = ReplicaSet rsName vMembers timeoutSecs let rs = ReplicaSet rsName vMembers timeoutSecs transportSecurity
_ <- updateMembers rs _ <- updateMembers rs
return rs return rs
closeReplicaSet :: ReplicaSet -> IO () closeReplicaSet :: ReplicaSet -> IO ()
-- ^ Close all connections to replica set -- ^ Close all connections to replica set
closeReplicaSet (ReplicaSet _ vMembers _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd) closeReplicaSet (ReplicaSet _ vMembers _ _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd)
primary :: ReplicaSet -> IO Pipe primary :: ReplicaSet -> IO Pipe
-- ^ Return connection to current primary of replica set. Fail if no primary available. -- ^ Return connection to current primary of replica set. Fail if no primary available.
primary rs@(ReplicaSet rsName _ _) = do primary rs@(ReplicaSet rsName _ _ _) = do
mHost <- statedPrimary <$> updateMembers rs mHost <- statedPrimary <$> updateMembers rs
case mHost of case mHost of
Just host' -> connection rs Nothing host' Just host' -> connection rs Nothing host'
@ -185,7 +198,7 @@ possibleHosts (_, info) = map readHostPort $ at "hosts" info
updateMembers :: ReplicaSet -> IO ReplicaInfo updateMembers :: ReplicaSet -> IO ReplicaInfo
-- ^ Fetch replica info from any server and update members accordingly -- ^ Fetch replica info from any server and update members accordingly
updateMembers rs@(ReplicaSet _ vMembers _) = do updateMembers rs@(ReplicaSet _ vMembers _ _) = do
(host', info) <- untilSuccess (fetchReplicaInfo rs) =<< readMVar vMembers (host', info) <- untilSuccess (fetchReplicaInfo rs) =<< readMVar vMembers
modifyMVar vMembers $ \members -> do modifyMVar vMembers $ \members -> do
let ((members', old), new) = intersection (map readHostPort $ at "hosts" info) members let ((members', old), new) = intersection (map readHostPort $ at "hosts" info) members
@ -199,7 +212,7 @@ updateMembers rs@(ReplicaSet _ vMembers _) = do
fetchReplicaInfo :: ReplicaSet -> (Host, Maybe Pipe) -> IO ReplicaInfo fetchReplicaInfo :: ReplicaSet -> (Host, Maybe Pipe) -> IO ReplicaInfo
-- Connect to host and fetch replica info from host creating new connection if missing or closed (previously failed). Fail if not member of named replica set. -- Connect to host and fetch replica info from host creating new connection if missing or closed (previously failed). Fail if not member of named replica set.
fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do fetchReplicaInfo rs@(ReplicaSet rsName _ _ _) (host', mPipe) = do
pipe <- connection rs mPipe host' pipe <- connection rs mPipe host'
info <- adminCommand ["isMaster" =: (1 :: Int)] pipe info <- adminCommand ["isMaster" =: (1 :: Int)] pipe
case B.lookup "setName" info of case B.lookup "setName" info of
@ -209,11 +222,15 @@ fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do
connection :: ReplicaSet -> Maybe Pipe -> Host -> IO Pipe connection :: ReplicaSet -> Maybe Pipe -> Host -> IO Pipe
-- ^ Return new or existing connection to member of replica set. If pipe is already known for host it is given, but we still test if it is open. -- ^ Return new or existing connection to member of replica set. If pipe is already known for host it is given, but we still test if it is open.
connection (ReplicaSet _ vMembers timeoutSecs) mPipe host' = connection (ReplicaSet _ vMembers timeoutSecs transportSecurity) mPipe host' =
maybe conn (\p -> isClosed p >>= \bad -> if bad then conn else return p) mPipe maybe conn (\p -> isClosed p >>= \bad -> if bad then conn else return p) mPipe
where where
conn = modifyMVar vMembers $ \members -> do conn = modifyMVar vMembers $ \members -> do
let new = connect' timeoutSecs host' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe) let (Host h p) = host'
let conn' = case transportSecurity of
Secure -> TLS.connect h p
Insecure -> connect' timeoutSecs host'
let new = conn' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe)
case List.lookup host' members of case List.lookup host' members of
Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe) Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe)
_ -> new _ -> new

View file

@ -1,7 +1,7 @@
-- | Compatibility layer for network package, including newtype 'PortID' -- | Compatibility layer for network package, including newtype 'PortID'
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving #-} {-# LANGUAGE CPP, GeneralizedNewtypeDeriving #-}
module Database.MongoDB.Internal.Network (PortID(..), N.HostName, connectTo) where module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo) where
#if !MIN_VERSION_network(2, 9, 0) #if !MIN_VERSION_network(2, 9, 0)
@ -50,3 +50,7 @@ connectTo hostname (PortNumber port) = do
N.socketToHandle sock ReadWriteMode N.socketToHandle sock ReadWriteMode
) )
#endif #endif
-- * Host
data Host = Host N.HostName PortID deriving (Show, Eq, Ord)

View file

@ -34,8 +34,7 @@ import Control.Applicative ((<$>))
import Control.Exception (bracketOnError) import Control.Exception (bracketOnError)
import Control.Monad (when, unless) import Control.Monad (when, unless)
import System.IO import System.IO
import Database.MongoDB (Pipe) import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Internal.Protocol (newPipeWith)
import Database.MongoDB.Transport (Transport(Transport)) import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType) import System.IO.Error (mkIOError, eofErrorType)