addpath decompose/
addpath learning/
addpath utilities/
addpath FastICA_25/
%%
% Bring the variables freq, X and subs in from the prepared data file.
load('data/data_mat.mat','freq','X','subs');
subjects=subs;

% --- Experiment layout ---
n_secs=4;
max_blocks=25;
n_learn=25;   % blocks used for training; may be at most 25
nmonte=1;     % a single initialization is run

% --- Learning setup ---
Nlearn_methods=3;   % number of dictionary-learning methods compared

% --- Model parameters ---
coeff_per_filter=1.25*[200 100 80 40];   % number of parameters per filter
subsample_rates=ones(1,4);               % decimation rates
durations=(coeff_per_filter.*subsample_rates)/freq;
fprintf('The durations will be %0.2g s\n',durations);
nfilt=4;                                 % filters per decimation rate
Filters_per_scale=repmat(nfilt,size(coeff_per_filter));
do_nonneg=1;
freq_thresh=2/60;
coverage=1;

% Name/value pairs handed to every learning method; the method-specific
% options are appended to this list further down.
model_sizes=[coeff_per_filter(:), Filters_per_scale(:)];
shared_model_params={ ...
    'model_sizes',model_sizes, ...
    'subsample_rates',subsample_rates(:), ...
    'freq',freq, ...
    'coverage',coverage, ...
    'nonneg_flag',do_nonneg, ...
    'approximation_passes',4};
clear model_sizes


% Output containers: one results cell per subject, and the run time of
% every (subject, method, initialization) combination.
n_subjects=numel(subjects);
allresults=cell(n_subjects,1);
timing=zeros(n_subjects,Nlearn_methods,nmonte);
clear n_subjects

% Learn one multiscale dictionary per subject, per method and per
% initialization, recording how long each learning call takes.
for sub_ii=1:numel(subjects)
    % results{method,run,1} holds the learned dictionary and
    % results{method,run,2} the display name of the method.
    results=cell(Nlearn_methods,nmonte,2);
    subject=subjects{sub_ii};
    XX=X{sub_ii};
    section=1;%training section
    % Stack the n_learn training columns of XX into one long column vector.
    x=reshape(XX(:,(section-1)*max_blocks+(1:n_learn)),[],1);
    % First sample of every stacked column after the first one
    % (sample 1 is always a block start, so it is not listed).
    block_start_indices=1+size(XX,1):size(XX,1):numel(x);
    xlearn=x;
    signal_params=struct('block_starts',block_start_indices,'freq',freq);
    for monte_ii=1:nmonte
        for method_ii=1:Nlearn_methods
            t2=tic;
            switch method_ii
                case 1
                    method_name='MP-SVD';
                    freq_thresh=3/numel(xlearn);
                    c=cat(2,shared_model_params,{'minimum_freq',freq_thresh});
                    % Turn the name/value list into a struct of options.
                    model_params=cell2struct(c(2:2:end),c(1:2:end),2);
                    [MSDict,err_ratio,atoms ] = multistageWaveformLearning(xlearn,signal_params,model_params);
                case 2 % single channel ICA
                    method_name='SC-ICA';
                    c=cat(2,shared_model_params,{'max_ica_iterations',50,'fraction_ica_comp',.4});
                    model_params=cell2struct(c(2:2:end),c(1:2:end),2);
                    [MSDict,err_ratio,atoms ] = multistageICASelection(xlearn,signal_params,model_params);
                case 3 %group-sparse Gabor dictionary
                    method_name='Gabor-subset';
                    c=cat(2,shared_model_params,{'Ncandidate',400});
                    model_params=cell2struct(c(2:2:end),c(1:2:end),2);
                    MSDict = multistageGaborSelection(xlearn,signal_params,model_params);
                otherwise
                    % Without this guard a stale MSDict from the previous
                    % method would be stored silently.
                    error('Unknown learning method index %d.',method_ii);
            end
            results{method_ii,monte_ii,1}=MSDict;
            results{method_ii,monte_ii,2}=method_name;
            % Read the timer started above through its handle; a bare toc
            % does not use t2. The loop variable (not an undefined name)
            % indexes the run dimension.
            timing(sub_ii,method_ii,monte_ii)=toc(t2);
        end
    end
    allresults{sub_ii}=results;
end
%% Save results
% timing is stored next to allresults so the measured run times are kept.
save('single_channel_filters.mat','allresults','timing') %will overwrite a previous run
%% Load results and plot waveforms learned for the different datasets
%  also calculates similarity to Gabor wavelets
%  One figure is produced per learning method, with one subplot per scale;
%  each figure is written to an EPS file named after the method.
load('single_channel_filters.mat','allresults')
load('data/data_mat.mat','freq');
% Number of methods is taken from the first subject's results cell.
Nlearn_methods=size(allresults{1},1);
% Dataset labels are hard-coded here.
% NOTE(review): assumes allresults has at least 5 entries -- confirm.
subjects={'A','B','C','D','E'};
nmonte=1;
monte_ii=1;% only the first initialization is plotted
for method_ii=1:Nlearn_methods
    %%
    method_name=allresults{1}{method_ii,1,2};
    scales=numel(allresults{1}{method_ii,1,1});
    % X{scale} gathers the waveforms of all subjects as columns and
    % Y{scale} holds one bookkeeping row per waveform.
    % Note: this X replaces the data variable X loaded at the top of the script.
    X=cell(scales,1);
    Y=cell(scales,1);
    for sub_ii=1:numel(subjects)
        subject=subjects{sub_ii};
        results=allresults{sub_ii};
        for scale=1:scales
            A=results{method_ii,monte_ii,1}{scale};
            X{scale}=cat(2,X{scale},A);
            Y{scale}=cat(1,Y{scale},cat(2,monte_ii*ones(size(A,2),1),...
                (1:size(A,2))',repmat(sub_ii,size(A,2),1)));
            % Y is arranged: monte carlo, filter index, subject
        end
    end
    % create set of Discrete prolate spheroidal (Slepian) sequences
    % for estimating frequency
    NFFT=400;
    % Move the tapers to the third dimension so that bsxfun below
    % multiplies every waveform by every taper.
    h_dpss=permute(dpss(NFFT,2),[1 3 2]);
    fx=linspace(0,freq/2,NFFT/2+1);
    
    figure
    for scale=1:scales
        n=size(X{scale},1);
        % create set of Gabor wavelets for matching
        Ncandidate=400;
        % Grid of 20 envelope widths TE (in samples, 5..n, log-spaced) by
        % 30 center frequencies FC (1..freq/4, log-spaced).
        TE=kron(ones(1,30),logspace(log10(5),log10(n),20));
        FC=kron(logspace(log10(1),log10(freq/4),30),ones(1,20));
        % Keep only pairs whose period 1/FC is shorter than TE/freq.
        keep_atoms=(1./FC)<TE/freq;
        TE=TE(keep_atoms);
        FC=FC(keep_atoms);
        GF=[];
        phases=[0 pi/4 pi/2];
        for phase_ii=phases;
            % Cosine carrier at each FC with the given phase offset ...
            gf=cos(phase_ii+2*pi*bsxfun(@times,FC,(0:n-1)')/freq);
            % ... multiplied by a Gaussian envelope centered on the window.
            gf=gf.*exp(-bsxfun(@rdivide,linspace(-n/2,n/2,n)'.^2,2*TE.^2));
            GF=cat(2,GF,gf);
        end
        % Repeat the parameter vectors so they line up with the columns of GF.
        TE=repmat(TE,1,numel(phases));
        FC=repmat(FC,1,numel(phases));
        
        % plot each waveform, changing color if it matches Gabor wavelet
        ii=0;% vertical offset of the next dataset's first waveform
        for sub_ii=1:numel(subjects)
            
            % Waveforms of this scale that belong to the current subject.
            x=X{scale}(:,Y{scale}(:,3)==sub_ii );
            n=size(x,1);
            
            % Zero-pad at the front to NFFT rows and apply each taper.
            xh=bsxfun(@times,cat(1,zeros(NFFT-n,size(x,2)),x),h_dpss);
            % Squared FFT magnitude averaged over the tapers.
            Sc=mean(abs(fft(xh,NFFT)).^2,3);
            % calcWidth is a project helper; presumably flow/fhigh are the
            % band edges of each spectrum -- TODO confirm.
            [flow,fhigh]=calcWidth(Sc(1:NFFT/2+1,:),fx(:));
            % Order the waveforms by increasing fhigh+flow.
            [~,fc_id]=sort(fhigh+flow);
            x=x(:,fc_id);
            Sc=Sc(1:NFFT/2+1,fc_id);
            % Scale every waveform to a peak absolute amplitude of 0.5.
            x=bsxfun(@rdivide,.5*x,max(abs(x)));
            [~,fid]=max(Sc);
            fx=linspace(0,freq/2,NFFT/2+1);
            peak_freqs=fx(fid);% not used further in this section
            % fasterXcorr2 is a project helper; presumably XC is the best
            % match score per waveform and IDX the matching column of GF
            % -- TODO confirm.
            [XC,IDX]=max(1-fasterXcorr2(GF,x,'abs'));
            % Append NaN rows amounting to 10% of the waveform length.
            x=cat(1,x,nan(round(size(x,1)*.1),size(x,2)));
            
            subplot(1,scales,scale)
            % Time axis in ms; waveform k is drawn at height ii-(k-1).
            h=plot((0:size(x,1)-1)/freq*1000,ii+bsxfun(@minus,x,0:size(x,2)-1),'-k','linewidth',1.2);
            hold all
            threshhold=.8;
            % Recolor the waveforms whose score reaches the threshold.
            set(h(XC>=threshhold),'color',[.5 .5 .3]);%,'linewidth',2);
            fc=FC(IDX);
            
            % Label waveforms with score >= 0.5 and matched frequency above
            % 1 Hz with that frequency (one decimal when it rounds below 10).
            for jj=find(XC>=0.5 & fc>1)
                if round(fc(jj))<10
                    text(.95*size(x,1)/freq*1000,ii-jj+1,sprintf('%0.01fHz',fc(jj)),'fontsize',8)
                else
                    text(.95*size(x,1)/freq*1000,ii-jj+1,sprintf('%0.0fHz',fc(jj)),'fontsize',8)
                end
            end
            % Dataset labels are written only next to the first subplot.
            if scale==1
                text(-1.1*size(x,1)/freq*1000,ii,['Dataset ' subjects{sub_ii}],'fontsize',12)
            end
            % Move the offset down by the number of waveforms just drawn.
            ii=ii-size(x,2);
        end
        % NOTE(review): this decrement has no effect, because ii is reset to
        % 0 at the top of every scale iteration; it may have been meant to
        % sit inside the subject loop to leave a gap between datasets.
        ii=ii-1;
        
        set(gca,'ytick',[]);%fliplr(0:-5:-20+1),'yticklabel',flipud(subjects(:)))
        % Scale-bar length in ms: 5^floor(log5(n/freq)) seconds times 1000.
        xt=5^floor(log10(n/freq)/log10(5))*1000;
        
        % Draw the scale bar from 0 to xt with short vertical end ticks.
        plot([0*[1 1 1]  [1 1 1 ]*xt]',1+.5+[-.2 .2 0 0  .2 -.2]','-k','linewidth',1.5),
        hold all;
        text(1,2.5,[num2str(xt),' ms'],'fontsize',12)
        text(-.1*size(x,1)/freq*1000,3.5,sprintf('Scale %i',scale),'fontsize',12)
        axis tight
        set(gca,'xlim',[-size(x,1)*.05 1.5*size(x,1)]/freq*1000);
        % Fixed y-limits (hard-coded).
        set(gca,'ylim',[-23.5 2])
        set(gca,'xtick',[],'visible','off')
    end
    
    % Export the figure at its on-screen size as a color EPS file.
    set(gcf,'PaperPositionMode','auto')
    saveas(gcf,sprintf('sc_%s_waves.eps',method_name),'epsc2')
end

