#include "constants.h"
#include "RelationCheck.h"
#include "QC_Settings.h"

RelationCheck::RelationCheck(int numSamples)
{
  // One slot per unordered sample pair (i < j); every per-pair array below is
  // sized to this count.
  const int pairCount = (numSamples * (numSamples - 1)) / 2;

  // Per-pair relationship labels and flags.
  putativeRelations.Dimension(pairCount);
  putativeRelations.Zero();
  estimatedRelations.Dimension(pairCount);
  estimatedRelations.Set(RELATED);
  relationError.Dimension(pairCount);
  relationError.Zero();

  // Per-pair accumulators of expected / observed IBS statistics.
  exp_IBS0_IBD0.Dimension(pairCount);
  exp_IBS0_IBD0.Zero();
  exp_IBS1_IBD0.Dimension(pairCount);
  exp_IBS1_IBD0.Zero();
  exp_IBS1_IBD1.Dimension(pairCount);
  exp_IBS1_IBD1.Zero();
  observed_IBS0.Dimension(pairCount);
  observed_IBS0.Zero();
  observed_IBS1.Dimension(pairCount);
  observed_IBS1.Zero();
  observed_IBS2.Dimension(pairCount);
  observed_IBS2.Zero();
  estimatedKinship.Dimension(pairCount);
  estimatedKinship.Zero();

  // private vars: per-pair z-scores and per-relation-group kinship summaries
  // (five groups, one per specific relation).
  zscore.Dimension(pairCount);
  zscore.Zero();
  meanKinship.Dimension(5);
  meanKinship.Zero();
  sdKinship.Dimension(5);
  sdKinship.Zero();

  // Relation ids in the row order of the tables below; each table row is
  // stored at slot (id - 1).
  const int relationIds[6] = { UNRELATED, HALF_SIBS, SIBLINGS,
                               PARENT_OFFSPRING, DUPLICATES, RELATED };

  const char *labels[6] = { "unrelated", "half sibling", "sibling",
                            "parent-offspring", "duplicate", "related" };
  relationNames.Dimension(6);
  for (int r = 0; r < 6; r++)
    relationNames[relationIds[r] - 1] = String(labels[r]);

  // Expected value of estimatedKinship (IBD1/2 + IBD2) and expected
  // IBD0/IBD1/IBD2 probabilities for the five specific relations
  // (RELATED has no fixed expectation).
  const double kinshipTable[5] = { 0.0, 0.25, 0.5, 0.5, 1.0 };
  const double ibdTable[5][3] = {
    { 1.0,  0.0,  0.0  },   // unrelated
    { 0.5,  0.5,  0.0  },   // half siblings
    { 0.25, 0.50, 0.25 },   // siblings
    { 0.0,  1.0,  0.0  },   // parent-offspring
    { 0.0,  0.0,  1.0  }    // duplicates
  };
  for (int r = 0; r < 5; r++)
  {
    const int slot = relationIds[r] - 1;
    expectedKinship[slot] = kinshipTable[r];
    for (int k = 0; k < 3; k++)
      expectedIBDProbs[slot][k] = ibdTable[r][k];
  }

  // Kinship tolerance used by checkRelationships when deciding between the
  // nearest IBD profile and the generic "related" label.
  maxDelta = 0.1;
}

RelationCheck::~RelationCheck()
{
   // Intentionally empty. ReleaseMemory() can be used to drop the large
   // per-pair arrays before the object itself goes away.
}

void RelationCheck::initRelations(IntArray &fatherColumn, IntArray &motherColumn, IntArray &columnToSampleId, int numSamples)
{
   // Derives the putative (pedigree-based) relationship of every sample pair.
   //
   // Samples occupy columns 1..numSamples. Pairs (first, second) with
   // first < second are enumerated row by row, and each result is written to
   // the next slot of putativeRelations. A parent column of -1 means "no
   // parent recorded". The checks are applied in priority order:
   //   same sample id -> DUPLICATES
   //   one is the other's parent -> PARENT_OFFSPRING
   //   both parents shared (and known) -> SIBLINGS
   //   one parent shared (and known) -> HALF_SIBS
   //   otherwise -> UNRELATED
   int pairSlot = 0;
   for (int first = 1; first < numSamples; first++)
   {
      const int dadA = fatherColumn[first];
      const int momA = motherColumn[first];

      for (int second = first + 1; second <= numSamples; second++)
      {
         const int dadB = fatherColumn[second];
         const int momB = motherColumn[second];

         const bool sameDad = (dadA == dadB) && (dadA != -1);
         const bool sameMom = (momA == momB) && (momA != -1);
         const bool parentChild = (dadA == second) || (dadB == first) ||
                                  (momA == second) || (momB == first);

         int relation;
         if (columnToSampleId[first] == columnToSampleId[second])
            relation = DUPLICATES;
         else if (parentChild)
            relation = PARENT_OFFSPRING;
         else if (sameDad && sameMom)
            relation = SIBLINGS;
         else if (sameDad || sameMom)
            relation = HALF_SIBS;
         else
            relation = UNRELATED;

         putativeRelations[pairSlot] = relation;
         pairSlot++;
      }
   }
}

void RelationCheck::updateRelMatrices(IntArray & genotypes, IntArray & mask)
{
  // Adds one marker's contribution to the per-pair IBS statistics.
  //
  // genotypes/mask are read at indices 1..Length()-1 (index 0 is never read).
  // Pairs (i, j), i < j, are visited row by row, which matches the slot order
  // of initRelations when genotypes.Length() == numSamples + 1.
  //
  // For every pair where neither genotype is MISSING this adds:
  //   - this marker's P(IBS0|IBD0), P(IBS1|IBD0), P(IBS1|IBD1) to
  //     exp_IBS0_IBD0 / exp_IBS1_IBD0 / exp_IBS1_IBD1, and
  //   - one count to exactly one of observed_IBS0 / observed_IBS1 / observed_IBS2.
  // Markers whose minor allele frequency is below RELATION_MAF_MIN are skipped.
  int numSamples = genotypes.Length();
  int pairIndex = 0;

  // Frequency of the first allele, estimated from non-missing genotypes of
  // samples whose mask is OKAY.
  double freq = 0;
  int totalAlleles = 0;
  for (int i = 1; i < numSamples; i++)
  {
    if ((mask[i] != OKAY) || (genotypes[i] == MISSING))
      continue;
    if (genotypes[i] == FIRST_HOM)
      freq += 2;
    else if (genotypes[i] == HET)
      freq += 1;
    totalAlleles += 2;
  }
  freq /= (totalAlleles + 1e-20);   // tiny epsilon: freq == 0 when nothing was counted

  // freq is the FIRST allele's frequency, which may be the major allele.
  // The minor allele frequency is whichever of freq / 1-freq is smaller.
  // Checking only freq let markers with a very rare second allele through.
  const double q = 1.0 - freq;
  const double maf = (freq < q) ? freq : q;
  if (maf < QC_Settings::RELATION_MAF_MIN)
    return; // use only markers frequent enough

  // Per-marker IBS probabilities conditional on IBD state, assuming
  // Hardy-Weinberg proportions (p = freq, q = 1 - p):
  //   P(IBS0 | IBD0) = 2 p^2 q^2
  //   P(IBS1 | IBD0) = 4 (p^3 q + q^3 p)
  //   P(IBS1 | IBD1) = 2 p q
  // Kept in double precision; they are accumulated into double vectors.
  const double prob_0_0 = 2.0 * freq * freq * q * q;
  const double prob_1_0 = 4.0 * (freq * freq * freq * q + q * q * q * freq);
  const double prob_1_1 = 2.0 * freq * q;

  // NOTE(review): unlike the frequency estimate above, this loop does not
  // consult mask, so genotypes of masked samples still contribute to the
  // IBS counts. Confirm that is intended.
  for (int i = 1; i < (numSamples - 1); i++)
  {
    for (int j = (i + 1); j < numSamples; j++, pairIndex++)
    {
      if ((genotypes[i] == MISSING) || (genotypes[j] == MISSING))
        continue;

      exp_IBS0_IBD0[pairIndex] += prob_0_0;
      exp_IBS1_IBD0[pairIndex] += prob_1_0;
      exp_IBS1_IBD1[pairIndex] += prob_1_1;

      if (genotypes[i] == genotypes[j])
        observed_IBS2[pairIndex]++;                       // identical genotypes
      else if ((genotypes[i] == HET) || (genotypes[j] == HET))
        observed_IBS1[pairIndex]++;                       // one allele shared
      else
        observed_IBS0[pairIndex]++;                       // opposite homozygotes
    }
  }
}

void RelationCheck::findIBDProbs(int pairIndex)
{
   if ((observed_IBS0[pairIndex] + observed_IBS1[pairIndex] + observed_IBS2[pairIndex]) == 0)
      return;
   
   RelationSolver relSolver(exp_IBS0_IBD0[pairIndex], exp_IBS1_IBD0[pairIndex], exp_IBS1_IBD1[pairIndex], observed_IBS0[pairIndex],
                            observed_IBS1[pairIndex], observed_IBS2[pairIndex]);
   relSolver.RunEMIterations();
   observed_IBS0[pairIndex] = relSolver.ibd[0];
   observed_IBS1[pairIndex] = relSolver.ibd[1];
   observed_IBS2[pairIndex] = relSolver.ibd[2];

   estimatedKinship[pairIndex] = relSolver.ibd[1]*0.5 + relSolver.ibd[2];
}

void RelationCheck::checkRelationships()
{
   Vector unrelatedKinships;
   Vector sibKinships;
   Vector halfSibKinships;
   Vector parentOffspringKinships;
   Vector duplicateKinships;
   
   for (int i = 0; i < estimatedKinship.Length(); i++)
   {
      switch(putativeRelations[i])
      {
         case UNRELATED:        unrelatedKinships.Push(estimatedKinship[i]); break;
	 case HALF_SIBS:        halfSibKinships.Push(estimatedKinship[i]); break;
	 case SIBLINGS:         sibKinships.Push(estimatedKinship[i]); break;
	 case PARENT_OFFSPRING: parentOffspringKinships.Push(estimatedKinship[i]); break;
	 case DUPLICATES:       duplicateKinships.Push(estimatedKinship[i]); break;
      }
   }
   
   // Here calculate the mean and std deviation of the vectors and check the relationships
   if (unrelatedKinships.Length() > 0)
   {
      meanKinship[0] = unrelatedKinships.Average();
      sdKinship[0] = sqrt(unrelatedKinships.Var());
   }
   if (halfSibKinships.Length() > 0)
   {
      meanKinship[1] = halfSibKinships.Average();
      sdKinship[1] = sqrt(halfSibKinships.Var());
   }
   if (sibKinships.Length() > 0)
   {
      meanKinship[2] = sibKinships.Average();
      sdKinship[2] = sqrt(sibKinships.Var());
   }
   if (parentOffspringKinships.Length() > 0)
   {
      meanKinship[3] = parentOffspringKinships.Average();
      sdKinship[3] = sqrt(parentOffspringKinships.Var());
   }
   if (duplicateKinships.Length() > 0)
   {
      meanKinship[4] = duplicateKinships.Average();
      sdKinship[4] = sqrt(duplicateKinships.Var());
   }
   
   for (int i = 0; i < estimatedKinship.Length(); i++)
   {
      if (sdKinship[putativeRelations[i] - 1] > 0.0)
      {
         zscore[i] = (estimatedKinship[i] - meanKinship[putativeRelations[i] - 1])/sdKinship[putativeRelations[i] - 1];
         if (fabs(zscore[i]) > QC_Settings::RELATION_ZSCORE)
            relationError[i] = 1;
      }
      else
      {
         if (estimatedKinship[i] - expectedKinship[putativeRelations[i] - 1] > QC_Settings::RELATION_MAX_MEAN_DIFF)
            relationError[i] = 1;
      }
      if (relationError[i] == 1)
      {
         int tempRelation = calculateEstimatedRelationship(observed_IBS0[i], observed_IBS1[i], observed_IBS2[i]);
         if ((tempRelation == 1 || tempRelation == 2)  && fabs(estimatedKinship[i] - expectedKinship[tempRelation - 1]) > maxDelta)
            estimatedRelations[i] = 5;
         else
            estimatedRelations[i] = tempRelation;
      }
      else
         estimatedRelations[i] = putativeRelations[i];
   }
}

void RelationCheck::outputRelationInfo(FILE *relFile, StringArray &sampleIDs)
{
   // Writes the relationship-check report to relFile:
   //  1. a WARNING for each putative-relation group with sd > 0 whose mean
   //     estimated kinship differs from the theoretical value by more than
   //     RELATION_MAX_MEAN_DIFF. The check is in either direction; the printed
   //     value keeps its sign.
   //  2. one line per pair flagged in relationError, naming both samples.
   //
   // sampleIDs is read at indices 1..Length()-1. Pairs are enumerated row by
   // row (i < j), matching the order of the per-pair arrays.
   for (int i = 0; i < 5; i++)
      if (sdKinship[i] > 0.0 && fabs(meanKinship[i] - expectedKinship[i]) > QC_Settings::RELATION_MAX_MEAN_DIFF)
         fprintf(relFile, "WARNING: The mean of the %s relation group is %lf away from the expected mean %lf.\n", (const char *)relationNames[i],
                 (meanKinship[i]-expectedKinship[i]), expectedKinship[i]);

   int count = 0;
   for (int i = 1; i < sampleIDs.Length() - 1; i++)
   {
      for (int j = i + 1; j < sampleIDs.Length(); j++, count++)
      {
         if (relationError[count] == 1)
         {
            // The "Est IBD" columns print observed_IBS0/1/2. findIBDProbs
            // overwrites these with the IBD estimates.
            fprintf(relFile, "%s and %s have misspecified relationship. Put: %s\tEst: %s\tEst kinship: %lf\tEst IBD0=%lf IBD1=%lf IBD2=%lf\n",
                    (const char *)sampleIDs[i], (const char *)sampleIDs[j], (const char *) relationNames[putativeRelations[count] - 1],
                    (const char *) relationNames[estimatedRelations[count] - 1], estimatedKinship[count], observed_IBS0[count], observed_IBS1[count], observed_IBS2[count]);
         }
      }
   }
}

void RelationCheck::ReleaseMemory()
{
   putativeRelations.Clear();
   relationError.Clear();
   exp_IBS0_IBD0.Clear();
   exp_IBS1_IBD0.Clear();
   exp_IBS1_IBD1.Clear();
   observed_IBS0.Clear();
   observed_IBS1.Clear();
   observed_IBS2.Clear();
   estimatedKinship.Clear();
}

void RelationCheck::outputRelFile(FILE* rel, int pairIndex)
{
   // Emits one tab-separated line for a pair: the three values held in
   // observed_IBS0/1/2 (IBD probabilities once findIBDProbs has run),
   // followed by the estimated kinship.
   const double first  = observed_IBS0[pairIndex];
   const double second = observed_IBS1[pairIndex];
   const double third  = observed_IBS2[pairIndex];
   const double kin    = estimatedKinship[pairIndex];

   fprintf(rel, "%lf\t%lf\t%lf\t%lf\n", first, second, third, kin);
}

RelationSolver::RelationSolver(double exp_IBS0_IBD0, double exp_IBS1_IBD0,
                               double exp_IBS1_IBD1, double obs_IBS0,
                               double obs_IBS1, double obs_IBS2)
{
  // N is the number of markers counted for the pair. Every stored quantity is
  // divided by it, so the solver works with per-marker proportions. The caller
  // must ensure N > 0.
  N = obs_IBS0 + obs_IBS1 + obs_IBS2;

  // P(IBS | IBD = 0). IBS2 takes whatever probability IBS0 and IBS1 leave.
  expected_IBS0_IBD0 = exp_IBS0_IBD0/N;
  expected_IBS1_IBD0 = exp_IBS1_IBD0/N;
  expected_IBS2_IBD0 = 1.0 - expected_IBS0_IBD0 - expected_IBS1_IBD0;

  // P(IBS | IBD = 1). IBS0 is impossible, so IBS2 is the complement of IBS1.
  expected_IBS1_IBD1 = exp_IBS1_IBD1/N;
  expected_IBS2_IBD1 = 1.0 - expected_IBS1_IBD1;

  // P(IBS | IBD = 2). Always IBS2.
  expected_IBS2_IBD2 = 1.0;

  // Observed IBS proportions.
  observed_IBS0 = obs_IBS0/N;
  observed_IBS1 = obs_IBS1/N;
  observed_IBS2 = obs_IBS2/N;

  // EM convergence threshold and starting point (the sibling profile 1/4, 1/2, 1/4).
  tolerance = 0.0001;
  ibd[0] = 0.25;
  ibd[1] = 0.50;
  ibd[2] = 0.25;
}

void RelationSolver::RunEMIterations()
{
   // Expectation-maximisation estimate of the IBD-sharing probabilities
   // ibd[0..2], given the observed IBS proportions and the P(IBS | IBD) terms
   // set up by the constructor. Starts from the current ibd[] values and
   // iterates until no component moves by `tolerance` or more in one step.
   //
   // Guarded against non-termination in two ways:
   //  - If the update produces NaN (sum1 or sum2 == 0, e.g. when every marker
   //    was monomorphic, giving 0/0), the last valid estimate is kept and the
   //    loop stops. Previously NaN made every convergence comparison false
   //    and the loop spun forever.
   //  - At most maxIterations steps are run.
   const int maxIterations = 1000;

   for (int iteration = 0; iteration < maxIterations; iteration++)
   {
      // Probability of an IBS1 / IBS2 observation under the current IBD
      // mixture. P(IBS1 | IBD2) is 0, so ibd[2] does not enter sum1.
      double sum1 = ibd[0] * expected_IBS1_IBD0 +
                    ibd[1] * expected_IBS1_IBD1;

      double sum2 = ibd[0] * expected_IBS2_IBD0 +
                    ibd[1] * expected_IBS2_IBD1 +
                    ibd[2] * expected_IBS2_IBD2;

      // Share each observed IBS class among the IBD states in proportion to
      // their posterior probability. IBS0 can only arise from IBD0.
      double ibd_new[3];
      ibd_new[0] = observed_IBS0 + observed_IBS1 * ibd[0] * expected_IBS1_IBD0 / sum1 +
                   observed_IBS2 * ibd[0] * expected_IBS2_IBD0 / sum2;

      ibd_new[1] = observed_IBS1 * ibd[1] * expected_IBS1_IBD1 / sum1 +
                   observed_IBS2 * ibd[1] * expected_IBS2_IBD1 / sum2;

      ibd_new[2] = observed_IBS2 * ibd[2] / sum2;

      // NaN is the only value that compares unequal to itself.
      if (ibd_new[0] != ibd_new[0] || ibd_new[1] != ibd_new[1] || ibd_new[2] != ibd_new[2])
         break;

      bool converged = fabs(ibd[0] - ibd_new[0]) < tolerance &&
                       fabs(ibd[1] - ibd_new[1]) < tolerance &&
                       fabs(ibd[2] - ibd_new[2]) < tolerance;

      ibd[0] = ibd_new[0];
      ibd[1] = ibd_new[1];
      ibd[2] = ibd_new[2];

      if (converged)
         break;
   }
}


int RelationCheck::calculateEstimatedRelationship(double ibd0, double ibd1, double ibd2)
{
   // Nearest-profile classification: returns the relation (1..5, i.e. the
   // row index of expectedIBDProbs plus one) whose expected (IBD0, IBD1, IBD2)
   // vector is closest in Euclidean distance to the given estimates. The
   // first row wins ties. RELATED is returned only if no row comes closer
   // than 5.0 (e.g. NaN or far out-of-range inputs).
   int best = RELATED;
   double bestDist = 5.0;

   for (int rel = 1; rel <= 5; ++rel)
   {
      const double d0 = ibd0 - expectedIBDProbs[rel - 1][0];
      const double d1 = ibd1 - expectedIBDProbs[rel - 1][1];
      const double d2 = ibd2 - expectedIBDProbs[rel - 1][2];
      const double dist = sqrt(pow(d0, 2) + pow(d1, 2) + pow(d2, 2));

      if (fabs(dist) < fabs(bestDist))
      {
         bestDist = dist;
         best = rel;
      }
   }
   return best;
}
