/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.functions.sql.ml;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlCharStringLiteral;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlModelCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlTableFunction;
import org.apache.calcite.sql.type.MapSqlType;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.validate.SqlNameMatcher;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.util.NlsString;
import org.apache.calcite.util.Util;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.api.config.MLPredictRuntimeConfigOptions;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeFamily;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
import org.apache.flink.types.Either;

/**
 * Base class for table functions that apply a model to exactly one input table.
 *
 * <p>The function's return type is {@link ReturnTypes#CURSOR}; the produced row type is supplied
 * by subclasses through {@link #inferRowType(SqlOperatorBinding)}. The static helpers validate the
 * model operand, the input-column descriptor and the optional runtime config map so that subclasses
 * can reuse them from their operand metadata.
 */
public abstract class SqlMLTableFunction extends SqlFunction implements SqlTableFunction {
    private static final String TABLE_INPUT_ERROR =
            "SqlMLTableFunction must have only one table as first operand.";

    /** Name of the input-table parameter. */
    protected static final String PARAM_INPUT = "INPUT";
    /** Name of the model parameter. */
    protected static final String PARAM_MODEL = "MODEL";
    /** Name of the input-column descriptor parameter. */
    protected static final String PARAM_COLUMN = "ARGS";
    /** Name of the runtime-config map parameter. */
    protected static final String PARAM_CONFIG = "CONFIG";

    /**
     * Creates a system function of kind {@link SqlKind#OTHER_FUNCTION} returning a cursor.
     *
     * @param name the SQL name of the function
     * @param operandMetadata operand names/types and operand type checking for the function
     */
    public SqlMLTableFunction(String name, SqlOperandMetadata operandMetadata) {
        super(
                name,
                SqlKind.OTHER_FUNCTION,
                ReturnTypes.CURSOR,
                null,
                operandMetadata,
                SqlFunctionCategory.SYSTEM);
    }

    /**
     * Validates every non-descriptor operand and enforces that at most one table operand (either a
     * plain query or a set-semantics table) is passed.
     *
     * <p>Descriptor operands are skipped here; presumably their columns are checked against the
     * input table by subclasses (e.g. via {@link #checkModelSignature}) — confirm in subclasses.
     *
     * @throws ValidationException if more than one table operand is present
     */
    @Override
    public void validateCall(
            SqlCall call,
            SqlValidator validator,
            SqlValidatorScope scope,
            SqlValidatorScope operandScope) {
        assert call.getOperator() == this;
        boolean foundTable = false;
        for (SqlNode operand : call.getOperandList()) {
            if (operand.getKind() == SqlKind.DESCRIPTOR) {
                continue;
            }
            SqlNode toValidate = operand;
            boolean isTable = false;
            if (operand.getKind() == SqlKind.SET_SEMANTICS_TABLE) {
                // Validate the wrapped query, but count the wrapper only once: the unwrapped node
                // is itself a SELECT and must not be counted as a second table.
                toValidate = ((SqlCall) operand).getOperandList().get(0);
                isTable = true;
            } else if (operand.getKind() == SqlKind.SELECT) {
                isTable = true;
            }
            if (isTable) {
                if (foundTable) {
                    throw new ValidationException(TABLE_INPUT_ERROR);
                }
                foundTable = true;
            }
            toValidate.validate(validator, scope);
        }
    }

    @Override
    public SqlReturnTypeInference getRowTypeInference() {
        return this::inferRowType;
    }

    /**
     * Infers the row type produced by this table function.
     *
     * @param var1 binding of the call whose output row type is inferred
     * @return the output row type
     */
    protected abstract RelDataType inferRowType(SqlOperatorBinding var1);

    /**
     * Checks that operand 1 is a model and that the input descriptor matches the model's input
     * schema in arity and (implicitly castable) column types.
     *
     * @param callBinding binding of the call; operand 0 is the input table, operand 1 the model
     * @param inputDescriptorIndex index of the DESCRIPTOR operand listing the input columns
     * @return the validation error, or empty if the signature is valid
     */
    protected static Optional<RuntimeException> checkModelSignature(
            SqlCallBinding callBinding, int inputDescriptorIndex) {
        final SqlValidator validator = callBinding.getValidator();
        if (!(callBinding.operand(1) instanceof SqlModelCall)) {
            return Optional.of(
                    new ValidationException("Second operand must be a model identifier."));
        }

        final SqlNode descriptorNode = callBinding.operand(inputDescriptorIndex);
        if (!(descriptorNode instanceof SqlCall)
                || descriptorNode.getKind() != SqlKind.DESCRIPTOR) {
            return Optional.of(
                    new ValidationException(
                            String.format(
                                    "Operand at position %d must be a DESCRIPTOR of input columns.",
                                    inputDescriptorIndex)));
        }
        final List<SqlNode> descriptorCols = ((SqlCall) descriptorNode).getOperandList();

        final SqlModelCall modelCall = (SqlModelCall) callBinding.operand(1);
        final RelDataType modelInputType = modelCall.getInputType(validator);
        if (descriptorCols.size() != modelInputType.getFieldCount()) {
            return Optional.of(
                    new ValidationException(
                            String.format(
                                    "Number of input descriptor columns (%d) does not match model input size (%d).",
                                    descriptorCols.size(), modelInputType.getFieldCount())));
        }

        final RelDataType tableType = validator.getValidatedNodeType(callBinding.operand(0));
        final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher();
        for (int i = 0; i < descriptorCols.size(); ++i) {
            final Tuple3<Boolean, LogicalType, LogicalType> result;
            try {
                result =
                        checkModelDescriptorType(
                                tableType,
                                modelInputType.getFieldList().get(i).getType(),
                                descriptorCols.get(i),
                                matcher);
            } catch (ValidationException e) {
                // Report unresolvable descriptor columns through the Optional contract instead
                // of escaping the type checker.
                return Optional.of(e);
            }
            if (!result.f0) {
                return Optional.of(
                        new ValidationException(
                                String.format(
                                        "Input descriptor column type %s cannot be assigned to model input type %s at position %d.",
                                        result.f1, result.f2, i)));
            }
        }
        return Optional.empty();
    }

    /**
     * Resolves one descriptor column in the input table and checks whether its type can be
     * implicitly cast to the corresponding model input type.
     *
     * @param tableType row type of the input table
     * @param modelType type of the model input field at the same position
     * @param descriptorNode the descriptor entry; expected to be a column identifier
     * @param matcher name matcher used to look up the column (honors the catalog's case rules)
     * @return (castable?, source column type, target model type)
     * @throws ValidationException if the entry is not an identifier or the column does not exist
     */
    protected static Tuple3<Boolean, LogicalType, LogicalType> checkModelDescriptorType(
            RelDataType tableType,
            RelDataType modelType,
            SqlNode descriptorNode,
            SqlNameMatcher matcher) {
        if (!(descriptorNode instanceof SqlIdentifier)) {
            throw new ValidationException(
                    String.format(
                            "Descriptor column must be an identifier but was %s.",
                            descriptorNode));
        }
        final SqlIdentifier columnName = (SqlIdentifier) descriptorNode;
        // For qualified names (e.g. t.col) only the last component names the column.
        final String descriptorColName = Util.last(columnName.names);
        final int index = matcher.indexOf(tableType.getFieldNames(), descriptorColName);
        if (index < 0) {
            throw new ValidationException(
                    String.format(
                            "Descriptor column '%s' not found in input table columns %s.",
                            descriptorColName, tableType.getFieldNames()));
        }
        final RelDataType sourceType = tableType.getFieldList().get(index).getType();
        final LogicalType sourceLogicalType = FlinkTypeFactory.toLogicalType(sourceType);
        final LogicalType targetLogicalType = FlinkTypeFactory.toLogicalType(modelType);
        return Tuple3.of(
                LogicalTypeCasts.supportsImplicitCast(sourceLogicalType, targetLogicalType),
                sourceLogicalType,
                targetLogicalType);
    }

    /**
     * Checks that the config operand is a MAP constructor of string keys/values whose entries are
     * string literals (optionally CAST to a string type), and that the resulting runtime options
     * parse and are consistent.
     *
     * @param callBinding binding of the call, used to obtain the validator
     * @param configNode the config operand
     * @return the validation error, or empty if the config is valid
     */
    protected static Optional<RuntimeException> checkConfig(
            SqlCallBinding callBinding, SqlNode configNode) {
        if (configNode.getKind() != SqlKind.MAP_VALUE_CONSTRUCTOR) {
            return Optional.of(new ValidationException("Config param should be a MAP."));
        }
        final SqlValidator validator = callBinding.getValidator();
        final RelDataType mapType = validator.getValidatedNodeType(configNode);
        assert mapType instanceof MapSqlType;

        final LogicalType keyType = FlinkTypeFactory.toLogicalType(mapType.getKeyType());
        final LogicalType valueType = FlinkTypeFactory.toLogicalType(mapType.getValueType());
        if (!keyType.is(LogicalTypeFamily.CHARACTER_STRING)
                || !valueType.is(LogicalTypeFamily.CHARACTER_STRING)) {
            return Optional.of(
                    new ValidationException(
                            String.format(
                                    "Config param can only be a MAP of string literals but node's type is %s at position %s.",
                                    mapType, configNode.getParserPosition())));
        }

        // MAP[k1, v1, k2, v2, ...]: operands alternate between keys and values.
        final List<SqlNode> operands = ((SqlCall) configNode).getOperandList();
        final Map<String, String> runtimeConfig = new HashMap<>();
        for (int i = 0; i < operands.size(); i += 2) {
            final Either<String, RuntimeException> key = reduceLiteral(operands.get(i), validator);
            if (key.isRight()) {
                return Optional.of(key.right());
            }
            final Either<String, RuntimeException> value =
                    reduceLiteral(operands.get(i + 1), validator);
            if (value.isRight()) {
                return Optional.of(value.right());
            }
            runtimeConfig.put(key.left(), value.left());
        }
        return checkConfigValue(runtimeConfig);
    }

    /**
     * Parses every supported runtime option from the given map and checks that, when ASYNC is
     * enabled, ASYNC_MAX_CONCURRENT_OPERATIONS (if set) is positive.
     *
     * @param runtimeConfig raw key/value pairs taken from the config MAP literal
     * @return the validation error, or empty if the options are valid
     */
    private static Optional<RuntimeException> checkConfigValue(Map<String, String> runtimeConfig) {
        final Configuration config = Configuration.fromMap(runtimeConfig);
        try {
            // Reading each option forces parsing, surfacing malformed values early.
            MLPredictRuntimeConfigOptions.getSupportedOptions().forEach(option -> config.get(option));
        } catch (Exception e) {
            return Optional.of(new ValidationException("Failed to parse the config.", e));
        }

        final Boolean async = config.get(MLPredictRuntimeConfigOptions.ASYNC);
        if (Boolean.TRUE.equals(async)) {
            final Integer maxConcurrentOperations =
                    config.get(MLPredictRuntimeConfigOptions.ASYNC_MAX_CONCURRENT_OPERATIONS);
            if (maxConcurrentOperations != null && maxConcurrentOperations <= 0) {
                return Optional.of(
                        new ValidationException(
                                String.format(
                                        "Invalid runtime config option '%s'. Its value should be positive integer but was %s.",
                                        MLPredictRuntimeConfigOptions.ASYNC_MAX_CONCURRENT_OPERATIONS
                                                .key(),
                                        maxConcurrentOperations)));
            }
        }
        return Optional.empty();
    }

    /**
     * Reduces a config map entry to its string value.
     *
     * <p>Accepts a character-string literal, or a CAST of such a literal (recursively) to a
     * character-string type.
     *
     * @param operand the map key or value node
     * @param validator validator used to derive the CAST target type
     * @return Left(string value) on success, Right(validation error) otherwise
     */
    private static Either<String, RuntimeException> reduceLiteral(
            SqlNode operand, SqlValidator validator) {
        if (operand instanceof SqlCharStringLiteral) {
            return Either.Left(
                    ((SqlCharStringLiteral) operand).getValueAs(NlsString.class).getValue());
        }
        if (operand.getKind() == SqlKind.CAST) {
            final SqlCall call = (SqlCall) operand;
            final SqlDataTypeSpec dataType = (SqlDataTypeSpec) call.operand(1);
            if (!FlinkTypeFactory.toLogicalType(dataType.deriveType(validator))
                    .is(LogicalTypeFamily.CHARACTER_STRING)) {
                return Either.Right(
                        new ValidationException("Don't support to cast value to non-string type."));
            }
            return reduceLiteral(call.operand(0), validator);
        }
        return Either.Right(
                new ValidationException(
                        String.format(
                                "Unsupported expression %s is in runtime config at position %s. Currently, runtime config should be a MAP of string literals.",
                                operand, operand.getParserPosition())));
    }
}

