{-# LANGUAGE RankNTypes #-}

-- |
-- A utility module which lets you put a concurrency limit on an action
-- running in 'IO' or in any monad with 'MonadIO' and 'MonadMask' instances.
module Control.Concurrent.Throttled
  ( Throttle,
    newThrottle,
    throttled,
  )
where

import Control.Concurrent.QSem
import Control.Monad.Catch
import Control.Monad.IO.Unlift

-- | A handle that bounds how many actions wrapped with 'throttled' may run
-- at the same time. Create one with 'newThrottle'.
--
-- The constructor is not exported; the only way to use a 'Throttle' is
-- through the 'throttled' field.
newtype Throttle = Throttle
  { -- | Run an action under this throttle's concurrency limit.
    --
    -- The field is rank-N polymorphic, so a single 'Throttle' can wrap
    -- actions in any monad satisfying the constraints, not just the monad
    -- it was created in.
    throttled :: forall m a. (MonadIO m, MonadMask m) => m a -> m a
  }

-- | Create a 'Throttle' that lets at most @n@ throttled actions run at once.
--
-- For a positive @n@, each call to 'throttled' waits on a 'QSem' initialised
-- to @n@ before running its action, and signals it afterwards via
-- 'bracket_'. The slot is therefore released even when the action throws.
--
-- For a non-positive @n@, the returned throttle imposes no limit at all:
-- 'throttled' is 'id'. A 'QSem' is not created in that case. 'newQSem'
-- rejects negative quantities, and a quantity of zero would make every
-- 'throttled' call wait indefinitely.
newThrottle :: MonadIO m => Int -> m Throttle
newThrottle n
  | n <= 0 = pure (Throttle id)
  | otherwise = do
      sem <- liftIO (newQSem n)
      -- The argument to 'Throttle' must stay polymorphic in the inner monad.
      -- It is built inline rather than let-bound so the monomorphism
      -- restriction cannot pin it to a single monad.
      pure $
        Throttle
          ( bracket_
              (liftIO (waitQSem sem))
              (liftIO (signalQSem sem))
          )