

cosy

T = TypeVar('T', bound=Hashable) module-attribute

__all__ = ['DSL', 'Literal', 'Var', 'Subtypes', 'Type', 'Omega', 'Constructor', 'Arrow', 'Intersection', 'Synthesizer', 'SolutionSpace'] module-attribute

Arrow dataclass

Source code in src/cosy/types.py
@dataclass(frozen=True)
class Arrow(Type):
    source: Type = field(init=True)
    target: Type = field(init=True)
    is_omega: bool = field(init=False, compare=False)
    size: int = field(init=False, compare=False)
    organized: set[Type] = field(init=False, compare=False)
    free_vars: set[str] = field(init=False, compare=False)

    def __post_init__(self) -> None:
        super().__init__(
            is_omega=self._is_omega(),
            size=self._size(),
            organized=self._organized(),
            free_vars=self._free_vars(),
        )

    def _is_omega(self) -> bool:
        return self.target.is_omega

    def _size(self) -> int:
        return 1 + self.source.size + self.target.size

    def _organized(self) -> set[Type]:
        if len(self.target.organized) == 0:
            return set()
        if len(self.target.organized) == 1:
            return {self}
        return {Arrow(self.source, tp) for tp in self.target.organized}

    def _free_vars(self) -> set[str]:
        return set.union(self.source.free_vars, self.target.free_vars)

    def __str__(self) -> str:
        return f"{self.source} -> {self.target}"

    def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
        if not any(var in substitution for var in self.free_vars):
            return self
        return Arrow(
            self.source.subst(groups, substitution),
            self.target.subst(groups, substitution),
        )

free_vars: set[str] = field(init=False, compare=False) class-attribute instance-attribute

is_omega: bool = field(init=False, compare=False) class-attribute instance-attribute

organized: set[Type] = field(init=False, compare=False) class-attribute instance-attribute

size: int = field(init=False, compare=False) class-attribute instance-attribute

source: Type = field(init=True) class-attribute instance-attribute

target: Type = field(init=True) class-attribute instance-attribute

__init__(source: Type, target: Type, *, is_omega: bool, size: int, organized: set[Type], free_vars: set[str]) -> None

__post_init__() -> None

Source code in src/cosy/types.py
def __post_init__(self) -> None:
    super().__init__(
        is_omega=self._is_omega(),
        size=self._size(),
        organized=self._organized(),
        free_vars=self._free_vars(),
    )

__str__() -> str

Source code in src/cosy/types.py
def __str__(self) -> str:
    return f"{self.source} -> {self.target}"

subst(groups: Mapping[str, str], substitution: dict[str, Any]) -> Type

Source code in src/cosy/types.py
def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
    if not any(var in substitution for var in self.free_vars):
        return self
    return Arrow(
        self.source.subst(groups, substitution),
        self.target.subst(groups, substitution),
    )
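
For illustration, here is what the derived attributes look like on a small arrow type. This is a minimal sketch; it assumes `Arrow` and `Constructor` are importable from the top-level `cosy` package (both appear in `__all__` above).

from cosy import Arrow, Constructor

int_t = Constructor("Int")      # the argument defaults to Omega(), so this prints as "Int"
bool_t = Constructor("Bool")
f = Arrow(int_t, bool_t)

print(f)           # Int -> Bool
print(f.size)      # 5 = 1 (arrow) + 2 (Int) + 2 (Bool)
print(f.is_omega)  # False, because the target is not omega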

CoSy

Source code in src/cosy/__init__.py
class CoSy(Generic[T]):
    component_specifications: Mapping[T, Specification]
    parameter_space: ParameterSpace | None = None
    taxonomy: Taxonomy | None = None
    _synthesizer: Synthesizer

    def __init__(
        self,
        component_specifications: Mapping[T, Specification],
        parameter_space: ParameterSpace | None = None,
        taxonomy: Taxonomy | None = None,
    ) -> None:
        self.component_specifications = component_specifications
        self.parameter_space = parameter_space
        self.taxonomy = taxonomy if taxonomy is not None else {}
        self._synthesizer = Synthesizer(component_specifications, parameter_space, self.taxonomy)

    def solve(self, query: Type, max_count: int = 100) -> Iterable[Any]:
        """
        Solves the given query by constructing a solution space and enumerating and interpreting the resulting trees.

        :param query: The query to solve.
        :param max_count: The maximum number of trees to enumerate.
        :return: An iterable of interpreted trees.
        """
        if not isinstance(query, Type):
            msg = "Query must be of type Type"
            raise TypeError(msg)
        solution_space = self._synthesizer.construct_solution_space(query).prune()

        trees = solution_space.enumerate_trees(query, max_count=max_count)
        for tree in trees:
            yield tree.interpret()

component_specifications: Mapping[T, Specification] = component_specifications instance-attribute

parameter_space: ParameterSpace | None = parameter_space class-attribute instance-attribute

taxonomy: Taxonomy | None = taxonomy if taxonomy is not None else {} class-attribute instance-attribute

__init__(component_specifications: Mapping[T, Specification], parameter_space: ParameterSpace | None = None, taxonomy: Taxonomy | None = None) -> None

Source code in src/cosy/__init__.py
def __init__(
    self,
    component_specifications: Mapping[T, Specification],
    parameter_space: ParameterSpace | None = None,
    taxonomy: Taxonomy | None = None,
) -> None:
    self.component_specifications = component_specifications
    self.parameter_space = parameter_space
    self.taxonomy = taxonomy if taxonomy is not None else {}
    self._synthesizer = Synthesizer(component_specifications, parameter_space, self.taxonomy)

solve(query: Type, max_count: int = 100) -> Iterable[Any]

Solves the given query by constructing a solution space and enumerating and interpreting the resulting trees.

:param query: The query to solve.
:param max_count: The maximum number of trees to enumerate.
:return: An iterable of interpreted trees.

Source code in src/cosy/__init__.py
def solve(self, query: Type, max_count: int = 100) -> Iterable[Any]:
    """
    Solves the given query by constructing a solution space and enumerating and interpreting the resulting trees.

    :param query: The query to solve.
    :param max_count: The maximum number of trees to enumerate.
    :return: An iterable of interpreted trees.
    """
    if not isinstance(query, Type):
        msg = "Query must be of type Type"
        raise TypeError(msg)
    solution_space = self._synthesizer.construct_solution_space(query).prune()

    trees = solution_space.enumerate_trees(query, max_count=max_count)
    for tree in trees:
        yield tree.interpret()
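
A minimal end-to-end sketch of how `CoSy` might be driven. Several details are assumptions: that component keys may be plain callables which `Tree.interpret()` applies to their interpreted children, that `CoSy` is importable from the top-level package, and that mapping a component directly to a `Type` is a valid `Specification`. Treat it as an illustration rather than the canonical workflow.

from cosy import CoSy, Arrow, Constructor

# Hypothetical components (assumption: interpret() calls them with their interpreted children).
def zero() -> int:
    return 0

def succ(n: int) -> int:
    return n + 1

nat = Constructor("Nat")
component_specifications = {
    zero: nat,              # zero : Nat
    succ: Arrow(nat, nat),  # succ : Nat -> Nat
}

cs = CoSy(component_specifications)
for value in cs.solve(nat, max_count=5):
    print(value)            # e.g. 0, 1, 2, ... (no particular order is guaranteed)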

Constructor dataclass

Source code in src/cosy/types.py
@dataclass(frozen=True)
class Constructor(Type):
    name: str = field(init=True)
    arg: Type = field(default=Omega(), init=True)
    is_omega: bool = field(init=False, compare=False)
    size: int = field(init=False, compare=False)
    organized: set[Type] = field(init=False, compare=False)
    free_vars: set[str] = field(init=False, compare=False)

    def __post_init__(self) -> None:
        super().__init__(
            is_omega=self._is_omega(),
            size=self._size(),
            organized=self._organized(),
            free_vars=self._free_vars(),
        )

    def _is_omega(self) -> bool:
        return False

    def _size(self) -> int:
        return 1 + self.arg.size

    def _organized(self) -> set[Type]:
        if len(self.arg.organized) <= 1:
            return {self}
        return {Constructor(self.name, ap) for ap in self.arg.organized}

    def _free_vars(self) -> set[str]:
        return self.arg.free_vars

    def __str__(self) -> str:
        if self.arg == Omega():
            return str(self.name)
        return f"{self.name!s}({self.arg!s})"

    def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
        if not any(var in substitution for var in self.free_vars):
            return self
        return Constructor(self.name, self.arg.subst(groups, substitution))

arg: Type = field(default=Omega(), init=True) class-attribute instance-attribute

free_vars: set[str] = field(init=False, compare=False) class-attribute instance-attribute

is_omega: bool = field(init=False, compare=False) class-attribute instance-attribute

name: str = field(init=True) class-attribute instance-attribute

organized: set[Type] = field(init=False, compare=False) class-attribute instance-attribute

size: int = field(init=False, compare=False) class-attribute instance-attribute

__init__(name: str, arg: Type = Omega(), *, is_omega: bool, size: int, organized: set[Type], free_vars: set[str]) -> None

__post_init__() -> None

Source code in src/cosy/types.py
def __post_init__(self) -> None:
    super().__init__(
        is_omega=self._is_omega(),
        size=self._size(),
        organized=self._organized(),
        free_vars=self._free_vars(),
    )

__str__() -> str

Source code in src/cosy/types.py
def __str__(self) -> str:
    if self.arg == Omega():
        return str(self.name)
    return f"{self.name!s}({self.arg!s})"

subst(groups: Mapping[str, str], substitution: dict[str, Any]) -> Type

Source code in src/cosy/types.py
def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
    if not any(var in substitution for var in self.free_vars):
        return self
    return Constructor(self.name, self.arg.subst(groups, substitution))
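
The `__str__` and `_organized` behaviour is easiest to see on small values; a brief sketch (assuming the types are importable from `cosy`):

from cosy import Constructor, Intersection

print(Constructor("Int"))                       # Int (the default Omega() argument is hidden)
print(Constructor("List", Constructor("Int")))  # List(Int)

# Organizing distributes a constructor over an intersection argument:
nested = Constructor("List", Intersection(Constructor("Int"), Constructor("Pos")))
print({str(t) for t in nested.organized})       # {'List(Int)', 'List(Pos)'}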

DSL

A domain-specific language (DSL) to define component specifications.

This class provides an interface for defining specifications in a declarative manner. It allows users to specify the name and group of each parameter, as well as filter predicates.

Examples:

    DSL()
        .Parameter("x", int)
        .Parameter("y", int, lambda vars: vars["x"] + 1)
        .Parameter("z", str)
        .ParameterConstraint(lambda vars: len(vars["z"]) == vars["x"] + vars["y"])
        .Suffix(<Type using Var("x"), Var("y") and Var("z")>)

constructs a specification for a function with three parameters:

- `x`: an integer
- `y`: an integer, the value of which is `x` + 1
- `z`: a string, whose length is equal to `x` + `y`

The `Suffix` method specifies the function, which uses the variables `x`, `y`, and `z`.
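
Note that in the source below the builder methods are lowercase (`parameter`, `parameter_constraint`, `argument`, `constraint`, `suffix`), groups are given as strings, and `candidates` is expected to return a sequence. A sketch of the same example with those names; the group names `"int"` and `"str"` and the suffix type are illustrative assumptions that depend on the surrounding `ParameterSpace`:

from cosy import DSL, Constructor, Var

spec = (
    DSL()
    .parameter("x", "int")
    .parameter("y", "int", lambda vars: [vars["x"] + 1])
    .parameter("z", "str")
    .parameter_constraint(lambda vars: len(vars["z"]) == vars["x"] + vars["y"])
    .suffix(Constructor("f", Var("x")))  # any Type using Var("x"), Var("y") and Var("z")
)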

Source code in src/cosy/dsl.py
class DSL:
    """
    A domain-specific language (DSL) to define component specifications.

    This class provides an interface for defining specifications in a declarative manner. It allows
    users to specify the name and group of each parameter, as well as filter predicates.

    Examples:
        DSL()
            .Parameter("x", int)
            .Parameter("y", int, lambda vars: vars["x"] + 1)
            .Parameter("z", str)
            .ParameterConstraint(lambda vars: len(vars["z"]) == vars["x"] + vars["y"])
            .Suffix(<Type using Var("x"), Var("y") and Var("z")>)

        constructs a specification for a function with three parameters:
        - `x`: an integer
        - `y`: an integer, the value of which is `x` + 1
        - `z`: a string, whose length is equal to `x` + `y`
        The `Suffix` method specifies the function, which uses the variables `x`, `y`, and `z`.
    """

    def __init__(self) -> None:
        """
        Initialize the DSL object
        """

        self._result: Callable[[Specification], Specification] = lambda suffix: suffix

    def parameter(  #
        self,
        name: str,
        group: str,
        candidates: Callable[[dict[str, Any]], Sequence[Any]] | None = None,
    ) -> DSL:
        """
        Introduce a new parameter variable.

        `group` is a string, and an instance of this specification will be generated
        for each valid literal in the corresponding literal group.
        You can use this variable as Var(name) in all `Type`s, after the introduction
        and in all predicates.
        Optionally, you can specify a sequence of candidate values, that will be used to generate
        the literals. This sequence is parameterized by the values of previously
        defined literal variables. This is useful, if you want to restrict the values of a variable
        to a subset of the values in the corresponding literal group.

        :param name: The name of the new variable.
        :type name: str
        :param group: The group of the variable.
        :type group: str
        :param candidates: Parameterized sequence of candidate values, that will be used to generate the literals.
        :type candidates: Callable[[dict[str, Any]], Sequence[Any]] | None
        :return: The DSL object.
        :rtype: DSL
        """

        def new_result(suffix: Specification, result=self._result) -> Specification:
            return result(Abstraction(LiteralParameter(name, group, candidates), suffix))

        self._result = new_result
        return self

    def argument(self, name: str, specification: Type) -> DSL:
        """
        Introduce a new variable.

        `group` is a `Type`, and an instance will be generated for each tree, satisfying
        the specification given by the type. Since this can only be done in the enumeration step,
        you can only use these variables in predicates, that themselves belong to variables whose `group` is a `Type`.

        :param name: The name of the new variable.
        :type name: str
        :param specification: The type of the variable.
        :type specification: Type
        :return: The DSL object.
        :rtype: DSL
        """

        def new_result(suffix: Specification, result=self._result) -> Specification:
            return result(Abstraction(TermParameter(name, specification), suffix))

        self._result = new_result
        return self

    def parameter_constraint(self, constraint: Callable[[Mapping[str, Any]], bool]) -> DSL:
        """
        Constraint on the previously defined parameter variables.

        :param constraint: A constraint deciding, if the currently chosen parameter values are valid.
            The values of variables are passed by a dictionary, where the keys are the names of the
            parameter variables and the values are the corresponding values.
        :type constraint: Callable[[Mapping[str, Any]], bool]
        :return: The DSL object.
        :rtype: DSL
        """

        def new_result(suffix: Specification, result=self._result) -> Specification:
            return result(Implication(Predicate(constraint, True), suffix))

        self._result = new_result
        return self

    def constraint(self, constraint: Callable[[Mapping[str, Any]], bool]) -> DSL:
        """
        Constraint on the previously defined parameter variables and argument variables.

        :param constraint: A constraint deciding, if the currently chosen values are valid.
            The values of variables are passed by a dictionary, where the keys are the names of the
            variables and the values are the corresponding values.
        :type constraint: Callable[[Mapping[str, Any]], bool]
        :return: The DSL object.
        :rtype: DSL
        """

        def new_result(suffix: Specification, result=self._result) -> Specification:
            return result(Implication(Predicate(constraint, False), suffix))

        self._result = new_result
        return self

    def suffix(self, suffix: Type) -> Specification:
        """
        Constructs the final specification wrapping the given `Type` `suffix`.

        :param suffix: The wrapped type.
        :type suffix: Type
        :return: The constructed specification.
        :rtype: Abstraction | Type
        """
        return self._result(suffix)

__init__() -> None

Initialize the DSL object

Source code in src/cosy/dsl.py
def __init__(self) -> None:
    """
    Initialize the DSL object
    """

    self._result: Callable[[Specification], Specification] = lambda suffix: suffix

argument(name: str, specification: Type) -> DSL

Introduce a new variable.

`specification` is a `Type`, and an instance will be generated for each tree satisfying the specification given by that type. Since trees are only available in the enumeration step, these variables can only be used in predicates that are evaluated during enumeration, i.e. those added with `constraint`, not with `parameter_constraint`.

:param name: The name of the new variable.
:type name: str
:param specification: The type of the variable.
:type specification: Type
:return: The DSL object.
:rtype: DSL

Source code in src/cosy/dsl.py
def argument(self, name: str, specification: Type) -> DSL:
    """
    Introduce a new variable.

    `group` is a `Type`, and an instance will be generated for each tree, satisfying
    the specification given by the type. Since this can only be done in the enumeration step,
    you can only use these variables in predicates, that themselves belong to variables whose `group` is a `Type`.

    :param name: The name of the new variable.
    :type name: str
    :param specification: The type of the variable.
    :type specification: Type
    :return: The DSL object.
    :rtype: DSL
    """

    def new_result(suffix: Specification, result=self._result) -> Specification:
        return result(Abstraction(TermParameter(name, specification), suffix))

    self._result = new_result
    return self

constraint(constraint: Callable[[Mapping[str, Any]], bool]) -> DSL

Constraint on the previously defined parameter variables and argument variables.

:param constraint: A constraint deciding whether the currently chosen values are valid.
    The values of variables are passed by a dictionary, where the keys are the names of the
    variables and the values are the corresponding values.
:type constraint: Callable[[Mapping[str, Any]], bool]
:return: The DSL object.
:rtype: DSL

Source code in src/cosy/dsl.py
def constraint(self, constraint: Callable[[Mapping[str, Any]], bool]) -> DSL:
    """
    Constraint on the previously defined parameter variables and argument variables.

    :param constraint: A constraint deciding, if the currently chosen values are valid.
        The values of variables are passed by a dictionary, where the keys are the names of the
        variables and the values are the corresponding values.
    :type constraint: Callable[[Mapping[str, Any]], bool]
    :return: The DSL object.
    :rtype: DSL
    """

    def new_result(suffix: Specification, result=self._result) -> Specification:
        return result(Implication(Predicate(constraint, False), suffix))

    self._result = new_result
    return self

parameter(name: str, group: str, candidates: Callable[[dict[str, Any]], Sequence[Any]] | None = None) -> DSL

Introduce a new parameter variable.

`group` is a string, and an instance of this specification will be generated for each valid literal in the corresponding literal group. You can use this variable as `Var(name)` in all `Type`s after its introduction and in all predicates. Optionally, you can specify a sequence of candidate values that will be used to generate the literals. This sequence is parameterized by the values of previously defined literal variables, which is useful if you want to restrict the values of a variable to a subset of the values in the corresponding literal group.

:param name: The name of the new variable.
:type name: str
:param group: The group of the variable.
:type group: str
:param candidates: Parameterized sequence of candidate values that will be used to generate the literals.
:type candidates: Callable[[dict[str, Any]], Sequence[Any]] | None
:return: The DSL object.
:rtype: DSL

Source code in src/cosy/dsl.py
def parameter(  #
    self,
    name: str,
    group: str,
    candidates: Callable[[dict[str, Any]], Sequence[Any]] | None = None,
) -> DSL:
    """
    Introduce a new parameter variable.

    `group` is a string, and an instance of this specification will be generated
    for each valid literal in the corresponding literal group.
    You can use this variable as Var(name) in all `Type`s, after the introduction
    and in all predicates.
    Optionally, you can specify a sequence of candidate values, that will be used to generate
    the literals. This sequence is parameterized by the values of previously
    defined literal variables. This is useful, if you want to restrict the values of a variable
    to a subset of the values in the corresponding literal group.

    :param name: The name of the new variable.
    :type name: str
    :param group: The group of the variable.
    :type group: str
    :param candidates: Parameterized sequence of candidate values, that will be used to generate the literals.
    :type candidates: Callable[[dict[str, Any]], Sequence[Any]] | None
    :return: The DSL object.
    :rtype: DSL
    """

    def new_result(suffix: Specification, result=self._result) -> Specification:
        return result(Abstraction(LiteralParameter(name, group, candidates), suffix))

    self._result = new_result
    return self
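
For example, `candidates` can restrict a later parameter to values computed from earlier ones. A hedged sketch; the group name `"int"` and its literal pool are assumptions about the surrounding `ParameterSpace`:

from cosy import DSL, Constructor, Intersection, Var

# "n" ranges over the whole "int" group; "m" is restricted to the divisors of n.
spec = (
    DSL()
    .parameter("n", "int")
    .parameter("m", "int", lambda vars: [d for d in range(1, vars["n"] + 1) if vars["n"] % d == 0])
    .suffix(Constructor("table", Intersection(Var("n"), Var("m"))))
)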

parameter_constraint(constraint: Callable[[Mapping[str, Any]], bool]) -> DSL

Constraint on the previously defined parameter variables.

:param constraint: A constraint deciding whether the currently chosen parameter values are valid.
    The values of variables are passed by a dictionary, where the keys are the names of the
    parameter variables and the values are the corresponding values.
:type constraint: Callable[[Mapping[str, Any]], bool]
:return: The DSL object.
:rtype: DSL

Source code in src/cosy/dsl.py
def parameter_constraint(self, constraint: Callable[[Mapping[str, Any]], bool]) -> DSL:
    """
    Constraint on the previously defined parameter variables.

    :param constraint: A constraint deciding, if the currently chosen parameter values are valid.
        The values of variables are passed by a dictionary, where the keys are the names of the
        parameter variables and the values are the corresponding values.
    :type constraint: Callable[[Mapping[str, Any]], bool]
    :return: The DSL object.
    :rtype: DSL
    """

    def new_result(suffix: Specification, result=self._result) -> Specification:
        return result(Implication(Predicate(constraint, True), suffix))

    self._result = new_result
    return self

suffix(suffix: Type) -> Specification

Constructs the final specification wrapping the given Type suffix.

:param suffix: The wrapped type.
:type suffix: Type
:return: The constructed specification.
:rtype: Abstraction | Type

Source code in src/cosy/dsl.py
def suffix(self, suffix: Type) -> Specification:
    """
    Constructs the final specification wrapping the given `Type` `suffix`.

    :param suffix: The wrapped type.
    :type suffix: Type
    :return: The constructed specification.
    :rtype: Abstraction | Type
    """
    return self._result(suffix)

Intersection dataclass

Source code in src/cosy/types.py
@dataclass(frozen=True)
class Intersection(Type):
    left: Type = field(init=True)
    right: Type = field(init=True)
    is_omega: bool = field(init=False, compare=False)
    size: int = field(init=False, compare=False)
    organized: set[Type] = field(init=False, compare=False)
    free_vars: set[str] = field(init=False, compare=False)

    def __post_init__(self) -> None:
        super().__init__(
            is_omega=self._is_omega(),
            size=self._size(),
            organized=self._organized(),
            free_vars=self._free_vars(),
        )

    def _is_omega(self) -> bool:
        return self.left.is_omega and self.right.is_omega

    def _size(self) -> int:
        return 1 + self.left.size + self.right.size

    def _organized(self) -> set[Type]:
        return set.union(self.left.organized, self.right.organized)

    def _free_vars(self) -> set[str]:
        return set.union(self.left.free_vars, self.right.free_vars)

    def __str__(self) -> str:
        return f"{self.left} & {self.right}"

    def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
        if not any(var in substitution for var in self.free_vars):
            return self
        return Intersection(
            self.left.subst(groups, substitution),
            self.right.subst(groups, substitution),
        )

free_vars: set[str] = field(init=False, compare=False) class-attribute instance-attribute

is_omega: bool = field(init=False, compare=False) class-attribute instance-attribute

left: Type = field(init=True) class-attribute instance-attribute

organized: set[Type] = field(init=False, compare=False) class-attribute instance-attribute

right: Type = field(init=True) class-attribute instance-attribute

size: int = field(init=False, compare=False) class-attribute instance-attribute

__init__(left: Type, right: Type, *, is_omega: bool, size: int, organized: set[Type], free_vars: set[str]) -> None

__post_init__() -> None

Source code in src/cosy/types.py
def __post_init__(self) -> None:
    super().__init__(
        is_omega=self._is_omega(),
        size=self._size(),
        organized=self._organized(),
        free_vars=self._free_vars(),
    )

__str__() -> str

Source code in src/cosy/types.py
def __str__(self) -> str:
    return f"{self.left} & {self.right}"

subst(groups: Mapping[str, str], substitution: dict[str, Any]) -> Type

Source code in src/cosy/types.py
def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
    if not any(var in substitution for var in self.free_vars):
        return self
    return Intersection(
        self.left.subst(groups, substitution),
        self.right.subst(groups, substitution),
    )
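
A small illustration (imports from `cosy` assumed): intersections print with `&`, are omega only if both sides are, and organize into the union of their parts.

from cosy import Constructor, Intersection, Omega

t = Intersection(Constructor("Serializable"), Constructor("Comparable"))
print(t)                                        # Serializable & Comparable
print(t.is_omega)                               # False
print(Intersection(Omega(), Omega()).is_omega)  # True
print({str(p) for p in t.organized})            # {'Serializable', 'Comparable'}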

Literal dataclass

Source code in src/cosy/types.py
@dataclass(frozen=True)
class Literal(Type):
    value: Any  # has to be Hashable
    group: str
    is_omega: bool = field(init=False, compare=False)
    size: int = field(init=False, compare=False)
    organized: set[Type] = field(init=False, compare=False)
    free_vars: set[str] = field(init=False, compare=False)

    def __post_init__(self) -> None:
        super().__init__(
            is_omega=self._is_omega(),
            size=self._size(),
            organized=self._organized(),
            free_vars=self._free_vars(),
        )

    def _is_omega(self) -> bool:
        return False

    def _size(self) -> int:
        return 1

    def _organized(self) -> set[Type]:
        return {self}

    def _free_vars(self) -> set[str]:
        return set()

    def __str__(self) -> str:
        return f"[{self.value!s}, {self.group}]"

    def subst(self, _groups: Mapping[str, str], _substitution: dict[str, Any]) -> Type:
        return self

free_vars: set[str] = field(init=False, compare=False) class-attribute instance-attribute

group: str instance-attribute

is_omega: bool = field(init=False, compare=False) class-attribute instance-attribute

organized: set[Type] = field(init=False, compare=False) class-attribute instance-attribute

size: int = field(init=False, compare=False) class-attribute instance-attribute

value: Any instance-attribute

__init__(value: Any, group: str, *, is_omega: bool, size: int, organized: set[Type], free_vars: set[str]) -> None

__post_init__() -> None

Source code in src/cosy/types.py
def __post_init__(self) -> None:
    super().__init__(
        is_omega=self._is_omega(),
        size=self._size(),
        organized=self._organized(),
        free_vars=self._free_vars(),
    )

__str__() -> str

Source code in src/cosy/types.py
def __str__(self) -> str:
    return f"[{self.value!s}, {self.group}]"

subst(_groups: Mapping[str, str], _substitution: dict[str, Any]) -> Type

Source code in src/cosy/types.py
def subst(self, _groups: Mapping[str, str], _substitution: dict[str, Any]) -> Type:
    return self
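
Literals pair a (hashable) value with the name of its group; a quick sketch (import path assumed):

from cosy import Literal

three = Literal(3, "int")
print(three)            # [3, int]
print(three.size)       # 1
print(three.free_vars)  # set()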

Omega dataclass

Source code in src/cosy/types.py
@dataclass(frozen=True)
class Omega(Type):
    is_omega: bool = field(init=False, compare=False)
    size: int = field(init=False, compare=False)
    organized: set[Type] = field(init=False, compare=False)
    free_vars: set[str] = field(init=False, compare=False)

    def __post_init__(self) -> None:
        super().__init__(
            is_omega=self._is_omega(),
            size=self._size(),
            organized=self._organized(),
            free_vars=self._free_vars(),
        )

    def _is_omega(self) -> bool:
        return True

    def _size(self) -> int:
        return 1

    def _organized(self) -> set[Type]:
        return set()

    def __str__(self) -> str:
        return "omega"

    def _free_vars(self) -> set[str]:
        return set()

    def subst(self, _groups: Mapping[str, str], _substitution: dict[str, Any]) -> Type:
        return self

free_vars: set[str] = field(init=False, compare=False) class-attribute instance-attribute

is_omega: bool = field(init=False, compare=False) class-attribute instance-attribute

organized: set[Type] = field(init=False, compare=False) class-attribute instance-attribute

size: int = field(init=False, compare=False) class-attribute instance-attribute

__init__(*, is_omega: bool, size: int, organized: set[Type], free_vars: set[str]) -> None

__post_init__() -> None

Source code in src/cosy/types.py
def __post_init__(self) -> None:
    super().__init__(
        is_omega=self._is_omega(),
        size=self._size(),
        organized=self._organized(),
        free_vars=self._free_vars(),
    )

__str__() -> str

Source code in src/cosy/types.py
def __str__(self) -> str:
    return "omega"

subst(_groups: Mapping[str, str], _substitution: dict[str, Any]) -> Type

Source code in src/cosy/types.py
def subst(self, _groups: Mapping[str, str], _substitution: dict[str, Any]) -> Type:
    return self

SolutionSpace

Source code in src/cosy/solution_space.py
class SolutionSpace(Generic[NT, T, G]):
    _rules: defaultdict[NT, deque[RHSRule[NT, T, G]]]

    def __init__(self, rules: dict[NT, deque[RHSRule[NT, T, G]]] | None = None) -> None:
        if rules is None:
            rules = defaultdict(deque)
        self._rules = defaultdict(deque, rules)

    def get(self, nonterminal: NT) -> deque[RHSRule[NT, T, G]] | None:
        return self._rules.get(nonterminal)

    def __getitem__(self, nonterminal: NT) -> deque[RHSRule[NT, T, G]]:
        return self._rules[nonterminal]

    def nonterminals(self) -> Iterable[NT]:
        return self._rules.keys()

    def as_tuples(self) -> Iterable[tuple[NT, deque[RHSRule[NT, T, G]]]]:
        return self._rules.items()

    def add_rule(
        self,
        nonterminal: NT,
        terminal: T,
        arguments: tuple[Argument, ...],
        predicates: tuple[Callable[[dict[str, Any]], bool], ...],
    ) -> None:
        self._rules[nonterminal].append(RHSRule(arguments, predicates, terminal))

    def show(self) -> str:
        return "\n".join(
            f"{nt!s} ~> {' | '.join([str(subrule) for subrule in rule])}" for nt, rule in self._rules.items()
        )

    def prune(self) -> SolutionSpace[NT, T, G]:
        """Keep only productive rules."""

        ground_types: set[NT] = set()
        queue: set[NT] = set()
        inverse_grammar: dict[NT, set[tuple[NT, frozenset[NT]]]] = defaultdict(set)

        for n, exprs in self._rules.items():
            for expr in exprs:
                non_terminals = expr.non_terminals
                for m in non_terminals:
                    inverse_grammar[m].add((n, non_terminals))
                if not non_terminals:
                    queue.add(n)

        while queue:
            n = queue.pop()
            if n not in ground_types:
                ground_types.add(n)
                for m, non_terminals in inverse_grammar[n]:
                    if m not in ground_types and all(t in ground_types for t in non_terminals):
                        queue.add(m)

        return SolutionSpace[NT, T, G](
            defaultdict(
                deque,
                {
                    target: deque(
                        possibility
                        for possibility in self._rules[target]
                        if all(t in ground_types for t in possibility.non_terminals)
                    )
                    for target in ground_types
                },
            )
        )

    def _enumerate_tree_vectors(
        self,
        non_terminals: Sequence[NT | None],
        existing_terms: Mapping[NT, set[Tree[T]]],
        nt_term: tuple[NT, Tree[T]] | None = None,
    ) -> Iterable[tuple[Tree[T] | None, ...]]:
        """Enumerate possible term vectors for a given list of non-terminals and existing terms. Use nt_term at least once (if given)."""
        if nt_term is None:
            yield from product(*([n] if n is None else existing_terms[n] for n in non_terminals))
        else:
            nt, term = nt_term
            for i, n in enumerate(non_terminals):
                if n == nt:
                    arg_lists: Iterable[Iterable[Tree[T] | None]] = (
                        [None] if m is None else [term] if i == j else existing_terms[m]
                        for j, m in enumerate(non_terminals)
                    )
                    yield from product(*arg_lists)

    def _generate_new_trees(
        self,
        rule: RHSRule[NT, T, G],
        existing_terms: Mapping[NT, set[Tree[T]]],
        max_count: int | None = None,
        nt_old_term: tuple[NT, Tree[T]] | None = None,
    ) -> set[Tree[T]]:
        # Generate new terms for rule `rule` from existing terms, up to `max_count`.
        # The term `old_term` should be a subterm of all resulting terms, at a position that corresponds to `nt`.

        output_set: set[Tree[T]] = set()
        if max_count == 0:
            return output_set

        named_non_terminals = [
            a.origin if isinstance(a, NonTerminalArgument) and a.name is not None else None for a in rule.arguments
        ]
        unnamed_non_terminals = [
            a.origin if isinstance(a, NonTerminalArgument) and a.name is None else None for a in rule.arguments
        ]
        literal_arguments = [Tree(a.value, ()) if isinstance(a, ConstantArgument) else None for a in rule.arguments]

        def interleave(
            parameters: Sequence[Tree[T] | None],
            literal_arguments: Sequence[Tree[T] | None],
            arguments: Sequence[Tree[T] | None],
        ) -> Iterable[Tree[T]]:
            """Interleave parameters, literal arguments and arguments."""
            for parameter, literal_argument, argument in zip(parameters, literal_arguments, arguments, strict=True):
                if parameter is not None:
                    yield parameter
                elif literal_argument is not None:
                    yield literal_argument
                elif argument is not None:
                    yield argument
                else:
                    msg = "All arguments of interleave are None"
                    raise ValueError(msg)

        def construct_tree(
            rule: RHSRule[NT, T, G],
            parameters: Sequence[Tree[T] | None],
            literal_arguments: Sequence[Tree[T] | None],
            arguments: Sequence[Tree[T] | None],
        ) -> Tree[T]:
            """Construct a new tree from the rule and the given specific arguments."""
            return Tree(
                rule.terminal,
                tuple(interleave(parameters, literal_arguments, arguments)),
            )

        def specific_substitution(parameters):
            return {
                a.name: p
                for p, a in zip(parameters, rule.arguments, strict=True)
                if isinstance(a, NonTerminalArgument) and a.name is not None
            } | rule.literal_substitution

        def valid_parameters(
            nt_term: tuple[NT, Tree[T]] | None,
        ) -> Iterable[tuple[Tree[T] | None, ...]]:
            """Enumerate all valid parameters for the rule."""
            for parameters in self._enumerate_tree_vectors(named_non_terminals, existing_terms, nt_term):
                substitution = specific_substitution(parameters)
                if all(predicate(substitution) for predicate in rule.predicates):
                    yield parameters

        for parameters in valid_parameters(nt_old_term):
            for arguments in self._enumerate_tree_vectors(unnamed_non_terminals, existing_terms):
                output_set.add(construct_tree(rule, parameters, literal_arguments, arguments))
                if max_count is not None and len(output_set) >= max_count:
                    return output_set

        if nt_old_term is not None:
            all_parameters: deque[tuple[Tree[T] | None, ...]] | None = None
            for arguments in self._enumerate_tree_vectors(unnamed_non_terminals, existing_terms):
                all_parameters = all_parameters if all_parameters is not None else deque(valid_parameters(None))
                for parameters in all_parameters:
                    output_set.add(construct_tree(rule, parameters, literal_arguments, arguments))
                    if max_count is not None and len(output_set) >= max_count:
                        return output_set
        return output_set

    def enumerate_trees(
        self,
        start: NT,
        max_count: int | None = None,
        max_bucket_size: int | None = None,
    ) -> Iterable[Tree[T]]:
        """
        Enumerate terms as an iterator efficiently - all terms are enumerated, no guaranteed term order.
        """
        if start not in self.nonterminals():
            return

        queues: dict[NT, PriorityQueue[Tree[T]]] = {n: PriorityQueue() for n in self.nonterminals()}
        existing_terms: dict[NT, set[Tree[T]]] = {n: set() for n in self.nonterminals()}
        inverse_grammar: dict[NT, deque[tuple[NT, RHSRule[NT, T, G]]]] = {n: deque() for n in self.nonterminals()}
        all_results: set[Tree[T]] = set()

        for n, exprs in self._rules.items():
            for expr in exprs:
                if all(m in self.nonterminals() for m in expr.non_terminals):
                    for m in expr.non_terminals:
                        inverse_grammar[m].append((n, expr))
                    for new_term in self._generate_new_trees(expr, existing_terms):
                        queues[n].put(new_term)
                        if n == start and new_term not in all_results:
                            if max_count is not None and len(all_results) >= max_count:
                                return
                            yield new_term
                            all_results.add(new_term)

        current_bucket_size = 1

        while (max_bucket_size is None or current_bucket_size <= max_bucket_size) and any(
            not queue.empty() for queue in queues.values()
        ):
            non_terminals = {n for n in self.nonterminals() if not queues[n].empty()}

            while non_terminals:
                n = non_terminals.pop()
                results = existing_terms[n]
                while len(results) < current_bucket_size and not queues[n].empty():
                    term = queues[n].get()
                    if term in results:
                        continue
                    results.add(term)
                    for m, expr in inverse_grammar[n]:
                        if len(existing_terms[m]) < current_bucket_size:
                            non_terminals.add(m)
                        if m == start:
                            for new_term in self._generate_new_trees(expr, existing_terms, max_count, (n, term)):
                                if new_term not in all_results:
                                    if max_count is not None and len(all_results) >= max_count:
                                        return
                                    yield new_term
                                    all_results.add(new_term)
                                    queues[start].put(new_term)
                        else:
                            for new_term in self._generate_new_trees(expr, existing_terms, max_bucket_size, (n, term)):
                                queues[m].put(new_term)
            current_bucket_size += 1
        return

    def contains_tree(self, start: NT, tree: Tree[T]) -> bool:
        """Check if the solution space contains a given `tree` derivable from `start`."""
        if start not in self.nonterminals():
            return False

        stack: deque[tuple | Callable] = deque([(start, tree)])
        results: deque[bool] = deque()

        def get_inputs(count: int) -> Generator[bool]:
            for _ in range(count):
                yield results.pop()
            return

        while stack:
            task = stack.pop()
            if isinstance(task, tuple):
                nt, tree = task
                relevant_rhss = [
                    rhs
                    for rhs in self._rules[nt]
                    if len(rhs.arguments) == len(tree.children)
                    and rhs.terminal == tree.root
                    and all(
                        argument.value == child.root and len(child.children) == 0
                        for argument, child in zip(rhs.arguments, tree.children, strict=True)
                        if isinstance(argument, ConstantArgument)
                    )
                ]

                # if there is a relevant rule containing only ConstantArguments that are equal to the children of the tree
                if any(
                    all(isinstance(argument, ConstantArgument) for argument in rhs.arguments) for rhs in relevant_rhss
                ):
                    results.append(True)
                    continue

                # disjunction of the results for individual rules
                def or_inputs(count: int = len(relevant_rhss)) -> None:
                    results.append(any(get_inputs(count)))

                stack.append(or_inputs)

                for rhs in relevant_rhss:
                    substitution = {
                        argument.name: child.root if isinstance(argument, ConstantArgument) else child
                        for argument, child in zip(rhs.arguments, tree.children, strict=True)
                        if argument.name is not None
                    }

                    # conjunction of the results for individual arguments in the rule
                    def and_inputs(
                        count: int = sum(1 for argument in rhs.arguments if isinstance(argument, NonTerminalArgument)),
                        substitution: dict[str, Any] = substitution,
                        predicates=rhs.predicates,
                    ) -> None:
                        results.append(
                            all(get_inputs(count)) and all(predicate(substitution) for predicate in predicates)
                        )

                    stack.append(and_inputs)
                    for argument, child in zip(rhs.arguments, tree.children, strict=True):
                        if isinstance(argument, NonTerminalArgument):
                            stack.append((argument.origin, child))
            elif isinstance(task, FunctionType):
                # task is a function to execute
                task()
        return results.pop()
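
A minimal sketch of driving a `SolutionSpace` by hand with argument-free rules. Real solution spaces are normally produced by `Synthesizer.construct_solution_space`; the string non-terminals and terminals here are illustrative assumptions, and the exact `show()` formatting may differ.

from cosy import SolutionSpace

space: SolutionSpace[str, str, str] = SolutionSpace()
space.add_rule("S", "hello", arguments=(), predicates=())
space.add_rule("S", "world", arguments=(), predicates=())

pruned = space.prune()  # both rules are productive, so nothing is dropped

trees = list(pruned.enumerate_trees("S", max_count=10))
print(sorted(t.root for t in trees))        # ['hello', 'world']
print(pruned.contains_tree("S", trees[0]))  # True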

__getitem__(nonterminal: NT) -> deque[RHSRule[NT, T, G]]

Source code in src/cosy/solution_space.py
def __getitem__(self, nonterminal: NT) -> deque[RHSRule[NT, T, G]]:
    return self._rules[nonterminal]

__init__(rules: dict[NT, deque[RHSRule[NT, T, G]]] | None = None) -> None

Source code in src/cosy/solution_space.py
def __init__(self, rules: dict[NT, deque[RHSRule[NT, T, G]]] | None = None) -> None:
    if rules is None:
        rules = defaultdict(deque)
    self._rules = defaultdict(deque, rules)

add_rule(nonterminal: NT, terminal: T, arguments: tuple[Argument, ...], predicates: tuple[Callable[[dict[str, Any]], bool], ...]) -> None

Source code in src/cosy/solution_space.py
def add_rule(
    self,
    nonterminal: NT,
    terminal: T,
    arguments: tuple[Argument, ...],
    predicates: tuple[Callable[[dict[str, Any]], bool], ...],
) -> None:
    self._rules[nonterminal].append(RHSRule(arguments, predicates, terminal))

as_tuples() -> Iterable[tuple[NT, deque[RHSRule[NT, T, G]]]]

Source code in src/cosy/solution_space.py
def as_tuples(self) -> Iterable[tuple[NT, deque[RHSRule[NT, T, G]]]]:
    return self._rules.items()

contains_tree(start: NT, tree: Tree[T]) -> bool

Check if the solution space contains a given tree derivable from start.

Source code in src/cosy/solution_space.py
def contains_tree(self, start: NT, tree: Tree[T]) -> bool:
    """Check if the solution space contains a given `tree` derivable from `start`."""
    if start not in self.nonterminals():
        return False

    stack: deque[tuple | Callable] = deque([(start, tree)])
    results: deque[bool] = deque()

    def get_inputs(count: int) -> Generator[bool]:
        for _ in range(count):
            yield results.pop()
        return

    while stack:
        task = stack.pop()
        if isinstance(task, tuple):
            nt, tree = task
            relevant_rhss = [
                rhs
                for rhs in self._rules[nt]
                if len(rhs.arguments) == len(tree.children)
                and rhs.terminal == tree.root
                and all(
                    argument.value == child.root and len(child.children) == 0
                    for argument, child in zip(rhs.arguments, tree.children, strict=True)
                    if isinstance(argument, ConstantArgument)
                )
            ]

            # if there is a relevant rule containing only ConstantArguments that are equal to the children of the tree
            if any(
                all(isinstance(argument, ConstantArgument) for argument in rhs.arguments) for rhs in relevant_rhss
            ):
                results.append(True)
                continue

            # disjunction of the results for individual rules
            def or_inputs(count: int = len(relevant_rhss)) -> None:
                results.append(any(get_inputs(count)))

            stack.append(or_inputs)

            for rhs in relevant_rhss:
                substitution = {
                    argument.name: child.root if isinstance(argument, ConstantArgument) else child
                    for argument, child in zip(rhs.arguments, tree.children, strict=True)
                    if argument.name is not None
                }

                # conjunction of the results for individual arguments in the rule
                def and_inputs(
                    count: int = sum(1 for argument in rhs.arguments if isinstance(argument, NonTerminalArgument)),
                    substitution: dict[str, Any] = substitution,
                    predicates=rhs.predicates,
                ) -> None:
                    results.append(
                        all(get_inputs(count)) and all(predicate(substitution) for predicate in predicates)
                    )

                stack.append(and_inputs)
                for argument, child in zip(rhs.arguments, tree.children, strict=True):
                    if isinstance(argument, NonTerminalArgument):
                        stack.append((argument.origin, child))
        elif isinstance(task, FunctionType):
            # task is a function to execute
            task()
    return results.pop()

enumerate_trees(start: NT, max_count: int | None = None, max_bucket_size: int | None = None) -> Iterable[Tree[T]]

Enumerate terms as an iterator efficiently - all terms are enumerated, no guaranteed term order.

Source code in src/cosy/solution_space.py
def enumerate_trees(
    self,
    start: NT,
    max_count: int | None = None,
    max_bucket_size: int | None = None,
) -> Iterable[Tree[T]]:
    """
    Enumerate terms as an iterator efficiently - all terms are enumerated, no guaranteed term order.
    """
    if start not in self.nonterminals():
        return

    queues: dict[NT, PriorityQueue[Tree[T]]] = {n: PriorityQueue() for n in self.nonterminals()}
    existing_terms: dict[NT, set[Tree[T]]] = {n: set() for n in self.nonterminals()}
    inverse_grammar: dict[NT, deque[tuple[NT, RHSRule[NT, T, G]]]] = {n: deque() for n in self.nonterminals()}
    all_results: set[Tree[T]] = set()

    for n, exprs in self._rules.items():
        for expr in exprs:
            if all(m in self.nonterminals() for m in expr.non_terminals):
                for m in expr.non_terminals:
                    inverse_grammar[m].append((n, expr))
                for new_term in self._generate_new_trees(expr, existing_terms):
                    queues[n].put(new_term)
                    if n == start and new_term not in all_results:
                        if max_count is not None and len(all_results) >= max_count:
                            return
                        yield new_term
                        all_results.add(new_term)

    current_bucket_size = 1

    while (max_bucket_size is None or current_bucket_size <= max_bucket_size) and any(
        not queue.empty() for queue in queues.values()
    ):
        non_terminals = {n for n in self.nonterminals() if not queues[n].empty()}

        while non_terminals:
            n = non_terminals.pop()
            results = existing_terms[n]
            while len(results) < current_bucket_size and not queues[n].empty():
                term = queues[n].get()
                if term in results:
                    continue
                results.add(term)
                for m, expr in inverse_grammar[n]:
                    if len(existing_terms[m]) < current_bucket_size:
                        non_terminals.add(m)
                    if m == start:
                        for new_term in self._generate_new_trees(expr, existing_terms, max_count, (n, term)):
                            if new_term not in all_results:
                                if max_count is not None and len(all_results) >= max_count:
                                    return
                                yield new_term
                                all_results.add(new_term)
                                queues[start].put(new_term)
                    else:
                        for new_term in self._generate_new_trees(expr, existing_terms, max_bucket_size, (n, term)):
                            queues[m].put(new_term)
        current_bucket_size += 1
    return

get(nonterminal: NT) -> deque[RHSRule[NT, T, G]] | None

Source code in src/cosy/solution_space.py
def get(self, nonterminal: NT) -> deque[RHSRule[NT, T, G]] | None:
    return self._rules.get(nonterminal)

nonterminals() -> Iterable[NT]

Source code in src/cosy/solution_space.py
def nonterminals(self) -> Iterable[NT]:
    return self._rules.keys()

prune() -> SolutionSpace[NT, T, G]

Keep only productive rules.

Source code in src/cosy/solution_space.py
def prune(self) -> SolutionSpace[NT, T, G]:
    """Keep only productive rules."""

    ground_types: set[NT] = set()
    queue: set[NT] = set()
    inverse_grammar: dict[NT, set[tuple[NT, frozenset[NT]]]] = defaultdict(set)

    for n, exprs in self._rules.items():
        for expr in exprs:
            non_terminals = expr.non_terminals
            for m in non_terminals:
                inverse_grammar[m].add((n, non_terminals))
            if not non_terminals:
                queue.add(n)

    while queue:
        n = queue.pop()
        if n not in ground_types:
            ground_types.add(n)
            for m, non_terminals in inverse_grammar[n]:
                if m not in ground_types and all(t in ground_types for t in non_terminals):
                    queue.add(m)

    return SolutionSpace[NT, T, G](
        defaultdict(
            deque,
            {
                target: deque(
                    possibility
                    for possibility in self._rules[target]
                    if all(t in ground_types for t in possibility.non_terminals)
                )
                for target in ground_types
            },
        )
    )

show() -> str

Source code in src/cosy/solution_space.py
def show(self) -> str:
    return "\n".join(
        f"{nt!s} ~> {' | '.join([str(subrule) for subrule in rule])}" for nt, rule in self._rules.items()
    )

Subtypes

Source code in src/cosy/subtypes.py
class Subtypes:
    def __init__(self, taxonomy: Taxonomy):
        self.taxonomy = self._transitive_closure(self._reflexive_closure(taxonomy))

    def _check_subtype_rec(
        self,
        subtypes: deque[Type],
        supertype: Type,
        groups: Mapping[str, str],
        substitutions: Mapping[str, Literal],
    ) -> bool:
        if supertype.is_omega:
            return True
        match supertype:
            case Literal(value2, group2):
                while subtypes:
                    match subtypes.pop():
                        case Literal(value1, group1):
                            if value2 == value1 and group1 == group2:
                                return True
                        case Var(name1):
                            if groups[name1] == supertype.group and substitutions[name1] == supertype.value:
                                return True
                        case Intersection(l, r):
                            subtypes.extend((l, r))
                return False
            case Constructor(name2, arg2):
                casted_constr: deque[Type] = deque()
                while subtypes:
                    match subtypes.pop():
                        case Constructor(name1, arg1):
                            if name2 == name1 or name2 in self.taxonomy.get(name1, {}):
                                casted_constr.append(arg1)
                        case Intersection(l, r):
                            subtypes.extend((l, r))
                return len(casted_constr) != 0 and self._check_subtype_rec(casted_constr, arg2, groups, substitutions)
            case Arrow(src2, tgt2):
                casted_arr: deque[Type] = deque()
                while subtypes:
                    match subtypes.pop():
                        case Arrow(src1, tgt1):
                            if self._check_subtype_rec(deque((src2,)), src1, groups, substitutions):
                                casted_arr.append(tgt1)
                        case Intersection(l, r):
                            subtypes.extend((l, r))
                return len(casted_arr) != 0 and self._check_subtype_rec(casted_arr, tgt2, groups, substitutions)
            case Intersection(l, r):
                return self._check_subtype_rec(subtypes.copy(), l, groups, substitutions) and self._check_subtype_rec(
                    subtypes, r, groups, substitutions
                )
            case Var(name):
                while subtypes:
                    match subtypes.pop():
                        case Literal(value, group):
                            if groups[name] == group and substitutions[name] == value:
                                return True
                        case Intersection(l, r):
                            subtypes.extend((l, r))
                return False
            case _:
                msg = f"Unsupported type in check_subtype: {supertype}"
                raise TypeError(msg)

    def check_subtype(
        self,
        subtype: Type,
        supertype: Type,
        groups: Mapping[str, str],
        substitutions: Mapping[str, Literal],
    ) -> bool:
        """Decides whether subtype <= supertype with respect to intersection type subtyping."""

        return self._check_subtype_rec(deque((subtype,)), supertype, groups, substitutions)

    def infer_substitution(self, subtype: Type, path: Type, groups: Mapping[str, str]) -> dict[str, Any] | None:
        """Infers a unique substitution S such that S(subtype) <= path where path is closed. Returns None or Ambiguous is no solution exists or multiple solutions exist respectively."""

        if subtype.is_omega:
            return None

        match subtype:
            case Literal(value1, group1):
                match path:
                    case Literal(value2, group2):
                        if value1 == value2 and group1 == group2:
                            return {}
            case Constructor(name1, arg1):
                match path:
                    case Constructor(name2, arg2):
                        if name2 == name1 or name2 in self.taxonomy.get(name1, {}):
                            if arg2.is_omega:
                                return {}
                            return self.infer_substitution(arg1, arg2, groups)
            case Arrow(src1, tgt1):
                match path:
                    case Arrow(src2, tgt2):
                        substitution = self.infer_substitution(tgt1, tgt2, groups)
                        if substitution is None:
                            return None
                        if all(name in substitution for name in src1.free_vars):
                            if self.check_subtype(src2, src1, groups, substitution):
                                return substitution
                            return None
                        return {}  # there are actual non-Ambiguous cases (relevant in practice?)
            case Intersection(l, r):
                substitution1 = self.infer_substitution(l, path, groups)
                substitution2 = self.infer_substitution(r, path, groups)
                if substitution1 is None:
                    return substitution2
                if substitution2 is None:
                    return substitution1
                if all(
                    (name in substitution2 and substitution2[name] == value for name, value in substitution1.items())
                ):
                    return substitution1  # substitution1 included in substitution2
                if all(
                    (name in substitution1 and substitution1[name] == value for name, value in substitution2.items())
                ):
                    return substitution2  # substitution2 included in substitution1
                return {}
            case Var(name):
                match path:
                    case Literal(value2, group2):
                        if groups[name] == group2:
                            return {name: value2}
            case _:
                msg = f"Unsupported type in infer_substitution: {subtype}"
                raise TypeError(msg)
        return None

    @staticmethod
    def _reflexive_closure(env: Mapping[str, set[str]]) -> dict[str, set[str]]:
        all_types: set[str] = set(env.keys())
        for v in env.values():
            all_types.update(v)
        result: dict[str, set[str]] = {subtype: {subtype}.union(env.get(subtype, set())) for subtype in all_types}
        return result

    @staticmethod
    def _transitive_closure(env: Mapping[str, set[str]]) -> dict[str, set[str]]:
        result: dict[str, set[str]] = {subtype: supertypes.copy() for (subtype, supertypes) in env.items()}
        has_changed = True

        while has_changed:
            has_changed = False
            for known_supertypes in result.values():
                for supertype in known_supertypes.copy():
                    to_add: set[str] = {
                        new_supertype for new_supertype in result[supertype] if new_supertype not in known_supertypes
                    }
                    if to_add:
                        has_changed = True
                    known_supertypes.update(to_add)

        return result

taxonomy = self._transitive_closure(self._reflexive_closure(taxonomy)) instance-attribute

__init__(taxonomy: Taxonomy)

Source code in src/cosy/subtypes.py
def __init__(self, taxonomy: Taxonomy):
    self.taxonomy = self._transitive_closure(self._reflexive_closure(taxonomy))
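
For example, the closure makes every name a supertype of itself and chains declared supertypes; a small sketch with illustrative constructor names:

from cosy import Subtypes

s = Subtypes({"Int": {"Number"}, "Number": {"Any"}})
assert s.taxonomy == {
    "Int": {"Int", "Number", "Any"},
    "Number": {"Number", "Any"},
    "Any": {"Any"},
}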

check_subtype(subtype: Type, supertype: Type, groups: Mapping[str, str], substitutions: Mapping[str, Literal]) -> bool

Decides whether subtype <= supertype with respect to intersection type subtyping.

Source code in src/cosy/subtypes.py
def check_subtype(
    self,
    subtype: Type,
    supertype: Type,
    groups: Mapping[str, str],
    substitutions: Mapping[str, Literal],
) -> bool:
    """Decides whether subtype <= supertype with respect to intersection type subtyping."""

    return self._check_subtype_rec(deque((subtype,)), supertype, groups, substitutions)
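
A usage sketch with illustrative constructor names, assuming the operator DSL on Type (see __rmatmul__ below) and a taxonomy that declares "Int" a subtype of "Number"; the expected results follow the subtyping rules above:

from cosy import Subtypes, Omega

subtypes = Subtypes({"Int": {"Number"}})
int_ty = "Int" @ Omega()        # Constructor("Int", Omega())
number_ty = "Number" @ Omega()  # Constructor("Number", Omega())

# no variables occur, so groups and substitutions stay empty
assert subtypes.check_subtype(int_ty, number_ty, {}, {})
assert not subtypes.check_subtype(number_ty, int_ty, {}, {})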

infer_substitution(subtype: Type, path: Type, groups: Mapping[str, str]) -> dict[str, Any] | None

Infers a unique substitution S such that S(subtype) <= path, where path is closed. Returns None if no solution exists and Ambiguous if multiple solutions exist.

Source code in src/cosy/subtypes.py
def infer_substitution(self, subtype: Type, path: Type, groups: Mapping[str, str]) -> dict[str, Any] | None:
    """Infers a unique substitution S such that S(subtype) <= path where path is closed. Returns None or Ambiguous is no solution exists or multiple solutions exist respectively."""

    if subtype.is_omega:
        return None

    match subtype:
        case Literal(value1, group1):
            match path:
                case Literal(value2, group2):
                    if value1 == value2 and group1 == group2:
                        return {}
        case Constructor(name1, arg1):
            match path:
                case Constructor(name2, arg2):
                    if name2 == name1 or name2 in self.taxonomy.get(name1, {}):
                        if arg2.is_omega:
                            return {}
                        return self.infer_substitution(arg1, arg2, groups)
        case Arrow(src1, tgt1):
            match path:
                case Arrow(src2, tgt2):
                    substitution = self.infer_substitution(tgt1, tgt2, groups)
                    if substitution is None:
                        return None
                    if all(name in substitution for name in src1.free_vars):
                        if self.check_subtype(src2, src1, groups, substitution):
                            return substitution
                        return None
                    return {}  # there are actual non-Ambiguous cases (relevant in practice?)
        case Intersection(l, r):
            substitution1 = self.infer_substitution(l, path, groups)
            substitution2 = self.infer_substitution(r, path, groups)
            if substitution1 is None:
                return substitution2
            if substitution2 is None:
                return substitution1
            if all(
                (name in substitution2 and substitution2[name] == value for name, value in substitution1.items())
            ):
                return substitution1  # substitution1 included in substitution2
            if all(
                (name in substitution1 and substitution1[name] == value for name, value in substitution2.items())
            ):
                return substitution2  # substitution2 included in substitution1
            return {}
        case Var(name):
            match path:
                case Literal(value2, group2):
                    if groups[name] == group2:
                        return {name: value2}
        case _:
            msg = f"Unsupported type in infer_substitution: {subtype}"
            raise TypeError(msg)
    return None
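
A sketch with an illustrative literal group "int": the open type "Size" @ Var("n") only fits the closed path "Size" @ Literal(3, "int") when "n" is mapped to 3, so that substitution is inferred:

from cosy import Subtypes, Var, Literal

subtypes = Subtypes({})
groups = {"n": "int"}  # the variable "n" ranges over the literal group "int"

open_ty = "Size" @ Var("n")
path = "Size" @ Literal(3, "int")

assert subtypes.infer_substitution(open_ty, path, groups) == {"n": 3}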

Synthesizer

Source code in src/cosy/synthesizer.py
class Synthesizer(Generic[C]):
    def __init__(
        self,
        component_specifications: Mapping[C, Specification],
        parameter_space: ParameterSpace | None = None,
        taxonomy: Taxonomy | None = None,
    ):
        self.literals: ParameterSpace = {} if parameter_space is None else dict(parameter_space.items())
        self.repository: tuple[tuple[C, CombinatorInfo], ...] = tuple(
            (c, Synthesizer._function_types(self.literals, ty)) for c, ty in component_specifications.items()
        )
        self.subtypes = Subtypes(taxonomy if taxonomy is not None else {})

    @staticmethod
    def _function_types(
        literals: ParameterSpace,
        parameterized_type: Specification,
    ) -> CombinatorInfo:
        """Presents a type as a list of 0-ary, 1-ary, ..., n-ary function types."""

        def unary_function_types(ty: Type) -> Iterable[tuple[Type, Type]]:
            tys: deque[Type] = deque((ty,))
            while tys:
                match tys.pop():
                    case Arrow(src, tgt) if not tgt.is_omega:
                        yield (src, tgt)
                    case Intersection(sigma, tau):
                        tys.extend((sigma, tau))

        prefix: list[LiteralParameter | TermParameter | Predicate] = []
        variables: set[str] = set()
        groups: dict[str, str] = {}
        while not isinstance(parameterized_type, Type):
            if isinstance(parameterized_type, Abstraction):
                param = parameterized_type.parameter
                if param.name in variables:
                    # check if parameter names are unique
                    msg = f"Duplicate name: {param.name}"
                    raise ValueError(msg)
                variables.add(param.name)
                if isinstance(param, LiteralParameter):
                    prefix.append(param)
                    groups[param.name] = param.group
                    # check if group is defined in the parameter space
                    if param.group not in literals:
                        msg = f"Group {param.group} is not defined in the parameter space."
                        raise ValueError(msg)
                elif isinstance(param, TermParameter):
                    prefix.append(param)
                    for free_var in param.group.free_vars:
                        if free_var not in groups:
                            # check if each parameter variable is abstracted
                            msg = f"Parameter {free_var} is not abstracted."
                            raise ValueError(msg)
                parameterized_type = parameterized_type.body
            elif isinstance(parameterized_type, Implication):
                prefix.append(parameterized_type.predicate)
                parameterized_type = parameterized_type.body

        for free_var in parameterized_type.free_vars:
            if free_var not in groups:
                # check if each parameter variable is abstracted
                msg = f"Parameter {free_var} is not abstracted."
                raise ValueError(msg)

        current: list[MultiArrow] = [MultiArrow((), parameterized_type)]

        multiarrows = []
        while len(current) != 0:
            multiarrows.append(current)
            current = [
                MultiArrow((*c.args, new_arg), new_tgt)
                for c in current
                for (new_arg, new_tgt) in unary_function_types(c.target)
            ]

        term_predicates: tuple[Callable[[dict[str, Any]], bool], ...] = tuple(
            p.constraint for p in prefix if isinstance(p, Predicate) and not p.only_literals
        )
        return CombinatorInfo(prefix, groups, term_predicates, None, multiarrows)

    def _enumerate_substitutions(
        self,
        prefix: list[LiteralParameter | TermParameter | Predicate],
        substitution: dict[str, Any],
    ) -> Iterable[dict[str, Any]]:
        """Enumerate all substitutions for the given parameters fairly.
        Take initial_substitution with inferred literals into account."""

        stack: deque[tuple[dict[str, Any], int, Iterator[Any] | None]] = deque([(substitution, 0, None)])

        while stack:
            substitution, index, generator = stack.pop()
            if index >= len(prefix):
                # no more parameters to process
                yield substitution
                continue
            parameter = prefix[index]
            if isinstance(parameter, LiteralParameter):
                if generator is None:
                    if parameter.name in substitution:
                        value = substitution[parameter.name]
                        if parameter.values is not None and value not in parameter.values(substitution):
                            # the inferred value is not in the set of values
                            continue
                        if value not in self.literals[parameter.group]:
                            # the inferred value is not in the group
                            continue
                        stack.appendleft((substitution, index + 1, None))
                    elif parameter.values is not None:
                        stack.appendleft((substitution, index, iter(parameter.values(substitution))))
                    else:
                        concrete_values = self.literals[parameter.group]
                        if not isinstance(concrete_values, Iterable):
                            msg = f"The value of {parameter.name} could not be inferred."
                            raise RuntimeError(msg)
                        else:
                            stack.appendleft((substitution, index, iter(concrete_values)))
                else:
                    try:
                        value = next(generator)
                    except StopIteration:
                        continue
                    if value in self.literals[parameter.group]:
                        stack.appendleft(({**substitution, parameter.name: value}, index + 1, None))
                    stack.appendleft((substitution, index, generator))

            elif isinstance(parameter, Predicate) and parameter.only_literals:
                if parameter.constraint(substitution):
                    # the predicate is satisfied
                    stack.appendleft((substitution, index + 1, None))
            else:
                stack.appendleft((substitution, index + 1, None))

    def _subqueries(
        self,
        nary_types: list[MultiArrow],
        paths: Iterable[Type],
        groups: dict[str, str],
        substitution: dict[str, Any],
    ) -> Sequence[list[Type]]:
        # does the target of a multi-arrow contain a given type?
        def target_contains(m: MultiArrow, t: Type) -> bool:
            return self.subtypes.check_subtype(m.target, t, groups, substitution)

        # cover target using targets of multi-arrows in nary_types
        covers = minimal_covers(nary_types, paths, target_contains)
        if len(covers) == 0:
            return []

        # intersect corresponding arguments of multi-arrows in each cover
        def intersect_args(args1: Iterable[Type], args2: Iterable[Type]) -> tuple[Type, ...]:
            return tuple(Intersection(a, b) for a, b in zip(args1, args2, strict=False))

        intersected_args: Generator[list[Type]] = (list(reduce(intersect_args, (m.args for m in ms))) for ms in covers)

        # consider only maximal argument vectors
        def compare_args(args1, args2) -> bool:
            return all(
                map(
                    lambda a, b: self.subtypes.check_subtype(a, b, groups, substitution),
                    args1,
                    args2,
                )
            )

        return maximal_elements(intersected_args, compare_args)

    def _necessary_substitution(
        self,
        paths: Iterable[Type],
        combinator_type: list[list[MultiArrow]],
        groups: dict[str, str],
    ) -> dict[str, Any] | None:
        """
        Computes a substitution that needs to be part of every substitution S such that
        S(combinator_type) <= paths.

        If no substitution can make this valid, None is returned.
        """

        result: dict[str, Any] = {}

        for path in paths:
            unique_substitution: dict[str, Any] | None = None
            is_unique = True

            for nary_types in combinator_type:
                for ty in nary_types:
                    substitution = self.subtypes.infer_substitution(ty.target, path, groups)
                    if substitution is None:
                        continue
                    if unique_substitution is None:
                        unique_substitution = substitution
                    else:
                        is_unique = False
                        break
                if not is_unique:
                    break

            if unique_substitution is None:
                return None  # no substitution for this path
            if not is_unique:
                continue  # substitution for this path is not unique; skip it

            # merge consistent substitution
            for k, v in unique_substitution.items():
                if k in result:
                    if result[k] != v:
                        return None  # conflict in necessary substitution
                else:
                    result[k] = v

        return result

    def construct_solution_space_rules(self, *targets: Type) -> Generator[tuple[Type, RHSRule]]:
        """Generate logic program rules for the given target types."""

        # current target types
        stack: deque[tuple[Type, tuple[C, CombinatorInfo, Iterator] | None]] = deque(
            (target, None) for target in targets
        )
        seen: set[Type] = set()

        while stack:
            current_target, current_target_info = stack.pop()
            # if the target is omega, then the result is junk
            if current_target.is_omega:
                msg = f"Target type {current_target} is omega."
                raise ValueError(msg)

            # target type was not initialized before
            if current_target not in seen or current_target_info is not None:
                if current_target_info is None:
                    seen.add(current_target)
                    # try each combinator
                    for combinator, combinator_info in self.repository:
                        # Compute necessary substitutions
                        substitution = self._necessary_substitution(
                            current_target.organized,
                            combinator_info.type,
                            combinator_info.groups,
                        )

                        # If there cannot be a suitable substitution, ignore this combinator
                        if substitution is None:
                            continue

                        # Keep necessary substitutions and enumerate the rest
                        selected_instantiations = self._enumerate_substitutions(combinator_info.prefix, substitution)
                        stack.appendleft(
                            (
                                current_target,
                                (
                                    combinator,
                                    combinator_info,
                                    iter(selected_instantiations),
                                ),
                            )
                        )
                else:
                    combinator, combinator_info, selected_instantiations = current_target_info
                    instantiation = next(selected_instantiations, None)
                    if instantiation is not None:
                        stack.appendleft((current_target, current_target_info))
                        named_arguments: tuple[Argument, ...] | None = None

                        # and every arity of the combinator type
                        for nary_types in combinator_info.type:
                            for subquery in self._subqueries(
                                nary_types,
                                current_target.organized,
                                combinator_info.groups,
                                instantiation,
                            ):
                                if named_arguments is None:  # do this only once for each instantiation
                                    named_arguments = tuple(
                                        ConstantArgument(
                                            param.name,
                                            instantiation[param.name],
                                            combinator_info.groups[param.name],
                                        )
                                        if isinstance(param, LiteralParameter)
                                        else NonTerminalArgument(
                                            param.name,
                                            param.group.subst(
                                                combinator_info.groups,
                                                instantiation,
                                            ),
                                        )
                                        for param in combinator_info.prefix
                                        if isinstance(param, Parameter)
                                    )
                                    stack.extendleft(
                                        (argument.origin, None)
                                        for argument in named_arguments
                                        if isinstance(argument, NonTerminalArgument)
                                    )

                                anonymous_arguments: tuple[Argument, ...] = tuple(
                                    NonTerminalArgument(
                                        None,
                                        ty.subst(combinator_info.groups, instantiation),
                                    )
                                    for ty in subquery
                                )
                                yield (
                                    current_target,
                                    RHSRule[Type, Any, str](
                                        (*named_arguments, *anonymous_arguments),
                                        combinator_info.term_predicates,
                                        combinator,
                                    ),
                                )
                                stack.extendleft((q.origin, None) for q in anonymous_arguments)

    def construct_solution_space(self, *targets: Type) -> SolutionSpace[Type, C, str]:
        """Constructs a logic program in the current environment for the given target types."""

        solution_space: SolutionSpace[Type, C, str] = SolutionSpace()
        for nt, rule in self.construct_solution_space_rules(*targets):
            solution_space.add_rule(nt, rule.terminal, rule.arguments, rule.predicates)

        return solution_space

literals: ParameterSpace = {} if parameter_space is None else dict(parameter_space.items()) instance-attribute

repository: tuple[tuple[C, CombinatorInfo], ...] = tuple((c, Synthesizer._function_types(self.literals, ty)) for (c, ty) in component_specifications.items()) instance-attribute

subtypes = Subtypes(taxonomy if taxonomy is not None else {}) instance-attribute

__init__(component_specifications: Mapping[C, Specification], parameter_space: ParameterSpace | None = None, taxonomy: Taxonomy | None = None)

Source code in src/cosy/synthesizer.py
def __init__(
    self,
    component_specifications: Mapping[C, Specification],
    parameter_space: ParameterSpace | None = None,
    taxonomy: Taxonomy | None = None,
):
    self.literals: ParameterSpace = {} if parameter_space is None else dict(parameter_space.items())
    self.repository: tuple[tuple[C, CombinatorInfo], ...] = tuple(
        (c, Synthesizer._function_types(self.literals, ty)) for c, ty in component_specifications.items()
    )
    self.subtypes = Subtypes(taxonomy if taxonomy is not None else {})

construct_solution_space(*targets: Type) -> SolutionSpace[Type, C, str]

Constructs a logic program in the current environment for the given target types.

Source code in src/cosy/synthesizer.py
def construct_solution_space(self, *targets: Type) -> SolutionSpace[Type, C, str]:
    """Constructs a logic program in the current environment for the given target types."""

    solution_space: SolutionSpace[Type, C, str] = SolutionSpace()
    for nt, rule in self.construct_solution_space_rules(*targets):
        solution_space.add_rule(nt, rule.terminal, rule.arguments, rule.predicates)

    return solution_space
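
A minimal end-to-end sketch with two hypothetical components, a constant zero: Int and a function show: Int -> String; requesting String should yield a rule applying show to an Int argument and a rule producing Int from zero:

from cosy import Synthesizer, Omega

int_ty = "Int" @ Omega()
string_ty = "String" @ Omega()

components = {
    "zero": int_ty,               # 0-ary component producing Int
    "show": int_ty ** string_ty,  # Arrow(int_ty, string_ty) via __pow__
}

synthesizer = Synthesizer(components)
space = synthesizer.construct_solution_space(string_ty)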

construct_solution_space_rules(*targets: Type) -> Generator[tuple[Type, RHSRule]]

Generate logic program rules for the given target types.

Source code in src/cosy/synthesizer.py
def construct_solution_space_rules(self, *targets: Type) -> Generator[tuple[Type, RHSRule]]:
    """Generate logic program rules for the given target types."""

    # current target types
    stack: deque[tuple[Type, tuple[C, CombinatorInfo, Iterator] | None]] = deque(
        (target, None) for target in targets
    )
    seen: set[Type] = set()

    while stack:
        current_target, current_target_info = stack.pop()
        # if the target is omega, then the result is junk
        if current_target.is_omega:
            msg = f"Target type {current_target} is omega."
            raise ValueError(msg)

        # target type was not initialized before
        if current_target not in seen or current_target_info is not None:
            if current_target_info is None:
                seen.add(current_target)
                # try each combinator
                for combinator, combinator_info in self.repository:
                    # Compute necessary substitutions
                    substitution = self._necessary_substitution(
                        current_target.organized,
                        combinator_info.type,
                        combinator_info.groups,
                    )

                    # If there cannot be a suitable substitution, ignore this combinator
                    if substitution is None:
                        continue

                    # Keep necessary substitutions and enumerate the rest
                    selected_instantiations = self._enumerate_substitutions(combinator_info.prefix, substitution)
                    stack.appendleft(
                        (
                            current_target,
                            (
                                combinator,
                                combinator_info,
                                iter(selected_instantiations),
                            ),
                        )
                    )
            else:
                combinator, combinator_info, selected_instantiations = current_target_info
                instantiation = next(selected_instantiations, None)
                if instantiation is not None:
                    stack.appendleft((current_target, current_target_info))
                    named_arguments: tuple[Argument, ...] | None = None

                    # and every arity of the combinator type
                    for nary_types in combinator_info.type:
                        for subquery in self._subqueries(
                            nary_types,
                            current_target.organized,
                            combinator_info.groups,
                            instantiation,
                        ):
                            if named_arguments is None:  # do this only once for each instantiation
                                named_arguments = tuple(
                                    ConstantArgument(
                                        param.name,
                                        instantiation[param.name],
                                        combinator_info.groups[param.name],
                                    )
                                    if isinstance(param, LiteralParameter)
                                    else NonTerminalArgument(
                                        param.name,
                                        param.group.subst(
                                            combinator_info.groups,
                                            instantiation,
                                        ),
                                    )
                                    for param in combinator_info.prefix
                                    if isinstance(param, Parameter)
                                )
                                stack.extendleft(
                                    (argument.origin, None)
                                    for argument in named_arguments
                                    if isinstance(argument, NonTerminalArgument)
                                )

                            anonymous_arguments: tuple[Argument, ...] = tuple(
                                NonTerminalArgument(
                                    None,
                                    ty.subst(combinator_info.groups, instantiation),
                                )
                                for ty in subquery
                            )
                            yield (
                                current_target,
                                RHSRule[Type, Any, str](
                                    (*named_arguments, *anonymous_arguments),
                                    combinator_info.term_predicates,
                                    combinator,
                                ),
                            )
                            stack.extendleft((q.origin, None) for q in anonymous_arguments)

Type dataclass

Source code in src/cosy/types.py
@dataclass(frozen=True)
class Type(ABC):
    is_omega: bool = field(init=True, kw_only=True, compare=False)
    size: int = field(init=True, kw_only=True, compare=False)
    organized: set[Type] = field(init=True, kw_only=True, compare=False)
    free_vars: set[str] = field(init=True, kw_only=True, compare=False)

    @abstractmethod
    def __str__(self) -> str:
        pass

    @abstractmethod
    def _organized(self) -> set[Type]:
        pass

    @abstractmethod
    def _size(self) -> int:
        pass

    @abstractmethod
    def _is_omega(self) -> bool:
        pass

    @abstractmethod
    def _free_vars(self) -> set[str]:
        pass

    @abstractmethod
    def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
        pass

    @staticmethod
    def intersect(types: Sequence[Type]) -> Type:
        if len(types) > 0:
            rtypes = reversed(types)
            result: Type = next(rtypes)
            for ty in rtypes:
                result = Intersection(ty, result)
            return result
        return Omega()

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        del state["is_omega"]
        del state["size"]
        del state["organized"]
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        self.__dict__["is_omega"] = self._is_omega()
        self.__dict__["size"] = self._size()
        self.__dict__["organized"] = self._organized()

    def __pow__(self, other: Type) -> Type:
        return Arrow(self, other)

    def __and__(self, other: Type) -> Type:
        return Intersection(self, other)

    def __rmatmul__(self, name: str) -> Type:
        return Constructor(name, self)

free_vars: set[str] = field(init=True, kw_only=True, compare=False) class-attribute instance-attribute

is_omega: bool = field(init=True, kw_only=True, compare=False) class-attribute instance-attribute

organized: set[Type] = field(init=True, kw_only=True, compare=False) class-attribute instance-attribute

size: int = field(init=True, kw_only=True, compare=False) class-attribute instance-attribute

__and__(other: Type) -> Type

Source code in src/cosy/types.py
def __and__(self, other: Type) -> Type:
    return Intersection(self, other)

__getstate__() -> dict[str, Any]

Source code in src/cosy/types.py
def __getstate__(self) -> dict[str, Any]:
    state = self.__dict__.copy()
    del state["is_omega"]
    del state["size"]
    del state["organized"]
    return state

__init__(*, is_omega: bool, size: int, organized: set[Type], free_vars: set[str]) -> None

__pow__(other: Type) -> Type

Source code in src/cosy/types.py
def __pow__(self, other: Type) -> Type:
    return Arrow(self, other)

__rmatmul__(name: str) -> Type

Source code in src/cosy/types.py
def __rmatmul__(self, name: str) -> Type:
    return Constructor(name, self)
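
Taken together, __and__, __pow__, and __rmatmul__ form a small DSL for writing types; a sketch with illustrative names (note that ** binds tighter than @ and &, so both sides of the arrow are parenthesized):

from cosy import Var, Omega

x = Var("x")
ty = ("Pair" @ (x & ("Int" @ Omega()))) ** ("List" @ x)
# Arrow(Constructor("Pair", Intersection(Var("x"), Constructor("Int", Omega()))),
#       Constructor("List", Var("x")))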

__setstate__(state: dict[str, Any]) -> None

Source code in src/cosy/types.py
def __setstate__(self, state: dict[str, Any]) -> None:
    self.__dict__.update(state)
    self.__dict__["is_omega"] = self._is_omega()
    self.__dict__["size"] = self._size()
    self.__dict__["organized"] = self._organized()

__str__() -> str abstractmethod

Source code in src/cosy/types.py
@abstractmethod
def __str__(self) -> str:
    pass

intersect(types: Sequence[Type]) -> Type staticmethod

Source code in src/cosy/types.py
@staticmethod
def intersect(types: Sequence[Type]) -> Type:
    if len(types) > 0:
        rtypes = reversed(types)
        result: Type = next(rtypes)
        for ty in rtypes:
            result = Intersection(ty, result)
        return result
    return Omega()
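
The fold is right-nested and an empty sequence yields Omega(); a small sketch with illustrative constructor names:

from cosy import Type, Intersection, Omega

a, b, c = "A" @ Omega(), "B" @ Omega(), "C" @ Omega()

assert Type.intersect([a, b, c]) == Intersection(a, Intersection(b, c))
assert Type.intersect([]) == Omega()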

subst(groups: Mapping[str, str], substitution: dict[str, Any]) -> Type abstractmethod

Source code in src/cosy/types.py
@abstractmethod
def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
    pass

Var dataclass

Source code in src/cosy/types.py
@dataclass(frozen=True)
class Var(Type):
    name: str
    is_omega: bool = field(init=False, compare=False)
    size: int = field(init=False, compare=False)
    organized: set[Type] = field(init=False, compare=False)
    free_vars: set[str] = field(init=False, compare=False)

    def __post_init__(self) -> None:
        super().__init__(
            is_omega=self._is_omega(),
            size=self._size(),
            organized=self._organized(),
            free_vars=self._free_vars(),
        )

    def _is_omega(self) -> bool:
        return False

    def _size(self) -> int:
        return 1

    def _organized(self) -> set[Type]:
        return {self}

    def _free_vars(self) -> set[str]:
        return {self.name}

    def __str__(self) -> str:
        return f"<{self.name!s}>"

    def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
        if self.name in substitution:
            return Literal(substitution[self.name], groups[self.name])
        return self

free_vars: set[str] = field(init=False, compare=False) class-attribute instance-attribute

is_omega: bool = field(init=False, compare=False) class-attribute instance-attribute

name: str instance-attribute

organized: set[Type] = field(init=False, compare=False) class-attribute instance-attribute

size: int = field(init=False, compare=False) class-attribute instance-attribute

__init__(name: str, *, is_omega: bool, size: int, organized: set[Type], free_vars: set[str]) -> None

__post_init__() -> None

Source code in src/cosy/types.py
def __post_init__(self) -> None:
    super().__init__(
        is_omega=self._is_omega(),
        size=self._size(),
        organized=self._organized(),
        free_vars=self._free_vars(),
    )

__str__() -> str

Source code in src/cosy/types.py
def __str__(self) -> str:
    return f"<{self.name!s}>"

subst(groups: Mapping[str, str], substitution: dict[str, Any]) -> Type

Source code in src/cosy/types.py
def subst(self, groups: Mapping[str, str], substitution: dict[str, Any]) -> Type:
    if self.name in substitution:
        return Literal(substitution[self.name], groups[self.name])
    return self
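
A small sketch (the group name "int" is illustrative): a variable is replaced by a literal of its group when the substitution binds it, and is returned unchanged otherwise.

from cosy import Var, Literal

v = Var("n")
assert v.subst({"n": "int"}, {"n": 3}) == Literal(3, "int")
assert v.subst({"n": "int"}, {}) is v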