package xim.resource

import xim.math.Matrix4f
import xim.math.Vector2f
import xim.math.Vector3f
import xim.poc.*
import xim.poc.camera.CameraReference
import xim.poc.gl.ByteColor
import xim.util.interpolate

/**
 * Base class for effects that advance over a fixed number of frames.
 *
 * Each call to `update(elapsedFrames)` adds the elapsed frames to [runTime], clamps it to [duration],
 * recomputes [progress] and then runs the subclass hook `update()`. On the tick where [runTime] reaches
 * [duration], [onComplete] is invoked; later ticks return early without doing anything.
 */
abstract class InterpolatedEffect {

    /** Frames accumulated so far; clamped so it never exceeds [duration]. */
    var runTime = 0f

    /** `runTime / duration()`, or `1.0` when the duration is zero. */
    var progress = 0f

    // Becomes true after the first tick, so an effect that is already complete when it starts
    // (e.g. zero duration) still gets one hook call and one onComplete.
    private var hasTicked = false

    /**
     * Advances this effect by [elapsedFrames] and then runs the per-frame hook.
     * If the effect finishes on this tick, [onComplete] runs as well.
     * Returns immediately if the effect had already completed on an earlier tick.
     */
    fun update(elapsedFrames: Float) {
        val alreadyComplete = isComplete()
        if (alreadyComplete && hasTicked) {
            return
        }

        val totalFrames = duration()
        runTime = (runTime + elapsedFrames).coerceAtMost(totalFrames)
        progress = if (totalFrames == 0f) 1.0f else runTime / totalFrames

        update()
        hasTicked = true

        if (isComplete()) {
            onComplete()
        }
    }

    /** Per-frame hook, called after [runTime] and [progress] have been refreshed. */
    open fun update() {}

    /** Total length of the effect, in the same frame units as `elapsedFrames`. */
    abstract fun duration(): Float

    /** True once [runTime] has reached [duration]. */
    fun isComplete(): Boolean = runTime >= duration()

    /** Called on the tick in which the effect finishes. */
    open fun onComplete() {}

}

/**
 * Starts [skeletonAnimationRoutine] on [actor]'s model at construction time. If the actor has no model,
 * nothing is started.
 *
 * [localDir] is passed ahead of the actor's own animation directories (presumably giving it lookup
 * priority — confirm in `setSkeletonAnimation`). Loop and transition times are halved before being handed
 * to the model, while [duration] reports the routine's unhalved duration.
 */
class SkeletonAnimationInstance(val actor: Actor, val skeletonAnimationRoutine: SkeletonAnimationRoutine, val localDir: DirectoryResource, val modelSlotVisibilityState: ModelSlotVisibilityState): InterpolatedEffect() {

    init {
        val routine = skeletonAnimationRoutine
        val searchDirectories = listOf(localDir) + actor.getAllAnimationDirectories()

        val loopParams = LoopParams(
            loopDuration = routine.duration.toFloat() / 2f,
            numLoops = routine.maxLoops
        )

        val transitionParams = TransitionParams(
            transitionInTime = routine.transitionInTime.toFloat() / 2f,
            transitionOutTime = routine.transitionOutTime.toFloat() / 2f
        )

        actor.actorModel?.setSkeletonAnimation(
            routine.id,
            searchDirectories,
            loopParams,
            transitionParams = transitionParams,
            modelSlotVisibilityState = modelSlotVisibilityState
        )
    }

    override fun duration(): Float = skeletonAnimationRoutine.duration.toFloat()

}

/**
 * Plays a flinch animation on [actor]. It only does so when the model is not animation-locked and every
 * animation slot is either empty or low-priority (idle).
 *
 * The effect lasts half of [FlinchRoutine.animationDuration], and the same half is used as both the
 * transition-in and transition-out time. Once runTime reaches that half-duration, the transition out is
 * requested eagerly.
 */
class FlinchAnimationInstance(val actor: Actor, val flinchRoutine: FlinchRoutine): InterpolatedEffect() {

    // Handed to the model in setModelRoutine() and mutated later in update().
    // This presumably relies on the model keeping this same instance; confirm.
    private val transitionParams = TransitionParams(transitionInTime = halfAnimationDuration(), transitionOutTime = halfAnimationDuration())

    init {
        setModelRoutine()
    }

    override fun update() {
        // runTime is clamped to duration() (the same half-duration), so this only trips once the effect has run its full length.
        if (runTime >= halfAnimationDuration()) {
            transitionParams.eagerTransitionOut = true
        }
    }

    override fun duration(): Float = halfAnimationDuration()

    // Read fresh on every call rather than cached, so it always reflects flinchRoutine's current value.
    private fun halfAnimationDuration(): Float = flinchRoutine.animationDuration / 2f

    private fun setModelRoutine() {
        val model = actor.actorModel ?: return
        if (model.isAnimationLocked()) { return }

        // Flinching should only ever overwrite idle animations: bail out if any slot holds a non-low-priority animation.
        val hasActiveAnimation = model.skeletonAnimationCoordinator.animations.any { slot ->
            slot != null && slot.currentAnimation?.loopParams?.lowPriority != true
        }
        if (hasActiveAnimation) { return }

        val searchDirectories = actor.getAllAnimationDirectories()
        val singlePlay = LoopParams(loopDuration = null, numLoops = 1)

        // TODO - using [dfi?] arbitrarily here. There's also [dbi?], which is used for getting hit from the back.
        // TODO - where is [dfi?] defined for pc models? [dfm?] is probably for extended knock-back effects
        val flinchId = DatId(if (model.model is PcModel) "dfm?" else "dfi?")
        model.setSkeletonAnimation(flinchId, searchDirectories, singlePlay, transitionParams = transitionParams)
    }

}

/**
 * Interpolates one indexed transform of model [modelId] in [area]. It moves from the value present at
 * construction time toward [ModelTransformEffect.finalValue].
 *
 * @param initialValueSupplier extracts the starting vector from the existing transform.
 * @param updater writes the interpolated vector back into the transform on each tick.
 */
class ModelTransformInstance(
    val modelId: DatId,
    val effect: ModelTransformEffect,
    val area: Area,
    val initialValueSupplier: (ModelTransform) -> Vector3f,
    val updater: (ModelTransform, Vector3f) -> Unit) : InterpolatedEffect() {

    // Starting point of the interpolation. It keeps its default-constructed value when the model
    // or the indexed transform is missing.
    private val initialValue = Vector3f()

    init {
        area.getModelTransform(modelId)
            ?.transforms
            ?.get(effect.index)
            ?.let { existing -> initialValue.copyFrom(initialValueSupplier(existing)) }
    }

    override fun update() {
        val interpolated = Vector3f.lerp(initialValue, effect.finalValue, progress)
        area.updateModelTransform(modelId, effect.index) { transform -> updater(transform, interpolated) }
    }

    override fun duration(): Float = effect.duration.toFloat()

}

/**
 * Fades [actor]'s `renderState.effectColor` from its value at construction time toward
 * [ActorFadeRoutine.endColor].
 */
class ActorColorTransform(val effect: ActorFadeRoutine, val actor: Actor) : InterpolatedEffect() {

    // Snapshot of the effect color when this transform was created; every tick interpolates from here.
    private val startColor = actor.renderState.effectColor.copy()

    override fun duration(): Float = effect.duration.toFloat()

    override fun update() {
        val blended = ByteColor.interpolate(startColor, effect.endColor, progress)
        actor.renderState.effectColor = blended
    }

}

/**
 * Wrap-effect state (texture, tint color and UV offset). Other effects in this file reach it through
 * `actor.renderState.wrapEffect`.
 *
 * @property textureLink texture used by the wrap effect; `DatLink(DatId.zero)` is used as the cleared value.
 * @property color tint color of the wrap effect.
 * @property uvTranslation UV translation of the wrap texture.
 */
class ActorWrapEffect(
    var textureLink: DatLink<TextureResource> = DatLink(DatId.zero),
    var color: ByteColor = ByteColor.zero.copy(),
    var uvTranslation: Vector2f = Vector2f()
) {

    // Constructor arguments are kept exactly as given. An init block must not call reset() here:
    // that would silently overwrite any non-default argument. The defaults above already equal reset()'s values.

    /** Restores every field to the same values as the constructor defaults, using fresh instances. */
    fun reset() {
        textureLink = DatLink(DatId.zero)
        color = ByteColor.zero.copy()
        uvTranslation = Vector2f()
    }
}

/**
 * Interpolates [actor]'s wrap-effect color from its value at construction time toward
 * [ActorWrapColor.endValue].
 */
class ActorWrapColorTransform(val effect: ActorWrapColor, val actor: Actor) : InterpolatedEffect() {

    // Wrap color captured at creation; used as the fixed starting point of the interpolation.
    private val startColor = actor.renderState.wrapEffect.color.copy()

    override fun duration(): Float = effect.duration.toFloat()

    override fun update() {
        val blended = ByteColor.interpolate(startColor, effect.endValue, progress)
        actor.renderState.wrapEffect.color = blended
    }
}

/**
 * Interpolates a single component of [actor]'s wrap-effect UV translation toward
 * [ActorWrapUvTranslation.endValue]. The component is selected by [ActorWrapUvTranslation.uv].
 */
class ActorWrapUvTransform(val effect: ActorWrapUvTranslation, val actor: Actor) : InterpolatedEffect() {

    // UV translation captured at creation; only the component selected by effect.uv is read from it.
    private val startTranslation = actor.renderState.wrapEffect.uvTranslation.copy()

    override fun duration(): Float = effect.duration.toFloat()

    override fun update() {
        val component = effect.uv
        val value = startTranslation[component].interpolate(effect.endValue, progress)
        actor.renderState.wrapEffect.uvTranslation[component] = value
    }
}

/**
 * Applies [ActorWrapTexture.textureLink] to [actor]'s wrap effect as soon as it is constructed. When the
 * effect completes, it clears the texture back to `DatLink(DatId.zero)`.
 */
class ActorWrapTextureEffect(val effect: ActorWrapTexture, val actor: Actor): InterpolatedEffect() {

    init {
        actor.renderState.wrapEffect.textureLink = effect.textureLink
    }

    override fun duration(): Float = effect.duration.toFloat()

    override fun onComplete() {
        val cleared = DatLink<TextureResource>(DatId.zero)
        actor.renderState.wrapEffect.textureLink = cleared
    }

}

/**
 * Sets [actor]'s display position offset to `end - start` for the effect's duration and clears it on
 * completion. For the player, it also locks the camera at [start] while the effect runs.
 */
class ActorJumpTransform(val effect: ActorJumpRoutine, val actor: Actor, val start: Vector3f, val end: Vector3f) : InterpolatedEffect() {

    init {
        actor.displayPositionOffset.copyFrom(end - start)
        setPlayerCameraLocked(true)
    }

    override fun duration(): Float = effect.duration.toFloat()

    override fun onComplete() {
        actor.displayPositionOffset.copyFrom(Vector3f.ZERO)
        setPlayerCameraLocked(false)
    }

    // Camera locking only applies when the actor is the player; locking pins the camera at `start`.
    private fun setPlayerCameraLocked(locked: Boolean) {
        if (!actor.isPlayer()) { return }

        val camera = CameraReference.getInstance()
        if (locked) {
            camera.lock(enable = true, position = start)
        } else {
            camera.lock(enable = false)
        }
    }

}

/**
 * Slides [actor]'s display position offset from zero to [end] over the effect's duration. The offset is
 * reset to zero on completion.
 */
class ActorForwardDisplacement(val effect: ForwardDisplacementEffect, val actor: Actor): InterpolatedEffect() {

    /**
     * Target offset, computed once at construction. It is the translation of a matrix rotated about Y by
     * the actor's `displayFacingDir` and then translated by `(displacement, 0, 0)`.
     */
    val end = Matrix4f()
        .rotateYInPlace(actor.displayFacingDir)
        .translateInPlace(Vector3f(effect.displacement, 0f, 0f))
        .getTranslationVector()

    override fun duration(): Float = effect.duration.toFloat()

    override fun update() {
        val offset = Vector3f.lerp(Vector3f.ZERO, end, progress)
        actor.displayPositionOffset.copyFrom(offset)
    }

    override fun onComplete() {
        actor.displayPositionOffset.copyFrom(Vector3f.ZERO)
    }

}

/**
 * Interpolates a point-light multiplier on every active particle of [particleGenerator] toward
 * [PointLightInterpolationEffect.endValue]. [PointLightInterpolationEffect.theta] selects whether the theta
 * multiplier or the range multiplier is changed.
 *
 * Each particle's starting value is recorded the first time that particle is seen. All particles share the
 * effect's [progress].
 */
class PointLightMultiplierModifier(val effect: PointLightInterpolationEffect, val particleGenerator: ParticleGenerator): InterpolatedEffect() {

    // Particle internalId -> multiplier value captured on first sight.
    private val startingValues = HashMap<Long, Float>()

    override fun duration(): Float = effect.duration.toFloat()

    override fun update() {
        particleGenerator.getActiveParticles().forEach { particle ->
            val initial = startingValues.getOrPut(particle.internalId) { getValue(particle) }
            setValue(particle, initial.interpolate(effect.endValue, progress))
        }
    }

    private fun getValue(particle: Particle): Float {
        val params = particle.pointLightParams
        return if (effect.theta) params.thetaMultiplier else params.rangeMultiplier
    }

    private fun setValue(particle: Particle, value: Float) {
        val params = particle.pointLightParams
        when {
            effect.theta -> params.thetaMultiplier = value
            else -> params.rangeMultiplier = value
        }
    }

}
