Add support for opening replica sets over TLS

This commit is contained in:
Neil Cowburn 2019-11-01 16:55:59 +00:00
parent 76d5f84f8a
commit d334d889ee
No known key found for this signature in database
GPG key ID: 4A07D4C2E3F34E66
3 changed files with 39 additions and 19 deletions

View file

@ -17,7 +17,7 @@ module Database.MongoDB.Connection (
Host(..), PortID(..), defaultPort, host, showHostPort, readHostPort,
readHostPortM, globalConnectTimeout, connect, connect',
-- * Replica Set
ReplicaSetName, openReplicaSet, openReplicaSet',
ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS',
ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName
) where
@ -47,12 +47,16 @@ import Data.Text (Text)
import qualified Data.Bson as B
import qualified Data.Text as T
import Database.MongoDB.Internal.Network (HostName, PortID(..), connectTo)
import Database.MongoDB.Internal.Network (Host(..), HostName, PortID(..), connectTo, lookupSeedList, lookupReplicaSetName)
import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed)
import Database.MongoDB.Internal.Util (untilSuccess, liftIOE,
updateAssocs, shuffle, mergesortM)
import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
slaveOk, runCommand, retrieveServerData)
import qualified Database.MongoDB.Transport.Tls as TLS (connect)
flip' :: (a -> b -> c -> d) -> b -> c -> a -> d
flip' f x y z = f z x y
adminCommand :: Command -> Pipe -> IO Document
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails.
@ -62,10 +66,6 @@ adminCommand cmd pipe =
failureToIOError (ConnectionFailure e) = e
failureToIOError e = userError $ show e
-- * Host
data Host = Host HostName PortID deriving (Show, Eq, Ord)
defaultPort :: PortID
-- ^ Default MongoDB port = 27017
defaultPort = PortNumber 27017
@ -124,12 +124,14 @@ connect' timeoutSecs (Host hostname port) = do
type ReplicaSetName = Text
data TransportSecurity = Secure | Insecure
-- | Maintains a connection (created on demand) to each server in the named replica set
data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs
data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs TransportSecurity
replSetName :: ReplicaSet -> Text
-- ^ name of connected replica set
replSetName (ReplicaSet rsName _ _) = rsName
replSetName (ReplicaSet rsName _ _ _) = rsName
openReplicaSet :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSet\'' instead.
@ -137,19 +139,30 @@ openReplicaSet rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSet'
openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSet' timeoutSecs (rsName, seedList) = do
openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Insecure)
openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in the replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetTLS\'' instead.
openReplicaSetTLS rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSetTLS' rsSeed
openReplicaSetTLS' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetTLS' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Secure)
_openReplicaSet :: Secs -> (ReplicaSetName, [Host], TransportSecurity) -> IO ReplicaSet
_openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do
vMembers <- newMVar (map (, Nothing) seedList)
let rs = ReplicaSet rsName vMembers timeoutSecs
let rs = ReplicaSet rsName vMembers timeoutSecs transportSecurity
_ <- updateMembers rs
return rs
closeReplicaSet :: ReplicaSet -> IO ()
-- ^ Close all connections to replica set
closeReplicaSet (ReplicaSet _ vMembers _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd)
closeReplicaSet (ReplicaSet _ vMembers _ _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd)
primary :: ReplicaSet -> IO Pipe
-- ^ Return connection to current primary of replica set. Fail if no primary available.
primary rs@(ReplicaSet rsName _ _) = do
primary rs@(ReplicaSet rsName _ _ _) = do
mHost <- statedPrimary <$> updateMembers rs
case mHost of
Just host' -> connection rs Nothing host'
@ -185,7 +198,7 @@ possibleHosts (_, info) = map readHostPort $ at "hosts" info
updateMembers :: ReplicaSet -> IO ReplicaInfo
-- ^ Fetch replica info from any server and update members accordingly
updateMembers rs@(ReplicaSet _ vMembers _) = do
updateMembers rs@(ReplicaSet _ vMembers _ _) = do
(host', info) <- untilSuccess (fetchReplicaInfo rs) =<< readMVar vMembers
modifyMVar vMembers $ \members -> do
let ((members', old), new) = intersection (map readHostPort $ at "hosts" info) members
@ -199,7 +212,7 @@ updateMembers rs@(ReplicaSet _ vMembers _) = do
fetchReplicaInfo :: ReplicaSet -> (Host, Maybe Pipe) -> IO ReplicaInfo
-- Connect to host and fetch replica info from host creating new connection if missing or closed (previously failed). Fail if not member of named replica set.
fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do
fetchReplicaInfo rs@(ReplicaSet rsName _ _ _) (host', mPipe) = do
pipe <- connection rs mPipe host'
info <- adminCommand ["isMaster" =: (1 :: Int)] pipe
case B.lookup "setName" info of
@ -209,11 +222,15 @@ fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do
connection :: ReplicaSet -> Maybe Pipe -> Host -> IO Pipe
-- ^ Return new or existing connection to member of replica set. If pipe is already known for host it is given, but we still test if it is open.
connection (ReplicaSet _ vMembers timeoutSecs) mPipe host' =
connection (ReplicaSet _ vMembers timeoutSecs transportSecurity) mPipe host' =
maybe conn (\p -> isClosed p >>= \bad -> if bad then conn else return p) mPipe
where
conn = modifyMVar vMembers $ \members -> do
let new = connect' timeoutSecs host' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe)
let (Host h p) = host'
let conn' = case transportSecurity of
Secure -> TLS.connect h p
Insecure -> connect' timeoutSecs host'
let new = conn' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe)
case List.lookup host' members of
Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe)
_ -> new

View file

@ -1,7 +1,7 @@
-- | Compatibility layer for network package, including newtype 'PortID'
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving #-}
module Database.MongoDB.Internal.Network (PortID(..), N.HostName, connectTo) where
module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo) where
#if !MIN_VERSION_network(2, 9, 0)
@ -50,3 +50,7 @@ connectTo hostname (PortNumber port) = do
N.socketToHandle sock ReadWriteMode
)
#endif
-- * Host
data Host = Host N.HostName PortID deriving (Show, Eq, Ord)

View file

@ -34,8 +34,7 @@ import Control.Applicative ((<$>))
import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO
import Database.MongoDB (Pipe)
import Database.MongoDB.Internal.Protocol (newPipeWith)
import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType)