addpath decompose/
addpath learning/
addpath utilities/
addpath FastICA_25/
addpath mptk_interface/
%%
% Set do_mptk to true (or 1) to also showcase the MPTK Gabor decomposition.
do_mptk = false;

% Number of dictionary-learning methods run by the learning loop below.
Nlearn_methods = 3;

% Total number of decomposition methods: the learned dictionaries plus,
% when requested, the MPTK Gabor decomposition.
Nmethods = Nlearn_methods;
if do_mptk
    Nmethods = Nmethods + 1;
end

% Create the synthetic signal: randomly placed waveforms plus white noise.

% Gaussian-windowed cosines of different lengths and periods.
slow_wave1 = gausswin(128) .* cos(2*pi*(0:127).'/50);  % defined but not referenced below
spindle1   = gausswin(32)  .* cos(2*pi*(0:31).'/10);
spindle2   = gausswin(128) .* cos(2*pi*(0:127).'/15);

% First entry is a two-column group (a descending negative ramp and an
% ascending positive ramp, both scaled to at most 1 in magnitude); the
% other two entries are single-column waveforms.
ramp_pair  = [-(100:-1:1).', (1:100).'] / 100;
true_waves = {ramp_pair, spindle1, spindle2};

signal_length = 80000;

% Number of columns in each waveform group, and the total column count.
group_dim = cellfun(@(w) size(w, 2), true_waves);
Nwaves    = sum(group_dim);

% Draw random events. Columns: a sample index in 1..signal_length, a
% waveform index in 1..Nwaves, and a constant 1 (presumably the amplitude
% -- confirm against flatAtomsToGroupedAtoms).
Nevents     = 400;
event_times = randi(signal_length, Nevents, 1);
event_waves = randi(Nwaves, Nevents, 1);
flat_atoms  = [event_times, event_waves, ones(Nevents, 1)];

atoms     = flatAtomsToGroupedAtoms(flat_atoms, group_dim);
Nkeep     = sum(cellfun(@(a) size(a, 1), atoms));
marked_pp = atomsToSources(atoms, group_dim, signal_length, Nkeep);
[~, components] = sourcesToSignalComponents(true_waves, marked_pp);

% Sum all per-waveform components into one signal, then add scaled
% Gaussian noise.
components_row = reshape(components, 1, []);
xall  = sum(cell2mat(components_row), 2);
noise = randn(size(xall));
xall  = xall + 0.2*noise;
%% Divide the signal into a training segment and a testing segment

% n_secs evenly spaced boundary indices spanning the whole signal.
n_secs  = 4;
nstarts = ceil(linspace(1, signal_length, n_secs));

% Training data: samples between the first two boundaries.
train_time = nstarts(1):nstarts(2);
xlearn     = xall(train_time);

% Testing data: samples between the last two boundaries. use_time keeps
% these indices because later sections reuse them.
use_time = nstarts(end-1):nstarts(end);
x        = xall(use_time);

% The noise-free components restricted to the testing segment.
test_components = cellfun(@(comp) comp(use_time), components, 'UniformOutput', false);

% No block boundaries are specified.
block_start_indices = [];


%%

% Model parameters
freq              = 1;
coeff_per_filter  = 200;  % number of parameters per filter
subsample_rates   = 1;    % decimation rates
durations         = coeff_per_filter .* subsample_rates / freq;
nfilt             = 4;
Filters_per_scale = repmat(nfilt, size(coeff_per_filter));

% Learning parameters
coverage                      = 1;
learning_approximation_passes = 4;
do_nonneg                     = 1;

% Storage for the learned dictionaries: one row per learning method,
% column 1 holds the dictionary and column 2 the method name.
learned_waveforms = cell(Nlearn_methods, 2);

signal_params = struct('block_starts', block_start_indices, 'freq', freq);

% Name/value pairs shared by every learning method; each method appends
% its own extra pairs before converting the list to a struct.
model_sizes = [coeff_per_filter(:), Filters_per_scale(:)];
shared_model_params = { ...
    'model_sizes',          model_sizes, ...
    'subsample_rates',      subsample_rates, ...
    'freq',                 freq, ...
    'coverage',             coverage, ...
    'nonneg_flag',          do_nonneg, ...
    'approximation_passes', learning_approximation_passes};
% Learn one dictionary per method from the training segment and store it
% together with the method's display name.
for method_ii = 1:Nlearn_methods
    if method_ii == 1
        % MP-SVD waveform learning
        method_name = 'MP-SVD';
        freq_thresh = 3 / numel(xlearn);
        c = [shared_model_params, {'minimum_freq', freq_thresh}];
        param_names  = c(1:2:end);
        param_values = c(2:2:end);
        model_params = cell2struct(param_values, param_names, 2);
        [MSDict, err_ratio, atoms] = multistageWaveformLearning(xlearn, signal_params, model_params);
    elseif method_ii == 2
        % Single-channel ICA
        method_name = 'SC-ICA';
        c = [shared_model_params, {'max_ica_iterations', 50, 'fraction_ica_comp', .4}];
        param_names  = c(1:2:end);
        param_values = c(2:2:end);
        model_params = cell2struct(param_values, param_names, 2);
        [MSDict, err_ratio, atoms] = multistageICASelection(xlearn, signal_params, model_params);
    elseif method_ii == 3
        % Group-sparse Gabor dictionary
        method_name = 'Gabor-subset';
        c = [shared_model_params, {'Ncandidate', 400}];
        param_names  = c(1:2:end);
        param_values = c(2:2:end);
        model_params = cell2struct(param_values, param_names, 2);
        MSDict = multistageGaborSelection(xlearn, signal_params, model_params);
    end
    learned_waveforms(method_ii, :) = {MSDict, method_name};
end
%%
% Decomposition settings used for every learned dictionary (a single
% approximation pass, unlike the multi-pass setting used for learning).
c={...
    'coverage',1,...
    'nonneg_flag',do_nonneg,...
    'approximation_passes',1};
model_params=cell2struct(c(2:2:end),c(1:2:end),2);

% One row per method:
% {per-filter components, atom count, dictionary, method name, residual norm}
decomp_results=cell(Nmethods,5);
for method_ii=1:Nmethods
    if method_ii<=Nlearn_methods
        % Decompose the testing segment with a learned dictionary.
        method_name=learned_waveforms{method_ii,2};
        AA=learned_waveforms{method_ii,1};
        group_dim=cellfun(@(x) size(x,2),AA);
        tic
        atoms_all = multistageMP(x,AA,signal_params,model_params);
        Natoms=sum(cellfun(@(x) size(x,1),atoms_all));
        estimate_sources=atomsToSources(atoms_all,group_dim,numel(x),Natoms);
        [VV,VV_filters]=sourcesToSignalComponents(AA,estimate_sources);
        toc
        fprintf('%i atoms\n',Natoms);
    else
        % MPTK Gabor decomposition. Natoms is deliberately left over from
        % the previous iteration so the same number of atoms is used.
        method_name='Gabor';
        Maxlength=max(coeff_per_filter);
        [VV,VV_filters,AA] = mptkDecompose_gabor(x,Maxlength,Natoms);
    end
    % Reconstruction is the sum of all per-filter components.
    approx=sum(cell2mat(VV_filters(:)'),2);
    % Named approx_error rather than "error" so the built-in error()
    % function is not shadowed in the workspace after the script runs.
    approx_error=norm(x-approx);
    decomp_results(method_ii,:)={VV_filters,Natoms,AA,method_name,approx_error};
end
% Intentionally unsuppressed: echo the results table to the command window.
decomp_results

%%
% Plot, for each method, the noise-free signal, the observed signal, the
% approximation and each contributing per-filter component, stacked
% vertically, and save the figure as an EPS file.
total_test=sum(cell2mat(test_components(:)'),2);

section=numel(nstarts);
for method_ii=1:Nmethods
    VV_filters=decomp_results{method_ii,1};
    AA=decomp_results{method_ii,3};
    method_name=decomp_results{method_ii,4};
    % Intentionally unsuppressed: echo this method's residual norm.
    decomp_results{method_ii,5}
    num_scale=numel(AA);
    % Five base colours, tiled so there is at least one per dictionary entry.
    Cbase=repmat([27,158,119
        117,112,179
        102,166,30
        231,41,138
        166,118,29]/255,ceil(num_scale/5),1);
    
    % One colour row per filter column; columns of the same dictionary
    % entry share that entry's base colour.
    Cmat=[];
    for ii=1:num_scale
        Cmat=cat(1,Cmat,repmat(Cbase(ii,:),size(AA{ii},2),1));
    end
    % Colour list: black, grey, one colour per filter, then a final orange.
    colors=cat(2,{[0 0 0],[.5 .5 .5]},mat2cell(Cmat,ones(size(Cmat,1),1),3)',{[217,95,2]/255});
    y=total_test;
    
    % NOTE(review): the offset uses the last boundary index, not the first
    % sample of the testing segment; it only shifts the (hidden) x axis.
    time_start=(nstarts(section)-1);
    time=time_start+linspace(0,(numel(use_time)-1),numel(use_time));
    % Display window: the first 8000 samples, clamped so a shorter testing
    % segment does not cause an out-of-range index. Named time_window
    % rather than "xlim" so the built-in xlim() function is not shadowed.
    time_window=time([1 min(8000,numel(time))]);
    figure
    
    % Columns: noise-free signal, observed signal, approximation, components.
    allV=cat(2,y(:),x(:),...
        sum(cell2mat(VV_filters(:)'),2),cell2mat(VV_filters(:)'));
    % Proportion of variance explained by each of columns 2:end with
    % respect to the noise-free signal, over the whole testing segment.
    all_pve=1-mean(bsxfun(@minus,allV(:,2:end),allV(:,1)).^2)/mean(allV(:,1).^2);
    %
    window_indicator=time>=time_window(1) & time<=time_window(2);
    % Keep only the components that are non-zero somewhere in the window.
    filter_energy_in_window=cellfun(@(v) full(sum(abs(v(window_indicator)))),VV_filters);
    VV_ene=VV_filters(filter_energy_in_window>0);
    
    % Force a row vector so the index list below concatenates whether
    % VV_filters is a row or a column cell array.
    kept_filters=reshape(find(filter_energy_in_window>0),1,[]);
    these_colors=colors([1, 1, 2, 2+kept_filters, numel(colors)]);
    V_ene=cat(2,y,x,...
        sum(cell2mat(VV_ene(:)'),2),...
        cell2mat(VV_ene(:)'));
    
    % Windowed variance explained; computed for inspection, not used in
    % the figure below.
    pve=1-mean(bsxfun(@minus,V_ene(window_indicator,2:end),V_ene(window_indicator,1)).^2)/mean(V_ene(window_indicator,1).^2);
    
    % Tick labels for the approximation and each kept component. Sized to
    % exactly numel(VV_ene)+1 so no empty trailing label is produced.
    approx_strs=cell(1,numel(VV_ene)+1);
    approx_strs{1}='Approx.';
    for ii=1:numel(VV_ene)
        approx_strs{ii+1}=sprintf('%i',ii);
    end
    yoffsets=[0, 1, 2, 2+(1:numel(VV_ene))];
    
    % Scale each trace by its peak magnitude inside the window (explicit
    % dimension so a one-sample window still yields one peak per column),
    % then stack the traces at integer vertical offsets.
    peak_in_window=max(abs(V_ene(window_indicator,:)),[],1);
    h=plot(time,bsxfun(@plus,bsxfun(@rdivide,.85*V_ene,peak_in_window),yoffsets));
    for ii=1:numel(h)
        set(h(ii),'color',these_colors{ii},'linewidth',1)
    end
    
    % One label per tick: signal, observed, approximation, components.
    tick_labels=[{'Signal','Observed'},approx_strs];
    set(gca,'ytick',yoffsets,'yticklabel',tick_labels)
    %title('Min-max normalized')
    set(gca,'fontsize',12)
    %xlabel('time (s)','fontsize',12)
    set(gca,'xlim',time_window);
    set(gca,'ylim',[yoffsets(1)-2.75 yoffsets(end)+1]);
    set(gca,'xtick',[],'box','off');
    
    % all_pve(2) corresponds to the approximation column of allV.
    str=sprintf('%s variance explained: %i%c',method_name,round(100*all_pve(2)),'%');
    title(str,'fontsize',12);
    if method_ii<=Nlearn_methods
        set(gcf,'position',[   360   338   560    175])
    end
    % File-name suffix marks runs made without the non-negativity flag.
    if do_nonneg==1
        str='';
    else
        str='neg';
    end
    set(gcf,'PaperPositionMode','auto')
    saveas(gcf,sprintf('toy_%s%s',method_name,str),'epsc2')
end