diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 2262a96e5..ba4ba273f 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -479,7 +479,7 @@ setUserNetworkInfo c@AgentClient {userNetworkInfo, userNetworkUpdated} ni = with reconnectAllServers :: AgentClient -> IO () reconnectAllServers c = do - reconnectServerClients c smpClients + withAgentEnv' c $ reconnectSMPServerClients c reconnectServerClients c xftpClients reconnectServerClients c ntfClients diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 0467c31f8..f61c1b483 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -33,6 +33,7 @@ module Simplex.Messaging.Agent.Client closeAgentClient, closeProtocolServerClients, reconnectServerClients, + reconnectSMPServerClients, reconnectSMPServer, closeXFTPServerClient, runSMPServerTest, @@ -922,6 +923,30 @@ reconnectServerClients :: ProtocolServerClient v err msg => AgentClient -> (Agen reconnectServerClients c clientsSel = readTVarIO (clientsSel c) >>= mapM_ (forkIO . closeClient_ c) +reconnectSMPServerClients :: AgentClient -> AM' () +reconnectSMPServerClients c = do + (clients, qs) <- atomically $ do + clients <- swapTVar (smpClients c) M.empty + qs <- RQ.getDelAllQueues (activeSubs c) + qs' <- RQ.getDelAllQueues (pendingSubs c) + pure (clients, qs <> qs') + atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone DOWN_ALL) + mapM_ (liftIO . forkIO . closeClient_ c) clients + (qSubRs, _) <- subscribeQueues c qs + let upConns = subscribedConnsByServer qSubRs + forM_ (M.toList upConns) $ \(server, connIds) -> + liftIO $ notifyUP server (S.toList . S.fromList $ connIds) + where + subscribedConnsByServer :: [(RcvQueue, Either AgentErrorType ())] -> Map SMPServer [ConnId] + subscribedConnsByServer = foldl' insertConnId M.empty + where + insertConnId :: Map SMPServer [ConnId] -> (RcvQueue, Either AgentErrorType ()) -> Map SMPServer [ConnId] + insertConnId acc (RcvQueue {server, connId}, qSubResult) = case qSubResult of + Right _ -> M.insertWith (<>) server [connId] acc + Left _ -> acc + notifyUP :: SMPServer -> [ConnId] -> IO () + notifyUP server connIds = atomically $ writeTBQueue (subQ c) ("", "", AEvt SAENone (UP server connIds)) + reconnectSMPServer :: AgentClient -> UserId -> SMPServer -> IO () reconnectSMPServer c userId srv = do cs <- readTVarIO $ smpClients c diff --git a/src/Simplex/Messaging/Agent/Protocol.hs b/src/Simplex/Messaging/Agent/Protocol.hs index b123fc1ec..5bcb7ded3 100644 --- a/src/Simplex/Messaging/Agent/Protocol.hs +++ b/src/Simplex/Messaging/Agent/Protocol.hs @@ -338,6 +338,7 @@ data AEvent (e :: AEntity) where CONNECT :: AProtocolType -> TransportHost -> AEvent AENone DISCONNECT :: AProtocolType -> TransportHost -> AEvent AENone DOWN :: SMPServer -> [ConnId] -> AEvent AENone + DOWN_ALL :: AEvent AENone UP :: SMPServer -> [ConnId] -> AEvent AENone SWITCH :: QueueDirection -> SwitchPhase -> ConnectionStats -> AEvent AEConn RSYNC :: RatchetSyncState -> Maybe AgentCryptoError -> ConnectionStats -> AEvent AEConn @@ -406,6 +407,7 @@ data AEventTag (e :: AEntity) where CONNECT_ :: AEventTag AENone DISCONNECT_ :: AEventTag AENone DOWN_ :: AEventTag AENone + DOWN_ALL_ :: AEventTag AENone UP_ :: AEventTag AENone SWITCH_ :: AEventTag AEConn RSYNC_ :: AEventTag AEConn @@ -458,6 +460,7 @@ aEventTag = \case CONNECT {} -> CONNECT_ DISCONNECT {} -> DISCONNECT_ DOWN {} -> DOWN_ + DOWN_ALL {} -> DOWN_ALL_ UP {} -> UP_ SWITCH {} -> SWITCH_ RSYNC {} -> RSYNC_ diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index 9ffe325b2..c326e56f5 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -11,6 +11,7 @@ module Simplex.Messaging.Agent.TRcvQueues deleteQueue, getSessQueues, getDelSessQueues, + getDelAllQueues, qKey, ) where @@ -96,6 +97,11 @@ getDelSessQueues tSess (TRcvQueues qs cs) = do Nothing -> (cId : removed, Nothing) Nothing -> (removed, Nothing) -- "impossible" in invariant holds, because we get keys from the known queues +getDelAllQueues :: TRcvQueues -> STM [RcvQueue] +getDelAllQueues (TRcvQueues qs cs) = do + writeTVar cs M.empty + M.elems <$> swapTVar qs M.empty + isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool isSession rq (uId, srv, connId_) = userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_ diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index ab78d2ee9..131cda3c5 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -14,6 +14,7 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-orphans #-} +{-# OPTIONS_GHC -fno-warn-ambiguous-fields #-} module AgentTests.FunctionalAPITests ( functionalAPITests, @@ -2904,7 +2905,7 @@ testDeliveryReceiptsConcurrent t = _ -> error "timeout" testTwoUsers :: HasCallStack => IO () -testTwoUsers = withAgentClients2 $ \a b -> do +testTwoUsers = withAgentClientsCfg2 aCfg aCfg $ \a b -> do let nc = netCfg initAgentServers sessionMode nc `shouldBe` TSMUser runRight_ $ do @@ -2916,7 +2917,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do b `hasClients` 1 liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity} liftIO $ threadDelay 250000 - ("", "", DOWN _ _) <- nGet a + ("", "", DOWN_ALL) <- nGet a ("", "", UP _ _) <- nGet a a `hasClients` 2 @@ -2925,9 +2926,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do liftIO $ threadDelay 250000 liftIO $ setNetworkConfig a nc {sessionMode = TSMUser} liftIO $ threadDelay 250000 - ("", "", DOWN _ _) <- nGet a - ("", "", DOWN _ _) <- nGet a - ("", "", UP _ _) <- nGet a + ("", "", DOWN_ALL) <- nGet a ("", "", UP _ _) <- nGet a a `hasClients` 1 @@ -2940,9 +2939,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do b `hasClients` 1 liftIO $ setNetworkConfig a nc {sessionMode = TSMEntity} liftIO $ threadDelay 250000 - ("", "", DOWN _ _) <- nGet a - ("", "", DOWN _ _) <- nGet a - ("", "", UP _ _) <- nGet a + ("", "", DOWN_ALL) <- nGet a ("", "", UP _ _) <- nGet a a `hasClients` 4 exchangeGreetingsMsgId 6 a bId1 b aId1 @@ -2952,13 +2949,7 @@ testTwoUsers = withAgentClients2 $ \a b -> do liftIO $ threadDelay 250000 liftIO $ setNetworkConfig a nc {sessionMode = TSMUser} liftIO $ threadDelay 250000 - ("", "", DOWN _ _) <- nGet a - ("", "", DOWN _ _) <- nGet a - ("", "", DOWN _ _) <- nGet a - ("", "", DOWN _ _) <- nGet a - ("", "", UP _ _) <- nGet a - ("", "", UP _ _) <- nGet a - ("", "", UP _ _) <- nGet a + ("", "", DOWN_ALL) <- nGet a ("", "", UP _ _) <- nGet a a `hasClients` 2 exchangeGreetingsMsgId 8 a bId1 b aId1 @@ -2966,6 +2957,8 @@ testTwoUsers = withAgentClients2 $ \a b -> do exchangeGreetingsMsgId 6 a bId2 b aId2 exchangeGreetingsMsgId 6 a bId2' b aId2' where + aCfg :: AgentConfig + aCfg = agentCfg {tbqSize = 16} hasClients :: HasCallStack => AgentClient -> Int -> ExceptT AgentErrorType IO () hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n diff --git a/tests/CoreTests/TRcvQueuesTests.hs b/tests/CoreTests/TRcvQueuesTests.hs index 9f7c4932e..14b894774 100644 --- a/tests/CoreTests/TRcvQueuesTests.hs +++ b/tests/CoreTests/TRcvQueuesTests.hs @@ -76,7 +76,7 @@ batchIdempotentTest = do atomically $ RQ.batchAddQueues trq qs checkDataInvariant trq `shouldReturn` True readTVarIO (RQ.getRcvQueues trq) `shouldReturn` qs' - fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn`cs' -- connections get duplicated, but that doesn't appear to affect anybody + fmap L.nub <$> readTVarIO (RQ.getConnections trq) `shouldReturn` cs' -- connections get duplicated, but that doesn't appear to affect anybody deleteConnTest :: IO () deleteConnTest = do