"""Functions for preprocessing data."""
from __future__ import division
from __future__ import print_function
import pandas as pd
import gzip
import numpy as np
import scipy.sparse as sparse
from sklearn import preprocessing
from sklearn.preprocessing import scale
from .flags_ppi import FLAGS
import itertools
#from six.moves import urllib
import urllib
import random
#from six.moves import xrange  # pylint: disable=redefined-builtin
#from .flags import FLAGS

import os

# Directory holding the DIP yeast input files and the generated split files.
# Paths are built as FILE_PATH + filename, so it must end with a separator;
# os.path.join(..., '') guarantees that. Set the PPI_FILE_PATH environment
# variable to point at another location instead of editing this file.
FILE_PATH = os.path.join(
    os.environ.get('PPI_FILE_PATH',
                   '/home/data/PPINcluster/DIP_Yeast_20150101/shortepoch/'),
    '')
# Side length of every square adjacency/label matrix built in this module
# (number of nodes). Other values used previously: 9617, 100.
SHAPE = 5093
def ppi_training_testing(filename, data_dir=None, ratio=0.25,
                         training_name='yeast_ppi_025rand_training_r1.txt',
                         testing_name='yeast_ppi_025rand_testing_r1.txt'):
  """Randomly split an edge list into a training file and a testing file.

  Each non-blank input line is "a<TAB>b" (extra columns are ignored). Edges
  are grouped by their first node. For each node, its neighbour list is
  shuffled; the first int(len * ratio) neighbours go to the training split and
  the rest to the testing split. A testing edge (a, b) is written only if
  neither (a, b) nor (b, a) is in the training split, so no pair appears in
  both files.

  Args:
    filename: input edge-list file name, appended to data_dir.
    data_dir: directory prefix (ending with a separator) used for the input
      and both output files; defaults to FILE_PATH.
    ratio: fraction of each node's neighbours assigned to training.
    training_name: output file name for the training edges.
    testing_name: output file name for the testing edges.
  """
  if data_dir is None:
    data_dir = FILE_PATH

  edge_dict = {}
  cnt = 0
  with open(data_dir + filename, 'r') as f:
    for line in f:
      if not line.strip():
        continue  # tolerate blank lines such as a trailing newline
      tmp = line.strip().split('\t')
      a, b = tmp[0], tmp[1]
      edge_dict.setdefault(a, []).append(b)
      cnt += 1
  print("The total number of edges is {0}".format(cnt))

  training_edge_dict = {}
  testing_edge_dict = {}
  # random.shuffle works in place and returns None, so shuffle a list copy
  # of the keys rather than iterating over its return value.
  sources = list(edge_dict)
  random.shuffle(sources)
  for k in sources:
    neighbours = edge_dict[k]
    random.shuffle(neighbours)
    t = int(len(neighbours) * ratio)
    training_edge_dict[k] = neighbours[:t]
    testing_edge_dict[k] = neighbours[t:]

  with open(data_dir + training_name, 'w') as fw:
    for a in training_edge_dict:
      for b in training_edge_dict[a]:
        fw.write(a + '\t' + b + '\n')

  # Set view of the training split for O(1) leak checks in either direction.
  training_sets = dict((a, set(bs)) for a, bs in training_edge_dict.items())
  with open(data_dir + testing_name, 'w') as fw:
    for a in testing_edge_dict:
      for b in testing_edge_dict[a]:
        if (b not in training_sets.get(a, ())
            and a not in training_sets.get(b, ())):
          fw.write(a + '\t' + b + '\n')

def read_data_as_matrix(filename, label=False, data_dir=None, shape=None):
  """Read a tab-separated edge list into a symmetric sparse matrix.

  Each non-blank line is "a<TAB>b<TAB>w" with 1-based node ids a and b.
  Both (a-1, b-1) and (b-1, a-1) receive the value w, so the result is
  symmetric. Duplicate coordinates are kept as separate COO entries and are
  summed by scipy on conversion. Note that a self-loop line (a == b) therefore
  contributes 2 * w.

  Args:
    filename: file name, appended to data_dir.
    label: if True, ignore any weight column and store 1 for every edge
      (integer dtype); otherwise parse the third column as a float.
    data_dir: directory prefix ending with a separator; defaults to FILE_PATH.
    shape: side length of the square matrix; defaults to SHAPE.

  Returns:
    scipy.sparse.coo_matrix of shape (shape, shape).
  """
  if data_dir is None:
    data_dir = FILE_PATH
  if shape is None:
    shape = SHAPE
  rows, cols, vals = [], [], []
  with open(data_dir + filename, 'r') as f:
    for line in f:
      line = line.strip()
      if not line:
        continue  # tolerate blank lines such as a trailing newline
      tmp = line.split('\t')
      a, b = int(tmp[0]) - 1, int(tmp[1]) - 1
      w = 1 if label else float(tmp[2])
      rows.extend((a, b))
      cols.extend((b, a))
      vals.extend((w, w))
  data = np.array(vals, dtype=int if label else float)
  row_idx = np.array(rows, dtype=int)
  col_idx = np.array(cols, dtype=int)
  return sparse.coo_matrix((data, (row_idx, col_idx)),
                           shape=(shape, shape), dtype=data.dtype)

def read_label_as_matrix(filename, data_dir=None, shape=None):
  """Read an unweighted edge list into a symmetric 0/1 sparse label matrix.

  Each non-blank line is "a<TAB>b" with 1-based node ids; both (a-1, b-1)
  and (b-1, a-1) get the value 1. Duplicate coordinates are summed by scipy
  on conversion.

  Args:
    filename: file name, appended to data_dir.
    data_dir: directory prefix ending with a separator; defaults to FILE_PATH.
    shape: side length of the square matrix; defaults to SHAPE.

  Returns:
    scipy.sparse.coo_matrix of shape (shape, shape) with an integer dtype.
  """
  if data_dir is None:
    data_dir = FILE_PATH
  if shape is None:
    shape = SHAPE
  arr = []
  with open(data_dir + filename, 'r') as f:
    for line in f:
      line = line.strip()
      if not line:
        continue  # tolerate blank lines such as a trailing newline
      tmp = line.split('\t')
      a, b = int(tmp[0]) - 1, int(tmp[1]) - 1
      arr.append([a, b, 1])
      arr.append([b, a, 1])
  # Explicit zero entries at the first and last diagonal cells. They keep
  # `arr` 2-D even for an empty file. The corner now follows `shape` instead
  # of the hard-coded 5092.
  arr.append([0, 0, 0])
  arr.append([shape - 1, shape - 1, 0])
  arr = np.array(arr)
  return sparse.coo_matrix((arr[:, 2], (arr[:, 0], arr[:, 1])),
                           shape=(shape, shape), dtype=arr.dtype)

class DataSet(object):
  """Mini-batch iterator over paired rows of a feature matrix and labels.

  next_batch returns consecutive row slices of (matrices, labels). When the
  requested batch would run past the end, the epoch counter is incremented.
  Both arrays are then reshuffled with the same permutation, and the batch
  restarts at row 0. Rows left over from the partial final batch are skipped.
  """

  def __init__(self, k1, labels):
    """Wrap feature rows `k1` and matching label rows `labels`.

    Raises:
      ValueError: if k1 and labels have different numbers of rows.
    """
    n_rows = np.shape(k1)[0]
    if n_rows != labels.shape[0]:
      raise ValueError("k1 has {0} rows but labels has {1}".format(
          n_rows, labels.shape[0]))
    self._matrices = k1
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0
    # Derived from the data (not the global SHAPE) so any size works.
    self._num_examples = n_rows
    print("The labels shape is {0}".format(labels.shape))

  @property
  def matrices(self):
    return self._matrices

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size):
    """Return the next `batch_size` examples from this data set."""
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch
      self._epochs_completed += 1
      # Shuffle features and labels with the same permutation.
      perm = np.arange(self._num_examples)
      np.random.shuffle(perm)
      self._matrices = self._matrices[perm]
      self._labels = self._labels[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._matrices[start:end], self._labels[start:end]

class DataSetPreTraining(object):
  """Mini-batch iterator for unsupervised pre-training.

  next_batch returns the same row slice twice, as (input, target). Epoch
  handling matches DataSet: when a batch would run past the end, the rows
  are reshuffled and the batch restarts at row 0.
  """

  def __init__(self, k1):
    self._matrices = k1
    self._epochs_completed = 0
    self._index_in_epoch = 0
    # One example per row of k1 (derived from the data, not the global SHAPE).
    self._num_examples = np.shape(k1)[0]

  @property
  def matrices(self):
    return self._matrices

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size):
    """Return the next `batch_size` examples as an (input, target) pair."""
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch
      self._epochs_completed += 1
      # Shuffle the data
      perm = np.arange(self._num_examples)
      np.random.shuffle(perm)
      self._matrices = self._matrices[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    # Input and target are the same rows.
    return self._matrices[start:end], self._matrices[start:end]

def module_matrix(arr):
  """Return the modularity-style matrix of a weighted adjacency matrix.

  For every non-zero entry A[i, j] the result holds
      B[i, j] = A[i, j] - d[i] * d[j] / sum(A)
  where d is the row-sum (degree) vector of A. Entries where A is zero stay
  0, so the output keeps the input's sparsity pattern.

  Args:
    arr: 2-D array-like, e.g. an ndarray or the np.matrix from `.todense()`.
      Presumably square and symmetric (an adjacency matrix); TODO confirm
      for other inputs.

  Returns:
    float64 ndarray with the same shape as `arr`.
  """
  a = np.asarray(arr)  # np.matrix -> ndarray so row sums are 1-D
  m, n = a.shape
  result = np.zeros((m, n))
  nonzero = a != 0
  if not nonzero.any():
    # Nothing to fill in. This also avoids dividing by a zero total weight.
    return result
  twom = a.sum()
  deg = a.sum(axis=1)
  # Same per-element arithmetic as the scalar form: (d_i * d_j) / twom.
  with np.errstate(divide='ignore', invalid='ignore'):
    expected = np.outer(deg, deg[:n]) / twom
  result[nonzero] = a[nonzero] - expected[nonzero]
  return result

def read_data_sets():
  """Load the similarity kernels and training labels and wrap them as datasets.

  The steps are:
    1. Read the BLAST, expression and HMM similarity edge lists.
    2. Regenerate the random training/testing split of DIP_Yeast_BPID.txt and
       load the training edges as labels.
    3. Convert each kernel with module_matrix and min-max scale its columns.
    4. Concatenate the kernels column-wise into one feature matrix.

  Returns:
    (DataSet, DataSetPreTraining) built from the concatenated features.
  """
  print("begin to read data...")
  kernel_files = [('k1', 'DIP_Yeast_BlastID.txt'),
                  ('k2', 'DIP_Yeast_ExprID.txt'),
                  ('k3', 'DIP_Yeast_HmmID.txt')]
  raw_kernels = []
  for name, fname in kernel_files:
    print("reading {0}...".format(name))
    raw_kernels.append(read_data_as_matrix(fname))
  print("reading labels...")
  ppi_training_testing('DIP_Yeast_BPID.txt')
  labels = read_label_as_matrix('yeast_ppi_025rand_training_r1.txt')

  kernels = [module_matrix(k.todense()) for k in raw_kernels]
  labels = labels.todense()
  min_max_scaler = preprocessing.MinMaxScaler()
  print("rescaling data...")
  # fit_transform refits the scaler on each kernel independently.
  kernels = [min_max_scaler.fit_transform(k) for k in kernels]

  print("generating datasets...")
  for (name, _), k in zip(kernel_files, kernels):
    print("checking nan for {0}...".format(name))
    print(np.isnan(k).any())
    print(k.shape)
  print("checking nan for labels...")
  print(np.isnan(labels).any())
  print(labels.shape)

  # DataSet/DataSetPreTraining take a single feature matrix. The previous
  # call passed k1, k2, k3 separately and raised TypeError. Concatenating
  # column-wise follows the commented-out code those classes used to contain.
  # NOTE(review): confirm the model's input width matches the concatenation.
  matrices = np.concatenate(kernels, axis=1)
  data_sets = DataSet(matrices, labels)
  pretrain_data_sets = DataSetPreTraining(matrices)

  return data_sets, pretrain_data_sets

def simulation_data_sets():
  """Build datasets with real training labels but all-zero feature matrices.

  Regenerates the random training/testing split of DIP_Yeast_BPID.txt and
  loads the training edges as dense labels. Features are a zero matrix of
  the module-wide SHAPE, so they agree with the label matrix size.

  Returns:
    (DataSet, DataSetPreTraining) built from the zero feature matrix.
  """
  ppi_training_testing('DIP_Yeast_BPID.txt')
  labels = read_label_as_matrix('yeast_ppi_025rand_training_r1.txt')
  labels = labels.todense()
  # Use SHAPE rather than a hard-coded size so it tracks the label matrix.
  k1 = np.zeros((SHAPE, SHAPE))
  data_sets = DataSet(k1, labels)
  pretrain_data_sets = DataSetPreTraining(k1)
  return data_sets, pretrain_data_sets

def _add_noise(x, rate):
  x_cp = np.copy(x)
  pix_to_drop = np.random.rand(x_cp.shape[0],
                                  x_cp.shape[1]) < rate
  x_cp[pix_to_drop] = FLAGS.zero_bound
  return x_cp


def fill_feed_dict_ae(data_set, input_pl, target_pl, noise=None):
  """Build an autoencoder feed dict from the next mini-batch of `data_set`.

  Args:
    data_set: object exposing next_batch(batch_size) -> (inputs, targets).
    input_pl: feed key (placeholder) that receives the input batch.
    target_pl: feed key (placeholder) that receives the target batch.
    noise: if truthy, used as the drop rate for _add_noise on the inputs.

  Returns:
    dict mapping input_pl and target_pl to their batch values.
  """
  batch_inputs, batch_targets = data_set.next_batch(FLAGS.batch_size)
  fed_inputs = _add_noise(batch_inputs, noise) if noise else batch_inputs
  return {input_pl: fed_inputs, target_pl: batch_targets}


def fill_feed_dict(data_set, matrices_pl, labels_pl, noise=False):
  """Build the feed_dict for one training step from the next mini-batch.

  The result has the form
    {matrices_pl: <batch of matrix rows>, labels_pl: <matching label rows>}

  Args:
    data_set: object exposing next_batch(batch_size) -> (matrices, labels),
      e.g. the DataSet returned by read_data_sets().
    matrices_pl: feed key (placeholder) that receives the matrix batch.
    labels_pl: feed key (placeholder) that receives the label batch.
    noise: if truthy, corrupt the matrix batch with _add_noise at rate
      FLAGS.drop_out_rate before feeding it.

  Returns:
    feed_dict: dict mapping the two placeholders to their batch values.
  """
  batch = data_set.next_batch(FLAGS.batch_size)
  batch_matrices, batch_labels = batch
  if not noise:
    return {matrices_pl: batch_matrices, labels_pl: batch_labels}
  noisy = _add_noise(batch_matrices, FLAGS.drop_out_rate)
  return {matrices_pl: noisy, labels_pl: batch_labels}