From 81cdd8a2350b96168a06662c2601a41141a19f2d Mon Sep 17 00:00:00 2001 From: Akshay Mankar Date: Mon, 25 Mar 2024 15:21:16 +0100 Subject: [PATCH] Authenticate with each node in cluster mode --- src/Database/Redis/Cluster.hs | 57 +++++++++++++++++++++----------- src/Database/Redis/Connection.hs | 4 +-- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/src/Database/Redis/Cluster.hs b/src/Database/Redis/Cluster.hs index 2f577bc2..1a0debcb 100644 --- a/src/Database/Redis/Cluster.hs +++ b/src/Database/Redis/Cluster.hs @@ -34,7 +34,7 @@ import Data.Typeable import qualified Scanner import System.IO.Unsafe(unsafeInterleaveIO) -import Database.Redis.Protocol(Reply(Error), renderRequest, reply) +import Database.Redis.Protocol(Reply(..), renderRequest, reply) import qualified Database.Redis.Cluster.Command as CMD -- This module implements a clustered connection whilst maintaining @@ -100,8 +100,11 @@ instance Exception UnsupportedClusterCommandException newtype CrossSlotException = CrossSlotException [[B.ByteString]] deriving (Show, Typeable) instance Exception CrossSlotException -connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection -connect commandInfos shardMapVar timeoutOpt = do +data ClusterAuthError = ClusterAuthError Host Port Reply deriving (Show) +instance Exception ClusterAuthError + +connect :: Maybe B.ByteString -> Maybe B.ByteString -> [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection +connect mUsername mPassword commandInfos shardMapVar timeoutOpt = do shardMap <- readMVar shardMapVar stateVar <- newMVar $ Pending [] pipelineVar <- newMVar $ Pipeline stateVar @@ -113,7 +116,16 @@ connect commandInfos shardMapVar timeoutOpt = do connectNode (Node n _ host port) = do ctx <- CC.connect host (CC.PortNumber $ toEnum port) timeoutOpt ref <- IOR.newIORef Nothing - return (n, NodeConnection ctx ref n) + let nodeConn = NodeConnection ctx ref n + case mPassword of + Nothing -> pure () + Just password -> do + let reqOpts = maybe [password] (:[password]) mUsername + authReply <- requestNode1 nodeConn ( ["AUTH"] <> reqOpts ) + case authReply of + SingleLine "OK" -> pure () + _ -> throwIO $ ClusterAuthError host port authReply + return (n, nodeConn) disconnect :: Connection -> IO () disconnect (Connection nodeConnMap _ _ _) = mapM_ disconnectNode (HM.elems nodeConnMap) where @@ -370,28 +382,35 @@ allMasterNodes (Connection nodeConns _ _ _) (ShardMap shardMap) = masterNodes = (\(Shard master _) -> master) <$> nub (IntMap.elems shardMap) requestNode :: NodeConnection -> [[B.ByteString]] -> IO [Reply] -requestNode (NodeConnection ctx lastRecvRef _) requests = do +requestNode nodeConn@(NodeConnection ctx _ _) requests = do mapM_ (sendNode . renderRequest) requests _ <- CC.flush ctx - replicateM (length requests) recvNode + replicateM (length requests) $ recvNode nodeConn where sendNode :: B.ByteString -> IO () sendNode = CC.send ctx - recvNode :: IO Reply - recvNode = do - maybeLastRecv <- IOR.readIORef lastRecvRef - scanResult <- case maybeLastRecv of - Just lastRecv -> Scanner.scanWith (CC.recv ctx) reply lastRecv - Nothing -> Scanner.scanWith (CC.recv ctx) reply B.empty - - case scanResult of - Scanner.Fail{} -> CC.errConnClosed - Scanner.More{} -> error "Hedis: parseWith returned Partial" - Scanner.Done rest' r -> do - IOR.writeIORef lastRecvRef (Just rest') - return r + +requestNode1 :: NodeConnection -> [B.ByteString] -> IO Reply +requestNode1 nodeConn@(NodeConnection ctx _ _) request = do + CC.send ctx $ renderRequest request + _ <- CC.flush ctx + recvNode nodeConn + +recvNode :: NodeConnection -> IO Reply +recvNode (NodeConnection ctx lastRecvRef _) = do + maybeLastRecv <- IOR.readIORef lastRecvRef + scanResult <- case maybeLastRecv of + Just lastRecv -> Scanner.scanWith (CC.recv ctx) reply lastRecv + Nothing -> Scanner.scanWith (CC.recv ctx) reply B.empty + + case scanResult of + Scanner.Fail{} -> CC.errConnClosed + Scanner.More{} -> error "Hedis: parseWith returned Partial" + Scanner.Done rest' r -> do + IOR.writeIORef lastRecvRef (Just rest') + return r nodes :: ShardMap -> [Node] nodes (ShardMap shardMap) = concatMap snd $ IntMap.toList $ fmap shardNodes shardMap where diff --git a/src/Database/Redis/Connection.hs b/src/Database/Redis/Connection.hs index 156662ec..c0a38bfc 100644 --- a/src/Database/Redis/Connection.hs +++ b/src/Database/Redis/Connection.hs @@ -231,9 +231,9 @@ connectCluster bootstrapConnInfo = do Left e -> throwIO $ ClusterConnectError e Right infos -> do #if MIN_VERSION_resource_pool(0,3,0) - pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)) + pool <- newPool (defaultPoolConfig (Cluster.connect (connectUsername bootstrapConnInfo) (connectAuth bootstrapConnInfo) infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)) #else - pool <- createPool (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo) + pool <- createPool (Cluster.connect (connectUsername bootstrapConnInfo) (connectAuth bootstrapConnInfo) infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo) #endif return $ ClusteredConnection shardMapVar pool