package be.ac.ulb.mlg.utils;

/*
 * The MIT License (MIT)
 * 
 * Copyright (c) 2013 Jean-Sebastien Lerat (Jean-Sebastien.Lerat@ulb.ac.be)
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

/**
 * 
 * @author Jean-Sebastien Lerat (Jean-Sebastien.Lerat@ulb.ac.be)
 * @version 1.00, 24/06/2013
 */

import java.nio.DoubleBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;


/**
 * DefaultRenormalizer supports custom or column-sum renormalization with features (rows to be ignored)
 * and can handle non-overlapping groups (division of input vectors into subsets where elements of each subset
 * are normalized only with others from the same group).
 */
public class DefaultRenormalizer implements Renormalizer{
	/**
	 * Sentinel value meaning that the renormalizer uses the sum function on columns
	 */
	public static final double SUM_NORMILIZER = 0.0;
	// group index of each row, or null when group renormalization is disabled
	private int[] groups;
	// current divisor per group and per column (sum[group][column])
	private double[][] sum;
	
	private int groupCount;
	private boolean shufflePair;
	// group label -> group index, or null when group renormalization is disabled
	private Map<String,Integer> groupNames;
	// custom divisors per group; an entry equal to SUM_NORMILIZER means "use the column sum"
	private List<double[]> specifiRenorms;
	// indexes of the rows that are excluded from the renormalization
	private List<Integer> features;
	
	/**
	 * Construct a new DefaultRenormalizer without groups, without features and with sum normalization.
	 * Shuffles are disabled.
	 */
	public DefaultRenormalizer(){
		features		= new ArrayList<Integer>();
		specifiRenorms	= new ArrayList<double[]>();
		shufflePair		= false;
		noGroups();
		noFeatures();
		resetNormalizer();
	}
	/**
	 * Register a row as a feature, i.e. a row that is ignored by the renormalization
	 * @param dataIndex The index of the input data vector to ignore
	 */
	public void addFeature(int dataIndex){
		this.features.add(Integer.valueOf(dataIndex));
	}
	/**
	 * Know if the configuration of the DefaultRenormalizer uses the group renormalization
	 * @return true if the DefaultRenormalizer uses the group renormalization
	 */
	public boolean isUsingGroups(){
		return groups != null;
	}
	/**
	 * Specifies to no longer use the group renormalization.
	 * The custom renormalization values of the first group are preserved;
	 * to reset them, use resetNormalizer()
	 */
	public void noGroups(){
		if(groupNames != null) groupNames.clear();
		groupNames	= null;
		groupCount	= 1;
		groups		= null;
		// keep only the custom values of group 0, which becomes the single implicit group
		final double[] backup = specifiRenorms.size()>0 ? specifiRenorms.get(0) : new double[0];
		specifiRenorms.clear();
		specifiRenorms.add(backup);
	}
	/**
	 * Reset the custom renormalizer to use the default method (sum by columns)
	 */
	public void resetNormalizer(){
		specifiRenorms.clear();
		for(int i=0;i<groupCount;i++)
			specifiRenorms.add(new double[0]);
	}
	/**
	 * Know if the configuration of the DefaultRenormalizer uses the feature renormalization.
	 * As a side effect, the registered feature indexes are sorted in ascending order.
	 * @return true if the DefaultRenormalizer uses the feature renormalization
	 */
	public boolean isUsingFeatures(){
		Collections.sort(this.features);
		return !features.isEmpty();
	}
	/**
	 * Know if the configuration of the DefaultRenormalizer uses custom renormalization values
	 * @return true if at least one custom value differs from {@link #SUM_NORMILIZER}
	 */
	public boolean isUsingCustomNormalizer(){
		for(int i=0;i<specifiRenorms.size();i++){
			final double[] ref = specifiRenorms.get(i);
			if(ref != null){
				for(final double normlizer:ref)
					if(normlizer != SUM_NORMILIZER)
						return true;
			}
		}
		return false;
	}
	/**
	 * Specifies to no longer use the features renormalization
	 */
	public void noFeatures(){
		features.clear();
	}
	/**
	 * Define the group for one input data vector.
	 * A null group name is silently ignored (the row keeps its current group).
	 * @param row The index of the input data vector (used with {@link Measurer})
	 * @param groupName The name of the group of the input data vector 
	 * @throws NullPointerException if groups have not been defined or if the name is unknown
	 */
	public void setGroup(int row,final String groupName){
		if(groupName != null)
			setGroup(row, groupNames.get(groupName).intValue());
		//else null => feature => no group
	}
	/**
	 * Define the group for one input data vector.
	 * A negative group index is silently ignored (the row keeps its current group).
	 * @param row The index of the input data vector (used with {@link Measurer})
	 * @param group The index of the group of the input data vector 
	 */
	public void setGroup(final int row,final int group){
		if(group >= 0)
			this.groups[row] = group;
		//else '-1' => feature => no group
	}
	/**
	 * Define groups and associate them a label.
	 * Every row is initially assigned to the group of index 0.
	 * @param rows The number of input data vectors to be processed
	 * @param names The name of all existing groups
	 */
	public void setGroupsNames(final int rows,final String[] names){
		groupNames = new HashMap<String,Integer>();
		for(int group=0;group<names.length;group++)
			groupNames.put(names[group],Integer.valueOf(group));
		groupCount	= names.length;
		groups		= new int[rows];//we can add Arrays.fill(groups,0) but Java autofill array with 0
		while(specifiRenorms.size()<groupCount)
			specifiRenorms.add(new double[0]);
	}
	/**
	 * Define the group of all input data vectors by their indexes.
	 * Each group is labelled with the decimal representation of its index.
	 * @param groups The array mapping, for each index i, the input data vector i to the group groups[i] (the array is copied)
	 * @param groupCount The number of groups
	 */
	public void setGroups(final int[] groups,int groupCount){
		this.groupCount	= groupCount;
		this.groupNames	= new HashMap<String,Integer>();
		this.groups		= Arrays.copyOf(groups,groups.length);
		for(int group=0;group<groupCount;group++){
			final Integer index = Integer.valueOf(group);
			groupNames.put(index.toString(),index);
		}
		// make sure every group owns a (possibly empty) array of custom values
		while(specifiRenorms.size()<groupCount)
			specifiRenorms.add(new double[0]);
	}
	/**
	 * Define a specific value to use instead of the sum by column for groups
	 * @param groupName The name of the group
	 * @param column  The column index (for an input data vector v, it's i of v[i])
	 * @param divider The specific value to use as normalizer
	 * @throws NullPointerException if groups have not been defined or if the name is unknown
	 */
	public void setNormalizer(final String groupName,final int column,final double divider){
		setNormalizer(groupNames.get(groupName).intValue(),column,divider);
	}
	/**
	 * Define a specific value to use instead of the sum by column for groups
	 * @param group The index of the group
	 * @param column  The column index (for an input data vector v, it's i of v[i])
	 * @param divider The specific value to use as normalizer
	 */
	public void setNormalizer(final int group,final int column,final double divider){
		double[] data = specifiRenorms.get(group);
		if(column >= data.length)
			data = Arrays.copyOf(data, column+1);//Auto expanded with 0.0 <=> SUM_NORMILIZER
		data[column] = divider;
		specifiRenorms.set(group, data);
	}
	/**
	 * Define a specific value to use instead of the sum by column (applies to the group of index 0)
	 * @param column  The column index (for an input data vector v, it's i of v[i])
	 * @param divider The specific value to use as normalizer
	 */
	public void setNormalizer(int column,double divider){
		setNormalizer(0,column,divider);
	}
	/**
	 * Count the number of features
	 * @return The number of features
	 */
	public int countFeatures(){
		return features.size();
	}
	/**
	 * Count the number of groups
	 * @return The number of groups (1 when group renormalization is disabled)
	 */
	public int countGroups(){
		return groupCount;
	}
	/**
	 * Get the group of a row
	 * @param row The row index
	 * @return The group identifier, or 0 (the single implicit group) when group renormalization is disabled
	 */
	public int groupOf(int row){
		return groups != null ? groups[row] : 0;
	}
	/**
	 * Overwrite the current divisors with the custom values:
	 * every custom value different from {@link #SUM_NORMILIZER} replaces the divisor
	 * of the same group and column.
	 */
	protected void mergeRenormilizer(){
		int col;
		for(int group=0;group<sum.length;group++){
			final double[] specific = specifiRenorms.get(group);
			for(col=0;col<specific.length;col++){
				if(specific[col] != SUM_NORMILIZER)
					sum[group][col] = specific[col];
			}
		}
	}
	//Useful for the native implementation: exposes the row -> group mapping as a buffer
	@SuppressWarnings("unused")
	private IntBuffer jni_groupsHelper(){
		return JNIUtils.createIntBuffer(groups.length).put(groups);
	}
	//Useful for the native implementation: one buffer per group, holding the number
	//of custom values followed by the custom values themselves
	@SuppressWarnings("unused")
	private DoubleBuffer[] jni_specificHelper(){
		final DoubleBuffer[] array = new DoubleBuffer[specifiRenorms.size()];
		
		for(int i=0;i<array.length;i++){
			final double[] tmp = specifiRenorms.get(i);
			array[i] = JNIUtils.createDoubleBuffer(tmp.length+1).put((double)tmp.length).put(tmp);
		}
		
		return array;
	}
	//Useful for the native implementation: the feature indexes followed by
	//Integer.MAX_VALUE used as terminator
	@SuppressWarnings("unused")
	private IntBuffer jni_featuresHelper(){
		final IntBuffer buff = JNIUtils.createIntBuffer(features.size()+1);
		for(Integer i:features)
			buff.put(i.intValue());
		return buff.put(Integer.MAX_VALUE);
	}
	/**
	 * Normalize the input in place: each value is divided by the divisor of its column,
	 * which is the column sum over the rows of the same group (NaN values are not summed)
	 * unless a custom value has been defined for that group and column.
	 * Rows registered as features are neither summed nor divided.
	 * Once the input is normalized, the stored divisors are reset to 1.0 and the custom
	 * values are applied again.
	 * @param input The input data vectors (modified in place); all rows are expected to have the length of the first one
	 * @return The same array as the one given as parameter
	 */
	@Override
	public double[][] normalizeInput(double[][] input){
		if(input.length == 0){
			// nothing to normalize: keep a consistent (zero column) state instead of failing on input[0]
			sum = new double[groupCount][0];
			return input;
		}
		// compute the sums
		sum = new double[groupCount][input[0].length];
		//we can add Arrays.fill(sum,0) but Java autofill array with 0
		
		final boolean useGroups	= isUsingGroups();
		final boolean[] ignored	= featureMask(input.length);
		int row,col;
		
		for(row=0;row<input.length;row++)
			if(!ignored[row])
				sum(input[row],sum[useGroups?groups[row]:0]);
		
		mergeRenormilizer();
		
		// normalize
		for(row=0;row<input.length;row++){
			if(ignored[row]) continue;
			final double[] divisors = sum[useGroups?groups[row]:0];
			for(col=0;col<input[row].length;col++)
				input[row][col]/=divisors[col];
		}
		
		// reset the divisors (done whether or not pairwise shuffle is enabled)
		for(final double[] groupSum:sum)
			Arrays.fill(groupSum,1.0);
		
		mergeRenormilizer();
		return input;
	}
	/**
	 * Build the mask of the rows that must be ignored by the renormalization
	 * @param rows The number of rows of the input
	 * @return An array where the element i is true if the row i is a feature
	 */
	private boolean[] featureMask(final int rows){
		final boolean[] mask = new boolean[rows];
		if(isUsingFeatures()){
			for(final Integer feature:features){
				final int index = feature.intValue();
				// indexes outside of the input cannot match any row
				if(index >= 0 && index < rows)
					mask[index] = true;
			}
		}
		return mask;
	}
	/**
	 * Know if a row has been registered as a feature
	 * @param row The row index
	 * @return true if the row must be ignored by the renormalization
	 */
	private boolean isFeatureRow(final int row){
		return isUsingFeatures() && features.contains(Integer.valueOf(row));
	}
	/**
	 * Add (or remove) a vector to (from) the divisors of the group of its row
	 * @param vect The vector to take into account
	 * @param row The row index of the vector
	 * @param ignore true to subtract the vector, false to add it
	 */
	private void accumulate(final double[] vect,final int row,final boolean ignore){
		final double[] target = sum[isUsingGroups()?groups[row]:0];
		if(ignore) difference(vect,target);
		else sum(vect,target);
	}
	/**
	 * Update the divisors with a vector, then apply the custom values again.
	 * Nothing is done when pairwise shuffle is disabled or when the row is a feature.
	 * @param vect The vector to take into account
	 * @param row The row index of the vector
	 * @param ignore true to subtract the vector from the divisors, false to add it
	 */
	@Override
	public void updateRenormalization(double[] vect,final int row,boolean ignore){
		if(!shufflePair())return;
		if(isFeatureRow(row)){
			// no need to take into account because this element must be ignored
			return;
		}
		accumulate(vect,row,ignore);
		mergeRenormilizer();
	}
	/**
	 * Update the divisors with a vector without applying the custom values.
	 * Nothing is done when the row is a feature.
	 * @param vect The vector to take into account
	 * @param row The row index of the vector
	 * @param ignore true to subtract the vector from the divisors, false to add it
	 */
	protected void prepareRenormalization(double[] vect,final int row,boolean ignore){
		if(isFeatureRow(row)){
			// no need to take into account because this element must be ignored
			return;
		}
		accumulate(vect,row,ignore);
	}
	/**
	 * Normalize a pair of shuffled vectors in place: both vectors are added to the divisors,
	 * divided by them, then removed from the divisors (NaN values are not removed).
	 * A vector whose row is a feature is left untouched.
	 * Nothing is done when pairwise shuffle is disabled.
	 * @param shuffledFirst The first shuffled vector (modified in place)
	 * @param shuffledSecond The second shuffled vector (modified in place)
	 * @param first The row index of the first vector
	 * @param second The row index of the second vector
	 */
	@Override
	public void simulateRenormalization(double[] shuffledFirst,double[] shuffledSecond,final int first,final int second){
		if(!shufflePair())return;
		final boolean processFirst	= !isFeatureRow(first);
		final boolean processSecond	= !isFeatureRow(second);
		if(!processFirst && !processSecond)
			return;//do not process
		
		final boolean useGroups = isUsingGroups();
		final double[][] ptr;
		final int[] groupPtr;
		if(processFirst && processSecond){
			ptr			= new double[][]{shuffledFirst,shuffledSecond};
			groupPtr	= new int[]{useGroups?groups[first]:0,useGroups?groups[second]:0};
		}else if(processFirst){
			ptr			= new double[][]{shuffledFirst};
			groupPtr	= new int[]{useGroups?groups[first]:0};
		}else{
			// only the second vector is processed: it must be the one that is normalized
			ptr			= new double[][]{shuffledSecond};
			groupPtr	= new int[]{useGroups?groups[second]:0};
		}
		
		updateRenormalization(shuffledFirst,first,false);
		updateRenormalization(shuffledSecond,second,false);
		
		int col,vect;
		double[] tmp = new double[ptr.length];
		
		for(col=0;col<shuffledFirst.length;col++){
			// divide every processed vector before removing any of them from the divisors
			for(vect=0;vect<tmp.length;vect++){
				tmp[vect]		=  ptr[vect][col];
				ptr[vect][col]	/= sum[groupPtr[vect]][col];
			}
			for(vect=0;vect<tmp.length;vect++){
				if(! Double.isNaN(tmp[vect]))
					sum[groupPtr[vect]][col] -= tmp[vect];
			}
		}
		mergeRenormilizer();
	}
	/**
	 * The output is not renormalized
	 * @return The output given as parameter, unchanged
	 */
	@Override
	public double[][] normalizeOutput(final double[][] input,final double[][] output,Measure measure){
		//nothing to do
		return output;
	}
	/**
	 * Add each non-NaN value of the row to the element of the same column of the sum
	 */
	private static void sum(final double[] row,final double sum[]){
		for(int col=0;col<sum.length;col++)
			if(! Double.isNaN(row[col]))
				sum[col] += row[col];
	}
	/**
	 * Subtract each non-NaN value of the row from the element of the same column of the sum
	 */
	private static void difference(final double[] row,final double sum[]){
		for(int col=0;col<sum.length;col++)
			if(! Double.isNaN(row[col]))
				sum[col] -= row[col];
	}
	@Override
	public boolean hasNativeImplementation() {
		return true;
	}
	@Override
	public boolean processInput() {
		return true;
	}
	@Override
	public boolean processOutput() {
		return false;
	}
	/**
	 * Enable or disable pairwise shuffle
	 * @param shuffle true to enable, false to disable
	 */
	public void usePairShuffle(final boolean shuffle){
		this.shufflePair=shuffle;
	}
	@Override
	public boolean shufflePair() {
		return shufflePair;
	}
}
