{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Data.AFIS
( split
, merge
) where
import Crypto.Hash
import Crypto.Random.Types
import Crypto.Internal.Compat
import Control.Monad (forM_, foldM)
import Data.Word
import Data.Bits
import Foreign.Storable
import Foreign.Ptr
import Crypto.Internal.ByteArray (ByteArray, Bytes, MemView(..))
import qualified Crypto.Internal.ByteArray as B
import Data.Memory.PtrMethods (memSet, memCopy)
-- | Split a secret into @expandTimes@ blocks of anti-forensic diffused data.
--
-- The first @expandTimes - 1@ blocks are filled with bytes drawn from the
-- random generator.  The final block is the original data XORed with an
-- accumulator.  The accumulator starts at all zeros; each random block is
-- XORed into it, and the accumulator is diffused with the hash algorithm
-- after every step:
--
-- > acc_0     = 0
-- > acc_(k+1) = diffuse (acc_k `xor` rand_k)
-- > last      = src `xor` acc_(expandTimes-1)
--
-- The output is @expandTimes * length src@ bytes long.  Calls 'error' when
-- @expandTimes <= 1@.
split :: (ByteArray ba, HashAlgorithm hash, DRG rng)
      => hash       -- ^ Hash algorithm used as diffuser
      -> rng        -- ^ Random generator supplying the random blocks
      -> Int        -- ^ Number of blocks to produce (must be > 1)
      -> ba         -- ^ Original data to split
      -> (ba, rng)  -- ^ Diffused data and the advanced generator
{-# NOINLINE split #-}
split hashAlg rng expandTimes src
    | expandTimes > 1 = unsafeDoIO $ do
        (rngOut, diffused) <- B.allocRet (blockLen * expandTimes) fill
        return (diffused, rngOut)
    | otherwise       = error "invalid expandTimes value"
  where
    blockLen = B.length src

    -- Lay out the output buffer: random blocks first, accumulator last.
    fill outPtr = do
        let accPtr   = outPtr `plusPtr` (blockLen * (expandTimes - 1))
            randPtrs = [ outPtr `plusPtr` (k * blockLen) | k <- [0 .. expandTimes - 2] ]
        memSet accPtr 0 blockLen
        -- First pass: fill every random block, threading the generator.
        gen' <- foldM writeRandom rng randPtrs
        -- Second pass: fold each random block into the accumulator.
        forM_ randPtrs $ \p -> do
            xorMem p accPtr blockLen
            diffuse hashAlg accPtr blockLen
        B.withByteArray src $ \srcPtr -> xorMem srcPtr accPtr blockLen
        return gen'

    -- Copy one block's worth of fresh random bytes to @dst@ and return the
    -- advanced generator.
    writeRandom gen dst = do
        let (bytes :: Bytes, gen') = randomBytesGenerate blockLen gen
        B.withByteArray bytes $ \bytesPtr -> memCopy dst bytesPtr blockLen
        return gen'
-- | Merge previously diffused data back into the original data.
--
-- This inverts 'split'.  It starts from an all-zero accumulator.  Each of the
-- first @expandTimes - 1@ blocks is XORed into the accumulator, which is then
-- diffused with the hash algorithm.  Finally the last block is XORed in to
-- recover the original bytes.  The result is @length bs / expandTimes@ bytes
-- long.
--
-- Calls 'error' when @expandTimes@ is not positive, when the input length
-- is not a multiple of @expandTimes@, or when the recovered data would be
-- empty.
merge :: (ByteArray ba, HashAlgorithm hash)
      => hash   -- ^ Hash algorithm used as diffuser
      -> Int    -- ^ Number of blocks the data was split into
      -> ba     -- ^ Diffused data
      -> ba     -- ^ Original data
{-# NOINLINE merge #-}
merge hashAlg expandTimes bs
    -- Validate expandTimes before anything forces the 'quotRem' below.
    -- Otherwise a zero surfaces as an opaque divide-by-zero exception, and a
    -- negative value surfaces as one of the other, misleading, messages.
    | expandTimes <= 0  = error "invalid expandTimes value"
    | r /= 0            = error "diffused data not a multiple of expandTimes"
    | originalSize <= 0 = error "diffused data null"
    | otherwise         = B.allocAndFreeze originalSize $ \dstPtr ->
        B.withByteArray bs $ \srcPtr -> do
            let blockAt i = srcPtr `plusPtr` (i * originalSize)
            memSet dstPtr 0 originalSize
            -- Replay the diffusion chain over the random blocks ...
            forM_ [0 .. expandTimes - 2] $ \i -> do
                xorMem (blockAt i) dstPtr originalSize
                diffuse hashAlg dstPtr originalSize
            -- ... then strip the accumulator from the final block.
            xorMem (blockAt (expandTimes - 1)) dstPtr originalSize
  where
    (originalSize, r) = B.length bs `quotRem` expandTimes
-- | In-place XOR of two buffers: for @0 <= i < sz@,
-- @dst[i] <- src[i] `xor` dst[i]@.  @src@ is only read.
--
-- 64-bit or 32-bit loads and stores are used only when @sz@ is a whole
-- number of words /and/ both pointers are aligned to the word width.  Word
-- accesses are therefore never misaligned; Foreign.Storable warns that
-- misaligned 'peek'/'poke' is architecture dependent.  XOR is byte-wise, so
-- the chosen width never changes the result.  A non-positive @sz@ is a
-- no-op.
xorMem :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
xorMem src dst sz
    | sz <= 0   = return ()
    | fits 8    = loop 8 (castPtr src :: Ptr Word64) (castPtr dst) sz
    | fits 4    = loop 4 (castPtr src :: Ptr Word32) (castPtr dst) sz
    | otherwise = loop 1 src dst sz
  where
    -- A word width @w@ is usable when the size is a multiple of @w@ and both
    -- pointers sit on a @w@-byte boundary.
    fits w = sz `mod` w == 0 && aligned w src && aligned w dst
    aligned w p = alignPtr p w == p

    -- XOR @n@ bytes, @incr@ bytes (one element of type @b@) at a time.
    -- The @<= 0@ test guarantees termination.
    loop :: (Storable b, Bits b) => Int -> Ptr b -> Ptr b -> Int -> IO ()
    loop incr s d n
        | n <= 0    = return ()
        | otherwise = do
            x <- peek s
            y <- peek d
            poke d (x `xor` y)
            loop incr (s `plusPtr` incr) (d `plusPtr` incr) (n - incr)
-- | Diffuse a buffer in place with a hash function.
--
-- The buffer is cut into consecutive chunks of the hash's digest size; the
-- last chunk may be shorter.  Chunk number @i@ is overwritten with the
-- leading bytes of @hash (BE32(i) ++ chunk_i)@.  Here @BE32(i)@ is @i@
-- encoded as 4 big-endian bytes.
diffuse :: HashAlgorithm hash
        => hash       -- ^ Hash function to use as diffuser
        -> Ptr Word8  -- ^ Buffer to diffuse, modified in place
        -> Int        -- ^ Length of the buffer in bytes
        -> IO ()
diffuse hashAlg src sz = do
    forM_ [0 .. fullChunks - 1] $ \i -> rehashChunk i digestSize
    if tailLen /= 0
        then rehashChunk fullChunks tailLen
        else return ()
  where
    digestSize = hashDigestSize hashAlg
    (fullChunks, tailLen) = sz `quotRem` digestSize

    -- Replace the @len@-byte chunk with index @i@ by the first @len@ bytes of
    -- its counter-prefixed digest.  The digest is forced (see 'chunkDigest')
    -- before the copy, so it is computed from the chunk's original contents.
    rehashChunk i len = do
        let chunk = src `plusPtr` (i * digestSize)
        digest <- chunkDigest i chunk len
        B.withByteArray digest $ \digestPtr -> memCopy chunk digestPtr len

    -- Hash of [ BE32(n), p[0 .. len) ]
    chunkDigest n p len =
        return $! hashFinalize (hashUpdate (hashUpdate (hashInitWith hashAlg) (be32 n)) (MemView p len))

    be32 :: Int -> Bytes
    be32 n = B.allocAndFreeze 4 $ \ptr ->
        forM_ (zip [0 ..] [24, 16, 8, 0]) $ \(off, shiftBy) ->
            poke (ptr `plusPtr` off) (fromIntegral (n `shiftR` shiftBy) :: Word8)