package com.intellij.remoteDev.tests.impl

import com.intellij.codeWithMe.ClientId
import com.intellij.codeWithMe.ClientId.Companion.isLocal
import com.intellij.codeWithMe.ClientIdContextElement
import com.intellij.codeWithMe.clientId
import com.intellij.diagnostic.LoadingState
import com.intellij.diagnostic.dumpCoroutines
import com.intellij.diagnostic.enableCoroutineDump
import com.intellij.ide.plugins.PluginManagerCore
import com.intellij.ide.plugins.PluginModuleDescriptor
import com.intellij.ide.plugins.PluginModuleId
import com.intellij.idea.AppMode
import com.intellij.notification.Notification
import com.intellij.notification.NotificationType
import com.intellij.openapi.application.*
import com.intellij.openapi.application.impl.LaterInvocator
import com.intellij.openapi.client.ClientKind
import com.intellij.openapi.client.ClientSessionsManager
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.project.Project
import com.intellij.openapi.project.ex.ProjectManagerEx
import com.intellij.openapi.rd.util.setSuspend
import com.intellij.openapi.util.SystemInfoRt
import com.intellij.remoteDev.tests.*
import com.intellij.remoteDev.tests.impl.utils.SerializedLambdaLoader
import com.intellij.remoteDev.tests.impl.utils.getArtifactsFileName
import com.intellij.remoteDev.tests.impl.utils.runLogged
import com.intellij.remoteDev.tests.impl.utils.waitSuspendingNotNull
import com.intellij.remoteDev.tests.modelGenerated.LambdaRdIdeType
import com.intellij.remoteDev.tests.modelGenerated.LambdaRdKeyValueEntry
import com.intellij.remoteDev.tests.modelGenerated.lambdaTestModel
import com.intellij.ui.WinFocusStealer
import com.intellij.util.ui.ImageUtil
import com.jetbrains.rd.framework.*
import com.jetbrains.rd.util.lifetime.EternalLifetime
import com.jetbrains.rd.util.reactive.viewNotNull
import kotlinx.coroutines.*
import org.jetbrains.annotations.ApiStatus
import org.jetbrains.annotations.TestOnly
import java.awt.Component
import java.awt.Window
import java.awt.image.BufferedImage
import java.io.File
import java.io.Serializable
import java.net.InetAddress
import java.net.URLClassLoader
import java.time.LocalTime
import javax.imageio.ImageIO
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.reflect.KClass
import kotlin.reflect.full.companionObject
import kotlin.reflect.full.isSubclassOf
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.minutes
import kotlin.time.Duration.Companion.seconds

@TestOnly
@ApiStatus.Internal
open class LambdaTestHost(coroutineScope: CoroutineScope) {
  companion object {
    // A dedicated "Host" logger category makes it easier to sort out logs coming from the test framework
    private val LOG
      get() = Logger.getInstance(RdctTestFrameworkLoggerCategory.category + "Host")

    /**
     * Reads the lambda test protocol port from the system property named by
     * [LambdaTestsConstants.protocolPortPropertyName].
     *
     * @return the port, or `null` when the property is not set or is not a valid integer
     */
    fun getLambdaTestPort(): Int? {
      val rawPort = System.getProperty(LambdaTestsConstants.protocolPortPropertyName)
      return rawPort?.toIntOrNull()
    }

    /**
     * Folder taken from the system property named by [LambdaTestsConstants.sourcePathProperty];
     * falls back to [PathManager.getHomePath] when the property is not set. Computed on first access.
     */
    val sourcesRootFolder: File by lazy {
      File(System.getProperty(LambdaTestsConstants.sourcePathProperty, PathManager.getHomePath()))
    }

    /**
     * ID of the plugin which contains test code.
     * Currently, only test code of the client part is put to a separate plugin.
     */
    const val TEST_MODULE_ID_PROPERTY_NAME: String = "lambda.test.module.id"

    /**
     * Base class of a lambda that is looked up by its qualified class name and executed against [lambdaIdeContext].
     */
    // TODO: plugin: PluginModuleDescriptor might be passed as a context parameter and not via constructor
    abstract class NamedLambda<T : LambdaIdeContext>(protected val lambdaIdeContext: T, protected val plugin: PluginModuleDescriptor) {
      /** Qualified name of the concrete lambda class; fails when the class has none. */
      fun name(): String {
        val qualifiedName = this::class.qualifiedName
        return qualifiedName ?: error("Can't get qualified name of lambda $this")
      }

      /** The body of the lambda; executed with the IDE context as the receiver. */
      abstract suspend fun T.lambda(args: List<LambdaRdKeyValueEntry>): Any?

      /** Runs [lambda] on [lambdaIdeContext]; the value produced by the lambda is discarded. */
      suspend fun runLambda(args: List<LambdaRdKeyValueEntry>) {
        lambdaIdeContext.lambda(args)
      }
    }
  }

  // Resolves the address of the test protocol and, when a port is configured, schedules creation of the protocol
  // for the moment the application components are loaded.
  init {
    val configuredHost = System.getProperty(LambdaTestsConstants.protocolHostPropertyName)
    val hostAddress =
      if (configuredHost == null) {
        InetAddress.getLoopbackAddress()
      }
      else {
        LOG.info("${LambdaTestsConstants.protocolHostPropertyName} system property is set=$configuredHost, will try to get address from it.")
        // Taking the last non-blank entry of /etc/hosts as the docker interface address is not an option:
        // this won't work when we do custom network setups as the default gateway will be overridden.
        // host.docker.internal is not available on linux yet (20.04+)
        InetAddress.getByName(configuredHost)
      }

    val port = getLambdaTestPort()
    if (port != null) {
      LOG.info("Queue creating protocol on $hostAddress:$port")
      coroutineScope.launch {
        // Dumps coroutines once if the components are still not loaded after 20 seconds
        val startupWatchdog = launch {
          delay(20.seconds)
          LOG.warn("LoadingState.COMPONENTS_LOADED has not occurred in 20 seconds: ${dumpCoroutines()}")
        }

        // Poll the loading state until the components are loaded
        while (!LoadingState.COMPONENTS_LOADED.isOccurred) {
          delay(10.milliseconds)
        }
        startupWatchdog.cancel()

        val edtInAnyModality = Dispatchers.EDT + ModalityState.any().asContextElement()
        withContext(edtInAnyModality) {
          createProtocol(hostAddress, port)
        }
      }
    }
  }

  /**
   * Instantiates the [NamedLambda] implementations related to the test class that [lambdaReference] points to.
   *
   * When the reference contains `.Companion`, the test class name is the reference without its last segment and
   * without the `.Companion` suffix; otherwise the reference itself is used as the class name.
   * Candidates are the classes nested in the companion object of the test class, the classes nested in the test
   * class and the test class itself. Every candidate that is a [NamedLambda] subclass is instantiated through its
   * single constructor with ([ideContext], [testModuleDescriptor]).
   * A candidate that can't be instantiated is skipped, the reason is logged.
   *
   * @param lambdaReference qualified name of the requested lambda
   * @param testModuleDescriptor descriptor whose class loader is used to load the test class
   * @param ideContext context passed to the constructor of every lambda
   * @return non-empty list of the instantiated lambdas
   * @throws ClassNotFoundException when the test class can't be loaded
   * @throws IllegalStateException when not a single lambda was instantiated
   */
  private fun findLambdaClasses(lambdaReference: String, testModuleDescriptor: PluginModuleDescriptor, ideContext: LambdaIdeContext): List<NamedLambda<*>> {
    val className = if (lambdaReference.contains(".Companion")) {
      lambdaReference.substringBeforeLast(".").removeSuffix(".Companion")
    }
    else lambdaReference

    val testClass = Class.forName(className, true, testModuleDescriptor.pluginClassLoader).kotlin

    val companionClasses: Collection<KClass<*>> = testClass.companionObject?.nestedClasses ?: listOf()
    val nestedClasses: Collection<KClass<*>> = testClass.nestedClasses

    val namedLambdas = (companionClasses + nestedClasses + testClass)
      .filter { it.isSubclassOf(NamedLambda::class) }
      .mapNotNull { lambdaClass ->
        runCatching {
          lambdaClass.constructors.single().call(ideContext, testModuleDescriptor) as NamedLambda<*> //todo maybe we can filter out constructor in a more clever way
        }.onFailure {
          // Still best effort, but the reason is kept in the log: otherwise a broken constructor
          // only shows up later as a "not found" error without any cause.
          LOG.warn("Can't instantiate named lambda '${lambdaClass.qualifiedName}', it is skipped", it)
        }.getOrNull()
      }

    LOG.info("Found ${namedLambdas.size} lambda classes: ${namedLambdas.joinToString(", ") { it.name() }}")

    check(namedLambdas.isNotEmpty()) { "Can't find any named lambda in the test class '${testClass.qualifiedName}'" }

    return namedLambdas
  }

  /**
   * Creates the lambda test protocol over a client socket wire connected to [hostAddress]:[port] and registers
   * the handlers of every non-null session that appears in the protocol model.
   *
   * Once all handlers of a session are registered, `session.ready` is set to `true`; if the registration throws,
   * the failure is logged and `session.ready` is set to `false`.
   *
   * @param hostAddress address the socket wire connects to
   * @param port port the socket wire connects to
   */
  private fun createProtocol(hostAddress: InetAddress, port: Int) {
    enableCoroutineDump()

    // EternalLifetime.createNested() is used intentionally to make sure logger session's lifetime is not terminated before the actual application stop.
    val lifetime = EternalLifetime.createNested()
    val protocolName = LambdaTestsConstants.protocolName
    LOG.info("Creating protocol '$protocolName' ...")

    val wire = SocketWire.Client(lifetime, LambdaTestIdeScheduler, port, protocolName, hostAddress)
    val protocol = Protocol(name = protocolName,
                            serializers = Serializers(),
                            identity = Identities(IdKind.Client),
                            scheduler = LambdaTestIdeScheduler,
                            wire = wire,
                            lifetime = lifetime)
    val model = protocol.lambdaTestModel

    LOG.info("Advise for session. Current state: ${model.session.value}...")
    model.session.viewNotNull(lifetime) { _, session ->

      try {
        // All session handlers are registered on this dispatcher, its parallelism is limited to 1
        @OptIn(ExperimentalCoroutinesApi::class)
        val sessionBgtDispatcher = Dispatchers.Default.limitedParallelism(1, "Lambda test session dispatcher")

        // Needed to enable proper focus behaviour
        if (SystemInfoRt.isWindows) {
          WinFocusStealer.setFocusStealingEnabled(true)
        }


        // Descriptor of the module with the test code; null when the module id system property is not set
        val testModuleDescriptor = run {
          val testModuleId = System.getProperty(TEST_MODULE_ID_PROPERTY_NAME)
                             ?: return@run null
          val tmd = PluginManagerCore.getPluginSet().findEnabledModule(PluginModuleId(testModuleId, PluginModuleId.JETBRAINS_NAMESPACE))
                    ?: error("Test plugin with test module '$testModuleId' is not found")

          assert(tmd.pluginClassLoader != null) {
            "Test plugin with test module '${testModuleId}' is not loaded." +
            "Probably due to missing dependencies, see `com.intellij.ide.plugins.ClassLoaderConfigurator#configureContentModule`."
          }
          return@run tmd
        }

        LOG.info("All test code will be loaded using '${testModuleDescriptor?.pluginClassLoader}'")

        // Creates a fresh IDE context of the kind matching the IDE type of the session;
        // the context gets the coroutine context of a new scope with its own SupervisorJob.
        fun getLambdaIdeContext(): LambdaIdeContext {
          val currentTestCoroutineScope = CoroutineScope(Dispatchers.Default + CoroutineName("Lambda test session scope") + SupervisorJob())

          currentTestCoroutineScope.coroutineContext.job.invokeOnCompletion {
            LOG.info("Test coroutine scope is completed")
          }
          return when (session.rdIdeType) {
            LambdaRdIdeType.BACKEND -> LambdaBackendContextClass(currentTestCoroutineScope.coroutineContext)
            LambdaRdIdeType.FRONTEND -> LambdaFrontendContextClass(currentTestCoroutineScope.coroutineContext)
            LambdaRdIdeType.MONOLITH -> LambdaMonolithContextClass(currentTestCoroutineScope.coroutineContext)
          }
        }

        // Replaced with a fresh context on every cleanUp request
        var ideContext = getLambdaIdeContext()

        // Cancels the job of the current context, waits for its completion and starts over with a new context
        session.cleanUp.setSuspend(sessionBgtDispatcher) { _, _ ->
          LOG.info("Resetting scopes")
          ideContext.coroutineContext.job.cancelAndJoin()
          ideContext = getLambdaIdeContext()
        }

        // Advise for processing events
        // Runs the named lambda whose name equals the received reference
        session.runLambda.setSuspend(sessionBgtDispatcher) { _, parameters ->
          LOG.info("'${parameters.reference}': received lambda execution request")

          assert(testModuleDescriptor != null) {
            "Test module descriptor is not set, can't find test class '${parameters.reference}'"
          }
          try {
            val lambdaReference = parameters.reference
            val namedLambdas = findLambdaClasses(lambdaReference = lambdaReference, testModuleDescriptor = testModuleDescriptor!!, ideContext = ideContext)

            val ideAction = namedLambdas.singleOrNull { it.name() == lambdaReference } ?: run {
              val text = "There is no Action with reference '${lambdaReference}', something went terribly wrong, " +
                         "all referenced actions: ${namedLambdas.map { it.name() }}"
              LOG.error(text)
              error(text)
            }

            assert(ClientId.current.isLocal) { "ClientId '${ClientId.current}' should be local before test method starts" }
            LOG.info("'$parameters': received action execution request")

            val providedCoroutineContext = Dispatchers.Default + CoroutineName("Lambda task: ${ideAction.name()}")
            val clientId = providedCoroutineContext.clientId() ?: ClientId.current

            withContext(providedCoroutineContext) {
              assert(ClientId.current == clientId) { "ClientId '${ClientId.current}' should equal $clientId one when test method starts" }

              // 1.minutes is presumably the timeout of the logged run - confirm against runLogged
              runLogged(parameters.reference, 1.minutes) {
                ideAction.runLambda(parameters.parameters ?: listOf())
              }
            }
          }
          catch (ex: Throwable) {
            LOG.warn("${session.rdIdeType}: ${parameters.let { "'$it' " }}hasn't finished successfully", ex)
            throw ex
          }
        }

        // Advise for processing events
        // Loads the received serialized lambda with a class loader built from the received class path,
        // runs it against the current IDE context and passes the result to SerializedLambdaLoader.save
        session.runSerializedLambda.setSuspend(sessionBgtDispatcher) { _, serializedLambda ->
          // On a backend running as a remote dev host the lambda runs under the client id of the single remote
          // session (waited for via waitSuspendingNotNull); otherwise no extra context element is added.
          suspend fun clientIdContextToRunLambda() = if (session.rdIdeType == LambdaRdIdeType.BACKEND && AppMode.isRemoteDevHost()) {
            waitSuspendingNotNull("Got remote client id", 10.seconds) {
              ClientSessionsManager.getAppSessions(ClientKind.REMOTE).singleOrNull()?.clientId
            }.let { ClientIdContextElement(it) }
          }
          else {
            EmptyCoroutineContext
          }

          try {
            assert(ClientId.current.isLocal) { "ClientId '${ClientId.current}' should be local before test method starts" }
            LOG.info("'$serializedLambda': received serialized lambda execution request")
            return@setSuspend withContext(Dispatchers.Default + CoroutineName("Lambda task: ${serializedLambda.stepName}") + clientIdContextToRunLambda()) {
              runLogged(serializedLambda.stepName, 10.minutes) {
                val urls = serializedLambda.classPath.map { File(it).toURI().toURL() }
                // NOTE(review): `this` below resolves to the innermost lambda receiver (at least the coroutine scope
                // of withContext), not to LambdaTestHost - confirm that the fallback class loader is the intended one.
                URLClassLoader(urls.toTypedArray(), testModuleDescriptor?.pluginClassLoader ?: this::class.java.classLoader).use { cl ->
                  withContext(ideContext.coroutineContext) {
                    val params: List<Serializable> = serializedLambda.parametersBase64.map { SerializedLambdaLoader().loadObject(it, classLoader = cl) }
                    val result = SerializedLambdaLoader().load<LambdaIdeContext>(serializedLambda.serializedDataBase64, classLoader = cl).accept(ideContext, params)
                    SerializedLambdaLoader().save(serializedLambda.stepName, result)
                  }
                }
              }
            }
          }
          catch (ex: Throwable) {
            LOG.warn("${session.rdIdeType}: '${serializedLambda.stepName}' hasn't finished successfully", ex)
            throw ex
          }
        }

        // Liveness probe: always answers true
        session.isResponding.setSuspend(sessionBgtDispatcher + NonCancellable) { _, _ ->
          LOG.info("Answering for session is responding...")
          true
        }

        // Answers with the names of the currently open projects
        session.projectsNames.setSuspend(sessionBgtDispatcher) { _, _ ->
          ProjectManagerEx.getOpenProjects().map { it.name }.also {
            LOG.info("Projects: ${it.joinToString(", ", "[", "]")}")
          }
        }

        // Polls once per second until the project is either initialised or disposed
        suspend fun waitProjectInitialisedOrDisposed(project: Project) {
          runLogged("Wait project '${project.name}' is initialised or disposed", 10.seconds) {
            while (!(project.isInitialized || project.isDisposed)) {
              delay(1.seconds)
            }
          }
        }

        // Gives modal dialogs up to ten one-second polls to disappear on their own, then forces leaving all modals
        // and polls ten more times; logs an error if the modality state is still not non-modal after that.
        // When throwErrorIfModal is set, an error is also logged before the modals are forced to close.
        suspend fun leaveAllModals(throwErrorIfModal: Boolean) {
          withContext(Dispatchers.EDT + ModalityState.any().asContextElement() + NonCancellable) {
            repeat(10) {
              if (ModalityState.current() == ModalityState.nonModal()) {
                return@withContext
              }
              delay(1.seconds)
            }
            if (throwErrorIfModal) {
              LOG.error("Unexpected modality: " + ModalityState.current())
            }
            LaterInvocator.forceLeaveAllModals("LambdaTestHost - leaveAllModals")
            repeat(10) {
              if (ModalityState.current() == ModalityState.nonModal()) {
                return@withContext
              }
              delay(1.seconds)
            }
            LOG.error("Failed to close modal dialog: " + ModalityState.current())
          }
        }

        // Leaves modal dialogs, waits until every open project is initialised or disposed,
        // then closes and disposes all projects on EDT without asking whether they can be closed
        session.closeAllOpenedProjects.setSuspend(sessionBgtDispatcher) { _, _ ->
          leaveAllModals(throwErrorIfModal = true)

          ProjectManagerEx.getOpenProjects().forEach { waitProjectInitialisedOrDisposed(it) }
          withContext(Dispatchers.EDT + NonCancellable) {
            writeIntentReadAction {
              ProjectManagerEx.getInstanceEx().closeAndDisposeAllProjects(checkCanClose = false)
            }
          }
        }

        session.makeScreenshot.setSuspend(sessionBgtDispatcher) { _, fileName ->
          makeScreenshot(fileName)
        }

        // True only when every open project reports isInitialized (vacuously true when no project is open)
        session.projectsAreInitialised.setSuspend(sessionBgtDispatcher) { _, _ ->
          ProjectManagerEx.getOpenProjects().all { it.isInitialized }
        }

        LOG.info("Test session ready!")
        session.ready.value = true
      }
      catch (ex: Throwable) {
        LOG.warn("Test session initialization hasn't finished successfully", ex)
        session.ready.value = false
      }
    }
  }

  /**
   * Builds the location of a screenshot inside the IDE log directory.
   * The file name is produced by [getArtifactsFileName] from [actionName], [suffix], the "png" extension and [timeStamp].
   */
  private fun screenshotFile(actionName: String, suffix: String, timeStamp: LocalTime): File {
    val pngName = getArtifactsFileName(actionName, suffix, "png", timeStamp)
    val logDirectory = File(PathManager.getLogPath())
    return logDirectory.resolve(pngName)
  }

  /**
   * Prints [component] into an ARGB image of the component's size and writes the image to [screenshotFile] as PNG.
   *
   * The file is written on [Dispatchers.IO] in a non-cancellable context. Failures of writing are logged
   * and never propagated to the caller.
   *
   * @param screenshotFile destination of the PNG image
   * @param component component to take the screenshot of
   */
  private suspend fun makeScreenshotOfComponent(screenshotFile: File, component: Component) {
    runLogged("Making screenshot of ${component}") {
      val img = ImageUtil.createImage(component.width, component.height, BufferedImage.TYPE_INT_ARGB)
      val graphics = img.createGraphics()
      try {
        component.printAll(graphics)
      }
      finally {
        // The graphics context holds system resources, release it as soon as printing is over
        graphics.dispose()
      }
      withContext(Dispatchers.IO + NonCancellable) {
        try {
          // ImageIO.write reports a missing writer via `false` instead of an exception
          if (ImageIO.write(img, "png", screenshotFile)) {
            LOG.info("Screenshot is saved at: $screenshotFile")
          }
          else {
            LOG.warn("Screenshot is not saved, no appropriate image writer is found: $screenshotFile")
          }
        }
        catch (t: Throwable) {
          LOG.warn("Exception while writing screenshot image to file", t)
        }
      }
    }
  }

  /**
   * Takes a screenshot of every showing window that has a non-zero size.
   *
   * Each window gets its own file whose suffix contains the index of the window; the suffix of a focused window
   * additionally ends with `_focusedWindow`. All files of one call share the same timestamp.
   *
   * @param actionName name used in the log messages and passed to [screenshotFile]
   * @return `true` when all windows were processed; `false` in a headless environment or when an exception was thrown
   */
  private suspend fun makeScreenshot(actionName: String): Boolean {
    val headless = ApplicationManager.getApplication().isHeadlessEnvironment
    if (headless) {
      LOG.warn("Can't make screenshot on application in headless mode.")
      return false
    }

    // ModalityState.any(): even if there is a modal window opened
    val edtInAnyModality = Dispatchers.EDT + ModalityState.any().asContextElement() + NonCancellable
    return runLogged("'$actionName': Making screenshot") {
      withContext(edtInAnyModality) {
        val timeStamp = LocalTime.now()

        try {
          val visibleWindows = Window.getWindows().filter { it.height != 0 && it.width != 0 && it.isShowing }
          for ((index, window) in visibleWindows.withIndex()) {
            val suffix = if (window.isFocused) "_${index}_focusedWindow" else "_$index"
            makeScreenshotOfComponent(screenshotFile(actionName, suffix, timeStamp), window)
          }
          true
        }
        catch (e: Throwable) {
          LOG.warn("Test action 'makeScreenshot' hasn't finished successfully", e)
          false
        }
      }
    }
  }
}

/**
 * Shows an informational "Test Framework" notification with the given [text].
 * Nothing is shown in a headless environment or when [text] is null or blank.
 */
@Suppress("HardCodedStringLiteral", "DialogTitleCapitalization")
private fun showNotification(text: String?) {
  val headless = ApplicationManager.getApplication().isHeadlessEnvironment
  if (headless || text.isNullOrBlank()) return

  val notification = Notification("TestFramework", "Test Framework", text, NotificationType.INFORMATION)
  notification.notify(null)
}