import csv
import torch as th
try: import cPickle as pickle
except ImportError:
	import pickle
from MNAmaster.alignments import *

class MNAEmbedding:
    """Load two PPI networks plus DGL-KE entity/relation embeddings and
    expose per-species embedding matrices for network alignment.

    Parameters
    ----------
    ppin1_filename : str
        TSV edge list for the first (source, "specie1") PPI network.
    ppin2_filename : str
        TSV edge list for the second (target, "specie2") PPI network.
    entity_emb_file : str
        ``.npy`` file holding the entity embedding matrix (row index = entity id).
    relation_emb_file : str
        ``.npy`` file holding the relation embedding matrix.
    ent_idmap_file : str
        TSV with rows ``id<TAB>name`` mapping entity ids to names.
    rel_idmap_file : str
        TSV with rows ``id<TAB>name`` mapping relation ids to names.
    known_entities_file : str, optional
        TSV whose second column lists the protein ids present in the
        knowledge graph; edges whose endpoints are not both listed here are
        dropped.  Defaults to the previously hard-coded ``'DATA/entities.tsv'``.
    """

    def __init__(self, ppin1_filename, ppin2_filename, entity_emb_file,
                 relation_emb_file, ent_idmap_file, rel_idmap_file,
                 known_entities_file='DATA/entities.tsv'):
        # Set (not list, as before) so the per-edge membership tests below
        # are O(1) instead of O(n).
        known_ids = self._load_known_ids(known_entities_file)

        print('Reading PPI network 1: ', ppin1_filename)
        self.datasaetnet1, source_species = self._read_network(
            ppin1_filename, known_ids, 'specie1')

        print('Reading PPI network 2: ', ppin2_filename)
        self.datasaetnet2, target_species = self._read_network(
            ppin2_filename, known_ids, 'specie2')

        entity_emb = np.load(entity_emb_file)
        self.relation_emb = np.load(relation_emb_file)
        # Seed pairs are accumulated later by write_to_file_mna_emb().
        self.seedpair = []

        # name -> id and id -> name maps for entities; name -> id for relations.
        self.entity_map = {}
        self.entity_id_map = {}
        self.relation_map = {}
        with open(ent_idmap_file, newline='', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile, delimiter='\t',
                                    fieldnames=['id', 'name'])
            for row_val in reader:
                self.entity_map[row_val['name']] = int(row_val['id'])
                self.entity_id_map[int(row_val['id'])] = row_val['name']

        with open(rel_idmap_file, newline='', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile, delimiter='\t',
                                    fieldnames=['id', 'name'])
            for row_val in reader:
                self.relation_map[row_val['name']] = int(row_val['id'])

        # Embedding-matrix row ids for each species' proteins.  Proteins
        # missing from the entity map are skipped (the original swallowed
        # the KeyError with a bare except; the membership test is explicit).
        source_ids = [self.entity_map[p] for p in source_species
                      if p in self.entity_map]
        target_ids = [self.entity_map[p] for p in target_species
                      if p in self.entity_map]
        self.humanproteins_ids = th.tensor(source_ids).long()
        self.Mouseproteins_ids = th.tensor(target_ids).long()
        self.human_emb = th.tensor(entity_emb[self.humanproteins_ids])
        self.Mouse_emb = th.tensor(entity_emb[self.Mouseproteins_ids])

        # Embedding of the 'ppi::Human' relation.  The lookup intentionally
        # still raises KeyError when the relation is missing (same as the
        # original); the result is now kept instead of being discarded.
        treatment = ['ppi::Human']
        treatment_rid = th.tensor([self.relation_map[t] for t in treatment])
        self.treatment_embs = [th.tensor(self.relation_emb[rid])
                               for rid in treatment_rid]

    @staticmethod
    def _load_known_ids(path):
        """Return the set of protein ids found in column 2 of *path*."""
        known = set()
        with open(path) as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) > 1:
                    known.add(parts[1])
        return known

    @staticmethod
    def _read_network(path, known_ids, species_label):
        """Read a TSV edge list, keeping only edges whose two endpoints are
        both in *known_ids*.

        Returns ``(edges, species_of)`` where ``edges`` is a list of
        ``[p1, p2]`` pairs and ``species_of`` maps each kept protein id to a
        set containing *species_label*.  Malformed (short) lines are skipped
        explicitly instead of via the original bare ``except: pass``.
        """
        edges = []
        species_of = {}
        with open(path) as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) < 2:
                    continue  # malformed line — best-effort, as before
                p1id, p2id = parts[0], parts[1]
                if p1id in known_ids and p2id in known_ids:
                    edges.append([p1id, p2id])
                    species_of.setdefault(str(p1id), set()).add(species_label)
                    species_of.setdefault(str(p2id), set()).add(species_label)
        return edges, species_of

    def get_embed_species(self):
        """Return ``(source_embeddings, target_embeddings)`` torch tensors."""
        return self.human_emb, self.Mouse_emb

    def get_entity_relation_map_embrdding(self):
        """Return ``(entity name->id, relation name->id, entity id->name)`` maps."""
        return self.entity_map, self.relation_map, self.entity_id_map

    def get_embed_similartities(self, specie1, specie2):
        """Compute, cache, and return the cosine-similarity alignment matrix
        between the two species' embedding matrices."""
        print(np.shape(specie1), np.shape(specie2))
        self.alignment_matrix = get_embedding_similarities(
            specie1, specie2, num_top=None, sim_measure="cosine")
        return self.alignment_matrix

    def get_networks(self):
        """Return the filtered edge lists of both networks."""
        return self.datasaetnet1, self.datasaetnet2

    def get_alignment_emb(self):
        """Return the alignment matrix cached by get_embed_similartities()."""
        return self.alignment_matrix

    def get_seed_pair(self):
        """Return the seed pairs accumulated by write_to_file_mna_emb()."""
        return self.seedpair

    def write_to_file_mna_emb(self, resultsemb, SEED_THRESHOLD=0.1):
        """Append to ``self.seedpair`` one ``[source_name, target_name,
        similarity]`` triple for every matrix cell above *SEED_THRESHOLD*.

        Row i indexes ``humanproteins_ids`` and column j indexes
        ``Mouseproteins_ids``; names come from ``entity_id_map``.  The
        original re-tested the threshold after ``np.nonzero`` — that second
        check was redundant and is removed.
        """
        threshold = float(SEED_THRESHOLD)
        rows, cols = np.nonzero(resultsemb > threshold)
        for i, j in zip(rows, cols):
            self.seedpair.append([
                str(self.entity_id_map[int(self.humanproteins_ids[i])]),
                str(self.entity_id_map[int(self.Mouseproteins_ids[j])]),
                float(resultsemb[i][j]),
            ])