/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore.flclient;

import com.mindspore.flclient.Common;
import com.mindspore.flclient.EncryptLevel;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.GetModel;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.SecureProtocol;
import com.mindspore.flclient.ServerMod;
import com.mindspore.flclient.StartFLJob;
import com.mindspore.flclient.UpdateModel;
import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.pki.PkiBean;
import com.mindspore.flclient.pki.PkiUtil;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Date;
import java.util.Map;
import java.util.logging.Logger;
import mindspore.fl.schema.CipherPublicParams;
import mindspore.fl.schema.FLPlan;
import mindspore.fl.schema.ResponseFLJob;
import mindspore.fl.schema.ResponseGetModel;
import mindspore.fl.schema.ResponseUpdateModel;

/**
 * Client-side driver for one federated-learning round: startFLJob, local training, feature masking,
 * updateModel, getModel, unmasking and evaluation.
 *
 * <p>Every public step returns an {@link FLClientStatus} and records a numeric return code that callers
 * can read through {@link #getRetCode()}. For RESTART results, {@link #getNextRequestTime()} holds the
 * time at which the request should be retried.
 */
public class FLLiteClient {
    private static final Logger LOGGER = FLLoggerGenerater.getModelLogger(FLLiteClient.class.toString());

    /** retCode recorded on success. */
    private static final int RET_CODE_SUCCEED = 200;
    /** retCode recorded when getModel finds the server not ready (status WAIT). */
    private static final int RET_CODE_WAIT = 201;
    /** retCode recorded when the server is not ready and the request should be retried later (status RESTART). */
    private static final int RET_CODE_RETRY_LATER = 300;
    /** retCode recorded for client-side and request failures. */
    private static final int RET_CODE_REQUEST_ERROR = 400;
    /** retCode recorded when the server's training job is disabled or finished. */
    private static final int RET_CODE_JOB_FINISHED = 500;

    /**
     * Global iteration number as last reported by the server in a startFLJob response.
     * NOTE: this field is static, so it is shared by every FLLiteClient instance in the process.
     */
    private static int iteration = 0;
    private double dpNormClipFactor = 1.0;
    private double dpNormClipAdapt = 0.05;
    private FLCommunication flCommunication;
    private FLClientStatus status;
    private int retCode = RET_CODE_REQUEST_ERROR;
    private int iterations = 1;
    private int epochs = 1;
    private int batchSize = 16;
    private int minSecretNum;
    private byte[] prime;
    private int featureSize;
    private int trainDataSize = 0;
    private int evaDataSize = 0;
    private double dpEps = 100.0;
    private double dpDelta = 0.01;
    private FLParameter flParameter = FLParameter.getInstance();
    private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
    private SecureProtocol secureProtocol = new SecureProtocol();
    private String nextRequestTime;
    private Client client;
    private float signK = 0.01f;
    private float signEps = 100.0f;
    private float signThrRatio = 0.6f;
    private float signGlobalLr = 1.0f;
    private int signDimOut = 0;
    private float evaAcc = 0.0f;

    /**
     * Creates a client bound to the communication singleton and to the model client registered
     * under the configured FL name.
     */
    public FLLiteClient() {
        this.flCommunication = FLCommunication.getInstance();
        this.client = ClientManager.getClient(this.flParameter.getFlName());
    }

    /** @return the accuracy produced by the most recent evaluation (0 before any evaluation) */
    public float getEvaAcc() {
        return this.evaAcc;
    }

    /** @return the return code recorded by the most recent step */
    public int getRetCode() {
        return this.retCode;
    }

    /** @return the global iteration number last reported by the server (shared static state) */
    public int getIteration() {
        return iteration;
    }

    /** @return the total number of iterations announced by the server */
    public int getIterations() {
        return this.iterations;
    }

    /** @return the retry time recorded by the last step that returned RESTART */
    public String getNextRequestTime() {
        return this.nextRequestTime;
    }

    /** @return the DP norm-clip factor received from the server */
    public double getDpNormClipFactor() {
        return this.dpNormClipFactor;
    }

    /** @return the current adaptive DP norm-clip bound */
    public double getDpNormClipAdapt() {
        return this.dpNormClipAdapt;
    }

    /** @param dpNormClipAdapt the new adaptive DP norm-clip bound */
    public void setDpNormClipAdapt(double dpNormClipAdapt) {
        this.dpNormClipAdapt = dpNormClipAdapt;
    }

    /** Builds the server base URL from the configured ELB flag, server count and domain name. */
    private String getServerUrl() {
        return Common.generateUrl(this.flParameter.isUseElb(), this.flParameter.getServerNum(), this.flParameter.getDomainName());
    }

    /** Records a RESTART because the server is not ready yet; the retry time comes from {@link Common#getNextReqTime()}. */
    private FLClientStatus serverNotReady(String logTag) {
        LOGGER.info("[" + logTag + "] the server is not ready now, need wait some time and request again");
        this.status = FLClientStatus.RESTART;
        this.nextRequestTime = Common.getNextReqTime();
        this.retCode = RET_CODE_RETRY_LATER;
        return this.status;
    }

    /**
     * Copies the global training, compression and encryption parameters of a successful startFLJob
     * response into this client and into {@link LocalFLParameter}.
     *
     * @param flJob the startFLJob response from the server
     * @return 0 on success, -1 if a required section of the response is missing or invalid
     */
    private int setGlobalParameters(ResponseFLJob flJob) {
        FLPlan flPlan = flJob.flPlanConfig();
        if (flPlan == null) {
            LOGGER.severe("[startFLJob] the FLPlan get from server is null");
            return -1;
        }
        this.iterations = flPlan.iterations();
        this.epochs = flPlan.epochs();
        this.batchSize = flPlan.miniBatch();
        String serverMod = flPlan.serverMode();
        this.localFLParameter.setServerMod(serverMod);
        byte uploadCompressType = flJob.uploadCompressType();
        LOGGER.info("[startFLJob] [compression] uploadCompressType: " + uploadCompressType);
        this.localFLParameter.setUploadCompressType(uploadCompressType);
        float uploadSparseRate = flJob.uploadSparseRate();
        LOGGER.info("[startFLJob] [compression] uploadSparseRate: " + uploadSparseRate);
        this.localFLParameter.setUploadSparseRatio(uploadSparseRate);
        // The server's iteration number is used as the compression seed.
        int seed = flJob.iteration();
        LOGGER.info("[startFLJob] [compression] seed: " + seed);
        this.localFLParameter.setSeed(seed);
        LOGGER.info("[startFLJob] the GlobalParameter <serverMod> from server: " + serverMod);
        LOGGER.info("[startFLJob] the GlobalParameter <iterations> from server: " + this.iterations);
        LOGGER.info("[startFLJob] the GlobalParameter <epochs> from server: " + this.epochs);
        LOGGER.info("[startFLJob] the GlobalParameter <batchSize> from server: " + this.batchSize);
        CipherPublicParams cipherPublicParams = flPlan.cipher();
        if (cipherPublicParams == null) {
            LOGGER.severe("[startFLJob] the cipherPublicParams returned from server is null");
            return -1;
        }
        String encryptLevel = cipherPublicParams.encryptType();
        if (encryptLevel == null || encryptLevel.isEmpty()) {
            LOGGER.severe("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the encryptLevel to NOT_ENCRYPT ");
            this.localFLParameter.setEncryptLevel(EncryptLevel.NOT_ENCRYPT.toString());
        } else {
            this.localFLParameter.setEncryptLevel(encryptLevel);
            LOGGER.info("[startFLJob] GlobalParameters <encryptLevel> from server: " + encryptLevel);
        }
        switch (this.localFLParameter.getEncryptLevel()) {
            case PW_ENCRYPT:
                return this.setPwParameters(cipherPublicParams);
            case DP_ENCRYPT:
                return this.setDpParameters(cipherPublicParams);
            case SIGNDS:
                return this.setSignDsParameters(cipherPublicParams);
            default:
                LOGGER.info("[startFLJob] NOT_ENCRYPT, do not set parameter for Encrypt");
                return 0;
        }
    }

    /** Reads minSecretNum (pwParams.t) and the prime bytes for PW_ENCRYPT; returns 0 on success, -1 if missing/invalid. */
    private int setPwParameters(CipherPublicParams cipherPublicParams) {
        // Absent flatbuffer sub-tables come back as null; fail the round instead of throwing an NPE.
        if (cipherPublicParams.pwParams() == null) {
            LOGGER.severe("[startFLJob] the pwParams returned from server is null");
            return -1;
        }
        this.minSecretNum = cipherPublicParams.pwParams().t();
        int primeLength = cipherPublicParams.pwParams().primeLength();
        this.prime = new byte[primeLength];
        for (int i = 0; i < primeLength; ++i) {
            this.prime[i] = (byte) cipherPublicParams.pwParams().prime(i);
        }
        LOGGER.info("[startFLJob] GlobalParameters <minSecretNum> from server: " + this.minSecretNum);
        if (this.minSecretNum <= 0) {
            LOGGER.severe("[startFLJob] GlobalParameters <minSecretNum> from server is not valid:  <=0");
            return -1;
        }
        return 0;
    }

    /** Reads dpEps, dpDelta and the norm-clip factor for DP_ENCRYPT; returns 0 on success, -1 if missing. */
    private int setDpParameters(CipherPublicParams cipherPublicParams) {
        if (cipherPublicParams.dpParams() == null) {
            LOGGER.severe("[startFLJob] the dpParams returned from server is null");
            return -1;
        }
        this.dpEps = cipherPublicParams.dpParams().dpEps();
        this.dpDelta = cipherPublicParams.dpParams().dpDelta();
        this.dpNormClipFactor = cipherPublicParams.dpParams().dpNormClip();
        LOGGER.info("[startFLJob] GlobalParameters <dpEps> from server: " + this.dpEps);
        LOGGER.info("[startFLJob] GlobalParameters <dpDelta> from server: " + this.dpDelta);
        LOGGER.info("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " + this.dpNormClipFactor);
        return 0;
    }

    /** Reads the SignDS parameters (k, eps, threshold ratio, global lr, dim-out); returns 0 on success, -1 if missing. */
    private int setSignDsParameters(CipherPublicParams cipherPublicParams) {
        if (cipherPublicParams.dsParams() == null) {
            LOGGER.severe("[startFLJob] the dsParams returned from server is null");
            return -1;
        }
        this.signK = cipherPublicParams.dsParams().signK();
        this.signEps = cipherPublicParams.dsParams().signEps();
        this.signThrRatio = cipherPublicParams.dsParams().signThrRatio();
        this.signGlobalLr = cipherPublicParams.dsParams().signGlobalLr();
        this.signDimOut = cipherPublicParams.dsParams().signDimOut();
        LOGGER.info("[startFLJob] GlobalParameters <signK> from server: " + this.signK);
        LOGGER.info("[startFLJob] GlobalParameters <signEps> from server: " + this.signEps);
        LOGGER.info("[startFLJob] GlobalParameters <signThrRatio> from server: " + this.signThrRatio);
        LOGGER.info("[startFLJob] GlobalParameters <signGlobalLr> from server: " + this.signGlobalLr);
        LOGGER.info("[startFLJob] GlobalParameters <SignDimOut> from server: " + this.signDimOut);
        return 0;
    }

    /**
     * Sends the startFLJob request (optionally carrying a PKI bean) and applies the server's response.
     *
     * @return SUCCESS, RESTART (server not ready; see {@link #getNextRequestTime()}) or FAILED
     */
    public FLClientStatus startFLJob() {
        LOGGER.info("[startFLJob] ====================================Verify server====================================");
        String url = this.getServerUrl();
        StartFLJob startFLJob = StartFLJob.getInstance();
        long time = System.currentTimeMillis();
        PkiBean pkiBean = null;
        if (this.flParameter.isPkiVerify()) {
            pkiBean = PkiUtil.genPkiBean(this.flParameter.getClientID(), time);
        }
        byte[] msg = startFLJob.getRequestStartFLJob(this.trainDataSize, this.evaDataSize, iteration, time, pkiBean);
        if (msg == null) {
            this.failed("[startFLJob] build the request message of startFLJob failed", RET_CODE_REQUEST_ERROR);
            return this.status;
        }
        try {
            long start = Common.startTime("single startFLJob");
            LOGGER.info("[startFLJob] the request message length: " + msg.length);
            byte[] message = this.flCommunication.syncRequest(url + "/startFLJob", msg);
            if (!Common.isSeverReady(message)) {
                return this.serverNotReady("startFLJob");
            }
            if (Common.isSeverJobFinished(message)) {
                return this.serverJobFinished("startFLJob");
            }
            LOGGER.info("[startFLJob] the response message length: " + message.length);
            Common.endTime(start, "single startFLJob");
            ResponseFLJob responseDataBuf = ResponseFLJob.getRootAsResponseFLJob(ByteBuffer.wrap(message));
            this.status = this.judgeStartFLJob(startFLJob, responseDataBuf);
        } catch (IOException e) {
            this.failed("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage(), RET_CODE_REQUEST_ERROR);
        }
        return this.status;
    }

    /**
     * Interprets a parsed startFLJob response, updating the shared iteration, retCode and the
     * parameters needed for the rest of the round. Every FAILED outcome also records a failure retCode.
     */
    private FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
        iteration = responseDataBuf.iteration();
        FLClientStatus response = startFLJob.doResponse(responseDataBuf);
        this.retCode = startFLJob.getRetCode();
        this.status = response;
        switch (response) {
            case SUCCESS: {
                LOGGER.info("[startFLJob] startFLJob success");
                this.featureSize = startFLJob.getFeatureSize();
                this.secureProtocol.setUpdateFeatureName(startFLJob.getUpdateFeatureName());
                LOGGER.info("[startFLJob] ***the feature size get in ResponseFLJob***: " + this.featureSize);
                if (this.setGlobalParameters(responseDataBuf) != 0) {
                    this.failed("[startFLJob] setGlobalParameters failed", RET_CODE_REQUEST_ERROR);
                }
                break;
            }
            case RESTART: {
                FLPlan flPlan = responseDataBuf.flPlanConfig();
                if (flPlan == null) {
                    this.failed("[startFLJob] the flPlan returned from server is null", RET_CODE_REQUEST_ERROR);
                    break;
                }
                this.iterations = flPlan.iterations();
                LOGGER.info("[startFLJob] GlobalParameters <iterations> from server: " + this.iterations);
                this.nextRequestTime = responseDataBuf.nextReqTime();
                break;
            }
            case FAILED: {
                LOGGER.severe("[startFLJob] startFLJob failed");
                break;
            }
            default: {
                this.failed("[startFLJob] failed: the response of startFLJob is out of range <SUCCESS, WAIT, FAILED, Restart>", RET_CODE_REQUEST_ERROR);
            }
        }
        return this.status;
    }

    /**
     * Runs local training for {@code epochs} epochs. If the resulting loss is NaN/infinite, the train
     * model file is restored and the step fails.
     */
    private FLClientStatus trainLoop() {
        Client trainClient = ClientManager.getClient(this.flParameter.getFlName());
        if (!trainClient.EnableTrain(true)) {
            LOGGER.severe("[train] switch the client to train mode failed");
            this.retCode = RET_CODE_REQUEST_ERROR;
            return FLClientStatus.FAILED;
        }
        // Start from a clean SUCCESS so a stale status from an earlier step is never returned.
        this.status = FLClientStatus.SUCCESS;
        this.retCode = RET_CODE_SUCCEED;
        LOGGER.info("[train] train in " + this.flParameter.getFlName());
        LOGGER.info("[train] lr for client is: " + this.localFLParameter.getLr());
        Status tag = trainClient.setLearningRate(this.localFLParameter.getLr());
        if (!Status.SUCCESS.equals(tag)) {
            LOGGER.severe("[train] setLearningRate failed, return -1, please check");
            this.retCode = RET_CODE_REQUEST_ERROR;
            return FLClientStatus.FAILED;
        }
        tag = trainClient.trainModel(this.epochs);
        float uploadLoss = trainClient.getUploadLoss();
        if (Float.isNaN(uploadLoss) || Float.isInfinite(uploadLoss)) {
            // A diverged model must not be uploaded: roll back to the stored train model first.
            trainClient.restoreModelFile(this.flParameter.getTrainModelPath());
            this.failed("[train] train failed, train loss is:" + uploadLoss, RET_CODE_REQUEST_ERROR);
        } else if (!Status.SUCCESS.equals(tag)) {
            this.failed("[train] unsolved error code in <client.trainModel>", RET_CODE_REQUEST_ERROR);
        }
        return this.status;
    }

    /** Runs one round of local training and records its status. */
    public FLClientStatus localTrain() {
        LOGGER.info("[train] ====================================global train epoch " + iteration + "====================================");
        this.status = this.trainLoop();
        return this.status;
    }

    /**
     * Builds and sends the updateModel request (training data size and last evaluation accuracy
     * included) and applies the server's response.
     *
     * @return SUCCESS, RESTART (see {@link #getNextRequestTime()}) or FAILED
     */
    public FLClientStatus updateModel() {
        String url = this.getServerUrl();
        UpdateModel updateModelBuf = UpdateModel.getInstance();
        byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, this.secureProtocol, this.trainDataSize, this.evaAcc);
        if (updateModelBuf.getStatus() == FLClientStatus.FAILED || updateModelBuffer == null) {
            this.failed("[updateModel] catch error in build RequestUpdateFLJob", RET_CODE_REQUEST_ERROR);
            return this.status;
        }
        try {
            long start = Common.startTime("single updateModel");
            LOGGER.info("[updateModel] the request message length: " + updateModelBuffer.length);
            byte[] message = this.flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
            if (!Common.isSeverReady(message)) {
                return this.serverNotReady("updateModel");
            }
            if (Common.isSeverJobFinished(message)) {
                return this.serverJobFinished("updateModel");
            }
            LOGGER.info("[updateModel] the response message length: " + message.length);
            Common.endTime(start, "single updateModel");
            ResponseUpdateModel responseDataBuf = ResponseUpdateModel.getRootAsResponseUpdateModel(ByteBuffer.wrap(message));
            this.status = updateModelBuf.doResponse(responseDataBuf);
            this.retCode = updateModelBuf.getRetCode();
            if (this.status == FLClientStatus.RESTART) {
                this.nextRequestTime = responseDataBuf.nextReqTime();
            }
            LOGGER.info("[updateModel] get response from server ok!");
        } catch (IOException e) {
            this.failed("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage(), RET_CODE_REQUEST_ERROR);
        }
        return this.status;
    }

    /**
     * Requests the aggregated model for the current iteration and applies the server's response.
     * Unlike startFLJob/updateModel, a not-ready server yields WAIT rather than RESTART.
     */
    public FLClientStatus getModel() {
        String url = this.getServerUrl();
        GetModel getModelBuf = GetModel.getInstance();
        byte[] buffer = getModelBuf.getRequestGetModel(this.flParameter.getFlName(), iteration);
        if (buffer == null) {
            this.failed("[getModel] build the request message of getModel failed", RET_CODE_REQUEST_ERROR);
            return this.status;
        }
        try {
            long start = Common.startTime("single getModel");
            LOGGER.info("[getModel] the request message length: " + buffer.length);
            byte[] message = this.flCommunication.syncRequest(url + "/getModel", buffer);
            if (!Common.isSeverReady(message)) {
                LOGGER.info("[getModel] the server is not ready now, need wait some time and request again");
                this.status = FLClientStatus.WAIT;
                this.retCode = RET_CODE_WAIT;
                return this.status;
            }
            if (Common.isSeverJobFinished(message)) {
                return this.serverJobFinished("getModel");
            }
            LOGGER.info("[getModel] the response message length: " + message.length);
            Common.endTime(start, "single getModel");
            LOGGER.info("[getModel] get model request success");
            ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(ByteBuffer.wrap(message));
            this.status = getModelBuf.doResponse(responseDataBuf);
            this.retCode = getModelBuf.getRetCode();
            if (this.status == FLClientStatus.RESTART) {
                this.nextRequestTime = responseDataBuf.timestamp();
            }
            LOGGER.info("[getModel] get response from server ok!");
        } catch (IOException e) {
            this.failed("[getModel] unsolved error code: catch IOException: " + e.getMessage(), RET_CODE_REQUEST_ERROR);
        }
        return this.status;
    }

    /**
     * For DP_ENCRYPT only: recomputes the adaptive norm-clip bound as
     * {@code dpNormClipFactor * L2-norm(weight update)}. On iteration 1 the new bound is always adopted;
     * afterwards it is adopted only if it is smaller than the current one.
     */
    public void updateDpNormClip() {
        EncryptLevel encryptLevel = this.localFLParameter.getEncryptLevel();
        if (encryptLevel != EncryptLevel.DP_ENCRYPT) {
            return;
        }
        if (!this.client.EnableTrain(true)) {
            LOGGER.warning("[DP] switch the client to train mode failed, the weight norm may be inaccurate");
        }
        float fedWeightUpdateNorm = this.client.getDpWeightNorm(this.secureProtocol.getUpdateFeatureName());
        LOGGER.info("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm);
        float newNormClip = (float) this.getDpNormClipFactor() * fedWeightUpdateNorm;
        if (iteration == 1 || (double) newNormClip < this.getDpNormClipAdapt()) {
            this.setDpNormClipAdapt(newNormClip);
            LOGGER.info("[DP] dpNormClip has been updated.");
        }
        LOGGER.info("[DP] Adaptive dpNormClip is: " + this.getDpNormClipAdapt());
    }

    /**
     * Prepares the encryption state for the configured encrypt level: creates the PW mask, or sets the
     * DP / SignDS parameters. NOT_ENCRYPT (and any unrecognized level) succeeds without masking.
     */
    public FLClientStatus getFeatureMask() {
        switch (this.localFLParameter.getEncryptLevel()) {
            case PW_ENCRYPT: {
                LOGGER.info("[Encrypt] creating feature mask of <" + this.localFLParameter.getEncryptLevel().toString() + ">");
                this.secureProtocol.setPWParameter(iteration, this.minSecretNum, this.prime, this.featureSize);
                FLClientStatus curStatus = this.secureProtocol.pwCreateMask();
                if (curStatus == FLClientStatus.RESTART) {
                    this.nextRequestTime = this.secureProtocol.getNextRequestTime();
                }
                this.retCode = this.secureProtocol.getRetCode();
                LOGGER.info("[Encrypt] the response of create mask for <" + this.localFLParameter.getEncryptLevel().toString() + "> : " + curStatus);
                return curStatus;
            }
            case DP_ENCRYPT: {
                FLClientStatus curStatus = this.secureProtocol.setDPParameter(iteration, this.dpEps, this.dpDelta, this.dpNormClipAdapt);
                this.retCode = RET_CODE_SUCCEED;
                if (curStatus != FLClientStatus.SUCCESS) {
                    LOGGER.severe("---Differential privacy init failed---");
                    this.retCode = RET_CODE_REQUEST_ERROR;
                    return FLClientStatus.FAILED;
                }
                LOGGER.info("[Encrypt] set parameters for DP_ENCRYPT!");
                return FLClientStatus.SUCCESS;
            }
            case SIGNDS: {
                FLClientStatus curStatus = this.secureProtocol.setDSParameter(this.signK, this.signEps, this.signThrRatio, this.signGlobalLr, this.signDimOut);
                this.retCode = RET_CODE_SUCCEED;
                if (curStatus != FLClientStatus.SUCCESS) {
                    LOGGER.severe("---SignDS init failed---");
                    this.retCode = RET_CODE_REQUEST_ERROR;
                    return FLClientStatus.FAILED;
                }
                LOGGER.info("[Encrypt] set parameters for SignDS!");
                return FLClientStatus.SUCCESS;
            }
            case NOT_ENCRYPT: {
                this.retCode = RET_CODE_SUCCEED;
                LOGGER.info("[Encrypt] don't mask model");
                return FLClientStatus.SUCCESS;
            }
        }
        this.retCode = RET_CODE_SUCCEED;
        LOGGER.severe("[Encrypt] The encrypt level is error, not encrypt by default");
        return FLClientStatus.SUCCESS;
    }

    /** Runs PW unmasking when PW_ENCRYPT is configured; every other encrypt level needs no unmasking. */
    public FLClientStatus unMasking() {
        switch (this.localFLParameter.getEncryptLevel()) {
            case PW_ENCRYPT: {
                FLClientStatus curStatus = this.secureProtocol.pwUnmasking();
                this.retCode = this.secureProtocol.getRetCode();
                LOGGER.info("[Encrypt] the response of unmasking : " + curStatus);
                if (curStatus == FLClientStatus.RESTART) {
                    this.nextRequestTime = this.secureProtocol.getNextRequestTime();
                }
                return curStatus;
            }
            case DP_ENCRYPT: {
                LOGGER.info("[Encrypt] DP_ENCRYPT do not need unmasking");
                this.retCode = RET_CODE_SUCCEED;
                return FLClientStatus.SUCCESS;
            }
            case NOT_ENCRYPT: {
                LOGGER.info("[Encrypt] haven't mask model");
                this.retCode = RET_CODE_SUCCEED;
                return FLClientStatus.SUCCESS;
            }
            case SIGNDS: {
                LOGGER.info("[Encrypt] SIGNDS do not need unmasking");
                this.retCode = RET_CODE_SUCCEED;
                return FLClientStatus.SUCCESS;
            }
        }
        LOGGER.severe("[Encrypt] The encrypt level is error, not encrypt by default");
        this.retCode = RET_CODE_SUCCEED;
        return FLClientStatus.SUCCESS;
    }

    /**
     * Evaluates the model and stores the accuracy in {@code evaAcc}. In HYBRID_TRAINING mode, training
     * mode is disabled and the infer model path is logged; otherwise training mode is enabled and the
     * train model path is logged. A NaN accuracy fails the step.
     */
    private FLClientStatus evaluateLoop() {
        this.status = FLClientStatus.SUCCESS;
        this.retCode = RET_CODE_SUCCEED;
        this.evaAcc = 0.0f;
        String serverMod = this.localFLParameter.getServerMod();
        // Constant-first comparison: a missing server mode falls back to the non-hybrid path instead of an NPE.
        boolean hybrid = ServerMod.HYBRID_TRAINING.toString().equals(serverMod);
        LOGGER.info("[evaluate] evaluateModel by " + serverMod);
        this.client.EnableTrain(!hybrid);
        String modelPath = hybrid ? this.flParameter.getInferModelPath() : this.flParameter.getTrainModelPath();
        LOGGER.info("[evaluate] modelPath: " + modelPath);
        this.evaAcc = this.client.evalModel();
        if (Float.isNaN(this.evaAcc)) {
            this.failed("[evaluate] unsolved error code in <evalModel>: the return acc is NAN", RET_CODE_REQUEST_ERROR);
            return this.status;
        }
        LOGGER.info("[evaluate] evaluate acc: " + this.evaAcc);
        return this.status;
    }

    /** Logs {@code log} as severe and records a FAILED status with the given return code. */
    private void failed(String log, int retCode) {
        LOGGER.severe(log);
        this.status = FLClientStatus.FAILED;
        this.retCode = retCode;
    }

    /** Evaluates the model received from the server and records the status. */
    public FLClientStatus evaluateModel() {
        LOGGER.info("===================================evaluate model after getting model from server===================================");
        this.status = this.evaluateLoop();
        return this.status;
    }

    /**
     * Loads the configured data sets into the client and records the train/eval data sizes.
     *
     * @return true if a positive number of training samples was loaded; false otherwise (retCode set to failure)
     */
    public boolean initDataSets() {
        this.retCode = RET_CODE_SUCCEED;
        LOGGER.info("==========set input===========");
        Map<RunType, Integer> dataInfo = this.client.initDataSets(this.flParameter.getDataMap());
        if (dataInfo == null) {
            LOGGER.severe("[initDataSets] the data info returned by client is null");
            this.retCode = RET_CODE_REQUEST_ERROR;
            return false;
        }
        // Avoid unboxing a missing entry: treat an absent size as 0.
        Integer trainSize = dataInfo.get(RunType.TRAINMODE);
        this.trainDataSize = trainSize == null ? 0 : trainSize;
        if (this.trainDataSize <= 0) {
            LOGGER.severe("[initDataSets] the train data size is not valid: " + this.trainDataSize);
            this.retCode = RET_CODE_REQUEST_ERROR;
            return false;
        }
        Integer evalSize = dataInfo.get(RunType.EVALMODE);
        this.evaDataSize = evalSize == null ? 0 : evalSize;
        return true;
    }

    /** Records that the server's job is finished or disabled; the caller should stop the task. */
    private FLClientStatus serverJobFinished(String logTag) {
        LOGGER.info("[" + logTag + "] The server's training job is disabled or finished. will stop the task and exist.");
        this.retCode = RET_CODE_JOB_FINISHED;
        return FLClientStatus.FAILED;
    }
}

