Source code for slickml.visualization._metrics

from typing import Any, Dict, Optional, Tuple

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as scp
import seaborn as sns
from matplotlib.figure import Figure
from sklearn.metrics import accuracy_score

from slickml.utils import check_var

# TODO(amir): these options should be set globally too
sns.set_style("ticks")
mpl.rcParams["axes.linewidth"] = 2
mpl.rcParams["lines.linewidth"] = 2


# TODO(amir): globally adjust options, fonts, ...
def plot_binary_classification_metrics(
    figsize: Optional[Tuple[float, float]] = (12, 12),
    save_path: Optional[str] = None,
    display_plot: Optional[bool] = False,
    return_fig: Optional[bool] = False,
    **kwargs: Dict[str, Any],
) -> Optional[Figure]:
    """Visualizes binary classification metrics using ``plotting_dict_`` attribute of ``BinaryClassificationMetrics``.

    Parameters
    ----------
    figsize : Tuple[float, float], optional
        Figure size, by default (12, 12)

    save_path : str, optional
        The full or relative path to save the plot including the image format such as
        "myplot.png" or "../../myplot.pdf", by default None

    display_plot : bool, optional
        Whether to show the plot, by default False

    return_fig : bool, optional
        Whether to return figure object, by default False

    **kwargs : Dict[str, Any]
        Key-value pairs of binary classification metrics plot

    Returns
    -------
    Figure, optional
    """
    check_var(figsize, var_name="figsize", dtypes=tuple)
    check_var(display_plot, var_name="display_plot", dtypes=bool)
    check_var(return_fig, var_name="return_fig", dtypes=bool)
    if save_path:
        check_var(save_path, var_name="save_path", dtypes=str)

    # TODO(amir): move this to a function ?
    # prepare thresholds for plotting
    thr_set1 = np.linspace(
        start=min(kwargs["roc_thresholds"]),
        stop=max(kwargs["roc_thresholds"]),
        num=1000,
    )
    thr_set2 = np.linspace(
        start=min(kwargs["pr_thresholds"]),
        stop=max(kwargs["pr_thresholds"]),
        num=1000,
    )
    f1_list = [
        2 * (kwargs["precision_list"][i] * kwargs["recall_list"][i])  # type: ignore
        / (kwargs["precision_list"][i] + kwargs["recall_list"][i])  # type: ignore
        for i in range(len(kwargs["precision_list"]))
    ]
    queue_rate_list = [(kwargs["y_pred_proba"] >= thr).mean() for thr in thr_set2]
    accuracy_list = [
        accuracy_score(kwargs["y_true"], (kwargs["y_pred_proba"] >= thr).astype(int))
        for thr in thr_set1
    ]

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=figsize)

    # -----------------------------------
    # subplot 1: roc curve
    ax1.plot(
        kwargs["fpr_list"],
        kwargs["tpr_list"],
        color="red",
        label=f"AUC = {kwargs['auc_roc']:.3f}",
    )
    ax1.plot(
        kwargs["fpr_list"][kwargs["youden_index"]],  # type: ignore
        kwargs["tpr_list"][kwargs["youden_index"]],  # type: ignore
        marker="o",
        color="navy",
        ms=10,
    )
    ax1.set(
        xlim=[-0.01, 1.01],
        ylim=[-0.01, 1.01],
        xlabel="1 - Specificity",
        ylabel="Sensitivity",
        title="ROC Curve",
    )
    ax1.tick_params(axis="both", which="major", labelsize=12)
    ax1.legend(prop={"size": 12}, loc=0, framealpha=0.0)
    ax1.annotate(
        f"Threshold = {kwargs['youden_threshold']:.3f}",
        xy=(
            kwargs["fpr_list"][kwargs["youden_index"]],  # type: ignore
            kwargs["tpr_list"][kwargs["youden_index"]],  # type: ignore
        ),
        xycoords="data",
        xytext=(
            kwargs["fpr_list"][kwargs["youden_index"]] + 0.4,  # type: ignore
            kwargs["tpr_list"][kwargs["youden_index"]] - 0.4,  # type: ignore
        ),
        arrowprops=dict(facecolor="black", shrink=0.05),
        horizontalalignment="right",
        verticalalignment="bottom",
    )

    # -----------------------------------
    # subplot 2: preferred scores vs thresholds
    ax2.plot(
        kwargs["roc_thresholds"],
        1 - kwargs["fpr_list"],  # type: ignore
        label="Specificity",
    )
    ax2.plot(kwargs["roc_thresholds"], kwargs["tpr_list"], label="Sensitivity")
    ax2.plot(thr_set1, accuracy_list, label="Accuracy")
    ax2.set(
        xlim=[-0.01, 1.01],
        ylim=[-0.01, 1.01],
        xlabel="Threshold",
        ylabel="Score",
        title="Preferred Scores vs Thresholds",
    )
    ax2.tick_params(axis="both", which="major", labelsize=12)
    ax2.legend(bbox_to_anchor=(1.2, 0.5), loc="center", ncol=1, framealpha=0.0)
    ax2.axvline(kwargs["sens_spec_threshold"], color="k", ls="--")
    if isinstance(kwargs["sens_spec_threshold"], float) and kwargs["sens_spec_threshold"] <= 0.5:
        ax2.annotate(
            f"Threshold = {kwargs['sens_spec_threshold']:.3f}",
            xy=(kwargs["sens_spec_threshold"], 0.05),
            xycoords="data",
            xytext=(kwargs["sens_spec_threshold"] + 0.1, 0.05),  # type: ignore
            arrowprops=dict(facecolor="black", shrink=0.05),
            horizontalalignment="left",
            verticalalignment="bottom",
        )
    else:
        ax2.annotate(
            f"Threshold = {kwargs['sens_spec_threshold']:.3f}",
            xy=(kwargs["sens_spec_threshold"], 0.05),
            xycoords="data",
            xytext=(kwargs["sens_spec_threshold"] - 0.4, 0.05),  # type: ignore
            arrowprops=dict(facecolor="black", shrink=0.05),
            horizontalalignment="left",
            verticalalignment="bottom",
        )

    # -----------------------------------
    # subplot 3: precision-recall curve
    ax3.plot(
        kwargs["recall_list"],
        kwargs["precision_list"],
        color="red",
        label=f"PR AUC = {kwargs['auc_pr']:.3f}",
    )
    ax3.plot(
        kwargs["recall_list"][kwargs["prec_rec_index"]],  # type: ignore
        kwargs["precision_list"][kwargs["prec_rec_index"]],  # type: ignore
        marker="o",
        color="navy",
        ms=10,
    )
    ax3.axvline(
        x=kwargs["recall_list"][kwargs["prec_rec_index"]],  # type: ignore
        ymin=kwargs["recall_list"][kwargs["prec_rec_index"]],  # type: ignore
        ymax=kwargs["precision_list"][kwargs["prec_rec_index"]],  # type: ignore
        color="navy",
        ls="--",
    )
    ax3.set(
        xlim=[-0.01, 1.01],
        ylim=[-0.01, 1.01],
        xlabel="Recall",
        ylabel="Precision",
        title="Precision-Recall Curve",
    )
    ax3.legend(prop={"size": 12}, loc=0, framealpha=0.0)
    ax3.tick_params(axis="both", which="major", labelsize=12)
    ax3.annotate(
        f"Threshold = {kwargs['prec_rec_threshold']:.3f}",
        xy=(
            kwargs["recall_list"][kwargs["prec_rec_index"]],  # type: ignore
            kwargs["precision_list"][kwargs["prec_rec_index"]],  # type: ignore
        ),
        xycoords="data",
        xytext=(
            kwargs["recall_list"][kwargs["prec_rec_index"]] - 0.4,  # type: ignore
            kwargs["precision_list"][kwargs["prec_rec_index"]] - 0.4,  # type: ignore
        ),
        arrowprops=dict(facecolor="black", shrink=0.05),
        horizontalalignment="left",
        verticalalignment="bottom",
    )

    # -----------------------------------
    # subplot 4: preferred scores vs thresholds
    ax4.plot(
        kwargs["pr_thresholds"],
        kwargs["precision_list"][1:],  # type: ignore
        label="Precision",
    )
    ax4.plot(
        kwargs["pr_thresholds"],
        kwargs["recall_list"][1:],  # type: ignore
        label="Recall",
    )
    ax4.plot(kwargs["pr_thresholds"], f1_list[1:], label="F1-Score")
    ax4.plot(thr_set2, queue_rate_list, label="Queue Rate")
    ax4.set(
        xlim=[-0.01, 1.01],
        ylim=[-0.01, 1.01],
        xlabel="Threshold",
        ylabel="Score",
        title="Preferred Scores vs Thresholds",
    )
    ax4.tick_params(axis="both", which="major", labelsize=12)
    ax4.legend(bbox_to_anchor=(1.2, 0.5), loc="center", ncol=1, framealpha=0.0)
    ax4.axvline(kwargs["prec_rec_threshold"], color="k", ls="--")
    if isinstance(kwargs["prec_rec_threshold"], float) and kwargs["prec_rec_threshold"] <= 0.5:
        ax4.annotate(
            f"Threshold = {kwargs['prec_rec_threshold']:.3f}",
            xy=(kwargs["prec_rec_threshold"], 0.03),
            xycoords="data",
            xytext=(kwargs["prec_rec_threshold"] + 0.1, 0.03),  # type: ignore
            arrowprops=dict(facecolor="black", shrink=0.05),
            horizontalalignment="left",
            verticalalignment="bottom",
        )
    else:
        ax4.annotate(
            f"Threshold = {kwargs['prec_rec_threshold']:.3f}",
            xy=(kwargs["prec_rec_threshold"], 0.03),
            xycoords="data",
            xytext=(kwargs["prec_rec_threshold"] - 0.4, 0.03),  # type: ignore
            arrowprops=dict(facecolor="black", shrink=0.05),
            horizontalalignment="left",
            verticalalignment="bottom",
        )

    if save_path:
        plt.savefig(save_path, bbox_inches="tight", dpi=200)

    if display_plot:
        plt.show()

    if return_fig:
        return fig

    return None
# TODO(amir): standardize all `plotting` params and options
# currently, lots of options are duplicated; this can be done globally
# TODO(amir): we are exposing this function in public API, therefore we need to probably expose
# `kwargs` values; otherwise this function won't be useful; we can also make this private
# TODO(amir): I think it would make sense to return `fig` so it can be tested
# TODO(amir): there are warnings on seaborn regarding passing positional args instead of kwargs
def plot_regression_metrics(
    figsize: Optional[Tuple[float, float]] = (12, 16),
    save_path: Optional[str] = None,
    display_plot: Optional[bool] = False,
    return_fig: Optional[bool] = False,
    **kwargs: Dict[str, Any],
) -> Optional[Figure]:
    """Visualizes regression metrics using ``plotting_dict_`` attribute of ``RegressionMetrics``.

    Parameters
    ----------
    figsize : Tuple[float, float], optional
        Figure size, by default (12, 16)

    save_path : str, optional
        The full or relative path to save the plot including the image format such as
        "myplot.png" or "../../myplot.pdf", by default None

    display_plot : bool, optional
        Whether to show the plot, by default False

    return_fig : bool, optional
        Whether to return figure object, by default False

    **kwargs : Dict[str, Any]
        Key-value pairs of regression metrics plot

    Returns
    -------
    Figure, optional
    """
    check_var(figsize, var_name="figsize", dtypes=tuple)
    check_var(display_plot, var_name="display_plot", dtypes=bool)
    check_var(return_fig, var_name="return_fig", dtypes=bool)
    if save_path:
        check_var(save_path, var_name="save_path", dtypes=str)

    fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(nrows=3, ncols=2, figsize=figsize)

    # -----------------------------------
    # subplot 1: actual vs predicted
    sns.regplot(
        x=kwargs["y_true"],
        y=kwargs["y_pred"],
        marker="o",
        scatter_kws={"edgecolors": "navy"},
        color="#B3C3F3",
        fit_reg=False,
        ax=ax1,
    )
    ax1.plot(
        [
            kwargs["y_true"].min(),  # type: ignore
            kwargs["y_true"].max(),  # type: ignore
        ],
        [
            kwargs["y_true"].min(),  # type: ignore
            kwargs["y_true"].max(),  # type: ignore
        ],
        "r--",
        lw=3,
    )
    ax1.set(
        xlabel="Actual Values",
        ylabel="Predicted Values",
        title="Actual-Predicted",
    )
    ax1.tick_params(axis="both", which="major", labelsize=12)
    ax1.text(
        0.05,
        0.93,
        f"MAPE = {kwargs['mape']:.3f}",
        fontsize=12,
        transform=ax1.transAxes,
    )
    ax1.text(
        0.05,
        0.86,
        f"$R^2$ = {kwargs['r2']:.3f}",
        fontsize=12,
        transform=ax1.transAxes,
    )

    # -----------------------------------
    # subplot 2: Q-Q Normal Plot
    scp.stats.probplot(kwargs["y_residual"], fit=True, dist="norm", plot=ax2)
    ax2.get_lines()[0].set_marker("o")
    ax2.get_lines()[0].set_markerfacecolor("#B3C3F3")
    ax2.get_lines()[0].set_markeredgecolor("navy")
    ax2.get_lines()[0].set_markersize(6.0)
    ax2.get_lines()[1].set_linewidth(3.0)
    ax2.get_lines()[1].set_linestyle("--")
    ax2.set(xlabel="Quantiles", ylabel="Residuals", title="Q-Q")
    ax2.tick_params(axis="both", which="major", labelsize=12)

    # -----------------------------------
    # subplot 3: Residuals vs Fitted
    sns.residplot(
        x=kwargs["y_pred"],
        y=kwargs["y_true"],
        lowess=True,
        order=1,
        line_kws={"color": "red", "lw": 3, "ls": "--", "alpha": 1},
        scatter_kws={"edgecolors": "navy"},
        color="#B3C3F3",
        robust=True,
        ax=ax3,
    )
    ax3.set(
        xlabel="Predicted Values",
        ylabel="Residuals",
        title="Residuals-Predicted",
    )
    ax3.tick_params(axis="both", which="major", labelsize=12)

    # -----------------------------------
    # subplot 4: Sqrt Standard Residuals vs Fitted
    sns.regplot(
        x=kwargs["y_pred"],
        y=kwargs["y_residual_normsq"],
        lowess=True,
        line_kws={"color": "red", "lw": 3, "ls": "--", "alpha": 1},
        scatter_kws={"edgecolors": "navy"},
        color="#B3C3F3",
        ax=ax4,
    )
    ax4.set(
        xlabel="Predicted Values",
        ylabel="Standardized Residuals Norm",
        title="Scale-Location",
    )
    ax4.tick_params(axis="both", which="major", labelsize=12)

    # -----------------------------------
    # subplot 5: Histogram of Coeff. of Variations
    freqs, _, _ = ax5.hist(
        kwargs["y_ratio"],
        histtype="bar",
        bins=np.arange(0.75, 1.25, 0.01),
        alpha=1.0,
        color="#B3C3F3",
        edgecolor="navy",
    )
    ax5.set_xticks([0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4])
    ax5.set_xticklabels(
        ["Less", 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, "More"],
        rotation=30,
    )
    ax5.set(ylabel="Frequency", title="Prediction Variation")
    ax5.tick_params(axis="both", which="major", labelsize=12)
    ax5_ylim = max(freqs)
    ax5.text(
        0.65,
        ax5_ylim,
        rf"""$\mu$ = {kwargs['mean_y_ratio']:.3f}""",
        fontsize=12,
    )
    ax5.text(
        0.65,
        0.93 * ax5_ylim,
        f"CV = {kwargs['cv_y_ratio']:.3f}",
        fontsize=12,
    )

    # -----------------------------------
    # subplot 6: REC
    ax6.plot(
        kwargs["deviation"],
        kwargs["accuracy"],
        color="red",
        label=f"AUC = {kwargs['auc_rec']:.3f}",
    )
    ax6.set(
        xlim=[-0.01, 1.01],
        ylim=[-0.01, 1.01],
        xlabel="Deviation",
        ylabel="Accuracy",
        title="REC Curve",
    )
    ax6.tick_params(axis="both", which="major", labelsize=12)
    ax6.legend(prop={"size": 12}, loc=4, framealpha=0.0)

    if save_path:
        plt.savefig(save_path, bbox_inches="tight", dpi=200)

    if display_plot:
        plt.show()

    if return_fig:
        return fig

    # TODO(amir): investigate the options to return axes as well if needed
    # TODO(amir): investigate a better option for `plt.show()`; currently, no matter what, the
    # figure is being shown when returning `fig`; what would be a global pattern here that can be
    # tested as well? Should we sacrifice the testing part here?
    return None
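
A minimal usage sketch of the function above, for reference only. It assumes that ``RegressionMetrics`` is importable from ``slickml.metrics``, that instantiating it with ``y_true`` and ``y_pred`` populates its ``plotting_dict_`` attribute, and that this dictionary provides the keys consumed above (``y_true``, ``y_pred``, ``y_residual``, ``y_ratio``, ``deviation``, ``accuracy``, ``auc_rec``, and so on); the random data is purely illustrative.

    # hypothetical usage sketch; import paths and constructor signature are assumptions
    import numpy as np
    from slickml.metrics import RegressionMetrics  # assumed public location
    from slickml.visualization._metrics import plot_regression_metrics

    rng = np.random.default_rng(seed=1367)
    y_true = rng.normal(loc=10.0, scale=2.0, size=100)      # illustrative targets
    y_pred = y_true + rng.normal(scale=0.5, size=100)       # illustrative predictions

    # assumed to expose `plotting_dict_` once instantiated
    metrics = RegressionMetrics(y_true, y_pred)
    fig = plot_regression_metrics(
        figsize=(12, 16),
        return_fig=True,
        **metrics.plotting_dict_,
    )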