#define Matrix rlMatrix
#define Vector3 rlVector3
#define Quaternion rlQuaternion
#include <raylib.h>
#include <rlgl.h>
#include <external/glad.h>
#define RAYGUI_IMPLEMENTATION
#include <raygui_dl.h>
#undef Quaternion
#undef Vector3
#undef Matrix
#include "jet.hpp"
#include <array>
#include <Eigen/Dense>
#include <Eigen/Geometry>
#include <iostream>
#include <cmath>
#ifdef __EMSCRIPTEN__
#include <emscripten.h>
#endif
using namespace Eigen;
/// Fixed-size column vector with six components.
template<typename T>
using Vector6 = Eigen::Matrix<T, 6, 1>;
/// Fixed-size 6x6 matrix.
template<typename T>
using Matrix6 = Eigen::Matrix<T, 6, 6>;

/// Rigid-body pose: a translation plus three rotation angles (about X, Y, Z).
///
/// `scalar` is a jet-like type (see jet.hpp) exposing a static `n_vars` and a
/// `deriv` vector. The default constructor builds the zero pose and seeds the
/// derivative channels as follows:
///   - translation[i] has deriv(i) = 1        (channels 0..2),
///   - euler_rotation_angles[i] has deriv(i + 3) = 1 (channels 3..5).
/// Anything computed from such a transform therefore carries its derivative
/// with respect to all six pose parameters.
template<typename scalar>
    requires(scalar::n_vars >= 6)
struct transform{
    Vector3<scalar> translation;           // derivative channels 0..2
    Vector3<scalar> euler_rotation_angles; // angles about X, Y, Z; derivative channels 3..5
    transform() : translation(0,0,0), euler_rotation_angles(0,0,0){
        for(size_t i = 0;i < 3;i++){
            translation[i].deriv(i) = 1;
            euler_rotation_angles[i].deriv(i + 3) = 1;
        }
    }
    /// Rotation matrix R = Rx(angle0) * Ry(angle1) * Rz(angle2).
    Matrix3<scalar> matrix()const noexcept{
        return quat().toRotationMatrix();
    }
    /// The rotation of matrix(), as a quaternion.
    Quaternion<scalar> quat()const noexcept{
        // Named by axis: the previous roll/yaw/pitch names did not match the
        // axes actually used (Y was called "yaw", Z "pitch").
        Eigen::AngleAxis<scalar> about_x(euler_rotation_angles.x(), Eigen::Vector3<scalar>::UnitX());
        Eigen::AngleAxis<scalar> about_y(euler_rotation_angles.y(), Eigen::Vector3<scalar>::UnitY());
        Eigen::AngleAxis<scalar> about_z(euler_rotation_angles.z(), Eigen::Vector3<scalar>::UnitZ());
        return about_x * about_y * about_z;
    }
    /// Applies the pose to every column of `input`: R * column + translation.
    /// Reads the pose only, so it is usable on const transforms.
    template<int N>
    Matrix<scalar, 3, N> apply(const Matrix<scalar, 3, N>& input)const{
        Matrix<scalar, 3, N> trfed = matrix() * input;
        trfed.colwise() += translation;
        return trfed;
    }
};

/**
 * Regular hexagon lying in the plane z = translation.z(), returned as six
 * column vectors in counter-clockwise order (viewed from +z).
 *
 * @param scale        side length, which for a regular hexagon is also its circumradius
 * @param translation  position of the hexagon's centre
 * @return             3x6 matrix, one vertex per column
 */
template<typename T>
Matrix<T, 3, 6> generate_hexagon(T scale, Vector3<T> translation = Vector3<T>(0, 0, 0)){
    const T height = scale * std::sqrt(T(3));  // distance between opposite edges
    const T half_height = height / T(2);
    // The raw vertices below are centred on (scale / 2, half_height); moving the
    // offset back by that amount centres the hexagon on `translation`.
    translation.x() -= scale * 0.5;
    translation.y() -= half_height;
    Matrix<T, 3, 6> vertices;
    vertices << T(0), scale, scale * 1.5, scale,  T(0),   -scale * 0.5,
                T(0), T(0),  half_height, height, height, half_height,
                T(0), T(0),  T(0),        T(0),   T(0),   T(0);
    vertices.colwise() += translation;
    return vertices;
}
/**
 * Degenerate ("scronched") hexagon: an equilateral triangle whose three corners
 * are each listed twice (columns 0/1, 2/3 and 4/5). The result has the same
 * 3x6 shape as generate_hexagon's output.
 *
 * @param scale        circumradius of the triangle
 * @param translation  position of the triangle's centroid
 * @return             3x6 matrix, one (possibly repeated) vertex per column
 */
template<typename T>
Matrix<T, 3, 6> generate_scronched_hexagon(T scale, Vector3<T> translation = Vector3<T>(0, 0, 0)){
    const T height = scale * std::sqrt(T(3));
    const T half_height = height / T(2);
    // Same centring shift as the regular hexagon: the raw corners below have
    // their centroid at (scale / 2, half_height).
    translation.x() -= scale * 0.5;
    translation.y() -= half_height;
    Matrix<T, 3, 6> vertices;
    vertices << T(0), T(0), scale * 1.5, scale * 1.5, T(0),   T(0),
                T(0), T(0), half_height, half_height, height, height,
                T(0), T(0), T(0),        T(0),        T(0),   T(0);
    vertices.colwise() += translation;
    return vertices;
}
/// Double-precision adapter over raygui's float-only GuiSlider.
///
/// The value is shown and edited as a float. It is written back to `*value`
/// only when the float slider actually changed it. A double the user did not
/// touch therefore keeps its full precision instead of being truncated to float
/// on every call (i.e. every frame).
///
/// @param value  must be non-null; read and possibly updated in place
/// @return       whatever the float GuiSlider overload returns
int GuiSlider(Rectangle bounds, const char *textLeft, const char *textRight, double *value, float minValue, float maxValue){
    const float before = static_cast<float>(*value);
    float edited = before;
    const int res = GuiSlider(bounds, textLeft, textRight, &edited, minValue, maxValue);
    if(edited != before){
        *value = edited;
    }
    return res;
}

template<typename scalar>
using jet_type = jet<scalar, 12>;
template<typename scalar>
using hexagon = Matrix<jet_type<scalar>, 3, 6>;

/// Target length of every leg (platform vertex to its anchor).
constexpr double distance = 1.5;

/// Selects which six derivative channels of a jet_type become Jacobian columns.
enum class jacobian_wrt {
    pose,      // head<6>(): channels 0..5, seeded by transform's default constructor
    actuators, // tail<6>(): the last six channels (presumably 6..11 for the 12-channel jet_type)
};

/// Shared implementation of evaluate_with_jacobian and
/// evaluate_with_jacobian_wrt_actuators.
///
/// Moves `points` by `trf` (R * p + t), measures each leg length
/// |trf(points_i) - anchors_i| and returns:
///   first  - the residuals length_i - distance,
///   second - row i = derivative of length_i over the channels selected by `Wrt`.
template<jacobian_wrt Wrt, typename scalar>
std::pair<Vector6<scalar>, Matrix6<scalar>> leg_residuals_with_jacobian(const transform<jet_type<scalar>>& trf, const hexagon<scalar>& points, const hexagon<scalar>& anchors){
    hexagon<scalar> trfed = trf.matrix() * points;
    trfed.colwise() += trf.translation;
    std::pair<Vector6<scalar>, Matrix6<scalar>> ret;
    for(int i = 0; i < 6; i++){
        const jet_type<scalar> length = (trfed.col(i) - anchors.col(i)).norm();
        ret.first(i) = length.value - distance;
        if constexpr (Wrt == jacobian_wrt::pose){
            ret.second.array().row(i) = length.deriv.template head<6>();
        }else{
            ret.second.array().row(i) = length.deriv.template tail<6>();
        }
    }
    return ret;
}

/// Leg-length residuals and their 6x6 Jacobian taken from the first six
/// derivative channels, i.e. with respect to the pose parameters when `trf`
/// carries the seeds set by transform's default constructor.
template<typename scalar>
std::pair<Vector6<scalar>, Matrix6<scalar>> evaluate_with_jacobian(const transform<jet_type<scalar>>& trf, const hexagon<scalar>& points, const hexagon<scalar>& anchors){
    return leg_residuals_with_jacobian<jacobian_wrt::pose, scalar>(trf, points, anchors);
}
/// Same residuals as evaluate_with_jacobian, with the Jacobian taken from the
/// last six derivative channels (the per-actuator seeds placed on `anchors`).
template<typename scalar>
std::pair<Vector6<scalar>, Matrix6<scalar>> evaluate_with_jacobian_wrt_actuators(const transform<jet_type<scalar>>& trf, const hexagon<scalar>& points, const hexagon<scalar>& anchors){
    return leg_residuals_with_jacobian<jacobian_wrt::actuators, scalar>(trf, points, anchors);
}
template<typename T, int N>
std::array<rlVector3, N> convert_to_raylib(const Matrix<T, 3, N>& mat){
    std::array<rlVector3, N> ret;
    for(int i = 0;i < N;i++){
        ret[i] = rlVector3(mat.col(i)[0], mat.col(i)[1], mat.col(i)[2]);
    }
    return ret;
}
/// Converts any vector-like object indexable as v(0), v(1), v(2) into a
/// raylib rlVector3.
rlVector3 erl(const auto& v){
    return rlVector3(v(0), v(1), v(2));
}
// Off-screen render target for the 3D view; loaded in main().
RenderTexture vp{};
// Orbit-camera state, updated from mouse drags in gameLoop().
float camangle{0.0f};    // azimuth angle (radians) passed to sin/cos
float campitch{1.0f};    // camera height is tan(campitch); clamped to [-1.2, 1.2] each frame
bool camdragging{false}; // true while a drag that started inside the 3D view is active
// Angle of each of the six rotary actuators (radians), set by sliders or by IK.
Vector6<float> actuator_angles = Vector6<float>::Zero();
// Current platform pose; default construction gives zero translation and rotation.
transform<jet_type<double>> current_transform{}; // Identity

void gameLoop(){
    using scalar = double;
    rlVector3 targ{0.5f,std::sqrt(3.0f) / 2.0f,0};
    campitch = std::clamp(campitch, -1.2f, 1.2f);
    Camera3D cam{.position = rlVector3{std::sin(camangle) * 4.0f + targ.x,std::cos(camangle) * 4.0f + targ.y, std::tan(campitch)}, .target = targ, .up = {0,0,1},.fovy =  60.0f, .projection =  CAMERA_PERSPECTIVE};
    BeginDrawing();
    if(!IsMouseButtonDown(MOUSE_LEFT_BUTTON)){
        camdragging = false;
    }
    if(IsMouseButtonPressed(MOUSE_LEFT_BUTTON)){
        if(GetMouseX() <= 1000){
            //std::cout << "oof\n";
            camdragging = true;
        }
    }else if(IsMouseButtonDown(MOUSE_BUTTON_LEFT) && camdragging){
        camangle += GetMouseDelta().x * 0.01f;
        campitch += GetMouseDelta().y * 0.01f;
    }
    
    ClearBackground(Color(30,30,10,255));
    BeginTextureMode(vp);
    ClearBackground(Color(30,30,10,255));
    BeginMode3D(cam);
    auto unit_hexagon = generate_scronched_hexagon(jet_type<scalar>(1.0));
    auto anchors = generate_hexagon(jet_type<scalar>(1.6), Vector3<jet_type<scalar>>(0, 0, -1.0));
    std::array<std::pair<Vector3<double>, Vector3<double>>, 6> rotary_planes;
    double rad = 0.4;
    for(int i = 0;i < 6;i++){
        Vector3<double> diag = anchors.col((i + 3) % 6).cast<double>() - anchors.col(i).cast<double>();
        rotary_planes[i].first = Vector3<double>(-diag.y(), diag.x(), 0).normalized();
        rotary_planes[i].second = Vector3<double>(0,0,1);
        
        for(int j = 0;j < 128;j++){
            double angle = 2.0 * M_PI * double(j) / 127;
            double anglen = 2.0 * M_PI * double(j + 1) / 127;
            Vector3<double> from = anchors.col(i).cast<double>() + rotary_planes[i].first * std::cos(angle) * rad + rotary_planes[i].second * std::sin(angle)  * rad;
            Vector3<double> to = anchors.col(i).cast<double>() + rotary_planes[i].first * std::cos(anglen)  * rad + rotary_planes[i].second * std::sin(anglen) * rad;
            rlVector3 rlf = erl(from);
            rlVector3 rlt = erl(to);
            DrawLine3D(rlf, rlt, ORANGE);
        }
    }
    for(int i = 0;i < 6;i++){
        jet_type<double> actuator_angle_jet(actuator_angles[i]);
        actuator_angle_jet.deriv(i + 6) = 1;
        anchors.col(i) += (jet_type<double>(rad) * cos(actuator_angle_jet) * rotary_planes[i].first .cast<jet_type<double>>())
                       +  (jet_type<double>(rad) * sin(actuator_angle_jet) * rotary_planes[i].second.cast<jet_type<double>>());
    }
    //anchors.col(0)(2) += 0.1 * std::sin(5.0 * GetTime());
    auto trf_backup = current_transform;
    current_transform = transform<jet_type<double>>();
    for(int i = 0;i < 20;i++){
        auto[v, J] = evaluate_with_jacobian<double>(current_transform, unit_hexagon, anchors);
        FullPivHouseholderQR<Matrix6<double>> qr(J);
        
        Vector6<double> subtract = qr.solve(v);
        current_transform.translation -= subtract.head<3>().cast<jet_type<double>>();
        current_transform.euler_rotation_angles -= subtract.segment(3, 3).cast<jet_type<double>>();
        //std::cout << v << "\n";
        //std::cout << "Rank: " << J.fullPivHouseholderQr().rank() << "\n";
        //std::cout << "Rank: " << J.fullPivHouseholderQr().solve(v) << "\n";
        auto[vi, Ji] = evaluate_with_jacobian_wrt_actuators<double>(current_transform, unit_hexagon, anchors);
        if(vi.norm() > 0.1){
            current_transform = trf_backup;
        }
    }
    
    Matrix<double, 3, 6> hexagon_vertices = current_transform.apply(unit_hexagon).cast<double>();
    std::array<rlVector3, 6> hvrl = convert_to_raylib(hexagon_vertices);
    std::array<rlVector3, 6> arl = convert_to_raylib(anchors);
    for(int i = 0;i < 6;i++){
        DrawLine3D(hvrl[i], arl[i], WHITE);
    }
    DrawTriangle3D(hvrl[0], hvrl[2], hvrl[4], Color{255, 255, 0, 120});
    EndMode3D();
    EndTextureMode();
    SetTextureFilter(vp.texture, TEXTURE_FILTER_BILINEAR);
    GenTextureMipmaps(&vp.texture);
    DrawTexturePro(vp.texture, Rectangle(0,0,vp.texture.width, -float(vp.texture.height)), Rectangle(0,0,1000,1000), ::Vector2(0,0), 0.0f, WHITE);
    DrawText(TextFormat("FPS: %d", GetFPS()), 0,0 , 40, GREEN);
    for(int i = 0;i < 6;i++){
        
        GuiSlider(Rectangle(1100, i * 50 + 100, 200, 40), TextFormat("Actuator %d", i + 1), "", actuator_angles.data() + i, 0.0f, 4.0f * M_PI);
    }
    constexpr const char* names[] = {"X", "Y", "Z", "Roll", "Pitch", "Yaw"};
    bool changed = false;
    
    auto aa_backup = actuator_angles;
    for(int i = 0;i < 3;i++){
        //float x = 0;
        //std::cout << current_transform.translation(i).value << ", ";
        double pp = current_transform.translation(i).value;
        double pv = current_transform.euler_rotation_angles(i).value;
        GuiSlider(Rectangle(1100, i * 50 + 500, 200, 40), names[i + 0], "", &current_transform.translation(i).value, -1.0f, 1.0f);
        GuiSlider(Rectangle(1100, i * 50 + 700, 200, 40), names[i + 3], "", &current_transform.euler_rotation_angles(i).value, -0.5f, 0.5f);
        if(std::abs(current_transform.translation(i).value - pp) > 1e-5){
            changed = true;
        }
        if(std::abs(current_transform.euler_rotation_angles(i).value - pv) > 1e-5){
            changed = true;
        }
    }
    if (GuiButton(Rectangle(1100, 55, 200, 40), "Reset")){
        changed = false;
        actuator_angles.fill(0);
    }
    if(changed){
        //std::cout << "Changed\n";
        
        for(int i = 0;i < 20;i++){
            auto[v, J] = evaluate_with_jacobian_wrt_actuators<double>(current_transform, unit_hexagon, anchors);
            FullPivHouseholderQR<Matrix6<double>> qr(J);

            Vector6<double> subtract = qr.solve(v);
            
            actuator_angles -= subtract.cast<float>();
            anchors = generate_hexagon(jet_type<scalar>(1.6), Vector3<jet_type<scalar>>(0, 0, -1.0));
            for(int j = 0;j < 6;j++){
                jet_type<double> actuator_angle_jet(actuator_angles[j]);
                actuator_angle_jet.deriv(j + 6) = 1;
                anchors.col(j) += (jet_type<double>(rad) * cos(actuator_angle_jet) * rotary_planes[j].first .cast<jet_type<double>>())
                               +  (jet_type<double>(rad) * sin(actuator_angle_jet) * rotary_planes[j].second.cast<jet_type<double>>());
            }
            //std::cout << v << "\n";
            if(i == 19){
                auto[vi, Ji] = evaluate_with_jacobian_wrt_actuators<double>(current_transform, unit_hexagon, anchors);
                if(vi.norm() > 0.1){
                    actuator_angles = aa_backup;
                }
            }
            //std::cout << anchors << "\n";

        }
        for(int j = 0;j < 6;j++){
            actuator_angles(j) = std::fmod(actuator_angles(j) + 2.0 * M_PI, 2.0 * M_PI);
        }
    }
    EndDrawing();
}
int main(){
    InitWindow(1500, 1000, "Simulator");
    rlSetLineWidth(5.0f);
    
    vp = LoadRenderTexture(4000, 4000);
    
    
    GuiSetStyle(DEFAULT, TEXT_SIZE, 30);
    GuiSetStyle(DEFAULT, TEXT_COLOR_NORMAL, ColorToInt(Color{100,255,100,255}));
    GuiSetStyle(DEFAULT, TEXT_SPACING, 4.0f);
    SetTargetFPS(60);
    rlDisableBackfaceCulling();
    #ifdef __EMSCRIPTEN__
    emscripten_set_main_loop(gameLoop, 0, 0);
    #else
    while(!WindowShouldClose()){
        gameLoop();
    }
    #endif
    
    
    
    //std::cout << current_guess.translation << "\n";
    //std::cout << current_guess.euler_rotation_angles << "\n\n";
    //for(int i = 0;i < 6;i++)
    //    std::cout << (current_guess.apply(unit_hexagon).col(i).cast<double>() - anchors.col(i).cast<double>()).norm() << "\n";
    //auto[v, J] = evaluate_with_jacobian<double>(current_guess, unit_hexagon, anchors);
}