package xim.poc

import xim.math.Vector3f
import xim.resource.CollisionObject
import xim.resource.CollisionObjectGroup
import xim.resource.DatId
import xim.resource.TerrainType
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min
import kotlin.math.sqrt

/**
 * Outcome of a box/triangle SAT test that found overlap on every tested axis.
 *
 * @property minDepth the smallest penetration recorded across the tested axes.
 */
class CollisionDepthResult(
    val minDepth: CollisionDepth,
)

/**
 * Mutable penetration record used while running separating-axis checks.
 *
 * @property amount signed distance to move the tested shape along [axis]. Starts at [Float.MAX_VALUE] so the
 *   first real measurement replaces it.
 * @property axis direction that [amount] applies to.
 * @property verticalEscape alternative push distance expressed along NegY; `null` when none has been recorded.
 */
class CollisionDepth(var amount: Float = Float.MAX_VALUE, var axis: Vector3f = Vector3f.ZERO, var verticalEscape: Float? = null) {

    /**
     * Merges this record into [other], keeping whichever values have the smaller magnitude.
     *
     * The vertical escape and the (amount, axis) pair are compared independently. A `null` vertical escape
     * counts as infinitely large.
     */
    fun applyIfSmaller(other: CollisionDepth) {
        val candidateEscape = abs(verticalEscape ?: Float.POSITIVE_INFINITY)
        val currentEscape = abs(other.verticalEscape ?: Float.POSITIVE_INFINITY)
        if (candidateEscape < currentEscape) {
            other.verticalEscape = verticalEscape
        }

        // Expressed as a negated `>=` (not `<`) so that a NaN amount still overwrites [other].
        val keepExisting = abs(amount) >= abs(other.amount)
        if (!keepExisting) {
            other.amount = amount
            other.axis = axis
        }
    }

}

/**
 * Closed interval `[min, max]` of a shape's projection onto a single axis.
 */
data class ProjectionRange(val min: Float, val max: Float) {

    /**
     * Signed distance to move this range so it no longer overlaps [other].
     *
     * @return `null` when the ranges are disjoint or the overlap magnitude is below 0.0001.
     */
    fun overlapAmount(other: ProjectionRange): Float? =
        overlap(min, max, other.min, other.max)?.takeUnless { abs(it) < 0.0001f }

    companion object {
        /**
         * Smallest signed translation of range `a = [amin, amax]` that separates it from `b = [bmin, bmax]`.
         *
         * A negative result moves `a` so its max meets `b`'s min. A positive result moves `a` so its min meets
         * `b`'s max. When one range contains the other, the shorter of the two moves is chosen.
         *
         * @return `null` when the ranges do not intersect. Returns `0f` only when no ordering between `amin`
         *   and `bmin` holds, which happens for NaN inputs.
         */
        fun overlap(amin: Float, amax: Float, bmin: Float, bmax: Float): Float? {
            if (amax < bmin || bmax < amin) return null

            // Once intersecting: pushDown <= 0 (a's max to b's min), pushUp >= 0 (a's min to b's max).
            val pushDown = bmin - amax
            val pushUp = bmax - amin
            // Comparing the gaps between the ends is equivalent to comparing |pushDown| with |pushUp| under containment.
            val maxEndsCloser = abs(amax - bmax) < abs(amin - bmin)

            return when {
                // a starts first: b is contained in a, or a runs into b from below.
                amin <= bmin -> if (amax > bmax && !maxEndsCloser) pushUp else pushDown
                // b starts first: a is contained in b, or a runs into b from above.
                bmin <= amin -> if (bmax > amax && !maxEndsCloser) pushDown else pushUp
                // Only reachable with NaN endpoints.
                else -> 0f
            }
        }

    }

}

/**
 * Describes one collision object that an actor was pushed out of.
 *
 * @property collisionObject the object that was touched.
 * @property environmentId the object's linked dat id, if any.
 * @property cullingTableIndex the object's culling table index, if any.
 * @property lightIndices the object's light indices.
 * @property terrainType terrain type recorded for the touched geometry (the first touched triangle's type when
 *   built by [Collider]).
 */
data class CollisionProperty(
    val collisionObject: CollisionObject,
    val environmentId: DatId?,
    val cullingTableIndex: Int?,
    val lightIndices: List<Int>,
    val terrainType: TerrainType,
)

/**
 * Resolves actor movement against zone collision meshes and answers nearest-floor queries.
 *
 * Movement is applied in small sub-steps. After each sub-step the actor's bounding box is pushed out of every
 * collision triangle it overlaps.
 */
object Collider {

    /** Steep triangles (|normal · NegY| < 0.75) taller than this never use the vertical-escape push. */
    private const val VERTICAL_ESCAPE_STEP = 0.5001f

    /** Largest distance the actor is moved before collisions are resolved again. */
    private const val MAX_SUB_STEP = 0.05f

    /**
     * Moves [actor] by [velocity] and resolves collisions along the way.
     *
     * The displacement is split into sub-steps of at most [MAX_SUB_STEP]. Each sub-step moves `actor.position`
     * and then pushes it out of overlapping geometry in every area of [areas]. A zero-length [velocity]
     * performs no sub-steps and returns an empty map.
     *
     * @param bbScale scale passed to `BoundingBox.scaled` to build the actor's collision box.
     * @param stopOnCollision when true, stop after the first sub-step that reports any collision.
     * @return the collisions reported by the last sub-step performed, keyed by area.
     */
    fun updatePosition(areas: List<Area>, actor: Actor, velocity: Vector3f, bbScale: Vector3f, stopOnCollision: Boolean): Map<Area, List<CollisionProperty>> {
        val totalDistance = velocity.magnitude()
        val direction = velocity.normalize()

        var travelled = 0f
        var lastResults: Map<Area, List<CollisionProperty>> = emptyMap()

        while (travelled < totalDistance) {
            val step = min(totalDistance - travelled, MAX_SUB_STEP)
            travelled += step

            actor.position += direction * step
            lastResults = updatePositionIteration(areas, actor, bbScale)

            val anyHit = lastResults.values.any { hits -> hits.isNotEmpty() }
            if (anyHit && stopOnCollision) break
        }

        return lastResults
    }

    /** Resolves the actor's current position against every area's collision map; areas without one map to an empty list. */
    private fun updatePositionIteration(areas: List<Area>, actor: Actor, bbScale: Vector3f): Map<Area, List<CollisionProperty>> =
        areas.associateWith { area ->
            val collisionMap = area.getZoneResource().zoneCollisionMap
            if (collisionMap == null) {
                emptyList()
            } else {
                collisionMap.getCollisionObjects(actor.position)
                    .flatMap { group -> updatePositionAgainstObject(actor.position, bbScale, group) }
            }
        }

    /**
     * Pushes [position] out of each object in [collisionObjectGroup], one object at a time.
     *
     * Objects linked to the scene's current sub-area are skipped, because that sub-area's own collision is used
     * instead.
     *
     * @return one [CollisionProperty] per object that was touched.
     */
    private fun updatePositionAgainstObject(position: Vector3f, scale: Vector3f, collisionObjectGroup: CollisionObjectGroup): List<CollisionProperty> =
        collisionObjectGroup.collisionObjects.mapNotNull { obj ->
            val linkedSubArea = obj.transformInfo.subAreaLinkId
            val deferredToSubArea = linkedSubArea != null && SceneManager.getCurrentScene().isCurrentSubArea(linkedSubArea)

            if (deferredToSubArea) {
                null
            } else {
                collideWithObj(position, scale, obj).firstOrNull()?.let { terrainType ->
                    CollisionProperty(
                        collisionObject = obj,
                        environmentId = obj.transformInfo.linkedDatId,
                        cullingTableIndex = obj.transformInfo.cullingTableIndex,
                        lightIndices = obj.transformInfo.lightIndices,
                        terrainType = terrainType
                    )
                }
            }
        }

    /** The actor's bounding box at [position], transformed by [obj]'s inverse transform into mesh space. */
    private fun objectSpaceBox(scale: Vector3f, position: Vector3f, obj: CollisionObject) =
        BoundingBox.scaled(scale, position).transform(obj.transformInfo.invTransform)

    /**
     * Pushes [position] out of every triangle of [obj] that the actor's box overlaps, updating [position] in place.
     *
     * @return the terrain types of the triangles that caused a push, in mesh order.
     */
    private fun collideWithObj(position: Vector3f, scale: Vector3f, obj: CollisionObject): List<TerrainType> {
        var localBox = objectSpaceBox(scale, position, obj)

        // Cheap rejection: the box's bounding sphere must touch the mesh's bounding sphere.
        if (!Sphere.intersects(localBox.toBoundingSphere(), obj.collisionMesh.boundingSphere)) {
            return emptyList()
        }

        val touched = ArrayList<TerrainType>()

        for (tri in obj.collisionMesh.tris) {
            val penetration = SatCollider.boxTriSeparatingAxisTheorem(localBox, tri)?.minDepth ?: continue

            val slope = abs(tri.normal.dot(Vector3f.NegY))
            val height = tri.height()
            // For steep triangles taller than the step limit, discard the vertical escape so the SAT-axis push is used.
            if (slope < 0.75f && height > VERTICAL_ESCAPE_STEP) {
                penetration.verticalEscape = null
            }

            val escape = penetration.verticalEscape
            val correction = if (escape != null && escape > 0f) {
                Vector3f.NegY * escape
            } else {
                // Push along the minimum-depth axis, mapped back to world space, overshooting by 1.5%.
                obj.transformInfo.transform.transformDirectionVector(penetration.axis) * (penetration.amount * 1.015f)
            }

            // `position` is a val parameter, so `+=` resolves to Vector3f.plusAssign and updates the caller's vector.
            position += correction

            // The box moved with `position`, so rebuild it before testing the next triangle.
            localBox = objectSpaceBox(scale, position, obj)
            touched.add(tri.type)
        }

        return touched
    }

    /**
     * Casts a ray along +Y from [position] and returns the nearest hit.
     *
     * Areas are checked in list order, and the first area that reports a hit wins.
     */
    fun nearestFloor(position: Vector3f, areas: List<Area>): Vector3f? =
        areas.firstNotNullOfOrNull { area -> nearestFloor(position, area) }

    /** Nearest +Y ray hit from [position] among the collision groups of [area] near that position. */
    private fun nearestFloor(position: Vector3f, area: Area): Vector3f? {
        val collisionMap = area.getZoneResource().zoneCollisionMap ?: return null

        var best: Pair<Vector3f, Float>? = null

        for (group in collisionMap.getCollisionObjects(position)) {
            for (candidate in group.collisionObjects) {
                val hit = getFloorDistance(position, candidate) ?: continue
                // Strictly-closer only: on ties the earlier object keeps the slot.
                if (best == null || hit.second < best.second) {
                    best = hit
                }
            }
        }

        return best?.first
    }

    /**
     * Intersects a +Y ray from [position] with [collisionObject]'s mesh (in mesh space).
     *
     * @return the world-space hit point and its ray parameter, or null when no triangle is hit.
     */
    private fun getFloorDistance(position: Vector3f, collisionObject: CollisionObject): Pair<Vector3f, Float>? {
        val info = collisionObject.transformInfo
        val localRay = info.invTransform.transformDirectionVector(Vector3f.Y)
        val localOrigin = info.invTransform.transform(position)

        var closest: Pair<Vector3f, Float>? = null

        for (tri in collisionObject.collisionMesh.tris) {
            val (localHit, distance) = RayTriangleCollider.intersect(tri, localOrigin, localRay) ?: continue

            // Only strictly farther hits are skipped, so a later triangle at the same distance replaces an earlier one.
            val current = closest
            if (current != null && current.second < distance) { continue }

            closest = Pair(info.transform.transform(localHit), distance)
        }

        return closest
    }

}

/**
 * Separating-axis-theorem (SAT) overlap tests for boxes against triangles and against other boxes.
 */
object SatCollider {

    /**
     * Tests box [b] against triangle [t] using the separating axis theorem.
     *
     * The candidate axes are:
     * - the X/Y/Z axes and the triangle normal;
     * - the cross products of X/Y/Z with each triangle edge, skipping near-zero ones.
     *
     * If any axis shows no overlap, the shapes are separated.
     *
     * NOTE(review): X/Y/Z are the box's face normals only if [b] is axis-aligned in this space. `Collider`
     * passes a box transformed by a collision object's inverse transform, which may rotate it — confirm.
     *
     * @return null when a separating axis exists. Otherwise returns the smallest-magnitude depth over all axes,
     *   with the smallest vertical escape tracked independently.
     */
    fun boxTriSeparatingAxisTheorem(b: BoundingBox, t: Triangle) : CollisionDepthResult? {
        val minDepth = CollisionDepth()

        // 3 axes from the (assumed axis-aligned) box, then 1 from the triangle's normal.
        boxTriAxisCheck(b, t, Vector3f.X)?.applyIfSmaller(minDepth) ?: return null
        boxTriAxisCheck(b, t, Vector3f.Y)?.applyIfSmaller(minDepth) ?: return null
        boxTriAxisCheck(b, t, Vector3f.Z)?.applyIfSmaller(minDepth) ?: return null
        boxTriAxisCheck(b, t, t.normal)?.applyIfSmaller(minDepth) ?: return null

        // Up to 9 more from unit axis x triangle edge; near-zero products (edge parallel to the axis) are skipped.
        val edges = t.getEdges()
        for (unitAxis in listOf(Vector3f.X, Vector3f.Y, Vector3f.Z)) {
            for (edge in edges) {
                val axis = unitAxis.cross(edge)
                if (axis.magnitudeSquare() < 0.001f) { continue }
                axis.normalizeInPlace()
                boxTriAxisCheck(b, t, axis)?.applyIfSmaller(minDepth) ?: return null
            }
        }

        return CollisionDepthResult(minDepth = minDepth)
    }

    /**
     * Projects [b] and [t] onto [axis] and measures their overlap.
     *
     * @return null when the projections are disjoint or overlap negligibly. Otherwise returns the box's signed
     *   push along [axis], together with that push divided by `axis · NegY` as the vertical escape. The escape
     *   is positive infinity when [axis] is (nearly) perpendicular to NegY.
     */
    private fun boxTriAxisCheck(b: BoundingBox, t: Triangle, axis: Vector3f) : CollisionDepth? {
        val boxProjection = projectVertices(b.vertices, axis)
        val triProjection = projectVertices(t.vertices, axis)
        val overlapAmount = boxProjection.overlapAmount(triProjection) ?: return null

        val verticalProjection = axis.dot(Vector3f.NegY)
        val verticalEscape = if (abs(verticalProjection) < 0.0001) { Float.POSITIVE_INFINITY } else { overlapAmount / verticalProjection }

        return CollisionDepth(overlapAmount, axis, verticalEscape)
    }

    /**
     * Returns whether boxes [a] and [b] overlap.
     *
     * The test runs in three stages:
     * 1. A bounding-sphere test rejects clearly distant pairs.
     * 2. Two axis-aligned boxes then use the exact AABB test.
     * 3. Anything else runs a SAT over the 3 face axes of each box plus the non-degenerate axis cross products.
     */
    fun boxBoxIntersection(a: Box, b: Box): Boolean {
        // Spheres overlap iff |ca - cb| <= ra + rb, i.e. distSq <= (ra + rb)^2. Comparing against ra^2 + rb^2
        // instead would wrongly reject overlapping pairs. Assumes getRadiusSq() is the squared bounding radius.
        val centerDistanceSq = Vector3f.distanceSquared(a.getCenter(), b.getCenter())
        val radiusSum = sqrt(a.getRadiusSq()) + sqrt(b.getRadiusSq())
        if (centerDistanceSq > radiusSum * radiusSum) { return false }

        if (a is AxisAlignedBoundingBox && b is AxisAlignedBoundingBox) {
            return AxisAlignedBoundingBox.intersects(a, b)
        }

        // Only a yes/no answer is needed, so any separating axis ends the test; depths are not accumulated.
        val aAxes = a.getAxes()
        val bAxes = b.getAxes()

        for (i in 0 until 3) {
            boxBoxAxisCheck(a, b, aAxes[i]) ?: return false
            boxBoxAxisCheck(a, b, bAxes[i]) ?: return false
        }

        for (i in 0 until 3) {
            for (j in 0 until 3) {
                val axis = aAxes[i].cross(bAxes[j])
                // Parallel axes give a degenerate cross product, which cannot separate anything.
                if (axis.magnitudeSquare() < 0.0001f) { continue }
                boxBoxAxisCheck(a, b, axis.normalizeInPlace()) ?: return false
            }
        }

        return true
    }

    /** Same as [boxTriAxisCheck], but for two boxes: null when their projections onto [axis] are disjoint. */
    private fun boxBoxAxisCheck(a: Box, b: Box, axis: Vector3f) : CollisionDepth? {
        val aProjection = projectVertices(a.getVertices(), axis)
        val bProjection = projectVertices(b.getVertices(), axis)
        val overlapAmount = aProjection.overlapAmount(bProjection) ?: return null

        val verticalProjection = axis.dot(Vector3f.NegY)
        val verticalEscape = if (abs(verticalProjection) < 0.0001) { Float.POSITIVE_INFINITY } else { overlapAmount / verticalProjection }

        return CollisionDepth(overlapAmount, axis, verticalEscape)
    }

    /** Range of `v · axis` over [vertices]; an empty list yields `[+inf, -inf]`. */
    private fun projectVertices(vertices: List<Vector3f>, axis: Vector3f) : ProjectionRange {
        var lo = Float.POSITIVE_INFINITY
        var hi = Float.NEGATIVE_INFINITY

        for (v in vertices) {
            val projection = v.dot(axis)
            lo = min(lo, projection)
            hi = max(hi, projection)
        }

        return ProjectionRange(lo, hi)
    }

    /** Array overload of [projectVertices]; delegates through a non-copying list view. */
    private fun projectVertices(vertices: Array<Vector3f>, axis: Vector3f) : ProjectionRange =
        projectVertices(vertices.asList(), axis)
}

/**
 * Intersection of a ray (not an infinite line) with a single triangle.
 */
object RayTriangleCollider {

    /**
     * Intersects the ray `origin + direction * t`, for `t >= 0`, with [tri].
     *
     * @return the hit point and its ray parameter `t`. The parameter is measured in units of [direction]'s
     *   length. Returns null when the ray is (nearly) parallel to the triangle's plane, the plane lies behind
     *   [origin], or the plane hit falls outside the triangle.
     */
    fun intersect(tri: Triangle, origin: Vector3f, direction: Vector3f): Pair<Vector3f, Float>? {
        val n = tri.normal
        val alignment = n.dot(direction)
        if (abs(alignment) <= 1e-5) { return null }

        val t = (n.dot(tri.t0) - n.dot(origin)) / alignment
        if (t < 0f) { return null }

        val hit = origin + direction * t

        val outside = isOutsideEdge(tri.t0, tri.t1, hit, n) ||
            isOutsideEdge(tri.t1, tri.t2, hit, n) ||
            isOutsideEdge(tri.t2, tri.t0, hit, n)

        return if (outside) null else Pair(hit, t)
    }

    /** Edge-side test: true when `(start - end) × (point - start) · normal` is negative. */
    private fun isOutsideEdge(start: Vector3f, end: Vector3f, point: Vector3f, normal: Vector3f): Boolean =
        (start - end).cross(point - start).dot(normal) < 0f

}

/**
 * Ray vs. sphere overlap test.
 */
object RaySphereCollider {

    /**
     * Returns whether [ray] (starting at its origin, extending forward) comes within [sphere]'s radius.
     *
     * The closest point on the ray to the sphere centre is found by projecting onto `ray.direction`. That
     * parameter is clamped to 0 because the ray starts at its origin. When the centre lies behind the origin,
     * the origin itself is the closest point, so an origin inside the sphere still counts as an intersection.
     *
     * NOTE(review): the projection treats `ray.direction` as unit length — confirm at callers.
     */
    fun intersect(ray: Ray, sphere: Sphere): Boolean {
        val toCenter = sphere.center - ray.origin
        val t = max(ray.direction.dot(toCenter), 0f)

        val closestPoint = ray.origin + ray.direction * t
        val distanceSq = Vector3f.distanceSquared(closestPoint, sphere.center)

        return distanceSq <= sphere.radiusSq
    }

}