/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore.flclient;

import com.mindspore.flclient.CipherClient;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.model.Client;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Client-side helper bundling three privacy-related facilities of the FL client:
 * <ul>
 *   <li>pairwise masking of weights, delegated to a {@link CipherClient}
 *       ({@link #setPWParameter}, {@link #pwCreateMask}, {@link #pwMaskWeight}, {@link #pwUnmasking});</li>
 *   <li>computation of a noise scale from the differential-privacy parameters
 *       ({@link #setDPParameter}, {@link #calculateSigma});</li>
 *   <li>the SignDS selection of feature dimensions
 *       ({@link #setDSParameter}, {@link #signDSModel}).</li>
 * </ul>
 *
 * <p>The object is stateful: the {@code set*Parameter} methods must be called before the
 * corresponding operations. In particular the pairwise-mask methods dereference the
 * {@code CipherClient} created by {@link #setPWParameter} without a null check.
 */
public class SecureProtocol {
    private static final Logger LOGGER = FLLoggerGenerater.getModelLogger(SecureProtocol.class.toString());

    /** Tolerance at which the bisection searches for the DP noise parameter stop. */
    private static final double DELTA_ERROR = 1.0E-6;

    /** Number of trapezoid segments used when integrating the error function numerically. */
    private static final int ERF_SEGMENT_NUM = 10000;

    /** Value the bisection step counter is compared against (the loop body runs at most this + 1 times). */
    private static final int MAX_BISECTION_ITERS = 1000;

    // Not referenced anywhere in this class; kept so the field layout is unchanged.
    private static Map<String, float[]> modelMap;

    private FLParameter flParameter = FLParameter.getInstance();
    private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
    private int iteration;
    private CipherClient cipherClient;
    private FLClientStatus status;
    private float[] featureMask = new float[0];
    private double dpEps;
    private double dpDelta;
    private double dpNormClip;
    private ArrayList<String> updateFeatureName = new ArrayList<>();
    private int retCode;
    private float signK;
    private float signEps;
    private float signThrRatio;
    private float signGlobalLr;
    private int signDimOut;

    /**
     * @return the status recorded by the most recent pairwise-mask step; {@code null} if none ran yet
     */
    public FLClientStatus getStatus() {
        return status;
    }

    /**
     * @return the return code last copied from the cipher client
     */
    public int getRetCode() {
        return retCode;
    }

    /**
     * Stores the iteration number and creates the {@link CipherClient} used by the pairwise-mask steps.
     *
     * @param iter         current iteration
     * @param minSecretNum forwarded to the {@code CipherClient} constructor
     * @param prime        forwarded to the {@code CipherClient} constructor; must be non-empty
     * @param featureSize  forwarded to the {@code CipherClient} constructor
     * @throws IllegalArgumentException if {@code prime} is {@code null} or empty
     */
    public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
        if (prime == null || prime.length == 0) {
            LOGGER.severe("[PairWiseMask] the input argument <prime> is null, please check!");
            throw new IllegalArgumentException();
        }
        iteration = iter;
        cipherClient = new CipherClient(iteration, minSecretNum, prime, featureSize);
    }

    /**
     * Stores the differential-privacy parameters later read by {@link #calculateSigma()}.
     * No validation is performed on the values.
     *
     * @param iter      current iteration
     * @param diffEps   epsilon
     * @param diffDelta delta
     * @param diffNorm  norm clip value
     * @return always {@code FLClientStatus.SUCCESS}
     */
    public FLClientStatus setDPParameter(int iter, double diffEps, double diffDelta, double diffNorm) {
        iteration = iter;
        dpEps = diffEps;
        dpDelta = diffDelta;
        dpNormClip = diffNorm;
        return FLClientStatus.SUCCESS;
    }

    /**
     * Stores the SignDS parameters later read by {@link #signDSModel(Client, boolean)}.
     *
     * @param signK        fraction of the input dimensions treated as "top-k"
     * @param signEps      epsilon applied (as {@code exp(eps)}) to outcomes at or above the threshold
     * @param signThrRatio expected-intersection ratio used when the output dimension is derived
     * @param signGlobalLr stored only; not read inside this class
     * @param signDimOut   number of dimensions to output; {@code 0} means "derive it on first use"
     * @return always {@code FLClientStatus.SUCCESS}
     */
    public FLClientStatus setDSParameter(float signK, float signEps, float signThrRatio, float signGlobalLr, int signDimOut) {
        this.signK = signK;
        this.signEps = signEps;
        this.signThrRatio = signThrRatio;
        this.signGlobalLr = signGlobalLr;
        this.signDimOut = signDimOut;
        return FLClientStatus.SUCCESS;
    }

    /**
     * @return the live list of feature names (not a copy)
     */
    public ArrayList<String> getUpdateFeatureName() {
        return updateFeatureName;
    }

    /**
     * @param updateFeatureName list of feature names; stored by reference
     */
    public void setUpdateFeatureName(ArrayList<String> updateFeatureName) {
        this.updateFeatureName = updateFeatureName;
    }

    /**
     * @return the next request time as reported by the cipher client
     */
    public String getNextRequestTime() {
        return cipherClient.getNextRequestTime();
    }

    /**
     * @return the norm clip value set through {@link #setDPParameter}
     */
    public double getDpNormClip() {
        return dpNormClip;
    }

    /**
     * Logs and reports whether the local stop-job flag is set.
     */
    private boolean isStopRequested() {
        if (!localFLParameter.isStopJobFlag()) {
            return false;
        }
        LOGGER.info("the stopJObFlag is set to true, the job will be stop");
        return true;
    }

    /**
     * Runs the mask-creation sequence on the cipher client: exchange keys, share secrets,
     * then build the feature mask. The stop-job flag is checked before each step; when it is
     * set the current status is returned unchanged.
     *
     * @return the status of the last step executed; {@code FLClientStatus.FAILED} when the
     *         cipher client produced a null or empty mask
     */
    public FLClientStatus pwCreateMask() {
        LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============", localFLParameter.getFlID()));
        if (isStopRequested()) {
            return status;
        }

        status = cipherClient.exchangeKeys();
        retCode = cipherClient.getRetCode();
        // The separator belongs in the format string; passing it as an argument printed it instead of the status.
        LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ============", status));
        if (status != FLClientStatus.SUCCESS || isStopRequested()) {
            return status;
        }

        status = cipherClient.shareSecrets();
        retCode = cipherClient.getRetCode();
        LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s =============", status));
        if (status != FLClientStatus.SUCCESS || isStopRequested()) {
            return status;
        }

        featureMask = cipherClient.doubleMaskingWeight();
        if (featureMask == null || featureMask.length <= 0) {
            LOGGER.severe("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight is null, please check!");
            // Keep getStatus() consistent with the value handed back to the caller.
            status = FLClientStatus.FAILED;
            return status;
        }
        retCode = cipherClient.getRetCode();
        LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
        return status;
    }

    /**
     * Applies the feature mask to one slice of weights:
     * {@code result[j] = feature[j] * trainDataSize + featureMask[maskIndex + j]}.
     *
     * @param trainDataSize multiplier applied to every weight
     * @param feature       weights to mask
     * @param maskIndex     offset of this slice inside the feature mask
     * @return a new array of the same length as {@code feature}
     * @throws RuntimeException if no mask has been created or the mask is too short for this slice
     */
    public float[] pwMaskWeight(int trainDataSize, float[] feature, int maskIndex) {
        if (featureMask == null || featureMask.length == 0) {
            throw new RuntimeException("[pwMaskWeight] feature mask is null, please check");
        }
        if (featureMask.length < maskIndex + feature.length) {
            throw new RuntimeException("[pwMaskWeight] the data length is out of range for array featureMask, featureMask length:" + featureMask.length + " data length:" + feature.length);
        }
        LOGGER.info(String.format("[pwMaskWeight] feature mask size: %s", featureMask.length));
        LOGGER.info(String.format("[pwMaskWeight] feature  size: %s", feature.length));
        float[] maskedData = new float[feature.length];
        for (int j = 0; j < feature.length; ++j) {
            maskedData[j] = feature[j] * (float) trainDataSize + featureMask[maskIndex + j];
        }
        return maskedData;
    }

    /**
     * Asks the cipher client to reconstruct secrets and records its status and return code.
     *
     * @return the status reported by the cipher client
     */
    public FLClientStatus pwUnmasking() {
        status = cipherClient.reconstructSecrets();
        retCode = cipherClient.getRetCode();
        LOGGER.info(String.format("[Encrypt] =============GetClientList+SendReconstructSecret: %s =============", status));
        return status;
    }

    /**
     * Approximates erf(x) = 2/sqrt(pi) * integral of exp(-t^2) over [0, x] with the trapezoidal
     * rule on {@link #ERF_SEGMENT_NUM} equal segments. The result is narrowed to {@code float}.
     */
    private static float calculateErf(double erfInput) {
        double deltaX = erfInput / (double) ERF_SEGMENT_NUM;
        // Left endpoint: exp(-0^2) == 1.
        double result = 1.0;
        // Interior points carry weight 2 in the trapezoidal sum.
        for (int i = 1; i < ERF_SEGMENT_NUM; ++i) {
            result += 2.0 * Math.exp(-Math.pow(deltaX * (double) i, 2.0));
        }
        // Right endpoint.
        result += Math.exp(-Math.pow(deltaX * (double) ERF_SEGMENT_NUM, 2.0));
        return (float) (result * deltaX / Math.pow(Math.PI, 0.5));
    }

    /**
     * Standard normal CDF expressed through {@link #calculateErf(double)}.
     */
    private static double calculatePhi(double phiInput) {
        return 0.5 * (1.0 + (double) calculateErf(phiInput / Math.sqrt(2.0)));
    }

    /**
     * Evaluates {@code Phi(sqrt(eps*s)) - exp(eps) * Phi(-sqrt(eps*(s+2)))}.
     */
    private static double calculateBPositive(double eps, double calInput) {
        double head = calculatePhi(Math.sqrt(eps * calInput));
        double tail = calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
        return head - Math.exp(eps) * tail;
    }

    /**
     * Evaluates {@code Phi(-sqrt(eps*s)) - exp(eps) * Phi(-sqrt(eps*(s+2)))}.
     */
    private static double calculateBNegative(double eps, double calInput) {
        double head = calculatePhi(-Math.sqrt(eps * calInput));
        double tail = calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
        return head - Math.exp(eps) * tail;
    }

    /**
     * Searches for {@code s} with {@code calculateBPositive(eps, s)} just below {@code targetDelta}:
     * first doubles the upper bound until the function exceeds the target, then bisects.
     *
     * @return the last midpoint examined
     */
    private static double calculateSPositive(double eps, double targetDelta, double initSInf, double initSSup) {
        double sInf = initSInf;
        double sSup = initSSup;
        // Expand the bracket until its upper end overshoots the target.
        double deltaSup = calculateBPositive(eps, sSup);
        while (deltaSup <= targetDelta) {
            sInf = sSup;
            sSup = 2.0 * sInf;
            deltaSup = calculateBPositive(eps, sSup);
        }
        double sMid = sInf + (sSup - sInf) / 2.0;
        int iters = 0;
        do {
            double bPositive = calculateBPositive(eps, sMid);
            if (bPositive <= targetDelta) {
                if (targetDelta - bPositive <= DELTA_ERROR) {
                    break;
                }
                sInf = sMid;
            } else {
                sSup = sMid;
            }
            sMid = sInf + (sSup - sInf) / 2.0;
        } while (++iters <= MAX_BISECTION_ITERS);
        return sMid;
    }

    /**
     * Searches for {@code s} with {@code calculateBNegative(eps, s)} just below {@code targetDelta}:
     * first doubles the upper bound until the function drops to the target, then bisects.
     *
     * @return the last midpoint examined
     */
    private static double calculateSNegative(double eps, double targetDelta, double initSInf, double initSSup) {
        double sInf = initSInf;
        double sSup = initSSup;
        // Expand the bracket until its upper end is at or below the target.
        double deltaSup = calculateBNegative(eps, sSup);
        while (deltaSup > targetDelta) {
            sInf = sSup;
            sSup = 2.0 * sInf;
            deltaSup = calculateBNegative(eps, sSup);
        }
        double sMid = sInf + (sSup - sInf) / 2.0;
        int iters = 0;
        do {
            double bNegative = calculateBNegative(eps, sMid);
            if (bNegative <= targetDelta) {
                if (targetDelta - bNegative <= DELTA_ERROR) {
                    break;
                }
                sSup = sMid;
            } else {
                sInf = sMid;
            }
            sMid = sInf + (sSup - sInf) / 2.0;
        } while (++iters <= MAX_BISECTION_ITERS);
        return sMid;
    }

    /**
     * Computes {@code alpha * dpNormClip / sqrt(2 * dpEps)} from the parameters stored by
     * {@link #setDPParameter}, where {@code alpha} is derived from a search whose direction
     * depends on how {@code dpDelta} compares with {@code calculateBPositive(dpEps, 0)}.
     *
     * <p>The parameters are not validated here; a {@code dpEps} of zero makes the final division
     * non-finite.
     *
     * @return the resulting sigma
     */
    public double calculateSigma() {
        double deltaZero = calculateBPositive(dpEps, 0.0);
        double alpha = 1.0;
        if (dpDelta > deltaZero) {
            double sPositive = calculateSPositive(dpEps, dpDelta, 0.0, 1.0);
            alpha = Math.sqrt(1.0 + sPositive / 2.0) - Math.sqrt(sPositive / 2.0);
        } else if (dpDelta < deltaZero) {
            double sNegative = calculateSNegative(dpEps, dpDelta, 0.0, 1.0);
            alpha = Math.sqrt(1.0 + sNegative / 2.0) + Math.sqrt(sNegative / 2.0);
        } else {
            LOGGER.info("[Encrypt] targetDelta = deltaZero");
        }
        return alpha * dpNormClip / Math.sqrt(2.0 * dpEps);
    }

    /**
     * Binomial coefficient C(n, k) computed in floating point.
     *
     * @return 0 when {@code k > n} or either argument is negative
     */
    private static double comb(double n, double k) {
        if (!(k <= n && n >= 0.0 && k >= 0.0)) {
            return 0.0;
        }
        double m = n + 1.0;
        // C(n, k) == C(n, n - k); iterate over the shorter side.
        double nTerm = Math.min(k, n - k);
        double res = 1.0;
        for (int i = 1; (double) i <= nTerm; ++i) {
            res *= m - (double) i;
            res /= (double) i;
        }
        return res;
    }

    /**
     * Number of ways to pick {@code outputDim} dimensions out of {@code inputDim} such that exactly
     * {@code numInter} of them come from the {@code topkDim} top-k dimensions.
     */
    private static double countCombs(int numInter, int topkDim, int inputDim, int outputDim) {
        return comb(topkDim, numInter) * comb(inputDim - topkDim, outputDim - numInter);
    }

    /**
     * Builds the normalised probability mass over the intersection count {@code 0..outputDim}.
     * Counts at or above {@code thr} are weighted by {@code exp(eps)}.
     *
     * @return the normalised masses, or an empty list (after logging) when all masses are zero
     */
    private static List<Double> calcPmf(int thr, int topkDim, int inputDim, int outputDim, float eps) {
        List<Double> pmf = new ArrayList<Double>();
        double pmfSum = 0.0;
        for (int v = 0; v <= outputDim; ++v) {
            double mass = countCombs(v, topkDim, inputDim, outputDim);
            if (v >= thr) {
                mass = mass * Math.exp(eps);
            }
            pmf.add(mass);
            pmfSum += mass;
        }
        if (pmfSum == 0.0) {
            LOGGER.severe("[SignDS] probability mass function is 0, please check");
            return new ArrayList<Double>();
        }
        for (int i = 0; i < pmf.size(); ++i) {
            pmf.set(i, pmf.get(i) / pmfSum);
        }
        return pmf;
    }

    /**
     * Expected value of the index under the given probability mass function.
     */
    private static double calcExpectation(List<Double> pmf) {
        double sumExpectation = 0.0;
        for (int i = 0; i < pmf.size(); ++i) {
            sumExpectation += (double) i * pmf.get(i);
        }
        return sumExpectation;
    }

    /**
     * Scans thresholds {@code 1..outputDim} while the expected intersection count keeps strictly
     * increasing and returns the last threshold that improved it.
     *
     * @return the chosen threshold, never below 1
     */
    private static int calcOptThr(int topkDim, int inputDim, int outputDim, float eps) {
        double optExpect = 0.0;
        double optT = 0.0;
        for (int t = 1; t <= outputDim; ++t) {
            double newExpect = calcExpectation(calcPmf(t, topkDim, inputDim, outputDim, eps));
            if (!(newExpect > optExpect)) {
                break;
            }
            optExpect = newExpect;
            optT = t;
        }
        return (int) Math.max(optT, 1.0);
    }

    /**
     * Finds the largest output dimension whose expected intersection ratio
     * (expected intersections / output dimension) is still at least {@code thrInterRatio}.
     *
     * <p>The search is capped at {@code inputDim}: beyond it every combination count is zero, so
     * the ratio is 0. For a positive threshold that is where the unbounded search stops anyway;
     * the cap additionally guarantees termination for a non-positive or NaN threshold.
     *
     * @return the chosen output dimension, never below 1
     */
    private static int findOptOutputDim(float thrInterRatio, int topkDim, int inputDim, float eps) {
        int outputDim = 1;
        while (outputDim <= inputDim) {
            int thr = calcOptThr(topkDim, inputDim, outputDim, eps);
            double expectedRatio = calcExpectation(calcPmf(thr, topkDim, inputDim, outputDim, eps)) / (double) outputDim;
            if (Double.isNaN(expectedRatio) || expectedRatio < (double) thrInterRatio) {
                break;
            }
            ++outputDim;
        }
        return Math.max(1, outputDim - 1);
    }

    /**
     * Draws the number of top-k intersections by walking the cumulative distribution until it
     * reaches a uniformly drawn probability.
     *
     * <p>The walk is bounded by {@code outputDim}. If accumulated rounding error leaves the
     * cumulative sum below the drawn value after the last count, the largest count with positive
     * mass is returned instead of continuing past the support (where every further term is zero
     * and the walk could never finish).
     */
    private static int countInters(int thrDim, double denominator, int topkDim, int inputDim, int outputDim, float eps) {
        SecureRandom secureRandom = new SecureRandom();
        double randomProb = secureRandom.nextDouble();
        int numInter = 0;
        int lastWithMass = 0;
        double prob = countCombs(numInter, topkDim, inputDim, outputDim) / denominator;
        while (prob < randomProb) {
            if (numInter >= outputDim) {
                return lastWithMass;
            }
            ++numInter;
            double combs = countCombs(numInter, topkDim, inputDim, outputDim);
            if (combs > 0.0) {
                lastWithMass = numInter;
            }
            if (numInter < thrDim) {
                prob += combs / denominator;
            } else {
                prob += Math.exp(eps) * combs / denominator;
            }
        }
        return numInter;
    }

    /**
     * Picks {@code selectNums} distinct entries of {@code inputList} uniformly at random and
     * writes them to {@code outputList} starting at {@code outStartPos}. Selected entries are
     * swapped to the tail of {@code inputList}, so the input array is reordered in place.
     *
     * <p>Selecting zero entries is a valid no-op; a negative count or a count larger than the
     * input is logged and ignored.
     */
    private static void randomSelect(SecureRandom secureRandom, int[] inputList, int selectNums, int[] outputList, int outStartPos) {
        if (selectNums == 0) {
            // Nothing to draw from this pool; not an error.
            return;
        }
        if (selectNums < 0) {
            LOGGER.severe("[SignDS] The number to be selected is set incorrectly!");
            return;
        }
        if (inputList.length < selectNums) {
            LOGGER.severe("[SignDS] The size of inputList is small than num!");
            return;
        }
        int total = inputList.length;
        for (int remaining = total; remaining > total - selectNums; --remaining) {
            int randomIndex = secureRandom.nextInt(remaining);
            int picked = inputList[randomIndex];
            // Move the picked entry out of the still-selectable prefix.
            inputList[randomIndex] = inputList[remaining - 1];
            inputList[remaining - 1] = picked;
            outputList[outStartPos + total - remaining] = picked;
        }
    }

    /**
     * Selects {@code signDimOut} dimension indices for the SignDS scheme.
     *
     * <p>The per-dimension update is {@code feature - preFeature}, flattened over all names in
     * {@code updateFeatureName}. The dimensions are ordered by update value, the first
     * {@code signK * inputDim} form the top-k set, a random intersection count is drawn, and that
     * many indices are sampled from the top-k set with the remainder sampled from the rest.
     *
     * <p>If {@code signDimOut} is 0 it is derived once and stored back into this object.
     *
     * @param client source of the features before and after training
     * @param sign   {@code true} orders updates descending (largest first), {@code false} ascending
     * @return the selected flat indices in ascending order, or an empty array when the
     *         configuration or the feature data is inconsistent (a SEVERE message is logged)
     */
    public int[] signDSModel(Client client, boolean sign) {
        int inputDim = 0;
        for (String key : updateFeatureName) {
            inputDim += client.getPreFeature(key).length;
        }
        int topkDim = (int) (signK * (float) inputDim);
        if (signDimOut == 0) {
            signDimOut = findOptOutputDim(signThrRatio, topkDim, inputDim, signEps);
        }
        int thrDim = calcOptThr(topkDim, inputDim, signDimOut, signEps);

        // Normalising constant of the intersection-count distribution.
        double combLessInter = 0.0;
        for (int i = 0; i < thrDim; ++i) {
            combLessInter += countCombs(i, topkDim, inputDim, signDimOut);
        }
        double combMoreInter = 0.0;
        for (int i = thrDim; i <= signDimOut; ++i) {
            combMoreInter += countCombs(i, topkDim, inputDim, signDimOut);
        }
        double denominator = combLessInter + Math.exp(signEps) * combMoreInter;
        if (denominator == 0.0) {
            LOGGER.severe("[SignDS] denominator is 0, please check");
            return new int[0];
        }

        int numInter = countInters(thrDim, denominator, topkDim, inputDim, signDimOut, signEps);
        int numOuter = signDimOut - numInter;
        // Also reject a draw that asks for more non-top-k entries than exist.
        if (topkDim < numInter || signDimOut <= 0 || numOuter > inputDim - topkDim) {
            LOGGER.severe("[SignDS] topkDim or signDimOut is ERROR! please check");
            return new int[0];
        }

        float[] originData = new float[inputDim];
        Integer[] originIndex = new Integer[inputDim];
        int index = 0;
        for (String key : updateFeatureName) {
            float[] dataAfterTrain = client.getFeature(key);
            float[] dataBeforeTrain = client.getPreFeature(key);
            // inputDim was sized from the pre-training features; a differing length would
            // overrun the buffers or leave unset entries that break the sort below.
            if (dataAfterTrain == null || dataAfterTrain.length != dataBeforeTrain.length) {
                LOGGER.severe("[SignDS] the feature size before and after training is inconsistent for " + key + ", please check");
                return new int[0];
            }
            for (int j = 0; j < dataAfterTrain.length; ++j) {
                originData[index] = dataAfterTrain[j] - dataBeforeTrain[j];
                originIndex[index] = index;
                ++index;
            }
        }

        if (sign) {
            Arrays.sort(originIndex, (l, r) -> Float.compare(originData[r], originData[l]));
        } else {
            Arrays.sort(originIndex, (l, r) -> Float.compare(originData[l], originData[r]));
        }

        int[] topkKeyList = new int[topkDim];
        int[] nonTopkKeyList = new int[inputDim - topkDim];
        for (int i = 0; i < topkDim; ++i) {
            topkKeyList[i] = originIndex[i];
        }
        for (int i = topkDim; i < inputDim; ++i) {
            nonTopkKeyList[i - topkDim] = originIndex[i];
        }

        int[] outputDimensionIndexList = new int[numInter + numOuter];
        SecureRandom secureRandom = Common.getSecureRandom();
        randomSelect(secureRandom, topkKeyList, numInter, outputDimensionIndexList, 0);
        randomSelect(secureRandom, nonTopkKeyList, numOuter, outputDimensionIndexList, numInter);
        Arrays.sort(outputDimensionIndexList);
        LOGGER.info("[SignDS] outputDimension size is " + outputDimensionIndexList.length);
        return outputDimensionIndexList;
    }
}

