/*
  Manage the PWMs
  This module is not well written. It should be rewritten at some point.
*/

#include <stdlib.h>
#include <stdio.h>
#include <ctype.h>
#include <assert.h>
#include <string.h>
#include <float.h>
#include <math.h>
#include "my_types.h"
#include "data_interface.h"
#include "misc_functions.h"
#include "markov.h"
#include "pwms.h"
#include "motif_scan.h"

static unsigned int motif_ID=0; // a global motif identifier used to associate other data with the motif

#define N_MOTIF_INC 50				/* Incremental number of motifs for which memory is allocated */
#define MAX_HEADER_LENGTH 80	/* Size of the header buffer (leading '>' and terminating NUL included); longer headers are truncated */
#define N_POS_INC 10			/* Incremental number of weight-matrix rows for which memory is allocated */
#define MAXLINE 100				/* Size of the buffer used when reading a line of the motif file */
#define NCLNUM 4				/* Number of entries in every weight-matrix row */
/* Index of entry (pos, let) in a flattened, row-major width x NCLNUM vector.
   Both arguments are parenthesized so that expression arguments such as
   PWM_IND(i + 1, j) expand correctly. */
#define PWM_IND(pos,let) ((pos) * NCLNUM + (let))


static double pwmPC; // total pseudocount; written by set_pwmPC(), read by pseudocount()
          // thresholds related variables (all four statics below are written by set_threshGenerationData())
static double sitePercentThresh = -DBL_MAX; // multiplied by the number of scored training sites in set_siteThreshold(); holds -DBL_MAX until it is set
static unsigned int siteThreshTrainSize; // size passed to genMarkovBlock() when the training block is generated
static int siteThresholdScanDir; // scan direction handed to get_chunkSites(); the value 2 doubles the size of the site buffer
static Boolean nullMarkovBlock; // when true, set_siteThreshold() generates its training data from nullMarkovModel instead of using nullTrainData
extern MarkovModelType nullMarkovModel; // defined in another translation unit
extern iLetterVec nullTrainData; // defined in another translation unit

         // internal methods
char *move_to_next_motif_start(FILE *fp); // skip to the next '>' and return a copy of that header line (NULL at end of file)
int read_wm_into_buffer(FILE *fp, PWM_Struct *pwm, char *line); // read matrix rows into pwm, return the number of rows read
void pseudocount(double **matr,int row,int col); // add the pwmPC-scaled row mean to every entry
void norm_freq(double **matr,int row,int col); // scale every row so that it sums to 1
int word_cnt(char *s); // number of whitespace-separated words in s
void word_split(char *s,double *freq); // convert the words of s to doubles, stored in freq
double *TwoDarray2LogVector(double **array, int width); // newly allocated flat vector holding log(array[i][j])



motifsVec read_motifs(char * wmfile)
     // read the weight matrix from a file
{
	motifsVec motifs={0,NULL};
	FILE *matfile;
	char line[MAXLINE], dots[MAXLINE];

	assert(matfile=fopen(wmfile,"r"));
	int n_motifs=0, n_max_motifs=0;
	motif_Struct *motif;
	PWM_Struct *pwm;
	
	while(!feof(matfile))
	{
		if (n_max_motifs <= n_motifs)
		  assert(motifs.entry = realloc(motifs.entry, (n_max_motifs+=N_MOTIF_INC)*sizeof(motif_Struct)) );
		motif = motifs.entry + n_motifs;     // readability
		memset(motif, 0, sizeof(motif_Struct));
		if ( (motif->header = move_to_next_motif_start(matfile)) == NULL)
		  break;
		motif->nPWMs = 0; motif->pwms = NULL;

		do {
			assert( motif->pwms = realloc(motif->pwms, sizeof(PWM_Struct)*( ++ motif->nPWMs)) );
			pwm =  motif->pwms + ((motif->nPWMs)-1);
			if (read_wm_into_buffer(matfile, pwm, line) == 0) {
			  printf( "Weight matrix %d of motif %s has 0 length!!! Goodbye!\n", motif->nPWMs+1, motif->header);
			  exit(0);
			}
			//motif -> width += pwm->width;
			if (line[0] == '<')
				break;
			if (sscanf(line, "%s %d %d", dots, &(pwm->minNextGap), &(pwm->maxNextGap)) != 3 || strcmp(dots, "...") != 0 ) {
			  printf( "Gap info after matrix %d of motif %s should be: '... min max'. Please fix it!\n", motif->nPWMs+1, motif->header);
			  exit(0);
			}
			if (pwm->minNextGap > pwm->maxNextGap ) {
			  printf( "Gap info after matrix %d of motif %s: Require min <= max. Please fix it!\n", motif->nPWMs+1, motif->header);
			  exit(0);
			}
			//(motif->width) += (pwm->maxNextGap);
		} while (TRUE);

		motif->id = motif_ID++;
		n_motifs++;
	}

	fclose(matfile);
	motifs.len = n_motifs;
	assert(motifs.entry = realloc(motifs.entry, n_motifs*sizeof(motif_Struct)) );
	return motifs;
}


char *move_to_next_motif_start(FILE *fp)
     // Advance `fp` to the next '>' character and return a copy of that header line.
     //
     // The returned string starts with '>' and holds at most MAX_HEADER_LENGTH-1
     // characters of the line (the newline is not included); the rest of an
     // over-long header is read and discarded. The copy is made by allocAndStrCpy().
     // Returns NULL when no further '>' is found before the end of the file.
{
	int h_len=0;
	int c;		// int, not char: must be able to hold EOF as returned by getc()
	char buffer[MAX_HEADER_LENGTH];

	// Testing the value returned by getc (rather than feof) also ends the loop on
	// a read error.
	do {
		c = getc(fp);
		if (c == EOF)
			return NULL;
	} while (c != '>');

	buffer[h_len++] = '>';
	// Stop at the newline or at end of file. EOF is never stored, so a header that
	// is the last line of a file without a trailing newline gets no stray byte.
	while ((c = getc(fp)) != EOF && c != '\n')
		if (h_len < MAX_HEADER_LENGTH-1)
			buffer[h_len++] = (char) c;
	buffer[h_len] = '\0';

	return allocAndStrCpy(buffer);
}


int read_wm_into_buffer(FILE *fp, PWM_Struct *pwm, char *line)
     // Read the rows of one weight matrix from `fp` into `pwm`.
     //
     // Rows are read until a line starting with '<' or with "..." is found; that
     // terminating line is left in `line` (a buffer of at least MAXLINE chars) for
     // the caller to inspect. If the file ends first, `line` is set to the empty
     // string. Every row must consist of exactly NCLNUM numbers.
     //
     // On return pwm->entry holds the rows after pseudocount() and norm_freq() were
     // applied, pwm->logFreqs their logs, and pwm->width the number of rows, which
     // is also the return value (0 for an empty matrix).
     //
     // The process is terminated on allocation failure or on a malformed row.
{
	int n_max_pos=0, n_pos=0, i, cnt;
	double **rows;

	pwm->entry = NULL;
	line[0] = '\0';
	for (;;)
	{
		if (n_pos>=n_max_pos)
		{
			// grow the row-pointer array in chunks of N_POS_INC
			rows = realloc(pwm->entry, (size_t)(n_max_pos + N_POS_INC) * sizeof *rows);
			if (rows == NULL) {
				fprintf(stderr, "read_wm_into_buffer: out of memory\n");
				exit(EXIT_FAILURE);
			}
			for (i = n_max_pos; i < n_max_pos + N_POS_INC; i++)
				rows[i] = NULL;		// rows themselves are allocated only when used
			pwm->entry = rows;
			n_max_pos += N_POS_INC;
		}
		if (fgets(line,MAXLINE,fp) == NULL) {
			// end of file or read error: do not re-parse whatever was in line before
			line[0] = '\0';
			break;
		}
		if(line[0]=='<' || strncmp(line,"...", 3) == 0)	break;

		cnt=word_cnt(line);
		if (cnt != NCLNUM) {
			// checked explicitly (not via assert) so that word_split() can never
			// write past the end of the row, even in NDEBUG builds
			fprintf(stderr, "read_wm_into_buffer: expected %d values per row but found %d in: %s\n", NCLNUM, cnt, line);
			exit(EXIT_FAILURE);
		}
		pwm->entry[n_pos] = malloc(NCLNUM*sizeof(double));
		if (pwm->entry[n_pos] == NULL) {
			fprintf(stderr, "read_wm_into_buffer: out of memory\n");
			exit(EXIT_FAILURE);
		}
		word_split(line,pwm->entry[n_pos]);		// parse the row into the matrix

		n_pos++;
	}
	pseudocount(pwm->entry,n_pos, NCLNUM);        // add the pseudocount to every entry
	norm_freq(pwm->entry,n_pos, NCLNUM);			// normalize the rows of the weight matrix
	pwm->logFreqs = TwoDarray2LogVector(pwm->entry, n_pos);		// create log entries

	return (pwm->width = n_pos);
}


/* PWM_Struct alloc_pwm(int width, charVec header) */
/*      // Allocate the PWM structure while initializing some values */
/* { */
/*   PWM_Struct pwm; */

/*   pwm.width = width; */
/*   pwm.siteThreshold = NULL; */
/*   assert( pwm.entry = calloc(4*width, sizeof(double)) );   */
/*   assert( pwm.reverseCompEntry = calloc(4*width, sizeof(double)) ); */
/*   pwm.id = PWM_ID++; */
/*   pwm.header = header; */

/*   return pwm; */
/* } */


void pseudocount(double **matr,int row,int col)
/*
	Adds a pseudocount to every entry of the weight matrix so that a later log()
	is not handed a 0 entry of an otherwise non-zero row.

	For each row the amount added to every entry is (row sum / col) * pwmPC,
	i.e. the row mean scaled by the total pseudocount. The same amount is added
	to all entries of the row, whether or not the row contains a 0.

	input:  matr = weight matrix, row = number of rows, col = number of columns
	output: matr is modified in place
*/
{
	int r, c;

	for (r = 0; r < row; r++)
	{
		double *cells = matr[r];
		double rowTotal = 0;
		double bump;

		for (c = 0; c < col; c++)
			rowTotal += cells[c];

		bump = (rowTotal / col) * pwmPC;
		for (c = 0; c < col; c++)
			cells[c] += bump;
	}
}




void norm_freq(double **matr,int row,int col)
/*
 Normalizes the weight matrix "matr" in place: every entry is divided by the sum
 of its row. "row" is the number of rows and "col" the number of columns.
*/
{
	int r = 0;

	while (r < row)
	{
		double *cells = matr[r];
		double rowTotal = 0;
		int c;

		for (c = 0; c < col; c++)
			rowTotal += cells[c];

		for (c = 0; c < col; c++)
			cells[c] = cells[c] / rowTotal;

		r++;
	}
}

int word_cnt(char *s)
/*
 Counts the whitespace-separated words in the NUL-terminated string "s".
 Returns 0 for an empty or all-whitespace string. "s" is not modified.
*/
{
	// The <ctype.h> functions require a value representable as unsigned char (or
	// EOF); scanning through an unsigned char pointer keeps bytes >= 0x80 valid.
	const unsigned char *p = (const unsigned char *) s;
	int cnt=0;

	while(*p != '\0')
	{
		while(isspace(*p)) ++p;				// skip white space
		if(*p != '\0')						// found a word
		{
			++cnt;
			while(*p != '\0' && !isspace(*p))	// skip the word
				++p;
		}
	}
	return cnt;
}


void word_split(char *s,double *freq)
/*
 Splits the NUL-terminated string "s" into whitespace-separated words and stores
 the numeric value of the k-th word in freq[k]. A word that does not start with
 a number yields 0, and trailing non-numeric characters of a word are ignored
 (the same result atof() gives for the word).

 "freq" must have room for as many values as "s" has words (see word_cnt).

 Each word is converted in place with strtod(), so no scratch buffer is needed
 and the length of a word is not limited.
 8/28/07: Uri removed unused num parameter
*/
{
	int cnt=0;

	while(*s !='\0')
	{
		// cast: <ctype.h> functions are undefined for negative char values
		while(isspace((unsigned char) *s)) ++s;		// skip white space
		if(*s!='\0')								// found a word
		{
			// strtod stops at the first character that cannot be part of the
			// number, which is never later than the end of the word
			freq[cnt++] = strtod(s, NULL);
			while(*s!='\0' && !isspace((unsigned char) *s))	// move past the word
				++s;
		}
	}
}


// Allocation helper for permute_motif: returns `count * size` bytes (at least one
// byte, so the result is never NULL) or terminates the process when out of memory.
static void *permute_motif_alloc(size_t count, size_t size)
{
	void *mem = malloc(count > 0 ? count * size : 1);

	if (mem == NULL) {
		fprintf(stderr, "permute_motif: out of memory\n");
		exit(EXIT_FAILURE);
	}
	return mem;
}

// Randomly permute the rows (positions) of a motif across all of its PWMs.
//
// The rows of all PWMs are treated as one concatenated list, shuffled with a
// permutation obtained from rand_perm_prealloc(), and dealt back to the PWMs so
// that every PWM keeps its width. The old rows and log vectors are freed, and
// each PWM's logFreqs is rebuilt from its new rows.
void permute_motif (motif_Struct *motif)	{
	int i, j, k, idx, base, totWidth;

	totWidth = 0;
	for (k = 0; k < motif->nPWMs; k++)	{
		totWidth += motif->pwms[k].width;
	}

	int *perm = permute_motif_alloc((size_t) totWidth, sizeof(int));
	perm = rand_perm_prealloc(perm, totWidth);

	// build the permuted rows of every PWM before touching the originals, because
	// a row may move from one PWM to another
	double ***tmpList = permute_motif_alloc((size_t) motif->nPWMs, sizeof(double **));
	int ip = 0;
	for (k = 0; k < motif->nPWMs; k++)	{
		int motifWidth = motif->pwms[k].width;
		double **tmp = permute_motif_alloc((size_t) motifWidth, sizeof(double *));
		for (i = 0; i < motifWidth; i++)	{
			tmp[i] = permute_motif_alloc(NCLNUM, sizeof(double));
			idx = perm[ip++] - 1;		// perm entries are used as 1-based positions
			// translate the global position into (PWM j, row idx)
			j = 0;
			while (j < motif->nPWMs && idx >= motif->pwms[j].width)	{
				idx -= motif->pwms[j].width;
				j++;
			}
			assert(j < motif->nPWMs && idx >= 0 && idx < motif->pwms[j].width);
			for (base = 0; base < NCLNUM; base++)	{
				tmp[i][base] = motif->pwms[j].entry[idx][base];
			}
		}
		tmpList[k] = tmp;
	}

	// swap the permuted rows in and refresh the derived log vectors
	for (k = 0; k < motif->nPWMs; k++)	{
		PWM_Struct *motifPWM = motif->pwms + k;
		for (i = 0; i < motifPWM->width; i++)	{
			free(motifPWM->entry[i]);
		}
		free(motifPWM->entry);
		motifPWM->entry = tmpList[k];

		free(motifPWM->logFreqs);
		motifPWM->logFreqs = TwoDarray2LogVector(motifPWM->entry, motifPWM->width);
	}

	free(tmpList);

	free(perm);
}


// void permute_wm(motif_Struct *motifs)
// // This function isn't working as it stands ****
//      // *** This needs urgent rewriting: a new structure with a new PWM_ID needs to be created ***
// // randomly permutes the weight matrix
// //  first creates an array of permuted indices to the weight matrix, then uses it
// //  to shuffle the values of the weight matrix
// // Uri: 6/22/07 fixed a bug whereby the reverse of the permuted matrix was off 
// //      and cleaned the code along the way.
// {
// 	int i,j;
// 	double **tempwm;
// 
// 	int *perm, width;
// 
// 	width = pwm->width;
// 	assert(perm = (int *)malloc(width*sizeof(int)));
// 	perm = rand_perm_prealloc(perm, width);
// 
// 	for (i=0; i<width; i++)
// 	  printf("%d ",perm[i]);
// 	printf("\n");
// 
// 	assert(tempwm=(double **)malloc(width*sizeof(double *)));
// 	for (i=0;i<width;i++)
// 		assert(tempwm[i]=(double *)malloc(NCLNUM*sizeof(double)));
// 	
// 	
// 	for (i=0;i<width;i++)
// 	{
// 		for (j=0;j<NCLNUM;j++)
// 			tempwm[i][j]=pwm->entry[i][j];
// 	}
// 	for (i=0;i<width;i++)
// 	{
// 		for (j=0;j<NCLNUM;j++)
// 		pwm->entry[i][j]=tempwm[perm[i]-1][j];
// 	}
// 
// 	for (i=0; i<width; i++)
// 	{
// 		for(j=0; j<4; j++)
// 			printf("%f ",pwm->entry[i][j]);
// 		printf("\n");
// 	}		
// 	free(perm);
// 	for (i=0; i<width; i++)
// 	  free(tempwm[i]);
// 	free(tempwm);
// 
// 	free(pwm->logFreqs); // patch
// 	pwm->logFreqs = TwoDarray2LogVector(pwm->entry, pwm->width);
// 	pwm -> id = PWM_ID++;
// }


double *TwoDarray2LogVector(double **array, int width)
     // Returns log(array) in vector format: a newly allocated vector of
     // width * NCLNUM doubles in which element PWM_IND(i,j) is log(array[i][j]).
     // The caller owns the vector and releases it with free().
     // For width <= 0 a valid (non-NULL) vector without usable elements is returned.
     // Terminates the process when memory cannot be allocated.
{
  int i, j;
  double *vec;
  size_t nElems = (width > 0) ? (size_t) width * NCLNUM : 0;

  // The allocation is done outside of assert() so it is not compiled away with
  // NDEBUG; at least one byte is requested because malloc(0) may return NULL.
  vec = malloc(nElems > 0 ? nElems * sizeof *vec : 1);
  if (vec == NULL) {
    fprintf(stderr, "TwoDarray2LogVector: out of memory\n");
    exit(EXIT_FAILURE);
  }
  for (i = 0; i < width; i++)
    for(j = 0; j < NCLNUM; j++)
      vec[PWM_IND(i,j)] = log(array[i][j]);
  return vec;
}


double sitePwmScore(iLetter *site, double *pwm, int width)
     // Return the PWM score of a site: the sum, over the first `width` positions,
     // of the entry of the flattened matrix `pwm` selected by the position and by
     // the letter found at that position of `site`.
{
  double total = 0;
  int pos = 0;

  while (pos < width) {
    total += pwm[PWM_IND(pos, site[pos])];
    pos++;
  }
  return total;
}
// 
// 
// double get_sitePercentThresh(PWM_Struct pwm)
//      // Retrieve or find a nominal % given the threshold
// {
//   siteStructVec trainSites;
//   double nScoresGTT=0;
//   int i;
// 
//   if (sitePercentThresh > -DBL_MAX) // the % was explicitly specified and should match the threshold
//     return sitePercentThresh;
//   assert( trainSites.entry = (void *) malloc((1+(siteThresholdScanDir==2))*nullTrainData.len * sizeof(siteStruct)) );
//            // get scores of all sites (forward direction only)
//   trainSites.len = get_chunkSites(nullTrainData, pwm, siteThresholdScanDir, 0, 0, trainSites.entry, FALSE, 0);
//   for (i = 0; i < trainSites.len; i++) // count the # of sites above threshold
//     nScoresGTT += (trainSites.entry[i].score > *(pwm.siteThreshold));
// 
//   free(trainSites.entry);
//   return nScoresGTT / trainSites.len;
// }


void set_siteThreshold(motif_Struct *motif)
     // Find the site threshold of `motif` from the nominal fraction
     // sitePercentThresh and the training data, unless a threshold is already set.
     //
     // The training data is either a block generated from nullMarkovModel (when
     // nullMarkovBlock is set) or nullTrainData. All of its sites are scored with
     // get_chunkSites() and sorted with siteStructCompare; the threshold is the
     // score at position sitePercentThresh * (number of sites) of the sorted list,
     // linearly interpolated between neighbouring sites. Positions outside of the
     // list are clamped to its first / last site.
     //
     // The result is stored in motif->siteThreshold via allocAndSetD().
     // Terminates the process when memory runs out or no site could be scored.
{
  siteStructVec trainSites;
  double nPercSites;
  iLetterVec trainData;
  int indPerc, lastInd;
  size_t nBytes;

  if (motif->siteThreshold != NULL) // threshold was already set
    return;

  if (nullMarkovBlock)
    trainData = genMarkovBlock(nullMarkovModel, siteThreshTrainSize);
  else
    trainData = nullTrainData;

  // scanning direction 2 needs room for twice as many sites
  nBytes = (1+(siteThresholdScanDir==2))*trainData.len * sizeof(siteStruct);
  // allocated outside of assert() so the call survives NDEBUG builds
  trainSites.entry = malloc(nBytes > 0 ? nBytes : 1);
  if (trainSites.entry == NULL) {
    fprintf(stderr, "set_siteThreshold: out of memory\n");
    exit(EXIT_FAILURE);
  }
           // get scores of all sites
  trainSites.len = get_chunkSites(trainData, *motif, siteThresholdScanDir, 0, 0, trainSites.entry, FALSE, 0);
  if (trainSites.len <= 0) {
    // there is nothing to index below; reading entry[0] would be out of bounds
    fprintf(stderr, "set_siteThreshold: no training sites were scored, cannot derive a threshold\n");
    exit(EXIT_FAILURE);
  }
  qsort((void *) trainSites.entry, trainSites.len, sizeof(siteStruct), (void *) siteStructCompare);

          // set threshold T so that no more than sitePercentThresh * trainSites.len sites are above T
  lastInd = (int) trainSites.len - 1;
  nPercSites = sitePercentThresh * trainSites.len;
  // clamp before converting to int: an unset (-DBL_MAX), negative or NaN fraction
  // must not reach the conversion, and a fraction near 1 must not index past the end
  if (!(nPercSites >= 0))
    nPercSites = 0;
  if (nPercSites > lastInd)
    nPercSites = lastInd;
  indPerc = (int) floor(nPercSites);
  if (indPerc < lastInd) // interpolate the threshold
    motif->siteThreshold = allocAndSetD(trainSites.entry[indPerc].score +
	             (trainSites.entry[indPerc+1].score - trainSites.entry[indPerc].score) * (nPercSites-indPerc));
  else
    motif->siteThreshold = allocAndSetD(trainSites.entry[indPerc].score);

  free(trainSites.entry);
  if (nullMarkovBlock) // the generated block is ours to release
    free(trainData.entry);
}

void set_threshGenerationData(double sitePercentThreshValue, unsigned int siteThreshTrainSizeValue, int siteThresholdBlockType, int siteThresholdScanDirValue)
     // Record the settings that set_siteThreshold() uses when it learns a site
     // threshold on demand: the nominal fraction, the size of a generated training
     // block, whether a generated block is used at all, and the scanning direction.
{
  // a block type of 0 turns the generated (Markov) block on
  nullMarkovBlock = 1 - siteThresholdBlockType;
  siteThreshTrainSize = siteThreshTrainSizeValue;

  siteThresholdScanDir = siteThresholdScanDirValue;
  sitePercentThresh = sitePercentThreshValue;
}


void set_siteThresholdsFromFile(char *thresholdsFileName, motifsVec motifs)
     // Reads the thresholds from a file and sets the siteThreshold slot of the
     // first motifs.len entries of `motifs`, in file order.
     //
     // The file is expected to hold exactly motifs.len thresholds, one per line;
     // white space between thresholds (blank lines included) is skipped. Each
     // threshold is converted with my_atod() and stored via allocAndSetD().
     // ERROR is raised when the file holds fewer or more thresholds than needed;
     // the process is terminated if the file cannot be opened.
{
	enum { THRESH_LINE_MAX = 100 };		// most characters fgets may store, NUL included
	FILE *tfile;
	// one extra byte: the first character of a line is read separately into
	// line[0], and fgets then fills the buffer starting at line[1]
	char line[THRESH_LINE_MAX + 1];
	int numread, num=motifs.len;
	int ch;		// int so that EOF can be told apart from any character
	size_t len;

	tfile = fopen(thresholdsFileName,"r");
	if (tfile == NULL) {
		fprintf(stderr, "set_siteThresholdsFromFile: cannot open thresholds file '%s'\n", thresholdsFileName);
		exit(EXIT_FAILURE);
	}

	numread = 0;
	// skip white spaces before the first threshold
	do {
		ch = fgetc(tfile);
	} while (ch != EOF && isspace(ch));
	// read the first num thresholds; ch is the first character of the next one
	while(ch != EOF && numread < num){
		line[0] = (char) ch;
		if (fgets(line+1,THRESH_LINE_MAX,tfile) == NULL)
			line[1] = '\0';		// the threshold was a single character at the end of the file
		// truncate line feeds and carriage returns, if they're included (DOS-style)
		len = strlen(line);
		if(len > 0 && line[len-1] == '\n')
			line[--len] = '\0';
		if(len > 0 && line[len-1] == '\r')
			line[--len] = '\0';
		if(len >= 1)
			motifs.entry[numread++].siteThreshold = allocAndSetD(my_atod(line));
		// skip white spaces to get to the next threshold
		do {
			ch = fgetc(tfile);
		} while (ch != EOF && isspace(ch));
	}

	if(numread < num)
		// if the file contained less thresholds than we needed, that is BAD, and we should throw an error
		ERROR(("There were %d thresholds (we needed %d) in %s\n"
				      "Please ensure there is at most one threshold per line.",
				      numread, num, thresholdsFileName) )
	else if(ch != EOF)
		// a further non-blank character means there were more thresholds than we wanted
		//  that is BAD, and we should throw an error
		ERROR(("There were too many thresholds in %s\n"
				      "Please include only %d thresholds.",
				      thresholdsFileName, num) )
	fclose(tfile);
}


void set_pwmPC(double pc)
     // Store `pc` as the total pseudocount that pseudocount() applies to every
     // weight matrix row read afterwards.
{
  pwmPC = pc;
}
