package xim.poc.gl

import xim.poc.gl.ShaderConstants.diffuseLightCalcFn
import xim.poc.gl.ShaderConstants.getUniformDiffuseLight
import xim.poc.gl.ShaderConstants.getUniformPointLight
import xim.poc.gl.ShaderConstants.pointLightCalcFn
import web.gl.WebGL2RenderingContext
import web.gl.WebGLUniformLocation
import xim.poc.gl.ShaderConstants.fogCalcFn
import xim.poc.gl.ShaderConstants.getUniformFog

/**
 * Uniform locations of the Xim shader program, resolved once when this object is constructed.
 *
 * Locations are queried in declaration order. A plain uniform (resolved through [getUniformLoc])
 * that the linked program does not expose raises an [IllegalStateException], so a mismatch between
 * the GLSL sources and this class fails fast instead of silently rendering wrong.
 *
 * @property ximProgram the linked program whose uniforms are queried.
 * @property context the WebGL2 context used for the queries.
 */
class XimLocations(val ximProgram: GLProgram, val context: WebGL2RenderingContext) {

    val diffuseTexture: WebGLUniformLocation = getUniformLoc("diffuseTexture")

    val uProjMatrix: WebGLUniformLocation = getUniformLoc("uProjMatrix")
    val uModelMatrix: WebGLUniformLocation = getUniformLoc("uModelMatrix")
    val uViewMatrix: WebGLUniformLocation = getUniformLoc("uViewMatrix")
    val discardThreshold: WebGLUniformLocation = getUniformLoc("discardThreshold")
    val positionBlendWeight: WebGLUniformLocation = getUniformLoc("positionBlendWeight")

    val ambientLightColor: WebGLUniformLocation = getUniformLoc("ambientLightColor")

    /** Handles for `pointLights[0]` .. `pointLights[3]`; the vertex shader loops over exactly 4 entries. */
    val pointLight: Array<UniformPointLight> = Array(4) { slot ->
        getUniformPointLight(ximProgram, "pointLights[$slot]")
    }

    val diffuseLight0: UniformDiffuseLight = getUniformDiffuseLight(ximProgram, "diffuseLights[0]")
    val diffuseLight1: UniformDiffuseLight = getUniformDiffuseLight(ximProgram, "diffuseLights[1]")

    val fog: UniformFog = getUniformFog(ximProgram, "fog")

    /**
     * Looks up a single uniform by [name] in [ximProgram].
     *
     * @throws IllegalStateException if the program has no active uniform with that name.
     */
    private fun getUniformLoc(name: String): WebGLUniformLocation =
        context.getUniformLocation(ximProgram.programId, name)
            ?: error("$name is not defined in XimProgram")

}

/**
 * GLSL sources for the Xim program, plus a cache of that program's uniform locations.
 */
object XimShader {

    private lateinit var locations: XimLocations

    /**
     * Returns the uniform locations for [ximProgram] in [context], resolving them on first use.
     *
     * The cached [XimLocations] is rebuilt when a different context or a different underlying
     * program object is supplied. Uniform locations are only meaningful for the program object they
     * were queried from, so reusing them after a relink or context recreation would be wrong.
     */
    fun getLocations(ximProgram: GLProgram, context: WebGL2RenderingContext) : XimLocations {
        loadLocations(ximProgram, context)
        return locations
    }

    /** (Re)builds [locations] if it is missing or was resolved for another program/context. */
    private fun loadLocations(ximProgram: GLProgram, context: WebGL2RenderingContext) {
        val stale = !this::locations.isInitialized ||
            locations.context !== context ||
            locations.ximProgram.programId !== ximProgram.programId
        if (stale) {
            locations = XimLocations(ximProgram, context)
        }
    }


    /**
     * Vertex shader.
     *
     * Adds `position1` (scaled by `positionBlendWeight`) to `position0` and transforms the result
     * by the model, view and projection matrices. It also computes per-vertex lighting
     * (ambient + 2 diffuse lights + 4 point lights) into `frag_color`, clamped to [0, 1].
     */
    const val vertShader = """${ShaderConstants.version}

${ShaderConstants.pointLightStruct}

${ShaderConstants.diffuseLightStruct}

uniform sampler2D diffuseTexture;

uniform mat4 uProjMatrix;
uniform mat4 uModelMatrix;
uniform mat4 uViewMatrix;

uniform vec4 ambientLightColor;

uniform float positionBlendWeight;

layout(location=0) in vec3 position0;
layout(location=1) in vec3 position1;

layout(location=2) in vec3 normal0;

layout(location=4) in vec2 textureCoords;
layout(location=8) in vec4 vertexColor;

out vec2 frag_textureCoords;
out vec4 frag_color;
out vec4 frag_cameraPos;

$diffuseLightCalcFn

$pointLightCalcFn

void main(){
    frag_textureCoords = textureCoords;

    // Normals use the inverse-transpose of the model matrix so non-uniform scale does not skew them.
    mat3 invTransposeModel = transpose(inverse(mat3(uModelMatrix)));
    vec3 transformedNormal = normalize(invTransposeModel * normal0);

    vec3 position = position0 + positionBlendWeight * position1;
    vec4 worldPos = uModelMatrix * vec4(position, 1.0);
    vec4 cameraPos = uViewMatrix * worldPos;

    // Lighting Calc
    vec4 finalAmbientColor = vertexColor * ambientLightColor;

    vec4 df0 = diffuseLightCalc(transformedNormal, vertexColor, diffuseLights[0]);
    vec4 df1 = diffuseLightCalc(transformedNormal, vertexColor, diffuseLights[1]);

    // The accumulator must start at zero: uninitialized locals are undefined in GLSL ES 3.00.
    vec4 pl = vec4(0.0);
    for (int i = 0; i < 4; i++) {
        pl += pointLightCalc(worldPos, transformedNormal, vertexColor, pointLights[i]);
    }

    frag_color = clamp(vec4((finalAmbientColor + pl + df0 + df1).rgb, vertexColor.a), 0.0, 1.0);
    frag_cameraPos = cameraPos;
    gl_Position = uProjMatrix * cameraPos;
}
"""


    /**
     * Fragment shader.
     *
     * Modulates the sampled texel by the interpolated vertex color (RGB x2, alpha x4) and discards
     * fragments whose resulting alpha is below `discardThreshold`. The color is then passed
     * through `fogCalc` using the camera-space position.
     */
    const val fragShader = """${ShaderConstants.version}
precision highp float;

uniform float discardThreshold;
uniform sampler2D diffuseTexture;

in vec2 frag_textureCoords;
in vec4 frag_color;
in vec4 frag_cameraPos;

out vec4 outColor;

${ShaderConstants.fogStruct}
$fogCalcFn

void main()
{
    vec4 basePixel = texture(diffuseTexture, frag_textureCoords);
    vec4 coloredPixel = vec4(2.0 * frag_color.rgb * basePixel.rgb, 4.0 * frag_color.a * basePixel.a);

    if (coloredPixel.a < discardThreshold) { discard; }

    outColor = fogCalc(frag_cameraPos.xyz, coloredPixel);
}
"""

}