/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore.flclient.model;

import com.mindspore.Graph;
import com.mindspore.MSTensor;
import com.mindspore.Model;
import com.mindspore.config.MSContext;
import com.mindspore.config.TrainCfg;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.Client;
import com.mindspore.flclient.model.DataSet;
import com.mindspore.flclient.model.LossCallback;
import com.mindspore.flclient.model.RunType;
import com.mindspore.flclient.model.Status;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import mindspore.fl.schema.FeatureMap;

/**
 * Wrapper around a MindSpore {@link Model} used by the federated-learning client.
 *
 * <p>It builds the model (optionally resizing its inputs), owns one direct input buffer per model
 * input, runs training steps over a {@link DataSet}, and reads / overwrites the model's
 * feature-map tensors by name.
 *
 * <p>NOTE(review): no synchronization is performed; presumably a single thread drives an
 * instance — confirm with callers.
 */
public class ModelProxy {
    private static final Logger logger = FLLoggerGenerater.getModelLogger(Client.class.toString());

    // Bytes per element used when sizing input buffers from explicit shapes.
    // NOTE(review): assumes every resized input is a 4-byte type (e.g. float32) — confirm.
    private static final int INPUT_ELEMENT_BYTES = 4;

    private Model model;
    private Map<RunType, DataSet> dataSets = new HashMap<RunType, DataSet>();
    // One direct, native-ordered buffer per model input, index-aligned with `inputs`.
    private final List<ByteBuffer> inputsBuffer = new ArrayList<ByteBuffer>();
    private List<MSTensor> inputs;
    // Feature-map tensors of the built model, keyed by tensor name.
    private final Map<String, MSTensor> featureMap = new HashMap<String, MSTensor>();
    private float uploadLoss = 0.0f;

    /**
     * Returns the underlying model.
     *
     * @return the built model, or {@code null} before a successful {@link #initModel} or after
     *     {@link #free()}
     */
    public Model getModel() {
        return this.model;
    }

    /**
     * Returns the loss recorded from a {@link LossCallback} at the end of the last epoch of the
     * most recent {@link #runModel} call (0 if never set).
     *
     * @return the loss value to upload
     */
    public float getUploadLoss() {
        return this.uploadLoss;
    }

    /**
     * Overrides the loss value to upload.
     *
     * @param uploadLoss the new loss value
     */
    public void setUploadLoss(float uploadLoss) {
        this.uploadLoss = uploadLoss;
    }

    /**
     * Releases the input tensors, the feature-map tensors and the native model, and drops the
     * cached input buffers. Safe to call repeatedly and after a failed {@link #initModel}.
     */
    public void free() {
        if (this.inputs != null) {
            this.inputs.forEach(MSTensor::free);
            this.inputs = null;
        }
        this.featureMap.values().forEach(MSTensor::free);
        this.featureMap.clear();
        this.inputsBuffer.clear();
        if (this.model != null) {
            this.model.free();
            this.model = null;
        }
    }

    /**
     * Creates an {@link MSContext} configured from {@link LocalFLParameter} (thread count, CPU
     * bind mode, device type, fp16 flag).
     *
     * @return the initialized context, or {@code null} if initialization failed (the partially
     *     built context is freed in that case)
     */
    private MSContext getMsContext() {
        int deviceType = LocalFLParameter.getInstance().getDeviceType();
        int threadNum = LocalFLParameter.getInstance().getThreadNum();
        int cpuBindMode = LocalFLParameter.getInstance().getCpuBindMode();
        boolean enableFp16 = LocalFLParameter.getInstance().isEnableFp16();
        MSContext msContext = new MSContext();
        if (!msContext.init(threadNum, cpuBindMode)) {
            logger.severe("Call msContext.init failed, threadNum " + threadNum + ", cpuBindMode " + cpuBindMode);
            msContext.free();
            return null;
        }
        if (!msContext.addDeviceInfo(deviceType, enableFp16, 0)) {
            logger.severe("Call msContext.addDeviceInfo failed, deviceType " + deviceType + ", enableFp16 " + enableFp16);
            msContext.free();
            return null;
        }
        return msContext;
    }

    /**
     * Builds a trainable model from a graph file and sizes input buffers from the model's own
     * input tensors.
     *
     * @param modelPath path of the graph file to load
     * @param msContext context used to build the model
     * @return {@code true} on success; on failure all partially created state is released
     */
    private boolean initModelWithoutShape(String modelPath, MSContext msContext) {
        TrainCfg trainCfg = new TrainCfg();
        if (!trainCfg.init()) {
            logger.severe("Call trainCfg.init failed ...");
            msContext.free();
            trainCfg.free();
            return false;
        }
        Graph graph = new Graph();
        if (!graph.load(modelPath)) {
            logger.severe("Call graph.load failed, modelPath: " + modelPath);
            graph.free();
            trainCfg.free();
            msContext.free();
            return false;
        }
        this.model = new Model();
        boolean built = this.model.build(graph, msContext, trainCfg);
        // The graph is released whether or not the build succeeded.
        graph.free();
        if (!built) {
            logger.severe("Call model.build failed ... ");
            // NOTE(review): msContext / trainCfg are not freed here; presumably model.build takes
            // ownership of them — confirm against the JNI layer.
            this.free();
            return false;
        }
        this.inputs = this.model.getInputs();
        for (MSTensor input : this.inputs) {
            this.inputsBuffer.add(allocateNativeBuffer((int) input.size()));
        }
        this.cacheFeatureMaps();
        return true;
    }

    /**
     * Builds a model from a file and resizes its inputs to the given shapes.
     *
     * @param modelPath path of the model file
     * @param msContext context used to build the model
     * @param inputShapes one shape per model input, in input order
     * @return {@code true} on success; on failure all partially created state is released
     */
    private boolean initModelWithShape(String modelPath, MSContext msContext, int[][] inputShapes) {
        this.model = new Model();
        // NOTE(review): 0 is presumably the model-type code expected by Model.build — confirm.
        if (!this.model.build(modelPath, 0, msContext)) {
            logger.severe("Call model.build failed ... ");
            this.free();
            return false;
        }
        this.inputs = this.model.getInputs();
        if (!this.model.resize(this.inputs, inputShapes)) {
            logger.severe("session resize failed");
            this.free();
            return false;
        }
        for (int[] shape : inputShapes) {
            // Identity 1 makes an empty (scalar) shape count as one element instead of throwing.
            int elementCount = IntStream.of(shape).reduce(1, (a, b) -> a * b);
            this.inputsBuffer.add(allocateNativeBuffer(elementCount * INPUT_ELEMENT_BYTES));
        }
        this.cacheFeatureMaps();
        return true;
    }

    /**
     * Initializes the model, releasing any model previously built by this proxy first.
     *
     * @param modelPath path of the model file; must not be {@code null}
     * @param inputShapes per-input shapes to resize to, or {@code null} to load the file as a
     *     trainable graph with its own input shapes
     * @return {@link Status#SUCCESS} or {@link Status#FAILED}
     */
    public Status initModel(String modelPath, int[][] inputShapes) {
        if (modelPath == null) {
            logger.severe("session init failed");
            return Status.FAILED;
        }
        // Avoid leaking a previous native model and reusing its stale input buffers.
        this.free();
        MSContext msContext = this.getMsContext();
        if (msContext == null) {
            return Status.FAILED;
        }
        boolean initModelRet = inputShapes == null
                ? this.initModelWithoutShape(modelPath, msContext)
                : this.initModelWithShape(modelPath, msContext, inputShapes);
        return initModelRet ? Status.SUCCESS : Status.FAILED;
    }

    /**
     * Lets the data set fill the input buffers for one batch, then binds each buffer to the
     * corresponding model input tensor.
     */
    private void fillModelInput(DataSet dataSet, int batchIdx) {
        dataSet.fillInputBuffer(this.inputsBuffer, batchIdx);
        for (int i = 0; i < this.inputs.size(); ++i) {
            this.inputs.get(i).setData(this.inputsBuffer.get(i));
        }
    }

    /**
     * Runs {@code epochs} passes over every batch of {@code dataSet}, invoking callbacks after
     * each step and each epoch. After the last epoch, the loss of any {@link LossCallback} is
     * stored as the upload loss.
     *
     * @param epochs number of epochs to run
     * @param callbacks callbacks to notify; {@code null} is treated as empty
     * @param dataSet data source; must not be {@code null}
     * @return {@link Status#SUCCESS}; {@link Status#FAILED} if a step fails or the job stop flag
     *     is set; {@link Status#NULLPTR} if the model is not initialized or dataSet is null
     */
    public Status runModel(int epochs, List<Callback> callbacks, DataSet dataSet) {
        if (this.model == null || dataSet == null) {
            logger.severe("model and dataSet cannot be null");
            return Status.NULLPTR;
        }
        List<Callback> activeCallbacks = callbacks == null ? new ArrayList<Callback>() : callbacks;
        LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
        long startTime = System.currentTimeMillis();
        for (int epoch = 0; epoch < epochs; ++epoch) {
            for (int batch = 0; batch < dataSet.batchNum; ++batch) {
                if (localFLParameter.isStopJobFlag()) {
                    logger.info("the stopJObFlag is set to true, the job will be stop");
                    return Status.FAILED;
                }
                this.fillModelInput(dataSet, batch);
                if (!this.model.runStep()) {
                    logger.severe("run graph failed");
                    return Status.FAILED;
                }
                for (Callback callback : activeCallbacks) {
                    callback.stepEnd();
                }
            }
            boolean lastEpoch = epoch == epochs - 1;
            for (Callback callback : activeCallbacks) {
                callback.epochEnd();
                if (lastEpoch && callback instanceof LossCallback) {
                    this.setUploadLoss(((LossCallback) callback).getUploadLoss());
                }
            }
        }
        long endTime = System.currentTimeMillis();
        logger.info("total run time:" + (endTime - startTime) + "ms");
        return Status.SUCCESS;
    }

    /**
     * Copies every feature-map tensor out as a float array.
     *
     * @return a new map from tensor name to its float data (empty if the model is not built)
     */
    public Map<String, float[]> getFeatureMap() {
        Map<String, float[]> features = new HashMap<String, float[]>(this.featureMap.size());
        for (Map.Entry<String, MSTensor> entry : this.featureMap.entrySet()) {
            features.put(entry.getKey(), entry.getValue().getFloatData());
        }
        return features;
    }

    /**
     * Returns the float data of one feature-map tensor.
     *
     * @param weightName tensor name
     * @return the tensor's float data, or {@code null} if no such tensor exists
     */
    public float[] getFeature(String weightName) {
        MSTensor tensor = this.featureMap.get(weightName);
        return tensor == null ? null : tensor.getFloatData();
    }

    /**
     * Overwrites the listed feature-map tensors and then exports the model to {@code modelName}.
     * Processing stops at the first failing feature; features already written stay written.
     *
     * @param modelName export file name; must be non-empty
     * @param featureMaps new feature values
     * @return {@link Status#SUCCESS}; {@link Status#NULLPTR} for missing arguments or unknown
     *     features; {@link Status#FAILED} if a tensor update or the export fails
     */
    public Status updateFeatures(String modelName, List<FeatureMap> featureMaps) {
        if (this.model == null || featureMaps == null || modelName == null || modelName.isEmpty()) {
            logger.severe("trainSession,featureMaps modelName cannot be null");
            return Status.NULLPTR;
        }
        for (FeatureMap newFeature : featureMaps) {
            Status status = this.setFeatureData(newFeature);
            if (status != Status.SUCCESS) {
                return status;
            }
        }
        // NOTE(review): arguments presumably mean no quantization, full (train) graph, default
        // outputs — confirm against the Model.export documentation.
        if (!this.model.export(modelName, 0, false, null)) {
            logger.severe("Export model failed, modelName: " + modelName);
            return Status.FAILED;
        }
        return Status.SUCCESS;
    }

    /**
     * Overwrites a single feature-map tensor with the data carried by {@code newFeature}.
     *
     * @param newFeature new feature value
     * @return {@link Status#SUCCESS}; {@link Status#NULLPTR} for a null/unknown feature;
     *     {@link Status#FAILED} if the tensor update fails
     */
    public Status updateFeature(FeatureMap newFeature) {
        return this.setFeatureData(newFeature);
    }

    /** Copies a feature's bytes into a fresh native buffer and binds it to the named tensor. */
    private Status setFeatureData(FeatureMap newFeature) {
        if (newFeature == null) {
            logger.severe("newFeature cannot be null");
            return Status.NULLPTR;
        }
        String weightName = newFeature.weightFullname();
        MSTensor tensor = (weightName == null || weightName.isEmpty()) ? null : this.featureMap.get(weightName);
        if (tensor == null) {
            logger.severe("Can't get feature for name:" + weightName);
            return Status.NULLPTR;
        }
        ByteBuffer source = newFeature.dataAsByteBuffer();
        if (source == null) {
            logger.severe("Feature data is empty, name:" + weightName);
            return Status.NULLPTR;
        }
        ByteBuffer newData = allocateNativeBuffer(source.remaining());
        newData.put(source);
        if (!tensor.setData(newData)) {
            logger.severe("Set tensor value failed, name:" + tensor.tensorName());
            return Status.FAILED;
        }
        return Status.SUCCESS;
    }

    /** Allocates a direct buffer of {@code size} bytes in the platform's native byte order. */
    private static ByteBuffer allocateNativeBuffer(int size) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(size);
        buffer.order(ByteOrder.nativeOrder());
        return buffer;
    }

    /** Indexes the built model's feature-map tensors by tensor name. */
    private void cacheFeatureMaps() {
        for (MSTensor item : this.model.getFeatureMaps()) {
            this.featureMap.put(item.tensorName(), item);
        }
    }
}

