import multiprocessing

import numpy as np
from cluster_scope import Cluster_Expansion
from collections import defaultdict 
from itertools import zip_longest
from copy import deepcopy
import deterministic_ipca
import coach_script
import mcode_script
import torch as th
from multiprocessing import Pool
# Module-level debug switch.
# NOTE(review): no code visible in this module reads it; get_score() takes its
# own DEBUG parameter, which shadows this name. Remove if unused elsewhere.
DEBUG = False 

class LocalAlignment:
    """Greedy seed-and-extend alignment of clusters across two PPI networks.

    Seed pairs ``(g1, g2, weight)`` are visited from the highest to the
    lowest weight. For each pair, the initial clusters that the selected
    clustering back-end produced for ``g1`` (network 1) and ``g2``
    (network 2) are grown one outer-boundary protein at a time. A candidate
    is linked to current members whose embedding-based edge score exceeds
    0.5. It is kept only if that raises the cluster's cross-network score
    (:meth:`get_score`) against the partner cluster.
    """

    def __init__(self, ppin1_filename, ppin2_filename, datanet1, datanet2,
                 result_seed_pair, entitymap, relationmap, entity_emb_file,
                 relation_emb_file, entity_id_map, GAMMA=0.25,
                 SCORE_THRESHOLD=0.02, EMB_SIM_THRESHOLD=0.7, alph=0.5,
                 clstm="ipca"):
        """Run the clustering back-end and load similarities and embeddings.

        Args:
            ppin1_filename, ppin2_filename: inputs handed to the clustering
                function, one per network; both are processed in parallel.
            datanet1, datanet2: passed to ``Cluster_Expansion`` for network 1
                and network 2 respectively.
            result_seed_pair: iterable of ``(g1, g2, weight)`` seed pairs;
                ``weight`` must be convertible with ``float``.
            entitymap: protein name -> row index into the entity embedding
                matrix.
            relationmap: relation name -> row index into the relation
                embedding matrix; must contain the four ``ppi::...``
                relation names used below.
            entity_emb_file, relation_emb_file: files loaded with ``np.load``.
            entity_id_map: entity row index -> protein name (presumably the
                inverse of ``entitymap`` -- verify against the caller).
            GAMMA, EMB_SIM_THRESHOLD: stored on the instance; not read by
                this class.
            SCORE_THRESHOLD: expansion of a seed pair stops once either
                cluster's score falls below this value.
            alph: weight of a protein's boundary score versus the local edge
                score when scoring its edges.
            clstm: clustering back-end: ``"ipca"``, ``"mcode"`` or ``"coach"``.

        Raises:
            ValueError: if ``clstm`` is not a known clustering back-end.
        """
        print("isnide init")
        self.GAMMA = GAMMA
        self.SCORE_THRESHOLD = SCORE_THRESHOLD
        self.EMB_SIM_THRESHOLD = EMB_SIM_THRESHOLD
        self.alpha = alph

        if clstm == "ipca":
            cluster_fn, done_message = deterministic_ipca.ipca, "ipca m"
        elif clstm == "mcode":
            cluster_fn, done_message = mcode_script.mcode, "mcode m"
        elif clstm == "coach":
            cluster_fn, done_message = coach_script.coach, "coach m"
        else:
            # An unknown name used to leave self.res empty, after which
            # get_clusters() silently skipped every seed pair.
            raise ValueError("unknown clustering method %r; expected "
                             "'ipca', 'mcode' or 'coach'" % (clstm,))

        items = [ppin1_filename, ppin2_filename]
        # Only two inputs are mapped, so more workers would just sit idle.
        with Pool(min(len(items), multiprocessing.cpu_count())) as pool:
            # res[0] / res[1]: clustering of network 1 / network 2, indexed by
            # seed protein in get_clusters().
            self.res = pool.map(cluster_fn, items)
            print(done_message)

        self.cluster_expans_net1 = Cluster_Expansion(datanet1)
        self.cluster_expans_net2 = Cluster_Expansion(datanet2)

        # sim_all: symmetric map g1 -> {g2: weight}. Seeds are visited from the
        # most to the least similar pair.
        self.sim_all, seed_pairs = load_weighted_graph(result_seed_pair)
        self.sorted_sim = sorted(seed_pairs, key=lambda pair: pair[2],
                                 reverse=True)

        self.entity_emb = np.load(entity_emb_file)
        self.relation_emb = np.load(relation_emb_file)
        self.param_entity_map = entitymap
        self.param_relation_map = relationmap

        # Relations used to score outer-boundary proteins against a cluster.
        self.treatment_rid, self.treatment_embs = self._relation_embeddings(
            ['ppi::DrosophilaMelanogaster', 'ppi::Mouse',
             'ppi::DrosophilaMelanogaster::complex', 'ppi::Mouse::complex'])
        # Relations used for intra-cluster edges of network 1 (HM) and
        # network 2 (Sach). treatment_rid keeps the ids of the last call.
        self.treatment_rid, self.treatment_embsHM = self._relation_embeddings(
            ['ppi::DrosophilaMelanogaster::complex'])
        self.treatment_rid, self.treatment_embsSach = self._relation_embeddings(
            ['ppi::Mouse::complex'])

        self.gamma = 12.0  # margin used by transE_l2
        self.entity_id_map = entity_id_map

    def _relation_embeddings(self, relation_names):
        """Return ``(ids, embeddings)`` for the given relation names.

        ``ids`` is a tensor of rows into ``self.relation_emb``; ``embeddings``
        is a list with one tensor per relation, in the same order.
        """
        rids = th.tensor([self.param_relation_map[name]
                          for name in relation_names])
        return rids, [th.tensor(self.relation_emb[rid]) for rid in rids]

    def Rotate(self, head, rel, tail):
        """Score a triple with a complex-rotation style product.

        The last dimension of each embedding is split in half into real and
        imaginary parts, and *head* is multiplied by *rel* as complex numbers
        before being combined with *tail*. *tail* may be a NumPy array; it is
        converted with ``th.from_numpy``.

        NOTE(review): unlike DistMult/transE_l2, this does not reduce over
        the embedding dimension. It returns real + imaginary per component,
        and it is not called anywhere in this class. Confirm the intended
        formula before using it.
        """
        if isinstance(tail, np.ndarray):
            tail = th.from_numpy(tail)

        head_real, head_imag = th.chunk(head, 2, dim=-1)
        rel_real, rel_imag = th.chunk(rel, 2, dim=-1)
        tail_real, tail_imag = th.chunk(tail, 2, dim=-1)

        rotated_head_real = head_real * rel_real - head_imag * rel_imag
        rotated_head_imag = head_real * rel_imag + head_imag * rel_real

        score_real = rotated_head_real * tail_real + rotated_head_imag * tail_imag
        score_imag = rotated_head_real * tail_imag - rotated_head_imag * tail_real

        return th.sum(th.stack([score_real, score_imag], dim=-1), dim=-1)

    def transE_l2(self, head, rel, tail):
        """Return ``self.gamma`` minus the L2 norm of ``head + rel - tail`` (last dim)."""
        score = head + rel - tail
        return self.gamma - th.norm(score, p=2, dim=-1)

    def DistMult(self, head, rel, tail):
        """Return the DistMult score: sum over the last dim of ``head * rel * tail``.

        Broadcasting applies, so a 2-D *head* yields one score per row.
        """
        return th.sum(head * rel * tail, dim=-1)

    def get_score(self, c1, c2, DEBUG=False):
        """Return the mean best seed similarity of *c1*'s proteins to *c2*.

        For each protein in *c1*, take its highest similarity in
        ``self.sim_all`` to any protein of *c2*. Missing pairs count as 0 and
        the best value is never below 0. Returns 0 for an empty *c1*.
        *DEBUG* is accepted for backward compatibility and ignored.
        """
        sim_map = self.sim_all
        sum_best_sim = 0
        for p1 in c1:
            row = sim_map.get(p1, {})
            sum_best_sim += max([0] + [row[p2] for p2 in c2 if p2 in row])
        return sum_best_sim / len(c1) if len(c1) > 0 else 0

    def sigmoid(self, x):
        """Logistic function ``1 / (1 + exp(-x))``."""
        return 1 / (1 + np.exp(-x))

    def _entity_embeddings(self, proteins):
        """Return a 2-D tensor with one embedding row per protein, in iteration order."""
        # A NumPy index array keeps the result 2-D even for a single protein.
        ids = np.array([self.param_entity_map[p] for p in proteins],
                       dtype=np.int64)
        return th.tensor(self.entity_emb[ids])

    def _score_outer_boundary(self, cluster_emb, boundary_ids):
        """Return ``{protein name: sigmoid(best score)}`` for boundary proteins.

        A boundary protein's best score is its highest DistMult score against
        any cluster member under any relation in ``self.treatment_embs``.

        Args:
            cluster_emb: tensor of cluster-member embeddings, one row per
                member (a single 1-D row is also accepted).
            boundary_ids: entity row indices of the boundary proteins.

        Returns:
            dict keyed by ``str(self.entity_id_map[id])``; empty if nothing
            was scored.
        """
        score_chunks, id_chunks = [], []
        for rel_emb in self.treatment_embs:
            for pid in boundary_ids:
                tail = th.as_tensor(self.entity_emb[pid])
                scores = self.DistMult(cluster_emb, rel_emb, tail).reshape(-1)
                score_chunks.append(scores)
                # Tag every score with the protein it was computed for, so ids
                # and scores stay aligned after sorting.
                id_chunks.append(th.full(scores.shape, int(pid), dtype=th.long))
        if not score_chunks:
            return {}

        scores = th.cat(score_chunks)
        ids = th.cat(id_chunks)
        order = th.argsort(scores, descending=True)
        scores = scores[order].numpy()
        ids = ids[order].numpy()
        # return_index gives each id's first position in descending order,
        # i.e. the position of its best score.
        unique_ids, first_pos = np.unique(ids, return_index=True)
        return {
            str(self.entity_id_map[int(pid)]): float(self.sigmoid(scores[pos]))
            for pid, pos in zip(unique_ids, first_pos)
        }

    def _try_add_protein(self, cluster, protein, graph, edge_rel,
                         boundary_scores, partner_cluster, best_cluster,
                         best_score):
        """Tentatively add *protein* to *cluster*; keep it only if the score improves.

        *protein* is linked to each neighbour in ``graph[protein]`` (pairs of
        ``(neighbour, weight)``) that is already in the cluster, when the edge
        score exceeds 0.5. The edge score is the sigmoid of the DistMult score
        under *edge_rel*. It is blended with the protein's entry in
        *boundary_scores* using weight ``self.alpha`` when such an entry
        exists. *cluster* is modified in place.

        Returns:
            ``(cluster, best_cluster, best_score, improved)``. On improvement
            these are the grown cluster, a snapshot of it and its new score.
            Otherwise they are a fresh copy of *best_cluster* plus the
            unchanged best cluster and score.
        """
        source_emb = th.tensor(self.entity_emb[self.param_entity_map[protein]])
        boundary_score = boundary_scores.get(str(protein))
        cluster[protein] = dict()
        for neighbour, _weight in graph[protein]:
            if neighbour not in cluster:
                continue
            target_emb = th.tensor(
                self.entity_emb[self.param_entity_map[neighbour]])
            local = float(self.sigmoid(
                float(self.DistMult(source_emb, edge_rel, target_emb))))
            if boundary_score is None:
                edge_score = local
            else:
                edge_score = (self.alpha * boundary_score
                              + (1 - self.alpha) * local)
            if edge_score > 0.5:
                cluster[protein][neighbour] = edge_score
                cluster[neighbour][protein] = edge_score

        score = self.get_score(cluster, partner_cluster)
        if score > best_score:
            return cluster, deepcopy(cluster), score, True
        return deepcopy(best_cluster), best_cluster, best_score, False

    def _expand_pair(self, cluster1, cluster2, cluster1_emb, cluster2_emb,
                     edge_rel1, edge_rel2):
        """Greedily grow a seed cluster pair and return the final ``(cluster1, cluster2)``.

        Each round proposes every outer-boundary protein of both clusters,
        interleaved with network 1 first. The loop stops when a round adds
        nothing, when both clusters have one member, or when either cluster's
        score falls below ``self.SCORE_THRESHOLD``.
        """
        expanded1 = expanded2 = True
        score1 = score2 = boundary1 = boundary2 = None
        while True:
            size1, size2 = len(cluster1), len(cluster2)
            best1, best2 = deepcopy(cluster1), deepcopy(cluster2)

            # Scores and boundaries only change when that cluster grew.
            if expanded1:
                score1 = self.get_score(cluster1, cluster2)
            if expanded2:
                score2 = self.get_score(cluster2, cluster1)
            if score1 < self.SCORE_THRESHOLD or score2 < self.SCORE_THRESHOLD:
                break
            if expanded1:
                boundary1 = self.cluster_expans_net1.get_protein_out_scope(cluster1)
            if expanded2:
                boundary2 = self.cluster_expans_net2.get_protein_out_scope(cluster2)

            boundary_ids1 = [self.param_entity_map[p] for p in boundary1]
            boundary_ids2 = [self.param_entity_map[p] for p in boundary2]
            # Boundary scores are only blended in with more than one candidate.
            boundary_scores1 = (
                self._score_outer_boundary(cluster1_emb, boundary_ids1)
                if len(boundary1) > 1 else {})
            boundary_scores2 = (
                self._score_outer_boundary(cluster2_emb, boundary_ids2)
                if len(boundary2) > 1 else {})

            expanded1 = expanded2 = False
            for v1, v2 in zip_longest(boundary1, boundary2):
                if v1:
                    cluster1, best1, score1, improved = self._try_add_protein(
                        cluster1, v1, self.cluster_expans_net1.graph, edge_rel1,
                        boundary_scores1, best2, best1, score1)
                    expanded1 = expanded1 or improved
                if v2:
                    # Network-2 proteins are looked up in network 2's graph.
                    cluster2, best2, score2, improved = self._try_add_protein(
                        cluster2, v2, self.cluster_expans_net2.graph, edge_rel2,
                        boundary_scores2, best1, best2, score2)
                    expanded2 = expanded2 or improved

            if size1 == len(cluster1) and size2 == len(cluster2):
                break
            if len(cluster1) == 1 and len(cluster2) == 1:
                break
        return cluster1, cluster2

    def get_clusters(self):
        """Grow aligned cluster pairs from every seed pair.

        A seed is skipped when both of its proteins already belong to
        accepted clusters, or when either seed protein has no initial
        cluster in ``self.res``. A pair is kept only when both final clusters
        have more than one member.

        Returns:
            ``(clusters1, clusters2)``: parallel lists of cluster dicts keyed
            by protein. ``clusters1[i]`` is aligned with ``clusters2[i]``.
            Proteins added during expansion map to ``{neighbour: edge score}``.
        """
        visited1, visited2 = set(), set()
        clusters1, clusters2 = [], []
        cached_init_clusters1, cached_init_clusters2 = {}, {}

        # Each list holds exactly one relation embedding (see __init__). The
        # attributes are kept for callers that read them afterwards.
        self.treatment_emb_HM_in = self.treatment_embsHM[-1]
        self.treatment_emb_Sach_in = self.treatment_embsSach[-1]

        total = len(self.sorted_sim)
        for num_visited, (g1, g2, _weight) in enumerate(self.sorted_sim, start=1):
            if g1 in visited1 and g2 in visited2:
                continue

            print("Number of nodes remaining: %6d" % (total - num_visited), end='\r')

            try:
                if g1 not in cached_init_clusters1:
                    cached_init_clusters1[g1] = self.res[0][g1]
                if g2 not in cached_init_clusters2:
                    cached_init_clusters2[g2] = self.res[1][g2]
            except (KeyError, IndexError, TypeError):
                # A seed protein without an initial cluster cannot be expanded.
                continue

            cluster1 = self.cluster_expans_net1.cluster_list_to_dict(
                cached_init_clusters1[g1])
            cluster2 = self.cluster_expans_net2.cluster_list_to_dict(
                cached_init_clusters2[g2])
            # Embeddings of the seed clusters; not refreshed as they grow.
            cluster1_emb = self._entity_embeddings(cluster1)
            cluster2_emb = self._entity_embeddings(cluster2)

            cluster1, cluster2 = self._expand_pair(
                cluster1, cluster2, cluster1_emb, cluster2_emb,
                self.treatment_emb_HM_in, self.treatment_emb_Sach_in)

            if len(cluster1) > 1 and len(cluster2) > 1:
                clusters1.append(cluster1)
                clusters2.append(cluster2)
                visited1.update(cluster1)
                visited2.update(cluster2)

        return clusters1, clusters2

def write_cluster_to_file(clusters, filename):
    """Write each cluster with more than one member to *filename*, one per line.

    A line holds the member count followed by the members separated by
    single spaces. Members must be strings. The file is overwritten.
    """
    with open(filename, 'w') as out:
        rows = (f"{len(cluster)} {' '.join(cluster)}"
                for cluster in clusters if len(cluster) > 1)
        for row in rows:
            print(row, file=out)
                
def load_weighted_graph(result_seed_pair):
    """Build a symmetric similarity map from ``(g1, g2, weight)`` triples.

    Returns:
        ``(graph, edges)``. *graph* is a plain dict with
        ``graph[g1][g2] == graph[g2][g1] == float(weight)``; a later triple
        for the same pair overwrites an earlier one. *edges* lists every
        triple in input order with the weight converted to float.
    """
    adjacency = defaultdict(dict)
    edges = []
    for left, right, weight in result_seed_pair:
        value = float(weight)
        adjacency[left][right] = value
        adjacency[right][left] = value
        edges.append((left, right, value))
    return dict(adjacency), edges
def comparison_context_two_cluster(self, v1, cluster1, treatment_emb_indiv,
                                   outer_boundary1,
                                   listoutboundary_all_embedding_List,
                                   initial_cluster2, initial_score1,
                                   initial_cluster1, graph=None):
    """Tentatively add protein *v1* to *cluster1*; keep it only if the score improves.

    Free-function form of one expansion step. *self* must provide
    ``entity_emb``, ``param_entity_map``, ``DistMult``, ``sigmoid``,
    ``alpha``, ``get_score`` and, when *graph* is omitted,
    ``cluster_expans_net1``.

    Args:
        v1: candidate protein name.
        cluster1: cluster dict being grown; modified in place (``v1`` is
            added even if the step is later reverted, so replace it with the
            returned cluster).
        treatment_emb_indiv: relation embedding used for edge scores.
        outer_boundary1: current outer boundary. When it has more than one
            protein, the edge score is blended with
            ``listoutboundary_all_embedding_List[str(v1)]`` using
            ``self.alpha``.
        listoutboundary_all_embedding_List: protein name -> boundary score.
        initial_cluster2: partner cluster passed to ``get_score``.
        initial_score1: best score so far for *cluster1*.
        initial_cluster1: best cluster so far (restored on revert).
        graph: adjacency of ``(neighbour, weight)`` pairs to read *v1*'s
            neighbours from; defaults to ``self.cluster_expans_net1.graph``.

    Returns:
        ``(cluster1, initial_cluster1, initial_score1, expanded1)``. On
        improvement these are the grown cluster, a snapshot of it, its new
        score and True. Otherwise they are a fresh copy of
        *initial_cluster1*, the unchanged best cluster and score, and False.
        (The previous version returned None, discarding all of these.)
    """
    if graph is None:
        graph = self.cluster_expans_net1.graph
    protsource_emb = th.tensor(self.entity_emb[self.param_entity_map[v1]])

    cluster1[v1] = dict()
    for n, _w in graph[v1]:
        if n not in cluster1:
            continue
        prottarget_emb = th.tensor(self.entity_emb[self.param_entity_map[n]])
        local = float(self.sigmoid(
            float(self.DistMult(protsource_emb, treatment_emb_indiv, prottarget_emb))))
        if len(outer_boundary1) > 1:
            score_lcl = float(
                self.alpha * listoutboundary_all_embedding_List[str(v1)]
                + (1 - self.alpha) * local)
        else:
            score_lcl = local
        if score_lcl > 0.5:
            cluster1[v1][n] = score_lcl
            cluster1[n][v1] = score_lcl

    score1 = self.get_score(cluster1, initial_cluster2)
    if score1 > initial_score1:
        return cluster1, deepcopy(cluster1), score1, True
    return deepcopy(initial_cluster1), initial_cluster1, initial_score1, False

   
