#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Dec  8 14:21:40 2020

@author: gilleschardon
"""

import numpy as np
import matplotlib.pyplot as plt

from scipy.io import loadmat
import damas


# Demo measurement file; the entries 'Data', 'Pmic' and 'k' are read from it
# further down.
mat = loadmat(file_name="damasdemo.mat")

# dictionary of sources
def dictionary(PX, PS, k):
    dx = PX[:, 0:1] - PS[:, 0:1].T
    dy = PX[:, 1:2] - PS[:, 1:2].T
    dz = PX[:, 2:3] - PS[:, 2:3].T

    d = np.sqrt(dx**2 + dy**2 + dz**2);
    
    D = np.exp( -1j * k * d) / d
    
    return D

Data = mat['Data']

# We keep the N microphones closest to the origin in the (x, y) plane (no
# fundamental reason, just to have an array of reasonable size).
N = 64
Pmic = mat['Pmic']
Norms = Pmic[:, 0]**2 + Pmic[:, 1]**2
idx = np.argsort(Norms)[:N]

Pmic = Pmic[idx, :]

# Restrict Data to the selected microphones along both axes.
Data = Data[np.ix_(idx, idx)]

# Source grid: Lx * Ly points. The endpoint of each linspace is dropped, so
# the grid covers [-2, 1) in x and [-0.75, -0.25) in y.
Lx = 90
Ly = 30
xx = np.linspace(-2, 1, Lx+1)[:-1]
yy = np.linspace(-0.75, -0.25, Ly+1)[:-1]

# Plotting grid (cell edges). pcolor paints C[i, j] on the cell bounded by
# edges j and j+1. The edges are therefore placed half a step on either side
# of each source point, so every cell is centred on its grid point. Using the
# source coordinates themselves as edges would draw each value half a cell
# up and to the right of where the source actually is.
step_x = xx[1] - xx[0]
step_y = yy[1] - yy[0]
xxp = np.append(xx, xx[-1] + step_x) - step_x / 2
yyp = np.append(yy, yy[-1] + step_y) - step_y / 2

Xg, Yg = np.meshgrid(xx, yy)  # shape (Ly, Lx)
Xp, Yp = np.meshgrid(xxp, yyp)  # shape (Ly + 1, Lx + 1)

# z-coordinate shared by all source grid points
Z = 4.3

# Dictionary: one column per grid point, in Xg.ravel() order, with all
# sources at z = Z.
D = dictionary(Pmic,
               np.column_stack([Xg.ravel(), Yg.ravel(), np.full(Xg.size, Z)]),
               mat['k'])


# Beamforming: for each grid point j, real(d_j^H Data d_j) with d_j = D[:, j].
bf_map = np.real(np.sum((D.conj().T @ Data) * D.T, 1))

# Tolerance forwarded to the damas solvers (see damas for its exact meaning).
tol = 100
# CMF-NNLS
cmf_map = damas.cmf_nnls(D, Data, tol)
# CMF-NNLS with diagonal removal
cmf_map_dr = damas.cmf_nnls_dr(D, Data, tol)
# CMF-NNLS with noise estimation: appending the identity adds one extra
# column per microphone. The entries of cmf_map_noise past Xg.size belong to
# those columns and are plotted below as the noise.
Dnoise = np.hstack([D, np.eye(D.shape[0])])
cmf_map_noise = damas.cmf_nnls(Dnoise, Data, tol)


#%%
# Each map is reshaped back onto the (Ly, Lx) source grid for display;
# Xg.shape is used rather than literal sizes so the plots follow Lx and Ly.
maps = [
    ('Beamforming', bf_map),
    ('CMF-NNLS', cmf_map),
    ('CMF-NNLS diagonal removal', cmf_map_dr),
    ('CMF-NNLS noise estimation', cmf_map_noise[:Xg.size]),
]
for title, source_map in maps:
    plt.figure()
    plt.pcolor(Xp, Yp, np.reshape(source_map, Xg.shape), cmap='hot')
    plt.axis('image')
    plt.title(title)

plt.figure()
plt.stem(cmf_map_noise[Xg.size:])
plt.title('CMF-NNLS noise estimation - noise')

# Without this call, running the file as a plain script exits without ever
# displaying the figures.
plt.show()