import os
import shutil
import tensorflow as tf
from ae.utils.flags_ppi import FLAGS, home_out
from ae.utils.start_tensorboard_ppi import start
from ae.utils.data_ppi import read_data_sets
from ae.utils.data_ppi import simulation_data_sets
import numpy as np

# Directory settings read from the project-wide FLAGS object.
# NOTE(review): none of these three names is referenced by the code in this
# file as shown; presumably kept for parity with sibling scripts that share
# the flags module — verify before removing.
_data_dir = FLAGS.data_dir
_summary_dir = FLAGS.summary_dir
_chkpt_dir = FLAGS.chkpt_dir

class AutoEncoderRec(object):
  """Deep autoencoder rebuilt from previously trained parameter values.

  This class creates no TensorFlow variables of its own. At construction it
  looks up already-restored variables by name and evaluates each one once
  with ``sess.run``. It then stores the fetched values, which
  ``supervised_net`` uses as the weights and biases of a feed-forward net.

  Looked-up names (last scope component of the TF name, without the ``:N``
  output index):
    * ``weights{i}`` and ``biases{i}`` for ``i`` in 1..num_hidden_layers + 1
    * ``biases{i}_out`` for ``i`` in 1..num_hidden_layers
  """
  _weights_str = "weights{0}"
  _biases_str = "biases{0}"
  # Activation names accepted by ``_activate``.
  _activations = ('sigmoid', 'softmax', 'linear', 'tanh', 'relu')

  def __init__(self, shape, sess, vars):
    """Autoencoder initializer.

    Args:
      shape: list of ints specifying
              num input, hidden1 units, ..., hidden_n units, num logits.
      sess: tensorflow session in which ``vars`` can be evaluated.
      vars: iterable of tf variables to recover the parameters from
        (``main`` passes ``tf.trainable_variables()``). The parameter name
        shadows the ``vars`` builtin but is kept for caller compatibility.

    Raises:
      KeyError: if a required weight/bias variable is missing from ``vars``.
    """
    self.__shape = shape  # [input_dim, hidden1_dim, ..., hidden_n_dim, output_dim]
    self.__num_hidden_layers = len(self.__shape) - 2

    self.__variables = {}
    self.__sess = sess
    self._recover_variables(vars)

  @property
  def shape(self):
    """Layer sizes passed to the constructor."""
    return self.__shape

  @property
  def num_hidden_layers(self):
    """Number of hidden layers, i.e. ``len(shape) - 2``."""
    return self.__num_hidden_layers

  @property
  def session(self):
    """Session used to fetch the stored parameter values."""
    return self.__sess

  @property
  def variables(self):
    """Dict mapping internal names to the recovered parameter values."""
    return self.__variables

  def __getitem__(self, item):
    """Return a recovered parameter value by internal name.

    Names are weights#, biases# and biases#_out. The stored object is
    whatever ``session.run`` returned for the matching variable.

    Args:
      item: string, the parameter's internal name.
    Returns:
      The stored parameter value.
    Raises:
      KeyError: if nothing was stored under ``item``.
    """
    return self.__variables[item]

  def __setitem__(self, key, value):
    """Store a recovered parameter value.

    NOTE: Don't call this explicitly. It is used internally by
    ``_recover_variables``.

    Args:
      key: string, internal name of the parameter.
      value: the value fetched for that parameter.
    """
    self.__variables[key] = value

  @staticmethod
  def _basename(var_name):
    """Return the last scope component of a TF name, minus its ':N' suffix.

    For example 'outer/inner/weights1:0' -> 'weights1'. Works for names
    with any number of scopes (including none).
    """
    return var_name.rsplit('/', 1)[-1].split(':', 1)[0]

  def _expected_names(self):
    """List the internal parameter names this architecture needs, in order."""
    names = []
    for i in range(1, self.__num_hidden_layers + 2):
      names.append(self._weights_str.format(i))
      names.append(self._biases_str.format(i))
      # Only hidden layers have a pretraining output bias.
      if i <= self.__num_hidden_layers:
        names.append(self._biases_str.format(i) + "_out")
    return names

  def _recover_variables(self, vars):
    """Fetch every required weight/bias from ``vars`` and store its value.

    Args:
      vars: iterable of objects with a ``name`` attribute (tf variables).

    Raises:
      KeyError: if any required parameter is not present in ``vars``.
        Previously a missing variable either raised an obscure NameError or
        silently reused the previous layer's value.
    """
    # Index once by basename instead of rescanning all variables per layer.
    # When several variables share a basename the last one wins, matching
    # the overwrite-in-loop behaviour of the earlier implementation.
    by_name = {self._basename(v.name): v for v in vars}

    wanted = self._expected_names()
    missing = [name for name in wanted if name not in by_name]
    if missing:
      raise KeyError(
          "Variables missing from restored model: " + ", ".join(missing))

    for name in wanted:
      value = self.session.run(by_name[name].name)
      print(name, np.shape(value))
      self[name] = value

  def _w(self, n, suffix=""):
    """Return the stored weights of layer ``n`` (1-based), plus ``suffix``."""
    return self[self._weights_str.format(n) + suffix]

  def _b(self, n, suffix=""):
    """Return the stored biases of layer ``n`` (1-based), plus ``suffix``."""
    return self[self._biases_str.format(n) + suffix]

  @staticmethod
  def _activate(x, w, b, transpose_w=False, name='sigmoid'):
    """Apply ``x @ w + b`` followed by the named activation.

    Args:
      x: input tensor.
      w: weight matrix; used transposed when ``transpose_w`` is True.
      b: bias added via ``tf.nn.bias_add``.
      transpose_w: whether to multiply by the transpose of ``w``.
      name: one of 'sigmoid', 'softmax', 'linear', 'tanh', 'relu'.
    Returns:
      The activated tensor ('linear' returns the affine result unchanged).
    Raises:
      ValueError: if ``name`` is not a supported activation (previously
        this silently returned None).
    """
    # Validate before building any graph ops.
    if name not in AutoEncoderRec._activations:
      raise ValueError("Unsupported activation: {0!r}".format(name))
    linear = tf.nn.bias_add(tf.matmul(x, w, transpose_b=transpose_w), b)
    if name == 'linear':
      return linear
    if name == 'sigmoid':
      return tf.nn.sigmoid(linear, name='encoded')
    if name == 'softmax':
      return tf.nn.softmax(linear, name='encoded')
    if name == 'tanh':
      return tf.nn.tanh(linear, name='encoded')
    return tf.nn.relu(linear, name='encoded')

  def supervised_net(self, input_pl):
    """Build the full feed-forward net from the stored parameters.

    Every layer, including the output layer, uses the default sigmoid
    activation of ``_activate``.

    Args:
      input_pl: tf placeholder for the ae input data.
    Returns:
      Tensor giving the output of the last layer.
    """
    last_output = input_pl
    for i in range(self.__num_hidden_layers + 1):
      last_output = self._activate(last_output, self._w(i + 1), self._b(i + 1))
    return last_output

def _check_and_clean_dir(d):
  if os.path.exists(d):
    shutil.rmtree(d)
    os.mkdir(d)
    
def final_output(ae, data, node_feature_file='./encoded_matrix.csv'):
  """Feed every example through the net and write the outputs as CSV.

  Each row ``data.matrices[i]`` (for ``i`` in ``range(data.num_examples)``)
  is run individually through ``ae.supervised_net``. The flattened output
  vector is written as one comma-separated line of ``node_feature_file``,
  in example order. Any existing file at that path is overwritten.

  Args:
    ae: AutoEncoderRec whose session (and its graph) is used.
    data: dataset object exposing ``num_examples`` and indexable
      ``matrices``.
    node_feature_file: output CSV path.
  """
  with ae.session.graph.as_default():
    sess = ae.session
    input_pl = tf.placeholder(tf.float32, name='input_pl')
    feature_v = ae.supervised_net(input_pl)
    # Open once in 'w' mode. This truncates any previous output (replacing
    # the old remove-then-append-per-row approach) and guarantees the file
    # is closed even if a session run fails part way through.
    with open(node_feature_file, 'w') as f:
      for i in range(data.num_examples):
        vec = sess.run(feature_v, feed_dict={input_pl: [data.matrices[i]]})
        values = [str(x) for x in np.nditer(vec[0])]
        print(i, vec.shape, len(values))
        f.write(','.join(values) + '\n')

def main(meta_graph_path='./supervised_model_ppiprediction.ckpt.meta',
         checkpoint_dir='./',
         node_feature_file='./encoded_matrix.csv'):
  """Restore the trained model and write its encoded feature matrix.

  Args:
    meta_graph_path: path of the saved ``.meta`` graph definition.
    checkpoint_dir: directory searched with ``tf.train.latest_checkpoint``.
    node_feature_file: output CSV path passed to ``final_output``.

  Raises:
    IOError: if no checkpoint is found in ``checkpoint_dir``.
  """
  data, _ = read_data_sets()  # the pretraining split is not needed here
  # The context manager closes the session (previously leaked).
  with tf.Session() as sess:
    print("Importing graph......")
    saver = tf.train.import_meta_graph(meta_graph_path)
    print("Importing variables......")
    checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if checkpoint is None:
      raise IOError("No checkpoint found in {0!r}".format(checkpoint_dir))
    saver.restore(sess, checkpoint)
    all_vars = tf.trainable_variables()
    num_hidden = FLAGS.num_hidden_layers
    ae_hidden_shapes = [getattr(FLAGS, "hidden{0}_units".format(j + 1))
                        for j in range(num_hidden)]
    ae_shape = [FLAGS.image_pixels] + ae_hidden_shapes + [FLAGS.num_classes]
    ae = AutoEncoderRec(ae_shape, sess, all_vars)
    print("Caculating encoded matrix......")
    final_output(ae, data, node_feature_file=node_feature_file)

# Script entry point: restore the trained model and write the encoded matrix.
if __name__ == '__main__':
    main()
