function [Y,bandwidth,G] = info_minfunc(Y,metric_name)
% INFO_MINFUNC  Two-stage optimisation of a kernel-based objective.
%
%   [Y,bandwidth,G] = info_minfunc(Y,metric_name)
%
%   Stage 1 picks the kernel exponent gamma by a bounded 1-D search
%   (fminbnd on [-5,5]) with the data held fixed.  Stage 2 keeps gamma
%   fixed and optimises the data itself with minFunc (conjugate gradient).
%
%   Inputs
%     Y           - n-by-d_x data matrix used as the starting point.
%     metric_name - name of the objective, forwarded to my_objfungrad.
%
%   Outputs
%     Y         - n-by-d_x matrix returned by minFunc, reshaped.
%     bandwidth - sqrt(1/2/10^gamma) for the selected gamma.
%     G         - n-by-d_x gradient of the objective at the initial Y.
[n,d_x]=size(Y);
identityMap=eye(d_x);   % the linear map is fixed to the identity here
y0=Y(:);                % starting point in vectorised form

% Settings handed to minFunc.
options.optTol=1e-9;
options.Method='cg';
options.Display='off';

% Stage 1: bounded scalar search over gamma (only the objective value is
% consumed by fminbnd).
% Gradient checks can be run on this handle with derivativeCheck, e.g.
%   derivativeCheck(gammaCost,0,1,1)
gammaCost=@(gamma) gamma_fungrad(y0,gamma,identityMap,metric_name);
gamma=fminbnd(gammaCost,-5,5);
bandwidth=sqrt(1/2/10^gamma);

% Stage 2: objective/gradient with respect to the vectorised data.
% Gradient checks can be run on this handle as well, e.g.
%   derivativeCheck(dataCost,Y(:),1,2)
dataCost=@(v) my_objfungrad(v,gamma,identityMap,metric_name);

% Gradient at the starting point, reported back in matrix layout.
[~,g0]=dataCost(y0);
G=reshape(g0,n,d_x);

[yOpt,~]=minFunc(dataCost,y0,options);
Y=reshape(yOpt,n,d_x);

end

function [f, grad_gamma]= gamma_fungrad(yvec,gamma,A,metric_name)
% GAMMA_FUNGRAD  Objective value and its derivative with respect to gamma.
%
%   Thin adapter around my_objfungrad: it requests the first three outputs
%   and hands back the first (objective) and third (gamma derivative),
%   discarding the gradient with respect to the data.
[objectiveValue,unusedDataGrad,dGamma]=my_objfungrad(yvec,gamma,A,metric_name); %#ok<ASGLU>
f=objectiveValue;
grad_gamma=dGamma;
end

function [f, gradf,grad_gamma,grad_A,grad_D]= my_objfungrad(xvec,gamma,A,metric_name)
% MY_OBJFUNGRAD  Kernel-matrix objective and its gradients.
%
%   [f,gradf,grad_gamma,grad_A,grad_D] = my_objfungrad(xvec,gamma,A,metric_name)
%
%   The data are reshaped to X (n-by-d_x, d_x = size(A,1)), mapped to
%   Y = X*A, and a kernel matrix K = exp(10^gamma * D2) is formed, where
%   D2(i,j) = -||Y(i,:) - Y(j,:)||^2.  K is then blended with the identity
%   (weight 1e-6) before the objective selected by metric_name is evaluated.
%
%   Inputs
%     xvec        - vectorised data; numel must be a multiple of size(A,1).
%     gamma       - scalar; the kernel scale is 10^gamma.
%     A           - linear map applied to the data (Y = X*A).
%     metric_name - case-insensitive, one of 'euclidean', 'cosine', 'cka',
%                   'hsic', 'bures', 'qh', 'bartlett'.
%
%   Outputs
%     f          - objective value (NaN if the kernel matrix contains NaN).
%     gradf      - gradient with respect to xvec (column vector).
%     grad_gamma - derivative with respect to gamma   (computed if nargout>2).
%     grad_A     - vectorised gradient for A          (computed if nargout>3).
%     grad_D     - diagonal of the same matrix        (computed if nargout>4).
%
%   Errors
%     info_minfunc:unknownMetric if metric_name is not recognised.
d_x=size(A,1);
X=reshape(xvec,[],d_x);
Y=X*A;
n=size(Y,1);
YY=Y*Y';
Y2=diag(YY);
% Negative squared pairwise distances: 2*y_i'*y_j - |y_i|^2 - |y_j|^2.
D2=bsxfun(@minus,bsxfun(@minus,2*YY,Y2),Y2');
K=exp(10^gamma*D2);
onevec=ones(n,1);

if any(isnan(K(:)))
    % Degenerate kernel: report a NaN objective and zero gradients of the
    % sizes the regular path would produce, so every requested output is
    % defined.
    f=NaN;
    gradf=zeros(numel(xvec),1);
    grad_gamma=0;
    grad_A=zeros(numel(A),1);
    grad_D=zeros(size(A,2),1);
    fprintf('whoa!')
    return
end

% we assume K is strictly positive definite
K=(1-1e-6)*K+1e-6*eye(n); % regularization

kbar=mean(K);          % column means of K (row vector)
kbarbar=mean(kbar);    % grand mean of K
H=eye(n)-1/n;          % centering matrix
Kc=H*K*H;              % doubly centred kernel

% Each case sets f and Grad, the derivative of f with respect to the
% entries of K.
switch lower(metric_name)
    case 'euclidean'
        nK=K(:).'*K(:);
        f=1-nK/n^2+(n*kbarbar^2-2*kbarbar+1)/(n-1);
        Grad=-2*K/n^2+(2*kbarbar/n-2/n^2)/(n-1);
    case 'cosine'
        nK=K(:).'*K(:);
        f=((n*kbarbar^2-2*kbarbar+1)/nK);
        Grad=(nK*(2*kbarbar/n-2/n^2)-2*(n*kbarbar^2-2*kbarbar+1)*K)/nK^2;
    case 'cka'
        trKH=trace(bsxfun(@minus,K,kbar));
        nHKH=sqrt(Kc(:).'*Kc(:));
        f=(trKH/nHKH)^2/sqrt(n-1)-log(1-kbarbar);
        Grad=2*trKH*(H-Kc*trKH/nHKH^2)/nHKH^2/sqrt(n-1)+ones(n)/n^2/(1-kbarbar);
    case 'hsic'
        f=1+1/n^2*(1/(n-1)*(n-n*kbarbar)^2-Kc(:).'*Kc(:));
        Grad=1/n^2*(-2/(n-1)*(1-kbarbar)-2*Kc);
    case 'bures'
        % Square root smoothed near zero: below huber_val a quadratic
        % replaces sqrt so the derivative stays bounded.
        huber_val=1e-9;
        h=@(x,h) (x>=h).*sqrt(x) + (x<h).*((h^-1.5)/4*x.^2+3/4*sqrt(h));
        hp=@(x,h) (x>=h)./sqrt(x)/2 + (x<h).*((h^-1.5)/2*x);

        [U,S]=eig((Kc+Kc.')/2);
        s=diag(S);
        % Drop the eigenvector whose entries sum to the largest magnitude
        % (the direction removed by centering).
        [~,remove_idx]=max(abs(sum(U,1)));
        U(:,remove_idx)=[];
        s(remove_idx)=[];

        tr_sqrt_HKH_h=sum(h(s,huber_val));
        dKc_h=(U*diag(2*hp(s,huber_val))*U');
        tr_sqrt_HKH=sum(sqrt(s));
        dKc=(U*diag(s.^-.5)*U');

        % With huber_val > 0 the smoothed quantities replace the exact ones.
        if huber_val>0
            tr_sqrt_HKH=tr_sqrt_HKH_h;
            dKc=dKc_h;
        end
        dKc=(dKc+dKc.')/2;
        f=(kbarbar+1/(n-1)/n*tr_sqrt_HKH^2);

        dKc=H*dKc*H;
        Grad=(1/n^2+tr_sqrt_HKH/n/(n-1)*dKc);
    case 'qh'
        [U,S]=eig((K+K.')/2);
        s=diag(S);
        U=U(:,s>eps);
        s=s(s>eps);
        K12=U*diag(sqrt(s))*U';   % matrix square root on the retained spectrum
        K12J=sum(K12(:));
        K12H=K12*H;
        tr_K12H=trace(K12H);

        f=(K12J^2/n^3+1/(n-1)/n*tr_K12H^2);

        % Solve K12*Z + Z*K12 = C for the two right-hand sides.
        dKJ=sylvester(K12,K12,ones(n));
        dKH=sylvester(K12,K12,H);

        dKJ=reshape(dKJ,size(K));
        dKH=reshape(dKH,size(K));

        Grad=2*(1/n^3*dKJ*K12J+1/(n-1)/n*tr_K12H*dKH);
    case 'bartlett'
        rbar=(kbarbar*n-1)/(n-1);
        lambda = 1-rbar;
        f = 1/lambda^2*(...
            1/2*(K(:).'*K(:)-n^2/(n-1)*(kbarbar^2*n-2*kbarbar+1))...
            ...
            -1/n*(n^2*(kbar*kbar.')-kbarbar^2*n^3)...
            );
        G1 = K-(n*kbarbar-1)/(n-1)-1/n*(bsxfun(@plus,n*kbar,n*kbar.')-2*kbarbar*n);
        G2 = -2/(n-1)^2*(1-kbarbar);
        f1 = lambda^2*f;
        Grad = (lambda^2*G1 - G2*f1)/lambda^4;
    otherwise
        error('info_minfunc:unknownMetric', ...
            'Unknown metric_name "%s".',metric_name);
end

% Chain rule through K = exp(10^gamma*D2): dK/dD2 = 10^gamma*K elementwise.
B = Grad.*K;
B=(B+B')/2;
L=B-diag(B*onevec);   % shared factor of every data-side gradient
gradY=L*X*(A*A.');
gradf=10^gamma*4*real(gradY(:));
if nargout>2
    grad_gamma=log(10)*10^gamma*D2(:).'*B(:);
end
% NOTE(review): the two expressions below use Y' and Y*A; confirm they are
% the intended derivatives when A is not the identity.
if nargout>3
    grad_A=10^gamma*4*reshape(real( Y'*L*(Y*A)),[],1);
end
if nargout>4
    grad_D=10^gamma*4*diag(real( Y'*L*(Y*A)));
end
end







