#include <assert.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include "my_types.h"
#include "misc_functions.h"
#include "tests.h"
#include "exactmw/main.h"
#include "exactmw/utils.h"
#include "dcdflib/cdflib.h"
#include "gsl/hypergeometric.h"


static long MW_tiedU(realVec sample1, realVec sample2);
static double* MWharding(int m, int n, int k);
static testStruct MWnormal(long U, intVec multiplicities, double m, double n);
static testStruct MWexact(long U, intVec multiplicities, double m, double n);
static double uniSumCDF(int num, double sum);
static double factorial(int num);
static double logProdGamma2pvals(double p, double q);



testStruct hypergeometric_test(int nSampleRed, int nRed, int nBlack, int nSample)
     /*
       Executes the hypergeometric test.
       nSampleRed is the observed count and is stored as the test statistic;
       nRed, nBlack and nSample are forwarded unchanged to the GSL
       hypergeometric CDF routines.
       Returns: pLT = P(X <= nSampleRed), pGT = P(X >= nSampleRed),
                p2s = -1 (the two-sided alternative is not implemented).
       If nSample == 0 both one-sided p-values are NaN.
     */
{
  testStruct hg;

  hg.stat = nSampleRed;
  hg.test = "hypergeometric test";
  hg.p2s = -1; // not implemented

  if (nSample == 0) { // the functions below do not handle this gracefully
    hg.pGT = myNaN();
    hg.pLT = myNaN();
    return hg;
  }

  // gsl_cdf_hypergeometric_Q(k, ...) is the strict upper tail P(X > k), so
  // P(X >= nSampleRed) is Q(nSampleRed-1).  The argument is unsigned, hence
  // the shift is impossible for 0; there P(X >= 0) is simply 1.
  if (nSampleRed > 0)
    hg.pGT = gsl_cdf_hypergeometric_Q(nSampleRed-1, nRed, nBlack, nSample);
  else
    hg.pGT = 1.0;
  hg.pLT = gsl_cdf_hypergeometric_P(nSampleRed, nRed, nBlack, nSample);

  return hg;
}


testStruct t_test(realVec sample1, realVec sample2)
     /*
       Executes the 2 sample t-test.
       The statistic is (mean1 - mean2) divided by a pooled standard error and
       is referred to a t distribution with len1 + len2 - 2 degrees of freedom.
       Returns: pLT / pGT = lower / upper tail of the t CDF at the statistic,
                p2s = 2 * min(pLT, pGT).
       If the CDF routine reports an error (e.g. non-positive degrees of
       freedom) all three p-values are NaN.
     */
{
  double mean1, mean2, std1, std2, num1, num2;
  testStruct t;
  double deg_free, bound, confidence, inv_confidence;
  int which, status;

  mean1 = compMean(sample1);
  mean2 = compMean(sample2);
  std1 = compStd(sample1);
  std2 = compStd(sample2);
  t.test = "t-test";
  num1 = sample1.len;
  num2 = sample2.len;
  // NOTE(review): the pooled term weights the compStd() results by the sample
  // sizes; this equals the usual pooled variance only if compStd() returns the
  // divide-by-n variance -- confirm against compStd().
  t.stat = (mean1-mean2) / sqrt( (num1*std1+num2*std2) / (num1+num2-2) * (1/num1+1/num2) );
  deg_free = num1+num2-2;
  which = 1; // DCDFLIB convention: compute P and Q from T and DF
  cdft(&which, &confidence, &inv_confidence, &t.stat, &deg_free, &status, &bound);
  if (status != 0) {
    // cdft leaves its outputs unset on failure; do not hand back garbage
    t.pGT = t.pLT = t.p2s = myNaN();
    return t;
  }
  t.pGT = inv_confidence;
  t.pLT = confidence;
  t.p2s = 2 * MIN(t.pGT, t.pLT);

  return t;
}


testStruct MW_test(realVec sample1, realVec sample2, const char *rqstdMethod)
     /*
       Executes the Mann Whitney test on the two input samples.
       rqstdMethod can be: "normal", "exact" or NULL
          NULL is: if min(m,n) >= MWnormalThresh then normal else exact.
       The function automatically switches to the Harding method if no ties and exact.

       If either sample is empty, a result with NaN statistic and p-values and
       test name "0 sample" is returned.
       An unrecognised rqstdMethod prints a message and terminates the program
       with a failure exit status.
     */
{
  int i;
  intVec multiplicities;
  realVec combined;
  long U;
  testStruct mw;
  const int MWnormalThresh=15; // if in default mode min(m,n)>= MWnormalThresh then normal is chosen

  if(sample1.len == 0 || sample2.len == 0) {              // can't make sense of such a case
    mw.pLT = mw.pGT = mw.p2s = mw.stat =  myNaN();
    mw.test = "0 sample";
    return mw;
  }

  // Pool both samples so the tie multiplicities can be computed over all points
  combined = alloc_realVec(sample1.len + sample2.len);
  for(i = 0; i < sample1.len; i++)
    combined.entry[i] = sample1.entry[i];
  for(i = 0; i < sample2.len; i++)
    combined.entry[i+sample1.len] = sample2.entry[i];

  multiplicities = multiset_multiplicities(combined);                 // find the multiplicities

  U = MW_tiedU(sample1, sample2);

  if ( ((rqstdMethod == NULL) && MIN(sample1.len, sample2.len) >= MWnormalThresh) || // no method was specified but sample is "large"
      ((rqstdMethod != NULL) && strcmp(rqstdMethod, "normal") == 0) ) {       // or normal was specifically requested
    mw = MWnormal(U, multiplicities, sample1.len, sample2.len);
    // only reachable with a small sample when "normal" was explicitly requested
    if (MIN(sample1.len, sample2.len) < MWnormalThresh/2)
      printf("One of the samples has only %d point so normal approximation is dubious\n", MIN(sample1.len, sample2.len));
  }
  else if ((rqstdMethod == NULL) ||                                   // no method but small sample
	   (strcmp(rqstdMethod, "exact") == 0))                       // or exact was specifically requested
    mw = MWexact(U, multiplicities, sample1.len, sample2.len);
  else {
    printf("Requested method for MW_test, %s, is unknown", rqstdMethod);
    exit(1);                                                          // this is an error: do not report success
  }
  free(multiplicities.entry);
  free(combined.entry);

  return mw;
}


long MW_tiedU(realVec sample1, realVec sample2) {
  /*
    Computes the tie-aware MW statistic: every (sample1, sample2) pair adds
    +1 when the sample1 point is larger, -1 when it is smaller and 0 when
    neither comparison holds.
  */
  long score = 0;
  int a, b;

  for (a = 0; a < sample1.len; a++) {
    for (b = 0; b < sample2.len; b++) {
      // (x > y) - (x < y) is the sign of the comparison: +1, 0 or -1
      score += (sample1.entry[a] > sample2.entry[b])
             - (sample1.entry[a] < sample2.entry[b]);
    }
  }
  return score;
}


testStruct MWnormal(long U, intVec multiplicities, double m, double n)
     /*
       Returns the testStruct mw using the normal approximation.
       U is the +1/0/-1 pair statistic (as produced by MW_tiedU), which has
       mean 0 under the null; m and n are the two sample sizes and
       multiplicities holds the tie-group sizes of the pooled sample.

       With N = m+n and T = sum(t^3 - t) over the tie groups, the variance of
       this statistic is  m*n*(N^3 - N - T) / (3*N*(N-1)),  i.e. 4x the
       variance of the classical 0/1 U statistic, tie correction included.

       Returns pLT / pGT from the normal CDF and p2s = 2*min(pLT, pGT).
       If every observation is tied the variance is 0 and U is necessarily 0;
       all three p-values are then reported as 1.
       Terminates the program if the normal CDF routine fails.
     */
{
  testStruct mw;
  double N = m+n;
  double tieSum = 0, variance, sd;
  double mean = 0, test_stat = U;
  double bound, p, q;
  int which = 1, status;
  int i;

  for (i = 0; i < multiplicities.len; i++)
    tieSum += pow(multiplicities.entry[i],3)-multiplicities.entry[i]; // this will be 0 if ( == 1)

  // Written in factored form so that the fully tied case yields exactly 0
  variance = m*n*(N*N*N - N - tieSum) / (3*N*(N-1));

  mw.test = "MW (normal)";
  mw.stat = U;

  if (variance <= 0) {
    // degenerate distribution: no information in either direction
    mw.pGT = mw.pLT = mw.p2s = 1;
    return mw;
  }

  sd = sqrt(variance);
  cdfnor(&which, &p, &q, &test_stat, &mean, &sd, &status, &bound);
  if(status != 0){
    fprintf(stdout, "Mann-Whitney Normalized Approximation failed!\n");
    exit(1);
  }
  mw.pGT = q;
  mw.pLT = p;
  mw.p2s = 2 * MIN(p, q);  // same two-sided convention as t_test

  return mw;
}


testStruct MWexact(long U, intVec multiplicities, double m, double n)
     /*
       Returns the testStruct mw using an exact computation.
       U is the +1/0/-1 pair statistic computed by MW_tiedU, m and n are the
       two sample sizes, and multiplicities comes from
       multiset_multiplicities() applied to the pooled sample.

       Without ties MWharding() is used, otherwise mw_latticeC().  Both results
       are treated here as arrays of log-probabilities indexed by the shifted
       statistic (they are combined with log_sum() and exp()).
       NOTE(review): the exact contract of mw_latticeC() is not visible in this
       file -- confirm that it returns a malloc'ed log-pmf of at least
       (its third argument + 1) entries.

       Sets pLT = P(stat <= observed), pGT = P(stat >= observed),
       p2s = -1 (not implemented) and stat = the shifted statistic U2.
     */
{
  testStruct mw;
  int noTies, i;
  double *exactMWResults, logP;
  long U2, stat;
  const double warningT=1e8; // threshold for generating complexity warning
  const double small_pGT=1e-6; // threshold for triggering a more costly computation of the GT alternative

  // Intentional assignment: noTies is true when multiplicities has m+n entries
  // (presumably one entry per distinct value, so all values are distinct)
  if (noTies = (multiplicities.len == m+n))
    mw.test = "MW (no-ties exact)";
  else
    mw.test = "MW (exact)";

  // Warn (but continue) when the chosen method is expected to be expensive
  if ((m*m*n*n > warningT && !noTies) || (pow(MIN2(m,n),2)*MAX2(m,n) > warningT && noTies)) {
      fprintf(stdout, "There are a large number of scores.  The exact MW test will\n");
      fprintf(stdout, "   take a long time, and may not have enough memory.\n");
  }

  if (noTies) {
    // map the +1/-1 statistic in [-mn, mn] to the 0/1-per-pair count in [0, mn]
    U2 = (U+m*n)/2;
    exactMWResults = MWharding(m, n, U2);  // harding is faster
  }
  else {
    // with ties the statistic is only shifted to be non-negative: [0, 2mn]
    U2 = U+m*n;
    exactMWResults =( double *) mw_latticeC(m, n, U2, multiplicities.len, multiplicities.entry);
  }
  // pLT: accumulate the log-probabilities of 0..U2 in log space
  logP = LOGZERO;
  for (i = 0; i <= U2; i++)
    logP = log_sum(logP, exactMWResults[i]);
  mw.pLT = exp(logP);
         
  if (mw.pLT > 1 - small_pGT) { // pLT is within small_pGT of 1, so pGT is tiny: 1-mw.pLT+correction
    free(exactMWResults);       // would lose precision; recompute the tail with the samples swapped
    if (noTies) {
      stat = m*n - U2;
      exactMWResults = MWharding(n, m, stat);  // harding is faster
    }
    else {
      stat = m*n - U;
      exactMWResults = (double *) mw_latticeC(n, m, m*n-U, multiplicities.len, multiplicities.entry);
    }
    logP = LOGZERO;
    for (i = 0; i <= stat; i++)
      logP = log_sum(logP, exactMWResults[i]);
    mw.pGT = exp(logP);
  }
  else                       // mw.pGT is not too small so we can save time here by doing trivial math
    mw.pGT = 1 - mw.pLT + exp(exactMWResults[U2]); // add back P(stat == U2) so that pGT is ">="

  free(exactMWResults);
  mw.p2s = -1;              // *** not implemented here yet ***
  mw.stat = U2;

  return mw;
}


double* MWharding(int m, int n, int k)
     /* 
	harding returns the log of the exact mann-whitney pmf evaluated on [0:k]
	m is the number of items in the first set
	n is the number of items in the second set
	k is the value of the test statistic: 0, 1 for each pair
     */
{
int t,u,s,j;
//double* f = new double[k+1];
double *f;
assert(f = (double *)malloc((k+1)*sizeof(double)));
for(t = 0; t < k+1; t++)
     f[t] = 0;
     f[0] = 1;

		if(n < k) {

for(t = n+1; t <= MIN(m+n, k); t++)
     for(u = k; u >= t; u--)
     f[u] -= f[u-t];
  }

  for(s = 1; s <= MIN(m, k); s++)
    for(u = s; u <= k; u++)
      f[u] += f[u-s];

  double log_CNm = 0;
  for(j = 1; j <= m; j++)
    log_CNm += log(n+j)-log(j);

  //double* result = new double[k+1];
  double * result;
  assert(result = (double *)malloc((k+1)*sizeof(double)));
  for(u = 0; u <= k; u++)
    result[u] = log(f[u])-log_CNm;

  free(f);
  return result;
}


/* testSTruct uniformSum(realVec sample) */
/*      /\* */
/*        Applies the uniform sum test to the input sample: the null hypothesis is that */
/*        the sample was generated from a U[0,1] distribution and the test statistic is */
/*        the sample sum. */
/*      *\/ */
/* { */
/*   double sum; */
/*   int k; */
/*   testStruct test; */

/*   sum = 0; */
/*   for(k=0; k < sample.len; k++) */
/*     sum += sample.entry[k]; */
/*   test.pLT = uniSumCDF(sample.len, sum); */
/*   test.pGT = 1 - test.pLT; */
/*   test.p2s = 2 * MIN(test.pLT, test.pGT); */

/*   return test; */
/* } */


testStruct unifSum2tests(testStruct t1, testStruct t2)
     /*
       Combines the p-values from two tests using the uniform sum statistic.
       Under the null model the two pairs of one-sided p-values are assumed to be IID U[0,1].
       The sum statistic is therefore distributed as a sum of a pair of such RVs.

       Returns: pGT / pLT = CDF of the sum of two U[0,1] RVs evaluated at
                t1.pGT + t2.pGT and t1.pLT + t2.pLT respectively,
                stat = t1.pGT + t2.pGT,
                p2s = -1 (the two-sided alternative is not implemented).
     */
{
  double sumP;
  testStruct uniSum;

  uniSum.test = "uniform sum test";
  uniSum.p2s = -1; // not implemented; set so that no field is returned uninitialized

  sumP = t1.pGT + t2.pGT;
  uniSum.pGT = uniSumCDF(2, sumP);
  uniSum.stat = sumP;
     // The second alternative uses its own sum, so stat above only describes pGT
  sumP = t1.pLT + t2.pLT;
  uniSum.pLT = uniSumCDF(2, sumP);

  return uniSum;
}


double uniSumCDF(int num, double sum)
     /*
	Evaluates at `sum` the CDF of the sum of `num` IID U[0,1] random variables.
	Up to 20 terms the closed-form alternating series is used; beyond that
	a normal approximation (mean num/2, variance num/12) replaces it.
	Terminates the program if the normal CDF routine fails.
     */
{
  double total, term, inner, binom;
  double mean, std, p, q, bound, test_val;
  int which, status;
  int k, j;

  if (num > 20) {
    // with many summands the alternating series accumulates too much
    // roundoff error, so fall back to the normal approximation
    mean = ((double)num)/2.0;
    std = sqrt(((double)num)/12.0);
    test_val = (double) sum;
    which = 1;
    cdfnor(&which, &p, &q, &test_val, &mean, &std, &status, &bound);
    if (status != 0) {
      fprintf(stdout, "Error in normalized approximation to the uniform sum distribution!\n");
      exit(1);
    }
    return p;
  }

  // The series is the integral of the density given as equation (3) of
  //      http://mathworld.wolfram.com/UniformSumDistribution.html
  // with each term integrated in closed form.
  total = 0;
  for (k = 0; k <= num; k++) {
    // binomial coefficient C(num, k), updated from the previous iteration
    if (k <= 0 || k >= num)
      binom = 1;
    else
      binom = binom*(num-k+1)/((double)k);

    // alternating sign (-1)^k
    if (k%2 == 1)
      term = -1;
    else
      term = 1;

    if (sum > k) {
      if (num%2 == 1) {
        inner = 0;
        for (j = 0; j < num; j++)
          inner += pow(sum-k,j)*pow((double)k,num-1-j);
        term *= binom*(sum-2*k)*inner/((double)num);
      }
      else {
        term *= binom*(pow(sum-k,num)/((double)num) + pow(-1.0*k,num)/((double)num));
      }
    }
    else {
      inner = 0;
      for (j = 0; j < num; j++)
        inner += pow(-1.0*k,j)*pow(sum-k,num-1-j);
      term *= binom*(-1*sum)*inner/((double)num);
    }
    total += term;
  }
  total /= 2.0*factorial(num-1);

  return total;
}

/* Factorial of an integer, computed in double precision.
	Inputs: num, the number whose factorial is wanted
	Outputs: num! as a double; 1 for any num below 2
*/
double factorial(int num){
	double product = 1;

	/* multiply from num downwards: num * (num-1) * ... * 2 */
	while (num > 1) {
		product *= (double)num;
		num--;
	}
	return product;
}




testStruct logProdGamma2tests(testStruct t1, testStruct t2)
     /*
       Tests the null hypothesis that each of the two pairs of one-sided p-values are IID U[0,1]
       using their respective log product statistics.
       Under the null these log products are (-) Gamma (shape=2,scale=1) RV.

       Returns: pGT / pLT = combined p-value of (t1.pGT, t2.pGT) and of
                (t1.pLT, t2.pLT) respectively,
                stat = log(t1.pGT) + log(t2.pGT),
                p2s = -1 (the two-sided alternative is not implemented).
     */
{
  testStruct logProdTest;

  logProdTest.test = "Gamma log product test";
  logProdTest.p2s = -1; // not implemented; set so that no field is returned uninitialized

  logProdTest.pGT = logProdGamma2pvals(t1.pGT, t2.pGT);
  logProdTest.stat = log(t1.pGT) + log(t2.pGT); // this is not really KOSHER as the second alternative is different!
     // The second alternative is combined from its own pair of p-values
  logProdTest.pLT = logProdGamma2pvals(t1.pLT, t2.pLT);

  return logProdTest;
}


double logProdGamma2pvals(double p, double q)
     /*
       Combines the p-values p & q using the log prod statistics and returns the p-value
       using a Gamma (shape=2,scale=1) distribution.

       Returns 0 if either input is exactly 0 (the log product is unbounded).
       Returns NaN, after printing an error message, if the Gamma CDF routine fails.
     */
{
  int which=1, status;
  double bound, confidence, inv_confidence, shape=2, scale=1;
  double logProd;

  if (p == 0 || q == 0)
    return 0;

  logProd = -log(p) - log(q);
  cdfgam(&which, &confidence, &inv_confidence, &logProd, &shape, &scale, &status, &bound);
  if(status != 0) {
    fprintf(stdout, "ERROR in logProdGamma2pvals\n");
    return myNaN();   // the outputs of cdfgam are not reliable after a failure
  }

  return inv_confidence; // upper tail: P(Gamma >= observed log product)
}
