import numpy as np
import numpy.random as random
import scipy.stats as scs
import matplotlib.pyplot as plt
import networkx as nx


class RootedTree:
    """Rooted tree stored in preallocated NumPy arrays.

    Vertices are identified internally by an integer index (the root is 0 and
    later vertices are numbered in the order they are added) and externally
    by a ``(depth, row_position)`` pair.  ``row_position`` is the 1-based
    order in which the vertex was added at its depth, so it reflects
    insertion order only and does not depend on the parent's row position.

    Args:
        max_size: Capacity of the backing arrays, i.e. the maximum number of
            vertices the tree can hold.
    """

    def __init__(self, max_size=10**6):
        # Per-vertex storage; only the first ``nodes_number`` entries are used
        # (unused depth slots keep the value -1).
        self.depth = -np.ones(max_size, dtype=np.int64)
        self.row_position = np.zeros(max_size, dtype=np.int64)
        # Index of each vertex's parent (the root's entry is never set).
        self.parent = np.empty(max_size, dtype=object)
        # List of child indices for each vertex.
        self.children = np.empty(max_size, dtype=object)
        self.depth[0], self.row_position[0], self.children[0] = 0, 1, []
        # limit: capacity; profile[d]: number of vertices at depth d;
        # nodes_number: number of vertices currently in the tree.
        self.limit, self.profile, self.nodes_number = max_size, [1], 1

    def __repr__(self):
        return f"Rooted tree with {self.nodes_number} vertices"

    def find_index(self, d, k):
        """Return the integer index of the vertex at depth ``d``, row ``k``.

        Raises:
            IndexError: If no vertex has that ``(depth, row_position)``.
        """
        n = self.nodes_number
        # Slice before comparing so only the used part of the arrays is scanned.
        matches = np.flatnonzero(
            (self.depth[:n] == d) & (self.row_position[:n] == k)
        )
        if matches.size == 0:
            raise IndexError(f"no vertex at depth {d}, row position {k}")
        return int(matches[0])

    def add_children(self, parent, c):
        """Append ``c`` new children to the vertex at ``parent``.

        The new vertices are placed one level below the parent and receive
        the next free row positions at that level.

        Args:
            parent: ``(depth, row_position)`` of the parent vertex.
            c: Number of children to add; must be non-negative.

        Raises:
            ValueError: If ``c`` is negative.
            IndexError: If no vertex matches ``parent``, or if adding ``c``
                vertices would exceed the capacity ``limit``.  The tree is
                left unchanged in every error case.
        """
        if c < 0:
            raise ValueError(f"cannot add a negative number of children ({c})")
        ip = self.find_index(*parent)
        start, stop = self.nodes_number, self.nodes_number + c
        # Check capacity up front so an overflow cannot leave ``profile`` and
        # the per-vertex arrays half-updated.
        if stop > self.limit:
            raise IndexError(
                f"adding {c} children would exceed the capacity of "
                f"{self.limit} vertices"
            )
        if c == 0:
            # Nothing to add; in particular, do not open an empty new level.
            return
        child_depth = int(self.depth[ip]) + 1
        if len(self.profile) == child_depth:
            # The parent is on the deepest level: start a new one.
            self.profile.append(0)
        for i in range(start, stop):
            self.depth[i] = child_depth
            self.profile[child_depth] += 1
            self.row_position[i] = self.profile[child_depth]
            self.parent[i], self.children[i] = ip, []
        self.children[ip] += list(range(start, stop))
        self.nodes_number = stop

    def networkx(self):
        """Return the tree as a ``networkx.DiGraph``.

        Edges go from each vertex to its children, and every node carries a
        ``label`` attribute holding its ``(depth, row_position)`` pair.
        """
        T = nx.DiGraph({i: self.children[i] for i in range(self.nodes_number)})
        for i in range(self.nodes_number):
            T.nodes[i]["label"] = (
                int(self.depth[i]),
                int(self.row_position[i]),
            )
        return T

    def layout(self):
        """Return drawing positions as ``{vertex index: array([x, y])}``.

        ``y`` is the vertex depth.  ``x`` is the row position shifted by
        ``-(width + 1) / 2``, where ``width`` is the number of vertices at
        that depth, so every level is centred on ``x = 0``.
        """
        positions = {}
        for i in range(self.nodes_number):
            width = self.profile[self.depth[i]]
            positions[i] = np.array(
                [-(width + 1) / 2 + self.row_position[i], self.depth[i]]
            )
        return positions

    def draw_on_ax(self, ax0, with_labels=False):
        """Draw the tree on the matplotlib axes ``ax0`` using ``layout()``.

        Args:
            ax0: Axes to draw on; its axis lines are hidden.
            with_labels: If true, draw larger nodes labelled with their
                ``(depth, row_position)`` pair; otherwise draw unlabelled
                nodes.
        """
        ax0.set_axis_off()
        T = self.networkx()
        # Arguments shared by the labelled and unlabelled drawings.
        common = dict(
            ax=ax0,
            pos=self.layout(),
            node_shape="s",
            node_color=np.array([[0.7, 0.7, 0.9]]),
        )
        if with_labels:
            nx.draw_networkx(
                T,
                node_size=800,
                labels={
                    i: T.nodes[i]["label"] for i in range(self.nodes_number)
                },
                **common,
            )
        else:
            nx.draw_networkx(T, node_size=300, with_labels=False, **common)
