import { Proj4Projection } from '@math.gl/proj4';
import type { Mesh, Scene, TransformNode } from '@babylonjs/core';
import { Matrix, Vector3 } from '@babylonjs/core';
import { lusolve, multiply, transpose } from 'mathjs';
import type { Tile2D } from './nodes/tiles2D';

/**
 * Three-component coordinate tuple whose meaning depends on the caller, e.g.
 * `[longitude, latitude, altitude]` as the input of `SlippyTransformer.wgs84ToTile`,
 * or tile-pixel `[x, y, z]` as the input of `SlippyTransformer.tileToWgs84`.
 */
export type Point = [number, number, number];

/**
 * A geographic location used as the origin of the local orthographic frame.
 * Latitude and longitude are in degrees: they are written into proj4's `+lat_0` / `+lon_0`
 * and converted with `radians()` by the slippy-map code.
 */
export interface GpsLocation {
  latitude: number;
  longitude: number;
  // Offset added to / subtracted from z when converting between the local frame and a CRS.
  // NOTE(review): presumably metres, since the local frame is declared with +units=m — confirm.
  altitude: number;
}

// Register a lower-case 'wgs84' alias (geographic long/lat on the WGS84 datum, in degrees)
// so CRS strings written as 'wgs84' resolve. Runs once, when this module is loaded.
// NOTE(review): presumably mirrors proj4's built-in upper-case 'WGS84' definition — confirm.
Proj4Projection.defineProjectionAliases({
  'wgs84': '+title=WGS 84 (long/lat) +proj=longlat +ellps=WGS84 +datum=WGS84 +units=degrees',
});

/** Multiplier that converts an angle in degrees to radians. */
const RADIANS_PER_DEGREE = Math.PI / 180.0;

/**
 * Converts an angle from degrees to radians.
 *
 * @param deg Angle in degrees.
 * @returns The same angle in radians.
 */
const radians = (deg: number): number => deg * RADIANS_PER_DEGREE;

/**
 * Converts between a local coordinate system and a Web Mercator slippy map.
 *
 * Tile-pixel coordinates assume 256 px tiles: at zoom `z` the map is `256 * 2^z` pixels wide,
 * x grows eastward from longitude -180°, and y grows southward from the map's northern edge.
 *
 * @see https://en.wikipedia.org/wiki/Tiled_web_map
 */
export class SlippyTransformer {
  /** Edge length of one tile in pixels (the `128` in the projection formulas is half of it). */
  private static readonly TILE_SIZE = 256;

  /**
   * Latitude limit (degrees) of the square Web Mercator map. Beyond it tile y leaves the map,
   * and towards ±90° the y formula grows without bound (it is +Infinity at -90°).
   */
  private static readonly MAX_LATITUDE = 85.0511287798066;

  /**
   * @param viewLocation Origin of the local frame that tile corners are converted into.
   */
  constructor(private readonly viewLocation: GpsLocation) {}

  /**
   * Projects a WGS84 point to global tile-pixel coordinates.
   *
   * @param point `[longitude, latitude, altitude]`, angles in degrees; altitude is passed through.
   * @param zoom Slippy-map zoom level.
   * @returns `[x, y, altitude]` in pixels. Latitudes beyond ±MAX_LATITUDE yield y outside the map.
   */
  public wgs84ToTile([lon, lat, alt]: Point, zoom: number): Point {
    const x = (128.0 / Math.PI) * Math.pow(2, zoom) * (radians(lon) + Math.PI);
    const y =
      (128.0 / Math.PI) * Math.pow(2, zoom) * (Math.PI - Math.log(Math.tan(Math.PI / 4.0 + radians(lat) / 2.0)));
    return [x, y, alt];
  }

  /**
   * Inverse of `wgs84ToTile`: converts global tile-pixel coordinates back to WGS84.
   *
   * @param point `[x, y, z]` in tile pixels; z is passed through unchanged.
   * @param zoom Slippy-map zoom level.
   * @returns `[longitude, latitude, z]` with angles in degrees.
   */
  public tileToWgs84([x, y, z]: Point, zoom: number): Point {
    const lat = (4.0 * Math.atan(Math.exp(Math.PI - (y * Math.PI) / (128.0 * Math.pow(2, zoom)))) - Math.PI) / 2.0;
    const lng = (x * Math.PI) / (128.0 * Math.pow(2, zoom)) - Math.PI;
    return [lng * (180.0 / Math.PI), lat * (180.0 / Math.PI), z];
  }

  /**
   * Lists the tiles covering the longitude/latitude bounding box of `geometry`, with each
   * tile's quad expressed in the local frame centred on `viewLocation`.
   *
   * Latitudes are clamped to ±MAX_LATITUDE and tile indices to `[0, 2^zoom - 1]`. This stops
   * polar or out-of-range input from producing invalid tiles or an unbounded tile loop.
   *
   * @param geometry Locations whose bounding box is covered; altitude is ignored.
   * @param zoom Slippy-map zoom level.
   * @returns One entry per tile, ordered by x then y; empty when `geometry` is empty.
   */
  public getTilesInLocalFrame(geometry: GpsLocation[], zoom: number): Tile2D[] {
    if (geometry.length === 0) {
      return [];
    }

    // Bounding box in one pass; `Math.min(...array)` throws RangeError for very large arrays.
    let minLon = Infinity;
    let maxLon = -Infinity;
    let minLat = Infinity;
    let maxLat = -Infinity;
    for (const { longitude, latitude } of geometry) {
      minLon = Math.min(minLon, longitude);
      maxLon = Math.max(maxLon, longitude);
      minLat = Math.min(minLat, latitude);
      maxLat = Math.max(maxLat, latitude);
    }

    const limit = SlippyTransformer.MAX_LATITUDE;
    const clampLatitude = (lat: number): number => Math.min(Math.max(lat, -limit), limit);

    // The north-west corner has the smallest tile x/y and the south-east corner the largest.
    const northWest = this.wgs84ToTile([minLon, clampLatitude(maxLat), 0], zoom);
    const southEast = this.wgs84ToTile([maxLon, clampLatitude(minLat), 0], zoom);

    const size = SlippyTransformer.TILE_SIZE;
    const lastIndex = Math.pow(2, zoom) - 1;
    // Pixel -> tile index, clamped so e.g. longitude 180° lands in the last column, not one past it.
    const toTileIndex = (pixel: number): number => Math.min(Math.max(Math.floor(pixel / size), 0), lastIndex);

    const minX = toTileIndex(northWest[0]);
    const minY = toTileIndex(northWest[1]);
    const maxX = toTileIndex(southEast[0]);
    const maxY = toTileIndex(southEast[1]);

    const tiles: [number, number][] = [];
    for (let x = minX; x <= maxX; x++) {
      for (let y = minY; y <= maxY; y++) {
        tiles.push([x, y]);
      }
    }

    return tiles.map((tile) => {
      // Tile corners: top-left (north-west) and bottom-right (south-east).
      const wgs841 = this.tileToWgs84([tile[0] * size, tile[1] * size, 0], zoom);
      const wgs842 = this.tileToWgs84([tile[0] * size + size, tile[1] * size + size, 0], zoom);
      const local1 = crsToWorld(this.viewLocation, 'WGS84', new Vector3(...wgs841));
      const local2 = crsToWorld(this.viewLocation, 'WGS84', new Vector3(...wgs842));
      return {
        index: tile,
        // Quad corners in order: (x1, y1), (x2, y1), (x2, y2), (x1, y2), matching `uvs`.
        positions: [local1.x, local1.y, 0, local2.x, local1.y, 0, local2.x, local2.y, 0, local1.x, local2.y, 0],
        indices: [0, 1, 2, 0, 2, 3],
        uvs: [0, 0, 1, 0, 1, 1, 0, 1],
        center: [(local1.x + local2.x) / 2, (local1.y + local2.y) / 2, 0],
        width: local2.x - local1.x,
        // NOTE(review): local2 is the southern corner, so this is negative if the local +y axis
        // points north — confirm Tile2D consumers expect a signed height.
        height: local2.y - local1.y,
      };
    });
  }
}

/**
 * Builds the proj4 definition of the local frame: an orthographic projection in metres on the
 * WGS84 ellipsoid, tangent at `gpsLocation` (whose latitude/longitude become `+lat_0`/`+lon_0`).
 * Altitude is not part of the definition.
 */
export const orthoProjString = (gpsLocation: GpsLocation): string => {
  const { latitude, longitude } = gpsLocation;
  // eslint-disable-next-line max-len
  const definition = `+proj=ortho +lat_0=${latitude} +lon_0=${longitude} +x_0=0 +y_0=0 +ellps=WGS84 +datum=WGS84 +units=m +no_defs`;
  return definition;
};

/**
 * Converts a point from the local orthographic frame centred on `gpsLocation` into `crs`.
 *
 * @param gpsLocation Origin of the local frame; its altitude is added to the projected z.
 * @param crs Target proj4 CRS definition or alias.
 * @param point Point in the local frame.
 * @returns The point expressed in `crs`.
 */
export const worldToCrs = (gpsLocation: GpsLocation, crs: string, point: Vector3): Vector3 => {
  const localToCrs = new Proj4Projection({ from: orthoProjString(gpsLocation), to: crs });
  const [x, y, z] = localToCrs.project(point.asArray());
  // Local z is relative to the view altitude; lift it back to an absolute height.
  return new Vector3(x, y, z + gpsLocation.altitude);
};

/**
 * Converts a point expressed in `crs` into the local orthographic frame centred on `gpsLocation`.
 *
 * @param gpsLocation Origin of the local frame; its altitude is subtracted from the projected z.
 * @param crs Source proj4 CRS definition or alias.
 * @param point Point in `crs`.
 * @returns The point in the local frame.
 */
export const crsToWorld = (gpsLocation: GpsLocation, crs: string, point: Vector3): Vector3 => {
  const crsToLocal = new Proj4Projection({ from: crs, to: orthoProjString(gpsLocation) });
  const [x, y, z] = crsToLocal.project(point.asArray());
  // Make z relative to the view altitude, the inverse of worldToCrs.
  return new Vector3(x, y, z - gpsLocation.altitude);
};

/**
 * Expresses a point given in the local (world) frame in another frame.
 *
 * @param viewLocation Origin of the local orthographic frame; only used when `frame` is a CRS string.
 * @param frame A proj4 CRS definition/alias, or a node whose local space is the target frame.
 * @param point Point in the local (world) frame.
 * @returns The point in the CRS, or in `frame`'s local space.
 */
export const pointToFrame = (
  viewLocation: GpsLocation,
  frame: Mesh | TransformNode | string,
  point: Vector3,
): Vector3 => {
  if (typeof frame === 'string') {
    return worldToCrs(viewLocation, frame, point);
  }
  frame.computeWorldMatrix(true);
  // `getWorldMatrix()` returns the node's own cached matrix and `Matrix.invert()` inverts in
  // place, so inverting it directly would corrupt the frame's world transform. Invert a copy.
  const worldToFrame = Matrix.Invert(frame.getWorldMatrix());
  return Vector3.TransformCoordinates(point, worldToFrame);
};

/**
 * Expresses a point given in another frame in the local (world) frame; inverse of `pointToFrame`.
 *
 * @param viewLocation Origin of the local orthographic frame; only used when `frame` is a CRS string.
 * @param frame A proj4 CRS definition/alias, or a node whose local space the point is given in.
 * @param point Point in `frame`.
 * @returns The point in the local (world) frame.
 */
export const pointFromFrame = (
  viewLocation: GpsLocation,
  frame: Mesh | TransformNode | string,
  point: Vector3,
): Vector3 => {
  if (typeof frame !== 'string') {
    frame.computeWorldMatrix(true);
    const frameToWorld = frame.getWorldMatrix();
    return Vector3.TransformCoordinates(point, frameToWorld);
  }
  return crsToWorld(viewLocation, frame, point);
};

/**
 * Fits an affine transform approximating the projection from the local orthographic frame
 * (centred on `viewLocation`) to `to`. The fit uses a 3x3x3 grid of sample points.
 *
 * The inverse of that transform becomes the pivot matrix of `transformNode`, and
 * `{ from, to, rmse }` is stored in its metadata. NOTE(review): presumably so content placed
 * under the node in `to` coordinates appears in the local frame — confirm.
 *
 * @param transformNode Node to configure.
 * @param viewLocation Origin of the local frame.
 * @param to Target proj4 CRS definition or alias.
 */
export const updateGeoTransformNode = (transformNode: TransformNode, viewLocation: GpsLocation, to: string): void => {
  const from = orthoProjString(viewLocation);
  const projection = new Proj4Projection({ from, to });

  // Sample a 3x3x3 grid around the origin: ±1000 horizontally, ±100 vertically (local units are metres).
  const offsets = [-1, 0, 1];
  const pointsFrom: Vector3[] = [];
  for (const dx of offsets) {
    for (const dy of offsets) {
      for (const dz of offsets) {
        pointsFrom.push(new Vector3(dx * 1000, dy * 1000, dz * 100));
      }
    }
  }

  const pointsTo = pointsFrom.map((local) => {
    const projected = new Vector3(...projection.project(local.asArray()));
    // The orthographic projection is 2D, so the projected z lacks the view altitude; add it back.
    projected.z += viewLocation.altitude;
    return projected;
  });

  const matrix = leastSquares(pointsFrom, pointsTo);
  // Measure the fit before inverting: `invert()` modifies the matrix in place.
  const rmse = computeRmse(pointsFrom, pointsTo, matrix);
  transformNode.setPivotMatrix(matrix.invert(), false);
  transformNode.metadata = { from, to, rmse };
};

/**
 * Re-fits every geo transform node in `scene` for a new view location.
 * A node qualifies when its `metadata.to` is truthy, as written by `updateGeoTransformNode`.
 *
 * @param scene Scene whose transform nodes are scanned.
 * @param viewLocation New origin of the local frame.
 */
export const updateViewLocation = (scene: Scene, viewLocation: GpsLocation): void => {
  const geoNodes = scene.transformNodes.filter((node) => node.metadata?.to);
  for (const geoNode of geoNodes) {
    updateGeoTransformNode(geoNode, viewLocation, geoNode.metadata.to);
  }
};

/**
 * Fits an affine transform to explicit point correspondences and installs its inverse as the
 * pivot matrix of `transformNode`. The metadata is replaced with `{ rmse }`; it has no `to` key,
 * so `updateViewLocation` leaves such nodes alone.
 *
 * @param transformNode Node to configure.
 * @param pointsPairArray Rows of `[fromX, fromY, fromZ, toX, toY, toZ]`.
 */
export const updateLeastSquaresTransformNode = (transformNode: TransformNode, pointsPairArray: number[][]): void => {
  const toVector = (values: number[]): Vector3 => new Vector3(...values);
  const pointsFrom = pointsPairArray.map((pair) => toVector(pair.slice(0, 3)));
  const pointsTo = pointsPairArray.map((pair) => toVector(pair.slice(3)));

  const fitted = leastSquares(pointsFrom, pointsTo);
  // Measure the fit before inverting: `invert()` modifies the matrix in place.
  const rmse = computeRmse(pointsFrom, pointsTo, fitted);
  transformNode.setPivotMatrix(fitted.invert(), false);
  transformNode.metadata = { rmse };
};

/**
 * Fits an affine transform mapping `pointsFrom` onto `pointsTo` in the least-squares sense.
 *
 * Each correspondence contributes three rows to an overdetermined linear system in 12 unknowns
 * (a 3x3 linear part plus a translation). The system is solved through the normal equations.
 *
 * @param pointsFrom Source points.
 * @param pointsTo Target points, paired index-by-index with `pointsFrom`.
 * @returns A Babylon matrix `M` such that `Vector3.TransformCoordinates(p, M)` approximates the target of `p`.
 * @throws Error when fewer than 4 correspondences are given or the arrays differ in length.
 *   NOTE(review): `lusolve` is expected to throw as well when the source points are degenerate
 *   (e.g. all coplanar), since AᵀA is then singular — confirm.
 */
const leastSquares = (pointsFrom: Vector3[], pointsTo: Vector3[]): Matrix => {
  // Equal lengths make a separate `pointsTo.length < 4` check redundant.
  if (pointsFrom.length < 4 || pointsFrom.length !== pointsTo.length) {
    // Previously this logged and returned `null` despite the `Matrix` return type, so callers
    // crashed on the null with an opaque TypeError; fail loudly with the reason instead.
    throw new Error('Invalid input: At least 4 points are required and source/target point counts must match.');
  }

  // Unknowns s = [a, b, c, tx, d, e, f, ty, g, h, i, tz] with
  // u = a*x + b*y + c*z + tx, v = d*x + e*y + f*z + ty, w = g*x + h*y + i*z + tz.
  const A: number[][] = [];
  const b: number[] = [];
  for (let i = 0; i < pointsFrom.length; i++) {
    const [x, y, z] = pointsFrom[i].asArray();
    const [u, v, w] = pointsTo[i].asArray();
    A.push([x, y, z, 1, 0, 0, 0, 0, 0, 0, 0, 0]);
    A.push([0, 0, 0, 0, x, y, z, 1, 0, 0, 0, 0]);
    A.push([0, 0, 0, 0, 0, 0, 0, 0, x, y, z, 1]);
    b.push(u, v, w);
  }

  // Normal equations (AᵀA) s = Aᵀb. NOTE: this squares the condition number of A, so large
  // or poorly spread coordinates lose precision.
  const At = transpose(A);
  const AtA = multiply(At, A);
  const AtB = multiply(At, b);
  const solution = lusolve(AtA, AtB);

  // `solution` is a 12x1 column; assemble the 4x4 affine matrix for column vectors.
  const transformationMatrix: number[][] = [
    [solution[0][0], solution[1][0], solution[2][0], solution[3][0]],
    [solution[4][0], solution[5][0], solution[6][0], solution[7][0]],
    [solution[8][0], solution[9][0], solution[10][0], solution[11][0]],
    [0, 0, 0, 1],
  ];

  // Babylon matrices use the row-vector convention (translation in the last row), hence the transpose.
  return Matrix.FromArray(transformationMatrix.flat()).transpose();
};

/**
 * Root-mean-square distance between each `pointsFrom[i]` transformed by `matrix` and `pointsTo[i]`.
 *
 * @param pointsFrom Source points.
 * @param pointsTo Target points, paired index-by-index with `pointsFrom`.
 * @param matrix Transform applied with `Vector3.TransformCoordinates`.
 * @returns The RMSE; NaN when `pointsFrom` is empty (0 / 0).
 */
const computeRmse = (pointsFrom: Vector3[], pointsTo: Vector3[], matrix: Matrix): number => {
  let sumOfSquares = 0;
  for (let i = 0; i < pointsFrom.length; i++) {
    const predicted = Vector3.TransformCoordinates(pointsFrom[i], matrix);
    sumOfSquares += Vector3.DistanceSquared(predicted, pointsTo[i]);
  }
  return Math.sqrt(sumOfSquares / pointsFrom.length);
};
