#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
//#include <iostream.h>
//#include <strstream.h>
//#include <fstream.h>
#include <string.h>
//#include <iomanip.h>
#include <math.h>
#include "utils.h"
#include "extended_exponent.h"
#include "main.h"

/* Forward declaration: the DP routine itself is defined further down in this file. */
double *mw_lattice(int m, int n, int k, int c, int *size);

/*
 * Forwarding entry point: hands every argument to mw_lattice() unchanged
 * and returns whatever pointer that call produces.
 */
double *mw_latticeC(int m, int n, int k, int c, int *size)
{
  double *result = mw_lattice(m, n, k, c, size);
  return result;
}


double* mw_lattice(int m, int n, int k, int c, int* size) {
/*
 DP in the tied case
 m, n - size of the two samples (m <-> X, n <-> Y)
 c - # of tie classes
 size - sizes of tie classes
 k - 2 * sum 1_{X_i>Y_j} + sum 1_{X_i=Y_j}
 Returns the log of the pmf
 The actual computation corresponds to the Wilcoxon rank sum statistic
*/

  ee_t** choose = new ee_t*[m+n+1];
  for(int i = 0; i < m+n+1; i++) {

    choose[i] = new ee_t[m+1];
    choose[i][0].log_set(0);
    for(int j = 1; j < MIN(i+1, m+1); j++)
      choose[i][j] = choose[i][j-1]*ee_t(log((i-j+1)/(j*1.0)));
  }
  
  int Q = 2*m*n+1 + m*(m+1);
  //int Q = k+1 + m*(m+1);
  ee_t** counts = new ee_t*[m+1];
  for(int i = 0; i < m+1; i++) 
    counts[i] = new ee_t[Q];

  int rank = 0;
  for(int d = 0; d < MIN(m+1, size[0]+1); d++)
    counts[d][d*(2*rank+size[0]+1)] = choose[size[0]][d];

  for(int a = 1; a < c; a++) {

    rank += size[a-1];  
    for(int b = MIN(m, rank+size[a]); b >= 0;  b--)
      for(int c = Q-1; c >= 0; c--)
	for(int d = 1; d < MIN(b+1, size[a]+1); d++)
	  if(c >= d*(2*rank+size[a]+1) && counts[b-d][c-d*(2*rank+size[a]+1)] != 0)
	    counts[b][c] += counts[b-d][c-d*(2*rank+size[a]+1)]*choose[size[a]][d];
  }
      
  double* log_pmf = new double[Q-m*(m+1)];
  double log_denominator = choose[m+n][m].log_get();
  for(int i = m*(m+1); i < Q; i++)
    log_pmf[i-m*(m+1)] = counts[m][i].log_get()-log_denominator;

  return log_pmf;
}


/*
int main(int argc, char** argv) {

  int argno = 1;
  int m = atoi(argv[argno++]);
  int n = atoi(argv[argno++]);
  int k = atoi(argv[argno++]);
  int c; int* size;
  int N=0;

  if(argc > 4) {
    
    c = atoi(argv[argno++]);
    size = new int[c];
    for(int i = 0; i < c; i++) {
      size[i] = atoi(argv[argno++]);
      N += size[i];
    }
    if (N != m+n) {
      printf("\n ERROR: sum of tied class = %d but m=%d & n=%d \n", N, m, n);
      return 0;
    }
  }

  printf("\nm=%d (X), n=%d (Y), # of tie classes=%d\n%d = 2 * sum 1_{X_i>Y_j} + sum 1_{X_i=Y_j}\n",
	 m, n, c, k);
  printf("The tie classes are: ");
  for(int i = 0; i < c; i++)
    printf("%d ", size[i]);
  printf("\n");
  
  int i = c;
  int msum = N;
  while (msum >= m)
    msum -= size[--i];
  int i0 = i;
  int m0 = m - msum;
  //  printf("i0=%d, m0=%d, msum=%d\n", i0, m0, msum);
  if (k < m0*(size[i0]-m0)) {
    printf("\nk is too small, minimum possible value is %d\n", m0*(size[i0]-m0));
    return 0;
  }

  i = c;
  msum = N;
  while (msum >= n)
    msum -= size[--i];
  i0 = i;
  int n0 = n - msum;
  //printf("i0=%d, n0=%d, msum=%d\n", i0, n0, msum);
  if (k > 2*m*n - n0*(size[i0]-n0)) {
    printf("\nk is too big, maximum possible value is %d\n", 2*m*n - n0*(size[i0]-n0));
    return 0;
  }
  
  double* log_pmf; double log_pval=LOGZERO;
  double start = clock(); 

  log_pmf = mw_lattice(m, n, k, c, size);
  //    for(int i = 0; i <= k; i++)
//	printf("%g ", log_pmf[i]);
  
  for (int i=0; i<=k; i++)
    log_pval = log_sum(log_pval, log_pmf[i]);

  double finish = clock();
  cout << "Time = " << (finish*1.0/CLOCKS_PER_SEC - start*1.0/CLOCKS_PER_SEC) << endl;
  cout << "P = " << setprecision(20) << log_pval << " (" << exp(log_pval) << ")" << endl;
  cout << "EP format: " << setprecision(20) << exp(log_pval-int(log_pval)) << " "
       << int(log_pval)/log(10.0) << endl;

  return 0;
}
*/
