''' This script needs the interactome file (DataS1_interactome.tsv), the disease-gene file (DataS2_disease_genes.tsv, used by disease_file to build the per-disease gene files) and the tab-separated file negsab.csv, which lists the disease pairs with low s_AB and high comorbidity.
Three values must be adapted to your setup: diseasepath, rocpath and r, which are set in the script section at the end of this file. They are the directories where the respective files are read from or stored.

The final result is written to roc.txt in the directory from which the script is run, and a histogram of the ROC scores is shown as well.'''

#! /usr/bin/env python
import csv
import os
import glob 
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from collections import Counter
import networkx as nx
import numpy as np
import random


def read_network(path="DataS1_interactome.tsv"):
    """
    Reads a network from an external file.

    * The edgelist must be provided as a tab-separated table. The
    first two columns of the table will be interpreted as an
    interaction gene1 <==> gene2

    * Lines that start with '#' and lines with fewer than two columns
    (e.g. blank lines) will be ignored

    PARAMETERS:
    -----------
        - path: edge-list file to read; defaults to the interactome file
          this script was written for

    RETURNS:
    --------
        - networkx Graph containing one edge per data line
    """

    G = nx.Graph()
    # 'with' guarantees the handle is closed once all edges are read
    with open(path, 'r') as edge_file:
        for line in edge_file:
            # lines starting with '#' will be ignored
            if line.startswith('#'):
                continue
            line_data = line.strip().split('\t')
            # a blank/malformed line would otherwise raise IndexError below
            if len(line_data) < 2:
                continue
            # The first two columns in the line will be interpreted as an
            # interaction gene1 <=> gene2
            G.add_edge(line_data[0], line_data[1])

    return G


# =============================================================================
def read_gene_list(gene_file,diseasepath):
    """
    Reads a list genes from an external file.

    The file read is <diseasepath>/<gene_file>.txt.

    * The genes must be provided as a table. If the table has more
    than one column, they must be tab-separated. The first column will
    be used only.

    * Lines that start with '#' and blank lines will be ignored

    RETURNS:
    --------
        - set of gene identifiers (strings)
    """

    genes_set = set()
    path = os.path.join(diseasepath, gene_file + ".txt")
    # 'with' guarantees the handle is closed after reading
    with open(path, 'r') as gene_handle:
        for line in gene_handle:
            # lines starting with '#' will be ignored
            if line.startswith('#'):
                continue
            # the first column in the line will be interpreted as a seed
            # gene:
            gene = line.strip().split('\t')[0]
            # blank lines would otherwise add '' as a gene
            if gene:
                genes_set.add(gene)

    return genes_set


# =============================================================================
def remove_self_links(G):
    """
    Remove every self-loop edge (u, u) from G in place.

    PARAMETERS:
    -----------
        - G: networkx graph, modified in place

    RETURNS:
    --------
        - None
    """
    try:
        # networkx >= 2.x: module-level function (the Graph method was
        # removed in networkx 2.4)
        loops = list(nx.selfloop_edges(G))
    except AttributeError:
        # fallback for old networkx releases that only offer the method
        loops = list(G.selfloop_edges())
    # materialise the loops first: in networkx 2.x the iterator is lazy and
    # removing edges while it is consumed raises RuntimeError
    G.remove_edges_from(loops)



# =============================================================================
def get_pathlengths_for_single_set(G,given_gene_set):

    """
    calculate the shortest paths of a given set of genes in a
    given network. The results are stored in a dictionary of
    dictionaries:
    all_path_lengths[gene1][gene2] = l
    with gene1 < gene2, so each pair is stored only once!

    Every gene of the (network-filtered) set gets an entry, possibly an
    empty dict, so callers can index all_path_lengths[gene] safely.
    Pairs that are not connected in G are omitted.

    PARAMETERS:
    -----------
        - G: network
        - given_gene_set: gene set for which paths should be computed
          (genes absent from G are dropped)

    RETURNS:
    --------
        - all_path_lengths[gene1][gene2] = l for all connected pairs of
          genes with gene1 < gene2
    """

    # remove all nodes that are not in the network; sorting makes the
    # "gene1 < gene2" partners of each gene a simple tail slice
    all_genes_in_network = set(G.nodes())
    ordered_genes = sorted(given_gene_set & all_genes_in_network)

    all_path_lengths = {}

    for index, gene1 in enumerate(ordered_genes):
        row = all_path_lengths.setdefault(gene1, {})
        later_genes = ordered_genes[index + 1:]
        if not later_genes:
            continue
        # one BFS from gene1 yields its distance to every reachable node,
        # replacing a separate shortest-path search per pair; unreachable
        # genes are simply absent from the result
        reachable = dict(nx.single_source_shortest_path_length(G, gene1))
        for gene2 in later_genes:
            if gene2 in reachable:
                row[gene2] = reachable[gene2]

    return all_path_lengths



# =============================================================================
def get_pathlengths_for_two_sets(G,given_gene_set1,given_gene_set2):

    """
    calculate the shortest paths between two given set of genes in a
    given network. The results are stored in a dictionary of
    dictionaries: all_path_lengths[gene1][gene2] = l with gene1 <
    gene2, so each pair is stored only once!

    Every gene of set 1 gets an entry (possibly empty); genes of set 2
    get an entry only when they are the smaller gene of a stored pair.
    Pairs that are not connected in G are omitted.

    PARAMETERS:
    -----------
        - G: network
        - given_gene_set1/2: gene sets for which paths should be computed
          (genes absent from G are dropped)

    RETURNS:
    --------
        - all_path_lengths[gene1][gene2] = l for all connected pairs of
          genes with gene1 < gene2
    """

    # remove all nodes that are not in the network
    all_genes_in_network = set(G.nodes())
    gene_set1 = given_gene_set1 & all_genes_in_network
    gene_set2 = given_gene_set2 & all_genes_in_network

    all_path_lengths = {}

    for gene1 in gene_set1:
        all_path_lengths.setdefault(gene1, {})
        partners = gene_set2 - set([gene1])
        if not partners:
            continue
        # one BFS from gene1 instead of one shortest-path search per pair
        reachable = dict(nx.single_source_shortest_path_length(G, gene1))
        for gene2 in partners:
            if gene2 not in reachable:
                # gene1 and gene2 are not connected in G
                continue
            l = reachable[gene2]
            # file the pair under the smaller gene so it is stored once
            if gene1 < gene2:
                all_path_lengths[gene1][gene2] = l
            else:
                all_path_lengths.setdefault(gene2, {})[gene1] = l

    return all_path_lengths


# =============================================================================
def calc_single_set_distance(G,given_gene_set):

    """
    Calculates the mean shortest distance for a set of genes on a
    given network

    For every gene the distance to its closest OTHER gene of the set is
    taken (genes without any reachable partner are ignored), and these
    per-gene minima are averaged.

    PARAMETERS:
    -----------
        - G: network
        - given_gene_set: gene set for which distance will be computed
          (genes absent from G are dropped)

    RETURNS:
    --------
         - mean shortest distance, or nan when no gene has a reachable
           partner (e.g. fewer than two genes of the set are in G)
    """

    # remove all nodes that are not in the network, just to be safe
    all_genes_in_network = set(G.nodes())
    gene_set = given_gene_set & all_genes_in_network

    # distances for all gene pairs, stored once per pair under the
    # lexicographically smaller gene
    all_path_lengths = get_pathlengths_for_single_set(G,gene_set)

    all_distances = []

    for geneA in gene_set:
        all_distances_A = []
        for geneB in gene_set:
            # a gene's distance to itself is never stored
            if geneA == geneB:
                continue
            low, high = (geneA, geneB) if geneA < geneB else (geneB, geneA)
            distance = all_path_lengths.get(low, {}).get(high)
            # None means the two genes are not connected in G
            if distance is not None:
                all_distances_A.append(distance)

        if all_distances_A:
            all_distances.append(min(all_distances_A))

    # np.mean([]) would also give nan, but with a RuntimeWarning
    if not all_distances:
        return np.nan
    return np.mean(all_distances)


# =============================================================================
def calc_set_pair_distances(G,given_gene_set1,given_gene_set2):

    """
    Calculates the mean shortest distance between two sets of genes on
    a given network

    For every gene of each set the distance to the closest gene of the
    OTHER set is taken (0 when the gene belongs to both sets; genes with
    no reachable partner are ignored). The minima of both directions are
    pooled and averaged.

    PARAMETERS:
    -----------
        - G: network
        - given_gene_set1/2: gene sets for which distance will be computed
          (genes absent from G are dropped)

    RETURNS:
    --------
         - mean shortest distance, or nan when no gene has a reachable
           partner in the other set
    """

    # remove all nodes that are not in the network
    all_genes_in_network = set(G.nodes())
    gene_set1 = given_gene_set1 & all_genes_in_network
    gene_set2 = given_gene_set2 & all_genes_in_network

    # get the network distances for all gene pairs:
    all_path_lengths = get_pathlengths_for_two_sets(G,gene_set1,gene_set2)

    all_distances = []

    # set 1 -> set 2 first, then set 2 -> set 1
    for source_set, target_set in ((gene_set1, gene_set2),
                                   (gene_set2, gene_set1)):
        for geneA in source_set:
            all_distances_A = []
            for geneB in target_set:
                if geneA == geneB:
                    # the gene is in both sets, so its distance is 0
                    all_distances_A.append(0)
                    continue
                # each pair is stored once, under the smaller gene
                low, high = (geneA, geneB) if geneA < geneB else (geneB, geneA)
                try:
                    all_distances_A.append(all_path_lengths[low][high])
                except KeyError:
                    # the two genes are not connected in G
                    pass

            if all_distances_A:
                all_distances.append(min(all_distances_A))

    # np.mean([]) would also give nan, but with a RuntimeWarning
    if not all_distances:
        return np.nan
    return np.mean(all_distances)
#-----------------------------------------------------------
def find_genes (gene_file_1,gene_file_2,diseasepath):
    """
    Split the genes of a disease pair into shared and non-shared genes.

    Reads the first whitespace-separated column of
    <diseasepath>/<gene_file_1>.txt and <diseasepath>/<gene_file_2>.txt
    (blank lines and lines starting with '#' are skipped), then writes to
    the current working directory:
      * common_genes.txt   - genes listed for BOTH diseases
      * uncommon_genes.txt - genes listed for exactly one of them
    (one gene per line, no duplicates).

    RETURNS:
    --------
        - (number of common genes, number of uncommon genes)
    """

    def _first_column(gene_file):
        # ordered, de-duplicated first-column entries of one disease file
        genes = []
        seen = set()
        path = os.path.join(diseasepath, gene_file + ".txt")
        with open(path, "r") as handle:
            for line in handle:
                fields = line.split()
                if not fields or line.startswith('#'):
                    continue
                if fields[0] not in seen:
                    seen.add(fields[0])
                    genes.append(fields[0])
        return genes

    genes_1 = _first_column(gene_file_1)
    genes_2 = _first_column(gene_file_2)

    # "common" means present in both diseases; counting occurrences in the
    # concatenated list also flagged genes duplicated within a single file
    common_genes = set(genes_1) & set(genes_2)
    uncommon_genes = [gene for gene in genes_1 + genes_2
                      if gene not in common_genes]

    with open("common_genes.txt", "w") as common_out:
        for gene in sorted(common_genes):
            common_out.write(str(gene) + "\n")
    with open("uncommon_genes.txt", "w") as uncommon_out:
        for gene in uncommon_genes:
            uncommon_out.write(str(gene) + "\n")

    return (len(common_genes), len(uncommon_genes))
		
		 	
#--------------------------calculate s_AB for uncommon genes
def s_AB_uncommon_genes (gene_file_1,gene_file_2,rocpath,diseasepath):
    """
    Compute s_AB after adding single uncommon genes to both disease sets.

    Up to 10 distinct genes are drawn at random from uncommon_genes.txt in
    the working directory (all of them, in file order, if there are 10 or
    fewer). For each drawn gene g, g is added to the network-filtered gene
    sets of both diseases and
        s_AB = d_AB - (d_A + d_B) / 2
    is computed. The file <rocpath>/<gene_file_1>&<gene_file_2>.txt is
    truncated and receives one line "g<TAB>s_AB<TAB>0" per drawn gene.
    """
    with open("uncommon_genes.txt", "r") as handle:
        candidates = [line.split()[0] for line in handle if line.strip()]

    # uniform sampling without replacement
    if len(candidates) <= 10:
        sampled = candidates
    else:
        sampled = random.sample(candidates, 10)

    out_path = rocpath + "/" + gene_file_1 + "&" + gene_file_2 + ".txt"
    with open(out_path, "w") as fw:
        if not sampled:
            return

        # the network and both gene lists are identical for every sample,
        # so they are loaded once instead of once per sampled gene
        G = read_network()
        all_genes_in_network = set(G.nodes())
        remove_self_links(G)
        base_A = read_gene_list(gene_file_1, diseasepath) & all_genes_in_network
        base_B = read_gene_list(gene_file_2, diseasepath) & all_genes_in_network

        for gene in sampled:
            # add the gene as ONE element; set(gene) would split the
            # identifier into its characters (e.g. "672" -> "6","7","2")
            genes_A = base_A | set([gene])
            genes_B = base_B | set([gene])
            # distances WITHIN the two gene sets:
            d_A = calc_single_set_distance(G, genes_A)
            d_B = calc_single_set_distance(G, genes_B)
            # distances BETWEEN the two gene sets:
            d_AB = calc_set_pair_distances(G, genes_A, genes_B)
            s_AB = d_AB - (d_A + d_B) / 2.
            # label 0 marks rows coming from uncommon genes
            fw.write(str(gene) + "\t" + str(s_AB) + "\t" + "0" + "\n")
   		
#--------------------------------calculate s_AB for common genes
def s_AB_common_genes (gene_file_1,gene_file_2,rocpath,diseasepath):
    """
    Compute s_AB after removing single common genes from one disease set.

    Two independent random draws of up to 5 distinct genes are taken from
    common_genes.txt in the working directory (all of them if there are 5
    or fewer). Genes of the first draw are removed, one at a time, from
    disease A's network-filtered gene set; genes of the second draw from
    disease B's. Each time
        s_AB = d_AB - (d_A + d_B) / 2
    is computed and "g<TAB>s_AB<TAB>1" is appended to
    <rocpath>/<gene_file_1>&<gene_file_2>.txt. Drawn genes that are not in
    the network are skipped.
    """

    def _sample_common_genes():
        # up to 5 distinct genes, drawn uniformly without replacement
        with open("common_genes.txt", "r") as handle:
            candidates = [line.split()[0] for line in handle if line.strip()]
        if len(candidates) <= 5:
            return candidates
        return random.sample(candidates, 5)

    def _separation(G, set_A, set_B):
        # network separation s_AB of two gene sets
        d_A = calc_single_set_distance(G, set_A)
        d_B = calc_single_set_distance(G, set_B)
        d_AB = calc_set_pair_distances(G, set_A, set_B)
        return d_AB - (d_A + d_B) / 2.

    out_path = rocpath + "/" + gene_file_1 + "&" + gene_file_2 + ".txt"
    # append: s_AB_uncommon_genes has already written the label-0 rows
    with open(out_path, "a+") as fw:
        sample_for_A = _sample_common_genes()
        sample_for_B = _sample_common_genes()
        if not (sample_for_A or sample_for_B):
            return

        # load the network and gene lists once for all samples
        G = read_network()
        all_genes_in_network = set(G.nodes())
        remove_self_links(G)
        genes_A = read_gene_list(gene_file_1, diseasepath) & all_genes_in_network
        genes_B = read_gene_list(gene_file_2, diseasepath) & all_genes_in_network

        # pass 1: drop each sampled gene from disease A only
        for gene in sample_for_A:
            if gene in genes_A:
                s_AB = _separation(G, genes_A - set([gene]), genes_B)
                fw.write(str(gene) + "\t" + str(s_AB) + "\t" + "1" + "\n")

        # pass 2: drop each sampled gene from disease B only
        for gene in sample_for_B:
            if gene in genes_B:
                s_AB = _separation(G, genes_A, genes_B - set([gene]))
                fw.write(str(gene) + "\t" + str(s_AB) + "\t" + "1" + "\n")
		 

#---------------------generate disease gene files
def disease_file (newpath):
    """
    Split DataS2_disease_genes.tsv into one gene file per disease.

    For every non-blank, non-'#' row of DataS2_disease_genes.tsv (read from
    the working directory) the file <newpath>/<column 0>.txt is written with
    one gene per line. Genes come from the ';'-separated lists in columns 4
    and, when present, 5 (0-indexed); empty entries and repeats are dropped.

    RETURNS:
    --------
        - newpath
    """
    with open("DataS2_disease_genes.tsv", 'r') as table:
        for line in table:
            # lines starting with '#' (and blank lines) will be ignored
            if line.startswith('#') or not line.strip():
                continue
            line_data = line.strip().split('\t')

            genes = []
            seen = set()
            for column in line_data[4:6]:
                for gene in column.split(";"):
                    if gene and gene not in seen:
                        seen.add(gene)
                        genes.append(gene)

            # each disease gets its own file; previously the genes of all
            # diseases were accumulated and written only into the last file
            out_path = os.path.join(newpath, line_data[0] + ".txt")
            with open(out_path, "w") as out:
                for gene in genes:
                    out.write(gene + "\n")

    return newpath

#-------------------------------------ROC Files
def roc_file (rocpath,r):
    """
    Annotate each per-pair result file with that pair's values from negsab.csv.

    Every <rocpath>/<disease1>&<disease2>.txt (rows "gene s_AB label")
    whose pair appears as the first two columns of a row in the
    tab-separated negsab.csv (working directory) is rewritten to
    <r>/<disease1>&<disease2>.txt with six tab-separated columns:
        gene, s_AB, row[2], row[3], s_AB - row[3], label
    If several negsab.csv rows match the same pair, the last one wins.
    Files without a matching row (or without '&' in the name) are skipped.
    """
    # read the reference table once instead of once per result file
    reference = {}
    with open('negsab.csv', 'r') as ref_handle:
        for row in csv.reader(ref_handle, delimiter='\t'):
            if len(row) >= 2:
                reference[(str(row[0]), str(row[1]))] = row

    for path in glob.glob(os.path.join(rocpath, '*.txt')):
        # file name without directory and ".txt": "<disease1>&<disease2>"
        pair_name = os.path.basename(path)[:-4]
        if '&' not in pair_name:
            continue
        split_at = pair_name.index('&')
        row = reference.get((pair_name[:split_at], pair_name[split_at + 1:]))
        if row is None:
            continue

        out_path = os.path.join(r, pair_name + ".txt")
        with open(path, 'r') as src, open(out_path, "w") as dst:
            for line in src:
                x = line.split()
                if not x:
                    continue
                dst.write(str(x[0]) + "\t" + str(x[1]) + "\t" + str(row[2])
                          + "\t" + str(row[3]) + "\t"
                          + str(float(x[1]) - float(row[3]))
                          + "\t" + str(x[2]) + "\n")


#----------------------roc score calculation
def roc_score (r):
    """
    Compute one ROC AUC per annotated disease-pair file and write roc.txt.

    For every <r>/*.txt file (six columns, as written by roc_file) column 4
    (0-indexed) is used as the score and column 5 as the 0/1 label. Each
    file containing both labels yields one tab-separated line in roc.txt
    (working directory):
        pair name, AUC, number of 0 labels, number of 1 labels,
        column 2 of the file's last row
    Files with a single label (AUC undefined) are skipped.
    """
    with open("roc.txt", "w") as fw:
        for path in glob.glob(os.path.join(r, '*.txt')):
            pair_name = os.path.basename(path)[:-4]
            scores = []
            labels = []
            last_fields = None
            with open(path, 'r') as handle:
                for line in handle:
                    fields = line.split()
                    if not fields:
                        continue
                    scores.append(float(fields[4]))
                    labels.append(int(fields[5]))
                    last_fields = fields

            # the AUC is only defined when both classes are present
            if len(set(labels)) != 2:
                continue
            counts = Counter(labels)
            false_positive_rate, true_positive_rate, _ = roc_curve(labels, scores)
            roc_auc = auc(false_positive_rate, true_positive_rate)
            fw.write(str(pair_name) + "\t" + str(roc_auc) + "\t"
                     + str(counts[0]) + "\t" + str(counts[1]) + "\t"
                     + last_fields[2] + "\n")

#---------------------histogram plot
def hist_plot ():
    """
    Show a histogram of the ROC AUC values stored in roc.txt.

    Reads column 1 (0-indexed) of the tab-separated roc.txt in the working
    directory, as written by roc_score, and displays it with matplotlib.
    Rows with fewer than two columns are ignored.
    """
    with open('roc.txt', 'r') as handle:
        x = [float(row[1]) for row in csv.reader(handle, delimiter='\t')
             if len(row) > 1]
    plt.hist(x)
    plt.title("ROC Score Histogram")
    plt.xlabel("Value")
    plt.ylabel("Frequency")
    plt.show()


"""
# =============================================================================

           E N D    O F    D E F I N I T I O N S 

# =============================================================================
"""
# Directory holding one <disease>.txt gene file per disease.
diseasepath = '/home/paki/Desktop/OriginalData/Disease_genes'
# To (re)generate those files from DataS2_disease_genes.tsv, uncomment:
#if not os.path.exists(diseasepath):
#    os.makedirs(diseasepath)
#disease_file(diseasepath)

# Directory receiving the raw per-pair s_AB files "<disease1>&<disease2>.txt".
rocpath = '/home/paki/Desktop/OriginalData/ROC'
if not os.path.exists(rocpath):
    os.makedirs(rocpath)

# negsab.csv is tab-separated; columns 0 and 1 name the two diseases of a
# pair (i.e. their gene files in diseasepath).
with open('negsab.csv', 'r') as pair_file:
    complete_data = list(csv.reader(pair_file, delimiter='\t'))

for line, row in enumerate(complete_data):
    # blank / malformed rows cannot name a disease pair
    if len(row) < 2:
        continue
    print('begin: ' + str(line))
    gene_file_1 = row[0]
    gene_file_2 = row[1]
    print(gene_file_1)
    print(gene_file_2)
    # separate common and uncommon genes of the disease pair
    a, b = find_genes(gene_file_1, gene_file_2, diseasepath)
    print(a)
    print(b)
    # sampling needs both shared and non-shared genes
    if a != 0 and b != 0:
        # calculate s_AB for uncommon genes (truncates the pair's file)
        s_AB_uncommon_genes(gene_file_1, gene_file_2, rocpath, diseasepath)
        # calculate s_AB for common genes (appends to the same file)
        s_AB_common_genes(gene_file_1, gene_file_2, rocpath, diseasepath)
        print('finished: ' + str(line))

# Directory receiving the annotated per-pair files used for the ROC scores.
r = '/home/paki/Desktop/OriginalData/R'
if not os.path.exists(r):
    os.makedirs(r)
roc_file(rocpath, r)
roc_score(r)
hist_plot()
        

