#encoding: utf8
#Ana Loureiro 104063

from semanticnetwork import *
from constraintsearch import *
from bayes_net import *


class MySN(SemanticNetwork):
    """Semantic network with inverse-aware local queries and typed inheritance.

    The methods below read ``self.declarations``, ``self.inverse``,
    ``self.opposite`` and ``self.query_local``; none of them is defined in
    this class (presumably they are provided by ``SemanticNetwork`` --
    confirm against the base class).
    """

    def __init__(self):
        SemanticNetwork.__init__(self)

    def new_query_local(self, e1, relname=None, e2=None):
        """Return the local ``(relation name, value)`` pairs of *e1*.

        Only the declarations stored in the network are consulted; nothing is
        inherited.  A declaration ``(x, r, e1)`` whose name ``r`` is a key of
        ``self.inverse`` is also reported, as ``(self.inverse[r], x)``.

        Args:
            e1: entity whose relations are wanted.
            relname: when given, keep only pairs with this relation name.
            e2: when given, keep only pairs with this value.

        Returns:
            A list of distinct ``(name, value)`` tuples, in no particular
            order.
        """
        pairs = set()

        # Direct declarations: (e1, name, x).
        for d in self.declarations:
            rel = d.relation
            if rel.entity1 != e1:
                continue
            if relname is not None and rel.name != relname:
                continue
            if e2 is not None and rel.entity2 != e2:
                continue
            pairs.add((rel.name, rel.entity2))

        # Inverse declarations: (x, r, e1) is read as (e1, inverse[r], x).
        for d in self.declarations:
            rel = d.relation
            if rel.entity2 != e1 or rel.name not in self.inverse:
                continue
            inv_name = self.inverse[rel.name]
            if relname is not None and inv_name != relname:
                continue
            if e2 is not None and rel.entity1 != e2:
                continue
            pairs.add((inv_name, rel.entity1))

        return list(pairs)

    def new_query(self, entity, relname):
        """Return the values of *relname* for *entity*, including inherited ones.

        Behaviour depends on the kind of relation:

        * ``"member"`` / ``"subtype"``: the distinct local values returned by
          ``query_local``; nothing is inherited.
        * Any other name is treated as ``AssocOne`` or ``AssocSome``,
          whichever type has more declarations named *relname* (or named
          ``self.inverse.get(relname)``).  A tie selects ``AssocSome``; when
          there is no such declaration at all the result is ``[]``.
        * ``AssocOne``: if the entity has local values, only the most frequent
          one is returned.  Otherwise the most frequent value inherited
          through the member/subtype parents is returned; each ancestor
          branch stops at the first entity that has local values.
        * ``AssocSome``: the union of the local values and of every value
          found at the ancestors.

        Values that the entity holds locally under the opposite relation
        (``self.opposite.get(relname)``) are removed from the *inherited*
        values only.  Inverse relations are taken into account through
        ``new_query_local``.

        Returns:
            A list of values (at most one element for ``AssocOne``).
        """
        # Member / Subtype: local declarations only.
        if relname in ("member", "subtype"):
            return list({d.relation.entity2
                         for d in self.query_local(e1=entity, relname=relname)})

        # Decide whether the relation behaves as AssocOne or AssocSome.
        one_count, some_count = self._assoc_type_counts(relname)
        if one_count == 0 and some_count == 0:
            return []
        rel_assoc = AssocSome if some_count >= one_count else AssocOne

        local_pairs = self._typed_local_values(entity, relname, rel_assoc)

        # AssocOne: a local value cancels inheritance altogether.
        if rel_assoc is AssocOne and local_pairs:
            local_counts = dict(local_pairs)
            return [max(local_counts, key=local_counts.get)]

        cancel_vals = self._cancelled_values(entity, relname)
        parents = self._parents_of(entity)

        if rel_assoc is AssocOne:
            inherited_counts = {}
            for p in parents:
                # Each parent branch gets its own visited set, so an ancestor
                # reachable through several parents contributes once per
                # branch.
                for v, k in self._inherit_counts(p, relname, set()).items():
                    inherited_counts[v] = inherited_counts.get(v, 0) + k

            # The opposite relation cancels inherited values.
            for v in cancel_vals:
                inherited_counts.pop(v, None)

            if not inherited_counts:
                return []
            return [max(inherited_counts, key=inherited_counts.get)]

        local_set = {v for v, _ in local_pairs}

        inherited_set = set()
        for p in parents:
            inherited_set = inherited_set | self._inherit_set(p, relname, set())
        inherited_set -= cancel_vals

        return list(local_set | inherited_set)

    def _assoc_type_counts(self, relname):
        """Count the AssocOne and AssocSome declarations named *relname* or its inverse."""
        inv_name = self.inverse.get(relname)
        one_count = 0
        some_count = 0
        for d in self.declarations:
            rel = d.relation
            if rel.name == relname or (inv_name is not None and rel.name == inv_name):
                if isinstance(rel, AssocOne):
                    one_count += 1
                elif isinstance(rel, AssocSome):
                    some_count += 1
        return one_count, some_count

    def _count_typed(self, e, relname, val, assoc):
        """Count declarations of type *assoc* stating ``(e, relname, val)``, directly or inversely."""
        inv_name = self.inverse.get(relname)
        count = 0
        for d in self.declarations:
            r = d.relation
            if not isinstance(r, assoc):
                continue
            # Direct form: (e, relname, val).
            if r.name == relname and r.entity1 == e and r.entity2 == val:
                count += 1
            # Inverse form: (val, inv_name, e).  Checked independently of the
            # direct form, so one declaration may be counted by both tests.
            if inv_name is not None and r.name == inv_name and r.entity1 == val and r.entity2 == e:
                count += 1
        return count

    def _typed_local_values(self, e, relname, assoc):
        """Return ``[(value, count)]`` for the local values of *e* backed by *assoc* declarations."""
        pairs = []
        for _, val in self.new_query_local(e, relname):
            k = self._count_typed(e, relname, val, assoc)
            if k > 0:
                pairs.append((val, k))
        return pairs

    def _cancelled_values(self, entity, relname):
        """Return the values *entity* holds locally under the opposite of *relname*."""
        oppo = self.opposite.get(relname)
        if oppo is None:
            return set()
        # The opposite relation gets its own majority type; with no
        # declarations at all it defaults to AssocSome.
        one_count, some_count = self._assoc_type_counts(oppo)
        oppo_assoc = AssocSome if some_count >= one_count else AssocOne
        return {v for v, _ in self._typed_local_values(entity, oppo, oppo_assoc)}

    def _parents_of(self, e):
        """Return the entities *e* is locally declared a member or subtype of."""
        return [d.relation.entity2
                for d in self.query_local(e1=e)
                if isinstance(d.relation, (Member, Subtype))]

    def _inherit_counts(self, e, relname, visited):
        """AssocOne inheritance: value counts at *e*, or summed over its parents when *e* has none."""
        if e in visited:
            return {}
        visited.add(e)

        local_counts = dict(self._typed_local_values(e, relname, AssocOne))
        if local_counts:
            # Local values cancel whatever could be inherited above this entity.
            return local_counts

        total = {}
        for p in self._parents_of(e):
            for v, k in self._inherit_counts(p, relname, visited).items():
                total[v] = total.get(v, 0) + k
        return total

    def _inherit_set(self, e, relname, visited):
        """AssocSome inheritance: values at *e* united with those of all its ancestors."""
        if e in visited:
            return set()
        visited.add(e)

        vals = {v for v, _ in self._typed_local_values(e, relname, AssocSome)}
        for p in self._parents_of(e):
            vals = vals | self._inherit_set(p, relname, visited)
        return vals





class MyCS(ConstraintSearch):
    """Constraint search that enumerates every solution of the problem.

    ``self.domains``, ``self.constraints`` and ``self.calls`` are read here
    but not defined in this class (presumably set by
    ``ConstraintSearch.__init__`` -- confirm against the base class).
    ``self.constraints`` is used as a mapping from ``(vj, vi)`` pairs to a
    callable ``constraint(vj, xj, vi, xi)``.
    """

    def __init__(self, domains, constraints):
        ConstraintSearch.__init__(self, domains, constraints)

        # Arcs indexed by target variable: in_edges[vi] lists every (vj, vi)
        # constraint key, so propagation finds the arcs pointing at vi
        # without scanning all constraints.
        self.in_edges = {}
        for (vj, vi) in self.constraints:
            self.in_edges.setdefault(vi, []).append((vj, vi))

    def _satisfies_constraints(self, assignment):
        """Return True when *assignment* (variable -> value) violates no constraint."""
        for (vj, vi) in self.constraints:
            # Arcs mentioning a variable outside the assignment cannot be
            # evaluated; skip them rather than raise.
            if vj not in assignment or vi not in assignment:
                continue
            constraint = self.constraints[(vj, vi)]
            if not constraint(vj, assignment[vj], vi, assignment[vi]):
                return False
        return True

    def search_all(self, domains=None):
        """Return the list of all solutions reachable from *domains*.

        Args:
            domains: mapping variable -> list of candidate values.  Defaults
                to ``self.domains``.

        Returns:
            A list of dicts (variable -> value), one per solution; ``[]``
            when there is none.

        Each invocation, including recursive ones, increments ``self.calls``.
        """
        self.calls += 1

        if domains is None:
            domains = self.domains

        # Failure: some variable has no value left.
        if any(len(values) == 0 for values in domains.values()):
            return []

        # Success: every variable is fixed to a single value.
        if all(len(values) == 1 for values in domains.values()):
            solution = {v: values[0] for v, values in domains.items()}
            # Propagation only starts from variables assigned while branching,
            # so a constraint between two variables that were already
            # singletons on entry has never been evaluated.  Verify before
            # accepting the assignment.
            if not self._satisfies_constraints(solution):
                return []
            return [solution]

        # Branch on the most constrained variable (smallest domain with more
        # than one value); ties are broken by variable order.
        candidates = [v for v in domains if len(domains[v]) > 1]
        var = min(candidates, key=lambda v: (len(domains[v]), v))

        sols = []

        for val in domains[var]:
            # Shallow copy is enough: domain lists are replaced, never mutated.
            newdomains = dict(domains)
            newdomains[var] = [val]

            # Propagate through the arcs that point at the variable just fixed.
            edges = list(self.in_edges.get(var, []))
            while edges:
                (vj, vi) = edges.pop()
                constraint = self.constraints[(vj, vi)]

                # Keep the values of vj supported by at least one value of vi.
                newdomain = [xj for xj in newdomains[vj]
                             if any(constraint(vj, xj, vi, xi) for xi in newdomains[vi])]

                # A shrunken domain may invalidate values of its neighbours:
                # revisit the arcs pointing at vj.
                if len(newdomain) < len(newdomains[vj]):
                    newdomains[vj] = newdomain
                    edges.extend(self.in_edges.get(vj, []))

            sols += self.search_all(newdomains)

        return sols

class MyBN(BayesNet):
    """Bayesian network extended with a structural query over the graph.

    ``self.dependencies`` is read but not defined in this class (presumably
    provided by ``BayesNet`` -- confirm).  It is used as a mapping
    variable -> table whose keys are iterables of ``(mother, value)`` pairs.
    """

    def __init__(self):
        BayesNet.__init__(self)

    def independence_bag(self, v1, v2):
        """Return the set of variables linking *v1* and *v2* through common ancestors.

        The set is built in two steps:

        1. Collect every variable lying on an upward path from *v1* or *v2*
           to one of their closest common ancestors.  A variable counts as
           its own ancestor here, and a common ancestor is "closest" when it
           is not an ancestor of another common ancestor.  When there is no
           common ancestor, only *v1* and *v2* are collected.
        2. Add the mothers (direct parents) of every collected variable.

        The graph is assumed to be acyclic, as in a Bayesian network; a cycle
        would make the ancestor computation recurse without end.

        Returns:
            A set of variable names.
        """
        # Mothers of each variable, read from the keys of its dependency table.
        mothers = {}
        for var, table in self.dependencies.items():
            mothers[var] = {m for conj in table.keys() for (m, _) in conj}

        # Variables that only ever appear as mothers get an empty entry too.
        for m in set().union(*mothers.values()):
            mothers.setdefault(m, set())

        # Ancestor sets are memoised: they are looked up once per pair of
        # common ancestors and once per candidate path node below.
        ancestor_cache = {}

        def ancestors(x):
            if x not in ancestor_cache:
                found = set()
                for m in mothers.get(x, ()):
                    found.add(m)
                    found |= ancestors(m)
                ancestor_cache[x] = found
            return ancestor_cache[x]

        # Common ancestors of v1 and v2, each counted as its own ancestor.
        common = (ancestors(v1) | {v1}) & (ancestors(v2) | {v2})

        # Closest common ancestors: drop any that is an ancestor of another.
        closest = {
            a for a in common
            if not any(a != b and a in ancestors(b) for b in common)
        }

        def nodes_on_path(start, goal):
            # A node lies on an upward path start -> goal exactly when it is
            # start or one of its ancestors, and it is goal or has goal among
            # its own ancestors.  The result is empty when goal cannot be
            # reached from start.
            return {
                n for n in ancestors(start) | {start}
                if n == goal or goal in ancestors(n)
            }

        # Path variables from v1 / v2 up to each closest common ancestor.
        if not closest:
            path_vars = {v1, v2}
        else:
            path_vars = set()
            for com_an in closest:
                path_vars |= nodes_on_path(v1, com_an)
                path_vars |= nodes_on_path(v2, com_an)

        # Add the mothers of everything found.
        bag = set(path_vars)
        for var in path_vars:
            bag |= mothers.get(var, set())

        return bag

