__author__ = 'leihuang'
import numpy as np
from sklearn import metrics
from operator import sub
from functools import reduce
from math import sqrt
from sklearn import preprocessing
#import matplotlib.pyplot as plt
from sklearn import svm
import random
# Directory holding the edge-list and feature files. File names passed to
# read_edges / read_node_feature are appended to it by plain string
# concatenation, so the trailing slash is required.
PATH = '/home/leihuang/PPINcluster/DIP_Yeast_20150101/longepoch/'
# Number of nodes: all_edges() enumerates candidate pairs over range(N), and
# those indices are used to look up rows of the feature list. Presumably this
# equals the number of proteins / rows in the feature file — confirm.
N = 5093

def read_edges(file_name):
    """Read a tab-separated edge list located at ``PATH + file_name``.

    Each non-blank line holds two 1-based integer node ids separated by a tab.
    They are shifted to 0-based ids so they line up with the indices produced
    by ``all_edges``.

    Args:
        file_name: file name relative to ``PATH``.

    Returns:
        set of ``(int, int)`` tuples, kept in the orientation written in the
        file (callers check both orientations).

    Raises:
        ValueError: if a non-blank line does not contain exactly two
            tab-separated integers.
    """
    edge_set = set()
    # ``with`` guarantees the file handle is closed (the original leaked it).
    with open(PATH + file_name) as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                # Tolerate blank/trailing empty lines instead of crashing on unpacking.
                continue
            a, b = line.split('\t')
            edge_set.add((int(a) - 1, int(b) - 1))
    return edge_set

def all_edges(n=None):
    """Yield every unordered node pair ``(i, j)`` with ``0 <= i <= j < n``.

    Self-pairs ``(i, i)`` are included. Pairs are produced in lexicographic
    order: (0, 0), (0, 1), ..., (0, n-1), (1, 1), ...

    Args:
        n: number of nodes; defaults to the module constant ``N`` so existing
            ``all_edges()`` callers are unaffected.

    Yields:
        ``(i, j)`` tuples of ints.
    """
    from itertools import combinations_with_replacement

    if n is None:
        n = N
    # Same pairs, same order as the nested i <= j loops, but iterated in C.
    yield from combinations_with_replacement(range(n), 2)

def read_node_feature(feature_file):
    """Load per-node feature vectors from ``PATH + feature_file``.

    Each line is tab-separated. The first field is not used; the remaining
    non-empty fields are parsed as floats. Row k of the file therefore becomes
    ``result[k]``. NOTE(review): callers index this list by 0-based node id,
    which assumes the file has one row per node, sorted by node id. Confirm
    against the data.

    Args:
        feature_file: file name relative to ``PATH``.

    Returns:
        list of lists of floats, one inner list per line of the file.
    """
    n_feature = []
    # ``with`` guarantees the file handle is closed (the original leaked it).
    with open(PATH + feature_file, 'r') as handle:
        for line in handle:
            values = line.strip().split('\t')[1:]  # drop the leading id column
            n_feature.append([float(x) for x in values if x != ''])
    return n_feature

def partition_pairs(pairs, training_edges, testing_edges):
    """Split candidate node pairs into train-positive, test-positive and negative pairs.

    A pair is positive for a split if it appears in that split's edge set in
    either orientation. A pair present in both sets counts as a training
    positive only. Order of ``pairs`` is preserved within each output list.

    Returns:
        (train_pos, test_pos, negatives): three lists of ``(a, b)`` tuples.
    """
    train_pos, test_pos, negatives = [], [], []
    for a, b in pairs:
        if (a, b) in training_edges or (b, a) in training_edges:
            train_pos.append((a, b))
        elif (a, b) in testing_edges or (b, a) in testing_edges:
            test_pos.append((a, b))
        else:
            negatives.append((a, b))
    return train_pos, test_pos, negatives


def pair_features(pairs, features):
    """Return, for each pair ``(a, b)``, the list concatenation ``features[a] + features[b]``."""
    return [features[a] + features[b] for a, b in pairs]


def main():
    """Train an RBF SVC on node-pair features and print the test ROC AUC."""
    training_edges = read_edges('yeast_training.txt')
    testing_edges = read_edges('yeast_testing.txt')
    features = read_node_feature('featurefile.txt')
    # NOTE(review): a MinMaxScaler used to be constructed here but was never
    # applied, so features reach the RBF kernel (gamma=0.7) unscaled. Consider
    # fitting a scaler on the training matrix and transforming the test matrix.

    # Keep only index pairs for the (very large) negative set; feature vectors
    # are built later for the sampled pairs alone, instead of for every pair.
    train_pos, test_pos, negatives = partition_pairs(
        all_edges(), training_edges, testing_edges)
    print("positive train pairs: %d, positive test pairs: %d, negative pairs: %d"
          % (len(train_pos), len(test_pos), len(negatives)))

    # Sampling the pair list directly selects exactly the elements that
    # sampling its indices would, so the negative sets match the original.
    sampled = random.sample(negatives, len(train_pos) + len(test_pos))
    train_neg = sampled[:len(train_pos)]
    test_neg = sampled[len(train_pos):]

    print("Building training and testing datasets......")
    Xtrain = pair_features(train_pos + train_neg, features)
    ytrain = [1] * len(train_pos) + [0] * len(train_neg)
    Xtest = pair_features(test_pos + test_neg, features)
    ytest = [1] * len(test_pos) + [0] * len(test_neg)

    C = 1.0
    print("Fitting the training data......")
    # probability=True was dropped: only decision_function is used, and Platt
    # scaling costs an extra internal 5-fold cross-validation during fit.
    rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C).fit(Xtrain, ytrain)
    print("Predicting......")
    y_score = rbf_svc.decision_function(Xtest)
    print(y_score)

    #fpr, tpr, thresholds = metrics.roc_curve(ytest, y_score)
    print("Caculating AUC score......")
    auc = metrics.roc_auc_score(ytest, y_score)
    print(auc)
    # Disabled ROC plot: re-enable together with the matplotlib import and
    # the roc_curve line above.
    '''
    plt.figure()
    lw = 2
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    plt.show()
    '''


if __name__ == '__main__':
    main()
