// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.codeInspection.dataFlow.memory;

import com.intellij.codeInspection.dataFlow.value.DfaVariableValue;
import com.intellij.codeInspection.dataFlow.value.RelationType;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.AbstractSet;
import java.util.Iterator;
import java.util.List;

/**
 * A set of "distinct" facts between equivalence classes of a {@link DfaMemoryStateImpl}, addressed by eq-class index.
 * <p>
 * Every fact is packed into a single {@code long} (see {@code createPair}):
 * <ul>
 *   <li>an unordered fact ({@code a != b}) is a non-negative value holding the smaller index in the upper 32 bits and
 *   the larger index in the lower 32 bits, so both argument orders map to the same key;</li>
 *   <li>an ordered fact ({@code a < b}) is the negation of {@code (b << 32) + a}.</li>
 * </ul>
 * An ordered fact is stronger than the unordered fact for the same two classes, so the mutators of this class keep at
 * most one of the two forms for any pair of classes.
 */
public final class DistinctPairSet extends AbstractSet<DistinctPairSet.DistinctPair> {
  private final DfaMemoryStateImpl myState;
  private final LongOpenHashSet myData;

  /**
   * Creates an empty set bound to the given memory state.
   *
   * @param state memory state used to resolve classes in {@link #contains(Object)} and to decode stored pairs
   */
  public DistinctPairSet(DfaMemoryStateImpl state) {
    myState = state;
    myData = new LongOpenHashSet();
  }

  /**
   * Creates a set containing a copy of {@code other}'s facts, bound to the given memory state.
   *
   * @param state memory state the new set is bound to
   * @param other set whose encoded facts are copied (the two sets do not share storage)
   */
  public DistinctPairSet(DfaMemoryStateImpl state, DistinctPairSet other) {
    myData = new LongOpenHashSet(other.myData);
    myState = state;
  }

  /**
   * Records that class {@code firstIndex} is less than class {@code secondIndex}. It also records the facts derivable
   * by one transitive step from existing ordered facts:
   * {@code first < x} for every stored {@code second < x}, and {@code y < second} for every stored {@code y < first}.
   * The unordered forms of all added facts are removed.
   *
   * @param firstIndex index of the smaller class
   * @param secondIndex index of the greater class
   * @return {@code false} if the new fact contradicts an existing ordered fact (the set is left unchanged in that case),
   * {@code true} otherwise
   */
  public boolean addOrdered(int firstIndex, int secondIndex) {
    LongOpenHashSet toAdd = new LongOpenHashSet();
    LongOpenHashSet toRemove = new LongOpenHashSet();
    toAdd.add(createPair(firstIndex, secondIndex, true));
    toRemove.add(createPair(firstIndex, secondIndex, false));
    for(DistinctPair pair : this) {
      if (!pair.isOrdered()) continue;
      if (pair.myFirst == secondIndex) {
        // second < x: adding first < second yields first < x; contradiction if x == first or x < first
        if (pair.mySecond == firstIndex || myData.contains(createPair(pair.mySecond, firstIndex, true))) return false;
        toAdd.add(createPair(firstIndex, pair.mySecond, true));
        toRemove.add(createPair(firstIndex, pair.mySecond, false));
      } else if (pair.mySecond == firstIndex) {
        // y < first: adding first < second yields y < second; contradiction if second < y
        if (myData.contains(createPair(secondIndex, pair.myFirst, true))) return false;
        toAdd.add(createPair(pair.myFirst, secondIndex, true));
        toRemove.add(createPair(pair.myFirst, secondIndex, false));
      }
    }
    myData.removeAll(toRemove);
    myData.addAll(toAdd);
    return true;
  }

  /**
   * Records that the two classes are distinct. Nothing is added if an ordered fact between them (in either direction)
   * is already present, because the ordered fact already implies inequality.
   *
   * @param firstIndex index of one class
   * @param secondIndex index of the other class
   */
  public void addUnordered(int firstIndex, int secondIndex) {
    if (!myData.contains(createPair(firstIndex, secondIndex, true)) &&
        !myData.contains(createPair(secondIndex, firstIndex, true))) {
      myData.add(createPair(firstIndex, secondIndex, false));
    }
  }

  /**
   * Removes the fact with the same indices and orderedness as the given pair.
   * Unlike {@link #contains(Object)}, the pair's raw indices are used directly, without resolving its classes through
   * this set's memory state.
   *
   * @return {@code true} if a fact was removed
   */
  @Override
  public boolean remove(Object o) {
    if (o instanceof DistinctPair dp) {
      return myData.remove(createPair(dp.myFirst, dp.mySecond, dp.myOrdered));
    }
    return false;
  }

  /**
   * Checks whether a fact equivalent to the given pair is stored. The pair's classes are mapped to this state's
   * eq-class indices via their first variable. The result is {@code false} if either class is empty or is unknown to
   * this state. The decoded stored pair must also compare equal to {@code o}.
   */
  @Override
  public boolean contains(Object o) {
    if (!(o instanceof DistinctPair dp)) return false;
    EqClass first = dp.getFirst();
    EqClass second = dp.getSecond();
    if (first.isEmpty() || second.isEmpty()) return false;
    DfaVariableValue firstVal = first.getVariable(0);
    DfaVariableValue secondVal = second.getVariable(0);
    int firstIndex = myState.getEqClassIndex(firstVal);
    if (firstIndex == -1) return false;
    int secondIndex = myState.getEqClassIndex(secondVal);
    if (secondIndex == -1) return false;
    long pair = createPair(firstIndex, secondIndex, dp.isOrdered());
    return myData.contains(pair) && decode(pair).equals(dp);
  }

  /**
   * Returns an iterator that decodes every stored fact into a fresh {@link DistinctPair}.
   * {@link Iterator#remove()} is supported and removes the fact from the underlying storage.
   */
  @Override
  public Iterator<DistinctPair> iterator() {
    return new Iterator<>() {
      final LongIterator iterator = myData.iterator();

      @Override
      public boolean hasNext() {
        return iterator.hasNext();
      }

      @Override
      public DistinctPair next() {
        return decode(iterator.nextLong());
      }

      @Override
      public void remove() {
        iterator.remove();
      }
    };
  }

  /**
   * @return number of stored facts; ordered and unordered facts are counted alike
   */
  @Override
  public int size() {
    return myData.size();
  }

  /**
   * Merge c2Index class into c1Index. Every fact about c2Index is rewritten as the same fact about c1Index.
   * The set keeps only one form per class pair, so a stored ordered fact wins over an unordered one.
   *
   * @param c1Index index of resulting class
   * @param c2Index index of class which becomes equivalent to c1Index
   * @return true if merge is successful, false if classes were distinct (directly, or because an ordered chain
   * {@code c1 < x < c2} or {@code c2 < x < c1} is stored). On false the set is left unchanged.
   */
  public boolean unite(int c1Index, int c2Index) {
    LongArrayList c2Pairs = new LongArrayList();
    long[] distincts = myData.toLongArray();
    for (long distinct : distincts) {
      int pc1 = low(distinct);
      int pc2 = high(distinct);
      boolean addedToC1 = false;

      if (pc1 == c1Index || pc2 == c1Index) {
        addedToC1 = true;
        if (distinct < 0) {
          // c1 < pc2 together with pc2 < c2, or pc1 < c1 together with c2 < pc1, makes c1 and c2 unequal
          if (pc1 == c1Index && myData.contains(createPair(pc2, c2Index, true)) ||
              pc2 == c1Index && myData.contains(createPair(c2Index, pc1, true))) {
            return false;
          }
        }
      }

      if (pc1 == c2Index || pc2 == c2Index) {
        // the fact relates c1 and c2 directly
        if (addedToC1) return false;
        c2Pairs.add(distinct);
      }
    }

    for (int i = 0; i < c2Pairs.size(); i++) {
      long c = c2Pairs.getLong(i);
      myData.remove(c);
      if (c >= 0) {
        // "c2 != other" becomes "c1 != other"; addUnordered skips it when c1 and other are already ordered,
        // so the stronger ordered fact is not shadowed by a redundant unordered one.
        addUnordered(c1Index, low(c) == c2Index ? high(c) : low(c));
      }
      else {
        // "c2 < other" becomes "c1 < other"; "other < c2" becomes "other < c1"
        boolean c2IsFirst = low(c) == c2Index;
        int first = c2IsFirst ? c1Index : low(c);
        int second = c2IsFirst ? high(c) : c1Index;
        // the ordered fact supersedes any unordered fact c1 already had with the other class
        myData.remove(createPair(first, second, false));
        myData.add(createPair(first, second, true));
      }
    }
    return true;
  }

  /**
   * Replaces every fact that involves class {@code index} with the same fact (same side, same orderedness) for each
   * of {@code splitIndices}. Facts involving {@code index} are removed from the set.
   *
   * @param index index of the class being split
   * @param splitIndices indices of the classes that inherit the facts of {@code index}
   */
  public void splitClass(int index, int[] splitIndices) {
    LongArrayList toAdd = new LongArrayList();
    for(LongIterator iterator = myData.iterator(); iterator.hasNext(); ) {
      DistinctPair pair = decode(iterator.nextLong());
      if (pair.myFirst == index) {
        for (int splitIndex : splitIndices) {
          toAdd.add(createPair(splitIndex, pair.mySecond, pair.isOrdered()));
        }
        iterator.remove();
      } else if (pair.mySecond == index) {
        for (int splitIndex : splitIndices) {
          toAdd.add(createPair(pair.myFirst, splitIndex, pair.isOrdered()));
        }
        iterator.remove();
      }
    }
    // added after iteration so the new facts are not re-visited by the loop above
    myData.addAll(toAdd);
  }

  /**
   * @return {@code true} if an unordered distinct fact for the two classes is stored; ordered facts are not considered
   */
  public boolean areDistinctUnordered(int c1Index, int c2Index) {
    return myData.contains(createPair(c1Index, c2Index, false));
  }

  /**
   * Returns the stored relation of class {@code c1Index} to class {@code c2Index}.
   *
   * @return {@link RelationType#NE} for an unordered fact, {@link RelationType#LT} if {@code c1 < c2} is stored,
   * {@link RelationType#GT} if {@code c2 < c1} is stored, or {@code null} if nothing is known
   */
  public @Nullable RelationType getRelation(int c1Index, int c2Index) {
    if (areDistinctUnordered(c1Index, c2Index)) {
      return RelationType.NE;
    }
    if (myData.contains(createPair(c1Index, c2Index, true))) {
      return RelationType.LT;
    }
    if (myData.contains(createPair(c2Index, c1Index, true))) {
      return RelationType.GT;
    }
    return null;
  }

  /**
   * Unpacks an encoded fact into a {@link DistinctPair} that refers to this state's eq-class list.
   * For unordered facts the larger index becomes {@code myFirst}.
   */
  private DistinctPair decode(long encoded) {
    boolean ordered = encoded < 0;
    encoded = Math.abs(encoded);
    return new DistinctPair(low(encoded), high(encoded), ordered, myState.myEqClasses);
  }

  /**
   * Replaces the given fact with its unordered form, provided the fact was present in the set.
   */
  public void dropOrder(DistinctPair pair) {
    if (remove(pair)) {
      addUnordered(pair.myFirst, pair.mySecond);
    }
  }

  /**
   * Encodes a fact about two class indices into a {@code long} key.
   * Ordered facts ({@code low < high}) are negative; unordered facts are normalized so that the argument order does
   * not matter. NOTE(review): {@code createPair(0, 0, true)} encodes to 0 and would decode as unordered; a class is
   * presumably never related to itself, so this should not occur.
   */
  private static long createPair(int low, int high, boolean ordered) {
    if (ordered) {
      return -(((long)high << 32) + low);
    }
    return low < high ? ((long)low << 32) + high : ((long)high << 32) + low;
  }

  /** @return the lower 32 bits of the magnitude of an encoded fact */
  private static int low(long l) {
    return (int)(Math.abs(l));
  }

  /** @return the upper 32 bits of the magnitude of an encoded fact */
  private static int high(long l) {
    return (int)((Math.abs(l) & 0xFFFFFFFF00000000L) >> 32);
  }

  /**
   * A decoded fact: the class at index {@code myFirst} is distinct from, or (if ordered) less than, the class at
   * index {@code mySecond}. The classes are looked up in the captured list each time they are requested.
   */
  public static final class DistinctPair {
    private final int myFirst;
    private final int mySecond;
    private final boolean myOrdered;
    private final List<EqClassImpl> myList;

    private DistinctPair(int first, int second, boolean ordered, List<EqClassImpl> list) {
      myFirst = first;
      mySecond = second;
      myOrdered = ordered;
      myList = list;
    }

    /** @return the class at the first index (the smaller class for an ordered fact) */
    public @NotNull EqClass getFirst() {
      return myList.get(myFirst);
    }

    public int getFirstIndex() {
      return myFirst;
    }

    /** @return the class at the second index (the greater class for an ordered fact) */
    public @NotNull EqClass getSecond() {
      return myList.get(mySecond);
    }

    public int getSecondIndex() {
      return mySecond;
    }

    /**
     * Consistency check.
     *
     * @throws IllegalStateException if either referenced slot of the class list is {@code null}
     */
    public void check() {
      if (myList.get(myFirst) == null) {
        throw new IllegalStateException(this + ": EqClass " + myFirst + " is missing");
      }
      if (myList.get(mySecond) == null) {
        throw new IllegalStateException(this + ": EqClass " + mySecond + " is missing");
      }
    }

    /** @return {@code true} for a "less than" fact, {@code false} for a plain "not equal" fact */
    public boolean isOrdered() {
      return myOrdered;
    }

    /**
     * @return the class on the opposite side of the pair from {@code eqClassIndex}, or {@code null} if the pair does
     * not involve that index
     */
    public @Nullable EqClass getOtherClass(int eqClassIndex) {
      if (myFirst == eqClassIndex) {
        return getSecond();
      }
      if (mySecond == eqClassIndex) {
        return getFirst();
      }
      return null;
    }

    /**
     * Pairs are compared by the referenced {@link EqClass} objects, not by indices. Unordered pairs are equal
     * regardless of direction; ordered pairs must match side by side.
     */
    @Override
    public boolean equals(Object obj) {
      if (obj == this) return true;
      if (!(obj instanceof DistinctPair that)) return false;
      if (that.myOrdered != this.myOrdered) return false;
      return that.getFirst().equals(this.getFirst()) && that.getSecond().equals(this.getSecond()) ||
             (!myOrdered && that.getSecond().equals(this.getFirst()) && that.getFirst().equals(this.getSecond()));
    }

    /** Symmetric in the two classes for unordered pairs, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
      return getFirst().hashCode() * (myOrdered ? 31 : 1) + getSecond().hashCode();
    }

    @Override
    public String toString() {
      return "{" + myList.get(myFirst) + (myOrdered ? "<" : "!=") + myList.get(mySecond) + "}";
    }
  }
}
