{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module Network.TLS.Handshake.Server.TLS13 (
recvClientSecondFlight13,
postHandshakeAuthServerWith,
) where
import Control.Monad.State.Strict
import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Extension
import Network.TLS.Handshake.Common hiding (expectFinished)
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Process
import Network.TLS.Handshake.Server.Common
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.IO
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.Session
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types
import Network.TLS.X509
-- | Receive (or schedule the reception of) the client's second flight of a
-- TLS 1.3 handshake.
--
-- The tuple argument carries the application secret triple, the client
-- handshake traffic secret, a flag telling whether the client is already
-- authenticated, and a flag telling whether 0-RTT was accepted.
--
-- One of three strategies is selected:
--
-- * client not yet authenticated and 'serverWantClientCert' is set:
--   Certificate, CertificateVerify and Finished are received right away;
-- * 0-RTT accepted and the context is not in QUIC mode: EndOfEarlyData and
--   Finished handlers are registered as pending receive actions;
-- * otherwise: only Finished is received right away.
recvClientSecondFlight13
    :: ServerParams
    -> Context
    -> ( SecretTriple ApplicationSecret
       , ClientTrafficSecret HandshakeSecret
       , Bool
       , Bool
       )
    -> CH
    -> IO ()
recvClientSecondFlight13 sparams ctx (appKey, clientHandshakeSecret, authenticated, rtt0OK) CH{chExtensions = exts} = do
    -- Taken before anything is received; 'expectFinished' passes it on to
    -- 'sendNewSessionTicket'.
    sfSentTime <- getCurrentTimeFromBase
    let -- The explicit signature keeps this handler polymorphic: it is used
        -- both in 'IO' and in 'RecvHandshake13M' 'IO'.
        onFinished :: MonadIO m => ByteString -> Handshake13 -> m ()
        onFinished =
            expectFinished sparams ctx exts appKey clientHandshakeSecret sfSentTime

        withClientAuth, afterEarlyData, finishedOnly, chosen :: IO ()
        withClientAuth = runRecvHandshake13 $ do
            recvHandshake13 ctx (expectCertificate sparams ctx)
            recvHandshake13hash ctx (expectCertVerify sparams ctx)
            recvHandshake13hash ctx onFinished
            ensureRecvComplete ctx
        afterEarlyData =
            setPendingRecvActions
                ctx
                [ PendingRecvAction True (expectEndOfEarlyData ctx clientHandshakeSecret)
                , PendingRecvActionHash True onFinished
                ]
        finishedOnly = runRecvHandshake13 $ do
            recvHandshake13hash ctx onFinished
            ensureRecvComplete ctx

        chosen
            | not authenticated && serverWantClientCert sparams = withClientAuth
            | rtt0OK && not (ctxQUICMode ctx) = afterEarlyData
            | otherwise = finishedOnly
    chosen
-- | Handler for the client's Finished message.
--
-- On a 'Finished13' message it, in order:
--
-- 1. sets 'tls13stRecvCF' to 'True' in the TLS 1.3 state;
-- 2. checks the verify data with 'checkFinished', using the hash of the
--    current receive record state, the client handshake traffic secret and
--    the supplied transcript hash;
-- 3. calls 'handshakeDone13';
-- 4. switches the receive record state to the client application traffic
--    secret ('triClient' of the secret triple);
-- 5. calls 'sendNewSessionTicket' with the base application secret
--    ('triBase' of the secret triple).
--
-- Any other message is reported through 'unexpected'.
expectFinished
    :: MonadIO m
    => ServerParams
    -> Context
    -> [ExtensionRaw]
    -> SecretTriple ApplicationSecret
    -> ClientTrafficSecret HandshakeSecret
    -> Word64
    -> ByteString
    -> Handshake13
    -> m ()
expectFinished sparams ctx exts appKey clientHandshakeSecret sfSentTime hChBeforeCf hs =
    case hs of
        Finished13 verifyData -> liftIO $ do
            modifyTLS13State ctx (\st -> st{tls13stRecvCF = True})
            (rxHash, rxCipher, _, _) <- getRxRecordState ctx
            let ClientTrafficSecret handshakeKey = clientHandshakeSecret
            checkFinished ctx rxHash handshakeKey hChBeforeCf verifyData
            handshakeDone13 ctx
            setRxRecordState ctx rxHash rxCipher (triClient appKey)
            sendNewSessionTicket sparams ctx rxCipher exts (triBase appKey) sfSentTime
        _ -> unexpected (show hs) (Just "finished 13")
-- | Handler for the EndOfEarlyData message.
--
-- On 'EndOfEarlyData13' the receive record state is re-keyed with the client
-- handshake traffic secret, keeping the hash and cipher already in use. Any
-- other message is reported through 'unexpected'.
expectEndOfEarlyData
    :: Context -> ClientTrafficSecret HandshakeSecret -> Handshake13 -> IO ()
expectEndOfEarlyData ctx clientHandshakeSecret hs = case hs of
    EndOfEarlyData13 ->
        getRxRecordState ctx >>= \(rxHash, rxCipher, _, _) ->
            setRxRecordState ctx rxHash rxCipher clientHandshakeSecret
    _ -> unexpected (show hs) (Just "end of early data")
-- | Handler for the client's Certificate message during the handshake.
--
-- Accepts both 'Certificate13' and 'CompressedCertificate13'. The
-- certificate request context must be empty, otherwise an 'Error_Protocol'
-- with 'IllegalParameter' is thrown; the chain is then handed to
-- 'clientCertificate'. The extensions carried by the message are not
-- inspected. Any other message is reported through 'unexpected'.
expectCertificate
    :: MonadIO m => ServerParams -> Context -> Handshake13 -> m ()
expectCertificate sparams ctx hs = case hs of
    Certificate13 certCtx (TLSCertificateChain certs) _ext ->
        acceptChain certCtx certs
    CompressedCertificate13 certCtx (TLSCertificateChain certs) _ext ->
        acceptChain certCtx certs
    _ -> unexpected (show hs) (Just "certificate 13")
  where
    -- Shared by the plain and the compressed certificate messages.
    acceptChain reqCtx chain = liftIO $ do
        when (reqCtx /= "") $
            throwCore $
                Error_Protocol "certificate request context MUST be empty" IllegalParameter
        clientCertificate sparams ctx chain
-- | Send a NewSessionTicket to the client once the handshake is complete.
--
-- Nothing is sent unless the client's @psk_key_exchange_modes@ extension
-- (looked up in the given ClientHello extensions) lists 'PSK_DHE_KE'.
--
-- The last argument is the time at which the server flight was sent; the
-- difference to the current time is passed to 'createTLS13TicketInfo' as the
-- round-trip time.
--
-- If 'newSession' does not provide a session identifier, no ticket is sent
-- and the function returns normally. The previous code bound the result with
-- a failable pattern, so that case raised a pattern-match 'IOError' after the
-- handshake had already completed.
sendNewSessionTicket
    :: ServerParams
    -> Context
    -> Cipher
    -> [ExtensionRaw]
    -> BaseSecret ApplicationSecret
    -> Word64
    -> IO ()
sendNewSessionTicket sparams ctx usedCipher exts applicationSecret sfSentTime = when sendNST $ do
    cfRecvTime <- getCurrentTimeFromBase
    let rtt = cfRecvTime - sfSentTime
    nonce <- getStateRNG ctx 32
    resumptionSecret <- calculateResumptionSecret ctx choice applicationSecret
    let life = adjustLifetime $ serverTicketLifetime sparams
        psk = derivePSK choice resumptionSecret nonce
    mSession <- generateSession life psk rtt0max rtt
    case mSession of
        -- No session identifier available: the ticket is optional, skip it.
        Nothing -> return ()
        Just (identity, add) -> do
            let nst = createNewSessionTicket life add nonce identity rtt0max
            sendPacket13 ctx $ Handshake13 [nst]
  where
    choice = makeCipherChoice TLS13 usedCipher
    rtt0max = safeNonNegative32 $ serverEarlyDataSize sparams
    sendNST = PSK_DHE_KE `elem` dhModes
    -- Key exchange modes offered by the client; empty when the extension is
    -- absent or cannot be decoded.
    dhModes = case extensionLookup EID_PskKeyExchangeModes exts
        >>= extensionDecode MsgTClientHello of
        Just (PskKeyExchangeModes ms) -> ms
        Nothing -> []
    -- Register the session with the session manager. Yields the ticket
    -- identity (the manager's ticket if it returns one, the session
    -- identifier otherwise) together with the ticket's age-add value, or
    -- 'Nothing' when 'newSession' supplies no session identifier.
    generateSession life psk maxSize rtt = do
        Session mSessionId <- newSession ctx
        case mSessionId of
            Nothing -> return Nothing
            Just sessionId -> do
                tinfo <- createTLS13TicketInfo life (Left ctx) (Just rtt)
                sdata <- getSessionData13 ctx usedCipher tinfo maxSize psk
                let mgr = sharedSessionManager $ serverShared sparams
                mticket <- sessionEstablish mgr sessionId sdata
                let identity = fromMaybe sessionId mticket
                return $ Just (identity, ageAdd tinfo)
    -- Build the message; its only extension is the early-data indication
    -- carrying the maximum early data size.
    createNewSessionTicket life add nonce identity maxSize =
        NewSessionTicket13 life add nonce identity extensions
      where
        earlyDataExt = toExtensionRaw $ EarlyDataIndication $ Just $ fromIntegral maxSize
        extensions = [earlyDataExt]
    -- Clamp the configured lifetime to the range 0 .. 604800 seconds
    -- (seven days).
    adjustLifetime i
        | i < 0 = 0
        | i > 604800 = 604800
        | otherwise = fromIntegral i
-- | Handler for the client's CertificateVerify message.
--
-- On 'CertVerify13' it fetches the client certificate chain with
-- 'checkValidClientCertChain', takes the public key of the first certificate
-- (throwing 'Error_Protocol' with 'HandshakeFailure' when the chain is
-- empty), checks the key with 'checkDigitalSignatureKey' for the negotiated
-- version, stores it in the handshake state, verifies the signature over the
-- supplied transcript hash with 'checkCertVerify' and passes the outcome to
-- 'clientCertVerify'. Any other message is reported through 'unexpected'.
expectCertVerify
    :: MonadIO m => ServerParams -> Context -> ByteString -> Handshake13 -> m ()
expectCertVerify sparams ctx hChCc hs = case hs of
    CertVerify13 (DigitallySigned sigAlg sig) -> liftIO $ do
        chain@(CertificateChain signedCerts) <-
            checkValidClientCertChain ctx "invalid client certificate chain"
        pubkey <- case signedCerts of
            leaf : _ -> (pure . certPubKey . getCertificate) leaf
            [] ->
                throwCore $
                    Error_Protocol "client certificate missing" HandshakeFailure
        version <- usingState_ ctx getVersion
        checkDigitalSignatureKey version pubkey
        usingHState ctx (setPublicKey pubkey)
        checkCertVerify ctx pubkey sigAlg sig hChCc
            >>= clientCertVerify sparams ctx chain
    _ -> unexpected (show hs) (Just "certificate verify 13")
-- | Act on the outcome of the client's CertificateVerify check.
--
-- When the signature verified, the chain is recorded with
-- 'setClientCertificateChain'. When it did not, the 'onUnverifiedClientCert'
-- server hook decides: 'True' records the chain anyway, 'False' fails with
-- 'decryptError'. The hook is consulted only on verification failure.
clientCertVerify :: ServerParams -> Context -> CertificateChain -> Bool -> IO ()
clientCertVerify sparams ctx certs verif = do
    accepted <-
        if verif
            then return True
            else onUnverifiedClientCert (serverHooks sparams)
    if accepted
        then usingState_ ctx (setClientCertificateChain certs)
        else decryptError "verification failed"
-- | Entry point for a post-handshake authentication message from the client.
--
-- 'Certificate13' and 'CompressedCertificate13' are forwarded, together with
-- their request context and certificate chain, to
-- 'processHandshakeAuthServerWith'. Any other message raises
-- 'Error_Protocol' with 'UnexpectedMessage'.
postHandshakeAuthServerWith :: ServerParams -> Context -> Handshake13 -> IO ()
postHandshakeAuthServerWith sparams ctx h = case h of
    Certificate13 certCtx (TLSCertificateChain certs) _ext ->
        forward certCtx certs
    CompressedCertificate13 certCtx (TLSCertificateChain certs) _ext ->
        forward certCtx certs
    _ ->
        throwCore $
            Error_Protocol
                "unexpected handshake message received in postHandshakeAuthServerWith"
                UnexpectedMessage
  where
    -- The original message is passed along as well.
    forward reqCtx chain =
        processHandshakeAuthServerWith sparams ctx reqCtx chain h
-- | Process a client Certificate received for post-handshake authentication.
--
-- Steps, in order:
--
-- 1. look up the CertificateRequest matching the request context with
--    'getCertRequest13'; an unknown context raises 'Error_Protocol' with
--    'DecodeError';
-- 2. hand the chain to 'clientCertificate';
-- 3. save the handshake state, then feed the CertificateRequest and the
--    received message to 'processHandshake13';
-- 4. require the receive record state to be at 'CryptApplicationSecret',
--    otherwise raise 'Error_Protocol' with 'InternalError';
-- 5. register pending receive actions: CertificateVerify (only when the
--    chain is not null) followed by Finished.
--
-- The Finished handler checks the verify data against the current
-- application traffic secret and then restores the saved handshake state.
processHandshakeAuthServerWith
    :: ServerParams
    -> Context
    -> CertReqContext
    -> CertificateChain
    -> Handshake13
    -> IO ()
processHandshakeAuthServerWith sparams ctx certCtx certs h = do
    mCertReq <- getCertRequest13 ctx certCtx
    -- Total match instead of an 'isNothing' check followed by 'fromJust'.
    certReq <- case mCertReq of
        Just req -> return req
        Nothing ->
            throwCore $
                Error_Protocol "unknown certificate request context" DecodeError

    clientCertificate sparams ctx certs

    baseHState <- saveHState ctx
    processHandshake13 ctx certReq
    processHandshake13 ctx h

    (usedHash, _, level, applicationSecretN) <- getRxRecordState ctx
    unless (level == CryptApplicationSecret) $
        throwCore $
            Error_Protocol
                "tried post-handshake authentication without application traffic secret"
                InternalError

    let expectFinished' hChBeforeCf (Finished13 verifyData) = do
            checkFinished ctx usedHash applicationSecretN hChBeforeCf verifyData
            void $ restoreHState ctx baseHState
        expectFinished' _ hs = unexpected (show hs) (Just "finished 13")

        finishedAction = PendingRecvActionHash False expectFinished'
        certVerifyAction = PendingRecvActionHash False (expectCertVerify sparams ctx)
    setPendingRecvActions ctx $
        if isNullCertificateChain certs
            then [finishedAction]
            else [certVerifyAction, finishedAction]