import json
import matplotlib.pyplot as plt
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D

class Hit:
    """A single detector hit, as loaded by ``read_hits``.

    All constructor arguments are stored unchanged as attributes of the same
    name; ``track`` is always initialised to ``None``.
    """

    def __init__(self, hit_id, particle_id=None, x=0, y=0, z=0, volume_id=None, layer_id=None, module_id=None,
                 selected=None, index=None):
        self.hit_id = hit_id
        # Cartesian position of the hit.
        self.x = x
        self.y = y
        self.z = z
        self.particle_id = particle_id
        self.volume_id = volume_id
        self.layer_id = layer_id
        self.module_id = module_id
        self.selected = selected
        self.track = None
        # Position of this hit in the flat list returned by read_hits.
        self.index = index

    def __repr__(self):
        return (f"{type(self).__name__}(hit_id={self.hit_id!r}, particle_id={self.particle_id!r}, "
                f"x={self.x!r}, y={self.y!r}, z={self.z!r}, layer_id={self.layer_id!r}, "
                f"index={self.index!r})")


def read_hits(path):
    """Load hits from a CSV file and group them by (halved) layer id.

    Columns are read by position, in the order
    ``hit_id, x, y, z, volume_id, layer_id, module_id, particle_id``. Column
    names are not checked, and any further columns are ignored.

    Each hit's ``layer_id`` is the CSV value divided by 2 (a float). The
    grouping key is ``int()`` of that value.

    Args:
        path: Path (or anything ``pandas.read_csv`` accepts) of the hits CSV.

    Returns:
        A tuple ``(hits_by_layers, hits)``. ``hits_by_layers`` maps each integer
        layer key to its hits in file order, with keys in order of first
        appearance. ``hits`` is the concatenation of those lists in the same
        order, and each hit's ``index`` is its position in ``hits``.
    """
    df = pd.read_csv(path)
    hits_by_layers = {}
    # itertuples keeps each column's own dtype. iterrows() packs a row into a
    # single Series, upcasting int columns to float64 whenever the row also
    # holds floats (x, y, z), which silently corrupts integer ids above 2**53.
    for row in df.itertuples(index=False, name=None):
        hit = Hit(
            hit_id=row[0],
            x=row[1],
            y=row[2],
            z=row[3],
            volume_id=row[4],
            layer_id=row[5] / 2,
            module_id=row[6],
            particle_id=row[7],
        )
        hits_by_layers.setdefault(int(hit.layer_id), []).append(hit)

    hits = [hit for layer_hits in hits_by_layers.values() for hit in layer_hits]
    # Number hits by their position in the layer-grouped flat list.
    for index, hit in enumerate(hits):
        hit.index = index
    return hits_by_layers, hits





def visualize_solution_file(file, do_show=True, out=None):
    """Plot the tracks stored in a solution JSON file.

    The JSON object must contain the keys ``'solution'``, ``'file_data_path'``
    and ``'model'``:

    - ``'solution'`` is the list of tracks passed to ``visualize_solution``.
    - ``'file_data_path'`` is the hits CSV, loaded with ``read_hits``.
    - ``'model'`` is used as the plot title.

    Args:
        file: Path of the solution JSON file.
        do_show: Forwarded to ``visualize_solution``; display the figure.
        out: Forwarded to ``visualize_solution``; if not None, save the figure there.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale encoding.
    with open(file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    tracks = data['solution']
    # NOTE(review): a relative 'file_data_path' is resolved against the current
    # working directory, not against the solution file's directory.
    _, hits = read_hits(data['file_data_path'])
    title = data['model']
    visualize_solution(hits, tracks, title, do_show=do_show, out=out)

def visualize_solution(hits, tracks, title, do_show=False, out=None):
    """Draw every hit as a red point and every track segment as a blue 3D line.

    Args:
        hits: Indexable sequence of objects with ``x``, ``y`` and ``z``
            attributes. Segment names index into it by position.
        tracks: Iterable of tracks. Each track is an iterable of strings shaped
            like ``'<prefix>_<i>_<j>'``; each one draws a segment from
            ``hits[i]`` to ``hits[j]``.
        title: Plot title.
        do_show: If true, display the figure with ``plt.show()``.
        out: If not None, save the figure to this path or file-like object.

    The figure is left open; use ``clean_visualize`` to close all figures.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    xs, ys, zs = [], [], []
    for hit in hits:
        xs.append(hit.x)
        ys.append(hit.y)
        zs.append(hit.z)
    ax.scatter(xs, ys, zs, marker='o', color='red')

    for track in tracks:
        for var in track:
            parts = var.split('_')
            h1 = hits[int(parts[1])]
            h2 = hits[int(parts[2])]
            ax.plot(xs=[h1.x, h2.x], ys=[h1.y, h2.y], zs=[h1.z, h2.z], color='blue')

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(title)

    # Save before showing. With an interactive backend, plt.show() blocks until
    # the window is closed, and the figure is destroyed with it. A pyplot-level
    # savefig afterwards would write a fresh, empty figure.
    if out is not None:
        fig.savefig(out)
    if do_show:
        plt.show()


def clean_visualize():
    """Close every open matplotlib figure, including ones left open by plotting helpers."""
    plt.close(fig='all')

if __name__ == '__main__':
    import sys

    # Default solution file (relative to the current working directory); an
    # optional first command-line argument overrides it.
    default_file = '../results/V_9_N_100_E_100_P_1_B_250_D_15/A_QUBM/solution.json'
    file = sys.argv[1] if len(sys.argv) > 1 else default_file
    visualize_solution_file(file)
