/*
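  Manage the PWMs (position weight matrices): read them from a file, apply pseudocounts,
  normalize them, build their reverse complements and log tables, score sites, and learn
  or load per-PWM site-score thresholds.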
  This module is not well written. It should be rewritten at some point.
*/

#include <stdlib.h>
#include <stdio.h>
#include <ctype.h>
#include <assert.h>
#include <string.h>
#include <float.h>
#include <math.h>
#include "my_types.h"
#include "misc_functions.h"
#include "data_interface.h"
#include "markov.h"
#include "pwms.h"
#include "motif_scan.h"

static unsigned int PWM_ID=0; // next PWM identifier to hand out; read_pwms and permute_wm assign it to a PWM and then increment it, so other data can be associated with that PWM by id

#define N_WM_INC 50				/* Number of weight matrices by which the PWM array grows on each reallocation */
#define MAX_HEADER_LENGTH 80	/* Size of the header buffer: at most MAX_HEADER_LENGTH-1 characters (including the leading '>') are kept */
#define N_POS_INC 10			/* Number of rows by which a weight matrix's row array grows on each reallocation */
#define MAXLINE 100				/* Size of the line buffers used when reading input files */
#define NCLNUM 4				/* Columns per matrix row: one per letter, in A C G T order (see rev_wm) */
/* Index of cell (pos, let) in a flattened width x NCLNUM table. The arguments are parenthesized
   so that expressions such as PWM_IND(i + 1, j) expand to ((i + 1) * NCLNUM + (j)). */
#define PWM_IND(pos,let) ((pos) * NCLNUM + (let))


static double pwmPC; // pseudocount fraction: pseudocount() adds (row sum / number of columns) * pwmPC to every entry of a row; set via set_pwmPC
          // site-threshold learning settings (set via set_threshGenerationData)
static double sitePercentThresh = -DBL_MAX; // nominal fraction of sites that should score above a PWM's threshold; -DBL_MAX means "not specified"
static unsigned int siteThreshTrainSize; // length of the block generated from nullMarkovModel when training thresholds
static Boolean nullMarkovBlock; // if true, thresholds are trained on a block generated from nullMarkovModel; otherwise on nullTrainData
extern MarkovModelType nullMarkovModel; // null Markov model used to generate threshold-training data (defined in another module)
extern iLetterVec nullTrainData; // null training sequence used for thresholds (defined in another module)

         // internal methods (helpers defined later in this file)
char *move_to_next_wm_start(FILE *fp); // skip to the next '>' and return that header line
int read_wm_into_buffer(FILE *fp, PWM_Struct *pwm); // read one matrix's rows into pwm; returns the number of rows
void pseudocount(double **matr,int row,int col); // add a pwmPC-scaled pseudocount to every row
void norm_freq(double **matr,int row,int col); // scale every row to sum to 1
int word_cnt(char *s); // count whitespace-separated words
void word_split(char *s,double *freq); // parse each word of s as a double into freq
void rev_wm(double ** wm, double **rev,int row,int col); // reverse complement: reverse both row and column order
double *TwoDarray2LogVector(double **array, int width); // flattened log(array), indexed with PWM_IND



PWMsVec read_pwms(char * wmfile)
     // Read every weight matrix in the file `wmfile` and return them in a PWMsVec.
     // Each matrix starts at a '>' header (move_to_next_wm_start) and its rows are read by
     // read_wm_into_buffer. Each PWM gets a fresh id plus forward and reverse-complement
     // log tables. Exits with a failure status if the file cannot be opened or a matrix
     // has no rows. An input file without any matrix yields {0, NULL}.
{
	PWMsVec pwms = {0, NULL};
	FILE *matfile;
	int n_wms = 0, n_max_wms = 0;
	PWM_Struct *pwm;

	// Calls with side effects must not live inside assert(): under NDEBUG they would never run.
	matfile = fopen(wmfile, "r");
	if (matfile == NULL) {
		fprintf(stderr, "Cannot open weight matrix file %s\n", wmfile);
		exit(EXIT_FAILURE);
	}

	while (!feof(matfile))
	{
		if (n_max_wms <= n_wms) {
			PWM_Struct *grown = realloc(pwms.entry, (n_max_wms + N_WM_INC) * sizeof *grown);
			if (grown == NULL)
				abort();        // out of memory
			pwms.entry = grown;
			n_max_wms += N_WM_INC;
		}
		pwm = pwms.entry + n_wms;     // readability
		memset(pwm, 0, sizeof *pwm);
		if ((pwm->header = move_to_next_wm_start(matfile)) == NULL)
			break;                  // no more matrices
		if (read_wm_into_buffer(matfile, pwm) == 0) {
			fprintf(stderr, "Weight matrix %d has 0 length!!! Goodbye!\n", n_wms + 1);
			exit(EXIT_FAILURE);     // an error, so do not report success (was exit(0))
		}
		pwm->id = PWM_ID++;
		// create log entries
		pwm->logFreqs = TwoDarray2LogVector(pwm->entry, pwm->width);
		pwm->logFreqsRC = TwoDarray2LogVector(pwm->reverseCompEntry, pwm->width);
		n_wms++;
	}

	fclose(matfile);
	pwms.len = n_wms;
	if (n_wms == 0) {
		// realloc(p, 0) may legitimately return NULL, so release the buffer explicitly instead
		free(pwms.entry);
		pwms.entry = NULL;
	} else {
		// Trim the slack; if shrinking fails the larger block is still valid, so keep it
		PWM_Struct *trimmed = realloc(pwms.entry, n_wms * sizeof *trimmed);
		if (trimmed != NULL)
			pwms.entry = trimmed;
	}
	return pwms;
}


char *move_to_next_wm_start(FILE *fp)
     // Skip input up to and including the next '>' character, then read the rest of that
     // line as the matrix header. Returns allocAndStrCpy() of '>' followed by the header
     // text (presumably a heap copy owned by the caller), truncated to MAX_HEADER_LENGTH-1
     // characters. Returns NULL if no further '>' is found before end of file or a read error.
{
	int c;                  // int, not char: EOF must be distinguishable from every byte value
	int h_len = 0;
	char buffer[MAX_HEADER_LENGTH];

	// Comparing with EOF (rather than polling feof) also terminates on read errors,
	// where feof() stays false forever.
	while ((c = getc(fp)) != EOF && c != '>')
		;                   // skip everything before the header marker
	if (c == EOF)
		return NULL;

	buffer[h_len++] = '>';
	while ((c = getc(fp)) != EOF && c != '\n')
		if (h_len < MAX_HEADER_LENGTH - 1)   // keep room for the terminating NUL
			buffer[h_len++] = (char) c;
	buffer[h_len] = '\0';

	return allocAndStrCpy(buffer);
}


int read_wm_into_buffer(FILE *fp, PWM_Struct *pwm)
     // Read the rows of one weight matrix from fp into pwm->entry: one row per line with
     // exactly NCLNUM numbers, stopping at a line that starts with '<' or at end of file.
     // Whitespace-only lines are skipped. The rows are then pseudocounted and normalized,
     // and pwm->reverseCompEntry is filled with the reverse complement.
     // Returns the number of rows read, which is also stored in pwm->width. When it is 0,
     // pwm->entry and pwm->reverseCompEntry are NULL. Exits on a row with the wrong column count.
{
	int n_max_pos = 0, n_pos = 0, i, cnt;
	char line[MAXLINE];

	pwm->entry = NULL;
	pwm->reverseCompEntry = NULL;
	// Stop as soon as fgets fails: the old feof()-driven loop re-parsed the stale buffer
	// at end of file and duplicated the last row.
	while (fgets(line, sizeof line, fp) != NULL)
	{
		if (line[0] == '<')
			break;                      // end-of-matrix marker

		cnt = word_cnt(line);
		if (cnt == 0)
			continue;                   // blank line
		if (cnt != NCLNUM) {
			// Checked explicitly: word_split would write past the row otherwise
			fprintf(stderr, "Weight matrix row %d has %d values (expected %d)\n",
			        n_pos + 1, cnt, NCLNUM);
			exit(EXIT_FAILURE);
		}

		if (n_pos >= n_max_pos) {
			double **grown = realloc(pwm->entry, (n_max_pos + N_POS_INC) * sizeof *grown);
			if (grown == NULL)
				abort();                // out of memory
			pwm->entry = grown;
			n_max_pos += N_POS_INC;
		}
		// Allocate a row only when it is actually read, so no unused rows leak
		if ((pwm->entry[n_pos] = malloc(NCLNUM * sizeof(double))) == NULL)
			abort();
		word_split(line, pwm->entry[n_pos]);    // read the weight matrix row
		n_pos++;
	}

	pwm->width = n_pos;
	if (n_pos == 0)
		return 0;

	pseudocount(pwm->entry, n_pos, NCLNUM);     // give every entry a nonzero pseudocount
	norm_freq(pwm->entry, n_pos, NCLNUM);       // normalize the rows of the weight matrix

	// allocate the space for the reverse-complement matrix
	if ((pwm->reverseCompEntry = malloc(n_pos * sizeof *pwm->reverseCompEntry)) == NULL)
		abort();
	for (i = 0; i < n_pos; i++)
		if ((pwm->reverseCompEntry[i] = malloc(NCLNUM * sizeof(double))) == NULL)
			abort();
	rev_wm(pwm->entry, pwm->reverseCompEntry, n_pos, NCLNUM);

	return n_pos;
}


/* PWM_Struct alloc_pwm(int width, charVec header) */
/*      // Allocate the PWM structure while initializing some values */
/* { */
/*   PWM_Struct pwm; */

/*   pwm.width = width; */
/*   pwm.siteThreshold = NULL; */
/*   assert( pwm.entry = calloc(4*width, sizeof(double)) );   */
/*   assert( pwm.reverseCompEntry = calloc(4*width, sizeof(double)) ); */
/*   pwm.id = PWM_ID++; */
/*   pwm.header = header; */

/*   return pwm; */
/* } */


void pseudocount(double **matr,int row,int col)
/*
	Add a pseudocount to every entry of every row of the weight matrix.
	For each row, (row sum / col) * pwmPC is added to each of its col entries, so the
	row total grows by pwmPC times its original value. This is done for every row,
	whether or not it contains zeros. When pwmPC > 0 and the row sum is positive, no
	entry is left at 0, which keeps log(0) = -infinity out of log-likelihood-ratio
	calculations.
	input: matr = weight matrix, row = number of rows, col = number of columns
	output: matr is modified in place

	Uri 6/22: The pseudocount as written originally was 40%! Moreover it wasn't consistently applied.
*/
{
	int r;

	for (r = 0; r < row; r++) {
		double *cells = matr[r];
		double total = 0;
		double bump;
		int c;

		for (c = 0; c < col; c++)
			total += cells[c];
		bump = (total / col) * pwmPC;   // same share of the pseudocount for every column
		for (c = 0; c < col; c++)
			cells[c] += bump;
	}
}




void norm_freq(double **matr,int row,int col)
/*
 Function norm_freq normalizes each row of the weight matrix "matr" so that it sums to 1.
 "row" is the number of rows in this matrix, "col" is the number of columns.
 A row whose entries sum to 0 cannot be scaled; rather than producing 0/0 = NaN it is set
 to the uniform distribution 1/col.
*/
{
	int i, j;
	double sum;

	for (i = 0; i < row; i++)
	{
		sum = 0;
		for (j = 0; j < col; j++)
			sum += matr[i][j];

		for (j = 0; j < col; j++)
			matr[i][j] = (sum != 0) ? matr[i][j] / sum : 1.0 / col;
	}
}

int word_cnt(char *s)
/*
 Function word_cnt counts the whitespace-separated words in the string "s".
 Characters are passed to isspace() as unsigned char: a negative char value
 (e.g. a byte >= 0x80 where char is signed) is undefined behavior for <ctype.h> functions.
*/
{
	int cnt = 0;

	while (*s != '\0')
	{
		while (isspace((unsigned char) *s))     // skip white space
			++s;
		if (*s != '\0')                         // found a word
		{
			++cnt;
			while (*s != '\0' && !isspace((unsigned char) *s))   // skip word
				++s;
		}
	}
	return cnt;
}


void word_split(char *s,double *freq)
/*
 Function word_split parses each whitespace-separated word of string "s" as a floating-point
 number (the same conversion atof performs) and stores the values, in order, in freq[0], freq[1], ...
 The caller must ensure "freq" has room for every word in "s" (see word_cnt).
 8/28/07: Uri removed unused num parameter
*/
{
	int cnt = 0;

	while (*s != '\0')
	{
		while (isspace((unsigned char) *s))     // skip white space
			++s;
		if (*s == '\0')
			break;
		// Parse in place: a number never contains whitespace, so strtod stops inside this word.
		// This avoids the old fixed-size copy buffer, which had no bound on the word length.
		freq[cnt++] = strtod(s, NULL);
		while (*s != '\0' && !isspace((unsigned char) *s))   // skip the rest of the word
			++s;
	}
}

void rev_wm(double ** wm, double **rev,int row,int col)
/*
 * Function rev_wm builds the reverse complement of weight matrix "wm" in "rev": both the
 * row order and the column order are reversed, i.e. rev[row-1-i][col-1-j] = wm[i][j].
 * Reversing the columns complements the letters because the columns are assumed to be in
 * A C G T order (so col is expected to be 4).
 * Input argument:  wm is the normal weight matrix to be reversed
 *					rev is the matrix to store the reversed weight matrix (may be wm itself)
 *					row is the number of rows of the weight matrix
 *					col is the number of columns of the weight matrix
 * Output is just the reversed weight matrix "rev"
 * No temporary copy is needed: cells are handled in mirror pairs (p, n-1-p) of the
 * row-major order, and both cells are read before either is written.
 */
{
	int n = row * col;
	int p;

	for (p = 0; p < (n + 1) / 2; p++) {
		int q = n - 1 - p;                      // mirror cell of p
		double front = wm[p / col][p % col];
		double back = wm[q / col][q % col];
		rev[p / col][p % col] = back;
		rev[q / col][q % col] = front;
	}
}


void permute_wm(PWM_Struct *pwm)
     // *** This needs urgent rewriting: a new structure with a new PWM_ID needs to be created ***
// Randomly permutes the rows (positions) of the weight matrix in place:
//  a permutation is obtained from rand_perm_prealloc (its values look 1-based, judging
//  by the perm[i]-1 below), and row i of the new matrix is row perm[i]-1 of the old one.
//  The reverse complement and both log tables are rebuilt and the PWM gets a new id.
//  The permutation and the permuted matrix are printed to stdout.
// Uri: 6/22/07 fixed a bug whereby the reverse of the permuted matrix was off
//      and cleaned the code along the way.
{
	int i, j, src, width = pwm->width;
	int *perm;
	double *saved;      // flat copy of the original rows, indexed with PWM_IND

	// Allocations are checked explicitly; the old assert(p = malloc(...)) skipped malloc under NDEBUG
	if ((perm = malloc(width * sizeof *perm)) == NULL)
		abort();
	perm = rand_perm_prealloc(perm, width);

	for (i = 0; i < width; i++)
		printf("%d ", perm[i]);
	printf("\n");

	if ((saved = malloc(width * NCLNUM * sizeof *saved)) == NULL)
		abort();
	for (i = 0; i < width; i++)
		for (j = 0; j < NCLNUM; j++)
			saved[PWM_IND(i, j)] = pwm->entry[i][j];
	for (i = 0; i < width; i++) {
		src = perm[i] - 1;      // plain variable keeps PWM_IND's argument a simple name
		for (j = 0; j < NCLNUM; j++)
			pwm->entry[i][j] = saved[PWM_IND(src, j)];
	}
	free(saved);

	rev_wm(pwm->entry, pwm->reverseCompEntry, width, NCLNUM);

	for (i = 0; i < width; i++) {
		for (j = 0; j < NCLNUM; j++)
			printf("%f ", pwm->entry[i][j]);
		printf("\n");
	}
	free(perm);

	// Rebuild the log tables from the permuted matrix (patch)
	free(pwm->logFreqs);
	free(pwm->logFreqsRC);
	pwm->logFreqs = TwoDarray2LogVector(pwm->entry, pwm->width);
	pwm->logFreqsRC = TwoDarray2LogVector(pwm->reverseCompEntry, pwm->width);
	pwm->id = PWM_ID++;
}


double *TwoDarray2LogVector(double **array, int width)
     // Return a newly malloc'ed vector of width * NCLNUM doubles with log(array[i][j])
     // stored at index PWM_IND(i, j). The caller owns the result and must free it.
{
  int i, j;
  double *vec = malloc(width * NCLNUM * sizeof *vec);

  if (vec == NULL)
    abort();   // out of memory; the old assert(vec = malloc(...)) skipped malloc under NDEBUG
  for (i = 0; i < width; i++)
    for (j = 0; j < NCLNUM; j++)
      vec[PWM_IND(i, j)] = log(array[i][j]);
  return vec;
}


double sitePwmScore(iLetter *site, double *pwm, int width)
     // Score a site against a flattened PWM table (e.g. a log-frequency vector): the sum,
     // over positions 0..width-1, of the table entry for the letter at that position.
{
  double total = 0;
  int pos = 0;

  while (pos < width) {
    total += pwm[PWM_IND(pos, site[pos])];
    pos++;
  }
  return total;
}


double get_sitePercentThresh(PWM_Struct pwm)
     // Retrieve or find a nominal % given the threshold.
     // If sitePercentThresh was set explicitly (> -DBL_MAX), it is returned unchanged.
     // Otherwise, return the fraction of nullTrainData sites (forward direction only)
     // scoring strictly above *pwm.siteThreshold, or 0 if there are no sites to score.
{
  siteStructVec trainSites;
  double nScoresGTT = 0;
  int i;

  if (sitePercentThresh > -DBL_MAX) // the % was explicitly specified and should match the threshold
    return sitePercentThresh;
  assert(pwm.siteThreshold != NULL);  // must be set first (e.g. set_siteThreshold / set_siteThresholdsFromFile)
  if (nullTrainData.len == 0)         // nothing to score; also avoids malloc(0) returning NULL
    return 0;
  trainSites.entry = malloc(nullTrainData.len * sizeof(siteStruct));
  if (trainSites.entry == NULL)
    abort();   // out of memory; the old assert(... = malloc(...)) skipped malloc under NDEBUG
           // get scores of all sites (forward direction only)
  trainSites.len = get_chunkSites(nullTrainData, pwm, 0, 0, 0, trainSites.entry);
  for (i = 0; i < trainSites.len; i++) // count the # of sites above threshold
    nScoresGTT += (trainSites.entry[i].score > *(pwm.siteThreshold));

  free(trainSites.entry);
  // avoid 0/0 = NaN when the data is too short to contain a single site
  return trainSites.len > 0 ? nScoresGTT / trainSites.len : 0;
}


void set_siteThreshold(PWM_Struct *pwm)
     // find threshold from nominal % one + training file.
     // If pwm->siteThreshold is not set yet, score every forward-strand site of the training
     // data, sort the scores with siteStructCompare, and interpolate the score at rank
     // sitePercentThresh * (number of sites). The training data is a fresh block of
     // siteThreshTrainSize letters from nullMarkovModel when nullMarkovBlock is set,
     // otherwise nullTrainData. Exits if the training data yields no sites.
{
  siteStructVec trainSites;
  iLetterVec trainData;
  double nPercSites;
  int indPerc, lastInd;

  if (pwm->siteThreshold != NULL) // threshold already set
    return;

  if (nullMarkovBlock)
    trainData = genMarkovBlock(nullMarkovModel, siteThreshTrainSize);
  else
    trainData = nullTrainData;
  trainSites.entry = malloc(trainData.len * sizeof(siteStruct));
  if (trainSites.entry == NULL)
    abort();   // out of memory; the old assert(... = malloc(...)) skipped malloc under NDEBUG
         // get scores of all sites (forward direction only)
  trainSites.len = get_chunkSites(trainData, *pwm, 0, 0, 0, trainSites.entry);
  if (trainSites.len == 0) {   // entry[0] below would be out of bounds
    fprintf(stderr, "Cannot set a site threshold: the training data contains no sites\n");
    exit(EXIT_FAILURE);
  }
  // Cast to qsort's comparator type: converting a function pointer to void * is not portable C
  qsort(trainSites.entry, trainSites.len, sizeof(siteStruct),
        (int (*)(const void *, const void *)) siteStructCompare);
        // set threshold T so that no more than sitePercentThresh * trainSites.len sites are above T
  lastInd = (int) trainSites.len - 1;
  nPercSites = sitePercentThresh * trainSites.len;
  // Clamp the rank to [0, lastInd] so entry[indPerc] and entry[indPerc+1] stay in bounds,
  // e.g. for a percentage >= 1 or the unset default -DBL_MAX. !(x > 0) also catches NaN.
  if (!(nPercSites > 0))
    nPercSites = 0;
  if (nPercSites > lastInd)
    nPercSites = lastInd;
  indPerc = (int) floor(nPercSites);
  if (indPerc < lastInd) // interpolate the threshold
    pwm->siteThreshold = allocAndSetD(trainSites.entry[indPerc].score +
	             (trainSites.entry[indPerc+1].score - trainSites.entry[indPerc].score) * (nPercSites-indPerc));
  else
    pwm->siteThreshold = allocAndSetD(trainSites.entry[indPerc].score);
  free(trainSites.entry);
  if (nullMarkovBlock)
    free(trainData.entry);
}

void set_threshGenerationData(double sitePercentThreshValue, unsigned int siteThreshTrainSizeValue, int siteThresholdBlockType)
     // Record how site thresholds are learned on demand (see set_siteThreshold):
     //   sitePercentThreshValue   - nominal fraction of sites that should score above the threshold
     //   siteThreshTrainSizeValue - length of the null Markov block generated for training
     //   siteThresholdBlockType   - stored as nullMarkovBlock = 1 - type, so 0 selects a generated
     //                              Markov block and 1 selects nullTrainData
{
  nullMarkovBlock = 1 - siteThresholdBlockType;
  siteThreshTrainSize = siteThreshTrainSizeValue;
  sitePercentThresh = sitePercentThreshValue;
}


void set_siteThresholdsFromFile(char *thresholdsFileName, PWMsVec pwms)
     // Reads the thresholds from a file and sets the appropriate slots of pwms.
     // One threshold per non-blank line: leading whitespace and blank lines are skipped, a
     // trailing "\n" or "\r\n" is stripped, and the rest of the line is parsed by my_atod into
     // pwms.entry[i].siteThreshold (via allocAndSetD), in file order.
     // Raises ERROR if the file cannot be opened or holds fewer or more than pwms.len thresholds.
{
	FILE *tfile;
	char line[MAXLINE];     // line[0] = first non-space character; fgets fills line[1..]
	size_t len;
	int c, numread = 0, num = pwms.len;

	tfile = fopen(thresholdsFileName, "r");
	if (tfile == NULL) {
		ERROR(("Cannot open thresholds file %s", thresholdsFileName))
		return;
	}

	// skip white spaces before the first threshold (c is an int so EOF stays distinct)
	while ((c = fgetc(tfile)) != EOF && isspace(c)) { }
	// read the first num thresholds
	while (c != EOF && numread < num) {
		line[0] = (char) c;
		// line + 1 has only sizeof line - 1 bytes (the old code passed the full size: 1-byte overflow)
		if (fgets(line + 1, sizeof line - 1, tfile) == NULL)
			line[1] = '\0';     // nothing after the first character (end of file)
		len = strlen(line);
		// truncate line feeds and carriage returns, if they're included (DOS-style)
		if (len > 0 && line[len - 1] == '\n')
			line[--len] = '\0';
		if (len > 0 && line[len - 1] == '\r')
			line[--len] = '\0';
		if (len >= 1)
			pwms.entry[numread++].siteThreshold = allocAndSetD(my_atod(line));
		// skip white spaces to get to the next threshold
		while ((c = fgetc(tfile)) != EOF && isspace(c)) { }
	}

	if(numread < num)
		// if the file contained less thresholds than we needed, that is BAD, and we should throw an error
		ERROR(("There were %d thresholds (we needed %d) in %s\n"
				      "Please ensure there is at most one threshold per line.",
				      numread, num, thresholdsFileName) )
	else if(c != EOF)
		// if we aren't at the end of the file yet, that means there were more thresholds than we wanted
		//  that is BAD, and we should throw an error
		ERROR(("There were too many thresholds in %s\n"
				      "Please include only %d thresholds.",
				      thresholdsFileName, num) )
	fclose(tfile);
}


void set_pwmPC(double pc)
     // Store the pseudocount fraction later used by pseudocount(): each entry of a matrix
     // row gets (row sum / number of columns) * pc added, raising the row total by pc times itself.
{
  pwmPC = pc;
}
