/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

import java.util.function.DoubleUnaryOperator;

/**
 * An implementation of the AdaDelta gradient optimiser.
 * <p>
 * For every parameter tensor the optimiser keeps two running averages, both decayed by
 * {@code rho} on each step:
 * <ul>
 *   <li>{@code gradsSquared}: the average of the squared incoming gradients;</li>
 *   <li>{@code velocitySquared}: the average of the squared rescaled updates.</li>
 * </ul>
 * Each incoming gradient is rescaled element-wise by
 * {@code sqrt(velocitySquared + epsilon) / sqrt(gradsSquared + epsilon)}. Because of this,
 * no global learning rate is used.
 * <p>
 * See: Zeiler, M.D. "ADADELTA: An Adaptive Learning Rate Method", arXiv:1212.5701 (2012).
 * <p>
 * The optimiser is stateful. {@link #initialise(Parameters)} must be called before
 * {@link #step(Tensor[], double)}, and {@link #reset()} discards the accumulated state.
 */
public class AdaDelta
implements StochasticGradientOptimiser {
    @Config(description="Momentum value.")
    private double rho = 0.95;
    @Config(description="Epsilon for numerical stability.")
    private double epsilon = 1.0E-6;
    // Running average of squared gradients, one tensor per parameter; null until initialise().
    private Tensor[] gradsSquared;
    // Running average of squared rescaled updates, one tensor per parameter; null until initialise().
    private Tensor[] velocitySquared;

    /**
     * Constructs an AdaDelta optimiser.
     *
     * @param rho     decay factor applied to both running averages on each step
     *                (the previous average is multiplied by {@code rho}, and the new term by
     *                {@code 1 - rho}).
     * @param epsilon constant added inside the square roots for numerical stability.
     */
    public AdaDelta(double rho, double epsilon) {
        this.rho = rho;
        this.epsilon = epsilon;
    }

    /**
     * Constructs an AdaDelta optimiser with {@code rho = 0.95}.
     *
     * @param epsilon constant added inside the square roots for numerical stability.
     */
    public AdaDelta(double epsilon) {
        this(0.95, epsilon);
    }

    /**
     * Constructs an AdaDelta optimiser with {@code rho = 0.95} and {@code epsilon = 1e-6}.
     */
    public AdaDelta() {
        this(0.95, 1.0E-6);
    }

    /**
     * Allocates the per-parameter running averages as empty copies of the supplied parameters.
     *
     * @param parameters the parameters that this optimiser will update.
     */
    @Override
    public void initialise(Parameters parameters) {
        this.gradsSquared = parameters.getEmptyCopy();
        this.velocitySquared = parameters.getEmptyCopy();
    }

    /**
     * Applies one AdaDelta step. The incoming gradients are rescaled in place and the running
     * averages are updated.
     * <p>
     * For each index {@code i}:
     * <ol>
     *   <li>{@code gradsSquared[i] = rho * gradsSquared[i] + (1 - rho) * updates[i]^2}</li>
     *   <li>{@code updates[i] *= sqrt(velocitySquared[i] + epsilon) / sqrt(gradsSquared[i] + epsilon)}</li>
     *   <li>{@code velocitySquared[i] = rho * velocitySquared[i] + (1 - rho) * updates[i]^2}</li>
     * </ol>
     *
     * @param updates the gradient tensors. They are mutated in place and returned.
     * @param weight  the example weight. It is not used by this implementation.
     * @return the same {@code updates} array, now holding the rescaled updates.
     * @throws IllegalStateException    if {@link #initialise(Parameters)} has not been called,
     *                                  or if {@link #reset()} was called since.
     * @throws IllegalArgumentException if {@code updates} has more tensors than the optimiser
     *                                  was initialised for.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        if (this.gradsSquared == null || this.velocitySquared == null) {
            throw new IllegalStateException("AdaDelta.step called before initialise(Parameters) or after reset()");
        }
        // Validate before mutating anything, so a bad call cannot leave the state partially updated.
        if (updates.length > this.gradsSquared.length || updates.length > this.velocitySquared.length) {
            throw new IllegalArgumentException("Received " + updates.length
                    + " gradient tensors but the optimiser was initialised for " + this.gradsSquared.length);
        }
        // Read the hyperparameters once and build the element-wise functions outside the loop,
        // instead of allocating new capturing lambdas on every iteration.
        final double decay = this.rho;
        final double eps = this.epsilon;
        final DoubleUnaryOperator scaledSquare = a -> a * a * (1.0 - decay);
        final DoubleUnaryOperator rms = a -> Math.sqrt(a + eps);
        final DoubleUnaryOperator inverseRms = a -> 1.0 / Math.sqrt(a + eps);
        for (int i = 0; i < updates.length; ++i) {
            // Decay and accumulate the squared gradients first, so the rescaling uses the current gradient.
            this.gradsSquared[i].scaleInPlace(decay);
            this.gradsSquared[i].intersectAndAddInPlace(updates[i], scaledSquare);
            // Rescale using the previous step's velocity average (not yet updated for this step).
            updates[i].hadamardProductInPlace(this.velocitySquared[i], rms);
            updates[i].hadamardProductInPlace(this.gradsSquared[i], inverseRms);
            // Accumulate the squared rescaled update for use on the next step.
            this.velocitySquared[i].scaleInPlace(decay);
            this.velocitySquared[i].intersectAndAddInPlace(updates[i], scaledSquare);
        }
        return updates;
    }

    @Override
    public String toString() {
        return "AdaDelta(rho=" + this.rho + ",epsilon=" + this.epsilon + ")";
    }

    /**
     * Discards the accumulated running averages. {@link #initialise(Parameters)} must be called
     * again before the next {@link #step(Tensor[], double)}.
     */
    @Override
    public void reset() {
        this.gradsSquared = null;
        this.velocitySquared = null;
    }

    /**
     * Creates a new, uninitialised optimiser with the same {@code rho} and {@code epsilon}.
     * Accumulated state is not copied.
     *
     * @return a fresh AdaDelta instance.
     */
    @Override
    public AdaDelta copy() {
        return new AdaDelta(this.rho, this.epsilon);
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "StochasticGradientOptimiser");
    }
}

