Open ReplicaSets over TLS

This commit is contained in:
Victor Denisov 2020-01-01 20:34:31 -08:00
commit 73cae15466
4 changed files with 107 additions and 20 deletions

View file

@ -17,7 +17,8 @@ module Database.MongoDB.Connection (
Host(..), PortID(..), defaultPort, host, showHostPort, readHostPort, Host(..), PortID(..), defaultPort, host, showHostPort, readHostPort,
readHostPortM, globalConnectTimeout, connect, connect', readHostPortM, globalConnectTimeout, connect, connect',
-- * Replica Set -- * Replica Set
ReplicaSetName, openReplicaSet, openReplicaSet', ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS',
openReplicaSetSRV, openReplicaSetSRV', openReplicaSetSRV'', openReplicaSetSRV''',
ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName
) where ) where
@ -49,12 +50,13 @@ import Data.Text (Text)
import qualified Data.Bson as B import qualified Data.Bson as B
import qualified Data.Text as T import qualified Data.Text as T
import Database.MongoDB.Internal.Network (HostName, PortID(..), connectTo) import Database.MongoDB.Internal.Network (Host(..), HostName, PortID(..), connectTo, lookupSeedList, lookupReplicaSetName)
import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed) import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed)
import Database.MongoDB.Internal.Util (untilSuccess, liftIOE, import Database.MongoDB.Internal.Util (untilSuccess, liftIOE,
updateAssocs, shuffle, mergesortM) updateAssocs, shuffle, mergesortM)
import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access, import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
slaveOk, runCommand, retrieveServerData) slaveOk, runCommand, retrieveServerData)
import qualified Database.MongoDB.Transport.Tls as TLS (connect)
adminCommand :: Command -> Pipe -> IO Document adminCommand :: Command -> Pipe -> IO Document
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails. -- ^ Run command against admin database on server connected to pipe. Fail if connection fails.
@ -64,10 +66,6 @@ adminCommand cmd pipe =
failureToIOError (ConnectionFailure e) = e failureToIOError (ConnectionFailure e) = e
failureToIOError e = userError $ show e failureToIOError e = userError $ show e
-- * Host
data Host = Host HostName PortID deriving (Show, Eq, Ord)
defaultPort :: PortID defaultPort :: PortID
-- ^ Default MongoDB port = 27017 -- ^ Default MongoDB port = 27017
defaultPort = PortNumber 27017 defaultPort = PortNumber 27017
@ -133,12 +131,14 @@ connect' timeoutSecs (Host hostname port) = do
type ReplicaSetName = Text type ReplicaSetName = Text
data TransportSecurity = Secure | Unsecure
-- | Maintains a connection (created on demand) to each server in the named replica set -- | Maintains a connection (created on demand) to each server in the named replica set
data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs TransportSecurity
replSetName :: ReplicaSet -> Text replSetName :: ReplicaSet -> Text
-- ^ name of connected replica set -- ^ name of connected replica set
replSetName (ReplicaSet rsName _ _) = rsName replSetName (ReplicaSet rsName _ _ _) = rsName
openReplicaSet :: (ReplicaSetName, [Host]) -> IO ReplicaSet openReplicaSet :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSet\'' instead. -- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSet\'' instead.
@ -146,19 +146,62 @@ openReplicaSet rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSet'
openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members. -- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSet' timeoutSecs (rsName, seedList) = do openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Unsecure)
openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in the replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetTLS\'' instead.
openReplicaSetTLS rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSetTLS' rsSeed
openReplicaSetTLS' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetTLS' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Secure)
_openReplicaSet :: Secs -> (ReplicaSetName, [Host], TransportSecurity) -> IO ReplicaSet
_openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do
vMembers <- newMVar (map (, Nothing) seedList) vMembers <- newMVar (map (, Nothing) seedList)
let rs = ReplicaSet rsName vMembers timeoutSecs let rs = ReplicaSet rsName vMembers timeoutSecs transportSecurity
_ <- updateMembers rs _ <- updateMembers rs
return rs return rs
openReplicaSetSRV :: HostName -> IO ReplicaSet
-- ^ Open non-secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV\'\'\'' instead.
openReplicaSetSRV hostname = do
timeoutSecs <- readIORef globalConnectTimeout
_openReplicaSetSRV timeoutSecs Unsecure hostname
openReplicaSetSRV' :: HostName -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV\'\'\'\'' instead.
openReplicaSetSRV' hostname = do
timeoutSecs <- readIORef globalConnectTimeout
_openReplicaSetSRV timeoutSecs Secure hostname
openReplicaSetSRV'' :: Secs -> HostName -> IO ReplicaSet
-- ^ Open non-secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetSRV'' timeoutSecs = _openReplicaSetSRV timeoutSecs Unsecure
openReplicaSetSRV''' :: Secs -> HostName -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetSRV''' timeoutSecs = _openReplicaSetSRV timeoutSecs Secure
_openReplicaSetSRV :: Secs -> TransportSecurity -> HostName -> IO ReplicaSet
_openReplicaSetSRV timeoutSecs transportSecurity hostname = do
replicaSetName <- lookupReplicaSetName hostname
hosts <- lookupSeedList hostname
case (replicaSetName, hosts) of
(Nothing, _) -> throwError $ userError "Failed to lookup replica set name"
(_, []) -> throwError $ userError "Failed to lookup replica set seedlist"
(Just rsName, _) ->
case transportSecurity of
Secure -> openReplicaSetTLS' timeoutSecs (rsName, hosts)
Unsecure -> openReplicaSet' timeoutSecs (rsName, hosts)
closeReplicaSet :: ReplicaSet -> IO () closeReplicaSet :: ReplicaSet -> IO ()
-- ^ Close all connections to replica set -- ^ Close all connections to replica set
closeReplicaSet (ReplicaSet _ vMembers _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd) closeReplicaSet (ReplicaSet _ vMembers _ _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd)
primary :: ReplicaSet -> IO Pipe primary :: ReplicaSet -> IO Pipe
-- ^ Return connection to current primary of replica set. Fail if no primary available. -- ^ Return connection to current primary of replica set. Fail if no primary available.
primary rs@(ReplicaSet rsName _ _) = do primary rs@(ReplicaSet rsName _ _ _) = do
mHost <- statedPrimary <$> updateMembers rs mHost <- statedPrimary <$> updateMembers rs
case mHost of case mHost of
Just host' -> connection rs Nothing host' Just host' -> connection rs Nothing host'
@ -194,7 +237,7 @@ possibleHosts (_, info) = map readHostPort $ at "hosts" info
updateMembers :: ReplicaSet -> IO ReplicaInfo updateMembers :: ReplicaSet -> IO ReplicaInfo
-- ^ Fetch replica info from any server and update members accordingly -- ^ Fetch replica info from any server and update members accordingly
updateMembers rs@(ReplicaSet _ vMembers _) = do updateMembers rs@(ReplicaSet _ vMembers _ _) = do
(host', info) <- untilSuccess (fetchReplicaInfo rs) =<< readMVar vMembers (host', info) <- untilSuccess (fetchReplicaInfo rs) =<< readMVar vMembers
modifyMVar vMembers $ \members -> do modifyMVar vMembers $ \members -> do
let ((members', old), new) = intersection (map readHostPort $ at "hosts" info) members let ((members', old), new) = intersection (map readHostPort $ at "hosts" info) members
@ -208,7 +251,7 @@ updateMembers rs@(ReplicaSet _ vMembers _) = do
fetchReplicaInfo :: ReplicaSet -> (Host, Maybe Pipe) -> IO ReplicaInfo fetchReplicaInfo :: ReplicaSet -> (Host, Maybe Pipe) -> IO ReplicaInfo
-- Connect to host and fetch replica info from host creating new connection if missing or closed (previously failed). Fail if not member of named replica set. -- Connect to host and fetch replica info from host creating new connection if missing or closed (previously failed). Fail if not member of named replica set.
fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do fetchReplicaInfo rs@(ReplicaSet rsName _ _ _) (host', mPipe) = do
pipe <- connection rs mPipe host' pipe <- connection rs mPipe host'
info <- adminCommand ["isMaster" =: (1 :: Int)] pipe info <- adminCommand ["isMaster" =: (1 :: Int)] pipe
case B.lookup "setName" info of case B.lookup "setName" info of
@ -218,11 +261,15 @@ fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do
connection :: ReplicaSet -> Maybe Pipe -> Host -> IO Pipe connection :: ReplicaSet -> Maybe Pipe -> Host -> IO Pipe
-- ^ Return new or existing connection to member of replica set. If pipe is already known for host it is given, but we still test if it is open. -- ^ Return new or existing connection to member of replica set. If pipe is already known for host it is given, but we still test if it is open.
connection (ReplicaSet _ vMembers timeoutSecs) mPipe host' = connection (ReplicaSet _ vMembers timeoutSecs transportSecurity) mPipe host' =
maybe conn (\p -> isClosed p >>= \bad -> if bad then conn else return p) mPipe maybe conn (\p -> isClosed p >>= \bad -> if bad then conn else return p) mPipe
where where
conn = modifyMVar vMembers $ \members -> do conn = modifyMVar vMembers $ \members -> do
let new = connect' timeoutSecs host' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe) let (Host h p) = host'
let conn' = case transportSecurity of
Secure -> TLS.connect h p
Unsecure -> connect' timeoutSecs host'
let new = conn' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe)
case List.lookup host' members of case List.lookup host' members of
Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe) Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe)
_ -> new _ -> new

View file

@ -1,7 +1,8 @@
-- | Compatibility layer for network package, including newtype 'PortID' -- | Compatibility layer for network package, including newtype 'PortID'
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving #-} {-# LANGUAGE CPP, GeneralizedNewtypeDeriving, OverloadedStrings #-}
module Database.MongoDB.Internal.Network (PortID(..), N.HostName, connectTo) where module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo,
lookupReplicaSetName, lookupSeedList) where
#if !MIN_VERSION_network(2, 9, 0) #if !MIN_VERSION_network(2, 9, 0)
@ -18,6 +19,14 @@ import System.IO (Handle, IOMode(ReadWriteMode))
#endif #endif
import Data.ByteString.Char8 (pack, unpack)
import Data.List (dropWhileEnd, lookup)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Network.DNS.Lookup (lookupSRV, lookupTXT)
import Network.DNS.Resolver (defaultResolvConf, makeResolvSeed, withResolver)
import Network.HTTP.Types.URI (parseQueryText)
-- | Wraps network's 'PortNumber' -- | Wraps network's 'PortNumber'
-- Used to ease compatibility between older and newer network versions. -- Used to ease compatibility between older and newer network versions.
@ -70,3 +79,29 @@ connectTo _ (UnixSocket path) = do
#endif #endif
#endif #endif
-- * Host
data Host = Host N.HostName PortID deriving (Show, Eq, Ord)
lookupReplicaSetName :: N.HostName -> IO (Maybe Text)
-- ^ Retrieves the replica set name from the TXT DNS record for the given hostname
lookupReplicaSetName hostname = do
rs <- makeResolvSeed defaultResolvConf
res <- withResolver rs $ \resolver -> lookupTXT resolver (pack hostname)
case res of
Left _ -> pure Nothing
Right [] -> pure Nothing
Right (x:_) ->
pure $ fromMaybe (Nothing :: Maybe Text) (lookup "replicaSet" $ parseQueryText x)
lookupSeedList :: N.HostName -> IO [Host]
-- ^ Retrieves the replica set seed list from the SRV DNS record for the given hostname
lookupSeedList hostname = do
rs <- makeResolvSeed defaultResolvConf
res <- withResolver rs $ \resolver -> lookupSRV resolver $ pack $ "_mongodb._tcp." ++ hostname
case res of
Left _ -> pure []
Right srv -> pure $ map (\(_, _, por, tar) ->
let tar' = dropWhileEnd (=='.') (unpack tar)
in Host tar' (PortNumber . fromIntegral $ por)) srv

View file

@ -34,8 +34,7 @@ import Control.Applicative ((<$>))
import Control.Exception (bracketOnError) import Control.Exception (bracketOnError)
import Control.Monad (when, unless) import Control.Monad (when, unless)
import System.IO import System.IO
import Database.MongoDB (Pipe) import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Internal.Protocol (newPipeWith)
import Database.MongoDB.Transport (Transport(Transport)) import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType) import System.IO.Error (mkIOError, eofErrorType)

View file

@ -58,6 +58,8 @@ Library
, base64-bytestring >= 1.0.0.1 , base64-bytestring >= 1.0.0.1
, nonce >= 1.0.5 , nonce >= 1.0.5
, fail , fail
, dns
, http-types
if flag(_old-network) if flag(_old-network)
-- "Network.BSD" is only available in network < 2.9 -- "Network.BSD" is only available in network < 2.9
@ -128,7 +130,11 @@ Benchmark bench
, lifted-base >= 0.1.0.3 , lifted-base >= 0.1.0.3
, transformers-base >= 0.4.1 , transformers-base >= 0.4.1
, hashtables >= 1.1.2.0 , hashtables >= 1.1.2.0
, fail
, dns
, http-types
, criterion , criterion
, tls >= 1.3.0
if flag(_old-network) if flag(_old-network)
-- "Network.BSD" is only available in network < 2.9 -- "Network.BSD" is only available in network < 2.9