/*
  This module provides the motif scanning methods. It contains the following segments:
  - compute the scores of all sites
  - select the top K non-overlapping sites (overlap as defined by the module variable overlapT)
    caveat: the selection is greedy, so it may have to settle for fewer than K sites
  - select the non-overlapping sites scoring >= T (also greedy)
*/


#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <assert.h>
#include <float.h>
#include <math.h>
#include "my_types.h"
#include "data_interface.h"
#include "misc_functions.h"
#include "pwms.h"
#include "sitePrtctdBtstrp.h"
#include "markov.h"
#include "motif_scan.h"


#define numStableBlockSize 10000		// recompute word likelihood from scratch every so many characters, for numerical stability

const double INF = 1e10;	// finite stand-in for infinity: site scores <= -INF mark positions where no site can be placed

// Forward declarations of the helpers defined below.
static siteStructVec gen_seqSitesScores(const gappedSeq seq, const motif_Struct motif, int seqID);
int get_gappedSeqMotifScores(const gappedSeq seq, const motif_Struct motif, const int scanDirection, const int seqID, 
		siteStruct *seqSites, const Boolean onlyCountGTT, const double T);
static void get_gappedSeqMotifScoresHelper (int seqID, const gappedSeq seq, const motif_Struct motif, const realVec seqSiteScores[], 
		const Boolean isRC, const Boolean onlyCountGTT, const double T, siteStruct *seqSites, int* nSites);
static void get_motifSeqScoreRecur(const motif_Struct motif, const realVec seqSiteScores[], int gaps[], int pwmIdx, int currentMotif, int currentMotifSpan, 
			double bestSeqSiteScores[], int bestSeqSiteMotifs[], int bestSeqSiteMotifSpans[]);
static void get_chunkPWMScores (const iLetterVec chunk, const PWM_Struct pwm, const realVec wordNullScores, realVec siteScores);

static void trim_K(slctdSitesStruct *slctdSites, int K);
static unsigned char **perSeqInit(gappedSeqVec set, motif_Struct motif, dsSitesStruct *dsSites);
static int selectTopKsites(gappedSeqVec set, int K, siteStructVec allSites, unsigned char **selected, slctdSitesStruct *slctdSites);
static int selectGTTsites(gappedSeqVec set, double T, siteStructVec allSites, unsigned char **selected, slctdSitesStruct *slctdSites);
static Boolean noOverlap(siteStruct site, gappedSeq seq, unsigned char *selected, double overlapT);
static void free_slctdSitesStruct(slctdSitesStruct *s);
static unsigned char **allocBoolSeqVec(gappedSeqVec set);
static void freeBoolSeqVec(int nSeqs, unsigned char **bool);
static void set_seqSite(int seqID, int chunkID, int leftChunkOffset, double siteScore, int motifConfig, int motifSpan, 
						Boolean isRC, Boolean onlyCountGTT, double T, siteStruct *seqSites, int *nSites);
static void get_nullWordsScoreFromPosScores(realVec posNullScores, int w, realVec wordNullScores);


static Boolean avg2StrandsNull;	 // Is the null score of a site the Markov LLR of the scanned strand (0) or the average over both
							 	 // strands (1)? The latter is needed to ensure consistency between the PWM and its RC.
int scanDirection;     // Overall scanning direction: 0 = forward only, 1 = reverse complement only, 2 = both strands. Other modules access it through extern
static double overlapT = 0;       // Tolerated overlap between selected sites, as a fraction of the motif span (see noOverlap)
MarkovModelType nullMarkovModel; // Null Markov model used for scoring here. Other modules can access it directly (extern)
static Boolean printTopK = FALSE;  // set to TRUE if we should print the top K sites from each sequence
static Boolean printGTT = FALSE;   // set to TRUE if we should print the sites scoring >= threshold (GTT) from each sequence
static int Kprint;                 // "K" used when printing the top K sites from each sequence


// ----------- Finding the scores of all sites -----------------


void get_dsSitesScores(gappedSeqVec set, motif_Struct motif, dsSitesStruct *dsSites)
     /* 
	Updates dsSites, if necessary, to include the scores of all the motif sites in the given set.
	Per-sequence sites are calculated first (each list sorted by descending score) and then
	concatenated and re-sorted (descending) into the dataset-wide list dsSites->allSites.
	Nothing is recomputed when dsSites->allSites is already non-empty.
	Allocation failure terminates the program. The allocations are kept out of assert() so
	that they still take place when NDEBUG is defined.
     */
{
  int iseq, nSites=0, iSite=0, iSeqSite;

  if (dsSites->allSites.len > 0) // already done
    return;

  // Else, first find the scores in each sequence
  if (dsSites->perSeqSites.len == 0 && set.len > 0) { // need to allocate the per-sequence array
    dsSites->perSeqSites.entry = (void *) calloc(set.len, sizeof(siteStructVec));
    if (dsSites->perSeqSites.entry == NULL) {
      fprintf(stderr, "get_dsSitesScores: out of memory (per-sequence sites)\n");
      exit(EXIT_FAILURE);
    }
    dsSites->perSeqSites.len = set.len;
  }
  for (iseq = 0; iseq < set.len; iseq++) {
    // A previous call that found no site at all leaves allSites empty, so we get here again.
    // Release the old per-sequence list before overwriting it instead of leaking it
    // (freshly calloc'ed entries are all-zero, so this is harmless for them).
    FREE_ATOMS_VEC(dsSites->perSeqSites.entry[iseq]);
    dsSites->perSeqSites.entry[iseq] = gen_seqSitesScores(set.entry[iseq], motif, iseq);
    nSites += dsSites->perSeqSites.entry[iseq].len; // keep track of the total number of sites
    if (dsSites->perSeqSites.entry[iseq].len > 1)  // nothing to sort otherwise (and entry may be NULL)
      qsort((void *) dsSites->perSeqSites.entry[iseq].entry, dsSites->perSeqSites.entry[iseq].len, sizeof(siteStruct), (void *) siteStructCompare); // descending order
  }

  // Next, concatenate all scores for the dataset-wide scores
  FREE_ATOMS_VEC(dsSites->allSites);
  dsSites->allSites.entry = NULL;
  dsSites->allSites.len = nSites;
  if (nSites > 0) { // calloc(0) may legally return NULL, so only allocate when there is something to hold
    dsSites->allSites.entry = (void *) calloc(nSites, sizeof(siteStruct));
    if (dsSites->allSites.entry == NULL) {
      fprintf(stderr, "get_dsSitesScores: out of memory (dataset sites)\n");
      exit(EXIT_FAILURE);
    }
  }
  for (iseq = 0; iseq < set.len; iseq++)
    for (iSeqSite = 0; iSeqSite < dsSites->perSeqSites.entry[iseq].len; iSeqSite++)
      dsSites->allSites.entry[iSite++] = dsSites->perSeqSites.entry[iseq].entry[iSeqSite];
  if (nSites > 1)
    qsort((void *) dsSites->allSites.entry, nSites, sizeof(siteStruct), (void *) siteStructCompare); // descending order
}


/* Begin new functions by Anand */

/*
 * Find all sites in the sequence seq.
 * Returns a vector with at most one site per start position per scanned strand, in scan order
 * (not sorted). The caller is responsible for freeing the returned entry array.
 * Allocation failure terminates the program.
 */
siteStructVec gen_seqSitesScores(const gappedSeq seq, const motif_Struct motif, int seqID)	{
	siteStructVec seqSites;
	// scanning both strands (scanDirection == 2) can yield up to twice as many sites as positions
	size_t capacity = (size_t) seq.body.len * (scanDirection < 2 ? 1 : 2);

	// request at least one element: calloc(0) may legally return NULL, which would look like a failure
	seqSites.entry = calloc(capacity > 0 ? capacity : 1, sizeof(siteStruct));
	if (seqSites.entry == NULL)	{
		fprintf(stderr, "gen_seqSitesScores: out of memory for sequence %d\n", seqID);
		exit(EXIT_FAILURE);
	}
	seqSites.len = get_gappedSeqMotifScores(seq, motif, scanDirection, seqID, seqSites.entry, FALSE, 0);
	return seqSites;
}


/*
 * Wrapper around get_gappedSeqMotifScores for an ungapped sequence: seq is treated as a gappedSeq
 * made of a single chunk that covers all of seq.
 * Returns the number of sites written to seqSites (or, when onlyCountGTT is TRUE, the number of
 * sites scoring >= T). The recorded sites carry chunk index 0; the chunkID argument is not used.
 */
int get_chunkSites(const iLetterVec seq, const motif_Struct motif, const int scanDirection, const int seqID, const int chunkID, 
		siteStruct *seqSites, const Boolean onlyCountGTT, const double T)	{
	// The one-element chunk tables are only read during the call below, so local storage is
	// enough and no (unchecked) heap allocation is needed.
	int chunkOffset = 0;
	int chunkLen = seq.len;
	gappedSeq gSeq;

	(void) chunkID;	// kept for interface compatibility
	memset(&gSeq, 0, sizeof gSeq);	// leave any other gappedSeq fields in a defined state
	gSeq.body = seq;
	gSeq.numChunks = 1;
	gSeq.chunkOffsets = &chunkOffset;
	gSeq.chunkLens = &chunkLen;

	return get_gappedSeqMotifScores(gSeq, motif, scanDirection, seqID, seqSites, onlyCountGTT, T);
}

/* 
 * Takes in a gappedSeq sequence (a body split into chunks described by chunkOffsets/chunkLens) and a
 * motif (one or more pwms; between consecutive pwms a gap of minNextGap..maxNextGap positions is
 * allowed), and scores every motif placement on the strand(s) chosen by scanDirection
 * (0 = forward only, 1 = reverse complement only, 2 = both).
 * Each pwm is only scored at start positions where it fits inside a single chunk. Its score is the
 * pwm score (sitePwmScore) minus the null Markov score of the same word. The null score is averaged
 * over both strands when avg2StrandsNull is set.
 * For each start position only the best-scoring gap configuration is kept.
 * If onlyCountGTT is FALSE, every valid placement is written to seqSites regardless of T. seqSites
 * needs room for seq.body.len entries per scanned strand. If onlyCountGTT is TRUE, nothing is
 * written and only the sites scoring >= T are counted.
 * Returns the number of sites written (or counted).
 */
int get_gappedSeqMotifScores(const gappedSeq seq, const motif_Struct motif, const int scanDirection, const int seqID, 
							siteStruct *seqSites, const Boolean onlyCountGTT, const double T)	{ 
	int numChunks = seq.numChunks, nPWMs = motif.nPWMs;
	int seqLen = seq.body.len;
	int pwmIdx, chunkIdx, pos, i;
	
	int nSites = 0;
	
	// per-pwm site scores over the whole body, for the forward and the reverse-complement strand
	realVec seqSiteScores[nPWMs], rcSeqSiteScores[nPWMs];
	
	int doForward = (scanDirection == 0 || scanDirection == 2);
	int doReverse = (scanDirection == 1 || scanDirection == 2);
	
	// rcSeq is needed to scan the reverse strand and/or to average the null scores over both strands.
	// It is only initialized (and only used) under this same condition.
	gappedSeq rcSeq;
	if (doReverse || avg2StrandsNull)	{	//make reverse complement of seq
		rcSeq = getSeqReverseComplement(seq);
		assert(rcSeq.numChunks == seq.numChunks);
		assert(rcSeq.body.len == seq.body.len);
	}
	
	// every position starts at -INF (no valid site); the chunk loop below overwrites each chunk's positions
	for (pwmIdx = 0; pwmIdx < nPWMs; pwmIdx++)	{
		if (doForward)	{
			seqSiteScores[pwmIdx] = alloc_realVec(seqLen);
			for (pos = 0; pos < seqLen; pos++)	seqSiteScores[pwmIdx].entry[pos] = -INF;
		}
		if (doReverse)	{
			rcSeqSiteScores[pwmIdx] = alloc_realVec(seqLen);
			for (pos = 0; pos < seqLen; pos++)	rcSeqSiteScores[pwmIdx].entry[pos] = -INF;
		}
	}
	
	//get null scores for each position of each chunk
	for (chunkIdx = 0; chunkIdx < numChunks; chunkIdx++)	{
		iLetterVec chunk = { seq.chunkLens[chunkIdx], seq.body.entry + seq.chunkOffsets[chunkIdx] };	// view into seq.body (no copy)
		iLetterVec rcChunk;
		
		realVec posNullScores, rcPosNullScores;
		realVec wordNullScores, rcWordNullScores;
		
		if (doForward || avg2StrandsNull)	{	//get forward null scores for each position of each chunk
			posNullScores = getPosMarkovLklhds(nullMarkovModel, chunk);
			wordNullScores = alloc_realVec(chunk.len);
		}
		
		if (doReverse || avg2StrandsNull)	{	//get reverse null scores for each position of each chunk
			// chunk chunkIdx of seq corresponds to chunk numChunks-1-chunkIdx of its reverse complement
			rcChunk.len = rcSeq.chunkLens[numChunks - 1 - chunkIdx];
			rcChunk.entry = rcSeq.body.entry + rcSeq.chunkOffsets[numChunks - 1 - chunkIdx];
			
			assert(chunk.len == rcChunk.len);	//sanity check that rcChunk is the RC of chunk
			
			rcPosNullScores = getPosMarkovLklhds(nullMarkovModel, rcChunk);
			rcWordNullScores = alloc_realVec(chunk.len);
		}
				
		//get null score and site score for each pwm of this chunk
		for (pwmIdx = 0; pwmIdx < nPWMs; pwmIdx++)	{
			PWM_Struct pwm = motif.pwms[pwmIdx];
			if (doForward || avg2StrandsNull)	{	// forward strand null word scores
		      	get_nullWordsScoreFromPosScores(posNullScores, pwm.width, wordNullScores);
			}

		    if (doReverse || avg2StrandsNull)	{	// reverse strand null word scores
				get_nullWordsScoreFromPosScores(rcPosNullScores, pwm.width, rcWordNullScores);
			}

		    if (avg2StrandsNull)	{	// average the null score over both strands
				// The forward word starting at i is the reverse complement of the rcChunk word starting at
				// chunk.len - i - pwm.width; both get the average of their two null scores.
				for (i = 0; i < chunk.len - pwm.width + 1; i++) {
					wordNullScores.entry[i] = (wordNullScores.entry[i] + rcWordNullScores.entry[chunk.len - i - pwm.width]) / 2;
					rcWordNullScores.entry[chunk.len - i - pwm.width] = wordNullScores.entry[i];
				}
			}
			
			
			//score the sites
			if (doForward)	{
				// siteScores is a view of seqSiteScores[pwmIdx] restricted to this chunk
				realVec siteScores = { chunk.len, seqSiteScores[pwmIdx].entry + seq.chunkOffsets[chunkIdx] };
				get_chunkPWMScores(chunk, motif.pwms[pwmIdx], wordNullScores, siteScores);
			}
			if (doReverse)	{
				// rcSiteScores is a view of rcSeqSiteScores[pwmIdx] restricted to the matching rc chunk
				realVec rcSiteScores = { rcChunk.len, rcSeqSiteScores[pwmIdx].entry + rcSeq.chunkOffsets[numChunks - 1 - chunkIdx] };
				get_chunkPWMScores(rcChunk, motif.pwms[pwmIdx], rcWordNullScores, rcSiteScores);
			}
		}
		
		if (doForward || avg2StrandsNull)	{
			FREE_ATOMS_VEC(posNullScores);
			FREE_ATOMS_VEC(wordNullScores);
		}
		
		if (doReverse || avg2StrandsNull)	{
			FREE_ATOMS_VEC(rcPosNullScores);
			FREE_ATOMS_VEC(rcWordNullScores);
		}
	}
	
	// combine the per-pwm scores into motif scores and record (or count) the sites of each scanned strand
	if (doForward)	{
		get_gappedSeqMotifScoresHelper (seqID, seq, motif, seqSiteScores, FALSE, onlyCountGTT, T, seqSites, &nSites);
	}
	if (doReverse)	{
		get_gappedSeqMotifScoresHelper (seqID, rcSeq, motif, rcSeqSiteScores, TRUE, onlyCountGTT, T, seqSites, &nSites);
	}
	
	//free the vectors which recorded the seq and chunk site scores
	for (pwmIdx = 0; pwmIdx < nPWMs; pwmIdx++)	{
		if (doForward)	FREE_ATOMS_VEC(seqSiteScores[pwmIdx]);
		if (doReverse)	FREE_ATOMS_VEC(rcSeqSiteScores[pwmIdx]);
	}
	
	if (doReverse || avg2StrandsNull)	{	//	free rcSeq
		FREE_ATOMS_VEC(rcSeq.body);
		free(rcSeq.chunkLens);
		free(rcSeq.chunkOffsets);
	}
	
	return nSites;
}


/* Helper function for get_gappedSeqMotifScores. Records the sites of one "strand" of the sequence seq.
 * When isRC is TRUE, seq is the reverse complement.
 * For every start position it takes the best motif placement over all allowed gap configurations,
 * computed by get_motifSeqScoreRecur from the precomputed per-pwm scores in seqSiteScores (which cover
 * the whole sequence). Each valid placement is passed to set_seqSite, which appends it to seqSites
 * (or only counts it) and advances *nSites. Reverse-strand placements are mapped back to a
 * (chunk, left offset) pair on the original forward sequence before being recorded.
 */
void get_gappedSeqMotifScoresHelper (int seqID, const gappedSeq seq, const motif_Struct motif, const realVec seqSiteScores[], 
									const Boolean isRC, const Boolean onlyCountGTT, const double T,
									siteStruct *seqSites, int *nSites)	{
	int numChunks = seq.numChunks, nPWMs = motif.nPWMs;
	int seqLen = seq.body.len;
	int chunkIdx, pos;
	
	// per start position: best motif score, its gap configuration code and its total span
	double* bestSeqSiteScores = calloc(seqLen, sizeof(double));
	int* bestSeqSiteMotifs = calloc(seqLen, sizeof(int));
	int* bestSeqSiteMotifSpans = calloc(seqLen, sizeof(int));
	
	// -INF score / -1 configuration / -1 span mark "no valid placement starts here"
	for (pos = 0; pos < seqLen; pos++)	bestSeqSiteScores[pos] = -INF;	//set scores at each position to -INF
	memset(bestSeqSiteMotifs, 255, seqLen * sizeof(int));	//set to -1 (all bytes 0xFF)
	memset(bestSeqSiteMotifSpans, 255, seqLen * sizeof(int));	//set to -1 (all bytes 0xFF)
	
	int gaps[nPWMs];	//needed for the recursive function call below
	gaps[nPWMs - 1] = 0;	//sentinel: no gap after the last pwm
	get_motifSeqScoreRecur(motif, seqSiteScores, gaps, 0, 0, 0, 
						   bestSeqSiteScores, bestSeqSiteMotifs, bestSeqSiteMotifSpans);

	for (chunkIdx = 0; chunkIdx < numChunks; chunkIdx++)	{
		int chunkOffset = seq.chunkOffsets[chunkIdx];
		int absPos = chunkOffset;	// position within seq.body; pos is the offset within the chunk
		// a positive span means a complete motif placement starts at absPos
		for (pos = 0; pos < seq.chunkLens[chunkIdx]; pos++, absPos++)	if (bestSeqSiteMotifSpans[absPos] > 0)	{
			if (!isRC)	{
				//printf("forward hit, tot seq len = %d, numChunks = %d, absPos = %d, chunkIdx = %d, span = %d, score = %lf\n", 
				//		seq.body.len, seq.numChunks, absPos, chunkIdx, bestSeqSiteMotifSpans[absPos], bestSeqSiteScores[absPos]);
				
				set_seqSite(seqID, chunkIdx, pos, bestSeqSiteScores[absPos], bestSeqSiteMotifs[absPos], bestSeqSiteMotifSpans[absPos], isRC, onlyCountGTT, T, seqSites, nSites);
			}
			else	{
				// The site is reported by its left end on the forward strand, which is the (exclusive)
				// right end of the placement on this reverse-complement strand.
				//find which chunk the end of the motif falls in the original sequence
				int endPos = absPos + bestSeqSiteMotifSpans[absPos];
				//printf("reverse hit, tot seq len = %d, numChunks = %d, absPos = %d, endPos = %d, chunkIdx = %d, span = %d, score = %lf\n", 
				//		seq.body.len, seq.numChunks, absPos, endPos, chunkIdx, bestSeqSiteMotifSpans[absPos], bestSeqSiteScores[absPos]);

				int trueChunkIdx = chunkIdx;
				while (endPos > seq.chunkOffsets[trueChunkIdx] + seq.chunkLens[trueChunkIdx])	trueChunkIdx++;
				//endPos falls in chunk trueChunkIdx
				assert(seq.chunkOffsets[trueChunkIdx] < endPos && endPos <= (seq.chunkOffsets[trueChunkIdx] + seq.chunkLens[trueChunkIdx]));
				
				// mirror the end offset within the rc chunk to get the left offset within the forward chunk;
				// rc chunk trueChunkIdx corresponds to forward chunk numChunks-1-trueChunkIdx
				int truePos = seq.chunkLens[trueChunkIdx] - (endPos - seq.chunkOffsets[trueChunkIdx]);
				set_seqSite(seqID, numChunks - 1 - trueChunkIdx, truePos, bestSeqSiteScores[absPos], bestSeqSiteMotifs[absPos], bestSeqSiteMotifSpans[absPos], isRC, onlyCountGTT, T, seqSites, nSites);
			}
		}
	}					
						
	//free allocated arrays
	free(bestSeqSiteScores);
	free(bestSeqSiteMotifs);
	free(bestSeqSiteMotifSpans);		
}

/* Recursively tries all possible gaps between the pwms of the motif, and scores each position 
 * of the sequence based on that.
 * gaps[k] receives the gap chosen after pwm k. The caller sets the entry after the last pwm to 0.
 * pwmIdx is the pwm whose following gap is chosen at this level of the recursion.
 * Once all gaps are fixed (pwmIdx == nPWMs - 1), the motif score at each start position is the sum
 * of the pwm scores at their offsets. It replaces the stored best only if it is strictly higher.
 * Positions where some pwm score is <= -INF, or where the motif would run past the end of the
 * sequence, are left untouched.
 * Puts the best found motif score, configuration and span in the last 3 argument vectors.
 * The set of gaps being tried is encoded in the integer currentMotif.
 * If n = #PWM's, g_0, g_1, ..., g_{n-2} are the gaps chosen after each of the pwms, and 
 * m_0, m_1, ..., m_{n-2} are the min allowed gaps after each of the pwms, and M_0, M_1, ..., M_{n-2} 
 * are the max allowed gaps after each of the PWM's, then the encoding for this set of gaps is given by:
 * (((g_0 - m_0) * (M_1 - m_1 + 1) + (g_1 - m_1))*(M_2 - m_2 + 1) + ...)*(M_{n-2} - m_{n-2} + 1) + (g_{n-2} - m_{n-2})
 * This is an optimal (mixed-radix) representation, and the code ranges over all integers from 
 * 0 to (M_0 - m_0 + 1)*(M_1 - m_1 + 1)...*(M_{n-2} - m_{n-2} + 1) - 1 (just 0 when n = 1).
 * currentMotifSpan accumulates the widths and gaps chosen so far. The total span of the motif is
 * what gets recorded.
 */
void get_motifSeqScoreRecur(const motif_Struct motif, const realVec seqSiteScores[], int gaps[], int pwmIdx, int currentMotif, int currentMotifSpan, 
							double bestSeqSiteScores[], int bestSeqSiteMotifs[], int bestSeqSiteMotifSpans[])	
{
	if (pwmIdx == motif.nPWMs - 1)	{	// all gaps chosen: score this configuration at every position
		currentMotifSpan += motif.pwms[pwmIdx].width;	// total span = all widths + all chosen gaps
		int seqLen = seqSiteScores[pwmIdx].len;
		
		int pos;
		for (pos = 0; pos < seqLen; pos++)	{	//score the motif on pos
			double score = 0.0;
			int offset = 0;
			// pwmIdx is reused here as the loop counter; offset jumps over each pwm's width plus its gap
			for (pwmIdx = 0; pwmIdx < motif.nPWMs && pos + offset < seqLen; pwmIdx++)	{
				if (seqSiteScores[pwmIdx].entry[pos + offset] <= -INF)	{
					break;	// this pwm cannot be placed here
				}
				score += seqSiteScores[pwmIdx].entry[pos + offset];
				offset += motif.pwms[pwmIdx].width + gaps[pwmIdx];
			}
			//printf("pos = %d, pwmIdx = %d, motif.nPWMs = %d, pos+offset = %d, seqLen = %d, score = %lf\n", pos, pwmIdx, motif.nPWMs, (pos+offset), seqLen, score);
			// every pwm placed, and the whole motif (offset now equals its span) fits in the sequence
			if (pwmIdx >= motif.nPWMs && pos + offset <= seqLen && score > bestSeqSiteScores[pos])	{
				bestSeqSiteScores[pos] = score;
				bestSeqSiteMotifs[pos] = currentMotif;
				bestSeqSiteMotifSpans[pos] = currentMotifSpan;
			}
		}
		return;
	}
	
	// shift the digits chosen so far by this pwm's number of gap choices (mixed-radix encoding)
	currentMotif *= (motif.pwms[pwmIdx].maxNextGap - motif.pwms[pwmIdx].minNextGap + 1);
	
	int g;
	for (g = motif.pwms[pwmIdx].minNextGap; g <= motif.pwms[pwmIdx].maxNextGap; g++)	{
		gaps[pwmIdx] = g;
		get_motifSeqScoreRecur (motif, seqSiteScores, gaps, pwmIdx+1, currentMotif, currentMotifSpan + g + motif.pwms[pwmIdx].width, 
								bestSeqSiteScores, bestSeqSiteMotifs, bestSeqSiteMotifSpans);
		currentMotif++;	// next digit value (g - minNextGap) for this pwm
	}
}

/** Scores every start position of a non-gapped chunk with a pwm, using precomputed null word scores
 *  whose word length equals the pwm width.
 *  siteScores.entry[i] = pwm score of the word starting at i minus wordNullScores.entry[i] wherever the
 *  whole pwm fits in the chunk; the remaining trailing positions are set to -INF. **/
void get_chunkPWMScores (const iLetterVec chunk, const PWM_Struct pwm, const realVec wordNullScores, realVec siteScores)	{
	const int nFit = chunk.len - pwm.width + 1;	// number of start positions where the pwm fits
	int start = 0;

	while (start < nFit)	{
		siteScores.entry[start] = sitePwmScore(chunk.entry + start, pwm.logFreqs, pwm.width) - wordNullScores.entry[start];
		start++;
	}

	// the pwm would run past the chunk end from here on
	for (start = (nFit > 0) ? nFit : 0; start < chunk.len; start++)
		siteScores.entry[start] = -INF;
}


// Records, or merely counts, one motif site.
// The site lies in sequence seqID, chunk chunkID, starting at offset leftChunkOffset within that chunk.
// onlyCountGTT == TRUE : nothing is stored; *nSites is incremented only if siteScore >= T.
// onlyCountGTT == FALSE: the site is stored at seqSites[*nSites] (whatever its score) and *nSites is incremented.
void set_seqSite(int seqID, int chunkID, int leftChunkOffset, double siteScore, int motifConfig, int motifSpan, 
				 Boolean isRC, Boolean onlyCountGTT, double T, siteStruct *seqSites, int *nSites)	{
	siteStruct *slot;

	if (onlyCountGTT)	{
		if (siteScore >= T)
			++*nSites;
		return;
	}

	slot = seqSites + *nSites;
	slot->seqID = seqID;
	slot->chunkID = chunkID;
	slot->leftChunkOffset = leftChunkOffset;
	slot->rev = isRC;
	slot->motifSpan = motifSpan;
	slot->motifConfig = motifConfig;
	slot->score = siteScore;
	++*nSites;
}

/** End new functions by Anand **/


void get_nullWordsScoreFromPosScores(realVec posNullScores, int w, realVec wordNullScores)
// Computes the null score of every length-w word from the per-position null scores:
// wordNullScores.entry[k] = posNullScores.entry[k] + ... + posNullScores.entry[k+w-1], for k = 0 .. posNullScores.len - w.
// A sliding window is used, but the sum is rebuilt from scratch every numStableBlockSize words
// to keep accumulated rounding error bounded.
{
  const int nWords = posNullScores.len - w + 1;
  int start, k;

  for (start = 0; start < nWords; start++) {
    if (start % numStableBlockSize != 0) {  // slide the window one position to the right
      wordNullScores.entry[start] = wordNullScores.entry[start-1] - posNullScores.entry[start-1] + posNullScores.entry[start+w-1];
      continue;
    }
    // first word of a new block: sum the window explicitly
    wordNullScores.entry[start] = 0;
    for (k = 0; k < w; k++)
      wordNullScores.entry[start] += posNullScores.entry[start+k];
  }
}


// ----------- Finding the top K scoring sites -----------------


int findDsTopK(gappedSeqVec set, motif_Struct motif, int K, dsSitesStruct *dsSites)
     /*
       Updates, if necessary, the dataset-wide "top K sites" entry of dsSites for the given dataset and motif.
       A cached selection made for some K' >= K is trimmed down to K; otherwise the selection is
       recomputed from scratch.
       Returns the actual number of sites selected (possibly fewer than K).
     */
{
  unsigned char **occupied;  // one flag vector per sequence, 1 where a selected site lies

  if (dsSites->slctdSites.K < K) {  // cache too small: starting from scratch isn't too costly and is cleaner
    get_dsSitesScores(set, motif, dsSites);  // the scores of all sites are needed first
    occupied = allocBoolSeqVec(set);         // all locations start unselected
    selectTopKsites(set, K, dsSites->allSites, occupied, &(dsSites->slctdSites));
    freeBoolSeqVec(set.len, occupied);
  }
  else
    trim_K(&(dsSites->slctdSites), K);       // drop surplus sites if the cache was built for a larger K

  return dsSites->slctdSites.topK.len;
}


void findPerSeqTopK(gappedSeqVec set, motif_Struct motif, int *Ks, dsSitesStruct *dsSites)
     /*
       Updates, if necessary, the "top K sites per sequence" entries of dsSites for the given dataset and motif.
       Ks[s] is the K to use for sequence s (Ks needs at least set.len entries).
       A sequence whose cached selection was made for K' >= Ks[s] is trimmed; the others are recomputed.
     */
{
  unsigned char **occupied = perSeqInit(set, motif, dsSites);  // per-sequence flags, 1 where a selected site lies
  int s;

  for (s = 0; s < set.len; s++) {
    slctdSitesStruct *seqSlctd = dsSites->perSeqSlctdSites.entry + s;

    if (seqSlctd->K < Ks[s])
      selectTopKsites(set, Ks[s], dsSites->perSeqSites.entry[s], occupied, seqSlctd);
    else
      trim_K(seqSlctd, Ks[s]);  // drop surplus sites if needed
  }

  freeBoolSeqVec(set.len, occupied);
}


void trim_K(slctdSitesStruct *slctdSites, int K)
     // Shrinks a cached top-K' selection (K' = slctdSites->K) to at most K sites when K < K'.
     // Does nothing when K >= K'.
{
  int keep;

  if (K >= slctdSites->K)
    return;

  keep = (K < slctdSites->topK.len) ? K : slctdSites->topK.len;
  slctdSites->K = K;
  slctdSites->topK.entry = (void *) my_realloc(slctdSites->topK.entry, keep*sizeof(siteStruct));
  slctdSites->topK.len = keep;
}


unsigned char **perSeqInit(gappedSeqVec set, motif_Struct motif, dsSitesStruct *dsSites)
     /*
       Common preparation for the per-sequence selections:
       - allocates dsSites->perSeqSlctdSites (one entry per sequence) if it is not allocated yet,
         with every T set to DBL_MAX so that no real threshold matches the initial value;
       - makes sure the scores of all sites are available (get_dsSitesScores);
       - returns a new zero-initialized "selected" flag vector per sequence, which the caller
         must release with freeBoolSeqVec(set.len, ...).
       Allocation failure terminates the program. The calloc is kept out of assert() so it is not
       compiled away when NDEBUG is defined.
     */
{
  int iSeq;

  if (dsSites->perSeqSlctdSites.len == 0 && set.len > 0) { // need to allocate it
    dsSites->perSeqSlctdSites.entry = (void *) calloc(set.len, sizeof(slctdSitesStruct));
    if (dsSites->perSeqSlctdSites.entry == NULL) {
      fprintf(stderr, "perSeqInit: out of memory (per-sequence selections)\n");
      exit(EXIT_FAILURE);
    }
    dsSites->perSeqSlctdSites.len = set.len;
    for (iSeq = 0; iSeq < set.len; iSeq++)
      dsSites->perSeqSlctdSites.entry[iSeq].T = DBL_MAX; // make sure T won't randomly match initial value
  }
  get_dsSitesScores(set, motif, dsSites); // make sure we have the scores of all sites first

  return allocBoolSeqVec(set);  // all locations initially unselected (0)
}


int selectTopKsites(gappedSeqVec set, int K, siteStructVec allSites, unsigned char **selected, slctdSitesStruct *slctdSites)
     /*
       Chooses the top K non-overlapping (as defined by the implicit overlapT) sites from allSites using a 
       greedy heuristic: sites are visited in order and kept whenever noOverlap accepts them. noOverlap
       also marks the accepted site's positions in selected.
       allSites is assumed sorted in a decreasing order.
       The result replaces slctdSites->topK, and slctdSites->K is set to K.
       Returns the actual number of sites found (at most K).
       Allocation failure terminates the program.
     */
{
  int iSite, nSelSite=0;
  // At most min(K, allSites.len) sites can be selected. Sizing the buffer by that bound avoids a huge
  // allocation for a large K and a zero-sized calloc, which may legally return NULL.
  int capacity = (K < allSites.len) ? K : allSites.len;

  FREE_ATOMS_VEC(slctdSites->topK);
  slctdSites->topK.entry = NULL;
  if (capacity > 0) { // kept out of assert() so the allocation survives NDEBUG builds
    slctdSites->topK.entry = (void *) calloc(capacity, sizeof(siteStruct));
    if (slctdSites->topK.entry == NULL) {
      fprintf(stderr, "selectTopKsites: out of memory\n");
      exit(EXIT_FAILURE);
    }
  }

  for (iSite = 0; (iSite < allSites.len) && (nSelSite < K); iSite++) {
    siteStruct site = allSites.entry[iSite]; // just for readability
    if ( noOverlap(site, set.entry[site.seqID], selected[site.seqID], overlapT) ) // add new site
      slctdSites->topK.entry[nSelSite++] = site;
  }

  slctdSites->topK.len = nSelSite;
  slctdSites->K = K;
  return nSelSite;
}


// ----------- Finding sites scoring >= T -----------------


int findDsSitesGTT(gappedSeqVec set, motif_Struct motif, double T, dsSitesStruct *dsSites)
     /*
       Updates, if necessary, the dataset-wide "sites scoring >= T" entry of dsSites for the given dataset and motif.
       The cached selection is reused only if it was made for exactly this T; otherwise it is rebuilt
       from scratch (simpler than adjusting a selection made for a bigger/smaller T).
       Returns the number of sites selected across the dataset.
     */
{
  unsigned char **occupied;  // one flag vector per sequence, 1 where a selected site lies

  if ( dsSites->slctdSites.T != T ) {
    get_dsSitesScores(set, motif, dsSites);  // the scores of all sites are needed first
    occupied = allocBoolSeqVec(set);         // all locations start unselected
    selectGTTsites(set, T, dsSites->allSites, occupied, &(dsSites->slctdSites));
    freeBoolSeqVec(set.len, occupied);
  }

  return dsSites->slctdSites.aboveT.len;
}


void findPerSeqSitesGTT(gappedSeqVec set, motif_Struct motif, double T, dsSitesStruct *dsSites)
     /*
       Updates, if necessary, the "per sequence sites scoring >= T" entries of dsSites for the given
       dataset and motif. T is currently shared across all sequences, so only the first sequence's
       cached T is checked.
       Since the selection is greedy, the test == T could be relaxed to <= T by dropping the sites
       that fail the more stringent threshold.
     */
{
  unsigned char **occupied;  // per-sequence flags, 1 where a selected site lies
  int s;

  if ( (dsSites->perSeqSlctdSites.len > 0) && (dsSites->perSeqSlctdSites.entry[0].T == T) )
    return;    // already worked out: all sequences are assumed to share the same T!!!

  occupied = perSeqInit(set, motif, dsSites);
  for (s = 0; s < set.len; s++)
    selectGTTsites(set, T, dsSites->perSeqSites.entry[s], occupied, dsSites->perSeqSlctdSites.entry + s);
  freeBoolSeqVec(set.len, occupied);
}


int selectGTTsites(gappedSeqVec set, double T, siteStructVec allSites, unsigned char **selected, slctdSitesStruct *slctdSites)
     /*
       Chooses the non-overlapping sites scoring >= T from allSites (greedy, in the given order).
       Implicit input: overlapT.
       allSites is assumed sorted in a decreasing order, so the candidates are the prefix of allSites
       that precedes the first site scoring < T.
       The result replaces slctdSites->aboveT, and slctdSites->T is set to T.
       Returns the actual number of sites found.
       Allocation failure terminates the program.
     */
{
  int iSite, nSelSite=0, nCandidates=0;

  // Count the candidates first so the buffer is sized by them rather than by every site in allSites.
  // The test is written as !(score < T) to keep the original stopping rule exactly.
  while (nCandidates < allSites.len && !(allSites.entry[nCandidates].score < T))
    nCandidates++;

  FREE_ATOMS_VEC(slctdSites->aboveT);
  slctdSites->aboveT.entry = NULL;
  if (nCandidates > 0) { // kept out of assert() so the allocation survives NDEBUG builds
    slctdSites->aboveT.entry = (void *) calloc(nCandidates, sizeof(siteStruct));
    if (slctdSites->aboveT.entry == NULL) {
      fprintf(stderr, "selectGTTsites: out of memory\n");
      exit(EXIT_FAILURE);
    }
  }

  for (iSite = 0; iSite < nCandidates; iSite++) {
    siteStruct site = allSites.entry[iSite]; // just for readability
    if ( noOverlap(site, set.entry[site.seqID], selected[site.seqID], overlapT) ) // add new site
      slctdSites->aboveT.entry[nSelSite++] = site;
  }

  if (nSelSite == 0) {
    // Release the buffer outright: realloc to size 0 has implementation-defined results.
    free(slctdSites->aboveT.entry);
    slctdSites->aboveT.entry = NULL;
  }
  else if (nSelSite < nCandidates) // drop the wasted space
    slctdSites->aboveT.entry = (void *) my_realloc(slctdSites->aboveT.entry, nSelSite*sizeof(siteStruct));

  slctdSites->aboveT.len = nSelSite;
  slctdSites->T = T;
  return nSelSite;
}


Boolean noOverlap(siteStruct site, gappedSeq seq, unsigned char *selected, double overlapT)
     /*
       Greedy overlap test for one candidate site.
       The site covers positions [first, first + site.motifSpan) of seq.body, where first is the
       offset of chunk site.chunkID plus site.leftChunkOffset.
       If at most floor(overlapT * motifSpan) of those positions are already marked in selected,
       all of them are marked and TRUE is returned. Otherwise selected is left unchanged and FALSE
       is returned.
       FIXME: How to quantify the overlap threshold if the motif has variable span???
     */
{
  const int first = seq.chunkOffsets[site.chunkID] + site.leftChunkOffset;
  const int end = first + site.motifSpan;
  int nTaken = 0, pos;

  for (pos = first; pos < end; pos++)
    nTaken += selected[pos];

  if (nTaken <= floor(overlapT * site.motifSpan)) {
    for (pos = first; pos < end; pos++)
      selected[pos] = 1;
    return TRUE;
  }
  return FALSE;
}


// ----------- Setting user input -----------------

void setNullTrainSeqFromFile(char *fname)	{
	nullMarkovModel = readMarkovModelFromFile(fname);
}

void setNullTrainSeq(iLetterVec trainData, int order, double pseudoCount, int scanMode)
     // Trains the null Markov model once and for all with trainMarkovModel(trainData, order,
     // pseudoCount, scanMode) and stores it in nullMarkovModel. trainData itself isn't saved.
{
  MarkovModelType trained = trainMarkovModel(trainData, order, pseudoCount, scanMode);
  nullMarkovModel = trained;
}


void set_avg2StrandsNull(Boolean value)
{
  avg2StrandsNull = value;  // FALSE: null score of the scanned strand only; TRUE: average over both strands
}


void set_scanDirection(int value)
{
  scanDirection = value;  // 0 = forward only, 1 = reverse complement only, 2 = both strands
}


void set_overlapThreshold(double value)
{
  overlapT = value;  // tolerated overlap as a fraction of the motif span (see noOverlap)
}


void set_printTopK(Boolean value)
{
  printTopK = value;  // print the top K sites of each sequence?
}


void set_printGTT(Boolean value)
{
  printGTT = value;  // print the sites scoring >= threshold?
}



// --------------- others ----------------
int siteStructCompare(const siteStruct *e1, const siteStruct *e2)
     // qsort comparator giving a DESCENDING order by score: negative when e1 scores higher,
     // 0 when the two scores compare equal, positive otherwise.
{
  if (e1->score > e2->score)
    return -1;
  if (e1->score == e2->score)
    return 0;
  return 1;
}


// ---------------  memory management ------------------


void free_dsSitesStruct(dsSitesStruct *dsSites)
		// Releases every array owned by dsSites: the dataset-wide site list and selections, then the
		// per-sequence site lists and selections (each element first, then its container).
{
	int k;

	FREE_ATOMS_VEC(dsSites->allSites);
	free_slctdSitesStruct(&(dsSites->slctdSites));

	for (k = 0; k < dsSites->perSeqSites.len; k++)
		FREE_ATOMS_VEC(dsSites->perSeqSites.entry[k]);
	FREE_ATOMS_VEC(dsSites->perSeqSites);

	for (k = 0; k < dsSites->perSeqSlctdSites.len; k++)
		free_slctdSitesStruct(&dsSites->perSeqSlctdSites.entry[k]);
	FREE_ATOMS_VEC(dsSites->perSeqSlctdSites);
}

void free_slctdSitesStruct(slctdSitesStruct *s)
	// Releases both selections held by s and resets it to the "nothing computed" state:
	// K = 0 and T = DBL_MAX (a value no real threshold is expected to match).
{
	FREE_ATOMS_VEC(s->topK);
	FREE_ATOMS_VEC(s->aboveT);
	s->T = DBL_MAX;
	s->K = 0;
}


unsigned char **allocBoolSeqVec(gappedSeqVec set)
     /*
       Allocates an array of zero-initialized unsigned char vectors: one vector per sequence of set,
       each as long as that sequence's raw body. This could have been more frugal but it will probably do.
       Release with freeBoolSeqVec(set.len, ...).
       Allocation failure terminates the program. The callocs are kept out of assert() so that they
       still run when NDEBUG is defined. At least one element is always requested, because calloc(0)
       may legally return NULL.
     */
{
  int iSeq;
  unsigned char **flags;  // not named "bool", which is a keyword in C23

  flags = calloc(set.len > 0 ? set.len : 1, sizeof *flags);
  if (flags == NULL) {
    fprintf(stderr, "allocBoolSeqVec: out of memory\n");
    exit(EXIT_FAILURE);
  }
  for (iSeq = 0; iSeq < set.len; iSeq++) {
    int len = set.entry[iSeq].body.len;
    flags[iSeq] = calloc(len > 0 ? len : 1, sizeof **flags);
    if (flags[iSeq] == NULL) {
      fprintf(stderr, "allocBoolSeqVec: out of memory for sequence %d\n", iSeq);
      exit(EXIT_FAILURE);
    }
  }
  return flags;
}


void freeBoolSeqVec(int nSeqs, unsigned char **flags)
     // Frees an array built by allocBoolSeqVec: each of its nSeqs per-sequence vectors, then the array itself.
{
  int s = nSeqs;

  while (s-- > 0)
    free(flags[s]);
  free(flags);
}




// ---------------  printing good sites ------------------

void printSites(gappedSeqVec *set, dsSitesStruct *dsSites, motif_Struct *motif, char *setName, FILE *output)
       // Prints the good sites selected by the printTopK / printGTT flags.
       // When all dataset sites are already scored (dsSites->allSites non-empty), the dataset-wide
       // printers are used. Otherwise each sequence is scanned and printed on its own.
{
  if (printGTT)
    fprintf(output, "                                             threshold = %f\n", motif->siteThreshold[0]);

  if (dsSites->allSites.len <= 0) {  // no dataset-wide scores: use the per-sequence printer
    if (printGTT)
      printPerSeqSites(set, dsSites, motif, setName, output, TRUE);
    if (printTopK)
      printPerSeqSites(set, dsSites, motif, setName, output, FALSE);
    return;
  }

  if (printTopK)
    printTopKSites(set, dsSites, motif, setName, output);
  if (printGTT)
    printGTTSites(set, dsSites, motif, setName, output);
}


void printTopKSites(gappedSeqVec *set, dsSitesStruct *dsSites, motif_Struct *motif, char *setName, FILE *output)
       /*
         Prints the top Kprint sites of every sequence (K as defined by set_printSitesTopK()),
         preceded by a header with the total number of printed sites.
         The per-sequence selections are computed first if needed (findPerSeqTopK).
         Allocation failure terminates the program.
       */
{
       int siteIt, seqIt;
       int totalNumScores = 0;
       siteStruct curSite;

       // just in case the top K sites haven't been found yet
       if (set->len > 0) {
               int *Ks = malloc(set->len * sizeof *Ks);   // kept out of assert() so it survives NDEBUG
               if (Ks == NULL) {
                       fprintf(stderr, "printTopKSites: out of memory\n");
                       exit(EXIT_FAILURE);
               }
               for (seqIt = 0; seqIt < set->len; seqIt++)
                       Ks[seqIt] = Kprint; // every sequence uses the same K
               // one call handles every sequence; repeating it per sequence would only redo the same work
               findPerSeqTopK(*set, *motif, Ks, dsSites);
               free(Ks);
       }

       // find how many scores we are going to print
       for (seqIt = 0; seqIt < set->len; seqIt++)
               totalNumScores += dsSites->perSeqSlctdSites.entry[seqIt].topK.len;

       // print top K sites
       fprintf(output, "\n");
       fprintf(output, "The following %d scores come from the %s\n", totalNumScores, setName);
       for (seqIt = 0; seqIt < set->len; seqIt++)
       {
               for (siteIt = 0; siteIt < dsSites->perSeqSlctdSites.entry[seqIt].topK.len; siteIt++)
               {
                       curSite = dsSites->perSeqSlctdSites.entry[seqIt].topK.entry[siteIt];
                       printSite(set, curSite.seqID, curSite.chunkID, curSite.leftChunkOffset, curSite.rev, curSite.score, *motif, curSite.motifConfig, curSite.motifSpan, output);
               }
               fprintf(output, "\n");
       }
       fprintf(output, "\n\n");
}


void printGTTSites(gappedSeqVec *set, dsSitesStruct *dsSites, motif_Struct *motif, char *setName, FILE *output)
       // Prints the dataset-wide non-overlapping sites scoring >= motif->siteThreshold[0],
       // selecting them first if that has not been done yet (findDsSitesGTT).
{
       const double T = motif->siteThreshold[0];
       int k;

       findDsSitesGTT(*set, *motif, T, dsSites);

       fprintf(output, "\n\n");
       fprintf(output, "The following %d scores come from the %s\n", dsSites->slctdSites.aboveT.len, setName);
       for (k = 0; k < dsSites->slctdSites.aboveT.len; k++) {
               const siteStruct *s = &dsSites->slctdSites.aboveT.entry[k];
               printSite(set, s->seqID, s->chunkID, s->leftChunkOffset, s->rev, s->score, *motif, s->motifConfig, s->motifSpan, output);
       }
       fprintf(output, "\n\n");
}


void set_printSitesTopK(char *Kstr)
{
       // Kprint is the number of top sites printed per sequence; Kstr is converted with my_atol.
       Kprint = my_atol(Kstr);
}


void printPerSeqSites(gappedSeqVec *set, dsSitesStruct *dsSites, motif_Struct *motif, char *setName, FILE *output, Boolean GTTflag)
     /*
       Print, for each sequence of the input set, either its top Kprint sites (GTTflag == FALSE) or
       its non-overlapping sites scoring >= motif->siteThreshold[0] (GTTflag == TRUE).
       Each sequence is scored, filtered, printed and released before the next one is processed.
       This is much more memory efficient than the dataset-wide print functions, which keep the
       sites of the entire set. dsSites is not read or modified.
     */
{
  unsigned char **selected;  // modeled after the set, =1 if location is selected
  siteStructVec seqSites, chosen;  // "chosen" so the local does not shadow the printSites() function
  int iSeq, iSite;
  slctdSitesStruct slctdSites;
  siteStruct curSite;

  fprintf(output, "\n\n");
  fprintf(output, "The following scores come from the %s\n", setName);

  selected = allocBoolSeqVec(*set);      // initialize all locations to 0 (a bit wasteful)
  // Start from an all-zero selection. The FREE_ATOMS_VEC calls inside selectGTTsites/selectTopKsites,
  // and those at the end of this function, must never see an uninitialized entry pointer.
  memset(&slctdSites, 0, sizeof slctdSites);
  for (iSeq = 0; iSeq < set->len; iSeq++) {
    seqSites = gen_seqSitesScores(set->entry[iSeq], *motif, iSeq);
    if (seqSites.len > 1)
      qsort((void *) seqSites.entry, seqSites.len, sizeof(siteStruct), (void *) siteStructCompare); // descending order
    if (GTTflag) {
      selectGTTsites(*set, *(motif->siteThreshold), seqSites, selected, &slctdSites);
      chosen = slctdSites.aboveT;
    }
    else {
      selectTopKsites(*set, Kprint, seqSites, selected, &slctdSites);
      chosen = slctdSites.topK;
    }
    FREE_ATOMS_VEC(seqSites);  // the selection holds copies of the sites it kept
    for (iSite = 0; iSite < chosen.len; iSite++) {
      curSite = chosen.entry[iSite];
      printSite(set, curSite.seqID, curSite.chunkID, curSite.leftChunkOffset, curSite.rev, curSite.score, *motif, curSite.motifConfig, curSite.motifSpan, output);
    }
    if (chosen.len > 0)
      fprintf(output, "\n");
  }
  fprintf(output, "\n\n\n");
  freeBoolSeqVec(set->len, selected);
  FREE_ATOMS_VEC(slctdSites.aboveT);
  FREE_ATOMS_VEC(slctdSites.topK);
}
