__author__ = 'leihuang'
import numpy as np
from sklearn import metrics
from operator import sub
from functools import reduce
from math import sqrt
from sklearn import preprocessing
#import matplotlib.pyplot as plt
# Directory prefix for every input file. It is joined to file names by plain
# string concatenation, so it must keep its trailing '/'.
PATH = ('/home/leihuang/PPINcluster/'
        'DIP_Yeast_20150101/longepoch/')
# Number of nodes; pair enumeration runs over node ids 0 .. N-1.
# NOTE(review): presumably the node count of this DIP yeast snapshot — confirm.
N = 5093

def read_edges(file_name, path=None):
    """Read a tab-separated edge list and return it as a set of 0-based pairs.

    Each non-blank line must hold exactly two tab-separated integer node ids.
    The ids are treated as 1-based and shifted down by one. Pairs are stored
    in the orientation they appear in the file; no symmetric copy is added,
    so callers must check both (a, b) and (b, a).

    :param file_name: name of the edge file, appended to the directory prefix.
    :param path: directory prefix, joined by string concatenation (so it must
        end with a separator). Defaults to the module-level ``PATH``.
    :return: set of ``(int, int)`` tuples.
    :raises ValueError: if a non-blank line does not have exactly two integer
        fields.
    """
    base = PATH if path is None else path
    edge_set = set()
    # `with` guarantees the handle is closed even if parsing raises.
    with open(base + file_name) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing on the unpack below.
                continue
            a, b = line.split('\t')
            edge_set.add((int(a) - 1, int(b) - 1))
    return edge_set

def all_edges(n=None):
    """Yield every unordered node pair (i, j) with 0 <= i <= j < n, in row order.

    Self-pairs (i, i) are included. Each unordered pair appears exactly once,
    with the smaller id first.

    :param n: number of nodes; defaults to the module-level ``N``.
    :return: generator of ``(int, int)`` tuples.
    """
    count = N if n is None else n
    for i in range(count):
        for j in range(i, count):
            yield (i, j)

def read_node_feature(feature_file, path=None):
    """Read per-node feature vectors from a tab-separated file.

    Each non-blank line has the form ``node<TAB>v1<TAB>v2...``. The first
    column is discarded, and the remaining non-empty fields are parsed as
    floats. Rows are returned in file order, so node identity comes from line
    position, not from the node column.
    NOTE(review): this assumes the file lists nodes in id order — confirm.

    :param feature_file: name of the feature file, appended to the prefix.
    :param path: directory prefix, joined by string concatenation (so it must
        end with a separator). Defaults to the module-level ``PATH``.
    :return: list of float lists, one per non-blank line.
    :raises ValueError: if a feature field is not a valid float.
    """
    base = PATH if path is None else path
    n_feature = []
    # `with` guarantees the handle is closed even if parsing raises.
    with open(base + feature_file, 'r') as fh:
        for line in fh:
            stripped = line.strip()
            if not stripped:
                # A blank line describes no node. Appending [] here would add
                # a phantom feature row.
                continue
            values = stripped.split('\t')[1:]
            # Skip empty fields produced by consecutive tabs.
            n_feature.append([float(x) for x in values if x != ''])
    return n_feature

def main():
    """Score candidate node pairs by feature distance and report the ROC AUC.

    Loads the training/testing edge lists and node features from ``PATH``.
    Every pair produced by ``all_edges()`` is scored, except pairs that are
    testing edges in either orientation. The score is the Euclidean distance
    between the two nodes' feature vectors. Pairs that are training edges are
    labelled 0 and all other pairs 1, so a larger distance is treated as
    evidence of a non-edge.

    The min-max-scaled scores and labels are written to ``prediction.txt``
    (in the current working directory) as ``score<TAB>label`` lines. The
    number of skipped testing pairs and the ROC AUC are then printed.
    """
    training_edges = read_edges('yeast_training.txt')
    testing_edges = read_edges('yeast_testing.txt')
    features = read_node_feature('featurefile.txt')

    scores, y = [], []
    n_skipped = 0
    for a, b in all_edges():
        # Edge files keep whatever orientation they were written in, so check
        # both. A pair is excluded if EITHER orientation is a testing edge.
        # The original `not in ... or not in ...` only excluded it when both
        # orientations were present, which leaked testing edges into the data.
        if (a, b) in testing_edges or (b, a) in testing_edges:
            n_skipped += 1
            continue
        # Euclidean distance. zip truncates to the shorter vector, which
        # matches the original map(sub, ...) behaviour for ragged rows.
        scores.append(sqrt(sum((fa - fb) ** 2
                               for fa, fb in zip(features[a], features[b]))))
        is_training = (a, b) in training_edges or (b, a) in training_edges
        y.append(0 if is_training else 1)

    y = np.array(y)
    # MinMaxScaler expects a 2-D (n_samples, n_features) array. Current
    # scikit-learn rejects a 1-D array, so scale the scores as a single
    # feature column and flatten back.
    min_max_scaler = preprocessing.MinMaxScaler()
    scores = min_max_scaler.fit_transform(
        np.array(scores).reshape(-1, 1)).ravel()

    with open('prediction.txt', 'w') as fw:
        for score, label in zip(scores, y):
            fw.write(str(score) + '\t' + str(label) + '\n')

    auc = metrics.roc_auc_score(y, scores)
    print('skipped testing pairs:', n_skipped)
    print('AUC:', auc)


if __name__ == '__main__':
    main()