|
"""Implement lattice interface.""" |
|
|
|
from overrides import overrides |
|
import numpy as np |
|
from abc import ABCMeta, abstractmethod |
|
|
|
|
|
class Lattice(metaclass=ABCMeta): |
|
|
|
"""Abstract base class representing a complemented lattice.""" |
|
|
|
@classmethod |
|
@abstractmethod |
|
def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: |
|
return NotImplemented |
|
|
|
@classmethod |
|
@abstractmethod |
|
def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: |
|
return NotImplemented |
|
|
|
@classmethod |
|
@abstractmethod |
|
def join_reduce(cls, probs: np.ndarray) -> np.ndarray: |
|
return NotImplemented |
|
|
|
@classmethod |
|
@abstractmethod |
|
def meet_reduce(cls, probs: np.ndarray) -> np.ndarray: |
|
return NotImplemented |
|
|
|
|
|
class Product(Lattice): |
|
"""Lattice where meet=prod and sum is defined accordingly. |
|
|
|
Equivalent to assuming independence, more or less. |
|
""" |
|
|
|
eps = 1e-9 |
|
|
|
@classmethod |
|
@overrides |
|
def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: |
|
return probs1 + probs2 - cls.meet(probs1, probs2) |
|
|
|
@classmethod |
|
@overrides |
|
def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray: |
|
return probs1 * probs2 |
|
|
|
@classmethod |
|
@overrides |
|
def join_reduce(cls, probs: np.ndarray) -> np.ndarray: |
|
"""Assumes disjoint events.""" |
|
|
|
return np.sum(probs, axis=-1) |
|
|
|
@classmethod |
|
@overrides |
|
def meet_reduce(cls, probs: np.ndarray) -> np.ndarray: |
|
return np.prod(probs, axis=-1) |
|
|
|
@classmethod |
|
def comp(cls, probs): |
|
return 1 - probs |
|
|
|
@classmethod |
|
def normalize(cls, probs): |
|
"""Normalize a distribution by dividing by the total mass.""" |
|
return probs / np.sum(probs + cls.eps, axis=-1) |
|
|