From d334d889ee64630ebf32fc75f22d6f4aa52603fb Mon Sep 17 00:00:00 2001 From: Neil Cowburn Date: Fri, 1 Nov 2019 16:55:59 +0000 Subject: [PATCH 1/3] Add support for opening replica sets over TLS --- Database/MongoDB/Connection.hs | 49 +++++++++++++++++++--------- Database/MongoDB/Internal/Network.hs | 6 +++- Database/MongoDB/Transport/Tls.hs | 3 +- 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/Database/MongoDB/Connection.hs b/Database/MongoDB/Connection.hs index 5a5e110..de77cad 100644 --- a/Database/MongoDB/Connection.hs +++ b/Database/MongoDB/Connection.hs @@ -17,7 +17,7 @@ module Database.MongoDB.Connection ( Host(..), PortID(..), defaultPort, host, showHostPort, readHostPort, readHostPortM, globalConnectTimeout, connect, connect', -- * Replica Set - ReplicaSetName, openReplicaSet, openReplicaSet', + ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS', ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName ) where @@ -47,12 +47,16 @@ import Data.Text (Text) import qualified Data.Bson as B import qualified Data.Text as T -import Database.MongoDB.Internal.Network (HostName, PortID(..), connectTo) +import Database.MongoDB.Internal.Network (Host(..), HostName, PortID(..), connectTo, lookupSeedList, lookupReplicaSetName) import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed) import Database.MongoDB.Internal.Util (untilSuccess, liftIOE, updateAssocs, shuffle, mergesortM) import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access, slaveOk, runCommand, retrieveServerData) +import qualified Database.MongoDB.Transport.Tls as TLS (connect) + +flip' :: (a -> b -> c -> d) -> b -> c -> a -> d +flip' f x y z = f z x y adminCommand :: Command -> Pipe -> IO Document -- ^ Run command against admin database on server connected to pipe. Fail if connection fails. @@ -62,10 +66,6 @@ adminCommand cmd pipe = failureToIOError (ConnectionFailure e) = e failureToIOError e = userError $ show e --- * Host - -data Host = Host HostName PortID deriving (Show, Eq, Ord) - defaultPort :: PortID -- ^ Default MongoDB port = 27017 defaultPort = PortNumber 27017 @@ -124,12 +124,14 @@ connect' timeoutSecs (Host hostname port) = do type ReplicaSetName = Text +data TransportSecurity = Secure | Insecure + -- | Maintains a connection (created on demand) to each server in the named replica set -data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs +data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs TransportSecurity replSetName :: ReplicaSet -> Text -- ^ name of connected replica set -replSetName (ReplicaSet rsName _ _) = rsName +replSetName (ReplicaSet rsName _ _ _) = rsName openReplicaSet :: (ReplicaSetName, [Host]) -> IO ReplicaSet -- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSet\'' instead. @@ -137,19 +139,30 @@ openReplicaSet rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSet' openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet -- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members. -openReplicaSet' timeoutSecs (rsName, seedList) = do +openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Insecure) + +openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet +-- ^ Open secure connections (on demand) to servers in the replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetTLS\'' instead. +openReplicaSetTLS rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSetTLS' rsSeed + +openReplicaSetTLS' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet +-- ^ Open secure connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members. +openReplicaSetTLS' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Secure) + +_openReplicaSet :: Secs -> (ReplicaSetName, [Host], TransportSecurity) -> IO ReplicaSet +_openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do vMembers <- newMVar (map (, Nothing) seedList) - let rs = ReplicaSet rsName vMembers timeoutSecs + let rs = ReplicaSet rsName vMembers timeoutSecs transportSecurity _ <- updateMembers rs return rs closeReplicaSet :: ReplicaSet -> IO () -- ^ Close all connections to replica set -closeReplicaSet (ReplicaSet _ vMembers _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd) +closeReplicaSet (ReplicaSet _ vMembers _ _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd) primary :: ReplicaSet -> IO Pipe -- ^ Return connection to current primary of replica set. Fail if no primary available. -primary rs@(ReplicaSet rsName _ _) = do +primary rs@(ReplicaSet rsName _ _ _) = do mHost <- statedPrimary <$> updateMembers rs case mHost of Just host' -> connection rs Nothing host' @@ -185,7 +198,7 @@ possibleHosts (_, info) = map readHostPort $ at "hosts" info updateMembers :: ReplicaSet -> IO ReplicaInfo -- ^ Fetch replica info from any server and update members accordingly -updateMembers rs@(ReplicaSet _ vMembers _) = do +updateMembers rs@(ReplicaSet _ vMembers _ _) = do (host', info) <- untilSuccess (fetchReplicaInfo rs) =<< readMVar vMembers modifyMVar vMembers $ \members -> do let ((members', old), new) = intersection (map readHostPort $ at "hosts" info) members @@ -199,7 +212,7 @@ updateMembers rs@(ReplicaSet _ vMembers _) = do fetchReplicaInfo :: ReplicaSet -> (Host, Maybe Pipe) -> IO ReplicaInfo -- Connect to host and fetch replica info from host creating new connection if missing or closed (previously failed). Fail if not member of named replica set. -fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do +fetchReplicaInfo rs@(ReplicaSet rsName _ _ _) (host', mPipe) = do pipe <- connection rs mPipe host' info <- adminCommand ["isMaster" =: (1 :: Int)] pipe case B.lookup "setName" info of @@ -209,11 +222,15 @@ fetchReplicaInfo rs@(ReplicaSet rsName _ _) (host', mPipe) = do connection :: ReplicaSet -> Maybe Pipe -> Host -> IO Pipe -- ^ Return new or existing connection to member of replica set. If pipe is already known for host it is given, but we still test if it is open. -connection (ReplicaSet _ vMembers timeoutSecs) mPipe host' = +connection (ReplicaSet _ vMembers timeoutSecs transportSecurity) mPipe host' = maybe conn (\p -> isClosed p >>= \bad -> if bad then conn else return p) mPipe where conn = modifyMVar vMembers $ \members -> do - let new = connect' timeoutSecs host' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe) + let (Host h p) = host' + let conn' = case transportSecurity of + Secure -> TLS.connect h p + Insecure -> connect' timeoutSecs host' + let new = conn' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe) case List.lookup host' members of Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe) _ -> new diff --git a/Database/MongoDB/Internal/Network.hs b/Database/MongoDB/Internal/Network.hs index ae94830..3fb9a86 100644 --- a/Database/MongoDB/Internal/Network.hs +++ b/Database/MongoDB/Internal/Network.hs @@ -1,7 +1,7 @@ -- | Compatibility layer for network package, including newtype 'PortID' {-# LANGUAGE CPP, GeneralizedNewtypeDeriving #-} -module Database.MongoDB.Internal.Network (PortID(..), N.HostName, connectTo) where +module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo) where #if !MIN_VERSION_network(2, 9, 0) @@ -50,3 +50,7 @@ connectTo hostname (PortNumber port) = do N.socketToHandle sock ReadWriteMode ) #endif + +-- * Host + +data Host = Host N.HostName PortID deriving (Show, Eq, Ord) diff --git a/Database/MongoDB/Transport/Tls.hs b/Database/MongoDB/Transport/Tls.hs index 696be93..6915d1f 100644 --- a/Database/MongoDB/Transport/Tls.hs +++ b/Database/MongoDB/Transport/Tls.hs @@ -34,8 +34,7 @@ import Control.Applicative ((<$>)) import Control.Exception (bracketOnError) import Control.Monad (when, unless) import System.IO -import Database.MongoDB (Pipe) -import Database.MongoDB.Internal.Protocol (newPipeWith) +import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith) import Database.MongoDB.Transport (Transport(Transport)) import qualified Database.MongoDB.Transport as T import System.IO.Error (mkIOError, eofErrorType) From bcfbcc29185c1a95cf6e5bed5cb8e0552d44887b Mon Sep 17 00:00:00 2001 From: Neil Cowburn Date: Fri, 1 Nov 2019 17:00:06 +0000 Subject: [PATCH 2/3] Add support for opening replica sets using v3.6-style connection strings --- Database/MongoDB/Connection.hs | 42 ++++++++++++++++++++++++---- Database/MongoDB/Internal/Network.hs | 35 +++++++++++++++++++++-- mongoDB.cabal | 2 ++ 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/Database/MongoDB/Connection.hs b/Database/MongoDB/Connection.hs index de77cad..dc42a37 100644 --- a/Database/MongoDB/Connection.hs +++ b/Database/MongoDB/Connection.hs @@ -18,6 +18,7 @@ module Database.MongoDB.Connection ( readHostPortM, globalConnectTimeout, connect, connect', -- * Replica Set ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS', + openReplicaSetSRV, openReplicaSetSRV', openReplicaSetSRV'', openReplicaSetSRV''', ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName ) where @@ -55,9 +56,6 @@ import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access, slaveOk, runCommand, retrieveServerData) import qualified Database.MongoDB.Transport.Tls as TLS (connect) -flip' :: (a -> b -> c -> d) -> b -> c -> a -> d -flip' f x y z = f z x y - adminCommand :: Command -> Pipe -> IO Document -- ^ Run command against admin database on server connected to pipe. Fail if connection fails. adminCommand cmd pipe = @@ -124,7 +122,7 @@ connect' timeoutSecs (Host hostname port) = do type ReplicaSetName = Text -data TransportSecurity = Secure | Insecure +data TransportSecurity = Secure | Unsecure -- | Maintains a connection (created on demand) to each server in the named replica set data ReplicaSet = ReplicaSet ReplicaSetName (MVar [(Host, Maybe Pipe)]) Secs TransportSecurity @@ -139,7 +137,7 @@ openReplicaSet rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSet' openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet -- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members. -openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Insecure) +openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Unsecure) openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet -- ^ Open secure connections (on demand) to servers in the replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetTLS\'' instead. @@ -156,6 +154,38 @@ _openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do _ <- updateMembers rs return rs +openReplicaSetSRV :: HostName -> IO ReplicaSet +-- ^ Open non-secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV\'\'\'' instead. +openReplicaSetSRV hostname = do + timeoutSecs <- readIORef globalConnectTimeout + _openReplicaSetSRV timeoutSecs Unsecure hostname + +openReplicaSetSRV' :: HostName -> IO ReplicaSet +-- ^ Open secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV\'\'\'\'' instead. +openReplicaSetSRV' hostname = do + timeoutSecs <- readIORef globalConnectTimeout + _openReplicaSetSRV timeoutSecs Secure hostname + +openReplicaSetSRV'' :: Secs -> HostName -> IO ReplicaSet +-- ^ Open non-secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members. +openReplicaSetSRV'' timeoutSecs = _openReplicaSetSRV timeoutSecs Unsecure + +openReplicaSetSRV''' :: Secs -> HostName -> IO ReplicaSet +-- ^ Open secure connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members. +openReplicaSetSRV''' timeoutSecs = _openReplicaSetSRV timeoutSecs Secure + +_openReplicaSetSRV :: Secs -> TransportSecurity -> HostName -> IO ReplicaSet +_openReplicaSetSRV timeoutSecs transportSecurity hostname = do + replicaSetName <- lookupReplicaSetName hostname + hosts <- lookupSeedList hostname + case (replicaSetName, hosts) of + (Nothing, _) -> throwError $ userError "Failed to lookup replica set name" + (_, []) -> throwError $ userError "Failed to lookup replica set seedlist" + (Just rsName, _) -> + case transportSecurity of + Secure -> openReplicaSetTLS' timeoutSecs (rsName, hosts) + Unsecure -> openReplicaSet' timeoutSecs (rsName, hosts) + closeReplicaSet :: ReplicaSet -> IO () -- ^ Close all connections to replica set closeReplicaSet (ReplicaSet _ vMembers _ _) = withMVar vMembers $ mapM_ (maybe (return ()) close . snd) @@ -229,7 +259,7 @@ connection (ReplicaSet _ vMembers timeoutSecs transportSecurity) mPipe host' = let (Host h p) = host' let conn' = case transportSecurity of Secure -> TLS.connect h p - Insecure -> connect' timeoutSecs host' + Unsecure -> connect' timeoutSecs host' let new = conn' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe) case List.lookup host' members of Just (Just pipe) -> isClosed pipe >>= \bad -> if bad then new else return (members, pipe) diff --git a/Database/MongoDB/Internal/Network.hs b/Database/MongoDB/Internal/Network.hs index 3fb9a86..802b1a7 100644 --- a/Database/MongoDB/Internal/Network.hs +++ b/Database/MongoDB/Internal/Network.hs @@ -1,7 +1,8 @@ -- | Compatibility layer for network package, including newtype 'PortID' -{-# LANGUAGE CPP, GeneralizedNewtypeDeriving #-} +{-# LANGUAGE CPP, GeneralizedNewtypeDeriving, OverloadedStrings #-} -module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo) where +module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo, + lookupReplicaSetName, lookupSeedList) where #if !MIN_VERSION_network(2, 9, 0) @@ -18,6 +19,14 @@ import System.IO (Handle, IOMode(ReadWriteMode)) #endif +import Data.ByteString.Char8 (pack, unpack) +import Data.List (dropWhileEnd, lookup) +import Data.Maybe (fromMaybe) +import Data.Text (Text) +import Network.DNS.Lookup (lookupSRV, lookupTXT) +import Network.DNS.Resolver (defaultResolvConf, makeResolvSeed, withResolver) +import Network.HTTP.Types.URI (parseQueryText) + -- | Wraps network's 'PortNumber' -- Used to ease compatibility between older and newer network versions. @@ -54,3 +63,25 @@ connectTo hostname (PortNumber port) = do -- * Host data Host = Host N.HostName PortID deriving (Show, Eq, Ord) + +lookupReplicaSetName :: N.HostName -> IO (Maybe Text) +-- ^ Retrieves the replica set name from the TXT DNS record for the given hostname +lookupReplicaSetName hostname = do + rs <- makeResolvSeed defaultResolvConf + res <- withResolver rs $ \resolver -> lookupTXT resolver (pack hostname) + case res of + Left _ -> pure Nothing + Right [] -> pure Nothing + Right (x:_) -> + pure $ fromMaybe (Nothing :: Maybe Text) (lookup "replicaSet" $ parseQueryText x) + +lookupSeedList :: N.HostName -> IO [Host] +-- ^ Retrieves the replica set seed list from the SRV DNS record for the given hostname +lookupSeedList hostname = do + rs <- makeResolvSeed defaultResolvConf + res <- withResolver rs $ \resolver -> lookupSRV resolver $ "_mongodb._tcp." ++ pack hostname + case res of + Left _ -> pure [] + Right srv -> pure $ map (\(_, _, por, tar) -> + let tar' = dropWhileEnd (=='.') (unpack tar) + in Host tar' (PortNumber . fromIntegral $ por)) srv \ No newline at end of file diff --git a/mongoDB.cabal b/mongoDB.cabal index 6160aeb..29e6b39 100644 --- a/mongoDB.cabal +++ b/mongoDB.cabal @@ -57,6 +57,8 @@ Library , base16-bytestring >= 0.1.1.6 , base64-bytestring >= 1.0.0.1 , nonce >= 1.0.5 + , dns + , http-types if flag(_old-network) -- "Network.BSD" is only available in network < 2.9 From 30ef4e15705eb9f539aee81417f67ce5a2825af5 Mon Sep 17 00:00:00 2001 From: Victor Denisov Date: Wed, 1 Jan 2020 17:38:13 -0800 Subject: [PATCH 3/3] Fix compilation error --- Database/MongoDB/Internal/Network.hs | 4 ++-- mongoDB.cabal | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/Database/MongoDB/Internal/Network.hs b/Database/MongoDB/Internal/Network.hs index 802b1a7..7f9084e 100644 --- a/Database/MongoDB/Internal/Network.hs +++ b/Database/MongoDB/Internal/Network.hs @@ -79,9 +79,9 @@ lookupSeedList :: N.HostName -> IO [Host] -- ^ Retrieves the replica set seed list from the SRV DNS record for the given hostname lookupSeedList hostname = do rs <- makeResolvSeed defaultResolvConf - res <- withResolver rs $ \resolver -> lookupSRV resolver $ "_mongodb._tcp." ++ pack hostname + res <- withResolver rs $ \resolver -> lookupSRV resolver $ pack $ "_mongodb._tcp." ++ hostname case res of Left _ -> pure [] Right srv -> pure $ map (\(_, _, por, tar) -> let tar' = dropWhileEnd (=='.') (unpack tar) - in Host tar' (PortNumber . fromIntegral $ por)) srv \ No newline at end of file + in Host tar' (PortNumber . fromIntegral $ por)) srv diff --git a/mongoDB.cabal b/mongoDB.cabal index 29e6b39..8a5717e 100644 --- a/mongoDB.cabal +++ b/mongoDB.cabal @@ -129,7 +129,11 @@ Benchmark bench , lifted-base >= 0.1.0.3 , transformers-base >= 0.4.1 , hashtables >= 1.1.2.0 + , fail + , dns + , http-types , criterion + , tls >= 1.3.0 if flag(_old-network) -- "Network.BSD" is only available in network < 2.9