Source code for gama.genetic_programming.components.primitive_node

from typing import List, Union
from .terminal import DATA_TERMINAL, Terminal
from .primitive import Primitive


class PrimitiveNode:
    """An instantiation for a Primitive with specific Terminals.

    Parameters
    ----------
    primitive: Primitive
        The Primitive type of this PrimitiveNode.
    data_node: PrimitiveNode or str
        The PrimitiveNode that specifies all preprocessing before this
        PrimitiveNode, or the `DATA_TERMINAL` marker for the innermost node.
    terminals: List[Terminal]
        A list of terminals matching the `primitive`.
    """

    def __init__(
        self,
        primitive: Primitive,
        data_node: Union["PrimitiveNode", str],
        terminals: List[Terminal],
    ):
        self._primitive = primitive
        self._data_node = data_node
        # Sorted by string form so the stored order does not depend on the
        # order in which the caller supplied the terminals.
        self._terminals = sorted(terminals, key=str)

    def __str__(self):
        """Recursively stringify all primitive nodes (primitive and hyperparameters).

        Terminals are rendered with `repr` here, whereas `str_nonrecursive`
        renders them with `str`.

        Examples: - "GaussianNB(data)"
                  - "BernoulliNB(data, alpha=1.0)"
                  - "BernoulliNB(FastICA(data, tol=0.5), alpha=1.0)"
        """
        if self._terminals:
            terminal_str = ", ".join(repr(terminal) for terminal in self._terminals)
            return f"{self._primitive}({self._data_node}, {terminal_str})"
        return f"{self._primitive}({self._data_node})"

    @property
    def str_nonrecursive(self) -> str:
        """Stringify this primitive node without its data node.

        Examples: - "GaussianNB()"
                  - "BernoulliNB(alpha=1.0)"
        """
        terminal_str = ", ".join(str(terminal) for terminal in self._terminals)
        return f"{self._primitive}({terminal_str})"

    def copy(self) -> "PrimitiveNode":
        """Copy the object. Shallow for terminals, deep for data_node."""
        if self._data_node == DATA_TERMINAL:
            # The innermost node has no preceding node to copy.
            data_node_copy = DATA_TERMINAL
        else:
            data_node_copy = self._data_node.copy()
        return PrimitiveNode(
            primitive=self._primitive,
            data_node=data_node_copy,
            terminals=self._terminals.copy(),
        )

    @classmethod
    def from_string(cls, string: str, primitive_set: dict):
        """Create a PrimitiveNode from string formatted like PrimitiveNode.__str__

        Parameters
        ----------
        string: str
            A string formatted similar to PrimitiveNode.__str__
        primitive_set: dict
            The dictionary defining all Terminals and Primitives.

        Returns
        -------
        PrimitiveNode
            The PrimitiveNode as defined by the string. If `string` names no
            primitive at all, the bare `DATA_TERMINAL` is returned instead.

        Raises
        ------
        ValueError
            If `string` does not contain `DATA_TERMINAL`, if its parentheses
            are unbalanced, or if a primitive is missing required terminals.
            Lookup errors from `find_primitive` and `find_terminal` propagate
            unchanged.
        """
        # General form is
        # A(B(C(data[, C.param=value, ..])[, B.param=value, ..])[, A.param=value, ..])
        # below assumes that left parenthesis is never part of a parameter name or value
        primitives = string.split("(")[:-1]
        terminal_start_index = string.index(DATA_TERMINAL)
        terminals_string = string[terminal_start_index + len(DATA_TERMINAL) :]
        terminal_sets = terminals_string.split(")")[:-1]

        # zip() below would silently drop the outermost primitives when the
        # parentheses do not pair up, so reject such input explicitly.
        if len(primitives) != len(terminal_sets):
            raise ValueError(
                f"Unbalanced parentheses in '{string}': found {len(primitives)} "
                f"opening and {len(terminal_sets)} closing."
            )

        last_node: Union[PrimitiveNode, str] = DATA_TERMINAL
        # Primitives are listed outermost-first, terminal sets innermost-first.
        for primitive_string, terminal_set in zip(reversed(primitives), terminal_sets):
            primitive = find_primitive(primitive_set, primitive_string)
            if terminal_set == "":
                terminals = []
            else:
                terminal_set = terminal_set[2:]  # 2 is because string starts with ', '
                terminals = [
                    find_terminal(primitive_set, terminal_string)
                    for terminal_string in terminal_set.split(", ")
                ]
            missing = set(primitive.input) - {
                terminal.identifier for terminal in terminals
            }
            if missing:
                raise ValueError(
                    f"Missing terminals {missing} for primitive {primitive}"
                )
            last_node = cls(primitive, last_node, terminals)

        return last_node
def find_primitive(primitive_set: dict, primitive_string: str) -> Primitive:
    """Find the Primitive that matches `primitive_string` in `primitive_set`.

    Parameters
    ----------
    primitive_set: dict
        The dictionary defining all Terminals and Primitives. Only the
        entries under `DATA_TERMINAL` and "prediction" are searched.
    primitive_string: str
        The `repr` of the Primitive to look up.

    Returns
    -------
    Primitive
        The first Primitive whose `repr` equals `primitive_string`.

    Raises
    ------
    IndexError
        If no Primitive matches.
    """
    all_primitives = primitive_set[DATA_TERMINAL] + primitive_set["prediction"]
    for primitive in all_primitives:
        if repr(primitive) == primitive_string:
            return primitive
    raise IndexError(f"Could not find Primitive of type '{primitive_string}'.")


def find_terminal(primitive_set: dict, terminal_string: str) -> Terminal:
    """Find the Terminal that matches `terminal_string` in `primitive_set`.

    Parameters
    ----------
    primitive_set: dict
        The dictionary defining all Terminals and Primitives. The part of
        `terminal_string` before the first "=" is used as the key.
    terminal_string: str
        The `repr` of the Terminal to look up, formatted as "<type>=<value>".

    Returns
    -------
    Terminal
        The first Terminal whose `repr` equals `terminal_string`.

    Raises
    ------
    ValueError
        If `terminal_string` contains no "=".
    RuntimeError
        If no Terminal matches.
    """
    # Split on the first "=" only: the value part may itself contain "=".
    term_type, _ = terminal_string.split("=", 1)
    for terminal in primitive_set[term_type]:
        if repr(terminal) == terminal_string:
            return terminal
    raise RuntimeError(f"Could not find Terminal of type '{terminal_string}'.")