#include "HapMapReference.h"
#include "AlleleTools.h"
#include "IntArray.h"

// Fallback max/min macros, defined only when the including environment has
// not already provided them. Being macros, each argument may be evaluated
// twice, so avoid passing expressions with side effects.
#ifndef  max
#define  max(a,b)  ((a)>(b)?(a):(b))
#endif

#ifndef  min
#define  min(a,b)  ((a)<(b)?(a):(b))
#endif


// Builds an empty reference: no storage is allocated and every running
// tally starts at zero. Storage is allocated later by LoadGenotypes().
HapMapReference::HapMapReference()
{
   // Dynamically allocated storage
   genotypes         = NULL;
   alleles[0]        = NULL;
   alleles[1]        = NULL;
   comparisonByScore = NULL;

   // Number of samples shared with the reference panel
   overlap = 0;

   // Per-marker tallies
   different_labels = 0;
   same_labels      = 0;
   flips            = 0;
   perfect          = 0;

   // Remaining counters
   tmatches = 0;
   thomhom  = 0;
   thethom  = 0;
   thomhet  = 0;
}

// Releases every array allocated by LoadGenotypes().
HapMapReference::~HapMapReference()
{
   // delete [] on a null pointer is a no-op, so no explicit checks are needed
   delete [] genotypes;
   delete [] alleles[0];
   delete [] alleles[1];
   delete [] comparisonByScore;
}

// Returns a pointer to the first stored genotype of the named marker, or
// NULL when the marker lookup yields a negative index. Each marker occupies
// `overlap` consecutive entries of the genotypes array.
char * HapMapReference::RetrieveGenotypes(String & markerName)
   {
   int row = markerLookup.Integer(markerName);

   if (row < 0)
      return NULL;

   return &genotypes[row * overlap];
   }

void HapMapReference::LoadGenotypes(FILE * input, AssayInfo & assayInfo, StringArray & sampleIds)
   {
   String      buffer;
   StringArray header, tokens;

   while (!feof(input) && header.Length() == 0)
      header.AddColumns(buffer.ReadLine(input), ' ');

   int columnQC = header.SlowFind("QC_code");
   int columnStrand = header.SlowFind("strand");
   int columnAlleles = header.SlowFind("SNPalleles");

   if (columnQC < 0 || columnStrand < 0 || columnAlleles < 0)
      {
      printf("I cannot parse the header of your HapMap genotype file\n");
      return;
      }

   for (int i = columnQC + 1; i < header.Length(); i++)
      sampleKey.Push(sampleIds.SlowFind(header[i]));

   overlap = sampleKey.CountIfGreaterOrEqual(0);

   if (overlap == 0)
      {
      printf("None of genotype file sample ids could be matched to the HapMap\n");
      return;
      }

   sampleComparisons.Dimension(sampleKey.Length());
   sampleDifferences.Dimension(sampleKey.Length());
   sampleComparisons.Zero();
   sampleDifferences.Zero();

   histogram.Dimension(overlap + 1);
   histogram.Zero();

   comparisonByScore = new IntArray[10004];

   for (int i = 0; i < 10004; i++)
      {
      comparisonByScore[i].Dimension(9);
      comparisonByScore[i].Zero();
      }

   printf("%d samples overlap with user supplied HapMap files ...\n\n", overlap);

   genotypes  = new char [assayInfo.snpArray.Entries() * overlap];
   alleles[0] = new char [assayInfo.snpArray.Entries()];
   alleles[1] = new char [assayInfo.snpArray.Entries()];

   int output = 0;

   while (!feof(input))
      {
      buffer.ReadLine(input);
      buffer.Trim();
      tokens.ReplaceColumns(buffer, ' ');

      if (tokens.Length() != header.Length()) continue;
      if (tokens[columnQC] != "QC+") continue;

      SNPInfo * snpInfo = assayInfo.getRSIDInfo(tokens[0]);

      if (snpInfo == NULL) continue;
      if (tokens[columnAlleles].Length() != 3 || tokens[columnAlleles][1] != '/') continue;
      if (markerLookup.Find(tokens[0]) >= 0) continue;

      alleles[0][output] = tokens[columnAlleles][0];
      alleles[1][output] = tokens[columnAlleles][2];

      char * error = NULL;
      for (int i = columnQC + 1, j = output * overlap; i < header.Length(); i++)
         if (sampleKey[i - columnQC - 1] >= 0)
            genotypes[j++] = TranslateGenotype(tokens[i][0], tokens[i][1], alleles[0][output], alleles[1][output], error);

      if (tokens[columnStrand] == "-")
         {
         alleles[0][output] = FlipAllele(alleles[0][output]);
         alleles[1][output] = FlipAllele(alleles[1][output]);
         }

      markerLookup.SetInteger(tokens[0], output);

      output++;
      }

   printf("HapMap genotypes for %d markers loaded ...\n\n", output);
   }

// Column indices into the 9-element comparison vectors. The name prefix
// gives the reference (HapMap) genotype class; the suffix gives the outcome
// recorded for the genotype being compared against it.

// Reference genotype is homozygous
#define   HAPHOM_COMPARE_MATCH     0
#define   HAPHOM_COMPARE_MISMATCH  1
#define   HAPHOM_COMPARE_MISS      2

// Reference genotype is heterozygous
#define   HAPHET_COMPARE_MATCH     3
#define   HAPHET_COMPARE_MISMATCH  4
#define   HAPHET_COMPARE_MISS      5

// Reference genotype is missing
#define   HAPMISS_COMPARE_HOM      6
#define   HAPMISS_COMPARE_HET      7
#define   HAPMISS_COMPARE_MISS     8

// Compares one marker's genotypes against the stored reference genotypes.
//
//   genos               - genotype codes indexed by user sample
//                         (0 = missing, 1 / 2 = homozygotes, 3 = heterozygote)
//   offset              - marker index into the stored reference genotypes
//   map                 - translates reference homozygote codes (1, 2) into
//                         the coding used by genos
//   detailedComparison  - zeroed, then filled with the 9 comparison counters
//
// Returns the number of mismatches (homozygous plus heterozygous reference).
int HapMapReference::CustomCompare(
   IntArray & genos, int offset, int * map, IntArray & detailedComparison)
   {
   detailedComparison.Zero();

   // Reference genotypes are stored only for matched samples, so this
   // cursor advances once per matched sample column
   int cursor = offset * overlap;

   for (int column = 0; column < sampleKey.Length(); column++)
      {
      int sample = sampleKey[column];

      if (sample < 0)
         continue;

      int hapmap = genotypes[cursor++];
      int call   = genos[sample];

      switch (hapmap)
         {
         case 0 :
            if (call == 1 || call == 2) detailedComparison[HAPMISS_COMPARE_HOM]++;
            if (call == 3)              detailedComparison[HAPMISS_COMPARE_HET]++;
            if (call == 0)              detailedComparison[HAPMISS_COMPARE_MISS]++;
            break;

         case 1 :
         case 2 :
            {
            int expected = map[hapmap];

            if (call == 3 || call == 3 - expected)
               detailedComparison[HAPHOM_COMPARE_MISMATCH]++;
            if (call == expected)
               detailedComparison[HAPHOM_COMPARE_MATCH]++;
            if (call == 0)
               detailedComparison[HAPHOM_COMPARE_MISS]++;
            }
            break;

         case 3 :
            if (call == 1 || call == 2) detailedComparison[HAPHET_COMPARE_MISMATCH]++;
            if (call == 3)              detailedComparison[HAPHET_COMPARE_MATCH]++;
            if (call == 0)              detailedComparison[HAPHET_COMPARE_MISS]++;
            break;

         default :
            break;
         }
      }

   int homMismatches = detailedComparison[HAPHOM_COMPARE_MISMATCH];
   int hetMismatches = detailedComparison[HAPHET_COMPARE_MISMATCH];

   return homMismatches + hetMismatches;
   }

// Accumulates the 9 comparison counters for one marker, binned by each
// sample's quality score.
//
//   genos               - genotype codes indexed by user sample
//                         (0 = missing, 1 / 2 = homozygotes, 3 = heterozygote)
//   qualityScores       - quality score for each user sample
//   offset              - marker index into the stored reference genotypes
//   map                 - translates reference homozygote codes (1, 2) into
//                         the coding used by genos
//   detailedComparison  - array of counter rows indexed by score; counters
//                         are added to, never reset, by this function
//
// Scores are clamped to the range 0 .. 10000 before being used as a row
// index. The table allocated by LoadGenotypes() has 10004 rows and rows
// 10001 and above are used for totals and scratch space, so an out-of-range
// score would otherwise corrupt those rows or touch memory outside the table.
void HapMapReference::CustomCompareByScore(
   IntArray & genos, IntArray & qualityScores,
   int offset, int * map, IntArray * detailedComparison)
   {
   const int MAX_SCORE = 10000;

   // Run the comparison assuming no flips ...
   for (int i = 0, j = offset * overlap; i < sampleKey.Length(); i++)
      if (sampleKey[i] >= 0)
         {
         int k = sampleKey[i];

         int score = qualityScores[k];

         // Keep the row index inside the per-score portion of the table
         if (score < 0)         score = 0;
         if (score > MAX_SCORE) score = MAX_SCORE;

         IntArray & counts = detailedComparison[score];

         if (genotypes[j] == 0)
            {
            counts[HAPMISS_COMPARE_HOM]  += genos[k] == 1 || genos[k] == 2;
            counts[HAPMISS_COMPARE_HET]  += genos[k] == 3;
            counts[HAPMISS_COMPARE_MISS] += genos[k] == 0;
            }
         else if (genotypes[j] == 1 || genotypes[j] == 2)
            {
            int reference = map[genotypes[j]];

            counts[HAPHOM_COMPARE_MISMATCH]  += genos[k] == 3 || genos[k] == 3 - reference;
            counts[HAPHOM_COMPARE_MATCH]     += genos[k] == reference;
            counts[HAPHOM_COMPARE_MISS]      += genos[k] == 0;
            }
         else if (genotypes[j] == 3)
            {
            counts[HAPHET_COMPARE_MISMATCH]  += genos[k] == 1 || genos[k] == 2;
            counts[HAPHET_COMPARE_MATCH]     += genos[k] == 3;
            counts[HAPHET_COMPARE_MISS]      += genos[k] == 0;
            }

         j++;
         }
   }

// Updates the per-sample comparison and difference tallies for one marker.
//
//   genos          - genotype codes indexed by user sample
//                    (0 = missing, 1 / 2 = homozygotes, 3 = heterozygote)
//   qualityScores  - quality score for each user sample; may be empty
//   offset         - marker index into the stored reference genotypes
//   map            - translates reference homozygote codes (1, 2) into the
//                    coding used by genos
//
// A genotype pair is counted only when both genotypes are non-missing and
// the score exceeds QC_Settings::QUALITY_THRESHOLD.
//
// CompareGenotypes() calls this function even when no quality scores were
// supplied, so a sample without a score entry is treated as having score 0
// (the same row CompareGenotypes() uses for unscored markers) instead of
// reading past the end of qualityScores.
void HapMapReference::CustomCompareBySample(
   IntArray & genos, IntArray & qualityScores,
   int offset, int * map)
   {
   const int availableScores = qualityScores.Length();

   // Run the comparison assuming no flips ...
   for (int i = 0, j = offset * overlap; i < sampleKey.Length(); i++)
      if (sampleKey[i] >= 0)
         {
         int k = sampleKey[i];

         int score = k < availableScores ? qualityScores[k] : 0;

         if (genotypes[j] != 0 && genos[k] != 0 && score > QC_Settings::QUALITY_THRESHOLD)
            {
            if (genotypes[j] == 3)
               sampleDifferences[i] += genos[k] != 3;
            else
               sampleDifferences[i] += genos[k] != map[genotypes[j]];

            sampleComparisons[i]++;
            }

         j++;
         }
   }

// Compares one marker's genotypes with the reference panel and updates all
// running statistics (label counts, flips, histogram, per-score and
// per-sample tallies). Markers absent from the reference are ignored.
//
//   markerName     - marker to look up in the reference
//   genos          - genotype codes indexed by user sample
//   al1, al2       - allele labels used by genos
//   strand         - '-' means al1 / al2 are flipped before comparison
//   qualityScores  - per-sample scores; when empty, counts go to score row 0
void HapMapReference::CompareGenotypes(String & markerName,
                      IntArray & genos, char al1, char al2, char strand,
                      IntArray & qualityScores)
   {
   // Bookkeeping rows that follow the per-score rows of comparisonByScore
   const int TOTALS_ROW  = 10001;
   const int AS_IS_ROW   = 10002;
   const int FLIPPED_ROW = 10003;

   const int marker = markerLookup.Integer(markerName);

   if (marker < 0)
      return;

   if (strand == '-')
      {
      al1 = FlipAllele(al1);
      al2 = FlipAllele(al2);
      }

   const char ref1 = alleles[0][marker];
   const char ref2 = alleles[1][marker];

   const bool sameOrder    = (al1 == ref1 && al2 == ref2);
   const bool swappedOrder = (al1 == ref2 && al2 == ref1);
   const bool wrong_labels = !sameOrder && !swappedOrder;

   if (wrong_labels)
      {
      different_labels++;
      badlyLabeled.Add(markerName);
      }
   else
      same_labels++;

   // Translation of reference homozygote codes when labels are taken as is ...
   int direct[4] = { 0, 1, 2, 3 };

   if (swappedOrder)
      {
      direct[1] = 2;
      direct[2] = 1;
      }

   // ... and when the two homozygotes are exchanged (a flip)
   int exchanged[4] = { 0, direct[2], direct[1], 3 };

   int mismatches  = CustomCompare(genos, marker, direct, comparisonByScore[AS_IS_ROW]);
   int fmismatches = CustomCompare(genos, marker, exchanged, comparisonByScore[FLIPPED_ROW]);

   // Accept the flip only if it is strictly better and nearly perfect
   const bool flipped = (fmismatches < mismatches) && (fmismatches < 5);

   if (flipped && !wrong_labels)
      {
      flips++;
      flippedMarkers.Add(markerName);
      }

   const int chosenRow   = flipped ? FLIPPED_ROW : AS_IS_ROW;
   const int differences = flipped ? fmismatches : mismatches;
   int *     chosenMap   = flipped ? exchanged : direct;

   if (differences == 0)
      perfect++;

   histogram[differences]++;

   comparisonByScore[TOTALS_ROW] += comparisonByScore[chosenRow];

   CustomCompareBySample(genos, qualityScores, marker, chosenMap);

   // Update statistics by quality score
   if (qualityScores.Length() == 0)
      comparisonByScore[0] += comparisonByScore[chosenRow];
   else
      CustomCompareByScore(genos, qualityScores, marker, chosenMap, comparisonByScore);
   }

// Prints the per-marker summary line followed by the column headings of the
// per-genotype table.
void HapMapReference::PrintComparisonHeader()
   {
   const int markers = different_labels + same_labels;

   // Marker level summary
   printf("SUMMARY BY MARKER\n\n");
   printf("MARKERS    BADLABELS   FLIPPED    PERFECT\n"
          "========   =========   ========   =======\n");
   printf("%8d %11d %10d %9d\n\n", markers, different_labels, flips, perfect);

   // Headings for the genotype level table
   printf("SUMMARY BY GENOTYPE\n\n");
   printf("          HAPMAP HOMOZYGOTES       HAPMAP HETEROZYGOTES        HAPMAP MISSING\n");
   printf("SCORE    MATCH  MISMATCH  MISS     MATCH  MISMATCH  MISS     HOM     HET    MISS\n");
   printf("=====  ======== ======== ======  ======== ======== ======  ======= ======= ======\n");
   }

// Prints the genotype comparison tables: first raw counts by quality score,
// then the same rows as percentages, each followed by an "ANY" totals row
// and a row restricted by QC_Settings::QUALITY_THRESHOLD.
//
// Does nothing when the comparison table was never allocated (that is, when
// LoadGenotypes() has not completed successfully).
//
// Side effects: when more than 200 score rows are populated, rows are merged
// in place into coarser bins; row 10002 is overwritten with the
// above-threshold totals.
void HapMapReference::PrintComparisonSummary()
   {
   // The table only exists after a successful LoadGenotypes()
   if (comparisonByScore == NULL)
      return;

   int rows = 0, from = -1, to = -1;

   // Find the populated score range and count the non-empty rows
   for (int i = 0; i <= 10000; i++)
      if (comparisonByScore[i].Sum())
         {
         if (from == -1) from = i;
         to = i;
         rows++;
         }

   // If there are too many rows to display, do some grouping.
   // More than 200 populated rows guarantees to - from >= 200, so step >= 2.
   if (rows > 200)
      for (int i = 0, step = (to - from) / 100; i <= 10000; i++)
         if ((i - from) % step != 0)
            {
            // Fold this row into the first row of its bin
            comparisonByScore[from + ((i - from) / step) * step] += comparisonByScore[i];
            comparisonByScore[i].Zero();
            }

   for (int i = 0; i <= 10000; i++)
      if (comparisonByScore[i].Sum())
         printf("%5d %9d %8d %6d %9d %8d %6d %8d %8d %6d\n",
                 (int) i, comparisonByScore[i][0], comparisonByScore[i][1], comparisonByScore[i][2],
                    comparisonByScore[i][3], comparisonByScore[i][4], comparisonByScore[i][5],
                    comparisonByScore[i][6], comparisonByScore[i][7], comparisonByScore[i][8]);

   int ALL = 10001; // the row with the totals
   printf("%5s %9d %8d %6d %9d %8d %6d %8d %8d %6d\n",
          "ANY", comparisonByScore[ALL][0], comparisonByScore[ALL][1], comparisonByScore[ALL][2],
                 comparisonByScore[ALL][3], comparisonByScore[ALL][4], comparisonByScore[ALL][5],
                 comparisonByScore[ALL][6], comparisonByScore[ALL][7], comparisonByScore[ALL][8]);

   // Row ALL+1 is reused here to hold the totals for the thresholded rows.
   // NOTE(review): the row is labelled ">threshold" but the sum starts at the
   // threshold itself, whereas CustomCompareBySample uses a strict ">" --
   // confirm which is intended.
   comparisonByScore[ALL+1].Zero();
   for (int i = (int) max(QC_Settings::QUALITY_THRESHOLD,0); i <= 10000; i++)
      comparisonByScore[ALL+1] += comparisonByScore[i];

   printf(">%4d %9d %8d %6d %9d %8d %6d %8d %8d %6d\n\n",
          (int) QC_Settings::QUALITY_THRESHOLD,
          comparisonByScore[ALL+1][0], comparisonByScore[ALL+1][1], comparisonByScore[ALL+1][2],
          comparisonByScore[ALL+1][3], comparisonByScore[ALL+1][4], comparisonByScore[ALL+1][5],
          comparisonByScore[ALL+1][6], comparisonByScore[ALL+1][7], comparisonByScore[ALL+1][8]);

   // Denominators below carry a tiny epsilon so empty categories divide
   // cleanly, and are pre-divided by 100 so that count / denominator is a
   // percentage
   double total_genotypes = (comparisonByScore[ALL].Sum() + 1e-30) / 100.;

   printf("SUMMARY BY GENOTYPE\n\n"
          "          HAPMAP HOMOZYGOTES       HAPMAP HETEROZYGOTES        HAPMAP MISSING\n"
          "        (N = %8d, %6.2f%%)   (N = %8d, %6.2f%%)   (N = %6d, %6.2f%%)\n"
          "SCORE   MATCH  MISMATCH   MISS    MATCH  MISMATCH   MISS     HOM     HET    MISS\n"
          "=====  ======= ======== =======  ======= ======== =======  ======= ======= =======\n",
          comparisonByScore[ALL].Sum(0,2), comparisonByScore[ALL].Sum(0,2) / total_genotypes,
          comparisonByScore[ALL].Sum(3,5), comparisonByScore[ALL].Sum(3,5) / total_genotypes,
          comparisonByScore[ALL].Sum(6,8), comparisonByScore[ALL].Sum(6,8) / total_genotypes);

   for (int i = 0; i <= 10000; i++)
      if (comparisonByScore[i].Sum())
         {
         double hom = (comparisonByScore[i].Sum(0,2) + 1e-30) / 100.;
         double het = (comparisonByScore[i].Sum(3,5) + 1e-30) / 100.;
         double miss = (comparisonByScore[i].Sum(6,8) + 1e-30) / 100.;

         printf("%5d %7.2f%% %7.2f%% %6.2f%% %7.2f%% %7.2f%% %6.2f%% %7.2f%% %7.2f%% %6.2f%%\n",
                 i, comparisonByScore[i][0] / hom, comparisonByScore[i][1] / hom, comparisonByScore[i][2] / hom,
                    comparisonByScore[i][3] / het, comparisonByScore[i][4] / het, comparisonByScore[i][5] / het,
                    comparisonByScore[i][6] / miss, comparisonByScore[i][7] / miss, comparisonByScore[i][8] / miss);
         }

   double hom = (comparisonByScore[ALL].Sum(0,2) + 1e-30) / 100.;
   double het = (comparisonByScore[ALL].Sum(3,5) + 1e-30) / 100.;
   double miss = (comparisonByScore[ALL].Sum(6,8) + 1e-30) / 100.;

   printf("%5s %7.2f%% %7.2f%% %6.2f%% %7.2f%% %7.2f%% %6.2f%% %7.2f%% %7.2f%% %6.2f%%\n",
          "ANY", comparisonByScore[ALL][0] / hom, comparisonByScore[ALL][1] / hom, comparisonByScore[ALL][2] / hom,
                 comparisonByScore[ALL][3] / het, comparisonByScore[ALL][4] / het, comparisonByScore[ALL][5] / het,
                 comparisonByScore[ALL][6] / miss, comparisonByScore[ALL][7] / miss, comparisonByScore[ALL][8] / miss);

   hom = (comparisonByScore[ALL+1].Sum(0,2) + 1e-30) / 100.;
   het = (comparisonByScore[ALL+1].Sum(3,5) + 1e-30) / 100.;
   miss = (comparisonByScore[ALL+1].Sum(6,8) + 1e-30) / 100.;

   printf(">%4d %7.2f%% %7.2f%% %6.2f%% %7.2f%% %7.2f%% %6.2f%% %7.2f%% %7.2f%% %6.2f%%\n",
          (int) QC_Settings::QUALITY_THRESHOLD,
          comparisonByScore[ALL+1][0] / hom, comparisonByScore[ALL+1][1] / hom, comparisonByScore[ALL+1][2] / hom,
          comparisonByScore[ALL+1][3] / het, comparisonByScore[ALL+1][4] / het, comparisonByScore[ALL+1][5] / het,
          comparisonByScore[ALL+1][6] / miss, comparisonByScore[ALL+1][7] / miss, comparisonByScore[ALL+1][8] / miss);

   // hom / het / miss now hold the thresholded totals divided by 100, so
   // each ratio below is the thresholded total as a percentage of all
   // genotypes in that category
   printf(">%4d    (CALL RATE: %7.2f%%)     (CALL RATE: %7.2f%%)    (CALL RATE: %7.2f%%)\n\n",
           (int) QC_Settings::QUALITY_THRESHOLD,
           hom * 10000. / (comparisonByScore[ALL].Sum(0,2) + 1e-30),
           het * 10000. / (comparisonByScore[ALL].Sum(3,5) + 1e-30),
           miss * 10000. / (comparisonByScore[ALL].Sum(6,8) + 1e-30));
   }

// Prints the distribution of per-marker discrepancy counts: one row for each
// of 0 through 5 discrepancies, then a single "6+" row that pools all larger
// counts.
//
// The row for exactly 5 discrepancies is printed explicitly; without it,
// those markers would appear neither in the individual rows nor in the
// pooled "6+" row, which sums from index 6 onwards.
void HapMapReference::PrintHistogram()
   {
   printf("DISTRIBUTION OF DISCREPANCIES, BY MARKER\n");
   printf("========================================\n\n");

   printf("Diffs\tMarkers\n");

   // Individual rows, limited to the bins that actually exist
   for (int i = 0; i < min(histogram.Length(),6); i++)
      printf("%d\t%d\n", i, histogram[i]);

   // Pooled row, only when there is at least one bin at index 6 or above
   if (histogram.Length() > 6)
      printf("6+\t%d\n", histogram.Sum(6));
   printf("\n\n");
   }

// Writes the names of flipped and badly labelled markers to a text file,
// one per line. Nothing is written or printed if the file cannot be opened.
void HapMapReference::LogFlips(const char * filename)
   {
   FILE * log = fopen(filename, "wt");

   if (log != NULL)
      {
      printf("Flipped markers logged to %s\n", filename);

      // Markers whose alleles had to be flipped to match the reference
      for (int slot = 0; slot < flippedMarkers.Capacity(); slot++)
         {
         if (!flippedMarkers.SlotInUse(slot))
            continue;

         fprintf(log, "FLIPPED: %s\n", (const char *) flippedMarkers[slot]);
         }

      // Markers whose allele labels disagree with the reference
      for (int slot = 0; slot < badlyLabeled.Capacity(); slot++)
         {
         if (!badlyLabeled.SlotInUse(slot))
            continue;

         fprintf(log, "BAD_LABELS: %s\n", (const char *) badlyLabeled[slot]);
         }

      fclose(log);
      }
   }

void HapMapReference::LogSampleComparisons(const char * filename, StringArray & sampleIds)
   {
   FILE * f = fopen(filename, "wt");

   if (f == NULL) return;

   printf("Sample-by-sample comparison summary logged to %s\n", filename);

   fprintf(f, "SAMPLE\tDIFFS\tOVERLAP\tDIFFRATE\n");

   for (int i = 0; i < sampleKey.Length(); i++)
      if (sampleComparisons[i] > 0)
         fprintf(f, "%s\t%d\t%d\t%.3f\n",
                 (const char *) sampleIds[sampleKey[i]],
                 sampleDifferences[i], sampleComparisons[i],
                 sampleDifferences[i] / (sampleComparisons[i] + 1e-20));

   fclose(f);
   }


