#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "board.h"
#include "entity.h"
#include "math.h"
#include "move.h"
#include "simul.h"
#include "mcts.h"

// Late merge of rbTree and mcts into this single file, to work around recursive (circular) header includes.


// Number of MCTS iterations run by mcts_take_move before it plays a move.
#define MCTS_NB_ITER 10
// Root of the red-black tree that indexes every mcts_node_t explored so far.
// Created by init_mcts and grown by search_view (through rbtInsert).
rb_node_t * root_mcts;

//----------------Red-black tree definitions---------------------

// Allocate a single red-black tree node holding `data`.
// The node starts RED with no children, which is the color a freshly inserted
// node must have; rbtRebalance recolors the tree root BLACK.
// Aborts the program if the allocation fails, instead of returning a pointer
// that every caller would dereference without checking.
struct rb_node_t * rbtCreateNode(mcts_node_t * data) {
	rb_node_t * node = malloc(sizeof *node);
	if (node == NULL) {
		perror("rbtCreateNode: malloc");
		exit(EXIT_FAILURE);
	}

	node->data = data;
	node->color = RED;
	node->link[0] = NULL;
	node->link[1] = NULL;

	return node;
}

/*
 * Left rotation around `x`.
 * x's right child moves up into x's place and x becomes its left child.
 * The child's former left subtree is re-attached as x's right subtree.
 * Colors are not touched. Returns the new subtree root; the caller must link it
 * back into the parent.
 */
struct rb_node_t * rbtRotateLeft(struct rb_node_t * x) {
	struct rb_node_t * pivot = x->link[1];

	x->link[1] = pivot->link[0];
	pivot->link[0] = x;

	return pivot;
}

/*
 * Right rotation around `y`.
 * y's left child moves up into y's place and y becomes its right child.
 * The child's former right subtree is re-attached as y's left subtree.
 * Colors are not touched. Returns the new subtree root; the caller must link it
 * back into the parent.
 */
struct rb_node_t * rbtRotateRight(struct rb_node_t * y) {
	struct rb_node_t * pivot = y->link[0];

	y->link[0] = pivot->link[1];
	pivot->link[1] = y;

	return pivot;
}

/*
 * Restore the red-black invariants after inserting the RED node stack[ht].
 *
 * Expected path layout:
 *   stack[k]  : k-th node on the path from the root down to the new node,
 *   dir[k]    : child index (0 = left, 1 = right) taken from stack[k] to stack[k + 1],
 *   stack[ht] : the freshly inserted node.
 * rbtInsert stores the root both at stack[0] (sentinel) and stack[1]. The
 * `ht >= 3` loop condition keeps the grandfather index >= 1, and a rotation
 * at the real root updates `root` instead of writing into the sentinel.
 *
 * Returns the (possibly new) root, colored BLACK.
 */
struct rb_node_t * rbtRebalance(struct rb_node_t * root, struct rb_node_t * stack[], int dir[], int ht) {
	struct rb_node_t *ptr, *father, *uncle, *grandfather, *subtree;

	// Cases 1, 2 and 3 (empty tree, black father, father is the root): nothing to do.
	while ((ht >= 3) && (stack[ht - 1]->color == RED)) {
		ptr = stack[ht];                 // Current (red) node
		father = stack[ht - 1];          // Its red father
		grandfather = stack[ht - 2];
		// The uncle is the grandfather's child on the side NOT taken towards
		// the father, i.e. opposite to dir[ht - 2] (not dir[ht - 1]).
		uncle = grandfather->link[1 - dir[ht - 2]];

		if (uncle != NULL && uncle->color == RED) { // Case 4: recolor, move up
			father->color = BLACK;
			uncle->color = BLACK;
			grandfather->color = RED;

			ht -= 2; // The grandfather may now conflict with its own father

		} else if (dir[ht - 1] != dir[ht - 2]) { // Case 5: zig-zag -> straight line

			if (dir[ht - 1] == 0) {
				subtree = rbtRotateRight(father);
			} else {
				subtree = rbtRotateLeft(father);
			}
			grandfather->link[dir[ht - 2]] = subtree; // Rechain (subtree == ptr)

			// ptr moved up into the father's slot and the old father became its
			// child on the same side as ptr hangs from the grandfather.
			stack[ht - 1] = ptr;
			stack[ht] = father;
			dir[ht - 1] = dir[ht - 2];
			// The next iteration finishes with case 6 on the corrected path.

		} else { // Case 6: straight line -> rotate the grandfather

			grandfather->color = RED;
			father->color = BLACK;

			if (dir[ht - 2] == 0) {
				subtree = rbtRotateRight(grandfather);
			} else {
				subtree = rbtRotateLeft(grandfather);
			}

			// Rechain the rotated sub-tree, updating the root when needed
			// (stack[ht - 3] would only be the root's sentinel copy then).
			if (grandfather == root) {
				root = subtree;
			} else {
				stack[ht - 3]->link[dir[ht - 3]] = subtree;
			}

			// The new subtree root is BLACK: no violation can remain above it.
			break;
		}
	}

	if (root != NULL) {
		root->color = BLACK;
	}

	return root;
}

/*
 * Ordering of MCTS nodes used as the red-black tree key.
 * Compares first the id of the player to move, then the stored view boards
 * (with board_cmp). Like strcmp, only the sign of the result is meaningful:
 * negative if node1 < node2, zero if equal, positive if node1 > node2.
 *
 * The fields are compared one at a time rather than folded into a single
 * `board_cmp + 100000 * id_delta` sum. That sum is only a valid lexicographic
 * order while |board_cmp| < 100000, and it can overflow.
 */
int mcts_node_cmp(mcts_node_t * node1, mcts_node_t * node2) {
	if (node1->curr_player.id != node2->curr_player.id) {
		return (node1->curr_player.id > node2->curr_player.id) ? 1 : -1;
	}
	return board_cmp(node1->view_board, node2->view_board);
}

/*
 * Insert `data` into the red-black tree rooted at `root` and return the
 * (possibly new) root.
 *
 * Nodes are ordered by mcts_node_cmp: a node greater than `data` sends the
 * search left, a smaller one sends it right. If a node comparing equal to
 * `data` already exists, a message is printed and the tree is returned
 * unchanged. `data` is then NOT stored, so the caller still owns it.
 */
struct rb_node_t * rbtInsert(struct rb_node_t * root, mcts_node_t * data) {
	struct rb_node_t *stack[100], *ptr, *newnode;
	int dir[100], ht = 0, index = 0, cmp;

	// Empty tree: the new node is the whole tree.
	if (!root) {
		return rbtCreateNode(data);
	}

	// The root is pushed here as a sentinel and pushed again by the first
	// iteration below, so the real path is stack[1..ht].
	stack[ht] = root;
	dir[ht++] = 0;

	// Walk down to the NULL link where `data` belongs, recording the path.
	for (ptr = root; ptr != NULL; ptr = ptr->link[index]) {
		cmp = mcts_node_cmp(ptr->data, data);
		if (cmp == 0) {
			printf("Duplicates Not Allowed!!\n");
			return root;
		}

		index = (cmp < 0);   // 0: go left (ptr > data), 1: go right (ptr < data)
		stack[ht] = ptr;     // Node we came from
		dir[ht++] = index;   // Direction followed from it
	}

	// Hang the new node under the last visited node.
	newnode = rbtCreateNode(data);
	stack[ht - 1]->link[index] = newnode;
	stack[ht] = newnode;

	// Fix the potential red-red violation.
	root = rbtRebalance(root, stack, dir, ht);

	// %p requires void * arguments.
	printf("newnode: %p | %p\n", (void *) newnode->link[0], (void *) newnode->link[1]);
	return root;
}

// Delete a node
struct rb_node_t * rbtDelete(struct rb_node_t * root, mcts_node_t * data) {
	struct rb_node_t *stack[100], *ptr, *xPtr, *yPtr, *pPtr;
	int dir[100], ht = 0, diff;
	enum nodeColor c;

	// stack : Stack of all the covered nodes
	// dir : The direction we followed in the RB tree

	// Case of the root

	if (!root) {
		printf("Tree not available\n");
		return root;
	}

	// Finding the node to delete

	ptr = root;
	while (ptr != NULL) {
		if (memcmp(ptr->data->view_board, data->view_board, sizeof(board_t)))
			break;

		diff = memcmp(ptr->data->view_board, data->view_board, sizeof(board_t)) > 0 ? 1 : 0; // Left of right child ?

		stack[ht] = ptr; // We keep a trace of the old node
		dir[ht++] = diff; // We also keep a trace of the direction we followed
		ptr = ptr->link[diff]; // We go to the next node
	}

	// Deleting the node

	if (ptr != NULL) { // We found the value to delete

		c = ptr->color;

		xPtr = ptr->link[0];
		yPtr = ptr->link[1];

		if (yPtr == NULL) {

				// We replace the node by its left child

				if (ptr == root) {
					root = xPtr;
				} else {
					stack[ht - 1]->link[dir[ht - 1]] =  xPtr;
				}

				if (xPtr != NULL) {
					xPtr->color = BLACK;
				}

				free(ptr);
				ptr = NULL;


		} else {

			if (yPtr->link[0] == NULL) {

					// We replace the node by its right child

					yPtr->link[0] = ptr->link[0];

					if (ptr == root) {
						root = yPtr;
					} else {
						stack[ht - 1]->link[dir[ht - 1]] = yPtr;
					}

					if (yPtr != NULL) {
						yPtr->color = BLACK;
					}

					free(ptr);
					ptr = NULL;

			} else {

					// We replace ptr with its successor in the BR-tree
					pPtr = yPtr;

					stack[ht] = pPtr;
					dir[ht++] = 1;

					while (pPtr->link[0] != NULL) {

						stack[ht] = pPtr; // We keep a trace of the old node
						dir[ht++] = 0; // We also keep a trace of the direction we followed
						pPtr = pPtr->link[0]; // We go to the next node

					}

					// We swap the data of the 2 nodes
					ptr->data = pPtr->data;
					pPtr->data = data;

					// Now, we can delete pPtr
					pPtr = rbtDelete(pPtr, data);
					stack[ht-1]->link[dir[ht-1]] = pPtr;

			}
		}

		if (ht < 1)
			return root;

		// Correcting the potential violations

		if (c == BLACK) {
			root = rbtRebalance(root, stack, dir, ht-1);
		}

	}

	return root;
}

// Print the tree in pre-order as "data(color) [left subtree] [right subtree]".
// An empty subtree prints nothing between its brackets.
void rbtDisplay(struct rb_node_t * node) {
	if (node == NULL) {
		return;
	}

	// %p requires a void * argument; node->data is an mcts_node_t *.
	printf("%p(%d) [", (void *) node->data, (int) node->color);
	rbtDisplay(node->link[0]);
	printf("] [");
	rbtDisplay(node->link[1]);
	printf("]");
}

//-------------MCTS-----------------
// Exploration constant of the UCT score computed by interest(); 1.4142 is
// approximately sqrt(2).
// NOTE(review): a non-static global named `c` is easy to shadow or to collide
// with other translation units; consider a more specific name.
float c = 1.4142;

// UCT score ("interest") of a move node.
// It is the node's empirical win rate plus an exploration bonus, weighted by
// `c`. The bonus grows with log(nb_total_try) and shrinks as the node's own
// nb_try grows. With node.nb_try == 0 the divisions yield inf/NaN, so callers
// must treat untried nodes separately.
float interest(node_t node, int nb_total_try) {
	double bonus = c * sqrt(log(nb_total_try) / node.nb_try);
	return node.nb_win / (float) node.nb_try + bonus;
}


/*
 * Choose the move node to explore next with the UCT rule:
 *   - the first node never tried (nb_try == 0), if any;
 *   - otherwise the node with the highest interest() score, which is printed.
 * Returns NULL when `list` is NULL or empty.
 */
node_t * get_max_interest_node(list_t * list){
	int i;
	int nb_total_try = 0;
	node_t * node_act;
	node_t * node_max;
	float max_val = -1;

	if (list == NULL || list->length <= 0 || list->head == NULL) {
		return NULL;
	}

	// Total number of tries over all sibling moves.
	node_act = list->head;
	for (i = 0; i < list->length; i++) {
		nb_total_try += node_act->nb_try;
		node_act = node_act->next;
	}

	node_max = list->head;
	node_act = list->head;
	for (i = 0; i < list->length; i++) {
		if (node_act->nb_try == 0) {
			return node_act;
		}

		// Keep the score as a float: UCT scores are small reals, and
		// truncating them to int made most moves compare equal.
		float act_val = interest(*node_act, nb_total_try);
		if (act_val > max_val) {
			node_max = node_act;
			max_val = act_val;
		}

		node_act = node_act->next;
	}
	printf("(%d, %d) -> (%d, %d) try: %d nb_win:%d\n", node_max->move[0].x, node_max->move[0].y, node_max->move[1].x, node_max->move[1].y, node_max->nb_try, node_max->nb_win);
	return node_max;
}

// Fill `view` with the board as seen by the adversary of `opposite_player`.
// Every cell is copied from `board`, except that pieces owned by
// `opposite_player` that have not attacked yet get their rank masked as
// piece_unknown.
void get_view_from_board(player_t opposite_player, board_t board, board_t view) {
	for (int row = 0; row < BOARD_SIZE; row++) {
		for (int col = 0; col < BOARD_SIZE; col++) {
			int hidden = !board[row][col].has_attacked
				&& board[row][col].owner_id == opposite_player.id;

			view[row][col] = board[row][col];
			if (hidden) {
				view[row][col].rank = piece_unknown;
			}
		}
	}
}

// Allocate an MCTS node for position `view` with `current_player` to move.
// The view is copied into the node and the moves available to current_player
// on that view are computed with possible_moves. Any other fields of
// mcts_node_t are not initialized here.
// Aborts the program if the allocation fails.
// NOTE(review): possible_moves receives the address of this function's local
// copy of current_player; confirm it does not keep that pointer.
mcts_node_t * create_mcts_node_from_view(board_t view, player_t current_player) {
	mcts_node_t * node = malloc(sizeof *node);
	if (node == NULL) {
		perror("create_mcts_node_from_view: malloc");
		exit(EXIT_FAILURE);
	}

	copy_board(view, node->view_board);
	node->curr_player = current_player;
	node->moves = possible_moves(view, &current_player);
	return node;
}

// Create an MCTS node for `current_player` from the full `board`.
// The unrevealed pieces of `opposite_player` are masked first (see
// get_view_from_board).
mcts_node_t * create_mcts_node_from_board(board_t board, player_t current_player, player_t opposite_player) {
	// The view is only needed until create_mcts_node_from_view copies it, so a
	// local board suffices; a heap-allocated one would have to be freed here.
	board_t view;
	get_view_from_board(opposite_player, board, view);

	return create_mcts_node_from_view(view, current_player);
}

// Replace every piece of `player` whose rank is piece_unknown on `board` with
// a rank drawn by pop_random_rank_from_stash, based on the player's stash.
// /!\ Modifies the stash AND the board.
void replace_unknown_pieces_on_board_by_random(board_t board, player_t * player) {
	for (int row = 0; row < BOARD_SIZE; row++) {
		for (int col = 0; col < BOARD_SIZE; col++) {
			if (board[row][col].owner_id != player->id) {
				continue;
			}
			if (board[row][col].rank != piece_unknown) {
				continue;
			}
			board[row][col].rank = pop_random_rank_from_stash(player);
		}
	}
}

/*
 * Look up the MCTS node for position `view` with `current_player` to move.
 * On return *new_address points to that node. If the node was not in the
 * tree, it is created from the view, inserted into root_mcts, and 0 is
 * returned; otherwise 1 is returned.
 *
 * The descent uses mcts_node_cmp, exactly like rbtInsert. Navigating with
 * board_cmp alone (ignoring the player) can take the wrong branch and miss
 * nodes that are already in the tree.
 */
bool_t search_view(board_t * view, mcts_node_t ** new_address, player_t * current_player){
	mcts_node_t key;   // Search key: only the fields mcts_node_cmp reads are set
	rb_node_t * ptr = root_mcts;
	int cmp;

	copy_board(*view, key.view_board);
	key.curr_player = *current_player;

	while (ptr != NULL) {
		cmp = mcts_node_cmp(ptr->data, &key);
		if (cmp == 0) {
			*new_address = ptr->data;
			return 1;
		}
		ptr = ptr->link[cmp < 0];   // Same direction rule as rbtInsert
	}

	*new_address = create_mcts_node_from_view(*view, *current_player);
	root_mcts = rbtInsert(root_mcts, *new_address);
	return 0;
}

/*
 * Run one MCTS iteration from position `view`, with `current_player` to move.
 * `complementary` is the full board; take_move is applied to it, so it must
 * be a scratch copy.
 *
 *  - Position already in the tree: choose a move with get_max_interest_node,
 *    play it, and recurse with the roles swapped.
 *  - New position (just inserted by search_view): choose a move, play it, and
 *    finish the game with simulate_from_board.
 * The chosen move's nb_try is then incremented, and nb_win too when
 * current_player won. Returns the id of the winner.
 */
player_id mcts(board_t * view, player_t * current_player, player_t * responding_player, board_t complementary){
	mcts_node_t * starting_node = NULL;
	node_t * node;
	move_t move;
	board_t next_view;   // Local: the previous heap copy was never freed
	player_id id;
	bool_t is_in_tree = search_view(view, &starting_node, current_player);

	if (!is_in_tree) {
		// NOTE(review): this replaces (and leaks) the move list that
		// create_mcts_node_from_view computed from the view.
		starting_node->moves = possible_moves(complementary, current_player);
	}

	if (is_player_loser(*current_player, complementary)) {
		return responding_player->id;
	}

	node = get_max_interest_node(starting_node->moves);
	if (node == NULL) {
		// No move to play: count it as a loss instead of dereferencing NULL.
		return responding_player->id;
	}

	copy_move(node->move, move);
	take_move(complementary, move, current_player, responding_player, 1, NULL);

	if (is_in_tree) {
		// responding_player moves next, so its view must mask the pieces of
		// its opponent (current_player), as mcts_take_move does for the root.
		get_view_from_board(*current_player, complementary, next_view);
		id = mcts(&next_view, responding_player, current_player, complementary);
	} else {
		id = simulate_from_board(complementary, *responding_player, *current_player);
	}

	node->nb_try += 1;
	if (id == current_player->id){
		node->nb_win += 1;
	}

	return id;
}

// Must be called before mcts_take_move.
// Starts a fresh search tree with a single node: start_board (normally the
// default starting position) as seen by mcts_player, i.e. with human_player's
// unrevealed pieces masked, and mcts_player to move.
void init_mcts(board_t start_board, player_t mcts_player, player_t human_player) {
	root_mcts = rbtCreateNode(create_mcts_node_from_board(start_board, mcts_player, human_player));
}


// Modifie le board après une action de mcts
void mcts_take_move(board_t board, player_t * mcts_player, player_t * human_player, move_t * move, int * result) {
	mcts_node_t * board_state_node = NULL;
	player_t copied_mcts;
	player_t copied_human;
	board_t copied_board;

    board_t view;
	get_blank_board(view);

    for (int i = 0; i < MCTS_NB_ITER; i++) {
        get_view_from_board(*human_player, board, view);
		copy_board(board, copied_board);
		copy_player(human_player, &copied_human);
		copy_player(mcts_player, &copied_mcts);
        
		mcts(&view, &copied_mcts, &copied_human, copied_board);
    }
	
	search_view(&view, &board_state_node, mcts_player);
    node_t * max_interest_node = get_max_interest_node(board_state_node->moves);
    take_move(board, max_interest_node->move, mcts_player, human_player, 1, result);
	copy_move(max_interest_node->move, *move);
}