/*
 * Decompiled with CFR 0.152.
 */
package ai.grazie.rules.tree;

import ai.grazie.rules.tree.AccessedParameters;
import ai.grazie.rules.tree.Node;
import ai.grazie.rules.tree.NodeMatch;
import ai.grazie.rules.tree.NodePattern;
import ai.grazie.rules.tree.PatternHint;
import ai.grazie.rules.tree.Tree;
import ai.grazie.rules.tree.TreeCache;
import ai.grazie.rules.util.TransformingCharSequence;
import com.hankcs.algorithm.AhoCorasickDoubleArrayTrie;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;

public class PatternSet {
    // Each map below goes from a hint string to the set of pattern indices (positions in
    // `patterns`) whose hint mentions that string.

    /** Form hints: only the nodes whose indexed form equals the key are tried. */
    private final Map<String, BitSet> byForm = new HashMap<>();
    /** Sentence-wide form hints: if the key occurs anywhere in the tree, every node is tried. */
    private final Map<String, BitSet> bySomeForm = new HashMap<>();
    /** Lemma hints: only the nodes having the key among their indexed lemmas are tried. */
    private final Map<String, BitSet> byLemma = new HashMap<>();
    /** Sentence-wide lemma hints: if the key occurs anywhere in the tree, every node is tried. */
    private final Map<String, BitSet> bySomeLemma = new HashMap<>();
    /** Hints on the relation between a node and its head; looked up by {@code Node.headRelation()}. */
    private final Map<String, BitSet> byHeadRel = new HashMap<>();
    /** Hints on the relation of a dependent; the candidate node is the head of that dependent. */
    private final Map<String, BitSet> byDepRel = new HashMap<>();
    /** Indices of patterns whose hint allows everything; they are tried on every node. */
    private final List<Integer> nonHinted = new ArrayList<>();
    /** All patterns of this set, addressed by pattern index. */
    private final NodePattern[] patterns;
    /** Rule ids addressed by pattern index (parallel to {@link #patterns}). */
    private final List<String> ids;
    /** Trie over substring hints, or {@code null} when no pattern has a substring hint. */
    private final AhoCorasickDoubleArrayTrie<BitSet> substringTrie;
    /** Cache of the {@link SentenceIndex} computed by {@link #buildIndex(Tree)}, registered as "textIndex". */
    private static final TreeCache<SentenceIndex> indexCache = new TreeCache<SentenceIndex>("textIndex", PatternSet::buildIndex) {

        // NOTE(review): presumably signals that the cached index stays valid while the tree
        // structure is unchanged - confirm against TreeCache.
        @Override
        protected boolean dependsOnlyOnStructure() {
            return true;
        }
    };

    /**
     * Creates a pattern set and indexes every pattern by the hints it declares, so that
     * {@link #matchTree(Tree)} only runs a pattern on nodes its hint could accept.
     *
     * <p>Patterns whose hint allows everything are remembered separately and tried on all nodes.
     * For every other pattern each hint disjunct is registered under the first kind it provides,
     * checked in this order: node form, node lemma, node substring, head relation, dependent
     * relation, some form, some lemma.
     *
     * @param patterns the patterns; their position in this list is the pattern index reported
     *                 in {@link IndexedMatch#index()}
     * @param ids      rule ids, looked up by pattern index; the list is stored by reference
     * @throws IllegalStateException if a hint disjunct provides none of the supported kinds
     */
    public PatternSet(List<NodePattern> patterns, List<String> ids) {
        this.patterns = patterns.toArray(new NodePattern[0]);
        this.ids = ids;

        // Substring hints are collected first and compiled into a trie once all are known.
        Map<String, BitSet> bySubstring = new HashMap<>();
        for (int i = 0; i < this.patterns.length; ++i) {
            PatternHint hint = this.patterns[i].hint;
            if (hint.allowsEverything()) {
                this.nonHinted.add(i);
                continue;
            }
            List<PatternHint.Disjunct> disjuncts = hint.disjuncts();
            assert !disjuncts.isEmpty();
            for (PatternHint.Disjunct disjunct : disjuncts) {
                if (disjunct.nodeForm() != null) {
                    PatternSet.addHints(this.byForm, disjunct.nodeForm(), i);
                } else if (disjunct.nodeLemma() != null) {
                    PatternSet.addHints(this.byLemma, disjunct.nodeLemma(), i);
                } else if (disjunct.nodeSubstring() != null) {
                    PatternSet.addHints(bySubstring, disjunct.nodeSubstring(), i);
                } else if (disjunct.headRel() != null) {
                    PatternSet.addHints(this.byHeadRel, disjunct.headRel(), i);
                } else if (disjunct.depRel() != null) {
                    PatternSet.addHints(this.byDepRel, disjunct.depRel(), i);
                } else if (disjunct.someForm() != null) {
                    PatternSet.addHints(this.bySomeForm, disjunct.someForm(), i);
                } else if (disjunct.someLemma() != null) {
                    PatternSet.addHints(this.bySomeLemma, disjunct.someLemma(), i);
                } else {
                    throw new IllegalStateException("Every disjunct is expected to have some info");
                }
            }
        }

        if (bySubstring.isEmpty()) {
            // No substring hints at all: matchTree skips the text scan entirely.
            this.substringTrie = null;
        } else {
            AhoCorasickDoubleArrayTrie<BitSet> trie = new AhoCorasickDoubleArrayTrie<>();
            trie.build(bySubstring);
            this.substringTrie = trie;
        }
    }

    /**
     * Registers a pattern under each of the given hint strings.
     *
     * @param map          hint string to pattern-index set; missing entries are created
     * @param hints        the hint strings the pattern declares
     * @param patternIndex index of the pattern to record
     */
    private static void addHints(Map<String, BitSet> map, String[] hints, int patternIndex) {
        for (int k = 0; k < hints.length; k++) {
            BitSet owners = map.computeIfAbsent(hints[k], key -> new BitSet());
            owners.set(patternIndex);
        }
    }

    /**
     * Matches all patterns against each of the given trees.
     *
     * @param trees the trees to process, in order
     * @return the matches of all trees concatenated in tree order
     */
    public List<IndexedMatch<NodeMatch>> match(List<Tree> trees) {
        List<IndexedMatch<NodeMatch>> collected = new ArrayList<>();
        trees.forEach(tree -> collected.addAll(this.matchTree(tree)));
        return collected;
    }

    /**
     * Matches all patterns against a single tree.
     *
     * <p>A candidate matrix is built first: bit {@code p * nodeCount + n} means "pattern p should
     * be tried on node n". It is filled from the hint indices and the tree's cached
     * {@link SentenceIndex}, and only the flagged pairs are actually matched.
     *
     * @param tree the tree to process
     * @return the matches found; an empty list if the tree has no nodes
     */
    public List<IndexedMatch<NodeMatch>> matchTree(Tree tree) {
        List<Node> nodes = tree.nodes();
        int nodeCount = nodes.size();
        if (nodeCount == 0) {
            return List.of();
        }

        SentenceIndex index = tree.getCached(indexCache);
        BitSet matrix = new BitSet(this.patterns.length * nodeCount);

        // Form and lemma hints exist in a per-node and a sentence-wide flavour; relations only per-node.
        PatternSet.fillMatrix(matrix, this.byForm, this.bySomeForm, nodeCount, index.wordNodes());
        PatternSet.fillMatrix(matrix, this.byLemma, this.bySomeLemma, nodeCount, index.lemmaNodes());
        PatternSet.fillMatrix(matrix, this.byHeadRel, Map.of(), nodeCount, index.headRelNodes());
        PatternSet.fillMatrix(matrix, this.byDepRel, Map.of(), nodeCount, index.depRelNodes());

        if (this.substringTrie != null) {
            int[] offsetToNode = index.startToNodeIndex();
            this.substringTrie.parseText(PatternSet.normalizedText(tree), (begin, end, ruleIndices) -> {
                // Attribute the hit to the node covering the offset where the substring begins.
                int nodeIndex = offsetToNode[begin];
                if (nodeIndex >= 0) {
                    for (int rule = ruleIndices.nextSetBit(0); rule >= 0; rule = ruleIndices.nextSetBit(rule + 1)) {
                        matrix.set(rule * nodeCount + nodeIndex);
                    }
                }
            });
        }

        // Patterns without a usable hint get their whole row enabled.
        for (int patternIndex : this.nonHinted) {
            int rowStart = patternIndex * nodeCount;
            matrix.set(rowStart, rowStart + nodeCount);
        }
        return this.matchByMatrix(nodes, nodeCount, matrix, tree);
    }

    /**
     * Produces the text that substring hints are searched in.
     *
     * <p>For every node whose raw form differs from its form and is at least as long, the node's
     * span in the text is overwritten with the form followed by NUL characters up to the raw
     * form's length. The result is wrapped by {@code TransformingCharSequence.lowerCase}.
     *
     * @param tree the tree whose text is normalized
     * @return the lower-casing view over the (possibly patched) text
     */
    private static CharSequence normalizedText(Tree tree) {
        String original = tree.text();
        StringBuilder patched = null; // created lazily, only if some node needs patching
        for (Node node : tree.nodes()) {
            String form = node.form();
            String rawForm = node.rawForm();
            boolean needsPatch = !rawForm.equals(form) && rawForm.length() >= form.length();
            if (!needsPatch) {
                continue;
            }
            if (patched == null) {
                patched = new StringBuilder(original);
            }
            String padding = "\u0000".repeat(rawForm.length() - form.length());
            patched.replace(node.startOffset(), node.endOffset(), form + padding);
        }
        CharSequence source = patched == null ? original : patched;
        return TransformingCharSequence.lowerCase(source);
    }

    /**
     * Runs the patterns on the (pattern, node) pairs flagged in the candidate matrix.
     *
     * <p>While a pattern runs, the current {@link AccessedParameters} (if any) carries that
     * pattern's rule id; the previous id is restored afterwards, also on failure.
     *
     * @param nodes     the nodes of the tree
     * @param nodeCount the row width of the matrix
     * @param matrix    candidate bits; bit {@code p * nodeCount + n} requests pattern p on node n
     * @param tree      the tree being processed; used only for the error message
     * @return the successful matches in matrix order
     * @throws RuntimeException wrapping anything thrown while matching, naming the rule id and language
     */
    private List<IndexedMatch<NodeMatch>> matchByMatrix(List<Node> nodes, int nodeCount, BitSet matrix, Tree tree) {
        AccessedParameters accessed = AccessedParameters.current();
        String savedRuleId = accessed == null ? null : accessed.ruleId;
        // Declared outside the try so the catch block can name the pattern that failed.
        int patternIndex = 0;
        try {
            List<IndexedMatch<NodeMatch>> found = new ArrayList<>();
            for (int bit = matrix.nextSetBit(0); bit >= 0; bit = matrix.nextSetBit(bit + 1)) {
                patternIndex = bit / nodeCount;
                Node node = nodes.get(bit % nodeCount);
                NodePattern pattern = this.patterns[patternIndex];
                if (accessed != null) {
                    accessed.ruleId = this.ids.get(patternIndex);
                }
                NodeMatch nodeMatch = pattern.matchWithPrevention(node, NodeMatch.EMPTY);
                if (nodeMatch != null) {
                    found.add(new IndexedMatch<>(patternIndex, nodeMatch));
                }
            }
            return found;
        } catch (Throwable t) {
            throw new RuntimeException("While processing " + this.ids.get(patternIndex) + " in " + String.valueOf(tree.language()), t);
        } finally {
            if (accessed != null) {
                accessed.ruleId = savedRuleId;
            }
        }
    }

    /**
     * Flags candidate (pattern, node) pairs in the matrix for one kind of hint.
     *
     * <p>For every key present in {@code index}: patterns registered under that key in
     * {@code nodeMap} are flagged for exactly the nodes listed for the key, and patterns
     * registered in {@code sentenceMap} are flagged for all nodes.
     *
     * @param matrix      candidate bits to update; bit {@code p * nodeCount + n} is pattern p on node n
     * @param nodeMap     key to pattern indices that apply to the nodes carrying the key
     * @param sentenceMap key to pattern indices that apply to every node once the key occurs
     * @param nodeCount   number of nodes, i.e. the row width of the matrix
     * @param index       key to node indices, as found in the tree
     */
    private static void fillMatrix(BitSet matrix, Map<String, BitSet> nodeMap, Map<String, BitSet> sentenceMap, int nodeCount, Map<String, BitSet> index) {
        for (Map.Entry<String, BitSet> entry : index.entrySet()) {
            String key = entry.getKey();
            BitSet nodeIndices = entry.getValue();

            BitSet nodeRules = nodeMap.get(key);
            if (nodeRules != null) {
                for (int rule = nodeRules.nextSetBit(0); rule >= 0; rule = nodeRules.nextSetBit(rule + 1)) {
                    int rowStart = rule * nodeCount;
                    for (int node = nodeIndices.nextSetBit(0); node >= 0; node = nodeIndices.nextSetBit(node + 1)) {
                        matrix.set(rowStart + node);
                    }
                }
            }

            BitSet sentenceRules = sentenceMap.get(key);
            if (sentenceRules != null) {
                for (int rule = sentenceRules.nextSetBit(0); rule >= 0; rule = sentenceRules.nextSetBit(rule + 1)) {
                    int rowStart = rule * nodeCount;
                    matrix.set(rowStart, rowStart + nodeCount);
                }
            }
        }
    }

    /**
     * Computes the lookup tables that hints are checked against for one tree.
     *
     * @param tree the tree to index
     * @return maps from lower-case form, lower-cased lemma readings and head relation to node
     *         indices, a map from relation to the indices of heads having such a dependent,
     *         and a per-offset table of the node covering each text position ({@code -1} if none)
     */
    private static SentenceIndex buildIndex(Tree tree) {
        List<Node> nodes = tree.nodes();
        int nodeCount = nodes.size();
        int relationCapacity = Math.min(20, nodeCount);

        Map<String, BitSet> wordNodes = new HashMap<>(nodeCount);
        Map<String, BitSet> lemmaNodes = new HashMap<>(nodeCount * 2);
        Map<String, BitSet> headRelNodes = new HashMap<>(relationCapacity);
        Map<String, BitSet> depRelNodes = new HashMap<>(relationCapacity);

        int[] offsetToNode = new int[tree.text().length()];
        Arrays.fill(offsetToNode, -1);

        Function<String, BitSet> freshSet = key -> new BitSet(nodeCount);
        int nodeIndex = 0;
        for (Node node : nodes) {
            // Every offset inside the node's span points back at the node.
            Arrays.fill(offsetToNode, node.startOffset(), node.endOffset(), nodeIndex);

            wordNodes.computeIfAbsent(node.lowForm(), freshSet).set(nodeIndex);
            headRelNodes.computeIfAbsent(node.headRelation(), freshSet).set(nodeIndex);

            // The same relation, seen from the head's side: record the head as having such a dependent.
            int headIndex = node.headIndex;
            if (headIndex >= 0) {
                depRelNodes.computeIfAbsent(node.headRelation(), freshSet).set(headIndex);
            }

            for (String lemma : node.lemmaReadings()) {
                String key = lemma.toLowerCase(Locale.ROOT);
                lemmaNodes.computeIfAbsent(key, freshSet).set(nodeIndex);
            }
            nodeIndex++;
        }
        return new SentenceIndex(wordNodes, lemmaNodes, headRelNodes, depRelNodes, offsetToNode);
    }

    /**
     * Lookup tables computed once per tree by {@code buildIndex} and consulted when deciding
     * which patterns to try on which nodes.
     *
     * @param wordNodes        lower-case form to the indices of nodes having it
     * @param lemmaNodes       lower-cased lemma reading to the indices of nodes having it
     * @param headRelNodes     head relation to the indices of nodes attached by that relation
     * @param depRelNodes      relation to the indices of head nodes having a dependent attached by it
     * @param startToNodeIndex for each text offset, the index of the node whose span covers it,
     *                         or {@code -1} if no node does
     */
    private record SentenceIndex(Map<String, BitSet> wordNodes, Map<String, BitSet> lemmaNodes, Map<String, BitSet> headRelNodes, Map<String, BitSet> depRelNodes, int[] startToNodeIndex) {
    }

    /**
     * A match together with the index of the pattern that produced it.
     *
     * @param index position of the pattern in the list the {@code PatternSet} was created from
     * @param match the match produced by that pattern
     * @param <T>   the type of the match
     */
    public record IndexedMatch<T>(int index, T match) {
    }
}

