
#include <vector>
#include <algorithm>
#include <cassert>

#include "bvh.h"


//! intersection rayon / triangle.
//! cf "fast, minimum storage ray-triangle intersection" http://www.graphics.cornell.edu/pubs/1997/MT97.pdf
bool intersect( const Triangle *triangle, const Ray& ray, const float htmax, float &t, float &tu, float &tv )
{
    /* begin calculating determinant - also used to calculate U parameter */
    Vector ac= Vector(triangle->a, triangle->c);
    const Vector pvec= cross(ray.direction, ac);
    
    /* if determinant is near zero, ray lies in plane of triangle */
    Vector ab= Vector(triangle->a, triangle->b);
    const float det= dot(ab, pvec);
    if (det > -0.0001f && det < 0.0001f) return false;
    
    const float inv_det= 1.0f / det;
    
    /* calculate distance from vert0 to ray origin */
    const Vector tvec= Vector(triangle->a, ray.origin);
    
    /* calculate U parameter and test bounds */
    const float u= dot(tvec, pvec) * inv_det;
    if(u < 0.0f || u > 1.0f) return false;
    
    /* prepare to test V parameter */
    const Vector qvec= cross(tvec, ab);
    
    /* calculate V parameter and test bounds */
    const float v= dot(ray.direction, qvec) * inv_det;
    if(v < 0.0f || u + v > 1.0f) return false;
    
    /* calculate t, ray intersects triangle */
    float hit= dot(ac, qvec) * inv_det;
    
    // ne renvoie vrai que si l'intersection est valide (comprise entre tmin / 0 et tmax du rayon)
    t= hit;
    tu= u;
    tv= v;
    return (hit > 0 && hit < htmax);
}

/*! intersection rayon / bbox.
renvoie faux si une intersection existe mais n'est pas dans l'intervalle [0 htmax]. \n
renvoie vrai + la position du point d'entree (rtmin) et celle du point de sortie (rtmax). \n
cf "An Efficient and Robust Ray-Box Intersection Algorithm" http://cag.csail.mit.edu/%7Eamy/papers/box-jgt.pdf

le parametre htmax permet de trouver tres facilement l'intersection la plus proche de l'origine du rayon.
\code
    float t= FLT_MAX;
    // rechercher la bbox la plus proche de l'origine du rayon
    for(int i= 0; i < n; i++)
    {
        float tmin, tmax;
        if(intersect(bbox[i], ray, t, tmin, tmax))
            // t= tmin; ne suffit pas si l'intervalle [tmin tmax] est en partie negatif, tmin < 0
            t= std::max(0.f, tmin);
    }
\endcode
*/
bool intersect( const BBox& box, const Ray &ray, const float htmax )
{
    // xslab
    float tmin= (box.min.x - ray.origin.x) / ray.direction.x;
    float tmax= (box.max.x - ray.origin.x) / ray.direction.x;
    if(tmax < tmin) std::swap(tmin, tmax);

    // y slab
    float tymin= (box.min.y - ray.origin.y) / ray.direction.y;
    float tymax= (box.max.y - ray.origin.y) / ray.direction.y;
    if(tymax < tymin) std::swap(tymin, tymax);
    
    if((tmin > tymax) || (tymin > tmax)) return false;
    if(tymin > tmin) tmin= tymin;
    if(tymax < tmax) tmax= tymax;

    // z slab
    float tzmin= (box.min.z - ray.origin.z) / ray.direction.z;
    float tzmax= (box.max.z - ray.origin.z) / ray.direction.z;
    if(tzmax < tzmin) std::swap(tzmin, tzmax);

    if((tmin > tzmax) || (tzmin > tmax)) return false;
    if(tzmin > tmin) tmin= tzmin;
    if(tzmax < tmax) tmax= tzmax;

    // ne renvoie vrai que si l'intersection est valide
    return (tmin < tmax && tmax > 0 && tmin < htmax);
}


BVHNode *create_node( const BBox& box, BVHNode *left, BVHNode *right )
{
    BVHNode *node= new BVHNode;
    node->left= left;
    node->right= right;
    node->triangle= nullptr;
    node->box= box;
    return node;
}

BVHNode *create_leaf( const BBox& box, Triangle *triangle )
{
    BVHNode *leaf= new BVHNode;
    leaf->left= nullptr;
    leaf->right= nullptr;
    leaf->triangle= triangle;
    leaf->box= box;
    return leaf;
}


struct CutLess
{
    CutLess( const int a, const float c ) : axis(a), cut(c) {}
    
    int axis;
    float cut;
    
    bool operator() ( const Triangle& t )
    {
        // construit la bbox du triangle
        BBox box;
        bbox_insert(box, t.a);
        bbox_insert(box, t.b);
        bbox_insert(box, t.c);
        
        if(axis == 0 && box.min.x < cut) return true;
        else if(axis == 1 && box.min.y < cut) return true;
        else if(axis == 2 && box.min.z < cut) return true;
        return false;
    }
};

BVHNode *build_node( std::vector<Triangle>& triangles, const unsigned int begin, const unsigned int end )
{
    if(begin == end) 
        return nullptr; // plus de triangles...
    
    // calcule la bbox de la sequence de triangles
    BBox box;
    for(unsigned int i= begin; i < end; i++)
    {
        bbox_insert(box, triangles[i].a);
        bbox_insert(box, triangles[i].b);
        bbox_insert(box, triangles[i].c);
    }
    
    // construit une feuille, s'il ne reste qu'un seul triangle
    if(begin + 1 == end)
        return create_leaf(box, &triangles.front() + begin);
    
    // trouve l'axe le plus etire de la bbox
    // et coupe la boite en 2 sur l'axe le plus etire
    int axis;
    float cut;
    Vector d= Vector(box.min, box.max);
    if(d.x > d.y && d.x > d.z)      { axis= 0; cut= (box.min.x + box.max.x) / 2; }
    else if(d.y > d.x && d.y > d.z) { axis= 1; cut= (box.min.y + box.max.y) / 2; }
    else                            { axis= 2; cut= (box.min.z + box.max.z) / 2; }
    
    // reparti les triangles en fonction de leur position par rapport a la coupe
    Triangle *pmid= std::partition(&triangles.front() + begin, &triangles.front() + end, CutLess(axis, cut));
    unsigned int mid= std::distance(&triangles.front(), pmid);
    
    if(mid == begin || mid == end)
        // force un decoupage en 2 groupes, si la partition a echouee
        mid= (begin + end) / 2;
    
    return create_node(box, 
        build_node(triangles, begin, mid), 
        build_node(triangles, mid, end));
}


BVH make_bvh( const std::vector<vec3>& positions )
{
    BVH bvh;
    
    // construit la liste de triangles
    bvh.triangles.reserve(positions.size() / 3);
    for(unsigned int i= 0; i +2 < positions.size(); i= i + 3)
        bvh.triangles.push_back( Triangle( Point(positions[i]), Point(positions[i +1]), Point(positions[i +2])) );

    printf("building bvh %d triangles...\n", (int) bvh.triangles.size());
    
    // construit l'arbre
    bvh.root= build_node(bvh.triangles, 0, bvh.triangles.size());
    bvh.box= bvh.root->box;

    return bvh;
}

unsigned long int box_n= 0;
unsigned long int tri_n= 0;


void node_intersect( const BVHNode *node, const Ray& ray, Hit& hit )
{
    if(node == nullptr) return;
    
    if(node->triangle)
    {
        //~ tri_n++;
        // feuille, intersection avec le triangle
        float t, u, v;
        if(intersect(node->triangle, ray, hit.t, t, u, v))
        {
            // mise a jour de l'intersection courante
            hit.t= t;
            hit.p= ray.origin + t * ray.direction;
            
        #if 1
            // calcule la normale geometrique du triangle
            Vector ab= Vector(node->triangle->a, node->triangle->b);
            Vector ac= Vector(node->triangle->a, node->triangle->c);
            hit.n= normalize(cross(ab, ac));
        #else
            // interpole la normale au point d'intersection, necessite des normales par sommet
        #endif
            
            unsigned int id= (unsigned long int) node->triangle & 255;
            hit.color= Color(1.f - (id % 16) / 15.f, (id % 16) / 15.f, id / 255.f);
            hit.hit= true;
        }
    }
    else
    {
        //~ box_n++;
        // noeud interne
        if(intersect(node->box, ray, hit.t))
        {
            // visite les 2 fils
            node_intersect(node->left, ray, hit);
            node_intersect(node->right, ray, hit);
        }
    }
}

bool intersect( const BVH& bvh, const Ray& ray, const float tmax, Hit& hit )
{
    hit.color= Black();
    hit.t= tmax;
    hit.hit= false;
    node_intersect(bvh.root, ray, hit);
    return hit.hit;
}

