#include "camera.hpp"
#include <Eigen/Dense>
#include "warping.hpp"
using namespace Eigen;
namespace {
/// One surface interaction recorded along a camera path by camera::get_image.
/// Kept in an anonymous namespace: the name is generic and the type is only
/// used in this translation unit, so internal linkage avoids ODR clashes with
/// any other `hit` type elsewhere in the program.
struct hit{
    Vector3f pos;     // hit point (ray origin advanced by dir * tfar)
    Vector3f normal;  // normalised geometric normal (Ng) at pos
    Vector3f indir;   // direction of the ray that arrived at pos
    Vector3f outdir;  // direction sampled for the next bounce (same hemisphere as normal)
};
} // namespace
std::vector<Eigen::Vector3f> camera::get_image(const scene& sc, size_t n, sampler _rng)const{
    using namespace Eigen;
    std::vector<Eigen::Vector3f> img(n * n, Eigen::Vector3f(0,0,0));
    const Eigen::Vector3f lookray = (look_at - pos).normalized();
	const Eigen::Vector3f left = lookray.cross(up).normalized();
	const Eigen::Vector3f realup = left.cross(lookray);
    const int n2 = n / 2;
    size_t raycasts = 0;
    #pragma omp parallel
    {
    sampler rng;
    std::uniform_real_distribution<float> pert_dis(-0.5,0.5);
    #pragma omp for collapse(2) schedule(guided) reduction(+:raycasts)
    for(int j = 0;j < n;j++){
		for(int i = 0;i < n;i++){
            
            Eigen::Vector3d accum(0,0,0);
            
            for(size_t smp = 0;smp < 128;smp++){
                Eigen::Vector3f di = lookray + (left * ((pert_dis(rng.m_rng) + float(i - n2)) / n2 / 1.5)) + realup * ((pert_dis(rng.m_rng) + float(j - n2)) / n2 / 1.5);
                di.normalize();
                struct RTCRayHit rayhit;
                rayhit.ray.org_x = pos.x();
                rayhit.ray.org_y = pos.y();
                rayhit.ray.org_z = pos.z();
                rayhit.ray.dir_x = di.x();
                rayhit.ray.dir_y = di.y();
                rayhit.ray.dir_z = di.z();
                Vector3f hitpos = pos;
                Vector3f transport(1,1,1);
                std::vector<hit> hits;
                hits.reserve(6);
                for(size_t bounces = 0;bounces < 6;bounces++){
                    struct RTCIntersectContext context;
                    rtcInitIntersectContext(&context);
                    rayhit.ray.tnear = 0.001f;
                    rayhit.ray.tfar = std::numeric_limits<float>::infinity();
                    rayhit.ray.mask = -1;
                    rayhit.ray.flags = 0;
                    rayhit.hit.geomID = RTC_INVALID_GEOMETRY_ID;
                    rayhit.hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
                    rtcIntersect1(sc.m_scene, &context, &rayhit);
                    raycasts++;
                    if(rayhit.hit.geomID != RTC_INVALID_GEOMETRY_ID){
                        if(sc.emitters.contains(rayhit.hit.geomID)){
                            Eigen::Vector3d transport(0,1,1);
                            for(auto it = hits.rbegin();it != hits.rend();it++){
                                transport.array() *= it->outdir.dot(it->normal);
                                for(const auto& pl : sc.pointlight){
                                    struct RTCIntersectContext sh_context;
                                    rtcInitIntersectContext(&sh_context);
                                    Vector3f sh_ray(pl.pos - it->pos);
                                    if(sh_ray.dot(it->normal) < 0)continue;
                                    float sh_ray_length = sh_ray.norm();
                                    sh_ray.normalize();
                                    RTCRay rsh_ray;
                                    rsh_ray.org_x = it->pos.x();
                                    rsh_ray.org_y = it->pos.y();
                                    rsh_ray.org_z = it->pos.z();
                                    rsh_ray.dir_x = sh_ray.x();
                                    rsh_ray.dir_y = sh_ray.y();
                                    rsh_ray.dir_z = sh_ray.z();
                                    rsh_ray.tnear = 0.001f;
                                    rsh_ray.tfar = sh_ray_length;
                                    rsh_ray.mask = -1;
                                    rsh_ray.flags = 0;
                                    raycasts++;
                                    rtcOccluded1(sc.m_scene, &sh_context, &rsh_ray);
                                    if(rsh_ray.tfar > 0){
                                        transport += pl.color.cast<double>() * std::abs(sh_ray.dot(it->normal)) * (1.0 / double(sh_ray_length * sh_ray_length));
                                    }
                                }
                            }
                            accum += transport;
                            goto inner;
                        }
                        rayhit.ray.org_x += rayhit.ray.dir_x * rayhit.ray.tfar;
                        rayhit.ray.org_y += rayhit.ray.dir_y * rayhit.ray.tfar;
                        rayhit.ray.org_z += rayhit.ray.dir_z * rayhit.ray.tfar;
                        Vector3f hitnormal = Vector3f(rayhit.hit.Ng_x, rayhit.hit.Ng_y, rayhit.hit.Ng_z).normalized();
                        
                        Vector3f newdir = uniform_sphere(rng.next2D());
                        if(newdir.dot(hitnormal) <= 0){
                            newdir *= -1.0f;
                        }
                        hits.push_back(hit{
                            Vector3f(rayhit.ray.org_x, rayhit.ray.org_y, rayhit.ray.org_z),
                            hitnormal,
                            Vector3f(rayhit.ray.dir_x, rayhit.ray.dir_y, rayhit.ray.dir_z),
                            newdir
                        });
                        transport *= newdir.dot(hitnormal);
                        rayhit.ray.dir_x = newdir.x();
                        rayhit.ray.dir_y = newdir.y();
                        rayhit.ray.dir_z = newdir.z();
                    }
                    else{
                        Eigen::Vector3d transport(0,0,0);
                        for(auto it = hits.rbegin();it != hits.rend();it++){
                            transport.array() *= it->outdir.dot(it->normal);
                            for(const auto& pl : sc.pointlight){
                                struct RTCIntersectContext sh_context;
                                rtcInitIntersectContext(&sh_context);
                                Vector3f sh_ray(pl.pos - it->pos);
                                if(sh_ray.dot(it->normal) < 0)continue;
                                float sh_ray_length = sh_ray.norm();
                                sh_ray.normalize();
                                RTCRay rsh_ray;
                                rsh_ray.org_x = it->pos.x();
                                rsh_ray.org_y = it->pos.y();
                                rsh_ray.org_z = it->pos.z();
                                rsh_ray.dir_x = sh_ray.x();
                                rsh_ray.dir_y = sh_ray.y();
                                rsh_ray.dir_z = sh_ray.z();
                                rsh_ray.tnear = 0.001f;
                                rsh_ray.tfar = sh_ray_length;
                                rsh_ray.mask = -1;
                                rsh_ray.flags = 0;
                                raycasts++;
                                rtcOccluded1(sc.m_scene, &sh_context, &rsh_ray);
                                if(rsh_ray.tfar > 0){
                                    transport += pl.color.cast<double>() * std::abs(sh_ray.dot(it->normal)) * 1 / (sh_ray_length * sh_ray_length);
                                }
                            }
                        }
                        accum += transport;
                        goto inner;
                    }
                }
                inner:
                (void)0;
            }
            img[j * n + i] = (accum.cast<float>() / 32.0f).array().pow(0.5f).matrix();
        }
    }
    }
    std::cout << std::to_string(raycasts) + " Raycasts\n";
    return img;
}

/// Creates a camera located at `loc` and aimed at `look`.
/// `u` is only a hint for the up direction: get_image re-derives an orthogonal
/// up vector from it, so it need not be perpendicular to the view direction.
camera::camera(const Eigen::Vector3f& loc, const Eigen::Vector3f& look, const Eigen::Vector3f& u)
    : pos(loc)
    , look_at(look)
    , up(u)
{}