//////////////////////////////////////////////////////////////////////
// MultiBalance.cpp
// Author: Wei-Min Chen
// March 22, 2007

#include "MultiBalance.h"
#include "Kinship.h"
#include "QuickIndex.h"
#include "VCLinear.h"
#include <math.h>

// Construct an empty balanced multivariate model over `pedigree`.
// The per-family and per-trait-pair index arrays are allocated lazily by
// InitCoef(); until then they are null.
POLY_Multivariate_Balance::POLY_Multivariate_Balance(Pedigree & pedigree)
   : POLY_Multivariate(pedigree)
{
   // Start with no covariates in the model.
   mCovariate.Dimension(0);

   // Not yet allocated; InitCoef() creates them with new[].
   Idx = NULL;
   pheno = NULL;
}

// Release the trait-pair index table allocated with new[] in InitCoef().
// NOTE(review): `pheno` is also allocated with new[] in InitCoef() but is not
// released here -- confirm whether a base class owns and frees it, otherwise
// it leaks once per object.
POLY_Multivariate_Balance::~POLY_Multivariate_Balance()
{
   if(Idx != NULL)
      delete [] Idx;
}

// Clamp negative per-trait variance components to zero.
// `variances` is split into two halves of parCount/2 entries each; within
// each half the per-trait (diagonal) components occupy slots 0..rCount-1.
// Returns 1 if at least one component was clamped, 0 otherwise.
int POLY_Multivariate_Balance::constraint(void)
{
   const int half = parCount / 2;
   int clamped = 0;
   for(int block = 0; block < 2; block++){
      const int start = block * half;
      for(int r = 0; r < rCount; r++){
         // Written as !(x < 0) so a NaN component is left untouched.
         if(!(variances[start + r] < 0)) continue;
         variances[start + r] = 0;
         clamped = 1;
      }
   }
   return clamped;
}

void POLY_Multivariate_Balance::GetPhi(int f)
{
   count = size;
   size = count * rCount;
   Kinship kin;
   Phi.Dimension(count, count);
   kin.Setup(*ped.families[f]);
   for(int i = 0; i < count; i++)
      for(int j = i; j < count; j++)
         Phi[j][i] = Phi[i][j] = 2 * kin(ped[pheno[f][i]], ped[pheno[f][j]]);
}

// Rebuild the (size x size) covariance matrix Omega for the current family.
// Rows/columns are trait-major: entry count*u+i is trait u of person i.
// For a trait pair (u,v) with parameter slot p:
//   same person      : variances[p] + variances[parCount/2 + p]
//   different persons: Phi[i][j] * variances[parCount/2 + p]
// Diagonal pairs use slot u directly; off-diagonal pairs use Idx[u][v].
// The parameter `f` is not used; Phi/count/size must already be set (GetPhi).
void POLY_Multivariate_Balance::RefreshO(int f)
{
   const int half = parCount / 2;
   Omega.Dimension(size, size);
   for(int u = 0; u < rCount; u++)
      for(int v = u; v < rCount; v++){
         const int slot = (u == v) ? u : Idx[u][v];
         const double samePerson = variances[slot] + variances[half + slot];
         const double kinScaled = variances[half + slot];
         for(int i = 0; i < count; i++){
            Omega[count*u+i][count*v+i] = samePerson;
            Omega[count*v+i][count*u+i] = samePerson;
            for(int j = i+1; j < count; j++){
               const double c = Phi[i][j] * kinScaled;
               // Fill all four symmetric positions of the (i,j) pair.
               Omega[count*u+i][count*v+j] = c;
               Omega[count*u+j][count*v+i] = c;
               Omega[count*v+j][count*u+i] = c;
               Omega[count*v+i][count*u+j] = c;
            }
         }
      }
}

// Compute, for every variance parameter p (0 <= p < parCount), the matrix
// OD[p]. From the block pattern below it appears to be Omega^{-1} times the
// derivative of Omega with respect to variances[p]. The first parCount/2
// parameters act through the identity (same person), the second parCount/2
// through Phi.
// Exploits the trait-major block structure: only count x count blocks of
// Omega^{-1} (O) and their products with Phi (OP) are formed explicitly.
// The parameter `f` is not used.
// NOTE(review): assumes CholeskyOmega already holds a decomposition of the
// current Omega and that OD[] is already dimensioned (size x size) elsewhere
// -- confirm against the caller.
void POLY_Multivariate_Balance::RefreshOD(int f)
{
   CholeskyOmega.Invert(); // most computationally intensive part
   // O[Idx[u][v]] = (u,v) count x count block of Omega^{-1}, stored for u <= v.
   for(int i = 0; i < parCount/2; i++) O[i].Dimension(count, count);
   for(int i = 0; i < count; i++)
      for(int j = 0; j < count; j++)
         for(int u = 0; u < rCount; u++)
            for(int v = u; v < rCount; v++)
               O[Idx[u][v]][i][j] = CholeskyOmega.inv[count*u+i][count*v+j];
   // OP[p] = O[p] * Phi: the inverse block times the kinship-scaled structure.
   for(int u = 0; u < parCount/2; u++) // most computationally intensive part
      OP[u].Product(O[u], Phi);
   for(int i = 0; i < parCount; i++) OD[i].Zero();
   // Row block k of OD[p] is filled from the inverse blocks of pairs (k,.).
   // NOTE(review): for k > u, O[Idx[k][u]] is the stored (u,k) block used in
   // place of (k,u), i.e. each inverse block is assumed symmetric --
   // presumably guaranteed by the balanced design; confirm.
   for(int i = 0; i < count; i++)
      for(int j = 0; j < count; j++)
         for(int k = 0; k < rCount; k++)
            for(int u = 0; u < rCount; u++){
               // Diagonal parameters of trait u (slot u / parCount/2+u):
               // only column block u is non-zero.
               OD[u][count*k+i][count*u+j] = O[Idx[k][u]][i][j];
               OD[parCount/2+u][count*k+i][count*u+j] = OP[Idx[k][u]][i][j];
               for(int v = u+1; v < rCount; v++){
                  // Cross-trait parameter (u,v): column block u takes the
                  // (k,v) block and column block v takes the (k,u) block.
                  OD[Idx[u][v]][count*k+i][count*u+j] = O[Idx[k][v]][i][j];
                  OD[Idx[u][v]][count*k+i][count*v+j] = O[Idx[k][u]][i][j];
                  OD[parCount/2+Idx[u][v]][count*k+i][count*u+j] = OP[Idx[k][v]][i][j];
                  OD[parCount/2+Idx[u][v]][count*k+i][count*v+j] = OP[Idx[k][u]][i][j];
               }
            }
}

void POLY_Multivariate_Balance::InitCoef()
{
   rCount = mTrait.Length();
   parCount = rCount*(rCount+1);
   if(Idx==NULL) Idx = new IntArray[rCount];
   for(int r = 0; r < rCount; r++) Idx[r].Dimension(rCount);
   int par_length = 0;
   for(int i = 0; i < rCount; i++)
      Idx[i][i] = par_length++;
   for(int i = 1; i < rCount; i++)
      for(int j = 0; j < i; j++)
         Idx[i][j] = Idx[j][i] = par_length++;
   if(pheno==NULL) pheno = new IntArray[ped.familyCount];
   int samplesize = 0;
   for (int f = 0; f < ped.familyCount; f++) {
      pheno[f].Dimension(0);
      for (int i = ped.families[f]->first; i <= ped.families[f]->last; i++){
         int missing = 0;
         for(int k = 0; k < mCovariate.Length(); k++)
            if(!ped[i].isControlled(mCovariate[k]))
               missing=1;
         for(int r = 0; r < rCount; r++)
            if(!ped[i].isPhenotyped(mTrait[r])) missing = 1;
         if(!missing) pheno[f].Push(i);
      }
      samplesize += pheno[f].Length();
   }
   SampleSize.Dimension(rCount);
   SampleSize.Set(samplesize);
   if(traits == NULL) traits = new Vector[ped.familyCount];
   if(covariates == NULL) covariates = new Matrix[ped.familyCount];
   variances.Dimension(parCount);
   variances.Zero();
   coef.Dimension(0);
   if(ModelPreset==0) InitModel();
   for(int r = 0; r < rCount; r++){
      variances[r] = Model[r]->variances[0];
      variances[parCount/2+r] = Model[r]->variances[1];
      coef.Stack(Model[r]->coef);
   }
   coefCount = coef.Length();
   ValidFamilies = ValidPersons = 0;
   for (int f = 0; f < ped.familyCount; f++){
      int count = pheno[f].Length();
      if(count) ValidFamilies ++;
      else{
         traits[f].Dimension(0);
         covariates[f].Dimension(coefCount, 0);
         continue;
      }
      ValidPersons += count;
      traits[f].Dimension(count*rCount);
      covariates[f].Dimension(coefCount, count*rCount);
      covariates[f].Zero();
      for(int r = 0; r < rCount; r++)
         for(int i = 0; i < count; i++){
            traits[f][count*r+i] = ped[pheno[f][i]].traits[mTrait[r]];
            covariates[f][r*(mCovariate.Length()+1)][count*r+i] = 1;
            for(int j = 0; j < mCovariate.Length(); j++)
               covariates[f][r*(mCovariate.Length()+1)+j+1][count*r+i] =
                  ped[pheno[f][i]].covariates[mCovariate[j]];
         }
   }
   double maxV = -1;
   for(int r = 0; r < rCount; r++){
      double v = variances[r] + variances[parCount/2+r];
      if(v > maxV) maxV = v;
   }
   Epsilon = 1E-7 * maxV * maxV;
}

/*
   double R1, R2;
   Vector overall;
   for(int r1 = 0; r1 < rCount; r1++)
      for(int r2 = r1+1; r2 < rCount; r2++){
         int par = Idx[r1][r2]-rCount;
         overall.Dimension(0);
         for(int i = 0; i < ped.count; i++){
            R1 = Model[r1]->residual(i);
            if(R1 ==_NAN_) continue;
            R2 = Model[r2]->residual(i);
            if(R2 ==_NAN_) continue;
            overall.Push(R1*R2);
         }
         variances[par+rCount] = overall.Sum() / overall.Length();
      }

   Kinship kin;
   Vector num, den;
   for(int r1 = 0; r1 < rCount; r1++)
      for(int r2 = r1+1; r2 < rCount; r2++){
         int par = Idx[r1][r2]-rCount;
         overall.Dimension(0);
         num.Dimension(0);
         den.Dimension(0);
         for(int f = 0; f < ped.familyCount; f++){
            size = ped.families[f]->last - ped.families[f]->first + 1;
            Matrix Phi;
            Phi.Dimension(size, size);
            kin.Setup(*ped.families[f]);
            for(int i = ped.families[f]->first; i <= ped.families[f]->last; i++)
               for(int j = i; j <= ped.families[f]->last; j++)
                  Phi[i-ped.families[f]->first][j-ped.families[f]->first]
                     = Phi[j-ped.families[f]->first][i-ped.families[f]->first]
                     = 2 * kin(ped[i], ped[j]);
            for(int i = ped.families[f]->first; i <= ped.families[f]->last; i++)
               for(int j = ped.families[f]->first; j <= ped.families[f]->last; j++){
                  if(Phi[i-ped.families[f]->first][j-ped.families[f]->first]==0.0)
                     continue;
                  if(i==j) continue;
                  R1 = Model[r1]->residual(i);
                  if(R1 ==_NAN_) continue;
                  R2 = Model[r2]->residual(j);
                  if(R2 ==_NAN_) continue;
                  num.Push(R1*R2*Phi[i-ped.families[f]->first][j-ped.families[f]->first]);
                  den.Push(Phi[i-ped.families[f]->first][j-ped.families[f]->first]*
                     Phi[i-ped.families[f]->first][j-ped.families[f]->first]);
               }
//            overall.Push(vec.Sum()/vec.Length());
         }
//         variances[parCount/2+Idx[r1][r2]] = overall.Sum() / overall.Length();
         variances[parCount/2+Idx[r1][r2]] = num.Sum() / den.Sum();
         variances[Idx[r1][r2]] -= variances[parCount/2+Idx[r1][r2]];
      }
*/


