/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore.flclient;

import com.mindspore.config.Version;
import com.mindspore.flclient.BindMode;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLLiteClient;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.IFLJobResultCallback;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.ServerMod;
import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.ClientManager;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.Status;
import com.mindspore.flclient.pki.PkiUtil;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Synchronous driver for a client-side FL job.
 *
 * <p>The class exposes three public operations that all work on the shared
 * {@link FLParameter} / {@link LocalFLParameter} singletons:
 * <ul>
 *   <li>{@link #flJobRun()} - runs the start / evaluate / mask / train / update / unmask /
 *       get-model loop until it fails, is stopped, or the iteration count is reached;</li>
 *   <li>{@link #modelInfer()} - runs inference with the configured client model;</li>
 *   <li>{@link #getModel()} - fetches the model once and saves it on success.</li>
 * </ul>
 * {@link #main(String[])} maps a positional command line onto these operations.
 */
public class SyncFLJob {
    private static final Logger LOGGER = FLLoggerGenerater.getModelLogger(SyncFLJob.class.toString());

    /** MindSpore Lite version strings this client accepts in the constructor check. */
    private static final HashSet<String> MS_VERSIONS =
            new HashSet<String>(Arrays.asList("MindSpore Lite 1.9.0", "MindSpore Lite 2.0.0"));

    /** The run loop stops once {@code tryTimePerIter} is greater than this value. */
    private static final int MAX_TRY_TIME_PER_ITER = 1;

    /** Polling in {@link #getModel(FLLiteClient)} stops once {@code waitTryTime} is greater than this value. */
    private static final int MAX_WAIT_TRY_TIME = 18;

    /** Value handed to {@code Common.sleep} when the configured sleep time is 0. */
    private static final long DEFAULT_SLEEP_TIME = 10000L;

    /** {@link #task(String[])} reads {@code args[0]} .. {@code args[21]}. */
    private static final int REQUIRED_ARG_COUNT = 22;

    private final FLParameter flParameter = FLParameter.getInstance();
    private final LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
    private IFLJobResultCallback flJobResultCallback;
    private FLClientStatus curStatus;
    private int tryTimePerIter = 0;
    private int lastIteration = -1;
    private int waitTryTime = 0;

    /**
     * Chooses the flID stored in {@link LocalFLParameter}: the equipment certificate hash when
     * pkiVerify is enabled, otherwise the configured client ID.
     *
     * @throws IllegalArgumentException if pkiVerify is enabled and no certificate hash is available
     */
    private void initFlIDForPkiVerify() {
        if (!this.flParameter.isPkiVerify()) {
            LOGGER.info("pkiVerify mode is not open!");
            this.localFLParameter.setFlID(this.flParameter.getClientID());
            return;
        }
        LOGGER.info("pkiVerify mode is open!");
        String equipCertHash = PkiUtil.genEquipCertHash(this.flParameter.getClientID());
        if (equipCertHash == null || equipCertHash.isEmpty()) {
            LOGGER.severe("equipCertHash is empty, please check your mobile phone, only Huawei phones are supported now.");
            throw new IllegalArgumentException("equipCertHash is empty, pkiVerify cannot be used on this device");
        }
        LOGGER.info("flID for pki verify is: " + equipCertHash);
        this.localFLParameter.setFlID(equipCertHash);
    }

    /**
     * Verifies that the class named by {@code flParameter.getFlName()} can be loaded and that the
     * MindSpore Lite runtime reports one of the supported versions.
     *
     * @throws IllegalArgumentException if the flName class cannot be found
     * @throws RuntimeException if the runtime version is not in the supported set
     */
    public SyncFLJob() {
        String flName = this.flParameter.getFlName();
        LOGGER.info("the flName: " + flName);
        try {
            Class.forName(flName);
        } catch (ClassNotFoundException e) {
            LOGGER.severe("catch ClassNotFoundException error, the set flName does not exist, please check: " + e.getMessage());
            // Keep the cause so the caller can see which class lookup failed.
            throw new IllegalArgumentException("the set flName does not exist: " + flName, e);
        }
        String msVersion = Version.version();
        if (!MS_VERSIONS.contains(msVersion)) {
            String expInfo = "Expect mindspore lite version in " + MS_VERSIONS.toString() + ", but got incompatible mindspore lite version:" + msVersion;
            throw new RuntimeException(expInfo);
        }
        // Report the version that was actually detected, not the whole supported set.
        LOGGER.info("Got compatible mindspore lite version:" + msVersion);
    }

    /**
     * Runs a complete FL job: initialises the model, executes {@link #flRunLoop(FLLiteClient)},
     * saves the model when the loop ended with {@code SUCCESS}, notifies the job callback and
     * frees the client.
     *
     * @return {@code FLClientStatus.FAILED} if model initialisation fails, otherwise the status
     *         left by the run loop
     */
    public FLClientStatus flJobRun() {
        this.flJobResultCallback = this.flParameter.getIflJobResultCallback();
        this.initSecureRandom();
        this.initFlIDForPkiVerify();
        this.localFLParameter.setMsConfig(0, this.flParameter.getThreadNum(), this.flParameter.getCpuBindMode(), false);
        Client client = ClientManager.getClient(this.flParameter.getFlName());
        Status modelInitRet = client.initModel(this.flParameter);
        if (modelInitRet != Status.SUCCESS) {
            LOGGER.severe("initModel failed");
            client.free();
            return FLClientStatus.FAILED;
        }
        FLLiteClient flLiteClient = new FLLiteClient();
        LOGGER.info("recovery StopJobFlag to false in the start of fl job");
        this.localFLParameter.setStopJobFlag(false);
        this.InitialParameters();
        LOGGER.info("flJobRun start");
        this.flRunLoop(flLiteClient);
        if (this.curStatus == FLClientStatus.SUCCESS) {
            client.saveModel(this.flParameter, this.localFLParameter);
        }
        LOGGER.info("flJobRun finish");
        this.flJobResultCallback.onFlJobFinished(this.flParameter.getFlName(), flLiteClient.getIterations(), flLiteClient.getRetCode());
        client.free();
        return this.curStatus;
    }

    /** Installs the secure random source in {@link Common} according to the deploy environment. */
    private void initSecureRandom() {
        if ("android".equals(this.flParameter.getDeployEnv())) {
            Common.setSecureRandom(Common.getFastSecureRandom());
        } else {
            Common.setSecureRandom(new SecureRandom());
        }
    }

    /**
     * Executes iterations until one of the following happens:
     * <ul>
     *   <li>the retry limit is exceeded or the stop flag is set (status becomes {@code FAILED});</li>
     *   <li>a step returns a status other than {@code SUCCESS} / {@code RESTART};</li>
     *   <li>the current iteration reaches {@code flLiteClient.getIterations()}.</li>
     * </ul>
     * A {@code RESTART} from any step waits for the next request time and then proceeds to the
     * iteration check ({@code continue} inside a {@code do-while} jumps to its condition).
     *
     * @param flLiteClient the client that performs the individual steps
     */
    private void flRunLoop(FLLiteClient flLiteClient) {
        do {
            if (this.tryTimeExceedsLimit() || this.checkStopJobFlag()) {
                break;
            }
            LOGGER.info("flName: " + this.flParameter.getFlName());
            if (!flLiteClient.initDataSets()) {
                this.curStatus = FLClientStatus.FAILED;
                this.failed("unsolved error code in <flLiteClient.setInput>: the return trainDataSize<=0, setInput", flLiteClient);
                break;
            }

            // startFLJob
            this.curStatus = flLiteClient.startFLJob();
            if (this.curStatus == FLClientStatus.RESTART) {
                ++this.tryTimePerIter;
                this.resetContext("[startFLJob]", flLiteClient.getNextRequestTime(), flLiteClient);
                continue;
            }
            if (this.curStatus != FLClientStatus.SUCCESS) {
                this.failed("[startFLJob]", flLiteClient);
                break;
            }
            LOGGER.info("[startFLJob] startFLJob succeed, curIteration: " + flLiteClient.getIteration());
            this.updateTryTimePerIter(flLiteClient);

            // evaluate; checkEvalPath() logs the reason itself when evaluation is skipped
            if (this.checkEvalPath()) {
                this.curStatus = flLiteClient.evaluateModel();
                if (this.curStatus != FLClientStatus.SUCCESS) {
                    this.failed("[evaluate] evaluate", flLiteClient);
                    break;
                }
                LOGGER.info("[evaluate] evaluate succeed");
            }

            Client client = ClientManager.getClient(this.flParameter.getFlName());
            client.EnableTrain(true);

            // create mask
            this.curStatus = flLiteClient.getFeatureMask();
            if (this.curStatus == FLClientStatus.RESTART) {
                this.resetContext("[Encrypt] creatMask", flLiteClient.getNextRequestTime(), flLiteClient);
                continue;
            }
            if (this.curStatus != FLClientStatus.SUCCESS) {
                this.failed("[Encrypt] createMask", flLiteClient);
                break;
            }

            // train
            this.curStatus = flLiteClient.localTrain();
            if (this.curStatus != FLClientStatus.SUCCESS) {
                this.failed("[train] train", flLiteClient);
                break;
            }
            LOGGER.info("[train] train succeed");

            // updateModel
            this.curStatus = flLiteClient.updateModel();
            if (this.curStatus == FLClientStatus.RESTART) {
                this.resetContext("[updateModel]", flLiteClient.getNextRequestTime(), flLiteClient);
                continue;
            }
            if (this.curStatus != FLClientStatus.SUCCESS) {
                this.failed("[updateModel] updateModel", flLiteClient);
                break;
            }

            // unmasking
            this.curStatus = flLiteClient.unMasking();
            if (this.curStatus == FLClientStatus.RESTART) {
                this.resetContext("[Encrypt] unmasking", flLiteClient.getNextRequestTime(), flLiteClient);
                continue;
            }
            if (this.curStatus != FLClientStatus.SUCCESS) {
                this.failed("[Encrypt] unmasking", flLiteClient);
                break;
            }

            // getModel
            this.curStatus = this.getModel(flLiteClient);
            if (this.curStatus == FLClientStatus.RESTART) {
                this.resetContext("[getModel]", flLiteClient.getNextRequestTime(), flLiteClient);
                continue;
            }
            if (this.curStatus != FLClientStatus.SUCCESS) {
                this.failed("[getModel] getModel", flLiteClient);
                break;
            }

            flLiteClient.updateDpNormClip();
            LOGGER.info("========================================================the total response of " + flLiteClient.getIteration() + ": " + this.curStatus + "======================================================================");
            this.flJobResultCallback.onFlJobIterationFinished(this.flParameter.getFlName(), flLiteClient.getIteration(), flLiteClient.getRetCode());
            this.tryTimePerIter = 0;
        } while (flLiteClient.getIteration() < flLiteClient.getIterations());
    }

    /** Resets the per-job retry counters before a new run. */
    private void InitialParameters() {
        this.tryTimePerIter = 0;
        this.lastIteration = -1;
        this.waitTryTime = 0;
    }

    /**
     * Checks the per-iteration retry counter.
     *
     * @return {@code true} (after setting the status to {@code FAILED}) when the counter exceeds
     *         {@link #MAX_TRY_TIME_PER_ITER}
     */
    private boolean tryTimeExceedsLimit() {
        if (this.tryTimePerIter > MAX_TRY_TIME_PER_ITER) {
            LOGGER.severe("[tryTimeExceedsLimit] the repeated request time exceeds the limit, current repeated request time is: " + this.tryTimePerIter + " the limited time is: " + MAX_TRY_TIME_PER_ITER);
            this.curStatus = FLClientStatus.FAILED;
            return true;
        }
        return false;
    }

    /**
     * Increments the retry counter when the client is still on the iteration seen last time;
     * otherwise restarts the counter at 1 and remembers the new iteration.
     */
    private void updateTryTimePerIter(FLLiteClient flLiteClient) {
        int iteration = flLiteClient.getIteration();
        if (this.lastIteration != -1 && this.lastIteration == iteration) {
            ++this.tryTimePerIter;
            return;
        }
        this.tryTimePerIter = 1;
        this.lastIteration = iteration;
    }

    /**
     * Checks the WAIT polling counter.
     *
     * @return {@code true} (after setting the status to {@code FAILED}) when the counter exceeds
     *         {@link #MAX_WAIT_TRY_TIME}
     */
    private boolean waitTryTimeExceedsLimit() {
        if (this.waitTryTime > MAX_WAIT_TRY_TIME) {
            LOGGER.severe("[waitTryTimeExceedsLimit] the waitTryTime exceeds the limit, current waitTryTime is: " + this.waitTryTime + " the limited time is: " + MAX_WAIT_TRY_TIME);
            this.curStatus = FLClientStatus.FAILED;
            return true;
        }
        return false;
    }

    /**
     * Calls {@code flLiteClient.getModel()} and keeps polling while it answers {@code WAIT}.
     *
     * @return the first non-{@code WAIT} status, or {@code FAILED} when the polling limit is
     *         exceeded or the stop flag is set
     */
    private FLClientStatus getModel(FLLiteClient flLiteClient) {
        FLClientStatus status = flLiteClient.getModel();
        this.waitTryTime = 0;
        while (status == FLClientStatus.WAIT) {
            ++this.waitTryTime;
            if (this.waitTryTimeExceedsLimit() || this.checkStopJobFlag()) {
                return FLClientStatus.FAILED;
            }
            this.waitSomeTime();
            status = flLiteClient.getModel();
        }
        return status;
    }

    /**
     * @return {@code true} when the configured data map has an {@code EVALMODE} entry; otherwise
     *         logs that evaluation is skipped and returns {@code false}
     */
    private boolean checkEvalPath() {
        if (this.flParameter.getDataMap().containsKey(RunType.EVALMODE)) {
            return true;
        }
        LOGGER.info("[evaluate] the data map set by user do not contain evaluation dataset, don't evaluate the model after getting model from server");
        return false;
    }

    /**
     * @return {@code true} (after setting the status to {@code FAILED}) when the stop flag in
     *         {@link LocalFLParameter} is set
     */
    private boolean checkStopJobFlag() {
        if (this.localFLParameter.isStopJobFlag()) {
            LOGGER.info("the stopJObFlag is set to true, the job will be stop");
            this.curStatus = FLClientStatus.FAILED;
            return true;
        }
        return false;
    }

    /**
     * Runs inference with the configured client model.
     *
     * @return the non-empty list produced by {@code client.inferModel()}, or {@code null} when
     *         dataset initialisation, model initialisation, {@code EnableTrain(false)} or
     *         inference itself fails; the client is freed on every path
     */
    public List<Object> modelInfer() {
        Client client = ClientManager.getClient(this.flParameter.getFlName());
        this.localFLParameter.setMsConfig(0, this.flParameter.getThreadNum(), this.flParameter.getCpuBindMode(), false);
        this.localFLParameter.setStopJobFlag(false);
        if (null != this.flParameter.getInputShape()) {
            LOGGER.info("[model inference] the inference model has dynamic input.");
        }
        Map<RunType, Integer> dataSize = client.initDataSets(this.flParameter.getDataMap());
        if (dataSize == null || dataSize.isEmpty()) {
            LOGGER.severe("[model inference] initDataSets failed, please check");
            client.free();
            return null;
        }
        Status modelInitRet = client.initModel(this.flParameter);
        if (modelInitRet != Status.SUCCESS) {
            LOGGER.severe("initModel failed");
            // Release the client here as well, like every other failure path.
            client.free();
            return null;
        }
        if (!client.EnableTrain(false)) {
            LOGGER.severe("[model inference] call EnableTrain failed");
            client.free();
            return null;
        }
        client.setBatchSize(this.flParameter.getBatchSize());
        LOGGER.info("===========model inference=============");
        List<Object> labels = client.inferModel();
        if (labels == null || labels.isEmpty()) {
            LOGGER.severe("[model inference] the returned label from client.inferModel() is null, please check");
            client.free();
            return null;
        }
        LOGGER.fine("[model inference] the predicted outputs: " + Arrays.deepToString(labels.toArray()));
        client.free();
        LOGGER.info("[model inference] inference finish");
        return labels;
    }

    /**
     * Fetches the model once through a fresh {@link FLLiteClient} and saves it when the fetch
     * succeeds.
     *
     * @return the status of {@code flLiteClient.getModel()}, or {@code null} when model
     *         initialisation fails (kept as {@code null} for existing callers)
     */
    public FLClientStatus getModel() {
        this.initSecureRandom();
        this.localFLParameter.setServerMod(this.flParameter.getServerMod().toString());
        this.localFLParameter.setMsConfig(0, 1, 0, false);
        Client client = ClientManager.getClient(this.flParameter.getFlName());
        Status modelInitRet = client.initModel(this.flParameter);
        if (modelInitRet != Status.SUCCESS) {
            LOGGER.severe("initModel failed");
            client.free();
            return null;
        }
        FLLiteClient flLiteClient = new FLLiteClient();
        FLClientStatus status = flLiteClient.getModel();
        if (status == FLClientStatus.SUCCESS) {
            client.saveModel(this.flParameter, this.localFLParameter);
        }
        client.free();
        return status;
    }

    /** Sets the stop flag checked by the run loop and calls {@code Common.notifyObject()}. */
    public void stopFLJob() {
        LOGGER.info("will stop the flJob");
        this.localFLParameter.setStopJobFlag(true);
        Common.notifyObject();
    }

    /** Sleeps for the configured sleep time, or {@link #DEFAULT_SLEEP_TIME} when it is 0. */
    private void waitSomeTime() {
        if (this.flParameter.getSleepTime() != 0) {
            Common.sleep(this.flParameter.getSleepTime());
        } else {
            Common.sleep(DEFAULT_SLEEP_TIME);
        }
    }

    /** Sleeps for the wait time that {@code Common.getWaitTime} derives from {@code nextReqTime}. */
    private void waitNextReqTime(String nextReqTime) {
        Common.sleep(Common.getWaitTime(nextReqTime));
    }

    /**
     * Handles a {@code RESTART} answer: logs it, waits until the next request time and reports the
     * iteration to the job callback.
     */
    private void resetContext(String tag, String nextReqTime, FLLiteClient flLiteClient) {
        LOGGER.info(tag + " out of time: need wait and request startFLJob again");
        this.waitNextReqTime(nextReqTime);
        this.flJobResultCallback.onFlJobIterationFinished(this.flParameter.getFlName(), flLiteClient.getIteration(), flLiteClient.getRetCode());
    }

    /** Logs a failed step together with the current status and reports the iteration to the job callback. */
    private void failed(String tag, FLLiteClient flLiteClient) {
        LOGGER.info(tag + " failed");
        LOGGER.info("=========================================the total response of " + flLiteClient.getIteration() + ": " + this.curStatus + "=========================================");
        this.flJobResultCallback.onFlJobIterationFinished(this.flParameter.getFlName(), flLiteClient.getIteration(), flLiteClient.getRetCode());
    }

    /** @return {@code true} when a command-line value is {@code null}, empty or the literal {@code "null"} */
    private static boolean isUnset(String value) {
        return value == null || value.isEmpty() || "null".equals(value);
    }

    /**
     * Builds the run-type to path-list map; each path argument is split with {@code pathRegex} and
     * unset arguments are left out of the map.
     */
    private static Map<RunType, List<String>> createDatasetMap(String trainDataPath, String evalDataPath, String inferDataPath, String pathRegex) {
        HashMap<RunType, List<String>> dataMap = new HashMap<RunType, List<String>>();
        if (isUnset(trainDataPath)) {
            LOGGER.info("the trainDataPath is null or empty, please check if you are in the case of only inference");
        } else {
            String[] trainPaths = trainDataPath.split(pathRegex);
            dataMap.put(RunType.TRAINMODE, Arrays.asList(trainPaths));
            LOGGER.info("the trainDataPath: " + Arrays.toString(trainPaths));
        }
        if (isUnset(evalDataPath)) {
            LOGGER.info("the evalDataPath is null or empty, please check if you are in the case of only training without evaluation");
        } else {
            String[] evalPaths = evalDataPath.split(pathRegex);
            dataMap.put(RunType.EVALMODE, Arrays.asList(evalPaths));
            LOGGER.info("the evalDataPath: " + Arrays.toString(evalPaths));
        }
        if (isUnset(inferDataPath)) {
            LOGGER.info("the inferDataPath is null or empty, please check if you are in the case of training without inference");
        } else {
            String[] inferPaths = inferDataPath.split(pathRegex);
            dataMap.put(RunType.INFERMODE, Arrays.asList(inferPaths));
            LOGGER.info("the inferDataPath: " + Arrays.toString(inferPaths));
        }
        return dataMap;
    }

    /**
     * Splits the train / infer weight-name arguments with {@code nameRegex} and stores each
     * resulting list on {@code flParameter}; unset arguments are only logged.
     */
    private static void createWeightNameList(String trainWeightName, String inferWeightName, String nameRegex, FLParameter flParameter) {
        if (isUnset(trainWeightName)) {
            LOGGER.info("the trainWeightName is null or empty");
        } else {
            String[] trainNames = trainWeightName.split(nameRegex);
            flParameter.setHybridWeightName(Arrays.asList(trainNames), RunType.TRAINMODE);
            LOGGER.info("the trainWeightName: " + Arrays.toString(trainNames));
        }
        if (isUnset(inferWeightName)) {
            LOGGER.info("the inferWeightName is null or empty");
        } else {
            String[] inferNames = inferWeightName.split(nameRegex);
            flParameter.setHybridWeightName(Arrays.asList(inferNames), RunType.INFERMODE);
            LOGGER.info("the inferWeightName: " + Arrays.toString(inferNames));
        }
    }

    /**
     * Parses a shape string such as {@code "1,2;3,4"}: {@code ';'} separates inputs and
     * {@code ','} separates the integers of one input. Whitespace around a number is ignored.
     *
     * @throws NumberFormatException if a token is not an integer
     */
    private static int[][] getInputShapeArray(String inputShape) {
        String[] inputs = inputShape.split(";");
        int[][] inputsArray = new int[inputs.length][];
        for (int i = 0; i < inputs.length; ++i) {
            inputsArray[i] = Arrays.stream(inputs[i].split(","))
                    .map(String::trim)
                    .mapToInt(Integer::parseInt)
                    .toArray();
        }
        return inputsArray;
    }

    /**
     * Maps the positional arguments onto {@link FLParameter} and runs the operation named by
     * {@code args[13]} ({@code "train"}, {@code "inference"} or {@code "getModel"}); any other
     * value only logs a message.
     *
     * @param args at least {@link #REQUIRED_ARG_COUNT} positional arguments
     */
    private static void task(String[] args) {
        String trainDataPath = args[0];
        String evalDataPath = args[1];
        String inferDataPath = args[2];
        String pathRegex = args[3];
        String flName = args[4];
        String trainModelPath = args[5];
        String inferModelPath = args[6];
        String sslProtocol = args[7];
        String deployEnv = args[8];
        String domainName = args[9];
        String certPath = args[10];
        boolean useElb = Boolean.parseBoolean(args[11]);
        int serverNum = Integer.parseInt(args[12]);
        String task = args[13];
        int threadNum = Integer.parseInt(args[14]);
        String cpuBindMode = args[15];
        String trainWeightName = args[16];
        String inferWeightName = args[17];
        String nameRegex = args[18];
        String serverMod = args[19];
        int batchSize = Integer.parseInt(args[20]);
        String inputShape = args[21];
        FLParameter flParameter = FLParameter.getInstance();
        // Treat an empty shape like the other optional arguments instead of failing in parseInt("").
        if (!isUnset(inputShape)) {
            flParameter.setInputShape(SyncFLJob.getInputShapeArray(inputShape));
        }
        Map<RunType, List<String>> dataMap = SyncFLJob.createDatasetMap(trainDataPath, evalDataPath, inferDataPath, pathRegex);
        SyncFLJob.createWeightNameList(trainWeightName, inferWeightName, nameRegex, flParameter);
        flParameter.setFlName(flName);
        SyncFLJob syncFLJob = new SyncFLJob();
        switch (task) {
            case "train": {
                LOGGER.info("start syncFLJob.flJobRun()");
                flParameter.setDataMap(dataMap);
                flParameter.setTrainModelPath(trainModelPath);
                flParameter.setInferModelPath(inferModelPath);
                flParameter.setSslProtocol(sslProtocol);
                flParameter.setDeployEnv(deployEnv);
                flParameter.setDomainName(domainName);
                if (Common.isHttps()) {
                    flParameter.setCertPath(certPath);
                }
                flParameter.setUseElb(useElb);
                flParameter.setServerNum(serverNum);
                flParameter.setThreadNum(threadNum);
                flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode));
                flParameter.setBatchSize(batchSize);
                syncFLJob.flJobRun();
                break;
            }
            case "inference": {
                LOGGER.info("start syncFLJob.modelInference()");
                flParameter.setDataMap(dataMap);
                flParameter.setInferModelPath(inferModelPath);
                flParameter.setThreadNum(threadNum);
                flParameter.setCpuBindMode(BindMode.valueOf(cpuBindMode));
                flParameter.setBatchSize(batchSize);
                syncFLJob.modelInfer();
                break;
            }
            case "getModel": {
                LOGGER.info("start syncFLJob.getModel()");
                flParameter.setTrainModelPath(trainModelPath);
                flParameter.setInferModelPath(inferModelPath);
                flParameter.setSslProtocol(sslProtocol);
                flParameter.setDeployEnv(deployEnv);
                flParameter.setDomainName(domainName);
                if (Common.isHttps()) {
                    flParameter.setCertPath(certPath);
                }
                flParameter.setUseElb(useElb);
                flParameter.setServerNum(serverNum);
                flParameter.setServerMod(ServerMod.valueOf(serverMod));
                syncFLJob.getModel();
                break;
            }
            default: {
                LOGGER.info("do not do any thing!");
            }
        }
    }

    /**
     * Command-line entry point.
     *
     * @param args positional arguments consumed by {@link #task(String[])}
     * @throws IllegalArgumentException if fewer than {@link #REQUIRED_ARG_COUNT} arguments are
     *         given or the flName argument ({@code args[4]}) is null or empty
     */
    public static void main(String[] args) {
        int argCount = args == null ? 0 : args.length;
        if (argCount < REQUIRED_ARG_COUNT) {
            String message = "expected at least " + REQUIRED_ARG_COUNT + " arguments but got " + argCount + ", please check";
            LOGGER.severe(message);
            throw new IllegalArgumentException(message);
        }
        if (args[4] == null || args[4].isEmpty()) {
            LOGGER.severe("the parameter of <args[4]> is null, please check");
            throw new IllegalArgumentException("the parameter of <args[4]> (flName) is null or empty");
        }
        SyncFLJob.task(args);
    }
}

