/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.ndarray;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.util.Preconditions;
import java.util.Arrays;

/**
 * Static, operator-style helpers over {@link NDArray}.
 *
 * <p>Almost every method is a thin delegation to the corresponding {@link NDArray} instance
 * method, or to {@code getNDArrayInternal()} for the reversed-operand arithmetic variants.
 * Overloads that take a {@link Number} as the <em>first</em> argument evaluate {@code n OP a}:
 *
 * <ul>
 *   <li>For comparisons the operator is mirrored; for example {@code gt(n, a)} calls
 *       {@code a.lt(n)}.
 *   <li>For non-commutative arithmetic the matching {@code r*} method on the internal view is
 *       used; for example {@code sub(n, a)} calls {@code rsub(n)}.
 *   <li>For commutative arithmetic the operands are simply swapped; for example
 *       {@code add(n, a)} calls {@code a.add(n)}.
 * </ul>
 *
 * <p>Methods whose names end in {@code i} delegate to the {@code *i} variants of
 * {@link NDArray}.
 */
public final class NDArrays {

    /** Not instantiable: this class only hosts static helpers. */
    private NDArrays() {}

    /**
     * Validates the operands of the variadic {@code add}, {@code mul}, {@code addi} and
     * {@code muli} overloads.
     *
     * <p>With exactly two operands no shape check is made here. Presumably that lets the binary
     * operator apply its own shape rules (e.g. broadcasting) — TODO confirm. With three or more
     * operands, every operand must satisfy {@code arrays[0].shapeEquals(...)}.
     *
     * @param arrays the operands to validate
     * @throws IllegalArgumentException if {@code arrays} is {@code null}, has fewer than two
     *     elements, or (for three or more) an operand's shape differs from the first one's
     */
    private static void checkInputs(NDArray[] arrays) {
        if (arrays == null || arrays.length < 2) {
            throw new IllegalArgumentException(
                    "Passed in arrays must have at least two elements, but got "
                            + (arrays == null ? "null" : String.valueOf(arrays.length)));
        }
        NDArray first = arrays[0];
        if (arrays.length > 2
                && Arrays.stream(arrays).skip(1L).anyMatch(array -> !first.shapeEquals(array))) {
            throw new IllegalArgumentException("The shape of all inputs must be the same");
        }
    }

    // ---------------------------------------------------------------------------------------
    // Equality / shape checks
    // ---------------------------------------------------------------------------------------

    /**
     * Compares an array's content with a number, tolerating a {@code null} array.
     *
     * @param a the array to compare; may be {@code null}
     * @param n the number to compare against
     * @return {@code false} if {@code a} is {@code null}, otherwise {@code a.contentEquals(n)}
     */
    public static boolean contentEquals(NDArray a, Number n) {
        return a != null && a.contentEquals(n);
    }

    /**
     * Returns {@code a.contentEquals(b)}.
     *
     * @param a the receiver; must not be {@code null}
     * @param b the array to compare against
     * @return the result of {@code a.contentEquals(b)}
     */
    public static boolean contentEquals(NDArray a, NDArray b) {
        return a.contentEquals(b);
    }

    /**
     * Returns {@code a.shapeEquals(b)}.
     *
     * @param a the receiver; must not be {@code null}
     * @param b the array whose shape is compared
     * @return the result of {@code a.shapeEquals(b)}
     */
    public static boolean shapeEquals(NDArray a, NDArray b) {
        return a.shapeEquals(b);
    }

    /**
     * Returns {@code a.allClose(b)}, using the default tolerances of {@link NDArray}.
     *
     * @param a the receiver; must not be {@code null}
     * @param b the array to compare against
     * @return the result of {@code a.allClose(b)}
     */
    public static boolean allClose(NDArray a, NDArray b) {
        return a.allClose(b);
    }

    /**
     * Returns {@code a.allClose(b, rtol, atol, equalNan)}.
     *
     * @param a the receiver; must not be {@code null}
     * @param b the array to compare against
     * @param rtol relative tolerance, passed through unchanged
     * @param atol absolute tolerance, passed through unchanged
     * @param equalNan NaN-equality flag, passed through unchanged
     * @return the result of the delegated call
     */
    public static boolean allClose(
            NDArray a, NDArray b, double rtol, double atol, boolean equalNan) {
        return a.allClose(b, rtol, atol, equalNan);
    }

    // ---------------------------------------------------------------------------------------
    // Comparisons (Number-first overloads mirror the operator)
    // ---------------------------------------------------------------------------------------

    /** Returns {@code a.eq(n)}. */
    public static NDArray eq(NDArray a, Number n) {
        return a.eq(n);
    }

    /** Evaluates {@code n == a}; equality is symmetric, so this is {@code a.eq(n)}. */
    public static NDArray eq(Number n, NDArray a) {
        return a.eq(n);
    }

    /** Returns {@code a.eq(b)}. */
    public static NDArray eq(NDArray a, NDArray b) {
        return a.eq(b);
    }

    /** Returns {@code a.neq(n)}. */
    public static NDArray neq(NDArray a, Number n) {
        return a.neq(n);
    }

    /** Evaluates {@code n != a}; inequality is symmetric, so this is {@code a.neq(n)}. */
    public static NDArray neq(Number n, NDArray a) {
        return a.neq(n);
    }

    /** Returns {@code a.neq(b)}. */
    public static NDArray neq(NDArray a, NDArray b) {
        return a.neq(b);
    }

    /** Returns {@code a.gt(n)}. */
    public static NDArray gt(NDArray a, Number n) {
        return a.gt(n);
    }

    /** Evaluates {@code n > a} as the mirrored comparison {@code a.lt(n)}. */
    public static NDArray gt(Number n, NDArray a) {
        return a.lt(n);
    }

    /** Returns {@code a.gt(b)}. */
    public static NDArray gt(NDArray a, NDArray b) {
        return a.gt(b);
    }

    /** Returns {@code a.gte(n)}. */
    public static NDArray gte(NDArray a, Number n) {
        return a.gte(n);
    }

    /** Evaluates {@code n >= a} as the mirrored comparison {@code a.lte(n)}. */
    public static NDArray gte(Number n, NDArray a) {
        return a.lte(n);
    }

    /** Returns {@code a.gte(b)}. */
    public static NDArray gte(NDArray a, NDArray b) {
        return a.gte(b);
    }

    /** Returns {@code a.lt(n)}. */
    public static NDArray lt(NDArray a, Number n) {
        return a.lt(n);
    }

    /** Evaluates {@code n < a} as the mirrored comparison {@code a.gt(n)}. */
    public static NDArray lt(Number n, NDArray a) {
        return a.gt(n);
    }

    /** Returns {@code a.lt(b)}. */
    public static NDArray lt(NDArray a, NDArray b) {
        return a.lt(b);
    }

    /** Returns {@code a.lte(n)}. */
    public static NDArray lte(NDArray a, Number n) {
        return a.lte(n);
    }

    /** Evaluates {@code n <= a} as the mirrored comparison {@code a.gte(n)}. */
    public static NDArray lte(Number n, NDArray a) {
        return a.gte(n);
    }

    /** Returns {@code a.lte(b)}. */
    public static NDArray lte(NDArray a, NDArray b) {
        return a.lte(b);
    }

    // ---------------------------------------------------------------------------------------
    // Selection / extrema / masking
    // ---------------------------------------------------------------------------------------

    /**
     * Delegates to {@code a.getNDArrayInternal().where(condition, b)}.
     *
     * <p>Presumably this picks values from {@code a} where {@code condition} holds and from
     * {@code b} elsewhere; the exact semantics are defined by the internal {@code where}
     * implementation — confirm there.
     *
     * @param condition the selector array
     * @param a the receiver of the internal call
     * @param b the alternative array
     * @return the array produced by the internal {@code where}
     */
    public static NDArray where(NDArray condition, NDArray a, NDArray b) {
        return a.getNDArrayInternal().where(condition, b);
    }

    /** Returns {@code a.maximum(n)}. */
    public static NDArray maximum(NDArray a, Number n) {
        return a.maximum(n);
    }

    /** Returns {@code a.maximum(n)}; operand order does not matter for a maximum. */
    public static NDArray maximum(Number n, NDArray a) {
        return a.maximum(n);
    }

    /** Returns {@code a.maximum(b)}. */
    public static NDArray maximum(NDArray a, NDArray b) {
        return a.maximum(b);
    }

    /** Returns {@code a.minimum(n)}. */
    public static NDArray minimum(NDArray a, Number n) {
        return a.minimum(n);
    }

    /** Returns {@code a.minimum(n)}; operand order does not matter for a minimum. */
    public static NDArray minimum(Number n, NDArray a) {
        return a.minimum(n);
    }

    /** Returns {@code a.minimum(b)}. */
    public static NDArray minimum(NDArray a, NDArray b) {
        return a.minimum(b);
    }

    /**
     * Applies a boolean mask along axis 0.
     *
     * @param data the array to mask
     * @param index the mask array
     * @return {@code data.booleanMask(index, 0)}
     */
    public static NDArray booleanMask(NDArray data, NDArray index) {
        return booleanMask(data, index, 0);
    }

    /**
     * Returns {@code data.booleanMask(index, axis)}.
     *
     * @param data the array to mask
     * @param index the mask array
     * @param axis the axis passed through to {@link NDArray#booleanMask}
     * @return the masked array
     */
    public static NDArray booleanMask(NDArray data, NDArray index, int axis) {
        return data.booleanMask(index, axis);
    }

    /**
     * Returns {@code data.sequenceMask(sequenceLength, value)}.
     *
     * @param data the array to mask
     * @param sequenceLength the per-sequence lengths, passed through unchanged
     * @param value the fill value, passed through unchanged
     * @return the masked array
     */
    public static NDArray sequenceMask(NDArray data, NDArray sequenceLength, float value) {
        return data.sequenceMask(sequenceLength, value);
    }

    /**
     * Returns {@code data.sequenceMask(sequenceLength)}, using {@link NDArray}'s default fill
     * value.
     *
     * @param data the array to mask
     * @param sequenceLength the per-sequence lengths, passed through unchanged
     * @return the masked array
     */
    public static NDArray sequenceMask(NDArray data, NDArray sequenceLength) {
        return data.sequenceMask(sequenceLength);
    }

    // ---------------------------------------------------------------------------------------
    // Arithmetic (new result)
    // ---------------------------------------------------------------------------------------

    /** Returns {@code a.add(n)}. */
    public static NDArray add(NDArray a, Number n) {
        return a.add(n);
    }

    /** Evaluates {@code n + a}; addition is commutative, so this is {@code a.add(n)}. */
    public static NDArray add(Number n, NDArray a) {
        return a.add(n);
    }

    /**
     * Adds two or more arrays.
     *
     * <p>Two operands are combined directly with {@code arrays[0].add(arrays[1])}. Three or
     * more operands, which must share one shape, are handled differently: they are stacked
     * along axis 0 and summed over that axis. The temporary stacked array is closed before
     * returning.
     *
     * @param arrays the operands; at least two
     * @return the sum of all operands
     * @throws IllegalArgumentException if fewer than two operands are given or, for three or
     *     more, their shapes differ
     */
    public static NDArray add(NDArray... arrays) {
        checkInputs(arrays);
        if (arrays.length == 2) {
            return arrays[0].add(arrays[1]);
        }
        try (NDArray stacked = stack(new NDList(arrays))) {
            return stacked.sum(new int[] {0});
        }
    }

    /** Returns {@code a.sub(n)}. */
    public static NDArray sub(NDArray a, Number n) {
        return a.sub(n);
    }

    /** Evaluates {@code n - a} via the internal reversed subtraction {@code rsub(n)}. */
    public static NDArray sub(Number n, NDArray a) {
        return a.getNDArrayInternal().rsub(n);
    }

    /** Returns {@code a.sub(b)}. */
    public static NDArray sub(NDArray a, NDArray b) {
        return a.sub(b);
    }

    /** Returns {@code a.mul(n)}. */
    public static NDArray mul(NDArray a, Number n) {
        return a.mul(n);
    }

    /** Evaluates {@code n * a}; multiplication is commutative, so this is {@code a.mul(n)}. */
    public static NDArray mul(Number n, NDArray a) {
        return a.mul(n);
    }

    /**
     * Multiplies two or more arrays.
     *
     * <p>Two operands are combined directly with {@code arrays[0].mul(arrays[1])}. Three or
     * more operands, which must share one shape, are handled differently: they are stacked
     * along axis 0 and multiplied together with {@code prod} over that axis. The temporary
     * stacked array is closed before returning.
     *
     * @param arrays the operands; at least two
     * @return the product of all operands
     * @throws IllegalArgumentException if fewer than two operands are given or, for three or
     *     more, their shapes differ
     */
    public static NDArray mul(NDArray... arrays) {
        checkInputs(arrays);
        if (arrays.length == 2) {
            return arrays[0].mul(arrays[1]);
        }
        try (NDArray stacked = stack(new NDList(arrays))) {
            return stacked.prod(new int[] {0});
        }
    }

    /** Returns {@code a.div(n)}. */
    public static NDArray div(NDArray a, Number n) {
        return a.div(n);
    }

    /** Evaluates {@code n / a} via the internal reversed division {@code rdiv(n)}. */
    public static NDArray div(Number n, NDArray a) {
        return a.getNDArrayInternal().rdiv(n);
    }

    /** Returns {@code a.div(b)}. */
    public static NDArray div(NDArray a, NDArray b) {
        return a.div(b);
    }

    /** Returns {@code a.mod(n)}. */
    public static NDArray mod(NDArray a, Number n) {
        return a.mod(n);
    }

    /** Evaluates {@code n % a} via the internal reversed modulo {@code rmod(n)}. */
    public static NDArray mod(Number n, NDArray a) {
        return a.getNDArrayInternal().rmod(n);
    }

    /** Returns {@code a.mod(b)}. */
    public static NDArray mod(NDArray a, NDArray b) {
        return a.mod(b);
    }

    /** Returns {@code a.pow(n)}. */
    public static NDArray pow(NDArray a, Number n) {
        return a.pow(n);
    }

    /** Evaluates {@code n ^ a} via the internal reversed power {@code rpow(n)}. */
    public static NDArray pow(Number n, NDArray a) {
        return a.getNDArrayInternal().rpow(n);
    }

    /** Returns {@code a.pow(b)}. */
    public static NDArray pow(NDArray a, NDArray b) {
        return a.pow(b);
    }

    // ---------------------------------------------------------------------------------------
    // Arithmetic (delegating to the *i variants)
    // ---------------------------------------------------------------------------------------

    /** Returns {@code a.addi(n)}. */
    public static NDArray addi(NDArray a, Number n) {
        return a.addi(n);
    }

    /** Returns {@code a.addi(n)}; the array operand is always the receiver. */
    public static NDArray addi(Number n, NDArray a) {
        return a.addi(n);
    }

    /**
     * Adds every operand after the first into {@code arrays[0]}, in order.
     *
     * <p>The return values of the individual {@code addi} calls are ignored. The result is
     * {@code arrays[0]} itself, so this relies on {@code addi} updating its receiver.
     *
     * @param arrays the operands; at least two, with {@code arrays[0]} as the accumulator
     * @return {@code arrays[0]}
     * @throws IllegalArgumentException if fewer than two operands are given or, for three or
     *     more, their shapes differ
     */
    public static NDArray addi(NDArray... arrays) {
        checkInputs(arrays);
        NDArray target = arrays[0];
        for (int i = 1; i < arrays.length; i++) {
            target.addi(arrays[i]);
        }
        return target;
    }

    /** Returns {@code a.subi(n)}. */
    public static NDArray subi(NDArray a, Number n) {
        return a.subi(n);
    }

    /** Evaluates {@code n - a} via the internal reversed subtraction {@code rsubi(n)}. */
    public static NDArray subi(Number n, NDArray a) {
        return a.getNDArrayInternal().rsubi(n);
    }

    /** Returns {@code a.subi(b)}. */
    public static NDArray subi(NDArray a, NDArray b) {
        return a.subi(b);
    }

    /** Returns {@code a.muli(n)}. */
    public static NDArray muli(NDArray a, Number n) {
        return a.muli(n);
    }

    /** Returns {@code a.muli(n)}; the array operand is always the receiver. */
    public static NDArray muli(Number n, NDArray a) {
        return a.muli(n);
    }

    /**
     * Multiplies every operand after the first into {@code arrays[0]}, in order.
     *
     * <p>The return values of the individual {@code muli} calls are ignored. The result is
     * {@code arrays[0]} itself, so this relies on {@code muli} updating its receiver.
     *
     * @param arrays the operands; at least two, with {@code arrays[0]} as the accumulator
     * @return {@code arrays[0]}
     * @throws IllegalArgumentException if fewer than two operands are given or, for three or
     *     more, their shapes differ
     */
    public static NDArray muli(NDArray... arrays) {
        checkInputs(arrays);
        NDArray target = arrays[0];
        for (int i = 1; i < arrays.length; i++) {
            target.muli(arrays[i]);
        }
        return target;
    }

    /** Returns {@code a.divi(n)}. */
    public static NDArray divi(NDArray a, Number n) {
        return a.divi(n);
    }

    /** Evaluates {@code n / a} via the internal reversed division {@code rdivi(n)}. */
    public static NDArray divi(Number n, NDArray a) {
        return a.getNDArrayInternal().rdivi(n);
    }

    /** Returns {@code a.divi(b)}. */
    public static NDArray divi(NDArray a, NDArray b) {
        return a.divi(b);
    }

    /** Returns {@code a.modi(n)}. */
    public static NDArray modi(NDArray a, Number n) {
        return a.modi(n);
    }

    /** Evaluates {@code n % a} via the internal reversed modulo {@code rmodi(n)}. */
    public static NDArray modi(Number n, NDArray a) {
        return a.getNDArrayInternal().rmodi(n);
    }

    /** Returns {@code a.modi(b)}. */
    public static NDArray modi(NDArray a, NDArray b) {
        return a.modi(b);
    }

    /** Returns {@code a.powi(n)}. */
    public static NDArray powi(NDArray a, Number n) {
        return a.powi(n);
    }

    /** Evaluates {@code n ^ a} via the internal reversed power {@code rpowi(n)}. */
    public static NDArray powi(Number n, NDArray a) {
        return a.getNDArrayInternal().rpowi(n);
    }

    /** Returns {@code a.powi(b)}. */
    public static NDArray powi(NDArray a, NDArray b) {
        return a.powi(b);
    }

    // ---------------------------------------------------------------------------------------
    // Products, stacking, concatenation
    // ---------------------------------------------------------------------------------------

    /** Returns {@code a.dot(b)}. */
    public static NDArray dot(NDArray a, NDArray b) {
        return a.dot(b);
    }

    /** Returns {@code a.matMul(b)}. */
    public static NDArray matMul(NDArray a, NDArray b) {
        return a.matMul(b);
    }

    /**
     * Stacks the arrays along axis 0.
     *
     * @param arrays the arrays to stack; must be non-empty
     * @return the stacked array
     * @throws IllegalArgumentException if {@code arrays} is empty (via {@code Preconditions})
     */
    public static NDArray stack(NDList arrays) {
        return stack(arrays, 0);
    }

    /**
     * Stacks the arrays along {@code axis}.
     *
     * <p>The call is delegated to the first array's internal view with the remaining arrays
     * as its argument, so a single-element list stacks against an empty sub-list.
     *
     * @param arrays the arrays to stack; must be non-empty
     * @param axis the axis passed through to the internal {@code stack}
     * @return the stacked array
     * @throws IllegalArgumentException if {@code arrays} is empty (via {@code Preconditions})
     */
    public static NDArray stack(NDList arrays, int axis) {
        Preconditions.checkArgument(arrays.size() > 0, "need at least one array to stack");
        NDArray first = arrays.head();
        return first.getNDArrayInternal().stack(arrays.subNDList(1), axis);
    }

    /**
     * Concatenates the arrays along axis 0.
     *
     * @param arrays the arrays to concatenate; must be non-empty
     * @return the concatenated array
     * @throws IllegalArgumentException if {@code arrays} is empty (via {@code Preconditions})
     */
    public static NDArray concat(NDList arrays) {
        return concat(arrays, 0);
    }

    /**
     * Concatenates the arrays along {@code axis}.
     *
     * <p>A single-element list short-circuits to {@code duplicate()} of that element, so the
     * caller never receives the input instance itself.
     *
     * @param arrays the arrays to concatenate; must be non-empty
     * @param axis the axis passed through to the internal {@code concat}
     * @return the concatenated array, or a duplicate when only one array is given
     * @throws IllegalArgumentException if {@code arrays} is empty (via {@code Preconditions})
     */
    public static NDArray concat(NDList arrays, int axis) {
        Preconditions.checkArgument(arrays.size() > 0, "need at least one array to concatenate");
        if (arrays.size() == 1) {
            return arrays.singletonOrThrow().duplicate();
        }
        NDArray first = arrays.head();
        return first.getNDArrayInternal().concat(arrays.subNDList(1), axis);
    }

    // ---------------------------------------------------------------------------------------
    // Logical and special functions
    // ---------------------------------------------------------------------------------------

    /** Returns {@code a.logicalAnd(b)}. */
    public static NDArray logicalAnd(NDArray a, NDArray b) {
        return a.logicalAnd(b);
    }

    /** Returns {@code a.logicalOr(b)}. */
    public static NDArray logicalOr(NDArray a, NDArray b) {
        return a.logicalOr(b);
    }

    /** Returns {@code a.logicalXor(b)}. */
    public static NDArray logicalXor(NDArray a, NDArray b) {
        return a.logicalXor(b);
    }

    /** Returns {@code input.erfinv()}. */
    public static NDArray erfinv(NDArray input) {
        return input.erfinv();
    }

    /** Returns {@code input.erf()}. */
    public static NDArray erf(NDArray input) {
        return input.erf();
    }
}

