import argparse
import math
import multiprocessing
import os
import time
from multiprocessing import Pool
from operator import itemgetter

import networkx as nx
import pandas as pd

from local_algn_emb_sim import LocalAlignment, write_cluster_to_file
from MNAmaster.mna_emb import MNAEmbedding

def parse_args(args=None):
    parser = argparse.ArgumentParser(
        description='Local network alignment using Knowledge Graph Embedding Models',
        usage='KOGAL.py [<args>] [-h | --help]'
    )
    parser.add_argument('--policy', help='Paths for BLAST files and networks', type=str, default='policy.txt')
    parser.add_argument('--strategy', help='(strategy == 1) the alignment process starts by calculating the degree centrality of nodes within each network, highlighting the importance of each protein in the network structure; (strategy == 2) the alignment process starts by computing the cosine similarity between embedding vectors derived from knowledge graph models.', default=2, type=int)
    parser.add_argument('--SEED_DC',
                        help='Threshold applied in strategy 1 to filter the pertinent seed node pairs when detecting the initial cluster pairs',
                        default=0.5, type=float)
    parser.add_argument('--SEED_THRESHOLD',
                        help='Threshold applied to filter the pertinent seed node pairs when detecting the initial cluster pairs',
                        default=0.5, type=float)
    parser.add_argument('--entity_emb_path', type=str, help='Entity embedding file',
                        default='DATA/ckpts/TransE_l2_LNA_12_epoch80mil_800hdimention/LNA_TransE_l2_entity.npy')
    parser.add_argument('--relation_emb_path', type=str, help='Relation embedding file',
                        default='DATA/ckpts/TransE_l2_LNA_12_epoch80mil_800hdimention/LNA_TransE_l2_relation.npy')
    parser.add_argument('--entity_idmap_path', type=str, help='Entity mapping file',
                        default='DATA/ckpts/TransE_l2_LNA_12_epoch80mil_800hdimention/entities.tsv')
    parser.add_argument('--relation_idmap_path', type=str, help='Relation mapping file',
                        default='DATA/ckpts/TransE_l2_LNA_12_epoch80mil_800hdimention/relations.tsv')
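    # Note: the defaults above appear to point at DGL-KE checkpoint artifacts
    # (entity/relation embedding .npy matrices plus their id-mapping .tsv
    # files); override them to match your own trained model.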

    parser.add_argument('--gamma', default=0.25, type=float)
    parser.add_argument('--SCORE_THRESHOLD', help='Minimum score needed for cluster detection; any cluster whose score falls below this threshold is discarded', default=0.02, type=float)

    parser.add_argument('--alpha', default=0.5, help='Tunes the trade-off between the local and global edge scores computed from the knowledge graph embedding', type=float)
    parser.add_argument('-save', '--save_path', default='results', help='Directory in which the alignment results are written', type=str)

    parser.add_argument('--clstm', type=str, help='Graph clustering technique to use (i.e. ipca, mcode or coach)',
                        default='mcode')

    return parser.parse_args(args)


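# Example invocation (all paths and values are illustrative):
#   python KOGAL.py --policy policy.txt --strategy 2 --SEED_THRESHOLD 0.5 \
#       --SCORE_THRESHOLD 0.02 --clstm mcode --save_path results
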

def main(args):

    if args.save_path is None:
        raise ValueError('Where do you want to save the alignment result?')

    if args.strategy not in (1, 2):
        raise ValueError('Unknown --strategy %d; expected 1 or 2.' % args.strategy)

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    st = time.time()
    '''
    Strategy 1 starts the alignment by generating seed pairs from degree
    centrality and cross-species BLAST sequence similarity, then refines them
    with the knowledge graph embedding. Strategy 2 derives the seed pairs
    directly from the knowledge graph embedding similarities.
    '''

    # Load a PPI network (two tab-separated protein columns, no header) into an
    # undirected NetworkX graph
    def load_ppi(file_path):
        df = pd.read_csv(file_path, sep='\t', header=None, names=['Protein_A', 'Protein_B'])
        G = nx.Graph()
        G.add_edges_from(df.itertuples(index=False, name=None))
        return G
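    # Illustrative network file layout, tab-separated (protein IDs are hypothetical):
    #   P04637    P38398
    #   P04637    Q00987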

    # Function to read BLAST bit scores from a file and store them in a dictionary
    def read_blast_scores(filename):
        bit_score_dict = {}
        with open(filename, 'r') as f:
            for line in f:
                protein1, protein2, bit_score = line.strip().split()
                bit_score_dict[(protein1, protein2)] = float(bit_score)
                bit_score_dict[(protein2, protein1)] = float(bit_score)  # Symmetric scores
        return bit_score_dict
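    # Expected BLAST score file layout: one whitespace-separated triple
    # "protein1 protein2 bit_score" per line, e.g. (illustrative values):
    #   P04637  YGL097W  118.0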

    # Path to the policy file supplied via --policy
    policy_file_path = args.policy

    def read_policy_file(policy_path):

        with open(policy_path, 'r') as file:
            lines = [line.strip() for line in file.readlines()]

        if len(lines) < 5:
            raise ValueError("policy.txt must contain at least 5 lines: 2 network paths followed by 3 BLAST file paths.")

        return {
            'network1': lines[0],
            'network2': lines[1],
            'blast_network1_network1': lines[2],
            'blast_network1_network2': lines[3],
            'blast_network2_network2': lines[4],
        }
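    # Illustrative policy.txt layout (one path per line; file names are examples):
    #   DATA/networks/specie1_ppi.tab
    #   DATA/networks/specie2_ppi.tab
    #   DATA/blast/specie1_specie1.blast
    #   DATA/blast/specie1_specie2.blast
    #   DATA/blast/specie2_specie2.blast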

    # Parse the policy file
    paths = read_policy_file(policy_file_path)

    # Load the PPI networks
    specie1 = load_ppi(paths['network1'])
    specie2 = load_ppi(paths['network2'])


    if args.strategy == 1:
        ############################## Seed generation ##############################


        # Load the BLAST bit scores using the paths from policy.txt
        specie1blast = read_blast_scores(paths['blast_network1_network1'])
        specie2blast = read_blast_scores(paths['blast_network2_network2'])
        specie1_specie2_blast = read_blast_scores(paths['blast_network1_network2'])


        # Step 1: Compute degree centrality for both species
        centrality_human = nx.degree_centrality(specie1)
        centrality_species = nx.degree_centrality(specie2)
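        # nx.degree_centrality returns deg(v) / (n - 1) for every node v, i.e.
        # the fraction of the other n - 1 nodes that v is directly connected to,
        # so values are comparable across networks of different sizes.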

        # Step 2: Seed generation based on high centrality (Top N nodes)
        N = 800  # Number of seeds to select

        # Sort the nodes by centrality and pick the top N
        top_human_seeds = sorted(centrality_human.items(), key=itemgetter(1), reverse=True)[:N]
        top_species_seeds = sorted(centrality_species.items(), key=itemgetter(1), reverse=True)[:N]

        # Extract just the node names (without the centrality values)
        seed_human = [node for node, centrality in top_human_seeds]
        seed_species = [node for node, centrality in top_species_seeds]

        # print(f"Top {N} seeds in Human: {seed_human}")
        # print(f"Top {N} seeds in Species: {seed_species}")

        # Step 3: Extract the local neighborhood of the seeds
        neighborhood_human = set()
        neighborhood_species = set()

        for seed in seed_human:
            neighborhood_human |= set(specie1.neighbors(seed)) | {seed}

        for seed in seed_species:
            neighborhood_species |= set(specie2.neighbors(seed)) | {seed}

        # print(f"Neighborhood of Human seeds: {neighborhood_human}")
        # print(f"Neighborhood of Species seeds: {neighborhood_species}")

        # Step 4: Subgraph matching - compare local neighborhoods
        subgraph_human = specie1.subgraph(neighborhood_human)
        subgraph_species = specie2.subgraph(neighborhood_species)

        # Step 5: Calculate sequence similarity only, discarding topological similarity
        def calculate_sequence_similarity(node_human, node_species, specie1blast, specie2blast,
                                          specie1_specie2_blast):
            # Cross-species sequence similarity: the log-transformed BLAST bit
            # score, normalized to [0, 1] by the maximum observed cross-species score
            seq_similarity_cross = specie1_specie2_blast.get((node_human, node_species), 0)
            if max_cross_species_score > 0:
                return seq_similarity_cross / max_cross_species_score
            return 0

        # Log transformation function
        def log_transform(score):
            return math.log(1 + score)
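        # The log transform compresses the heavy-tailed bit-score distribution
        # so that a few very strong hits do not dominate the normalization, e.g.
        # raw scores 10 and 1000 map to ~2.40 and ~6.91 respectively.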

        # Step 6: Alignment expansion based on sequence similarity
        def expand_alignment(subgraph_human, subgraph_species, specie1blast, specie2blast, specie1_specie2_blast, seed_threshold):
            alignment = []
            for node_human in subgraph_human.nodes():
                for node_species in subgraph_species.nodes():
                    similarity = calculate_sequence_similarity(node_human, node_species, specie1blast,
                                                               specie2blast,
                                                               specie1_specie2_blast)
                    if similarity > seed_threshold:  # Threshold for accepting an aligned pair (tune this)
                        alignment.append([node_human, node_species, similarity])
            return alignment
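        # Note: expand_alignment performs a dense O(|V1| * |V2|) sweep over the
        # two seed neighborhoods; the top-N seed selection above is what keeps
        # these subgraphs small enough for this to be practical.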

        # Apply log transformation to BLAST bit scores
        log_human_human_blast = {k: log_transform(v) for k, v in specie1blast.items()}
        log_yeast_yeast_blast = {k: log_transform(v) for k, v in specie2blast.items()}
        log_human_yeast_sim = {k: log_transform(v) for k, v in specie1_specie2_blast.items()}
        # Maximum score observed in the cross-species BLAST data, used for normalization
        max_cross_species_score = max(log_human_yeast_sim.values(), default=0)
        # Run the alignment expansion with the strategy-1 seed threshold
        alignment = expand_alignment(subgraph_human, subgraph_species, log_human_human_blast, log_yeast_yeast_blast,
                                     log_human_yeast_sim, args.SEED_DC)
        print("Aligned seed pairs based on cross-species sequence similarity:", len(alignment))

        #####################################################################

    # The embedding-based pipeline below is shared by both strategies; only the
    # source of the seed pairs handed to LocalAlignment differs.
    cluster_source_specie = args.save_path + '/' + paths['network1'].split("/")[-1]
    cluster_target_specie = args.save_path + '/' + paths['network2'].split("/")[-1]
    print("\nBegin Knowledge Graph Embedding similarities")
    mna_emb_two_species = MNAEmbedding(paths['network1'], paths['network2'], args.entity_emb_path,
                                       args.relation_emb_path, args.entity_idmap_path, args.relation_idmap_path)
    datanet1, datanet2 = mna_emb_two_species.get_networks()
    specie1Emb, specie2Emb = mna_emb_two_species.get_embed_species()
    resultsemb_same_specie = mna_emb_two_species.get_embed_similartities(specie1Emb, specie2Emb)

    mna_emb_two_species.write_to_file_mna_emb(resultsemb_same_specie, args.SEED_THRESHOLD)

    # Strategy 1 seeds the alignment with the BLAST/centrality pairs computed
    # above; strategy 2 uses the embedding-derived seed pairs.
    if args.strategy == 1:
        seed_pairs = alignment
    else:
        seed_pairs = mna_emb_two_species.get_seed_pair()

    entity_map, relation_map, entity_id_map = mna_emb_two_species.get_entity_relation_map_embrdding()
    la = LocalAlignment(paths['network1'], paths['network2'], datanet1, datanet2, seed_pairs, entity_map,
                        relation_map, args.entity_emb_path, args.relation_emb_path, entity_id_map, args.gamma,
                        args.SCORE_THRESHOLD, float(0), args.alpha, args.clstm)
    cluster1, cluster2 = la.get_clusters()

    # Write the two species' cluster files in parallel, one worker per file
    with Pool(multiprocessing.cpu_count()) as pool:
        pool.starmap(write_cluster_to_file, [(cluster1, cluster_source_specie), (cluster2, cluster_target_specie)])
    print("Number of alignments discovered: %d" % (len(cluster1)))

    print("Alignment results are in files '%s' and '%s'" % (cluster_source_specie, cluster_target_specie))

    # Read a cluster file back in: one cluster per line, skipping the first
    # token on each line (matching the format written by write_cluster_to_file)
    def read_cluster_file(path):
        clusters = {}
        with open(path) as f:
            for cluster_id, line in enumerate(f, start=1):
                members = line.strip().split(" ")[1:]
                clusters[cluster_id] = set(members)
        return clusters

    ppi_specie1_id_DGLKE_List = read_cluster_file(cluster_source_specie)
    ppi_specie2_id_DGLKE_List = read_cluster_file(cluster_target_specie)

    # Write each aligned cluster pair on one line: species-1 proteins followed
    # by their aligned species-2 proteins, tab-separated
    with open(args.save_path + '/' + "alignment_spec1_spec2_protein_ppi.txt", 'w+') as fedrgId:
        for cluster_id1 in ppi_specie1_id_DGLKE_List:
            for p1 in ppi_specie1_id_DGLKE_List[cluster_id1]:
                fedrgId.write("{}\t".format(p1))
            for p2 in ppi_specie2_id_DGLKE_List[cluster_id1]:
                fedrgId.write("{}\t".format(p2))
            fedrgId.write("\n")

    # Write every cross-species protein pair (the Cartesian product of each
    # pair of aligned clusters), one pair per line
    with open(args.save_path + '/' + "alignment_spec1_spec2_pairwise_protein_PPI.txt", 'w+') as fedrgId:
        for cluster_id1 in ppi_specie1_id_DGLKE_List:
            for p1 in ppi_specie1_id_DGLKE_List[cluster_id1]:
                for p2 in ppi_specie2_id_DGLKE_List[cluster_id1]:
                    fedrgId.write("{}\t{}\t\n".format(p1, p2))

    et = time.time()

    # get the execution time
    elapsed_time = et - st

    print('Execution time:', time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))

if __name__ == '__main__':
    main(parse_args())
