/*
 * Decompiled with CFR 0.152.
 */
package cora.termination.dependency_pairs.processors;

import charlie.smt.BVar;
import charlie.smt.Constraint;
import charlie.smt.IVar;
import charlie.smt.IntegerExpression;
import charlie.smt.SmtFactory;
import charlie.smt.SmtProblem;
import charlie.smt.SmtSolver;
import charlie.smt.Valuation;
import charlie.substitution.MutableSubstitution;
import charlie.terms.FunctionSymbol;
import charlie.terms.Term;
import charlie.terms.TermFactory;
import charlie.terms.TheoryFactory;
import charlie.terms.Value;
import charlie.terms.Variable;
import charlie.theorytranslation.TermSmtTranslator;
import charlie.trs.TRS;
import charlie.trs.TrsProperties;
import charlie.types.Arrow;
import charlie.types.Type;
import charlie.types.TypeFactory;
import cora.config.Settings;
import cora.termination.dependency_pairs.DP;
import cora.termination.dependency_pairs.Problem;
import cora.termination.dependency_pairs.processors.IntegerMappingProof;
import cora.termination.dependency_pairs.processors.Processor;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

/**
 * A {@link Processor} on DP problems that is identified in the settings by the code returned from
 * {@link #queryDisabledCode()}.
 *
 * As far as this file shows, {@code processDPP} works in these steps: it creates fresh variables
 * for the argument positions of every head symbol, collects candidate terms over those variables
 * for every head symbol, discards candidates that do not instantiate to suitable theory terms, and
 * finally asks the configured SMT solver to pick one candidate index per head symbol together with
 * one boolean flag per DP.
 *
 * NOTE(review): the instance keeps per-run state in its fields, which are overwritten at the start
 * of every {@code processDPP} call — presumably instances are not meant to be shared between
 * concurrent runs; confirm with callers.
 */
public class IntegerMappingProcessor
implements Processor {
    // SMT problem that receives the requirements on the candidate choice; recreated per processDPP call.
    private SmtProblem _smt;
    // For each head symbol: one fresh variable per argument position (built by computeFreshVars).
    private Map<FunctionSymbol, List<Variable>> _fnToFreshVar;
    // For each head symbol: the candidate terms, expressed over that symbol's fresh variables.
    private Map<FunctionSymbol, List<Term>> _candidates;

    /**
     * Returns the short code that identifies this processor when it is looked up in the settings
     * (see its use with {@code Settings.isDisabled} in {@code isApplicable}).
     *
     * @return the literal string {@code "imap"}
     */
    public static String queryDisabledCode() {
        final String disabledCode = "imap";
        return disabledCode;
    }

    /**
     * Decides whether this processor may be run on the given DP problem.
     *
     * The processor is applicable when it has not been disabled in the settings and the original
     * TRS of the problem passes {@code verifyProperties} for the property combination listed below.
     *
     * @param dp the DP problem to inspect
     * @return {@code true} if the processor is enabled and the property check succeeds
     */
    @Override
    public boolean isApplicable(Problem dp) {
        // Guard: a disabled processor is never applicable, and the TRS is not even inspected.
        if (Settings.isDisabled(IntegerMappingProcessor.queryDisabledCode())) {
            return false;
        }
        return dp.getOriginalTRS().verifyProperties(
            TrsProperties.Level.APPLICATIVE,
            TrsProperties.Constrained.YES,
            TrsProperties.TypeLevel.SIMPLE,
            TrsProperties.Lhs.PATTERN,
            TrsProperties.Root.THEORY,
            TrsProperties.FreshRight.CVARS,
            new TRS.RuleScheme[0]);
    }

    /**
     * Creates, for every head symbol of the given DP problem, one fresh variable per argument
     * position in that symbol's type.
     *
     * The type of each symbol is walked as long as it is an {@link Arrow}: the left component
     * becomes the type of the next fresh variable (named {@code arg_1}, {@code arg_2}, ...) and the
     * walk continues on the right component.
     *
     * Unlike the decompiled original, exceptions thrown by the {@code Arrow} accessors are no
     * longer caught and rewrapped as {@code MatchException}; they propagate unchanged.
     *
     * @param dpp the DP problem whose head symbols are processed
     * @return a sorted map from each head symbol to its list of fresh variables; the list is empty
     *         for a symbol whose type is not an {@code Arrow}
     */
    private Map<FunctionSymbol, List<Variable>> computeFreshVars(Problem dpp) {
        Map<FunctionSymbol, List<Variable>> ret = new TreeMap<>();
        for (FunctionSymbol fSharp : dpp.getHeads()) {
            List<Variable> newvars = new ArrayList<>();
            Type ty = fSharp.queryType();
            int i = 1;
            // Peel off one argument type per iteration: A1 -> (A2 -> (... -> B)).
            while (ty instanceof Arrow arrow) {
                newvars.add(TermFactory.createVar("arg_" + i, arrow.left()));
                ty = arrow.right();
                ++i;
            }
            ret.put(fSharp, newvars);
        }
        return ret;
    }

    /**
     * Builds the application of {@code TheoryFactory.plusSymbol} to the two given terms.
     *
     * @param a the first argument of the sum
     * @param b the second argument of the sum
     * @return the term created by {@code TermFactory.createApp} for {@code a} plus {@code b}
     */
    private Term makePlus(Term a, Term b) {
        Term sum = TermFactory.createApp(TheoryFactory.plusSymbol, a, b);
        return sum;
    }

    /**
     * Resets {@code _candidates} to a new sorted map that holds an empty, mutable candidate list
     * for every head symbol of the given DP problem.
     *
     * @param dpp the DP problem whose head symbols receive a candidate list
     */
    private void initiateCandidates(Problem dpp) {
        Set<FunctionSymbol> heads = dpp.getHeads();
        this._candidates = new TreeMap<>();
        for (FunctionSymbol head : heads) {
            this._candidates.put(head, new ArrayList<>());
        }
    }

    /**
     * Adds the basic candidates for every head symbol that already has a candidate list.
     *
     * For each fresh variable {@code y} of the symbol whose type equals {@code TypeFactory.intSort},
     * three candidates are appended in this order: {@code y}, {@code y + (-1)} and {@code y + 1}.
     */
    private void addSimpleCandidates() {
        for (Map.Entry<FunctionSymbol, List<Term>> entry : this._candidates.entrySet()) {
            List<Term> options = entry.getValue();
            for (Variable arg : this._fnToFreshVar.get(entry.getKey())) {
                // Only integer-typed argument positions yield candidates.
                if (!arg.queryType().equals(TypeFactory.intSort)) {
                    continue;
                }
                options.add(arg);
                options.add(this.makePlus(arg, TheoryFactory.createValue(-1)));
                options.add(this.makePlus(arg, TheoryFactory.createValue(1)));
            }
        }
    }

    /**
     * Adds candidates derived from the constraints of the DPs in the given problem.
     *
     * For every DP, the arguments of the left-hand side that are variables of a theory base type
     * are mapped to the fresh variable of the corresponding argument position. The terms suggested
     * by {@code addAllComparisonFunctions} for the DP's constraint are then renamed with that
     * mapping and appended to the candidate list of the left-hand side's root symbol. A suggestion
     * is skipped if it contains a variable that the mapping does not cover.
     *
     * @param dpp the DP problem whose DPs are scanned
     */
    private void addComplexCandidates(Problem dpp) {
        for (DP dp : dpp.getDPList()) {
            Term left = dp.lhs();
            FunctionSymbol head = left.queryRoot();

            // Map each suitable argument variable of the lhs to the fresh variable of its position.
            MutableSubstitution renaming = new MutableSubstitution();
            for (int pos = 0; pos < left.numberArguments(); pos++) {
                Term arg = left.queryArgument(pos + 1);
                boolean usable = arg.isVariable()
                    && arg.queryType().isTheoryType()
                    && arg.queryType().isBaseType();
                if (usable) {
                    renaming.extend(arg.queryVariable(), this._fnToFreshVar.get(head).get(pos));
                }
            }

            ArrayList<Term> suggestions = new ArrayList<Term>();
            this.addAllComparisonFunctions(dp.constraint(), suggestions);

            for (Term suggestion : suggestions) {
                // A suggestion is only usable if every variable in it can be renamed.
                boolean fullyCovered = true;
                for (Variable v : suggestion.vars()) {
                    if (renaming.get(v) == null) {
                        fullyCovered = false;
                        break;
                    }
                }
                if (fullyCovered) {
                    this._candidates.get(head).add(renaming.substitute(suggestion));
                }
            }
        }
    }

    /**
     * Collects terms suggested by the comparisons that occur in the given constraint and appends
     * them to {@code ret}.
     *
     * Conjunctions and disjunctions are traversed recursively. For a binary comparison
     * {@code l OP r} the following terms are appended, in this order:
     * - for {@code >=}, integer equality and {@code >}: {@code l} if {@code r} is the value 0,
     *   otherwise {@code l + (-r)}; for {@code >} additionally that term plus {@code -1};
     * - for {@code <=}, integer equality and {@code <}: {@code r} if {@code l} is the value 0,
     *   otherwise {@code r + (-l)}; for {@code <} additionally that term plus {@code -1}.
     * Anything that is not a functional term, or has a different shape, contributes nothing.
     *
     * @param constraint the constraint (or sub-constraint) to scan
     * @param ret the list that receives the suggested terms
     */
    private void addAllComparisonFunctions(Term constraint, ArrayList<Term> ret) {
        if (!constraint.isFunctionalTerm()) {
            return;
        }
        FunctionSymbol f = constraint.queryRoot();

        boolean isJunction = f.equals(TheoryFactory.andSymbol) || f.equals(TheoryFactory.orSymbol);
        if (isJunction) {
            int count = constraint.numberArguments();
            for (int pos = 1; pos <= count; pos++) {
                this.addAllComparisonFunctions(constraint.queryArgument(pos), ret);
            }
        }

        // Everything below only concerns binary comparison symbols.
        if (constraint.numberArguments() != 2) {
            return;
        }
        Term left = constraint.queryArgument(1);
        Term right = constraint.queryArgument(2);
        boolean isEquality = f.equals(TheoryFactory.intEqualSymbol);
        boolean isGreater = f.equals(TheoryFactory.greaterSymbol);
        boolean isSmaller = f.equals(TheoryFactory.smallerSymbol);

        if (f.equals(TheoryFactory.geqSymbol) || isEquality || isGreater) {
            // left - right, simplified to left when right is 0
            Term leftMinusRight = right.equals(TheoryFactory.createValue(0))
                ? left
                : this.makePlus(left, TheoryFactory.minusSymbol.apply(right));
            ret.add(leftMinusRight);
            if (isGreater) {
                ret.add(this.makePlus(leftMinusRight, TheoryFactory.createValue(-1)));
            }
        }

        if (f.equals(TheoryFactory.leqSymbol) || isEquality || isSmaller) {
            // right - left, simplified to right when left is 0
            Term rightMinusLeft = left.equals(TheoryFactory.createValue(0))
                ? right
                : this.makePlus(right, TheoryFactory.minusSymbol.apply(left));
            ret.add(rightMinusLeft);
            if (isSmaller) {
                ret.add(this.makePlus(rightMinusLeft, TheoryFactory.createValue(-1)));
            }
        }
    }

    /**
     * Checks that no candidate list in {@code _candidates} is empty.
     *
     * @return {@code true} if every head symbol has at least one candidate (also when there are no
     *         head symbols at all), {@code false} as soon as one list is empty
     */
    private boolean everyFunctionHasAtLeastOneCandidate() {
        return this._candidates.values().stream().noneMatch(List::isEmpty);
    }

    /**
     * Instantiates a candidate, written over the fresh variables of a head symbol, with the actual
     * arguments of a term headed by that symbol.
     *
     * @param candidate a candidate term over the fresh variables of {@code term}'s root symbol
     * @param term the term whose arguments replace the fresh variables; its root symbol's arity
     *             determines how many argument positions are substituted
     * @return {@code candidate} with the fresh variable of each position replaced by the argument
     *         of {@code term} at that position
     */
    private Term instantiateCandidate(Term candidate, Term term) {
        FunctionSymbol head = term.queryRoot();
        MutableSubstitution binding = new MutableSubstitution();
        int arity = head.queryArity();
        // Argument positions are 1-based on terms, 0-based in the fresh-variable list.
        for (int pos = 1; pos <= arity; pos++) {
            Variable placeholder = this._fnToFreshVar.get(head).get(pos - 1);
            binding.extend(placeholder, term.queryArgument(pos));
        }
        return binding.substitute(candidate);
    }

    /**
     * Replaces the candidate list of {@code term}'s root symbol by the sublist of candidates that
     * remain usable when instantiated on {@code term}.
     *
     * A candidate is kept if its instantiation is a theory term and every variable occurring in
     * the instantiation is contained in {@code theoryVars}. The relative order of the kept
     * candidates is preserved.
     *
     * @param term the term (one side of a DP) on which the candidates are instantiated
     * @param theoryVars the variables that may occur in an instantiated candidate
     */
    private void filterCandidateList(Term term, Set<Variable> theoryVars) {
        FunctionSymbol head = term.queryRoot();
        List<Term> kept = new ArrayList<>();
        for (Term candidate : this._candidates.get(head)) {
            Term instance = this.instantiateCandidate(candidate, term);
            boolean acceptable = instance.isTheoryTerm();
            if (acceptable) {
                for (Variable v : instance.vars()) {
                    if (!theoryVars.contains(v)) {
                        acceptable = false;
                        break;
                    }
                }
            }
            if (acceptable) {
                kept.add(candidate);
            }
        }
        this._candidates.replace(head, kept);
    }

    /**
     * Filters the candidate lists against every DP of the given problem.
     *
     * For each DP, {@code filterCandidateList} is applied first to the left-hand side and then to
     * the right-hand side, both times with the variable set returned by {@code dp.lvars()}.
     *
     * The decompiled original also fetched the DP's constraint into a local that was never read;
     * that dead call has been removed.
     *
     * @param dpp the DP problem whose DPs drive the filtering
     */
    private void updateCandidates(Problem dpp) {
        for (DP dp : dpp.getDPList()) {
            Set<Variable> theoryVars = dp.lvars();
            this.filterCandidateList(dp.lhs(), theoryVars);
            this.filterCandidateList(dp.rhs(), theoryVars);
        }
    }

    /**
     * Creates one integer variable in {@code _smt} for every head symbol of the given DP problem.
     *
     * @param dpp the DP problem whose head symbols are processed
     * @return a sorted map from each head symbol to its newly created integer variable
     */
    private Map<FunctionSymbol, IVar> generateIVars(Problem dpp) {
        TreeMap<FunctionSymbol, IVar> choiceVars = new TreeMap<>();
        for (FunctionSymbol head : dpp.getHeads()) {
            choiceVars.put(head, this._smt.createIntegerVariable());
        }
        return choiceVars;
    }

    /**
     * Creates one boolean variable in {@code _smt} for every DP of the given problem.
     *
     * @param dpp the DP problem whose DPs are processed
     * @return an insertion-ordered map from each DP (in DP-list order) to its boolean variable
     */
    private Map<DP, BVar> generateDpBVarMap(Problem dpp) {
        List<DP> dps = dpp.getDPList();
        LinkedHashMap<DP, BVar> flags = new LinkedHashMap<>(dps.size());
        for (DP dp : dps) {
            flags.put(dp, this._smt.createBooleanVariable());
        }
        return flags;
    }

    /**
     * Restricts every integer variable to the valid index range of its symbol's candidate list.
     *
     * For each entry, {@code _smt} is required to satisfy {@code 0 <= ivar} and
     * {@code ivar <= size - 1}, where {@code size} is the number of candidates of the symbol.
     *
     * @param intMap the integer variable of every head symbol
     */
    private void requiresCtrs(Map<FunctionSymbol, IVar> intMap) {
        for (Map.Entry<FunctionSymbol, IVar> entry : intMap.entrySet()) {
            IVar choice = entry.getValue();
            int lastIndex = this._candidates.get(entry.getKey()).size() - 1;
            this._smt.require(SmtFactory.createLeq(SmtFactory.createValue(0), choice));
            this._smt.require(SmtFactory.createLeq(choice, SmtFactory.createValue(lastIndex)));
        }
    }

    /**
     * Requires in {@code _smt} that at least one of the given boolean variables holds, by adding
     * the disjunction of all values of {@code boolMap}.
     *
     * @param boolMap the boolean variable of every DP
     */
    private void requireAtLeastOneStrict(Map<DP, BVar> boolMap) {
        ArrayList<Constraint> anyOf = new ArrayList<Constraint>(boolMap.values());
        this._smt.require(SmtFactory.createDisjunction(anyOf));
    }

    /**
     * Adds to {@code _smt} the requirements that tie the choice of candidates to each DP.
     *
     * For every DP and every pair (i, j) of candidate indices, with i ranging over the candidates
     * of the left-hand side's root symbol and j over those of the right-hand side's root symbol
     * (pairs with i != j are skipped when both roots are the same symbol), the two candidates are
     * instantiated on their side of the DP and up to two validity checks are run through
     * {@code Settings.smtSolver} on a separate, freshly created {@link SmtProblem}:
     *
     *   (1) constraint implies Li >= Rj
     *   (2) constraint implies (Li >= 0 and Li > Rj)
     *
     * With "pair not chosen" standing for (intMap[lhsRoot] != i or intMap[rhsRoot] != j), the
     * requirement added to {@code _smt} is:
     *   - (1) not valid:            pair not chosen
     *   - (1) valid, (2) valid:     pair not chosen, or the DP's boolean variable holds
     *   - (1) valid, (2) not valid: pair not chosen, or the DP's boolean variable does not hold
     *
     * @param intMap the integer variable of each root symbol, compared against candidate indices
     * @param boolMap the boolean variable of each DP
     * @param dpp the DP problem whose DPs are inspected
     */
    private void putDpRequirements(Map<FunctionSymbol, IVar> intMap, Map<DP, BVar> boolMap, Problem dpp) {
        for (DP dp : dpp.getDPList()) {
            Term lhs = dp.lhs();
            Term rhs = dp.rhs();
            Term ctr = dp.constraint();
            FunctionSymbol lhsHead = lhs.queryRoot();
            FunctionSymbol rhsHead = rhs.queryRoot();
            for (int i = 0; i < this._candidates.get(lhsHead).size(); ++i) {
                for (int j = 0; j < this._candidates.get(rhsHead).size(); ++j) {
                    // A symbol has a single integer variable, so for equal roots only i == j can be chosen.
                    if (lhsHead.equals(rhsHead) && i != j) continue;
                    Term instLi = this.instantiateCandidate(this._candidates.get(lhsHead).get(i), lhs);
                    Term instRj = this.instantiateCandidate(this._candidates.get(rhsHead).get(j), rhs);
                    // The validity checks use their own problem; the translator is bound to it.
                    SmtProblem validityProblem = new SmtProblem();
                    TermSmtTranslator tst = new TermSmtTranslator(validityProblem);
                    Constraint constraintTranslation = tst.translateConstraint(ctr);
                    IntegerExpression candLiExpr = tst.translateIntegerExpression(instLi);
                    IntegerExpression candRjExpr = tst.translateIntegerExpression(instRj);
                    // "Pair not chosen": the lhs symbol does not pick i, or the rhs symbol does not pick j.
                    Constraint fSharpDisjunction = SmtFactory.createDisjunction(SmtFactory.createUnequal(intMap.get(lhsHead), SmtFactory.createValue(i)), SmtFactory.createUnequal(intMap.get(rhsHead), SmtFactory.createValue(j)));
                    // Check (1): constraint implies Li >= Rj.
                    validityProblem.requireImplication(constraintTranslation, SmtFactory.createGeq(candLiExpr, candRjExpr));
                    if (!Settings.smtSolver.checkValidity(validityProblem)) {
                        // Check (1) failed: this pair of candidates must not be chosen together.
                        this._smt.require(fSharpDisjunction);
                        continue;
                    }
                    // Reuse the same problem (and the expressions already translated for it) for check (2).
                    // NOTE(review): relies on clear() keeping those expressions usable — confirm against SmtProblem.
                    validityProblem.clear();
                    // Check (2): constraint implies (Li >= 0 and Li > Rj).
                    validityProblem.requireImplication(constraintTranslation, SmtFactory.createConjunction(SmtFactory.createGeq(candLiExpr, SmtFactory.createValue(0)), SmtFactory.createGreater(candLiExpr, candRjExpr)));
                    if (Settings.smtSolver.checkValidity(validityProblem)) {
                        // Check (2) holds: choosing this pair forces the DP's boolean variable to hold.
                        this._smt.require(SmtFactory.createDisjunction(fSharpDisjunction, (Constraint)boolMap.get(dp)));
                        continue;
                    }
                    // Only check (1) holds: choosing this pair forces the DP's boolean variable to be false.
                    this._smt.require(SmtFactory.createDisjunction(fSharpDisjunction, SmtFactory.createNegation(boolMap.get(dp))));
                }
            }
        }
    }

    /*
     * Loose catch block
     */
    @Override
    public IntegerMappingProof processDPP(Problem dpp) {
        Valuation result;
        this._smt = new SmtProblem();
        this._fnToFreshVar = this.computeFreshVars(dpp);
        this.initiateCandidates(dpp);
        this.addSimpleCandidates();
        this.addComplexCandidates(dpp);
        this.updateCandidates(dpp);
        if (!this.everyFunctionHasAtLeastOneCandidate()) {
            return new IntegerMappingProof(dpp);
        }
        Map<FunctionSymbol, IVar> intMap = this.generateIVars(dpp);
        this.requiresCtrs(intMap);
        Map<DP, BVar> boolMap = this.generateDpBVarMap(dpp);
        this.requireAtLeastOneStrict(boolMap);
        this.putDpRequirements(intMap, boolMap, dpp);
        SmtSolver.Answer answer = Settings.smtSolver.checkSatisfiability(this._smt);
        Objects.requireNonNull(answer);
        SmtSolver.Answer answer2 = answer;
        int n = 0;
        switch (SwitchBootstraps.typeSwitch("typeSwitch", new Object[]{SmtSolver.Answer.YES.class}, (Object)answer2, n)) {
            case 0: {
                Valuation valuation;
                Valuation val;
                SmtSolver.Answer.YES yES = (SmtSolver.Answer.YES)answer2;
                Valuation valuation2 = val = (valuation = yES.val());
                break;
            }
            default: {
                Valuation valuation2 = result = null;
            }
        }
        if (result == null) {
            return new IntegerMappingProof(dpp);
        }
        TreeSet<Integer> indexOfOrientedDPs = new TreeSet<Integer>();
        TreeMap<FunctionSymbol, Term> candFun = new TreeMap<FunctionSymbol, Term>();
        List<DP> originalDPs = dpp.getDPList();
        ArrayList<DP> remainingDPs = new ArrayList<DP>();
        intMap.forEach((f, ivar) -> candFun.put((FunctionSymbol)f, this._candidates.get(f).get(result.queryAssignment((IVar)ivar))));
        for (int index = 0; index < originalDPs.size(); ++index) {
            DP dp = originalDPs.get(index);
            BVar bvar = boolMap.get(dp);
            if (result.queryAssignment(bvar)) {
                indexOfOrientedDPs.add(index);
                continue;
            }
            remainingDPs.add(dp);
        }
        return new IntegerMappingProof(dpp, indexOfOrientedDPs, this._fnToFreshVar, candFun);
        catch (Throwable throwable) {
            throw new MatchException(throwable.toString(), throwable);
        }
    }
}

