diff --git a/.gitignore b/.gitignore index 0d511c2..74deb86 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ dist/ cabal.sandbox.config .cabal-sandbox/ .stack-work/ +dist-newstyle/* +!dist-newstyle/config \ No newline at end of file diff --git a/Database/MongoDB/Admin.hs b/Database/MongoDB/Admin.hs index 1afbba7..da9f197 100644 --- a/Database/MongoDB/Admin.hs +++ b/Database/MongoDB/Admin.hs @@ -33,7 +33,6 @@ import Control.Applicative ((<$>)) #endif import Control.Concurrent (forkIO, threadDelay) import Control.Monad (forever, unless, liftM) -import Control.Monad.Fail(MonadFail) import Data.IORef (IORef, newIORef, readIORef, writeIORef) import Data.Maybe (maybeToList) import Data.Set (Set) diff --git a/Database/MongoDB/Connection.hs b/Database/MongoDB/Connection.hs index 4b88cd6..0d386de 100644 --- a/Database/MongoDB/Connection.hs +++ b/Database/MongoDB/Connection.hs @@ -17,8 +17,8 @@ module Database.MongoDB.Connection ( Host(..), PortID(..), defaultPort, host, showHostPort, readHostPort, readHostPortM, globalConnectTimeout, connect, connect', -- * Replica Set - ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS', - openReplicaSetSRV, openReplicaSetSRV', openReplicaSetSRV'', openReplicaSetSRV''', + ReplicaSetName, openReplicaSet, openReplicaSet', openReplicaSetTLS, openReplicaSetTLS', + openReplicaSetSRV, openReplicaSetSRV', openReplicaSetSRV'', openReplicaSetSRV''', ReplicaSet, primary, secondaryOk, routedHost, closeReplicaSet, replSetName ) where @@ -32,7 +32,6 @@ import Control.Applicative ((<$>)) #endif import Control.Monad (forM_, guard) -import Control.Monad.Fail(MonadFail) import System.IO.Unsafe (unsafePerformIO) import System.Timeout (timeout) import Text.ParserCombinators.Parsec (parse, many1, letter, digit, char, anyChar, eof, @@ -40,7 +39,6 @@ import Text.ParserCombinators.Parsec (parse, many1, letter, digit, char, anyChar import qualified Data.List as List -import Control.Monad.Identity (runIdentity) import Control.Monad.Except (throwError) import Control.Concurrent.MVar.Lifted (MVar, newMVar, withMVar, modifyMVar, readMVar) @@ -149,28 +147,28 @@ openReplicaSet' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet -- ^ Open connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members. openReplicaSet' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Unsecure) -openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet +openReplicaSetTLS :: (ReplicaSetName, [Host]) -> IO ReplicaSet -- ^ Open secure connections (on demand) to servers in the replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetTLS'' instead. openReplicaSetTLS rsSeed = readIORef globalConnectTimeout >>= flip openReplicaSetTLS' rsSeed -openReplicaSetTLS' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet +openReplicaSetTLS' :: Secs -> (ReplicaSetName, [Host]) -> IO ReplicaSet -- ^ Open secure connections (on demand) to servers in replica set. Supplied hosts is seed list. At least one of them must be a live member of the named replica set, otherwise fail. Supplied seconds timeout is used for connect attempts to members. openReplicaSetTLS' timeoutSecs (rs, hosts) = _openReplicaSet timeoutSecs (rs, hosts, Secure) _openReplicaSet :: Secs -> (ReplicaSetName, [Host], TransportSecurity) -> IO ReplicaSet -_openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do +_openReplicaSet timeoutSecs (rsName, seedList, transportSecurity) = do vMembers <- newMVar (map (, Nothing) seedList) let rs = ReplicaSet rsName vMembers timeoutSecs transportSecurity _ <- updateMembers rs return rs -openReplicaSetSRV :: HostName -> IO ReplicaSet +openReplicaSetSRV :: HostName -> IO ReplicaSet -- ^ Open /non-secure/ connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV''' instead. -openReplicaSetSRV hostname = do +openReplicaSetSRV hostname = do timeoutSecs <- readIORef globalConnectTimeout _openReplicaSetSRV timeoutSecs Unsecure hostname -openReplicaSetSRV' :: HostName -> IO ReplicaSet +openReplicaSetSRV' :: HostName -> IO ReplicaSet -- ^ Open /secure/ connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. The value of 'globalConnectTimeout' at the time of this call is the timeout used for future member connect attempts. To use your own value call 'openReplicaSetSRV'''' instead. -- -- The preferred connection method for cloud MongoDB providers. A typical connecting sequence is shown in the example below. @@ -180,27 +178,27 @@ openReplicaSetSRV' :: HostName -> IO ReplicaSet -- > pipe <- openReplicatSetSRV' "cluster#.xxxxx.yyyyy.zzz" -- > is_auth <- access pipe master "admin" $ auth user_name password -- > unless (not is_auth) (throwIO $ userError "Authentication failed!") -openReplicaSetSRV' hostname = do +openReplicaSetSRV' hostname = do timeoutSecs <- readIORef globalConnectTimeout _openReplicaSetSRV timeoutSecs Secure hostname -openReplicaSetSRV'' :: Secs -> HostName -> IO ReplicaSet +openReplicaSetSRV'' :: Secs -> HostName -> IO ReplicaSet -- ^ Open /non-secure/ connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members. openReplicaSetSRV'' timeoutSecs = _openReplicaSetSRV timeoutSecs Unsecure -openReplicaSetSRV''' :: Secs -> HostName -> IO ReplicaSet +openReplicaSetSRV''' :: Secs -> HostName -> IO ReplicaSet -- ^ Open /secure/ connections (on demand) to servers in a replica set. The seedlist and replica set name is fetched from the SRV and TXT DNS records for the given hostname. Supplied seconds timeout is used for connect attempts to members. openReplicaSetSRV''' timeoutSecs = _openReplicaSetSRV timeoutSecs Secure -_openReplicaSetSRV :: Secs -> TransportSecurity -> HostName -> IO ReplicaSet -_openReplicaSetSRV timeoutSecs transportSecurity hostname = do - replicaSetName <- lookupReplicaSetName hostname - hosts <- lookupSeedList hostname - case (replicaSetName, hosts) of +_openReplicaSetSRV :: Secs -> TransportSecurity -> HostName -> IO ReplicaSet +_openReplicaSetSRV timeoutSecs transportSecurity hostname = do + replicaSetName <- lookupReplicaSetName hostname + hosts <- lookupSeedList hostname + case (replicaSetName, hosts) of (Nothing, _) -> throwError $ userError "Failed to lookup replica set name" (_, []) -> throwError $ userError "Failed to lookup replica set seedlist" - (Just rsName, _) -> - case transportSecurity of + (Just rsName, _) -> + case transportSecurity of Secure -> openReplicaSetTLS' timeoutSecs (rsName, hosts) Unsecure -> openReplicaSet' timeoutSecs (rsName, hosts) @@ -229,7 +227,7 @@ routedHost :: ((Host, Bool) -> (Host, Bool) -> IO Ordering) -> ReplicaSet -> IO routedHost f rs = do info <- updateMembers rs hosts <- shuffle (possibleHosts info) - let addIsPrimary h = (h, if Just h == statedPrimary info then True else False) + let addIsPrimary h = (h, Just h == statedPrimary info) hosts' <- mergesortM (\a b -> f (addIsPrimary a) (addIsPrimary b)) hosts untilSuccess (connection rs Nothing) hosts' @@ -275,8 +273,8 @@ connection (ReplicaSet _ vMembers timeoutSecs transportSecurity) mPipe host' = where conn = modifyMVar vMembers $ \members -> do let (Host h p) = host' - let conn' = case transportSecurity of - Secure -> TLS.connect h p + let conn' = case transportSecurity of + Secure -> TLS.connect h p Unsecure -> connect' timeoutSecs host' let new = conn' >>= \pipe -> return (updateAssocs host' (Just pipe) members, pipe) case List.lookup host' members of diff --git a/Database/MongoDB/GridFS.hs b/Database/MongoDB/GridFS.hs index 5515ace..a78bb6c 100644 --- a/Database/MongoDB/GridFS.hs +++ b/Database/MongoDB/GridFS.hs @@ -1,7 +1,7 @@ -- Author: -- Brent Tubbs -- | MongoDB GridFS implementation -{-# LANGUAGE OverloadedStrings, RecordWildCards, NamedFieldPuns, TupleSections, FlexibleContexts, FlexibleInstances, UndecidableInstances, MultiParamTypeClasses, GeneralizedNewtypeDeriving, StandaloneDeriving, TypeSynonymInstances, TypeFamilies, CPP, RankNTypes #-} +{-# LANGUAGE OverloadedStrings, FlexibleContexts, FlexibleInstances, UndecidableInstances, MultiParamTypeClasses, TypeFamilies, CPP, RankNTypes #-} module Database.MongoDB.GridFS ( Bucket @@ -23,10 +23,8 @@ module Database.MongoDB.GridFS ) where -import Control.Applicative((<$>)) import Control.Monad(when) -import Control.Monad.Fail(MonadFail) import Control.Monad.IO.Class import Control.Monad.Trans(lift) @@ -64,7 +62,7 @@ openBucket :: (Monad m, MonadIO m) => Text -> Action m Bucket openBucket name = do let filesCollection = name `append` ".files" let chunksCollection = name `append` ".chunks" - ensureIndex $ (index filesCollection ["filename" =: (1::Int), "uploadDate" =: (1::Int)]) + ensureIndex $ index filesCollection ["filename" =: (1::Int), "uploadDate" =: (1::Int)] ensureIndex $ (index chunksCollection ["files_id" =: (1::Int), "n" =: (1::Int)]) { iUnique = True, iDropDups = True } return $ Bucket filesCollection chunksCollection @@ -72,9 +70,9 @@ data File = File {bucket :: Bucket, document :: Document} getChunk :: (MonadFail m, MonadIO m) => File -> Int -> Action m (Maybe S.ByteString) -- ^ Get a chunk of a file -getChunk (File bucket doc) i = do +getChunk (File _bucket doc) i = do files_id <- B.look "_id" doc - result <- findOne $ select ["files_id" := files_id, "n" =: i] $ chunks bucket + result <- findOne $ select ["files_id" := files_id, "n" =: i] $ chunks _bucket let content = at "data" <$> result case content of Just (Binary b) -> return (Just b) @@ -82,36 +80,36 @@ getChunk (File bucket doc) i = do findFile :: MonadIO m => Bucket -> Selector -> Action m [File] -- ^ Find files in the bucket -findFile bucket sel = do - cursor <- find $ select sel $ files bucket +findFile _bucket sel = do + cursor <- find $ select sel $ files _bucket results <- rest cursor - return $ File bucket <$> results + return $ File _bucket <$> results findOneFile :: MonadIO m => Bucket -> Selector -> Action m (Maybe File) -- ^ Find one file in the bucket -findOneFile bucket sel = do - mdoc <- findOne $ select sel $ files bucket - return $ File bucket <$> mdoc +findOneFile _bucket sel = do + mdoc <- findOne $ select sel $ files _bucket + return $ File _bucket <$> mdoc fetchFile :: MonadIO m => Bucket -> Selector -> Action m File -- ^ Fetch one file in the bucket -fetchFile bucket sel = do - doc <- fetch $ select sel $ files bucket - return $ File bucket doc +fetchFile _bucket sel = do + doc <- fetch $ select sel $ files _bucket + return $ File _bucket doc deleteFile :: (MonadIO m, MonadFail m) => File -> Action m () -- ^ Delete files in the bucket -deleteFile (File bucket doc) = do +deleteFile (File _bucket doc) = do files_id <- B.look "_id" doc - delete $ select ["_id" := files_id] $ files bucket - delete $ select ["files_id" := files_id] $ chunks bucket + delete $ select ["_id" := files_id] $ files _bucket + delete $ select ["files_id" := files_id] $ chunks _bucket putChunk :: (Monad m, MonadIO m) => Bucket -> ObjectId -> Int -> L.ByteString -> Action m () -- ^ Put a chunk in the bucket -putChunk bucket files_id i chunk = do - insert_ (chunks bucket) ["files_id" =: files_id, "n" =: i, "data" =: Binary (L.toStrict chunk)] +putChunk _bucket files_id i chunk = do + insert_ (chunks _bucket) ["files_id" =: files_id, "n" =: i, "data" =: Binary (L.toStrict chunk)] -sourceFile :: (MonadFail m, MonadIO m) => File -> Producer (Action m) S.ByteString +sourceFile :: (MonadFail m, MonadIO m) => File -> ConduitT File S.ByteString (Action m) () -- ^ A producer for the contents of a file sourceFile file = yieldChunk 0 where yieldChunk i = do @@ -134,19 +132,19 @@ data FileWriter = FileWriter -- Finalize file, calculating md5 digest, saving the last chunk, and creating the file in the bucket finalizeFile :: (Monad m, MonadIO m) => Text -> FileWriter -> Action m File -finalizeFile filename (FileWriter chunkSize bucket files_id i size acc md5context md5acc) = do +finalizeFile filename (FileWriter chunkSize _bucket files_id i size acc md5context md5acc) = do let md5digest = finalizeMD5 md5context (L.toStrict md5acc) - when (L.length acc > 0) $ putChunk bucket files_id i acc - currentTimestamp <- liftIO $ getCurrentTime + when (L.length acc > 0) $ putChunk _bucket files_id i acc + currentTimestamp <- liftIO getCurrentTime let doc = [ "_id" =: files_id , "length" =: size , "uploadDate" =: currentTimestamp - , "md5" =: show (md5digest) + , "md5" =: show md5digest , "chunkSize" =: chunkSize , "filename" =: filename ] - insert_ (files bucket) doc - return $ File bucket doc + insert_ (files _bucket) doc + return $ File _bucket doc -- finalize the remainder and return the MD5Digest. finalizeMD5 :: MD5Context -> S.ByteString -> MD5Digest @@ -160,11 +158,11 @@ finalizeMD5 ctx remainder = -- Write as many chunks as can be written from the file writer writeChunks :: (Monad m, MonadIO m) => FileWriter -> L.ByteString -> Action m FileWriter -writeChunks (FileWriter chunkSize bucket files_id i size acc md5context md5acc) chunk = do +writeChunks (FileWriter chunkSize _bucket files_id i size acc md5context md5acc) chunk = do -- Update md5 context let md5BlockLength = fromIntegral $ untag (blockLength :: Tagged MD5Digest Int) let md5acc_temp = (md5acc `L.append` chunk) - let (md5context', md5acc') = + let (md5context', md5acc') = if (L.length md5acc_temp < md5BlockLength) then (md5context, md5acc_temp) else let numBlocks = L.length md5acc_temp `div` md5BlockLength @@ -174,17 +172,17 @@ writeChunks (FileWriter chunkSize bucket files_id i size acc md5context md5acc) let size' = (size + L.length chunk) let acc_temp = (acc `L.append` chunk) if (L.length acc_temp < chunkSize) - then return (FileWriter chunkSize bucket files_id i size' acc_temp md5context' md5acc') + then return (FileWriter chunkSize _bucket files_id i size' acc_temp md5context' md5acc') else do let (newChunk, acc') = L.splitAt chunkSize acc_temp - putChunk bucket files_id i newChunk - writeChunks (FileWriter chunkSize bucket files_id (i+1) size' acc' md5context' md5acc') L.empty + putChunk _bucket files_id i newChunk + writeChunks (FileWriter chunkSize _bucket files_id (i+1) size' acc' md5context' md5acc') L.empty -sinkFile :: (Monad m, MonadIO m) => Bucket -> Text -> Consumer S.ByteString (Action m) File +sinkFile :: (Monad m, MonadIO m) => Bucket -> Text -> ConduitT S.ByteString () (Action m) File -- ^ A consumer that creates a file in the bucket and puts all consumed data in it -sinkFile bucket filename = do +sinkFile _bucket filename = do files_id <- liftIO $ genObjectId - awaitChunk $ FileWriter defaultChunkSize bucket files_id 0 0 L.empty md5InitialContext L.empty + awaitChunk $ FileWriter defaultChunkSize _bucket files_id 0 0 L.empty md5InitialContext L.empty where awaitChunk fw = do mchunk <- await diff --git a/Database/MongoDB/Internal/Network.hs b/Database/MongoDB/Internal/Network.hs index 11190b1..8eb6934 100644 --- a/Database/MongoDB/Internal/Network.hs +++ b/Database/MongoDB/Internal/Network.hs @@ -1,10 +1,9 @@ -- | Compatibility layer for network package, including newtype 'PortID' -{-# LANGUAGE CPP, GeneralizedNewtypeDeriving, OverloadedStrings #-} +{-# LANGUAGE CPP, OverloadedStrings #-} module Database.MongoDB.Internal.Network (Host(..), PortID(..), N.HostName, connectTo, lookupReplicaSetName, lookupSeedList) where - #if !MIN_VERSION_network(2, 9, 0) import qualified Network as N @@ -20,7 +19,7 @@ import System.IO (Handle, IOMode(ReadWriteMode)) #endif import Data.ByteString.Char8 (pack, unpack) -import Data.List (dropWhileEnd, lookup) +import Data.List (dropWhileEnd) import Data.Maybe (fromMaybe) import Data.Text (Text) import Network.DNS.Lookup (lookupSRV, lookupTXT) @@ -60,7 +59,7 @@ connectTo hostname (PortNumber port) = do proto <- BSD.getProtocolNumber "tcp" bracketOnError (N.socket N.AF_INET N.Stream proto) - (N.close) -- only done if there's an error + N.close -- only done if there's an error (\sock -> do he <- BSD.getHostByName hostname N.connect sock (N.SockAddrInet port (hostAddress he)) @@ -71,7 +70,7 @@ connectTo hostname (PortNumber port) = do connectTo _ (UnixSocket path) = do bracketOnError (N.socket N.AF_UNIX N.Stream 0) - (N.close) + N.close (\sock -> do N.connect sock (N.SockAddrUnix path) N.socketToHandle sock ReadWriteMode @@ -104,4 +103,4 @@ lookupSeedList hostname = do Left _ -> pure [] Right srv -> pure $ map (\(_, _, por, tar) -> let tar' = dropWhileEnd (=='.') (unpack tar) - in Host tar' (PortNumber . fromIntegral $ por)) srv + in Host tar' (PortNumber . fromIntegral $ por)) srv \ No newline at end of file diff --git a/Database/MongoDB/Internal/Protocol.hs b/Database/MongoDB/Internal/Protocol.hs index a86f4a1..ed99a0c 100644 --- a/Database/MongoDB/Internal/Protocol.hs +++ b/Database/MongoDB/Internal/Protocol.hs @@ -4,8 +4,8 @@ -- This module is not intended for direct use. Use the high-level interface at -- "Database.MongoDB.Query" and "Database.MongoDB.Connection" instead. -{-# LANGUAGE RecordWildCards, StandaloneDeriving, OverloadedStrings #-} -{-# LANGUAGE CPP, FlexibleContexts, TupleSections, TypeSynonymInstances #-} +{-# LANGUAGE RecordWildCards, OverloadedStrings #-} +{-# LANGUAGE CPP, FlexibleContexts #-} {-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, UndecidableInstances #-} {-# LANGUAGE BangPatterns #-} @@ -35,7 +35,7 @@ module Database.MongoDB.Internal.Protocol ( #if !MIN_VERSION_base(4,8,0) import Control.Applicative ((<$>)) #endif -import Control.Monad (forM, replicateM, unless) +import Control.Monad ( forM, replicateM, unless, forever ) import Data.Binary.Get (Get, runGet) import Data.Binary.Put (Put, runPut) import Data.Bits (bit, testBit) @@ -46,7 +46,6 @@ import System.IO.Error (doesNotExistErrorType, mkIOError) import System.IO.Unsafe (unsafePerformIO) import Data.Maybe (maybeToList) import GHC.Conc (ThreadStatus(..), threadStatus) -import Control.Monad (forever) import Control.Monad.STM (atomically) import Control.Concurrent (ThreadId, killThread, forkIOWithUnmask) import Control.Concurrent.STM.TChan (TChan, newTChan, readTChan, writeTChan, isEmptyTChan) @@ -70,6 +69,7 @@ import Database.MongoDB.Internal.Util (bitOr, byteStringHex) import Database.MongoDB.Transport (Transport) import qualified Database.MongoDB.Transport as Tr + #if MIN_VERSION_base(4,6,0) import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar, putMVar, readMVar, mkWeakMVar, isEmptyMVar) @@ -83,6 +83,7 @@ mkWeakMVar :: MVar a -> IO () -> IO () mkWeakMVar = addMVarFinalizer #endif + -- * Pipeline -- | Thread-safe and pipelined connection @@ -270,6 +271,7 @@ type ResponseTo = RequestId genRequestId :: (MonadIO m) => m RequestId -- ^ Generate fresh request id +{-# NOINLINE genRequestId #-} genRequestId = liftIO $ atomicModifyIORef counter $ \n -> (n + 1, n) where counter :: IORef RequestId counter = unsafePerformIO (newIORef 0) diff --git a/Database/MongoDB/Query.hs b/Database/MongoDB/Query.hs index c814317..a68c8c5 100644 --- a/Database/MongoDB/Query.hs +++ b/Database/MongoDB/Query.hs @@ -1,6 +1,6 @@ -- | Query and update documents -{-# LANGUAGE OverloadedStrings, RecordWildCards, NamedFieldPuns, TupleSections, FlexibleContexts, FlexibleInstances, UndecidableInstances, MultiParamTypeClasses, GeneralizedNewtypeDeriving, StandaloneDeriving, TypeSynonymInstances, TypeFamilies, CPP, DeriveDataTypeable, ScopedTypeVariables, BangPatterns #-} +{-# LANGUAGE OverloadedStrings, RecordWildCards, NamedFieldPuns, TupleSections, FlexibleContexts, FlexibleInstances, UndecidableInstances, MultiParamTypeClasses, TypeFamilies, CPP, DeriveDataTypeable, ScopedTypeVariables, BangPatterns #-} module Database.MongoDB.Query ( -- * Monad @@ -46,69 +46,92 @@ module Database.MongoDB.Query ( eval, retrieveServerData, ServerData(..) ) where -import Prelude hiding (lookup) -import Control.Exception (Exception, throwIO) -import Control.Monad (unless, replicateM, liftM, liftM2) -import Control.Monad.Fail(MonadFail) -import Data.Default.Class (Default(..)) -import Data.Int (Int32, Int64) -import Data.Either (lefts, rights) -import Data.List (foldl1') -import Data.Maybe (listToMaybe, catMaybes, isNothing) -import Data.Word (Word32) -#if !MIN_VERSION_base(4,8,0) -import Data.Monoid (mappend) -#endif -import Data.Typeable (Typeable) -import System.Mem.Weak (Weak) - import qualified Control.Concurrent.MVar as MV -#if MIN_VERSION_base(4,6,0) -import Control.Concurrent.MVar.Lifted (MVar, - readMVar) -#else -import Control.Concurrent.MVar.Lifted (MVar, addMVarFinalizer, - readMVar) -#endif -import Control.Applicative ((<$>)) -import Control.Exception (catch) -import Control.Monad (when, void) -import Control.Monad.Reader (MonadReader, ReaderT, runReaderT, ask, asks, local) +import Control.Concurrent.MVar.Lifted + ( MVar, + readMVar, + ) +import Control.Exception (Exception, catch, throwIO) +import Control.Monad + ( liftM2, + replicateM, + unless, + void, + when, + ) +import Control.Monad.Reader (MonadReader, ReaderT, ask, asks, local, runReaderT) import Control.Monad.Trans (MonadIO, liftIO) -import Data.Binary.Put (runPut) -import Data.Bson (Document, Field(..), Label, Val, Value(String, Doc, Bool), - Javascript, at, valueAt, lookup, look, genObjectId, (=:), - (=?), (!?), Val(..), ObjectId, Value(..)) -import Data.Bson.Binary (putDocument) -import Data.Text (Text) -import qualified Data.Text as T - -import Database.MongoDB.Internal.Protocol (Reply(..), QueryOption(..), - ResponseFlag(..), InsertOption(..), - UpdateOption(..), DeleteOption(..), - CursorId, FullCollection, Username, - Password, Pipe, Notice(..), - Request(GetMore, qOptions, qSkip, - qFullCollection, qBatchSize, - qSelector, qProjector), - pwKey, ServerData(..)) -import Database.MongoDB.Internal.Util (loop, liftIOE, true1, (<.>)) -import qualified Database.MongoDB.Internal.Protocol as P - -import qualified Crypto.Nonce as Nonce -import qualified Data.ByteString as BS -import qualified Data.ByteString.Lazy as LBS -import qualified Data.ByteString.Base16 as B16 -import qualified Data.ByteString.Base64 as B64 -import qualified Data.ByteString.Char8 as B -import qualified Data.Either as E import qualified Crypto.Hash.MD5 as MD5 import qualified Crypto.Hash.SHA1 as SHA1 import qualified Crypto.MAC.HMAC as HMAC +import qualified Crypto.Nonce as Nonce +import Data.Binary.Put (runPut) import Data.Bits (xor) +import Data.Bson + ( Document, + Field (..), + Javascript, + Label, + ObjectId, + Val (..), + Value (..), + at, + genObjectId, + look, + lookup, + valueAt, + (!?), + (=:), + (=?), + ) +import Data.Bson.Binary (putDocument) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Base16 as B16 +import qualified Data.ByteString.Base64 as B64 +import qualified Data.ByteString.Char8 as B +import qualified Data.ByteString.Lazy as LBS +import Data.Default.Class (Default (..)) +import Data.Either (lefts, rights) +import qualified Data.Either as E +import Data.Functor ((<&>)) +import Data.Int (Int32, Int64) +import Data.List (foldl1') import qualified Data.Map as Map +import Data.Maybe (catMaybes, fromMaybe, isNothing, listToMaybe, mapMaybe) +import Data.Text (Text) +import qualified Data.Text as T +import Data.Typeable (Typeable) +import Data.Word (Word32) +import Database.MongoDB.Internal.Protocol + ( CursorId, + DeleteOption (..), + FullCollection, + InsertOption (..), + Notice (..), + Password, + Pipe, + QueryOption (..), + Reply (..), + Request + ( GetMore, + qBatchSize, + qFullCollection, + qOptions, + qProjector, + qSelector, + qSkip + ), + ResponseFlag (..), + ServerData (..), + UpdateOption (..), + Username, + pwKey, + ) +import qualified Database.MongoDB.Internal.Protocol as P +import Database.MongoDB.Internal.Util (liftIOE, loop, true1, (<.>)) +import System.Mem.Weak (Weak) import Text.Read (readMaybe) -import Data.Maybe (fromMaybe) +import Prelude hiding (lookup) -- * Monad @@ -185,7 +208,7 @@ slaveOk = ReadStaleOk accessMode :: (Monad m) => AccessMode -> Action m a -> Action m a -- ^ Run action with given 'AccessMode' -accessMode mode act = local (\ctx -> ctx {mongoAccessMode = mode}) act +accessMode mode = local (\ctx -> ctx {mongoAccessMode = mode}) readMode :: AccessMode -> ReadMode readMode ReadStaleOk = StaleOk @@ -227,7 +250,7 @@ type Database = Text allDatabases :: (MonadIO m) => Action m [Database] -- ^ List all databases residing on server -allDatabases = (map (at "name") . at "databases") `liftM` useDb "admin" (runCommand1 "listDatabases") +allDatabases = map (at "name") . at "databases" <$> useDb "admin" (runCommand1 "listDatabases") thisDatabase :: (Monad m) => Action m Database -- ^ Current database in use @@ -235,34 +258,34 @@ thisDatabase = asks mongoDatabase useDb :: (Monad m) => Database -> Action m a -> Action m a -- ^ Run action against given database -useDb db act = local (\ctx -> ctx {mongoDatabase = db}) act +useDb db = local (\ctx -> ctx {mongoDatabase = db}) -- * Authentication auth :: MonadIO m => Username -> Password -> Action m Bool -- ^ Authenticate with the current database (if server is running in secure mode). Return whether authentication was successful or not. Reauthentication is required for every new pipe. SCRAM-SHA-1 will be used for server versions 3.0+, MONGO-CR for lower versions. auth un pw = do - let serverVersion = liftM (at "version") $ useDb "admin" $ runCommand ["buildinfo" =: (1 :: Int)] - mmv <- liftM (readMaybe . T.unpack . head . T.splitOn ".") $ serverVersion + let serverVersion = fmap (at "version") $ useDb "admin" $ runCommand ["buildinfo" =: (1 :: Int)] + mmv <- readMaybe . T.unpack . head . T.splitOn "." <$> serverVersion maybe (return False) performAuth mmv where performAuth majorVersion = - case (majorVersion >= (3 :: Int)) of - True -> authSCRAMSHA1 un pw - False -> authMongoCR un pw + if majorVersion >= (3 :: Int) + then authSCRAMSHA1 un pw + else authMongoCR un pw authMongoCR :: (MonadIO m) => Username -> Password -> Action m Bool -- ^ Authenticate with the current database, using the MongoDB-CR authentication mechanism (default in MongoDB server < 3.0) authMongoCR usr pss = do - n <- at "nonce" `liftM` runCommand ["getnonce" =: (1 :: Int)] - true1 "ok" `liftM` runCommand ["authenticate" =: (1 :: Int), "user" =: usr, "nonce" =: n, "key" =: pwKey n usr pss] + n <- at "nonce" <$> runCommand ["getnonce" =: (1 :: Int)] + true1 "ok" <$> runCommand ["authenticate" =: (1 :: Int), "user" =: usr, "nonce" =: n, "key" =: pwKey n usr pss] authSCRAMSHA1 :: MonadIO m => Username -> Password -> Action m Bool -- ^ Authenticate with the current database, using the SCRAM-SHA-1 authentication mechanism (default in MongoDB server >= 3.0) authSCRAMSHA1 un pw = do let hmac = HMAC.hmac SHA1.hash 64 - nonce <- liftIO (Nonce.withGenerator Nonce.nonce128 >>= return . B64.encode) - let firstBare = B.concat [B.pack $ "n=" ++ (T.unpack un) ++ ",r=", nonce] + nonce <- liftIO (Nonce.withGenerator Nonce.nonce128 <&> B64.encode) + let firstBare = B.concat [B.pack $ "n=" ++ T.unpack un ++ ",r=", nonce] let client1 = ["saslStart" =: (1 :: Int), "mechanism" =: ("SCRAM-SHA-1" :: String), "payload" =: (B.unpack . B64.encode $ B.concat [B.pack "n,,", firstBare]), "autoAuthorize" =: (1 :: Int)] server1 <- runCommand client1 @@ -286,7 +309,7 @@ authSCRAMSHA1 un pw = do let clientFinal = B.concat [withoutProof, B.pack ",p=", pval] let serverKey = hmac saltedPass (B.pack "Server Key") let serverSig = B64.encode $ hmac serverKey authMsg - let client2 = ["saslContinue" =: (1 :: Int), "conversationId" =: (at "conversationId" server1 :: Int), "payload" =: (B.unpack $ B64.encode clientFinal)] + let client2 = ["saslContinue" =: (1 :: Int), "conversationId" =: (at "conversationId" server1 :: Int), "payload" =: B.unpack (B64.encode clientFinal)] server2 <- runCommand client2 shortcircuit (true1 "ok" server2) $ do @@ -317,19 +340,19 @@ scramHI digest salt iters = snd $ foldl com (u1, u1) [1..(iters-1)] com (u,uc) _ = let u' = hmacd u in (u', BS.pack $ BS.zipWith xor uc u') parseSCRAM :: B.ByteString -> Map.Map B.ByteString B.ByteString -parseSCRAM = Map.fromList . fmap cleanup . (fmap $ T.breakOn "=") . T.splitOn "," . T.pack . B.unpack +parseSCRAM = Map.fromList . fmap (cleanup . T.breakOn "=") . T.splitOn "," . T.pack . B.unpack where cleanup (t1, t2) = (B.pack $ T.unpack t1, B.pack . T.unpack $ T.drop 1 t2) retrieveServerData :: (MonadIO m) => Action m ServerData retrieveServerData = do d <- runCommand1 "isMaster" let newSd = ServerData - { isMaster = (fromMaybe False $ lookup "ismaster" d) - , minWireVersion = (fromMaybe 0 $ lookup "minWireVersion" d) - , maxWireVersion = (fromMaybe 0 $ lookup "maxWireVersion" d) - , maxMessageSizeBytes = (fromMaybe 48000000 $ lookup "maxMessageSizeBytes" d) - , maxBsonObjectSize = (fromMaybe (16 * 1024 * 1024) $ lookup "maxBsonObjectSize" d) - , maxWriteBatchSize = (fromMaybe 1000 $ lookup "maxWriteBatchSize" d) + { isMaster = fromMaybe False $ lookup "ismaster" d + , minWireVersion = fromMaybe 0 $ lookup "minWireVersion" d + , maxWireVersion = fromMaybe 0 $ lookup "maxWireVersion" d + , maxMessageSizeBytes = fromMaybe 48000000 $ lookup "maxMessageSizeBytes" d + , maxBsonObjectSize = fromMaybe (16 * 1024 * 1024) $ lookup "maxBsonObjectSize" d + , maxWriteBatchSize = fromMaybe 1000 $ lookup "maxWriteBatchSize" d } return newSd @@ -343,11 +366,11 @@ allCollections :: MonadIO m => Action m [Collection] allCollections = do p <- asks mongoPipe let sd = P.serverData p - if (maxWireVersion sd <= 2) + if maxWireVersion sd <= 2 then do db <- thisDatabase docs <- rest =<< find (query [] "system.namespaces") {sort = ["name" =: (1 :: Int)]} - return . filter (not . isSpecial db) . map dropDbPrefix $ map (at "name") docs + (return . filter (not . isSpecial db)) (map (dropDbPrefix . at "name") docs) else do r <- runCommand1 "listCollections" let curData = do @@ -355,14 +378,14 @@ allCollections = do (curId :: Int64) <- curDoc !? "id" (curNs :: Text) <- curDoc !? "ns" (firstBatch :: [Value]) <- curDoc !? "firstBatch" - return $ (curId, curNs, ((catMaybes (map cast' firstBatch)) :: [Document])) + return (curId, curNs, mapMaybe cast' firstBatch :: [Document]) case curData of Nothing -> return [] Just (curId, curNs, firstBatch) -> do db <- thisDatabase nc <- newCursor db curNs 0 $ return $ Batch Nothing curId firstBatch docs <- rest nc - return $ catMaybes $ map (\d -> (d !? "name")) docs + return $ mapMaybe (\d -> d !? "name") docs where dropDbPrefix = T.tail . T.dropWhile (/= '.') isSpecial db col = T.any (== '$') col && db <.> col /= "local.oplog.$main" @@ -473,7 +496,7 @@ insert' opts col docs = do NoConfirm -> ["w" =: (0 :: Int)] Confirm params -> params let docSize = sizeOfDocument $ insertCommandDocument opts col [] writeConcern - let ordered = (not (KeepGoing `elem` opts)) + let ordered = KeepGoing `notElem` opts let preChunks = splitAtLimit (maxBsonObjectSize sd - docSize) -- size of auxiliary part of insert @@ -487,7 +510,7 @@ insert' opts col docs = do else rights preChunks let lens = map length chunks - let lSums = 0 : (zipWith (+) lSums lens) + let lSums = 0 : zipWith (+) lSums lens chunkResults <- interruptibleFor ordered (zip lSums chunks) $ insertBlock opts col @@ -508,13 +531,13 @@ insertBlock opts col (prevCount, docs) = do p <- asks mongoPipe let sd = P.serverData p - if (maxWireVersion sd < 2) + if maxWireVersion sd < 2 then do res <- liftDB $ write (Insert (db <.> col) opts docs) let errorMessage = do jRes <- res em <- lookup "err" jRes - return $ WriteFailure prevCount (maybe 0 id $ lookup "code" jRes) em + return $ WriteFailure prevCount (fromMaybe 0 $ lookup "code" jRes) em -- In older versions of ^^ the protocol we can't really say which document failed. -- So we just report the accumulated number of documents in the previous blocks. @@ -530,45 +553,45 @@ insertBlock opts col (prevCount, docs) = do case (look "writeErrors" doc, look "writeConcernError" doc) of (Nothing, Nothing) -> return $ Right $ map (valueAt "_id") docs (Just (Array errs), Nothing) -> do - let writeErrors = map (anyToWriteError prevCount) $ errs + let writeErrors = map (anyToWriteError prevCount) errs let errorsWithFailureIndex = map (addFailureIndex prevCount) writeErrors return $ Left $ CompoundFailure errorsWithFailureIndex (Nothing, Just err) -> do return $ Left $ WriteFailure prevCount - (maybe 0 id $ lookup "ok" doc) + (fromMaybe 0 $ lookup "ok" doc) (show err) (Just (Array errs), Just writeConcernErr) -> do - let writeErrors = map (anyToWriteError prevCount) $ errs + let writeErrors = map (anyToWriteError prevCount) errs let errorsWithFailureIndex = map (addFailureIndex prevCount) writeErrors - return $ Left $ CompoundFailure $ (WriteFailure + return $ Left $ CompoundFailure $ WriteFailure prevCount - (maybe 0 id $ lookup "ok" doc) - (show writeConcernErr)) : errorsWithFailureIndex + (fromMaybe 0 $ lookup "ok" doc) + (show writeConcernErr) : errorsWithFailureIndex (Just unknownValue, Nothing) -> do return $ Left $ ProtocolFailure prevCount $ "Expected array of errors. Received: " ++ show unknownValue (Just unknownValue, Just writeConcernErr) -> do - return $ Left $ CompoundFailure $ [ ProtocolFailure prevCount $ "Expected array of errors. Received: " ++ show unknownValue - , WriteFailure prevCount (maybe 0 id $ lookup "ok" doc) $ show writeConcernErr] + return $ Left $ CompoundFailure [ ProtocolFailure prevCount $ "Expected array of errors. Received: " ++ show unknownValue + , WriteFailure prevCount (fromMaybe 0 $ lookup "ok" doc) $ show writeConcernErr] splitAtLimit :: Int -> Int -> [Document] -> [Either Failure [Document]] splitAtLimit maxSize maxCount list = chop (go 0 0 []) list where - go :: Int -> Int -> [Document] -> [Document] -> ((Either Failure [Document]), [Document]) + go :: Int -> Int -> [Document] -> [Document] -> (Either Failure [Document], [Document]) go _ _ res [] = (Right $ reverse res, []) go curSize curCount [] (x:xs) | - ((curSize + (sizeOfDocument x) + 2 + curCount) > maxSize) = + (curSize + sizeOfDocument x + 2 + curCount) > maxSize = (Left $ WriteFailure 0 0 "One document is too big for the message", xs) go curSize curCount res (x:xs) = - if ( ((curSize + (sizeOfDocument x) + 2 + curCount) > maxSize) + if ((curSize + sizeOfDocument x + 2 + curCount) > maxSize) -- we have ^ 2 brackets and curCount commas in -- the document that we need to take into -- account - || ((curCount + 1) > maxCount)) + || ((curCount + 1) > maxCount) then (Right $ reverse res, x:xs) else - go (curSize + (sizeOfDocument x)) (curCount + 1) (x:res) xs + go (curSize + sizeOfDocument x) (curCount + 1) (x:res) xs chop :: ([a] -> (b, [a])) -> [a] -> [b] chop _ [] = [] @@ -581,7 +604,7 @@ assignId :: Document -> IO Document -- ^ Assign a unique value to _id field if missing assignId doc = if any (("_id" ==) . label) doc then return doc - else (\oid -> ("_id" =: oid) : doc) `liftM` genObjectId + else (\oid -> ("_id" =: oid) : doc) <$> genObjectId -- ** Update @@ -696,22 +719,21 @@ update' ordered col updateDocs = do then takeRightsUpToLeft preChunks else rights preChunks let lens = map length chunks - let lSums = 0 : (zipWith (+) lSums lens) + let lSums = 0 : zipWith (+) lSums lens blocks <- interruptibleFor ordered (zip lSums chunks) $ \b -> do - ur <- runReaderT (updateBlock ordered col b) ctx - return ur + runReaderT (updateBlock ordered col b) ctx `catch` \(e :: Failure) -> do return $ WriteResult True 0 Nothing 0 [] [e] [] - let failedTotal = or $ map failed blocks + let failedTotal = any failed blocks let updatedTotal = sum $ map nMatched blocks let modifiedTotal = - if all isNothing $ map nModified blocks + if all (isNothing . nModified) blocks then Nothing - else Just $ sum $ catMaybes $ map nModified blocks - let totalWriteErrors = concat $ map writeErrors blocks - let totalWriteConcernErrors = concat $ map writeConcernErrors blocks + else Just $ sum $ mapMaybe nModified blocks + let totalWriteErrors = concatMap writeErrors blocks + let totalWriteConcernErrors = concatMap writeConcernErrors blocks - let upsertedTotal = concat $ map upserted blocks + let upsertedTotal = concatMap upserted blocks return $ WriteResult failedTotal updatedTotal @@ -728,7 +750,7 @@ updateBlock :: (MonadIO m) updateBlock ordered col (prevCount, docs) = do p <- asks mongoPipe let sd = P.serverData p - if (maxWireVersion sd < 2) + if maxWireVersion sd < 2 then liftIO $ ioError $ userError "updateMany doesn't support mongodb older than 2.6" else do mode <- asks mongoWriteMode @@ -751,7 +773,7 @@ updateBlock ordered col (prevCount, docs) = do [ ProtocolFailure prevCount $ "Expected array of error docs, but received: " - ++ (show unknownErr)] + ++ show unknownErr] [] let writeConcernResults = @@ -778,9 +800,9 @@ updateBlock ordered col (prevCount, docs) = do [ ProtocolFailure prevCount $ "Expected doc in writeConcernError, but received: " - ++ (show unknownErr)] + ++ show unknownErr] - let upsertedList = map docToUpserted $ fromMaybe [] (doc !? "upserted") + let upsertedList = maybe [] (map docToUpserted) (doc !? "upserted") let successResults = WriteResult False n (doc !? "nModified") 0 upsertedList [] [] return $ foldl1' mergeWriteResults [writeErrorsResults, writeConcernResults, successResults] @@ -799,10 +821,10 @@ mergeWriteResults :: WriteResult -> WriteResult -> WriteResult mergeWriteResults (WriteResult failed1 nMatched1 nModified1 nDeleted1 upserted1 writeErrors1 writeConcernErrors1) (WriteResult failed2 nMatched2 nModified2 nDeleted2 upserted2 writeErrors2 writeConcernErrors2) = - (WriteResult + WriteResult (failed1 || failed2) (nMatched1 + nMatched2) - ((liftM2 (+)) nModified1 nModified2) + (liftM2 (+) nModified1 nModified2) (nDeleted1 + nDeleted2) -- This function is used in foldl1' function. The first argument is the accumulator. -- The list in the accumulator is usually longer than the subsequent value which goes in the second argument. @@ -811,7 +833,6 @@ mergeWriteResults (upserted2 ++ upserted1) (writeErrors2 ++ writeErrors1) (writeConcernErrors2 ++ writeConcernErrors1) - ) docToUpserted :: Document -> Upserted @@ -905,7 +926,7 @@ delete' ordered col deleteDocs = do deletes ctx <- ask let lens = map (either (const 1) length) chunks - let lSums = 0 : (zipWith (+) lSums lens) + let lSums = 0 : zipWith (+) lSums lens let failureResult e = return $ WriteResult True 0 Nothing 0 [] [e] [] let doChunk b = runReaderT (deleteBlock ordered col b) ctx `catch` failureResult blockResult <- liftIO $ interruptibleFor ordered (zip lSums chunks) $ \(n, c) -> @@ -924,7 +945,7 @@ deleteBlock :: (MonadIO m) deleteBlock ordered col (prevCount, docs) = do p <- asks mongoPipe let sd = P.serverData p - if (maxWireVersion sd < 2) + if maxWireVersion sd < 2 then liftIO $ ioError $ userError "deleteMany doesn't support mongodb older than 2.6" else do mode <- asks mongoWriteMode @@ -948,7 +969,7 @@ deleteBlock ordered col (prevCount, docs) = do [ ProtocolFailure prevCount $ "Expected array of error docs, but received: " - ++ (show unknownErr)] + ++ show unknownErr] [] let writeConcernResults = case look "writeConcernError" doc of @@ -974,7 +995,7 @@ deleteBlock ordered col (prevCount, docs) = do [ ProtocolFailure prevCount $ "Expected doc in writeConcernError, but received: " - ++ (show unknownErr)] + ++ show unknownErr] return $ foldl1' mergeWriteResults [successResults, writeErrorsResults, writeConcernResults] anyToWriteError :: Int -> Value -> Failure @@ -1115,11 +1136,11 @@ findAndModifyOpts :: (MonadIO m, MonadFail m) => Query -> FindAndModifyOpts -> Action m (Either String (Maybe Document)) -findAndModifyOpts (Query { +findAndModifyOpts Query { selection = Select sel collection , project = project , sort = sort - }) famOpts = do + } famOpts = do result <- runCommand ([ "findAndModify" := String collection , "query" := Doc sel @@ -1165,13 +1186,13 @@ explain q = do -- same as findOne but with explain set to true count :: (MonadIO m) => Query -> Action m Int -- ^ Fetch number of documents satisfying query (including effect of skip and/or limit if present) -count Query{selection = Select sel col, skip, limit} = at "n" `liftM` runCommand +count Query{selection = Select sel col, skip, limit} = at "n" <$> runCommand (["count" =: col, "query" =: sel, "skip" =: (fromIntegral skip :: Int32)] ++ ("limit" =? if limit == 0 then Nothing else Just (fromIntegral limit :: Int32))) distinct :: (MonadIO m) => Label -> Selection -> Action m [Value] -- ^ Fetch distinct values of field in selected documents -distinct k (Select sel col) = at "values" `liftM` runCommand ["distinct" =: col, "key" =: k, "query" =: sel] +distinct k (Select sel col) = at "values" <$> runCommand ["distinct" =: col, "key" =: k, "query" =: sel] queryRequest :: (Monad m) => Bool -> Query -> Action m (Request, Maybe Limit) -- ^ Translate Query to Protocol.Query. If first arg is true then add special $explain attribute. @@ -1192,7 +1213,7 @@ queryRequest isExplain Query{..} = do special = catMaybes [mOrder, mSnapshot, mHint, mExplain] qSelector = if null special then s else ("$query" =: s) : special where s = selector selection -batchSizeRemainingLimit :: BatchSize -> (Maybe Limit) -> (Int32, Maybe Limit) +batchSizeRemainingLimit :: BatchSize -> Maybe Limit -> (Int32, Maybe Limit) -- ^ Given batchSize and limit return P.qBatchSize and remaining limit batchSizeRemainingLimit batchSize mLimit = let remaining = @@ -1253,10 +1274,10 @@ nextBatch (Cursor fcol batchSize var) = liftDB $ modifyMVar var $ \dBatch -> do Batch mLimit cid docs <- liftDB $ fulfill' fcol batchSize dBatch let newLimit = do limit <- mLimit - return $ limit - (min limit $ fromIntegral $ length docs) + return $ limit - min limit (fromIntegral $ length docs) let emptyBatch = return $ Batch (Just 0) 0 [] let getNextBatch = nextBatch' fcol batchSize newLimit cid - let resultDocs = (maybe id (take . fromIntegral) mLimit) docs + let resultDocs = maybe id (take . fromIntegral) mLimit docs case (cid, newLimit) of (0, _) -> return (emptyBatch, resultDocs) (_, Just 0) -> do @@ -1269,11 +1290,11 @@ fulfill' :: FullCollection -> BatchSize -> DelayedBatch -> Action IO Batch -- Discard pre-fetched batch if empty with nonzero cid. fulfill' fcol batchSize dBatch = do b@(Batch limit cid docs) <- fulfill dBatch - if cid /= 0 && null docs && (limit > (Just 0)) + if cid /= 0 && null docs && (limit > Just 0) then nextBatch' fcol batchSize limit cid >>= fulfill else return b -nextBatch' :: (MonadIO m) => FullCollection -> BatchSize -> (Maybe Limit) -> CursorId -> Action m DelayedBatch +nextBatch' :: (MonadIO m) => FullCollection -> BatchSize -> Maybe Limit -> CursorId -> Action m DelayedBatch nextBatch' fcol batchSize limit cid = do pipe <- asks mongoPipe liftIO $ request pipe [] (GetMore fcol batchSize' cid, remLimit) @@ -1286,7 +1307,7 @@ next (Cursor fcol batchSize var) = liftDB $ modifyMVar var nextState where -- nextState:: DelayedBatch -> Action m (DelayedBatch, Maybe Document) nextState dBatch = do Batch mLimit cid docs <- liftDB $ fulfill' fcol batchSize dBatch - if mLimit == (Just 0) + if mLimit == Just 0 then return (return $ Batch (Just 0) 0 [], Nothing) else case docs of @@ -1294,10 +1315,10 @@ next (Cursor fcol batchSize var) = liftDB $ modifyMVar var nextState where let newLimit = do limit <- mLimit return $ limit - 1 - dBatch' <- if null docs' && cid /= 0 && ((newLimit > (Just 0)) || (isNothing newLimit)) + dBatch' <- if null docs' && cid /= 0 && ((newLimit > Just 0) || isNothing newLimit) then nextBatch' fcol batchSize newLimit cid else return $ return (Batch newLimit cid docs') - when (newLimit == (Just 0)) $ unless (cid == 0) $ do + when (newLimit == Just 0) $ unless (cid == 0) $ do pipe <- asks mongoPipe liftIOE ConnectionFailure $ P.send pipe [KillCursors [cid]] return (dBatch', Just doc) @@ -1309,7 +1330,7 @@ next (Cursor fcol batchSize var) = liftDB $ modifyMVar var nextState where nextN :: MonadIO m => Int -> Cursor -> Action m [Document] -- ^ Return next N documents or less if end is reached -nextN n c = catMaybes `liftM` replicateM n (next c) +nextN n c = catMaybes <$> replicateM n (next c) rest :: MonadIO m => Cursor -> Action m [Document] -- ^ Return remaining documents in query result @@ -1321,7 +1342,7 @@ closeCursor (Cursor _ _ var) = liftDB $ modifyMVar var $ \dBatch -> do unless (cid == 0) $ do pipe <- asks mongoPipe liftIOE ConnectionFailure $ P.send pipe [KillCursors [cid]] - return $ (return $ Batch (Just 0) 0 [], ()) + return (return $ Batch (Just 0) 0 [], ()) isCursorClosed :: MonadIO m => Cursor -> Action m Bool isCursorClosed (Cursor _ _ var) = do @@ -1404,7 +1425,7 @@ groupDocument Group{..} = group :: (MonadIO m) => Group -> Action m [Document] -- ^ Execute group query and return resulting aggregate value for each distinct key -group g = at "retval" `liftM` runCommand ["group" =: groupDocument g] +group g = at "retval" <$> runCommand ["group" =: groupDocument g] -- ** MapReduce @@ -1497,7 +1518,7 @@ type Command = Document runCommand :: (MonadIO m) => Command -> Action m Document -- ^ Run command against the database and return its result -runCommand c = maybe err id `liftM` findOne (query c "$cmd") where +runCommand c = fromMaybe err <$> findOne (query c "$cmd") where err = error $ "Nothing returned for command: " ++ show c runCommand1 :: (MonadIO m) => Text -> Action m Document @@ -1506,7 +1527,7 @@ runCommand1 c = runCommand [c =: (1 :: Int)] eval :: (MonadIO m, Val v) => Javascript -> Action m v -- ^ Run code on server -eval code = at "retval" `liftM` runCommand ["$eval" =: code] +eval code = at "retval" <$> runCommand ["$eval" =: code] modifyMVar :: MVar a -> (a -> Action IO (a, b)) -> Action IO b modifyMVar v f = do @@ -1516,6 +1537,7 @@ modifyMVar v f = do mkWeakMVar :: MVar a -> Action IO () -> Action IO (Weak (MVar a)) mkWeakMVar m closing = do ctx <- ask + #if MIN_VERSION_base(4,6,0) liftIO $ MV.mkWeakMVar m $ runReaderT closing ctx #else diff --git a/Database/MongoDB/Transport/Tls.hs b/Database/MongoDB/Transport/Tls.hs index 6915d1f..39b6c0e 100644 --- a/Database/MongoDB/Transport/Tls.hs +++ b/Database/MongoDB/Transport/Tls.hs @@ -1,6 +1,5 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} #if (__GLASGOW_HASKELL__ >= 706) {-# LANGUAGE RecursiveDo #-} @@ -21,16 +20,17 @@ ATTENTION!!! Be aware that this module is highly experimental and is barely tested. The current implementation doesn't verify server's identity. It only allows you to connect to a mongodb server using TLS protocol. -} + module Database.MongoDB.Transport.Tls -(connect) +( connect +, connectWithTlsParams +) where import Data.IORef -import Data.Monoid import qualified Data.ByteString as ByteString import qualified Data.ByteString.Lazy as Lazy.ByteString import Data.Default.Class (def) -import Control.Applicative ((<$>)) import Control.Exception (bracketOnError) import Control.Monad (when, unless) import System.IO @@ -45,15 +45,19 @@ import Database.MongoDB.Query (access, slaveOk, retrieveServerData) -- | Connect to mongodb using TLS connect :: HostName -> PortID -> IO Pipe -connect host port = bracketOnError (connectTo host port) hClose $ \handle -> do - - let params = (TLS.defaultParamsClient host "") +connect host port = connectWithTlsParams params host port + where + params = (TLS.defaultParamsClient host "") { TLS.clientSupported = def - { TLS.supportedCiphers = TLS.ciphersuite_default} + { TLS.supportedCiphers = TLS.ciphersuite_default } , TLS.clientHooks = def - { TLS.onServerCertificate = \_ _ _ _ -> return []} + { TLS.onServerCertificate = \_ _ _ _ -> return [] } } - context <- TLS.contextNew handle params + +-- | Connect to mongodb using TLS using provided TLS client parameters +connectWithTlsParams :: TLS.ClientParams -> HostName -> PortID -> IO Pipe +connectWithTlsParams clientParams host port = bracketOnError (connectTo host port) hClose $ \handle -> do + context <- TLS.contextNew handle clientParams TLS.handshake context conn <- tlsConnection context diff --git a/dist-newstyle/cache/config b/dist-newstyle/cache/config index 7e39491..feaef74 100644 Binary files a/dist-newstyle/cache/config and b/dist-newstyle/cache/config differ