-- | This module implements a method to ingest a sequence of "Data.Binary"
-- encoded records using bounded memory. Minimal example:
--
-- > {-# LANGUAGE TypeApplications #-}
-- >
-- > import Data.Function ((&))
-- > import qualified Data.ByteString.Streaming as Q
-- > import Streaming
-- > import Streaming.Binary
-- > import qualified Streaming.Prelude as S
-- >
-- > -- Interpret all bytes on stdin as a sequence of integers.
-- > -- Print them on-the-fly on stdout.
-- > main = Q.getContents & decoded @Int & S.print

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

module Streaming.Binary
  ( decode
  , decodeWith
  , decoded
  , decodedWith
  , encode
  , encodeWith
  , encoded
  , encodedWith
  ) where

import qualified Data.Binary.Get as Binary
import qualified Data.Binary.Put as Binary
import Data.Binary (Binary(..))
import qualified Data.ByteString.Builder.Extra as BS
import qualified Data.ByteString.Streaming as Q
import Data.ByteString.Streaming (ByteString)
import Data.Int (Int64)
import Streaming
import qualified Streaming.Prelude as S

-- | Decode a single element from a streaming bytestring, using the element's
-- 'Binary' instance.
--
-- Returns the leftover input (which can be handed back to 'decode' to read the
-- next element), the number of bytes consumed, and either an error string or
-- the element if decoding succeeded.
decode
  :: (Binary a, Monad m)
  => ByteString m r
  -> m (ByteString m r, Int64, Either String a)
decode = decodeWith get

-- | Like 'decode', but with an explicitly provided decoder.
--
-- Chunks are pulled from the input one at a time and fed to the incremental
-- decoder until it either finishes or fails. Whatever the decoder did not
-- consume is pushed back onto the front of the remaining input.
decodeWith
  :: Monad m
  => Binary.Get a
  -> ByteString m r
  -> m (ByteString m r, Int64, Either String a)
decodeWith getter = feed (Binary.runGetIncremental getter)
  where
    -- Only one decoder runs, starting at offset 0, so the byte count it
    -- reports when it stops is returned directly.
    feed decoder input = case decoder of
      Binary.Fail rest consumed msg ->
        return (Q.chunk rest >> input, consumed, Left msg)
      Binary.Done rest consumed value ->
        return (Q.chunk rest >> input, consumed, Right value)
      Binary.Partial resume ->
        Q.nextChunk input >>= either
          -- End of input: tell the decoder no more bytes are coming.
          (\res -> feed (resume Nothing) (return res))
          (\(bs, input') -> feed (resume (Just bs)) input')


-- | Decode a sequence of elements from a streaming bytestring, using the
-- elements' 'Binary' instance. Each element is yielded as soon as it has been
-- decoded.
--
-- Returns any leftover input, the number of bytes consumed, and either an
-- error string or the input's return value if there were no errors. Decoding
-- stops at the first error.
decoded
  :: (Binary a, Monad m)
  => ByteString m r
  -> Stream (Of a) m (ByteString m r, Int64, Either String r)
decoded = decodedWith get

-- | Like 'decoded', but with an explicitly provided decoder.
--
-- Elements are yielded as soon as they are decoded, and decoding stops at the
-- first error or when the input is exhausted.
--
-- * Running out of input exactly on an element boundary, including when the
--   input is empty, is a success: the result is @Right r@, where @r@ is the
--   input's own return value.
-- * Otherwise the decoder's error is returned as @Left err@, together with
--   the bytes it did not consume.
--
-- The returned byte count accumulates over every element decoded.
decodedWith
  :: Monad m
  => Binary.Get a
  -> ByteString m r
  -> Stream (Of a) m (ByteString m r, Int64, Either String r)
decodedWith getter = atBoundary 0
  where
    decoder0 = Binary.runGetIncremental getter

    -- No byte of the next element has been seen yet. This is the only point
    -- where end of input means success, so check for it before starting a
    -- fresh decoder. (Starting the decoder first would resume it with
    -- 'Nothing' on empty input and report a spurious error.)
    atBoundary !total p =
      lift (Q.nextChunk p) >>= \case
        Left res -> return (return res, total, Right res)
        Right (bs, p') -> go total (decoder0 `Binary.pushChunk` bs) p'

    go !total (Binary.Fail leftover nconsumed err) p =
      -- Give back the bytes the decoder did not consume.
      return (Q.chunk leftover >> p, total + nconsumed, Left err)
    go !total (Binary.Done "" nconsumed x) p = do
      S.yield x
      atBoundary (total + nconsumed) p
    go !total (Binary.Done leftover nconsumed x) p = do
      S.yield x
      -- The leftover bytes are the start of the next element.
      go (total + nconsumed) (decoder0 `Binary.pushChunk` leftover) p
    go !total (Binary.Partial k) p =
      lift (Q.nextChunk p) >>= \case
        Left res -> go total (k Nothing) (return res)
        Right (bs, p') -> go total (k (Just bs)) p'

-- | Encode a single element to a streaming bytestring, using its 'Binary'
-- instance.
encode
  :: (Binary a, MonadIO m)
  => a
  -> ByteString m ()
encode = encodeWith put

-- | Like 'encode', but with an explicitly provided encoder.
--
-- The serialized 'Builder' is rendered into chunks with an untrimmed
-- allocation strategy. The first buffer is 'BS.smallChunkSize' bytes and
-- later buffers are 'BS.defaultChunkSize' bytes.
encodeWith
  :: MonadIO m
  => (a -> Binary.Put)
  -> a
  -> ByteString m ()
encodeWith putter x = Q.toStreamingByteStringWith strategy builder
  where
    builder = Binary.execPut (putter x)
    strategy = BS.untrimmedStrategy BS.smallChunkSize BS.defaultChunkSize

-- | Encode a stream of elements to a streaming bytestring, using the
-- elements' 'Binary' instance.
encoded
  :: (Binary a, MonadIO m)
  => Stream (Of a) IO ()
  -> ByteString m ()
encoded = encodedWith put

-- | Like 'encoded', but with an explicitly provided encoder.
--
-- The input stream must run in 'IO'. Each element is serialized to a
-- 'Builder', and the builders are joined with 'Q.concatBuilders', which only
-- accepts a stream over 'IO'. The rendered bytestring is then lifted into @m@.
-- Chunks use an untrimmed allocation strategy: a first buffer of
-- 'BS.smallChunkSize' bytes, then buffers of 'BS.defaultChunkSize' bytes.
encodedWith
  :: MonadIO m
  => (a -> Binary.Put)
  -> Stream (Of a) IO ()
  -> ByteString m ()
encodedWith putter xs =
    hoist liftIO (Q.toStreamingByteStringWith strategy builder)
  where
    builders = S.map (Binary.execPut . putter) xs
    builder = Q.concatBuilders builders
    strategy = BS.untrimmedStrategy BS.smallChunkSize BS.defaultChunkSize