import numpy as np
import sklearn.metrics.pairwise
from scipy.sparse import csr_matrix, coo_matrix
from sklearn.neighbors import KDTree

def get_embedding_similarities(embed, embed2 = None, sim_measure = "euclidean", num_top = None):
	"""Compute similarities between the rows of two embedding matrices.

	Args:
		embed: 2-D array of embeddings, one row per item.
		embed2: optional second 2-D array of embeddings. When None, ``embed``
			is compared against itself.
		sim_measure: "cosine" selects cosine similarity. On the dense path any
			other value selects ``exp(-d)`` with ``d`` the Euclidean distance.
			On the top-k path the value is forwarded unchanged to ``kd_align``
			as its ``distance_metric``.
		num_top: when not None, only the ``num_top`` nearest rows of ``embed2``
			are kept for each row of ``embed`` (delegated to ``kd_align``).

	Returns:
		Whatever ``kd_align`` returns when ``num_top`` is given; otherwise a
		dense array of shape ``(embed.shape[0], embed2.shape[0])``.
	"""
	# Report the measure actually requested instead of a hard-coded "cosine".
	print("Start embedding similarities using {} measure".format(sim_measure))
	if embed2 is None:
		embed2 = embed

	if num_top is not None:
		# Top-k path: only the nearest neighbours are computed and stored.
		similarity_matrix = kd_align(embed, embed2, distance_metric = sim_measure, num_top = num_top)
	elif sim_measure == "cosine":
		# All pairwise cosine similarities.
		similarity_matrix = sklearn.metrics.pairwise.cosine_similarity(embed, embed2)
	else:
		# All pairwise Euclidean distances, turned into similarities in (0, 1].
		distances = sklearn.metrics.pairwise.euclidean_distances(embed, embed2)
		similarity_matrix = np.exp(-distances)

	# Printed on every path so the start/finish log lines always pair up.
	print("Finish embedding similarities")

	return similarity_matrix


def _l2_normalize_rows(emb):
	"""Return (unit_rows, is_zero_row) for a 2-D array; all-zero rows are left as zeros."""
	emb = np.asarray(emb, dtype = float)
	norms = np.linalg.norm(emb, axis = 1)
	is_zero_row = norms == 0
	# Divide zero rows by 1 instead of 0 so they stay zero rather than becoming NaN.
	safe_norms = np.where(is_zero_row, 1.0, norms)
	return emb / safe_norms[:, np.newaxis], is_zero_row


def kd_align(emb1, emb2, normalize=False, distance_metric = "cosinus", num_top = 50):
	"""Build a sparse top-k similarity matrix between two embedding sets.

	For every row of ``emb1`` the nearest rows of ``emb2`` are found with a
	KD-tree and only those entries are stored.

	Args:
		emb1: 2-D array of query embeddings, one row per item.
		emb2: 2-D array of candidate embeddings, one row per item.
		normalize: accepted for interface compatibility; not used by this function.
		distance_metric: "cosine" (or the default spelling "cosinus") selects
			cosine similarity. Any other value is passed to ``KDTree`` as its
			``metric`` and the stored similarity is ``exp(-distance)``.
		num_top: number of nearest neighbours kept per row of ``emb1``; capped
			at the number of rows in ``emb2``.

	Returns:
		A ``scipy.sparse`` CSR matrix of shape ``(emb1.shape[0], emb2.shape[0])``
		with up to ``num_top`` stored entries per row.
	"""
	use_cosine = distance_metric in ("cosine", "cosinus")
	n_queries = emb1.shape[0]
	n_candidates = emb2.shape[0]
	# KDTree.query rejects k larger than the number of indexed points.
	k = min(num_top, n_candidates)

	if use_cosine:
		# KDTree has no cosine metric. For unit vectors u, v we have
		# ||u - v||^2 = 2 - 2*cos(u, v), so Euclidean nearest neighbours of the
		# normalised rows are exactly the cosine nearest neighbours.
		emb1, zero_rows1 = _l2_normalize_rows(emb1)
		emb2, zero_rows2 = _l2_normalize_rows(emb2)
		tree_metric = "euclidean"
	else:
		tree_metric = distance_metric

	kd_tree = KDTree(emb2, metric = tree_metric)
	dist, ind = kd_tree.query(emb1, k = k)

	if use_cosine:
		similarities = 1.0 - 0.5 * dist ** 2
		# A zero vector has no direction; report similarity 0 for it, as the
		# dense cosine path does.
		similarities[zero_rows1[:, np.newaxis] | zero_rows2[ind]] = 0.0
	else:
		similarities = np.exp(-dist)

	# Row index i repeated k times, matching the row-major flattening of `ind`.
	row = np.repeat(np.arange(n_queries), k)
	col = ind.ravel()
	data = similarities.ravel()
	sparse_align_matrix = coo_matrix((data, (row, col)), shape=(n_queries, n_candidates))
	return sparse_align_matrix.tocsr()