
# Module: tree_search
# 
# This module provides a set of classes for automated
# problem solving through tree search:
#    SearchDomain  - problem domains
#    SearchProblem - concrete problems to be solved
#    SearchNode    - search tree nodes
#    SearchTree    - search tree with the necessary methods for searching
#
#  (c) Luis Seabra Lopes
#  Introducao a Inteligencia Artificial, 2012-2020,
#  Inteligência Artificial, 2014-2026

#### atributos e metodos ####
## self.problem - O problema a resolver (uma instancia de SearchProblem)

## self.problem.domain - O dominio (SearchDomain) em que se enquadra o problema
## self.problem.domain.actions(state) - Devolve uma lista com as acoes aplicaveis em state
## self.problem.domain.result(state,action) - Devolve o resultado de action em state
## self.problem.domain.cost(state,action) - Devolve o custo de action em state
## self.problem.domain.heuristic(state1,state2) - Devolve uma estimativa do custo de ir de state1 para state2
## self.problem.domain.satisfies(state,goal) - Verifica se um dado estado (state) satisfaz um dado objectivo (goal)

## self.problem.initial - O estado inicial
## self.problem.goal - O estado objectivo
## self.problem.goal_test(state) - Verifica se state e o objectivo

## self.strategy - A estrategia de pesquisa usada
## self.open_nodes - A fila dos nos abertos (folhas da arvore, a expandir), em que cada no é uma instancia de SearchNode
## self.search() - O metodo principal de pesquisa




from abc import ABC, abstractmethod

# Dominios de pesquisa
# Permitem calcular
# as accoes possiveis em cada estado, etc
class SearchDomain(ABC):
    """Abstract interface that every search domain has to implement.

    A domain knows which actions are possible in each state, what they
    lead to, what they cost, and how to estimate the remaining cost.
    """

    @abstractmethod
    def __init__(self):
        """Constructor; concrete domains must provide their own."""

    @abstractmethod
    def actions(self, state):
        """List of the actions that are possible in ``state``."""

    @abstractmethod
    def result(self, state, action):
        """Result of ``action`` in ``state``, i.e. the next state."""

    @abstractmethod
    def cost(self, state, action):
        """Cost of applying ``action`` in ``state``."""

    @abstractmethod
    def heuristic(self, state, goal):
        """Estimated cost of getting from ``state`` to ``goal``."""

    @abstractmethod
    def satisfies(self, state, goal):
        """Test whether the given ``goal`` is satisfied in ``state``."""


# Problemas concretos a resolver
# dentro de um determinado dominio
class SearchProblem:
    """A concrete problem: a domain together with an initial state and a goal."""

    def __init__(self, domain, initial, goal):
        self.domain, self.initial, self.goal = domain, initial, goal

    def goal_test(self, state):
        """Ask the domain whether ``state`` satisfies this problem's goal."""
        domain = self.domain
        return domain.satisfies(state, self.goal)



# Nos de uma arvore de pesquisa
class SearchNode:
    """One node of the search tree, linked to the node it was generated from."""

    def __init__(self, state, parent, cost=0, heuristic=0, action=None):
        # The root (no parent) sits at depth 0; any other node is one level
        # below its parent.
        if parent is None:
            level = 0
        else:
            level = parent.depth + 1

        self.state = state
        self.parent = parent
        self.depth = level
        self.cost = cost
        self.heuristic = heuristic
        self.action = action  # action that led to this node

    def __str__(self):
        # The parent is rendered the same way, so the text nests the whole
        # chain of ancestors up to the root.
        return f"no({self.state!s},{self.parent!s})"

    def __repr__(self):
        # Same text as str().
        return str(self)



# Arvores de pesquisa
class SearchTree:
    """Search tree that explores a SearchProblem using a chosen strategy.

    Attributes:
        problem: the problem being solved.
        open_nodes: queue (a list) of the nodes still waiting to be visited.
        strategy: one of 'breadth', 'depth', 'uniform', 'greedy' or 'a*'.
        solution: the node that satisfied the goal, or None.
        limit: default depth limit applied by ``search``.
        terminals / non_terminals: node counters updated by ``search``.
        highest_cost_nodes / max_cost: generated nodes with the largest
            accumulated cost seen so far, and that cost.
        total_depth / total_nodes / average_depth: depth statistics over the
            generated nodes (the root counts as one node of depth 0).
        plan: list of actions leading to the solution, or None.
    """

    def __init__(self, problem, strategy='breadth', limit=None):
        """Create the tree with only the root node (the initial state) open.

        Args:
            problem: problem to solve; must expose ``initial``, ``goal``,
                ``domain`` and ``goal_test``.
            strategy: name of the strategy used to order the open nodes.
            limit: default maximum depth at which nodes are still expanded
                (None means unlimited).
        """
        self.problem = problem
        root = SearchNode(problem.initial, None)
        self.open_nodes = [root]
        self.strategy = strategy
        self.solution = None
        self.limit = limit

        self.terminals = 0
        self.non_terminals = 0

        self.highest_cost_nodes = []
        self.max_cost = 0

        self.total_depth = 0
        self.total_nodes = 1  # the root is already in the tree
        self.average_depth = 0

        self.plan = None

    @property
    def length(self):
        """Number of transitions in the solution, or None without a solution."""
        if self.solution is None:
            return None
        path = self.get_path(self.solution)
        return len(path) - 1

    @property
    def avg_branching(self):
        """Average branching factor: (all nodes - 1) / non-terminal nodes.

        Returns 0 when no non-terminal node has been counted yet.
        """
        if self.non_terminals == 0:
            return 0
        total_nodes = self.terminals + self.non_terminals
        return (total_nodes - 1) / self.non_terminals

    @property
    def cost(self):
        """Accumulated cost of the solution, or None without a solution."""
        if self.solution is None:
            return None
        return self.solution.cost

    def get_path(self, node):
        """Return the sequence of states from the root down to ``node``.

        Walks the parent links iteratively, so long branches do not hit the
        interpreter recursion limit.
        """
        path = []
        while node is not None:
            path.append(node.state)
            node = node.parent
        path.reverse()
        return path

    def get_plan(self, node):
        """Return the sequence of actions from the root down to ``node``.

        The root contributes no action, so the plan of the root is empty.
        """
        plan = []
        while node.parent is not None:
            plan.append(node.action)
            node = node.parent
        plan.reverse()
        return plan

    def in_path(self, node, state):
        """Tell whether ``state`` occurs in ``node`` or in any of its ancestors."""
        while node is not None:
            if node.state == state:
                return True
            node = node.parent
        return False

    def search(self, limit=None):
        """Run the search until a goal node is found or no open node is left.

        Args:
            limit: maximum depth at which nodes are still expanded. When
                omitted, the limit given to the constructor is used; None in
                both places means unlimited.

        Returns:
            The list of states from the root to the solution, or None when
            no solution has been found.
        """
        # Fall back to the limit configured on the tree.
        if limit is None:
            limit = self.limit

        domain = self.problem.domain

        while self.open_nodes:
            node = self.open_nodes.pop(0)

            if self.problem.goal_test(node.state):
                self.solution = node
                break

            # NOTE: counted before the depth check, so nodes that are not
            # expanded because of the limit are counted here as well.
            self.non_terminals += 1

            # only expand while the depth limit is not exceeded
            if limit is not None and node.depth >= limit:
                continue

            lnewnodes = []
            for a in domain.actions(node.state):
                newstate = domain.result(node.state, a)

                # cycle prevention: skip states already on this branch
                if self.in_path(node, newstate):
                    continue

                # accumulated cost up to the new node
                new_cost = node.cost + domain.cost(node.state, a)
                h = domain.heuristic(newstate, self.problem.goal)

                newnode = SearchNode(newstate, node, new_cost, h, a)
                lnewnodes.append(newnode)

                # update the statistics
                self.total_nodes += 1
                self.total_depth += newnode.depth

                # keep track of the nodes with the highest accumulated cost
                if new_cost > self.max_cost:
                    self.max_cost = new_cost
                    self.highest_cost_nodes = [newnode]
                elif new_cost == self.max_cost:
                    self.highest_cost_nodes.append(newnode)

            self.add_to_open(lnewnodes)

        self.terminals = len(self.open_nodes)

        if self.total_nodes > 0:
            self.average_depth = self.total_depth / self.total_nodes

        if self.solution is None:
            return None

        # The solution node was removed from the queue, so count it too.
        self.terminals += 1
        # Derive plan and path from the stored solution (not from the loop
        # variable), so a repeated call stays consistent and never touches
        # an unbound local.
        self.plan = self.get_plan(self.solution)
        return self.get_path(self.solution)

    def add_to_open(self, lnewnodes):
        """Merge ``lnewnodes`` into the open queue according to the strategy.

        'breadth' appends, 'depth' prepends, and 'uniform', 'greedy' and 'a*'
        append and then re-sort the whole queue by cost, heuristic, or their
        sum. Any other strategy name leaves the queue untouched, i.e. the
        new nodes are dropped.
        """
        if self.strategy == 'breadth':
            self.open_nodes.extend(lnewnodes)
        elif self.strategy == 'depth':
            self.open_nodes[:0] = lnewnodes
        elif self.strategy == 'uniform':
            self.open_nodes.extend(lnewnodes)
            self.open_nodes.sort(key=lambda node: node.cost)
        elif self.strategy == 'greedy':
            self.open_nodes.extend(lnewnodes)
            self.open_nodes.sort(key=lambda node: node.heuristic)
        elif self.strategy == 'a*':
            self.open_nodes.extend(lnewnodes)
            self.open_nodes.sort(key=lambda node: node.cost + node.heuristic)
