package xim.resource

import xim.math.Matrix4f
import xim.math.Quaternion
import xim.math.Vector3f
import xim.poc.*
import xim.util.PI_f

/**
 * Mutable runtime state for one joint of a [SkeletonInstance].
 *
 * @property index position of this joint within the owning skeleton's joint list
 * @property parent the parent joint instance, or `null` for a root joint
 * @property definition the static joint definition (translation, rotation, parent index) this instance was built from
 * @property currentTransform the joint's transform as last computed by the owning skeleton; updated in place
 */
class JointInstance(
    val index: Int,
    val parent: JointInstance?,
    val definition: Joint,
    val currentTransform: Matrix4f
)

class SkeletonInstance(val resource: SkeletonResource) {

    /**
     * One [JointInstance] per joint in [resource], in the same order.
     * Every joint's parent has a smaller index than the joint itself (enforced in `init`).
     */
    val joints = ArrayList<JointInstance>()

    init {
        for (i in resource.joints.indices) {
            val jointDef = resource.joints[i]
            val parentIndex = jointDef.parentIndex

            val parent = if (parentIndex == -1) {
                null
            } else {
                // The parent must already be instantiated. tPose() and animate() also walk the joints
                // in index order and read the parent's freshly computed transform.
                require(parentIndex in 0 until i) { "Joint $i references parent $parentIndex, which is not an earlier joint" }
                joints[parentIndex]
            }

            joints.add(JointInstance(i, parent, jointDef, Matrix4f()))
        }
    }

    /** Resolves [standardPosition] to this skeleton's joint reference. */
    fun getStandardJoint(standardPosition: StandardPosition): JointReference {
        return getStandardJoint(standardPosition.referenceIndex)
    }

    /** Returns entry [referenceIndex] of the resource's joint-reference table. */
    fun getStandardJoint(referenceIndex: Int): JointReference {
        return resource.jointReference[referenceIndex]
    }

    /**
     * Resolves a standard-joint reference index that may depend on where [fromActor] is relative to [toActor].
     *
     * Indices outside `49..51` are returned unchanged. For `49..51`, the candidates `13..20` are evaluated with
     * [Actor.getWorldSpaceJointPosition] on [toActor]. The one whose position has the smallest squared distance
     * to [fromActor]'s display position is returned. Ties go to the lowest candidate. If no candidate is
     * strictly closer than [Float.MAX_VALUE], [referenceIndex] is returned.
     */
    fun getStandardJointExtended(referenceIndex: Int, fromActor: Actor, toActor: Actor): Int {
        if (referenceIndex !in 49..51) { return referenceIndex }
        // Joint 49 refers to a set of 8 joints ([13, 20]) that form a circle around the actor.
        // The nearest of these should be chosen.
        // TODO - how do 50 and 51 differ from 49?

        val origin = fromActor.displayPosition

        var nearestReference = referenceIndex
        var nearestDistanceSq = Float.MAX_VALUE

        for (candidate in 13..20) {
            val candidatePosition = toActor.getWorldSpaceJointPosition(candidate)
            val candidateDistanceSq = Vector3f.distanceSquared(origin, candidatePosition)

            if (candidateDistanceSq < nearestDistanceSq) {
                nearestReference = candidate
                nearestDistanceSq = candidateDistanceSq
            }
        }

        return nearestReference
    }

    /** Returns the joint instance that [jointReference] points at. */
    fun getJoint(jointReference: JointReference): JointInstance {
        return joints[jointReference.index]
    }

    /** Returns the joint instance for [standardPosition]. */
    fun getJoint(standardPosition: StandardPosition): JointInstance {
        return getJoint(getStandardJoint(standardPosition))
    }

    /** See [isJointTouchingGround]; evaluated for the left-foot standard joint. */
    fun isLeftFootTouchingGround(): Boolean {
        return isJointTouchingGround(getStandardJoint(StandardPosition.LeftFoot))
    }

    /** See [isJointTouchingGround]; evaluated for the right-foot standard joint. */
    fun isRightFootTouchingGround(): Boolean {
        return isJointTouchingGround(getStandardJoint(StandardPosition.RightFoot))
    }

    /**
     * True when the y component of the joint's offset position (as returned by [getJointPosition]) is above -0.02.
     * NOTE(review): -0.02f looks like a hand-tuned tolerance, and the position is in the same space as the
     * joints' current transforms, not world space. Confirm both are intended.
     */
    private fun isJointTouchingGround(jointReference: JointReference): Boolean {
        return getJointPosition(jointReference).y > -0.02f
    }

    /** Resets every joint's current transform to identity. */
    private fun identity() {
        joints.forEach { it.currentTransform.identity() }
    }

    /**
     * Poses the skeleton from its joint definitions alone, with no animation applied.
     * Each joint gets its definition's translation then rotation, composed with its parent's transform.
     * Relies on parents preceding children in [joints].
     */
    fun tPose() {
        identity()

        for (joint in joints) {
            joint.currentTransform.translateInPlace(joint.definition.translation)
            joint.currentTransform.multiplyInPlace(joint.definition.rotation.toMat4())
            if (joint.parent != null) {
                joint.parent.currentTransform.multiply(joint.currentTransform, joint.currentTransform)
            }
        }
    }

    /** Position of standard-reference [index] (see [getJointPosition]). */
    fun getStandardJointPosition(index: Int): Vector3f {
        return getJointPosition(getStandardJoint(index))
    }

    /** Position of [standardPosition] (see [getJointPosition]). */
    fun getStandardJointPosition(standardPosition: StandardPosition): Vector3f {
        return getJointPosition(getStandardJoint(standardPosition))
    }

    /** Transforms the reference's position offset by the referenced joint's current transform. */
    private fun getJointPosition(jointReference: JointReference): Vector3f {
        val jointInstance = getJoint(jointReference)
        return jointInstance.currentTransform.transform(jointReference.positionOffset)
    }

    /**
     * Recomputes every joint's current transform for the current animation frame, then refreshes the actor's
     * skeleton bounding boxes.
     *
     * @param actor supplies facing direction, engaged state and world position
     * @param actorModel supplies the model (for weapon re-parenting) and per-joint animation transforms
     * @param actorMount when non-null, joint 2 is attached to the mount if its attach point can be resolved
     */
    fun animate(actor: Actor, actorModel: ActorModel, actorMount: Mount?) {
        identity()
        // The root starts from the actor's facing rotation; updateCurrentJointTransform composes onto it.
        joints[0].currentTransform.rotateYInPlace(actor.displayFacingDir)

        // When a PC Model is engaged, the joints that correspond to weapon-handles get re-parented to the right & left hand.
        // The re-parenting effect has odd scaling properties - it seems that the scale of the new parent should be ignored.
        val jointParentOverrides = computeJointParentOverrides(actor, actorModel)
        val overrideParentIndices = jointParentOverrides.values.mapTo(HashSet()) { it.index }

        // A re-parented joint might have children, and the new parent might have a greater index than the joint
        // (meaning it would be processed later). Any re-parented joint, or any descendant of one, is deferred to
        // a second pass. Insertion order keeps parents ahead of their children there.
        val deferred = LinkedHashSet<Int>()

        for (joint in joints) {
            if (jointParentOverrides.containsKey(joint.index) || deferred.contains(joint.parent?.index)) {
                deferred += joint.index
                continue
            }

            updateCurrentJointTransform(actor, actorModel, actorMount, joint, overrideParentIndices)
        }

        for (jointIndex in deferred) {
            val joint = joints[jointIndex]
            val overrideParent = jointParentOverrides[jointIndex]

            if (overrideParent == null) {
                // Descendant of a re-parented joint: its (possibly new) parent has been computed by now.
                updateCurrentJointTransform(actor, actorModel, actorMount, joint, overrideParentIndices)
                continue
            }

            // Re-parented joint: adopt the new parent's transform, scaled only by this joint's own animation scale.
            val scale = Vector3f(1f, 1f, 1f)

            val animationTransform = actorModel.skeletonAnimationCoordinator.getJointTransform(joint.index)
            if (animationTransform != null) { scale *= animationTransform.scale }

            joint.currentTransform.copyFrom(overrideParent.currentTransform).scaleInPlace(scale)
        }

        updateBoundingBoxes(actor)
    }

    /**
     * Computes one joint's current transform from its definition plus any animation transform, then places it
     * under its parent. For joint 2, when mounted, it is placed at the mount's attach transform instead.
     *
     * @param parents indices of joints that are re-parenting targets; these are built with unit scale
     */
    private fun updateCurrentJointTransform(actor: Actor, actorModel: ActorModel, actorMount: Mount?, joint: JointInstance, parents: Set<Int>) {
        val translation = Vector3f(joint.definition.translation)
        val rotation = Quaternion(joint.definition.rotation)
        val scale = Vector3f(1f, 1f, 1f)

        val animationTransform = actorModel.skeletonAnimationCoordinator.getJointTransform(joint.index)
        if (animationTransform != null) {
            translation += animationTransform.translation
            Quaternion.multiplyAndStore(animationTransform.rotation, rotation, rotation)
            scale *= animationTransform.scale
        }

        // For the root-joint, the translation doesn't seem to be in "skeleton-space"
        if (joint.index == 0) { translation.rotate270() }

        // The new parent's scale should not leak into re-parented weapon joints (see animate).
        if (parents.contains(joint.index)) { scale.copyFrom(Vector3f.ONE) }

        val transform = rotation.toMat4().translateDirect(translation).scaleInPlace(scale)
        if (joint.index == 0) {
            // The root already holds the facing rotation set up in animate(); compose onto it.
            joint.currentTransform.multiplyInPlace(transform)
        } else {
            joint.currentTransform.copyFrom(transform)
        }

        // NOTE(review): joint 2 is hard-coded as the rider's mount attachment joint — confirm against model data.
        val mountTransform = if (joint.index == 2 && actorMount != null) {
            getMountAttachTransform(actor, actorMount)
        } else {
            null
        }

        if (mountTransform != null) {
            joint.currentTransform.copyFrom(mountTransform)
        } else if (joint.parent != null) {
            // Also taken when mounted but the mount's attach point can't be resolved (e.g. mount actor not found).
            // Otherwise joint 2 would be left in unparented local space.
            joint.parent.currentTransform.multiply(joint.currentTransform, joint.currentTransform)
        }
    }

    /**
     * Maps weapon-handle joint index -> hand joint it should be re-parented to.
     * Returns an empty map unless the actor is display-engaged and its model is a [PcModel].
     * The main and sub weapons are resolved independently.
     */
    private fun computeJointParentOverrides(actor: Actor, actorModel: ActorModel): Map<Int, JointInstance> {
        if (!actor.isDisplayEngaged()) { return emptyMap() }
        val model = actorModel.model as? PcModel ?: return emptyMap()

        val overrideMap = HashMap<Int, JointInstance>()

        // Missing main-weapon info must not prevent the sub weapon from being attached to the left hand.
        model.getMainWeaponInfo()?.standardJointIndex?.let { mainJointIndex ->
            overrideMap[getStandardJoint(mainJointIndex).index] = getJoint(StandardPosition.RightHand)
        }

        model.getSubWeaponInfo()?.standardJointIndex?.let { subJointIndex ->
            overrideMap[getStandardJoint(subJointIndex).index] = getJoint(StandardPosition.LeftHand)
        }

        return overrideMap
    }

    /**
     * Builds the rider attachment transform from the mount actor's skeleton.
     *
     * Reads standard reference `48 + riderTypeIndex` on the mount, where riderTypeIndex is
     * `raceGenderConfig.index - 1` for a [PcModel] rider and 0 otherwise. The result is translated to that
     * joint's position, rotated about Y by the rider's facing (minus π/2) plus the mount's rider rotation,
     * then offset by (0, -0.1, 0).
     * NOTE(review): presumably the mount exposes one attach joint per rider race/gender starting at reference 48 — confirm.
     *
     * @return null if the rider model, the mount actor, or the mount's skeleton is unavailable
     */
    private fun getMountAttachTransform(actor: Actor, actorMount: Mount): Matrix4f? {
        val riderModel = actor.actorModel?.model ?: return null
        val riderTypeIndex = if (riderModel is PcModel) { riderModel.raceGenderConfig.index - 1 } else { 0 }

        val mount = ActorManager[actorMount.id] ?: return null
        val mountSkeleton = mount.actorModel?.getSkeleton() ?: return null

        val jointRef = mountSkeleton.getStandardJoint(48 + riderTypeIndex)
        val jointInstance = mountSkeleton.getJoint(jointRef)

        return Matrix4f().translateInPlace(jointInstance.currentTransform.transform(jointRef.positionOffset))
            .rotateYInPlace(actor.displayFacingDir - PI_f / 2f + actorMount.getRiderRotation())
            .translateInPlace(Vector3f(0f, -0.1f, 0f)) // TODO is there a real offset to use here?
    }

    /**
     * Transforms every base bounding box of [resource] and hands the results to the actor.
     * The transform takes its upper-left part from standard joint 0's current transform, and its translation
     * from the actor's world-space position of standard joint 0.
     */
    private fun updateBoundingBoxes(actor: Actor) {
        val boxes = ArrayList<BoundingBox>(resource.boundingBoxes.size)

        val stdRef = getStandardJoint(0)
        val joint = getJoint(stdRef)
        val jointPos = actor.getWorldSpaceJointPosition(0)

        val transform = Matrix4f().translateDirect(jointPos).copyUpperLeft(joint.currentTransform)

        for (baseBox in resource.boundingBoxes) {
            boxes += baseBox.transform(transform)
        }

        actor.updateSkeletonBoundingBoxes(boxes)
    }

}