Compare commits

..

No commits in common. "scram-sha256" and "master" have entirely different histories.

View file

@ -11,7 +11,7 @@ module Database.MongoDB.Query (
-- * Database -- * Database
Database, allDatabases, useDb, thisDatabase, Database, allDatabases, useDb, thisDatabase,
-- ** Authentication -- ** Authentication
Username, Password, auth, authMongoCR, authSCRAMSHA1, authSCRAMSHA256, Username, Password, auth, authMongoCR, authSCRAMSHA1,
-- * Collection -- * Collection
Collection, allCollections, Collection, allCollections,
-- ** Selection -- ** Selection
@ -63,7 +63,6 @@ import Control.Monad.Reader (MonadReader, ReaderT, ask, asks, local, runReaderT)
import Control.Monad.Trans (MonadIO, liftIO, lift) import Control.Monad.Trans (MonadIO, liftIO, lift)
import qualified Crypto.Hash.MD5 as MD5 import qualified Crypto.Hash.MD5 as MD5
import qualified Crypto.Hash.SHA1 as SHA1 import qualified Crypto.Hash.SHA1 as SHA1
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.MAC.HMAC as HMAC import qualified Crypto.MAC.HMAC as HMAC
import qualified Crypto.Nonce as Nonce import qualified Crypto.Nonce as Nonce
import Data.Binary.Put (runPut) import Data.Binary.Put (runPut)
@ -275,10 +274,10 @@ auth un pw = do
mmv <- readMaybe . T.unpack . head . T.splitOn "." <$> serverVersion mmv <- readMaybe . T.unpack . head . T.splitOn "." <$> serverVersion
maybe (return False) performAuth mmv maybe (return False) performAuth mmv
where where
performAuth majorVersion performAuth majorVersion =
| majorVersion >= (6 :: Int) = authSCRAMSHA256 un pw if majorVersion >= (3 :: Int)
| majorVersion >= (3 :: Int) = authSCRAMSHA1 un pw then authSCRAMSHA1 un pw
| otherwise = authMongoCR un pw else authMongoCR un pw
authMongoCR :: (MonadIO m) => Username -> Password -> Action m Bool authMongoCR :: (MonadIO m) => Username -> Password -> Action m Bool
-- ^ Authenticate with the current database, using the MongoDB-CR authentication mechanism (default in MongoDB server < 3.0) -- ^ Authenticate with the current database, using the MongoDB-CR authentication mechanism (default in MongoDB server < 3.0)
@ -286,48 +285,29 @@ authMongoCR usr pss = do
n <- at "nonce" <$> runCommand ["getnonce" =: (1 :: Int)] n <- at "nonce" <$> runCommand ["getnonce" =: (1 :: Int)]
true1 "ok" <$> runCommand ["authenticate" =: (1 :: Int), "user" =: usr, "nonce" =: n, "key" =: pwKey n usr pss] true1 "ok" <$> runCommand ["authenticate" =: (1 :: Int), "user" =: usr, "nonce" =: n, "key" =: pwKey n usr pss]
data HashAlgorithm = SHA1 | SHA256
hash :: HashAlgorithm -> B.ByteString -> B.ByteString
hash SHA1 = SHA1.hash
hash SHA256 = SHA256.hash
authSCRAMSHA1 :: MonadIO m => Username -> Password -> Action m Bool authSCRAMSHA1 :: MonadIO m => Username -> Password -> Action m Bool
authSCRAMSHA1 = authSCRAMWith SHA1
authSCRAMSHA256 :: MonadIO m => Username -> Password -> Action m Bool
authSCRAMSHA256 = authSCRAMWith SHA256
authSCRAMWith :: MonadIO m => HashAlgorithm -> Username -> Password -> Action m Bool
-- ^ Authenticate with the current database, using the SCRAM-SHA-1 authentication mechanism (default in MongoDB server >= 3.0) -- ^ Authenticate with the current database, using the SCRAM-SHA-1 authentication mechanism (default in MongoDB server >= 3.0)
authSCRAMWith algo un pw = do authSCRAMSHA1 un pw = do
let hmac = HMAC.hmac (hash algo) 64 let hmac = HMAC.hmac SHA1.hash 64
nonce <- liftIO (Nonce.withGenerator Nonce.nonce128 <&> B64.encode) nonce <- liftIO (Nonce.withGenerator Nonce.nonce128 <&> B64.encode)
let firstBare = B.concat [B.pack $ "n=" ++ T.unpack un ++ ",r=", nonce] let firstBare = B.concat [B.pack $ "n=" ++ T.unpack un ++ ",r=", nonce]
let client1 = let client1 = ["saslStart" =: (1 :: Int), "mechanism" =: ("SCRAM-SHA-1" :: String), "payload" =: (B.unpack . B64.encode $ B.concat [B.pack "n,,", firstBare]), "autoAuthorize" =: (1 :: Int)]
[ "saslStart" =: (1 :: Int)
, "mechanism" =: case algo of
SHA1 -> "SCRAM-SHA-1" :: String
SHA256 -> "SCRAM-SHA-256"
, "payload" =: (B.unpack . B64.encode $ B.concat [B.pack "n,,", firstBare])
, "autoAuthorize" =: (1 :: Int)
]
server1 <- runCommand client1 server1 <- runCommand client1
shortcircuit (true1 "ok" server1) server1 $ do shortcircuit (true1 "ok" server1) $ do
let serverPayload1 = B64.decodeLenient . B.pack . at "payload" $ server1 let serverPayload1 = B64.decodeLenient . B.pack . at "payload" $ server1
let serverData1 = parseSCRAM serverPayload1 let serverData1 = parseSCRAM serverPayload1
let iterations = read . B.unpack $ Map.findWithDefault "1" "i" serverData1 let iterations = read . B.unpack $ Map.findWithDefault "1" "i" serverData1
let salt = B64.decodeLenient $ Map.findWithDefault "" "s" serverData1 let salt = B64.decodeLenient $ Map.findWithDefault "" "s" serverData1
let snonce = Map.findWithDefault "" "r" serverData1 let snonce = Map.findWithDefault "" "r" serverData1
shortcircuit (B.isInfixOf nonce snonce) "nonce" $ do shortcircuit (B.isInfixOf nonce snonce) $ do
let withoutProof = B.concat [B.pack "c=biws,r=", snonce] let withoutProof = B.concat [B.pack "c=biws,r=", snonce]
let digestS = B.pack $ T.unpack un ++ ":mongo:" ++ T.unpack pw let digestS = B.pack $ T.unpack un ++ ":mongo:" ++ T.unpack pw
let digest = B16.encode $ MD5.hash digestS let digest = B16.encode $ MD5.hash digestS
let saltedPass = scramHI algo digest salt iterations let saltedPass = scramHI digest salt iterations
let clientKey = hmac saltedPass (B.pack "Client Key") let clientKey = hmac saltedPass (B.pack "Client Key")
let storedKey = hash algo clientKey let storedKey = SHA1.hash clientKey
let authMsg = B.concat [firstBare, B.pack ",", serverPayload1, B.pack ",", withoutProof] let authMsg = B.concat [firstBare, B.pack ",", serverPayload1, B.pack ",", withoutProof]
let clientSig = hmac storedKey authMsg let clientSig = hmac storedKey authMsg
let pval = B64.encode . BS.pack $ BS.zipWith xor clientKey clientSig let pval = B64.encode . BS.pack $ BS.zipWith xor clientKey clientSig
@ -337,12 +317,12 @@ authSCRAMWith algo un pw = do
let client2 = ["saslContinue" =: (1 :: Int), "conversationId" =: (at "conversationId" server1 :: Int), "payload" =: B.unpack (B64.encode clientFinal)] let client2 = ["saslContinue" =: (1 :: Int), "conversationId" =: (at "conversationId" server1 :: Int), "payload" =: B.unpack (B64.encode clientFinal)]
server2 <- runCommand client2 server2 <- runCommand client2
shortcircuit (true1 "ok" server2) "server2" $ do shortcircuit (true1 "ok" server2) $ do
let serverPayload2 = B64.decodeLenient . B.pack $ at "payload" server2 let serverPayload2 = B64.decodeLenient . B.pack $ at "payload" server2
let serverData2 = parseSCRAM serverPayload2 let serverData2 = parseSCRAM serverPayload2
let serverSigComp = Map.findWithDefault "" "v" serverData2 let serverSigComp = Map.findWithDefault "" "v" serverData2
shortcircuit (serverSig == serverSigComp) "server2'" $ do shortcircuit (serverSig == serverSigComp) $ do
let done = true1 "done" server2 let done = true1 "done" server2
if done if done
then return True then return True
@ -351,16 +331,16 @@ authSCRAMWith algo un pw = do
, "conversationId" =: (at "conversationId" server1 :: Int) , "conversationId" =: (at "conversationId" server1 :: Int)
, "payload" =: String ""] , "payload" =: String ""]
server3 <- runCommand client2Step2 server3 <- runCommand client2Step2
shortcircuit (true1 "ok" server3) "server3" $ do shortcircuit (true1 "ok" server3) $ do
return True return True
where where
shortcircuit True _ f = f shortcircuit True f = f
shortcircuit False reason _ = liftIO (print reason) >> return False shortcircuit False _ = return False
scramHI :: HashAlgorithm -> B.ByteString -> B.ByteString -> Int -> B.ByteString scramHI :: B.ByteString -> B.ByteString -> Int -> B.ByteString
scramHI algo digest salt iters = snd $ foldl com (u1, u1) [1..(iters-1)] scramHI digest salt iters = snd $ foldl com (u1, u1) [1..(iters-1)]
where where
hmacd = HMAC.hmac (hash algo) 64 digest hmacd = HMAC.hmac SHA1.hash 64 digest
u1 = hmacd (B.concat [salt, BS.pack [0, 0, 0, 1]]) u1 = hmacd (B.concat [salt, BS.pack [0, 0, 0, 1]])
com (u,uc) _ = let u' = hmacd u in (u', BS.pack $ BS.zipWith xor uc u') com (u,uc) _ = let u' = hmacd u in (u', BS.pack $ BS.zipWith xor uc u')