// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.compiler.notNullVerification;

import com.intellij.compiler.instrumentation.FailSafeMethodVisitor;
import org.jetbrains.org.objectweb.asm.AnnotationVisitor;
import org.jetbrains.org.objectweb.asm.ClassReader;
import org.jetbrains.org.objectweb.asm.ClassVisitor;
import org.jetbrains.org.objectweb.asm.Handle;
import org.jetbrains.org.objectweb.asm.Label;
import org.jetbrains.org.objectweb.asm.MethodVisitor;
import org.jetbrains.org.objectweb.asm.Type;
import org.jetbrains.org.objectweb.asm.TypePath;
import org.jetbrains.org.objectweb.asm.TypeReference;

import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

import static org.jetbrains.org.objectweb.asm.Opcodes.ACC_BRIDGE;
import static org.jetbrains.org.objectweb.asm.Opcodes.ACC_ENUM;
import static org.jetbrains.org.objectweb.asm.Opcodes.ACC_FINAL;
import static org.jetbrains.org.objectweb.asm.Opcodes.ACC_PRIVATE;
import static org.jetbrains.org.objectweb.asm.Opcodes.ACC_STATIC;
import static org.jetbrains.org.objectweb.asm.Opcodes.ALOAD;
import static org.jetbrains.org.objectweb.asm.Opcodes.ANEWARRAY;
import static org.jetbrains.org.objectweb.asm.Opcodes.API_VERSION;
import static org.jetbrains.org.objectweb.asm.Opcodes.ARETURN;
import static org.jetbrains.org.objectweb.asm.Opcodes.CHECKCAST;
import static org.jetbrains.org.objectweb.asm.Opcodes.DUP;
import static org.jetbrains.org.objectweb.asm.Opcodes.DUP2;
import static org.jetbrains.org.objectweb.asm.Opcodes.DUP2_X1;
import static org.jetbrains.org.objectweb.asm.Opcodes.DUP2_X2;
import static org.jetbrains.org.objectweb.asm.Opcodes.DUP_X1;
import static org.jetbrains.org.objectweb.asm.Opcodes.DUP_X2;
import static org.jetbrains.org.objectweb.asm.Opcodes.GOTO;
import static org.jetbrains.org.objectweb.asm.Opcodes.IFNONNULL;
import static org.jetbrains.org.objectweb.asm.Opcodes.IINC;
import static org.jetbrains.org.objectweb.asm.Opcodes.INVOKEDYNAMIC;
import static org.jetbrains.org.objectweb.asm.Opcodes.INVOKESPECIAL;
import static org.jetbrains.org.objectweb.asm.Opcodes.INVOKESTATIC;
import static org.jetbrains.org.objectweb.asm.Opcodes.INVOKEVIRTUAL;
import static org.jetbrains.org.objectweb.asm.Opcodes.JSR;
import static org.jetbrains.org.objectweb.asm.Opcodes.LDC;
import static org.jetbrains.org.objectweb.asm.Opcodes.LOOKUPSWITCH;
import static org.jetbrains.org.objectweb.asm.Opcodes.MULTIANEWARRAY;
import static org.jetbrains.org.objectweb.asm.Opcodes.NEW;
import static org.jetbrains.org.objectweb.asm.Opcodes.NEWARRAY;
import static org.jetbrains.org.objectweb.asm.Opcodes.NOP;
import static org.jetbrains.org.objectweb.asm.Opcodes.RET;
import static org.jetbrains.org.objectweb.asm.Opcodes.TABLESWITCH;

public final class NotNullVerifyingInstrumenter extends ClassVisitor {
  // Internal names of the exception classes assigned by default:
  // IllegalArgumentException for annotated parameters, IllegalStateException for annotated return values.
  private static final String IAE_CLASS_NAME = "java/lang/IllegalArgumentException";
  private static final String ISE_CLASS_NAME = "java/lang/IllegalStateException";
  // Descriptor of the class-level annotation that marks a class as Kotlin-generated; such classes are not instrumented.
  private static final String KOTLIN_METADATA_ANNOTATION_CLASS_DESCRIPTOR = "Lkotlin/Metadata;";

  // Name of the annotation attribute from which a custom failure message is read.
  private static final String ANNOTATION_DEFAULT_METHOD = "value";

  // Shared empty argument list, used when a custom message is reported without format arguments.
  @SuppressWarnings("SSBasedInspection")
  private static final String[] EMPTY_STRING_ARRAY = new String[0];

  // Per-method nullability data gathered in a preliminary pass over the class (see collectMethodData).
  private final MethodData myMethodData;
  // Becomes true as soon as at least one null check has been generated.
  private boolean myIsModification = false;
  // The first error registered via registerError(); rethrown by processPostponedErrors().
  private RuntimeException myPostponedError;
  // Emits the error-reporting calls and, in visitEnd(), the auxiliary reporting method.
  private final AuxiliaryMethodGenerator myAuxGenerator;

  /**
   * Creates an instrumenter that forwards the visited class to {@code classVisitor}.
   * The fully qualified annotation names are converted to type descriptors and used for a
   * preliminary pass over {@code reader} that collects per-method nullability data.
   *
   * @param classVisitor       the visitor receiving the (possibly instrumented) class
   * @param reader             the reader of the class being processed
   * @param notNullAnnotations fully qualified names of annotations to be treated as "not null"
   */
  private NotNullVerifyingInstrumenter(ClassVisitor classVisitor, ClassReader reader, String[] notNullAnnotations) {
    super(API_VERSION, classVisitor);
    final Set<String> descriptors = new HashSet<>();
    for (int i = 0; i < notNullAnnotations.length; i++) {
      // "pk.Name" -> "Lpk/Name;"
      final String internalName = notNullAnnotations[i].replace('.', '/');
      descriptors.add('L' + internalName + ';');
    }
    myMethodData = collectMethodData(reader, descriptors);
    myAuxGenerator = new AuxiliaryMethodGenerator(reader);
  }

  /**
   * Instruments the class provided by {@code reader} and passes the result to {@code writer}.
   *
   * @param reader             the reader of the class to process
   * @param writer             the visitor receiving the processed class
   * @param notNullAnnotations fully qualified names of annotations to be treated as "not null"
   * @return {@code true} if at least one null check was generated; {@code false} otherwise,
   *         including the case of Kotlin bytecode, which is not visited at all
   */
  public static boolean processClassFile(ClassReader reader, ClassVisitor writer, String[] notNullAnnotations) {
    final NotNullVerifyingInstrumenter visitor = new NotNullVerifyingInstrumenter(writer, reader, notNullAnnotations);
    final boolean kotlinBytecode = visitor.myMethodData.myIsKotlinBytecode;
    if (!kotlinBytecode) {
      reader.accept(visitor, 0);
      return visitor.myIsModification;
    }
    // skip Kotlin-generated bytecode, as nullability assertions are handled on compiler level by kotlinc
    return false;
  }

  /**
   * Nullability data collected for a single method: the state of its return value,
   * names and states of its parameters (both keyed by parameter index), and flags computed from its access modifiers.
   */
  private static class MethodInfo {
    final NotNullState nullability = new NotNullState();
    final Map<Integer, String> paramNames = new HashMap<>();
    final Map<Integer, NotNullState> paramNullability = new LinkedHashMap<>();
    boolean isStable;
    int paramAnnotationOffset;

    /**
     * Returns the nullability state registered for the parameter with the given index,
     * registering a fresh one first if there is none yet.
     */
    NotNullState obtainParameterNullability(int index) {
      final NotNullState existing = paramNullability.get(index);
      if (existing != null) {
        return existing;
      }
      final NotNullState created = new NotNullState();
      paramNullability.put(index, created);
      return created;
    }
  }

  /**
   * Result of the preliminary pass over a class: the class name, the Kotlin marker
   * and the collected {@link MethodInfo} objects keyed by method name + descriptor.
   */
  private static final class MethodData {
    private String myClassName;
    private boolean myIsKotlinBytecode;
    private final Map<String, MethodInfo> myMethodInfos = new HashMap<>();

    /** Builds the lookup key for a method: its name immediately followed by its descriptor. */
    static String key(String methodName, String desc) {
      return methodName + desc;
    }

    /**
     * Returns the recorded name of the parameter with index {@code num} of the given method,
     * or {@code null} if the method or the name is unknown.
     */
    String lookupParamName(String methodName, String desc, Integer num) {
      final MethodInfo found = myMethodInfos.get(key(methodName, desc));
      if (found == null) {
        return null;
      }
      final Map<Integer, String> knownNames = found.paramNames;
      if (knownNames == null) {
        return null;
      }
      return knownNames.get(num);
    }

    /**
     * Tells whether the given method belongs to the class being processed, is marked as stable
     * and has a return value that is considered not-null.
     */
    boolean isAlwaysNotNull(String className, String methodName, String desc) {
      if (!myClassName.equals(className)) {
        return false;
      }
      final MethodInfo found = myMethodInfos.get(key(methodName, desc));
      if (found == null || !found.isStable) {
        return false;
      }
      return found.nullability.isNotNull();
    }
  }

  /**
   * Performs a preliminary pass over the class and collects, for every method, the data needed for instrumentation:
   * not-null state of the return value and of the parameters, parameter names taken from the local variable table,
   * and whether the class carries the Kotlin metadata annotation.
   *
   * @param reader             the reader of the class to analyze
   * @param notNullAnnotations descriptors (e.g. {@code "Lpk/NotNull;"}) of annotations to be treated as "not null"
   * @return the collected data
   */
  private static MethodData collectMethodData(ClassReader reader, final Set<String> notNullAnnotations) {
    final MethodData result = new MethodData();
    reader.accept(new ClassVisitor(API_VERSION) {
      // myEnum: the class has the ACC_ENUM flag; myInner: the class is listed in its own InnerClasses data as non-static
      private boolean myEnum, myInner;

      @Override
      public AnnotationVisitor visitAnnotation(String desc, boolean visible) {
        if (KOTLIN_METADATA_ANNOTATION_CLASS_DESCRIPTOR.equals(desc)) {
          result.myIsKotlinBytecode = true;
        }
        return super.visitAnnotation(desc, visible);
      }

      @Override
      public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
        super.visit(version, access, name, signature, superName, interfaces);
        result.myClassName = name;
        myEnum = (access & ACC_ENUM) != 0;
      }

      @Override
      public void visitInnerClass(String name, String outerName, String innerName, int access) {
        super.visitInnerClass(name, outerName, innerName, access);
        // only the entry describing the class itself is of interest
        if (result.myClassName.equals(name)) {
          myInner = (access & ACC_STATIC) == 0;
        }
      }

      @Override
      public MethodVisitor visitMethod(int access, final String name, final String desc, String signature, String[] exceptions) {
        final Type[] args = Type.getArgumentTypes(desc);
        final boolean methodCanHaveNullability = isReferenceType(Type.getReturnType(desc));

        final Map<Integer, Integer> paramSlots = new LinkedHashMap<>(); // map: localVariableSlot -> methodParameterIndex
        int slotIndex = isStatic(access) ? 0 : 1; // slot 0 is occupied by 'this' for instance methods
        for (int paramIndex = 0; paramIndex < args.length; paramIndex++) {
          Type arg = args[paramIndex];
          paramSlots.put(slotIndex, paramIndex);
          slotIndex += arg.getSize();
        }

        final MethodInfo methodInfo = new MethodInfo();
        methodInfo.isStable = (access & (ACC_FINAL | ACC_STATIC | ACC_PRIVATE)) != 0;
        // constructors of enums and of non-static inner classes get an index shift of 2 and 1 respectively,
        // which is added to annotation parameter indices to obtain descriptor parameter indices
        methodInfo.paramAnnotationOffset = !"<init>".equals(name) ? 0 : myEnum ? 2 : myInner ? 1 : 0;
        result.myMethodInfos.put(MethodData.key(name, desc), methodInfo);

        return new MethodVisitor(api) {
          private int myParamAnnotationOffset = methodInfo.paramAnnotationOffset;

          @Override
          public void visitAnnotableParameterCount(int parameterCount, boolean visible) {
            // if annotations are recorded for every descriptor parameter, no index shift is needed
            if (myParamAnnotationOffset != 0 && parameterCount == args.length) {
              myParamAnnotationOffset = 0;
            }
            super.visitAnnotableParameterCount(parameterCount, visible);
          }

          @Override
          public AnnotationVisitor visitParameterAnnotation(int parameter, String anno, boolean visible) {
            AnnotationVisitor base = super.visitParameterAnnotation(parameter, anno, visible);
            return checkParameterNullability(parameter + myParamAnnotationOffset, anno, base, false);
          }

          @Override
          public AnnotationVisitor visitAnnotation(String anno, boolean isRuntime) {
            AnnotationVisitor base = super.visitAnnotation(anno, isRuntime);
            if (methodCanHaveNullability && notNullAnnotations.contains(anno)) {
              return collectNotNullArgs(base, methodInfo.nullability.withNotNull(anno, ISE_CLASS_NAME));
            }
            return base;
          }

          @Override
          public AnnotationVisitor visitTypeAnnotation(int typeRef, TypePath typePath, String anno, boolean visible) {
            AnnotationVisitor base = super.visitTypeAnnotation(typeRef, typePath, anno, visible);
            // annotations with a type path (array elements, type arguments, ...) are not considered
            if (typePath != null) return base;

            TypeReference ref = new TypeReference(typeRef);
            if (methodCanHaveNullability && ref.getSort() == TypeReference.METHOD_RETURN) {
              if (notNullAnnotations.contains(anno)) {
                return collectNotNullArgs(base, methodInfo.nullability.withNotNull(anno, ISE_CLASS_NAME));
              }
              else if (seemsNullable(anno)) {
                methodInfo.nullability.hasTypeUseNullable = true;
              }
            }
            else if (ref.getSort() == TypeReference.METHOD_FORMAL_PARAMETER) {
              return checkParameterNullability(ref.getFormalParameterIndex() + methodInfo.paramAnnotationOffset, anno, base, true);
            }

            return base;
          }

          private boolean seemsNullable(String anno) {
            String shortName = getAnnoShortName(anno);
            // use hardcoded short names until it causes trouble
            // this is to avoid cumbersome passing of configured nullable names from the IDE
            return shortName.contains("Nullable") || shortName.equals("CheckForNull");
          }

          /**
           * Wraps {@code base} so that the attributes "value" (custom message) and "exception" (custom exception class)
           * of a not-null annotation are stored into {@code state}.
           */
          private AnnotationVisitor collectNotNullArgs(AnnotationVisitor base, final NotNullState state) {
            return new AnnotationVisitor(API_VERSION, base) {
              @Override
              public void visit(String methodName, Object o) {
                if (ANNOTATION_DEFAULT_METHOD.equals(methodName)) {
                  // a configured annotation is not guaranteed to declare 'value' as a String:
                  // ignore values of other types instead of failing with ClassCastException
                  if (o instanceof String && !((String)o).isEmpty()) {
                    state.message = (String)o;
                  }
                }
                else if ("exception".equals(methodName) && o instanceof Type && !((Type)o).getClassName().equals(Exception.class.getName())) {
                  state.exceptionType = ((Type)o).getInternalName();
                }
                super.visit(methodName, o);
              }
            };
          }

          /**
           * Records nullability information for the parameter with the given descriptor index,
           * provided the index is valid and the parameter has a reference type.
           *
           * @param typeUse {@code true} if the annotation was reported as a type annotation
           */
          private AnnotationVisitor checkParameterNullability(int parameter, String anno, AnnotationVisitor av, boolean typeUse) {
            if (parameter >= 0 && parameter < args.length && isReferenceType(args[parameter])) {
              if (notNullAnnotations.contains(anno)) {
                return collectNotNullArgs(av, methodInfo.obtainParameterNullability(parameter).withNotNull(anno, IAE_CLASS_NAME));
              }
              else if (typeUse && seemsNullable(anno)) {
                methodInfo.obtainParameterNullability(parameter).hasTypeUseNullable = true;
              }
            }

            return av;
          }

          @Override
          public void visitLocalVariable(String name2, String desc, String signature, Label start, Label end, int slotIndex) {
            // remember the names of those local variables that occupy parameter slots
            Integer paramIndex = paramSlots.get(slotIndex);
            if (paramIndex != null) {
              methodInfo.paramNames.put(paramIndex, name2);
            }
          }
        };
      }
    }, ClassReader.SKIP_FRAMES);
    return result;
  }

  /**
   * Nullability information for a method return value or a single parameter:
   * the not-null annotation found (if any), an optional custom message and the exception class to report with.
   */
  private static class NotNullState {
    String message;
    String exceptionType;
    String notNullAnno;
    boolean hasTypeUseNullable;

    /** Stores the descriptor of the not-null annotation and the exception class name; returns this object. */
    NotNullState withNotNull(String notNullAnno, String exceptionType) {
      this.exceptionType = exceptionType;
      this.notNullAnno = notNullAnno;
      return this;
    }

    /** A not-null annotation has been recorded and no type-use "nullable" annotation contradicts it. */
    boolean isNotNull() {
      if (hasTypeUseNullable) {
        return false;
      }
      return notNullAnno != null;
    }

    /**
     * Returns the custom message if one is set, otherwise a format pattern for a null parameter.
     * The pattern wording depends on whether the parameter name is known.
     */
    String getNullParamMessage(String paramName) {
      if (message != null) {
        return message;
      }
      final String annotation = "@" + getAnnoShortName(notNullAnno);
      return paramName == null
             ? "Argument %s for " + annotation + " parameter of %s.%s must not be null"
             : "Argument for " + annotation + " parameter '%s' of %s.%s must not be null";
    }

    /** Returns the custom message if one is set, otherwise a format pattern for a null return value. */
    String getNullResultMessage() {
      if (message != null) {
        return message;
      }
      final StringBuilder pattern = new StringBuilder("@");
      pattern.append(getAnnoShortName(notNullAnno)).append(" method %s.%s must not return null");
      return pattern.toString();
    }
  }

  /**
   * Extracts the simple name from a type descriptor: the first and the last characters are dropped
   * and everything up to and including the last '/' is removed, e.g. "Lpk/name;" -> "name".
   */
  private static String getAnnoShortName(String anno) {
    final String internalName = anno.substring(1, anno.length() - 1); // "Lpk/name;" -> "pk/name"
    final int lastSlash = internalName.lastIndexOf('/');
    if (lastSlash < 0) {
      return internalName;
    }
    return internalName.substring(lastSlash + 1);
  }

  /**
   * Wraps the visitor of the method being written so that null checks are generated:
   * for not-null parameters at the start of the method code, and for a not-null return value
   * before each ARETURN for which the instruction tracker reports that the value may be null.
   * Bridge methods and methods without collected info get no checks.
   */
  @Override
  public MethodVisitor visitMethod(int access, final String name, final String desc, String signature, String[] exceptions) {
    final MethodInfo info = myMethodData.myMethodInfos.get(MethodData.key(name, desc));
    if ((access & ACC_BRIDGE) != 0 || info == null) {
      return new FailSafeMethodVisitor(API_VERSION, super.visitMethod(access, name, desc, signature, exceptions));
    }

    final boolean isStatic = isStatic(access);
    final Type[] args = Type.getArgumentTypes(desc);
    // every instruction written to the method passes through the tracker, including the generated ones
    final NotNullInstructionTracker instrTracker = new NotNullInstructionTracker(cv.visitMethod(access, name, desc, signature, exceptions));
    return new FailSafeMethodVisitor(API_VERSION, instrTracker) {
      // label placed before the generated parameter checks; stays null if no checks are generated
      private Label myStartGeneratedCodeLabel;

      @Override
      public void visitCode() {
        // keep only those parameters that are effectively not-null
        for (Iterator<NotNullState> iterator = info.paramNullability.values().iterator(); iterator.hasNext(); ) {
          if (!iterator.next().isNotNull()) {
            iterator.remove();
          }
        }
        if (!info.paramNullability.isEmpty()) {
          myStartGeneratedCodeLabel = new Label();
          mv.visitLabel(myStartGeneratedCodeLabel);
        }
        for (Map.Entry<Integer, NotNullState> entry : info.paramNullability.entrySet()) {
          Integer param = entry.getKey();
          // compute the local variable slot of the parameter: 'this' takes slot 0 for instance methods,
          // and each preceding parameter takes getSize() slots
          int var = isStatic ? 0 : 1;
          for (int i = 0; i < param; ++i) {
            var += args[i].getSize();
          }
          mv.visitVarInsn(ALOAD, var);

          // jump over the error reporting code if the argument is not null
          Label end = new Label();
          mv.visitJumpInsn(IFNONNULL, end);

          NotNullState state = entry.getValue();
          String paramName = myMethodData.lookupParamName(name, desc, param);
          String descrPattern = state.getNullParamMessage(paramName);
          // NOTE: this local 'args' (format arguments) shadows the captured 'args' (parameter types) from here on;
          // a custom message is reported without format arguments
          String[] args = state.message != null
                          ? EMPTY_STRING_ARRAY
                          : new String[]{paramName != null ? paramName : String.valueOf(param - info.paramAnnotationOffset), myMethodData.myClassName, name};
          reportError(state.exceptionType, end, descrPattern, args);
        }
      }

      @Override
      public void visitLocalVariable(String name, String desc, String signature, Label start, Label end, int index) {
        // make the ranges of 'this' and of parameters start at the generated code, so that they cover the generated checks
        // NOTE(review): 'index' is a local variable slot while args.length is a parameter count;
        // the comparison looks imprecise for methods with long/double parameters - confirm
        boolean isParameterOrThisRef = isStatic ? index < args.length : index <= args.length;
        Label label = (isParameterOrThisRef && myStartGeneratedCodeLabel != null) ? myStartGeneratedCodeLabel : start;
        mv.visitLocalVariable(name, desc, signature, label, end, index);
      }

      @Override
      public void visitInsn(int opcode) {
        if (opcode == ARETURN && instrTracker.canBeNull() && info.nullability.isNotNull()) {
          // duplicate the value to be returned, so that it is still on the stack after the check
          mv.visitInsn(DUP);
          Label skipLabel = new Label();
          mv.visitJumpInsn(IFNONNULL, skipLabel);
          String descrPattern = info.nullability.getNullResultMessage();
          String[] args = info.nullability.message != null ? EMPTY_STRING_ARRAY : new String[]{myMethodData.myClassName, name};
          reportError(info.nullability.exceptionType, skipLabel, descrPattern, args);
        }

        mv.visitInsn(opcode);
      }

      // Emits the error reporting code followed by the 'end' label (the target of the preceding IFNONNULL jump),
      // marks the class as modified and rethrows an error registered earlier, if any.
      private void reportError(String exceptionClass, Label end, String descrPattern, String[] args) {
        myAuxGenerator.reportError(mv, myMethodData.myClassName, exceptionClass, descrPattern, args);
        mv.visitLabel(end);
        myIsModification = true;
        processPostponedErrors();
      }

      @Override
      @SuppressWarnings("SpellCheckingInspection")
      public void visitMaxs(int maxStack, int maxLocals) {
        try {
          super.visitMaxs(maxStack, maxLocals);
        }
        catch (Throwable e) {
          // not rethrown here directly: registerError() decides whether to throw immediately or to postpone
          registerError(name, "visitMaxs", e);
        }
      }
    };
  }

  /**
   * Asks the auxiliary generator to emit its reporting method into the target class visitor
   * and then completes the visit of the class.
   */
  @Override
  public void visitEnd() {
    final ClassVisitor target = cv;
    myAuxGenerator.generateReportingMethod(target);
    super.visitEnd();
  }

  /** Tells whether the ACC_STATIC bit is set in the given access flags. */
  private static boolean isStatic(int access) {
    final int staticBit = access & ACC_STATIC;
    return staticBit != 0;
  }

  /** Tells whether the given type is an object or an array type, i.e. whether a value of this type can be null. */
  private static boolean isReferenceType(Type type) {
    final int sort = type.getSort();
    if (sort == Type.OBJECT) {
      return true;
    }
    return sort == Type.ARRAY;
  }

  /**
   * Registers a failure that happened while processing a method. Only the first registered failure is kept.
   * If the class has already been modified, the kept error is thrown immediately; otherwise it stays postponed
   * until {@link #processPostponedErrors()} is called.
   *
   * @param methodName    name of the method being processed when the failure happened
   * @param operationName name of the failed visitor operation, used in the error message
   * @param t             the failure; if it has a cause, the cause is reported instead
   */
  private void registerError(String methodName, @SuppressWarnings("SameParameterValue") String operationName, Throwable t) {
    if (myPostponedError == null) {
      // throw the first error that occurred
      Throwable cause = t.getCause();
      // report the underlying cause when there is one, otherwise the throwable itself
      Throwable reported = cause != null ? cause : t;

      String message = reported.getMessage();

      StringWriter writer = new StringWriter();
      reported.printStackTrace(new PrintWriter(writer));

      StringBuilder text = new StringBuilder();
      text.append("Operation '").append(operationName).append("' failed for ").append(myMethodData.myClassName).append(".").append(methodName).append("(): ");
      if (message != null) text.append(message);
      text.append('\n').append(writer.getBuffer());
      // always chain the reported throwable: passing t.getCause() here would drop it whenever 't' has no cause
      myPostponedError = new RuntimeException(text.toString(), reported);
    }
    if (myIsModification) {
      processPostponedErrors();
    }
  }

  /** Throws the postponed error if one has been registered; does nothing otherwise. */
  private void processPostponedErrors() {
    if (myPostponedError == null) {
      return;
    }
    throw myPostponedError;
  }

  /**
   * Pass-through method visitor that keeps a flag telling whether the value on top of the stack may be null,
   * judging by the most recently visited instruction. The flag is consulted before ARETURN to skip
   * return value checks that are obviously redundant.
   * <p>
   * The analysis is purely linear, so the flag is reset to "may be null" at every label:
   * a label can be a jump target where several control-flow paths merge, and the instruction
   * textually preceding it describes only one of those paths.
   */
  private final class NotNullInstructionTracker extends MethodVisitor {
    private boolean myCanBeNull = true; // initially assume the value can be null

    NotNullInstructionTracker(MethodVisitor delegate) {
      super(API_VERSION, delegate);
    }

    /** @return {@code false} only if the last visited instructions guarantee a non-null value on top of the stack */
    public boolean canBeNull() {
      return myCanBeNull;
    }

    @Override
    public void visitLabel(Label label) {
      // e.g. for 'cond ? null : "x"' the code is: ifeq L1; aconst_null; goto L2; L1: ldc "x"; L2: areturn
      // without the reset the LDC would make the value at L2 look non-null, although null arrives via the GOTO
      myCanBeNull = true;
      super.visitLabel(label);
    }

    @Override
    public void visitIntInsn(int opcode, int operand) {
      myCanBeNull = nextCanBeNullValue(opcode);
      super.visitIntInsn(opcode, operand);
    }

    @Override
    public void visitVarInsn(int opcode, int var) {
      myCanBeNull = nextCanBeNullValue(opcode);
      super.visitVarInsn(opcode, var);
    }

    @Override
    public void visitTypeInsn(int opcode, String type) {
      myCanBeNull = nextCanBeNullValue(opcode);
      super.visitTypeInsn(opcode, type);
    }

    @Override
    public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
      myCanBeNull = nextCanBeNullValue(opcode);
      super.visitFieldInsn(opcode, owner, name, descriptor);
    }

    @Override
    public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
      // method calls are analyzed separately: constructors and stable not-null methods of this class yield non-null
      myCanBeNull = nextCanBeNullValue(opcode, owner, name, descriptor);
      super.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
    }

    @Override
    public void visitInvokeDynamicInsn(String name, String descriptor, Handle bootstrapMethodHandle, Object... bootstrapMethodArguments) {
      myCanBeNull = nextCanBeNullValue(INVOKEDYNAMIC);
      super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
    }

    @Override
    public void visitJumpInsn(int opcode, Label label) {
      myCanBeNull = nextCanBeNullValue(opcode);
      super.visitJumpInsn(opcode, label);
    }

    @Override
    public void visitLdcInsn(Object value) {
      myCanBeNull = nextCanBeNullValue(LDC);
      super.visitLdcInsn(value);
    }

    @Override
    public void visitIincInsn(int var, int increment) {
      myCanBeNull = nextCanBeNullValue(IINC);
      super.visitIincInsn(var, increment);
    }

    @Override
    public void visitTableSwitchInsn(int min, int max, Label defaultLabel, Label... labels) {
      myCanBeNull = nextCanBeNullValue(TABLESWITCH);
      super.visitTableSwitchInsn(min, max, defaultLabel, labels);
    }

    @Override
    public void visitLookupSwitchInsn(Label defaultLabel, int[] keys, Label[] labels) {
      myCanBeNull = nextCanBeNullValue(LOOKUPSWITCH);
      super.visitLookupSwitchInsn(defaultLabel, keys, labels);
    }

    @Override
    public void visitMultiANewArrayInsn(String descriptor, int numDimensions) {
      myCanBeNull = nextCanBeNullValue(MULTIANEWARRAY);
      super.visitMultiANewArrayInsn(descriptor, numDimensions);
    }

    @Override
    public void visitInsn(int opcode) {
      myCanBeNull = nextCanBeNullValue(opcode);
      super.visitInsn(opcode);
    }

    /**
     * Computes the new flag value for a method call instruction.
     *
     * @return {@code false} for a constructor call and for a call of a method of the processed class
     *         that is stable and not-null (only INVOKESPECIAL, INVOKESTATIC and INVOKEVIRTUAL are considered);
     *         {@code true} otherwise
     */
    private boolean nextCanBeNullValue(int nextMethodCallOpcode, String owner, String name, String descriptor) {
      if (nextMethodCallOpcode == INVOKESPECIAL && ("<init>".equals(name) || myMethodData.isAlwaysNotNull(owner, name, descriptor))) {
        // a constructor call or a NotNull marked own method
        return false;
      }
      if ((nextMethodCallOpcode == INVOKESTATIC || nextMethodCallOpcode == INVOKEVIRTUAL) &&
          myMethodData.isAlwaysNotNull(owner, name, descriptor)) {
        return false;
      }
      return true;
    }

    /** Computes the new flag value for an instruction other than a method call. */
    private boolean nextCanBeNullValue(int nextOpcode) {
      // if instruction guaranteed produces non-null stack value
      if (nextOpcode == LDC || nextOpcode == NEW || nextOpcode == ANEWARRAY || nextOpcode == NEWARRAY || nextOpcode == MULTIANEWARRAY) {
        return false;
      }
      // for some instructions, it is safe not to change the previously calculated flag value
      if (nextOpcode == DUP || nextOpcode == DUP_X1 || nextOpcode == DUP_X2 ||
          nextOpcode == DUP2 || nextOpcode == DUP2_X1 || nextOpcode == DUP2_X2 ||
          nextOpcode == JSR || nextOpcode == GOTO || nextOpcode == NOP ||
          nextOpcode == RET || nextOpcode == CHECKCAST) {
        return myCanBeNull;
      }
      // by default assume nullable
      return true;
    }
  }
}
