from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from matplotlib.figure import Figure
from sklearn.metrics import (
accuracy_score,
auc,
average_precision_score,
balanced_accuracy_score,
confusion_matrix,
fbeta_score,
precision_recall_curve,
precision_recall_fscore_support,
roc_auc_score,
roc_curve,
)
from slickml.utils import check_var
from slickml.visualization import plot_binary_classification_metrics
# TODO(amir): update docstrings types in attributes
# TODO(amir): currently `pos_label` in `roc` is defaulted to None
# the None option is ok for [0, 1] and [-1, 1] cases; plan to expose `pos_label`
# with a default value of None? what would be the breaking changes?
@dataclass
class BinaryClassificationMetrics:
"""BinaryClassificationMetrics calculates binary classification metrics in one place.
Binary metrics are computed based on three methods for calculating the thresholds to binarize
the prediction probabilities. Threshold computations including:
1) Youden Index [youden-j-index]_.
2) Maximizing Precision-Recall.
3) Maximizing Sensitivity-Specificity.
Parameters
----------
y_true : Union[List[int], np.ndarray, pd.Series]
List of ground truth values such as [0, 1] for binary problems
y_pred_proba : Union[List[float], np.ndarray, pd.Series]
List of predicted probabilities for the positive class (class=1) in binary problems
or ``y_pred_proba[:, 1]`` in scikit-learn API
threshold : float, optional
Inclusive threshold value to binarize ``y_pred_proba`` into ``y_pred`` where any value that
satisfies ``y_pred_proba >= threshold`` is set to ``class=1`` (positive class). Note that
``">="`` is used instead of ``">"``, by default 0.5
average_method : str, optional
Method to calculate the average of any metric. Possible values are ``"micro"``, ``"macro"``,
``"weighted"``, ``"binary"``, by default "binary"
precision_digits : int, optional
The number of precision digits to format the scores dataframe, by default 3
display_df : bool, optional
Whether to display the formatted scores' dataframe, by default True
Methods
-------
plot(figsize=(12, 12), save_path=None, display_plot=False, return_fig=False)
Plots classification metrics
get_metrics(dtype="dataframe")
Returns calculated classification metrics
Attributes
----------
y_pred_ : np.ndarray
Predicted class based on the ``threshold``. The threshold value inclusively binarizes
``y_pred_proba`` into ``y_pred`` where any value that satisfies ``y_pred_proba >= threshold``
is set to ``class=1`` (positive class). Note that ``">="`` is used instead of ``">"``
accuracy_ : float
Accuracy based on the initial ``threshold`` value with a possible value between 0.0 and 1.0
balanced_accuracy_ : float
Balanced accuracy based on the initial ``threshold`` value considering the prevalence of the
classes with a possible value between 0.0 and 1.0
fpr_list_ : np.ndarray
List of calculated false-positive-rates based on ``roc_thresholds_``
tpr_list_ : np.ndarray
List of calculated true-positive-rates based on ``roc_thresholds_``
roc_thresholds_ : np.ndarray
List of threshold values used to calculate ``fpr_list_`` and ``tpr_list_``
auc_roc_ : float
Area under ROC curve with a possible value between 0.0 and 1.0
precision_list_ : np.ndarray
List of calculated precision based on ``pr_thresholds_``
recall_list_ : np.ndarray
List of calculated recall based on ``pr_thresholds_``
pr_thresholds_ : np.ndarray
List of precision-recall threshold values used to calculate ``precision_list_`` and ``recall_list_``
auc_pr_ : float
Area under Precision-Recall curve with a possible value between 0.0 and 1.0
precision_ : float
Precision based on the ``threshold`` value with a possible value between 0.0 and 1.0
recall_ : float
Recall based on the ``threshold`` value with a possible value between 0.0 and 1.0
f1_ : float
F1-score based on the ``threshold`` value (beta=1.0) with a possible value between 0.0 and 1.0
f2_ : float
F2-score based on the ``threshold`` value (beta=2.0) with a possible value between 0.0 and 1.0
f05_ : float
F(1/2)-score based on the ``threshold`` value (beta=0.5) with a possible value between 0.0 and
1.0
average_precision_ : float
Average precision based on the ``threshold`` value and class prevalence with a possible value
between 0.0 and 1.0
tn_ : np.int64
True negative counts based on the ``threshold`` value
fp_ : np.int64
False positive counts based on the ``threshold`` value
fn_ : np.int64
False negative counts based on the ``threshold`` value
tp_ : np.int64
True positive counts based on the ``threshold`` value
threat_score_ : float
Threat score based on the ``threshold`` value with a possible value between 0.0 and 1.0
youden_index_ : np.int64
Index of the Youden threshold within ``roc_thresholds_``
youden_threshold_ : float
Threshold calculated based on Youden Index with a possible value between 0.0 and 1.0
sens_spec_index_ : np.int64
Index of the sensitivity-specificity threshold within ``roc_thresholds_``
sens_spec_threshold_ : float
Threshold calculated based on maximized sensitivity-specificity with a possible value
between 0.0 and 1.0
prec_rec_index_ : np.int64
Index of the precision-recall threshold within ``pr_thresholds_``
prec_rec_threshold_ : float
Threshold calculated based on maximized precision-recall with a possible value between 0.0
and 1.0
thresholds_dict_ : Dict[str, float]
Calculated thresholds based on different algorithms including Youden Index
``youden_threshold_``, maximizing the area under the sensitivity-specificity curve
``sens_spec_threshold_``, and maximizing the area under the precision-recall curve
``prec_rec_threshold_``
metrics_dict_ : Dict[str, float]
Rounded metrics based on the number of precision digits
metrics_df_ : pd.DataFrame
Pandas DataFrame of all calculated metrics with ``threshold`` set as index
average_methods_ : List[str]
List of all possible average methods
plotting_dict_ : Dict[str, Any]
Plotting properties
References
----------
.. [youden-j-index] https://en.wikipedia.org/wiki/Youden%27s_J_statistic
Examples
--------
>>> from slickml.metrics import BinaryClassificationMetrics
>>> cm = BinaryClassificationMetrics(
... y_true=[1, 1, 0, 0],
... y_pred_proba=[0.95, 0.3, 0.1, 0.9]
... )
>>> f = cm.plot()
>>> m = cm.get_metrics()
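The cut-off thresholds and metrics can also be pulled out directly, e.g. via the
``thresholds_dict_`` attribute and ``get_metrics(dtype="dict")``, continuing the example above:
>>> thresholds = cm.thresholds_dict_
>>> metrics = cm.get_metrics(dtype="dict")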
"""
y_true: Union[List[int], np.ndarray, pd.Series]
y_pred_proba: Union[List[float], np.ndarray, pd.Series]
threshold: Optional[float] = 0.5
average_method: Optional[str] = "binary"
precision_digits: Optional[int] = 3
display_df: Optional[bool] = True
def __post_init__(self) -> None:
"""Post instantiation validations and assignments."""
check_var(
self.y_true,
var_name="y_true",
dtypes=(
np.ndarray,
pd.Series,
list,
),
)
check_var(
self.y_pred_proba,
var_name="y_pred_proba",
dtypes=(
np.ndarray,
pd.Series,
list,
),
)
check_var(
self.threshold,
var_name="threshold",
dtypes=float,
)
check_var(
self.average_method,
var_name="average_method",
dtypes=str,
values=(
"micro",
"macro",
"weighted",
"binary",
),
)
check_var(
self.precision_digits,
var_name="precision_digits",
dtypes=int,
)
check_var(
self.display_df,
var_name="display_df",
dtypes=bool,
)
# TODO(amir): add `values_between` option to `check_var()`
if self.threshold is not None and (self.threshold < 0.0 or self.threshold > 1.0):
raise ValueError(
"The input threshold must have a value between 0.0 and 1.0.",
)
# TODO(amir): how we can pull off special cases like this ?
if self.average_method == "binary" or not self.average_method:
self.average_method = None
# TODO(amir): add `list_to_array()` function into slickml.utils
# TODO(amir): how numpy works with pd.Series here? kinda fuzzy
if not isinstance(self.y_true, np.ndarray):
self.y_true = np.array(self.y_true)
if not isinstance(self.y_pred_proba, np.ndarray):
self.y_pred_proba = np.array(self.y_pred_proba)
self.y_pred_ = (self.y_pred_proba >= self.threshold).astype(int)
self.accuracy_ = self._accuracy()
self.balanced_accuracy_ = self._balanced_accuracy()
(
self.fpr_list_,
self.tpr_list_,
self.roc_thresholds_,
) = self._roc_curve()
self.auc_roc_ = self._auc_roc()
(
self.precision_list_,
self.recall_list_,
self.pr_thresholds_,
) = self._precision_recall_curve()
self.auc_pr_ = self._auc_pr()
(
self.precision_,
self.recall_,
self.f1_,
) = self._precision_recall_f1()
(
self.f2_,
self.f05_,
) = self._f2_f50()
self.average_precision_ = self._average_precision()
(
self.tn_,
self.fp_,
self.fn_,
self.tp_,
) = self._confusion_matrix()
self.threat_score_ = self._threat_score()
self.metrics_dict_ = self._metrics_dict()
self.metrics_df_ = self._metrics_df()
(
self.youden_index_,
self.youden_threshold_,
) = self._threshold_youden()
(
self.sens_spec_index_,
self.sens_spec_threshold_,
) = self._threshold_sens_spec()
(
self.prec_rec_index_,
self.prec_rec_threshold_,
) = self._threshold_prec_rec()
self.thresholds_dict_ = self._thresholds_dict()
self.plotting_dict_ = self._plotting_dict()
self.average_methods_ = self._average_methods()
def plot(
self,
figsize: Optional[Tuple[float, float]] = (12, 12),
save_path: Optional[str] = None,
display_plot: Optional[bool] = False,
return_fig: Optional[bool] = False,
) -> Optional[Figure]:
"""Plots classification metrics.
Parameters
----------
figsize : Tuple[float, float], optional
Figure size, by default (12, 12)
save_path : str, optional
The full or relative path to save the plot including the image format such as
"myplot.png" or "../../myplot.pdf", by default None
display_plot : bool, optional
Whether to show the plot, by default False
return_fig : bool, optional
Whether to return figure object, by default False
Returns
-------
Figure
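Examples
--------
A minimal usage sketch, assuming an already-instantiated ``cm = BinaryClassificationMetrics(...)``:
>>> fig = cm.plot(
... figsize=(12, 12),
... display_plot=True,
... return_fig=True,
... )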
"""
return plot_binary_classification_metrics(
figsize=figsize,
save_path=save_path,
display_plot=display_plot,
return_fig=return_fig,
**self.plotting_dict_,
)
def get_metrics(
self,
dtype: Optional[str] = "dataframe",
) -> Union[pd.DataFrame, Dict[str, Optional[float]]]:
"""Returns calculated metrics with desired dtypes.
Currently, available output types are "dataframe" and "dict".
Parameters
----------
dtype : str, optional
Results dtype, by default "dataframe"
Returns
-------
Union[pd.DataFrame, Dict[str, Optional[float]]]
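Examples
--------
For instance, assuming an already-instantiated ``cm = BinaryClassificationMetrics(...)``:
>>> metrics_df = cm.get_metrics(dtype="dataframe")
>>> metrics_dict = cm.get_metrics(dtype="dict")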
"""
check_var(
dtype,
var_name="dtype",
dtypes=str,
values=("dataframe", "dict"),
)
if dtype == "dataframe":
return self.metrics_df_
else:
return self.metrics_dict_
def _accuracy(self) -> float:
"""Calculates accuracy score.
Returns
-------
float
"""
return accuracy_score(
y_true=self.y_true,
y_pred=self.y_pred_,
normalize=True,
)
def _balanced_accuracy(self) -> float:
"""Calculates balanced accuracy score.
Returns
-------
float
"""
return balanced_accuracy_score(
y_true=self.y_true,
y_pred=self.y_pred_,
adjusted=False,
)
# TODO(amir): check return types here between ndarray or list
def _roc_curve(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Calculates the roc curve elements: fpr, tpr, thresholds.
Returns
-------
Tuple[np.ndarray, np.ndarray, np.ndarray]
"""
fpr_list, tpr_list, roc_thresholds = roc_curve(
y_true=self.y_true,
y_score=self.y_pred_proba,
)
return (fpr_list, tpr_list, roc_thresholds)
# TODO(amir): check the API when `average_method="binary"` that does it pass None as the method
# or keep it as "binary"
def _auc_roc(self) -> float:
"""Calculates the area under ROC curve (auc_roc).
Returns
-------
float
"""
return roc_auc_score(
y_true=self.y_true,
y_score=self.y_pred_proba,
average=self.average_method,
)
def _precision_recall_curve(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Calculates precision recall curve elements: precision_list, recall_list, pr_thresholds.
Returns
-------
Tuple[np.ndarray, np.ndarray, np.ndarray]
"""
precision_list, recall_list, pr_thresholds = precision_recall_curve(
y_true=self.y_true,
y_score=self.y_pred_proba,
)
return precision_list, recall_list, pr_thresholds
def _auc_pr(self) -> float:
"""Calculates the area under Precision-Recal curve (auc_pr).
Returns
-------
float
"""
return auc(
self.recall_list_,
self.precision_list_,
)
def _precision_recall_f1(self) -> Tuple[float, float, float]:
"""Calculates precision, recall, and f1-score.
Returns
-------
Tuple[float, float, float]
"""
precision, recall, f1, _ = precision_recall_fscore_support(
y_true=self.y_true,
y_pred=self.y_pred_,
beta=1.0,
average=self.average_method,
)
# with `average=None` (i.e., the "binary" case), scores are returned per class; keep the positive class (index 1)
if not self.average_method:
precision = precision[1]
recall = recall[1]
f1 = f1[1]
return (precision, recall, f1)
def _f2_f50(self) -> Tuple[float, float]:
"""Calculates f2-score and f0.5-score.
Returns
-------
Tuple[float, float]
"""
f2 = fbeta_score(
y_true=self.y_true,
y_pred=self.y_pred_,
beta=2.0,
average=self.average_method,
)
f05 = fbeta_score(
y_true=self.y_true,
y_pred=self.y_pred_,
beta=0.5,
average=self.average_method,
)
# with `average=None` (i.e., the "binary" case), keep the positive class (index 1) scores
if not self.average_method:
f2 = f2[1]
f05 = f05[1]
return (f2, f05)
def _average_precision(self) -> float:
"""Calculates average precision.
Returns
-------
float
"""
return average_precision_score(
y_true=self.y_true,
y_score=self.y_pred_proba,
average=self.average_method,
)
def _confusion_matrix(self) -> Tuple[float, float, float, float]:
"""Calculates confusion matrix elements: tn, fp, fn, tp.
Returns
-------
Tuple[float, float, float, float]
"""
return confusion_matrix(
y_true=self.y_true,
y_pred=self.y_pred_,
).ravel()
def _threat_score(self) -> float:
"""Calculates threat score.
Returns
-------
float
"""
if self.average_method == "weighted":
w = self.tp_ + self.tn_
wp = self.tp_ / w
wn = self.tn_ / w
threat_score = wp * (self.tp_ / (self.tp_ + self.fp_ + self.fn_)) + wn * (
self.tn_ / (self.tn_ + self.fn_ + self.fp_)
)
elif self.average_method == "macro":
threat_score = 0.5 * (self.tp_ / (self.tp_ + self.fp_ + self.fn_)) + 0.5 * (
self.tn_ / (self.tn_ + self.fn_ + self.fp_)
)
else:
threat_score = self.tp_ / (self.tp_ + self.fp_ + self.fn_)
return threat_score
def _metrics_dict(self) -> Dict[str, float]:
"""Rounded calculated metrics based on the number of precision digits.
Returns
-------
Dict[str, float]
"""
return {
"Accuracy": round(
number=self.accuracy_,
ndigits=self.precision_digits,
),
"Balanced Accuracy": round(
number=self.balanced_accuracy_,
ndigits=self.precision_digits,
),
"ROC AUC": round(
number=self.auc_roc_,
ndigits=self.precision_digits,
),
"PR AUC": round(
number=self.auc_pr_,
ndigits=self.precision_digits,
),
"Precision": round(
number=self.precision_,
ndigits=self.precision_digits,
),
"Recall": round(
number=self.recall_,
ndigits=self.precision_digits,
),
"F-1 Score": round(
number=self.f1_,
ndigits=self.precision_digits,
),
"F-2 Score": round(
number=self.f2_,
ndigits=self.precision_digits,
),
"F-0.50 Score": round(
number=self.f05_,
ndigits=self.precision_digits,
),
"Threat Score": round(
number=self.threat_score_,
ndigits=self.precision_digits,
),
"Average Precision": round(
number=self.average_precision_,
ndigits=self.precision_digits,
),
"TP": self.tp_,
"TN": self.tn_,
"FP": self.fp_,
"FN": self.fn_,
}
def _metrics_df(self) -> pd.DataFrame:
"""Creates a pandas DataFrame of all calculated metrics with custom formatting.
The resulting dataframe contains all the metrics based on the precision digits and selected
average method.
Returns
-------
pd.DataFrame
"""
# update None average_method back to binary for printing
if not self.average_method:
self.average_method = "binary"
metrics_df = pd.DataFrame(
data=self.metrics_dict_,
index=[
f"""Threshold = {self.threshold:.{self.precision_digits}f} | Average =
{self.average_method.title()}""",
],
)
# TODO(amir): can we do df.reindex() ?
metrics_df = metrics_df.reindex(
columns=[
"Accuracy",
"Balanced Accuracy",
"ROC AUC",
"PR AUC",
"Precision",
"Recall",
"Average Precision",
"F-1 Score",
"F-2 Score",
"F-0.50 Score",
"Threat Score",
"TP",
"TN",
"FP",
"FN",
],
)
# TODO(amir): move this to a utility function under utils/format.py since it is repeated
# that would make it more general and scalable across API
# Set CSS properties
th_props = [
("font-size", "12px"),
("text-align", "left"),
("font-weight", "bold"),
]
td_props = [
("font-size", "12px"),
("text-align", "center"),
]
# Set table styles
styles = [
dict(selector="th", props=th_props),
dict(selector="td", props=td_props),
]
cm = sns.light_palette(
"blue",
as_cmap=True,
)
if self.display_df:
display(
metrics_df.style.background_gradient(
cmap=cm,
).set_table_styles(styles),
)
return metrics_df
def _threshold_youden(self) -> Tuple[int, float]:
"""Calculates the Youden index and Youden threshold.
Returns
-------
Tuple[int, float]
"""
youden_index = np.argmax(
np.abs(self.tpr_list_ - self.fpr_list_),
)
youden_threshold = self.roc_thresholds_[youden_index]
return (youden_index, youden_threshold)
def _threshold_sens_spec(self) -> Tuple[int, float]:
"""Calculates the threshold that maximizes sensitivity-specificity curve.
Returns
-------
Tuple[int, float]
"""
sens_spec_index = np.argmin(
abs(self.tpr_list_ + self.fpr_list_ - 1),
)
sens_spec_threshold = self.roc_thresholds_[sens_spec_index]
return (sens_spec_index, sens_spec_threshold)
def _threshold_prec_rec(self) -> Tuple[int, float]:
"""Calculates the threshold that maximizes precision-recall curve.
Returns
-------
Tuple[int, float]
"""
prec_rec_index = np.argmin(abs(self.precision_list_ - self.recall_list_))
prec_rec_threshold = self.pr_thresholds_[prec_rec_index]
return (prec_rec_index, prec_rec_threshold)
def _thresholds_dict(self) -> Dict[str, float]:
"""Returns the calculated thresholds as a dictionary.
Returns
-------
Dict[str, float]
"""
return {
"Youden": self.youden_threshold_,
"Sensitivity-Specificity": self.sens_spec_threshold_,
"Precision-Recall-F1": self.prec_rec_threshold_,
}
# TODO(amir): check Any here since it can be Union[np.ndarray, int, float] ?
def _plotting_dict(self) -> Dict[str, Any]:
"""Returns the plotting properties."""
return {
"roc_thresholds": self.roc_thresholds_,
"pr_thresholds": self.pr_thresholds_,
"precision_list": self.precision_list_,
"recall_list": self.recall_list_,
"y_pred_proba": self.y_pred_proba,
"y_true": self.y_true,
"fpr_list": self.fpr_list_,
"tpr_list": self.tpr_list_,
"auc_roc": self.auc_roc_,
"youden_index": self.youden_index_,
"youden_threshold": self.youden_threshold_,
"sens_spec_threshold": self.sens_spec_threshold_,
"prec_rec_threshold": self.prec_rec_threshold_,
"auc_pr": self.auc_pr_,
"prec_rec_index": self.prec_rec_index_,
}
def _average_methods(self) -> List[str]:
"""Returns the list of average methods.
Returns
-------
List[str]
"""
return [
"binary",
"weighted",
"macro",
"micro",
]