Open ReplicaSets over TLS

This commit is contained in:
Victor Denisov 2020-01-01 20:34:31 -08:00
commit 73cae15466
4 changed files with 107 additions and 20 deletions

View file

@ -17,7 +17,8 @@ module Database.MongoDB.Connection (
Host(..), PortID(..), defaultPort, host, showHostPort, readHostPort,
readHostPortM, globalConnectTimeout, connect, connect',
-- * Replica Set
ReplicaSetName, openReplicaSet, openReplicaSet',
ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS',
openReplicaSetSRV, openReplicaSetSRV', openReplicaSetSRV'', openReplicaSetSRV''',
ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName
) where
@ -49,12 +50,13 @@ import Data.Text (Text)
import qualified Data.Bson as B
import qualified Data.Text as T
import Database.MongoDB.Internal.Network (HostName, PortID(..), connectTo)
import Database.MongoDB.Internal.Network (Host(..), HostName, PortID(..), connectTo, lookupSeedList, lookupReplicaSetName)
import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed)
import Database.MongoDB.Internal.Util (untilSuccess, liftIOE,
updateAssocs, shuffle, mergesortM)
import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
slaveOk, runCommand, retrieveServerData)
import qualified Database.MongoDB.Transport.Tls as TLS (connect)
adminCommand :: Command -> Pipe -> IO Document
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails.
@ -64,10 +66,6 @@ adminCommand cmd pipe =
failureToIOError (ConnectionFailure e) = e
failureToIOError e = userError $ show e
-- * Host
data Host = Host HostName PortID deriving (Show, Eq, Ord)
defaultPort :: PortID
-- ^ Default MongoDB port = 27017
defaultPort = PortNumber 27017
@ -133,12 +131,14 @@ connect' timeoutSecs (Host hostname port) = do
type ReplicaSetName = Text
data TransportSecurity = Secure | Unsecure
-- | Maintains a connection (created on demand) to each server in the named replica set
data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs
data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs TransportSecurity
replSetName :: ReplicaSet -> Text
-- ^ name of connected replica set
replSetName (ReplicaSet rsName _ _) = rsName
replSetName (ReplicaSet rsName _ _ _) = rsName
openReplicaSet :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSet\'' instead.
@ -146,19 +146,62 @@ openReplicaSet rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSet'
openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSet' timeoutSecs (rsName, seedList) = do
openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Unsecure)
openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in the replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetTLS\'' instead.
openReplicaSetTLS rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSetTLS' rsSeed
openReplicaSetTLS' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetTLS' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Secure)
_openReplicaSet :: Secs -> (ReplicaSetName, [Host], TransportSecurity) -> IO ReplicaSet
_openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do
vMembers <- newMVar (map (, Nothing) seedList)
let rs = ReplicaSet rsName vMembers timeoutSecs
let rs = ReplicaSet rsName vMembers timeoutSecs transportSecurity
_ <- updateMembers rs
return rs
openReplicaSetSRV :: HostName -> IO ReplicaSet
-- ^ Open non-secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV\'\'\'' instead.
openReplicaSetSRV hostname = do
timeoutSecs <- readIORef globalConnectTimeout
_openReplicaSetSRV timeoutSecs Unsecure hostname
openReplicaSetSRV' :: HostName -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV\'\'\'\'' instead.
openReplicaSetSRV' hostname = do
timeoutSecs <- readIORef globalConnectTimeout
_openReplicaSetSRV timeoutSecs Secure hostname
openReplicaSetSRV'' :: Secs -> HostName -> IO ReplicaSet
-- ^ Open non-secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetSRV'' timeoutSecs = _openReplicaSetSRV timeoutSecs Unsecure
openReplicaSetSRV''' :: Secs -> HostName -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetSRV''' timeoutSecs = _openReplicaSetSRV timeoutSecs Secure
_openReplicaSetSRV :: Secs -> TransportSecurity -> HostName -> IO ReplicaSet
_openReplicaSetSRV timeoutSecs transportSecurity hostname = do
replicaSetName <- lookupReplicaSetName hostname
hosts <- lookupSeedList hostname
case (replicaSetName, hosts) of
(Nothing, _) -> throwError $ userError "Failed to lookup replica set name"
(_, []) -> throwError $ userError "Failed to lookup replica set seedlist"
(Just rsName, _) ->
case transportSecurity of
Secure -> openReplicaSetTLS' timeoutSecs (rsName, hosts)
Unsecure -> openReplicaSet' timeoutSecs (rsName, hosts)
closeReplicaSet :: ReplicaSet -> IO ()
-- ^ Close all connections to replica set
closeReplicaSet (ReplicaSet _ vMembers _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd)
closeReplicaSet (ReplicaSet _ vMembers _ _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd)
primary :: ReplicaSet -> IO Pipe
-- ^ Return connection to current primary of replica set. Fail if no primary available.
primary rs@(ReplicaSet rsName _ _) = do
primary rs@(ReplicaSet rsName _ _ _) = do
mHost <- statedPrimary <$> updateMembers rs
case mHost of
Just host' -> connection rs Nothing host'
@ -194,7 +237,7 @@ possibleHosts (_, info) = map readHostPort $ at "hosts" info
updateMembers :: ReplicaSet -> IO ReplicaInfo
-- ^ Fetch replica info from any server and update members accordingly
updateMembers rs@(ReplicaSet _ vMembers _) = do
updateMembers rs@(ReplicaSet _ vMembers _ _) = do
(host', info) <- untilSuccess (fetchReplicaInfo rs) =<< readMVar vMembers
modifyMVar vMembers $ \members -> do
let ((members', old), new) = intersection (map readHostPort $ at "hosts" info) members
@ -208,7 +251,7 @@ updateMembers rs@(ReplicaSet _ vMembers _) = do
fetchReplicaInfo :: ReplicaSet -> (Host, Maybe Pipe) -> IO ReplicaInfo
-- Connect to host and fetch replica info from host creating new connection if missing or closed (previously failed). Fail if not member of named replica set.
fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do
fetchReplicaInfo rs@(ReplicaSet rsName _ _ _) (host', mPipe) = do
pipe <- connection rs mPipe host'
info <- adminCommand ["isMaster" =: (1 :: Int)] pipe
case B.lookup "setName" info of
@ -218,11 +261,15 @@ fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do
connection :: ReplicaSet -> Maybe Pipe -> Host -> IO Pipe
-- ^ Return new or existing connection to member of replica set. If pipe is already known for host it is given, but we still test if it is open.
connection (ReplicaSet _ vMembers timeoutSecs) mPipe host' =
connection (ReplicaSet _ vMembers timeoutSecs transportSecurity) mPipe host' =
maybe conn (\p -> isClosed p >>= \bad -> if bad then conn else return p) mPipe
where
conn = modifyMVar vMembers $ \members -> do
let new = connect' timeoutSecs host' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe)
let (Host h p) = host'
let conn' = case transportSecurity of
Secure -> TLS.connect h p
Unsecure -> connect' timeoutSecs host'
let new = conn' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe)
case List.lookup host' members of
Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe)
_ -> new

View file

@ -1,7 +1,8 @@
-- | Compatibility layer for network package, including newtype 'PortID'
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving, OverloadedStrings #-}
module Database.MongoDB.Internal.Network (PortID(..), N.HostName, connectTo) where
module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo,
lookupReplicaSetName, lookupSeedList) where
#if !MIN_VERSION_network(2, 9, 0)
@ -18,6 +19,14 @@ import System.IO (Handle, IOMode(ReadWriteMode))
#endif
import Data.ByteString.Char8 (pack, unpack)
import Data.List (dropWhileEnd, lookup)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Network.DNS.Lookup (lookupSRV, lookupTXT)
import Network.DNS.Resolver (defaultResolvConf, makeResolvSeed, withResolver)
import Network.HTTP.Types.URI (parseQueryText)
-- | Wraps network's 'PortNumber'
-- Used to ease compatibility between older and newer network versions.
@ -70,3 +79,29 @@ connectTo _ (UnixSocket path) = do
#endif
#endif
-- * Host
data Host = Host N.HostName PortID deriving (Show, Eq, Ord)
lookupReplicaSetName :: N.HostName -> IO (Maybe Text)
-- ^ Retrieves the replica set name from the TXT DNS record for the given hostname
lookupReplicaSetName hostname = do
rs <- makeResolvSeed defaultResolvConf
res <- withResolver rs $ \resolver -> lookupTXT resolver (pack hostname)
case res of
Left _ -> pure Nothing
Right [] -> pure Nothing
Right (x:_) ->
pure $ fromMaybe (Nothing :: Maybe Text) (lookup "replicaSet" $ parseQueryText x)
lookupSeedList :: N.HostName -> IO [Host]
-- ^ Retrieves the replica set seed list from the SRV DNS record for the given hostname
lookupSeedList hostname = do
rs <- makeResolvSeed defaultResolvConf
res <- withResolver rs $ \resolver -> lookupSRV resolver $ pack $ "_mongodb._tcp." ++ hostname
case res of
Left _ -> pure []
Right srv -> pure $ map (\(_, _, por, tar) ->
let tar' = dropWhileEnd (=='.') (unpack tar)
in Host tar' (PortNumber . fromIntegral $ por)) srv

View file

@ -34,8 +34,7 @@ import Control.Applicative ((<$>))
import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO
import Database.MongoDB (Pipe)
import Database.MongoDB.Internal.Protocol (newPipeWith)
import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType)

View file

@ -58,6 +58,8 @@ Library
, base64-bytestring >= 1.0.0.1
, nonce >= 1.0.5
, fail
, dns
, http-types
if flag(_old-network)
-- "Network.BSD" is only available in network < 2.9
@ -128,7 +130,11 @@ Benchmark bench
, lifted-base >= 0.1.0.3
, transformers-base >= 0.4.1
, hashtables >= 1.1.2.0
, fail
, dns
, http-types
, criterion
, tls >= 1.3.0
if flag(_old-network)
-- "Network.BSD" is only available in network < 2.9