package xim.poc.gl

import js.buffer.ArrayBuffer
import js.typedarrays.Uint8Array
import js.typedarrays.Float32Array
import web.gl.WebGL2RenderingContext
import web.gl.WebGL2RenderingContext.Companion.ARRAY_BUFFER
import web.gl.WebGL2RenderingContext.Companion.DYNAMIC_DRAW
import web.gl.WebGLBuffer
import web.gl.WebGLVertexArrayObject
import xim.resource.UiElementComponent
import xim.resource.UiVertex

/**
 * Deduplicating GPU store for UI quads.
 *
 * Each distinct set of four [UiVertex] values is packed once into a slot of a single, fixed-size WebGL
 * `ARRAY_BUFFER`. Later requests for an equal quad return the same slot index. Slots are never freed.
 *
 * Per-vertex layout (24 bytes): `x, y, z = 0` (3 × f32), `u, v` (2 × f32), then `r, g, b, a` (4 × u8).
 * One slot holds four such vertices, in the order of `element.vertices[0..3]`.
 */
object QuadVertexBuffer {

    // Lookup path: vertices[0] -> vertices[1] -> vertices[2] -> vertices[3] -> slot index.
    // A cache hit costs four map lookups and allocates nothing.
    private val lookupTree = HashMap<UiVertex, HashMap<UiVertex, HashMap<UiVertex, HashMap<UiVertex, Int>>>>()

    // Fixed texture coordinates: one (u, v) pair per vertex position 0..3.
    private val uvs = floatArrayOf(
        -0.5f, -0.5f,
         0.5f, -0.5f,
        -0.5f,  0.5f,
         0.5f,  0.5f
    )

    private const val MAX_ALLOCATIONS = 20000
    private var numAllocated = 0

    private const val VERTICES_PER_ELEMENT = 4
    // Per vertex: 3 position floats + 2 UV floats + one float-sized slot holding the 4 RGBA color bytes.
    private const val FLOATS_PER_VERTEX = 3 + 2 + 1
    private const val SIZE_OF_ELEMENT_FLOATS = VERTICES_PER_ELEMENT * FLOATS_PER_VERTEX
    private const val SIZE_OF_ELEMENT_BYTES = Float.SIZE_BYTES * SIZE_OF_ELEMENT_FLOATS

    private lateinit var vertexArray: WebGLVertexArrayObject
    private lateinit var vertexBuffer: WebGLBuffer

    /**
     * Returns the slot index holding [element]'s quad, packing and uploading it on first use.
     *
     * Quads are matched by [UiVertex] `equals`/`hashCode` on `vertices[0..3]`. This presumably relies on
     * [UiVertex] having structural equality (TODO confirm); otherwise every element gets its own slot.
     * The slot's first byte is `index * SIZE_OF_ELEMENT_BYTES` into the buffer.
     *
     * [getVao] must have been called before the first allocation so that the GPU buffer exists.
     *
     * @throws IllegalStateException if the buffer is not initialized yet or every slot is in use.
     */
    fun get(element: UiElementComponent): Int {
        return lookupTree.getOrPut(element.vertices[0]) { HashMap() }
            .getOrPut(element.vertices[1]) { HashMap() }
            .getOrPut(element.vertices[2]) { HashMap() }
            .getOrPut(element.vertices[3]) { allocate(element) }
    }

    /**
     * Returns the VAO for the shared quad buffer, creating the VAO and the buffer on the first call.
     *
     * On that first call [vertexAttributeSetter] runs while the VAO is bound and the buffer is bound to
     * `ARRAY_BUFFER`, so any vertex-attribute state it sets is captured by the VAO. The VAO is left bound
     * afterwards. On later calls [vertexAttributeSetter] is ignored.
     */
    fun getVao(vertexAttributeSetter: () -> Unit): WebGLVertexArrayObject {
        if (!this::vertexBuffer.isInitialized) {
            initializeBuffer(GlDisplay.getContext(), vertexAttributeSetter)
        }
        return vertexArray
    }

    /**
     * Packs [element] into the next free slot, uploads it, and returns that slot's index.
     *
     * Both preconditions are checked before any state changes. The slot counter advances only after the
     * upload is issued, so a failure does not leak a slot.
     */
    private fun allocate(element: UiElementComponent): Int {
        check(this::vertexBuffer.isInitialized) { "QuadVertexBuffer.getVao must be called before allocating quads" }
        check(numAllocated < MAX_ALLOCATIONS) { "Ran out of room in the buffer :(" }

        val offset = numAllocated
        val packed = pack(element)

        val webgl = GlDisplay.getContext()
        webgl.bindBuffer(ARRAY_BUFFER, vertexBuffer)
        webgl.bufferSubData(ARRAY_BUFFER, offset * SIZE_OF_ELEMENT_BYTES, packed)

        numAllocated += 1
        return offset
    }

    /** Writes [element]'s four vertices into a fresh [SIZE_OF_ELEMENT_BYTES]-byte buffer in the interleaved layout. */
    private fun pack(element: UiElementComponent): ArrayBuffer {
        val buffer = ArrayBuffer(SIZE_OF_ELEMENT_BYTES)
        val floatView = Float32Array(buffer)
        val bytesView = Uint8Array(buffer)

        var floatPos = 0
        for (i in 0 until VERTICES_PER_ELEMENT) {
            val vertex = element.vertices[i]

            floatView[floatPos++] = vertex.point.x
            floatView[floatPos++] = vertex.point.y
            floatView[floatPos++] = 0f

            floatView[floatPos++] = uvs[i * 2]
            floatView[floatPos++] = uvs[i * 2 + 1]

            // The color takes the next float-sized slot and is written byte by byte through the Uint8 view.
            val colorByte = floatPos * Float.SIZE_BYTES
            bytesView[colorByte + 0] = vertex.color.r.toByte()
            bytesView[colorByte + 1] = vertex.color.g.toByte()
            bytesView[colorByte + 2] = vertex.color.b.toByte()
            bytesView[colorByte + 3] = vertex.color.a.toByte()

            floatPos += 1
        }

        return buffer
    }

    /** Creates the VAO and the zero-filled vertex buffer, then lets the caller configure vertex attributes. */
    private fun initializeBuffer(webgl: WebGL2RenderingContext, vertexAttributeSetter: () -> Unit) {
        vertexArray = checkNotNull(webgl.createVertexArray()) { "Failed to create the quad vertex array" }
        webgl.bindVertexArray(vertexArray)

        vertexBuffer = checkNotNull(webgl.createBuffer()) { "Failed to create the quad vertex buffer" }
        webgl.bindBuffer(ARRAY_BUFFER, vertexBuffer)

        // Reserve storage for every slot up front; allocate() only ever writes into it via bufferSubData.
        val bufferData = Float32Array(MAX_ALLOCATIONS * SIZE_OF_ELEMENT_FLOATS)
        webgl.bufferData(ARRAY_BUFFER, bufferData, DYNAMIC_DRAW)

        // Runs with the VAO still bound so the attribute state it sets is recorded in it.
        vertexAttributeSetter.invoke()
    }

}