//----------------------------------------------------------------------------
// William Baxter III's Ray Tracer
//
//     Project for Comp 238, Raster Graphics
//     University of North Carolina at Chapel Hill
//     
// $Id:$
//----------------------------------------------------------------------------

#include "wb3TracingShader.hpp"

#include <cmath>

#include "wb3Light.hpp"

//----------------------------------------------------------------------------
wb3TracingShader::wb3TracingShader()
{
  // Secondary rays whose summed RGB attenuation does not exceed this
  // threshold are not traced by Shade().
  this->m_fCutoffAtten = 3e-6f;

  // Shade() only spawns secondary rays while its 'hits' count is below this.
  this->m_iMaxDepth = 10;
}

//----------------------------------------------------------------------------
wb3TracingShader::~wb3TracingShader()
{
  // Intentionally empty: the constructor only assigns scalar settings,
  // so there is nothing for this class to release here.
}


//----------------------------------------------------------------------------
inline float CLAMP(float f)
{
  // Saturate f to the closed interval [0,1].  Both comparisons are false
  // for NaN, so a NaN input is returned unchanged.
  // NOTE(review): not called anywhere in this file as shown; candidate for
  // removal if nothing else relies on it.
  return (f > 1.0f) ? 1.0f
       : (f < 0.0f) ? 0.0f
       :              f;
}

//----------------------------------------------------------------------------
void wb3TracingShader::Shade(const wb3Scene* scene,
                             const wb3Artifact* from,
                             const wb3Artifact* to,
                             const Ray3f& I, const Vec3f &hitPt, 
                             Vec3f& color, Vec3f& attenuation, int hits) const
{
  // Get normal at point of intersection
  Vec3f N;
  if (to != scene) // HACK!
    to->GetNormal(I, hitPt, N);
  else
    from->GetNormal(I, hitPt, N);    

  // Perturb intersection in dirn of normal to prevent
  // hitting again on the way out due to numerical imprecision.
  float IdotN(I.V() * N);  

  // Get lights
  const wb3LightArray& lights = scene->GetLights();

  // ---- LOCAL ILLUMINATION
  for (unsigned int i=0; i<lights.GetSize(); i++)
  {
    const wb3Light *light = lights.GetAt(i);
    Vec3f lColor(Vec3f::ZERO);
    light->Contribute(scene, to, I, hitPt, N, lColor);
    lColor = lColor.CompMult(attenuation);
    color += lColor;
  }

  // ---- RECURSIVE ILLUMINATION
  if (hits < m_iMaxDepth) {
    const wb3Material& toMat = *to->GetMaterial();
    const wb3Material& fromMat = *from->GetMaterial();

    // ---- REFLECTION

    // HACK: use component-wise product of reflectivities
    // of from and to materials for interface reflectivity.
    // Works well because usually one or the other is air.

    Vec3f rAttenuation(
      attenuation.CompMult(toMat.GetReflect().CompMult(fromMat.GetReflect())));
    if (rAttenuation.x + rAttenuation.y + rAttenuation.z > m_fCutoffAtten)
    {
      Vec3f Rv(I.V() + 2.0f * -IdotN * N); // reflected direction
      Ray3f R(hitPt, Rv); // reflected ray
      Vec3f rColor(Vec3f::ZERO);

      // What to use for from and to in the recursive call is kinda tricky.
      // Pass the same thing we started with since scene->Shade() will use
      // the 'to' for filtering out bogus hits from numerical imprecision.
      scene->Shade(scene, from, to, R, Vec3f::ZERO, 
                   rColor, rAttenuation, hits+1);
      color += rColor;
    }

#if 1
    // HACK AGAIN: use component-wise product of refractivities
    // of from and to materials for interface refractivity.
    // Works well because usually one or the other is air.

    Vec3f tAttenuation(
      attenuation.CompMult(toMat.GetRefract().CompMult(fromMat.GetRefract())));
    
    if (tAttenuation.x + tAttenuation.y + tAttenuation.z > m_fCutoffAtten)
    {
      // ---- REFRACTION
      float n = from->GetIndex() / to->GetIndex();
      float underRad = 1.0f + n*n*(IdotN*IdotN - 1.0f);
      if (underRad > 0) // not total internal reflection
      {
        // The eq for the refracted ray expects the normal
        // to be pointing a particular direction w.r.t the incomming ray
        // flip sign accordingly (equivalent to flipping normal).
        float fSgn = (IdotN > 0) ? 1.0f : -1.0f;

        Vec3f Tv      // Transmitted direction
          (n*I.V() + (-n*IdotN + fSgn*float(sqrt(underRad)) ) * N);
        Ray3f T(hitPt, Tv); // transmitted ray
        Vec3f tColor(Vec3f::ZERO);
        scene->Shade(scene, to, 0, T, Vec3f::ZERO, 
                     tColor, tAttenuation, hits+1);
        color += tColor;
      }
    }
#endif 
  }
}
//----------------------------------------------------------------------------


