/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.utils;

import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.mapping.IntPair;
import org.apache.flink.shaded.guava33.com.google.common.collect.Sets;
import org.apache.flink.table.catalog.Index;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.connector.ChangelogMode;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.LookupTableSource;
import org.apache.flink.table.functions.AsyncTableFunction;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.planner.plan.nodes.exec.spec.DeltaJoinSpec;
import org.apache.flink.table.planner.plan.nodes.exec.spec.TemporalTableSourceSpec;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalExchange;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalIntermediateTableScan;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalJoin;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalRel;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalTableSourceScan;
import org.apache.flink.table.planner.plan.schema.IntermediateRelTable;
import org.apache.flink.table.planner.plan.schema.TableSourceTable;
import org.apache.flink.table.planner.plan.trait.DuplicateChanges;
import org.apache.flink.table.planner.plan.utils.ChangelogPlanUtils;
import org.apache.flink.table.planner.plan.utils.DuplicateChangesUtils;
import org.apache.flink.table.planner.plan.utils.FunctionCallUtils;
import org.apache.flink.table.planner.plan.utils.JoinTypeUtil;
import org.apache.flink.table.planner.plan.utils.LookupJoinUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.functions.table.lookup.CachingAsyncLookupFunction;
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.runtime.operators.join.lookup.RetryableAsyncLookupFunctionDelegator;
import org.apache.flink.types.RowKind;
import org.apache.flink.util.Preconditions;

/**
 * Utilities for deciding whether a {@link StreamPhysicalJoin} can be converted into a delta join,
 * and for building the {@link DeltaJoinSpec} that describes one lookup direction of such a join.
 */
public final class DeltaJoinUtil {

    /**
     * Node classes that may appear between a delta join and its source scans. Membership is checked
     * against {@link Object#getClass()}, so only these exact classes (not subclasses) are accepted.
     */
    private static final Set<Class<?>> ALL_SUPPORTED_DELTA_JOIN_UPSTREAM_NODES =
            Sets.<Class<?>>newHashSet(
                    StreamPhysicalTableSourceScan.class, StreamPhysicalExchange.class);

    private DeltaJoinUtil() {}

    /**
     * Returns whether {@code join} passes every check performed here for delta-join conversion.
     *
     * <p>The join must:
     * <ul>
     *   <li>be an inner join;
     *   <li>have at least one equi-join key pair;
     *   <li>have duplicate changes derived as {@link DuplicateChanges#ALLOW};
     *   <li>have only insert-only inputs;
     *   <li>have only whitelisted upstream nodes (table source scans and exchanges);
     *   <li>read from table scans that pass {@link #isTableScanSupported}.
     * </ul>
     * The checks are evaluated in this order and stop at the first failure.
     *
     * @param join the physical stream join to inspect
     * @return {@code true} if all checks pass
     * @throws IllegalStateException if duplicate changes or changelog mode cannot be derived for a
     *     node that is inspected
     */
    public static boolean canConvertToDeltaJoin(StreamPhysicalJoin join) {
        FlinkJoinType flinkJoinType = JoinTypeUtil.getFlinkJoinType(join.getJoinType());
        return isJoinTypeSupported(flinkJoinType)
                && areJoinConditionsSupported(join)
                && canJoinOutputDuplicateChanges(join)
                && areAllInputsInsertOnly(join)
                && areAllJoinInputsInWhiteList(join)
                && areAllJoinTableScansSupported(join);
    }

    /**
     * Returns the {@link RelOptTable} of the table scan reached from {@code node}.
     *
     * <p>HEP vertices, intermediate-table blocks and exchanges are descended through on the way.
     *
     * @throws IllegalStateException if the node reached is not a {@link TableScan}
     */
    public static RelOptTable getTableScanRelOptTable(RelNode node) {
        return getTableScan(node).getTable();
    }

    /**
     * Builds the {@link DeltaJoinSpec} for looking up one side of {@code join} from the other.
     *
     * <p>The spec contains three parts:
     * <ul>
     *   <li>The lookup keys, keyed by lookup-side field index. Each maps to a
     *       {@link FunctionCallUtils.FieldRef} of the corresponding stream-side field index.
     *   <li>The conjunction of the join's non-equi conditions, or {@code null} when that
     *       conjunction is always true.
     *   <li>The lookup table.
     * </ul>
     *
     * @param join the join to describe
     * @param treatRightAsLookupSide {@code true} to use the right input as the lookup side,
     *     {@code false} to use the left input
     * @return the delta join spec for the chosen direction
     * @throws IllegalStateException if the lookup side's table is not a {@link TableSourceTable}
     */
    public static DeltaJoinSpec getDeltaJoinSpec(
            StreamPhysicalJoin join, boolean treatRightAsLookupSide) {
        JoinInfo joinInfo = join.analyzeCondition();
        RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        RexNode condition = RexUtil.composeConjunction(rexBuilder, joinInfo.nonEquiConditions);
        // A trivially-true residual condition is dropped so the spec carries no condition at all.
        Optional<RexNode> remainingCondition =
                condition.isAlwaysTrue() ? Optional.empty() : Optional.of(condition);

        // JoinInfo#pairs() yields (left, right) key pairs; orient them as (stream, lookup).
        List<IntPair> streamToLookupJoinKeys;
        RelOptTable lookupRelOptTable;
        if (treatRightAsLookupSide) {
            streamToLookupJoinKeys = joinInfo.pairs();
            lookupRelOptTable = getTableScanRelOptTable(join.getRight());
        } else {
            streamToLookupJoinKeys = reverseIntPairs(joinInfo.pairs());
            lookupRelOptTable = getTableScanRelOptTable(join.getLeft());
        }

        Preconditions.checkState(lookupRelOptTable instanceof TableSourceTable);
        TableSourceTable lookupTable = (TableSourceTable) lookupRelOptTable;
        Map<Integer, FunctionCallUtils.FunctionParam> allLookupKeys =
                analyzerDeltaJoinLookupKeys(streamToLookupJoinKeys);
        return new DeltaJoinSpec(
                new TemporalTableSourceSpec(lookupTable),
                allLookupKeys,
                remainingCondition.orElse(null));
    }

    /**
     * Resolves the lookup function of {@code temporalTable} and unwraps the
     * {@link CachingAsyncLookupFunction} and {@link RetryableAsyncLookupFunctionDelegator} layers
     * around it. Unwrapping stops at the first function that is neither wrapper.
     *
     * @param temporalTable the lookup table
     * @param lookupKeys the lookup key field indices
     * @param classLoader class loader passed to {@link LookupJoinUtil#getLookupFunction}
     * @return the innermost function, as an {@link AsyncTableFunction}
     * @throws IllegalStateException if the innermost function is not an {@link AsyncTableFunction}
     */
    public static AsyncTableFunction<?> getUnwrappedAsyncLookupFunction(
            RelOptTable temporalTable, Collection<Integer> lookupKeys, ClassLoader classLoader) {
        UserDefinedFunction lookupFunction =
                LookupJoinUtil.getLookupFunction(
                        temporalTable, lookupKeys, classLoader, true, null, false);
        boolean changed = true;
        while (changed) {
            // Cast the function being unwrapped (not the table) to reach its delegate.
            if (lookupFunction instanceof CachingAsyncLookupFunction) {
                lookupFunction = ((CachingAsyncLookupFunction) lookupFunction).getDelegate();
            } else if (lookupFunction instanceof RetryableAsyncLookupFunctionDelegator) {
                lookupFunction =
                        ((RetryableAsyncLookupFunctionDelegator) lookupFunction)
                                .getUserLookupFunction();
            } else {
                changed = false;
            }
        }
        if (!(lookupFunction instanceof AsyncTableFunction)) {
            throw new IllegalStateException(
                    String.format(
                            "Table [%s] does not support async lookup. If the table supports the option of async lookup joins, add it to the with parameters of the DDL.",
                            String.join(".", temporalTable.getQualifiedName())));
        }
        return (AsyncTableFunction<?>) lookupFunction;
    }

    /** Returns whether {@code flinkJoinType} is supported by delta join; only INNER is. */
    public static boolean isJoinTypeSupported(FlinkJoinType flinkJoinType) {
        return FlinkJoinType.INNER == flinkJoinType;
    }

    /**
     * Maps each lookup-side key index ({@link IntPair#target}) to a field reference to the matching
     * stream-side index ({@link IntPair#source}), preserving the input order.
     */
    private static Map<Integer, FunctionCallUtils.FunctionParam> analyzerDeltaJoinLookupKeys(
            List<IntPair> streamToLookupJoinKeys) {
        Map<Integer, FunctionCallUtils.FunctionParam> allFieldRefLookupKeys = new LinkedHashMap<>();
        for (IntPair intPair : streamToLookupJoinKeys) {
            allFieldRefLookupKeys.put(
                    intPair.target, new FunctionCallUtils.FieldRef(intPair.source));
        }
        return allFieldRefLookupKeys;
    }

    /** Returns a new list in which every pair has its source and target swapped. */
    private static List<IntPair> reverseIntPairs(List<IntPair> intPairs) {
        return intPairs.stream()
                .map(pair -> new IntPair(pair.target, pair.source))
                .collect(Collectors.toList());
    }

    /**
     * Returns, for every index declared on the table's schema, the positions of its columns in the
     * table's row type.
     *
     * <p>A column name that is not in the row type maps to {@code -1}, via
     * {@link List#indexOf}.
     */
    private static int[][] getColumnIndicesOfAllTableIndexes(TableSourceTable tableSourceTable) {
        // Loop-invariant: resolve the field names once for all indexes.
        List<String> fieldNames = tableSourceTable.getRowType().getFieldNames();
        return getAllIndexesColumnsOfTable(tableSourceTable).stream()
                .map(columns -> columns.stream().mapToInt(fieldNames::indexOf).toArray())
                .toArray(int[][]::new);
    }

    /** Returns the column names of every index declared on the table's resolved schema. */
    private static List<List<String>> getAllIndexesColumnsOfTable(
            TableSourceTable tableSourceTable) {
        ResolvedSchema schema = tableSourceTable.contextResolvedTable().getResolvedSchema();
        List<Index> indexes = schema.getIndexes();
        return indexes.stream().map(Index::getColumns).collect(Collectors.toList());
    }

    /** Returns whether the join condition contains at least one equi-join key pair. */
    private static boolean areJoinConditionsSupported(StreamPhysicalJoin join) {
        JoinInfo joinInfo = join.analyzeCondition();
        return !joinInfo.pairs().isEmpty();
    }

    /** Returns whether both input table scans are supported for their side's join keys. */
    private static boolean areAllJoinTableScansSupported(StreamPhysicalJoin join) {
        return isTableScanSupported(getTableScan(join.getLeft()), join.joinSpec().getLeftKeys())
                && isTableScanSupported(
                        getTableScan(join.getRight()), join.joinSpec().getRightKeys());
    }

    /**
     * Returns whether {@code tableScan} can serve as the lookup side for {@code lookupKeys}.
     *
     * <p>The scan must:
     * <ul>
     *   <li>be a {@link StreamPhysicalTableSourceScan} with no ability specs;
     *   <li>have a source that is a {@link LookupTableSource};
     *   <li>declare at least one index;
     *   <li>have the columns of every declared index contained in {@code lookupKeys};
     *   <li>have {@link LookupJoinUtil#isAsyncLookup} accept the table with these keys.
     * </ul>
     */
    private static boolean isTableScanSupported(TableScan tableScan, int[] lookupKeys) {
        if (!(tableScan instanceof StreamPhysicalTableSourceScan)) {
            return false;
        }
        TableSourceTable tableSourceTable =
                ((StreamPhysicalTableSourceScan) tableScan).tableSourceTable();
        if (tableSourceTable.abilitySpecs().length != 0) {
            return false;
        }
        DynamicTableSource source = tableSourceTable.tableSource();
        if (!(source instanceof LookupTableSource)) {
            return false;
        }
        int[][] idxsOfAllIndexes = getColumnIndicesOfAllTableIndexes(tableSourceTable);
        if (idxsOfAllIndexes.length == 0) {
            return false;
        }
        Set<Integer> lookupKeysSet =
                Arrays.stream(lookupKeys).boxed().collect(Collectors.toSet());
        // NOTE(review): this requires EVERY declared index to be covered by the lookup keys, not
        // just one of them. Confirm that this stricter rule is intended.
        for (int[] idxsOfIndex : idxsOfAllIndexes) {
            Preconditions.checkState(idxsOfIndex.length > 0);
            boolean containsIndex = Arrays.stream(idxsOfIndex).allMatch(lookupKeysSet::contains);
            if (!containsIndex) {
                return false;
            }
        }
        return LookupJoinUtil.isAsyncLookup(tableSourceTable, lookupKeysSet, null, false, false);
    }

    /**
     * Returns the {@link TableScan} reached from {@code node}. HEP vertices and intermediate-table
     * blocks are unwrapped first, and exchanges are descended recursively.
     *
     * @throws IllegalStateException if the node reached is not a {@link TableScan}
     */
    private static TableScan getTableScan(RelNode node) {
        RelNode unwrapped = unwrapNode(node, true);
        if (unwrapped instanceof StreamPhysicalExchange) {
            return getTableScan(((StreamPhysicalExchange) unwrapped).getInput());
        }
        Preconditions.checkState(unwrapped instanceof TableScan);
        return (TableScan) unwrapped;
    }

    /**
     * Recursively checks that every transitive input of {@code node} is a whitelisted node class.
     * Inputs are unwrapped (including into intermediate-table blocks) before the check.
     */
    private static boolean areAllJoinInputsInWhiteList(RelNode node) {
        for (RelNode rawInput : node.getInputs()) {
            RelNode input = unwrapNode(rawInput, true);
            if (!isTheNodeInWhiteList(input) || !areAllJoinInputsInWhiteList(input)) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether {@code node}'s exact runtime class is in the upstream whitelist. */
    private static boolean isTheNodeInWhiteList(RelNode node) {
        return ALL_SUPPORTED_DELTA_JOIN_UPSTREAM_NODES.contains(node.getClass());
    }

    /**
     * Returns whether the join's derived duplicate changes are {@link DuplicateChanges#ALLOW}.
     *
     * @throws IllegalStateException if duplicate changes cannot be derived for {@code join}
     */
    private static boolean canJoinOutputDuplicateChanges(StreamPhysicalJoin join) {
        DuplicateChanges duplicateChanges =
                DuplicateChangesUtils.getDuplicateChanges(join)
                        .orElseThrow(
                                () ->
                                        new IllegalStateException(
                                                String.format(
                                                        "Unable to derive duplicate changes from node %s. This is a bug.",
                                                        join)));
        return DuplicateChanges.ALLOW.equals(duplicateChanges);
    }

    /**
     * Returns whether every direct input of {@code join} produces insert-only changes. Inputs are
     * not unwrapped into intermediate-table blocks for this check.
     */
    private static boolean areAllInputsInsertOnly(StreamPhysicalJoin join) {
        for (RelNode input : join.getInputs()) {
            if (!isInsertOnly(unwrapNode(input, false))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether {@code node}'s changelog mode contains only {@link RowKind#INSERT}.
     *
     * @throws IllegalStateException if the changelog mode cannot be derived for {@code node}
     */
    private static boolean isInsertOnly(StreamPhysicalRel node) {
        ChangelogMode changelogMode =
                JavaScalaConversionUtil.toJava(ChangelogPlanUtils.getChangelogMode(node))
                        .orElseThrow(
                                () ->
                                        new IllegalStateException(
                                                String.format(
                                                        "Unable to derive changelog mode from node %s. This is a bug.",
                                                        node)));
        return changelogMode.containsOnly(RowKind.INSERT);
    }

    /**
     * Unwraps {@code node} to a {@link StreamPhysicalRel}.
     *
     * <p>A {@link HepRelVertex} is replaced by its current rel. When
     * {@code transposeToChildBlock} is {@code true}, a
     * {@link StreamPhysicalIntermediateTableScan} is replaced by the rel node of its
     * {@link IntermediateRelTable}.
     *
     * @throws IllegalStateException if an intermediate table scan has no table
     */
    private static StreamPhysicalRel unwrapNode(RelNode node, boolean transposeToChildBlock) {
        RelNode current = node;
        if (current instanceof HepRelVertex) {
            current = ((HepRelVertex) current).getCurrentRel();
        }
        if (transposeToChildBlock && current instanceof StreamPhysicalIntermediateTableScan) {
            IntermediateRelTable inputBlockOptimizedTree =
                    (IntermediateRelTable) current.getTable();
            Preconditions.checkState(inputBlockOptimizedTree != null);
            current = inputBlockOptimizedTree.relNode();
        }
        return (StreamPhysicalRel) current;
    }
}

