Add support for opening replica sets using v3.6-style connection strings

This commit is contained in:
Neil Cowburn 2019-11-01 17:00:06 +00:00
parent d334d889ee
commit bcfbcc2918
No known key found for this signature in database
GPG key ID: 4A07D4C2E3F34E66
3 changed files with 71 additions and 8 deletions

View file

@ -18,6 +18,7 @@ module Database.MongoDB.Connection (
readHostPortM, globalConnectTimeout, connect, connect',
-- * Replica Set
ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS',
openReplicaSetSRV, openReplicaSetSRV', openReplicaSetSRV'', openReplicaSetSRV''',
ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName
) where
@ -55,9 +56,6 @@ import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
slaveOk, runCommand, retrieveServerData)
import qualified Database.MongoDB.Transport.Tls as TLS (connect)
flip' :: (a -> b -> c -> d) -> b -> c -> a -> d
flip' f x y z = f z x y
adminCommand :: Command -> Pipe -> IO Document
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails.
adminCommand cmd pipe =
@ -124,7 +122,7 @@ connect' timeoutSecs (Host hostname port) = do
type ReplicaSetName = Text
data TransportSecurity = Secure | Insecure
data TransportSecurity = Secure | Unsecure
-- | Maintains a connection (created on demand) to each server in the named replica set
data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs TransportSecurity
@ -139,7 +137,7 @@ openReplicaSet rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSet'
openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members.
openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Insecure)
openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Unsecure)
openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in the replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetTLS\'' instead.
@ -156,6 +154,38 @@ _openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do
_ <- updateMembers rs
return rs
openReplicaSetSRV :: HostName -> IO ReplicaSet
-- ^ Open non-secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV\'\'\'' instead.
openReplicaSetSRV hostname = do
timeoutSecs <- readIORef globalConnectTimeout
_openReplicaSetSRV timeoutSecs Unsecure hostname
openReplicaSetSRV' :: HostName -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV\'\'\'\'' instead.
openReplicaSetSRV' hostname = do
timeoutSecs <- readIORef globalConnectTimeout
_openReplicaSetSRV timeoutSecs Secure hostname
openReplicaSetSRV'' :: Secs -> HostName -> IO ReplicaSet
-- ^ Open non-secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetSRV'' timeoutSecs = _openReplicaSetSRV timeoutSecs Unsecure
openReplicaSetSRV''' :: Secs -> HostName -> IO ReplicaSet
-- ^ Open secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members.
openReplicaSetSRV''' timeoutSecs = _openReplicaSetSRV timeoutSecs Secure
_openReplicaSetSRV :: Secs -> TransportSecurity -> HostName -> IO ReplicaSet
_openReplicaSetSRV timeoutSecs transportSecurity hostname = do
replicaSetName <- lookupReplicaSetName hostname
hosts <- lookupSeedList hostname
case (replicaSetName, hosts) of
(Nothing, _) -> throwError $ userError "Failed to lookup replica set name"
(_, []) -> throwError $ userError "Failed to lookup replica set seedlist"
(Just rsName, _) ->
case transportSecurity of
Secure -> openReplicaSetTLS' timeoutSecs (rsName, hosts)
Unsecure -> openReplicaSet' timeoutSecs (rsName, hosts)
closeReplicaSet :: ReplicaSet -> IO ()
-- ^ Close all connections to replica set
closeReplicaSet (ReplicaSet _ vMembers _ _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd)
@ -229,7 +259,7 @@ connection (ReplicaSet _ vMembers timeoutSecs transportSecurity) mPipe host' =
let (Host h p) = host'
let conn' = case transportSecurity of
Secure -> TLS.connect h p
Insecure -> connect' timeoutSecs host'
Unsecure -> connect' timeoutSecs host'
let new = conn' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe)
case List.lookup host' members of
Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe)

View file

@ -1,7 +1,8 @@
-- | Compatibility layer for network package, including newtype 'PortID'
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP, GeneralizedNewtypeDeriving, OverloadedStrings #-}
module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo) where
module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo,
lookupReplicaSetName, lookupSeedList) where
#if !MIN_VERSION_network(2, 9, 0)
@ -18,6 +19,14 @@ import System.IO (Handle, IOMode(ReadWriteMode))
#endif
import Data.ByteString.Char8 (pack, unpack)
import Data.List (dropWhileEnd, lookup)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Network.DNS.Lookup (lookupSRV, lookupTXT)
import Network.DNS.Resolver (defaultResolvConf, makeResolvSeed, withResolver)
import Network.HTTP.Types.URI (parseQueryText)
-- | Wraps network's 'PortNumber'
-- Used to ease compatibility between older and newer network versions.
@ -54,3 +63,25 @@ connectTo hostname (PortNumber port) = do
-- * Host
data Host = Host N.HostName PortID deriving (Show, Eq, Ord)
lookupReplicaSetName :: N.HostName -> IO (Maybe Text)
-- ^ Retrieves the replica set name from the TXT DNS record for the given hostname
lookupReplicaSetName hostname = do
rs <- makeResolvSeed defaultResolvConf
res <- withResolver rs $ \resolver -> lookupTXT resolver (pack hostname)
case res of
Left _ -> pure Nothing
Right [] -> pure Nothing
Right (x:_) ->
pure $ fromMaybe (Nothing :: Maybe Text) (lookup "replicaSet" $ parseQueryText x)
lookupSeedList :: N.HostName -> IO [Host]
-- ^ Retrieves the replica set seed list from the SRV DNS record for the given hostname
lookupSeedList hostname = do
rs <- makeResolvSeed defaultResolvConf
res <- withResolver rs $ \resolver -> lookupSRV resolver $ "_mongodb._tcp." ++ pack hostname
case res of
Left _ -> pure []
Right srv -> pure $ map (\(_, _, por, tar) ->
let tar' = dropWhileEnd (=='.') (unpack tar)
in Host tar' (PortNumber . fromIntegral $ por)) srv

View file

@ -57,6 +57,8 @@ Library
, base16-bytestring >= 0.1.1.6
, base64-bytestring >= 1.0.0.1
, nonce >= 1.0.5
, dns
, http-types
if flag(_old-network)
-- "Network.BSD" is only available in network < 2.9