#include <cstdio>
#include <cassert>

#include <cmath>
#include <algorithm>

#include "ray.h"


// Builds a ray from the point o towards the point e.
// This is a thin wrapper over Ray's (Point, Point) constructor declared in ray.h.
// How that constructor derives the direction/extent from o and e is not visible
// here — NOTE(review): check ray.h before relying on normalization or tmax.
Ray make_ray( const Point& o, const Point& e )
{
    Ray ray(o, e);
    return ray;
}


// Reference: B. de Greve, "Reflections and Refractions in Ray Tracing": https://www.bramz.net/data/writings/reflection_transmission.pdf

// Mirror reflection of `ray` at the hit point p.
//
// n : surface normal, oriented towards the outside of the object, while the
//     incoming ray is oriented towards the object.
// Returns a ray with direction d - 2 (d.n) n; flipping n does not change this
// direction. The origin is p pushed 0.001 along the normal on the side the ray
// came from, so the new ray does not immediately re-hit the same surface.
// If n is unit length, the reflected direction keeps the length of d.
Ray reflect( const Ray& ray, const Point& p, const Vector& n )
{
    // a ray travelling along n is inside the object: use the inward normal so
    // the offset stays on the incoming side
    const Vector facing= (dot(ray.direction, n) > 0) ? -n : n;

    Ray reflected;
    reflected.direction= ray.direction - 2 * dot(ray.direction, facing) * facing;
    reflected.origin= p + 0.001f * facing;
    return reflected;
}

// Refracted (transmitted) ray at the hit point p, using Snell-Descartes' law.
//
// ray : incoming ray. Its direction and n are assumed to be unit length —
//       TODO confirm with callers — since the cosine is read directly from dot().
// n   : surface normal, oriented towards the outside of the object.
// ir  : index of refraction of the object, the outside medium having index 1.
//
// The origin of the result is p pushed 0.001 to the transmitted side of the
// surface. On total internal reflection, no refracted direction exists: only
// the origin is set, and the direction is whatever Ray's default constructor
// leaves. Callers should test fresnel_refract() first.
Ray refract( const Ray& ray, const Point& p, const Vector& n, const float ir )
{
    const float d_dot_n= dot(ray.direction, n);
    const bool exiting= d_dot_n > 0;      // travelling along n: leaving the object

    // normal facing the incoming ray, cosine of the incident angle, ratio n1/n2
    const Vector facing= exiting ? -n : n;
    const float cos_i= exiting ? d_dot_n : -d_dot_n;
    const float eta= exiting ? ir : 1.f / ir;

    const float sin2_t= eta*eta * (1.f - cos_i*cos_i);

    Ray refracted;
    refracted.origin= p - 0.001f * facing;
    if(sin2_t >= 1.f)
        // total internal reflection: no transmitted direction
        return refracted;

    const float cos_t= std::sqrt(1.f - sin2_t);
    refracted.direction= eta * (ray.direction + cos_i * facing) - cos_t * facing;
    return refracted;
}

// Tells whether a refracted ray exists at this hit, i.e. there is no total
// internal reflection.
//
// ray : incoming ray. Its direction and n are assumed unit length — TODO confirm.
// p   : hit point (unused here).
// n   : surface normal, oriented towards the outside of the object.
// ir  : index of refraction of the object, the outside medium having index 1.
// Returns true when sin^2 of the transmitted angle is below 1.
bool fresnel_refract( const Ray& ray, const Point& p, const Vector& n, const float ir )
{
    const float cos_theta= -dot(ray.direction, n);

    // ratio n1/n2: 1/ir when entering the object, ir/1 when leaving it
    const float eta= (cos_theta < 0) ? ir : 1.f / ir;

    const float sin2_t= eta*eta * (1.f - cos_theta*cos_theta);
    return sin2_t < 1.f;
}

// Fresnel reflectance at the interface, using Schlick's approximation:
//   R = R0 + (1 - R0) * (1 - cos theta)^5,   R0 = ((n1 - n2) / (n1 + n2))^2
//
// ray : incoming ray. Its direction and n are assumed unit length (TODO confirm
//       with callers), since the cosine is read directly from dot().
// p   : hit point (unused here).
// n   : surface normal, oriented towards the outside of the object.
// ir  : index of refraction of the object, the outside medium having index 1.
// Returns the reflected fraction. It is in [0, 1] for unit vectors, and exactly
// 1 on total internal reflection.
float fresnel( const Ray& ray, const Point& p, const Vector& n, const float ir )
{
    float n1= 1;
    float n2= ir;
    
    float cos_theta= -dot(ray.direction, n);
    if(cos_theta < 0)
    {
        // leaving the object: swap the media, measure the angle from the inward normal
        n1= ir;
        n2= 1;
        cos_theta= -cos_theta;
    }
    
    // Total internal reflection is possible whenever n1 > n2. That is the case
    // when leaving an object with ir > 1, but also when entering one with
    // ir < 1. This uses the same test as fresnel_refract() and refract(), so
    // all three agree.
    float i= n1 / n2;
    float sin2_theta= i*i * (1.f - cos_theta*cos_theta);
    if(sin2_theta >= 1.f)
        // total internal reflection: everything is reflected
        return 1.f;
    
    // Schlick's formula must use the angle on the optically less dense side,
    // i.e. the transmitted angle when n1 > n2 (cf. de Greve).
    if(n1 > n2)
        cos_theta= std::sqrt(1.f - sin2_theta);
    
    float r0= (n1 - n2) / (n1 + n2);
    r0= r0 * r0;
    
    float c= 1.f - cos_theta;
    return r0 + (1.f - r0) * c*c*c*c*c;
}
