Source code for data_morph.shapes.factory

"""Factory class for generating shape objects."""

from itertools import zip_longest
from numbers import Number
from typing import ClassVar

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

from ..data.dataset import Dataset
from ..plotting.style import plot_with_custom_style
from .bases.shape import Shape
from .circles import Bullseye, Circle, Rings
from .lines import (
    Diamond,
    HighLines,
    HorizontalLines,
    Rectangle,
    SlantDownLines,
    SlantUpLines,
    Star,
    VerticalLines,
    WideLines,
    XLines,
)
from .points import (
    Club,
    DotsGrid,
    DownParabola,
    Heart,
    LeftParabola,
    RightParabola,
    Scatter,
    Spade,
    UpParabola,
)


class ShapeFactory:
    """
    Factory for generating shape objects based on data.

    .. plot::
        :caption:
            Target shapes currently available.

        from data_morph.data.loader import DataLoader
        from data_morph.shapes.factory import ShapeFactory

        dataset = DataLoader.load_dataset('panda')
        _ = ShapeFactory(dataset).plot_available_shapes()

    Parameters
    ----------
    dataset : Dataset
        The starting dataset to morph into other shapes.
    """

    _SHAPE_CLASSES: ClassVar[tuple[type[Shape], ...]] = (
        Bullseye,
        Circle,
        Club,
        Diamond,
        DotsGrid,
        DownParabola,
        Heart,
        HighLines,
        HorizontalLines,
        LeftParabola,
        Rectangle,
        RightParabola,
        Rings,
        Scatter,
        SlantDownLines,
        SlantUpLines,
        Spade,
        Star,
        UpParabola,
        VerticalLines,
        WideLines,
        XLines,
    )
    """New shape classes must be registered here."""

    _SHAPE_MAPPING: ClassVar[dict[str, type[Shape]]] = {
        shape_cls.get_name(): shape_cls for shape_cls in _SHAPE_CLASSES
    }
    """Mapping of shape display names to classes."""

    AVAILABLE_SHAPES: ClassVar[list[str]] = sorted(_SHAPE_MAPPING.keys())
    """The list of available shapes, which can be visualized with
    :meth:`.plot_available_shapes`."""

    def __init__(self, dataset: Dataset) -> None:
        self._dataset: Dataset = dataset

    def generate_shape(self, shape: str, **kwargs: Number) -> Shape:
        """
        Generate the shape object based on the dataset.

        Parameters
        ----------
        shape : str
            The desired shape. See :attr:`.AVAILABLE_SHAPES`.
        **kwargs
            Additional keyword arguments to pass down when creating the shape.

        Returns
        -------
        Shape
            A shape object of the requested type.

        Raises
        ------
        ValueError
            If ``shape`` is not one of :attr:`.AVAILABLE_SHAPES`.
        """
        # Guard only the name lookup: a KeyError raised while constructing
        # the shape itself is a genuine error and must not be reported as
        # an unknown shape name.
        try:
            shape_cls = self._SHAPE_MAPPING[shape]
        except KeyError as err:
            raise ValueError(f'No such shape as {shape}.') from err
        return shape_cls(self._dataset, **kwargs)

    @plot_with_custom_style
    def plot_available_shapes(self) -> np.ndarray:
        """
        Plot the available target shapes.

        Returns
        -------
        numpy.ndarray
            The array of :class:`~matplotlib.axes.Axes` objects (one per
            grid cell, as returned by :func:`matplotlib.pyplot.subplots`)
            containing the plot.

        See Also
        --------
        AVAILABLE_SHAPES
            The list of available shapes.
        """
        num_cols = 5
        num_plots = len(self.AVAILABLE_SHAPES)
        num_rows = int(np.ceil(num_plots / num_cols))

        fig, axs = plt.subplots(
            num_rows,
            num_cols,
            layout='constrained',
            figsize=(10, 2 * num_rows),
        )
        fig.get_layout_engine().set(w_pad=0.2, h_pad=0.2)

        # The grid may contain more cells than there are shapes; zip_longest
        # pads the shape side with None so the leftover axes get removed.
        for shape, ax in zip_longest(self.AVAILABLE_SHAPES, axs.flatten()):
            if shape is None:
                ax.remove()
                continue

            ax.tick_params(
                axis='both',
                which='both',
                bottom=False,
                left=False,
                right=False,
                labelbottom=False,
                labelleft=False,
            )
            shape_obj = self.generate_shape(shape)
            shape_obj.plot(ax=ax).set(xlabel='', ylabel='', title=str(shape_obj))

        return axs