import type { BoundingInfo, Scene } from '@babylonjs/core';
import { ShaderMaterial } from '@babylonjs/core';
import { Vector3, VertexBuffer } from '@babylonjs/core';
import { Mesh, MeshBuilder } from '@babylonjs/core';

/* eslint-disable no-bitwise */
/**
 * A single decoded voxel record, as assembled while parsing a payload in
 * `Voxels.renderBoundUpdate`.
 */
interface Voxel {
  /**
   * Position. Each axis is read from the record as a big-endian uint16 and is then
   * replaced by the result of `inverseTransform` with the header's 4x4 matrix.
   */
  x: number;
  y: number;
  z: number;
  /**
   * Color channels, set from the record's uint8 color bytes (0-255) when its
   * color flag is set. They are divided by 255 before being uploaded.
   */
  r: number;
  g: number;
  b: number;
}

/**
 * Color modes offered by `Voxels`.
 *
 * `label` is the user-facing name matched by `Voxels.setColourMethod`. `key` is the
 * numeric value written to the shader's `colorType` uniform. Each key equals the
 * entry's position in the list below, so only append new modes: reordering the list
 * changes the keys the shader receives.
 */
export const VoxelColorMethods = ['Color', 'Rainbow', 'Height-Rainbow', 'Height-Mono'].map((label, key) => ({
  label,
  key,
}));

/**
 * Default deadline value, expressed as the duration string '5s'.
 * NOTE(review): not referenced in this file. Presumably callers pass it when
 * requesting voxel data; confirm which consumer parses this string format.
 */
export const VoxelsDefaultDeadline = '5s';
/**
 * Default TTL value, expressed as the duration string '5s'.
 * NOTE(review): not referenced in this file. Confirm its meaning with the consumer.
 */
export const VoxelsDefaultTTL = '5s';

/**
 * Mesh whose bounding info is recomputed on every query from the bounds reported by
 * `getHierarchyBoundingVectors()`, so that it tracks its children.
 */
export class VoxelsTiler extends Mesh {
  /**
   * Rebuilds this mesh's bounding info from its hierarchy bounds, then returns it.
   * NOTE(review): depending on the Babylon.js version, `getHierarchyBoundingVectors`
   * may itself call `getBoundingInfo()` for meshes with sub-meshes. Confirm this
   * override cannot recurse.
   */
  public override getBoundingInfo(): BoundingInfo {
    const { min, max } = this.getHierarchyBoundingVectors();
    this.buildBoundingInfo(min, max);
    return super.getBoundingInfo();
  }
}

/**
 * Renders a voxel cloud as thin instances of a single box mesh parented to this node.
 *
 * Updates arrive through `renderBoundUpdate` as base64 payloads. A frame may be split
 * across several payloads ("chunks") that share a `messageID`. Chunks are buffered
 * until `numberOfChunks` of them have arrived. Only then is the combined frame
 * uploaded, replacing whatever was rendered before.
 */
export class Voxels extends Mesh {
  /** Active color mode; its `key` is written to the shader's `colorType` uniform. */
  public colorMethod = VoxelColorMethods[0];
  /** Child box mesh holding one thin instance per voxel; unset while nothing is rendered. */
  private voxels?: Mesh;
  /** `maxDepth` of the most recently accepted chunk; written to the shader's `maxDepth` uniform. */
  private maxOctreeDepth = 0;
  /** `messageID` of the frame being accumulated, or `null` before the first chunk arrives. */
  private currentFrameID: bigint | null = null;
  /** `numberOfChunks` announced by the frame being accumulated. */
  private expectedChunks = 0;
  /** Chunks of the current frame accepted so far. */
  private receivedChunks = 0;
  /** Decoded arrays of the current frame's chunks, in arrival order; merged once complete. */
  private pendingChunks: Array<Record<'matricesData' | 'colorData' | 'countData' | 'depthData', Float32Array>> =
    [];
  /** Merged arrays of the frame currently rendered by `voxels`, or `null` when nothing is rendered. */
  private dataArrays: Record<'matricesData' | 'colorData' | 'countData' | 'depthData', Float32Array> | null = null;
  private readonly shaderMaterial: ShaderMaterial;

  constructor(name: string, scene: Scene) {
    super(name, scene);
    this.shaderMaterial = this.createShaderMaterial();
  }

  /**
   * Selects a color mode by label. Unknown labels fall back to the first entry of
   * `VoxelColorMethods`.
   */
  public setColourMethod(mode: string) {
    this.colorMethod = VoxelColorMethods.find((m) => m.label === mode) ?? VoxelColorMethods[0];
    this.updateMaterial();
  }

  /** Label of the active color mode. */
  public getColorMethod(): string {
    return this.colorMethod.label;
  }

  /** Number of voxels in the currently rendered frame; 0 when nothing is rendered. */
  public getVoxelCount(): number {
    if (!this.voxels || !this.dataArrays) return 0;
    return this.dataArrays.matricesData.length / 16;
  }

  /**
   * Feeds one encoded chunk into the renderer.
   *
   * An empty `payload` removes the rendered voxels. Otherwise the header decides what
   * happens to the chunk:
   * - the very first chunk, or one with a newer `messageID`, starts a new frame and
   *   discards any unfinished previous frame;
   * - a chunk of the current frame is buffered until `numberOfChunks` have arrived;
   * - chunks of older frames, and surplus chunks of an already complete frame, are ignored.
   * When the last chunk of a frame arrives, that frame replaces the rendered voxels.
   *
   * @param payload Base64-encoded chunk (header followed by voxel records), or `''` to clear.
   * @throws RangeError if an accepted chunk is shorter than its voxel count requires.
   */
  public renderBoundUpdate(payload: string) {
    if (payload === '') {
      this.clearVoxels();
      return;
    }

    const decoded = decodePayload(payload);
    const header = extractHeaderFromPayload(decoded);
    // Decide before decoding so stale chunks cost nothing and cannot disturb state.
    if (!this.acceptChunk(header.messageID, header.numberOfChunks)) return;

    const chunk = Voxels.decodeChunk(
      decoded,
      header.headerLength,
      header.numberOfVoxels,
      header.transform,
      header.maxDepth,
    );
    this.maxOctreeDepth = header.maxDepth;
    this.pendingChunks.push(chunk);
    this.receivedChunks++;
    if (this.receivedChunks !== this.expectedChunks) return;

    // Concatenate once per frame instead of re-copying the accumulated arrays on every chunk.
    const chunks = this.pendingChunks;
    this.pendingChunks = [];
    this.dataArrays = {
      matricesData: Voxels.concatFloat32(chunks.map((c) => c.matricesData)),
      colorData: Voxels.concatFloat32(chunks.map((c) => c.colorData)),
      countData: Voxels.concatFloat32(chunks.map((c) => c.countData)),
      depthData: Voxels.concatFloat32(chunks.map((c) => c.depthData)),
    };
    this.updateVoxels(header.resolution);
  }

  /** Disposes the rendered voxel mesh and forgets the data it was built from. */
  private clearVoxels() {
    this.voxels?.dispose();
    // Drop the reference so getVoxelCount/updateMaterial do not act on a disposed mesh.
    this.voxels = undefined;
    this.dataArrays = null;
  }

  /**
   * Decides whether a chunk belongs to the frame being accumulated. A chunk whose
   * `messageID` is newer than the current frame starts a new frame.
   *
   * @returns `true` if the chunk should be decoded and buffered.
   */
  private acceptChunk(messageID: bigint, numberOfChunks: number): boolean {
    if (this.currentFrameID === null || messageID > this.currentFrameID) {
      this.currentFrameID = messageID;
      this.expectedChunks = numberOfChunks;
      this.receivedChunks = 0;
      this.pendingChunks = [];
      return true;
    }
    // Older frame, or a surplus chunk for a frame that has already been completed.
    return messageID === this.currentFrameID && this.receivedChunks < this.expectedChunks;
  }

  /**
   * Decodes `count` voxel records starting at byte `start` of `decoded` into
   * per-instance arrays. Each voxel gets:
   * - a 4x4 matrix (identity with the voxel position as translation);
   * - an RGBA color in [0, 1];
   * - a count of 1;
   * - `depth` as its octree depth.
   *
   * A record is a big-endian uint16 flag word followed by three big-endian uint16
   * coordinates. When the color flag is set, three uint8 color channels follow.
   * NOTE(review): other flags (probability, classification, ...) do not advance the
   * offset here. Confirm the sender never sets them.
   */
  private static decodeChunk(
    decoded: Uint8Array,
    start: number,
    count: number,
    transform: number[][],
    depth: number,
  ) {
    const matricesData = new Float32Array(16 * count);
    const colorData = new Float32Array(4 * count);
    const countData = new Float32Array(count);
    const depthData = new Float32Array(count);

    // Read records in place rather than copying each one into a scratch buffer.
    const view = new DataView(decoded.buffer, decoded.byteOffset, decoded.byteLength);
    let offset = start;
    for (let i = 0; i < count; i++) {
      const flags = extractVoxelFromPayload(view.getUint16(offset, false));
      const voxel = inverseTransform(
        {
          x: view.getUint16(offset + 2, false),
          y: view.getUint16(offset + 4, false),
          z: view.getUint16(offset + 6, false),
          // Records without color would otherwise produce NaN (undefined / 255); draw them white.
          r: 255,
          g: 255,
          b: 255,
        },
        transform,
      );
      if (flags.hasColor) {
        voxel.r = view.getUint8(offset + 8);
        voxel.g = view.getUint8(offset + 9);
        voxel.b = view.getUint8(offset + 10);
      }
      offset += flags.hasColor ? 11 : 8;

      const m = i * 16;
      matricesData[m] = 1;
      matricesData[m + 5] = 1;
      matricesData[m + 10] = 1;
      matricesData[m + 15] = 1;
      matricesData[m + 12] = voxel.x;
      matricesData[m + 13] = voxel.y;
      matricesData[m + 14] = voxel.z;
      countData[i] = 1.0;
      colorData[i * 4] = voxel.r / 255;
      colorData[i * 4 + 1] = voxel.g / 255;
      colorData[i * 4 + 2] = voxel.b / 255;
      colorData[i * 4 + 3] = 1.0;
      depthData[i] = depth;
    }
    return { matricesData, colorData, countData, depthData };
  }

  /** Concatenates `parts` in order; a single part is returned as-is without copying. */
  private static concatFloat32(parts: Float32Array[]): Float32Array {
    if (parts.length === 1) return parts[0];
    const out = new Float32Array(parts.reduce((total, p) => total + p.length, 0));
    let at = 0;
    for (const p of parts) {
      out.set(p, at);
      at += p.length;
    }
    return out;
  }

  /**
   * Replaces the child box mesh with a new one of edge length `resolution`. Uploads
   * the merged frame as a `matrix` thin-instance buffer plus instanced `pointcolor`,
   * `counts` and `octreeDepth` vertex buffers.
   */
  private updateVoxels(resolution: number) {
    const arrays = this.dataArrays;
    if (!arrays) return;

    this.voxels?.dispose();
    // Pass the scene explicitly; without it Babylon falls back to the last-created scene.
    this.voxels = MeshBuilder.CreateBox(`${this.name}-voxels`, { size: resolution }, this.getScene());
    this.updateMaterial();
    this.voxels.parent = this;

    const engine = this.getEngine();
    this.voxels.thinInstanceSetBuffer('matrix', arrays.matricesData, 16, true);
    this.voxels.setVerticesBuffer(new VertexBuffer(engine, arrays.colorData, 'pointcolor', false, false, 4, true));
    this.voxels.setVerticesBuffer(new VertexBuffer(engine, arrays.countData, 'counts', false, false, 1, true));
    this.voxels.setVerticesBuffer(
      new VertexBuffer(engine, arrays.depthData, 'octreeDepth', false, false, 1, true),
    );
  }

  /**
   * Pushes the active color mode and octree depth into the shader. Assigns the
   * material to the voxel mesh if one exists.
   */
  private updateMaterial() {
    this.shaderMaterial.setFloat('colorType', this.colorMethod.key);
    this.shaderMaterial.setFloat('maxDepth', this.maxOctreeDepth);
    if (this.voxels) this.voxels.material = this.shaderMaterial;
  }

  /**
   * Creates the shader material backed by './assets/shaders/simpleVoxel'. Declares
   * its attributes and uniforms and seeds the initial uniform values.
   */
  private createShaderMaterial() {
    const material = new ShaderMaterial(this.name, this.getScene(), './assets/shaders/simpleVoxel', {
      needAlphaBlending: false,
      needAlphaTesting: false,
      attributes: ['position', 'color', 'pointcolor', 'counts', 'octreeDepth'],
      defines: [],
      samplers: [],
      uniforms: [
        'billboard',
        'camPos',
        'camTarget',
        'colorMix',
        'confidenceThreshold',
        'maxDepth',
        'paraboloidPoint',
        'paraboloidWeight',
        'pointScale',
        'projection',
        'resolution',
        'roundCorners',
        'time',
        'view',
        'viewProjection',
        'visType',
        'world',
        'worldView',
        'worldViewProjection',
      ],
    });
    material.setFloat('billboard', 0.0);
    material.setFloat('colorMix', 1.0);
    material.setFloat('colorType', 0.0);
    material.setFloat('confidenceThreshold', 0.0);
    material.setFloat('maxDepth', 1.0);
    material.setFloat('paraboloidPoint', 0.0);
    material.setFloat('paraboloidWeight', 0.2);
    material.setFloat('pointScale', 1.0);
    material.setFloat('roundCorners', 0.0);
    material.setFloat('visType', 1.0);
    material.setVector3('boundsMax', new Vector3(1.0, 1.0, 1.0));
    material.setVector3('boundsMin', new Vector3(-1.0, -1.0, -1.0));
    return material;
  }
}

/**
 * Decodes a base64 string into raw bytes.
 * Each character of the `atob` result becomes one byte of the returned array.
 */
const decodePayload = (payload: string) => {
  const binary = atob(payload);
  const bytes = new Uint8Array(binary.length);
  for (let idx = 0; idx < binary.length; idx++) {
    bytes[idx] = binary.charCodeAt(idx);
  }
  return bytes;
};

/**
 * Parses the fixed 86-byte frame header at the start of a decoded payload.
 *
 * All multi-byte fields are big-endian:
 * - byte 0: not read here;
 * - bytes 1-8: `messageID` (uint64, returned as a bigint);
 * - bytes 9-12: `numberOfChunks` (uint32);
 * - bytes 13-16: `numberOfVoxels` (uint32);
 * - bytes 17-20: `resolution` (float32);
 * - byte 21: `maxDepth` (uint8);
 * - bytes 22-85: `transform`, sixteen float32 values, with `transform[i][k]` read
 *   from offset 22 + 16*i + 4*k.
 * Payloads shorter than the header are zero-padded, not rejected.
 */
const extractHeaderFromPayload = (decoded: Uint8Array) => {
  const headerLength = 74 + 8 + 4;
  // Copy into a zero-filled buffer so short payloads read as zeros instead of throwing.
  const headerBytes = new Uint8Array(headerLength);
  headerBytes.set(decoded.subarray(0, headerLength));
  const view = new DataView(headerBytes.buffer);

  const transformStart = 22;
  const transform = Array.from({ length: 4 }, (_, row) =>
    Array.from({ length: 4 }, (__, col) => view.getFloat32(transformStart + 16 * row + 4 * col, false)),
  );

  return {
    headerLength,
    messageID: view.getBigUint64(1, false),
    numberOfChunks: view.getUint32(9, false),
    numberOfVoxels: view.getUint32(13, false),
    resolution: view.getFloat32(17, false),
    maxDepth: view.getUint8(21),
    transform,
  };
};

/**
 * Unpacks a voxel record's 16-bit flag word.
 * `version` is bits 0-3. The remaining fields are single bits (1 = present, 0 = absent):
 * color (bit 4), probability (5), classification (6), return (7) and intensity (8).
 */
const extractVoxelFromPayload = (header: number) => {
  const bitAt = (position: number) => (header >> position) & 0x1;
  return {
    version: header & 0xf,
    hasColor: bitAt(4),
    hasProbability: bitAt(5),
    hasClassification: bitAt(6),
    hasReturn: bitAt(7),
    hasIntensity: bitAt(8),
  };
};

/**
 * Maps the voxel position through `mat` as the homogeneous row vector
 * `[x y z 1] * mat`, so `mat[3]` supplies the translation. The result is divided by
 * the resulting w component.
 * The voxel is updated in place and also returned; its color fields are untouched.
 */
const inverseTransform = (v: Voxel, mat: Array<Array<number>>) => {
  const { x: px, y: py, z: pz } = v;
  // Dot product of the homogeneous position with column `c` of the matrix.
  const column = (c: number) => mat[0][c] * px + mat[1][c] * py + mat[2][c] * pz + mat[3][c];
  const invW = 1 / column(3);
  v.x = column(0) * invW;
  v.y = column(1) * invW;
  v.z = column(2) * invW;
  return v;
};
