#STUDENT NAME: Ana Loureiro
#STUDENT NUMBER: 104063


from tree_search import *
from strips import *
from blocksworld2 import *

class MyNode(SearchNode):
    """SearchNode extended with the bookkeeping used by MyTree's strategies.

    Besides the state and parent handled by SearchNode, each node stores its
    depth, the accumulated path cost, the heuristic estimate of its state and
    the action that produced it (None for the root).
    """

    def __init__(self, state, parent, depth=0, cost=0, heuristic=0, action=None):
        super().__init__(state, parent)

        # position in the tree and accumulated cost from the root
        self.depth, self.cost = depth, cost
        # heuristic estimate for this state and the action leading to it
        self.heuristic, self.action = heuristic, action



class MyTree(SearchTree):
    """Search tree over MyNode nodes with extra search statistics.

    Adds a "hybrid" strategy (uniform cost that switches to greedy), a plain
    search loop that records terminal / non-terminal counts (search2) and a
    bidirectional search (bipolar_search).
    """

    def __init__(self, problem, strategy='breadth'):
        super().__init__(problem, strategy)

        # Replace the base-class root with a MyNode so depth, cost and the
        # heuristic estimate are available from the start.
        self.root = MyNode(self.problem.initial, None, 0, 0,
                           self.problem.domain.heuristic(self.problem.initial, self.problem.goal))
        self.open_nodes = [self.root]

        # search2 tests this attribute; do not rely on the base class for it.
        self.solution = None

        self.terminal = 0       # leaves: unexpanded open nodes plus the solution node
        self.non_terminal = 0   # expanded nodes
        self.solution_cost = 0

        # current phase of the "hybrid" strategy: "uniform" or "greedy"
        self.mode = "uniform"

    def hybrid_add_to_open(self, lnewnodes):
        """Add lnewnodes to the open list and re-sort it for the hybrid strategy.

        The tree starts in "uniform" mode (ordered by path cost). It switches
        permanently to "greedy" mode (ordered by heuristic) as soon as a new
        node's cost exceeds half of the root's heuristic estimate. Ties are
        broken by depth and then by state.
        """
        if not lnewnodes:
            return

        self.open_nodes.extend(lnewnodes)

        if self.mode == "uniform":
            threshold = self.root.heuristic / 2
            if any(node.cost > threshold for node in lnewnodes):
                self.mode = "greedy"

        # NOTE(review): the last tie-breaker compares states with "<";
        # confirm the domain's state type supports ordering.
        if self.mode == "uniform":
            self.open_nodes.sort(key=lambda n: (n.cost, n.depth, n.state))
        else:
            self.open_nodes.sort(key=lambda n: (n.heuristic, n.depth, n.state))

    def search2(self):
        """Run the search, recording terminal / non-terminal statistics.

        Nodes are taken from the front of the open list until one satisfies
        the goal test; every other popped node is expanded with expand_node.

        :return: the path (as returned by get_path) to the solution node,
            or None if the open list is exhausted without reaching the goal.
        """
        while self.open_nodes:
            node = self.open_nodes.pop(0)

            if self.problem.goal_test(node.state):
                self.solution = node
                break

            expand_node(self, node)

        # nodes still in the open list were generated but never expanded
        self.terminal += len(self.open_nodes)

        if self.solution is None:
            return None

        # Use the stored solution rather than the last popped node, so the
        # result is correct even when nothing was popped in this call.
        self.solution_cost = self.solution.cost
        self.terminal += 1  # the solution node is a leaf of the search tree
        return self.get_path(self.solution)

    @staticmethod
    def _bipolar_step(tree, visited, other_tree, other_visited, blind):
        """Pop and record the next node of tree; return it if the other tree
        already reached its state, otherwise expand it and return None."""
        node = tree.open_nodes.pop(0)
        visited.append(node)

        # Blind strategies only accept states the other side has expanded;
        # informed ones also accept states still waiting in its open list.
        seen = other_visited if blind else other_visited + other_tree.open_nodes
        if any(node.state == other.state for other in seen):
            return node

        expand_node(tree, node)
        return None

    def bipolar_search(self):
        """Bidirectional search from the initial state and from the goal.

        A forward tree (initial -> goal) and a reverse tree (goal -> initial)
        are advanced alternately, one node each. A popped node is a meeting
        point when its state was already reached by the other tree:

        - for blind strategies ("breadth", "depth"), among that tree's
          visited nodes;
        - for every other strategy, among its visited and open nodes.

        Statistics set on self:

        - solution_cost: the sum of both half-path costs.
        - non_terminal: the number of distinct visited states.
        - terminal: the number of distinct open states.

        On success, the meeting state is moved from the non-terminal count
        to the terminal count. These statistics are set both when a path is
        found and when the search fails.

        :return: list of states from the initial state to the goal, or None
            if either tree runs out of open nodes before they meet.
        """
        fwd_tree = MyTree(self.problem, self.strategy)
        # backward problem: start from the goal, aim for the initial state
        reverse_problem = SearchProblem(self.problem.domain, self.problem.goal, self.problem.initial)
        rev_tree = MyTree(reverse_problem, self.strategy)

        visited_nodes_fwd = []
        visited_nodes_rev = []
        blind = self.strategy in ("breadth", "depth")

        solution_node = None
        # Alternate one forward and one reverse step until the trees meet
        # or one of them has nothing left to explore.
        while fwd_tree.open_nodes and rev_tree.open_nodes:
            solution_node = self._bipolar_step(fwd_tree, visited_nodes_fwd,
                                               rev_tree, visited_nodes_rev, blind)
            if solution_node is not None:
                break
            solution_node = self._bipolar_step(rev_tree, visited_nodes_rev,
                                               fwd_tree, visited_nodes_fwd, blind)
            if solution_node is not None:
                break

        # distinct states visited / still open in either direction
        expanded = {n.state for n in visited_nodes_fwd} | {n.state for n in visited_nodes_rev}
        still_open = {n.state for n in fwd_tree.open_nodes} | {n.state for n in rev_tree.open_nodes}

        if solution_node is None:
            self.non_terminal = len(expanded)
            self.terminal = len(still_open)
            return None

        meeting_state = solution_node.state
        if blind:
            pool_fwd, pool_rev, rank = visited_nodes_fwd, visited_nodes_rev, "depth"
        else:
            pool_fwd = visited_nodes_fwd + fwd_tree.open_nodes
            pool_rev = visited_nodes_rev + rev_tree.open_nodes
            rank = "cost"

        # Shallowest (blind) or cheapest (informed) node of each tree at the
        # meeting state; both pools are guaranteed to contain one.
        node_fwd = min((n for n in pool_fwd if n.state == meeting_state),
                       key=lambda n: getattr(n, rank))
        node_rev = min((n for n in pool_rev if n.state == meeting_state),
                       key=lambda n: getattr(n, rank))

        path_fwd = fwd_tree.get_path(node_fwd)
        path_rev = rev_tree.get_path(node_rev)[::-1]

        self.solution_cost = node_fwd.cost + node_rev.cost
        # the meeting node was popped but not expanded: count it as a leaf
        self.non_terminal = len(expanded) - 1
        self.terminal = len(still_open) + 1

        # the meeting state ends path_fwd and starts the reversed path_rev
        return path_fwd + path_rev[1:]


#auxiliary functions
def expand_node(tree, node):
    """Expand node in tree and queue its children on tree's open list.

    One child is created per applicable action. Actions whose resulting state
    already lies on the path from the root to node are skipped (loop
    prevention). Each child records:

    - its depth;
    - its accumulated cost;
    - its heuristic estimate towards tree.problem.goal;
    - the action that produced it.

    The expansion is counted in tree.non_terminal. Children are queued with
    hybrid_add_to_open for the "hybrid" strategy and with add_to_open
    otherwise.

    :param tree: a MyTree (or compatible) search tree.
    :param node: the MyNode to expand.
    :return: list of the newly created child nodes.
    """
    domain = tree.problem.domain
    goal = tree.problem.goal
    # The root-to-node path does not change while this node's children are
    # built, so compute it once instead of once per action.
    path = tree.get_path(node)

    lnewnodes = []
    for action in domain.actions(node.state):
        newstate = domain.result(node.state, action)
        if newstate in path:
            continue  # would revisit a state already on this branch

        new_cost = node.cost + domain.cost(node.state, action)
        heuristic = domain.heuristic(newstate, goal)
        lnewnodes.append(MyNode(newstate, node, node.depth + 1, new_cost, heuristic, action))

    # node has been expanded
    tree.non_terminal += 1

    # queue the children according to the tree's strategy
    if tree.strategy == "hybrid":
        tree.hybrid_add_to_open(lnewnodes)
    else:
        tree.add_to_open(lnewnodes)

    return lnewnodes




class MySTRIPS(STRIPS):
    """STRIPS domain that grounds operators by enumerating the variable
    bindings suggested by the facts of the current state."""

    def get_instanciations(self, op, state):
        """Return every instantiation of operator op applicable in state.

        Candidate values for op's variables are gathered from the state's
        facts (see var_possible_assignments). Every combination of candidates
        is then checked against all preconditions. Operators without
        preconditions, or whose preconditions contain no variables, go
        through the same path. They therefore yield instantiated actions
        (never raw dicts), as long as all their variables can be bound.

        :param op: operator exposing ``args`` (its variables), ``pc`` (its
            preconditions) and ``instanciate(argvalues)``.
        :param state: collection of ground facts.
        :return: list with one ``op.instanciate(...)`` result per valid
            binding, in candidate-enumeration order; ``[]`` if none.
        """
        from itertools import product

        candidates = self.var_possible_assignments(op, state)

        # A variable no fact can bind (e.g. one absent from every
        # precondition) cannot be grounded from the state, so the operator
        # has no instantiation here.
        if any(var not in candidates for var in op.args):
            return []

        variables = list(candidates)
        valid_instanciations = []
        # product() yields the same combinations, in the same order, as
        # extending partial assignments one variable at a time.
        for values in product(*(candidates[var] for var in variables)):
            assignment = dict(zip(variables, values))
            if self.precondition_satisfied(op.pc, state, assignment):
                args = [assignment[a] for a in op.args]
                valid_instanciations.append(op.instanciate(args))

        return valid_instanciations

    def var_possible_assignments(self, op, state):
        """Collect the candidate values of each variable of op.

        Each precondition is matched against every fact in state of the same
        predicate type and arity. A fact contributes bindings only if it
        matches the precondition completely: constants must be equal, and a
        variable repeated inside the precondition must take a single value.

        :return: dict mapping variable -> set of candidate values. Only
            variables that some fact can bind appear as keys.
        """
        var_assignments = {}

        for pre in op.pc:
            pre_args = getattr(pre, 'args', [])
            for fact in state:
                # only facts of exactly the same predicate type can match
                if type(pre) != type(fact):
                    continue

                fact_args = getattr(fact, 'args', [])
                if len(pre_args) != len(fact_args):
                    continue

                binding = self._match_args(pre_args, fact_args, op.args)
                if binding is None:
                    continue  # fact does not match; contribute nothing

                for var, value in binding.items():
                    var_assignments.setdefault(var, set()).add(value)

        return var_assignments

    @staticmethod
    def _match_args(pre_args, fact_args, variables):
        """Match a precondition's arguments against a fact's arguments.

        :return: dict of variable bindings if every argument matches,
            otherwise None.
        """
        binding = {}
        for x, a in zip(pre_args, fact_args):
            if x in variables:
                # a variable: bind it, consistently if it appears twice
                if binding.setdefault(x, a) != a:
                    return None
            elif x != a:
                # a constant that differs from the fact's argument
                return None
        return binding

    def precondition_satisfied(self, pc, state, assignment):
        """Return True iff every precondition in pc, after substituting
        assignment, is a fact of state."""
        return all(prec.substitute(assignment) in state for prec in pc)
                            