/*
 * Decompiled with CFR 0.152.
 */
package cora.termination.dependency_pairs.processors;

import charlie.substitution.Matcher;
import charlie.substitution.MutableSubstitution;
import charlie.terms.Environment;
import charlie.terms.FunctionSymbol;
import charlie.terms.Term;
import charlie.terms.TermFactory;
import charlie.terms.TheoryFactory;
import charlie.terms.Variable;
import charlie.terms.replaceable.MutableRenaming;
import charlie.terms.replaceable.Renaming;
import charlie.terms.replaceable.Replaceable;
import charlie.trs.TRS;
import charlie.util.Pair;
import cora.config.Settings;
import cora.io.OutputModule;
import cora.termination.dependency_pairs.DP;
import cora.termination.dependency_pairs.Problem;
import cora.termination.dependency_pairs.processors.Processor;
import cora.termination.dependency_pairs.processors.ProcessorProofObject;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/**
 * A dependency-pair processor that chains DPs through a single head symbol.
 *
 * <p>For a chosen head symbol f, every DP whose right-hand side has root f (a "dp1") is chained
 * with every DP whose left-hand side has root f (a "dp2"). If every such combination succeeds,
 * all participating DPs are removed from the problem and the chained DPs are added instead. The
 * processor is only applicable to innermost problems.
 */
public class ChainingProcessor
implements Processor {
    /** When false, heads for which some DP is both a dp1 and a dp2 are never chosen. */
    private final boolean _allowSelfChaining;

    /**
     * @param allowSelfChaining whether a head may be chosen even if some DP has that head both
     *     as the root of its left-hand side and as the root of its right-hand side
     */
    public ChainingProcessor(boolean allowSelfChaining) {
        this._allowSelfChaining = allowSelfChaining;
    }

    /** Returns the code under which this processor can be disabled through {@link Settings}. */
    public static String queryDisabledCode() {
        return "chaining";
    }

    /**
     * Returns true for innermost problems, when the processor is not disabled, and when the
     * problem has at least one DP (self-chaining allowed) or at least two DPs (otherwise).
     */
    @Override
    public boolean isApplicable(Problem dpp) {
        if (Settings.isDisabled(ChainingProcessor.queryDisabledCode())) {
            return false;
        }
        if (!dpp.isInnermost()) {
            return false;
        }
        int minimumDPs = this._allowSelfChaining ? 1 : 2;
        return dpp.getDPList().size() >= minimumDPs;
    }

    /**
     * Tries the candidate heads in the order produced by {@link #chooseHeadCandidates}.
     *
     * <p>The first head for which every (dp1, dp2) pair can be chained determines the result.
     * The participating DPs are replaced by the chained ones. If no head works, a proof object
     * without output problem is returned.
     */
    @Override
    public ProcessorProofObject processDPP(Problem dpp) {
        Map<FunctionSymbol, Set<DP>> rootToDP1s = new LinkedHashMap<>();
        Map<FunctionSymbol, Set<DP>> rootToDP2s = new LinkedHashMap<>();
        Set<FunctionSymbol> forbiddenDP2Roots = new LinkedHashSet<>();
        ChainingProcessor.classifyDPs(dpp, rootToDP1s, rootToDP2s, forbiddenDP2Roots);
        SortedSet<Pair<FunctionSymbol, Integer>> headsToTry =
            this.chooseHeadCandidates(rootToDP1s, rootToDP2s, forbiddenDP2Roots);
        for (Pair<FunctionSymbol, Integer> candidate : headsToTry) {
            FunctionSymbol chosenHead = candidate.fst();
            Set<DP> dp1s = rootToDP1s.get(chosenHead);
            Set<DP> dp2s = rootToDP2s.get(chosenHead);
            Optional<Map<DP, Pair<DP, DP>>> chained = ChainingProcessor.chainAllPairs(dp1s, dp2s);
            if (chained.isEmpty()) {
                // some pair could not be chained: this head is unusable, try the next one
                continue;
            }
            Map<DP, Pair<DP, DP>> chainedToOldDPs = chained.get();
            Set<DP> deletedDPs = new LinkedHashSet<>(dp1s);
            deletedDPs.addAll(dp2s);
            Set<DP> dpsResult = new LinkedHashSet<>(dpp.getDPList());
            dpsResult.removeAll(deletedDPs);
            dpsResult.addAll(chainedToOldDPs.keySet());
            Problem result = new Problem(new ArrayList<DP>(dpsResult), dpp.getRuleList(), null,
                dpp.getOriginalTRS(), dpp.isInnermost(), dpp.hasExtraRules(),
                dpp.queryTerminationStatus());
            return new ChainingProof(dpp, result, chainedToOldDPs, deletedDPs);
        }
        return new ChainingProof(dpp);
    }

    /**
     * Chains every DP in {@code dp1s} with every DP in {@code dp2s}.
     *
     * @return a map from each chained DP to the (dp1, dp2) pair it was built from, in iteration
     *     order; or empty if at least one pair cannot be chained
     */
    private static Optional<Map<DP, Pair<DP, DP>>> chainAllPairs(Set<DP> dp1s, Set<DP> dp2s) {
        Map<DP, Pair<DP, DP>> chainedToOldDPs = new LinkedHashMap<>();
        for (DP dp1 : dp1s) {
            for (DP dp2 : dp2s) {
                Optional<DP> dpChained = ChainingProcessor.chainDPs(dp1, dp2);
                if (dpChained.isEmpty()) {
                    return Optional.empty();
                }
                chainedToOldDPs.put(dpChained.get(), new Pair<DP, DP>(dp1, dp2));
            }
        }
        return Optional.of(chainedToOldDPs);
    }

    /**
     * Groups the DPs of {@code dpp} by head symbol and records the forbidden heads.
     *
     * <ul>
     *   <li>{@code rootToDP1s} maps a head f to the DPs whose right-hand side has root f.
     *   <li>{@code rootToDP2s} maps f to the DPs whose left-hand side has root f.
     *   <li>A right-hand-side root is added to {@code forbiddenDP2Roots} when that right-hand
     *       side contains a defined symbol of the original TRS.
     * </ul>
     *
     * <p>Both maps receive an (initially empty) entry for every head in {@code dpp.getHeads()}.
     */
    private static void classifyDPs(Problem dpp, Map<FunctionSymbol, Set<DP>> rootToDP1s, Map<FunctionSymbol, Set<DP>> rootToDP2s, Set<FunctionSymbol> forbiddenDP2Roots) {
        for (FunctionSymbol fSharp : dpp.getHeads()) {
            rootToDP1s.put(fSharp, new LinkedHashSet<DP>());
            rootToDP2s.put(fSharp, new LinkedHashSet<DP>());
        }
        TRS trs = dpp.getOriginalTRS();
        Set<FunctionSymbol> definedSymbols = trs.definedSymbols();
        List<DP> allDPs = dpp.getDPList();
        for (DP dp : allDPs) {
            Term lhs = dp.lhs();
            Term rhs = dp.rhs();
            FunctionSymbol lhsRoot = lhs.queryRoot();
            FunctionSymbol rhsRoot = rhs.queryRoot();
            Set<FunctionSymbol> rhsSymbols = new LinkedHashSet<>();
            rhs.storeFunctionSymbols(rhsSymbols);
            if (!Collections.disjoint(definedSymbols, rhsSymbols)) {
                forbiddenDP2Roots.add(rhsRoot);
            }
            rootToDP1s.get(rhsRoot).add(dp);
            rootToDP2s.get(lhsRoot).add(dp);
        }
    }

    /**
     * Collects the heads that may be chosen, paired with the growth in DP count they would cause.
     *
     * <p>A head is skipped when it is forbidden, or when it has no dp1s or no dp2s. It is also
     * skipped when self-chaining is disallowed and some DP is both a dp1 and a dp2 for it.
     *
     * <p>The growth is {@code |dp1s| * |dp2s| - (|dp1s| + |dp2s|)}. The result is sorted by
     * ascending growth, then by head symbol.
     */
    private SortedSet<Pair<FunctionSymbol, Integer>> chooseHeadCandidates(Map<FunctionSymbol, Set<DP>> rootToDP1s, Map<FunctionSymbol, Set<DP>> rootToDP2s, Set<FunctionSymbol> forbiddenDP2Roots) {
        // Integer.compare instead of subtraction: subtraction can overflow and flip the order
        Comparator<Pair<FunctionSymbol, Integer>> cmp = (p1, p2) -> {
            int bySize = Integer.compare(p1.snd(), p2.snd());
            return bySize != 0 ? bySize : p1.fst().compareTo(p2.fst());
        };
        TreeSet<Pair<FunctionSymbol, Integer>> result = new TreeSet<>(cmp);
        for (Map.Entry<FunctionSymbol, Set<DP>> rootToDP1Entry : rootToDP1s.entrySet()) {
            FunctionSymbol root = rootToDP1Entry.getKey();
            if (forbiddenDP2Roots.contains(root)) {
                continue;
            }
            Set<DP> dp1s = rootToDP1Entry.getValue();
            Set<DP> dp2s = rootToDP2s.get(root);
            if (!this._allowSelfChaining && !Collections.disjoint(dp1s, dp2s)) {
                continue;
            }
            int size1 = dp1s.size();
            int size2 = dp2s.size();
            int sizeNew = size1 * size2;
            if (sizeNew == 0) {
                // nothing to chain with at this head
                continue;
            }
            int sizeOld = size1 + size2;
            result.add(new Pair<FunctionSymbol, Integer>(root, sizeNew - sizeOld));
        }
        return result;
    }

    /**
     * Attempts to chain {@code dp1} with {@code dp2} by matching dp2's left-hand side against
     * dp1's right-hand side.
     *
     * <p>dp2 is first replaced by {@code dp2.getRenamed()} (presumably a variable-renamed copy,
     * to keep its variables apart from dp1's — confirm). Chaining fails if no match exists, or
     * if a theory variable of dp2 in the matcher's domain is mapped to a non-theory term. On
     * success the result has:
     *
     * <ul>
     *   <li>dp1's left-hand side;
     *   <li>the instantiated right-hand side of dp2;
     *   <li>the conjunction of dp1's constraint with dp2's instantiated constraint;
     *   <li>as theory variables, dp1's theory variables plus all variables in the images of
     *       dp2's theory variables.
     * </ul>
     */
    private static Optional<DP> chainDPs(DP dp1, DP dp2) {
        dp2 = dp2.getRenamed();
        Term dp1Rhs = dp1.rhs();
        Term dp2Lhs = dp2.lhs();
        MutableSubstitution matcher = Matcher.match(dp2Lhs, dp1Rhs);
        if (matcher == null) {
            return Optional.empty();
        }
        Set<Replaceable> replaceables = new LinkedHashSet<Replaceable>(matcher.domain());
        replaceables.retainAll(dp2.lvars());
        for (Replaceable var : replaceables) {
            if (!matcher.get(var).isTheoryTerm()) {
                return Optional.empty();
            }
        }
        Term resultRhs = matcher.substitute(dp2.rhs());
        Term dp2ConstraintSubst = matcher.substitute(dp2.constraint());
        Term resultConstraint = TermFactory.createApp(TheoryFactory.andSymbol, dp1.constraint(), dp2ConstraintSubst);
        Set<Variable> resultTheoryVars = new LinkedHashSet<Variable>(dp1.lvars());
        for (Variable x : dp2.lvars()) {
            Term xReplacement = matcher.get(x);
            if (xReplacement == null) {
                // NOTE(review): a theory variable of dp2 outside the matcher's domain (e.g. not
                // occurring in dp2's lhs) is left untouched by substitution, so it stays a theory
                // variable; previously this dereferenced null.
                resultTheoryVars.add(x);
                continue;
            }
            Environment<Variable> env = xReplacement.vars();
            for (Variable v : env) {
                resultTheoryVars.add(v);
            }
        }
        DP result = new DP(dp1.lhs(), resultRhs, resultConstraint, resultTheoryVars);
        return Optional.of(result);
    }

    /** Proof object describing a (possibly failed) application of the chaining processor. */
    private static class ChainingProof
    extends ProcessorProofObject {
        /** Each chained DP mapped to the (dp1, dp2) pair it came from; null on failure. */
        private final Map<DP, Pair<DP, DP>> _chainedToOriginalDPs;
        /** The DPs removed from the input problem; null on failure. */
        private final Set<DP> _deletedDPs;

        /** Creates a proof object signalling that no suitable chaining was found. */
        public ChainingProof(Problem input) {
            super(input);
            this._chainedToOriginalDPs = null;
            this._deletedDPs = null;
        }

        /** Creates a proof object for a successful chaining step from input to output. */
        public ChainingProof(Problem input, Problem output, Map<DP, Pair<DP, DP>> chainedToOriginalDPs, Set<DP> deletedDPs) {
            super(input, output);
            this._chainedToOriginalDPs = chainedToOriginalDPs;
            this._deletedDPs = deletedDPs;
        }

        /**
         * Builds one naming for printing three DPs side by side.
         *
         * <p>dp1's variables are named uniquely via the output module. The renaming is then
         * extended with the variables of dp2 and dp3 (see {@link #extend}).
         */
        private Renaming getRenaming(OutputModule module, DP dp1, DP dp2, DP dp3) {
            Set<Variable> allvars = dp1.getAllVariables();
            Term[] vars = new Term[allvars.size()];
            int i = 0;
            for (Variable x : allvars) {
                vars[i++] = x;
            }
            MutableRenaming ret = module.generateUniqueNaming(vars);
            for (Variable x : dp2.getAllVariables()) {
                this.extend(ret, x);
            }
            for (Variable x : dp3.getAllVariables()) {
                this.extend(ret, x);
            }
            return ret;
        }

        /**
         * Gives {@code x} a name in {@code renaming} if it does not have one yet.
         *
         * <p>The name starts from x's own name and gets primes appended until it is available.
         */
        private void extend(MutableRenaming renaming, Variable x) {
            if (renaming.getName(x) != null) {
                return;
            }
            String name = x.queryName();
            while (!renaming.isAvailable(name)) {
                name = name + "'";
            }
            renaming.setName(x, name);
        }

        @Override
        public void justify(OutputModule module) {
            if (this._output == null) {
                module.println("No suitable chaining could be found.", new Object[0]);
                return;
            }
            module.println("We chain DPs according to the following mapping:", new Object[0]);
            module.println();
            module.startTable();
            this._chainedToOriginalDPs.forEach((chainedDP, origins) -> {
                // one renaming per row, so shared variables print identically in all three DPs
                Renaming renaming = this.getRenaming(module, chainedDP, origins.fst(), origins.snd());
                module.nextColumn("%a", new Pair<DP, Renaming>(chainedDP, renaming));
                module.nextColumn(" is obtained by chaining ", new Object[0]);
                module.nextColumn("%a", new Pair<DP, Renaming>(origins.fst(), renaming));
                module.nextColumn("and", new Object[0]);
                module.nextColumn("%a", new Pair<DP, Renaming>(origins.snd(), renaming));
                module.println();
            });
            module.endTable();
            module.println();
            module.println("The following DPs were deleted:", new Object[0]);
            for (DP dp : this._deletedDPs) {
                module.println("%a", dp);
            }
            module.println();
            module.println("By chaining, we added " + this._chainedToOriginalDPs.size() + " DPs and removed " + this._deletedDPs.size() + " DPs.", new Object[0]);
        }

        @Override
        public String queryProcessorName() {
            return "Chaining Processor";
        }
    }
}

