Retrieve server data on connection

This commit is contained in:
Victor Denisov 2016-05-19 21:44:42 -07:00
parent 76ac212708
commit 2ba71ca277
4 changed files with 54 additions and 16 deletions

View file

@ -2,6 +2,12 @@
{-# LANGUAGE CPP, OverloadedStrings, ScopedTypeVariables, TupleSections #-} {-# LANGUAGE CPP, OverloadedStrings, ScopedTypeVariables, TupleSections #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
module Database.MongoDB.Connection ( module Database.MongoDB.Connection (
-- * Util -- * Util
Secs, Secs,
@ -46,7 +52,7 @@ import Database.MongoDB.Internal.Protocol (Pipe, newPipe, close, isClosed)
import Database.MongoDB.Internal.Util (untilSuccess, liftIOE, import Database.MongoDB.Internal.Util (untilSuccess, liftIOE,
updateAssocs, shuffle, mergesortM) updateAssocs, shuffle, mergesortM)
import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access, import Database.MongoDB.Query (Command, Failure(ConnectionFailure), access,
slaveOk, runCommand) slaveOk, runCommand, retrieveServerData)
adminCommand :: Command -> Pipe -> IO Document adminCommand :: Command -> Pipe -> IO Document
-- ^ Run command against admin database on server connected to pipe. Fail if connection fails. -- ^ Run command against admin database on server connected to pipe. Fail if connection fails.
@ -113,7 +119,10 @@ connect' :: Secs -> Host -> IO Pipe
connect' timeoutSecs (Host hostname port) = do connect' timeoutSecs (Host hostname port) = do
mh <- timeout (round $ timeoutSecs * 1000000) (connectTo hostname port) mh <- timeout (round $ timeoutSecs * 1000000) (connectTo hostname port)
handle <- maybe (ioError $ userError "connect timed out") return mh handle <- maybe (ioError $ userError "connect timed out") return mh
newPipe handle rec
p <- newPipe sd handle
sd <- access p slaveOk "admin" retrieveServerData
return p
-- * Replica Set -- * Replica Set

View file

@ -28,7 +28,7 @@ module Database.MongoDB.Internal.Protocol (
Reply(..), ResponseFlag(..), Reply(..), ResponseFlag(..),
-- * Authentication -- * Authentication
Username, Password, Nonce, pwHash, pwKey, Username, Password, Nonce, pwHash, pwKey,
isClosed, close isClosed, close, ServerData(..), Pipeline(..)
) where ) where
#if !MIN_VERSION_base(4,8,0) #if !MIN_VERSION_base(4,8,0)
@ -83,15 +83,22 @@ mkWeakMVar = addMVarFinalizer
-- * Pipeline -- * Pipeline
-- | Thread-safe and pipelined connection -- | Thread-safe and pipelined connection
data Pipeline = Pipeline { data Pipeline = Pipeline
vStream :: MVar Transport, -- ^ Mutex on handle, so only one thread at a time can write to it { vStream :: MVar Transport -- ^ Mutex on handle, so only one thread at a time can write to it
responseQueue :: Chan (MVar (Either IOError Response)), -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response. , responseQueue :: Chan (MVar (Either IOError Response)) -- ^ Queue of threads waiting for responses. Every time a response arrive we pop the next thread and give it the response.
listenThread :: ThreadId , listenThread :: ThreadId
, serverData :: ServerData
} }
data ServerData = ServerData
{ isMaster :: Bool
, minWireVersion :: Int
, maxWireVersion :: Int
}
-- | Create new Pipeline over given handle. You should 'close' pipeline when finished, which will also close handle. If pipeline is not closed but eventually garbage collected, it will be closed along with handle. -- | Create new Pipeline over given handle. You should 'close' pipeline when finished, which will also close handle. If pipeline is not closed but eventually garbage collected, it will be closed along with handle.
newPipeline :: Transport -> IO Pipeline newPipeline :: ServerData -> Transport -> IO Pipeline
newPipeline stream = do newPipeline serverData stream = do
vStream <- newMVar stream vStream <- newMVar stream
responseQueue <- newChan responseQueue <- newChan
rec rec
@ -150,13 +157,13 @@ pcall p@Pipeline{..} message = withMVar vStream doCall `onException` close p wh
type Pipe = Pipeline type Pipe = Pipeline
-- ^ Thread-safe TCP connection with pipelined requests -- ^ Thread-safe TCP connection with pipelined requests
newPipe :: Handle -> IO Pipe newPipe :: ServerData -> Handle -> IO Pipe
-- ^ Create pipe over handle -- ^ Create pipe over handle
newPipe handle = T.fromHandle handle >>= newPipeWith newPipe sd handle = T.fromHandle handle >>= (newPipeWith sd)
newPipeWith :: Transport -> IO Pipe newPipeWith :: ServerData -> Transport -> IO Pipe
-- ^ Create pipe over connection -- ^ Create pipe over connection
newPipeWith conn = newPipeline conn newPipeWith sd conn = newPipeline sd conn
send :: Pipe -> [Notice] -> IO () send :: Pipe -> [Notice] -> IO ()
-- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails. -- ^ Send notices as a contiguous batch to server with no reply. Throw IOError if connection fails.

View file

@ -42,7 +42,7 @@ module Database.MongoDB.Query (
MRResult, mapReduce, runMR, runMR', MRResult, mapReduce, runMR, runMR',
-- * Command -- * Command
Command, runCommand, runCommand1, Command, runCommand, runCommand1,
eval, eval, retrieveServerData
) where ) where
import Prelude hiding (lookup) import Prelude hiding (lookup)
@ -84,7 +84,7 @@ import Database.MongoDB.Internal.Protocol (Reply(..), QueryOption(..),
Request(GetMore, qOptions, qSkip, Request(GetMore, qOptions, qSkip,
qFullCollection, qBatchSize, qFullCollection, qBatchSize,
qSelector, qProjector), qSelector, qProjector),
pwKey) pwKey, ServerData(..))
import Database.MongoDB.Internal.Util (loop, liftIOE, true1, (<.>)) import Database.MongoDB.Internal.Util (loop, liftIOE, true1, (<.>))
import qualified Database.MongoDB.Internal.Protocol as P import qualified Database.MongoDB.Internal.Protocol as P
@ -99,6 +99,7 @@ import qualified Crypto.MAC.HMAC as HMAC
import Data.Bits (xor) import Data.Bits (xor)
import qualified Data.Map as Map import qualified Data.Map as Map
import Text.Read (readMaybe) import Text.Read (readMaybe)
import Data.Maybe (fromMaybe)
#if !MIN_VERSION_base(4,6,0) #if !MIN_VERSION_base(4,6,0)
--mkWeakMVar = addMVarFinalizer --mkWeakMVar = addMVarFinalizer
@ -296,6 +297,16 @@ parseSCRAM :: B.ByteString -> Map.Map B.ByteString B.ByteString
parseSCRAM = Map.fromList . fmap cleanup . (fmap $ T.breakOn "=") . T.splitOn "," . T.pack . B.unpack parseSCRAM = Map.fromList . fmap cleanup . (fmap $ T.breakOn "=") . T.splitOn "," . T.pack . B.unpack
where cleanup (t1, t2) = (B.pack $ T.unpack t1, B.pack . T.unpack $ T.drop 1 t2) where cleanup (t1, t2) = (B.pack $ T.unpack t1, B.pack . T.unpack $ T.drop 1 t2)
retrieveServerData :: (MonadIO m) => Action m ServerData
retrieveServerData = do
d <- runCommand1 "isMaster"
let newSd = ServerData
{ isMaster = (fromMaybe False $ lookup "ismaster" d)
, minWireVersion = (fromMaybe 0 $ lookup "minWireVersion" d)
, maxWireVersion = (fromMaybe 0 $ lookup "maxWireVersion" d)
}
return newSd
-- * Collection -- * Collection
type Collection = Text type Collection = Text

View file

@ -1,6 +1,13 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE RecordWildCards #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
{-| {-|
Module : MongoDB TLS Module : MongoDB TLS
Description : TLS transport for mongodb Description : TLS transport for mongodb
@ -36,6 +43,7 @@ import System.IO.Error (mkIOError, eofErrorType)
import Network (connectTo, HostName, PortID) import Network (connectTo, HostName, PortID)
import qualified Network.TLS as TLS import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS import qualified Network.TLS.Extra.Cipher as TLS
import Database.MongoDB.Query (access, slaveOk, retrieveServerData)
-- | Connect to mongodb using TLS -- | Connect to mongodb using TLS
connect :: HostName -> PortID -> IO Pipe connect :: HostName -> PortID -> IO Pipe
@ -51,7 +59,10 @@ connect host port = bracketOnError (connectTo host port) hClose $ \handle -> do
TLS.handshake context TLS.handshake context
conn <- tlsConnection context conn <- tlsConnection context
newPipeWith conn rec
p <- newPipeWith sd conn
sd <- access p slaveOk "admin" retrieveServerData
return p
tlsConnection :: TLS.Context -> IO Transport tlsConnection :: TLS.Context -> IO Transport
tlsConnection ctx = do tlsConnection ctx = do