from functools  import cmp_to_key
from collections import defaultdict
from itertools import combinations 
# Module-level debug switch. It is not referenced by the code in this file
# section; presumably it is read elsewhere to enable diagnostic output (verify).
DEBUG = False

class Cluster_Expansion:
    """Undirected, unit-weight graph with memoised per-vertex statistics.

    The graph is built from an iterable of ``(g1, g2)`` edge pairs. Every edge
    gets weight ``1.0`` and is stored in both directions in two parallel
    adjacency structures:

    * ``graph``      -- ``defaultdict(set)``: vertex -> set of ``(neighbor, weight)``.
    * ``graph_dict`` -- plain ``dict``: vertex -> ``{neighbor: weight}``.

    ``CC`` (clustering coefficient) and ``DEG`` (weighted degree) cache
    per-vertex results, using ``-1`` as the "not computed yet" marker.
    """

    def __init__(self, data):
        """Build the adjacency structures from ``data``.

        Parameters:
            data: iterable of 2-item ``(g1, g2)`` pairs. Each pair becomes an
                undirected edge of weight 1.0. Vertices must be hashable.
        """
        self.graph = defaultdict(set)        # vertex -> {(neighbor, weight), ...}
        self.CC = defaultdict(lambda: -1)    # vertex -> cached clustering coefficient (-1 = not computed)
        self.DEG = defaultdict(lambda: -1)   # vertex -> cached sum of incident edge weights (-1 = not computed)
        self.current_entropy = 0             # entropy of the current iteration; only initialised in this class

        graph_dict = defaultdict(dict)       # vertex -> {neighbor: weight}
        for g1, g2 in data:
            # Undirected graph: record the unit-weight edge in both directions.
            self.graph[g1].add((g2, 1.0))
            self.graph[g2].add((g1, 1.0))
            graph_dict[g1][g2] = 1.0
            graph_dict[g2][g1] = 1.0

        # Plain dict, so looking up an unknown vertex does not insert an empty entry.
        self.graph_dict = dict(graph_dict)

    def clustering_coefficient(self, v):
        """Return the local clustering coefficient of vertex ``v`` (memoised in ``CC``).

        The value is ``2 * E / (k * (k - 1))``:
        - ``k`` is the number of neighbors of ``v``.
        - ``E`` is the number of edges among those neighbors.

        Vertices with fewer than two neighbors get ``0``. This includes
        vertices that are not in the graph at all.

        Parameters:
            v: vertex
        """
        cached = self.CC[v]
        if cached > -1:  # already computed
            return cached

        # .get avoids inserting an empty adjacency set for unknown vertices.
        neighbors = {n for n, _ in self.graph.get(v, ())}
        deg = len(neighbors)
        if deg < 2:
            # The formula divides by k * (k - 1), which is 0 here; define it as 0.
            self.CC[v] = 0
            return 0

        # An edge between two neighbors is counted once from each endpoint.
        # The sum is therefore 2 * E, matching the ordered-pair denominator.
        links = sum(
            len(neighbors & {n for n, _ in self.graph[n1]}) for n1 in neighbors
        )
        self.CC[v] = links / (deg * (deg - 1))
        return self.CC[v]

    def compare_vertices(self, v1, v2):
        """Three-way compare two vertices; intended for ``functools.cmp_to_key``.

        Vertices are ordered by weighted degree (``get_degree``). Ties are
        broken by clustering coefficient.

        Parameters:
            v1: first vertex
            v2: second vertex

        Returns:
            The difference of the compared quantities, which may be a float:
            < 0 if v1 < v2, 0 if they tie on both keys, > 0 if v1 > v2.
        """
        s1, s2 = self.get_degree(v1), self.get_degree(v2)
        if s1 != s2:
            return s1 - s2
        return self.clustering_coefficient(v1) - self.clustering_coefficient(v2)

    def cluster_list_to_dict(self, cluster_list):
        """Return the subgraph induced by ``cluster_list`` as a nested dict.

        Parameters:
            cluster_list: sized, iterable collection of vertices (e.g. list or set).

        Returns:
            ``{v1: {v2: weight, ...}, ...}`` holding every graph edge whose two
            endpoints are both in ``cluster_list``, stored in both directions.
            A single-vertex cluster yields ``{v: {}}``; an empty one yields ``{}``.
        """
        if len(cluster_list) == 1:
            return {next(iter(cluster_list)): {}}

        cluster = defaultdict(dict)
        for v1, v2 in combinations(cluster_list, 2):
            wt = self.graph_dict.get(v1, {}).get(v2)
            if wt is None:  # not adjacent, or v1 not in the graph
                continue
            # __init__ stores each edge symmetrically with the same weight.
            cluster[v1][v2] = wt
            cluster[v2][v1] = wt
        # NOTE(review): in clusters of 2+ vertices, a member with no edge to
        # another member is omitted here, unlike the single-vertex case.
        # Confirm callers rely on this.
        return dict(cluster)

    def get_protein_out_scope(self, cluster):
        """Return the vertices adjacent to ``cluster`` but not contained in it.

        Parameters:
            cluster: container of vertices supporting iteration and ``in``.

        Returns:
            A list of those neighboring vertices, sorted in descending order
            by ``compare_vertices``: highest weighted degree first, ties
            broken by higher clustering coefficient.
        """
        frontier = {
            n
            for v in cluster
            for n, _ in self.graph.get(v, ())
            if n not in cluster
        }
        return sorted(frontier, key=cmp_to_key(self.compare_vertices), reverse=True)

    def get_degree(self, v):
        """Return the weighted degree of ``v`` (sum of incident edge weights).

        The result is memoised in ``DEG``. Vertices not in the graph have
        degree 0 and are not added to ``graph``.
        """
        cached = self.DEG[v]
        if cached > -1:  # already computed
            return cached
        self.DEG[v] = sum(wt for _, wt in self.graph.get(v, ()))
        return self.DEG[v]

    
            