## This file is part of mlpy.
## Imputing.
    
## This code is written by Davide Albanese, <albanese@fbk.eu>.
## (C) 2009 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.

## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Public API of this module: only these names are exported by a star-import;
# the distance helpers and knn_core are not listed here.
__all__ = ['purify', 'knn_imputing']


import numpy as np


def purify(x, th0=0.1, th1=0.1):
    """
    Return the matrix x without rows and cols
    containing respectively more than
    th0 * x.shape[1] and th1 * x.shape[0] NaNs.

    Both the row and the column NaN fractions are computed on the
    original matrix x, so dropping a row does not change which
    columns are kept (and vice versa).

    :Parameters:
      x : 2d ndarray float
        input data
      th0 : float
        maximum fraction of NaNs allowed in a row
      th1 : float
        maximum fraction of NaNs allowed in a column

    :Returns:

      (xout, v0, v1) : (2d ndarray, 1d ndarray int, 1d ndarray int)
                     xout is a new array (x is not modified),
                     v0 are the valid index at dimension 0 and
                     v1 are the valid index at dimension 1

    Example:

    >>> import numpy as np
    >>> import mlpy
    >>> x = np.array([[1,      4,      4     ],
    ...               [2,      9,      np.nan],
    ...               [2,      5,      8     ],
    ...               [8,      np.nan, np.nan],
    ...               [np.nan, 4,      4     ]])
    >>> x, v0, v1 = mlpy.purify(x, 0.4, 0.4)
    >>> x
    array([[ 1.,  4.,  4.],
           [ 2.,  9., nan],
           [ 2.,  5.,  8.],
           [nan,  4.,  4.]])
    >>> v0
    array([0, 1, 2, 4])
    >>> v1
    array([0, 1, 2])
    """

    nrows, ncols = x.shape
    nan0, nan1 = np.where(np.isnan(x))

    if nan0.shape[0] == 0:
        # no missing values: every row and column is kept
        valid0 = np.arange(nrows)
        valid1 = np.arange(ncols)
    else:
        # fraction of NaNs per row / per column; minlength yields one entry
        # per row (column) even when the trailing ones contain no NaN.
        # Here nrows, ncols > 0 because at least one NaN exists.
        frac0 = np.bincount(nan0, minlength=nrows) / float(ncols)
        frac1 = np.bincount(nan1, minlength=ncols) / float(nrows)
        valid0 = np.where(frac0 <= th0)[0]
        valid1 = np.where(frac1 <= th1)[0]

    # np.ix_ selects the rows x cols sub-matrix; advanced indexing
    # always returns a new array
    xout = x[np.ix_(valid0, valid1)]

    return xout, valid0, valid1


def euclidean_distance(x1, x2):
    """
    Euclidean distance ignoring missing coordinates.

    Coordinates where x1 or x2 is NaN are skipped; the result is the
    2-norm of the remaining component-wise differences between
    x1=(x1_1, x1_2, ..., x1_n) and x2=(x2_1, x2_2, ..., x2_n).

    :Returns:
      d : float
        the distance, or np.inf when no coordinate is observed
        in both points
    """

    diff = x1 - x2
    kept = diff[~np.isnan(diff)]

    if kept.shape[0] == 0:
        # the points share no observed coordinate
        return np.inf

    return np.linalg.norm(kept)


def euclidean_squared_distance(x1, x2):
    """
    Squared Euclidean distance.

    Compute the squared Euclidean distance between points
    x1=(x1_1, x1_2, ..., x1_n) and x2=(x2_1, x2_2, ..., x2_n),
    ignoring the coordinates where either point is NaN.

    :Returns:
      d : float
        sum of the squared differences over the coordinates observed
        in both points, or np.inf when there is no such coordinate
    """

    d = np.asarray(x1, dtype=float) - np.asarray(x2, dtype=float)
    du = d[np.logical_not(np.isnan(d))]

    if du.shape[0] != 0:
        # sum the squares directly instead of squaring the 2-norm: the
        # sqrt/square round trip adds rounding error (sqrt(5)**2 != 5.0)
        return np.dot(du, du)
    else:
        return np.inf
   

def knn_core(x, k, dist='se', method='mean'):
    """
    Impute the NaN entries of x with a k-nearest-neighbor estimate.

    For every row i containing NaNs, the distance from row i to every
    row of x is computed over the coordinates observed in both rows.
    Each missing entry x[i, j] is then replaced by ``method`` applied
    to the column-j values of the k nearest rows having column j
    observed. Distances always use the original x, never values
    imputed earlier in the same call.

    :Parameters:
      x : 2d ndarray float (samples x feats)
        data to impute (not modified)
      k : integer
        number of nearest neighbors
      dist : string ('se' = SQUARED EUCLIDEAN, 'e' = EUCLIDEAN)
        adopted distance
      method : string ('mean', 'median')
        statistic combining the neighbor values

    :Returns:
      xout : 2d ndarray float (samples x feats)
        a copy of x with the missing values imputed; an entry stays
        NaN when no row has the corresponding column observed

    :Raises:
      ValueError : if dist or method is not valid
    """

    if dist == 'se':
        distfunc = euclidean_squared_distance
    elif dist == 'e':
        distfunc = euclidean_distance
    else:
        raise ValueError("dist %s is not valid" % dist)

    if method == 'mean':
        methodfunc = np.mean
    elif method == 'median':
        methodfunc = np.median
    else:
        raise ValueError("method %s is not valid" % method)

    xout = x.copy()
    observed = np.logical_not(np.isnan(x))
    distance = np.empty(x.shape[0], dtype=float)

    # only rows containing at least one NaN need imputation
    for i in np.where(np.logical_not(observed.all(axis=1)))[0]:
        for s in range(x.shape[0]):
            distance[s] = distfunc(x[i], x[s])
        # sort once per row; a stable sort breaks distance ties by row index
        idxsort = np.argsort(distance, kind='stable')

        for j in np.where(np.logical_not(observed[i]))[0]:
            # donors: rows ordered by distance that observe column j
            # (row i itself is excluded because x[i, j] is NaN)
            donors = idxsort[observed[idxsort, j]][:k]
            if donors.shape[0] > 0:
                xout[i, j] = methodfunc(x[donors, j])
            # else: no row observes column j, the entry is left NaN

    return xout
    

def knn_imputing(x, k, dist='e', method='mean', y=None, ldep=False):
    """
    Knn imputing

    Replace each NaN entry of x with the mean or median of the
    corresponding feature over the k nearest samples that have
    that feature observed. Distances between samples are computed
    only over the features observed in both samples.

    :Parameters:
      x : 2d ndarray float (samples x feats)
        data to impute (not modified)
      k : integer
        number of nearest neighbor
      dist : string ('se' = SQUARED EUCLIDEAN, 'e' = EUCLIDEAN)
           adopted distance
      method : string ('mean', 'median')
             method to compute the missing values
      y : 1d ndarray
        labels, one per sample (used only when ldep is True)
      ldep : bool
           label dependent: if True and y is not None, each class is
           imputed separately, using only the samples of that class
           as neighbors

    :Returns:
      xout : 2d ndarray float (samples x feats)
           data imputed (a new array); an entry stays NaN when no
           candidate neighbor has that feature observed

    :Raises:
      ValueError : if dist or method is not valid

    >>> import numpy as np
    >>> import mlpy
    >>> x = np.array([[1,      4,      4     ],
    ...               [2,      9,      np.nan],
    ...               [2,      5,      8     ],
    ...               [8,      np.nan, np.nan],
    ...               [np.nan, 4,      4     ]])
    >>> y = np.array([1, -1, 1, -1, -1])
    >>> x, v0, v1 = mlpy.purify(x, 0.4, 0.4)
    >>> x
    array([[ 1.,  4.,  4.],
           [ 2.,  9., nan],
           [ 2.,  5.,  8.],
           [nan,  4.,  4.]])
    >>> v0
    array([0, 1, 2, 4])
    >>> v1
    array([0, 1, 2])
    >>> y = y[v0]
    >>> x = mlpy.knn_imputing(x, 2, dist='e', method='median')
    >>> x
    array([[1. , 4. , 4. ],
           [2. , 9. , 6. ],
           [2. , 5. , 8. ],
           [1.5, 4. , 4. ]])
    """

    # `y is not None` (not `y != None`): on an ndarray `!=` compares
    # elementwise, and `ldep and <array>` would raise ValueError
    if ldep and y is not None:
        y = np.asarray(y)
        xout = x.copy()

        for c in np.unique(y):
            mask = (y == c)
            xout[mask, :] = knn_core(x=x[mask], k=k, dist=dist, method=method)
    else:
        xout = knn_core(x=x, k=k, dist=dist, method=method)

    return xout
