/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.tracker;

import ai.djl.TrainingDivergedException;
import ai.djl.training.tracker.Tracker;
import java.util.Objects;

/**
 * A {@link Tracker} that produces warm-up values for an initial number of updates and then
 * delegates to a main tracker.
 *
 * <p>For {@code numUpdate < warmUpSteps} the value comes from the warm-up schedule selected by
 * {@link Mode}. From then on the main tracker is queried with {@code numUpdate - warmUpSteps}, so
 * the main schedule starts from its own update 0 once warm-up ends.
 */
public final class WarmUpTracker implements Tracker {
    Tracker mainTracker;
    int warmUpSteps;
    float warmUpBeginValue;
    float warmUpFinalValue;
    Mode warmUpMode;

    /**
     * Creates a tracker from the given builder's settings.
     *
     * <p>The warm-up target value is taken from the main tracker's value at update 0, so a linear
     * warm-up ends at the value the main schedule starts with.
     *
     * @param builder the builder holding the configuration
     */
    WarmUpTracker(Builder builder) {
        this.mainTracker = builder.mainTracker;
        this.warmUpSteps = builder.warmUpSteps;
        this.warmUpBeginValue = builder.warmUpBeginValue;
        this.warmUpMode = builder.warmUpMode;
        this.warmUpFinalValue = this.mainTracker.getNewValue(0);
    }

    /**
     * Creates a new builder for {@link WarmUpTracker}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Computes the value for an update that falls inside the warm-up phase.
     *
     * <p>In {@link Mode#LINEAR} the value is interpolated from the begin value toward the main
     * tracker's initial value. In any other mode the begin value is returned unchanged.
     *
     * @param numUpdate the update index, expected to be below {@code warmUpSteps}
     * @return the warm-up value
     * @throws TrainingDivergedException if the computed value is NaN
     */
    float getWarmUpValue(int numUpdate) {
        float value = this.warmUpBeginValue;
        if (this.warmUpMode == Mode.LINEAR) {
            value = this.warmUpBeginValue + (this.warmUpFinalValue - this.warmUpBeginValue) * (float)numUpdate / (float)this.warmUpSteps;
        }
        this.checkValue(value);
        return value;
    }

    /**
     * Returns the value for the given update.
     *
     * @param numUpdate the number of updates performed so far
     * @return the warm-up value during warm-up, otherwise the main tracker's value at
     *     {@code numUpdate - warmUpSteps}
     */
    @Override
    public float getNewValue(int numUpdate) {
        if (numUpdate < this.warmUpSteps) {
            return this.getWarmUpValue(numUpdate);
        }
        // Shift the index so the main schedule begins at its own step 0 after warm-up.
        return this.mainTracker.getNewValue(numUpdate - this.warmUpSteps);
    }

    /**
     * Rejects NaN values.
     *
     * @param value the value to check
     * @throws TrainingDivergedException if {@code value} is NaN
     */
    void checkValue(float value) {
        if (Float.isNaN(value)) {
            throw new TrainingDivergedException("Value is Nan.");
        }
    }

    /** A builder for {@link WarmUpTracker}. */
    public static final class Builder {
        Tracker mainTracker;
        int warmUpSteps;
        float warmUpBeginValue;
        Mode warmUpMode = Mode.LINEAR;

        private Builder() {
        }

        /**
         * Sets the tracker that takes over once warm-up is finished (required).
         *
         * @param mainTracker the main tracker
         * @return this builder
         */
        public Builder setMainTracker(Tracker mainTracker) {
            this.mainTracker = mainTracker;
            return this;
        }

        /**
         * Sets the number of warm-up updates. Defaults to 0, which means no warm-up phase.
         *
         * @param warmUpSteps the number of warm-up updates; must be non-negative
         * @return this builder
         */
        public Builder optWarmUpSteps(int warmUpSteps) {
            this.warmUpSteps = warmUpSteps;
            return this;
        }

        /**
         * Sets the value at the start of warm-up. Defaults to 0.
         *
         * @param warmUpBeginValue the initial warm-up value
         * @return this builder
         */
        public Builder optWarmUpBeginValue(float warmUpBeginValue) {
            this.warmUpBeginValue = warmUpBeginValue;
            return this;
        }

        /**
         * Sets the warm-up schedule. Defaults to {@link Mode#LINEAR}.
         *
         * @param warmUpMode the warm-up mode; must not be null
         * @return this builder
         */
        public Builder optWarmUpMode(Mode warmUpMode) {
            this.warmUpMode = warmUpMode;
            return this;
        }

        /**
         * Validates the configuration and builds the tracker.
         *
         * @return a new {@link WarmUpTracker}
         * @throws NullPointerException if the main tracker or the warm-up mode is not set
         * @throws IllegalArgumentException if the number of warm-up steps is negative
         */
        public WarmUpTracker build() {
            Objects.requireNonNull(mainTracker, "You must set a main tracker");
            // A null mode would otherwise silently behave like CONSTANT in getWarmUpValue.
            Objects.requireNonNull(warmUpMode, "warmUpMode must not be null");
            if (warmUpSteps < 0) {
                throw new IllegalArgumentException(
                        "warmUpSteps must be non-negative, got " + warmUpSteps);
            }
            return new WarmUpTracker(this);
        }
    }

    /** The shape of the warm-up schedule. */
    public enum Mode {
        /** Interpolates linearly from the begin value to the main tracker's initial value. */
        LINEAR,
        /** Holds the begin value for the whole warm-up phase. */
        CONSTANT
    }
}

