-- |
-- Module      : Crypto.Cipher.Types.BlockIO
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : Stable
-- Portability : Excellent
--
-- Mutable, pointer-based block cipher API
--
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ViewPatterns #-}
module Crypto.Cipher.Types.BlockIO
    ( BlockCipherIO(..)
    , PtrDest
    , PtrSource
    , PtrIV
    , BufferLength
    , onBlock
    ) where

import Control.Applicative
import Data.Bits (xor, Bits)
import Data.Word
import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Ptr
--import Foreign.Ptr (plusPtr, Ptr, castPtr, nullPtr)
import Foreign.Storable (poke, peek, Storable)

import Data.ByteString (ByteString)
import qualified Data.ByteString.Internal as B (fromForeignPtr, memcpy)
import Data.Byteable

import Crypto.Cipher.Types.Block

-- | Address of the buffer an operation writes its output bytes to.
type PtrDest      = Ptr Word8

-- | Address of the buffer an operation reads its input bytes from.
type PtrSource    = Ptr Word8

-- | Address of the initialization vector (IV) bytes.
type PtrIV        = Ptr Word8

-- | Size of a pointed-to buffer, counted in bytes.
type BufferLength = Word32

-- | Symmetric block cipher class, mutable (pointer-based) API.
--
-- Every operation reads from a caller-supplied source buffer and writes to a
-- caller-supplied destination buffer; lengths are counted in bytes.
class BlockCipher cipher => BlockCipherIO cipher where
    -- | Encrypt using the ECB mode.
    --
    -- The input length needs to be a multiple of the block size.
    ecbEncryptMutable :: cipher -> PtrDest -> PtrSource -> BufferLength -> IO ()

    -- | Decrypt using the ECB mode.
    --
    -- The input length needs to be a multiple of the block size.
    ecbDecryptMutable :: cipher -> PtrDest -> PtrSource -> BufferLength -> IO ()

    -- | Encrypt using the CBC mode.
    --
    -- The input length needs to be a multiple of the block size.
    -- Defaults to 'cbcEncryptGeneric'.
    cbcEncryptMutable :: cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
    cbcEncryptMutable = cbcEncryptGeneric

    -- | Decrypt using the CBC mode.
    --
    -- The input length needs to be a multiple of the block size.
    -- Defaults to 'cbcDecryptGeneric'.
    cbcDecryptMutable :: cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
    cbcDecryptMutable = cbcDecryptGeneric

{-
    -- | encrypt using the CFB mode.
    --
    -- input need to be a multiple of the blocksize
    cfbEncryptMutable :: cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
    cfbEncryptMutable = cfbEncryptGeneric

    -- | decrypt using the CFB mode.
    --
    -- input need to be a multiple of the blocksize
    cfbDecryptMutable :: cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
    cfbDecryptMutable = cfbDecryptGeneric

    -- | combine using the CTR mode.
    --
    -- CTR mode produce a stream of randomized data that is combined
    -- (by XOR operation) with the input stream.
    --
    -- encryption and decryption are the same operation.
    --
    -- input can be of any size
    ctrCombineMutable :: cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
    ctrCombineMutable = ctrCombineGeneric

    -- | encrypt using the XTS mode.
    --
    -- input need to be a multiple of the blocksize
    xtsEncryptMutable :: (cipher, cipher) -> PtrIV -> DataUnitOffset -> PtrDest -> PtrSource -> BufferLength -> IO ()
    xtsEncryptMutable = xtsEncryptGeneric
    -- | decrypt using the XTS mode.
    --
    -- input need to be a multiple of the blocksize
    xtsDecryptMutable :: (cipher, cipher) -> PtrIV -> DataUnitOffset -> PtrDest -> PtrSource -> BufferLength -> IO ()
    xtsDecryptMutable = xtsDecryptGeneric
-}

-- | Generic CBC encryption built on 'ecbEncryptMutable'.
--
-- For each block: @C_i = E(P_i xor C_{i-1})@, where @C_0@ is the IV.
-- The caller's IV buffer is only read, never written.
--
-- NOTE(review): each block is encrypted in place (destination passed as both
-- source and destination of 'ecbEncryptMutable'); this assumes instances
-- accept identical source and destination pointers.
cbcEncryptGeneric :: BlockCipherIO cipher => cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
cbcEncryptGeneric cipher = loopBS cipher encrypt
  where
    encrypt :: Int -> PtrIV -> PtrDest -> PtrSource -> IO PtrIV
    encrypt bs iv d s = do
        -- d <- P_i xor C_{i-1}
        mutableXor d iv s bs
        -- d <- E(d)
        ecbEncryptMutable cipher d d (fromIntegral bs)
        -- CBC chains on the ciphertext block just produced, not the plaintext.
        return d

-- | Generic CBC decryption built on 'ecbDecryptMutable'.
--
-- For each block: @P_i = D(C_i) xor C_{i-1}@, where @C_0@ is the IV.
--
-- Each block is decrypted into a temporary buffer, and the ciphertext block is
-- saved into a private chaining buffer before the plaintext is written to the
-- destination. Decrypting in place (destination equal to source) is therefore
-- supported. The caller's IV buffer is only read, never written.
cbcDecryptGeneric :: BlockCipherIO cipher => cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
cbcDecryptGeneric cipher iv dst src len =
    allocaBytes bs $ \plain ->
    allocaBytes bs $ \chain ->
        loopBS cipher (decrypt plain chain) iv dst src len
  where
    bs :: Int
    bs = blockSize cipher

    -- One CBC decryption step; returns the chaining value for the next block.
    decrypt :: PtrIV -> PtrIV -> Int -> PtrIV -> PtrDest -> PtrSource -> IO PtrIV
    decrypt plain chain n prevIV d s = do
        -- plain <- D(C_i)
        ecbDecryptMutable cipher plain s (fromIntegral n)
        -- plain <- D(C_i) xor C_{i-1}. This must happen before 'chain' is
        -- overwritten, because from the second block on prevIV == chain.
        mutableXor plain plain prevIV n
        -- Save C_i before d (which may be the same buffer as s) is written.
        B.memcpy chain s (fromIntegral n)
        B.memcpy d plain (fromIntegral n)
        return chain

{-
cfbEncryptGeneric :: BlockCipherIO cipher => cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
cfbEncryptGeneric cipher = loopBS cipher encrypt
  where encrypt bs iv d s = do
            ecbEncryptMutable cipher d iv (fromIntegral bs)
            mutableXor d d s bs
            return d


cfbDecryptGeneric :: BlockCipherIO cipher => cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
cfbDecryptGeneric cipher = loopBS cipher decrypt
  where decrypt bs iv d s = do
            ecbEncryptMutable cipher d iv (fromIntegral bs)
            mutableXor d d s bs
            return s

ctrCombineGeneric :: BlockCipherIO cipher => cipher -> PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
ctrCombineGeneric cipher ivini dst src len = return () {-B.concat $ doCnt ivini $ chunk (blockSize cipher) input
  where doCnt _  [] = []
        doCnt iv (i:is) =
            let ivEnc = ecbEncrypt cipher (toBytes iv)
             in bxor i ivEnc : doCnt (ivAdd iv 1) is-}
-}

-- | Helper to use a purer interface: apply a pure block transformation to a
-- buffer, one cipher block at a time.
--
-- Each source block is exposed to @f@ as a 'ByteString' without copying, and
-- @f@'s result is copied into the matching destination block. @f@ must
-- return exactly one block; any other length raises an 'IOError'. The check
-- prevents reading past the end of a shorter result and silently truncating
-- a longer one.
--
-- NOTE(review): the 'ByteString' given to @f@ aliases the source buffer, so
-- @f@ must not retain it beyond the call.
onBlock :: BlockCipherIO cipher
        => cipher
        -> (ByteString -> ByteString)
        -> PtrDest
        -> PtrSource
        -> BufferLength
        -> IO ()
onBlock cipher f dst src len = loopBS cipher wrap nullPtr dst src len
  where
    -- The IV is irrelevant here; it is threaded through unchanged.
    wrap bs fakeIv d s = do
        fSrc <- newForeignPtr_ s
        let res = f (B.fromForeignPtr fSrc 0 bs)
        if byteableLength res /= bs
            then ioError (userError "onBlock: function did not return exactly one block")
            else withBytePtr res $ \r -> B.memcpy d r (fromIntegral bs)
        return fakeIv

-- | Run a per-block step over a buffer, one cipher block at a time.
--
-- @f bs iv d s@ processes the @bs@-byte block at @s@ and writes its output
-- at @d@. It returns the pointer to use as the IV for the next block. The
-- first call receives the IV passed to 'loopBS'.
--
-- The length must be a multiple of the cipher's block size. Otherwise an
-- 'IOError' is raised before any block is touched. Checking up front means
-- the 'Word32' byte counter can never underflow past zero.
loopBS :: BlockCipherIO cipher
       => cipher
       -> (Int -> PtrIV -> PtrDest -> PtrSource -> IO PtrIV)
       -> PtrIV -> PtrDest -> PtrSource -> BufferLength
       -> IO ()
loopBS cipher f iv dst src len
    | len `mod` bsLen /= 0 =
        ioError (userError "loopBS: buffer length is not a multiple of the block size")
    | otherwise = loop iv dst src len
  where
    bs :: Int
    bs = blockSize cipher

    bsLen :: BufferLength
    bsLen = fromIntegral bs

    loop :: PtrIV -> PtrDest -> PtrSource -> BufferLength -> IO ()
    loop _ _ _ 0 = return ()
    loop i d s n = do
        newIV <- f bs i d s
        loop newIV (d `plusPtr` bs) (s `plusPtr` bs) (n - bsLen)

-- | @mutableXor dst src iv len@ stores @src[k] xor iv[k]@ into @dst[k]@ for
-- each of the first @len@ bytes.
--
-- Lengths of 8 and 16 bytes use a 'Word64'-at-a-time fast path, but only
-- when all three pointers are 8-byte aligned. 'Storable' 'peek'/'poke' may
-- require aligned addresses on some architectures. Every other case is
-- handled byte by byte.
--
-- Any of the three pointers may be the same buffer, because each position is
-- read before it is written. A non-positive length does nothing.
mutableXor :: PtrDest -> PtrSource -> PtrIV -> Int -> IO ()
mutableXor dst src iv len
    | len == 16 && allAligned = do
        peeksAndPoke dst64 src64 iv64
        peeksAndPoke (dst64 `plusPtr` 8) (src64 `plusPtr` 8) (iv64 `plusPtr` 8 :: Ptr Word64)
    | len == 8 && allAligned = peeksAndPoke dst64 src64 iv64
    | otherwise = loop dst src iv len
  where
    dst64 = to64 dst
    src64 = to64 src
    iv64  = to64 iv

    allAligned = all (\p -> alignPtr p 8 == p) [dst, src, iv]

    loop :: PtrDest -> PtrSource -> PtrIV -> Int -> IO ()
    loop d s i n
        | n <= 0    = return ()
        | otherwise = do
            peeksAndPoke d s i
            loop (d `plusPtr` 1) (s `plusPtr` 1) (i `plusPtr` 1) (n - 1)

-- | Reinterpret a byte pointer as a pointer to 64-bit words, keeping the
-- same address.
to64 :: Ptr Word8 -> Ptr Word64
to64 ptr = castPtr ptr

-- | Read one value from each of the two source pointers (first @a@, then
-- @b@), XOR them, and write the result to @dst@.
peeksAndPoke :: (Bits a, Storable a) => Ptr a -> Ptr a -> Ptr a -> IO ()
peeksAndPoke dst a b = do
    x <- peek a
    y <- peek b
    poke dst (x `xor` y)