# Source code for data_morph.data.loader

"""Load data for morphing."""

from __future__ import annotations

from importlib.resources import files
from itertools import zip_longest
from numbers import Number
from pathlib import Path
from typing import ClassVar

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes

from .. import MAIN_DIR
from ..plotting.style import plot_with_custom_style
from .dataset import Dataset


class DataLoader:
    """
    Class for loading datasets for morphing.

    .. plot::
        :caption:
            Datasets currently included in Data Morph. The ``dino`` shape is
            Alberto Cairo's `Datasaurus`_. All logos are used with permission
            (see below). Note that CSV files are also supported by the
            :meth:`.load_dataset` method.

        from data_morph.data.loader import DataLoader

        DataLoader.plot_available_datasets()

    .. _Datasaurus: https://web.archive.org/web/20240620205540/http://www.thefunctionalart.com/2016/08/download-datasaurus-never-trust-summary.html

    .. important::
        All logos are used with permission.

        * The Python logo is a `trademark of the Python Software Foundation (PSF)
          <https://www.python.org/psf/trademarks/>`_, used with permission from
          the Foundation.
        * The `SuperDataScience (SDS) <https://www.superdatascience.com/podcast>`_
          logo is used with permission.
    """

    # Location of the built-in CSV files, relative to the package resources
    # root given by MAIN_DIR.
    _DATA_PATH: str = 'data/starter_shapes/'

    # Built-in dataset name -> CSV filename inside _DATA_PATH.
    _DATASETS: ClassVar[dict[str, str]] = {
        'bunny': 'bunny.csv',
        'cat': 'cat.csv',
        'dino': 'dino.csv',
        'dog': 'dog.csv',
        'gorilla': 'gorilla.csv',
        'music': 'music.csv',
        'panda': 'panda.csv',
        'pi': 'pi.csv',
        'Python': 'python.csv',
        'SDS': 'superdatascience.csv',
        'sheep': 'sheep.csv',
        'soccer': 'soccer_ball.csv',
    }

    AVAILABLE_DATASETS = sorted(_DATASETS.keys(), key=str.lower)
    """list[str]: List of available built-in starter datasets, which can be visualized with
    :meth:`plot_available_datasets`."""

    def __init__(self) -> None:
        # DataLoader is used only through its class methods; creating an
        # instance is deliberately unsupported.
        raise NotImplementedError

    @classmethod
    def load_dataset(
        cls,
        dataset: str,
        scale: Number | None = None,
    ) -> Dataset:
        """
        Load dataset.

        Parameters
        ----------
        dataset : str
            Either one of :attr:`AVAILABLE_DATASETS` or a path to a CSV file
            containing two columns: x and y.
        scale : numbers.Number, optional
            The factor to scale the data by (can be used to speed up morphing).
            Values in the data's x and y columns will be divided by this value.

        Returns
        -------
        Dataset
            The starting dataset for morphing.

        Raises
        ------
        ValueError
            If ``dataset`` is neither a built-in dataset name nor the path to
            an existing file.

        Notes
        -----
        If you are looking to create a :class:`.Dataset` from a
        :class:`~pandas.DataFrame` object, use the :class:`.Dataset` class
        directly.
        """
        # The keys for the two logos are not lowercase, so also accept the
        # all-lowercase spellings of those names.
        if dataset == 'python':
            dataset = dataset.capitalize()
        elif dataset == 'sds':
            dataset = dataset.upper()

        # Only the lookup is expected to raise KeyError; keep it alone in the
        # try so errors from reading a built-in file are not misrouted.
        try:
            filename = cls._DATASETS[dataset]
        except KeyError:
            # Not a built-in name: treat the argument as a path to a CSV file.
            name = Path(dataset).stem
            try:
                df = pd.read_csv(dataset)
            except FileNotFoundError as err:
                raise ValueError(
                    f'Unknown dataset "{dataset}". '
                    'Provide a valid path to a CSV dataset or use one of '
                    f'the included datasets: {", ".join(cls.AVAILABLE_DATASETS)}.'
                ) from err
        else:
            # Traversable.joinpath() expects plain str segments; joining them
            # one at a time avoids handing it an OS-specific Path object (e.g.
            # backslash separators on Windows), which non-filesystem resource
            # readers cannot resolve.
            filepath = files(MAIN_DIR).joinpath(cls._DATA_PATH).joinpath(filename)
            name = dataset
            df = pd.read_csv(filepath)

        return Dataset(name=name, df=df, scale=scale)

    @classmethod
    @plot_with_custom_style
    def plot_available_datasets(cls) -> np.ndarray:
        """
        Plot the built-in datasets.

        Returns
        -------
        numpy.ndarray
            The grid (three columns wide) of :class:`~matplotlib.axes.Axes`
            objects containing the plots. Axes for unused grid cells are
            removed from the figure but remain in the returned array.

        See Also
        --------
        AVAILABLE_DATASETS
            The list of available datasets built into Data Morph.
        """
        num_plots = len(cls.AVAILABLE_DATASETS)
        num_cols = 3
        num_rows = int(np.ceil(num_plots / num_cols))

        fig, axs = plt.subplots(
            num_rows,
            num_cols,
            layout='constrained',
            figsize=(12, 4 * num_rows),
            subplot_kw={'aspect': 'equal'},
        )
        fig.get_layout_engine().set(w_pad=0.2, h_pad=0.2)

        # zip_longest pads with None once the datasets run out; those are the
        # leftover cells in the last row of the grid.
        for dataset, ax in zip_longest(cls.AVAILABLE_DATASETS, axs.flatten()):
            if dataset is None:
                ax.remove()
                continue

            ax.tick_params(
                axis='both',
                which='both',
                bottom=False,
                left=False,
                right=False,
                labelbottom=False,
                labelleft=False,
            )

            points = cls.load_dataset(dataset)

            # Title text; the logo datasets get a suffix.
            label = dataset
            if dataset == 'Python':
                label += ' logo\u2122'
            elif dataset == 'SDS':
                label += ' logo'

            ax.scatter(points.df.x, points.df.y, s=4, color='black')

            # tight plot bounds for the grid of datasets in the docs: the
            # offsets are 10% of each axis's data range
            bounds = points.data_bounds.clone()
            x_offset, y_offset = (offset * 0.1 for offset in bounds.range)
            bounds.adjust_bounds(x=x_offset, y=y_offset)
            bounds.align_aspect_ratio()

            ax.set(
                title=f'{label} ({points.df.shape[0]:,d} points)',
                xlim=bounds.x_bounds,
                ylim=bounds.y_bounds,
                xlabel='',
                ylabel='',
            )

        return axs