//----------------------------------------------------------------------------
// William Baxter III's Ray Tracer
//
//     Project for Comp 238, Raster Graphics
//     University of North Carolina at Chapel Hill
//     
// $Id:$
//----------------------------------------------------------------------------

#include "wb3TracingShader.hpp"
#include "wb3Light.hpp"
#include "wb3Sampling.hpp"

//----------------------------------------------------------------------------
// Default constructor.
// Sets the cutoff used by Shade(): a reflected or refracted ray is only
// traced when the sum of its attenuation components exceeds this value.
wb3TracingShader::wb3TracingShader()
{
  this->m_fCutoffAtten = 3.0e-6f;
}

//----------------------------------------------------------------------------
// Destructor: nothing is released explicitly.
wb3TracingShader::~wb3TracingShader()
{
  // intentionally empty
}


//----------------------------------------------------------------------------
// Clamp 'f' into the closed interval [0,1].
// Values above 1 map to 1 and values below 0 map to 0. Anything else is
// returned unchanged, including NaN, which fails both comparisons.
inline float CLAMP(float f)
{
  return (f > 1.0f) ? 1.0f
       : (f < 0.0f) ? 0.0f
       : f;
}

//----------------------------------------------------------------------------
void wb3TracingShader::Shade(
    const wb3Scene* scene,
    const wb3Artifact* from,
    const wb3Artifact* to,
    const Ray3f& I, const Vec3f &hitPt, 
    Vec3f& color, Vec3f& attenuation, int hitsLeft)  const
{
  using namespace wb3Sampling;

  // Get normal at point of intersection
  Vec3f N;
  if (to != scene) // HACK!
    to->GetNormal(I, hitPt, N);
  else
    from->GetNormal(I, hitPt, N);    

  // Perturb intersection in dirn of normal to prevent
  // hitting again on the way out due to numerical imprecision.
  float IdotN(I.V() * N);  

  // Get lights
  const wb3LightArray& lights = scene->GetLights();

  // ---- LOCAL ILLUMINATION
  for (unsigned int i=0; i<lights.GetSize(); i++)
  {
    const wb3Light *light = lights.GetAt(i);
    Vec3f lColor(Vec3f::ZERO);
    light->Contribute(scene, to, I, hitPt, N, lColor);
    lColor &= attenuation;
    color += lColor;
  }

  // ---- RECURSIVE ILLUMINATION
  if (hitsLeft > 0) {
    Vec3f toC, fromC;

    // ---- REFLECTION

    // HACK: use component-wise product of reflectivities
    // of from and to materials for interface reflectivity.
    // Works well because usually one or the other is air.

    Vec3f rAttenuation(attenuation);
    rAttenuation &= to->GetReflect(toC,hitPt);
    rAttenuation &= from->GetReflect(fromC,hitPt);

    if (rAttenuation.x + rAttenuation.y + rAttenuation.z > m_fCutoffAtten)
    {
      Vec3f Rv(I.V() + 2.0f * -IdotN * N); // reflected direction
      if (to->GetGloss() < 1.0f)
      {
        // Not perfectly mirrorlike.  Perturb the reflected ray.
        // Rayshade docs say to sample a cone of height 1, centered
        // around the mirror reflection direction, with a base of 
        // width 1.0-gloss.  In otherwords, an angle of atan(1-gloss).
        Vec3f original(Rv);
        PerturbSampleRadially(Rv, Rv, 1.0f - to->GetGloss());
        Rv.Normalize();
      }

      Ray3f R(hitPt, Rv); // reflected ray
      Vec3f rColor(Vec3f::ZERO);

      // What to use for from and to in the recursive call is kinda tricky.
      // Pass the same thing we started with since scene->Shade() will use
      // the 'to' for filtering out bogus hits from numerical imprecision.
      scene->Shade(scene, from, to, R, Vec3f::ZERO, 
                   rColor, rAttenuation, hitsLeft-1);
      color += rColor;
    }


    // HACK AGAIN: use component-wise product of refractivities
    // of from and to materials for interface refractivity.
    // Works well because usually one or the other is air.

    Vec3f tAttenuation(attenuation);
    tAttenuation &= to->GetRefract(toC, hitPt);
    tAttenuation &= from->GetRefract(fromC, hitPt);
    
    if (tAttenuation.x + tAttenuation.y + tAttenuation.z > m_fCutoffAtten)
    {
      // ---- REFRACTION
      float n = from->GetMaterial()->GetIndex() / to->GetMaterial()->GetIndex();
      float underRad = 1.0f + n*n*(IdotN*IdotN - 1.0f);
      if (underRad > 0) // not total internal reflection
      {
        // The eq for the refracted ray expects the normal
        // to be pointing a particular direction w.r.t the incomming ray
        // flip sign accordingly (equivalent to flipping normal).
        float fSgn = (IdotN > 0) ? 1.0f : -1.0f;

        Vec3f Tv      // Transmitted direction
          (n*I.V() + (-n*IdotN + fSgn*float(sqrt(underRad)) ) * N);
        Ray3f T(hitPt, Tv); // transmitted ray
        Vec3f tColor(Vec3f::ZERO);
        scene->Shade(scene, to, 0, T, Vec3f::ZERO, 
                     tColor, tAttenuation, hitsLeft-1);
        color += tColor;
      }
    }
  }
}
//----------------------------------------------------------------------------


