#include "timeseries.h"

#include <math.h>

/*
 * general_MLE -- generalised least-squares fit plus Gaussian log-likelihood.
 *
 * Fits the design matrix dk.A to the data dk.d under the data covariance C
 * and evaluates the log-likelihood of the residuals:
 *
 *   L       = chol(C)                   lower-triangular factor (dpotrf)
 *   new_A   = L \ A,  new_d = L \ d     whitened design matrix and data
 *   [Q, R]  = qr(new_A)
 *   t1      = inv(R) * Q'
 *   fit     = t1 * new_d
 *   cov_fit = sig^2 * t1 * t1'
 *   P       = || L \ (d - A * fit) ||^2  (= r' * inv(C) * r)
 *   sig     = sqrt(P / nrow)
 *
 * Arguments
 *   dk      : uses dk.n_data (nrow), dk.n_par (ncol), dk.A (nrow x ncol,
 *             column-major, as indexed below) and dk.d (length nrow).
 *   C       : nrow x nrow covariance matrix, column-major; only read (it is
 *             copied into a work array before factorisation).
 *   op      : op.verbose / op.fpout control the diagnostic output.
 *   fit     : output, ncol parameter estimates.
 *   cov_fit : output, ncol x ncol parameter covariance; it is always
 *             multiplied by sig^2, whatever the value of `scale`.
 *   sigma   : output, sig.
 *   scale   : non-zero -> log-likelihood for covariance sig^2 * C
 *             (sig^2 = P / nrow);
 *             zero     -> log-likelihood for covariance C as given.
 *
 * Returns the log-likelihood.  If the dimensions are unusable
 * (n_data < n_par, or either < 1), memory cannot be allocated, or a LAPACK
 * call reports failure (e.g. C not positive definite, R exactly singular),
 * NAN is returned and *sigma is set to NAN; the contents of fit and cov_fit
 * are then unspecified.  With op.verbose > 0 the reason is written to
 * op.fpout.
 */
double general_MLE(data_kernel dk, double *C, options op, double *fit, double *cov_fit, double *sigma, int scale) {

int j, k, info, nrhs, lwork, one;
int ncol, nrow;
int *ipiv = NULL;

double alpha, beta, a, P, sig, N, out;
double *W = NULL, *W2 = NULL, *Q = NULL, *R = NULL, *t1 = NULL;
double *new_A = NULL, *new_d = NULL, *work = NULL;
size_t nn;
const char *err = NULL;

clock_t time1, time2;

/*********/
/* START */
/*********/

/* Failure result; replaced only once every step has succeeded. */
out      = NAN;
(*sigma) = NAN;
info     = 0;

nrow = dk.n_data;
ncol = dk.n_par;

/* R is a square ncol x ncol matrix only when nrow >= ncol; with fewer
   data than parameters dgetrf/dgetri below would read past R. */
if (nrow < 1 || ncol < 1 || nrow < ncol) {
	err = "need 1 <= n_par <= n_data";
	goto cleanup;
}

/* Element counts in size_t: nrow*nrow evaluated in int overflows for
   nrow > 46340. */
nn = (size_t) nrow * (size_t) nrow;

W    = (double *) calloc(nn, sizeof(double));
W2   = (double *) calloc(nn, sizeof(double));
work = (double *) calloc((size_t) nrow * (size_t) (ncol + 1), sizeof(double));
ipiv = (int *)    calloc((size_t) nrow, sizeof(int));
if (W == NULL || W2 == NULL || work == NULL || ipiv == NULL) {
	err = "out of memory";
	goto cleanup;
}

dlacpy_("Full",&nrow,&nrow,C,&nrow,W,&nrow);

/***************/
/* C = chol(C) */
/***************/

time1 = clock();

dpotrf_("L",&nrow,W,&nrow,&info);
if (info != 0) {
	/* info > 0: C is not positive definite; its diagonal would give a
	   NaN or -inf log-determinant. */
	err = "Cholesky factorisation of covariance matrix failed (dpotrf)";
	goto cleanup;
}

/* log(det(C)) = 2 * sum(log(diag(L))).  dpotrf("L") does not reference the
   strict upper triangle, which therefore still holds C; zero it so that W
   is exactly L. */
for (a = 0.0, j = 0; j < nrow; j++) {
	a += log(W[j + j * nrow]);
	for (k = j+1; k < nrow; k++) W[j + k * nrow] = 0.0;
}
a = a * 2.0;

/* Keep a copy of L: dgesv below overwrites W with its LU factors. */
dlacpy_("Full",&nrow,&nrow,W,&nrow,W2,&nrow);

time2 = clock();

if (op.verbose > 1) {
	fprintf(op.fpout, " Time taken to create cholesky of covariance matrix : ");
	fprintf(op.fpout, "%f seconds\n", ((double) (time2-time1)) / CLOCKS_PER_SEC);
}

/* Right-hand sides [A d], column-major, nrow x (ncol+1). */
nrhs = ncol + 1;
for (j = 0; j < nrow; j++) {
	for (k = 0; k < ncol; k++) work[j + k * nrow] = dk.A[j + k * nrow];
	work[j + ncol * nrow] = dk.d[j];
}

/**************************/
/* New_d = C \ data       */
/* New_A = C \ A          */
/* both computed together */
/**************************/

/* NOTE(review): W is lower triangular, so a triangular solve (dtrtrs)
   would avoid the O(nrow^3) LU factorisation done by dgesv here. */
dgesv_(&nrow, &nrhs, W, &nrow, ipiv, work, &nrow, &info);
if (info != 0) {
	err = "solve with Cholesky factor failed (dgesv)";
	goto cleanup;
}

new_A = (double *) calloc((size_t) nrow * (size_t) ncol, sizeof(double));
new_d = (double *) calloc((size_t) nrow,                 sizeof(double));
if (new_A == NULL || new_d == NULL) {
	err = "out of memory";
	goto cleanup;
}

for (j = 0; j < nrow; j++) {
	for (k = 0; k < ncol; k++) new_A[j + k * nrow] = work[j + k * nrow];
	new_d[j] = work[j + ncol * nrow];
}

/*********************/
/* QR factorization  */
/* [Q,R] = qr(new_A) */
/*********************/

/* nrow >= ncol is guaranteed above, so min(nrow, ncol) == ncol:
   Q is nrow x ncol and R is ncol x ncol. */
Q = (double *) calloc((size_t) nrow * (size_t) ncol, sizeof(double));
R = (double *) calloc((size_t) ncol * (size_t) ncol, sizeof(double));
if (Q == NULL || R == NULL) {
	err = "out of memory";
	goto cleanup;
}

qr(new_A, nrow, ncol, Q, R);
free(new_A);
new_A = NULL;

/********************/
/* Calculate inv(R) */
/********************/

free(ipiv);
ipiv = NULL;
free(work);
work = NULL;

lwork = ncol;     /* dgetri requires lwork >= max(1, ncol) */
ipiv  = (int *)    calloc((size_t) ncol,  sizeof(int));
work  = (double *) calloc((size_t) lwork, sizeof(double));
if (ipiv == NULL || work == NULL) {
	err = "out of memory";
	goto cleanup;
}

dgetrf_(&ncol, &ncol, R, &ncol, ipiv, &info);
if (info == 0) dgetri_(&ncol, R, &ncol, ipiv, work, &lwork, &info);
if (info != 0) {
	/* info > 0: R is exactly singular, i.e. the whitened design matrix
	   appears rank deficient. */
	err = "inversion of R failed (dgetrf/dgetri)";
	goto cleanup;
}

free(ipiv);
ipiv = NULL;
free(work);
work = NULL;

/********************/
/* t1 = inv(R) * Q' */
/********************/

alpha = 1.0;
beta  = 0.0;
one   = 1;

/* The residual-step buffers are allocated here, before fit and cov_fit are
   written, so that an allocation failure leaves the outputs untouched. */
t1   = (double *) calloc((size_t) ncol * (size_t) nrow, sizeof(double));
work = (double *) calloc((size_t) nrow, sizeof(double));
ipiv = (int *)    calloc((size_t) nrow, sizeof(int));
if (t1 == NULL || work == NULL || ipiv == NULL) {
	err = "out of memory";
	goto cleanup;
}

dgemm_("N","T",&ncol,&nrow,&ncol,&alpha,R,&ncol,Q,&nrow,&beta,t1,&ncol);
free(Q);
Q = NULL;

/**********************/
/* cov_fit = t1 * t1' */
/**********************/

dgemm_("N","T",&ncol, &ncol,&nrow,&alpha,t1,&ncol,t1,&ncol,&beta,cov_fit,&ncol);

/********************/
/* fit = t1 * new_d */
/********************/

dgemm_("N","N",&ncol,&one,&nrow,&alpha,t1,&ncol,new_d,&nrow,&beta,fit,&ncol);

/**********************/
/* data_hat = A * fit */
/**********************/

dgemm_("N","N",&nrow,&one,&ncol,&alpha,dk.A,&nrow,fit,&ncol,&beta,work,&nrow);

/**************************************/
/* Done : Now calculate the final mle */
/**************************************/

time1 = clock();

/* Whitened residuals: L \ (d - A * fit); their squared norm is
   P = r' * inv(C) * r. */
nrhs = 1;
for (j = 0; j < nrow; j++) work[j] = dk.d[j] - work[j];

dgesv_(&nrow, &nrhs, W2, &nrow, ipiv, work, &nrow, &info);
if (info != 0) {
	err = "whitening of residuals failed (dgesv)";
	goto cleanup;
}

for (P = 0.0, j = 0; j < nrow; j++) P += work[j]*work[j];

N = (double) nrow;

sig = sqrt(P / N);

for (j = 0; j < ncol; j++) {
	for (k = 0; k < ncol; k++) {
		cov_fit[j + k * ncol] *= (sig*sig);
	}
}

if (scale) {
	out = -N * log(2.0 * M_PI) / 2.0 - a / 2.0 - N * log(sig*sig) / 2.0 - P / 2.0 / sig / sig;
} else {
	out = -N * log(2.0 * M_PI) / 2.0 - a / 2.0 - P / 2.0;
}

(*sigma) = sig;

if (op.verbose > 1) fprintf(op.fpout, " a = %f  sig = %f P = %f mle = %f\n", a, sig, P, out);

time2 = clock();

if (op.verbose > 1) {
	fprintf(op.fpout, " Time taken to do final mle part : ");
	fprintf(op.fpout, "%f seconds\n", ((double) (time2-time1)) / CLOCKS_PER_SEC);
}

cleanup:

if (err != NULL && op.verbose > 0) {
	fprintf(op.fpout, " general_MLE : %s (n_data = %d, n_par = %d, info = %d)\n", err, nrow, ncol, info);
}

/* free(NULL) is a no-op, so every buffer can be released unconditionally. */
free(ipiv);
free(work);
free(new_A);
free(new_d);
free(Q);
free(R);
free(t1);
free(W);
free(W2);

return(out);

}
