/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.core.protocol.handler;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.CipherSuite;
import de.rub.nds.tlsattacker.core.constants.CompressionMethod;
import de.rub.nds.tlsattacker.core.constants.DigestAlgorithm;
import de.rub.nds.tlsattacker.core.constants.HKDFAlgorithm;
import de.rub.nds.tlsattacker.core.constants.HandshakeMessageType;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.constants.Tls13KeySetType;
import de.rub.nds.tlsattacker.core.crypto.HKDFunction;
import de.rub.nds.tlsattacker.core.crypto.ec.CurveFactory;
import de.rub.nds.tlsattacker.core.crypto.ec.EllipticCurve;
import de.rub.nds.tlsattacker.core.crypto.ec.Point;
import de.rub.nds.tlsattacker.core.crypto.ec.PointFormatter;
import de.rub.nds.tlsattacker.core.crypto.ec.RFC7748Curve;
import de.rub.nds.tlsattacker.core.exceptions.AdjustmentException;
import de.rub.nds.tlsattacker.core.exceptions.CryptoException;
import de.rub.nds.tlsattacker.core.protocol.handler.HandshakeMessageHandler;
import de.rub.nds.tlsattacker.core.protocol.message.ServerHelloMessage;
import de.rub.nds.tlsattacker.core.protocol.message.computations.PWDComputations;
import de.rub.nds.tlsattacker.core.protocol.message.extension.keyshare.DragonFlyKeyShareEntry;
import de.rub.nds.tlsattacker.core.protocol.message.extension.keyshare.KeyShareStoreEntry;
import de.rub.nds.tlsattacker.core.protocol.parser.ServerHelloParser;
import de.rub.nds.tlsattacker.core.protocol.parser.extension.keyshare.DragonFlyKeyShareEntryParser;
import de.rub.nds.tlsattacker.core.protocol.preparator.ServerHelloPreparator;
import de.rub.nds.tlsattacker.core.protocol.serializer.ServerHelloSerializer;
import de.rub.nds.tlsattacker.core.record.cipher.RecordCipher;
import de.rub.nds.tlsattacker.core.record.cipher.RecordCipherFactory;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySet;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySetGenerator;
import de.rub.nds.tlsattacker.core.state.Session;
import de.rub.nds.tlsattacker.core.state.TlsContext;
import de.rub.nds.tlsattacker.core.workflow.chooser.Chooser;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.math.BigInteger;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import javax.crypto.Mac;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Handles ServerHello messages: copies the negotiated parameters (protocol version, compression,
 * session id, cipher suite, server random, extensions) into the {@link TlsContext} and, for TLS 1.3,
 * derives the handshake traffic secrets and installs the server's handshake record cipher.
 */
public class ServerHelloHandler extends HandshakeMessageHandler<ServerHelloMessage> {
    private static final Logger LOGGER = LogManager.getLogger();

    public ServerHelloHandler(TlsContext tlsContext) {
        super(tlsContext);
    }

    /**
     * Creates a preparator for the given message backed by this context's chooser.
     *
     * @param message the ServerHello to prepare
     * @return a new preparator
     */
    public ServerHelloPreparator getPreparator(ServerHelloMessage message) {
        return new ServerHelloPreparator(this.tlsContext.getChooser(), message);
    }

    /**
     * Creates a serializer for the given message using the currently selected protocol version.
     *
     * @param message the ServerHello to serialize
     * @return a new serializer
     */
    public ServerHelloSerializer getSerializer(ServerHelloMessage message) {
        return new ServerHelloSerializer(message, this.tlsContext.getChooser().getSelectedProtocolVersion());
    }

    /**
     * Creates a parser that reads a ServerHello from {@code message} starting at {@code pointer}.
     *
     * @param message raw bytes containing the message
     * @param pointer start offset inside {@code message}
     * @return a new parser
     */
    @Override
    public ServerHelloParser getParser(byte[] message, int pointer) {
        return new ServerHelloParser(pointer, message, this.tlsContext.getChooser().getLastRecordVersion(), this.tlsContext.getConfig());
    }

    /**
     * Applies the contents of a ServerHello to the context.
     *
     * <p>For a HelloRetryRequest only the parameter updates are performed. Otherwise, for TLS 1.3 the
     * handshake traffic secrets are derived, and, when the message was sent by the peer, the server
     * handshake cipher is installed. Finally the PRF is updated and, if a session is cached under the
     * server's session id, its master secret is loaded and a new record cipher is set.
     *
     * @param message the ServerHello whose values are applied
     */
    @Override
    public void adjustTLSContext(ServerHelloMessage message) {
        this.adjustSelectedProtocolVersion(message);
        this.adjustSelectedCompression(message);
        this.adjustSelectedSessionID(message);
        this.adjustSelectedCiphersuite(message);
        this.adjustServerRandom(message);
        this.adjustExtensions(message, HandshakeMessageType.SERVER_HELLO);
        if (message.isTls13HelloRetryRequest().booleanValue()) {
            // Nothing beyond the parameter updates above is done for a HelloRetryRequest.
            return;
        }
        if (this.tlsContext.getChooser().getSelectedProtocolVersion().isTLS13()) {
            this.adjustHandshakeTrafficSecrets();
            // Only install the cipher here when the peer sent the message; when we are the sender,
            // adjustTlsContextAfterSerialize does it once the message has been serialized.
            if (this.tlsContext.getTalkingConnectionEndType() != this.tlsContext.getChooser().getConnectionEndType()) {
                this.setServerRecordCipher();
            }
        }
        this.adjustPRF(message);
        if (this.tlsContext.hasSession(this.tlsContext.getChooser().getServerSessionId())) {
            LOGGER.info("Resuming Session");
            LOGGER.debug("Loading Mastersecret");
            Session session = this.tlsContext.getSession(this.tlsContext.getChooser().getServerSessionId());
            this.tlsContext.setMasterSecret(session.getMasterSecret());
            this.setRecordCipher();
        }
    }

    /** Stores the message's cipher suite in the context if it maps to a known {@link CipherSuite}. */
    private void adjustSelectedCiphersuite(ServerHelloMessage message) {
        CipherSuite suite = null;
        if (message.getSelectedCipherSuite() != null) {
            suite = CipherSuite.getCipherSuite(message.getSelectedCipherSuite().getValue());
        }
        if (suite != null) {
            this.tlsContext.setSelectedCipherSuite(suite);
            LOGGER.debug("Set SelectedCipherSuite in Context to " + suite.name());
        } else {
            LOGGER.warn("Unknown CipherSuite, did not adjust Context");
        }
    }

    /** Copies the message's random value into the context as the server random. */
    private void adjustServerRandom(ServerHelloMessage message) {
        this.tlsContext.setServerRandom(message.getRandom().getValue());
        LOGGER.debug("Set ServerRandom in Context to " + ArrayConverter.bytesToHexString(this.tlsContext.getServerRandom()));
    }

    /** Stores the message's compression method in the context if it maps to a known method. */
    private void adjustSelectedCompression(ServerHelloMessage message) {
        CompressionMethod method = null;
        if (message.getSelectedCompressionMethod() != null) {
            method = CompressionMethod.getCompressionMethod(message.getSelectedCompressionMethod().getValue());
        }
        if (method != null) {
            this.tlsContext.setSelectedCompressionMethod(method);
            LOGGER.debug("Set SelectedCompressionMethod in Context to " + method.name());
        } else {
            LOGGER.warn("Not adjusting CompressionMethod - Method is null!");
        }
    }

    /** Copies the message's session id into the context as the server session id. */
    private void adjustSelectedSessionID(ServerHelloMessage message) {
        byte[] sessionID = message.getSessionId().getValue();
        this.tlsContext.setServerSessionId(sessionID);
        LOGGER.debug("Set SessionID in Context to " + ArrayConverter.bytesToHexString(sessionID, false));
    }

    /**
     * Stores the message's protocol version in the context if it maps to a known version.
     *
     * <p>If the message carries no protocol version field at all, only a warning is logged. The
     * previous implementation dereferenced the missing field in its warning and threw an NPE.
     */
    private void adjustSelectedProtocolVersion(ServerHelloMessage message) {
        if (message.getProtocolVersion() == null) {
            LOGGER.warn("Did not Adjust ProtocolVersion since ServerHello contains no protocol version");
            return;
        }
        byte[] versionBytes = message.getProtocolVersion().getValue();
        ProtocolVersion version = ProtocolVersion.getProtocolVersion(versionBytes);
        if (version != null) {
            this.tlsContext.setSelectedProtocolVersion(version);
            LOGGER.debug("Set SelectedProtocolVersion in Context to " + version.name());
        } else {
            LOGGER.warn("Did not Adjust ProtocolVersion since version is undefined " + ArrayConverter.bytesToHexString(versionBytes));
        }
    }

    /** Sets the PRF algorithm for the selected version and cipher suite (skipped for SSL versions). */
    private void adjustPRF(ServerHelloMessage message) {
        Chooser chooser = this.tlsContext.getChooser();
        if (!chooser.getSelectedProtocolVersion().isSSL()) {
            this.tlsContext.setPrfAlgorithm(AlgorithmResolver.getPRFAlgorithm(chooser.getSelectedProtocolVersion(), chooser.getSelectedCipherSuite()));
        }
    }

    /** Installs a record cipher built from a key set of type {@link Tls13KeySetType#NONE}. */
    private void setRecordCipher() {
        KeySet keySet = this.getKeySet(this.tlsContext, Tls13KeySetType.NONE);
        LOGGER.debug("Setting new Cipher in RecordLayer");
        RecordCipher recordCipher = RecordCipherFactory.getRecordCipher(this.tlsContext, keySet);
        this.tlsContext.getRecordLayer().setRecordCipher(recordCipher);
    }

    /**
     * Switches the server side to the handshake traffic key set and installs the matching cipher.
     *
     * <p>As a client this resets the read sequence number and updates the decryption cipher. As a
     * server it resets the write sequence number and updates the encryption cipher.
     */
    private void setServerRecordCipher() {
        this.tlsContext.setActiveServerKeySetType(Tls13KeySetType.HANDSHAKE_TRAFFIC_SECRETS);
        LOGGER.debug("Setting cipher for server to use handshake secrets");
        KeySet serverKeySet = this.getKeySet(this.tlsContext, this.tlsContext.getActiveServerKeySetType());
        RecordCipher recordCipherServer = RecordCipherFactory.getRecordCipher(this.tlsContext, serverKeySet, this.tlsContext.getChooser().getSelectedCipherSuite());
        this.tlsContext.getRecordLayer().setRecordCipher(recordCipherServer);
        if (this.tlsContext.getChooser().getConnectionEndType() == ConnectionEndType.CLIENT) {
            this.tlsContext.setReadSequenceNumber(0L);
            this.tlsContext.getRecordLayer().updateDecryptionCipher();
        } else {
            this.tlsContext.setWriteSequenceNumber(0L);
            this.tlsContext.getRecordLayer().updateEncryptionCipher();
        }
    }

    /**
     * Generates a key set of the given type for {@code context}.
     *
     * @throws UnsupportedOperationException wrapping any crypto or algorithm-lookup failure
     */
    private KeySet getKeySet(TlsContext context, Tls13KeySetType keySetType) {
        try {
            LOGGER.debug("Generating new KeySet");
            // Use the supplied context consistently (previously the version came from this.tlsContext).
            return KeySetGenerator.generateKeySet(context, context.getChooser().getSelectedProtocolVersion(), keySetType);
        } catch (CryptoException | NoSuchAlgorithmException ex) {
            throw new UnsupportedOperationException("The specified Algorithm is not supported", ex);
        }
    }

    /** After we have serialized a TLS 1.3 ServerHello, installs the server handshake cipher. */
    @Override
    public void adjustTlsContextAfterSerialize(ServerHelloMessage message) {
        if (this.tlsContext.getChooser().getSelectedProtocolVersion().isTLS13()) {
            this.setServerRecordCipher();
        }
    }

    /**
     * Derives the TLS 1.3 handshake secret and the client/server handshake traffic secrets and stores
     * them in the context.
     *
     * <p>Early secret = HKDF-Extract(empty salt, PSK or zero bytes of MAC length). The salt for the
     * handshake secret is Derive-Secret(early, "derived", empty). Handshake secret =
     * HKDF-Extract(salt, shared key-exchange secret). The traffic secrets are derived with the labels
     * "c hs traffic" / "s hs traffic" over the current transcript digest.
     *
     * @throws AdjustmentException wrapping any crypto or algorithm-lookup failure
     */
    private void adjustHandshakeTrafficSecrets() {
        Chooser chooser = this.tlsContext.getChooser();
        HKDFAlgorithm hkdfAlgorithm = AlgorithmResolver.getHKDFAlgorithm(chooser.getSelectedCipherSuite());
        DigestAlgorithm digestAlgo = AlgorithmResolver.getDigestAlgorithm(chooser.getSelectedProtocolVersion(), chooser.getSelectedCipherSuite());
        try {
            int macLength = Mac.getInstance(hkdfAlgorithm.getMacAlgorithm().getJavaName()).getMacLength();
            byte[] psk = this.tlsContext.getConfig().isUsePsk() || this.tlsContext.getPsk() != null ? chooser.getPsk() : new byte[macLength];
            byte[] earlySecret = HKDFunction.extract(hkdfAlgorithm, new byte[0], psk);
            byte[] saltHandshakeSecret = HKDFunction.deriveSecret(hkdfAlgorithm, digestAlgo.getJavaName(), earlySecret, "derived", new byte[0]);
            byte[] sharedSecret = chooser.getConnectionEndType() == ConnectionEndType.CLIENT
                    ? this.computeClientSharedSecret(chooser)
                    : this.computeServerSharedSecret(chooser);
            byte[] handshakeSecret = HKDFunction.extract(hkdfAlgorithm, saltHandshakeSecret, sharedSecret);
            this.tlsContext.setHandshakeSecret(handshakeSecret);
            LOGGER.debug("Set handshakeSecret in Context to " + ArrayConverter.bytesToHexString(handshakeSecret));
            byte[] clientHandshakeTrafficSecret = HKDFunction.deriveSecret(hkdfAlgorithm, digestAlgo.getJavaName(), handshakeSecret, "c hs traffic", this.tlsContext.getDigest().getRawBytes());
            this.tlsContext.setClientHandshakeTrafficSecret(clientHandshakeTrafficSecret);
            LOGGER.debug("Set clientHandshakeTrafficSecret in Context to " + ArrayConverter.bytesToHexString(clientHandshakeTrafficSecret));
            byte[] serverHandshakeTrafficSecret = HKDFunction.deriveSecret(hkdfAlgorithm, digestAlgo.getJavaName(), handshakeSecret, "s hs traffic", this.tlsContext.getDigest().getRawBytes());
            this.tlsContext.setServerHandshakeTrafficSecret(serverHandshakeTrafficSecret);
            LOGGER.debug("Set serverHandshakeTrafficSecret in Context to " + ArrayConverter.bytesToHexString(serverHandshakeTrafficSecret));
        } catch (CryptoException | NoSuchAlgorithmException ex) {
            throw new AdjustmentException(ex);
        }
    }

    /**
     * Client side: computes the shared secret from the server's key share.
     *
     * <p>For non-PWD suites a configured default pre-master secret (if non-empty) replaces the computed
     * value. The computation still runs first because it records the peer key and group in the context.
     */
    private byte[] computeClientSharedSecret(Chooser chooser) throws CryptoException {
        if (chooser.getSelectedCipherSuite().isPWD()) {
            return this.computeSharedPWDSecret(chooser.getServerKeyShare());
        }
        byte[] sharedSecret = this.computeSharedSecret(chooser.getServerKeyShare());
        if (this.tlsContext.getConfig().getDefaultPreMasterSecret().length > 0) {
            LOGGER.debug("Using specified PMS instead of computed PMS");
            return this.tlsContext.getConfig().getDefaultPreMasterSecret();
        }
        return sharedSecret;
    }

    /**
     * Server side: computes the shared secret from the client's key share for the server's group.
     *
     * <p>Uses the first client share whose group matches the server key share's group; falls back to the
     * client's first share (with a warning) if none matches.
     */
    private byte[] computeServerSharedSecret(Chooser chooser) throws CryptoException {
        byte[] serverGroup = chooser.getServerKeyShare().getGroup().getValue();
        KeyShareStoreEntry clientShare = null;
        for (KeyShareStoreEntry entry : chooser.getClientKeyShares()) {
            if (Arrays.equals(entry.getGroup().getValue(), serverGroup)) {
                clientShare = entry;
                break;
            }
        }
        if (clientShare == null) {
            LOGGER.warn("Client did not send the KeyShareType we expected. Choosing first in his List");
            clientShare = chooser.getClientKeyShares().get(0);
        }
        return chooser.getSelectedCipherSuite().isPWD()
                ? this.computeSharedPWDSecret(clientShare)
                : this.computeSharedSecret(clientShare);
    }

    /**
     * Computes the ECDH shared secret between the configured key-share private key and the given share.
     *
     * <p>Side effect: stores the share's decoded point as server EC public key and its group as the
     * selected group in the context. X25519/X448 use the RFC 7748 curve computation. The listed
     * SECP/SECT curves return the x-coordinate of the product, zero-padded to the modulus byte length.
     *
     * @throws UnsupportedOperationException for any other group
     */
    private byte[] computeSharedSecret(KeyShareStoreEntry keyShare) {
        EllipticCurve curve = CurveFactory.getCurve(keyShare.getGroup());
        Point publicPoint = PointFormatter.formatFromByteArray(keyShare.getGroup(), keyShare.getPublicKey());
        this.tlsContext.setServerEcPublicKey(publicPoint);
        this.tlsContext.setSelectedGroup(keyShare.getGroup());
        BigInteger privateKey = this.tlsContext.getConfig().getKeySharePrivate();
        switch (keyShare.getGroup()) {
            case ECDH_X25519:
            case ECDH_X448: {
                RFC7748Curve rfcCurve = (RFC7748Curve) curve;
                return rfcCurve.computeSharedSecretDecodedPoint(privateKey, publicPoint);
            }
            case SECP160K1:
            case SECP160R1:
            case SECP160R2:
            case SECP192K1:
            case SECP192R1:
            case SECP224K1:
            case SECP224R1:
            case SECP256K1:
            case SECP256R1:
            case SECP384R1:
            case SECP521R1:
            case SECT163K1:
            case SECT163R1:
            case SECT163R2:
            case SECT193R1:
            case SECT193R2:
            case SECT233K1:
            case SECT233R1:
            case SECT239K1:
            case SECT283K1:
            case SECT283R1:
            case SECT409K1:
            case SECT409R1:
            case SECT571K1:
            case SECT571R1: {
                Point sharedPoint = curve.mult(privateKey, publicPoint);
                int elementLength = ArrayConverter.bigIntegerToByteArray(sharedPoint.getX().getModulus()).length;
                return ArrayConverter.bigIntegerToNullPaddedByteArray(sharedPoint.getX().getData(), elementLength);
            }
            default:
                break;
        }
        throw new UnsupportedOperationException("KeyShare type " + keyShare.getGroup() + " is unsupported");
    }

    /**
     * Computes the TLS-PWD (Dragonfly) shared secret for the given key share.
     *
     * <p>The peer's share is parsed into an element (point) and a scalar. The result is the x-coordinate
     * of {@code ownPrivate * (peerScalar * passwordElement + peerElement)}. That value is encoded in
     * {@code ceil(bitLength(modulus) / 8)} bytes. {@code ownPrivate} is the configured client or server
     * PWD private value, reduced modulo the base-point order.
     */
    private byte[] computeSharedPWDSecret(KeyShareStoreEntry keyShare) throws CryptoException {
        Chooser chooser = this.tlsContext.getChooser();
        EllipticCurve curve = CurveFactory.getCurve(keyShare.getGroup());
        DragonFlyKeyShareEntryParser parser = new DragonFlyKeyShareEntryParser(keyShare.getPublicKey(), keyShare.getGroup());
        DragonFlyKeyShareEntry dragonFlyKeyShareEntry = parser.parse();
        // Round up: a modulus that is not a whole number of bytes (e.g. 521 bits) needs the extra byte.
        // The previous "bitLength / 8" under-sized it and made the padding helper emit a doubled block.
        int elementLength = (curve.getModulus().bitLength() + 7) / 8;
        Point keySharePoint = PointFormatter.fromRawFormat(keyShare.getGroup(), dragonFlyKeyShareEntry.getRawPublicKey());
        BigInteger scalar = dragonFlyKeyShareEntry.getScalar();
        Point passwordElement = PWDComputations.computePasswordElement(chooser, curve);
        byte[] ownPrivate = chooser.getConnectionEndType() == ConnectionEndType.CLIENT
                ? chooser.getConfig().getDefaultClientPWDPrivate()
                : chooser.getConfig().getDefaultServerPWDPrivate();
        BigInteger privateKeyScalar = new BigInteger(1, ownPrivate).mod(curve.getBasePointOrder());
        LOGGER.debug("Element: " + ArrayConverter.bytesToHexString(PointFormatter.toRawFormat(keySharePoint)));
        LOGGER.debug("Scalar: " + ArrayConverter.bytesToHexString(ArrayConverter.bigIntegerToByteArray(scalar)));
        Point sharedSecret = curve.mult(privateKeyScalar, curve.add(curve.mult(scalar, passwordElement), keySharePoint));
        return ArrayConverter.bigIntegerToByteArray(sharedSecret.getX().getData(), elementLength, true);
    }
}

