--------------------------------------------------------------------------------
-- | Demultiplexing of frames into messages
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings  #-}
module Network.WebSockets.Hybi13.Demultiplex
    ( FrameType (..)
    , Frame (..)
    , DemultiplexState
    , emptyDemultiplexState
    , DemultiplexResult (..)
    , demultiplex
    ) where


--------------------------------------------------------------------------------
import           Data.ByteString.Builder               (Builder)
import qualified Data.ByteString.Builder               as B
import           Control.Exception                     (Exception)
import           Data.Binary.Get                       (getWord16be, runGet)
import qualified Data.ByteString.Lazy                  as BL
import           Data.Int                              (Int64)
import           Data.Monoid                           (mappend)
import           Data.Typeable                         (Typeable)
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Types


--------------------------------------------------------------------------------
-- | A low-level representation of a WebSocket packet
data Frame = Frame
    { frameFin     :: !Bool
      -- ^ Whether this frame is the final fragment of its message
    , frameRsv1    :: !Bool
      -- ^ First reserved bit (RSV1); carried over into data messages
    , frameRsv2    :: !Bool
      -- ^ Second reserved bit (RSV2); carried over into data messages
    , frameRsv3    :: !Bool
      -- ^ Third reserved bit (RSV3); carried over into data messages
    , frameType    :: !FrameType
      -- ^ Kind of frame (data, continuation or control)
    , framePayload :: !BL.ByteString
      -- ^ Payload carried by the frame
    } deriving (Eq, Show)


--------------------------------------------------------------------------------
-- | The type of a frame. Not all types are allowed for all protocols.
data FrameType
    = ContinuationFrame
      -- ^ Subsequent fragment of a message started by an earlier frame
    | TextFrame
      -- ^ First (or only) frame of a text message
    | BinaryFrame
      -- ^ First (or only) frame of a binary message
    | CloseFrame
      -- ^ Control frame requesting that the connection be closed
    | PingFrame
      -- ^ Control frame carrying a ping
    | PongFrame
      -- ^ Control frame carrying a pong
    deriving (Eq, Show)


--------------------------------------------------------------------------------
-- | Thrown if the client sends invalid multiplexed data
data DemultiplexException = DemultiplexException
    deriving (Show, Typeable)


--------------------------------------------------------------------------------
-- Uses the default 'Exception' methods, so 'DemultiplexException' can be
-- thrown and caught with the functions from "Control.Exception".
instance Exception DemultiplexException


--------------------------------------------------------------------------------
-- | Internal state used by the demultiplexer
data DemultiplexState
    = EmptyDemultiplexState
      -- ^ No fragmented message is currently being assembled
    | DemultiplexState !Int64 !Builder !(Builder -> Message)
      -- ^ A fragmented message is being assembled: the number of payload
      -- bytes received so far, the payload accumulated so far, and the
      -- function that turns the complete payload into a 'Message'


--------------------------------------------------------------------------------
-- | Initial demultiplexer state: no fragmented message is in progress.
emptyDemultiplexState :: DemultiplexState
emptyDemultiplexState = EmptyDemultiplexState


--------------------------------------------------------------------------------
-- | Result of demultiplexing
data DemultiplexResult
    = DemultiplexSuccess  Message
      -- ^ The frame completed a message, which is returned
    | DemultiplexError    ConnectionException
      -- ^ The frame was rejected; carries the exception describing why
    | DemultiplexContinue
      -- ^ The frame was buffered; more frames are needed to finish the
      -- message

--------------------------------------------------------------------------------
-- | Feed a single 'Frame' to the demultiplexer.
--
-- Takes the message size limit, the current 'DemultiplexState' and the
-- incoming frame. Returns what the frame amounted to (a complete message, an
-- error, or "need more frames") together with the state to use for the next
-- frame.
--
-- Equations are tried top to bottom:
--
-- * Final, unflagged Ping\/Pong\/Close frames are handled regardless of the
--   current state, so they may arrive in the middle of a fragmented message.
--
-- * With no message in progress, a Text or Binary frame either completes a
--   message immediately or starts a fragmented one.
--
-- * With a message in progress, an unflagged Continuation frame extends or
--   completes it.
--
-- * Every other combination is answered with a 1002 close request.
--
-- Every error result resets the state to 'emptyDemultiplexState'.
demultiplex :: SizeLimit
            -> DemultiplexState
            -> Frame
            -> (DemultiplexResult, DemultiplexState)

-- Ping: payloads longer than 125 bytes are rejected; otherwise the ping is
-- reported and any message in progress is left untouched.
demultiplex _ state (Frame True False False False PingFrame pl)
    | BL.length pl > 125 =
        (DemultiplexError $ CloseRequest 1002 "Protocol Error", emptyDemultiplexState)
    | otherwise =
        (DemultiplexSuccess $ ControlMessage (Ping pl), state)

-- Pong: reported as-is; any message in progress is left untouched.
-- NOTE(review): unlike Ping, Pong and Close payloads are not checked against
-- the 125-byte limit here -- confirm whether the frame decoder enforces it.
demultiplex _ state (Frame True False False False PongFrame pl) =
    (DemultiplexSuccess (ControlMessage (Pong pl)), state)

-- Close: reported with the parsed status code and reason; any message in
-- progress is discarded.
demultiplex _ _ (Frame True False False False CloseFrame pl) =
    (DemultiplexSuccess (ControlMessage (uncurry Close parsedClose)), emptyDemultiplexState)
  where
    -- The Close frame MAY contain a body (the "Application data" portion of the
    -- frame) that indicates a reason for closing, such as an endpoint shutting
    -- down, an endpoint having received a frame too large, or an endpoint
    -- having received a frame that does not conform to the format expected by
    -- the endpoint. If there is a body, the first two bytes of the body MUST
    -- be a 2-byte unsigned integer (in network byte order) representing a
    -- status code with value /code/ defined in Section 7.4.
    --
    -- 'runGet' is only applied once at least two bytes are known to be
    -- present, so decoding the status code cannot fail.
    parsedClose
       | BL.length pl >= 2 = case runGet getWord16be pl of
              -- Codes below 1000 and the listed codes are not accepted:
              -- answer with 1002 and drop the reason.
              a | a < 1000 || a `elem` [1004,1005,1006
                                       ,1014,1015,1016
                                       ,1100,2000,2999
                                       ,5000,65535] -> (1002, BL.empty)
              -- Accepted code: the rest of the body is the reason.
              a -> (a, BL.drop 2 pl)
       -- A single byte cannot hold a status code.
       | BL.length pl == 1 = (1002, BL.empty)
       -- Empty body: treated as a close with code 1000.
       | otherwise         = (1000, BL.empty)

-- No message in progress: only Text and Binary frames are acceptable here.
demultiplex sizeLimit EmptyDemultiplexState (Frame fin rsv1 rsv2 rsv3 tp pl) = case tp of
    -- The size limit is checked before the frame type is looked at.
    _ | not (atMostSizeLimit size sizeLimit) ->
        ( DemultiplexError $ ParseException $
            "Message of size " ++ show size ++ " exceeded limit"
        , emptyDemultiplexState
        )

    TextFrame
        | fin       ->
            (DemultiplexSuccess (text pl), emptyDemultiplexState)
        | otherwise ->
            (DemultiplexContinue, DemultiplexState size plb (text . B.toLazyByteString))

    BinaryFrame
        | fin       ->
            (DemultiplexSuccess (binary pl), emptyDemultiplexState)
        | otherwise ->
            (DemultiplexContinue, DemultiplexState size plb (binary . B.toLazyByteString))

    -- Anything else that reaches this equation (a continuation without a
    -- message in progress, or a control frame that did not match the
    -- equations above) is a protocol error.
    _ -> (DemultiplexError $ CloseRequest 1002 "Protocol Error", emptyDemultiplexState)

  where
    size     = BL.length pl
    plb      = B.lazyByteString pl
    -- The reserved bits of the first frame are recorded on the message.
    text   x = DataMessage rsv1 rsv2 rsv3 (Text x Nothing)
    binary x = DataMessage rsv1 rsv2 rsv3 (Binary x)

-- Message in progress: an unflagged continuation frame extends it.
demultiplex sizeLimit (DemultiplexState size0 b f) (Frame fin False False False ContinuationFrame pl)
    -- The limit applies to the total size accumulated across all fragments.
    | not (atMostSizeLimit size1 sizeLimit) =
        ( DemultiplexError $ ParseException $
            "Message of size " ++ show size1 ++ " exceeded limit"
        , emptyDemultiplexState
        )
    | fin         = (DemultiplexSuccess (f b'), emptyDemultiplexState)
    | otherwise   = (DemultiplexContinue, DemultiplexState size1 b' f)
  where
    size1 = size0 + BL.length pl
    b'    = b `mappend` plb
    plb   = B.lazyByteString pl

-- Every remaining combination of state and frame is a protocol error.
demultiplex _ _ _ =
    (DemultiplexError (CloseRequest 1002 "Protocol Error"), emptyDemultiplexState)