"""Utility functions for static plotting."""fromfunctoolsimportpartialfromnumbersimportNumberfrompathlibimportPathfromtypingimportIterable,Unionimportmatplotlib.pyplotaspltimportnumpyasnpimportpandasaspdfrommatplotlib.axesimportAxesfrommatplotlib.tickerimportEngFormatterfrom..data.statsimportget_valuesfrom.styleimportplot_with_custom_style
def _integer_digits(values) -> int:
    """Return the column width reserved for the integer part of the statistics.

    Computed as ``int(log10(max |value|)) + 1`` over the finite entries of
    *values*. Falls back to ``1`` when there is no finite, non-zero value,
    where the logarithm is undefined (``int(log10(0))`` would raise
    ``OverflowError`` and a NaN would raise ``ValueError``).
    """
    magnitudes = np.abs(np.asarray(values, dtype=float))
    magnitudes = magnitudes[np.isfinite(magnitudes)]
    largest = magnitudes.max() if magnitudes.size else 0.0
    if largest == 0:
        return 1
    return int(np.log10(largest)) + 1


def _keep_prefix(text: str, hidden: int) -> str:
    """Return *text* without its last *hidden* characters.

    Unlike ``text[:-hidden]``, this keeps the whole string when ``hidden`` is
    zero or negative (``text[:-0]`` is ``''``) and returns ``''`` when
    ``hidden`` exceeds the length of *text*.
    """
    keep = len(text) - max(hidden, 0)
    return text[: max(keep, 0)]


@plot_with_custom_style
def plot(
    df: pd.DataFrame,
    x_bounds: Iterable[Number],
    y_bounds: Iterable[Number],
    save_to: Union[str, Path],
    decimals: int,
    **save_kwds,
) -> Union[Axes, None]:
    """
    Plot the dataset and summary statistics.

    Parameters
    ----------
    df : pandas.DataFrame
        The dataset to plot; its ``x`` and ``y`` columns are scattered.
    x_bounds, y_bounds : Iterable[numbers.Number]
        The plotting limits.
    save_to : str or pathlib.Path
        Path to save the plot frame to. Missing parent directories are
        created. When falsey, nothing is saved and the axes are returned.
    decimals : int
        The number of decimal places to highlight as preserved. Every
        statistic is printed with 7 decimal places; digits past ``decimals``
        are drawn faded. Values of 7 or more highlight every digit.
    **save_kwds
        Additional keyword arguments that will be passed down to
        :meth:`matplotlib.figure.Figure.savefig`.

    Returns
    -------
    matplotlib.axes.Axes or None
        When ``save_to`` is falsey, an :class:`~matplotlib.axes.Axes` object
        is returned. Otherwise the figure is saved, closed, and ``None`` is
        returned.
    """
    fig, ax = plt.subplots(
        figsize=(7, 3), layout='constrained', subplot_kw={'aspect': 'equal'}
    )
    # Extra horizontal padding leaves room for the statistics drawn to the
    # right of the axes (at x=1.05 in axes coordinates).
    fig.get_layout_engine().set(w_pad=1.4, h_pad=0.2, wspace=0)

    ax.scatter(df.x, df.y, s=1, alpha=0.7, color='black')
    ax.set(xlim=x_bounds, ylim=y_bounds)

    tick_formatter = EngFormatter()
    ax.xaxis.set_major_formatter(tick_formatter)
    ax.yaxis.set_major_formatter(tick_formatter)

    res = get_values(df)
    labels = ('X Mean', 'Y Mean', 'X SD', 'Y SD', 'Corr.')
    locs = np.linspace(0.8, 0.2, num=len(labels))

    visible_decimals = 7
    label_width = max(len(label) for label in labels)
    # One column for the decimal point, plus one for a minus sign when either
    # mean is negative, so that the decimal points of all rows line up.
    offset = 2 if res.x_mean < 0 or res.y_mean < 0 else 1
    value_width = _integer_digits(res) + visible_decimals + offset

    def format_row(label, value, sign=''):
        # e.g. 'X Mean: 54.2633272'; sign='+' always prints the sign.
        return (
            f'{label:<{label_width}}: '
            f'{value:{sign}{value_width}.{visible_decimals}f}'
        )

    rows = [format_row(label, stat) for label, stat in zip(labels[:-1], res)]
    rows.append(format_row(labels[-1], res.correlation, sign='+'))

    hidden = visible_decimals - decimals
    add_stat_text = partial(
        ax.text,
        1.05,
        fontsize=15,
        transform=ax.transAxes,
        va='center',
    )
    for loc, text in zip(locs, rows):
        # Draw the full-precision value faintly, then overlay the preserved
        # prefix at full opacity so the dropped digits appear faded.
        add_stat_text(loc, text, alpha=0.3)
        add_stat_text(loc, _keep_prefix(text, hidden))

    if not save_to:
        # Keep the constrained layout configured above; calling
        # fig.tight_layout() here would replace that layout engine.
        return ax

    save_to = Path(save_to)
    save_to.parent.mkdir(parents=True, exist_ok=True)
    try:
        fig.savefig(save_to, bbox_inches='tight', **save_kwds)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)