/*
  This module provides the scanning methods. It contains the following segments:
  - compute the scores of all sites
  - select the top K non-overlapping sites (overlap as defined by the module variable overlapT)
    caveat: greedy, so it might have to settle for fewer than K sites
  - select the non-overlapping sites scoring >= T (also greedy)
*/


#include <stdlib.h>
#include <stdio.h>
#include <assert.h>
#include <float.h>
#include <math.h>
#include "my_types.h"
#include "misc_functions.h"
#include "data_interface.h"
#include "pwms.h"
#include "sitePrtctdBtstrp.h"
#include "markov.h"
#include "motif_scan.h"



// File-local helpers. Declaring them static here gives them internal linkage,
// which their definitions further down inherit even though those omit `static`.
static siteStructVec gen_seqSitesScores(chunksSeq seq, PWM_Struct pwm, int seqID);
static void trim_K(slctdSitesStruct *slctdSites, int K);
static unsigned char **perSeqInit(chunksSeqVec set, PWM_Struct pwm, dsSitesStruct *dsSites);
static int selectTopKsites(chunksSeqVec set, int width, int K, siteStructVec allSites, unsigned char **selected, slctdSitesStruct *slctdSites);
static int selectGTTsites(chunksSeqVec set, int width, double T, siteStructVec allSites, unsigned char **selected, slctdSitesStruct *slctdSites);
static Boolean noOverlap(siteStruct site,chunksSeq seq, int width, unsigned char *selected, double overlapT);
static void free_slctdSitesStruct(slctdSitesStruct *s);
static unsigned char **allocBoolSeqVec(chunksSeqVec set);
// The parameter is not named `bool`: that is a keyword in C23 and a macro once <stdbool.h> is included.
static void freeBoolSeqVec(int nSeqs, unsigned char **flags);



// Strands scanned when scoring sites (gen_seqSitesScores passes it to get_chunkSites):
// 0 = forward only, 1 = reverse complement only, 2 = both.
int scanDirection;     // Not static: other modules access it through extern. Set via set_scanDirection
// Tolerated overlap with already-selected positions, as a fraction of the site width:
// noOverlap accepts a site sharing up to floor(overlapT * width) positions with earlier picks.
static double overlapT = 0;       // Set via set_overlapThreshold; 0 = no overlap tolerated
// Null (background) Markov model: the getMarkovLklhd score of each window is subtracted
// from the PWM score of every site. Other modules can access it directly (extern).
MarkovModelType nullMarkovModel; // Trained by setNullTrainSeq
static Boolean printTopK = FALSE;  // TRUE: printSites prints the top K sites of each sequence
static Boolean printGTT = FALSE;   // TRUE: printSites prints the sites scoring >= the PWM threshold (GTT)
static int Kprint;                 // The K used when printing top-K sites (set by set_printSitesTopK; 0 until set)


// ----------- Finding the scores of all sites -----------------


void get_dsSitesScores(chunksSeqVec set, PWM_Struct pwm, dsSitesStruct *dsSites)
     /*
	Updates dsSites, if necessary, to hold the scores of all pwm sites in the given set.
	Per-sequence scores (dsSites->perSeqSites, each sorted by decreasing score) are
	computed first, then concatenated and re-sorted into the dataset-wide list
	dsSites->allSites. Nothing is done when dsSites->allSites is already populated.
	Aborts on allocation failure.
     */
{
  int iseq, nSites = 0, iSite = 0, iSeqSite;

  if (dsSites->allSites.len > 0) // already done
    return;

  // First find the scores in each sequence
  if (dsSites->perSeqSites.len == 0 && set.len > 0) { // need to allocate the per-sequence array
    // Kept out of assert(): under NDEBUG the whole assert expression, allocation included, vanishes.
    dsSites->perSeqSites.entry = calloc(set.len, sizeof(siteStructVec));
    if (dsSites->perSeqSites.entry == NULL) {
      fprintf(stderr, "get_dsSitesScores: out of memory\n");
      abort();
    }
    dsSites->perSeqSites.len = set.len;
  }
  for (iseq = 0; iseq < set.len; iseq++) {
    // An earlier call that found no sites at all leaves allSites empty but these
    // vectors allocated, so release them before regenerating (fresh entries are zeroed by calloc).
    FREE_ATOMS_VEC(dsSites->perSeqSites.entry[iseq]);
    dsSites->perSeqSites.entry[iseq] = gen_seqSitesScores(set.entry[iseq], pwm, iseq);
    nSites += dsSites->perSeqSites.entry[iseq].len; // running total of sites
    qsort(dsSites->perSeqSites.entry[iseq].entry, dsSites->perSeqSites.entry[iseq].len, sizeof(siteStruct),
          (int (*)(const void *, const void *)) siteStructCompare); // descending order
  }

  // Next, concatenate all per-sequence scores into the dataset-wide list
  FREE_ATOMS_VEC(dsSites->allSites);
  dsSites->allSites.len = nSites;
  dsSites->allSites.entry = NULL;
  if (nSites == 0) // empty set, or no chunk is long enough to hold a site
    return;
  dsSites->allSites.entry = calloc(nSites, sizeof(siteStruct));
  if (dsSites->allSites.entry == NULL) {
    fprintf(stderr, "get_dsSitesScores: out of memory\n");
    abort();
  }
  for (iseq = 0; iseq < set.len; iseq++)
    for (iSeqSite = 0; iSeqSite < dsSites->perSeqSites.entry[iseq].len; iSeqSite++)
      dsSites->allSites.entry[iSite++] = dsSites->perSeqSites.entry[iseq].entry[iSeqSite];
  qsort(dsSites->allSites.entry, nSites, sizeof(siteStruct),
        (int (*)(const void *, const void *)) siteStructCompare); // descending order
}


siteStructVec gen_seqSitesScores(chunksSeq seq, PWM_Struct pwm, int seqID)
     /*
       Scores all sites of pwm in one sequence by scanning each of its chunks with
       get_chunkSites (strands chosen by the module-level scanDirection).
       Returns a newly allocated vector in scan order (not sorted), which the caller
       releases with FREE_ATOMS_VEC. Aborts on allocation failure.
     */
{
  int iChunk, nSites = 0;
  int sitesPerPos = (scanDirection < 2) ? 1 : 2; // one strand, or both strands
  size_t capacity;
  siteStructVec seqSites;

  // sitesPerPos * total_len is longer than it need be (chunks shorter than pwm.width
  // contribute no sites at all). Keep at least one slot so that NULL can only mean
  // out of memory; the allocation must not sit inside assert() (removed under NDEBUG).
  capacity = (size_t) sitesPerPos * seq.total_len;
  seqSites.entry = calloc(capacity > 0 ? capacity : 1, sizeof(siteStruct));
  if (seqSites.entry == NULL) {
    fprintf(stderr, "gen_seqSitesScores: out of memory\n");
    abort();
  }
  for (iChunk = 0; iChunk < seq.num_chunks; iChunk++)
    nSites += get_chunkSites(seq.chunks[iChunk].body, pwm, scanDirection, seqID, iChunk, seqSites.entry + nSites);

  seqSites.len = nSites;
  return seqSites;
}
 

int get_chunkSites(iLetterVec chunk, PWM_Struct pwm, int scanDirection, int seqID, int chunkID, siteStruct * chunkSites)
     /*
       Scores every site of pwm in chunk and stores the sites consecutively in chunkSites,
       each tagged with seqID and chunkID. For every start position the forward site
       (scanDirection 0 or 2) is stored before the reverse-complement site (1 or 2).
       A site's score is sitePwmScore of the window minus getMarkovLklhd's score of the
       same window under the module-level nullMarkovModel (an implicit input).
       Returns the number of sites written; chunkSites must have room for all of them.
       TODO: account for gapped positions and possibly for a right flanking segment.
     */
{
  int pos, count = 0;
  int wantFwd = (scanDirection == 0 || scanDirection == 2);
  int wantRc = (scanDirection == 1 || scanDirection == 2);
  double bgScore;
  siteStruct *out;

  for (pos = 0; pos < chunk.len-pwm.width+1; pos++) {
    bgScore = getMarkovLklhd(nullMarkovModel, chunk.entry, pos, pwm.width);
    if (wantFwd) {
      out = chunkSites + count++;
      out->leftPos = pos;
      out->rev = FALSE;
      out->score = sitePwmScore(chunk.entry+pos, pwm.logFreqs, pwm.width) - bgScore;
      out->seqID = seqID;
      out->chunkID = chunkID;
    }
    if (wantRc) {
      out = chunkSites + count++;
      out->leftPos = pos;
      out->rev = TRUE;
      out->score = sitePwmScore(chunk.entry+pos, pwm.logFreqsRC, pwm.width) - bgScore;
      out->seqID = seqID;
      out->chunkID = chunkID;
    }
  }
  return count;
}


int get_nChunkSitesGTT(iLetterVec chunk, PWM_Struct pwm, int scanDirection, double T)
     /*
       Counts the sites of pwm in chunk whose score (as computed by get_chunkSites)
       is >= T, on the strands selected by scanDirection. Unlike calling
       get_chunkSites first, this needs no site buffer.
     */
{
  int pos, hits = 0;
  int wantFwd = (scanDirection == 0 || scanDirection == 2);
  int wantRc = (scanDirection == 1 || scanDirection == 2);
  double bgScore;

  for (pos = 0; pos < chunk.len-pwm.width+1; pos++) {
    bgScore = getMarkovLklhd(nullMarkovModel, chunk.entry, pos, pwm.width);
    if (wantFwd && sitePwmScore(chunk.entry+pos, pwm.logFreqs, pwm.width) - bgScore >= T)
      hits++;
    if (wantRc && sitePwmScore(chunk.entry+pos, pwm.logFreqsRC, pwm.width) - bgScore >= T)
      hits++;
  }
  return hits;
}


// ----------- Finding the top K scoring sites -----------------


int findDsTopK(chunksSeqVec set, PWM_Struct pwm, int K, dsSitesStruct *dsSites)
     /*
       Makes sure dsSites->slctdSites.topK holds the top K non-overlapping sites of the whole
       dataset. A cached selection made for a K at least as large is trimmed instead of
       recomputed (the cache is assumed to belong to the same dataset and pwm).
       Returns the number of sites actually selected, which can be less than K.
     */
{
  unsigned char **occupied;  // per-sequence position flags used by the greedy selection

  if (dsSites->slctdSites.K < K) {
    // Cached selection too small (or absent): starting from scratch isn't too costly and is cleaner
    get_dsSitesScores(set, pwm, dsSites); // make sure all site scores exist first
    occupied = allocBoolSeqVec(set);      // every location starts unselected
    selectTopKsites(set, pwm.width, K, dsSites->allSites, occupied, &(dsSites->slctdSites));
    freeBoolSeqVec(set.len, occupied);
  }
  else
    trim_K(&(dsSites->slctdSites), K);    // cached for >= K: drop any sites beyond the first K

  return dsSites->slctdSites.topK.len;
}


void findPerSeqTopK(chunksSeqVec set, PWM_Struct pwm, int *Ks, dsSitesStruct *dsSites)
     /*
       Makes sure dsSites->perSeqSlctdSites.entry[i].topK holds the top Ks[i] non-overlapping
       sites of sequence i, for every sequence of set. Cached selections made for a K at
       least as large are trimmed; the others are recomputed.
     */
{
  int s;
  unsigned char **occupied = perSeqInit(set, pwm, dsSites); // per-sequence position flags

  for (s = 0; s < set.len; s++) {
    if (dsSites->perSeqSlctdSites.entry[s].K < Ks[s])
      selectTopKsites(set, pwm.width, Ks[s], dsSites->perSeqSites.entry[s], occupied, &dsSites->perSeqSlctdSites.entry[s]);
    else
      trim_K(&dsSites->perSeqSlctdSites.entry[s], Ks[s]); // drop sites beyond the first Ks[s]
  }

  freeBoolSeqVec(set.len, occupied);
}


void trim_K(slctdSitesStruct *slctdSites, int K)
     /*
       If K < slctdSites->K, lowers the cached K to K and keeps only the first K selected
       sites (the highest scoring ones, since selection walks sites in decreasing score order),
       shrinking the topK buffer. Does nothing otherwise.
     */
{
  int n;

  if (K >= slctdSites->K)
    return;
  slctdSites->K = K;
  n = MIN2(K, slctdSites->topK.len);
  if (n == slctdSites->topK.len)
    return; // nothing to drop
  if (n <= 0) {
    // realloc(p, 0) is implementation-defined (undefined since C23): release explicitly
    free(slctdSites->topK.entry);
    slctdSites->topK.entry = NULL;
    slctdSites->topK.len = 0;
  }
  else {
    slctdSites->topK.entry = (void *) my_realloc(slctdSites->topK.entry, n*sizeof(siteStruct));
    slctdSites->topK.len = n;
  }
}


unsigned char **perSeqInit(chunksSeqVec set, PWM_Struct pwm, dsSitesStruct *dsSites)
     /*
       Common preparation for the per-sequence selections:
       - on first use, allocates dsSites->perSeqSlctdSites (one entry per sequence) with every
         entry's T set to DBL_MAX, so the cached-threshold test won't randomly match;
       - makes sure the per-sequence site scores exist (get_dsSitesScores);
       - returns a fresh all-zero occupancy array (allocBoolSeqVec) that the caller releases
         with freeBoolSeqVec(set.len, ...).
       Aborts on allocation failure.
     */
{
  unsigned char **selected;  // modeled after the set, =1 if location is selected
  int iSeq;

  if (dsSites->perSeqSlctdSites.len == 0 && set.len > 0) { // need to allocate it
    // Kept out of assert(): under NDEBUG the whole assert expression, allocation included, vanishes.
    dsSites->perSeqSlctdSites.entry = calloc(set.len, sizeof(slctdSitesStruct));
    if (dsSites->perSeqSlctdSites.entry == NULL) {
      fprintf(stderr, "perSeqInit: out of memory\n");
      abort();
    }
    dsSites->perSeqSlctdSites.len = set.len;
    for (iSeq = 0; iSeq < set.len; iSeq++)
      dsSites->perSeqSlctdSites.entry[iSeq].T = DBL_MAX; // make sure T won't randomly match initial value
  }
  get_dsSitesScores(set, pwm, dsSites); // make sure we have the scores of all sites first
  selected = allocBoolSeqVec(set);      // initialize all locations to 0

  return selected;
}


int selectTopKsites(chunksSeqVec set, int width, int K, siteStructVec allSites, unsigned char **selected, slctdSitesStruct *slctdSites)
     /*
       Greedily chooses the top K non-overlapping sites (overlap as defined by the implicit
       overlapT) from allSites, which is assumed sorted in decreasing score order.
       Chosen sites are marked in selected (see noOverlap). The result replaces
       slctdSites->topK, and slctdSites->K is set to K.
       Returns the actual number of sites found, which can be less than K.
       Aborts on allocation failure.
     */
{
  int iSite, nSelSite = 0, capacity;
  siteStruct site;

  FREE_ATOMS_VEC(slctdSites->topK);
  // At most min(K, allSites.len) sites can be chosen, so don't size the buffer by K alone
  // (K may be huge); keep at least one slot so that NULL can only mean out of memory.
  capacity = (K < allSites.len) ? K : allSites.len;
  if (capacity < 1)
    capacity = 1;
  // Kept out of assert(): under NDEBUG the whole assert expression, allocation included, vanishes.
  slctdSites->topK.entry = calloc(capacity, sizeof(siteStruct));
  if (slctdSites->topK.entry == NULL) {
    fprintf(stderr, "selectTopKsites: out of memory\n");
    abort();
  }

  for (iSite = 0; (iSite < allSites.len) && (nSelSite < K); iSite++) {
    site = allSites.entry[iSite];
    if ( noOverlap(site, set.entry[site.seqID], width, selected[site.seqID], overlapT) ) // add new site
      slctdSites->topK.entry[nSelSite++] = site;
  }

  slctdSites->topK.len = nSelSite;
  slctdSites->K = K;
  return nSelSite;
}


// ----------- Finding sites scoring >= T -----------------


int findDsSitesGTT(chunksSeqVec set, PWM_Struct pwm, double T, dsSitesStruct *dsSites)
     /*
       Makes sure dsSites->slctdSites.aboveT holds the dataset's non-overlapping sites scoring
       >= T, recomputing them from scratch unless they were last computed for exactly this T.
       Returns the number of such sites across the dataset.
       NOTE(review): a cache hit relies on slctdSites.T not equalling T before the first
       computation (free_slctdSitesStruct resets it to DBL_MAX); confirm dsSites is
       initialized the same way.
     */
{
  unsigned char **occupied;  // per-sequence position flags used by the greedy selection

  if (dsSites->slctdSites.T != T) {
    // Starting from scratch isn't too costly and is cleaner than building from a bigger/smaller T
    get_dsSitesScores(set, pwm, dsSites); // make sure all site scores exist first
    occupied = allocBoolSeqVec(set);      // every location starts unselected
    selectGTTsites(set, pwm.width, T, dsSites->allSites, occupied, &(dsSites->slctdSites));
    freeBoolSeqVec(set.len, occupied);
  }

  return dsSites->slctdSites.aboveT.len;
}


void findPerSeqSitesGTT(chunksSeqVec set, PWM_Struct pwm, double T, dsSitesStruct *dsSites)
     /*
       Makes sure dsSites->perSeqSlctdSites.entry[i].aboveT holds the non-overlapping sites of
       sequence i scoring >= T, for every sequence of set.
       T is currently shared by all sequences, so the cache check inspects only the first
       sequence's T. Since selection is greedy, the test == T could become <= T, dropping the
       cached sites that fail the more stringent threshold instead of recomputing.
     */
{
  unsigned char **selected;  // modeled after the set, =1 if location is selected
  int iSeq;

  if ( (dsSites->perSeqSlctdSites.len > 0) && (dsSites->perSeqSlctdSites.entry[0].T == T) )
    return;    // we already worked it out: all sequences are assumed to share the same T!!!
  selected = perSeqInit(set, pwm, dsSites);

  for (iSeq = 0; iSeq < set.len; iSeq++)
    selectGTTsites(set, pwm.width, T, dsSites->perSeqSites.entry[iSeq], selected, dsSites->perSeqSlctdSites.entry+iSeq);

  freeBoolSeqVec(set.len, selected);
}


int selectGTTsites(chunksSeqVec set, int width, double T, siteStructVec allSites, unsigned char **selected, slctdSitesStruct *slctdSites)
     /*
       Greedily chooses the non-overlapping sites scoring >= T from allSites, which is assumed
       sorted in decreasing score order (implicit input: overlapT).
       Chosen sites are marked in selected (see noOverlap). The result replaces
       slctdSites->aboveT, and slctdSites->T is set to T.
       Returns the actual number of sites found. Aborts on allocation failure.
     */
{
  int iSite, nSelSite = 0;
  siteStruct site;

  FREE_ATOMS_VEC(slctdSites->aboveT);
  // Worst case every site qualifies. Keep at least one slot so that NULL can only mean
  // out of memory; the allocation must not sit inside assert() (removed under NDEBUG).
  slctdSites->aboveT.entry = calloc(allSites.len > 0 ? allSites.len : 1, sizeof(siteStruct));
  if (slctdSites->aboveT.entry == NULL) {
    fprintf(stderr, "selectGTTsites: out of memory\n");
    abort();
  }

  for (iSite = 0; iSite < allSites.len; iSite++) {
    site = allSites.entry[iSite];
    if (site.score < T)           // scores are decreasing: no later site can qualify
      break;
    if ( noOverlap(site, set.entry[site.seqID], width, selected[site.seqID], overlapT) ) // add new site
      slctdSites->aboveT.entry[nSelSite++] = site;
  }

  slctdSites->aboveT.len = nSelSite;
  slctdSites->T = T;
  if (nSelSite == 0) {
    // realloc(p, 0) is implementation-defined (undefined since C23): release explicitly
    free(slctdSites->aboveT.entry);
    slctdSites->aboveT.entry = NULL;
  }
  else if (nSelSite < allSites.len) // drop the wasted space
    slctdSites->aboveT.entry = (void *) my_realloc(slctdSites->aboveT.entry, nSelSite*sizeof(siteStruct));
  return nSelSite;
}


Boolean noOverlap(siteStruct site,chunksSeq seq, int width, unsigned char *selected, double overlapT)
     /*
       Checks whether site shares at most floor(overlapT * width) positions with the
       locations already marked in selected. If so, marks all width positions of the site
       and returns TRUE; otherwise returns FALSE and leaves selected untouched.
       selected is indexed by raw sequence position: the chunk's startPos plus site.leftPos.
       (The overlapT parameter shadows the module-level variable of the same name.)
     */
{
  int k, taken = 0;
  unsigned char *window = selected + (seq.chunks[site.chunkID].startPos + site.leftPos);

  for (k = 0; k < width; k++)
    taken += window[k];

  if (taken <= floor(overlapT * width)) {
    for (k = 0; k < width; k++)
      window[k] = 1;
    return TRUE;
  }
  return FALSE;
}


// ----------- Setting user input -----------------


void setNullTrainSeq(iLetterVec trainData, int order, double pseudoCount)
     /*
       Trains, once and for all, the null Markov model (nullMarkovModel) that every
       site score is measured against, from trainData with the given order and
       pseudoCount. trainData itself isn't saved by this module.
     */
{
  nullMarkovModel = trainMarkovModel(trainData, order, pseudoCount);
}


void set_scanDirection(int value)
     // Strands to scan: 0 = forward only, 1 = reverse complement only, 2 = both
{
  scanDirection = value;
}


void set_overlapThreshold(double value)
     // Fraction of a site's width that may overlap already-selected sites
{
  overlapT = value;
}


void set_printTopK(Boolean value)
     // TRUE requests printing the top K sites of each sequence
{
  printTopK = value;
}


void set_printGTT(Boolean value)
     // TRUE requests printing the sites scoring >= the PWM threshold
{
  printGTT = value;
}



// --------------- others ----------------
int siteStructCompare(const siteStruct *e1, const siteStruct *e2)
     /*
       qsort comparator that orders sites by DEcreasing score:
       returns -1 if e1 scores higher, 0 if the scores are equal, and 1 otherwise
       (which includes the case where either score is NaN).
     */
{
  if (e1->score > e2->score)
    return -1;
  if (e1->score == e2->score)
    return 0;
  return 1;
}


// ---------------  memory management ------------------


void free_dsSitesStruct(dsSitesStruct *dsSites)
		// Releases every array owned by dsSites: the dataset-wide scores and selection,
		// then the per-sequence score vectors and per-sequence selections with their containers.
{
	int k;

	// dataset-wide data
	FREE_ATOMS_VEC(dsSites->allSites);
	free_slctdSitesStruct(&(dsSites->slctdSites));

	// per-sequence site scores: each vector, then the container
	for (k = 0; k < dsSites->perSeqSites.len; k++)
		FREE_ATOMS_VEC(dsSites->perSeqSites.entry[k]);
	FREE_ATOMS_VEC(dsSites->perSeqSites);

	// per-sequence selections: each entry's buffers, then the container
	for (k = 0; k < dsSites->perSeqSlctdSites.len; k++)
		free_slctdSitesStruct(&(dsSites->perSeqSlctdSites.entry[k]));
	FREE_ATOMS_VEC(dsSites->perSeqSlctdSites);
}

void free_slctdSitesStruct(slctdSitesStruct *s)
	// Releases both selections held by s and resets its cache keys: K = 0 (no top-K
	// selection) and T = DBL_MAX, so cached-threshold checks don't match by accident.
{
	s->K = 0;
	s->T = DBL_MAX;
	FREE_ATOMS_VEC(s->aboveT);
	FREE_ATOMS_VEC(s->topK);
}


unsigned char **allocBoolSeqVec(chunksSeqVec set)
     /*
       Allocates an array of set.len unsigned char flag vectors, all initialized to 0.
       Vector i spans raw sequence i up to the end of its last chunk
       (chunks[num_chunks-1].startPos + body.len), so it can be indexed by raw position.
       This could have been more frugal but it will probably do.
       Release with freeBoolSeqVec(set.len, ...). Aborts on allocation failure.
     */
{
  int iSeq, last;
  size_t rawLen;
  unsigned char **flags;  // not named `bool`: a keyword in C23 and a macro under <stdbool.h>

  // Allocations are kept out of assert(), which NDEBUG compiles away. At least one
  // element is requested so that NULL can only mean out of memory.
  flags = calloc(set.len > 0 ? set.len : 1, sizeof *flags);
  if (flags == NULL) {
    fprintf(stderr, "allocBoolSeqVec: out of memory\n");
    abort();
  }
  for (iSeq = 0; iSeq < set.len; iSeq++) {
    last = set.entry[iSeq].num_chunks - 1;
    // A sequence without chunks holds no sites; never index chunks[-1]
    rawLen = (last >= 0)
      ? (size_t) (set.entry[iSeq].chunks[last].startPos + set.entry[iSeq].chunks[last].body.len)
      : 0;
    flags[iSeq] = calloc(rawLen > 0 ? rawLen : 1, sizeof(unsigned char));
    if (flags[iSeq] == NULL) {
      fprintf(stderr, "allocBoolSeqVec: out of memory\n");
      abort();
    }
  }
  return flags;
}


void freeBoolSeqVec(int nSeqs, unsigned char **flags)
     /*
       Frees an array of nSeqs flag vectors (as returned by allocBoolSeqVec) and the
       array itself. A NULL array is ignored, mirroring free(NULL).
       The parameter is not named `bool`: a keyword in C23 and a macro under <stdbool.h>.
     */
{
  int iSeq;

  if (flags == NULL)
    return;
  for (iSeq = 0; iSeq < nSeqs; iSeq++)
    free(flags[iSeq]);
  free(flags);
}




// ---------------  printing good sites ------------------

void printSites(chunksSeqVec *set, dsSitesStruct *dsSites, PWM_Struct *pwm, char *setName, FILE *output)
       /*
         Prints the requested good sites (top K and/or >= threshold, per printTopK/printGTT).
         When the dataset-wide scores were never computed, the memory-frugal per-sequence
         printer is used instead of the dataset-wide printers.
       */
{
  if (dsSites->allSites.len <= 0) {
    // no dataset-wide scores: score and print one sequence at a time
    if (printGTT)
      printPerSeqSites(set, dsSites, pwm, setName, output, TRUE);
    if (printTopK)
      printPerSeqSites(set, dsSites, pwm, setName, output, FALSE);
    return;
  }

  // all sites already found, so the dataset-wide printers can reuse them
  if (printTopK)
    printTopKSites(set, dsSites, pwm, setName, output);
  if (printGTT)
    printGTTSites(set, dsSites, pwm, setName, output);
}


void printTopKSites(chunksSeqVec *set, dsSitesStruct *dsSites, PWM_Struct *pwm, char *setName, FILE *output)
     /*
       Prints the top Kprint non-overlapping sites of every sequence in *set to output,
       preceded by a header line giving the total number of sites printed.
       The per-sequence selections are computed first if needed (findPerSeqTopK).
       Aborts on allocation failure.
     */
{
  int siteIt, seqIt, totalNumScores = 0;
  int *Ks;
  siteStruct curSite;

  if (set->len > 0) {
    // Same K (as defined by set_printSitesTopK) for every sequence
    Ks = malloc(set->len * sizeof *Ks);
    if (Ks == NULL) {
      fprintf(stderr, "printTopKSites: out of memory\n");
      abort();
    }
    for (seqIt = 0; seqIt < set->len; seqIt++)
      Ks[seqIt] = Kprint;
    // findPerSeqTopK handles every sequence itself, so a single call suffices
    findPerSeqTopK(*set, *pwm, Ks, dsSites);
    free(Ks);
  }

  // find how many scores we are going to print
  for (seqIt = 0; seqIt < set->len; seqIt++)
    totalNumScores += dsSites->perSeqSlctdSites.entry[seqIt].topK.len;

  // print top K sites
  fprintf(output, "\n\n");
  fprintf(output, "The following %d scores come from the %s\n", totalNumScores, setName);
  for (seqIt = 0; seqIt < set->len; seqIt++) {
    for (siteIt = 0; siteIt < dsSites->perSeqSlctdSites.entry[seqIt].topK.len; siteIt++) {
      curSite = dsSites->perSeqSlctdSites.entry[seqIt].topK.entry[siteIt];
      printSite(set, curSite.seqID, curSite.chunkID, curSite.leftPos, curSite.rev, curSite.score, pwm->width, output);
    }
  }
  fprintf(output, "\n\n");
}


void printGTTSites(chunksSeqVec *set, dsSitesStruct *dsSites, PWM_Struct *pwm, char *setName, FILE *output)
     /*
       Prints the dataset-wide non-overlapping sites scoring >= pwm->siteThreshold[0]
       (computed first if needed by findDsSitesGTT), preceded by a header line with their count.
     */
{
  int k;

  findDsSitesGTT(*set, *pwm, pwm->siteThreshold[0], dsSites);

  fprintf(output, "\n\n");
  fprintf(output, "The following %d scores come from the %s\n", dsSites->slctdSites.aboveT.len, setName);
  for (k = 0; k < dsSites->slctdSites.aboveT.len; k++) {
    siteStruct s = dsSites->slctdSites.aboveT.entry[k];
    printSite(set, s.seqID, s.chunkID, s.leftPos, s.rev, s.score, pwm->width, output);
  }
  fprintf(output, "\n\n");
}


void set_printSitesTopK(char *Kstr)
     // Sets Kprint, the number of top sites printed per sequence, from Kstr (parsed by my_atol)
{
  Kprint = (int) my_atol(Kstr);
}


void printPerSeqSites(chunksSeqVec *set, dsSitesStruct *dsSites, PWM_Struct *pwm, char *setName, FILE *output, Boolean GTTflag)
     /*
       Prints, sequence by sequence, either the non-overlapping sites scoring >= pwm->siteThreshold[0]
       (GTTflag TRUE) or the top Kprint non-overlapping sites (GTTflag FALSE) of each sequence in *set.
       Each sequence's sites are scored, selected, printed and discarded before moving on, so this
       is much more memory efficient than printTopKSites/printGTTSites, which keep the sites of the
       entire set. dsSites is neither read nor modified.
     */
{
  unsigned char **selected;  // modeled after the set, =1 if location is selected
  siteStructVec seqSites, chosenSites;  // not named printSites: that would shadow the function
  // Zero-initialized: selectGTTsites/selectTopKsites begin by FREE_ATOMS_VEC-ing the previous
  // selection, which must never see an uninitialized entry pointer.
  slctdSitesStruct slctdSites = {0};
  int iSeq, iSite;
  siteStruct curSite;

  fprintf(output, "\n\n");
  fprintf(output, "The following scores come from the %s\n", setName);

  selected = allocBoolSeqVec(*set);      // initialize all locations to 0 (a bit wasteful)
  for (iSeq = 0; iSeq < set->len; iSeq++) {
    seqSites = gen_seqSitesScores(set->entry[iSeq], *pwm, iSeq);
    qsort(seqSites.entry, seqSites.len, sizeof(siteStruct),
          (int (*)(const void *, const void *)) siteStructCompare); // descending order
    if (GTTflag) {
      selectGTTsites(*set, pwm->width, pwm->siteThreshold[0], seqSites, selected, &slctdSites);
      chosenSites = slctdSites.aboveT;
    }
    else {
      selectTopKsites(*set, pwm->width, Kprint, seqSites, selected, &slctdSites);
      chosenSites = slctdSites.topK;
    }
    FREE_ATOMS_VEC(seqSites);  // chosen sites are copies, so the scores can go now
    for (iSite = 0; iSite < chosenSites.len; iSite++) {
      curSite = chosenSites.entry[iSite];
      printSite(set, curSite.seqID, curSite.chunkID, curSite.leftPos, curSite.rev, curSite.score, pwm->width, output);
    }
    if (chosenSites.len > 0)
      fprintf(output, "\n");
  }
  fprintf(output, "\n\n\n");
  freeBoolSeqVec(set->len, selected);
  FREE_ATOMS_VEC(slctdSites.aboveT);
  FREE_ATOMS_VEC(slctdSites.topK);
}
