Compare commits

...

1 commit

Author SHA1 Message Date
Fumiaki Kinoshita
6e79777656 add authSCRAMSHA256 2023-04-17 12:19:06 +09:00

View file

@ -11,7 +11,7 @@ module Database.MongoDB.Query (
-- * Database -- * Database
Database, allDatabases, useDb, thisDatabase, Database, allDatabases, useDb, thisDatabase,
-- ** Authentication -- ** Authentication
Username, Password, auth, authMongoCR, authSCRAMSHA1, Username, Password, auth, authMongoCR, authSCRAMSHA1, authSCRAMSHA256,
-- * Collection -- * Collection
Collection, allCollections, Collection, allCollections,
-- ** Selection -- ** Selection
@ -63,6 +63,7 @@ import Control.Monad.Reader (MonadReader, ReaderT, ask, asks, local, runReaderT)
import Control.Monad.Trans (MonadIO, liftIO, lift) import Control.Monad.Trans (MonadIO, liftIO, lift)
import qualified Crypto.Hash.MD5 as MD5 import qualified Crypto.Hash.MD5 as MD5
import qualified Crypto.Hash.SHA1 as SHA1 import qualified Crypto.Hash.SHA1 as SHA1
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.MAC.HMAC as HMAC import qualified Crypto.MAC.HMAC as HMAC
import qualified Crypto.Nonce as Nonce import qualified Crypto.Nonce as Nonce
import Data.Binary.Put (runPut) import Data.Binary.Put (runPut)
@ -274,10 +275,10 @@ auth un pw = do
mmv <- readMaybe . T.unpack . head . T.splitOn "." <$> serverVersion mmv <- readMaybe . T.unpack . head . T.splitOn "." <$> serverVersion
maybe (return False) performAuth mmv maybe (return False) performAuth mmv
where where
performAuth majorVersion = performAuth majorVersion
if majorVersion >= (3 :: Int) | majorVersion >= (6 :: Int) = authSCRAMSHA256 un pw
then authSCRAMSHA1 un pw | majorVersion >= (3 :: Int) = authSCRAMSHA1 un pw
else authMongoCR un pw | otherwise = authMongoCR un pw
authMongoCR :: (MonadIO m) => Username -> Password -> Action m Bool authMongoCR :: (MonadIO m) => Username -> Password -> Action m Bool
-- ^ Authenticate with the current database, using the MongoDB-CR authentication mechanism (default in MongoDB server < 3.0) -- ^ Authenticate with the current database, using the MongoDB-CR authentication mechanism (default in MongoDB server < 3.0)
@ -285,29 +286,48 @@ authMongoCR usr pss = do
n <- at "nonce" <$> runCommand ["getnonce" =: (1 :: Int)] n <- at "nonce" <$> runCommand ["getnonce" =: (1 :: Int)]
true1 "ok" <$> runCommand ["authenticate" =: (1 :: Int), "user" =: usr, "nonce" =: n, "key" =: pwKey n usr pss] true1 "ok" <$> runCommand ["authenticate" =: (1 :: Int), "user" =: usr, "nonce" =: n, "key" =: pwKey n usr pss]
data HashAlgorithm = SHA1 | SHA256
hash :: HashAlgorithm -> B.ByteString -> B.ByteString
hash SHA1 = SHA1.hash
hash SHA256 = SHA256.hash
authSCRAMSHA1 :: MonadIO m => Username -> Password -> Action m Bool authSCRAMSHA1 :: MonadIO m => Username -> Password -> Action m Bool
authSCRAMSHA1 = authSCRAMWith SHA1
authSCRAMSHA256 :: MonadIO m => Username -> Password -> Action m Bool
authSCRAMSHA256 = authSCRAMWith SHA256
authSCRAMWith :: MonadIO m => HashAlgorithm -> Username -> Password -> Action m Bool
-- ^ Authenticate with the current database, using the SCRAM-SHA-1 authentication mechanism (default in MongoDB server >= 3.0) -- ^ Authenticate with the current database, using the SCRAM-SHA-1 authentication mechanism (default in MongoDB server >= 3.0)
authSCRAMSHA1 un pw = do authSCRAMWith algo un pw = do
let hmac = HMAC.hmac SHA1.hash 64 let hmac = HMAC.hmac (hash algo) 64
nonce <- liftIO (Nonce.withGenerator Nonce.nonce128 <&> B64.encode) nonce <- liftIO (Nonce.withGenerator Nonce.nonce128 <&> B64.encode)
let firstBare = B.concat [B.pack $ "n=" ++ T.unpack un ++ ",r=", nonce] let firstBare = B.concat [B.pack $ "n=" ++ T.unpack un ++ ",r=", nonce]
let client1 = ["saslStart" =: (1 :: Int), "mechanism" =: ("SCRAM-SHA-1" :: String), "payload" =: (B.unpack . B64.encode $ B.concat [B.pack "n,,", firstBare]), "autoAuthorize" =: (1 :: Int)] let client1 =
[ "saslStart" =: (1 :: Int)
, "mechanism" =: case algo of
SHA1 -> "SCRAM-SHA-1" :: String
SHA256 -> "SCRAM-SHA-256"
, "payload" =: (B.unpack . B64.encode $ B.concat [B.pack "n,,", firstBare])
, "autoAuthorize" =: (1 :: Int)
]
server1 <- runCommand client1 server1 <- runCommand client1
shortcircuit (true1 "ok" server1) $ do shortcircuit (true1 "ok" server1) server1 $ do
let serverPayload1 = B64.decodeLenient . B.pack . at "payload" $ server1 let serverPayload1 = B64.decodeLenient . B.pack . at "payload" $ server1
let serverData1 = parseSCRAM serverPayload1 let serverData1 = parseSCRAM serverPayload1
let iterations = read . B.unpack $ Map.findWithDefault "1" "i" serverData1 let iterations = read . B.unpack $ Map.findWithDefault "1" "i" serverData1
let salt = B64.decodeLenient $ Map.findWithDefault "" "s" serverData1 let salt = B64.decodeLenient $ Map.findWithDefault "" "s" serverData1
let snonce = Map.findWithDefault "" "r" serverData1 let snonce = Map.findWithDefault "" "r" serverData1
shortcircuit (B.isInfixOf nonce snonce) $ do shortcircuit (B.isInfixOf nonce snonce) "nonce" $ do
let withoutProof = B.concat [B.pack "c=biws,r=", snonce] let withoutProof = B.concat [B.pack "c=biws,r=", snonce]
let digestS = B.pack $ T.unpack un ++ ":mongo:" ++ T.unpack pw let digestS = B.pack $ T.unpack un ++ ":mongo:" ++ T.unpack pw
let digest = B16.encode $ MD5.hash digestS let digest = B16.encode $ MD5.hash digestS
let saltedPass = scramHI digest salt iterations let saltedPass = scramHI algo digest salt iterations
let clientKey = hmac saltedPass (B.pack "Client Key") let clientKey = hmac saltedPass (B.pack "Client Key")
let storedKey = SHA1.hash clientKey let storedKey = hash algo clientKey
let authMsg = B.concat [firstBare, B.pack ",", serverPayload1, B.pack ",", withoutProof] let authMsg = B.concat [firstBare, B.pack ",", serverPayload1, B.pack ",", withoutProof]
let clientSig = hmac storedKey authMsg let clientSig = hmac storedKey authMsg
let pval = B64.encode . BS.pack $ BS.zipWith xor clientKey clientSig let pval = B64.encode . BS.pack $ BS.zipWith xor clientKey clientSig
@ -317,12 +337,12 @@ authSCRAMSHA1 un pw = do
let client2 = ["saslContinue" =: (1 :: Int), "conversationId" =: (at "conversationId" server1 :: Int), "payload" =: B.unpack (B64.encode clientFinal)] let client2 = ["saslContinue" =: (1 :: Int), "conversationId" =: (at "conversationId" server1 :: Int), "payload" =: B.unpack (B64.encode clientFinal)]
server2 <- runCommand client2 server2 <- runCommand client2
shortcircuit (true1 "ok" server2) $ do shortcircuit (true1 "ok" server2) "server2" $ do
let serverPayload2 = B64.decodeLenient . B.pack $ at "payload" server2 let serverPayload2 = B64.decodeLenient . B.pack $ at "payload" server2
let serverData2 = parseSCRAM serverPayload2 let serverData2 = parseSCRAM serverPayload2
let serverSigComp = Map.findWithDefault "" "v" serverData2 let serverSigComp = Map.findWithDefault "" "v" serverData2
shortcircuit (serverSig == serverSigComp) $ do shortcircuit (serverSig == serverSigComp) "server2'" $ do
let done = true1 "done" server2 let done = true1 "done" server2
if done if done
then return True then return True
@ -331,16 +351,16 @@ authSCRAMSHA1 un pw = do
, "conversationId" =: (at "conversationId" server1 :: Int) , "conversationId" =: (at "conversationId" server1 :: Int)
, "payload" =: String ""] , "payload" =: String ""]
server3 <- runCommand client2Step2 server3 <- runCommand client2Step2
shortcircuit (true1 "ok" server3) $ do shortcircuit (true1 "ok" server3) "server3" $ do
return True return True
where where
shortcircuit True f = f shortcircuit True _ f = f
shortcircuit False _ = return False shortcircuit False reason _ = liftIO (print reason) >> return False
scramHI :: B.ByteString -> B.ByteString -> Int -> B.ByteString scramHI :: HashAlgorithm -> B.ByteString -> B.ByteString -> Int -> B.ByteString
scramHI digest salt iters = snd $ foldl com (u1, u1) [1..(iters-1)] scramHI algo digest salt iters = snd $ foldl com (u1, u1) [1..(iters-1)]
where where
hmacd = HMAC.hmac SHA1.hash 64 digest hmacd = HMAC.hmac (hash algo) 64 digest
u1 = hmacd (B.concat [salt, BS.pack [0, 0, 0, 1]]) u1 = hmacd (B.concat [salt, BS.pack [0, 0, 0, 1]])
com (u,uc) _ = let u' = hmacd u in (u', BS.pack $ BS.zipWith xor uc u') com (u,uc) _ = let u' = hmacd u in (u', BS.pack $ BS.zipWith xor uc u')