/*
 * Decompiled with CFR 0.152.
 */
package cora.termination.transformation;

import charlie.substitution.Substitution;
import charlie.terms.FunctionSymbol;
import charlie.terms.Term;
import charlie.terms.TermFactory;
import charlie.terms.Variable;
import charlie.terms.position.ArgumentPos;
import charlie.terms.position.FinalPos;
import charlie.terms.position.Position;
import charlie.terms.replaceable.ReplaceableList;
import charlie.trs.Alphabet;
import charlie.trs.Rule;
import charlie.trs.TRS;
import charlie.trs.TrsFactory;
import charlie.trs.TrsProperties;
import charlie.types.Type;
import charlie.util.Pair;
import cora.io.OutputModule;
import cora.io.ProofObject;
import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * Implements the "helper function" transformation on a TRS.
 *
 * <p>The transformation searches right-hand sides for subterms {@code g(..., f(l1,...,ln), ...)} where
 * {@code f(l1,...,ln)} is a proper head of the rule's own left-hand side {@code f(l1,...,lk)}, and {@code n} is
 * smaller than the minimal number of arguments {@code f} takes in any left-hand side. Candidates that are also
 * compatible with every left-hand side (see {@link #checkCandidateSuitability}) are renamed to use a fresh copy
 * {@code f'} of {@code f}. Rules are instantiated beforehand so that the renamed terms can still be rewritten.
 */
public class HelperFunctionTransformer {
    /** Whether the TRS satisfies the preconditions of the transformation. */
    private boolean _applicable;
    /** Maps each root symbol of a left-hand side to the minimal number of arguments it has in a left-hand side. */
    private final TreeMap<FunctionSymbol, Integer> _ruleArity = new TreeMap<>();
    /** The TRS being transformed (never modified; a new TRS is built instead). */
    private final TRS _trs;

    /**
     * Sets up the transformer for the given TRS.
     *
     * <p>The transformation is considered applicable only if the TRS is left-linear, satisfies the property
     * check below, and every left-hand side is a functional term.
     *
     * @param trs the TRS to transform
     */
    public HelperFunctionTransformer(TRS trs) {
        this._trs = trs;
        this._applicable = trs.isLeftLinear() && trs.verifyProperties(TrsProperties.Level.APPLICATIVE, TrsProperties.Constrained.YES, TrsProperties.TypeLevel.SIMPLE, TrsProperties.Lhs.PATTERN, TrsProperties.Root.THEORY, TrsProperties.FreshRight.CVARS, new TRS.RuleScheme[0]);
        this.getSymbolArities();
    }

    /**
     * Fills {@link #_ruleArity} with, for every left-hand side root, the minimal number of arguments that
     * symbol takes in any left-hand side.
     *
     * <p>If some left-hand side is not a functional term, the transformation is marked as not applicable and
     * the computation stops early, which may leave the map only partially filled.
     */
    private void getSymbolArities() {
        for (int i = 0; i < this._trs.queryRuleCount(); ++i) {
            Term left = this._trs.queryRule(i).queryLeftSide();
            if (!left.isFunctionalTerm()) {
                this._applicable = false;
                return;
            }
            // keep the smallest arity seen so far for this root symbol
            this._ruleArity.merge(left.queryRoot(), left.numberArguments(), Math::min);
        }
    }

    /**
     * Returns the minimal number of arguments {@code f} takes in a left-hand side, or {@code null} if
     * {@code f} was not recorded as the root of any left-hand side.
     */
    Integer queryRuleArity(FunctionSymbol f) {
        return this._ruleArity.get(f);
    }

    /** Adds the candidate {@code (g, i, f, n)} to {@code lst} unless an equal candidate is already present. */
    private void add(ArrayList<Candidate> lst, FunctionSymbol g, int i, FunctionSymbol f, int n) {
        Candidate c = new Candidate(g, i, f, n);
        if (!lst.contains(c)) {
            lst.add(c);
        }
    }

    /**
     * Returns whether {@code v} is a proper head of {@code u}.
     *
     * <p>This holds when {@code v} is a functional term with the same root as {@code u}, has strictly fewer
     * arguments than {@code u}, and equals the head subterm of {@code u} with that many arguments. Assumes
     * {@code u} is a functional term (its root is queried).
     */
    private boolean isHeadOf(Term v, Term u) {
        if (!v.isFunctionalTerm() || !v.queryRoot().equals(u.queryRoot())) {
            return false;
        }
        if (u.numberArguments() <= v.numberArguments()) {
            return false;
        }
        return u.queryImmediateHeadSubterm(v.numberArguments()).equals(v);
    }

    /**
     * Collects all (duplicate-free) candidates for replacement.
     *
     * <p>A candidate {@code (g, i, f, n)} is recorded whenever some rule {@code f(l1,...,lk) -> r} has a
     * subterm {@code g(..., s_i, ...)} in {@code r}, where {@code s_i} is a proper head {@code f(l1,...,ln)} of
     * the left-hand side and {@code n} is below the minimal left-hand side arity of {@code f}.
     */
    ArrayList<Candidate> getReplacementCandidates() {
        ArrayList<Candidate> ret = new ArrayList<>();
        for (int j = 0; j < this._trs.queryRuleCount(); ++j) {
            Rule rule = this._trs.queryRule(j);
            Term left = rule.queryLeftSide();
            FunctionSymbol f = left.queryRoot();
            rule.queryRightSide().visitSubterms((s, p) -> {
                if (!s.isFunctionalTerm()) {
                    return;
                }
                FunctionSymbol g = s.queryRoot();
                for (int i = 1; i <= s.numberArguments(); ++i) {
                    Term arg = s.queryArgument(i);
                    int n = arg.numberArguments();
                    // the arity lookup is only evaluated for proper heads of the lhs
                    if (!this.isHeadOf(arg, left) || n >= this._ruleArity.get(f)) continue;
                    this.add(ret, g, i, f, n);
                }
            });
        }
        return ret;
    }

    /**
     * Checks that a candidate {@code (g, i, f, n)} is compatible with all left-hand sides.
     *
     * <p>The candidate is rejected if some left-hand side contains a {@code g}-rooted subterm that either has
     * fewer than {@code i} arguments, or whose {@code i}-th argument is neither a variable, an abstraction,
     * nor a functional term with a root other than {@code f}.
     *
     * @return {@code true} if no left-hand side contains such a subterm
     */
    boolean checkCandidateSuitability(Candidate cand) {
        FunctionSymbol g = cand.below();
        FunctionSymbol f = cand.main();
        int i = cand.argument();
        for (int j = 0; j < this._trs.queryRuleCount(); ++j) {
            Term left = this._trs.queryRule(j).queryLeftSide();
            if (left.findSubterm((s, p) -> {
                if (!s.isFunctionalTerm() || !s.queryRoot().equals(g)) {
                    return false;
                }
                if (s.numberArguments() < i) {
                    return true;
                }
                Term arg = s.queryArgument(i);
                if (arg.isVariable() || arg.isAbstraction()) {
                    return false;
                }
                if (arg.isFunctionalTerm()) {
                    return arg.queryRoot().equals(f);
                }
                return true;
            }) != null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Creates one fresh constant {@code f'} (same type as {@code f}) for every distinct main symbol {@code f}
     * among the candidates.
     *
     * <p>The name is {@code f'}, or {@code f'0}, {@code f'1}, ... if that name is already taken in the TRS
     * or by another copy created here.
     *
     * @return a map from each main symbol to its fresh copy
     */
    TreeMap<FunctionSymbol, FunctionSymbol> createCopies(List<Candidate> cands) {
        TreeSet<FunctionSymbol> symbols = new TreeSet<>();
        // names already handed out in this call; previously (decompiled) typed as CallSite, which made the
        // add below throw a ClassCastException
        TreeSet<String> newnames = new TreeSet<>();
        TreeMap<FunctionSymbol, FunctionSymbol> ret = new TreeMap<>();
        for (Candidate cand : cands) {
            symbols.add(cand.main());
        }
        for (FunctionSymbol f : symbols) {
            String newname = f.queryName() + "'";
            int i = 0;
            while (this._trs.lookupSymbol(newname) != null || newnames.contains(newname)) {
                newname = f.queryName() + "'" + i;
                ++i;
            }
            newnames.add(newname);
            FunctionSymbol helper = TermFactory.createConstant(newname, f.queryType());
            ret.put(f, helper);
        }
        return ret;
    }

    /** Returns the private symbol names of the original TRS together with the names of all fresh copies. */
    private TreeSet<String> determinePrivate(TreeMap<FunctionSymbol, FunctionSymbol> copies) {
        TreeSet<String> ret = new TreeSet<>(this._trs.queryPrivateSymbols());
        for (FunctionSymbol f : copies.values()) {
            ret.add(f.queryName());
        }
        return ret;
    }

    /**
     * Builds the substitutions used to instantiate a rule.
     *
     * <p>For every subterm {@code g(..., x, ...)} of {@code term} where {@code g} is the {@code below} symbol of
     * some candidate, {@code x} is a plain variable at that candidate's argument position, and {@code x} occurs
     * in {@code okay}, a substitution is created. It maps {@code x} to {@code f(arg.j.1, ..., arg.j.n)} with
     * fresh variables, where {@code f} is the candidate's main symbol and {@code n} its {@code numArgs}.
     * Arguments that are not plain variables are skipped, because {@code queryVariable()} is not meaningful for
     * them (for example, functional terms with a root other than {@code f}).
     */
    private ArrayList<Substitution> getReplacementSubstitutions(Term term, ReplaceableList okay, List<Candidate> cands) {
        ArrayList<Substitution> ret = new ArrayList<>();
        for (Pair<Term, Position> sub : term.querySubterms()) {
            Term subterm = sub.fst();
            if (!subterm.isFunctionalTerm()) continue;
            FunctionSymbol root = subterm.queryRoot();
            for (int j = 0; j < cands.size(); ++j) {
                Candidate cand = cands.get(j);
                int index = cand.argument();
                if (!root.equals(cand.below()) || subterm.numberArguments() < index) continue;
                Term argTerm = subterm.queryArgument(index);
                if (!argTerm.isVariable()) continue;
                Variable arg = argTerm.queryVariable();
                if (!okay.contains(arg)) continue;
                Term replacement = cand.main();
                for (int i = 0; i < cand.numArgs(); ++i) {
                    String varname = "arg." + (j + 1) + "." + (i + 1);
                    // the type of the partial application so far tells us the next argument type
                    Variable x = TermFactory.createVar(varname, replacement.queryType().subtype(1));
                    replacement = replacement.apply(x);
                }
                ret.add(Substitution.of(arg, replacement));
            }
        }
        return ret;
    }

    /**
     * For each substitution in turn, appends an instantiated copy of every rule currently in {@code rules}
     * whose left-hand side is changed by it. Copies made for earlier substitutions are themselves candidates
     * for instantiation by later ones.
     */
    private void applyAllUpdates(ArrayList<Rule> rules, ArrayList<Substitution> substitutions) {
        for (Substitution subst : substitutions) {
            // snapshot the size: only rules present before this substitution are instantiated
            int n = rules.size();
            for (int i = 0; i < n; ++i) {
                Rule rule = rules.get(i);
                Term lhs = rule.queryLeftSide();
                Term lhssubst = subst.substitute(lhs);
                if (lhssubst.equals(lhs)) continue;
                Term rhssubst = subst.substitute(rule.queryRightSide());
                rules.add(TrsFactory.createRule(lhssubst, rhssubst, rule.queryConstraint()));
            }
        }
    }

    /**
     * Replaces the head symbol {@code f} by its copy {@code f'} in every candidate occurrence in {@code term}.
     * A candidate occurrence is a subterm {@code g(..., f(s1,...,sn), ...)} at the candidate's argument
     * position, with exactly {@code n = numArgs} arguments.
     *
     * <p>Positions are taken from the original term. This is safe because each replacement only swaps one head
     * symbol for another of the same type.
     */
    Term renameSymbolsInsideCandidates(Term term, List<Candidate> candidates, TreeMap<FunctionSymbol, FunctionSymbol> renamings) {
        for (Pair<Term, Position> p : term.querySubterms()) {
            Term subterm = p.fst();
            if (!subterm.isFunctionalTerm()) continue;
            FunctionSymbol g = subterm.queryRoot();
            for (Candidate cand : candidates) {
                if (!g.equals(cand.below())) continue;
                int index = cand.argument();
                if (subterm.numberArguments() < index) continue;
                Term arg = subterm.queryArgument(index);
                if (!arg.isFunctionalTerm() || !arg.queryRoot().equals(cand.main())) continue;
                // FinalPos(numArgs) addresses the bare head only if there are exactly numArgs arguments
                if (arg.numberArguments() != cand.numArgs()) continue;
                Position pos = p.snd().append(new ArgumentPos(index, new FinalPos(cand.numArgs())));
                term = term.replaceSubterm(pos, renamings.get(cand.main()));
            }
        }
        return term;
    }

    /** Applies {@link #renameSymbolsInsideCandidates} to both sides of every rule, keeping constraints. */
    private void renameAll(ArrayList<Rule> rules, List<Candidate> candidates, TreeMap<FunctionSymbol, FunctionSymbol> renamings) {
        for (int i = 0; i < rules.size(); ++i) {
            Rule rule = rules.get(i);
            Term lhs = this.renameSymbolsInsideCandidates(rule.queryLeftSide(), candidates, renamings);
            Term rhs = this.renameSymbolsInsideCandidates(rule.queryRightSide(), candidates, renamings);
            rules.set(i, TrsFactory.createRule(lhs, rhs, rule.queryConstraint()));
        }
    }

    /**
     * Returns the renamed rule together with all of its renamed instantiated copies.
     *
     * <p>Only variables that also occur in the right-hand side are instantiated.
     */
    ArrayList<Rule> getInstantiatedCopies(Rule rule, List<Candidate> candidates, TreeMap<FunctionSymbol, FunctionSymbol> renamings) {
        ArrayList<Substitution> updates = this.getReplacementSubstitutions(rule.queryLeftSide(), rule.queryRightSide().freeReplaceables(), candidates);
        ArrayList<Rule> rulecopies = new ArrayList<>();
        rulecopies.add(rule);
        this.applyAllUpdates(rulecopies, updates);
        this.renameAll(rulecopies, candidates, renamings);
        return rulecopies;
    }

    /**
     * Builds the transformed TRS.
     *
     * <p>The alphabet is extended with the fresh copies, and the copies are marked private. Every rule is
     * replaced by its renamed instantiated copies. The result is an LCSTRS if the original includes theories,
     * and an LCTRS otherwise.
     */
    public TRS computeReplacementTRS(List<Candidate> candidates, TreeMap<FunctionSymbol, FunctionSymbol> renamings) {
        Alphabet newalf = this._trs.queryAlphabet().add(renamings.values());
        TreeSet<String> newpriv = this.determinePrivate(renamings);
        ArrayList<Rule> newrules = new ArrayList<>();
        for (int i = 0; i < this._trs.queryRuleCount(); ++i) {
            newrules.addAll(this.getInstantiatedCopies(this._trs.queryRule(i), candidates, renamings));
        }
        return TrsFactory.createTrs(newalf, newrules, newpriv, false, this._trs.theoriesIncluded() ? TrsFactory.LCSTRS : TrsFactory.LCTRS);
    }

    /**
     * Runs the transformation.
     *
     * @return a proof object whose answer is NO if the transformation is not applicable, MAYBE if there were
     *     no suitable candidates (both with the original TRS as result), and YES with the transformed TRS
     *     otherwise
     */
    public TransformerProofObject transform() {
        if (!this._applicable) {
            return new TransformerProofObject(null, null, this._trs);
        }
        List<Candidate> candidates = this.getReplacementCandidates().stream().filter(this::checkCandidateSuitability).toList();
        if (candidates.isEmpty()) {
            return new TransformerProofObject(candidates, null, this._trs);
        }
        TreeMap<FunctionSymbol, FunctionSymbol> copies = this.createCopies(candidates);
        TRS t = this.computeReplacementTRS(candidates, copies);
        return new TransformerProofObject(candidates, copies, t);
    }

    /**
     * A replacement candidate: occurrences of {@code main} applied to {@code numArgs} arguments, found as the
     * {@code argument}-th (1-based) argument of a {@code below}-rooted term.
     *
     * <p>Equality is the record's component-wise equality.
     */
    record Candidate(FunctionSymbol below, int argument, FunctionSymbol main, int numArgs) {
    }

    /** Records the outcome of the transformation and can print a justification for it. */
    public static class TransformerProofObject
    implements ProofObject {
        private final TRS _result;
        /** {@code null} if the transformation was not applicable; empty if no candidate was suitable. */
        private final List<Candidate> _candidates;
        /** Maps each main symbol to its fresh copy; {@code null} if nothing was renamed. */
        private final TreeMap<FunctionSymbol, FunctionSymbol> _renamings;

        private TransformerProofObject(List<Candidate> candidates, TreeMap<FunctionSymbol, FunctionSymbol> renamings, TRS result) {
            this._candidates = candidates;
            this._renamings = renamings;
            this._result = result;
        }

        /** NO if not applicable, MAYBE if no candidates were found, YES otherwise. */
        @Override
        public ProofObject.Answer queryAnswer() {
            if (this._candidates == null) {
                return ProofObject.Answer.NO;
            }
            if (this._candidates.isEmpty()) {
                return ProofObject.Answer.MAYBE;
            }
            return ProofObject.Answer.YES;
        }

        /** Returns the transformed TRS, or the original TRS if nothing was changed. */
        public TRS queryResultingTRS() {
            return this._result;
        }

        @Override
        public void justify(OutputModule module) {
            if (this._candidates == null) {
                module.println("The TRS does not satisfy the conditions to apply the helper function transformation.", new Object[0]);
                return;
            }
            if (this._candidates.isEmpty()) {
                module.println("The helper function transformation was not applied: I could not find any candidate positions to replace.", new Object[0]);
                return;
            }
            module.println("We observe that the TRS can be modified, without affecting termination, in the following way:", new Object[0]);
            module.startTable();
            for (Candidate cand : this._candidates) {
                // build g(x{1}, ..., x{i-1}); previously the applications were computed but discarded
                Term base = cand.below();
                Type t = base.queryType();
                for (int i = 1; i < cand.argument(); ++i) {
                    base = base.apply(TermFactory.createVar("x{" + i + "}", t.subtype(1)));
                    t = t.subtype(2);
                }
                ArrayList<Term> subargs = new ArrayList<>();
                Type y = cand.main().queryType();
                for (int i = 1; i <= cand.numArgs(); ++i) {
                    subargs.add(TermFactory.createVar("y{" + i + "}", y.subtype(1)));
                    y = y.subtype(2);
                }
                Term original = base.apply(cand.main().apply(subargs));
                Term replacement = base.apply(this._renamings.get(cand.main()).apply(subargs));
                module.println("We replace all occurrences of %a by %a.", original, replacement);
            }
            module.endTable();
            module.print("This yields a(n) ", new Object[0]);
            module.printTrs(this._result);
        }
    }
}

