diff --git a/Database/MongoDB/Connection.hs b/Database/MongoDB/Connection.hs index 875603a..9408a5f 100644 --- a/Database/MongoDB/Connection.hs +++ b/Database/MongoDB/Connection.hs @@ -2,6 +2,12 @@ {-# LANGUAGE CPP, OverloadedStrings, ScopedTypeVariables, TupleSections #-} +#if (__GLASGOW_HASKELL__ >= 706) +{-# LANGUAGE RecursiveDo #-} +#else +{-# LANGUAGE DoRec #-} +#endif + module Database.MongoDB.Connection ( -- * Util Secs, @@ -46,7 +52,7 @@ import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed) import Database.MongoDB.Internal.Util (untilSuccess, liftIOE, updateAssocs, shuffle, mergesortM) import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access, - slaveOk, runCommand) + slaveOk, runCommand, retrieveServerData) adminCommand :: Command -> Pipe -> IO Document -- ^ Run command against admin database on server connected to pipe. Fail if connection fails. @@ -113,7 +119,10 @@ connect' :: Secs -> Host -> IO Pipe connect' timeoutSecs (Host hostname port) = do mh <- timeout (round $ timeoutSecs * 1000000) (connectTo hostname port) handle <- maybe (ioError $ userError "connect timed out") return mh - newPipe handle + rec + p <- newPipe sd handle + sd <- access p slaveOk "admin" retrieveServerData + return p -- * Replica Set diff --git a/Database/MongoDB/Internal/Protocol.hs b/Database/MongoDB/Internal/Protocol.hs index 06b54f6..2bf0916 100644 --- a/Database/MongoDB/Internal/Protocol.hs +++ b/Database/MongoDB/Internal/Protocol.hs @@ -28,7 +28,7 @@ module Database.MongoDB.Internal.Protocol ( Reply(..), ResponseFlag(..), -- * Authentication Username, Password, Nonce, pwHash, pwKey, - isClosed, close + isClosed, close, ServerData(..), Pipeline(..) ) where #if !MIN_VERSION_base(4,8,0) @@ -83,15 +83,22 @@ mkWeakMVar = addMVarFinalizer -- * Pipeline -- | Thread-safe and pipelined connection -data Pipeline = Pipeline { - vStream :: MVar Transport, -- ^ Mutex on handle, so only one thread at a time can write to it - responseQueue :: Chan (MVar (Either IOError Response)), -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response. - listenThread :: ThreadId +data Pipeline = Pipeline + { vStream :: MVar Transport -- ^ Mutex on handle, so only one thread at a time can write to it + , responseQueue :: Chan (MVar (Either IOError Response)) -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response. + , listenThread :: ThreadId + , serverData :: ServerData } +data ServerData = ServerData + { isMaster :: Bool + , minWireVersion :: Int + , maxWireVersion :: Int + } + -- | Create new Pipeline over given handle. You should 'close' pipeline when finished, which will also close handle. If pipeline is not closed but eventually garbage collected, it will be closed along with handle. -newPipeline :: Transport -> IO Pipeline -newPipeline stream = do +newPipeline :: ServerData -> Transport -> IO Pipeline +newPipeline serverData stream = do vStream <- newMVar stream responseQueue <- newChan rec @@ -150,13 +157,13 @@ pcall p@Pipeline{..} message = withMVar vStream doCall `onException` close p wh type Pipe = Pipeline -- ^ Thread-safe TCP connection with pipelined requests -newPipe :: Handle -> IO Pipe +newPipe :: ServerData -> Handle -> IO Pipe -- ^ Create pipe over handle -newPipe handle = T.fromHandle handle >>= newPipeWith +newPipe sd handle = T.fromHandle handle >>= (newPipeWith sd) -newPipeWith :: Transport -> IO Pipe +newPipeWith :: ServerData -> Transport -> IO Pipe -- ^ Create pipe over connection -newPipeWith conn = newPipeline conn +newPipeWith sd conn = newPipeline sd conn send :: Pipe -> [Notice] -> IO () -- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails. diff --git a/Database/MongoDB/Query.hs b/Database/MongoDB/Query.hs index ac6cff2..cc578d9 100644 --- a/Database/MongoDB/Query.hs +++ b/Database/MongoDB/Query.hs @@ -42,7 +42,7 @@ module Database.MongoDB.Query ( MRResult, mapReduce, runMR, runMR', -- * Command Command, runCommand, runCommand1, - eval, + eval, retrieveServerData ) where import Prelude hiding (lookup) @@ -84,7 +84,7 @@ import Database.MongoDB.Internal.Protocol (Reply(..), QueryOption(..), Request(GetMore, qOptions, qSkip, qFullCollection, qBatchSize, qSelector, qProjector), - pwKey) + pwKey, ServerData(..)) import Database.MongoDB.Internal.Util (loop, liftIOE, true1, (<.>)) import qualified Database.MongoDB.Internal.Protocol as P @@ -99,6 +99,7 @@ import qualified Crypto.MAC.HMAC as HMAC import Data.Bits (xor) import qualified Data.Map as Map import Text.Read (readMaybe) +import Data.Maybe (fromMaybe) #if !MIN_VERSION_base(4,6,0) --mkWeakMVar = addMVarFinalizer @@ -296,6 +297,16 @@ parseSCRAM :: B.ByteString -> Map.Map B.ByteString B.ByteString parseSCRAM = Map.fromList . fmap cleanup . (fmap $ T.breakOn "=") . T.splitOn "," . T.pack . B.unpack where cleanup (t1, t2) = (B.pack $ T.unpack t1, B.pack . T.unpack $ T.drop 1 t2) +retrieveServerData :: (MonadIO m) => Action m ServerData +retrieveServerData = do + d <- runCommand1 "isMaster" + let newSd = ServerData + { isMaster = (fromMaybe False $ lookup "ismaster" d) + , minWireVersion = (fromMaybe 0 $ lookup "minWireVersion" d) + , maxWireVersion = (fromMaybe 0 $ lookup "maxWireVersion" d) + } + return newSd + -- * Collection type Collection = Text diff --git a/Database/MongoDB/Transport/Tls.hs b/Database/MongoDB/Transport/Tls.hs index 2fd3c73..fa1fc0f 100644 --- a/Database/MongoDB/Transport/Tls.hs +++ b/Database/MongoDB/Transport/Tls.hs @@ -1,6 +1,13 @@ +{-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} +#if (__GLASGOW_HASKELL__ >= 706) +{-# LANGUAGE RecursiveDo #-} +#else +{-# LANGUAGE DoRec #-} +#endif + {-| Module : MongoDB TLS Description : TLS transport for mongodb @@ -36,6 +43,7 @@ import System.IO.Error (mkIOError, eofErrorType) import Network (connectTo, HostName, PortID) import qualified Network.TLS as TLS import qualified Network.TLS.Extra.Cipher as TLS +import Database.MongoDB.Query (access, slaveOk, retrieveServerData) -- | Connect to mongodb using TLS connect :: HostName -> PortID -> IO Pipe @@ -51,7 +59,10 @@ connect host port = bracketOnError (connectTo host port) hClose $ \handle -> do TLS.handshake context conn <- tlsConnection context - newPipeWith conn + rec + p <- newPipeWith sd conn + sd <- access p slaveOk "admin" retrieveServerData + return p tlsConnection :: TLS.Context -> IO Transport tlsConnection ctx = do