% Candidate criteria used to pick the embedding dimension.
% Each entry is a function handle that is called with one embedding matrix
% (one column per sample, columns scaled to unit norm by the caller) and
% whose scalar return value is maximised over the candidate dimensions.
% The trailing number is the entry's position; it lines up with the label
% at the same position in func_names.
info_func = {
    @matrix_info_fro_embed_general                 % 1
    @matrix_info_cosine_embed_general              % 2
    @matrix_info_hsic_embed_general                % 3
    @matrix_info_cka_embed_general                 % 4
    @matrix_info_chernoff_embed                    % 5
    @matrix_info_quantum_hellinger_embed_general   % 6
    @bures_lb_embed_info                           % 7
    @bures_super_lb_matrix_info_embed              % 8
    @(G) matrix_info_bartlett(G'*G)                % 9  (evaluated on the Gram matrix G'*G)
    @(G) matrix_info_lawley(G'*G)                  % 10 (evaluated on the Gram matrix G'*G)
    };
% Display labels for the criteria, in the same order as info_func
% (used later to build the column headers of the results table).
func_names = {'Euclid.', 'Cosine', 'HSIC', 'CKA', 'Chernoff', ...
              'QH', 'Bures', 'Sub-Bures', 'Bartlett', 'Lawley'};

% Load the cached experiment results if they exist; otherwise run the
% experiment and cache it.
%
% For every dataset:
%   1. scale the features, build squared Euclidean distances and a Gaussian
%      kernel with a heuristically chosen bandwidth;
%   2. embed the samples with the eigendecomposition of the kernel matrix
%      (uncentered -> Zfull) and of the centered kernel matrix (-> Zc);
%   3. for every candidate dimension k evaluate each criterion in info_func
%      on the unit-norm leading-k embedding, uncentered and centered, and
%      keep the maximising k per criterion;
%   4. over nmonte random train/test splits, additionally pick k by the
%      centered alignment between the training embedding and the label
%      indicator matrix, then record the 1-nearest-neighbour test accuracy
%      and embedding dimension of every selection plus the full kernel.
%
% Output (also saved to disk): allresults{dataset} =
%   {accuracy, dimension, original feature count, sample count},
% each array of size nmonte x 1 x (2*numel(info_func)+3).
if exist('../results/kpca_all_measure_select_ksize.mat','file')
    load('../results/kpca_all_measure_select_ksize','allresults','datasets')
else
    prop_train=1/2;   % fraction of the samples used as the labelled (training) set
    nmonte = 20;      % number of random train/test splits per dataset
    datasets={'datasets-uci-wisconsin_breast_cancer_wbdc',...
        'sonar_uci',...
        'datasets-uci-ionosphere',...
        'datasets-uci-parkinson',...
        'datasets-uci-iris',...
        'datasets-uci-glass',...
        'datasets-uci-ecoli'};
    % Project helper invoked before the data files are read.
    % NOTE(review): presumably makes ../data/uci available - confirm.
    getData('uci');
    allresults=cell(numel(datasets),1);
    n_methods=numel(info_func);

    for datasetii=1:numel(datasets)
        setname=datasets{datasetii};
        load(['../data/uci/',setname],'all_labels','pX')

        % Scale every feature to unit standard deviation. A constant feature
        % has zero deviation; leave it unscaled instead of dividing by zero
        % (which would turn the whole kernel into Inf/NaN).
        feat_std=std(pX,0,1);
        feat_std(feat_std==0)=1;
        inputX=bsxfun(@rdivide,pX,feat_std);

        % Group the samples by class label.
        [all_labels,resort]=sort(all_labels);
        un_lab=unique(all_labels);%unique class label
        inputX=inputX(resort,:);
        [n,d]=size(inputX);%rows are samples, columns are dimensions

        % Pairwise squared Euclidean distances.
        x2=sum(inputX.^2,2);
        D2=bsxfun(@plus,bsxfun(@plus,-2*(inputX*inputX'),x2),x2');
        D2=D2.*(D2>0);% ensure no small negative values
        D2=(D2+D2')/2;% ensure D2 is symmetric

        % Gaussian kernel bandwidth: 15% of the way from the smallest
        % non-zero pairwise distance (dt) to the largest one (sig_max).
        Dtemp=sort(D2,1);
        sig_max = sqrt(max(Dtemp(end,:)));
        Dtemp(Dtemp==0)=inf;
        dt=sqrt(min(reshape(Dtemp(2:end,:),[],1)));
        clear Dtemp
        optw=dt+.15*(sig_max-dt);
        Kfull=exp(-full(D2)/(2*optw^2));

        % Embedding from the uncentered kernel: rows of Zfull are the
        % components in order of decreasing eigenvalue, columns are samples.
        % Eigenvalues at or below the tolerance are dropped.
        tol=sqrt(sum(Kfull(:).^2))*eps;
        [V,E]=eig(Kfull);
        [p,resort_idx]=sort(diag(E),'descend');
        V=V(:,resort_idx);
        V=V(:,p>tol);
        p=p(p>tol);
        Zfull=real(diag(sqrt(p))*V');

        % Same construction for the centered kernel; only strictly positive
        % eigenvalues are kept.
        kbar=mean(Kfull,1);
        Kc=bsxfun(@minus,bsxfun(@minus,Kfull,kbar),kbar')+mean(kbar);
        [V,E]=eig(Kc);
        [p,resort_idx]=sort(diag(E),'descend');
        V=V(:,resort_idx);
        V=V(:,p>0);
        p=p(p>0);
        Zc=real(diag(sqrt(p))*V');

        % Candidate embedding dimensions: every leading-k truncation.
        embedding_dimensions=1:size(Zfull,1);
        N_diff_dims=numel(embedding_dimensions);

        % eval_results(k,2*i-1): criterion i on the uncentered embedding,
        % eval_results(k,2*i)  : criterion i on the centered embedding.
        eval_results=zeros(N_diff_dims,n_methods*2);
        Zs=cell(N_diff_dims,1);
        Zcs=Zs;
        for kk=1:N_diff_dims
            Ztemp=Zfull(1:embedding_dimensions(kk),:);
            dim_z=size(Ztemp,1);
            Zs{kk}=Ztemp;
            Zsphr=bsxfun(@times,Ztemp,1./sqrt(sum(Ztemp.^2,1)));%oblique manifold

            % The centered embedding may have fewer components than Zfull.
            Ztemp=Zc(1:min(size(Zc,1),dim_z),:);
            Zcs{kk}=Ztemp;
            Zc_sphr=bsxfun(@times,Ztemp,1./sqrt(sum(Ztemp.^2,1)));%oblique manifold

            for ii=1:n_methods
                eval_results(kk,2*ii-1)=info_func{ii}(Zsphr);
                eval_results(kk,2*ii)=info_func{ii}(Zc_sphr);
            end
        end
        % Explicit dimension: with a single candidate a plain max() would
        % reduce across the criteria instead of across the candidates.
        [~,best_result_indices]=max(eval_results,[],1);

        rez = cell(nmonte,1);   % 1-NN accuracies per split
        rez2=rez;               % matching embedding dimensions
        embed_sets={Zs,Zcs};
        N_keep=round(numel(all_labels)*(prop_train));

        rng(880);% same sequence of splits for every dataset

        for mmm = 1:nmonte
            outoforder=randperm(numel(all_labels));
            train_idx = sort(outoforder(1:N_keep));
            test_idx = sort(outoforder(1+N_keep:end));

            labels=all_labels(train_idx);
            test_labels=all_labels(test_idx);

            Y=1*bsxfun(@eq,labels,un_lab');% embedding of training labels
            Yc=bsxfun(@minus,Y,mean(Y,1));%centered
            Yc_2=Yc/sqrt(sqrt(sum(reshape(Yc'*Yc,[],1).^2))); %  ||y*y'||_F=1

            % Centered alignment between the training embedding and the
            % labels; column 1 uses Zs, column 2 uses Zcs.
            align_scores=zeros(numel(Zs),2);
            for kk=1:numel(Zs)
                for vv=1:2
                    Z_tr=embed_sets{vv}{kk}(:,train_idx);
                    Zc_tr=bsxfun(@minus,Z_tr,mean(Z_tr,2));% centered
                    Zc_tr2=Zc_tr/sqrt(sqrt(sum(reshape(Zc_tr*Zc_tr',[],1).^2))); %  Schatten-2 normalized
                    align_scores(kk,vv)=norm(Zc_tr2*Yc_2,'fro');
                end
            end
            [~,best_result_indices2]=max(align_scores,[],1);
            all_result_indices=cat(2,best_result_indices,best_result_indices2);

            % Odd positions select from Zs, even positions from Zcs; the
            % extra last entry evaluates the full kernel matrix.
            n_eval=numel(all_result_indices)+1;
            acc=zeros(1,1,n_eval);
            dims=zeros(1,1,n_eval);
            for method_ii=1:n_eval
                if method_ii<=numel(all_result_indices)
                    kk=all_result_indices(method_ii);
                    if mod(method_ii,2)==1
                        Z=Zs{kk};
                    else
                        Z=Zcs{kk};
                    end
                    K=Z'*Z;
                    embed_dim=size(Z,1);
                else
                    K=Kfull;
                    embed_dim=size(K,2);
                end
                %Euclidean distance of embedding
                K1=diag(K)*ones(1,size(K,1));
                Dsq=-2*K+K1+K1';
                Dsq(Dsq<0)=0;% round-off can go slightly negative; keep sqrt real
                Dmat = sqrt(Dsq);

                % Nearest training sample of every test sample (first index
                % wins on ties, as with a stable sort).
                [~,nn_idx]=min(Dmat(train_idx,test_idx),[],1);
                acc(method_ii)=mean(labels(nn_idx)==test_labels);%1NN
                dims(method_ii)=embed_dim;
            end
            rez{mmm}=acc;
            rez2{mmm}=dims;
        end
        new_results=cell2mat(rez);
        new_dims=cell2mat(rez2);
        orig_dims=d*ones(size(new_dims));
        num_samples=n*ones(size(new_dims));
        allresults{datasetii}={new_results new_dims orig_dims,num_samples};
    end
    % Do not lose the whole run because the output folder is missing.
    if ~exist('../results','dir')
        mkdir('../results');
    end
    save('../results/kpca_all_measure_select_ksize','allresults','datasets')
end
%%


% Average accuracy and selected dimension per dataset (rows) and selection
% method (columns). Each stored array is runs x 1 x methods, so the mean is
% taken explicitly along dimension 1; without the dimension argument a
% result holding a single run would be averaged over the methods instead.
acc_ave=cell2mat(cellfun(@(X) squeeze(mean(X{1},1)),allresults','uni',0))';
dim_sel=cell2mat(cellfun(@(X) squeeze(mean(X{2},1)),allresults','uni',0))';
%dim_ori=cell2mat(cellfun(@(X) squeeze(mean(X{3},1)),allresults','uni',0))';
%n_samples=cell2mat(cellfun(@(X) squeeze(mean(X{4},1)),allresults','uni',0))';

% Build two LaTeX tabular environments in S and echo a plain-text version
% to the console. Each table has one row per dataset showing
% "accuracy% (selected dimension)", followed by two summary rows giving
% mean +/- std over the datasets of the accuracy lost relative to
%   - the full kernel (the last method, labelled 'Original'), and
%   - the best method of each Monte Carlo run.
%
% Method columns of acc_ave/dim_sel: for each entry of func_names an
% uncentered and a centered variant (interleaved), then the two
% alignment-based selections 'CKA^1' and 'CKA^2', then 'Original'.

pm=char(177);              % plus/minus sign; replaced by $\pm$ in the LaTeX text
n_all=size(acc_ave,2);     % index of the 'Original' (full kernel) column

method_names=[reshape([func_names(:),cellfun(@(x) [x ' centered'],func_names(:),'uni',0)]',[],1)',...
    {'CKA^1','CKA^2','Original'}];

% Accuracy margins in percent, datasets x methods (same for both tables).
loss_vs_orig=100*cell2mat(cellfun(@(X) squeeze(mean(bsxfun(@minus,X{1}(:,:,n_all),X{1}),1)),allresults','uni',0))';
loss_vs_best=100*cell2mat(cellfun(@(X) squeeze(mean(bsxfun(@minus,max(X{1},[],3),X{1}),1)),allresults','uni',0))';

% Short row labels derived from the dataset file names: drop everything up
% to the last '-', drop the part after the last '_', drop the part before
% the first remaining '_', and turn any other '_' into a space.
row_headers=datasets';
for row=1:numel(row_headers)
    str=row_headers{row};
    idx=find(str=='-',1,'last');
    if ~isempty(idx)
        str=str(idx+1:end);
    end
    idx=find(str=='_',1,'last');
    if ~isempty(idx)
        str=str(1:idx-1);
    end
    idx=find(str=='_',1,'first');
    if ~isempty(idx)
        str=str(idx+1:end);
    end
    row_headers{row}=strrep(str,'_',' ');
end

S=[];
for do_it_twice=1:2
    if do_it_twice==1
        % full kernel, the two alignment selections, then four criteria
        method_order=[n_all n_all-2 n_all-1 [1 3 11 13 ]];
    else
        method_order=[5  7 9  15 17 19 ];
    end
    M1=dim_sel(:,method_order);
    M2=acc_ave(:,method_order)*100;
    M3=loss_vs_orig(:,method_order);
    M4=loss_vs_best(:,method_order);
    full_names=method_names(method_order);
    C=numel(full_names);

    % Header.
    S=cat(2,S,sprintf('\\begin{tabular}{%s}\n',repmat('l ',1,C+1)));
    S=cat(2,S,sprintf('\\hline\n'));
    str='  &';
    for col=1:C
        str=cat(2,str,sprintf('%15s &',full_names{col}));
    end
    fprintf('%s\n',strrep(str(1:end-1),'&','|'));
    S=cat(2,S,sprintf('%s\\\\\n',str(1:end-1)));
    S=cat(2,S,sprintf('\\hline\n'));
    fprintf('%s\n',repmat('-',1,100));

    % One row per dataset.
    for row=1:numel(row_headers)
        str=sprintf('%15s',[row_headers{row},'&']);
        for col=1:C
            val1=M1(row,col);
            val2=M2(row,col);
            if do_it_twice==1 && col==1
                % full kernel: the dimension is printed as stored
                cell_txt=sprintf('%i (%i)',round(val2),val1);
            elseif do_it_twice==1 && (col==2 || col==3)
                % alignment selections: dimension varies per run, show the mean
                cell_txt=sprintf('%i (%.1f)',round(val2),val1);
            else
                cell_txt=sprintf('%i (%i)',round(val2),round(val1));
            end
            str=cat(2,str,sprintf('%15s &',cell_txt));
        end
        fprintf('%s\n',strrep(str(1:end-1),'&','|'));
        str=strrep(str,'%','\%');
        S=cat(2,S,sprintf('%s\\\\\n',str(1:end-1)));
    end
    S=cat(2,S,sprintf('\\hline\n'));
    fprintf('%s\n',repmat('-',1,100));

    % Summary rows: mean +/- std over the datasets.
    summary_rows={'vs. original',M3;'Loss margin',M4};
    for rr=1:size(summary_rows,1)
        vals=summary_rows{rr,2};
        str=sprintf('%15s',[summary_rows{rr,1},'&']);
        for col=1:C
            val1=mean(vals(:,col));
            val2=std(vals(:,col));
            if round(val2)<=1
                cell_txt=sprintf('%6.1f%s%.1f',val1,pm,val2);
            else
                cell_txt=sprintf('%6.1f%s%i',val1,pm,round(val2));
            end
            str=cat(2,str,sprintf('%15s &',cell_txt));
        end
        fprintf('%s\n',strrep(str(1:end-1),'&','|'));
        str=strrep(str,'%','\%');
        str=strrep(str,pm,'$\pm$');
        S=cat(2,S,sprintf('%s\\\\\n',str(1:end-1)));
    end

    S=cat(2,S,sprintf('\\hline\n'));
    fprintf('%s\n',repmat('-',1,100));
    S=cat(2,S,sprintf('\\end{tabular}\n'));
end

% Write the LaTeX source collected in S to disk. Fail with a clear message
% if the file cannot be opened instead of passing an invalid identifier on
% to fwrite.
[fid,errmsg]=fopen('../results/Table_6.txt','w');
if fid<0
    error('Could not open ../results/Table_6.txt for writing: %s',errmsg);
end
fwrite(fid,S);% semicolon keeps the returned element count off the console
fclose(fid);

