clear all; close all; clc;
rng('default');   % fixed random seed so the experiment is reproducible

%% Initial transition topology
% Default: initial transition guess restricted to the correct topology.
ConnectionMatrix = [1 0 1 0
                    0 1 0 1
                    0 1 1 0
                    1 0 0 1];
% Alternative: fully connected initial transition guess (uncomment to use).
% ConnectionMatrix = ones(4);

K = 10;   % number of cross-validation folds

%% Data
symSpaceTM = 'ioMm';                    % label alphabet (after TMRelabel)
symSpaceaa = 'ACDEFGHIKLMNPQRSTVWY';    % 20-letter amino-acid alphabet

[ID,Seq] = fastaread('TMMOD_data_SeqAll.txt');
[~,Lbl]  = fastaread('TMMOD_data_LblAll.txt');
re_lbl   = TMRelabel(Lbl);

% Convert symbol strings to numeric sequences over the alphabets above.
Seq    = sym2number(Seq, symSpaceaa);
re_lbl = sym2number(re_lbl, symSpaceTM);

r_space = 0.05:0.05:0.95;                % fractions of training labels to hide
Results = zeros(5, numel(r_space));      % one row per model, one column per r

% Sweep over r, the fraction of each training label sequence that is hidden.
% For every r, run K-fold cross-validation and accumulate per-residue Viterbi
% decoding accuracy for five models (rows of Results):
%   1   : HMMPL_Other (Scheffer et al.), trained on partially labelled data
%   2-4 : the three models returned by HMMBW_PartialLabel_EarlyStop
%   5   : hmm_ML, trained on the FULLY labelled training fold (reference)
for r_ind = 1:length(r_space)

    r = r_space(r_ind);

    Indices = crossvalind('Kfold',length(Seq),K);

    % correct(m): correctly decoded residues of model m, summed over all folds.
    correct = zeros(5,1);
    tot = 0;

    for i = 1:K

        % Random, row-normalized initial guesses; transitions are masked by
        % ConnectionMatrix so only allowed transitions start non-zero.
        E_guess = rand(4,20);   E_guess = matrixNorm(E_guess);
        T_guess = rand(4,4);    T_guess = T_guess.*ConnectionMatrix;
        T_guess = matrixNorm(T_guess);

        TrainInd = find(Indices~=i);             TestInd = find(Indices==i);
        TrainSeq = Seq(TrainInd);                TrainLbl = re_lbl(TrainInd);
        TestSeq = Seq(TestInd);                  TestLbl = re_lbl(TestInd);

        % Make labels partially available: set a random fraction r of each
        % training label sequence to 0 (presumably "unknown" for the
        % partial-label trainers).
        for j = 1:length(TrainSeq)
            randind = randperm(length(TrainLbl{j}));
            randind = randind(1:round(r*length(TrainLbl{j})));
            TrainLbl{j}(randind) = 0;
        end

        % Scoring treats states 3 and 4 as the same class ('M' and 'm' in
        % symSpaceTM, assuming sym2number maps symbols to their index), so
        % merge 4 into 3 in the ground truth once, and count test residues.
        for j = 1:length(TestLbl)
            TestLbl{j}(TestLbl{j} == 4) = 3;
            tot = tot + length(TestLbl{j});
        end

        % Train all models, in the same order as before in case the trainers
        % draw from the random stream.
        [TR_other,E_other] = HMMPL_Other(TrainSeq,TrainLbl,T_guess,E_guess);
        [TR_our1,E_our1,TR_our2,E_our2,TR_our3,E_our3] = ...
            HMMBW_PartialLabel_EarlyStop(TrainSeq,TrainLbl,T_guess,E_guess);
        [TR_ML,E_ML] = hmm_ML(Seq(TrainInd),re_lbl(TrainInd),4,20);

        TRs = {TR_other, TR_our1, TR_our2, TR_our3, TR_ML};
        Es  = {E_other,  E_our1,  E_our2,  E_our3,  E_ML};

        for m = 1:numel(TRs)
            TR_m = TRs{m};
            E_m  = Es{m};
            for j = 1:length(TestSeq)
                try
                    pred = hmmviterbi(TestSeq{j},TR_m,E_m);
                catch
                    % Decoding failed (presumably due to zero emission
                    % probabilities): smooth and renormalize the emission
                    % matrix, then retry. The smoothed matrix is kept for the
                    % remaining test sequences of this model in this fold.
                    E_m = matrixNorm(E_m + 0.0001);
                    pred = hmmviterbi(TestSeq{j},TR_m,E_m);
                end
                pred(pred == 4) = 3;
                % Compare as column vectors so a row/column mismatch cannot
                % trigger implicit expansion.
                correct(m) = correct(m) + nnz(pred(:) == TestLbl{j}(:));
            end
        end
    end

    Results(:,r_ind) = correct/tot;

end
% Mean relative improvement (in percent, averaged over r_space) of models 2-4
% over the baseline in row 1 of Results (Scheffer et al.).
cmp = Results(1,:);
avgimprove = arrayfun(@(row) 100*mean((Results(row,:) - cmp)./cmp), 2:4);

%% Show results graphically
% Plots accuracy vs. fraction of hidden training labels for rows 1, 3, 4 and 5
% of Results. Row 2 (first model from HMMBW_PartialLabel_EarlyStop) is
% computed but not plotted.
figure;
plot(r_space,Results(1,:),'r');
hold on;
plot(r_space,Results(3,:),'b');
plot(r_space,Results(4,:),'m');
plot(r_space,Results(5,:),'k');
hold off;

% Describe the initial transition topology actually used, so the title stays
% correct when ConnectionMatrix is switched to the fully connected variant.
if all(ConnectionMatrix(:))
    initDesc = 'fully connected';
else
    initDesc = 'correct';
end
title(['Decoding Accuracy for methods of Scheffer et al, cBW with model selection, cBW, and ML method, with ', initDesc, ' initial transition']);

% r_space holds fractions (0.05-0.95), not percentages.
xlabel('Fraction of unlabelled data');
ylabel('Accuracy');
legend('Scheffer et al','cBW with model selection','cBW', 'ML method');

function re_lbl = TMRelabel(lbl)
%TMRELABEL Mark membrane labels by the side they are entered from.
%   RE_LBL = TMRELABEL(LBL) takes a cell array LBL of label character
%   arrays and returns a cell array of the same shape. In each entry, every
%   'M' whose most recent preceding 'o'/'i' character is 'o' is replaced by
%   'm'. 'M' characters preceded by 'i', or by neither, are unchanged. All
%   other characters are copied as-is.
    re_lbl = lbl;
    for s = 1:length(lbl)
        labels = lbl{s};
        seenOutside = false;   % true while the latest 'o'/'i' seen was 'o'
        for pos = 1:length(labels)
            switch labels(pos)
                case 'o'
                    seenOutside = true;
                case 'i'
                    seenOutside = false;
                case 'M'
                    if seenOutside
                        labels(pos) = 'm';
                    end
            end
        end
        re_lbl{s} = labels;
    end
end