package xim.poc

import xim.resource.DatId
import xim.resource.SkeletonAnimation
import xim.resource.SkeletonAnimationKeyFrameTransform
import xim.resource.SkeletonAnimationResource

/**
 * A frozen, per-joint pose captured at one instant; used as the "from" side of an [AnimationTransition].
 *
 * Joints whose transform could not be resolved at capture time are simply absent from the snapshot.
 */
class AnimationSnapshot {

    // Joint index -> transform captured when this snapshot was built.
    private val jointSnapshots = HashMap<Int, SkeletonAnimationKeyFrameTransform>()

    /** Captures every joint that has key frames in [previous]'s animation, sampled at its current frame. */
    constructor(previous: SkeletonAnimationContext) {
        previous.animation.keyFrameSets.keys.forEach { jointIndex ->
            previous.animation.getJointTransform(jointIndex, previous.currentFrame)
                ?.let { transform -> jointSnapshots[jointIndex] = transform }
        }
    }

    /**
     * Captures the current blended pose of an in-flight [transition], covering every joint known either to
     * its starting snapshot or to its target animation.
     */
    constructor(transition: AnimationTransition) {
        val candidateJoints = transition.previous.jointSnapshots.keys
            .union(transition.next.animation.keyFrameSets.keys)

        candidateJoints.forEach { jointIndex ->
            transition.getJointTransform(jointIndex)
                ?.let { transform -> jointSnapshots[jointIndex] = transform }
        }
    }

    /** The captured transform for [jointIndex], or null if that joint was not captured. */
    fun getJointTransform(jointIndex: Int): SkeletonAnimationKeyFrameTransform? = jointSnapshots[jointIndex]

}

/**
 * Blends from a frozen [previous] pose into the [next] animation over [transitionDuration] frames.
 *
 * When [inBetween] is supplied, the blend passes through that resource's first frame. The first half of the
 * transition blends `previous -> inBetween` and the second half blends `inBetween -> next`.
 */
class AnimationTransition(val previous: AnimationSnapshot, val next: SkeletonAnimationContext, val transitionDuration: Float, val inBetween: SkeletonAnimationResource? = null) {

    // Frames elapsed since the transition started.
    private var progress = 0f

    /**
     * Advances the transition by [elapsedFrames].
     *
     * @return true once the transition has run for its full duration.
     */
    fun update(elapsedFrames: Float): Boolean {
        progress += elapsedFrames
        return isComplete()
    }

    /** True once the accumulated progress has reached [transitionDuration]. */
    fun isComplete(): Boolean {
        return progress >= transitionDuration
    }

    /**
     * The blended transform for [jointIndex] at the current progress.
     *
     * @return null if either endpoint of the active blend segment has no transform for that joint.
     */
    fun getJointTransform(jointIndex: Int): SkeletonAnimationKeyFrameTransform? {
        val t = blendFactor()

        return if (inBetween == null) {
            val keyFramePrev = previous.getJointTransform(jointIndex) ?: return null
            val keyFrameNext = next.getJointTransform(jointIndex) ?: return null
            SkeletonAnimationKeyFrameTransform.interpolate(keyFramePrev, keyFrameNext, t)
        } else if (t < 0.5f) {
            // First half: previous pose -> in-between pose, remapped to a full 0..1 blend.
            val keyFramePrev = previous.getJointTransform(jointIndex) ?: return null
            val keyFrameNext = inBetween.skeletonAnimation.getJointTransform(jointIndex, 0f) ?: return null
            SkeletonAnimationKeyFrameTransform.interpolate(keyFramePrev, keyFrameNext, t * 2f)
        } else {
            // Second half: in-between pose -> next animation, remapped to a full 0..1 blend.
            val keyFramePrev = inBetween.skeletonAnimation.getJointTransform(jointIndex, 0f) ?: return null
            val keyFrameNext = next.getJointTransform(jointIndex) ?: return null
            SkeletonAnimationKeyFrameTransform.interpolate(keyFramePrev, keyFrameNext, (t - 0.5f) * 2f)
        }
    }

    // Normalized progress clamped to [0, 1]. A non-positive duration counts as already finished, so the
    // blend snaps to `next` instead of dividing by zero (NaN/Infinity). Clamping also keeps negative or
    // overshooting progress from extrapolating past either pose.
    private fun blendFactor(): Float {
        if (transitionDuration <= 0f) {
            return 1f
        }
        return (progress / transitionDuration).coerceIn(0f, 1f)
    }
}

/**
 * Playback/looping configuration for a [SkeletonAnimationContext].
 *
 * @property loopDuration frames a single pass should take. `null` plays at the animation's own length;
 *   `0` pins playback to the first frame and marks it completed.
 * @property numLoops passes to play before completing. `null` means no required count (the context reports
 *   itself done looping immediately); `0` loops indefinitely.
 */
data class LoopParams(
    val loopDuration: Float?,
    val numLoops: Int?,
) {
    companion object {
        /** Native-speed playback with no required loop count, so it counts as done looping right away. */
        fun lowPriorityLoop(): LoopParams = LoopParams(null, null)
    }
}

/**
 * Blend timings for entering and leaving an animation, measured in frames.
 *
 * @property transitionInTime frames to blend into this animation; `0` switches instantly.
 * @property transitionOutTime frames to blend away from this animation when the incoming animation supplies
 *   no params of its own (only honoured when greater than zero).
 * @property inBetween optional id of an intermediate animation to pass through mid-transition.
 */
class TransitionParams(
    val transitionInTime: Float = 7.5f,
    val transitionOutTime: Float = 7.5f,
    val inBetween: DatId? = null,
) {
    /**
     * In-between animation resources keyed by animation slot; `null` until resolved.
     * NOTE(review): never assigned in this file — presumably populated by whoever resolves [inBetween].
     */
    var resolvedInBetween: Map<Int, SkeletonAnimationResource>? = null
}

/**
 * Playback state of one [animation]: current frame position plus loop bookkeeping driven by [loopParams].
 */
class SkeletonAnimationContext(
    val animation: SkeletonAnimation,
    val loopParams: LoopParams,
    val transitionParams: TransitionParams?,
) {

    // Playback position, measured in the animation's own frames.
    var currentFrame = 0f

    private var completed = false
    private var loopCounter = 0

    /**
     * Advances playback by [elapsedFrames]. The rate is scaled so that one pass through the animation lasts
     * `loopParams.loopDuration` frames, or the animation's own length when that is null.
     */
    fun advance(elapsedFrames: Float) {
        val lengthInFrames = animation.getLengthInFrames()

        // Nothing to play through: either a zero loop duration or an animation with no length. Pin to the
        // first frame and mark it finished. Without the length check, a zero-length animation would get a
        // NaN scaling factor (0/0) when loopDuration is null, or would never complete its loop count.
        if (loopParams.loopDuration == 0f || lengthInFrames <= 0f) {
            currentFrame = 0f
            completed = true
            return
        }

        val loopDuration = loopParams.loopDuration ?: lengthInFrames
        val scalingFactor = lengthInFrames / loopDuration

        currentFrame += (elapsedFrames * scalingFactor)
        currentFrame = applyLoopBounds(lengthInFrames)
    }

    /** The transform of [jointIndex] at the current frame, or null if the animation has none for it. */
    fun getJointTransform(jointIndex: Int): SkeletonAnimationKeyFrameTransform? {
        return animation.getJointTransform(jointIndex, currentFrame)
    }

    /** True when no loop count was requested, or when the requested number of loops has been played. */
    fun isDoneLooping(): Boolean {
        return loopParams.numLoops == null || completed
    }

    // Wraps currentFrame back into [0, lengthInFrames], counting each wrap as one completed loop. Once the
    // requested non-zero loop count is reached, marks the context completed and holds on the final frame.
    // `numLoops` of null or 0 loops indefinitely. Callers must pass a positive length (see advance).
    private fun applyLoopBounds(lengthInFrames: Float): Float {
        val maxLoops = loopParams.numLoops ?: 0

        while (currentFrame > lengthInFrames) {
            loopCounter += 1
            currentFrame -= lengthInFrames
        }

        if (maxLoops != 0 && loopCounter >= maxLoops) {
            completed = true
            return lengthInFrames
        }

        return currentFrame
    }

}

/**
 * Drives the animation playing in a single slot ([animationSlot]) and blends between successive animations.
 */
class SkeletonAnimator(private val animationSlot: Int) {

    var currentAnimation: SkeletonAnimationContext? = null
    var transition: AnimationTransition? = null

    /**
     * Advances the active transition and/or the current animation by [elapsedFrames].
     *
     * While a transition is blending, the current animation also advances. On the tick where the transition
     * completes, the transition is dropped and the animation is not advanced.
     */
    fun update(elapsedFrames: Float) {
        if (transition?.update(elapsedFrames) == true) {
            transition = null
        } else {
            currentAnimation?.advance(elapsedFrames)
        }
    }

    /**
     * Makes [skeletonAnimationContext] the current animation, blending from the present pose when possible.
     *
     * - With nothing playing, or a zero [TransitionParams.transitionInTime], the switch is immediate.
     * - Requesting the animation already playing (same id) is ignored.
     * - Otherwise the blend duration is the incoming transitionInTime. Without incoming params, it falls
     *   back to the outgoing animation's positive transitionOutTime, and then to 7.5 frames.
     */
    fun setNextAnimation(skeletonAnimationContext: SkeletonAnimationContext, transitionParams: TransitionParams?) {
        val current = currentAnimation

        if (current == null || transitionParams?.transitionInTime == 0f) {
            // Instant switch: drop any in-flight blend. Otherwise it would keep overriding the new animation
            // in getJointTransform, still blending toward the previous target, until it finished.
            transition = null
            currentAnimation = skeletonAnimationContext
            return
        }

        if (current.animation.id == skeletonAnimationContext.animation.id) {
            return
        }

        val transitionDuration = when {
            transitionParams != null -> transitionParams.transitionInTime
            current.transitionParams != null && current.transitionParams.transitionOutTime > 0f ->
                current.transitionParams.transitionOutTime
            else -> 7.5f
        }

        // When interrupting a blend, start from the blended pose rather than from the raw current animation.
        val snapshot = transition?.let { AnimationSnapshot(it) } ?: AnimationSnapshot(current)
        val maybeInBetweenFrame = transitionParams?.resolvedInBetween?.get(animationSlot)

        transition = AnimationTransition(snapshot, skeletonAnimationContext, transitionDuration, maybeInBetweenFrame)
        currentAnimation = skeletonAnimationContext
    }

    /**
     * The transform for [jointIndex]: the blended pose while transitioning, falling back to the current
     * animation when no transition is active or the transition cannot resolve that joint.
     */
    fun getJointTransform(jointIndex: Int): SkeletonAnimationKeyFrameTransform? {
        return transition?.getJointTransform(jointIndex) ?: currentAnimation?.getJointTransform(jointIndex)
    }
}

/**
 * Owns one [SkeletonAnimator] per animation slot and routes registrations and joint queries to them.
 */
class SkeletonAnimationCoordinator {

    // Slot index -> animator for that slot; created lazily on first registration.
    val animations = HashMap<Int, SkeletonAnimator>()

    /** Advances every slot's animator by [elapsedFrames]. */
    fun update(elapsedFrames: Float) {
        for (animator in animations.values) {
            animator.update(elapsedFrames)
        }
    }

    /**
     * Queues each resource on the slot given by its id's final digit (slot 0 when absent). The switch only
     * happens if [overrideCondition] accepts that slot's animator.
     */
    fun registerAnimation(
        skeletonAnimationResources: List<SkeletonAnimationResource>,
        loopParams: LoopParams,
        transitionParams: TransitionParams? = null,
        overrideCondition: (SkeletonAnimator) -> Boolean = { true },
    ) {
        skeletonAnimationResources.forEach { resource ->
            val slot = resource.id.finalDigit() ?: 0
            val animator = animations.getOrPut(slot) { SkeletonAnimator(slot) }
            val context = SkeletonAnimationContext(resource.skeletonAnimation, loopParams, transitionParams)

            if (overrideCondition(animator)) animator.setNextAnimation(context, transitionParams)
        }
    }

    // See: [Skeleton Animation Parameterization]
    fun registerAnimationDirect(skeletonAnimationResource: SkeletonAnimationResource, loopParams: LoopParams, transitionParams: TransitionParams?) {
        // First slot in 0..7 that is empty or whose animation has moved past frame 0; slot 7 if none qualify.
        val slot = (0 until 8).firstOrNull { candidate ->
            val occupant = animations[candidate]?.currentAnimation
            occupant == null || occupant.currentFrame > 0f
        } ?: 7

        val animator = animations.getOrPut(slot) { SkeletonAnimator(slot) }
        animator.setNextAnimation(
            SkeletonAnimationContext(skeletonAnimationResource.skeletonAnimation, loopParams, transitionParams),
            transitionParams,
        )
    }

    /** True if any slot is ready to transition out under the stricter out-time check. */
    fun hasCompleteTransitionOutAnimations(): Boolean =
        animations.values.any { animator -> readyForTransitionOut(animator, requireTransitionOut = true) }

    /** Registers [skeletonAnimationResources] as low-priority loops on slots that are ready to be replaced. */
    fun registerIdleAnimation(skeletonAnimationResources: List<SkeletonAnimationResource>, requireTransitionOut: Boolean) {
        registerAnimation(
            skeletonAnimationResources,
            loopParams = LoopParams.lowPriorityLoop(),
            overrideCondition = { animator -> readyForTransitionOut(animator, requireTransitionOut) },
        )
    }

    // A slot is ready when it is empty or done looping. When requireTransitionOut is set, the current out time
    // must also be absent or non-negative.
    // NOTE(review): that out-time condition only rejects negative/NaN values, so it is nearly always true;
    // confirm whether requiring a positive out time was intended.
    private fun readyForTransitionOut(animator: SkeletonAnimator, requireTransitionOut: Boolean): Boolean {
        val current = animator.currentAnimation

        if (requireTransitionOut) {
            val outTime = current?.transitionParams?.transitionOutTime
            val outTimeAcceptable = outTime == null || outTime >= 0f
            if (!outTimeAcceptable) {
                return false
            }
        }

        return current == null || current.isDoneLooping()
    }

    /** One transform for [jointIndex] from each slot that can resolve it. */
    fun getJointTransform(jointIndex: Int): List<SkeletonAnimationKeyFrameTransform> =
        animations.values.mapNotNull { animator -> animator.getJointTransform(jointIndex) }

    /** True while any slot is blending between animations. */
    fun isTransitioning(): Boolean =
        animations.values.any { animator -> animator.transition != null }

}