-- |
-- Module      : Data.SecureMem
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : Stable
-- Portability : GHC
--
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE CPP #-}
module Data.SecureMem
    ( SecureMem
    , secureMemGetSize
    , secureMemCopy
    , ToSecureMem(..)
    -- * Allocation and early termination
    , allocateSecureMem
    , createSecureMem
    , unsafeCreateSecureMem
    , finalizeSecureMem
    -- * Pointer manipulation
    , withSecureMemPtr
    , withSecureMemPtrSz
    , withSecureMemCopy
    -- * Conversion
    , secureMemFromByteString
    , secureMemFromByteable
    ) where

import           Foreign.ForeignPtr (withForeignPtr)
import           Foreign.Ptr
import           Data.Word (Word8)
#if MIN_VERSION_base(4,9,0)
import           Data.Semigroup
import           Data.Foldable (toList)
#else
import           Data.Monoid
#endif
import           Control.Applicative
import           Data.Byteable

import           Data.ByteString (ByteString)
import           Data.ByteArray  (ScrubbedBytes)
import qualified Data.ByteArray as B
import qualified Data.Memory.PtrMethods as B (memSet)

import qualified Data.ByteString.Internal as BS

#if MIN_VERSION_base(4,4,0)
import System.IO.Unsafe (unsafeDupablePerformIO)
#else
import System.IO.Unsafe (unsafePerformIO)
#endif

-- | Run an 'IO' action and hand back its result as a pure value.
--
-- In this module it only wraps actions that allocate a fresh buffer and
-- fill it. On base >= 4.4 it is 'unsafeDupablePerformIO', which omits the
-- single-thread check of 'unsafePerformIO'. If several threads force the
-- same thunk at once, the action may therefore run more than once.
pureIO :: IO a -> a
#if MIN_VERSION_base(4,4,0)
pureIO action = unsafeDupablePerformIO action
#else
pureIO action = unsafePerformIO action
#endif

-- | SecureMem is a memory chunk, stored in a 'ScrubbedBytes' buffer, with
-- the following properties:
--
-- * It is scrubbed after it goes out of scope.
--
-- * Its Show instance never shows the actual content.
--
-- * Its Eq instance runs in constant time.
--
newtype SecureMem = SecureMem ScrubbedBytes

-- | Number of bytes held by a 'SecureMem'.
secureMemGetSize :: SecureMem -> Int
secureMemGetSize sm = case sm of
    SecureMem payload -> B.length payload

-- | Content equality of two 'SecureMem' values, delegated to the 'Eq'
-- instance of the underlying 'ScrubbedBytes'. The constant-time comparison
-- advertised for 'SecureMem' comes from that instance.
secureMemEq :: SecureMem -> SecureMem -> Bool
secureMemEq (SecureMem lhs) (SecureMem rhs) = (==) lhs rhs

-- | Concatenate two 'SecureMem' values into one, first argument first.
secureMemAppend :: SecureMem -> SecureMem -> SecureMem
secureMemAppend (SecureMem front) (SecureMem back) =
    SecureMem (mappend front back)

-- | Concatenate a list of 'SecureMem' values, preserving their order.
secureMemConcat :: [SecureMem] -> SecureMem
secureMemConcat mems = SecureMem (mconcat [ sb | SecureMem sb <- mems ])

-- | Make an independent copy of a 'SecureMem'. This is 'withSecureMemCopy'
-- with an action that does nothing.
secureMemCopy :: SecureMem -> IO SecureMem
secureMemCopy sm = withSecureMemCopy sm (\_ -> return ())

-- | Copy a 'SecureMem' and run an action on a pointer to the copy's memory
-- before returning the copy.
withSecureMemCopy :: SecureMem -> (Ptr Word8 -> IO ()) -> IO SecureMem
withSecureMemCopy (SecureMem src) f = do
    copied <- B.copy src f
    return (SecureMem copied)

-- Never reveal the contents: every value renders as the same placeholder.
instance Show SecureMem where
    show = const "<secure-mem>"

-- Byteable view of a SecureMem. 'toBytes' copies the content into an
-- ordinary ByteString.
instance Byteable SecureMem where
    toBytes sm            = secureMemToByteString sm
    byteableLength sm     = secureMemGetSize sm
    withBytePtr sm action = withSecureMemPtr sm action

-- Equality delegates to 'secureMemEq'.
instance Eq SecureMem where
    a == b = secureMemEq a b

#if MIN_VERSION_base(4,9,0)
-- Appending concatenates the underlying bytes, left operand first.
instance Semigroup SecureMem where
    a <> b     = secureMemAppend a b
    sconcat ne = secureMemConcat (toList ne)
#endif

-- The identity element is a zero-length SecureMem.
instance Monoid SecureMem where
    mempty = unsafeCreateSecureMem 0 noInit
      where noInit _ = return ()
#if !(MIN_VERSION_base(4,11,0))
    mappend = secureMemAppend
    mconcat = secureMemConcat
#endif

-- | Types that can be converted to a secure mem object.
class ToSecureMem a where
    -- | Convert a value into a 'SecureMem'.
    toSecureMem :: a -> SecureMem

-- A SecureMem is already secure memory: return it unchanged.
instance ToSecureMem SecureMem where
    toSecureMem = id
-- Copy the bytes of a ByteString into a new SecureMem.
instance ToSecureMem ByteString where
    toSecureMem = secureMemFromByteString

-- | Allocate a new 'SecureMem' of @sz@ bytes.
--
-- No initializer writes into the buffer, so treat its contents as
-- unspecified until you fill it. The memory is allocated on the Haskell
-- heap and will be scrubbed before being released.
allocateSecureMem :: Int -> IO SecureMem
allocateSecureMem sz = createSecureMem sz (\_ -> return ())

-- | Create a new 'SecureMem' of @sz@ bytes and run an initializer on its
-- memory before returning it.
createSecureMem :: Int -> (Ptr Word8 -> IO ()) -> IO SecureMem
createSecureMem sz f = SecureMem <$> B.create sz f

-- | Pure variant of 'createSecureMem': allocate @sz@ bytes, run the
-- initializer on them, and return the result without an 'IO' wrapper.
--
-- The allocation runs through 'pureIO'. On base >= 4.4 that is
-- 'unsafeDupablePerformIO', so the initializer may run more than once if
-- several threads force the result concurrently. Keep it free of effects
-- other than writing into the supplied buffer.
unsafeCreateSecureMem :: Int -> (Ptr Word8 -> IO ()) -> SecureMem
unsafeCreateSecureMem sz initialise = pureIO $ createSecureMem sz initialise
{-# NOINLINE unsafeCreateSecureMem #-}

-- | Run an action on the raw pointer to the memory inside a 'SecureMem'
-- and return the action's result.
--
-- This is similar to 'withForeignPtr' for a 'ForeignPtr': the pointer is
-- only meant to be used inside the action.
withSecureMemPtr :: SecureMem -> (Ptr Word8 -> IO b) -> IO b
withSecureMemPtr (SecureMem sb) = B.withByteArray sb

-- | Like 'withSecureMemPtr', but the action also receives the size in bytes
-- of the pointed-to memory, as its first argument.
withSecureMemPtrSz :: SecureMem -> (Int -> Ptr Word8 -> IO b) -> IO b
withSecureMemPtrSz sm f = withSecureMemPtr sm (f (secureMemGetSize sm))

-- | Scrub a 'SecureMem' early by overwriting every byte with zero.
--
-- The buffer is modified in place and is not released. The value remains
-- usable with the same size, and every reference to it now reads zeros.
finalizeSecureMem :: SecureMem -> IO ()
finalizeSecureMem sm =
    withSecureMemPtrSz sm $ \size ptr -> B.memSet ptr 0 size

-- | Copy the contents of a 'SecureMem' into a freshly created 'ByteString'.
--
-- The copy is an ordinary 'ByteString', so it does not get the scrubbing
-- behaviour of 'SecureMem'.
secureMemToByteString :: SecureMem -> ByteString
secureMemToByteString sm = BS.unsafeCreate size copyInto
  where
    !size = secureMemGetSize sm
    copyInto dst = withSecureMemPtr sm $ \src ->
        BS.memcpy dst src (fromIntegral size)

-- | Copy the bytes of a 'ByteString' into a new 'SecureMem'.
--
-- Only the copy is held in secure memory; the source 'ByteString' is not
-- modified.
secureMemFromByteString :: ByteString -> SecureMem
secureMemFromByteString b = unsafeCreateSecureMem len $ \dst ->
    withForeignPtr fp $ \srcBase ->
        -- honour the ByteString's offset into its underlying buffer
        BS.memcpy dst (srcBase `plusPtr` off) (fromIntegral len)
  where
    (fp, off, !len) = BS.toForeignPtr b
{-# NOINLINE secureMemFromByteString #-}

-- | Copy the bytes of any 'Byteable' value into a new 'SecureMem'.
--
-- Exactly 'byteableLength' bytes are copied from the pointer returned by
-- 'withBytePtr'.
secureMemFromByteable :: Byteable b => b -> SecureMem
secureMemFromByteable bs = unsafeCreateSecureMem len copyFrom
  where
    len = byteableLength bs
    copyFrom dst = withBytePtr bs $ \src ->
        BS.memcpy dst src (fromIntegral len)
{-# NOINLINE secureMemFromByteable #-}