import json
import sys
import re

def split_terms(filename):
    '''
    Split an OBO ontology file into the raw text of its individual [Term] stanzas.

    Everything before the first "[Term]" header (the file header) is discarded,
    and any trailing non-Term stanza (e.g. "[Typedef]") that follows the last
    term is cut off so it cannot leak into that term's text.

    Args:
        filename: a file path to a GO terms (OBO) file
    Returns:
        A list of term strings, each starting just after its "[Term]" header.
        An empty list is returned (after printing a message) if the file does
        not exist.
    '''
    try:
        with open(filename, encoding="utf-8") as f:
            contents = f.read()
    except FileNotFoundError:
        print("File Not Found!")
        return []
    # The first chunk is the file header, so it is skipped.  (The previous
    # [1:-1] slice also dropped the final term of the file.)
    chunks = contents.split("[Term]")[1:]
    # Stanza headers always start a line; cutting each chunk at the first one
    # removes Typedef/Instance stanzas that follow the last term.
    return [chunk.split("\n[", 1)[0] for chunk in chunks]


def map_protein_to_go(GAF):
    '''
    Map each annotated protein to the GO IDs it is annotated with in a GAF file.

    Header lines (starting with "!"), malformed lines with fewer than eight
    tab-separated columns, and annotations whose evidence code (column 7) is
    "IEA" are skipped.  Each remaining annotation's GO ID is recorded for the
    protein in column 2 and also for every "UniProtKB:" accession listed in
    the With/From column (column 8).

    Args:
        GAF: a file path to a GO annotations (GAF) file
    Returns:
        A dict {protein ID: [GO IDs]} with GO IDs de-duplicated in first-seen
        order.  An empty dict is returned (after printing a message) if the
        file does not exist.
    '''
    dict1 = {}

    def _add(protein, goid):
        # Record goid for protein once, preserving first-seen order.
        goids = dict1.setdefault(protein, [])
        if goid not in goids:
            goids.append(goid)

    try:
        with open(GAF) as f:
            for line in f:
                if line.startswith("!"):
                    continue
                columns = line.split("\t")
                # Blank or truncated lines (e.g. a trailing empty line) would
                # otherwise raise IndexError and abort the whole parse.
                if len(columns) < 8:
                    continue
                if columns[6] == "IEA":
                    continue
                # Fall back to column 6 when column 5 does not hold a GO ID.
                GOID = columns[4] if "GO:" in columns[4] else columns[5]
                _add(columns[1], GOID)
                # Also credit proteins named in the With/From column.
                for other in columns[7].strip().split("|"):
                    if other.startswith("UniProtKB:"):
                        _add(other.strip().split("UniProtKB:")[1], GOID)
    except FileNotFoundError:
        print("File Not Found!")
    return dict1

def parse_go_term(term):
    '''
    Extract a GO term's own ID and its direct is_a parent IDs.

    Only lines that *start* with "id: GO:" or "is_a:" are used, so
    "alt_id: GO:..." lines (which contain the substring "id: GO:") cannot
    overwrite the term's primary ID, and free text that merely contains
    "is_a" is ignored.

    Args:
        term: the raw text of a single GO term stanza (as produced by
            split_terms), or None
    Returns:
        A single-entry dict {GO ID: [parent GO IDs]}.  The key is None when no
        "id: GO:" line is found (including when term is None or empty).
    '''
    is_a = []
    go_id = None
    if term:
        for line in term.split("\n"):
            fields = line.split()
            if line.startswith("id: GO:"):
                go_id = fields[1]
            elif line.startswith("is_a:") and len(fields) > 1:
                # e.g. "is_a: GO:0000001 ! name" -> "GO:0000001"
                is_a.append(fields[1])
    return {go_id: is_a}

def find_parent_terms(go_id, go_dict, term_map=None):
    '''
    Collect the parent (is_a) GO IDs of a GO term.

    Args:
        go_id: a singular GO ID
        go_dict: a {GO ID: [parent GO IDs]} dictionary, as returned by
            parse_go_term
        term_map: optional {GO ID: raw term text} dictionary (as returned by
            mapping_to_term).  When given, parents are followed transitively
            so every reachable ancestor is returned (cycles are tolerated).
            When omitted, only the direct parents listed in go_dict are
            returned, which is what the previous implementation produced in
            practice.
    Returns:
        A list of unique GO IDs in first-seen (breadth-first) order.
    '''
    found = {}  # insertion-ordered set of ancestors already collected
    queue = list(go_dict.get(go_id, []))
    position = 0
    while position < len(queue):
        parent = queue[position]
        position += 1
        if parent in found:
            continue
        found[parent] = None
        if term_map is not None and parent in term_map:
            # parse_go_term returns a single {id: parents} entry; take its
            # parent list without relying on the parsed key.
            for grandparents in parse_go_term(term_map[parent]).values():
                queue.extend(grandparents)
    return list(found)

def mapping_to_term(terms):
    '''
    Index raw GO term stanzas by the GO ID that parse_go_term extracts from them.

    Args:
        terms: an iterable of raw term strings (as returned by split_terms)
    Returns:
        A dict {GO ID: raw term text}.  If two stanzas yield the same ID, the
        later one wins.
    '''
    return {next(iter(parse_go_term(stanza))): stanza for stanza in terms}

def main():
    '''
    Build GO-based yeast/human complex alignments and annotate a yeast PPI network.

    The first part always runs and reads hard-coded input paths relative to
    the working directory ('go-basic.obo', 'goa/goa_human.gaf', 'data.tsv',
    'CYC2008_complex/CYC2008_complex.tab',
    'humanComplexes/humanComplexes.json').  It writes 'listGOterms.txt',
    'listGOGO_GOterms.txt', 'alignment_ground_yeast_human_basicv3.txt',
    'gold_standard_reference_yeast_human_basicv3.txt' and
    'gold_standard_reference_yeast_human_basicv2.txt'.

    The second part only runs when at least two command-line arguments are
    given.  It re-parses the GO and GAF files named on the command line and
    reads the hard-coded HINT yeast interaction file.  For each interacting
    protein pair where both proteins are annotated, it writes each protein's
    GO IDs together with the IDs returned by find_parent_terms.

    Run of command line: Python3 *NameOfProgram.py* <GO file> <GAF file> [output file]
    The output file defaults to "results_yeast.tsv".
    '''
    terms = split_terms('go-basic.obo')
    # GO IDs of cellular_component terms -> {'Gene Ontology terms'}
    entity_GOT_List={}
    with open("listGOterms.txt", 'w+') as frkg2:
        for x in terms:
            substr = str(x).strip().split('\n')
            # NOTE(review): str.split always returns a non-empty list, so this
            # condition is always true.  The code assumes the stanza's third
            # line is "namespace: ..."; strnamespace[1] raises IndexError if not.
            if (str(substr[2]).strip().split('namespace:')):
                strnamespace = str(substr[2]).strip().split('namespace:')
                if (str(strnamespace[1]).strip() == "cellular_component"):
                    # goid[1] keeps the space after "id:" (e.g. " GO:0005575"),
                    # so the IDs written/stored below carry a leading space.
                    goid = str(substr[0]).strip().split('id:')
                    #print(goid[1])
                    frkg2.writelines("{}\n".format(str(goid[1])))
                    if str(goid[1]) not in entity_GOT_List:
                        entity_GOT_List[str(goid[1])] = set()
                    entity_GOT_List[str(goid[1])].add(str('Gene Ontology terms'))


    # protein ID -> [GO IDs] from the human GAF file (non-IEA annotations)
    mapped_protein = map_protein_to_go('goa/goa_human.gaf')
    print("---------------- Reference GOLD standard--------------")
    ########################
    # data.tsv: first column -> set of space-separated names from the second
    # column (presumably protein ID -> gene names; confirm against the file).
    entity_map_gene_to_protein_idmap_List = {}
    fentity_idmap_file2 = 'data.tsv'
    with open(fentity_idmap_file2) as f:
        f.__next__()  # skip the header row
        for line in f:
            it = line.strip().split("\t")
            protname = it[0]
            genelist=it[1].strip().split(" ")
            for g in genelist:
                if str(protname) not in entity_map_gene_to_protein_idmap_List:
                    entity_map_gene_to_protein_idmap_List[str(protname)] = set()
                entity_map_gene_to_protein_idmap_List[str(protname)].add(str(g))



    # CYC2008 yeast complexes: complex name (column 3) -> GO IDs (column 6,
    # "-" meaning none), and complex name -> data.tsv proteins whose name set
    # contains the row's gene (column 2).
    entity_yeast_complex_map_GOID_idmap_List = {}
    entity_yeast_complex_map_Proteins_idmap_List = {}
    fentity_idmap_file2 = 'CYC2008_complex/CYC2008_complex.tab'
    with open(fentity_idmap_file2) as f:
        f.__next__()  # skip the header row
        for line in f:

            it = line.strip().split("\t")
            complexName = it[2]
            gene = it[1]
            try:
                got_name = it[5]
                if (str(got_name).strip() !="-"):
                    if str(complexName) not in entity_yeast_complex_map_GOID_idmap_List:
                        entity_yeast_complex_map_GOID_idmap_List[str(complexName)] = set()
                    entity_yeast_complex_map_GOID_idmap_List[str(complexName)].add(str(got_name))

            # NOTE(review): bare except; appears intended to tolerate rows with
            # fewer than 6 columns (IndexError on it[5]) but hides any error.
            except:
                pass
            if str(complexName) not in entity_yeast_complex_map_Proteins_idmap_List:
                entity_yeast_complex_map_Proteins_idmap_List[str(complexName)] = set()
            # Reverse lookup by linear scan over every data.tsv entry, once per row.
            for i, row in enumerate(entity_map_gene_to_protein_idmap_List):
                protein_d = row
                gotList_id2 = entity_map_gene_to_protein_idmap_List[protein_d]
                for x in gotList_id2:
                    if (x==gene):
                        entity_yeast_complex_map_Proteins_idmap_List[str(complexName)].add(str(protein_d))


    # Human complexes JSON: every record is copied into dict['Cmplx'][index];
    # only records with Organism == "Human" also populate the two maps below.
    cluster_protein_List = {}
    Complex_Name = {}
    dict = {}  # NOTE(review): shadows the builtin 'dict' inside main()
    dict['Cmplx'] = {}
    with open('humanComplexes/humanComplexes.json') as json_file:
        data = json.load(json_file)
        icmptdict = 0  # ends as the number of records stored in dict['Cmplx']
        for p in data:
            dict['Cmplx'][int(icmptdict)] = {}
            # dict['Cmplx'][int(icmptdict)]['DiseaseComment'] = str(p['Disease comment'])
            dict['Cmplx'][int(icmptdict)]['GODescription'] = str(p['GO description'])
            # dict['Cmplx'][int(icmptdict)]['FunCatDescription'] = str(p['FunCat description'])
            dict['Cmplx'][int(icmptdict)]['GeneName'] = str(p['subunits(Gene name)'])
            dict['Cmplx'][int(icmptdict)]['Synonyms'] = str(p['Synonyms'])
            dict['Cmplx'][int(icmptdict)]['ComplexID'] = str(p['ComplexID'])
            g_name = str(p['subunits(Gene name)']).strip().split(';')
            dict['Cmplx'][int(icmptdict)]['ComplexName'] = str(p['ComplexName'])
            varcmplexname = str(p['ComplexName'])

            dict['Cmplx'][int(icmptdict)]['Organism'] = str(p['Organism'])
            if (str(p['Organism']) == "Human"):
                if str(p['ComplexID']) not in cluster_protein_List:
                    cluster_protein_List[str(p['ComplexID'])] = set()
                for gn in g_name:
                    cluster_protein_List[str(p['ComplexID'])].add(gn)

                if str(p['ComplexID']) not in Complex_Name:
                    Complex_Name[str(p['ComplexID'])] = set()
                # Redundant loop: adds the same name once per subunit (set keeps one).
                for gn in g_name:
                    Complex_Name[str(p['ComplexID'])].add(varcmplexname)

            dict['Cmplx'][int(icmptdict)]['GeneNameSyn'] = str(p['subunits(Gene name syn)'])
            # dict['Cmplx'][int(icmptdict)]['ProtComplexPurificationMethod'] = str(p['Protein complex purification method'])
            # dict['Cmplx'][int(icmptdict)]['ComplexComment'] = str(p['Complex comment'])
            # dict['Cmplx'][int(icmptdict)]['SubunitsComment'] = str(p['Subunits comment'])
            # dict['Cmplx'][int(icmptdict)]['PubMedID'] = str(p['PubMed ID'])
            # dict['Cmplx'][int(icmptdict)]['CellLine'] = str(p['Cell line'])
            dict['Cmplx'][int(icmptdict)]['GOID'] = str(p['GO ID'])
            # dict['Cmplx'][int(icmptdict)]['FunCatID'] = str(p['FunCat ID'])
            dict['Cmplx'][int(icmptdict)]['subunitsProteinName'] = str(p['subunits(Protein name)'])
            dict['Cmplx'][int(icmptdict)]['subunitsUniProtIDs'] = str(p['subunits(UniProt IDs)'])
            dict['Cmplx'][int(icmptdict)]['SWISSPROTOrganism'] = str(p['SWISSPROT organism'])
            icmptdict = icmptdict + 1

    # For each yeast complex, write one of its GO IDs paired with every
    # cellular_component GO ID collected above.
    with open("listGOGO_GOterms.txt", 'w+') as frkg3:
        for i, row in enumerate(entity_yeast_complex_map_GOID_idmap_List):
            complex_id1 = row
            gotList_id2 = entity_yeast_complex_map_GOID_idmap_List[complex_id1]
            # NOTE(review): this loop only leaves fp2 bound to the last-iterated
            # GO ID of the set (set order is arbitrary); other IDs are ignored.
            for p2 in gotList_id2:
                fp2 = str(p2)
            for i, row in enumerate(entity_GOT_List):
                gotid = str(row)
                frkg3.writelines("{}\t{}\n".format(str(fp2), str(gotid)))








    # NOTE(review): mapped_term is not used before it is recomputed in the
    # command-line branch below.
    mapped_term = mapping_to_term(terms)
    # For each human complex, count GO annotations of its UniProt subunits that
    # equal a yeast complex's (single, last-iterated) GO ID; if the count
    # exceeds half the subunit count, write every yeast-protein/human-protein pair.
    with open("alignment_ground_yeast_human_basicv3.txt", 'w+') as frkg:
        for k in range(icmptdict):
            cmpt = 0
            listprotcomplex = dict['Cmplx'][int(k)]['subunitsUniProtIDs']
            protids = str(listprotcomplex).strip().split(";")
            for l in protids:
                try:
                    listgot = mapped_protein[l]
                    for z in listgot:
                        for i, row in enumerate(entity_yeast_complex_map_GOID_idmap_List):
                            complex_id1 = row
                            gotList_id2 = entity_yeast_complex_map_GOID_idmap_List[complex_id1]
                            # NOTE(review): fp2 ends as only the last-iterated GO ID.
                            for p2 in gotList_id2:
                                fp2 = str(p2)
                            if (z == fp2):
                                cmpt = cmpt + 1
                # Presumably reached via KeyError when l has no annotations;
                # the unannotated ID is printed and skipped.
                except:
                    print(l)
                    continue
            res = len(protids) / 2  # unused
            if (cmpt > len(protids) / 2):
                # NOTE(review): complex_id1 is whichever yeast complex the loop
                # above visited last, not necessarily the one that matched; it
                # is unbound (NameError) if that loop has never run yet.
                listprotyeast = entity_yeast_complex_map_Proteins_idmap_List[str(complex_id1)]
                listprohuman = protids
                for lm in listprotyeast:
                    for ls in listprohuman:
                        frkg.writelines("{}\t{}".format(str(lm), str(ls)))
                        frkg.writelines("\n")


    # Same matching as above, but writes one tab-separated line per matching
    # human complex: all yeast proteins followed by all human subunit IDs.
    with open("gold_standard_reference_yeast_human_basicv3.txt", 'w+') as frkg:
        for k in range(icmptdict):
            cmpt = 0
            listprotcomplex = dict['Cmplx'][int(k)]['subunitsUniProtIDs']
            protids = str(listprotcomplex).strip().split(";")
            for l in protids:
                try:
                    listgot = mapped_protein[l]
                    for z in listgot:
                        for i, row in enumerate(entity_yeast_complex_map_GOID_idmap_List):
                            complex_id1 = row
                            gotList_id2 = entity_yeast_complex_map_GOID_idmap_List[complex_id1]
                            # NOTE(review): fp2 ends as only the last-iterated GO ID.
                            for p2 in gotList_id2:
                                fp2 = str(p2)
                            if (z == fp2):
                                cmpt = cmpt + 1
                # Presumably reached via KeyError for unannotated IDs.
                except:
                    print(l)
                    continue
            res = len(protids) / 2  # unused
            if (cmpt > len(protids) / 2):
                # NOTE(review): same leaked complex_id1 issue as above.
                listprotyeast = entity_yeast_complex_map_Proteins_idmap_List[str(complex_id1)]
                listprohuman = protids
                for lm in listprotyeast:
                    frkg.writelines("{}\t".format(str(lm)))
                for ls in listprohuman:
                    frkg.writelines("{}\t".format(str(ls)))
                frkg.writelines("\n")









    # Per yeast complex (outer loop), compare its (last-iterated) GO ID against
    # every human complex; here complex_id1 is the complex actually compared.
    with open("gold_standard_reference_yeast_human_basicv2.txt", 'w+') as frkg:
        for i, row in enumerate(entity_yeast_complex_map_GOID_idmap_List):
            complex_id1 = row
            gotList_id2 = entity_yeast_complex_map_GOID_idmap_List[complex_id1]
            # NOTE(review): fp2 ends as only the last-iterated GO ID.
            for p2 in gotList_id2:
                fp2 = str(p2)

            for k in range(icmptdict):
                cmpt = 0
                listprotcomplex = dict['Cmplx'][int(k)]['subunitsUniProtIDs']
                protids = str(listprotcomplex).strip().split(";")
                for l in protids:
                    try:
                        listgot = mapped_protein[l]
                        for z in listgot:
                            if (z == fp2):
                                cmpt = cmpt + 1
                    # Presumably reached via KeyError for unannotated IDs.
                    except:
                        print(l)
                        continue
                res=len(protids) / 2  # unused
                if (cmpt> len(protids) / 2 ):
                    listprotyeast=entity_yeast_complex_map_Proteins_idmap_List[str(complex_id1)]
                    listprohuman=protids
                    for lm in listprotyeast:
                        frkg.writelines("{}\t".format(str(lm)))
                    for ls in listprohuman:
                        frkg.writelines("{}\t".format(str(ls)))
                    frkg.writelines("\n")

                    #print(entity_yeast_complex_map_Proteins_idmap_List[str(complex_id1)], '|', protids)
                    #print("found true")






    # for k in range(i):
    #     listgoidcomplex = dict['Cmplx'][int(k)]['GOID']
    #     goids = str(listgoidcomplex).strip().split(";")
    #     for l in goids:
    #         for i, row in enumerate(entity_yeast_complex_map_GOID_idmap_List):
    #             complex_id1 = row
    #             gotList_id2 = entity_yeast_complex_map_GOID_idmap_List[complex_id1]
    #             for p2 in gotList_id2:
    #                 fp2 = str(p2)
    #                 if (l == fp2):
    #                     print(l, "complex n", k, "equal to", complex_id1)

    print("End")

    ########################################################################

    # NOTE(review): the argument check only happens after all the hard-coded
    # work above has already completed.
    if len(sys.argv) > 2:
        input_terms = sys.argv[1]
        input_annotations = sys.argv[2]
        terms = split_terms(input_terms)
        mapped_protein = map_protein_to_go(input_annotations)

        mapped_term = mapping_to_term(terms)


        #testing = {'A0A075B6K4': mapped_protein['A0A075B6K4'], 'A8MW99': mapped_protein['A8MW99']}
        if len(sys.argv) == 3:
            outfile = "results_yeast.tsv"
        else:
            outfile = sys.argv[3]
        # Interaction adjacency: first column -> set of second-column partners,
        # self-interactions skipped; each pair is stored only in file order.
        ppi_idmap_List = {}

        pid_idmap_file = 'HINTSaccharomycesCerevisiaeS288C/SaccharomycesCerevisiaeS288Cbinaryhq.txt'
        with open(pid_idmap_file) as f:
            for line in f:
                it = line.strip().split("\t")
                p1id = it[0]
                p2id = it[1]
                if (p1id != p2id):
                    if str(p1id) not in ppi_idmap_List:
                        ppi_idmap_List[str(p1id)] = set()
                    ppi_idmap_List[str(p1id)].add(str(p2id))

        #
        # testing = {'A0A075B6K4': mapped_protein['A0A075B6K4'], 'A8MW99': mapped_protein['A8MW99']}
        # with open(outfile, 'w+') as annotated_result:
        #     cprot=0
        #     for prot, updated_goids in testing.items():
        #         annotated_result.write(prot + " ")
        #         c1 = 0
        #         for new_goid in updated_goids:
        #             combine = find_parent_terms(new_goid, parse_go_term(mapped_term.get(new_goid)))
        #             proteinstring = " ".join(combine)
        #             print("{0} {1}".format(new_goid, proteinstring))
        #             c1 = c1 + 1
        #             if (len(updated_goids) == c1 and cprot == 0):
        #                 if (proteinstring != ""):
        #                     annotated_result.write("{0} {1}".format(new_goid, proteinstring))
        #                 else:
        #                     annotated_result.write(new_goid)
        #                 annotated_result.write(";")
        #             else:
        #                 if (proteinstring != ""):
        #                     annotated_result.write("{0} {1}".format(new_goid, proteinstring))
        #                     annotated_result.write(" ")
        #                 else:
        #                     annotated_result.write(new_goid + " ")
        #         cprot = cprot + 1





        # One output line per interacting pair:
        # "\n<prot1> <GO IDs + parents ...>; <prot2> <GO IDs + parents ...>"
        with open(outfile, 'w+') as annotated_result:
            for i, row in enumerate(ppi_idmap_List):
                prot_id1 = row
                protList_id2 = ppi_idmap_List[prot_id1]
                for p2 in protList_id2:
                    fprot_id1 = str(prot_id1)
                    fp2 = str(p2)
                    testing = {}
                    try:
                        testing[fprot_id1] = mapped_protein[fprot_id1]
                        testing[fp2] = mapped_protein[fp2]
                    # Presumably KeyError: skip pairs where either protein
                    # has no GO annotation.
                    except:
                        continue
                    cprot = 0  # index of the protein within the pair (0 or 1)
                    for prot, updated_goids in testing.items():
                        if (cprot == 0):
                            annotated_result.write("\n" + prot + " ")
                        else:
                            annotated_result.write(prot + " ")
                        c1 = 0
                        for new_goid in updated_goids:
                            combine = find_parent_terms(new_goid, parse_go_term(mapped_term.get(new_goid)))
                            proteinstring = " ".join(combine)
                            #print("{0} {1}".format(new_goid, proteinstring))
                            c1 = c1 + 1
                            # After the first protein's last GO ID, write "; "
                            # to separate it from the second protein.
                            if (len(updated_goids) == c1 and cprot == 0):
                                if (proteinstring != ""):
                                    annotated_result.write("{0} {1}".format(new_goid, proteinstring))
                                else:
                                    annotated_result.write(new_goid)
                                annotated_result.write("; ")
                            else:
                                if (proteinstring != ""):
                                    annotated_result.write("{0} {1}".format(new_goid, proteinstring))
                                    annotated_result.write(" ")
                                else:
                                    annotated_result.write(new_goid + " ")
                        cprot = cprot + 1







        # print(
        #     find_parent_terms(
        #         "GO:0002250", parse_go_term(mapped_term['GO:0002250'])))
    else:
        print("Error! Not enough command line arguments")


# Run main() only when this file is executed as a script, not when imported.
if __name__=="__main__":
    main()
    #map_protein_to_go('goa_human_subset.gaf')
