{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}
module Network.DNS.Transport (
Resolver(..)
, resolve
) where
import Control.Concurrent.Async (async, cancel, waitAny, waitAnyCancel)
import Control.Exception as E
import qualified Data.ByteString.Char8 as BS
import qualified Data.List.NonEmpty as NE
import Network.Socket (AddrInfo(..), SockAddr(..), Family(AF_INET, AF_INET6), Socket, SocketType(Stream), close, socket, connect, defaultProtocol)
import System.IO.Error (annotateIOError)
import System.Timeout (timeout)

import Network.DNS.IO
import Network.DNS.Imports
import Network.DNS.Types
import Network.DNS.Types.Internal
-- | Decide whether a received message answers the query we sent.
--
--   A response is accepted only when the identifier in its header equals
--   the expected identifier and its question section equals the list of
--   questions that was sent.
checkResp :: [Question] -> Identifier -> DNSMessage -> Bool
checkResp q seqno resp = sameIdentifier && sameQuestions
  where
    sameIdentifier = identifier (header resp) == seqno
    sameQuestions  = q == question resp
-- | Internal control-flow exception: thrown by 'udpLookup' when a UDP
--   response carries the truncation flag, and caught by 'udpTcpLookup',
--   which then repeats the query over TCP.
data TCPFallback = TCPFallback
    deriving (Show, Typeable)

instance Exception TCPFallback
-- | A lookup that only lacks its two final arguments and, once it has
--   them, yields either a 'DNSError' or the received 'DNSMessage'.
type Rslv0 =
       Bool                       -- flag forwarded unchanged to encodeQuestions
    -> (Socket -> IO DNSMessage)  -- action the UDP lookup uses to read one message
    -> IO (Either DNSError DNSMessage)
-- | The per-query settings shared by 'resolveOne', 'resolveSequential'
--   and 'resolveConcurrent', followed by the arguments of 'Rslv0'.
type Rslv1 =
       [Question]        -- questions to put into the query
    -> [ResourceRecord]  -- records passed to encodeQuestions for the first UDP query
                         -- (call sites name this @edns@)
    -> Int               -- per-attempt limit handed to 'timeout' (microseconds)
    -> Int               -- retry count at which the UDP lookup gives up
    -> Rslv0
-- | Shape of a single lookup against one server ('tcpLookup'); failures
--   are reported by throwing rather than through 'Either'.
type TcpRslv =
       Identifier   -- identifier expected back in the response header
    -> AddrInfo     -- server to contact
    -> [Question]   -- questions to send
    -> Int          -- limit handed to 'timeout' (microseconds)
    -> Bool         -- flag forwarded unchanged to encodeQuestions
    -> IO DNSMessage
-- | Shape of a UDP-capable lookup ('udpLookup', 'udpTcpLookup'): the
--   UDP-specific settings come first, then the arguments of 'TcpRslv'.
type UdpRslv =
       [ResourceRecord]           -- records encoded into the first query
                                  -- (call sites name this @edns@)
    -> Int                        -- retry count at which the lookup gives up
    -> (Socket -> IO DNSMessage)  -- action used to read one message from the socket
    -> TcpRslv
-- | Look up records of the given type for a domain using the servers and
--   settings stored in the 'Resolver'.
--
--   * An illegal domain (see 'isIllegal') is rejected with
--     'IllegalDomain' before any network activity.
--   * With exactly one name server, that server is queried directly.
--   * Otherwise the servers are queried concurrently or one after the
--     other, depending on @resolvConcurrent@.
--
--   The domain is turned into a fully-qualified name by appending a dot
--   when it does not already end in one.
resolve :: Domain -> TYPE -> Resolver -> Rslv0
resolve dom typ rlv ad rcv
    | isIllegal dom = return (Left IllegalDomain)
    | otherwise     = case servers of
        [server]                  -> resolveOne server (NE.head idGens) q edns tm retry ad rcv
        _ | resolvConcurrent conf -> resolveConcurrent servers gens q edns tm retry ad rcv
          | otherwise             -> resolveSequential servers gens q edns tm retry ad rcv
  where
    seed    = resolvseed rlv
    conf    = resolvconf seed
    servers = NE.toList (nameservers seed)
    idGens  = genIds rlv
    gens    = NE.toList idGens
    tm      = resolvTimeout conf
    retry   = resolvRetry conf
    edns    = resolvEDNS conf
    -- Only demanded after 'isIllegal' has rejected the empty domain, so
    -- 'BS.last' is never applied to an empty string here.
    fqdn
      | BS.last dom == '.' = dom
      | otherwise          = dom <> "."
    q = [Question fqdn typ]
-- | Query the servers one at a time, pairing each server with its own
--   identifier generator.
--
--   The first successful answer is returned.  A failure moves on to the
--   next server; the result of the last server is returned as is,
--   whether it succeeded or not.  Calling this with empty lists, or with
--   lists that run out at different times, hits the 'error' case.
resolveSequential :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveSequential nss gs q edns tm retry ad rcv = go nss gs
  where
    attempt ai gen = resolveOne ai gen q edns tm retry ad rcv

    -- Last server: whatever it says is the final result.
    go [ai] [gen] = attempt ai gen
    -- More servers left: only a success stops the walk.
    go (ai:ais) (gen:gens) =
        attempt ai gen >>= either (const (go ais gens)) (return . Right)
    go _ _ = error "resolveSequential:loop"
-- | Query all servers at the same time, pairing each server with its own
--   identifier generator.
--
--   The first /successful/ answer wins.  A lookup that finishes with a
--   'Left' is set aside and the remaining lookups keep running, so one
--   quickly failing server cannot mask the answer of a healthy one; this
--   mirrors 'resolveSequential', which also moves on after a failure.
--   Only when every lookup has failed is the last failure returned.
--
--   All lookups still running when this function returns, or when it is
--   interrupted by an exception, are cancelled.
resolveConcurrent :: [AddrInfo] -> [IO Identifier] -> Rslv1
resolveConcurrent nss gens q edns tm retry ad rcv = do
    asyncs <- mapM mkAsync $ zip nss gens
    firstSuccess asyncs `E.finally` mapM_ cancel asyncs
  where
    mkAsync (ai,gen) = async $ resolveOne ai gen q edns tm retry ad rcv

    -- Wait for the next lookup to finish; keep waiting on the others
    -- while it failed and there is still somebody left to wait for.
    firstSuccess pending = do
        (finished, res) <- waitAny pending
        case (res, filter (/= finished) pending) of
            (Left _, rest@(_:_)) -> firstSuccess rest
            _                    -> return res
-- | Query a single server.
--
--   A fresh identifier is drawn from the supplied generator, the lookup
--   is run through 'udpTcpLookup', and any 'DNSError' it throws is
--   captured as a 'Left'.
resolveOne :: AddrInfo -> IO Identifier -> Rslv1
resolveOne ai gen q edns tm retry ad rcv =
    gen >>= \ident -> E.try (udpTcpLookup edns retry rcv ident ai q tm ad)
-- | Try the query over UDP first; when 'udpLookup' signals
--   'TCPFallback', repeat it over TCP with the same identifier, server,
--   questions and timeout.
udpTcpLookup :: UdpRslv
udpTcpLookup edns retry rcv ident ai q tm ad = E.catch viaUdp viaTcp
  where
    viaUdp = udpLookup edns retry rcv ident ai q tm ad
    viaTcp TCPFallback = tcpLookup ident ai q tm ad
-- | Handler that re-throws an 'IOError' as a 'NetworkFailure'.
--
--   Before wrapping, the error is annotated with the shown 'AddrInfo' as
--   its location and the given tag (e.g. @"UDP"@) in its path field.
--   This function never returns normally.
ioErrorToDNSError :: AddrInfo -> String -> IOError -> IO DNSMessage
ioErrorToDNSError ai tag ioe =
    throwIO . NetworkFailure $ annotateIOError ioe (show ai) Nothing (Just tag)
-- | Create a socket from the family, socket type and protocol recorded in
--   the 'AddrInfo' and connect it to the address in that 'AddrInfo'.
--
--   If 'connect' throws, the freshly created socket is closed before the
--   exception propagates, so a failed connect does not leak the socket.
--   On success the caller owns the socket and must close it.
udpOpen :: AddrInfo -> IO Socket
udpOpen ai =
    E.bracketOnError
        (socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai))
        close
        (\sock -> connect sock (addrAddress ai) >> return sock)
-- | Send the query over UDP and wait for a matching response.
--
--   * Each attempt sends the query and waits up to @tm@ (as interpreted
--     by 'timeout') for a response accepted by 'checkResp'; responses
--     that do not match are discarded.
--   * A timed-out attempt is retried; once @retry@ attempts have been
--     used up 'RetryLimitExceeded' is thrown.  The limit is checked with
--     @>=@ so that a zero or negative @retry@ fails immediately instead
--     of retrying forever against a silent server.
--   * A truncated response throws 'TCPFallback'.
--   * If the query carried @edns@ records and the server answers
--     'FormatErr', the query is re-sent once without them; this does not
--     consume an attempt.
--   * Any 'IOError' is re-thrown through 'ioErrorToDNSError' with the
--     tag @"UDP"@.  The socket is always closed.
udpLookup :: UdpRslv
udpLookup edns retry rcv ident ai q tm ad =
    E.handle (ioErrorToDNSError ai "UDP") $
      bracket (udpOpen ai) close (loop ednsQuery (not (null edns)) 0)
  where
    ednsQuery  = encodeQuestions ident q edns ad
    plainQuery = encodeQuestions ident q [] ad

    loop qry ednsRetry cnt sock
      | cnt >= retry = E.throwIO RetryLimitExceeded
      | otherwise    = do
          mres <- timeout tm (send sock qry >> getAns sock)
          case mres of
            Nothing -> loop qry ednsRetry (cnt + 1) sock
            Just res
              | trunCation flgs                      -> E.throwIO TCPFallback
              | ednsRetry && rcode flgs == FormatErr -> loop plainQuery False cnt sock
              | otherwise                            -> return res
              where
                flgs = flags (header res)

    -- Keep reading until a message matches our identifier and questions;
    -- the surrounding 'timeout' is what bounds this loop.
    getAns sock = do
        mres <- rcv sock
        if checkResp q ident mres
          then return mres
          else getAns sock
-- | Create an unconnected stream socket whose family matches the peer
--   address: 'AF_INET' for 'SockAddrInet', 'AF_INET6' for
--   'SockAddrInet6'.  Any other kind of address throws 'ServerFailure'.
tcpOpen :: SockAddr -> IO Socket
tcpOpen peer = maybe (E.throwIO ServerFailure) streamSocket family
  where
    streamSocket fam = socket fam Stream defaultProtocol
    family = case peer of
        SockAddrInet{}  -> Just AF_INET
        SockAddrInet6{} -> Just AF_INET6
        _               -> Nothing
-- | Perform the query over TCP.
--
--   Connecting, sending and receiving together must finish within @tm@
--   (as interpreted by 'timeout'); otherwise 'TimeoutExpired' is thrown.
--   A response rejected by 'checkResp' throws 'SequenceNumberMismatch'.
--   The query is encoded with an empty record list where the UDP lookup
--   passes its @edns@ records.  Any 'IOError' is re-thrown through
--   'ioErrorToDNSError' with the tag @"TCP"@, and the socket is always
--   closed.
tcpLookup :: TcpRslv
tcpLookup ident ai q tm ad =
    E.handle (ioErrorToDNSError ai "TCP") (bracket (tcpOpen peer) close exchange)
  where
    peer  = addrAddress ai
    query = encodeQuestions ident q [] ad

    exchange vc = do
        outcome <- timeout tm (connect vc peer >> sendVC vc query >> receiveVC vc)
        maybe (E.throwIO TimeoutExpired) validate outcome

    validate res
      | checkResp q ident res = return res
      | otherwise             = E.throwIO SequenceNumberMismatch
-- | 'True' when the domain is empty or too long: more than 254
--   characters when it ends in a dot, more than 253 otherwise.
badLength :: Domain -> Bool
badLength dom
    | BS.null dom = True
    | otherwise   = BS.length dom > limit
  where
    -- Only demanded for a non-empty domain, so 'BS.last' is safe here.
    limit
      | BS.last dom == '.' = 254
      | otherwise          = 253
-- | 'True' when the domain must not be sent to a server.  A domain is
--   illegal when any of the following holds, checked in this order:
--
--   * its length is bad (see 'badLength');
--   * it contains no @.@ at all;
--   * it contains a @:@ or a @/@;
--   * one of its dot-separated labels is longer than 63 characters.
isIllegal :: Domain -> Bool
isIllegal dom =
       badLength dom
    || BS.notElem '.' dom
    || BS.elem ':' dom
    || BS.elem '/' dom
    || any ((> 63) . BS.length) (BS.split '.' dom)