/*
 * Decompiled with CFR 0.152.
 */
package com.mindspore.flclient.model;

import com.mindspore.Model;
import com.mindspore.flclient.common.FLLoggerGenerater;
import com.mindspore.flclient.model.Callback;
import com.mindspore.flclient.model.Status;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Training callback that reads the loss value produced by the model after every
 * step, logs it, and at the end of each epoch records the epoch's average loss
 * so it can be retrieved via {@link #getUploadLoss()}.
 */
public class LossCallback extends Callback {
    private static final Logger logger = FLLoggerGenerater.getModelLogger(LossCallback.class.toString());

    /** Sum of the per-step losses accumulated during the current epoch. */
    private float lossSum = 0.0f;

    /** Average loss of the most recently completed epoch that had at least one step. */
    private float uploadLoss = 0.0f;

    /**
     * Creates a loss callback bound to the given model.
     *
     * @param model the model whose outputs are inspected after each step
     */
    public LossCallback(Model model) {
        super(model);
    }

    /**
     * No-op hook invoked before each training step.
     *
     * @return always {@link Status#SUCCESS}
     */
    @Override
    public Status stepBegin() {
        return Status.SUCCESS;
    }

    /**
     * Reads the loss from the first single-element model output, logs it, and adds
     * it to the running epoch sum.
     *
     * @return {@link Status#NULLPTR} if no single-element output exists,
     *         {@link Status#FAILED} if the loss value is missing or NaN,
     *         otherwise {@link Status#SUCCESS}
     */
    @Override
    public Status stepEnd() {
        Map<String, float[]> outputs = this.getOutputsBySize(1);
        if (outputs.isEmpty()) {
            logger.severe("cannot find loss tensor");
            return Status.NULLPTR;
        }
        float[] lossValues = outputs.values().iterator().next();
        // Report an empty/absent tensor separately from a NaN loss so the log is accurate.
        if (lossValues == null || lossValues.length < 1) {
            logger.severe("loss tensor is empty");
            return Status.FAILED;
        }
        float loss = lossValues[0];
        if (Float.isNaN(loss)) {
            logger.severe("loss is nan");
            return Status.FAILED;
        }
        logger.info("batch:" + this.steps + ",loss:" + loss);
        this.lossSum += loss;
        ++this.steps;
        return Status.SUCCESS;
    }

    /**
     * No-op hook invoked before each epoch.
     *
     * @return always {@link Status#SUCCESS}
     */
    @Override
    public Status epochBegin() {
        return Status.SUCCESS;
    }

    /**
     * Logs the epoch's average loss, stores it as the upload loss, and resets the
     * per-epoch counters.
     *
     * <p>If no steps were recorded during the epoch, the average is undefined
     * (0/0 would be NaN). In that case a warning is logged and the previously
     * stored upload loss is kept.
     *
     * @return always {@link Status#SUCCESS}
     */
    @Override
    public Status epochEnd() {
        if (this.steps > 0) {
            float averageLoss = this.lossSum / (float) this.steps;
            logger.info("----------epoch:" + this.epochs + ",average loss:" + averageLoss + "----------");
            this.setUploadLoss(averageLoss);
        } else {
            logger.warning("----------epoch:" + this.epochs
                    + " recorded no steps, keeping previous upload loss:" + this.uploadLoss + "----------");
        }
        this.steps = 0;
        ++this.epochs;
        this.lossSum = 0.0f;
        return Status.SUCCESS;
    }

    /**
     * Returns the average loss of the most recently completed non-empty epoch.
     *
     * @return the stored upload loss (0.0 until an epoch with steps has ended)
     */
    public float getUploadLoss() {
        return this.uploadLoss;
    }

    /**
     * Overrides the stored upload loss.
     *
     * @param uploadLoss the value to store
     */
    public void setUploadLoss(float uploadLoss) {
        this.uploadLoss = uploadLoss;
    }
}

