#include "FirstPass.h"

#include <math.h>

// Constructs an accumulator that has not yet seen any markers. The
// per-sample arrays are sized later by Prepare().
FirstPass::FirstPass()
   : markers(0)
   {
   }

// Nothing to release explicitly: the destructor body is empty and the
// member containers are destroyed automatically.
FirstPass::~FirstPass()
   {
   // intentionally empty
   }

// Sizes and clears every per-sample accumulator.
//
// individuals - number of samples to track. One extra slot is allocated
//               because the loops in this file index samples from 1,
//               leaving slot 0 unused.
void FirstPass::Prepare(int individuals)
   {
   const int slots = individuals + 1;

   // Size all per-sample arrays first ...
   genotypeCounts.Dimension(slots);
   heterozygoteCounts.Dimension(slots);
   sampleMendelErrors.Dimension(slots);
   sampleQualityScores.Dimension(slots);
   mask.Dimension(slots);
   logSexOdds.Dimension(slots);
   logLikelihood.Dimension(slots);
   tags.Dimension(slots);

   // ... then reset the numeric ones (tags is only dimensioned, not zeroed).
   genotypeCounts.Zero();
   heterozygoteCounts.Zero();
   sampleMendelErrors.Zero();
   sampleQualityScores.Zero();
   mask.Zero();
   logSexOdds.Zero();
   logLikelihood.Zero();
   }

// Folds one marker's genotype calls into the per-sample tallies.
//
// genotypes - one call per sample, indexed from 1. A zero entry is skipped
//             (no call counted); the value 3 is additionally counted as a
//             heterozygote.
//
// The marker counter is advanced exactly once per invocation.
void FirstPass::ProcessMarker(IntArray & genotypes)
   {
   const int sampleSlots = genotypeCounts.Length();

   for (int sample = 1; sample < sampleSlots; sample++)
      {
      const int call = genotypes[sample];

      if (!call)
         continue;

      genotypeCounts[sample]++;

      if (call == 3)
         heterozygoteCounts[sample]++;
      }

   markers++;
   }

// Flags samples that fail quality control and prints a summary of each check.
//
// A failing sample gets a non-zero mask[i] and a reason string in tags[i].
// Checks run in order and a sample is only ever flagged by the first check
// it fails:
//   1. absolute number of genotype calls (QC_Settings::SAMPLE_CALLS_MIN)
//   2. call count that is an outlier relative to the remaining samples
//      (QC_Settings::SAMPLE_CALLS_ZSCORE standard deviations below the mean)
//   3. absolute heterozygosity bounds (SAMPLE_HET_MIN / SAMPLE_HET_MAX)
//   4. heterozygosity that is an outlier relative to the remaining samples
//      (SAMPLE_HET_ZSCORE standard deviations either side of the mean)
//   5. Mendelian error rate above SAMPLE_MENDEL_MAX
//   6. accumulated log sex odds contradicting the reported sex
//
// sex - reported sex per sample, indexed from 1 like the other arrays.
void FirstPass::BuildMask(IntArray & sex)
{
   Vector proportions;
   int failures = 0;

   printf("Masking poor quality samples ...\n");
   printf("================================\n\n");

   printf("Checking total number of genotype calls\n");

   int minGenos = int(markers * QC_Settings::SAMPLE_CALLS_MIN + 1);
   int minGenosFail = 0;

   for (int i = 1; i < genotypeCounts.Length(); i++)
     if (genotypeCounts[i] < minGenos || genotypeCounts[i] == 0)
       {
       tags[i] = "TOO_FEW_GENOTYPES[ABSOLUTE]";
       minGenosFail++;
       mask[i]++;
       }
     else
       proportions.Push(genotypeCounts[i]);

   printf("   Flagged %d samples with < %.3f genotyping proportion (<%d/%d markers)\n\n",
        minGenosFail, QC_Settings::SAMPLE_CALLS_MIN, minGenos, markers);
   failures += minGenosFail;

   double average = proportions.Average(0.0);
   double sd = sqrt(proportions.Var(0.0));

   printf("Average sample has %.1f genotypes (+/- %.1f stdev)\n",
        average, sd);

   int newMinGenos = int(average - sd * QC_Settings::SAMPLE_CALLS_ZSCORE + 1);
   int newMinGenosFail = 0;

   // The relative threshold can only catch additional samples when it is
   // stricter than the absolute one: every unmasked sample already has at
   // least minGenos calls. (The previous guard tested '<', which made this
   // pass unable to flag anything.) Without any spread (sd == 0) there are no
   // outliers, and the threshold would otherwise sit just above the common
   // value and flag every sample.
   if (sd > 0.0 && newMinGenos > minGenos &&
       proportions.CountIfGreater(newMinGenos) != proportions.Length())
     for (int i = 1; i < genotypeCounts.Length(); i++)
       if (genotypeCounts[i] < newMinGenos && !mask[i])
         {
         tags[i] = "TOO_FEW_GENOTYPES[RELATIVE]";
         newMinGenosFail++;
         mask[i]++;
         }

   printf("   Flagged %d samples with outlier genotyping proportion (<%d/%d markers)\n\n",
        newMinGenosFail, newMinGenos, markers);
   failures += newMinGenosFail;

   printf("Checking absolute heterozygosity\n");

   proportions.Clear();
   int hetFail = 0;

   for (int i = 1; i < genotypeCounts.Length(); i++)
     if (!mask[i])
       {
       // Unmasked samples have a non-zero genotype count (see first check)
       double het = (double) heterozygoteCounts[i] / (double) genotypeCounts[i];

       if (het < QC_Settings::SAMPLE_HET_MIN)
         {
         tags[i] = "LOW_HETEROZYGOSITY[ABSOLUTE]";
         hetFail++;
         mask[i]++;
         continue;
         }

       if (het > QC_Settings::SAMPLE_HET_MAX)
         {
         tags[i] = "HIGH_HETEROZYGOSITY[ABSOLUTE]";
         hetFail++;
         mask[i]++;
         continue;
         }

       // Only samples inside the absolute bounds contribute to the mean and
       // standard deviation used by the relative check below
       proportions.Push(het);
       }

   printf("   Flagged %d samples with low (<%.3f) or high (>%.3f) heterozygosity\n\n",
        hetFail, QC_Settings::SAMPLE_HET_MIN, QC_Settings::SAMPLE_HET_MAX);
   failures += hetFail;

   average = proportions.Average(0.0);
   sd = sqrt(proportions.Var(0.0));

   double newMinHet = average - QC_Settings::SAMPLE_HET_ZSCORE * sd;
   double newMaxHet = average + QC_Settings::SAMPLE_HET_ZSCORE * sd;
   hetFail = 0;

   printf("Average sample has %.3f heterozygosity (+/- %.3f stdev)\n",
        average, sd);

   // Relative check: compare against the z-score bounds computed above.
   // (Previously this pass re-tested the absolute bounds, so it could never
   // flag a sample.) Skipped when there is no spread, so that rounding in
   // the mean cannot flag a cohort of identical values.
   if (sd > 0.0)
     for (int i = 1; i < genotypeCounts.Length(); i++)
       if (!mask[i])
         {
         double het = (double) heterozygoteCounts[i] / (double) genotypeCounts[i];

         if (het < newMinHet)
           {
           tags[i] = "LOW_HETEROZYGOSITY[RELATIVE]";
           hetFail++;
           mask[i]++;
           continue;
           }

         if (het > newMaxHet)
           {
           tags[i] = "HIGH_HETEROZYGOSITY[RELATIVE]";
           hetFail++;
           mask[i]++;
           continue;
           }
         }

   printf("   Flagged %d samples with outlier low (<%.3f) or high (>%.3f) heterozygosity\n\n",
        hetFail, newMinHet, newMaxHet);
   failures += hetFail;

   printf("Checking mendel errors\n");

   int mendelFail = 0;
   for (int i = 1; i < genotypeCounts.Length(); i++)
      // The tiny constant keeps the ratio finite when no markers were processed
      if (((sampleMendelErrors[i]/(markers + 1e-20)) > QC_Settings::SAMPLE_MENDEL_MAX) && !mask[i])
      {
         tags[i] = "TOO_MANY_MENDEL_ERRORS";
         mendelFail++;
         mask[i]++;
      }
   printf("   Flagged %d samples with mendel errors at >%.3f%% of markers\n\n", mendelFail,
        QC_Settings::SAMPLE_MENDEL_MAX * 100.0);
   failures += mendelFail;

   printf("Checking sex odds errors\n");
   int sexOddsFail = 0;

   // The threshold is applied symmetrically, whatever the sign of the setting
   const double sexOddsLimit = fabs(QC_Settings::SAMPLE_LSEXODDS);

   for (int i = 1; i < genotypeCounts.Length(); i++)
      if (!mask[i])
        // logSexOdds accumulates log P(genotype | male) - log P(genotype | female)
        // (see UpdateSexOdds), so strongly negative values contradict MALE
        // and strongly positive values contradict FEMALE
        if (((sex[i] == MALE) && (logSexOdds[i] < -sexOddsLimit)) ||
            ((sex[i] == FEMALE) && (logSexOdds[i] > sexOddsLimit)))
        {
          tags[i] = "SEX_ODDS_MISMATCH";
          sexOddsFail++;
          mask[i]++;
        }

   printf("   Flagged %d samples with log sex odds outside [-%.2f, %.2f]\n\n", sexOddsFail,
        QC_Settings::SAMPLE_LSEXODDS, QC_Settings::SAMPLE_LSEXODDS);
   failures += sexOddsFail;

   // Slot 0 is not a sample, hence the extra -1
   printf("   %d samples passed quality control\n\n", (genotypeCounts.Length() - failures - 1));
}


// Adds one X-linked marker's contribution to each sample's log sex odds,
// i.e. log P(genotype | male) - log P(genotype | female).
//
// genotypes - one call per sample, indexed from 1. The likelihood tables are
//             indexed by genotype - 1, which assumes the codes are
//             1 = first-allele homozygote, 2 = other homozygote,
//             3 = heterozygote (TODO confirm against the constant definitions)
// sex       - reported sex per sample; used only to estimate allele frequency
// errorRate - probability that a male call is generated like a diploid call
//             instead of a hemizygous one
//
// The allele frequency is estimated from unmasked samples with known sex and
// a non-missing call; the odds are then updated for every sample with a
// non-missing call. Markers without any usable sample, or that are
// monomorphic among the usable samples, are skipped: their likelihood terms
// would be log(0) or 0/0 and would turn the running totals into NaN.
void FirstPass::UpdateSexOdds (IntArray & genotypes, IntArray & sex, double errorRate)
{
  double firstFreq = 0;
  int totalAlleles = 0;
  int numSamples = genotypes.Length();
  for (int i=1; i < numSamples; i++)
  {
    if ((mask[i] != OKAY) || (sex[i] == MISSING) || (genotypes[i] == MISSING))
      continue;

    // Males carry a single X allele, so they count half as much as females
    if (genotypes[i] == FIRST_HOM)
      firstFreq += (sex[i] == MALE) ? 1 : 2;
    else if (genotypes[i] == HET)
      firstFreq += (sex[i] == MALE) ? 0.5 : 1;
    totalAlleles += (sex[i] == MALE) ? 1 : 2;
  }

  // No usable sample: the frequency would be 0/0
  if (totalAlleles == 0)
    return;

  firstFreq /= totalAlleles;

  // Monomorphic marker: some terms below would be log(0) = -inf, and the
  // male-minus-female difference (-inf) - (-inf) is NaN. Consistent samples
  // would receive a contribution of zero anyway, so nothing is lost.
  if (firstFreq <= 0.0 || firstFreq >= 1.0)
    return;

  double LGeno_X_Females[3] = { 2*log(firstFreq),
                                2*log(1 - firstFreq),
                                log(2*firstFreq*(1 - firstFreq)) };
  double LGeno_X_Males[3] = { log((1-errorRate)*firstFreq + errorRate*exp(LGeno_X_Females[0])),
                       log((1-errorRate)*(1 - firstFreq) + errorRate*exp(LGeno_X_Females[1])),
                              log(errorRate*exp(LGeno_X_Females[2])) };

  for (int i=1; i < numSamples; i++)
  {
    if (genotypes[i] != MISSING)
    {
      logSexOdds[i] += LGeno_X_Males[genotypes[i] - 1];
      logSexOdds[i] -= LGeno_X_Females[genotypes[i] - 1];
    }
  }
}

// Adds one marker's genotype log-likelihood to each sample's running total.
//
// genotypes   - one call per sample, indexed from 1. The likelihood tables are
//               indexed by genotype - 1, which assumes the codes are
//               1 = first-allele homozygote, 2 = other homozygote,
//               3 = heterozygote (TODO confirm against the constant definitions)
// sex         - reported sex per sample; only consulted for sex-linked markers
// isSexLinked - when true, males are scored with a hemizygous model and
//               contribute one allele to the frequency estimate
// errorRate   - for sex-linked markers, probability that a male call is
//               generated like a diploid call
//
// The allele frequency is estimated from unmasked samples with a non-missing
// call (and, for sex-linked markers, known sex); the likelihood is then added
// for every sample with a non-missing call. Markers with no usable sample or
// with the first allele (almost) absent or (almost) fixed are skipped, since
// their terms would be log(0) or NaN and would poison the running totals.
void FirstPass::UpdateLikelihood(IntArray & genotypes, IntArray & sex, bool isSexLinked, double errorRate)
{
  int numGenos = genotypes.Length();
  double firstFreq = 0;
  int totalAlleles = 0;

  if (isSexLinked)
  {
    for (int i=1; i < numGenos; i++)
    {
      if ((mask[i] != OKAY) || (sex[i] == MISSING) || (genotypes[i] == MISSING))
        continue;

      // Males carry a single X allele, so they count half as much as females
      if (genotypes[i] == FIRST_HOM)
        firstFreq += (sex[i] == MALE) ? 1 : 2;
      else if (genotypes[i] == HET)
        firstFreq += (sex[i] == MALE) ? 0.5 : 1;
      totalAlleles += (sex[i] == MALE) ? 1 : 2;
    }
  }
  else
  {
    for (int i=1; i < numGenos; i++)
    {
      // Sex is irrelevant for autosomal markers
      if ((mask[i] != OKAY) || (genotypes[i] == MISSING))
        continue;

      if (genotypes[i] == FIRST_HOM)
        firstFreq += 2;
      else if (genotypes[i] == HET)
        firstFreq += 1;
      totalAlleles += 2;
    }
  }

  // No usable sample: 0/0 would give NaN, which compares false against any
  // threshold and would slip past the frequency check below
  if (totalAlleles == 0)
    return;

  firstFreq /= totalAlleles;

  // Skip markers where either allele is (almost) absent; log(0) is -inf
  if (firstFreq < 1e-7 || firstFreq > 1.0 - 1e-7)
    return;

  // Diploid genotype log-probabilities, used for autosomal markers and for
  // females at sex-linked markers
  double LGeno_Diploid[3] = { 2*log(firstFreq),
                              2*log(1 - firstFreq),
                              log(2*firstFreq*(1 - firstFreq)) };

  if (isSexLinked)
  {
    double LGeno_X_Males[3] = { log((1-errorRate)*firstFreq + errorRate*firstFreq*firstFreq),
                                log((1-errorRate)*(1 - firstFreq) + errorRate*(1-firstFreq)*(1-firstFreq)),
                                log(errorRate*2*firstFreq*(1-firstFreq)) };

    for (int i=1; i < numGenos; i++)
    {
      // Any sample not reported as MALE (including unknown sex) is scored
      // with the diploid table
      if (genotypes[i] != MISSING)
        logLikelihood[i] += ((sex[i] == MALE) ? LGeno_X_Males[genotypes[i] - 1] :
                                                LGeno_Diploid[genotypes[i] - 1]);
    }
  }
  else
  {
    for (int i=1; i < numGenos; i++)
    {
      if (genotypes[i] != MISSING)
        logLikelihood[i] += LGeno_Diploid[genotypes[i] - 1];
    }
  }
}

// Accumulates one marker's Mendelian-error counts into the per-sample totals.
//
// markerMendelErrors - error count per sample for the current marker,
//                      indexed from 1 (slot 0 is ignored).
void FirstPass::UpdateMendelErrors(IntArray & markerMendelErrors)
{
   const int slots = markerMendelErrors.Length();
   int sample = 1;

   while (sample < slots)
   {
      sampleMendelErrors[sample] += markerMendelErrors[sample];
      sample++;
   }
}

// Writes a tab-separated per-sample report: a header line followed by one
// row for each sample slot from 1 upward.
//
// sampleFile   - open stream the report is written to
// columnLabels - sample identifiers, indexed like the per-sample arrays
//
// Ratios add a tiny constant to the denominator so that a zero marker or
// genotype count prints as 0 instead of dividing by zero.
void FirstPass::OutputSampleStatistics(FILE * sampleFile, StringArray & columnLabels)
{
   fprintf(sampleFile, "SampleId\tCompleteness\tHeterozygosity\tMendelErrors\tSexOdds\tLogL\tAvgQualityScore\tFlagged\tComments\n");

   const int sampleSlots = genotypeCounts.Length();

   for (int sample = 1; sample < sampleSlots; sample++)
   {
      const double completeness = (genotypeCounts[sample]*1.0)/(markers + 1e-20);
      const double heterozygosity = (heterozygoteCounts[sample]*1.0)/(genotypeCounts[sample] + 1e-20);
      const double mendelRate = (sampleMendelErrors[sample]*1.0)/(markers + 1e-20);

      fprintf(sampleFile, "%s\t%.4f\t%.4f\t%.4f\t%f\t%f\t",
              (const char *)columnLabels[sample],
              completeness,
              heterozygosity,
              mendelRate,
              logSexOdds[sample], logLikelihood[sample]);

      // A stored total of exactly -1 is reported as "no quality score"
      if (sampleQualityScores[sample] != -1)
         fprintf(sampleFile, "%.4f\t", (sampleQualityScores[sample]/(markers + 1e-20)));
      else
         fprintf(sampleFile, "N/A\t");

      const bool passed = (mask[sample] == 0);

      if (passed)
         fprintf(sampleFile, "PASSED\t-\n");
      else
         fprintf(sampleFile, "FAILED\t%s\n", (const char *)tags[sample]);
   }
}

// Clears the per-sample counters and likelihood accumulators.
// mask, tags and sampleQualityScores are not touched here.
void FirstPass::ReleaseMemory()
{
   // Likelihood accumulators
   logLikelihood.Clear();
   logSexOdds.Clear();

   // Per-sample counters
   sampleMendelErrors.Clear();
   heterozygoteCounts.Clear();
   genotypeCounts.Clear();
}

// Adds one marker's quality scores to the per-sample running totals.
//
// qualityScores - score per sample for the current marker, indexed from 1.
//                 A score of -1 contributes nothing to the total.
void FirstPass::UpdateSampleScores(IntArray & qualityScores)
{
   const int slots = qualityScores.Length();

   for (int sample = 1; sample < slots; sample++)
   {
      if (qualityScores[sample] == -1)
         continue;

      sampleQualityScores[sample] += qualityScores[sample];
   }
}

