88{-# LANGUAGE RecordWildCards #-}
99{-# LANGUAGE ScopedTypeVariables #-}
1010{-# LANGUAGE TypeFamilies #-}
11+ {-# LANGUAGE TupleSections #-}
1112
1213module Development.IDE.Graph.Internal.Database (newDatabase , incDatabase , build , getDirtySet , getKeysAndVisitAge ) where
1314
@@ -32,14 +33,15 @@ import Data.IORef.Extra
3233import Data.Maybe
3334import Data.Traversable (for )
3435import Data.Tuple.Extra
36+ import Debug.Trace (traceM )
3537import Development.IDE.Graph.Classes
3638import Development.IDE.Graph.Internal.Rules
3739import Development.IDE.Graph.Internal.Types
3840import qualified Focus
3941import qualified ListT
4042import qualified StmContainers.Map as SMap
43+ import System.Time.Extra (duration , sleep )
4144import System.IO.Unsafe
42- import System.Time.Extra (duration )
4345
4446newDatabase :: Dynamic -> TheRules -> IO Database
4547newDatabase databaseExtra databaseRules = do
@@ -120,7 +122,7 @@ builder db@Database{..} stack keys = withRunInIO $ \(RunInIO run) -> do
120122 pure (id , val)
121123
122124 toForceList <- liftIO $ readTVarIO toForce
123- let waitAll = run $ mapConcurrentlyAIO_ id toForceList
125+ let waitAll = run $ waitConcurrently_ toForceList
124126 case toForceList of
125127 [] -> return $ Left results
126128 _ -> return $ Right $ do
@@ -170,6 +172,10 @@ compute db@Database{..} stack key mode result = do
170172 deps | not (null deps)
171173 && runChanged /= ChangedNothing
172174 -> do
175+ -- IMPORTANT: record the reverse deps **before** marking the key Clean.
176+ -- If an async exception strikes before the deps have been recorded,
177+ -- we won't be able to accurately propagate dirtiness for this key
178+ -- on the next build.
173179 void $
174180 updateReverseDeps key db
175181 (getResultDepsDefault [] previousDeps)
@@ -224,7 +230,8 @@ updateReverseDeps
224230 -> [Key ] -- ^ Previous direct dependencies of Id
225231 -> HashSet Key -- ^ Current direct dependencies of Id
226232 -> IO ()
227- updateReverseDeps myId db prev new = uninterruptibleMask_ $ do
233+ -- mask to ensure that all the reverse dependencies are updated
234+ updateReverseDeps myId db prev new = do
228235 forM_ prev $ \ d ->
229236 unless (d `HSet.member` new) $
230237 doOne (HSet. delete myId) d
@@ -252,20 +259,27 @@ transitiveDirtySet database = flip State.execStateT HSet.empty . traverse_ loop
252259 next <- lift $ atomically $ getReverseDependencies database x
253260 traverse_ loop (maybe mempty HSet. toList next)
254261
255- -- | IO extended to track created asyncs to clean them up when the thread is killed,
256- -- generalizing 'withAsync'
262+ --------------------------------------------------------------------------------
263+ -- Asynchronous computations with cancellation
264+
265+ -- | A simple monad to implement cancellation on top of 'Async',
266+ -- generalizing 'withAsync' to monadic scopes.
257267newtype AIO a = AIO { unAIO :: ReaderT (IORef [Async () ]) IO a }
258268 deriving newtype (Applicative , Functor , Monad , MonadIO )
259269
270+ -- | Run the monadic computation, cancelling all the spawned asyncs if an exception arises
260271runAIO :: AIO a -> IO a
261272runAIO (AIO act) = do
262273 asyncs <- newIORef []
263274 runReaderT act asyncs `onException` cleanupAsync asyncs
264275
276+ -- | Like 'async' but with built-in cancellation.
277+ -- Returns an IO action to wait on the result.
265278asyncWithCleanUp :: AIO a -> AIO (IO a )
266279asyncWithCleanUp act = do
267280 st <- AIO ask
268281 io <- unliftAIO act
282+ -- mask to make sure we keep track of the spawned async
269283 liftIO $ uninterruptibleMask $ \ restore -> do
270284 a <- async $ restore io
271285 atomicModifyIORef'_ st (void a : )
@@ -284,27 +298,40 @@ withRunInIO k = do
284298 k $ RunInIO (\ aio -> runReaderT (unAIO aio) st)
285299
286300cleanupAsync :: IORef [Async a ] -> IO ()
287- cleanupAsync ref = uninterruptibleMask_ $ do
288- asyncs <- readIORef ref
301+ -- mask to make sure we interrupt all the asyncs
302+ cleanupAsync ref = uninterruptibleMask $ \ unmask -> do
303+ asyncs <- atomicModifyIORef' ref ([] ,)
304+ -- interrupt all the asyncs without waiting
289305 mapM_ (\ a -> throwTo (asyncThreadId a) AsyncCancelled ) asyncs
290- mapM_ waitCatch asyncs
306+ -- Wait until all the asyncs are done
307+ -- But if it takes more than 10 seconds, log to stderr
308+ unless (null asyncs) $ do
309+ let warnIfTakingTooLong = unmask $ forever $ do
310+ sleep 10
311+ traceM " cleanupAsync: waiting for asyncs to finish"
312+ withAsync warnIfTakingTooLong $ \ _ ->
313+ mapM_ waitCatch asyncs
314+
315+ data Wait
316+ = Wait { justWait :: ! (IO () )}
317+ | Spawn { justWait :: ! (IO () )}
291318
292- data Wait a
293- = Wait { justWait :: ! a }
294- | Spawn { justWait :: ! a }
295- deriving Functor
319+ fmapWait :: (IO () -> IO () ) -> Wait -> Wait
320+ fmapWait f (Wait io) = Wait (f io)
321+ fmapWait f (Spawn io) = Spawn (f io)
296322
297- waitOrSpawn :: Wait ( IO a ) -> IO (Either (IO a ) (Async a ))
323+ waitOrSpawn :: Wait -> IO (Either (IO () ) (Async () ))
298324waitOrSpawn (Wait io) = pure $ Left io
299325waitOrSpawn (Spawn io) = Right <$> async io
300326
301- mapConcurrentlyAIO_ :: ( a -> IO () ) -> [Wait a ] -> AIO ()
302- mapConcurrentlyAIO_ _ [] = pure ()
303- mapConcurrentlyAIO_ f [one] = liftIO $ justWait $ fmap f one
304- mapConcurrentlyAIO_ f many = do
327+ waitConcurrently_ :: [Wait ] -> AIO ()
328+ waitConcurrently_ [] = pure ()
329+ waitConcurrently_ [one] = liftIO $ justWait one
330+ waitConcurrently_ many = do
305331 ref <- AIO ask
306- waits <- liftIO $ uninterruptibleMask $ \ restore -> do
307- waits <- liftIO $ traverse (waitOrSpawn . fmap (restore . f)) many
332+ -- mask to make sure we keep track of all the asyncs
333+ waits <- liftIO $ uninterruptibleMask $ \ unmask -> do
334+ waits <- liftIO $ traverse (waitOrSpawn . fmapWait unmask) many
308335 let asyncs = rights waits
309336 liftIO $ atomicModifyIORef'_ ref (asyncs ++ )
310337 return waits
0 commit comments