/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore.flclient;

import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.CipherClient;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.EncryptLevel;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.SecureProtocol;
import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.compression.CompressWeight;
import com.mindspore.flclient.compression.EncodeExecutor;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Date;
import java.util.logging.Logger;
import mindspore.fl.schema.CompressFeatureMap;
import mindspore.fl.schema.FeatureMap;
import mindspore.fl.schema.RequestUpdateModel;
import mindspore.fl.schema.ResponseUpdateModel;

/**
 * Builds the serialized "updateModel" request sent by the federated-learning client and
 * interprets the server's {@link ResponseUpdateModel}.
 *
 * <p>Obtain the shared instance through {@link #getInstance()}.
 */
public class UpdateModel {
    private static final Logger LOGGER = FLLoggerGenerater.getModelLogger(UpdateModel.class.toString());
    private static volatile UpdateModel updateModel;
    private final FLParameter flParameter = FLParameter.getInstance();
    private final LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
    private FLClientStatus status;
    private int retCode = 400;

    private UpdateModel() {
    }

    /**
     * Returns the shared instance, creating it lazily on first call.
     *
     * <p>Double-checked locking: the unsynchronized fast path reads the {@code volatile}
     * field; creation happens at most once under the class monitor.
     *
     * @return the singleton {@code UpdateModel}
     */
    public static UpdateModel getInstance() {
        UpdateModel localRef = updateModel;
        if (localRef == null) {
            synchronized (UpdateModel.class) {
                localRef = updateModel;
                if (localRef == null) {
                    localRef = new UpdateModel();
                    updateModel = localRef;
                }
            }
        }
        return localRef;
    }

    /**
     * @return the status recorded by this object; only set to FAILED when SignDS masking fails
     */
    public FLClientStatus getStatus() {
        return this.status;
    }

    /**
     * @return the last return code from the server, or one set locally on failure (initially 400)
     */
    public int getRetCode() {
        return this.retCode;
    }

    /**
     * Serializes an updateModel request for the given iteration.
     *
     * <p>When PKI verification is enabled, the current time in milliseconds and the iteration
     * are signed via {@link CipherClient#signTimeAndIter} and the signature is attached.
     * Otherwise the timestamp is filled in by the builder and no signature is sent.
     *
     * @param iteration      current federated-learning iteration
     * @param secureProtocol supplies the feature names to upload and the masking/noise logic
     * @param trainDataSize  number of local training samples; weights are scaled by it
     * @param evaAcc         evaluation accuracy reported to the server
     * @return the finished FlatBuffers request bytes
     * @throws IllegalArgumentException if a mandatory field (name, id, signature, ...) is missing
     */
    public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize, float evaAcc) {
        RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(this.localFLParameter.getEncryptLevel());
        boolean isPkiVerify = this.flParameter.isPkiVerify();
        Client client = ClientManager.getClient(this.flParameter.getFlName());
        float uploadLoss = client == null ? 0.0f : client.getUploadLoss();

        // "null" is a sentinel understood by RequestUpdateModelBuilder.time(): use "now".
        String dateTime = "null";
        byte[] signature = null;
        if (isPkiVerify) {
            // Sign before building so the signed timestamp is exactly the one sent.
            dateTime = String.valueOf(System.currentTimeMillis());
            signature = CipherClient.signTimeAndIter(dateTime, iteration);
        }

        // The call order matches the original so the FlatBuffers layout is unchanged.
        builder.flName(this.flParameter.getFlName())
                .time(dateTime)
                .id(this.localFLParameter.getFlID())
                .featuresMap(secureProtocol, trainDataSize)
                .iteration(iteration);
        if (isPkiVerify) {
            builder.signData(signature);
        }
        return builder.uploadLoss(uploadLoss).evalAccuracy(evaAcc).build();
    }

    /**
     * Records and logs the server's response to an updateModel request.
     *
     * @param response the parsed server response
     * @return SUCCESS for 200, RESTART for 300, FAILED for 400/500 and for any other code
     */
    public FLClientStatus doResponse(ResponseUpdateModel response) {
        this.retCode = response.retcode();
        LOGGER.info("[updateModel] ==========the response message of updateModel is================");
        LOGGER.info("[updateModel] ==========retCode: " + this.retCode);
        LOGGER.info("[updateModel] ==========reason: " + response.reason());
        LOGGER.info("[updateModel] ==========next request time: " + response.nextReqTime());
        switch (this.retCode) {
            case 200:
                LOGGER.info("[updateModel] updateModel success");
                return FLClientStatus.SUCCESS;
            case 300:
                return FLClientStatus.RESTART;
            case 400:
            case 500:
                LOGGER.warning("[updateModel] catch RequestError or SystemError");
                return FLClientStatus.FAILED;
            default:
                LOGGER.severe("[updateModel]the return <retCode> from server is invalid: " + this.retCode);
                return FLClientStatus.FAILED;
        }
    }

    static {
        System.loadLibrary("mindspore-lite-jni");
    }

    /**
     * Fluent builder for the {@link RequestUpdateModel} FlatBuffers table.
     *
     * <p>All strings and vectors are created in the underlying {@link FlatBufferBuilder} as the
     * fluent setters are called. {@link #build()} then assembles the table from the recorded
     * offsets. Offsets left at 0 correspond to fields that were never set.
     */
    class RequestUpdateModelBuilder {
        private final FlatBufferBuilder builder;
        private final EncryptLevel encryptLevel;
        private int fmOffset = 0;
        private int compFmOffset = 0;
        private int nameOffset = 0;
        private int idOffset = 0;
        private int timestampOffset = 0;
        private int signDataOffset = 0;
        private int sign = 0;
        private int indexArrayOffset = 0;
        private int iteration = 0;
        private byte uploadCompressType = 0;
        private float uploadSparseRate = 0.0f;
        private float uploadLossValue = 0.0f;
        private float evalAccuracy = 0.0f;
        private int nameVecOffset = 0;

        private RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
            this.builder = new FlatBufferBuilder();
            this.encryptLevel = encryptLevel;
        }

        /** Sets the federated-learning job name; must be non-empty. */
        private RequestUpdateModelBuilder flName(String name) {
            if (name == null || name.isEmpty()) {
                LOGGER.severe("[updateModel] the parameter of <name> is null or empty, please check!");
                throw new IllegalArgumentException();
            }
            this.nameOffset = this.builder.createString(name);
            return this;
        }

        /**
         * Sets the request timestamp. The literal string {@code "null"} means
         * "use the current time in milliseconds".
         */
        private RequestUpdateModelBuilder time(String setTime) {
            if (setTime == null || setTime.isEmpty()) {
                LOGGER.severe("[updateModel] the parameter of <setTime> is null or empty, please check!");
                throw new IllegalArgumentException();
            }
            String value = "null".equals(setTime) ? String.valueOf(System.currentTimeMillis()) : setTime;
            this.timestampOffset = this.builder.createString(value);
            return this;
        }

        private RequestUpdateModelBuilder iteration(int iteration) {
            this.iteration = iteration;
            return this;
        }

        /** Sets the client id; must be non-empty. */
        private RequestUpdateModelBuilder id(String id) {
            if (id == null || id.isEmpty()) {
                LOGGER.severe("[updateModel] the parameter of <id> is null or empty, please check!");
                throw new IllegalArgumentException();
            }
            this.idOffset = this.builder.createString(id);
            return this;
        }

        /** Picks the weight encrypter for the configured level; anything else uploads unmasked weights. */
        private EncrypterBase getEncrypter(Client client, SecureProtocol secureProtocol, int trainDataSize) {
            switch (this.encryptLevel) {
                case PW_ENCRYPT:
                    return new PwEncrypter(client, secureProtocol, trainDataSize);
                case DP_ENCRYPT:
                    return new DpEncrypter(client, secureProtocol, trainDataSize);
                default:
                    return new NoEncrypter(client, secureProtocol, trainDataSize);
            }
        }

        /**
         * Serializes the weights to upload.
         *
         * <p>SIGNDS is delegated to {@link #signDSEncrypt}. Otherwise each feature is passed
         * through the selected encrypter and stored either uncompressed (compress type 0) or
         * via {@link EncodeExecutor} compression.
         */
        private RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
            ArrayList<String> updateFeatureName = secureProtocol.getUpdateFeatureName();
            if (this.encryptLevel == EncryptLevel.SIGNDS) {
                return this.signDSEncrypt(secureProtocol, updateFeatureName);
            }
            Client client = ClientManager.getClient(UpdateModel.this.flParameter.getFlName());
            EncrypterBase encrypterBase = this.getEncrypter(client, secureProtocol, trainDataSize);
            // NOTE(review): init()'s boolean result is ignored. DpEncrypter returns false only
            // when the update L2 norm is 0; its clipFactor then stays 0.0, which scales
            // all-zero deltas anyway.
            encrypterBase.init();
            this.uploadCompressType = UpdateModel.this.localFLParameter.getUploadCompressType();
            if (this.uploadCompressType == 0) {
                this.buildPlainFeatureMaps(encrypterBase, updateFeatureName.size());
            } else {
                this.buildCompressedFeatureMaps(client, encrypterBase, updateFeatureName, trainDataSize);
            }
            return this;
        }

        /** Writes every encrypted weight as an uncompressed FeatureMap entry. */
        private void buildPlainFeatureMaps(EncrypterBase encrypterBase, int featureCount) {
            long startTime = System.currentTimeMillis();
            int[] fmOffsets = new int[featureCount];
            int index = 0;
            while (!encrypterBase.isEnd()) {
                String featureName = encrypterBase.getNextFeature();
                float[] encryptWeight = encrypterBase.geEncryptWeight(featureName);
                LOGGER.fine("[updateModel build featuresMap] feature name: " + featureName + " feature size: " + encryptWeight.length);
                int featureNameOffset = this.builder.createString(featureName);
                int weightOffset = FeatureMap.createDataVector(this.builder, encryptWeight);
                fmOffsets[index] = FeatureMap.createFeatureMap(this.builder, featureNameOffset, weightOffset);
                ++index;
            }
            this.fmOffset = RequestUpdateModel.createFeatureMapVector(this.builder, fmOffsets);
            long endTime = System.currentTimeMillis();
            LOGGER.info("No compression and encrypt type is:" + this.encryptLevel + " cost " + (endTime - startTime) + "ms");
        }

        /**
         * Writes every encrypted weight as a CompressFeatureMap entry, diffed against the
         * pre-training weights. Also records the sparse rate and the feature-name vector.
         */
        private void buildCompressedFeatureMaps(Client client, EncrypterBase encrypterBase,
                ArrayList<String> updateFeatureName, int trainDataSize) {
            long startTime = System.currentTimeMillis();
            int totalMaskLen = 0;
            for (String featureName : updateFeatureName) {
                totalMaskLen += client.getPreFeature(featureName).length;
            }
            boolean[] maskArray = EncodeExecutor.getInstance().constructMaskArray(totalMaskLen);
            // Running offset into maskArray: advanced by each feature's length.
            int maskedLen = 0;
            int index = 0;
            int[] compFmOffsets = new int[updateFeatureName.size()];
            while (!encrypterBase.isEnd()) {
                String featureName = encrypterBase.getNextFeature();
                float[] encryptWeight = encrypterBase.geEncryptWeight(featureName);
                float[] preWeight = client.getPreFeature(featureName);
                // NOTE(review): 8 is passed through unchanged; presumably the quantization bit width — confirm.
                CompressWeight compressWeight = EncodeExecutor.enDiffSparseQuantData(featureName, encryptWeight, preWeight, 8, trainDataSize, maskArray, maskedLen);
                byte[] data = compressWeight.getCompressData();
                float minVal = compressWeight.getMinValue();
                float maxVal = compressWeight.getMaxValue();
                LOGGER.fine("[updateModel build compressWeight] origin size: " + preWeight.length + ", after compress size: " + data.length);
                int featureNameOffset = this.builder.createString(featureName);
                int weightOffset = CompressFeatureMap.createCompressDataVector(this.builder, data);
                int featureOffset = CompressFeatureMap.createCompressFeatureMap(this.builder, featureNameOffset, weightOffset, minVal, maxVal);
                LOGGER.fine("[updateModel Compression] " + featureName + "min_val: " + minVal + ", max_val: " + maxVal);
                compFmOffsets[index] = featureOffset;
                ++index;
                maskedLen += preWeight.length;
            }
            this.compFmOffset = RequestUpdateModel.createCompressFeatureMapVector(this.builder, compFmOffsets);
            this.uploadSparseRate = UpdateModel.this.localFLParameter.getUploadSparseRatio();
            this.nameVecOffset = this.buildNameVecOffset(updateFeatureName);
            long endTime = System.currentTimeMillis();
            LOGGER.info("compression time is " + (endTime - startTime) + "ms, encrypt is " + this.encryptLevel);
        }

        /**
         * SignDS path: uploads the index array from {@link SecureProtocol#signDSModel}
         * plus a random sign (+1/-1). Each feature name is sent with an empty weight vector.
         *
         * @throws IllegalArgumentException if signDSModel returns no indices; retCode and status
         *                                  are set to 400/FAILED first
         */
        private RequestUpdateModelBuilder signDSEncrypt(SecureProtocol secureProtocol, ArrayList<String> updateFeatureName) {
            long startTime = System.currentTimeMillis();
            Client client = ClientManager.getClient(UpdateModel.this.flParameter.getFlName());
            SecureRandom secureRandom = Common.getSecureRandom();
            boolean signBool = secureRandom.nextBoolean();
            this.sign = signBool ? 1 : -1;
            int[] indexArray = secureProtocol.signDSModel(client, signBool);
            if (indexArray == null || indexArray.length == 0) {
                LOGGER.severe("[Encrypt] the return fmOffsetsSignDS from <secureProtocol.signDSModel> is null, please check");
                UpdateModel.this.retCode = 400;
                UpdateModel.this.status = FLClientStatus.FAILED;
                throw new IllegalArgumentException();
            }
            this.indexArrayOffset = RequestUpdateModel.createIndexArrayVector(this.builder, indexArray);
            int compFeatureSize = updateFeatureName.size();
            int[] fmOffsetsSignds = new int[compFeatureSize];
            for (int i = 0; i < compFeatureSize; ++i) {
                int featureName = this.builder.createString(updateFeatureName.get(i));
                int weight = FeatureMap.createDataVector(this.builder, new float[0]);
                fmOffsetsSignds[i] = FeatureMap.createFeatureMap(this.builder, featureName, weight);
            }
            this.fmOffset = RequestUpdateModel.createFeatureMapVector(this.builder, fmOffsetsSignds);
            LOGGER.info("[Encrypt] SignDS mask model ok!");
            long endTime = System.currentTimeMillis();
            LOGGER.info("signds time is " + (endTime - startTime) + "ms");
            return this;
        }

        /** Serializes the feature names, in list order, as the request's name vector. */
        private int buildNameVecOffset(ArrayList<String> updateFeatureName) {
            int featureSize = updateFeatureName.size();
            int[] nameVecOffsets = new int[featureSize];
            for (int i = 0; i < featureSize; ++i) {
                nameVecOffsets[i] = this.builder.createString(updateFeatureName.get(i));
            }
            return RequestUpdateModel.createNameVecVector(this.builder, nameVecOffsets);
        }

        /** Attaches the PKI signature; must be non-empty. */
        private RequestUpdateModelBuilder signData(byte[] signData) {
            if (signData == null || signData.length == 0) {
                LOGGER.severe("[updateModel] the parameter of <signData> is null or empty, please check!");
                throw new IllegalArgumentException();
            }
            this.signDataOffset = RequestUpdateModel.createSignatureVector(this.builder, signData);
            return this;
        }

        private RequestUpdateModelBuilder uploadLoss(float uploadLoss) {
            this.uploadLossValue = uploadLoss;
            return this;
        }

        private RequestUpdateModelBuilder evalAccuracy(float evalAccuracy) {
            this.evalAccuracy = evalAccuracy;
            return this;
        }

        /** Assembles the RequestUpdateModel table from the recorded offsets and returns its bytes. */
        private byte[] build() {
            RequestUpdateModel.startRequestUpdateModel(this.builder);
            RequestUpdateModel.addFlName(this.builder, this.nameOffset);
            RequestUpdateModel.addFlId(this.builder, this.idOffset);
            RequestUpdateModel.addTimestamp(this.builder, this.timestampOffset);
            RequestUpdateModel.addIteration(this.builder, this.iteration);
            RequestUpdateModel.addCompressFeatureMap(this.builder, this.compFmOffset);
            RequestUpdateModel.addUploadCompressType(this.builder, this.uploadCompressType);
            RequestUpdateModel.addUploadSparseRate(this.builder, this.uploadSparseRate);
            RequestUpdateModel.addNameVec(this.builder, this.nameVecOffset);
            RequestUpdateModel.addFeatureMap(this.builder, this.fmOffset);
            RequestUpdateModel.addSignature(this.builder, this.signDataOffset);
            RequestUpdateModel.addUploadLoss(this.builder, this.uploadLossValue);
            RequestUpdateModel.addUploadAccuracy(this.builder, this.evalAccuracy);
            RequestUpdateModel.addSign(this.builder, this.sign);
            RequestUpdateModel.addIndexArray(this.builder, this.indexArrayOffset);
            int root = RequestUpdateModel.endRequestUpdateModel(this.builder);
            this.builder.finish(root);
            return this.builder.sizedByteArray();
        }

        /**
         * Differential-privacy encrypter.
         *
         * <p>Clips the update (current minus pre-training weights) by a global L2-norm factor,
         * adds Gaussian noise, and scales the result by trainDataSize. Writes into the array
         * returned by {@code client.getFeature}.
         */
        class DpEncrypter
        extends EncrypterBase {
            private double clipFactor;
            private double gaussianSigma;

            public DpEncrypter(Client client, SecureProtocol secureProtocol, int trainDataSize) {
                super(client, secureProtocol, trainDataSize);
            }

            /**
             * Computes the noise sigma and the clip factor min(1, dpNormClip / ||update||2).
             *
             * @return false if the update norm is 0 (clipFactor is then left at 0.0)
             * @throws RuntimeException if a feature's current and pre-training data are
             *                          missing or differ in length
             */
            @Override
            public boolean init() {
                this.gaussianSigma = this.secureProtocol.calculateSigma();
                double dpNormClip = this.secureProtocol.getDpNormClip();
                double updateL2Norm = 0.0;
                for (String key : this.featureNames) {
                    float[] data = this.client.getFeature(key);
                    float[] dataBeforeTrain = this.client.getPreFeature(key);
                    if (data == null || dataBeforeTrain == null || data.length != dataBeforeTrain.length) {
                        throw new RuntimeException("data of feature size is not same, feature name:" + key);
                    }
                    for (int j = 0; j < data.length; ++j) {
                        float updateData = data[j] - dataBeforeTrain[j];
                        updateL2Norm += (double) (updateData * updateData);
                    }
                }
                updateL2Norm = Math.sqrt(updateL2Norm);
                if (updateL2Norm == 0.0) {
                    LOGGER.severe("[Encrypt] updateL2Norm is 0, please check");
                    return false;
                }
                this.clipFactor = Math.min(1.0, dpNormClip / updateL2Norm);
                return true;
            }

            @Override
            public float[] geEncryptWeight(String weightName) {
                float[] weight = this.client.getFeature(weightName);
                float[] weightBeforeTrain = this.client.getPreFeature(weightName);
                if (weight == null || weightBeforeTrain == null || weight.length != weightBeforeTrain.length) {
                    throw new RuntimeException("data of feature size is not same, feature name:" + weightName);
                }
                SecureRandom secureRandom = Common.getSecureRandom();
                for (int j = 0; j < weight.length; ++j) {
                    float rawDataBeforeTrain = weightBeforeTrain[j];
                    float updateData = weight[j] - rawDataBeforeTrain;
                    updateData = (float) ((double) updateData * this.clipFactor);
                    double gaussianNoise = secureRandom.nextGaussian() * this.gaussianSigma;
                    updateData = (float) ((double) updateData + gaussianNoise);
                    weight[j] = rawDataBeforeTrain + updateData;
                    weight[j] = weight[j] * (float) this.trainDataSize;
                }
                return weight;
            }
        }

        /**
         * Pairwise-masking encrypter: delegates to {@link SecureProtocol#pwMaskWeight}, passing
         * a running offset accumulated across features.
         */
        class PwEncrypter
        extends EncrypterBase {
            private int maskedLen;

            public PwEncrypter(Client client, SecureProtocol secureProtocol, int trainDataSize) {
                super(client, secureProtocol, trainDataSize);
                this.maskedLen = 0;
            }

            @Override
            public boolean init() {
                return true;
            }

            @Override
            public float[] geEncryptWeight(String weightName) {
                float[] weight = this.client.getFeature(weightName);
                float[] encryptWeight = this.secureProtocol.pwMaskWeight(this.trainDataSize, weight, this.maskedLen);
                // Advance by this feature's length, as the compression path does with its mask
                // offset. The previous `maskedLen += maskedLen` kept the offset at 0 for every feature.
                if (weight != null) {
                    this.maskedLen += weight.length;
                }
                return encryptWeight;
            }
        }

        /** No masking: scales each weight by trainDataSize, writing into the array from getFeature. */
        class NoEncrypter
        extends EncrypterBase {
            public NoEncrypter(Client client, SecureProtocol secureProtocol, int trainDataSize) {
                super(client, secureProtocol, trainDataSize);
            }

            @Override
            public boolean init() {
                return true;
            }

            @Override
            public float[] geEncryptWeight(String weightName) {
                float[] weight = this.client.getFeature(weightName);
                for (int i = 0; i < weight.length; ++i) {
                    weight[i] = weight[i] * (float) this.trainDataSize;
                }
                return weight;
            }
        }

        /**
         * Iterates over {@code secureProtocol.getUpdateFeatureName()} in order and produces the
         * (possibly masked) weights for each feature.
         */
        abstract class EncrypterBase {
            protected Client client;
            protected SecureProtocol secureProtocol;
            protected ArrayList<String> featureNames;
            protected int trainDataSize;
            protected int curIter = 0;

            public EncrypterBase(Client client, SecureProtocol secureProtocol, int trainDataSize) {
                this.client = client;
                this.secureProtocol = secureProtocol;
                this.trainDataSize = trainDataSize;
                this.featureNames = secureProtocol.getUpdateFeatureName();
            }

            /** Prepares per-request state; returns false if preparation is considered failed. */
            abstract boolean init();

            /** @return true once every feature name has been returned by getNextFeature() */
            boolean isEnd() {
                return this.curIter >= this.featureNames.size();
            }

            /** Returns the next feature name and advances the cursor. */
            public String getNextFeature() {
                String weightName = this.featureNames.get(this.curIter);
                ++this.curIter;
                return weightName;
            }

            /** Returns the weights to upload for the named feature. */
            public abstract float[] geEncryptWeight(String var1);
        }
    }
}

