diff --git a/io-sim/src/Control/Monad/IOSim/Internal.hs b/io-sim/src/Control/Monad/IOSim/Internal.hs index eb14197b..47a464d6 100644 --- a/io-sim/src/Control/Monad/IOSim/Internal.hs +++ b/io-sim/src/Control/Monad/IOSim/Internal.hs @@ -1148,19 +1148,45 @@ execAtomically !time !tid !tlbl !nextVid0 action0 k0 = ThrowStm e -> {-# SCC "execAtomically.go.ThrowStm" #-} do - -- Revert all the TVar writes + -- Rollback `TVar`s written since catch handler was installed !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - k0 $ StmTxAborted [] (toException e) + case ctl of + AtomicallyFrame -> do + k0 $ StmTxAborted (Map.elems read) (toException e) + + BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + -- Execute the left side in a new frame with an empty written set. + -- but preserve ones that were set prior to it, as specified in the + -- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package. + let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl' + go ctl'' read Map.empty [] [] nextVid (h e) + + BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e) + + BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e) + + CatchStm a h k -> + {-# SCC "execAtomically.go.ThrowStm" #-} do + -- Execute the catch handler with an empty written set. + -- but preserve ones that were set prior to it, as specified in the + -- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package. + let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl + go ctl' read Map.empty [] [] nextVid a + Retry -> - {-# SCC "execAtomically.go.Retry" #-} - do + {-# SCC "execAtomically.go.Retry" #-} do -- Always revert all the TVar writes for the retry !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written case ctl of AtomicallyFrame -> do -- Return vars read, so the thread can block on them - k0 $! StmTxBlocked $! (Map.elems read) + k0 $! StmTxBlocked $! Map.elems read BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> {-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do diff --git a/io-sim/src/Control/Monad/IOSim/Types.hs b/io-sim/src/Control/Monad/IOSim/Types.hs index af54e0ad..9a214fd2 100644 --- a/io-sim/src/Control/Monad/IOSim/Types.hs +++ b/io-sim/src/Control/Monad/IOSim/Types.hs @@ -195,6 +195,7 @@ runSTM (STM k) = k ReturnStm data StmA s a where ReturnStm :: a -> StmA s a ThrowStm :: SomeException -> StmA s a + CatchStm :: StmA s a -> (SomeException -> StmA s a) -> (a -> StmA s b) -> StmA s b NewTVar :: Maybe String -> x -> (TVar s x -> StmA s b) -> StmA s b LabelTVar :: String -> TVar s a -> StmA s b -> StmA s b @@ -339,6 +340,31 @@ instance MonadThrow (STM s) where instance Exceptions.MonadThrow (STM s) where throwM = MonadThrow.throwIO + +instance MonadCatch (STM s) where + + catch action handler = STM $ oneShot $ \k -> CatchStm (runSTM action) (runSTM . fromHandler handler) k + where + -- Get a total handler from the given handler + fromHandler :: Exception e => (e -> STM s a) -> SomeException -> STM s a + fromHandler h e = case fromException e of + Nothing -> throwIO e -- Rethrow the exception if handler does not handle it. + Just e' -> h e' + + -- Masking is not required as STM actions are always run inside + -- `execAtomically` and behave as if masked. Also note that the default + -- implementation of `generalBracket` needs mask, and is part of `MonadThrow`. + generalBracket acquire release use = do + resource <- acquire + b <- use resource `catch` \e -> do + _ <- release resource (ExitCaseException e) + throwIO e + c <- release resource (ExitCaseSuccess b) + return (b, c) + +instance Exceptions.MonadCatch (STM s) where + catch = MonadThrow.catch + instance MonadCatch (IOSim s) where catch action handler = IOSim $ oneShot $ \k -> Catch (runIOSim action) (runIOSim . handler) k @@ -867,9 +893,22 @@ data StmTxResult s a = | StmTxAborted [SomeTVar s] SomeException --- | OrElse/Catch give rise to an alternate right hand side branch. A right branch --- can be a NoOp -data BranchStmA s a = OrElseStmA (StmA s a) | NoOpStmA +-- | A branch indicates that an alternative statement is available in the current +-- context. For example, `OrElse` has two alternative statements, say "left" +-- and "right". While executing the left statement, `OrElseStmA` branch indicates +-- that the right branch is still available, in case the left statement fails. +data BranchStmA s a = + -- | `OrElse` statement with its 'right' alternative. + OrElseStmA (StmA s a) + -- | `CatchStm` statement with the 'catch' handler. + | CatchStmA (SomeException -> StmA s a) + -- | Unlike the other two branches, the no-op branch is not an explicit + -- part of the STM syntax. It simply indicates that there are no + -- alternative statements left to be executed. For example, when running + -- right alternative of the `OrElse` statement or when running the catch + -- handler of a `CatchStm` statement, there are no alternative statements + -- available. This case is represented by the no-op branch. + | NoOpStmA data StmStack s b a where -- | Executing in the context of a top level 'atomically'. diff --git a/io-sim/src/Control/Monad/IOSimPOR/Internal.hs b/io-sim/src/Control/Monad/IOSimPOR/Internal.hs index 64020326..15ca22f8 100644 --- a/io-sim/src/Control/Monad/IOSimPOR/Internal.hs +++ b/io-sim/src/Control/Monad/IOSimPOR/Internal.hs @@ -1391,32 +1391,54 @@ execAtomically time tid tlbl nextVid0 action0 k0 = {-# SCC "execAtomically.go.ThrowStm" #-} do -- Revert all the TVar writes !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - k0 $ StmTxAborted (Map.elems read) (toException e) + case ctl of + AtomicallyFrame -> do + k0 $ StmTxAborted (Map.elems read) (toException e) + + BranchFrame (CatchStmA h) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + -- Execute the left side in a new frame with an empty written set. + -- but preserve ones that were set prior to it, as specified in the + -- [stm](https://hackage.haskell.org/package/stm/docs/Control-Monad-STM.html#v:catchSTM) package. + let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl' + go ctl'' read Map.empty [] [] nextVid (h e) + + BranchFrame (OrElseStmA _r) _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e) + + BranchFrame NoOpStmA _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid (ThrowStm e) + + CatchStm a h k -> + {-# SCC "execAtomically.go.ThrowStm" #-} do + -- Execute the left side in a new frame with an empty written set + let ctl' = BranchFrame (CatchStmA h) k written writtenSeq createdSeq ctl + go ctl' read Map.empty [] [] nextVid a Retry -> - {-# SCC "execAtomically.go.Retry" #-} - do - -- Always revert all the TVar writes for the retry - !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - case ctl of - AtomicallyFrame -> do - -- Return vars read, so the thread can block on them - k0 $! StmTxBlocked $! Map.elems read - - BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> - {-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do - !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written - -- Execute the orElse right hand with an empty written set - let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl' - go ctl'' read Map.empty [] [] nextVid b - - BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> - {-# SCC "execAtomically.go.BranchFrame" #-} do - -- Retry makes sense only within a OrElse context. If it is a branch other than - -- OrElse left side, then bubble up the `retry` to the frame above. - -- Skip the continuation and propagate the retry into the outer frame - -- using the written set for the outer frame - go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry + {-# SCC "execAtomically.go.Retry" #-} do + -- Always revert all the TVar writes for the retry + !_ <- traverse_ (\(SomeTVar tvar) -> revertTVar tvar) written + case ctl of + AtomicallyFrame -> do + -- Return vars read, so the thread can block on them + k0 $! StmTxBlocked $! Map.elems read + + BranchFrame (OrElseStmA b) k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame.OrElseStmA" #-} do + -- Execute the orElse right hand with an empty written set + let ctl'' = BranchFrame NoOpStmA k writtenOuter writtenOuterSeq createdOuterSeq ctl' + go ctl'' read Map.empty [] [] nextVid b + + BranchFrame _ _k writtenOuter writtenOuterSeq createdOuterSeq ctl' -> + {-# SCC "execAtomically.go.BranchFrame" #-} do + -- Retry makes sense only within a OrElse context. If it is a branch other than + -- OrElse left side, then bubble up the `retry` to the frame above. + -- Skip the continuation and propagate the retry into the outer frame + -- using the written set for the outer frame + go ctl' read writtenOuter writtenOuterSeq createdOuterSeq nextVid Retry OrElse a b k -> {-# SCC "execAtomically.go.OrElse" #-} do diff --git a/io-sim/test/Test/IOSim.hs b/io-sim/test/Test/IOSim.hs index fc0316fc..ff204914 100644 --- a/io-sim/test/Test/IOSim.hs +++ b/io-sim/test/Test/IOSim.hs @@ -1249,7 +1249,7 @@ prop_stm_referenceSim t = -- | Compare the behaviour of the STM reference operational semantics with -- the behaviour of any 'MonadSTM' STM implementation. -- -prop_stm_referenceM :: (MonadSTM m, MonadThrow (STM m), MonadCatch m) +prop_stm_referenceM :: (MonadSTM m, MonadCatch (STM m), MonadCatch m) => SomeTerm -> m Property prop_stm_referenceM (SomeTerm _tyrep t) = do let (r1, _heap) = evalAtomically t diff --git a/io-sim/test/Test/STM.hs b/io-sim/test/Test/STM.hs index 10c8a2d5..f818b705 100644 --- a/io-sim/test/Test/STM.hs +++ b/io-sim/test/Test/STM.hs @@ -67,6 +67,7 @@ data Term (t :: Type) where Return :: Expr t -> Term t Throw :: Expr a -> Term t + Catch :: Term t -> Term t -> Term t Retry :: Term t ReadTVar :: Name (TyVar t) -> Term t @@ -296,7 +297,7 @@ deriving instance Show (NfTerm t) -- | The STM transition rules. They reduce a 'Term' to a normal-form 'NfTerm'. -- -- Compare the implementation of this against the operational semantics in --- Figure 4 in the paper. Note that @catch@ is not included. +-- Figure 4 in the paper including the `Catch` semantics from the Appendix A. -- evalTerm :: Env -> Heap -> Allocs -> Term t -> (NfTerm t, Heap, Allocs) evalTerm !env !heap !allocs term = case term of @@ -309,6 +310,30 @@ evalTerm !env !heap !allocs term = case term of where e' = evalExpr env e + -- Exception semantics are detailed in "Appendix A Exception semantics" p 12-13 of + -- + Catch t1 t2 -> + let (nf1, heap', allocs') = evalTerm env heap mempty t1 in case nf1 of + + -- Rule XSTM1 + -- M; heap, {} => return P; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[return P]; heap', allocs U allocs' + NfReturn v -> (NfReturn v, heap', allocs <> allocs') + + -- Rule XSTM2 + -- M; heap, {} => throw P; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[N P]; heap U allocs', allocs U allocs' + NfThrow _ -> evalTerm env (heap <> allocs') (allocs <> allocs') t2 + + -- Rule XSTM3 + -- M; heap, {} => retry; heap', allocs' + -- -------------------------------------------------------- + -- S[catch M N]; heap, allocs => S[retry]; heap, allocs + NfRetry -> (NfRetry, heap, allocs) + + Retry -> (NfRetry, heap, allocs) -- Rule READ @@ -437,7 +462,7 @@ extendExecEnv (Name name _tyrep) v (ExecEnv env) = -- | Execute an STM 'Term' in the 'STM' monad. -- -execTerm :: (MonadSTM m, MonadThrow (STM m)) +execTerm :: (MonadSTM m, MonadCatch (STM m)) => ExecEnv m -> Term t -> STM m (ExecValue m t) @@ -451,6 +476,8 @@ execTerm env t = let e' = execExpr env e throwSTM =<< snapshotExecValue e' + Catch t1 t2 -> execTerm env t1 `catch` \(_ :: ImmValue) -> execTerm env t2 + Retry -> retry ReadTVar n -> do @@ -491,7 +518,7 @@ snapshotExecValue (ExecValInt x) = return (ImmValInt x) snapshotExecValue (ExecValVar v _) = fmap ImmValVar (snapshotExecValue =<< readTVar v) -execAtomically :: forall m t. (MonadSTM m, MonadThrow (STM m), MonadCatch m) +execAtomically :: forall m t. (MonadSTM m, MonadCatch (STM m), MonadCatch m) => Term t -> m TxResult execAtomically t = toTxResult <$> try (atomically action') @@ -657,7 +684,7 @@ genTerm env tyrep = Nothing) ] - binTerm = frequency [ (2, bindTerm), (1, orElseTerm)] + binTerm = frequency [ (2, bindTerm), (1, orElseTerm), (1, catchTerm)] bindTerm = sized $ \sz -> do @@ -671,10 +698,15 @@ genTerm env tyrep = return (Bind t1 name t2) orElseTerm = - sized $ \sz -> resize (sz `div` 2) $ + scale (`div` 2) $ OrElse <$> genTerm env tyrep <*> genTerm env tyrep + catchTerm = + scale (`div` 2) $ + Catch <$> genTerm env tyrep + <*> genTerm env tyrep + genSomeExpr :: GenEnv -> Gen SomeExpr genSomeExpr env = oneof' @@ -713,6 +745,8 @@ shrinkTerm t = case t of Return e -> [Return e' | e' <- shrinkExpr e] Throw e -> [Throw e' | e' <- shrinkExpr e] + Catch t1 t2 -> [t1, t2] + ++ [Catch t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2)] Retry -> [] ReadTVar _ -> [] @@ -721,12 +755,10 @@ shrinkTerm t = NewTVar e -> [NewTVar e' | e' <- shrinkExpr e] Bind t1 n t2 -> [ t2 | nameId n `Set.notMember` freeNamesTerm t2 ] - ++ [ Bind t1' n t2 | t1' <- shrinkTerm t1 ] - ++ [ Bind t1 n t2' | t2' <- shrinkTerm t2 ] + ++ [ Bind t1' n t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ] OrElse t1 t2 -> [t1, t2] - ++ [ OrElse t1' t2 | t1' <- shrinkTerm t1 ] - ++ [ OrElse t1 t2' | t2' <- shrinkTerm t2 ] + ++ [ OrElse t1' t2' | (t1', t2') <- liftShrink2 shrinkTerm shrinkTerm (t1, t2) ] shrinkExpr :: Expr t -> [Expr t] shrinkExpr ExprUnit = [] @@ -738,6 +770,10 @@ shrinkExpr (ExprName (Name _ (TyRepVar _))) = [] freeNamesTerm :: Term t -> Set NameId freeNamesTerm (Return e) = freeNamesExpr e freeNamesTerm (Throw e) = freeNamesExpr e +-- The current generator of catch term ignores the argument passed to the +-- handler. +-- TODO: Correctly handle free names when the handler also binds a variable. +freeNamesTerm (Catch t1 t2) = freeNamesTerm t1 <> freeNamesTerm t2 freeNamesTerm Retry = Set.empty freeNamesTerm (ReadTVar n) = Set.singleton (nameId n) freeNamesTerm (WriteTVar n e) = Set.singleton (nameId n) <> freeNamesExpr e @@ -768,6 +804,7 @@ prop_genSomeTerm (SomeTerm tyrep term) = termSize :: Term a -> Int termSize Return{} = 1 termSize Throw{} = 1 +termSize (Catch a b) = 1 + termSize a + termSize b termSize Retry{} = 1 termSize ReadTVar{} = 1 termSize WriteTVar{} = 1 @@ -778,6 +815,7 @@ termSize (OrElse a b) = 1 + termSize a + termSize b termDepth :: Term a -> Int termDepth Return{} = 1 termDepth Throw{} = 1 +termDepth (Catch a b) = 1 + max (termDepth a) (termDepth b) termDepth Retry{} = 1 termDepth ReadTVar{} = 1 termDepth WriteTVar{} = 1 @@ -790,6 +828,9 @@ showTerm p (Return e) = showParen (p > 10) $ showString "return " . showExpr 11 e showTerm p (Throw e) = showParen (p > 10) $ showString "throwSTM " . showExpr 11 e +showTerm p (Catch t1 t2) = showParen (p > 9) $ + showTerm 10 t1 . showString " `catch` " + . showTerm 10 t2 showTerm _ Retry = showString "retry" showTerm p (ReadTVar n) = showParen (p > 10) $ showString "readTVar " . showName n