/*
 * Decompiled with CFR 0.152.
 */
package cora.rwinduction.engine.deduction;

import charlie.terms.FunctionSymbol;
import charlie.terms.Term;
import charlie.terms.TheoryFactory;
import charlie.terms.Variable;
import charlie.terms.replaceable.MutableRenaming;
import charlie.terms.replaceable.Renaming;
import charlie.theorytranslation.TermAnalyser;
import charlie.util.Pair;
import cora.io.OutputModule;
import cora.rwinduction.engine.DeductionStep;
import cora.rwinduction.engine.Equation;
import cora.rwinduction.engine.EquationContext;
import cora.rwinduction.engine.PartialProof;
import cora.rwinduction.engine.ProofContext;
import cora.rwinduction.engine.ProofState;
import cora.rwinduction.engine.VariableNamer;
import cora.rwinduction.engine.deduction.CalcHelper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Locale;
import java.util.Optional;

public final class DeductionCalcAll
extends DeductionStep {
    private Equation _newEquation;
    private Renaming _newRenaming;
    private Side _side;
    private int _numReplacements;

    private DeductionCalcAll(ProofState state, ProofContext context, Equation eq, Renaming ren, int numreps, Side side) {
        super(state, context);
        this._newEquation = eq;
        this._newRenaming = ren.makeImmutable();
        this._numReplacements = numreps;
        this._side = side;
    }

    @Override
    public ProofState tryApply(Optional<OutputModule> module) {
        EquationContext ctxt = this._equ.replace(this._newEquation, this._newRenaming, this._state.getLastUsedIndex() + 1);
        return this._state.replaceTopEquation(ctxt);
    }

    @Override
    public boolean verify(Optional<OutputModule> module) {
        return true;
    }

    @Override
    public String commandDescription() {
        return switch (this._side.ordinal()) {
            default -> throw new MatchException(null, null);
            case 0 -> "calc left";
            case 1 -> "calc right";
            case 2 -> "calc";
        };
    }

    @Override
    public void explain(OutputModule module) {
        String location;
        switch (this._side.ordinal()) {
            default: {
                throw new MatchException(null, null);
            }
            case 0: {
                String string = " in the left-hand side of the equation";
                break;
            }
            case 1: {
                String string = " in the right-hand side of the equation";
                break;
            }
            case 2: {
                String string = location = "";
            }
        }
        if (this._numReplacements == 1) {
            module.println("We use CALC at the only position%a where it is possible.", location);
        } else {
            module.println("We use CALC at all %a positions%a where it is possible.", this._numReplacements, location);
        }
    }

    public static DeductionCalcAll createStep(PartialProof proof, Optional<OutputModule> m, Side side) {
        EquationContext ec = DeductionCalcAll.getTopEquation(proof.getProofState(), m);
        if (ec == null) {
            return null;
        }
        MutableRenaming renaming = ec.getRenaming().copy();
        HashMap<Term, Variable> definedVars = CalcHelper.breakupConstraint(ec.getConstraint());
        VariableNamer namer = proof.getContext().getVariableNamer();
        Term left = ec.getLhs();
        Term right = ec.getRhs();
        ReplacementInfo info = new ReplacementInfo();
        if (side == Side.Left || side == Side.Both) {
            left = DeductionCalcAll.doCalculations(left, definedVars, renaming, info, namer, null);
        }
        if (side == Side.Right || side == Side.Both) {
            right = DeductionCalcAll.doCalculations(right, definedVars, renaming, info, namer, null);
        }
        if (info.count == 0) {
            m.ifPresent(o -> o.println("There are no calculatable subterms.", new Object[0]));
            return null;
        }
        Term constraint = DeductionCalcAll.buildConstraint(ec.getConstraint(), info);
        Equation neweq = new Equation(left, right, constraint);
        return new DeductionCalcAll(proof.getProofState(), proof.getContext(), neweq, renaming, info.count, side);
    }

    private static Term doCalculations(Term term, HashMap<Term, Variable> map, MutableRenaming renaming, ReplacementInfo info, VariableNamer namer, Pair<FunctionSymbol, Integer> parent) {
        if (CalcHelper.calculatable(term)) {
            ++info.count;
            if (term.isGround()) {
                return TermAnalyser.evaluate(term);
            }
            Term replacement = map.get(term);
            if (replacement != null) {
                return replacement;
            }
            Variable x = namer.chooseDerivativeForTerm(term, renaming, term.queryType().toString().substring(0, 1).toLowerCase(), parent);
            info.freshVars.add(x);
            info.varReplacements.add(term);
            map.put(term, x);
            return x;
        }
        return DeductionCalcAll.doCalculcationsRecurse(term, map, renaming, info, namer);
    }

    private static Term doCalculcationsRecurse(Term term, HashMap<Term, Variable> map, MutableRenaming renaming, ReplacementInfo info, VariableNamer namer) {
        Term head = term.queryHead();
        ArrayList<Term> args = null;
        for (int i = 1; i <= term.numberArguments(); ++i) {
            Pair<FunctionSymbol, Integer> parent;
            Term replacement;
            Term arg = term.queryArgument(i);
            if (arg != (replacement = DeductionCalcAll.doCalculations(arg, map, renaming, info, namer, parent = head.isConstant() ? new Pair<FunctionSymbol, Integer>(head.queryRoot(), i) : null)) && args == null) {
                args = new ArrayList<Term>();
                for (int j = 1; j < i; ++j) {
                    args.add(term.queryArgument(j));
                }
            }
            if (args == null) continue;
            args.add(replacement);
        }
        if (args == null) {
            return term;
        }
        return head.apply(args);
    }

    private static Term buildConstraint(Term base, ReplacementInfo info) {
        for (int i = 0; i < info.freshVars.size(); ++i) {
            base = TheoryFactory.createConjunction(base, TheoryFactory.createEquality(info.freshVars.get(i), info.varReplacements.get(i)));
        }
        return base;
    }

    public static enum Side {
        Left,
        Right,
        Both;

    }

    private static class ReplacementInfo {
        ArrayList<Term> freshVars = new ArrayList();
        ArrayList<Term> varReplacements = new ArrayList();
        int count = 0;

        ReplacementInfo() {
        }
    }
}

