import type { AbstractMesh, Scene } from '@babylonjs/core';
import { Mesh, MeshBuilder, Quaternion, SceneLoader } from '@babylonjs/core';
import type { RocosSdkClientService } from '@shared/services';
import type { Subscription } from 'rxjs';
import { type PrimitiveMetaData } from '../primitives/primitives';
import { defaultPanels } from '@shared-modules/properties-editor-panel/pipes/editor-panel/editor-panel.pipe';

/**
 * Registration metadata for the Spot primitive (key, label, icon, editor panels).
 *
 * Starts from the default property-editor panels but switches off the position and rotation
 * (Euler and quaternion) bindings.
 * NOTE(review): presumably because the `Spot` mesh drives its pose from telemetry rather than from
 * the editor — confirm.
 */
export const spotMetaData: PrimitiveMetaData = {
  key: 'spot-robot',
  label: 'Spot',
  icon: 'ri-robots',
  editorPanels: {
    ...defaultPanels,
    bindPosition: false,
    bindRotationEuler: false,
    bindRotationQuaternion: false,
  },
};

// Spot glTF model: file name and base URL, passed separately to SceneLoader.ImportMeshAsync.
const spotMeshFilename = 'spotJoints.gltf';
const spotMeshUrl = 'https://content.rocos.io/3dmodels/';

// Telemetry sources subscribed by the Spot mesh.
// NOTE(review): the `int=200ms` query presumably sets the delivery interval — confirm with the telemetry service.
// Joint angles (`joint_states[].name` / `position.value`), read by Spot.subscribeJoints.
const jointsSource = '/spot/robot_state/kinematic_state?int=200ms';
// Body pose (`pose.pose.position` + `pose.pose.orientation` quaternion), read by Spot.subscribePosition.
const positionSource = '/spot/graphnav/ros/anchor?int=200ms';

/**
 * Joint names as they appear in `joint_states[].name` of the kinematic-state telemetry; the Spot
 * mesh uses these values as keys when routing incoming joint angles.
 *
 * Naming: leg prefix (`fl`/`fr`/`hl`/`hr` = front/hind, left/right) followed by `hx`/`hy` for the
 * two hip axes or `kn` for the knee.
 */
enum SpotJoint {
  FrontLeftHipX = 'fl.hx',
  FrontLeftHipY = 'fl.hy',
  FrontRightHipX = 'fr.hx',
  FrontRightHipY = 'fr.hy',
  HindLeftHipX = 'hl.hx',
  HindLeftHipY = 'hl.hy',
  HindRightHipX = 'hr.hx',
  HindRightHipY = 'hr.hy',
  FrontLeftKnee = 'fl.kn',
  FrontRightKnee = 'fr.kn',
  HindLeftKnee = 'hl.kn',
  HindRightKnee = 'hr.kn',
}

/**
 * Ids of the rigged parts inside the Spot glTF model, matched against `mesh.id` right after import.
 *
 * Leg ids end in `U` (upper segment) or `L` (lower segment); the leg prefix follows the same
 * front/hind, left/right convention as {@link SpotJoint}.
 */
enum SpotMesh {
  FrontLeftUpperLeg = 'SpotLeg_FLU',
  FrontLeftLowerLeg = 'SpotLeg_FLL',
  FrontRightUpperLeg = 'SpotLeg_FRU',
  FrontRightLowerLeg = 'SpotLeg_FRL',
  HindLeftUpperLeg = 'SpotLeg_HLU',
  HindLeftLowerLeg = 'SpotLeg_HLL',
  HindRightUpperLeg = 'SpotLeg_HRU',
  HindRightLowerLeg = 'SpotLeg_HRL',
  Body = 'SpotBody',
}

/**
 * Z offset given to the body joint while the leg joints are being attached; `Spot.addJoints` sets the
 * body joint's z back to 0 once the rig is assembled.
 * NOTE(review): presumably compensates for the body height in the glTF model so the joint pivots line
 * up with the leg meshes — confirm before changing.
 */
const SPOT_BODY_OFFSET = 1.3;

/**
 * Narrows an imported glTF mesh id to a {@link SpotMesh} member.
 *
 * @param meshId Id of a mesh from the imported model.
 * @returns The matching SpotMesh value, or `undefined` for meshes that are not part of the rig.
 */
const getSpotMeshFromString = (meshId: string): SpotMesh | undefined => {
  const rigMeshIds: readonly string[] = Object.values(SpotMesh);
  return rigMeshIds.includes(meshId) ? (meshId as SpotMesh) : undefined;
};

/**
 * Creates a small box that acts as a rotation pivot ("joint") in the Spot rig.
 *
 * @param name Babylon node name for the box.
 * @param parent Node the joint is attached to; the box is created in the same scene as this parent.
 * @param visibility 0 hides the helper box, 1 shows it (handy when debugging the rig).
 * @returns The new joint mesh, already parented to `parent`.
 */
const createJoint = (name: string, parent: AbstractMesh, visibility: 0 | 1): Mesh => {
  // Pass the parent's scene explicitly: without it MeshBuilder falls back to the last created scene,
  // which puts the joint in a different scene from its parent whenever several scenes exist.
  const joint = MeshBuilder.CreateBox(name, { size: 0.2 }, parent.getScene());
  joint.parent = parent;
  joint.visibility = visibility;
  return joint;
};

/** One entry of `joint_states` in kinematic-state telemetry; fields are optional because the payload is not validated. */
interface SpotJointStatePayload {
  name?: string;
  position?: { value?: unknown };
}

/** Kinematic-state telemetry payload; only `joint_states` is read. */
interface SpotKinematicStatePayload {
  joint_states?: ReadonlyArray<SpotJointStatePayload | null | undefined>;
}

/** Graph-nav anchor telemetry payload; only `pose.pose` is read. */
interface SpotAnchorPayload {
  pose?: {
    pose?: {
      position?: { x: number; y: number; z: number };
      orientation?: { x: number; y: number; z: number; w: number };
    };
  };
}

/** Static description of one leg of the Spot rig. */
interface SpotLegSpec {
  /** Used to name the helper joints, e.g. `jointFrontLeftHip` / `jointFrontLeftKnee`. */
  readonly label: string;
  readonly hipJointX: SpotJoint;
  readonly hipJointY: SpotJoint;
  readonly kneeJoint: SpotJoint;
  readonly upperLeg: SpotMesh;
  readonly lowerLeg: SpotMesh;
  /** Hip pivot (x, y) in the body joint's local frame; z is 0 for every leg. */
  readonly hipPosition: readonly [number, number];
}

/** Leg table, in rig-assembly order: front-left, front-right, hind-left, hind-right. */
const SPOT_LEGS: readonly SpotLegSpec[] = [
  {
    label: 'FrontLeft',
    hipJointX: SpotJoint.FrontLeftHipX,
    hipJointY: SpotJoint.FrontLeftHipY,
    kneeJoint: SpotJoint.FrontLeftKnee,
    upperLeg: SpotMesh.FrontLeftUpperLeg,
    lowerLeg: SpotMesh.FrontLeftLowerLeg,
    hipPosition: [0.9, 0.4],
  },
  {
    label: 'FrontRight',
    hipJointX: SpotJoint.FrontRightHipX,
    hipJointY: SpotJoint.FrontRightHipY,
    kneeJoint: SpotJoint.FrontRightKnee,
    upperLeg: SpotMesh.FrontRightUpperLeg,
    lowerLeg: SpotMesh.FrontRightLowerLeg,
    hipPosition: [0.9, -0.4],
  },
  {
    label: 'HindLeft',
    hipJointX: SpotJoint.HindLeftHipX,
    hipJointY: SpotJoint.HindLeftHipY,
    kneeJoint: SpotJoint.HindLeftKnee,
    upperLeg: SpotMesh.HindLeftUpperLeg,
    lowerLeg: SpotMesh.HindLeftLowerLeg,
    hipPosition: [-0.9, 0.4],
  },
  {
    label: 'HindRight',
    hipJointX: SpotJoint.HindRightHipX,
    hipJointY: SpotJoint.HindRightHipY,
    kneeJoint: SpotJoint.HindRightKnee,
    upperLeg: SpotMesh.HindRightUpperLeg,
    lowerLeg: SpotMesh.HindRightLowerLeg,
    hipPosition: [-0.9, -0.4],
  },
];

/**
 * Babylon mesh that renders a Spot robot and animates it from telemetry.
 *
 * On construction it imports the Spot glTF model and rebuilds the legs as a hierarchy of (hidden)
 * joint boxes. It then subscribes to joint-angle and body-pose telemetry for the given
 * project/callsign. `dispose()` closes both subscriptions and releases the imported model.
 */
export class Spot extends Mesh {
  private jointsSub?: Subscription;
  private positionSub?: Subscription;

  private readonly spotMeshes = new Map<SpotMesh, AbstractMesh>();
  /** Keyed by SpotJoint value; typed as string so telemetry joint names can be looked up directly. */
  private readonly joints = new Map<string, Joint>();
  private bodyJointMesh?: Mesh;
  /** `__root__` node created by the glTF loader: owned by this Spot but not one of its descendants. */
  private loadedRoot?: AbstractMesh;

  /**
   * @param name Babylon node name of this mesh; also prefixes the ids of the imported model meshes.
   * @param projectId Project whose telemetry is subscribed.
   * @param callsign Robot callsign whose telemetry is subscribed.
   * @param scene Scene this mesh and the imported model are added to.
   * @param sdk Client used to open the telemetry subscriptions.
   */
  constructor(
    name: string,
    private readonly projectId: string,
    private readonly callsign: string,
    scene: Scene,
    private readonly sdk: RocosSdkClientService,
  ) {
    super(name, scene);

    // Fire-and-forget: loading failures are logged instead of escaping the constructor.
    void this.loadSpot()
      .then((ready) => {
        // `ready` is false when dispose() ran while the model was still loading.
        if (!ready) return;
        this.addJoints();
        this.jointsSub = this.subscribeJoints();
        this.positionSub = this.subscribePosition();
      })
      .catch((err: unknown) => console.warn(`Failed to load spot mesh`, err));
  }

  /** Closes the telemetry subscriptions and releases the imported model before disposing this mesh. */
  override dispose(doNotRecurse?: boolean, disposeMaterialAndTextures?: boolean): void {
    this.jointsSub?.unsubscribe();
    this.positionSub?.unsubscribe();
    this.jointsSub = undefined;
    this.positionSub = undefined;

    // The glTF `__root__` is not under this mesh (the rig parts were re-parented onto the joint boxes),
    // so super.dispose() would leave it, and anything still under it, in the scene.
    if (this.loadedRoot && !this.loadedRoot.isDisposed()) {
      this.loadedRoot.dispose(false, disposeMaterialAndTextures);
    }
    this.loadedRoot = undefined;

    super.dispose(doNotRecurse, disposeMaterialAndTextures);
  }

  /**
   * Imports the Spot glTF model into this mesh's scene and indexes the rig parts by {@link SpotMesh}.
   * Every imported mesh id is rewritten through `getMeshId`.
   *
   * @returns `true` when the model is ready to be rigged. Returns `false` when this Spot was disposed
   *   while the model was loading; the imported meshes are released in that case.
   * @throws Error when the model has no `__root__` node or lacks one of the {@link SpotMesh} parts.
   */
  private async loadSpot(): Promise<boolean> {
    const result = await SceneLoader.ImportMeshAsync(null, spotMeshUrl, spotMeshFilename, this.getScene());
    const rootMesh = result.meshes.find((mesh) => mesh.id === '__root__');

    if (this.isDisposed() || !rootMesh) {
      // Nothing will own these meshes, so release them now instead of leaking them into the scene.
      for (const mesh of result.meshes) {
        if (!mesh.isDisposed()) mesh.dispose();
      }
      if (this.isDisposed()) return false;
      throw new Error(`Spot model '${spotMeshFilename}' has no __root__ node`);
    }

    this.loadedRoot = rootMesh;
    // NOTE(review): Babylon ignores `rotation` while a node has a `rotationQuaternion`. If the glTF
    // loader assigns one to `__root__`, these two lines have no effect — confirm visually before changing.
    rootMesh.rotation.z = Math.PI / 2;
    rootMesh.rotation.y = Math.PI / 2;

    for (const mesh of result.meshes) {
      const spotMesh = getSpotMeshFromString(mesh.id);
      if (spotMesh) this.spotMeshes.set(spotMesh, mesh);
      mesh.id = this.getMeshId(mesh.id);
    }

    // Fail before any joint boxes are created rather than half-way through rigging.
    const missing = Object.values(SpotMesh).filter((part) => !this.spotMeshes.has(part));
    if (missing.length > 0) {
      throw new Error(`Spot model '${spotMeshFilename}' is missing meshes: ${missing.join(', ')}`);
    }
    return true;
  }

  /** Returns the imported mesh for `part`, throwing a descriptive error if it was never loaded. */
  private getSpotMesh(part: SpotMesh): AbstractMesh {
    const mesh = this.spotMeshes.get(part);
    if (!mesh) throw new Error(`Spot mesh '${part}' has not been loaded`);
    return mesh;
  }

  /**
   * Builds the joint hierarchy (body -> hips -> knees), hangs the model parts from it, registers
   * every telemetry-driven joint, and finally attaches the scaled rig to this mesh.
   */
  private addJoints(): void {
    const visibility = 0;

    const jointBody = MeshBuilder.CreateBox('jointBody', { size: 0.2 }, this.getScene());
    jointBody.visibility = visibility;
    // Temporary offset while the legs are attached; reset to 0 below once the rig is built.
    jointBody.position.z = SPOT_BODY_OFFSET;
    jointBody.addChild(this.getSpotMesh(SpotMesh.Body));
    this.bodyJointMesh = jointBody;

    // Every hip is placed before any knee is created, so each knee is built against final hip poses.
    const hips = SPOT_LEGS.map((leg) => ({ leg, hip: this.addHipJoint(leg, jointBody, visibility) }));
    for (const { leg, hip } of hips) {
      this.addKneeJoint(leg, hip, visibility);
    }

    // Scale to real size, drop the temporary body offset and attach the rig to this mesh.
    jointBody.scaling.set(0.4, 0.4, 0.4);
    jointBody.position.z = 0;
    jointBody.parent = this;
  }

  /** Creates the hip pivot for `leg` on the body joint, hangs the upper leg from it and registers its X/Y joints. */
  private addHipJoint(leg: SpotLegSpec, jointBody: Mesh, visibility: 0 | 1): Mesh {
    const hip = createJoint(`joint${leg.label}Hip`, jointBody, visibility);
    hip.position.set(leg.hipPosition[0], leg.hipPosition[1], 0.0);
    hip.addChild(this.getSpotMesh(leg.upperLeg));
    this.addJoint(leg.hipJointX, hip, 'x', 0, 1);
    this.addJoint(leg.hipJointY, hip, 'y', -Math.PI / 4, 1);
    return hip;
  }

  /** Creates the knee pivot for `leg`, moves it onto the upper leg, hangs the lower leg from it and registers it. */
  private addKneeJoint(leg: SpotLegSpec, hip: Mesh, visibility: 0 | 1): void {
    const knee = createJoint(`joint${leg.label}Knee`, hip, visibility);
    // Position is expressed in the hip frame; setParent() then keeps this world transform while
    // re-homing the knee under the upper-leg mesh, so the knee follows that mesh.
    knee.position.set(-0.55, 0, -0.8);
    knee.setParent(this.getSpotMesh(leg.upperLeg));
    knee.addChild(this.getSpotMesh(leg.lowerLeg));
    this.addJoint(leg.kneeJoint, knee, 'x', -Math.PI / 2, -1);
  }

  /** Registers `babylonMesh` as the pivot that telemetry joint `jointName` rotates around `axis`. */
  private addJoint(jointName: SpotJoint, babylonMesh: Mesh, axis: 'x' | 'y' | 'z', offset: number, direction: 1 | -1) {
    this.joints.set(jointName, new Joint(jointName, babylonMesh, axis, offset, direction));
  }

  /** Streams joint angles from kinematic-state telemetry into the registered joints. */
  private subscribeJoints(): Subscription {
    return this.sdk.client
      .getTelemetryService()
      .subscribe<SpotKinematicStatePayload>({
        projectId: this.projectId,
        callsigns: [this.callsign],
        sources: [jointsSource],
      })
      .subscribe((res) => {
        for (const jointState of res?.payload?.joint_states ?? []) {
          const joint = jointState?.name ? this.joints.get(jointState.name) : undefined;
          if (!joint) continue;

          // Same numeric coercion the joint arithmetic would apply, but skip values that would
          // turn the rotation into NaN/Infinity (e.g. a missing `position.value`).
          const angle = Number(jointState?.position?.value);
          if (!Number.isFinite(angle)) continue;

          joint.setJointPosition(angle);
        }
      });
  }

  /** Streams the body pose (position + orientation) from graph-nav anchor telemetry onto the body joint. */
  private subscribePosition(): Subscription {
    return this.sdk.client
      .getTelemetryService()
      .subscribe<SpotAnchorPayload>({
        projectId: this.projectId,
        callsigns: [this.callsign],
        sources: [positionSource],
      })
      .subscribe((res) => {
        const pose = res?.payload?.pose?.pose;
        const body = this.bodyJointMesh;
        if (!pose?.position || !pose.orientation || !body) return;

        const { position, orientation } = pose;
        body.position.set(position.x, position.y, position.z);

        // Telemetry carries a quaternion; the body joint is driven through its Euler `rotation`.
        const rotation = new Quaternion(orientation.x, orientation.y, orientation.z, orientation.w);
        body.rotation.copyFrom(rotation.toEulerAngles());
      });
  }

  /** Namespaces an imported mesh id with this mesh's name so several Spots can share a scene. */
  private getMeshId(meshId: string) {
    return `${this.name}-spot-${meshId}`;
  }
}

/**
 * Maps a telemetry joint angle onto a single rotation axis of a Babylon mesh.
 *
 * The applied rotation is `direction * position + offset`: `direction` flips the sign convention
 * and `offset` shifts the zero pose of the model.
 */
class Joint {
  constructor(
    public jointName: SpotJoint,
    private readonly babylonMesh: Mesh,
    private readonly axis: 'x' | 'y' | 'z',
    private readonly offset: number,
    private readonly direction: 1 | -1,
  ) {}

  /**
   * Sets the mesh rotation on this joint's axis from a telemetry joint position.
   *
   * @param position Joint angle as reported by telemetry.
   */
  public setJointPosition(position: number): void {
    const angle = this.offset + this.direction * position;
    const { rotation } = this.babylonMesh;
    rotation[this.axis] = angle;
  }
}
