package be.ac.ulb.mlg.utils.measure;


/*
 * The MIT License (MIT)
 * 
 * Copyright (c) 2013 Jean-Sebastien Lerat (Jean-Sebastien.Lerat@ulb.ac.be)
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

/**
 * 
 * @author Jean-Sebastien Lerat (Jean-Sebastien.Lerat@ulb.ac.be)
 * @version 1.00, 30/08/2013
 */

import be.ac.ulb.mlg.utils.Measure;
import be.ac.ulb.mlg.utils.MeasureUtils;

/**
 * Brownian distance correlation (dCor) between two samples.
 * <p>
 * References:
 * <ul>
 * <li>Sz&eacute;kely, Rizzo and Bakirov (2007)</li>
 * <li>Sz&eacute;kely and Rizzo (2009)</li>
 * <li><a href="http://en.wikipedia.org/wiki/Distance_correlation#Distance_correlation">
 * Wikipedia: Distance correlation</a></li>
 * </ul>
 * <p>
 * Missing values are encoded as {@code NaN}. Every pairwise distance that involves a
 * {@code NaN} is treated as missing. The grand, row and column means used for double
 * centering are estimated from the available distances only. Double-centered entries
 * that end up missing are skipped when the distance covariance and variances are
 * accumulated.
 */
public class BrownCorrelation implements Measure {

	/**
	 * Computes the distance correlation of {@code a} and {@code b}.
	 *
	 * @param a first sample; {@code NaN} marks a missing value
	 * @param b second sample, same length as {@code a}; {@code NaN} marks a missing value
	 * @param computable not consulted by this measure (missing values are detected through
	 *        {@code NaN}); may be {@code null}
	 * @return the distance correlation, or {@code 0.0} when the normalising denominator is
	 *         not strictly positive
	 * @throws IllegalArgumentException if {@code a} and {@code b} have different lengths
	 */
	@Override
	public double measure(final double[] a,final double[] b,final boolean[] computable){
		return doublyCenteredImplementation(a, b);
	}

	/**
	 * Reference estimator: doubly centers the pairwise distance matrices and accumulates
	 * the sums of products, without materialising any n x n matrix.
	 * <p>
	 * The distance matrix |x_i - x_j| is symmetric (including its NaN pattern), so its
	 * column means equal its row means. Only one vector of row means per sample is kept,
	 * and each double-centered entry is recomputed on the fly as
	 * {@code ((d_ij + grandMean) - rowMean_i) - rowMean_j}. This is the same operation
	 * order as the matrix-based formulation. Memory is O(n); time is O(n^2).
	 *
	 * @param a first sample ({@code NaN} = missing)
	 * @param b second sample ({@code NaN} = missing)
	 * @return the distance correlation, or {@code 0.0} when the denominator is not positive
	 * @throws IllegalArgumentException if the samples have different lengths
	 */
	private static double doublyCenteredImplementation(final double[] a, final double[] b) {
		if (a.length != b.length) {
			throw new IllegalArgumentException(
					"Samples must have the same length: " + a.length + " != " + b.length);
		}
		final int n = a.length;
		final double[] rowMeansA = new double[n];
		final double[] rowMeansB = new double[n];
		final double grandMeanA = distanceMeans(a, rowMeansA);
		final double grandMeanB = distanceMeans(b, rowMeansB);

		// Sums of products of double-centered distances. The 1/n^2 normalisation of the
		// squared distance covariance/variances cancels in the final ratio, so it is omitted.
		double sumAB = 0.0, sumAA = 0.0, sumBB = 0.0;
		for (int i = 0; i < n; i++) {
			for (int j = 0; j < n; j++) {
				final double centeredA = distance(a, i, j) + grandMeanA - rowMeansA[i] - rowMeansA[j];
				final double centeredB = distance(b, i, j) + grandMeanB - rowMeansB[i] - rowMeansB[j];
				final boolean hasA = !Double.isNaN(centeredA);
				final boolean hasB = !Double.isNaN(centeredB);
				if (hasA) {
					sumAA += centeredA * centeredA;
				}
				if (hasB) {
					sumBB += centeredB * centeredB;
				}
				if (hasA && hasB) {
					sumAB += centeredA * centeredB;
				}
			}
		}

		// dCor = dCov / sqrt(dStd(a) * dStd(b)).
		// NOTE(review): with missing values sumAB may become negative; the result then
		// depends on how MeasureUtils.sqrt handles negative input.
		final double dCov = MeasureUtils.sqrt(sumAB);
		final double dStdA = MeasureUtils.sqrt(sumAA);
		final double dStdB = MeasureUtils.sqrt(sumBB);
		final double denom = MeasureUtils.sqrt(dStdA * dStdB);
		if (denom > 0) {
			return dCov / denom;
		}
		return 0.0;
	}

	/**
	 * Computes the NaN-skipping means of the pairwise distance matrix |x_i - x_j|.
	 *
	 * @param x sample; {@code NaN} marks a missing value
	 * @param rowMeans output, same length as {@code x}. Receives the mean of the available
	 *        distances of each row, or {@code NaN} when a row has no available distance.
	 * @return the mean of all available distances ({@code NaN} when none is available)
	 */
	private static double distanceMeans(final double[] x, final double[] rowMeans) {
		double total = 0.0;
		long available = 0; // long: x.length^2 overflows int for large samples
		for (int i = 0; i < x.length; i++) {
			double rowSum = 0.0;
			int rowAvailable = 0;
			for (int j = 0; j < x.length; j++) {
				final double d = distance(x, i, j);
				if (!Double.isNaN(d)) {
					total += d;
					available++;
					rowSum += d;
					rowAvailable++;
				}
			}
			rowMeans[i] = rowSum / rowAvailable;
		}
		return total / available;
	}

	/**
	 * Absolute difference {@code |x[i] - x[j]|}; {@code NaN} when either value is {@code NaN}.
	 */
	private static double distance(final double[] x, final int i, final int j) {
		return Math.abs(x[i] - x[j]);
	}

	/**
	 * Ad-hoc comparison of the two estimators on a small sample, with and without missing values.
	 */
	public static void main(String[] args){
		final double[] x = {5.463489, 14.579548, 1.934457,  53.518047, 3.400016};
		final double[] y = {1.225586, 1.225586, 18.829184, 12.167940, 11.560835};
		final boolean[] computable = new boolean[x.length];
		double ineff,eff;
		java.util.Arrays.fill(computable, true);
		System.out.println("Without NAN");
		System.out.println("\tInefficient: "+(ineff=doublyCenteredImplementation(x, y)));
		System.out.println("\tEfficient: "+(eff=efficientImplementation(x, y, computable)));
		System.out.println("\tError: "+Math.abs(ineff-eff));

		System.out.println();
		x[0]=Double.NaN;
		computable[0]=false;

		System.out.println("With NAN in x");
		System.out.println("\tInefficient: "+(ineff=doublyCenteredImplementation(x, y)));
		System.out.println("\tEfficient: "+(eff=efficientImplementation(x, y, computable)));
		System.out.println("\tError: "+Math.abs(ineff-eff));

		System.out.println();
		x[0]=5.463489;
		y[0]=Double.NaN;

		System.out.println("With NAN in y");
		System.out.println("\tInefficient: "+(ineff=doublyCenteredImplementation(x, y)));
		System.out.println("\tEfficient: "+(eff=efficientImplementation(x, y, computable)));
		System.out.println("\tError: "+Math.abs(ineff-eff));

		System.out.println();
		x[0]=Double.NaN;

		System.out.println("With NAN in x and y at the same position");
		System.out.println("\tInefficient: "+(ineff=doublyCenteredImplementation(x, y)));
		System.out.println("\tEfficient: "+(eff=efficientImplementation(x, y, computable)));
		System.out.println("\tError: "+Math.abs(ineff-eff));

		System.out.println();
		x[0]=5.463489;
		x[1]=Double.NaN;
		computable[1]=false;

		System.out.println("With NAN in x and y at different positions");
		System.out.println("\tInefficient: "+(ineff=doublyCenteredImplementation(x, y)));
		System.out.println("\tEfficient: "+(eff=efficientImplementation(x, y, computable)));
		System.out.println("\tError: "+Math.abs(ineff-eff));
	}

	/**
	 * Experimental closed-form estimator. It is only used by {@link #main} for comparison
	 * and is not called by {@link #measure}.
	 * <p>
	 * It avoids the double-centered matrices by using the expansion
	 * {@code sum(A_ij*B_ij) = sum(a_ij*b_ij) + n^2*mean(a)*mean(b) - 2n*sum_i(ma_i*mb_i)},
	 * where {@code ma_i} and {@code mb_i} are row means. The variances use the analogous
	 * expansion. Each pairwise distance is accumulated once ({@code j > i}) and the sums are
	 * then doubled.
	 * <p>
	 * For complete data this is algebraically equivalent to
	 * {@link #doublyCenteredImplementation}. With missing values it is a different
	 * estimator:
	 * <ul>
	 * <li>row means are divided by the number of non-NaN values of the whole sample;</li>
	 * <li>the covariance term uses the number of computable indices.</li>
	 * </ul>
	 *
	 * @param computable per-index usability mask. {@code null} means every index is used and
	 *        no NaN check is performed.
	 */
	private static double efficientImplementation(double[] a, double[] b, boolean[] computable) {
		int i,j;
		int countAB = a.length;
		final int countA,countB;
		{
			if(computable == null) countA = countB = countAB;
			else{
				int count = a.length;
				for(double x:a)
					if(Double.isNaN(x))
						count--;
				countA = count;
				count = b.length;
				for(double x:b)
					if(Double.isNaN(x))
						count--;
				countB = count;
			}
		}

		double daa = 0.0,dbb = 0.0,dab = 0.0;
		double ma2 = 0.0,mb2 = 0.0,mab = 0.0;
		double daMean=0.0,dbMean=0.0;
		final double[] ma = new double[a.length];
		final double[] mb = new double[ma.length];

		double da,db;
		if(computable != null){
			for(i=0;i<a.length;i++){
				if(computable[i]){//can use a[i] and b[i]
					for(j=i+1;j<a.length;j++){
						if(computable[j]){//can use a[j] and b[j]
							da		= a[i]-a[j];
							if(da < 0.0) da = -da;
							ma[i]	+= da;
							ma[j]	+= da;
							daa		+= da*da;
							db		=  b[i]-b[j];
							if(db < 0.0) db = -db;
							mb[i]	+= db;
							mb[j]	+= db;
							dbb		+= db*db;
							dab		+= da*db;
						}else if(!Double.isNaN(a[j])){//can only use a[j]
							da		= a[i]-a[j];
							if(da < 0.0) da = -da;
							ma[i]	+= da;
							ma[j]	+= da;
							daa		+= da*da;
						}else if(!Double.isNaN(b[j])){//can only use b[j]
							db		= b[i]-b[j];
							if(db < 0.0) db = -db;
							mb[i]	+= db;
							mb[j]	+= db;
							dbb		+= db*db;
						}
					}
					// ma[i]/mb[i] are complete here: contributions from j < i were added
					// while processing earlier rows.
					ma[i]	/= countA;
					mb[i]	/= countB;
					daMean	+= ma[i];
					dbMean	+= mb[i];
					ma2		+= ma[i]*ma[i];
					mb2		+= mb[i]*mb[i];
					mab 	+= ma[i]*mb[i];
				}else{
					--countAB;
					if(!Double.isNaN(a[i])){//can only use a[i]
						for(j=i+1;j<a.length;j++){
							if(!Double.isNaN(a[j])){//can only use a[j]
								da		= a[i]-a[j];
								if(da < 0.0) da = -da;
								ma[i]	+= da;
								ma[j]	+= da;
								daa		+= da*da;
							}
						}
						ma[i]	/= countA;
						daMean	+= ma[i];
						ma2		+= ma[i]*ma[i];
					}else if(!Double.isNaN(b[i])){//can only use b[i]
						for(j=i+1;j<b.length;j++){
							if(!Double.isNaN(b[j])){//can only use b[j]
								db		= b[i]-b[j];
								if(db < 0.0) db = -db;
								mb[i]	+= db;
								mb[j]	+= db;
								dbb		+= db*db;
							}
						}
						mb[i]	/= countB;
						dbMean	+= mb[i];
						mb2		+= mb[i]*mb[i];
					}
				}
			}
		}else{
			for(i=0;i<a.length;i++){
				for(j=i+1;j<a.length;j++){
					da		= a[i]-a[j];
					if(da < 0.0) da = -da;
					ma[i]	+= da;
					ma[j]	+= da;
					daa		+= da*da;
					db		=  b[i]-b[j];
					if(db < 0.0) db = -db;
					mb[i]	+= db;
					mb[j]	+= db;
					dbb		+= db*db;
					dab		+= da*db;
				}
				ma[i]	/= countA;
				mb[i]	/= countB;
				daMean	+= ma[i];
				dbMean	+= mb[i];
				ma2		+= ma[i]*ma[i];
				mb2		+= mb[i]*mb[i];
				mab 	+= ma[i]*mb[i];
			}
		}
		// Only j > i was visited, and the diagonal distances are zero: double every sum.
		ma2 *= 2.0;
		mb2 *= 2.0;
		mab *= 2.0;

		daa *= 2;
		dbb *= 2;
		dab *= 2;

		daMean /= countA;
		dbMean /= countB;

		final double dCov	= MeasureUtils.sqrt(dab + countAB  * (daMean * dbMean * countAB - mab));
		final double dVarA	= MeasureUtils.sqrt(daa + countA  * (daMean * daMean * countA  - ma2));
		final double dVarB	= MeasureUtils.sqrt(dbb + countB  * (dbMean * dbMean * countB  - mb2));

		final double denom = MeasureUtils.sqrt(dVarA * dVarB);
		if(denom > 0)
			return dCov/denom;
		return 0.0;
	}

	/** @return {@code false}: no native implementation is provided yet. */
	@Override
	public boolean hasNativeImplementation() {
		return false;//TODO implement it
	}

	/** @return {@code false}: the input data is used as-is. */
	@Override
	public boolean requireDataTransformation() {
		return false;
	}

	/** No-op, since {@link #requireDataTransformation()} returns {@code false}. */
	@Override
	public void transform(double[][] data) {
		// nothing to do
	}

}
