package xim.poc

import xim.math.Vector3f
import xim.resource.CollisionObject
import xim.resource.CollisionObjectGroup
import xim.resource.DatId
import xim.resource.TerrainType
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.min
import kotlin.math.sqrt

/** Result of a successful collision test; carries the [CollisionDepth] the test accumulated in [minDepth]. */
class CollisionDepthResult(val minDepth: CollisionDepth)

/**
 * Mutable penetration record: a signed [amount] along an [axis], plus an optional [verticalEscape] distance.
 *
 * A freshly constructed instance has an infinite [amount], so any finite candidate replaces it.
 */
class CollisionDepth(
    var amount: Float = Float.POSITIVE_INFINITY,
    var axis: Vector3f = Vector3f.ZERO,
    var verticalEscape: Float? = null
) {

    /**
     * Folds this depth into [other] (the accumulator), which is modified in place.
     *
     * - [other]'s vertical escape becomes this one when this one is positive and smaller than the value
     *   [other] currently holds (a missing value counts as infinity).
     * - [other]'s amount and axis are overwritten unless this amount's magnitude is greater than or
     *   equal to [other]'s.
     */
    fun applyIfSmaller(other: CollisionDepth) {
        val candidateEscape = verticalEscape
        val currentEscape = other.verticalEscape ?: Float.POSITIVE_INFINITY
        if (candidateEscape != null && candidateEscape > 0f && candidateEscape < currentEscape) {
            other.verticalEscape = candidateEscape
        }

        // Written as a negated ">=" on purpose: it keeps the exact outcome for NaN magnitudes.
        val otherIsAlreadyAsShallow = abs(amount) >= abs(other.amount)
        if (!otherIsAlreadyAsShallow) {
            other.amount = amount
            other.axis = axis
        }
    }

}

/**
 * Closed interval `[min, max]`, used by the SAT checks as the projection of a shape onto an axis.
 */
data class ProjectionRange(val min: Float, val max: Float) {

    /**
     * Signed overlap between this range and [other].
     *
     * @return `null` when the ranges are disjoint or when the overlap magnitude is below `0.0001`;
     * otherwise the value computed by [overlap].
     */
    fun overlapAmount(other: ProjectionRange) : Float? {
        return overlap(min, max, other.min, other.max)?.takeUnless { abs(it) < 0.0001f }
    }

    companion object {
        /**
         * Computes the signed shift that, added to range `a`, brings it to just touch range `b`.
         *
         * - `bmin - amax` (zero or negative) slides `a` towards the negative side of `b`.
         * - `bmax - amin` (zero or positive) slides `a` towards the positive side of `b`.
         *
         * When one range fully contains the other, the shift with the smaller magnitude is chosen.
         *
         * @return `null` if the ranges do not intersect; `0f` only when neither ordering comparison
         * of the minimums holds (which happens with NaN bounds).
         */
        fun overlap(amin: Float, amax: Float, bmin: Float, bmax: Float): Float? {
            if (amax < bmin || bmax < amin) return null

            val shiftTowardNegative = bmin - amax
            val shiftTowardPositive = bmax - amin
            val maxGapIsSmaller = abs(amax - bmax) < abs(amin - bmin)

            return when {
                // a starts first: b is either nested inside a, or overlaps a's upper end.
                amin <= bmin -> if (amax > bmax && !maxGapIsSmaller) shiftTowardPositive else shiftTowardNegative
                // b starts first: a is either nested inside b, or overlaps b's upper end.
                bmin <= amin -> if (bmax > amax && !maxGapIsSmaller) shiftTowardNegative else shiftTowardPositive
                else -> 0f
            }
        }

    }

}

/**
 * Describes one [CollisionObject] that produced a collision, together with the metadata recorded for it:
 * an optional environment id, an optional culling-table index, light indices and a terrain type.
 */
data class CollisionProperty(
    val collisionObject: CollisionObject,
    val environmentId: DatId?,
    val cullingTableIndex: Int?,
    val lightIndices: List<Int>,
    val terrainType: TerrainType
)

object Collider {

    // Height threshold used by resolveCollision: a triangle whose |normal . NegY| is below 0.80 and whose
    // height() exceeds this value has its vertical escape discarded.
    private const val maxVerticalEscapeStep = 0.4001f

    /**
     * Moves [actor] along [velocity] in sub-steps of at most 0.05 units, running collision resolution
     * against [areas] after every sub-step.
     *
     * Stepping ends once the full length of [velocity] has been consumed, or early when a sub-step
     * (after resolution) left the actor less than 0.001 squared units from where that sub-step began.
     *
     * @param bbScale forwarded to the per-object collision routine.
     * @return the collision results of the last sub-step executed, or an empty map if no sub-step ran.
     */
    fun updatePosition(areas: List<Area>, actor: Actor, velocity: Vector3f, bbScale: Vector3f): Map<Area, List<CollisionProperty>> {
        val totalDistance = velocity.magnitude()
        val direction = velocity.normalize()

        var latestResults: Map<Area, List<CollisionProperty>> = emptyMap()
        var travelled = 0f

        while (travelled < totalDistance) {
            val step = min(totalDistance - travelled, 0.05f)
            travelled += step

            val positionBeforeStep = Vector3f(actor.position)
            actor.position += (direction * step)
            latestResults = updatePositionIteration(areas, actor, bbScale)

            // The actor effectively did not move this step; further stepping is abandoned.
            val movedSquared = Vector3f.distanceSquared(positionBeforeStep, actor.position)
            if (movedSquared < 0.001f) { break }
        }

        return latestResults
    }

    /**
     * Runs one collision pass of [actor] against every area in [areas].
     *
     * @return for each area, the collision properties gathered from its collision map; an area without a
     * collision map maps to an empty list.
     */
    private fun updatePositionIteration(areas: List<Area>, actor: Actor, bbScale: Vector3f): Map<Area, List<CollisionProperty>> {
        return areas.associateWith { area ->
            val collisionMap = area.getZoneResource().zoneCollisionMap

            if (collisionMap == null) {
                emptyList()
            } else {
                val hits = ArrayList<CollisionProperty>()

                // The candidate groups are looked up once, using the actor's position at the start of the pass.
                for (group in collisionMap.getCollisionObjects(actor.position)) {
                    hits += updatePositionAgainstObject(actor, bbScale, group)
                }

                hits
            }
        }
    }

    /**
     * Collides [actor] with every object of [collisionObjectGroup] and reports the objects that were hit.
     *
     * @return one [CollisionProperty] per object that produced at least one triangle collision; its terrain
     * type is that of the first triangle hit on that object.
     */
    private fun updatePositionAgainstObject(actor: Actor, scale: Vector3f, collisionObjectGroup: CollisionObjectGroup): List<CollisionProperty> {
        val hits = ArrayList<CollisionProperty>()

        for (candidate in collisionObjectGroup.collisionObjects) {
            val info = candidate.transformInfo

            // Objects linked to the current sub-area are skipped: the loaded sub-area's collision is used instead.
            val linkedSubArea = info.subAreaLinkId
            if (linkedSubArea != null && SceneManager.getCurrentScene().isCurrentSubArea(linkedSubArea)) { continue }

            val touchedTerrain = collideSphereWithObj(actor, scale, candidate)
            val primaryTerrain = touchedTerrain.firstOrNull() ?: continue

            hits += CollisionProperty(
                collisionObject = candidate,
                environmentId = info.linkedDatId,
                cullingTableIndex = info.cullingTableIndex,
                lightIndices = info.lightIndices,
                terrainType = primaryTerrain
            )
        }

        return hits
    }

    /**
     * Tests a sphere placed at the actor against every triangle of [obj], resolving each hit as it is found.
     *
     * The sphere has radius `scale.x`. Its center is the actor's position with `scale.x` subtracted from `y`,
     * transformed in place by the object's `invTransform`.
     *
     * @return the type of every triangle that produced a collision, in mesh order; empty when the sphere
     * does not intersect the mesh's bounding sphere or when no triangle is hit.
     */
    private fun collideSphereWithObj(actor: Actor, scale: Vector3f, obj: CollisionObject) : List<TerrainType> {
        // Sphere center: a copy of the actor position, shifted by the radius along y, then moved by invTransform.
        val offsetPosition = Vector3f(actor.position).also { it.y -= scale.x }
        obj.transformInfo.invTransform.transformInPlace(offsetPosition)
        val transformedSphere = Sphere(offsetPosition, scale.x)

        // Broad phase: skip the per-triangle work when the mesh's bounding sphere is not touched.
        if (!Sphere.intersects(transformedSphere, obj.collisionMesh.boundingSphere)) {
            return emptyList()
        }

        val terrainTypes = ArrayList<TerrainType>()

        for (tri in obj.collisionMesh.tris) {
            val depthResult = TriangleSphereCollider.intersect(obj, tri, transformedSphere) ?: continue
            // resolveCollision moves the actor, so the sphere center has to be recomputed afterwards.
            resolveCollision(actor, obj, tri, depthResult)

            // NOTE(review): only offsetPosition is refreshed here; this presumably relies on Sphere keeping a
            // reference to the vector it was constructed with, so transformedSphere follows along — confirm.
            offsetPosition.copyFrom(actor.position).also { it.y -= scale.x }
            obj.transformInfo.invTransform.transformInPlace(offsetPosition)

            terrainTypes.add(tri.type)
        }

        return terrainTypes
    }

    /**
     * Box-based counterpart of the sphere collision routine: tests the actor's scaled bounding box against
     * every triangle of [obj] using the separating axis theorem, resolving each hit as it is found.
     *
     * @return the type of every triangle that produced a collision, in mesh order; empty when the box's
     * bounding sphere does not intersect the mesh's bounding sphere or when no triangle is hit.
     */
    private fun collideBoxWithObj(actor: Actor, scale: Vector3f, obj: CollisionObject) : List<TerrainType> {
        // Builds the actor's box at its current position and moves it by the object's invTransform.
        fun currentBox() = BoundingBox.scaled(scale, actor.position).transform(obj.transformInfo.invTransform)

        var localBox = currentBox()

        val broadPhaseSphere = localBox.toBoundingSphere()
        if (!Sphere.intersects(broadPhaseSphere, obj.collisionMesh.boundingSphere)) { return emptyList() }

        val touched = ArrayList<TerrainType>()

        for (tri in obj.collisionMesh.tris) {
            val hit = SatCollider.boxTriSeparatingAxisTheorem(localBox, tri) ?: continue
            resolveCollision(actor, obj, tri, hit)

            // The actor was moved by the resolution, so the box is rebuilt before the next triangle.
            localBox = currentBox()

            touched.add(tri.type)
        }

        return touched
    }

    /**
     * Moves [actor] out of a collision with [tri].
     *
     * If the depth still carries a positive vertical escape after the checks below, the actor is moved
     * along `Vector3f.NegY` by that distance. Otherwise it is moved along the penetration axis (converted
     * with the object's `transform`) by the penetration amount scaled by 1.015.
     *
     * The vertical escape stored in [depthResult] is cleared when any of these hold:
     * - `|tri.normal . NegY|` is below 0.80 and the triangle's height exceeds [maxVerticalEscapeStep];
     * - the triangle's type does not permit vertical escape;
     * - the actor's last collision result reports free-fall.
     */
    private fun resolveCollision(actor: Actor, obj: CollisionObject, tri: Triangle, depthResult: CollisionDepthResult) {
        val depth = depthResult.minDepth

        val steepness = abs(tri.normal.dot(Vector3f.NegY))
        val triangleHeight = tri.height()
        val tooSteepAndTall = steepness < 0.80f && triangleHeight > maxVerticalEscapeStep

        // No snapping while in free-fall: at high frame-rates (less gravity per frame) it enables ledge-walking bugs.
        val escapeDisallowed = !tri.type.verticalEscape || actor.lastCollisionResult.isInFreeFall()

        if (tooSteepAndTall || escapeDisallowed) {
            depth.verticalEscape = null
        }

        val escape = depth.verticalEscape
        val correction = if (escape != null && escape > 0f) {
            Vector3f.NegY * escape
        } else {
            obj.transformInfo.transform.transformDirectionVector(depth.axis) * (depth.amount * 1.015f)
        }

        actor.position += correction
    }

    /**
     * Looks up the nearest floor point for [position], trying [areas] in order.
     *
     * @return the result from the first area that yields one, or `null` if none does.
     */
    fun nearestFloor(position: Vector3f, areas: List<Area>): Vector3f? {
        for (area in areas) {
            val floor = nearestFloor(position, area)
            if (floor != null) { return floor }
        }

        return null
    }

    /**
     * Finds the closest floor intersection for [position] among the collision objects of a single [area].
     *
     * @return the intersection point with the smallest ray distance, or `null` when the area has no
     * collision map or no object is intersected. On equal distances the earlier candidate is kept.
     */
    private fun nearestFloor(position: Vector3f, area: Area): Vector3f? {
        val collisionMap = area.getZoneResource().zoneCollisionMap ?: return null

        var best: Pair<Vector3f, Float>? = null

        for (group in collisionMap.getCollisionObjects(position)) {
            for (collisionObject in group.collisionObjects) {
                val candidate = getFloorDistance(position, collisionObject) ?: continue

                val current = best
                if (current == null || current.second > candidate.second) {
                    best = candidate
                }
            }
        }

        return best?.first
    }

    /**
     * Casts a ray from [position] along `Vector3f.Y` (both converted with the object's `invTransform`)
     * against every triangle of [collisionObject].
     *
     * @return the closest hit as (intersection point converted back with `transform`, ray distance), or
     * `null` if no triangle is hit. On equal distances the later triangle wins.
     */
    private fun getFloorDistance(position: Vector3f, collisionObject: CollisionObject): Pair<Vector3f, Float>? {
        val info = collisionObject.transformInfo
        val localRay = info.invTransform.transformDirectionVector(Vector3f.Y)
        val localOrigin = info.invTransform.transform(position)

        var best: Pair<Vector3f, Float>? = null

        for (tri in collisionObject.collisionMesh.tris) {
            val (localHit, distance) = RayTriangleCollider.intersect(tri, localOrigin, localRay) ?: continue

            // Replace the current best unless it is strictly closer than this hit.
            val current = best
            if (current == null || !(current.second < distance)) {
                best = Pair(info.transform.transform(localHit), distance)
            }
        }

        return best
    }

}

object SatCollider {

    /**
     * Separating-axis test between box [b] and triangle [t] over up to 13 axes: the three unit axes, the
     * triangle normal, and the cross product of each unit axis with each triangle edge.
     *
     * @return `null` as soon as any axis shows no overlap; otherwise the accumulated minimum depth across
     * all tested axes.
     */
    fun boxTriSeparatingAxisTheorem(b: BoundingBox, t: Triangle) : CollisionDepthResult? {
        val shallowest = CollisionDepth()

        // The three unit axes first, then the triangle's own normal.
        val faceAxes = listOf(Vector3f.X, Vector3f.Y, Vector3f.Z, t.normal)
        for (faceAxis in faceAxes) {
            val depth = boxTriAxisCheck(b, t, faceAxis) ?: return null
            depth.applyIfSmaller(shallowest)
        }

        // Then the nine unit-axis x triangle-edge combinations.
        for (unitAxis in listOf(Vector3f.X, Vector3f.Y, Vector3f.Z)) {
            for (edge in t.getEdges()) {
                val crossAxis = unitAxis.cross(edge)

                // A (near-)zero cross product gives no usable axis, so that pair is skipped.
                if (crossAxis.magnitudeSquare() < 0.001f) { continue }
                crossAxis.normalizeInPlace()

                val depth = boxTriAxisCheck(b, t, crossAxis) ?: return null
                depth.applyIfSmaller(shallowest)
            }
        }

        return CollisionDepthResult(minDepth = shallowest)
    }

    /**
     * Projects box [b] and triangle [t] onto [axis] and measures how far the projections overlap.
     *
     * @return `null` when the projections do not overlap; otherwise a depth holding the signed overlap,
     * the axis, and a vertical escape equal to the overlap divided by `axis . NegY` (positive infinity
     * when that dot product is within 0.0001 of zero).
     */
    private fun boxTriAxisCheck(b: BoundingBox, t: Triangle, axis: Vector3f) : CollisionDepth? {
        val boxRange = projectVertices(b.vertices, axis)
        val triRange = projectVertices(t.vertices, axis)
        val overlap = boxRange.overlapAmount(triRange) ?: return null

        val alongNegY = axis.dot(Vector3f.NegY)
        val escape = when {
            abs(alongNegY) < 0.0001 -> Float.POSITIVE_INFINITY
            else -> overlap / alongNegY
        }

        return CollisionDepth(amount = overlap, axis = axis, verticalEscape = escape)
    }

    /**
     * Tests whether boxes [a] and [b] intersect.
     *
     * Steps:
     * 1. Bounding-sphere rejection using each box's center and squared radius.
     * 2. If both are [AxisAlignedBoundingBox], delegate to [AxisAlignedBoundingBox.intersects].
     * 3. Otherwise run the separating axis theorem over up to 15 axes: the three axes of each box plus the
     *    nine pairwise cross products (near-parallel pairs are skipped).
     *
     * @return `true` when no separating axis is found.
     */
    fun boxBoxIntersection(a: Box, b: Box): Boolean {
        val centerDistanceSq = Vector3f.distanceSquared(a.getCenter(), b.getCenter())
        val radiusSqA = a.getRadiusSq()
        val radiusSqB = b.getRadiusSq()

        // Two spheres are disjoint when d^2 > (rA + rB)^2 = rA^2 + rB^2 + 2*rA*rB.
        // Comparing against rA^2 + rB^2 alone would reject boxes whose bounding spheres still overlap.
        val combinedRadiusSq = radiusSqA + radiusSqB + 2f * sqrt(radiusSqA * radiusSqB)
        if (centerDistanceSq > combinedRadiusSq) { return false }

        if (a is AxisAlignedBoundingBox && b is AxisAlignedBoundingBox) {
            return AxisAlignedBoundingBox.intersects(a, b)
        }

        // For SAT, prove there's overlap against all 15 axes; only the yes/no answer is needed here.
        val aAxes = a.getAxes()
        val bAxes = b.getAxes()

        for (i in 0 until 3) {
            if (boxBoxAxisCheck(a, b, aAxes[i]) == null) { return false }
            if (boxBoxAxisCheck(a, b, bAxes[i]) == null) { return false }
        }

        for (i in 0 until 3) {
            for (j in 0 until 3) {
                val axis = aAxes[i].cross(bAxes[j])

                // (Near-)parallel axes produce a degenerate cross product; nothing to test for that pair.
                if (axis.magnitudeSquare() < 0.0001f) { continue }
                if (boxBoxAxisCheck(a, b, axis.normalizeInPlace()) == null) { return false }
            }
        }

        return true
    }

    /**
     * Projects the vertices of boxes [a] and [b] onto [axis] and measures how far the projections overlap.
     *
     * @return `null` when the projections do not overlap; otherwise a depth holding the signed overlap,
     * the axis, and a vertical escape equal to the overlap divided by `axis . NegY` (positive infinity
     * when that dot product is within 0.0001 of zero).
     */
    private fun boxBoxAxisCheck(a: Box, b: Box, axis: Vector3f) : CollisionDepth? {
        val rangeA = projectVertices(a.getVertices(), axis)
        val rangeB = projectVertices(b.getVertices(), axis)
        val overlap = rangeA.overlapAmount(rangeB) ?: return null

        val alongNegY = axis.dot(Vector3f.NegY)
        val escape = when {
            abs(alongNegY) < 0.0001 -> Float.POSITIVE_INFINITY
            else -> overlap / alongNegY
        }

        return CollisionDepth(amount = overlap, axis = axis, verticalEscape = escape)
    }

    /**
     * Projects every vertex of the list onto [axis] (dot product) and returns the covered range.
     *
     * An empty list yields a range from positive infinity to negative infinity.
     */
    private fun projectVertices(vertices: List<Vector3f>, axis: Vector3f) : ProjectionRange {
        var lowest = Float.POSITIVE_INFINITY
        var highest = Float.NEGATIVE_INFINITY

        vertices.forEach { vertex ->
            val along = vertex.dot(axis)
            lowest = min(lowest, along)
            highest = max(highest, along)
        }

        return ProjectionRange(min = lowest, max = highest)
    }

    /**
     * Projects every vertex of the array onto [axis] (dot product) and returns the covered range.
     *
     * An empty array yields a range from positive infinity to negative infinity.
     */
    private fun projectVertices(vertices: Array<Vector3f>, axis: Vector3f) : ProjectionRange {
        var lowest = Float.POSITIVE_INFINITY
        var highest = Float.NEGATIVE_INFINITY

        vertices.forEach { vertex ->
            val along = vertex.dot(axis)
            lowest = min(lowest, along)
            highest = max(highest, along)
        }

        return ProjectionRange(min = lowest, max = highest)
    }
}

object RayTriangleCollider {

    /**
     * Intersects a ray with the triangle [tri].
     *
     * @param origin start of the ray.
     * @param direction direction of the ray; the returned distance is expressed in multiples of it.
     * @param rayLength optional upper bound on the accepted distance.
     * @return the intersection point and its distance along the ray, or `null` when the ray is (nearly)
     * parallel to the triangle's plane, the plane lies behind [origin], the hit is farther than
     * [rayLength], or the hit point fails one of the three edge tests.
     */
    fun intersect(tri: Triangle, origin: Vector3f, direction: Vector3f, rayLength: Float? = null): Pair<Vector3f, Float>? {
        val normal = tri.normal

        val approach = normal.dot(direction)
        if (abs(approach) <= 1e-5) { return null }

        // Solve normal . (origin + direction * s) = normal . t0 for s.
        val rayDistance = (normal.dot(tri.t0) - normal.dot(origin)) / approach

        if (rayDistance < 0f) { return null }
        if (rayLength != null && rayDistance > rayLength) { return null }

        val hitPoint = origin + direction * rayDistance

        // The plane hit must pass the same signed test on all three edges to count as inside the triangle.
        val corners = arrayOf(tri.t0, tri.t1, tri.t2)
        for (i in corners.indices) {
            val from = corners[i]
            val to = corners[(i + 1) % corners.size]
            if ((from - to).cross(hitPoint - from).dot(normal) < 0f) { return null }
        }

        return Pair(hitPoint, rayDistance)
    }

}

object RaySphereCollider {

    /**
     * Tests whether [ray] touches [sphere].
     *
     * The sphere center is projected onto the ray; the point of the ray closest to the center is then
     * compared against the sphere's squared radius.
     *
     * NOTE(review): the projection is used directly as a distance along the ray, which presumes
     * `ray.direction` is unit length — confirm against callers.
     *
     * @return `true` when the closest point of the ray to the sphere center lies within the sphere.
     */
    fun intersect(ray: Ray, sphere: Sphere): Boolean {
        val toCenter = sphere.center - ray.origin
        val proj = ray.direction.dot(toCenter)

        if (proj < 0f) {
            // The center is behind the origin, so the origin itself is the ray's closest point to it.
            // The ray still touches the sphere when it starts inside (or on) the sphere.
            return Vector3f.distanceSquared(ray.origin, sphere.center) <= sphere.radiusSq
        }

        val closestPoint = ray.origin + ray.direction * proj
        val distanceSq = Vector3f.distanceSquared(closestPoint, sphere.center)

        return distanceSq <= sphere.radiusSq
    }

}

object TriangleSphereCollider {

    /**
     * Tests [sphere] against [triangle] and reports the shallowest penetration found.
     *
     * A ray is first cast from the sphere center against the triangle, opposite to the triangle normal and
     * limited to the sphere radius. If it hits, that point is the only contact considered. Otherwise the
     * closest point on each of the three edges is considered.
     *
     * @param obj object owning the triangle; its `invTransform` is used for the vertical-escape computation.
     * @return `null` when no contact point lies within the sphere radius; otherwise the accumulated depth.
     */
    fun intersect(obj: CollisionObject, triangle: Triangle, sphere: Sphere): CollisionDepthResult? {
        val minDepth = CollisionDepth()

        val planeIntersection = RayTriangleCollider.intersect(tri = triangle, origin = sphere.center, direction = triangle.normal * -1f, rayLength = sphere.radius)

        if (planeIntersection != null) {
            resolvePointCollision(obj, planeIntersection.first, sphere)?.applyIfSmaller(minDepth)
        } else {
            val edges = listOf(
                triangle.t0 to triangle.t1,
                triangle.t1 to triangle.t2,
                triangle.t2 to triangle.t0
            )

            for ((edgeStart, edgeEnd) in edges) {
                val closest = closestPointOnEdge(edgeStart, edgeEnd, sphere)
                resolvePointCollision(obj, closest, sphere)?.applyIfSmaller(minDepth)
            }
        }

        // The amount is still at its initial infinite value when nothing was applied.
        if (minDepth.amount.isInfinite()) { return null }
        return CollisionDepthResult(minDepth)
    }

    /**
     * Returns the point on the segment [edgeStart]..[edgeEnd] closest to the center of [sphere].
     *
     * A zero-length edge has no direction to project onto; in that case a copy of [edgeStart] is returned
     * instead of dividing by zero, which would otherwise yield a NaN point that propagates into the
     * collision depth.
     */
    private fun closestPointOnEdge(edgeStart: Vector3f, edgeEnd: Vector3f, sphere: Sphere): Vector3f {
        val edgeVector = edgeEnd - edgeStart
        val edgeLengthSq = edgeVector.dot(edgeVector)
        if (edgeLengthSq <= 0f) { return Vector3f(edgeStart) }

        val t = edgeVector.dot(sphere.center - edgeStart) / edgeLengthSq
        return edgeStart + (edgeVector * t.coerceIn(0f, 1f))
    }

    /**
     * Builds the collision depth for a single contact [point].
     *
     * @return `null` when [point] is farther from the sphere center than the radius. Otherwise a depth with:
     * - amount: radius minus the center-to-point distance;
     * - axis: the normalized vector from [point] to the sphere center;
     * - verticalEscape: see the inline comments.
     */
    private fun resolvePointCollision(obj: CollisionObject, point: Vector3f, sphere: Sphere): CollisionDepth? {
        // Penetration-normal escape
        val sphereToPointVector = sphere.center - point
        val sphereToPointDistance = sphereToPointVector.magnitude()
        if (sphereToPointDistance > sphere.radius) { return null }
        val amount = sphere.radius - sphereToPointDistance

        // Vertical escape: from the horizontal offset between center and point, derive the vertical
        // half-extent of the sphere at that offset, then measure from the point's y to that height.
        val horizontal = sphere.center - point.withY(sphere.center.y)
        val verticalLength = sqrt(sphere.radiusSq - horizontal.magnitudeSquare())
        val verticalPoint = sphere.center.y + obj.transformInfo.invTransform.transformYAxisVector(y = verticalLength, w = 0f) // Some objects have flipped collision-space...
        val verticalEscape = verticalPoint - point.y

        return CollisionDepth(amount = amount, axis = sphereToPointVector.normalize(), verticalEscape = verticalEscape)
    }

}