#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec  8 14:36:35 2020

@author: gilleschardon
"""

import numpy as np
import scipy.linalg as la

# Fast matrix products: apply the DAMAS-style operator and its adjoint
# without ever forming the n-by-n operator matrix.

def proddamastranspose(D, Data):
    """Apply the adjoint operator to a matrix.

    Returns the real vector whose j-th entry is Re(d_j^H @ Data @ d_j),
    where d_j is the j-th column of D.
    """
    # row j of `projected` is d_j^H @ Data
    projected = D.conj().T @ Data
    # multiplying row j entrywise by d_j^T and summing gives d_j^H Data d_j
    quad_forms = (projected * D.T).sum(axis=1)
    return quad_forms.real

def proddamas(D, x):
    """Apply the normal operator to a vector of source powers x.

    Builds the model matrix D @ diag(x) @ D^H, then maps it back with
    proddamastranspose, i.e. entry i is Re(d_i^H D diag(x) D^H d_i).
    """
    # column j of conj(D), scaled by x[j]
    scaled_conj = x * D.conj()
    model_scm = D @ scaled_conj.T
    return proddamastranspose(D, model_scm)
    
def proddamasdr(D, x):
    """Diagonal-removal variant of proddamas.

    Builds the model matrix D @ diag(x) @ D^H, subtracts its diagonal, and
    maps the result back with proddamastranspose.
    """
    model_scm = D @ (x * D.conj()).T
    diagonal_part = np.diag(np.diag(model_scm))
    return proddamastranspose(D, model_scm - diagonal_part)
    
# unconstrained least-squares problem
def solve_ls(D, Data):
    """Solve the unconstrained least-squares normal equations.

    The system matrix has entries |d_i^H d_j|^2 and the right-hand side is
    proddamastranspose(D, Data). The matrix is passed to SciPy as positive
    definite (assume_a='pos').
    """
    cross = D.conj().T @ D
    gram = np.abs(cross) ** 2
    rhs = proddamastranspose(D, Data)
    return la.solve(gram, rhs, assume_a='pos')

# unconstrained least-squares problem for diagonal removal
def solve_ls_dr(D, Data):
    """Solve the unconstrained least-squares normal equations with diagonal removal.

    The system matrix is |d_i^H d_j|^2 minus sum_k |D_ki|^2 |D_kj|^2, i.e.
    the contribution of the model's diagonal is subtracted. The right-hand
    side uses Data as given; the caller is expected to have zeroed its
    diagonal (cmf_nnls_dr does).
    """
    sq_mod = np.abs(D ** 2)  # entrywise |D_ki|^2
    full_gram = np.abs(D.conj().T @ D) ** 2
    diag_contrib = sq_mod.T @ sq_mod
    gram = full_gram - diag_contrib
    rhs = proddamastranspose(D, Data)
    return la.solve(gram, rhs, assume_a='pos')


def cmf_nnls(D, Data, epsilon, *, verbose=True, maxiter=None):
    """Non-negative CMF source powers via the Lawson-Hanson active-set method.

    Fits the model D @ diag(x) @ D^H to Data in the least-squares sense,
    subject to x >= 0.

    Parameters
    ----------
    D : ndarray, shape (m, n)
        Dictionary of sources, one column per source.
    Data : ndarray, shape (m, m)
        Spatial covariance matrix (SCM).
    epsilon : float
        Tolerance. The outer loop stops when no active-set entry of the
        negative gradient exceeds epsilon. Passive coefficients that fall to
        epsilon or below are moved back to the active set.
    verbose : bool, keyword-only
        If True (default, matching the historical behaviour), print the
        iteration number and current tolerance at each outer iteration.
    maxiter : int or None, keyword-only
        Maximum number of outer iterations. None (default) means no limit.

    Returns
    -------
    x : ndarray, shape (n,)
        Non-negative coefficients (zero on the final active set).
    """
    n = D.shape[1]
    N = np.arange(n)
    R = np.ones(n, dtype=bool)  # active set: coefficients held at zero
    x = np.zeros(n)

    def passive_ls(active):
        # Unconstrained LS restricted to the passive set (~active), zero
        # elsewhere. An empty passive set never reaches the solver.
        sol = np.zeros(n)
        passive = ~active
        if np.any(passive):
            sol[passive] = solve_ls(D[:, passive], Data)
        return sol

    # w = A^T (y - A x): negative gradient of the LS objective
    Ay = proddamastranspose(D, Data)
    w = Ay - proddamas(D, x)

    it = 0
    while np.any(R) and np.max(w[R]) > epsilon:
        if maxiter is not None and it >= maxiter:
            break
        if verbose:
            print(f"iter {it} tol {np.max(w[R])}")
        it += 1

        # move the active coefficient with the largest gradient to the passive set
        idx = N[R][np.argmax(w[R])]
        R[idx] = False

        s = passive_ls(R)

        # while the LS solution is infeasible, step from x toward s and drop
        # the coefficients that hit zero; the passive-set guard prevents
        # np.min on an empty array once every coefficient has been dropped
        while np.any(~R) and np.min(s[~R]) <= 0:
            Q = N[(s <= 0) & ~R]
            denom = x[Q] - s[Q]
            # x >= 0 and s <= 0 on Q, so denom >= 0. denom == 0 only when
            # x == s == 0; the admissible step is then 0 (avoids 0/0 -> NaN).
            ratios = np.divide(x[Q], denom, out=np.zeros_like(denom),
                               where=denom > 0)
            k = np.argmin(ratios)
            x = x + ratios[k] * (s - x)
            # the blocking coefficient lands on the bound; pin it to exactly 0
            # so rounding cannot leave it above epsilon and stall this loop
            x[Q[k]] = 0.0
            R = R | (x <= epsilon)
            s = passive_ls(R)

        x = s
        w = Ay - proddamas(D, x)

    return x


def cmf_nnls_dr(D, Data, epsilon, *, verbose=False, maxiter=None):
    """Non-negative CMF with diagonal removal (Lawson-Hanson active set).

    Same algorithm as cmf_nnls, but the diagonal of the SCM is discarded:
    the diagonal of Data is zeroed, and the model D @ diag(x) @ D^H is fitted
    on off-diagonal entries only (via proddamasdr and solve_ls_dr).

    Parameters
    ----------
    D : ndarray, shape (m, n)
        Dictionary of sources, one column per source.
    Data : ndarray, shape (m, m)
        Spatial covariance matrix (SCM). It is not modified.
    epsilon : float
        Tolerance. The outer loop stops when no active-set entry of the
        negative gradient exceeds epsilon. Passive coefficients that fall to
        epsilon or below are moved back to the active set.
    verbose : bool, keyword-only
        If True, print the iteration number and current tolerance at each
        outer iteration. Default False (the historical, silent behaviour).
    maxiter : int or None, keyword-only
        Maximum number of outer iterations. None (default) means no limit.

    Returns
    -------
    x : ndarray, shape (n,)
        Non-negative coefficients (zero on the final active set).
    """
    # discard the SCM diagonal (builds a new array; the caller's Data is untouched)
    Data = Data - np.diag(np.diag(Data))
    n = D.shape[1]
    N = np.arange(n)
    R = np.ones(n, dtype=bool)  # active set: coefficients held at zero
    x = np.zeros(n)

    def passive_ls(active):
        # Unconstrained diagonal-removed LS on the passive set (~active), zero
        # elsewhere. An empty passive set never reaches the solver.
        sol = np.zeros(n)
        passive = ~active
        if np.any(passive):
            sol[passive] = solve_ls_dr(D[:, passive], Data)
        return sol

    # w = A^T (y - A x) for the diagonal-removed operator
    Ay = proddamastranspose(D, Data)
    w = Ay - proddamasdr(D, x)

    it = 0
    while np.any(R) and np.max(w[R]) > epsilon:
        if maxiter is not None and it >= maxiter:
            break
        if verbose:
            print(f"iter {it} tol {np.max(w[R])}")
        it += 1

        # move the active coefficient with the largest gradient to the passive set
        idx = N[R][np.argmax(w[R])]
        R[idx] = False

        s = passive_ls(R)

        # while the LS solution is infeasible, step from x toward s and drop
        # the coefficients that hit zero; the passive-set guard prevents
        # np.min on an empty array once every coefficient has been dropped
        while np.any(~R) and np.min(s[~R]) <= 0:
            Q = N[(s <= 0) & ~R]
            denom = x[Q] - s[Q]
            # denom == 0 only when x == s == 0; the step is then 0 (no 0/0 NaN)
            ratios = np.divide(x[Q], denom, out=np.zeros_like(denom),
                               where=denom > 0)
            k = np.argmin(ratios)
            x = x + ratios[k] * (s - x)
            # pin the blocking coefficient to exactly 0 so rounding cannot
            # leave it above epsilon and stall this loop
            x[Q[k]] = 0.0
            R = R | (x <= epsilon)
            s = passive_ls(R)

        x = s
        w = Ay - proddamasdr(D, x)

    return x