first commit
This commit is contained in:
@@ -0,0 +1,186 @@
|
||||
"""
|
||||
The :mod:`sklearn.metrics` module includes score functions, performance metrics
|
||||
and pairwise metrics and distance computations.
|
||||
"""
|
||||
|
||||
|
||||
from ._ranking import auc
|
||||
from ._ranking import average_precision_score
|
||||
from ._ranking import coverage_error
|
||||
from ._ranking import det_curve
|
||||
from ._ranking import dcg_score
|
||||
from ._ranking import label_ranking_average_precision_score
|
||||
from ._ranking import label_ranking_loss
|
||||
from ._ranking import ndcg_score
|
||||
from ._ranking import precision_recall_curve
|
||||
from ._ranking import roc_auc_score
|
||||
from ._ranking import roc_curve
|
||||
from ._ranking import top_k_accuracy_score
|
||||
|
||||
from ._classification import accuracy_score
|
||||
from ._classification import balanced_accuracy_score
|
||||
from ._classification import classification_report
|
||||
from ._classification import cohen_kappa_score
|
||||
from ._classification import confusion_matrix
|
||||
from ._classification import f1_score
|
||||
from ._classification import fbeta_score
|
||||
from ._classification import hamming_loss
|
||||
from ._classification import hinge_loss
|
||||
from ._classification import jaccard_score
|
||||
from ._classification import log_loss
|
||||
from ._classification import matthews_corrcoef
|
||||
from ._classification import precision_recall_fscore_support
|
||||
from ._classification import precision_score
|
||||
from ._classification import recall_score
|
||||
from ._classification import zero_one_loss
|
||||
from ._classification import brier_score_loss
|
||||
from ._classification import multilabel_confusion_matrix
|
||||
|
||||
from ._dist_metrics import DistanceMetric
|
||||
|
||||
from . import cluster
|
||||
from .cluster import adjusted_mutual_info_score
|
||||
from .cluster import adjusted_rand_score
|
||||
from .cluster import rand_score
|
||||
from .cluster import pair_confusion_matrix
|
||||
from .cluster import completeness_score
|
||||
from .cluster import consensus_score
|
||||
from .cluster import homogeneity_completeness_v_measure
|
||||
from .cluster import homogeneity_score
|
||||
from .cluster import mutual_info_score
|
||||
from .cluster import normalized_mutual_info_score
|
||||
from .cluster import fowlkes_mallows_score
|
||||
from .cluster import silhouette_samples
|
||||
from .cluster import silhouette_score
|
||||
from .cluster import calinski_harabasz_score
|
||||
from .cluster import v_measure_score
|
||||
from .cluster import davies_bouldin_score
|
||||
|
||||
from .pairwise import euclidean_distances
|
||||
from .pairwise import nan_euclidean_distances
|
||||
from .pairwise import pairwise_distances
|
||||
from .pairwise import pairwise_distances_argmin
|
||||
from .pairwise import pairwise_distances_argmin_min
|
||||
from .pairwise import pairwise_kernels
|
||||
from .pairwise import pairwise_distances_chunked
|
||||
|
||||
from ._regression import explained_variance_score
|
||||
from ._regression import max_error
|
||||
from ._regression import mean_absolute_error
|
||||
from ._regression import mean_squared_error
|
||||
from ._regression import mean_squared_log_error
|
||||
from ._regression import median_absolute_error
|
||||
from ._regression import mean_absolute_percentage_error
|
||||
from ._regression import mean_pinball_loss
|
||||
from ._regression import r2_score
|
||||
from ._regression import mean_tweedie_deviance
|
||||
from ._regression import mean_poisson_deviance
|
||||
from ._regression import mean_gamma_deviance
|
||||
from ._regression import d2_tweedie_score
|
||||
from ._regression import d2_pinball_score
|
||||
from ._regression import d2_absolute_error_score
|
||||
|
||||
|
||||
from ._scorer import check_scoring
|
||||
from ._scorer import make_scorer
|
||||
from ._scorer import SCORERS
|
||||
from ._scorer import get_scorer
|
||||
from ._scorer import get_scorer_names
|
||||
|
||||
|
||||
from ._plot.det_curve import plot_det_curve
|
||||
from ._plot.det_curve import DetCurveDisplay
|
||||
from ._plot.roc_curve import plot_roc_curve
|
||||
from ._plot.roc_curve import RocCurveDisplay
|
||||
from ._plot.precision_recall_curve import plot_precision_recall_curve
|
||||
from ._plot.precision_recall_curve import PrecisionRecallDisplay
|
||||
|
||||
from ._plot.confusion_matrix import plot_confusion_matrix
|
||||
from ._plot.confusion_matrix import ConfusionMatrixDisplay
|
||||
|
||||
|
||||
__all__ = [
|
||||
"accuracy_score",
|
||||
"adjusted_mutual_info_score",
|
||||
"adjusted_rand_score",
|
||||
"auc",
|
||||
"average_precision_score",
|
||||
"balanced_accuracy_score",
|
||||
"calinski_harabasz_score",
|
||||
"check_scoring",
|
||||
"classification_report",
|
||||
"cluster",
|
||||
"cohen_kappa_score",
|
||||
"completeness_score",
|
||||
"ConfusionMatrixDisplay",
|
||||
"confusion_matrix",
|
||||
"consensus_score",
|
||||
"coverage_error",
|
||||
"d2_tweedie_score",
|
||||
"d2_absolute_error_score",
|
||||
"d2_pinball_score",
|
||||
"dcg_score",
|
||||
"davies_bouldin_score",
|
||||
"DetCurveDisplay",
|
||||
"det_curve",
|
||||
"DistanceMetric",
|
||||
"euclidean_distances",
|
||||
"explained_variance_score",
|
||||
"f1_score",
|
||||
"fbeta_score",
|
||||
"fowlkes_mallows_score",
|
||||
"get_scorer",
|
||||
"hamming_loss",
|
||||
"hinge_loss",
|
||||
"homogeneity_completeness_v_measure",
|
||||
"homogeneity_score",
|
||||
"jaccard_score",
|
||||
"label_ranking_average_precision_score",
|
||||
"label_ranking_loss",
|
||||
"log_loss",
|
||||
"make_scorer",
|
||||
"nan_euclidean_distances",
|
||||
"matthews_corrcoef",
|
||||
"max_error",
|
||||
"mean_absolute_error",
|
||||
"mean_squared_error",
|
||||
"mean_squared_log_error",
|
||||
"mean_pinball_loss",
|
||||
"mean_poisson_deviance",
|
||||
"mean_gamma_deviance",
|
||||
"mean_tweedie_deviance",
|
||||
"median_absolute_error",
|
||||
"mean_absolute_percentage_error",
|
||||
"multilabel_confusion_matrix",
|
||||
"mutual_info_score",
|
||||
"ndcg_score",
|
||||
"normalized_mutual_info_score",
|
||||
"pair_confusion_matrix",
|
||||
"pairwise_distances",
|
||||
"pairwise_distances_argmin",
|
||||
"pairwise_distances_argmin_min",
|
||||
"pairwise_distances_chunked",
|
||||
"pairwise_kernels",
|
||||
"plot_confusion_matrix",
|
||||
"plot_det_curve",
|
||||
"plot_precision_recall_curve",
|
||||
"plot_roc_curve",
|
||||
"PrecisionRecallDisplay",
|
||||
"precision_recall_curve",
|
||||
"precision_recall_fscore_support",
|
||||
"precision_score",
|
||||
"r2_score",
|
||||
"rand_score",
|
||||
"recall_score",
|
||||
"RocCurveDisplay",
|
||||
"roc_auc_score",
|
||||
"roc_curve",
|
||||
"SCORERS",
|
||||
"get_scorer_names",
|
||||
"silhouette_samples",
|
||||
"silhouette_score",
|
||||
"top_k_accuracy_score",
|
||||
"v_measure_score",
|
||||
"zero_one_loss",
|
||||
"brier_score_loss",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Common code for all metrics.
|
||||
|
||||
"""
|
||||
# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
||||
# Mathieu Blondel <mathieu@mblondel.org>
|
||||
# Olivier Grisel <olivier.grisel@ensta.org>
|
||||
# Arnaud Joly <a.joly@ulg.ac.be>
|
||||
# Jochen Wersdorfer <jochen@wersdoerfer.de>
|
||||
# Lars Buitinck
|
||||
# Joel Nothman <joel.nothman@gmail.com>
|
||||
# Noel Dawe <noel@dawe.me>
|
||||
# License: BSD 3 clause
|
||||
|
||||
from itertools import combinations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..utils import check_array, check_consistent_length
|
||||
from ..utils.multiclass import type_of_target
|
||||
|
||||
|
||||
def _average_binary_score(binary_metric, y_true, y_score, average, sample_weight=None):
    """Average a binary metric over the labels of a multilabel problem.

    Binary targets are handed to ``binary_metric`` unchanged. For
    multilabel-indicator targets the metric is evaluated once per label,
    once per sample, or once on the flattened data, and the individual
    results are then combined as requested by ``average``.

    Parameters
    ----------
    binary_metric : callable
        The binary metric function to use. It is invoked as
        ``binary_metric(y_true, y_score, sample_weight=...)``.

    y_true : array, shape = [n_samples] or [n_samples, n_classes]
        True binary labels in binary label indicators.

    y_score : array, shape = [n_samples] or [n_samples, n_classes]
        Target scores, can either be probability estimates of the positive
        class, confidence values, or binary decisions.

    average : {None, 'micro', 'macro', 'samples', 'weighted'}
        If ``None``, the scores for each class are returned. Otherwise,
        this determines the type of averaging performed on the data:

        ``'micro'``:
            Flatten the label indicator matrix and compute a single metric
            over all of its elements.
        ``'macro'``:
            Compute the metric for each label and take the unweighted
            mean. Label imbalance is not taken into account.
        ``'weighted'``:
            Compute the metric for each label and average the results,
            weighted by support (the, possibly sample-weighted, number of
            true instances for each label).
        ``'samples'``:
            Compute the metric for each instance and average the results.

        Will be ignored when ``y_true`` is binary.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    Returns
    -------
    score : float or array of shape [n_classes]
        The averaged score, or the per-class scores when ``average`` is
        ``None``. For binary ``y_true`` this is whatever ``binary_metric``
        returns. ``0`` is returned for ``average='weighted'`` when the
        total support is (close to) zero.

    Raises
    ------
    ValueError
        If ``average`` is not a supported option, or if the target type of
        ``y_true`` is neither binary nor multilabel-indicator.
    """
    allowed_averages = (None, "micro", "macro", "weighted", "samples")
    if average not in allowed_averages:
        raise ValueError("average has to be one of {0}".format(allowed_averages))

    target_kind = type_of_target(y_true)
    if target_kind == "binary":
        # Nothing to average: delegate directly to the binary metric.
        return binary_metric(y_true, y_score, sample_weight=sample_weight)
    if target_kind != "multilabel-indicator":
        raise ValueError("{0} format is not supported".format(target_kind))

    check_consistent_length(y_true, y_score, sample_weight)
    y_true = check_array(y_true)
    y_score = check_array(y_score)

    # Axis whose entries each define one binary sub-problem
    # (1 -> one problem per label, 0 -> one problem per sample).
    slice_axis = 1
    # Weights forwarded to ``binary_metric`` for every sub-problem.
    metric_weight = sample_weight
    # Weights used when combining the per-problem scores.
    averaging_weight = None

    if average == "micro":
        # A single problem over the flattened matrix; every sample weight is
        # repeated once per label so it still lines up with the raveled data.
        if metric_weight is not None:
            metric_weight = np.repeat(metric_weight, y_true.shape[1])
        y_true = y_true.ravel()
        y_score = y_score.ravel()
    elif average == "weighted":
        # Support of each label, optionally weighted per sample.
        if metric_weight is None:
            averaging_weight = np.sum(y_true, axis=0)
        else:
            per_row_weight = np.reshape(metric_weight, (-1, 1))
            averaging_weight = np.sum(np.multiply(y_true, per_row_weight), axis=0)
        if np.isclose(averaging_weight.sum(), 0.0):
            return 0
    elif average == "samples":
        # Sample weights now weight the final average instead of the metric.
        averaging_weight, metric_weight = metric_weight, None
        slice_axis = 0

    # The micro branch produced 1-D arrays; restore a single column so the
    # generic slicing loop below applies.
    if y_true.ndim == 1:
        y_true = y_true.reshape((-1, 1))
    if y_score.ndim == 1:
        y_score = y_score.reshape((-1, 1))

    n_problems = y_score.shape[slice_axis]
    scores = np.zeros((n_problems,))
    for idx in range(n_problems):
        true_slice = y_true.take([idx], axis=slice_axis).ravel()
        score_slice = y_score.take([idx], axis=slice_axis).ravel()
        scores[idx] = binary_metric(
            true_slice, score_slice, sample_weight=metric_weight
        )

    if average is None:
        return scores

    if averaging_weight is not None:
        # Zero-weighted entries are forced to 0 so that a NaN score carrying
        # no weight cannot contaminate the average.
        averaging_weight = np.asarray(averaging_weight)
        scores[averaging_weight == 0] = 0
    return np.average(scores, weights=averaging_weight)
|
||||
|
||||
|
||||
def _average_multiclass_ovo_score(binary_metric, y_true, y_score, average="macro"):
    """Average one-versus-one scores for multiclass classification.

    Every unordered pair of classes found in ``y_true`` is scored twice with
    the binary metric (each class taking the positive role once); the two
    results are averaged per pair and the pair scores are then averaged,
    following the Hand & Till (2001) algorithm.

    Parameters
    ----------
    binary_metric : callable
        The binary metric function to use that accepts the following as input:
            y_true_target : array, shape = [n_samples_target]
                Some sub-array of y_true for a pair of classes designated
                positive and negative in the one-vs-one scheme.
            y_score_target : array, shape = [n_samples_target]
                Scores corresponding to the probability estimates
                of a sample belonging to the designated positive class label

    y_true : array-like of shape (n_samples,)
        True multiclass labels.

    y_score : array-like of shape (n_samples, n_classes)
        Target scores corresponding to probability estimates of a sample
        belonging to a particular class.

    average : {'macro', 'weighted'}, default='macro'
        Determines the type of averaging performed on the pairwise binary
        metric scores:

        ``'macro'``:
            Unweighted mean of the pair scores. Label imbalance is not
            taken into account; classes are assumed to be uniformly
            distributed.
        ``'weighted'``:
            Mean of the pair scores weighted by the fraction of samples
            that belong to either class of the pair.

    Returns
    -------
    score : float
        Average of the pairwise binary metric scores.
    """
    check_consistent_length(y_true, y_score)

    labels = np.unique(y_true)
    class_pairs = list(combinations(labels, 2))
    n_pairs = len(class_pairs)

    use_prevalence = average == "weighted"
    pair_scores = np.empty(n_pairs)
    pair_weights = np.empty(n_pairs) if use_prevalence else None

    for pos, (first, second) in enumerate(class_pairs):
        is_first = y_true == first
        is_second = y_true == second
        in_pair = np.logical_or(is_first, is_second)

        if use_prevalence:
            # Share of all samples that belong to this pair of classes.
            pair_weights[pos] = np.average(in_pair)

        # NOTE(review): the class labels themselves select columns of
        # ``y_score``; presumably callers pass label-encoded ``y_true``
        # (0 .. n_classes - 1) -- confirm against the caller.
        first_positive = binary_metric(is_first[in_pair], y_score[in_pair, first])
        second_positive = binary_metric(is_second[in_pair], y_score[in_pair, second])
        pair_scores[pos] = (first_positive + second_positive) / 2

    return np.average(pair_scores, weights=pair_weights)
|
||||
|
||||
|
||||
def _check_pos_label_consistency(pos_label, y_true):
    """Return the positive label, inferring it when it was not specified.

    When ``pos_label`` is ``None`` it can only be inferred (as ``1``) if the
    labels found in ``y_true`` form one of the sets {0, 1}, {-1, 1}, {0},
    {-1} or {1}. Any other label set requires an explicit ``pos_label``.

    Parameters
    ----------
    pos_label : int, str or None
        The positive label.
    y_true : ndarray of shape (n_samples,)
        The target vector.

    Returns
    -------
    pos_label : int
        ``pos_label`` unchanged when it was given, otherwise the inferred
        value ``1``.

    Raises
    ------
    ValueError
        If ``pos_label`` is ``None`` and the labels of ``y_true`` are
        strings/objects or are not one of the inferable label sets.
    """
    # The unique labels are computed even when ``pos_label`` is given.
    observed = np.unique(y_true)
    if pos_label is not None:
        return pos_label

    # The dtype kind is tested first: comparing object/str/bytes labels with
    # integers through np.array_equal could trigger a FutureWarning when the
    # elements are not comparable.
    inferable_label_sets = ([0, 1], [-1, 1], [0], [-1], [1])
    can_infer = observed.dtype.kind not in "OUS" and any(
        np.array_equal(observed, candidate) for candidate in inferable_label_sets
    )
    if can_infer:
        return 1

    listed = ", ".join(repr(value) for value in observed)
    raise ValueError(
        f"y_true takes value in {{{listed}}} and pos_label is not "
        "specified: either make y_true take value in {0, 1} or "
        "{-1, 1} or pass pos_label explicitly."
    )
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -0,0 +1,87 @@
|
||||
cimport numpy as np
|
||||
from libc.math cimport sqrt, exp
|
||||
|
||||
from ..utils._typedefs cimport DTYPE_t, ITYPE_t
|
||||
|
||||
######################################################################
|
||||
# Inline distance functions
|
||||
#
|
||||
# We use these for the default (euclidean) case so that they can be
|
||||
# inlined. This leads to faster computation for the most common case
|
||||
cdef inline DTYPE_t euclidean_dist(const DTYPE_t* x1, const DTYPE_t* x2,
                                   ITYPE_t size) nogil except -1:
    # Euclidean distance between the first `size` elements of x1 and x2:
    # the square root of the sum of squared coordinate differences.
    cdef DTYPE_t delta, acc=0
    cdef np.intp_t k
    for k in range(size):
        delta = x1[k] - x2[k]
        acc += delta * delta
    return sqrt(acc)
|
||||
|
||||
|
||||
cdef inline DTYPE_t euclidean_rdist(const DTYPE_t* x1, const DTYPE_t* x2,
                                    ITYPE_t size) nogil except -1:
    # Sum of squared coordinate differences over the first `size` elements
    # of x1 and x2, i.e. the Euclidean distance without the final sqrt.
    cdef DTYPE_t delta, acc=0
    cdef np.intp_t k
    for k in range(size):
        delta = x1[k] - x2[k]
        acc += delta * delta
    return acc
|
||||
|
||||
|
||||
cdef inline DTYPE_t euclidean_dist_to_rdist(const DTYPE_t dist) nogil except -1:
    # Square the given distance (the value euclidean_rdist would produce
    # for a pair whose euclidean_dist is `dist`).
    return dist * dist
|
||||
|
||||
|
||||
cdef inline DTYPE_t euclidean_rdist_to_dist(const DTYPE_t dist) nogil except -1:
    # Take the square root of the argument.
    # NOTE(review): going by the function name, `dist` holds a reduced
    # (squared) distance here despite the parameter's name -- confirm.
    return sqrt(dist)
|
||||
|
||||
|
||||
######################################################################
|
||||
# DistanceMetric base class
|
||||
cdef class DistanceMetric:
    # C-level declaration of the distance metric base class: shared
    # attributes plus the cdef methods that can be called without the GIL.
    # NOTE(review): the method bodies are not part of this declaration file;
    # presumably they live in the matching .pyx -- confirm.

    # The following attributes are required for a few of the subclasses.
    # we must define them here so that cython's limited polymorphism will work.
    # Because we don't expect to instantiate a lot of these objects, the
    # extra memory overhead of this setup should not be an issue.
    cdef DTYPE_t p
    cdef DTYPE_t[::1] vec
    cdef DTYPE_t[:, ::1] mat
    cdef ITYPE_t size
    cdef object func
    cdef object kwargs

    # Distance between two C arrays of `size` elements; a return value of
    # -1 signals an error to the caller (`except -1`).
    cdef DTYPE_t dist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                      ITYPE_t size) nogil except -1

    # Same inputs as `dist`.
    # NOTE(review): presumably the "reduced" distance, as with the inline
    # euclidean_rdist helper that skips the final sqrt -- confirm.
    cdef DTYPE_t rdist(self, const DTYPE_t* x1, const DTYPE_t* x2,
                       ITYPE_t size) nogil except -1

    # Writes into D using the rows of X; returns -1 on error.
    # NOTE(review): presumably all pairwise distances within X -- confirm.
    cdef int pdist(self, const DTYPE_t[:, ::1] X, DTYPE_t[:, ::1] D) except -1

    # Writes into D using the rows of X and Y; returns -1 on error.
    # NOTE(review): presumably distances between every row of X and every
    # row of Y -- confirm.
    cdef int cdist(self, const DTYPE_t[:, ::1] X, const DTYPE_t[:, ::1] Y,
                   DTYPE_t[:, ::1] D) except -1

    # Conversions between the reduced distance and the distance itself.
    cdef DTYPE_t _rdist_to_dist(self, DTYPE_t rdist) nogil except -1

    cdef DTYPE_t _dist_to_rdist(self, DTYPE_t dist) nogil except -1
|
||||
|
||||
|
||||
######################################################################
|
||||
# DatasetsPair base class
|
||||
cdef class DatasetsPair:
    # Declaration of a base class that holds a DistanceMetric and exposes
    # index-based (i, j) distance queries callable without the GIL.
    # NOTE(review): presumably i indexes rows of a dataset X and j rows of
    # a dataset Y, as suggested by n_samples_X / n_samples_Y -- confirm.
    cdef DistanceMetric distance_metric

    cdef ITYPE_t n_samples_X(self) nogil

    cdef ITYPE_t n_samples_Y(self) nogil

    cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil

    # NOTE(review): presumably a cheaper stand-in for `dist` (such as the
    # metric's reduced distance) -- confirm in the implementation.
    cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil
|
||||
|
||||
|
||||
cdef class DenseDenseDatasetsPair(DatasetsPair):
    # DatasetsPair subclass storing both datasets as read-only,
    # C-contiguous 2-D memoryviews.
    cdef:
        const DTYPE_t[:, ::1] X
        const DTYPE_t[:, ::1] Y
        # NOTE(review): presumably the number of features (columns) shared
        # by X and Y -- confirm.
        ITYPE_t d
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,116 @@
|
||||
from ...base import is_classifier
|
||||
|
||||
|
||||
def _check_classifier_response_method(estimator, response_method):
    """Resolve the prediction method of a classifier.

    Parameters
    ----------
    estimator : object
        Classifier to check.

    response_method : {'auto', 'predict_proba', 'decision_function'}
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    Returns
    -------
    prediction_method : callable
        Prediction method of the estimator.

    Raises
    ------
    ValueError
        If ``response_method`` is not one of the accepted values, or if the
        estimator does not define the requested method.
    """
    if response_method not in ("predict_proba", "decision_function", "auto"):
        raise ValueError(
            "response_method must be 'predict_proba', 'decision_function' or 'auto'"
        )

    if response_method == "auto":
        # Both attributes are looked up; predict_proba takes precedence.
        proba_fn = getattr(estimator, "predict_proba", None)
        decision_fn = getattr(estimator, "decision_function", None)
        resolved = proba_fn or decision_fn
        requested = "decision_function or predict_proba"
    else:
        resolved = getattr(estimator, response_method, None)
        requested = response_method

    if resolved is None:
        raise ValueError(
            "response method {} is not defined in {}".format(
                requested, estimator.__class__.__name__
            )
        )
    return resolved
|
||||
|
||||
|
||||
def _get_response(X, estimator, response_method, pos_label=None):
    """Return the response of a binary classifier and the positive label.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    response_method : {'auto', 'predict_proba', 'decision_function'}
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    pos_label : str or int, default=None
        The class considered as the positive class when computing
        the metrics. By default, `estimator.classes_[1]` is
        considered as the positive class.

    Returns
    -------
    y_pred : ndarray of shape (n_samples,)
        Target scores calculated from the provided response_method
        and pos_label.

    pos_label : str or int
        The class considered as the positive class when computing
        the metrics.

    Raises
    ------
    ValueError
        If ``estimator`` is not a classifier, if ``pos_label`` is not one of
        ``estimator.classes_``, or if a 2-D response does not have exactly
        two columns.
    """
    estimator_name = estimator.__class__.__name__
    not_binary_msg = (
        f"Expected 'estimator' to be a binary classifier, but got {estimator_name}"
    )

    if not is_classifier(estimator):
        raise ValueError(not_binary_msg)

    predict = _check_classifier_response_method(estimator, response_method)
    y_pred = predict(X)

    if pos_label is None:
        positive_col = 1
        pos_label = estimator.classes_[positive_col]
    else:
        try:
            positive_col = estimator.classes_.tolist().index(pos_label)
        except ValueError as err:
            raise ValueError(
                "The class provided by 'pos_label' is unknown. Got "
                f"{pos_label} instead of one of {set(estimator.classes_)}"
            ) from err

    if y_pred.ndim == 1:
        # 1-D response (`decision_function`): the sign is flipped, in place,
        # when the positive class is the first entry of `classes_`.
        if pos_label == estimator.classes_[0]:
            y_pred *= -1
        return y_pred, pos_label

    # 2-D response (`predict_proba`): keep the column of the positive class.
    n_columns = y_pred.shape[1]
    if n_columns != 2:
        raise ValueError(
            f"{not_binary_msg} fit on multiclass ({n_columns} classes) data"
        )
    return y_pred[:, positive_col], pos_label
|
||||
@@ -0,0 +1,603 @@
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import confusion_matrix
|
||||
from ...utils import check_matplotlib_support
|
||||
from ...utils import deprecated
|
||||
from ...utils.multiclass import unique_labels
|
||||
from ...base import is_classifier
|
||||
|
||||
|
||||
class ConfusionMatrixDisplay:
|
||||
"""Confusion Matrix visualization.
|
||||
|
||||
It is recommended to use
|
||||
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator` or
|
||||
:func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` to
|
||||
create a :class:`ConfusionMatrixDisplay`. All parameters are stored as
|
||||
attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
confusion_matrix : ndarray of shape (n_classes, n_classes)
|
||||
Confusion matrix.
|
||||
|
||||
display_labels : ndarray of shape (n_classes,), default=None
|
||||
Display labels for plot. If None, display labels are set from 0 to
|
||||
`n_classes - 1`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
im_ : matplotlib AxesImage
|
||||
Image representing the confusion matrix.
|
||||
|
||||
text_ : ndarray of shape (n_classes, n_classes), dtype=matplotlib Text, \
|
||||
or None
|
||||
Array of matplotlib text elements. `None` if `include_values` is false.
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with confusion matrix.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the confusion matrix.
|
||||
|
||||
See Also
|
||||
--------
|
||||
confusion_matrix : Compute Confusion Matrix to evaluate the accuracy of a
|
||||
classification.
|
||||
ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
|
||||
given an estimator, the data, and the label.
|
||||
ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
|
||||
given the true and predicted labels.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
|
||||
... random_state=0)
|
||||
>>> clf = SVC(random_state=0)
|
||||
>>> clf.fit(X_train, y_train)
|
||||
SVC(random_state=0)
|
||||
>>> predictions = clf.predict(X_test)
|
||||
>>> cm = confusion_matrix(y_test, predictions, labels=clf.classes_)
|
||||
>>> disp = ConfusionMatrixDisplay(confusion_matrix=cm,
|
||||
... display_labels=clf.classes_)
|
||||
>>> disp.plot()
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
|
||||
def __init__(self, confusion_matrix, *, display_labels=None):
|
||||
self.confusion_matrix = confusion_matrix
|
||||
self.display_labels = display_labels
|
||||
|
||||
def plot(
    self,
    *,
    include_values=True,
    cmap="viridis",
    xticks_rotation="horizontal",
    values_format=None,
    ax=None,
    colorbar=True,
    im_kw=None,
):
    """Plot visualization.

    Draws the stored confusion matrix as an image, optionally annotating
    every cell with its value, and records the created matplotlib artists
    on the display object (``im_``, ``text_``, ``ax_`` and ``figure_``).

    Parameters
    ----------
    include_values : bool, default=True
        Includes values in confusion matrix.

    cmap : str or matplotlib Colormap, default='viridis'
        Colormap recognized by matplotlib.

    xticks_rotation : {'vertical', 'horizontal'} or float, \
            default='horizontal'
        Rotation of xtick labels.

    values_format : str, default=None
        Format specification for values in confusion matrix. If `None`,
        the format specification is 'd' or '.2g' whichever is shorter.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.

    colorbar : bool, default=True
        Whether or not to add a colorbar to the plot.

    im_kw : dict, default=None
        Dict with keywords forwarded to the `imshow` call; entries
        override the default ``interpolation`` and ``cmap`` settings.

    Returns
    -------
    display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
    """
    check_matplotlib_support("ConfusionMatrixDisplay.plot")
    import matplotlib.pyplot as plt

    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots()

    matrix = self.confusion_matrix
    n_classes = matrix.shape[0]

    # Caller-supplied imshow keywords take precedence over the defaults.
    imshow_kwargs = {"interpolation": "nearest", "cmap": cmap, **(im_kw or {})}

    self.im_ = ax.imshow(matrix, **imshow_kwargs)
    self.text_ = None
    # Colors at both ends of the colormap, used for contrasting cell text.
    color_low = self.im_.cmap(0)
    color_high = self.im_.cmap(1.0)

    if include_values:
        self.text_ = np.empty_like(matrix, dtype=object)

        # Cells below the midpoint of the value range get the color of the
        # top of the colormap, all other cells the color of the bottom, so
        # the text contrasts with its background.
        midpoint = (matrix.max() + matrix.min()) / 2.0

        for row in range(n_classes):
            for col in range(n_classes):
                value = matrix[row, col]
                text_color = color_high if value < midpoint else color_low

                if values_format is not None:
                    label = format(value, values_format)
                else:
                    label = format(value, ".2g")
                    if matrix.dtype.kind != "f":
                        # Non-float matrices use the plain integer form
                        # when it is strictly shorter.
                        as_integer = format(value, "d")
                        if len(as_integer) < len(label):
                            label = as_integer

                self.text_[row, col] = ax.text(
                    col, row, label, ha="center", va="center", color=text_color
                )

    tick_labels = self.display_labels
    if tick_labels is None:
        tick_labels = np.arange(n_classes)
    if colorbar:
        fig.colorbar(self.im_, ax=ax)
    ax.set(
        xticks=np.arange(n_classes),
        yticks=np.arange(n_classes),
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        ylabel="True label",
        xlabel="Predicted label",
    )

    # Inverted y-limits: the first class is displayed in the top row.
    ax.set_ylim((n_classes - 0.5, -0.5))
    plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)

    self.figure_ = fig
    self.ax_ = ax
    return self
|
||||
|
||||
@classmethod
def from_estimator(
    cls,
    estimator,
    X,
    y,
    *,
    labels=None,
    sample_weight=None,
    normalize=None,
    display_labels=None,
    include_values=True,
    xticks_rotation="horizontal",
    values_format=None,
    cmap="viridis",
    ax=None,
    colorbar=True,
    im_kw=None,
):
    """Plot Confusion Matrix given an estimator and some data.

    The predictions returned by ``estimator.predict(X)`` are compared
    against `y`; the drawing itself is delegated to
    :meth:`from_predictions`.

    Read more in the :ref:`User Guide <confusion_matrix>`.

    .. versionadded:: 1.0

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Target values.

    labels : array-like of shape (n_classes,), default=None
        List of labels to index the confusion matrix. This may be used to
        reorder or select a subset of labels. If `None` is given, those
        that appear at least once in `y_true` or `y_pred` are used in
        sorted order.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    normalize : {'true', 'pred', 'all'}, default=None
        Either to normalize the counts display in the matrix:

            - if `'true'`, the confusion matrix is normalized over the true
              conditions (e.g. rows);
            - if `'pred'`, the confusion matrix is normalized over the
              predicted conditions (e.g. columns);
            - if `'all'`, the confusion matrix is normalized by the total
              number of samples;
            - if `None` (default), the confusion matrix will not be normalized.

    display_labels : array-like of shape (n_classes,), default=None
        Target names used for plotting. By default, `labels` will be used
        if it is defined, otherwise the unique labels of `y_true` and
        `y_pred` will be used.

    include_values : bool, default=True
        Includes values in confusion matrix.

    xticks_rotation : {'vertical', 'horizontal'} or float, \
            default='horizontal'
        Rotation of xtick labels.

    values_format : str, default=None
        Format specification for values in confusion matrix. If `None`, the
        format specification is 'd' or '.2g' whichever is shorter.

    cmap : str or matplotlib Colormap, default='viridis'
        Colormap recognized by matplotlib.

    ax : matplotlib Axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.

    colorbar : bool, default=True
        Whether or not to add a colorbar to the plot.

    im_kw : dict, default=None
        Dict with keywords passed to `matplotlib.pyplot.imshow` call.

    Returns
    -------
    display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`

    Raises
    ------
    ValueError
        If `estimator` is not recognised as a classifier.

    See Also
    --------
    ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
        given the true and predicted labels.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import ConfusionMatrixDisplay
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...         X, y, random_state=0)
    >>> clf = SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> ConfusionMatrixDisplay.from_estimator(
    ...     clf, X_test, y_test)
    <...>
    >>> plt.show()
    """
    caller = f"{cls.__name__}.from_estimator"
    check_matplotlib_support(caller)

    # Reject anything that is not a classifier before asking for predictions.
    if not is_classifier(estimator):
        raise ValueError(f"{caller} only supports classifiers")

    # Rendering options are handed to `from_predictions` unchanged.
    rendering_options = {
        "display_labels": display_labels,
        "include_values": include_values,
        "cmap": cmap,
        "ax": ax,
        "xticks_rotation": xticks_rotation,
        "values_format": values_format,
        "colorbar": colorbar,
        "im_kw": im_kw,
    }

    return cls.from_predictions(
        y,
        estimator.predict(X),
        sample_weight=sample_weight,
        labels=labels,
        normalize=normalize,
        **rendering_options,
    )
|
||||
|
||||
@classmethod
def from_predictions(
    cls,
    y_true,
    y_pred,
    *,
    labels=None,
    sample_weight=None,
    normalize=None,
    display_labels=None,
    include_values=True,
    xticks_rotation="horizontal",
    values_format=None,
    cmap="viridis",
    ax=None,
    colorbar=True,
    im_kw=None,
):
    """Plot Confusion Matrix given true and predicted labels.

    The confusion matrix is computed from `y_true` and `y_pred`, wrapped in
    a display object and drawn with :meth:`plot`.

    Read more in the :ref:`User Guide <confusion_matrix>`.

    .. versionadded:: 1.0

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        True labels.

    y_pred : array-like of shape (n_samples,)
        The predicted labels given by the method `predict` of an
        classifier.

    labels : array-like of shape (n_classes,), default=None
        List of labels to index the confusion matrix. This may be used to
        reorder or select a subset of labels. If `None` is given, those
        that appear at least once in `y_true` or `y_pred` are used in
        sorted order.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    normalize : {'true', 'pred', 'all'}, default=None
        Either to normalize the counts display in the matrix:

            - if `'true'`, the confusion matrix is normalized over the true
              conditions (e.g. rows);
            - if `'pred'`, the confusion matrix is normalized over the
              predicted conditions (e.g. columns);
            - if `'all'`, the confusion matrix is normalized by the total
              number of samples;
            - if `None` (default), the confusion matrix will not be normalized.

    display_labels : array-like of shape (n_classes,), default=None
        Target names used for plotting. By default, `labels` will be used
        if it is defined, otherwise the unique labels of `y_true` and
        `y_pred` will be used.

    include_values : bool, default=True
        Includes values in confusion matrix.

    xticks_rotation : {'vertical', 'horizontal'} or float, \
            default='horizontal'
        Rotation of xtick labels.

    values_format : str, default=None
        Format specification for values in confusion matrix. If `None`, the
        format specification is 'd' or '.2g' whichever is shorter.

    cmap : str or matplotlib Colormap, default='viridis'
        Colormap recognized by matplotlib.

    ax : matplotlib Axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.

    colorbar : bool, default=True
        Whether or not to add a colorbar to the plot.

    im_kw : dict, default=None
        Dict with keywords passed to `matplotlib.pyplot.imshow` call.

    Returns
    -------
    display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`

    See Also
    --------
    ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
        given an estimator, the data, and the label.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import ConfusionMatrixDisplay
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...         X, y, random_state=0)
    >>> clf = SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> y_pred = clf.predict(X_test)
    >>> ConfusionMatrixDisplay.from_predictions(
    ...    y_test, y_pred)
    <...>
    >>> plt.show()
    """
    check_matplotlib_support(f"{cls.__name__}.from_predictions")

    # Tick labels fall back to `labels`, then to the labels seen in the data.
    if display_labels is None:
        display_labels = unique_labels(y_true, y_pred) if labels is None else labels

    matrix = confusion_matrix(
        y_true,
        y_pred,
        sample_weight=sample_weight,
        labels=labels,
        normalize=normalize,
    )
    display = cls(confusion_matrix=matrix, display_labels=display_labels)

    plot_options = {
        "include_values": include_values,
        "cmap": cmap,
        "ax": ax,
        "xticks_rotation": xticks_rotation,
        "values_format": values_format,
        "colorbar": colorbar,
        "im_kw": im_kw,
    }
    return display.plot(**plot_options)
|
||||
|
||||
|
||||
@deprecated(
    "Function `plot_confusion_matrix` is deprecated in 1.0 and will be "
    "removed in 1.2. Use one of the class methods: "
    "ConfusionMatrixDisplay.from_predictions or "
    "ConfusionMatrixDisplay.from_estimator."
)
def plot_confusion_matrix(
    estimator,
    X,
    y_true,
    *,
    labels=None,
    sample_weight=None,
    normalize=None,
    display_labels=None,
    include_values=True,
    xticks_rotation="horizontal",
    values_format=None,
    cmap="viridis",
    ax=None,
    colorbar=True,
):
    """Plot Confusion Matrix.

    Read more in the :ref:`User Guide <confusion_matrix>`.

    .. deprecated:: 1.0
       `plot_confusion_matrix` is deprecated in 1.0 and will be removed in
       1.2. Use one of the following class methods:
       :func:`~sklearn.metrics.ConfusionMatrixDisplay.from_predictions` or
       :func:`~sklearn.metrics.ConfusionMatrixDisplay.from_estimator`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y_true : array-like of shape (n_samples,)
        Target values.

    labels : array-like of shape (n_classes,), default=None
        List of labels to index the matrix. This may be used to reorder or
        select a subset of labels. If `None` is given, those that appear at
        least once in `y_true` or `y_pred` are used in sorted order.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    normalize : {'true', 'pred', 'all'}, default=None
        Either to normalize the counts display in the matrix:

            - if `'true'`, the confusion matrix is normalized over the true
              conditions (e.g. rows);
            - if `'pred'`, the confusion matrix is normalized over the
              predicted conditions (e.g. columns);
            - if `'all'`, the confusion matrix is normalized by the total
              number of samples;
            - if `None` (default), the confusion matrix will not be normalized.

    display_labels : array-like of shape (n_classes,), default=None
        Target names used for plotting. By default, `labels` will be used if
        it is defined, otherwise the unique labels of `y_true` and `y_pred`
        will be used.

    include_values : bool, default=True
        Includes values in confusion matrix.

    xticks_rotation : {'vertical', 'horizontal'} or float, \
                        default='horizontal'
        Rotation of xtick labels.

    values_format : str, default=None
        Format specification for values in confusion matrix. If `None`,
        the format specification is 'd' or '.2g' whichever is shorter.

    cmap : str or matplotlib Colormap, default='viridis'
        Colormap recognized by matplotlib.

    ax : matplotlib Axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.

    colorbar : bool, default=True
        Whether or not to add a colorbar to the plot.

        .. versionadded:: 0.24

    Returns
    -------
    display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
        Object that stores computed values.

    Raises
    ------
    ValueError
        If `estimator` is not recognised as a classifier.

    See Also
    --------
    confusion_matrix : Compute Confusion Matrix to evaluate the accuracy of a
        classification.
    ConfusionMatrixDisplay : Confusion Matrix visualization.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import plot_confusion_matrix
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...         X, y, random_state=0)
    >>> clf = SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> plot_confusion_matrix(clf, X_test, y_test)  # doctest: +SKIP
    >>> plt.show()
    """
    check_matplotlib_support("plot_confusion_matrix")

    if not is_classifier(estimator):
        raise ValueError("plot_confusion_matrix only supports classifiers")

    predictions = estimator.predict(X)
    matrix = confusion_matrix(
        y_true,
        predictions,
        sample_weight=sample_weight,
        labels=labels,
        normalize=normalize,
    )

    # Tick labels: an explicit `display_labels` wins, then `labels`, and
    # finally the labels found in the data.
    if display_labels is not None:
        tick_labels = display_labels
    elif labels is not None:
        tick_labels = labels
    else:
        tick_labels = unique_labels(y_true, predictions)

    display = ConfusionMatrixDisplay(
        confusion_matrix=matrix, display_labels=tick_labels
    )
    return display.plot(
        include_values=include_values,
        cmap=cmap,
        ax=ax,
        xticks_rotation=xticks_rotation,
        values_format=values_format,
        colorbar=colorbar,
    )
|
||||
@@ -0,0 +1,472 @@
|
||||
import scipy as sp
|
||||
|
||||
from .base import _get_response
|
||||
|
||||
from .. import det_curve
|
||||
from .._base import _check_pos_label_consistency
|
||||
|
||||
from ...utils import check_matplotlib_support
|
||||
from ...utils import deprecated
|
||||
|
||||
|
||||
class DetCurveDisplay:
    """DET curve visualization.

    It is recommended to use
    :func:`~sklearn.metrics.DetCurveDisplay.from_estimator` or
    :func:`~sklearn.metrics.DetCurveDisplay.from_predictions` to create a
    visualizer. All parameters are stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    .. versionadded:: 0.24

    Parameters
    ----------
    fpr : ndarray
        False positive rate.

    fnr : ndarray
        False negative rate.

    estimator_name : str, default=None
        Name of estimator. If None, the estimator name is not shown.

    pos_label : str or int, default=None
        The label of the positive class.

    Attributes
    ----------
    line_ : matplotlib Artist
        DET Curve.

    ax_ : matplotlib Axes
        Axes with DET Curve.

    figure_ : matplotlib Figure
        Figure containing the curve.

    See Also
    --------
    det_curve : Compute error rates for different probability thresholds.
    DetCurveDisplay.from_estimator : Plot DET curve given an estimator and
        some data.
    DetCurveDisplay.from_predictions : Plot DET curve given the true and
        predicted labels.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import det_curve, DetCurveDisplay
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(n_samples=1000, random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...     X, y, test_size=0.4, random_state=0)
    >>> clf = SVC(random_state=0).fit(X_train, y_train)
    >>> y_pred = clf.decision_function(X_test)
    >>> fpr, fnr, _ = det_curve(y_test, y_pred)
    >>> display = DetCurveDisplay(
    ...     fpr=fpr, fnr=fnr, estimator_name="SVC"
    ... )
    >>> display.plot()
    <...>
    >>> plt.show()
    """

    def __init__(self, *, fpr, fnr, estimator_name=None, pos_label=None):
        self.fpr = fpr
        self.fnr = fnr
        self.estimator_name = estimator_name
        self.pos_label = pos_label

    @classmethod
    def from_estimator(
        cls,
        estimator,
        X,
        y,
        *,
        sample_weight=None,
        response_method="auto",
        pos_label=None,
        name=None,
        ax=None,
        **kwargs,
    ):
        """Plot DET curve given an estimator and data.

        Read more in the :ref:`User Guide <visualizations>`.

        .. versionadded:: 1.0

        Parameters
        ----------
        estimator : estimator instance
            Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
            in which the last estimator is a classifier.

        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input values.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        response_method : {'predict_proba', 'decision_function', 'auto'}, \
                default='auto'
            Specifies whether to use :term:`predict_proba` or
            :term:`decision_function` as the predicted target response. If set
            to 'auto', :term:`predict_proba` is tried first and if it does not
            exist :term:`decision_function` is tried next.

        pos_label : str or int, default=None
            The label of the positive class. When `pos_label=None`, if `y_true`
            is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
            error will be raised.

        name : str, default=None
            Name of DET curve for labeling. If `None`, use the name of the
            estimator.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        **kwargs : dict
            Additional keywords arguments passed to matplotlib `plot` function.

        Returns
        -------
        display : :class:`~sklearn.metrics.DetCurveDisplay`
            Object that stores computed values.

        See Also
        --------
        det_curve : Compute error rates for different probability thresholds.
        DetCurveDisplay.from_predictions : Plot DET curve given the true and
            predicted labels.
        RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
            (ROC) curve given an estimator and some data.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import DetCurveDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.svm import SVC
        >>> X, y = make_classification(n_samples=1000, random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...     X, y, test_size=0.4, random_state=0)
        >>> clf = SVC(random_state=0).fit(X_train, y_train)
        >>> DetCurveDisplay.from_estimator(
        ...    clf, X_test, y_test)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_estimator")

        # Default the legend entry to the estimator's class name.
        name = estimator.__class__.__name__ if name is None else name

        # `_get_response` hands back the scores together with a `pos_label`;
        # the returned value (not the raw argument) is what gets forwarded.
        y_pred, pos_label = _get_response(
            X,
            estimator,
            response_method,
            pos_label=pos_label,
        )

        return cls.from_predictions(
            y_true=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            name=name,
            ax=ax,
            pos_label=pos_label,
            **kwargs,
        )

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        *,
        sample_weight=None,
        pos_label=None,
        name=None,
        ax=None,
        **kwargs,
    ):
        """Plot DET curve given the true and predicted labels.

        Read more in the :ref:`User Guide <visualizations>`.

        .. versionadded:: 1.0

        Parameters
        ----------
        y_true : array-like of shape (n_samples,)
            True labels.

        y_pred : array-like of shape (n_samples,)
            Target scores, can either be probability estimates of the positive
            class, confidence values, or non-thresholded measure of decisions
            (as returned by `decision_function` on some classifiers).

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        pos_label : str or int, default=None
            The label of the positive class. When `pos_label=None`, if `y_true`
            is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
            error will be raised.

        name : str, default=None
            Name of DET curve for labeling. If `None`, name will be set to
            `"Classifier"`.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        **kwargs : dict
            Additional keywords arguments passed to matplotlib `plot` function.

        Returns
        -------
        display : :class:`~sklearn.metrics.DetCurveDisplay`
            Object that stores computed values. It is an instance of the
            class the method was called on, so subclasses get their own type.

        See Also
        --------
        det_curve : Compute error rates for different probability thresholds.
        DetCurveDisplay.from_estimator : Plot DET curve given an estimator and
            some data.
        RocCurveDisplay.from_predictions : Plot Receiver Operating
            Characteristic (ROC) curve given the true and predicted values.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import DetCurveDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.svm import SVC
        >>> X, y = make_classification(n_samples=1000, random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...     X, y, test_size=0.4, random_state=0)
        >>> clf = SVC(random_state=0).fit(X_train, y_train)
        >>> y_pred = clf.decision_function(X_test)
        >>> DetCurveDisplay.from_predictions(
        ...    y_test, y_pred)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")
        fpr, fnr, _ = det_curve(
            y_true,
            y_pred,
            pos_label=pos_label,
            sample_weight=sample_weight,
        )

        pos_label = _check_pos_label_consistency(pos_label, y_true)
        name = "Classifier" if name is None else name

        # Build the display from `cls` rather than the hard-coded base class so
        # that calling this constructor on a subclass returns that subclass.
        viz = cls(
            fpr=fpr,
            fnr=fnr,
            estimator_name=name,
            pos_label=pos_label,
        )

        return viz.plot(ax=ax, name=name, **kwargs)

    def plot(self, ax=None, *, name=None, **kwargs):
        """Plot visualization.

        Parameters
        ----------
        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of DET curve for labeling. If `None`, use `estimator_name` if
            it is not `None`, otherwise no labeling is shown.

        **kwargs : dict
            Additional keywords arguments passed to matplotlib `plot` function.
            A `label` entry given here overrides `name`.

        Returns
        -------
        display : :class:`~sklearn.metrics.DetCurveDisplay`
            Object that stores computed values.
        """
        check_matplotlib_support("DetCurveDisplay.plot")

        name = self.estimator_name if name is None else name
        line_kwargs = {} if name is None else {"label": name}
        # User-supplied keywords take precedence over the derived label.
        line_kwargs.update(**kwargs)

        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        # Both rates are mapped through the standard normal quantile function
        # before plotting. Rates of exactly 0 or 1 map to -inf/+inf and thus
        # fall outside the fixed axis limits set below.
        (self.line_,) = ax.plot(
            sp.stats.norm.ppf(self.fpr),
            sp.stats.norm.ppf(self.fnr),
            **line_kwargs,
        )
        info_pos_label = (
            f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
        )

        xlabel = "False Positive Rate" + info_pos_label
        ylabel = "False Negative Rate" + info_pos_label
        ax.set(xlabel=xlabel, ylabel=ylabel)

        if "label" in line_kwargs:
            ax.legend(loc="lower right")

        # Ticks are placed at the transformed positions but labelled with the
        # untransformed rates, written as percentages.
        ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
        tick_locations = sp.stats.norm.ppf(ticks)
        tick_labels = [
            "{:.0%}".format(s) if (100 * s).is_integer() else "{:.1%}".format(s)
            for s in ticks
        ]
        ax.set_xticks(tick_locations)
        ax.set_xticklabels(tick_labels)
        ax.set_xlim(-3, 3)
        ax.set_yticks(tick_locations)
        ax.set_yticklabels(tick_labels)
        ax.set_ylim(-3, 3)

        self.ax_ = ax
        self.figure_ = ax.figure
        return self
|
||||
|
||||
|
||||
@deprecated(
    "Function plot_det_curve is deprecated in 1.0 and will be "
    "removed in 1.2. Use one of the class methods: "
    "DetCurveDisplay.from_predictions or "
    "DetCurveDisplay.from_estimator."
)
def plot_det_curve(
    estimator,
    X,
    y,
    *,
    sample_weight=None,
    response_method="auto",
    name=None,
    ax=None,
    pos_label=None,
    **kwargs,
):
    """Plot detection error tradeoff (DET) curve.

    Extra keyword arguments will be passed to matplotlib's `plot`.

    Read more in the :ref:`User Guide <visualizations>`.

    .. versionadded:: 0.24

    .. deprecated:: 1.0
       `plot_det_curve` is deprecated in 1.0 and will be removed in
       1.2. Use one of the following class methods:
       :func:`~sklearn.metrics.DetCurveDisplay.from_predictions` or
       :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    response_method : {'predict_proba', 'decision_function', 'auto'} \
            default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the predicted target response. If set to
        'auto', :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name of DET curve for labeling. If `None`, use the name of the
        estimator.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    pos_label : str or int, default=None
        The label of the positive class.
        When `pos_label=None`, if `y_true` is in {-1, 1} or {0, 1},
        `pos_label` is set to 1, otherwise an error will be raised.

    **kwargs : dict
        Additional keywords arguments passed to matplotlib `plot` function.

    Returns
    -------
    display : :class:`~sklearn.metrics.DetCurveDisplay`
        Object that stores computed values.

    See Also
    --------
    det_curve : Compute error rates for different probability thresholds.
    DetCurveDisplay : DET curve visualization.
    DetCurveDisplay.from_estimator : Plot DET curve given an estimator and
        some data.
    DetCurveDisplay.from_predictions : Plot DET curve given the true and
        predicted labels.
    RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
        (ROC) curve given an estimator and some data.
    RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
        (ROC) curve given the true and predicted values.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import plot_det_curve
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(n_samples=1000, random_state=0)
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...     X, y, test_size=0.4, random_state=0)
    >>> clf = SVC(random_state=0).fit(X_train, y_train)
    >>> plot_det_curve(clf, X_test, y_test)  # doctest: +SKIP
    <...>
    >>> plt.show()
    """
    check_matplotlib_support("plot_det_curve")

    # The curve label defaults to the estimator's class name.
    if name is None:
        name = estimator.__class__.__name__

    # The `pos_label` handed back by `_get_response` replaces the argument
    # for everything that follows.
    scores, pos_label = _get_response(
        X, estimator, response_method, pos_label=pos_label
    )

    false_pos_rate, false_neg_rate, _ = det_curve(
        y,
        scores,
        pos_label=pos_label,
        sample_weight=sample_weight,
    )

    display = DetCurveDisplay(
        fpr=false_pos_rate,
        fnr=false_neg_rate,
        estimator_name=name,
        pos_label=pos_label,
    )
    return display.plot(ax=ax, name=name, **kwargs)
|
||||
@@ -0,0 +1,499 @@
|
||||
from sklearn.base import is_classifier
|
||||
from .base import _get_response
|
||||
|
||||
from .. import average_precision_score
|
||||
from .. import precision_recall_curve
|
||||
from .._base import _check_pos_label_consistency
|
||||
from .._classification import check_consistent_length
|
||||
|
||||
from ...utils import check_matplotlib_support, deprecated
|
||||
|
||||
|
||||
class PrecisionRecallDisplay:
|
||||
"""Precision Recall visualization.
|
||||
|
||||
It is recommended to use
|
||||
:func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` or
|
||||
:func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` to create
|
||||
a :class:`~sklearn.metrics.PrecisionRecallDisplay`. All parameters are
|
||||
stored as attributes.
|
||||
|
||||
Read more in the :ref:`User Guide <visualizations>`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
precision : ndarray
|
||||
Precision values.
|
||||
|
||||
recall : ndarray
|
||||
Recall values.
|
||||
|
||||
average_precision : float, default=None
|
||||
Average precision. If None, the average precision is not shown.
|
||||
|
||||
estimator_name : str, default=None
|
||||
Name of estimator. If None, then the estimator name is not shown.
|
||||
|
||||
pos_label : str or int, default=None
|
||||
The class considered as the positive class. If None, the class will not
|
||||
be shown in the legend.
|
||||
|
||||
.. versionadded:: 0.24
|
||||
|
||||
Attributes
|
||||
----------
|
||||
line_ : matplotlib Artist
|
||||
Precision recall curve.
|
||||
|
||||
ax_ : matplotlib Axes
|
||||
Axes with precision recall curve.
|
||||
|
||||
figure_ : matplotlib Figure
|
||||
Figure containing the curve.
|
||||
|
||||
See Also
|
||||
--------
|
||||
precision_recall_curve : Compute precision-recall pairs for different
|
||||
probability thresholds.
|
||||
PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given
|
||||
a binary classifier.
|
||||
PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve
|
||||
using predictions from a binary classifier.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The average precision (cf. :func:`~sklearn.metrics.average_precision`) in
|
||||
scikit-learn is computed without any interpolation. To be consistent with
|
||||
this metric, the precision-recall curve is plotted without any
|
||||
interpolation as well (step-wise style).
|
||||
|
||||
You can change this style by passing the keyword argument
|
||||
`drawstyle="default"` in :meth:`plot`, :meth:`from_estimator`, or
|
||||
:meth:`from_predictions`. However, the curve will not be strictly
|
||||
consistent with the reported average precision.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import (precision_recall_curve,
|
||||
... PrecisionRecallDisplay)
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.svm import SVC
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
|
||||
... random_state=0)
|
||||
>>> clf = SVC(random_state=0)
|
||||
>>> clf.fit(X_train, y_train)
|
||||
SVC(random_state=0)
|
||||
>>> predictions = clf.predict(X_test)
|
||||
>>> precision, recall, _ = precision_recall_curve(y_test, predictions)
|
||||
>>> disp = PrecisionRecallDisplay(precision=precision, recall=recall)
|
||||
>>> disp.plot()
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
precision,
|
||||
recall,
|
||||
*,
|
||||
average_precision=None,
|
||||
estimator_name=None,
|
||||
pos_label=None,
|
||||
):
|
||||
self.estimator_name = estimator_name
|
||||
self.precision = precision
|
||||
self.recall = recall
|
||||
self.average_precision = average_precision
|
||||
self.pos_label = pos_label
|
||||
|
||||
def plot(self, ax=None, *, name=None, **kwargs):
|
||||
"""Plot visualization.
|
||||
|
||||
Extra keyword arguments will be passed to matplotlib's `plot`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ax : Matplotlib Axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is
|
||||
created.
|
||||
|
||||
name : str, default=None
|
||||
Name of precision recall curve for labeling. If `None`, use
|
||||
`estimator_name` if not `None`, otherwise no labeling is shown.
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
|
||||
Object that stores computed values.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The average precision (cf. :func:`~sklearn.metrics.average_precision`)
|
||||
in scikit-learn is computed without any interpolation. To be consistent
|
||||
with this metric, the precision-recall curve is plotted without any
|
||||
interpolation as well (step-wise style).
|
||||
|
||||
You can change this style by passing the keyword argument
|
||||
`drawstyle="default"`. However, the curve will not be strictly
|
||||
consistent with the reported average precision.
|
||||
"""
|
||||
check_matplotlib_support("PrecisionRecallDisplay.plot")
|
||||
|
||||
name = self.estimator_name if name is None else name
|
||||
|
||||
line_kwargs = {"drawstyle": "steps-post"}
|
||||
if self.average_precision is not None and name is not None:
|
||||
line_kwargs["label"] = f"{name} (AP = {self.average_precision:0.2f})"
|
||||
elif self.average_precision is not None:
|
||||
line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
|
||||
elif name is not None:
|
||||
line_kwargs["label"] = name
|
||||
line_kwargs.update(**kwargs)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if ax is None:
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
(self.line_,) = ax.plot(self.recall, self.precision, **line_kwargs)
|
||||
info_pos_label = (
|
||||
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
|
||||
)
|
||||
|
||||
xlabel = "Recall" + info_pos_label
|
||||
ylabel = "Precision" + info_pos_label
|
||||
ax.set(xlabel=xlabel, ylabel=ylabel)
|
||||
|
||||
if "label" in line_kwargs:
|
||||
ax.legend(loc="lower left")
|
||||
|
||||
self.ax_ = ax
|
||||
self.figure_ = ax.figure
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_estimator(
|
||||
cls,
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
*,
|
||||
sample_weight=None,
|
||||
pos_label=None,
|
||||
response_method="auto",
|
||||
name=None,
|
||||
ax=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot precision-recall curve given an estimator and some data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : estimator instance
|
||||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
|
||||
in which the last estimator is a classifier.
|
||||
|
||||
X : {array-like, sparse matrix} of shape (n_samples, n_features)
|
||||
Input values.
|
||||
|
||||
y : array-like of shape (n_samples,)
|
||||
Target values.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
pos_label : str or int, default=None
|
||||
The class considered as the positive class when computing the
|
||||
precision and recall metrics. By default, `estimators.classes_[1]`
|
||||
is considered as the positive class.
|
||||
|
||||
response_method : {'predict_proba', 'decision_function', 'auto'}, \
|
||||
default='auto'
|
||||
Specifies whether to use :term:`predict_proba` or
|
||||
:term:`decision_function` as the target response. If set to 'auto',
|
||||
:term:`predict_proba` is tried first and if it does not exist
|
||||
:term:`decision_function` is tried next.
|
||||
|
||||
name : str, default=None
|
||||
Name for labeling curve. If `None`, no name is used.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is created.
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
|
||||
|
||||
See Also
|
||||
--------
|
||||
PrecisionRecallDisplay.from_predictions : Plot precision-recall curve
|
||||
using estimated probabilities or output of decision function.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The average precision (cf. :func:`~sklearn.metrics.average_precision`)
|
||||
in scikit-learn is computed without any interpolation. To be consistent
|
||||
with this metric, the precision-recall curve is plotted without any
|
||||
interpolation as well (step-wise style).
|
||||
|
||||
You can change this style by passing the keyword argument
|
||||
`drawstyle="default"`. However, the curve will not be strictly
|
||||
consistent with the reported average precision.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import PrecisionRecallDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.linear_model import LogisticRegression
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = LogisticRegression()
|
||||
>>> clf.fit(X_train, y_train)
|
||||
LogisticRegression()
|
||||
>>> PrecisionRecallDisplay.from_estimator(
|
||||
... clf, X_test, y_test)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
method_name = f"{cls.__name__}.from_estimator"
|
||||
check_matplotlib_support(method_name)
|
||||
if not is_classifier(estimator):
|
||||
raise ValueError(f"{method_name} only supports classifiers")
|
||||
y_pred, pos_label = _get_response(
|
||||
X,
|
||||
estimator,
|
||||
response_method,
|
||||
pos_label=pos_label,
|
||||
)
|
||||
|
||||
name = name if name is not None else estimator.__class__.__name__
|
||||
|
||||
return cls.from_predictions(
|
||||
y,
|
||||
y_pred,
|
||||
sample_weight=sample_weight,
|
||||
name=name,
|
||||
pos_label=pos_label,
|
||||
ax=ax,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_predictions(
|
||||
cls,
|
||||
y_true,
|
||||
y_pred,
|
||||
*,
|
||||
sample_weight=None,
|
||||
pos_label=None,
|
||||
name=None,
|
||||
ax=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Plot precision-recall curve given binary class predictions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_true : array-like of shape (n_samples,)
|
||||
True binary labels.
|
||||
|
||||
y_pred : array-like of shape (n_samples,)
|
||||
Estimated probabilities or output of decision function.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
pos_label : str or int, default=None
|
||||
The class considered as the positive class when computing the
|
||||
precision and recall metrics.
|
||||
|
||||
name : str, default=None
|
||||
Name for labeling curve. If `None`, name will be set to
|
||||
`"Classifier"`.
|
||||
|
||||
ax : matplotlib axes, default=None
|
||||
Axes object to plot on. If `None`, a new figure and axes is created.
|
||||
|
||||
**kwargs : dict
|
||||
Keyword arguments to be passed to matplotlib's `plot`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
|
||||
|
||||
See Also
|
||||
--------
|
||||
PrecisionRecallDisplay.from_estimator : Plot precision-recall curve
|
||||
using an estimator.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The average precision (cf. :func:`~sklearn.metrics.average_precision`)
|
||||
in scikit-learn is computed without any interpolation. To be consistent
|
||||
with this metric, the precision-recall curve is plotted without any
|
||||
interpolation as well (step-wise style).
|
||||
|
||||
You can change this style by passing the keyword argument
|
||||
`drawstyle="default"`. However, the curve will not be strictly
|
||||
consistent with the reported average precision.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> from sklearn.datasets import make_classification
|
||||
>>> from sklearn.metrics import PrecisionRecallDisplay
|
||||
>>> from sklearn.model_selection import train_test_split
|
||||
>>> from sklearn.linear_model import LogisticRegression
|
||||
>>> X, y = make_classification(random_state=0)
|
||||
>>> X_train, X_test, y_train, y_test = train_test_split(
|
||||
... X, y, random_state=0)
|
||||
>>> clf = LogisticRegression()
|
||||
>>> clf.fit(X_train, y_train)
|
||||
LogisticRegression()
|
||||
>>> y_pred = clf.predict_proba(X_test)[:, 1]
|
||||
>>> PrecisionRecallDisplay.from_predictions(
|
||||
... y_test, y_pred)
|
||||
<...>
|
||||
>>> plt.show()
|
||||
"""
|
||||
check_matplotlib_support(f"{cls.__name__}.from_predictions")
|
||||
|
||||
check_consistent_length(y_true, y_pred, sample_weight)
|
||||
pos_label = _check_pos_label_consistency(pos_label, y_true)
|
||||
|
||||
precision, recall, _ = precision_recall_curve(
|
||||
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
|
||||
)
|
||||
average_precision = average_precision_score(
|
||||
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
|
||||
)
|
||||
|
||||
name = name if name is not None else "Classifier"
|
||||
|
||||
viz = PrecisionRecallDisplay(
|
||||
precision=precision,
|
||||
recall=recall,
|
||||
average_precision=average_precision,
|
||||
estimator_name=name,
|
||||
pos_label=pos_label,
|
||||
)
|
||||
|
||||
return viz.plot(ax=ax, name=name, **kwargs)
|
||||
|
||||
|
||||
@deprecated(
    "Function `plot_precision_recall_curve` is deprecated in 1.0 and will be "
    "removed in 1.2. Use one of the class methods: "
    "PrecisionRecallDisplay.from_predictions or "
    "PrecisionRecallDisplay.from_estimator."
)
def plot_precision_recall_curve(
    estimator,
    X,
    y,
    *,
    sample_weight=None,
    response_method="auto",
    name=None,
    ax=None,
    pos_label=None,
    **kwargs,
):
    """Draw the precision-recall curve of a binary classifier.

    Any extra keyword argument is forwarded to matplotlib's `plot`.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    .. deprecated:: 1.0
       `plot_precision_recall_curve` is deprecated in 1.0 and will be removed in
       1.2. Use one of the following class methods:
       :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` or
       :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Binary target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    response_method : {'predict_proba', 'decision_function', 'auto'}, \
                      default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name for labeling curve. If `None`, the name of the
        estimator is used.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    pos_label : str or int, default=None
        The class considered as the positive class when computing the precision
        and recall metrics. By default, `estimator.classes_[1]` is considered
        as the positive class.

        .. versionadded:: 0.24

    **kwargs : dict
        Keyword arguments to be passed to matplotlib's `plot`.

    Returns
    -------
    display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
        Object that stores computed values.

    See Also
    --------
    precision_recall_curve : Compute precision-recall pairs for different
        probability thresholds.
    PrecisionRecallDisplay : Precision Recall visualization.
    """
    check_matplotlib_support("plot_precision_recall_curve")

    # `_get_response` hands back the scores together with the positive label
    # to use from here on.
    scores, pos_label = _get_response(
        X, estimator, response_method, pos_label=pos_label
    )

    # Both metrics are computed with the same positive label and weights.
    metric_kwargs = {"pos_label": pos_label, "sample_weight": sample_weight}
    precision, recall, _ = precision_recall_curve(y, scores, **metric_kwargs)
    average_precision = average_precision_score(y, scores, **metric_kwargs)

    if name is None:
        name = estimator.__class__.__name__

    display = PrecisionRecallDisplay(
        precision=precision,
        recall=recall,
        average_precision=average_precision,
        estimator_name=name,
        pos_label=pos_label,
    )
    return display.plot(ax=ax, name=name, **kwargs)
|
||||
@@ -0,0 +1,470 @@
|
||||
from .base import _get_response
|
||||
|
||||
from .. import auc
|
||||
from .. import roc_curve
|
||||
from .._base import _check_pos_label_consistency
|
||||
|
||||
from ...utils import check_matplotlib_support, deprecated
|
||||
|
||||
|
||||
class RocCurveDisplay:
    """ROC Curve visualization.

    It is recommended to use
    :func:`~sklearn.metrics.RocCurveDisplay.from_estimator` or
    :func:`~sklearn.metrics.RocCurveDisplay.from_predictions` to create
    a :class:`~sklearn.metrics.RocCurveDisplay`. All parameters are
    stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    fpr : ndarray
        False positive rate.

    tpr : ndarray
        True positive rate.

    roc_auc : float, default=None
        Area under ROC curve. If None, the roc_auc score is not shown.

    estimator_name : str, default=None
        Name of estimator. If None, the estimator name is not shown.

    pos_label : str or int, default=None
        The class considered as the positive class when computing the roc auc
        metrics. If None, the positive class is not mentioned in the axis
        labels.

        .. versionadded:: 0.24

    Attributes
    ----------
    line_ : matplotlib Artist
        ROC Curve.

    ax_ : matplotlib Axes
        Axes with ROC Curve.

    figure_ : matplotlib Figure
        Figure containing the curve.

    See Also
    --------
    roc_curve : Compute Receiver operating characteristic (ROC) curve.
    RocCurveDisplay.from_estimator : Plot Receiver Operating Characteristic
        (ROC) curve given an estimator and some data.
    RocCurveDisplay.from_predictions : Plot Receiver Operating Characteristic
        (ROC) curve given the true and predicted values.
    roc_auc_score : Compute the area under the ROC curve.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> import numpy as np
    >>> from sklearn import metrics
    >>> y = np.array([0, 0, 1, 1])
    >>> pred = np.array([0.1, 0.4, 0.35, 0.8])
    >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred)
    >>> roc_auc = metrics.auc(fpr, tpr)
    >>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
    ...                                   estimator_name='example estimator')
    >>> display.plot()
    <...>
    >>> plt.show()
    """

    def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=None):
        self.estimator_name = estimator_name
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.pos_label = pos_label

    def plot(self, ax=None, *, name=None, **kwargs):
        """Plot visualization.

        Extra keyword arguments will be passed to matplotlib's ``plot``.

        Parameters
        ----------
        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of ROC Curve for labeling. If `None`, use `estimator_name` if
            not `None`, otherwise no labeling is shown.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        display : :class:`~sklearn.metrics.RocCurveDisplay`
            Object that stores computed values.
        """
        check_matplotlib_support("RocCurveDisplay.plot")

        name = self.estimator_name if name is None else name

        # Build the default legend label; user kwargs applied below may
        # override it.
        line_kwargs = {}
        if self.roc_auc is not None and name is not None:
            line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})"
        elif self.roc_auc is not None:
            line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}"
        elif name is not None:
            line_kwargs["label"] = name

        line_kwargs.update(**kwargs)

        # Imported lazily so that matplotlib stays an optional dependency.
        import matplotlib.pyplot as plt

        if ax is None:
            _, ax = plt.subplots()

        (self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs)
        info_pos_label = (
            f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
        )

        xlabel = "False Positive Rate" + info_pos_label
        ylabel = "True Positive Rate" + info_pos_label
        ax.set(xlabel=xlabel, ylabel=ylabel)

        # Only draw a legend when there is something to show in it.
        if "label" in line_kwargs:
            ax.legend(loc="lower right")

        self.ax_ = ax
        self.figure_ = ax.figure
        return self

    @classmethod
    def from_estimator(
        cls,
        estimator,
        X,
        y,
        *,
        sample_weight=None,
        drop_intermediate=True,
        response_method="auto",
        pos_label=None,
        name=None,
        ax=None,
        **kwargs,
    ):
        """Create a ROC Curve display from an estimator.

        Parameters
        ----------
        estimator : estimator instance
            Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
            in which the last estimator is a classifier.

        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input values.

        y : array-like of shape (n_samples,)
            Target values.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        drop_intermediate : bool, default=True
            Whether to drop some suboptimal thresholds which would not appear
            on a plotted ROC curve. This is useful in order to create lighter
            ROC curves.

        response_method : {'predict_proba', 'decision_function', 'auto'}, \
                default='auto'
            Specifies whether to use :term:`predict_proba` or
            :term:`decision_function` as the target response. If set to 'auto',
            :term:`predict_proba` is tried first and if it does not exist
            :term:`decision_function` is tried next.

        pos_label : str or int, default=None
            The class considered as the positive class when computing the roc auc
            metrics. By default, `estimator.classes_[1]` is considered
            as the positive class.

        name : str, default=None
            Name of ROC Curve for labeling. If `None`, use the name of the
            estimator.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is created.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        display : :class:`~sklearn.metrics.RocCurveDisplay`
            The ROC Curve display.

        See Also
        --------
        roc_curve : Compute Receiver operating characteristic (ROC) curve.
        RocCurveDisplay.from_predictions : ROC Curve visualization given the
            probabilities or scores of a classifier.
        roc_auc_score : Compute the area under the ROC curve.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import RocCurveDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.svm import SVC
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...     X, y, random_state=0)
        >>> clf = SVC(random_state=0).fit(X_train, y_train)
        >>> RocCurveDisplay.from_estimator(
        ...    clf, X_test, y_test)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_estimator")

        name = estimator.__class__.__name__ if name is None else name

        y_pred, pos_label = _get_response(
            X,
            estimator,
            response_method=response_method,
            pos_label=pos_label,
        )

        return cls.from_predictions(
            y_true=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            drop_intermediate=drop_intermediate,
            name=name,
            ax=ax,
            pos_label=pos_label,
            **kwargs,
        )

    @classmethod
    def from_predictions(
        cls,
        y_true,
        y_pred,
        *,
        sample_weight=None,
        drop_intermediate=True,
        pos_label=None,
        name=None,
        ax=None,
        **kwargs,
    ):
        """Plot ROC curve given the true and predicted values.

        Read more in the :ref:`User Guide <visualizations>`.

        .. versionadded:: 1.0

        Parameters
        ----------
        y_true : array-like of shape (n_samples,)
            True labels.

        y_pred : array-like of shape (n_samples,)
            Target scores, can either be probability estimates of the positive
            class, confidence values, or non-thresholded measure of decisions
            (as returned by "decision_function" on some classifiers).

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        drop_intermediate : bool, default=True
            Whether to drop some suboptimal thresholds which would not appear
            on a plotted ROC curve. This is useful in order to create lighter
            ROC curves.

        pos_label : str or int, default=None
            The label of the positive class. When `pos_label=None`, if `y_true`
            is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
            error will be raised.

        name : str, default=None
            Name of ROC curve for labeling. If `None`, name will be set to
            `"Classifier"`.

        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        **kwargs : dict
            Additional keywords arguments passed to matplotlib `plot` function.

        Returns
        -------
        display : :class:`~sklearn.metrics.RocCurveDisplay`
            Object that stores computed values.

        See Also
        --------
        roc_curve : Compute Receiver operating characteristic (ROC) curve.
        RocCurveDisplay.from_estimator : ROC Curve visualization given an
            estimator and some data.
        roc_auc_score : Compute the area under the ROC curve.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> from sklearn.datasets import make_classification
        >>> from sklearn.metrics import RocCurveDisplay
        >>> from sklearn.model_selection import train_test_split
        >>> from sklearn.svm import SVC
        >>> X, y = make_classification(random_state=0)
        >>> X_train, X_test, y_train, y_test = train_test_split(
        ...     X, y, random_state=0)
        >>> clf = SVC(random_state=0).fit(X_train, y_train)
        >>> y_pred = clf.decision_function(X_test)
        >>> RocCurveDisplay.from_predictions(
        ...    y_test, y_pred)
        <...>
        >>> plt.show()
        """
        check_matplotlib_support(f"{cls.__name__}.from_predictions")

        fpr, tpr, _ = roc_curve(
            y_true,
            y_pred,
            pos_label=pos_label,
            sample_weight=sample_weight,
            drop_intermediate=drop_intermediate,
        )
        roc_auc = auc(fpr, tpr)

        name = "Classifier" if name is None else name
        pos_label = _check_pos_label_consistency(pos_label, y_true)

        # Instantiate through `cls` so that subclasses get an instance of
        # their own type back instead of the base display.
        viz = cls(
            fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label
        )

        return viz.plot(ax=ax, name=name, **kwargs)
|
||||
|
||||
|
||||
@deprecated(
    "Function :func:`plot_roc_curve` is deprecated in 1.0 and will be "
    "removed in 1.2. Use one of the class methods: "
    ":meth:`sklearn.metrics.RocCurveDisplay.from_predictions` or "
    ":meth:`sklearn.metrics.RocCurveDisplay.from_estimator`."
)
def plot_roc_curve(
    estimator,
    X,
    y,
    *,
    sample_weight=None,
    drop_intermediate=True,
    response_method="auto",
    name=None,
    ax=None,
    pos_label=None,
    **kwargs,
):
    """Plot Receiver operating characteristic (ROC) curve.

    Extra keyword arguments will be passed to matplotlib's `plot`.

    Read more in the :ref:`User Guide <visualizations>`.

    .. deprecated:: 1.0
       `plot_roc_curve` is deprecated in 1.0 and will be removed in
       1.2. Use one of the following class methods:
       :meth:`sklearn.metrics.RocCurveDisplay.from_predictions` or
       :meth:`sklearn.metrics.RocCurveDisplay.from_estimator`.

    Parameters
    ----------
    estimator : estimator instance
        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
        in which the last estimator is a classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    drop_intermediate : bool, default=True
        Whether to drop some suboptimal thresholds which would not appear
        on a plotted ROC curve. This is useful in order to create lighter
        ROC curves.

    response_method : {'predict_proba', 'decision_function', 'auto'}, \
            default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name of ROC Curve for labeling. If `None`, use the name of the
        estimator.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    pos_label : str or int, default=None
        The class considered as the positive class when computing the roc auc
        metrics. By default, `estimator.classes_[1]` is considered
        as the positive class.

        .. versionadded:: 0.24

    **kwargs : dict
        Additional keywords arguments passed to matplotlib `plot` function.

    Returns
    -------
    display : :class:`~sklearn.metrics.RocCurveDisplay`
        Object that stores computed values.

    See Also
    --------
    roc_curve : Compute Receiver operating characteristic (ROC) curve.
    RocCurveDisplay.from_estimator : ROC Curve visualization given an estimator
        and some data.
    RocCurveDisplay.from_predictions : ROC Curve visualisation given the
        true and predicted values.
    roc_auc_score : Compute the area under the ROC curve.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn import datasets, metrics, model_selection, svm
    >>> X, y = datasets.make_classification(random_state=0)
    >>> X_train, X_test, y_train, y_test = model_selection.train_test_split(
    ...     X, y, random_state=0)
    >>> clf = svm.SVC(random_state=0)
    >>> clf.fit(X_train, y_train)
    SVC(random_state=0)
    >>> metrics.plot_roc_curve(clf, X_test, y_test) # doctest: +SKIP
    <...>
    >>> plt.show()
    """
    check_matplotlib_support("plot_roc_curve")

    # `_get_response` returns the scores together with the positive label
    # to use for the rest of the computation.
    y_pred, pos_label = _get_response(
        X, estimator, response_method, pos_label=pos_label
    )

    fpr, tpr, _ = roc_curve(
        y,
        y_pred,
        pos_label=pos_label,
        sample_weight=sample_weight,
        drop_intermediate=drop_intermediate,
    )
    roc_auc = auc(fpr, tpr)

    name = estimator.__class__.__name__ if name is None else name

    viz = RocCurveDisplay(
        fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label
    )

    return viz.plot(ax=ax, name=name, **kwargs)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||||
|
||||
from sklearn.metrics._plot.base import _get_response
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "estimator, err_msg, params",
    [
        (
            DecisionTreeRegressor(),
            "Expected 'estimator' to be a binary classifier",
            dict(response_method="auto"),
        ),
        (
            DecisionTreeClassifier(),
            "The class provided by 'pos_label' is unknown.",
            dict(response_method="auto", pos_label="unknown"),
        ),
        (
            DecisionTreeClassifier(),
            "fit on multiclass",
            dict(response_method="predict_proba"),
        ),
    ],
)
def test_get_response_error(estimator, err_msg, params):
    """Ensure `_get_response` raises a `ValueError` with the expected message.

    Each case fits `estimator` on the full iris dataset and then calls
    `_get_response` with `params`; the call must fail with a message
    matching `err_msg`.
    """
    features, target = load_iris(return_X_y=True)
    estimator.fit(features, target)

    with pytest.raises(ValueError, match=err_msg):
        _get_response(features, estimator, **params)
|
||||
|
||||
|
||||
def test_get_response_predict_proba():
    """Verify what `_get_response` returns when asked for `predict_proba`.

    With no `pos_label` the second probability column and label 1 are
    expected; with `pos_label=0` the first column and label 0 are expected.
    """
    X, y = load_iris(return_X_y=True)
    # Only the first 100 samples are used as the binary problem.
    X_two, y_two = X[:100], y[:100]
    tree = DecisionTreeClassifier().fit(X_two, y_two)

    # (extra keyword arguments, expected probability column, expected label)
    cases = [
        ({}, 1, 1),
        ({"pos_label": 0}, 0, 0),
    ]
    for extra, column, expected_label in cases:
        scores, label = _get_response(
            X_two, tree, response_method="predict_proba", **extra
        )
        np.testing.assert_allclose(scores, tree.predict_proba(X_two)[:, column])
        assert label == expected_label
|
||||
|
||||
|
||||
def test_get_response_decision_function():
    """Verify what `_get_response` returns when asked for `decision_function`.

    With no `pos_label` the raw decision values and label 1 are expected;
    with `pos_label=0` the negated decision values and label 0 are expected.
    """
    X, y = load_iris(return_X_y=True)
    # Only the first 100 samples are used as the binary problem.
    X_two, y_two = X[:100], y[:100]
    logreg = LogisticRegression().fit(X_two, y_two)

    scores, label = _get_response(
        X_two, logreg, response_method="decision_function"
    )
    np.testing.assert_allclose(scores, logreg.decision_function(X_two))
    assert label == 1

    flipped_scores, flipped_label = _get_response(
        X_two, logreg, response_method="decision_function", pos_label=0
    )
    np.testing.assert_allclose(flipped_scores, logreg.decision_function(X_two) * -1)
    assert flipped_label == 0
|
||||
@@ -0,0 +1,152 @@
|
||||
import pytest
|
||||
|
||||
from sklearn.base import ClassifierMixin, clone
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
|
||||
from sklearn.metrics import (
|
||||
DetCurveDisplay,
|
||||
PrecisionRecallDisplay,
|
||||
RocCurveDisplay,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data():
    """Module-scoped fixture providing the iris `(X, y)` pair."""
    X, y = load_iris(return_X_y=True)
    return X, y
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data_binary(data):
    """Module-scoped fixture keeping only the samples of `data` with `y < 2`."""
    X, y = data
    keep = y < 2
    return X[keep], y[keep]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_error_non_binary(pyplot, data, Display):
    """Each curve display must reject a classifier fitted on the full `data`
    fixture, since only binary classification is supported."""
    features, target = data
    tree = DecisionTreeClassifier().fit(features, target)

    expected_error = (
        "Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier"
    )
    with pytest.raises(ValueError, match=expected_error):
        Display.from_estimator(tree, features, target)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "response_method, msg",
    [
        (
            "predict_proba",
            "response method predict_proba is not defined in MyClassifier",
        ),
        (
            "decision_function",
            "response method decision_function is not defined in MyClassifier",
        ),
        (
            "auto",
            "response method decision_function or predict_proba is not "
            "defined in MyClassifier",
        ),
        (
            "bad_method",
            "response_method must be 'predict_proba', 'decision_function' or 'auto'",
        ),
    ],
)
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_error_no_response(
    pyplot,
    data_binary,
    response_method,
    msg,
    Display,
):
    """Each curve display must raise a `ValueError` matching `msg` when the
    requested response method is unavailable on the fitted classifier."""

    class MyClassifier(ClassifierMixin):
        # Minimal classifier: it only records `classes_` and exposes neither
        # `predict_proba` nor `decision_function`.
        def fit(self, X, y):
            self.classes_ = [0, 1]
            return self

    features, target = data_binary
    fitted = MyClassifier().fit(features, target)

    with pytest.raises(ValueError, match=msg):
        Display.from_estimator(
            fitted, features, target, response_method=response_method
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_display_curve_estimator_name_multiple_calls(
    pyplot,
    data_binary,
    Display,
    constructor_name,
):
    """A `name` given to `plot` must replace the name given at construction
    in the legend label."""
    X, y = data_binary
    first_name = "my hand-crafted name"
    model = LogisticRegression().fit(X, y)
    positive_proba = model.predict_proba(X)[:, 1]

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")

    if constructor_name == "from_predictions":
        disp = Display.from_predictions(y, positive_proba, name=first_name)
    else:
        disp = Display.from_estimator(model, X, y, name=first_name)
    assert disp.estimator_name == first_name

    # Re-plot without a name: the construction-time name is still used.
    pyplot.close("all")
    disp.plot()
    assert first_name in disp.line_.get_label()

    # Re-plot with an explicit name: the new name shows up in the label.
    pyplot.close("all")
    second_name = "another_name"
    disp.plot(name=second_name)
    assert second_name in disp.line_.get_label()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "clf",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
@pytest.mark.parametrize(
    "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay]
)
def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display):
    """`from_estimator` must raise `NotFittedError` on an unfitted classifier
    and succeed once the classifier is fitted."""
    features, target = data_binary

    # The parametrized `clf` is reused for every `Display`, so work on a
    # clone; otherwise it would already be fitted from the second run onwards.
    estimator = clone(clf)
    with pytest.raises(NotFittedError):
        Display.from_estimator(estimator, features, target)

    estimator.fit(features, target)
    disp = Display.from_estimator(estimator, features, target)

    class_name = estimator.__class__.__name__
    assert class_name in disp.line_.get_label()
    assert disp.estimator_name == class_name
|
||||
@@ -0,0 +1,379 @@
|
||||
from numpy.testing import (
|
||||
assert_allclose,
|
||||
assert_array_equal,
|
||||
)
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVC, SVR
|
||||
|
||||
from sklearn.metrics import ConfusionMatrixDisplay
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
|
||||
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
# Module-wide filter for the 'np.bool_' DeprecationWarning attributed to
# matplotlib modules.
pytestmark = pytest.mark.filterwarnings(
    "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
    "matplotlib.*"
)
|
||||
|
||||
|
||||
def test_confusion_matrix_display_validation(pyplot):
    """Exercise the input validation of both `ConfusionMatrixDisplay`
    class-method constructors."""
    X, y = make_classification(
        n_samples=100, n_informative=5, n_classes=5, random_state=0
    )

    # An unfitted classifier is rejected.
    with pytest.raises(NotFittedError):
        ConfusionMatrixDisplay.from_estimator(SVC(), X, y)

    svr = SVR().fit(X, y)
    regression_pred = svr.predict(X)
    classification_pred = SVC().fit(X, y).predict(X)

    # A regressor is rejected by `from_estimator`.
    with pytest.raises(
        ValueError,
        match="ConfusionMatrixDisplay.from_estimator only supports classifiers",
    ):
        ConfusionMatrixDisplay.from_estimator(svr, X, y)

    # Mixed target types are rejected by `from_predictions`. In the first
    # pair, `y + 0.5` forces `y_true` to be seen as a regression problem.
    mixed_pairs = [
        (y + 0.5, classification_pred),
        (y, regression_pred),
    ]
    for y_true, y_hat in mixed_pairs:
        with pytest.raises(ValueError, match="Mix type of y not allowed, got types"):
            ConfusionMatrixDisplay.from_predictions(y_true, y_hat)

    # Inputs of different lengths are rejected.
    with pytest.raises(
        ValueError,
        match="Found input variables with inconsistent numbers of samples",
    ):
        ConfusionMatrixDisplay.from_predictions(y, classification_pred[::2])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_confusion_matrix_display_invalid_option(pyplot, constructor_name):
    """An unsupported `normalize` value must raise a `ValueError`."""
    X, y = make_classification(
        n_samples=100, n_informative=5, n_classes=5, random_state=0
    )
    svc = SVC().fit(X, y)
    predictions = svc.predict(X)

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")
    bad_options = dict(normalize="invalid")

    expected_error = r"normalize must be one of \{'true', 'pred', 'all', None\}"
    with pytest.raises(ValueError, match=expected_error):
        if constructor_name == "from_predictions":
            ConfusionMatrixDisplay.from_predictions(y, predictions, **bad_options)
        else:
            ConfusionMatrixDisplay.from_estimator(svc, X, y, **bad_options)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
def test_confusion_matrix_display_custom_labels(
    pyplot, constructor_name, with_labels, with_display_labels
):
    """Check the tick labels and matrix when `labels` and/or
    `display_labels` are supplied."""
    n_classes = 5
    X, y = make_classification(
        n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
    )
    svc = SVC().fit(X, y)
    predictions = svc.predict(X)

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")

    ax = pyplot.gca()
    labels = [2, 1, 0, 3, 4] if with_labels else None
    display_labels = ["b", "d", "a", "e", "f"] if with_display_labels else None

    reference_cm = confusion_matrix(y, predictions, labels=labels)
    shared_kwargs = dict(ax=ax, display_labels=display_labels, labels=labels)
    if constructor_name == "from_predictions":
        disp = ConfusionMatrixDisplay.from_predictions(y, predictions, **shared_kwargs)
    else:
        disp = ConfusionMatrixDisplay.from_estimator(svc, X, y, **shared_kwargs)
    assert_allclose(disp.confusion_matrix, reference_cm)

    # `display_labels` takes precedence over `labels`; with neither given,
    # the class indices are expected.
    if display_labels is not None:
        expected = display_labels
    elif labels is not None:
        expected = labels
    else:
        expected = list(range(n_classes))
    expected_as_text = list(map(str, expected))

    xtick_text = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    ytick_text = [tick.get_text() for tick in disp.ax_.get_yticklabels()]

    assert_array_equal(disp.display_labels, expected)
    assert_array_equal(xtick_text, expected_as_text)
    assert_array_equal(ytick_text, expected_as_text)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("normalize", ["true", "pred", "all", None])
@pytest.mark.parametrize("include_values", [True, False])
def test_confusion_matrix_display_plotting(
    pyplot,
    constructor_name,
    normalize,
    include_values,
):
    """Check the rendered plot for every `normalize` / `include_values`
    combination."""
    import matplotlib as mpl

    n_classes = 5
    X, y = make_classification(
        n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
    )
    svc = SVC().fit(X, y)
    predictions = svc.predict(X)

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")

    ax = pyplot.gca()
    cmap = "plasma"

    cm = confusion_matrix(y, predictions)
    shared_kwargs = dict(
        normalize=normalize, cmap=cmap, ax=ax, include_values=include_values
    )
    if constructor_name == "from_predictions":
        disp = ConfusionMatrixDisplay.from_predictions(y, predictions, **shared_kwargs)
    else:
        disp = ConfusionMatrixDisplay.from_estimator(svc, X, y, **shared_kwargs)

    assert disp.ax_ == ax

    # Normalize the reference matrix the same way the display was asked to.
    if normalize == "all":
        cm = cm / cm.sum()
    elif normalize is not None:
        # "true" normalizes over rows, "pred" over columns.
        axis = 1 if normalize == "true" else 0
        cm = cm / cm.sum(axis=axis, keepdims=True)

    assert_allclose(disp.confusion_matrix, cm)

    assert isinstance(disp.im_, mpl.image.AxesImage)
    assert disp.im_.get_cmap().name == cmap
    assert isinstance(disp.ax_, pyplot.Axes)
    assert isinstance(disp.figure_, pyplot.Figure)

    assert disp.ax_.get_ylabel() == "True label"
    assert disp.ax_.get_xlabel() == "Predicted label"

    xtick_text = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    ytick_text = [tick.get_text() for tick in disp.ax_.get_yticklabels()]

    expected = list(range(n_classes))
    expected_as_text = list(map(str, expected))

    assert_array_equal(disp.display_labels, expected)
    assert_array_equal(xtick_text, expected_as_text)
    assert_array_equal(ytick_text, expected_as_text)

    assert_allclose(disp.im_.get_array().data, cm)

    if not include_values:
        assert disp.text_ is None
    else:
        assert disp.text_.shape == (n_classes, n_classes)
        fmt = ".2g"
        expected_text = np.array([format(v, fmt) for v in cm.ravel(order="C")])
        shown_text = np.array([t.get_text() for t in disp.text_.ravel(order="C")])
        assert_array_equal(expected_text, shown_text)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_confusion_matrix_display(pyplot, constructor_name):
    """Build a display via a class method, then check that repeated `plot`
    calls honour `cmap`, `include_values`, `xticks_rotation` and
    `values_format`."""
    n_classes = 5
    X, y = make_classification(
        n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
    )
    svc = SVC().fit(X, y)
    predictions = svc.predict(X)

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")

    cm = confusion_matrix(y, predictions)
    shared_kwargs = dict(
        normalize=None,
        include_values=True,
        cmap="viridis",
        xticks_rotation=45.0,
    )
    if constructor_name == "from_predictions":
        disp = ConfusionMatrixDisplay.from_predictions(y, predictions, **shared_kwargs)
    else:
        disp = ConfusionMatrixDisplay.from_estimator(svc, X, y, **shared_kwargs)

    assert_allclose(disp.confusion_matrix, cm)
    assert disp.text_.shape == (n_classes, n_classes)

    tick_angles = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
    assert_allclose(tick_angles, 45.0)

    assert_allclose(disp.im_.get_array().data, cm)

    # Colormap override.
    disp.plot(cmap="plasma")
    assert disp.im_.get_cmap().name == "plasma"

    # Cell text can be switched off.
    disp.plot(include_values=False)
    assert disp.text_ is None

    # Tick rotation override.
    disp.plot(xticks_rotation=90.0)
    tick_angles = [tick.get_rotation() for tick in disp.ax_.get_xticklabels()]
    assert_allclose(tick_angles, 90.0)

    # Custom value format.
    disp.plot(values_format="e")
    expected_text = np.array([format(v, "e") for v in cm.ravel(order="C")])
    shown_text = np.array([t.get_text() for t in disp.text_.ravel(order="C")])
    assert_array_equal(expected_text, shown_text)
|
||||
|
||||
|
||||
def test_confusion_matrix_contrast(pyplot):
    """Check that the text color is appropriate depending on background."""

    cm = np.eye(2) / 2
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

    disp.plot(cmap=pyplot.cm.gray)
    # diagonal text is black
    assert_allclose(disp.text_[0, 0].get_color(), [0.0, 0.0, 0.0, 1.0])
    assert_allclose(disp.text_[1, 1].get_color(), [0.0, 0.0, 0.0, 1.0])

    # off-diagonal text is white
    assert_allclose(disp.text_[0, 1].get_color(), [1.0, 1.0, 1.0, 1.0])
    assert_allclose(disp.text_[1, 0].get_color(), [1.0, 1.0, 1.0, 1.0])

    # With the reversed colormap the expected text colors swap.
    disp.plot(cmap=pyplot.cm.gray_r)
    # off-diagonal text is black
    assert_allclose(disp.text_[0, 1].get_color(), [0.0, 0.0, 0.0, 1.0])
    assert_allclose(disp.text_[1, 0].get_color(), [0.0, 0.0, 0.0, 1.0])

    # diagonal text is white
    assert_allclose(disp.text_[0, 0].get_color(), [1.0, 1.0, 1.0, 1.0])
    assert_allclose(disp.text_[1, 1].get_color(), [1.0, 1.0, 1.0, 1.0])

    # Regression test for #15920
    cm = np.array([[19, 34], [32, 58]])
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

    disp.plot(cmap=pyplot.cm.Blues)
    min_color = pyplot.cm.Blues(0)
    max_color = pyplot.cm.Blues(255)
    # Only the largest cell, [1, 1], is expected to carry the colormap's
    # minimum color as text; the other three carry its maximum color.
    assert_allclose(disp.text_[0, 0].get_color(), max_color)
    assert_allclose(disp.text_[0, 1].get_color(), max_color)
    assert_allclose(disp.text_[1, 0].get_color(), max_color)
    assert_allclose(disp.text_[1, 1].get_color(), min_color)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "clf",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])),
            LogisticRegression(),
        ),
    ],
    ids=["clf", "pipeline-clf", "pipeline-column_transformer-clf"],
)
def test_confusion_matrix_pipeline(pyplot, clf):
    """Check the behaviour of the plotting with more complex pipeline."""
    from sklearn.base import clone

    n_classes = 5
    X, y = make_classification(
        n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
    )

    # The parametrized instances are created once at collection time. Fit a
    # clone so the shared object stays unfitted and the `NotFittedError`
    # check below still holds if the test runs again in the same process.
    model = clone(clf)
    with pytest.raises(NotFittedError):
        ConfusionMatrixDisplay.from_estimator(model, X, y)
    model.fit(X, y)
    y_pred = model.predict(X)

    disp = ConfusionMatrixDisplay.from_estimator(model, X, y)
    cm = confusion_matrix(y, y_pred)

    assert_allclose(disp.confusion_matrix, cm)
    assert disp.text_.shape == (n_classes, n_classes)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_confusion_matrix_with_unknown_labels(pyplot, constructor_name):
    """With `labels=None`, the displayed labels must come from the unique
    values found in `y_true` and `y_pred`.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/pull/18405
    """
    n_classes = 5
    X, y = make_classification(
        n_samples=100, n_informative=5, n_classes=n_classes, random_state=0
    )
    svc = SVC().fit(X, y)
    predictions = svc.predict(X)
    # create unseen labels in `y_true` not seen during fitting and not present
    # in 'classifier.classes_'
    y = y + 1

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")

    shared_kwargs = dict(labels=None)
    if constructor_name == "from_predictions":
        disp = ConfusionMatrixDisplay.from_predictions(y, predictions, **shared_kwargs)
    else:
        disp = ConfusionMatrixDisplay.from_estimator(svc, X, y, **shared_kwargs)

    shown_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    expected_labels = list(map(str, range(n_classes + 1)))
    assert_array_equal(expected_labels, shown_labels)
|
||||
|
||||
|
||||
def test_colormap_max(pyplot):
    """Check that the max color is used for the color of the text."""
    # `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed
    # in 3.9; `pyplot.get_cmap` accepts the same `(name, lut)` arguments.
    gray = pyplot.get_cmap("gray", 1024)
    # Named `cm_values` so it does not shadow the imported `confusion_matrix`.
    cm_values = np.array([[1.0, 0.0], [0.0, 1.0]])

    disp = ConfusionMatrixDisplay(cm_values)
    disp.plot(cmap=gray)

    # The text of an off-diagonal cell is expected to be pure white.
    color = disp.text_[1, 0].get_color()
    assert_allclose(color, [1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
|
||||
def test_im_kw_adjust_vmin_vmax(pyplot):
    """The `vmin` / `vmax` given through `im_kw` must become the image's
    color limits."""
    cm_values = np.array([[0.48, 0.04], [0.08, 0.4]])
    disp = ConfusionMatrixDisplay(cm_values)
    disp.plot(im_kw={"vmin": 0.0, "vmax": 0.8})

    lower, upper = disp.im_.get_clim()
    assert lower == pytest.approx(0.0)
    assert upper == pytest.approx(0.8)
|
||||
@@ -0,0 +1,108 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
from sklearn.metrics import det_curve
|
||||
from sklearn.metrics import DetCurveDisplay
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
def test_det_curve_display(
    pyplot, constructor_name, response_method, with_sample_weight, with_strings
):
    """Check the values, artists and axis labels produced by `DetCurveDisplay`.

    The display is built with either class method and compared against a
    direct `det_curve` call, with and without sample weights and string
    labels.
    """
    X, y = load_iris(return_X_y=True)
    # Binarize the data with only the two first classes
    X, y = X[y < 2], y[y < 2]

    pos_label = None
    if with_strings:
        y = np.array(["c", "b"])[y]
        pos_label = "c"

    if with_sample_weight:
        rng = np.random.RandomState(42)
        sample_weight = rng.randint(1, 4, size=(X.shape[0]))
    else:
        sample_weight = None

    lr = LogisticRegression()
    lr.fit(X, y)
    y_pred = getattr(lr, response_method)(X)
    if y_pred.ndim == 2:
        # A 2-D response has one column per class; keep the second one.
        y_pred = y_pred[:, 1]

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")

    common_kwargs = {
        "name": lr.__class__.__name__,
        "alpha": 0.8,
        "sample_weight": sample_weight,
        "pos_label": pos_label,
    }
    if constructor_name == "from_estimator":
        disp = DetCurveDisplay.from_estimator(lr, X, y, **common_kwargs)
    else:
        disp = DetCurveDisplay.from_predictions(y, y_pred, **common_kwargs)

    fpr, fnr, _ = det_curve(
        y,
        y_pred,
        sample_weight=sample_weight,
        pos_label=pos_label,
    )

    assert_allclose(disp.fpr, fpr)
    assert_allclose(disp.fnr, fnr)

    assert disp.estimator_name == "LogisticRegression"

    # cannot fail thanks to pyplot fixture
    import matplotlib as mpl  # noqa

    assert isinstance(disp.line_, mpl.lines.Line2D)
    assert disp.line_.get_alpha() == 0.8
    assert isinstance(disp.ax_, mpl.axes.Axes)
    assert isinstance(disp.figure_, mpl.figure.Figure)
    assert disp.line_.get_label() == "LogisticRegression"

    # When no `pos_label` is given, the axis labels are expected to mention 1.
    expected_pos_label = 1 if pos_label is None else pos_label
    expected_ylabel = f"False Negative Rate (Positive label: {expected_pos_label})"
    expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"
    assert disp.ax_.get_ylabel() == expected_ylabel
    assert disp.ax_.get_xlabel() == expected_xlabel
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "constructor_name, expected_clf_name",
    [
        ("from_estimator", "LogisticRegression"),
        ("from_predictions", "Classifier"),
    ],
)
def test_det_curve_display_default_name(
    pyplot,
    constructor_name,
    expected_clf_name,
):
    """Check the default name shown in the figure when `name` is not provided."""
    X, y = load_iris(return_X_y=True)
    # Binarize the data with only the two first classes
    first_two = y < 2
    X, y = X[first_two], y[first_two]

    logreg = LogisticRegression().fit(X, y)
    positive_proba = logreg.predict_proba(X)[:, 1]

    if constructor_name != "from_estimator":
        disp = DetCurveDisplay.from_predictions(y, positive_proba)
    else:
        disp = DetCurveDisplay.from_estimator(logreg, X, y)

    assert disp.estimator_name == expected_clf_name
    assert disp.line_.get_label() == expected_clf_name
|
||||
@@ -0,0 +1,365 @@
|
||||
# TODO: remove this file when the deprecated plot_confusion_matrix is removed (1.2)
|
||||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
from numpy.testing import assert_array_equal
|
||||
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVC, SVR
|
||||
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from sklearn.metrics import plot_confusion_matrix
|
||||
from sklearn.metrics import ConfusionMatrixDisplay
|
||||
|
||||
|
||||
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
# Module-wide filter for the 'np.bool_' DeprecationWarning attributed to
# matplotlib modules.
pytestmark = pytest.mark.filterwarnings(
    "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
    "matplotlib.*"
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def n_classes():
    """Module-scoped fixture: number of classes used by the other fixtures."""
    class_count = 5
    return class_count
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data(n_classes):
    """Module-scoped fixture: a generated `(X, y)` classification problem
    with `n_classes` classes."""
    features, target = make_classification(
        n_samples=100,
        n_informative=5,
        n_classes=n_classes,
        random_state=0,
    )
    return features, target
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def fitted_clf(data):
    """Module-scoped fixture: a linear `SVC` (C=0.01) fitted on `data`."""
    X, y = data
    svc = SVC(kernel="linear", C=0.01)
    return svc.fit(X, y)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def y_pred(data, fitted_clf):
    """Module-scoped fixture: predictions of `fitted_clf` on the features
    of `data`."""
    features = data[0]
    return fitted_clf.predict(features)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
def test_error_on_regressor(pyplot, data):
    """`plot_confusion_matrix` must raise a `ValueError` for a regressor."""
    features, target = data
    regressor = SVR().fit(features, target)

    with pytest.raises(
        ValueError, match="plot_confusion_matrix only supports classifiers"
    ):
        plot_confusion_matrix(regressor, features, target)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
def test_error_on_invalid_option(pyplot, fitted_clf, data):
    """An unsupported `normalize` value must raise a `ValueError`."""
    features, target = data
    expected_error = r"normalize must be one of \{'true', 'pred', 'all', None\}"

    with pytest.raises(ValueError, match=expected_error):
        plot_confusion_matrix(fitted_clf, features, target, normalize="invalid")
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
@pytest.mark.parametrize("with_labels", [True, False])
@pytest.mark.parametrize("with_display_labels", [True, False])
def test_plot_confusion_matrix_custom_labels(
    pyplot, data, y_pred, fitted_clf, n_classes, with_labels, with_display_labels
):
    """Check the tick labels and matrix from `plot_confusion_matrix` when
    `labels` and/or `display_labels` are supplied."""
    X, y = data
    ax = pyplot.gca()
    labels = [2, 1, 0, 3, 4] if with_labels else None
    display_labels = ["b", "d", "a", "e", "f"] if with_display_labels else None

    reference_cm = confusion_matrix(y, y_pred, labels=labels)
    disp = plot_confusion_matrix(
        fitted_clf,
        X,
        y,
        ax=ax,
        display_labels=display_labels,
        labels=labels,
    )

    assert_allclose(disp.confusion_matrix, reference_cm)

    # `display_labels` takes precedence over `labels`; with neither given,
    # the class indices are expected.
    if display_labels is not None:
        expected = display_labels
    elif labels is not None:
        expected = labels
    else:
        expected = list(range(n_classes))
    expected_as_text = list(map(str, expected))

    xtick_text = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    ytick_text = [tick.get_text() for tick in disp.ax_.get_yticklabels()]

    assert_array_equal(disp.display_labels, expected)
    assert_array_equal(xtick_text, expected_as_text)
    assert_array_equal(ytick_text, expected_as_text)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
@pytest.mark.parametrize("normalize", ["true", "pred", "all", None])
@pytest.mark.parametrize("include_values", [True, False])
def test_plot_confusion_matrix(
    pyplot, data, y_pred, n_classes, fitted_clf, normalize, include_values
):
    """Check the plot rendered by `plot_confusion_matrix` for every
    `normalize` / `include_values` combination."""
    import matplotlib as mpl

    X, y = data
    ax = pyplot.gca()
    cmap = "plasma"
    cm = confusion_matrix(y, y_pred)
    disp = plot_confusion_matrix(
        fitted_clf,
        X,
        y,
        normalize=normalize,
        cmap=cmap,
        ax=ax,
        include_values=include_values,
    )

    assert disp.ax_ == ax

    # Normalize the reference matrix the same way the plot was asked to.
    if normalize == "all":
        cm = cm / cm.sum()
    elif normalize is not None:
        # "true" normalizes over rows, "pred" over columns.
        axis = 1 if normalize == "true" else 0
        cm = cm / cm.sum(axis=axis, keepdims=True)

    assert_allclose(disp.confusion_matrix, cm)

    assert isinstance(disp.im_, mpl.image.AxesImage)
    assert disp.im_.get_cmap().name == cmap
    assert isinstance(disp.ax_, pyplot.Axes)
    assert isinstance(disp.figure_, pyplot.Figure)

    assert disp.ax_.get_ylabel() == "True label"
    assert disp.ax_.get_xlabel() == "Predicted label"

    xtick_text = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    ytick_text = [tick.get_text() for tick in disp.ax_.get_yticklabels()]

    expected = list(range(n_classes))
    expected_as_text = list(map(str, expected))

    assert_array_equal(disp.display_labels, expected)
    assert_array_equal(xtick_text, expected_as_text)
    assert_array_equal(ytick_text, expected_as_text)

    assert_allclose(disp.im_.get_array().data, cm)

    if not include_values:
        assert disp.text_ is None
    else:
        assert disp.text_.shape == (n_classes, n_classes)
        fmt = ".2g"
        expected_text = np.array([format(v, fmt) for v in cm.ravel(order="C")])
        shown_text = np.array([t.get_text() for t in disp.text_.ravel(order="C")])
        assert_array_equal(expected_text, shown_text)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
def test_confusion_matrix_display(pyplot, data, fitted_clf, y_pred, n_classes):
    """Check the display returned by `plot_confusion_matrix` and its re-plotting.

    The display is first compared with `confusion_matrix(y, y_pred)`; then
    `disp.plot` is called again with one option changed at a time to check
    that each option is reflected in the new plot.
    """

    def xtick_rotations(display):
        # rotation of every x tick label currently drawn on the display axes
        return [label.get_rotation() for label in display.ax_.get_xticklabels()]

    X, y = data
    expected_cm = confusion_matrix(y, y_pred)

    disp = plot_confusion_matrix(
        fitted_clf,
        X,
        y,
        normalize=None,
        include_values=True,
        cmap="viridis",
        xticks_rotation=45.0,
    )

    # state right after the initial call
    assert_allclose(disp.confusion_matrix, expected_cm)
    assert disp.text_.shape == (n_classes, n_classes)
    assert_allclose(xtick_rotations(disp), 45.0)
    assert_allclose(disp.im_.get_array().data, expected_cm)

    # another colormap
    disp.plot(cmap="plasma")
    assert disp.im_.get_cmap().name == "plasma"

    # no text in the cells
    disp.plot(include_values=False)
    assert disp.text_ is None

    # another tick rotation
    disp.plot(xticks_rotation=90.0)
    assert_allclose(xtick_rotations(disp), 90.0)

    # another text format
    disp.plot(values_format="e")
    expected_text = np.array([format(v, "e") for v in expected_cm.ravel(order="C")])
    shown_text = np.array([t.get_text() for t in disp.text_.ravel(order="C")])
    assert_array_equal(expected_text, shown_text)
|
||||
|
||||
|
||||
def test_confusion_matrix_contrast(pyplot):
    """Check that the cell text color is chosen according to the background."""
    black = [0.0, 0.0, 0.0, 1.0]
    white = [1.0, 1.0, 1.0, 1.0]
    diagonal = [(0, 0), (1, 1)]
    off_diagonal = [(0, 1), (1, 0)]

    cm = np.eye(2) / 2
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

    disp.plot(cmap=pyplot.cm.gray)
    # gray colormap: diagonal text is black ...
    for row, col in diagonal:
        assert_allclose(disp.text_[row, col].get_color(), black)
    # ... and off-diagonal text is white
    for row, col in off_diagonal:
        assert_allclose(disp.text_[row, col].get_color(), white)

    disp.plot(cmap=pyplot.cm.gray_r)
    # reversed colormap: off-diagonal text is black ...
    for row, col in off_diagonal:
        assert_allclose(disp.text_[row, col].get_color(), black)
    # ... and diagonal text is white
    for row, col in diagonal:
        assert_allclose(disp.text_[row, col].get_color(), white)

    # Regression test for #15920
    cm = np.array([[19, 34], [32, 58]])
    disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])

    disp.plot(cmap=pyplot.cm.Blues)
    min_color = pyplot.cm.Blues(0)
    max_color = pyplot.cm.Blues(255)
    # only the bottom-right cell is expected to use the lowest colormap color
    expected_colors = {
        (0, 0): max_color,
        (0, 1): max_color,
        (1, 0): max_color,
        (1, 1): min_color,
    }
    for (row, col), color in expected_colors.items():
        assert_allclose(disp.text_[row, col].get_color(), color)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
@pytest.mark.parametrize(
    "clf",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
def test_confusion_matrix_pipeline(pyplot, clf, data, n_classes):
    """Check `plot_confusion_matrix` with plain estimators and pipelines.

    An unfitted estimator must raise `NotFittedError`; once fitted, the
    display must match `confusion_matrix` computed from the predictions.
    """
    from sklearn.base import clone

    X, y = data

    # The parametrized instances are created once, when the module is
    # collected. Fit a clone so that the shared instance stays unfitted and
    # the `NotFittedError` check still holds if this test runs again.
    model = clone(clf)

    with pytest.raises(NotFittedError):
        plot_confusion_matrix(model, X, y)

    model.fit(X, y)
    y_pred = model.predict(X)

    disp = plot_confusion_matrix(model, X, y)
    cm = confusion_matrix(y, y_pred)

    assert_allclose(disp.confusion_matrix, cm)
    assert disp.text_.shape == (n_classes, n_classes)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
@pytest.mark.parametrize("colorbar", [True, False])
def test_plot_confusion_matrix_colorbar(pyplot, data, fitted_clf, colorbar):
    """Check that the `colorbar` option adds or omits the image colorbar."""

    def assert_colorbar_state(display, expected):
        attached = display.im_.colorbar
        if not expected:
            assert attached is None
            return
        assert attached is not None
        assert attached.__class__.__name__ == "Colorbar"

    X, y = data

    disp = plot_confusion_matrix(fitted_clf, X, y, colorbar=colorbar)
    assert_colorbar_state(disp, colorbar)

    # plot again asking for the opposite setting
    flipped = not colorbar
    disp.plot(colorbar=flipped)
    assert_colorbar_state(disp, flipped)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
@pytest.mark.parametrize("values_format", ["e", "n"])
def test_confusion_matrix_text_format(
    pyplot, data, y_pred, n_classes, fitted_clf, values_format
):
    """Check that the cell text is rendered with `values_format`."""
    X, y = data
    expected_cm = confusion_matrix(y, y_pred)

    disp = plot_confusion_matrix(
        fitted_clf, X, y, include_values=True, values_format=values_format
    )

    assert disp.text_.shape == (n_classes, n_classes)

    # compare cell by cell, in flattened order
    wanted = np.array([format(value, values_format) for value in expected_cm.ravel()])
    shown = np.array([cell.get_text() for cell in disp.text_.ravel()])
    assert_array_equal(wanted, shown)
|
||||
|
||||
|
||||
def test_confusion_matrix_standard_format(pyplot):
    """Check the text shown in the cells when no `values_format` is given."""

    def cell_texts(matrix):
        # plot `matrix` with default options and collect the cell strings
        display = ConfusionMatrixDisplay(matrix, display_labels=[False, True])
        return [cell.get_text() for cell in display.plot().text_.ravel()]

    # Integer matrix: values are shown as whole numbers, except the first
    # value, which is shown as 1e+07, and the last one, which is shown as
    # 1.2e+07 (their whole-number forms are longer).
    cm = np.array([[10000000, 0], [123456, 12345678]])
    assert cell_texts(cm) == ["1e+07", "0", "123456", "1.2e+07"]

    # Matrix containing floats: values are shown with at most two significant
    # digits (e.g. 100 becomes 1e+02).
    cm = np.array([[0.1, 10], [100, 0.525]])
    assert cell_texts(cm) == ["0.1", "10", "1e+02", "0.53"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "display_labels, expected_labels",
    [
        (None, ["0", "1"]),
        (["cat", "dog"], ["cat", "dog"]),
    ],
)
def test_default_labels(pyplot, display_labels, expected_labels):
    """Check the tick labels with and without explicit `display_labels`."""
    cm = np.array([[10, 0], [12, 120]])
    disp = ConfusionMatrixDisplay(cm, display_labels=display_labels).plot()

    # both axes must carry the same labels
    for tick_labels in (disp.ax_.get_xticklabels(), disp.ax_.get_yticklabels()):
        assert_array_equal([tick.get_text() for tick in tick_labels], expected_labels)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_confusion_matrix is deprecated")
def test_error_on_a_dataset_with_unseen_labels(pyplot, fitted_clf, data, n_classes):
    """Check that the unique values of `y_true` and `y_pred` are used as labels
    when `labels=None`.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/pull/18405
    """
    X, y = data

    # shifting the targets by one introduces a label in `y_true` that was not
    # seen during fitting and is not part of `fitted_clf.classes_`
    y_shifted = y + 1
    disp = plot_confusion_matrix(fitted_clf, X, y_shifted)

    shown_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    wanted_labels = [str(label) for label in range(n_classes + 1)]
    assert_array_equal(wanted_labels, shown_labels)
|
||||
|
||||
|
||||
def test_plot_confusion_matrix_deprecation_warning(pyplot, fitted_clf, data):
    """Check that calling `plot_confusion_matrix` raises a `FutureWarning`."""
    X, y = data
    with pytest.warns(FutureWarning):
        plot_confusion_matrix(fitted_clf, X, y)
|
||||
@@ -0,0 +1,131 @@
|
||||
import pytest
|
||||
|
||||
from sklearn.base import ClassifierMixin
|
||||
from sklearn.base import clone
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
|
||||
from sklearn.metrics import plot_det_curve
|
||||
from sklearn.metrics import plot_roc_curve
|
||||
|
||||
# Ignore the `plot_roc_curve` deprecation message in every test of this module.
pytestmark = pytest.mark.filterwarnings("ignore:Function plot_roc_curve is deprecated")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data():
    """Return the iris dataset as an ``(X, y)`` tuple."""
    X, y = load_iris(return_X_y=True)
    return X, y
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data_binary(data):
    """Return the samples of `data` whose target is lower than 2."""
    X, y = data
    keep = y < 2
    return X[keep], y[keep]
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated")
@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve])
def test_plot_curve_error_non_binary(pyplot, data, plot_func):
    """Check that a classifier fitted on non-binary targets is rejected."""
    X, y = data
    tree = DecisionTreeClassifier()
    tree.fit(X, y)

    expected_error = (
        "Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier"
    )
    with pytest.raises(ValueError, match=expected_error):
        plot_func(tree, X, y)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated")
@pytest.mark.parametrize(
    "response_method, msg",
    [
        (
            "predict_proba",
            "response method predict_proba is not defined in MyClassifier",
        ),
        (
            "decision_function",
            "response method decision_function is not defined in MyClassifier",
        ),
        (
            "auto",
            "response method decision_function or predict_proba is not "
            "defined in MyClassifier",
        ),
        (
            "bad_method",
            "response_method must be 'predict_proba', 'decision_function' or 'auto'",
        ),
    ],
)
@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve])
def test_plot_curve_error_no_response(
    pyplot,
    data_binary,
    response_method,
    msg,
    plot_func,
):
    """Check the error raised when the requested response method is missing
    from the classifier or is not a supported value."""

    class MyClassifier(ClassifierMixin):
        # minimal classifier defining neither `predict_proba` nor
        # `decision_function`
        def fit(self, X, y):
            self.classes_ = [0, 1]
            return self

    X, y = data_binary
    dummy = MyClassifier()
    dummy.fit(X, y)

    with pytest.raises(ValueError, match=msg):
        plot_func(dummy, X, y, response_method=response_method)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated")
@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve])
def test_plot_curve_estimator_name_multiple_calls(pyplot, data_binary, plot_func):
    """Non-regression test: the `name` given to `plot_func` must also be used
    by later `disp.plot()` calls, unless a new name is passed."""
    X, y = data_binary
    first_name = "my hand-crafted name"
    model = LogisticRegression().fit(X, y)

    disp = plot_func(model, X, y, name=first_name)
    assert disp.estimator_name == first_name

    # plotting again without a name keeps the original one
    pyplot.close("all")
    disp.plot()
    assert first_name in disp.line_.get_label()

    # an explicit name takes over
    pyplot.close("all")
    second_name = "another_name"
    disp.plot(name=second_name)
    assert second_name in disp.line_.get_label()
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated")
@pytest.mark.parametrize(
    "clf",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
@pytest.mark.parametrize("plot_func", [plot_det_curve, plot_roc_curve])
def test_plot_det_curve_not_fitted_errors(pyplot, data_binary, clf, plot_func):
    """Check that an unfitted estimator raises `NotFittedError` and that the
    class name is used as display name once the estimator is fitted."""
    X, y = data_binary

    # The same `clf` instance is shared by every plotting function of the
    # parametrization, so fit a clone to keep the original unfitted.
    model = clone(clf)
    model_name = model.__class__.__name__

    with pytest.raises(NotFittedError):
        plot_func(model, X, y)

    model.fit(X, y)
    disp = plot_func(model, X, y)

    assert model_name in disp.line_.get_label()
    assert disp.estimator_name == model_name
|
||||
@@ -0,0 +1,84 @@
|
||||
# TODO: remove this file when plot_det_curve is removed in 1.2
|
||||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
||||
from sklearn.metrics import det_curve
|
||||
from sklearn.metrics import plot_det_curve
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data():
    """Return the iris dataset as an ``(X, y)`` tuple."""
    X, y = load_iris(return_X_y=True)
    return X, y
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data_binary(data):
    """Return the samples of `data` whose target is lower than 2."""
    X, y = data
    keep = y < 2
    return X[keep], y[keep]
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore: Function plot_det_curve is deprecated")
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
def test_plot_det_curve(
    pyplot, response_method, data_binary, with_sample_weight, with_strings
):
    """Check `plot_det_curve` against `det_curve` computed by hand.

    The test is run for both response methods, with and without sample
    weights, and with integer or string targets.
    """
    X, y = data_binary

    pos_label = None
    if with_strings:
        # map target 0 to "c" and target 1 to "b"
        y = np.array(["c", "b"])[y]
        pos_label = "c"

    if with_sample_weight:
        rng = np.random.RandomState(42)
        sample_weight = rng.randint(1, 4, size=(X.shape[0]))
    else:
        sample_weight = None

    lr = LogisticRegression()
    lr.fit(X, y)

    viz = plot_det_curve(
        lr,
        X,
        y,
        alpha=0.8,
        sample_weight=sample_weight,
        # forward the parametrized method so that the display is built from
        # the same scores as the reference values computed below
        response_method=response_method,
    )

    y_pred = getattr(lr, response_method)(X)
    if y_pred.ndim == 2:
        # 2-D output: keep the second column only
        y_pred = y_pred[:, 1]

    fpr, fnr, _ = det_curve(
        y,
        y_pred,
        sample_weight=sample_weight,
        pos_label=pos_label,
    )

    assert_allclose(viz.fpr, fpr)
    assert_allclose(viz.fnr, fnr)

    assert viz.estimator_name == "LogisticRegression"

    # cannot fail thanks to pyplot fixture
    import matplotlib as mpl  # noqa

    assert isinstance(viz.line_, mpl.lines.Line2D)
    assert viz.line_.get_alpha() == 0.8
    assert isinstance(viz.ax_, mpl.axes.Axes)
    assert isinstance(viz.figure_, mpl.figure.Figure)
    assert viz.line_.get_label() == "LogisticRegression"

    expected_pos_label = 1 if pos_label is None else pos_label
    expected_ylabel = f"False Negative Rate (Positive label: {expected_pos_label})"
    expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"
    assert viz.ax_.get_ylabel() == expected_ylabel
    assert viz.ax_.get_xlabel() == expected_xlabel
|
||||
@@ -0,0 +1,248 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.metrics import plot_precision_recall_curve
|
||||
from sklearn.metrics import average_precision_score
|
||||
from sklearn.metrics import precision_recall_curve
|
||||
from sklearn.datasets import make_classification
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.utils import shuffle
|
||||
from sklearn.compose import make_column_transformer
|
||||
|
||||
# Module-wide warning filters:
# - the first one can go when https://github.com/numpy/numpy/issues/14397 is
#   resolved;
# - TODO: remove the second one in 1.2 (as well as all the tests below).
pytestmark = pytest.mark.filterwarnings(
    "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
    "matplotlib.*",
    "ignore:Function plot_precision_recall_curve is deprecated",
)
|
||||
|
||||
|
||||
def test_errors(pyplot):
    """Check the errors raised for unfitted, multiclass and regressor inputs."""
    X, y_multiclass = make_classification(
        n_classes=3, n_samples=50, n_informative=3, random_state=0
    )
    y_binary = y_multiclass == 0

    # an unfitted classifier is rejected
    binary_clf = DecisionTreeClassifier()
    with pytest.raises(NotFittedError):
        plot_precision_recall_curve(binary_clf, X, y_binary)
    binary_clf.fit(X, y_binary)

    # a classifier fitted on multiclass targets is rejected, even when it is
    # given binary targets for plotting
    multi_clf = DecisionTreeClassifier().fit(X, y_multiclass)
    expected_error = (
        "Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier"
    )
    with pytest.raises(ValueError, match=expected_error):
        plot_precision_recall_curve(multi_clf, X, y_binary)

    # a regressor is rejected as well
    reg = DecisionTreeRegressor().fit(X, y_multiclass)
    expected_error = (
        "Expected 'estimator' to be a binary classifier, but got DecisionTreeRegressor"
    )
    with pytest.raises(ValueError, match=expected_error):
        plot_precision_recall_curve(reg, X, y_binary)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "response_method, msg",
    [
        (
            "predict_proba",
            "response method predict_proba is not defined in MyClassifier",
        ),
        (
            "decision_function",
            "response method decision_function is not defined in MyClassifier",
        ),
        (
            "auto",
            "response method decision_function or predict_proba is not "
            "defined in MyClassifier",
        ),
        (
            "bad_method",
            "response_method must be 'predict_proba', 'decision_function' or 'auto'",
        ),
    ],
)
def test_error_bad_response(pyplot, response_method, msg):
    """Check the error raised when the requested response method is missing
    from the classifier or is not a supported value."""

    class MyClassifier(ClassifierMixin, BaseEstimator):
        # minimal classifier defining neither `predict_proba` nor
        # `decision_function`
        def fit(self, X, y):
            self.fitted_ = True
            self.classes_ = [0, 1]
            return self

    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
    dummy = MyClassifier()
    dummy.fit(X, y)

    with pytest.raises(ValueError, match=msg):
        plot_precision_recall_curve(dummy, X, y, response_method=response_method)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
def test_plot_precision_recall(pyplot, response_method, with_sample_weight):
    """Check `plot_precision_recall_curve` against `precision_recall_curve`
    and `average_precision_score` computed by hand."""
    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)

    lr = LogisticRegression().fit(X, y)

    sample_weight = None
    if with_sample_weight:
        rng = np.random.RandomState(42)
        sample_weight = rng.randint(0, 4, size=X.shape[0])

    disp = plot_precision_recall_curve(
        lr,
        X,
        y,
        alpha=0.8,
        response_method=response_method,
        sample_weight=sample_weight,
    )

    # reference values computed from the same response method
    y_score = getattr(lr, response_method)(X)
    if response_method == "predict_proba":
        y_score = y_score[:, 1]

    prec, recall, _ = precision_recall_curve(y, y_score, sample_weight=sample_weight)
    avg_prec = average_precision_score(y, y_score, sample_weight=sample_weight)

    assert_allclose(disp.precision, prec)
    assert_allclose(disp.recall, recall)
    assert disp.average_precision == pytest.approx(avg_prec)

    assert disp.estimator_name == "LogisticRegression"

    # cannot fail thanks to pyplot fixture
    import matplotlib as mpl  # noqa

    assert isinstance(disp.line_, mpl.lines.Line2D)
    assert disp.line_.get_alpha() == 0.8
    assert isinstance(disp.ax_, mpl.axes.Axes)
    assert isinstance(disp.figure_, mpl.figure.Figure)

    legend_label = "LogisticRegression (AP = {:0.2f})".format(avg_prec)
    assert disp.line_.get_label() == legend_label
    assert disp.ax_.get_xlabel() == "Recall (Positive label: 1)"
    assert disp.ax_.get_ylabel() == "Precision (Positive label: 1)"

    # plotting again with another name changes the legend label
    disp.plot(name="MySpecialEstimator")
    legend_label = "MySpecialEstimator (AP = {:0.2f})".format(avg_prec)
    assert disp.line_.get_label() == legend_label
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "clf",
    [
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
def test_precision_recall_curve_pipeline(pyplot, clf):
    """Check `plot_precision_recall_curve` with pipelines.

    An unfitted pipeline must raise `NotFittedError`; once fitted, the display
    name must be the class name of the pipeline.
    """
    from sklearn.base import clone

    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)

    # The parametrized instances are created once, when the module is
    # collected. Fit a clone so that the shared instance stays unfitted and
    # the `NotFittedError` check still holds if this test runs again.
    model = clone(clf)

    with pytest.raises(NotFittedError):
        plot_precision_recall_curve(model, X, y)

    model.fit(X, y)
    disp = plot_precision_recall_curve(model, X, y)
    assert disp.estimator_name == model.__class__.__name__
|
||||
|
||||
|
||||
def test_precision_recall_curve_string_labels(pyplot):
    """Regression test for #15738: targets given as strings."""
    cancer = load_breast_cancer()
    X = cancer.data
    # replace the integer targets by their string names
    y = cancer.target_names[cancer.target]

    lr = make_pipeline(StandardScaler(), LogisticRegression())
    lr.fit(X, y)
    for target_name in cancer.target_names:
        assert target_name in lr.classes_

    disp = plot_precision_recall_curve(lr, X, y)

    # reference value: scores of the second class, used as positive label
    positive_scores = lr.predict_proba(X)[:, 1]
    avg_prec = average_precision_score(y, positive_scores, pos_label=lr.classes_[1])

    assert disp.average_precision == pytest.approx(avg_prec)
    assert disp.estimator_name == lr.__class__.__name__
|
||||
|
||||
|
||||
def test_plot_precision_recall_curve_estimator_name_multiple_calls(pyplot):
    """Non-regression test: the `name` given to `plot_precision_recall_curve`
    must also be used by later `disp.plot()` calls, unless a new name is
    passed."""
    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
    first_name = "my hand-crafted name"
    model = LogisticRegression().fit(X, y)

    disp = plot_precision_recall_curve(model, X, y, name=first_name)
    assert disp.estimator_name == first_name

    # plotting again without a name keeps the original one
    pyplot.close("all")
    disp.plot()
    assert first_name in disp.line_.get_label()

    # an explicit name takes over
    pyplot.close("all")
    second_name = "another_name"
    disp.plot(name=second_name)
    assert second_name in disp.line_.get_label()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
def test_plot_precision_recall_pos_label(pyplot, response_method):
    """Check that `pos_label` selects the class whose statistics are shown."""
    X, y = load_breast_cancer(return_X_y=True)

    # build a highly imbalanced version of the dataset: every sample of
    # class 0 but only the first 25 samples of class 1
    positives = np.flatnonzero(y == 1)
    negatives = np.flatnonzero(y == 0)
    selected = np.hstack([negatives, positives[:25]])
    X, y = X[selected], y[selected]
    X, y = shuffle(X, y, random_state=42)

    # only use 2 features to make the problem even harder
    X = X[:, :2]
    y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        stratify=y,
        random_state=0,
    )

    classifier = LogisticRegression()
    classifier.fit(X_train, y_train)

    # sanity check to be sure the positive class is classes_[0] and that we
    # are betrayed by the class imbalance
    assert classifier.classes_.tolist() == ["cancer", "not cancer"]

    # with an explicit positive label we get the statistics of "cancer"
    disp = plot_precision_recall_curve(
        classifier, X_test, y_test, pos_label="cancer", response_method=response_method
    )
    upper_bound = 0.65
    assert disp.average_precision < upper_bound
    assert -np.trapz(disp.precision, disp.recall) < upper_bound

    # otherwise we get the statistics of the "not cancer" class
    disp = plot_precision_recall_curve(
        classifier,
        X_test,
        y_test,
        response_method=response_method,
    )
    lower_bound = 0.95
    assert disp.average_precision > lower_bound
    assert -np.trapz(disp.precision, disp.recall) > lower_bound
|
||||
@@ -0,0 +1,176 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from sklearn.metrics import plot_roc_curve
|
||||
from sklearn.metrics import roc_curve
|
||||
from sklearn.metrics import auc
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.utils import shuffle
|
||||
from sklearn.compose import make_column_transformer
|
||||
|
||||
# Module-wide warning filters.
# TODO: Remove the first filter when
# https://github.com/numpy/numpy/issues/14397 is resolved
pytestmark = pytest.mark.filterwarnings(
    "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
    "matplotlib.*",
    # deprecation message of `plot_roc_curve`
    "ignore:Function plot_roc_curve is deprecated",
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data():
    """Return the iris dataset as an ``(X, y)`` tuple."""
    X, y = load_iris(return_X_y=True)
    return X, y
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data_binary(data):
    """Return the samples of `data` whose target is lower than 2."""
    X, y = data
    keep = y < 2
    return X[keep], y[keep]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
def test_plot_roc_curve(
    pyplot,
    response_method,
    data_binary,
    with_sample_weight,
    drop_intermediate,
    with_strings,
):
    """Check `plot_roc_curve` against `roc_curve` and `auc` computed by hand.

    The test is run for both response methods, with and without sample
    weights, with and without `drop_intermediate`, and with integer or string
    targets.
    """
    X, y = data_binary

    pos_label = None
    if with_strings:
        # map target 0 to "c" and target 1 to "b"
        y = np.array(["c", "b"])[y]
        pos_label = "c"

    if with_sample_weight:
        rng = np.random.RandomState(42)
        sample_weight = rng.randint(1, 4, size=(X.shape[0]))
    else:
        sample_weight = None

    lr = LogisticRegression()
    lr.fit(X, y)

    viz = plot_roc_curve(
        lr,
        X,
        y,
        alpha=0.8,
        sample_weight=sample_weight,
        drop_intermediate=drop_intermediate,
        # forward the parametrized method so that the display is built from
        # the same scores as the reference values computed below
        response_method=response_method,
    )

    y_pred = getattr(lr, response_method)(X)
    if y_pred.ndim == 2:
        # 2-D output: keep the second column only
        y_pred = y_pred[:, 1]

    fpr, tpr, _ = roc_curve(
        y,
        y_pred,
        sample_weight=sample_weight,
        drop_intermediate=drop_intermediate,
        pos_label=pos_label,
    )

    assert_allclose(viz.roc_auc, auc(fpr, tpr))
    assert_allclose(viz.fpr, fpr)
    assert_allclose(viz.tpr, tpr)

    assert viz.estimator_name == "LogisticRegression"

    # cannot fail thanks to pyplot fixture
    import matplotlib as mpl  # noqa

    assert isinstance(viz.line_, mpl.lines.Line2D)
    assert viz.line_.get_alpha() == 0.8
    assert isinstance(viz.ax_, mpl.axes.Axes)
    assert isinstance(viz.figure_, mpl.figure.Figure)

    expected_label = "LogisticRegression (AUC = {:0.2f})".format(viz.roc_auc)
    assert viz.line_.get_label() == expected_label

    expected_pos_label = 1 if pos_label is None else pos_label
    expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})"
    expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"

    assert viz.ax_.get_ylabel() == expected_ylabel
    assert viz.ax_.get_xlabel() == expected_xlabel
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "clf",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
def test_roc_curve_not_fitted_errors(pyplot, data_binary, clf):
    """Check `plot_roc_curve` with unfitted and fitted estimators.

    An unfitted estimator must raise `NotFittedError`; once fitted, the class
    name must appear in the line label and be used as display name.
    """
    from sklearn.base import clone

    X, y = data_binary

    # The parametrized instances are created once, when the module is
    # collected. Fit a clone so that the shared instance stays unfitted and
    # the `NotFittedError` check still holds if this test runs again.
    model = clone(clf)

    with pytest.raises(NotFittedError):
        plot_roc_curve(model, X, y)

    model.fit(X, y)
    disp = plot_roc_curve(model, X, y)
    assert model.__class__.__name__ in disp.line_.get_label()
    assert disp.estimator_name == model.__class__.__name__
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
def test_plot_roc_curve_pos_label(pyplot, response_method):
    """Check the ROC AUC obtained with and without an explicit `pos_label`."""
    X, y = load_breast_cancer(return_X_y=True)

    # build a highly imbalanced version of the dataset: every sample of
    # class 0 but only the first 25 samples of class 1
    positives = np.flatnonzero(y == 1)
    negatives = np.flatnonzero(y == 0)
    selected = np.hstack([negatives, positives[:25]])
    X, y = X[selected], y[selected]
    X, y = shuffle(X, y, random_state=42)

    # only use 2 features to make the problem even harder
    X = X[:, :2]
    y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        stratify=y,
        random_state=0,
    )

    classifier = LogisticRegression()
    classifier.fit(X_train, y_train)

    # sanity check to be sure the positive class is classes_[0] and that we
    # are betrayed by the class imbalance
    assert classifier.classes_.tolist() == ["cancer", "not cancer"]

    expected_auc = 0.95679

    # explicit positive label
    disp = plot_roc_curve(
        classifier, X_test, y_test, pos_label="cancer", response_method=response_method
    )
    assert disp.roc_auc == pytest.approx(expected_auc)
    assert np.trapz(disp.tpr, disp.fpr) == pytest.approx(expected_auc)

    # default positive label: the same area is expected
    disp = plot_roc_curve(
        classifier,
        X_test,
        y_test,
        response_method=response_method,
    )
    assert disp.roc_auc == pytest.approx(expected_auc)
    assert np.trapz(disp.tpr, disp.fpr) == pytest.approx(expected_auc)
|
||||
@@ -0,0 +1,304 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import load_breast_cancer, make_classification
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import average_precision_score, precision_recall_curve
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.svm import SVC, SVR
|
||||
from sklearn.utils import shuffle
|
||||
|
||||
from sklearn.metrics import PrecisionRecallDisplay, plot_precision_recall_curve
|
||||
|
||||
# Module-wide warning filter.
# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
pytestmark = pytest.mark.filterwarnings(
    "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:"
    "matplotlib.*"
)
|
||||
|
||||
|
||||
def test_precision_recall_display_validation(pyplot):
    """Check the errors raised by `PrecisionRecallDisplay` on invalid inputs."""
    X, y = make_classification(
        n_samples=100, n_informative=5, n_classes=5, random_state=0
    )

    # an estimator that was never fitted is rejected
    with pytest.raises(NotFittedError):
        PrecisionRecallDisplay.from_estimator(SVC(), X, y)

    svr = SVR().fit(X, y)
    y_pred_svr = svr.predict(X)
    svc = SVC(probability=True).fit(X, y)
    y_score_svc = svc.predict_proba(X)[:, -1]

    # a regressor is rejected by `from_estimator`
    classifier_only_msg = (
        "PrecisionRecallDisplay.from_estimator only supports classifiers"
    )
    with pytest.raises(ValueError, match=classifier_only_msg):
        PrecisionRecallDisplay.from_estimator(svr, X, y)

    # a classifier trained on the 5-class problem is rejected
    binary_only_msg = "Expected 'estimator' to be a binary classifier, but got SVC"
    with pytest.raises(ValueError, match=binary_only_msg):
        PrecisionRecallDisplay.from_estimator(svc, X, y)

    format_msg = "{} format is not supported"
    with pytest.raises(ValueError, match=format_msg.format("continuous")):
        # Force `y_true` to be seen as a regression problem
        PrecisionRecallDisplay.from_predictions(y + 0.5, y_score_svc, pos_label=1)
    with pytest.raises(ValueError, match=format_msg.format("multiclass")):
        PrecisionRecallDisplay.from_predictions(y, y_pred_svr, pos_label=1)

    # `y_true` and the scores must have the same number of samples
    inconsistent_msg = "Found input variables with inconsistent numbers of samples"
    with pytest.raises(ValueError, match=inconsistent_msg):
        PrecisionRecallDisplay.from_predictions(y, y_score_svc[::2])

    # binary problem labelled {10, 11}: `pos_label` has to be given explicitly
    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
    y += 10
    svc.fit(X, y)
    y_score_svc = svc.predict_proba(X)[:, -1]
    missing_pos_label_msg = (
        r"y_true takes value in {10, 11} and pos_label is not specified"
    )
    with pytest.raises(ValueError, match=missing_pos_label_msg):
        PrecisionRecallDisplay.from_predictions(y, y_score_svc)
|
||||
|
||||
|
||||
# FIXME: Remove in 1.2
def test_plot_precision_recall_curve_deprecation(pyplot):
    """Check that calling `plot_precision_recall_curve` emits a
    FutureWarning announcing its deprecation."""
    X, y = make_classification(random_state=0)
    fitted_clf = LogisticRegression().fit(X, y)

    with pytest.warns(
        FutureWarning, match="Function plot_precision_recall_curve is deprecated"
    ):
        plot_precision_recall_curve(fitted_clf, X, y)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
def test_precision_recall_display_plotting(pyplot, constructor_name, response_method):
    """Check the overall plotting rendering.

    The display is built either from the fitted estimator or from its
    predictions, and its curve, average precision, matplotlib artists and
    axis labels are compared with values computed directly from the metric
    functions.
    """
    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
    pos_label = 1

    # a single fit is enough; refitting on the same data was redundant
    classifier = LogisticRegression().fit(X, y)

    y_pred = getattr(classifier, response_method)(X)
    # `predict_proba` is 2D: keep only the column of the positive class
    y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, pos_label]

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")

    if constructor_name == "from_estimator":
        display = PrecisionRecallDisplay.from_estimator(
            classifier, X, y, response_method=response_method
        )
    else:
        display = PrecisionRecallDisplay.from_predictions(
            y, y_pred, pos_label=pos_label
        )

    # reference values computed directly from the metric functions
    precision, recall, _ = precision_recall_curve(y, y_pred, pos_label=pos_label)
    average_precision = average_precision_score(y, y_pred, pos_label=pos_label)

    np.testing.assert_allclose(display.precision, precision)
    np.testing.assert_allclose(display.recall, recall)
    assert display.average_precision == pytest.approx(average_precision)

    # imported inside the test body rather than at module level
    import matplotlib as mpl

    assert isinstance(display.line_, mpl.lines.Line2D)
    assert isinstance(display.ax_, mpl.axes.Axes)
    assert isinstance(display.figure_, mpl.figure.Figure)

    assert display.ax_.get_xlabel() == "Recall (Positive label: 1)"
    assert display.ax_.get_ylabel() == "Precision (Positive label: 1)"

    # plotting passing some new parameters
    display.plot(alpha=0.8, name="MySpecialEstimator")
    expected_label = f"MySpecialEstimator (AP = {average_precision:0.2f})"
    assert display.line_.get_label() == expected_label
    assert display.line_.get_alpha() == pytest.approx(0.8)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "constructor_name, default_label",
    [
        ("from_estimator", "LogisticRegression (AP = {:.2f})"),
        ("from_predictions", "Classifier (AP = {:.2f})"),
    ],
)
def test_precision_recall_display_name(pyplot, constructor_name, default_label):
    """Check the behaviour of the name parameters.

    The legend label must use the default name associated with the
    constructor, and must switch to a user-provided name when `plot` is
    called with `name`.
    """
    X, y = make_classification(n_classes=2, n_samples=100, random_state=0)
    pos_label = 1

    # a single fit is enough; refitting on the same data was redundant
    classifier = LogisticRegression().fit(X, y)

    y_pred = classifier.predict_proba(X)[:, pos_label]

    # safe guard for the binary if/else construction
    assert constructor_name in ("from_estimator", "from_predictions")

    if constructor_name == "from_estimator":
        display = PrecisionRecallDisplay.from_estimator(classifier, X, y)
    else:
        display = PrecisionRecallDisplay.from_predictions(
            y, y_pred, pos_label=pos_label
        )

    average_precision = average_precision_score(y, y_pred, pos_label=pos_label)

    # check that the default name is used
    assert display.line_.get_label() == default_label.format(average_precision)

    # check that the name can be set
    display.plot(name="MySpecialEstimator")
    assert (
        display.line_.get_label()
        == f"MySpecialEstimator (AP = {average_precision:.2f})"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "clf",
    [
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
def test_precision_recall_display_pipeline(pyplot, clf):
    """Check `PrecisionRecallDisplay.from_estimator` with pipelines.

    An unfitted pipeline must raise `NotFittedError`; once fitted, the
    display reports the class name of the pipeline as its estimator name.
    """
    from sklearn.base import clone

    # The parametrized pipelines are created once at collection time. Work on
    # an unfitted copy so that the `NotFittedError` check below still holds if
    # the test is executed again in the same process.
    clf = clone(clf)

    X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
    with pytest.raises(NotFittedError):
        PrecisionRecallDisplay.from_estimator(clf, X, y)
    clf.fit(X, y)
    display = PrecisionRecallDisplay.from_estimator(clf, X, y)
    assert display.estimator_name == clf.__class__.__name__
|
||||
|
||||
|
||||
def test_precision_recall_display_string_labels(pyplot):
    """Check the display with string class labels (regression test #15738)."""
    dataset = load_breast_cancer()
    X = dataset.data
    # map the integer targets to their string names
    y = dataset.target_names[dataset.target]

    pipeline = make_pipeline(StandardScaler(), LogisticRegression())
    pipeline.fit(X, y)
    assert all(klass in pipeline.classes_ for klass in dataset.target_names)
    display = PrecisionRecallDisplay.from_estimator(pipeline, X, y)

    positive_class = pipeline.classes_[1]
    y_score = pipeline.predict_proba(X)[:, 1]
    expected_ap = average_precision_score(y, y_score, pos_label=positive_class)

    assert display.average_precision == pytest.approx(expected_ap)
    assert display.estimator_name == pipeline.__class__.__name__

    # with string labels, `pos_label` has to be given to `from_predictions`
    with pytest.raises(
        ValueError, match=r"y_true takes value in {'benign', 'malignant'}"
    ):
        PrecisionRecallDisplay.from_predictions(y, y_score)

    display = PrecisionRecallDisplay.from_predictions(
        y, y_score, pos_label=positive_class
    )
    assert display.average_precision == pytest.approx(expected_ap)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "average_precision, estimator_name, expected_label",
    [
        (0.9, None, "AP = 0.90"),
        (None, "my_est", "my_est"),
        (0.8, "my_est2", "my_est2 (AP = 0.80)"),
    ],
)
def test_default_labels(pyplot, average_precision, estimator_name, expected_label):
    """Check the legend label built from the average precision and/or the
    estimator name given to the display."""
    display = PrecisionRecallDisplay(
        np.array([1, 0.5, 0]),  # precision
        np.array([0, 0.5, 1]),  # recall
        average_precision=average_precision,
        estimator_name=estimator_name,
    )
    display.plot()

    plotted_label = display.line_.get_label()
    assert plotted_label == expected_label
|
||||
|
||||
|
||||
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_method):
    """Check that `pos_label` selects the class whose statistics are shown."""
    X, y = load_breast_cancer(return_X_y=True)

    # build a highly imbalanced version of the breast cancer dataset: every
    # sample of class 0 but only the first 25 samples of class 1
    positives = np.flatnonzero(y == 1)
    negatives = np.flatnonzero(y == 0)
    kept = np.hstack([negatives, positives[:25]])
    X, y = X[kept], y[kept]
    X, y = shuffle(X, y, random_state=42)
    # only use 2 features to make the problem even harder
    X = X[:, :2]
    y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=0
    )

    classifier = LogisticRegression()
    classifier.fit(X_train, y_train)

    # sanity check to be sure the positive class is classes_[0] and that we
    # are betrayed by the class imbalance
    assert classifier.classes_.tolist() == ["cancer", "not cancer"]

    y_pred = getattr(classifier, response_method)(X_test)
    if y_pred.ndim == 1:
        # decision function: reverse its sign to score the "cancer" class
        y_pred_cancer, y_pred_not_cancer = -1 * y_pred, y_pred
    else:
        # probabilities: pick the column matching each class
        y_pred_cancer, y_pred_not_cancer = y_pred[:, 0], y_pred[:, 1]

    def build_display(pos_label, y_score):
        # build the display with the constructor under test
        if constructor_name == "from_estimator":
            return PrecisionRecallDisplay.from_estimator(
                classifier,
                X_test,
                y_test,
                pos_label=pos_label,
                response_method=response_method,
            )
        return PrecisionRecallDisplay.from_predictions(
            y_test, y_score, pos_label=pos_label
        )

    # we should obtain the statistics of the "cancer" class
    display = build_display("cancer", y_pred_cancer)
    avg_prec_limit = 0.65
    assert display.average_precision < avg_prec_limit
    assert -np.trapz(display.precision, display.recall) < avg_prec_limit

    # otherwise we should obtain the statistics of the "not cancer" class
    display = build_display("not cancer", y_pred_not_cancer)
    avg_prec_limit = 0.95
    assert display.average_precision > avg_prec_limit
    assert -np.trapz(display.precision, display.recall) > avg_prec_limit
|
||||
@@ -0,0 +1,263 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
|
||||
from sklearn.compose import make_column_transformer
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
from sklearn.datasets import load_breast_cancer, make_classification
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import roc_curve
|
||||
from sklearn.metrics import auc
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.utils import shuffle
|
||||
|
||||
|
||||
from sklearn.metrics import RocCurveDisplay, plot_roc_curve
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data():
    """Module-scoped fixture returning the iris dataset as an `(X, y)` pair."""
    X_y = load_iris(return_X_y=True)
    return X_y
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def data_binary(data):
    """Module-scoped fixture keeping only the samples of `data` whose target
    is lower than 2, which makes the problem binary."""
    X, y = data
    first_two_classes = y < 2
    return X[first_two_classes], y[first_two_classes]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
@pytest.mark.parametrize(
    "constructor_name, default_name",
    [
        ("from_estimator", "LogisticRegression"),
        ("from_predictions", "Classifier"),
    ],
)
def test_roc_curve_display_plotting(
    pyplot,
    response_method,
    data_binary,
    with_sample_weight,
    drop_intermediate,
    with_strings,
    constructor_name,
    default_name,
):
    """Check the overall plotting behaviour.

    The display is built either from the fitted estimator or from its
    predictions, and its curve, AUC, matplotlib artists, legend label and
    axis labels are compared with values computed from `roc_curve` / `auc`.
    """
    X, y = data_binary

    pos_label = None
    if with_strings:
        # class 0 becomes "c" and class 1 becomes "b"
        y = np.array(["c", "b"])[y]
        pos_label = "c"

    if with_sample_weight:
        rng = np.random.RandomState(42)
        sample_weight = rng.randint(1, 4, size=(X.shape[0]))
    else:
        sample_weight = None

    lr = LogisticRegression()
    lr.fit(X, y)

    y_pred = getattr(lr, response_method)(X)
    # `predict_proba` is 2D: keep only the second column
    y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1]

    if constructor_name == "from_estimator":
        display = RocCurveDisplay.from_estimator(
            lr,
            X,
            y,
            sample_weight=sample_weight,
            drop_intermediate=drop_intermediate,
            # forward the parametrized method so that the display is built
            # from the same scores as the reference curve computed below
            response_method=response_method,
            pos_label=pos_label,
            alpha=0.8,
        )
    else:
        display = RocCurveDisplay.from_predictions(
            y,
            y_pred,
            sample_weight=sample_weight,
            drop_intermediate=drop_intermediate,
            pos_label=pos_label,
            alpha=0.8,
        )

    # reference curve computed directly from the metric function
    fpr, tpr, _ = roc_curve(
        y,
        y_pred,
        sample_weight=sample_weight,
        drop_intermediate=drop_intermediate,
        pos_label=pos_label,
    )

    assert_allclose(display.roc_auc, auc(fpr, tpr))
    assert_allclose(display.fpr, fpr)
    assert_allclose(display.tpr, tpr)

    assert display.estimator_name == default_name

    # imported inside the test body rather than at module level
    import matplotlib as mpl  # noqa

    assert isinstance(display.line_, mpl.lines.Line2D)
    assert display.line_.get_alpha() == 0.8
    assert isinstance(display.ax_, mpl.axes.Axes)
    assert isinstance(display.figure_, mpl.figure.Figure)

    expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})"
    assert display.line_.get_label() == expected_label

    expected_pos_label = 1 if pos_label is None else pos_label
    expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})"
    expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"

    assert display.ax_.get_ylabel() == expected_ylabel
    assert display.ax_.get_xlabel() == expected_xlabel
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "clf",
    [
        LogisticRegression(),
        make_pipeline(StandardScaler(), LogisticRegression()),
        make_pipeline(
            make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression()
        ),
    ],
)
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructor_name):
    """Check the behaviour with complex pipeline.

    With `from_estimator`, an unfitted estimator must raise `NotFittedError`.
    For both constructors, the estimator name must appear in the legend label
    and be stored on the display.
    """
    from sklearn.base import clone

    # Each parametrized estimator is one object shared by both
    # `constructor_name` cases. Fitting it in place would make the
    # `NotFittedError` check below depend on the order in which the cases run,
    # so work on an unfitted copy instead.
    clf = clone(clf)

    X, y = data_binary

    if constructor_name == "from_estimator":
        with pytest.raises(NotFittedError):
            RocCurveDisplay.from_estimator(clf, X, y)

    clf.fit(X, y)

    if constructor_name == "from_estimator":
        display = RocCurveDisplay.from_estimator(clf, X, y)
        name = clf.__class__.__name__
    else:
        # the true labels stand in for the predictions here
        display = RocCurveDisplay.from_predictions(y, y)
        name = "Classifier"

    assert name in display.line_.get_label()
    assert display.estimator_name == name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "roc_auc, estimator_name, expected_label",
    [
        (0.9, None, "AUC = 0.90"),
        (None, "my_est", "my_est"),
        (0.8, "my_est2", "my_est2 (AUC = 0.80)"),
    ],
)
def test_roc_curve_display_default_labels(
    pyplot, roc_auc, estimator_name, expected_label
):
    """Check the legend label built from the AUC and/or the estimator name
    given to the display."""
    false_positive_rate = np.array([0, 0.5, 1])
    true_positive_rate = np.array([0, 0.5, 1])
    display = RocCurveDisplay(
        fpr=false_positive_rate,
        tpr=true_positive_rate,
        roc_auc=roc_auc,
        estimator_name=estimator_name,
    )
    # the object returned by `plot` is the one inspected below
    disp = display.plot()
    assert disp.line_.get_label() == expected_label
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name):
    """Check that `pos_label` can be provided and that the expected ROC AUC
    is obtained for each of the two classes."""
    X, y = load_breast_cancer(return_X_y=True)

    # build a highly imbalanced dataset: every sample of class 0 but only the
    # first 25 samples of class 1
    positives = np.flatnonzero(y == 1)
    negatives = np.flatnonzero(y == 0)
    kept = np.hstack([negatives, positives[:25]])
    X, y = X[kept], y[kept]
    X, y = shuffle(X, y, random_state=42)
    # only use 2 features to make the problem even harder
    X = X[:, :2]
    y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=0
    )

    classifier = LogisticRegression()
    classifier.fit(X_train, y_train)

    # sanity check to be sure the positive class is classes_[0] and that we
    # are betrayed by the class imbalance
    assert classifier.classes_.tolist() == ["cancer", "not cancer"]

    y_pred = getattr(classifier, response_method)(X_test)
    if y_pred.ndim == 1:
        # decision function: reverse its sign to score the "cancer" class
        y_pred_cancer, y_pred_not_cancer = -1 * y_pred, y_pred
    else:
        # probabilities: pick the column matching each class
        y_pred_cancer, y_pred_not_cancer = y_pred[:, 0], y_pred[:, 1]

    roc_auc_limit = 0.95679

    # the same area is expected whichever class is taken as positive
    scores_per_label = [
        ("cancer", y_pred_cancer),
        ("not cancer", y_pred_not_cancer),
    ]
    for pos_label, y_score in scores_per_label:
        if constructor_name == "from_estimator":
            display = RocCurveDisplay.from_estimator(
                classifier,
                X_test,
                y_test,
                pos_label=pos_label,
                response_method=response_method,
            )
        else:
            display = RocCurveDisplay.from_predictions(
                y_test, y_score, pos_label=pos_label
            )

        assert display.roc_auc == pytest.approx(roc_auc_limit)
        assert np.trapz(display.tpr, display.fpr) == pytest.approx(roc_auc_limit)
|
||||
|
||||
|
||||
# FIXME: Remove in 1.2
# NOTE(review): the function name mentions precision-recall although the body
# exercises `plot_roc_curve` -- looks like a copy-paste leftover; consider
# renaming it.
def test_plot_precision_recall_curve_deprecation(pyplot):
    """Check that calling `plot_roc_curve` emits a FutureWarning announcing
    its deprecation."""
    X, y = make_classification(random_state=0)
    fitted_clf = LogisticRegression().fit(X, y)

    with pytest.warns(FutureWarning, match="Function plot_roc_curve is deprecated"):
        plot_roc_curve(fitted_clf, X, y)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,836 @@
|
||||
"""
|
||||
The :mod:`sklearn.metrics.scorer` submodule implements a flexible
|
||||
interface for model selection and evaluation using
|
||||
arbitrary score functions.
|
||||
|
||||
A scorer object is a callable that can be passed to
|
||||
:class:`~sklearn.model_selection.GridSearchCV` or
|
||||
:func:`sklearn.model_selection.cross_val_score` as the ``scoring``
|
||||
parameter, to specify how a model should be evaluated.
|
||||
|
||||
The signature of the call is ``(estimator, X, y)`` where ``estimator``
|
||||
is the model to be evaluated, ``X`` is the test data and ``y`` is the
|
||||
ground truth labeling (or ``None`` in the case of unsupervised models).
|
||||
"""
|
||||
|
||||
# Authors: Andreas Mueller <amueller@ais.uni-bonn.de>
|
||||
# Lars Buitinck
|
||||
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||||
# License: Simplified BSD
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from collections import Counter
|
||||
|
||||
import numpy as np
|
||||
import copy
|
||||
import warnings
|
||||
|
||||
from . import (
|
||||
r2_score,
|
||||
median_absolute_error,
|
||||
max_error,
|
||||
mean_absolute_error,
|
||||
mean_squared_error,
|
||||
mean_squared_log_error,
|
||||
mean_poisson_deviance,
|
||||
mean_gamma_deviance,
|
||||
accuracy_score,
|
||||
top_k_accuracy_score,
|
||||
f1_score,
|
||||
roc_auc_score,
|
||||
average_precision_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
log_loss,
|
||||
balanced_accuracy_score,
|
||||
explained_variance_score,
|
||||
brier_score_loss,
|
||||
jaccard_score,
|
||||
mean_absolute_percentage_error,
|
||||
matthews_corrcoef,
|
||||
)
|
||||
|
||||
from .cluster import adjusted_rand_score
|
||||
from .cluster import rand_score
|
||||
from .cluster import homogeneity_score
|
||||
from .cluster import completeness_score
|
||||
from .cluster import v_measure_score
|
||||
from .cluster import mutual_info_score
|
||||
from .cluster import adjusted_mutual_info_score
|
||||
from .cluster import normalized_mutual_info_score
|
||||
from .cluster import fowlkes_mallows_score
|
||||
|
||||
from ..utils.multiclass import type_of_target
|
||||
from ..base import is_regressor
|
||||
|
||||
|
||||
def _cached_call(cache, estimator, method, *args, **kwargs):
|
||||
"""Call estimator with method and args and kwargs."""
|
||||
if cache is None:
|
||||
return getattr(estimator, method)(*args, **kwargs)
|
||||
|
||||
try:
|
||||
return cache[method]
|
||||
except KeyError:
|
||||
result = getattr(estimator, method)(*args, **kwargs)
|
||||
cache[method] = result
|
||||
return result
|
||||
|
||||
|
||||
class _MultimetricScorer:
|
||||
"""Callable for multimetric scoring used to avoid repeated calls
|
||||
to `predict_proba`, `predict`, and `decision_function`.
|
||||
|
||||
`_MultimetricScorer` will return a dictionary of scores corresponding to
|
||||
the scorers in the dictionary. Note that `_MultimetricScorer` can be
|
||||
created with a dictionary with one key (i.e. only one actual scorer).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scorers : dict
|
||||
Dictionary mapping names to callable scorers.
|
||||
"""
|
||||
|
||||
def __init__(self, **scorers):
|
||||
self._scorers = scorers
|
||||
|
||||
def __call__(self, estimator, *args, **kwargs):
|
||||
"""Evaluate predicted target values."""
|
||||
scores = {}
|
||||
cache = {} if self._use_cache(estimator) else None
|
||||
cached_call = partial(_cached_call, cache)
|
||||
|
||||
for name, scorer in self._scorers.items():
|
||||
if isinstance(scorer, _BaseScorer):
|
||||
score = scorer._score(cached_call, estimator, *args, **kwargs)
|
||||
else:
|
||||
score = scorer(estimator, *args, **kwargs)
|
||||
scores[name] = score
|
||||
return scores
|
||||
|
||||
def _use_cache(self, estimator):
|
||||
"""Return True if using a cache is beneficial.
|
||||
|
||||
Caching may be beneficial when one of these conditions holds:
|
||||
- `_ProbaScorer` will be called twice.
|
||||
- `_PredictScorer` will be called twice.
|
||||
- `_ThresholdScorer` will be called twice.
|
||||
- `_ThresholdScorer` and `_PredictScorer` are called and
|
||||
estimator is a regressor.
|
||||
- `_ThresholdScorer` and `_ProbaScorer` are called and
|
||||
estimator does not have a `decision_function` attribute.
|
||||
|
||||
"""
|
||||
if len(self._scorers) == 1: # Only one scorer
|
||||
return False
|
||||
|
||||
counter = Counter([type(v) for v in self._scorers.values()])
|
||||
|
||||
if any(
|
||||
counter[known_type] > 1
|
||||
for known_type in [_PredictScorer, _ProbaScorer, _ThresholdScorer]
|
||||
):
|
||||
return True
|
||||
|
||||
if counter[_ThresholdScorer]:
|
||||
if is_regressor(estimator) and counter[_PredictScorer]:
|
||||
return True
|
||||
elif counter[_ProbaScorer] and not hasattr(estimator, "decision_function"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class _BaseScorer:
|
||||
def __init__(self, score_func, sign, kwargs):
|
||||
self._kwargs = kwargs
|
||||
self._score_func = score_func
|
||||
self._sign = sign
|
||||
|
||||
@staticmethod
|
||||
def _check_pos_label(pos_label, classes):
|
||||
if pos_label not in list(classes):
|
||||
raise ValueError(f"pos_label={pos_label} is not a valid label: {classes}")
|
||||
|
||||
def _select_proba_binary(self, y_pred, classes):
|
||||
"""Select the column of the positive label in `y_pred` when
|
||||
probabilities are provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y_pred : ndarray of shape (n_samples, n_classes)
|
||||
The prediction given by `predict_proba`.
|
||||
|
||||
classes : ndarray of shape (n_classes,)
|
||||
The class labels for the estimator.
|
||||
|
||||
Returns
|
||||
-------
|
||||
y_pred : ndarray of shape (n_samples,)
|
||||
Probability predictions of the positive class.
|
||||
"""
|
||||
if y_pred.shape[1] == 2:
|
||||
pos_label = self._kwargs.get("pos_label", classes[1])
|
||||
self._check_pos_label(pos_label, classes)
|
||||
col_idx = np.flatnonzero(classes == pos_label)[0]
|
||||
return y_pred[:, col_idx]
|
||||
|
||||
err_msg = (
|
||||
f"Got predict_proba of shape {y_pred.shape}, but need "
|
||||
f"classifier with two classes for {self._score_func.__name__} "
|
||||
"scoring"
|
||||
)
|
||||
raise ValueError(err_msg)
|
||||
|
||||
def __repr__(self):
|
||||
kwargs_string = "".join(
|
||||
[", %s=%s" % (str(k), str(v)) for k, v in self._kwargs.items()]
|
||||
)
|
||||
return "make_scorer(%s%s%s%s)" % (
|
||||
self._score_func.__name__,
|
||||
"" if self._sign > 0 else ", greater_is_better=False",
|
||||
self._factory_args(),
|
||||
kwargs_string,
|
||||
)
|
||||
|
||||
def __call__(self, estimator, X, y_true, sample_weight=None):
|
||||
"""Evaluate predicted target values for X relative to y_true.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : object
|
||||
Trained estimator to use for scoring. Must have a predict_proba
|
||||
method; the output of that is used to compute the score.
|
||||
|
||||
X : {array-like, sparse matrix}
|
||||
Test data that will be fed to estimator.predict.
|
||||
|
||||
y_true : array-like
|
||||
Gold standard target values for X.
|
||||
|
||||
sample_weight : array-like of shape (n_samples,), default=None
|
||||
Sample weights.
|
||||
|
||||
Returns
|
||||
-------
|
||||
score : float
|
||||
Score function applied to prediction of estimator on X.
|
||||
"""
|
||||
return self._score(
|
||||
partial(_cached_call, None),
|
||||
estimator,
|
||||
X,
|
||||
y_true,
|
||||
sample_weight=sample_weight,
|
||||
)
|
||||
|
||||
def _factory_args(self):
|
||||
"""Return non-default make_scorer arguments for repr."""
|
||||
return ""
|
||||
|
||||
|
||||
class _PredictScorer(_BaseScorer):
    """Scorer computing the score from the output of `predict`."""

    def _score(self, method_caller, estimator, X, y_true, sample_weight=None):
        """Evaluate predicted target values for X relative to y_true.

        Parameters
        ----------
        method_caller : callable
            Returns predictions given an estimator, method name, and other
            arguments, potentially caching results.

        estimator : object
            Trained estimator to use for scoring. Must have a `predict`
            method; the output of that is used to compute the score.

        X : {array-like, sparse matrix}
            Test data that will be fed to estimator.predict.

        y_true : array-like
            Gold standard target values for X.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights.

        Returns
        -------
        score : float
            Score function applied to prediction of estimator on X.
        """
        predictions = method_caller(estimator, "predict", X)

        # `sample_weight` is forwarded only when it was provided
        if sample_weight is None:
            raw_score = self._score_func(y_true, predictions, **self._kwargs)
        else:
            raw_score = self._score_func(
                y_true, predictions, sample_weight=sample_weight, **self._kwargs
            )
        return self._sign * raw_score
|
||||
|
||||
|
||||
class _ProbaScorer(_BaseScorer):
    """Scorer computing the score from the output of `predict_proba`."""

    def _score(self, method_caller, clf, X, y, sample_weight=None):
        """Evaluate predicted probabilities for X relative to y_true.

        Parameters
        ----------
        method_caller : callable
            Returns predictions given an estimator, method name, and other
            arguments, potentially caching results.

        clf : object
            Trained classifier to use for scoring. Must have a `predict_proba`
            method; the output of that is used to compute the score.

        X : {array-like, sparse matrix}
            Test data that will be fed to clf.predict_proba.

        y : array-like
            Gold standard target values for X. These must be class labels,
            not probabilities.

        sample_weight : array-like, default=None
            Sample weights.

        Returns
        -------
        score : float
            Score function applied to prediction of estimator on X.
        """
        target_type = type_of_target(y)
        probabilities = method_caller(clf, "predict_proba", X)

        # `target_type` can be "binary" even for a multi-class classifier
        # (when only two classes show up in `y` at scoring time), hence the
        # extra check on the number of columns.
        if target_type == "binary" and probabilities.shape[1] <= 2:
            probabilities = self._select_proba_binary(probabilities, clf.classes_)

        # `sample_weight` is forwarded only when it was provided
        if sample_weight is None:
            raw_score = self._score_func(y, probabilities, **self._kwargs)
        else:
            raw_score = self._score_func(
                y, probabilities, sample_weight=sample_weight, **self._kwargs
            )
        return self._sign * raw_score

    def _factory_args(self):
        """Return the `make_scorer` argument specific to this scorer."""
        return ", needs_proba=True"
|
||||
|
||||
|
||||
class _ThresholdScorer(_BaseScorer):
    """Scorer that feeds continuous decision values to the wrapped score function."""

    def _score(self, method_caller, clf, X, y, sample_weight=None):
        """Evaluate decision function output for X relative to y_true.

        Parameters
        ----------
        method_caller : callable
            Returns predictions given an estimator, method name, and other
            arguments, potentially caching results.

        clf : object
            Trained classifier to use for scoring. Must have either a
            decision_function method or a predict_proba method; the output of
            that is used to compute the score.

        X : {array-like, sparse matrix}
            Test data that will be fed to clf.decision_function or
            clf.predict_proba.

        y : array-like
            Gold standard target values for X. These must be class labels,
            not decision function values.

        sample_weight : array-like, default=None
            Sample weights.

        Returns
        -------
        score : float
            Score function applied to prediction of estimator on X.

        Raises
        ------
        ValueError
            If the type of `y` is neither "binary" nor "multilabel-indicator".
        """
        y_type = type_of_target(y)
        if y_type not in ("binary", "multilabel-indicator"):
            raise ValueError("{0} format is not supported".format(y_type))

        if is_regressor(clf):
            y_pred = method_caller(clf, "predict", X)
        else:
            try:
                y_pred = method_caller(clf, "decision_function", X)

                if isinstance(y_pred, list):
                    # For multi-output multi-class estimator
                    y_pred = np.vstack(y_pred).T
                elif y_type == "binary" and "pos_label" in self._kwargs:
                    self._check_pos_label(self._kwargs["pos_label"], clf.classes_)
                    if self._kwargs["pos_label"] == clf.classes_[0]:
                        # The implicit positive class of the binary classifier
                        # does not match `pos_label`: we need to invert the
                        # predictions. Build a new array instead of negating in
                        # place: `method_caller` may hand back a cached object
                        # and mutating it would corrupt it for other users.
                        y_pred = y_pred * -1

            except (NotImplementedError, AttributeError):
                # No usable decision_function: fall back to probabilities.
                y_pred = method_caller(clf, "predict_proba", X)

                if y_type == "binary":
                    y_pred = self._select_proba_binary(y_pred, clf.classes_)
                elif isinstance(y_pred, list):
                    # One array per output; keep the last column of each.
                    y_pred = np.vstack([p[:, -1] for p in y_pred]).T

        if sample_weight is not None:
            return self._sign * self._score_func(
                y, y_pred, sample_weight=sample_weight, **self._kwargs
            )
        else:
            return self._sign * self._score_func(y, y_pred, **self._kwargs)

    def _factory_args(self):
        """Return the ``make_scorer`` argument fragment that selects this class."""
        return ", needs_threshold=True"
|
||||
|
||||
|
||||
def get_scorer(scoring):
    """Get a scorer from string.

    Read more in the :ref:`User Guide <scoring_parameter>`.
    :func:`~sklearn.metrics.get_scorer_names` can be used to retrieve the names
    of all available scorers.

    Parameters
    ----------
    scoring : str or callable
        Scoring method as string. If callable it is returned as is.

    Returns
    -------
    scorer : callable
        The scorer.

    Notes
    -----
    When passed a string, this function always returns a copy of the scorer
    object. Calling `get_scorer` twice for the same scorer results in two
    separate scorer objects.
    """
    # Anything that is not a name is handed back untouched.
    if not isinstance(scoring, str):
        return scoring

    try:
        scorer = copy.deepcopy(_SCORERS[scoring])
    except KeyError:
        raise ValueError(
            "%r is not a valid scoring value. "
            "Use sklearn.metrics.get_scorer_names() "
            "to get valid options." % scoring
        )
    return scorer
|
||||
|
||||
|
||||
def _passthrough_scorer(estimator, *args, **kwargs):
|
||||
"""Function that wraps estimator.score"""
|
||||
return estimator.score(*args, **kwargs)
|
||||
|
||||
|
||||
def check_scoring(estimator, scoring=None, *, allow_none=False):
    """Determine scorer from user options.

    A TypeError will be thrown if the estimator cannot be scored.

    Parameters
    ----------
    estimator : estimator object implementing 'fit'
        The object to use to fit the data.

    scoring : str or callable, default=None
        A string (see model evaluation documentation) or
        a scorer callable object / function with signature
        ``scorer(estimator, X, y)``.
        If None, the provided estimator object's `score` method is used.

    allow_none : bool, default=False
        If no scoring is specified and the estimator has no score function, we
        can either return None or raise an exception.

    Returns
    -------
    scoring : callable
        A scorer callable object / function with signature
        ``scorer(estimator, X, y)``.
    """
    if not hasattr(estimator, "fit"):
        raise TypeError(
            "estimator should be an estimator implementing 'fit' method, %r was passed"
            % estimator
        )

    if isinstance(scoring, str):
        return get_scorer(scoring)

    if callable(scoring):
        # Heuristic to ensure user has not passed a metric
        module = getattr(scoring, "__module__", None)
        looks_like_metric = (
            hasattr(module, "startswith")
            and module.startswith("sklearn.metrics.")
            and not module.startswith("sklearn.metrics._scorer")
            and not module.startswith("sklearn.metrics.tests.")
        )
        if looks_like_metric:
            raise ValueError(
                "scoring value %r looks like it is a metric "
                "function rather than a scorer. A scorer should "
                "require an estimator as its first parameter. "
                "Please use `make_scorer` to convert a metric "
                "to a scorer." % scoring
            )
        return get_scorer(scoring)

    if scoring is None:
        # Fall back to the estimator's own `score`, if there is one.
        if hasattr(estimator, "score"):
            return _passthrough_scorer
        if allow_none:
            return None
        raise TypeError(
            "If no scoring is specified, the estimator passed should "
            "have a 'score' method. The estimator %r does not." % estimator
        )

    if isinstance(scoring, Iterable):
        raise ValueError(
            "For evaluating multiple scores, use "
            "sklearn.model_selection.cross_validate instead. "
            f"{scoring} was passed."
        )

    raise ValueError(
        "scoring value should either be a callable, string or None. %r was passed"
        % scoring
    )
|
||||
|
||||
|
||||
def _check_multimetric_scoring(estimator, scoring):
|
||||
"""Check the scoring parameter in cases when multiple metrics are allowed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : sklearn estimator instance
|
||||
The estimator for which the scoring will be applied.
|
||||
|
||||
scoring : list, tuple or dict
|
||||
Strategy to evaluate the performance of the cross-validated model on
|
||||
the test set.
|
||||
|
||||
The possibilities are:
|
||||
|
||||
- a list or tuple of unique strings;
|
||||
- a callable returning a dictionary where they keys are the metric
|
||||
names and the values are the metric scores;
|
||||
- a dictionary with metric names as keys and callables a values.
|
||||
|
||||
See :ref:`multimetric_grid_search` for an example.
|
||||
|
||||
Returns
|
||||
-------
|
||||
scorers_dict : dict
|
||||
A dict mapping each scorer name to its validated scorer.
|
||||
"""
|
||||
err_msg_generic = (
|
||||
f"scoring is invalid (got {scoring!r}). Refer to the "
|
||||
"scoring glossary for details: "
|
||||
"https://scikit-learn.org/stable/glossary.html#term-scoring"
|
||||
)
|
||||
|
||||
if isinstance(scoring, (list, tuple, set)):
|
||||
err_msg = (
|
||||
"The list/tuple elements must be unique strings of predefined scorers. "
|
||||
)
|
||||
try:
|
||||
keys = set(scoring)
|
||||
except TypeError as e:
|
||||
raise ValueError(err_msg) from e
|
||||
|
||||
if len(keys) != len(scoring):
|
||||
raise ValueError(
|
||||
f"{err_msg} Duplicate elements were found in"
|
||||
f" the given list. {scoring!r}"
|
||||
)
|
||||
elif len(keys) > 0:
|
||||
if not all(isinstance(k, str) for k in keys):
|
||||
if any(callable(k) for k in keys):
|
||||
raise ValueError(
|
||||
f"{err_msg} One or more of the elements "
|
||||
"were callables. Use a dict of score "
|
||||
"name mapped to the scorer callable. "
|
||||
f"Got {scoring!r}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{err_msg} Non-string types were found "
|
||||
f"in the given list. Got {scoring!r}"
|
||||
)
|
||||
scorers = {
|
||||
scorer: check_scoring(estimator, scoring=scorer) for scorer in scoring
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"{err_msg} Empty list was given. {scoring!r}")
|
||||
|
||||
elif isinstance(scoring, dict):
|
||||
keys = set(scoring)
|
||||
if not all(isinstance(k, str) for k in keys):
|
||||
raise ValueError(
|
||||
"Non-string types were found in the keys of "
|
||||
f"the given dict. scoring={scoring!r}"
|
||||
)
|
||||
if len(keys) == 0:
|
||||
raise ValueError(f"An empty dict was passed. {scoring!r}")
|
||||
scorers = {
|
||||
key: check_scoring(estimator, scoring=scorer)
|
||||
for key, scorer in scoring.items()
|
||||
}
|
||||
else:
|
||||
raise ValueError(err_msg_generic)
|
||||
return scorers
|
||||
|
||||
|
||||
def make_scorer(
    score_func,
    *,
    greater_is_better=True,
    needs_proba=False,
    needs_threshold=False,
    **kwargs,
):
    """Make a scorer from a performance metric or loss function.

    This factory function wraps scoring functions for use in
    :class:`~sklearn.model_selection.GridSearchCV` and
    :func:`~sklearn.model_selection.cross_val_score`.
    It takes a score function, such as :func:`~sklearn.metrics.accuracy_score`,
    :func:`~sklearn.metrics.mean_squared_error`,
    :func:`~sklearn.metrics.adjusted_rand_score` or
    :func:`~sklearn.metrics.average_precision_score`
    and returns a callable that scores an estimator's output.
    The signature of the call is `(estimator, X, y)` where `estimator`
    is the model to be evaluated, `X` is the data and `y` is the
    ground truth labeling (or `None` in the case of unsupervised models).

    Read more in the :ref:`User Guide <scoring>`.

    Parameters
    ----------
    score_func : callable
        Score function (or loss function) with signature
        `score_func(y, y_pred, **kwargs)`.

    greater_is_better : bool, default=True
        Whether `score_func` is a score function (default), meaning high is
        good, or a loss function, meaning low is good. In the latter case, the
        scorer object will sign-flip the outcome of the `score_func`.

    needs_proba : bool, default=False
        Whether `score_func` requires `predict_proba` to get probability
        estimates out of a classifier.

        If True, for binary `y_true`, the score function is supposed to accept
        a 1D `y_pred` (i.e., probability of the positive class, shape
        `(n_samples,)`).

    needs_threshold : bool, default=False
        Whether `score_func` takes a continuous decision certainty.
        This only works for binary classification using estimators that
        have either a `decision_function` or `predict_proba` method.

        If True, for binary `y_true`, the score function is supposed to accept
        a 1D `y_pred` (i.e., probability of the positive class or the decision
        function, shape `(n_samples,)`).

        For example `average_precision` or the area under the roc curve
        can not be computed using discrete predictions alone.

    **kwargs : additional arguments
        Additional parameters to be passed to `score_func`.

    Returns
    -------
    scorer : callable
        Callable object that returns a scalar score; greater is better.

    Notes
    -----
    If `needs_proba=False` and `needs_threshold=False`, the score
    function is supposed to accept the output of :term:`predict`. If
    `needs_proba=True`, the score function is supposed to accept the
    output of :term:`predict_proba` (For binary `y_true`, the score function is
    supposed to accept probability of the positive class). If
    `needs_threshold=True`, the score function is supposed to accept the
    output of :term:`decision_function` or :term:`predict_proba` when
    :term:`decision_function` is not present.

    Examples
    --------
    >>> from sklearn.metrics import fbeta_score, make_scorer
    >>> ftwo_scorer = make_scorer(fbeta_score, beta=2)
    >>> ftwo_scorer
    make_scorer(fbeta_score, beta=2)
    >>> from sklearn.model_selection import GridSearchCV
    >>> from sklearn.svm import LinearSVC
    >>> grid = GridSearchCV(LinearSVC(), param_grid={'C': [1, 10]},
    ...                     scoring=ftwo_scorer)
    """
    # The two prediction modes are mutually exclusive.
    if needs_proba and needs_threshold:
        raise ValueError(
            "Set either needs_proba or needs_threshold to True, but not both."
        )

    # Losses are negated so that "greater is better" always holds.
    sign = 1 if greater_is_better else -1

    if needs_proba:
        scorer_cls = _ProbaScorer
    elif needs_threshold:
        scorer_cls = _ThresholdScorer
    else:
        scorer_cls = _PredictScorer
    return scorer_cls(score_func, sign, kwargs)
|
||||
|
||||
|
||||
# Standard regression scores.
# Error/loss metrics are registered with greater_is_better=False so that the
# resulting scorer negates them ("neg_" prefix).
explained_variance_scorer = make_scorer(explained_variance_score)
r2_scorer = make_scorer(r2_score)
max_error_scorer = make_scorer(max_error, greater_is_better=False)
neg_mean_squared_error_scorer = make_scorer(
    mean_squared_error, greater_is_better=False
)
neg_mean_squared_log_error_scorer = make_scorer(
    mean_squared_log_error, greater_is_better=False
)
neg_mean_absolute_error_scorer = make_scorer(
    mean_absolute_error, greater_is_better=False
)
neg_mean_absolute_percentage_error_scorer = make_scorer(
    mean_absolute_percentage_error, greater_is_better=False
)
neg_median_absolute_error_scorer = make_scorer(
    median_absolute_error, greater_is_better=False
)
neg_root_mean_squared_error_scorer = make_scorer(
    mean_squared_error, greater_is_better=False, squared=False
)
neg_mean_poisson_deviance_scorer = make_scorer(
    mean_poisson_deviance, greater_is_better=False
)
neg_mean_gamma_deviance_scorer = make_scorer(
    mean_gamma_deviance, greater_is_better=False
)

# Standard classification scores (computed from hard predictions).
accuracy_scorer = make_scorer(accuracy_score)
balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)
matthews_corrcoef_scorer = make_scorer(matthews_corrcoef)

# Score functions that need decision values.
top_k_accuracy_scorer = make_scorer(top_k_accuracy_score, needs_threshold=True)
roc_auc_scorer = make_scorer(roc_auc_score, needs_threshold=True)
average_precision_scorer = make_scorer(
    average_precision_score, needs_threshold=True
)
roc_auc_ovo_scorer = make_scorer(
    roc_auc_score, needs_proba=True, multi_class="ovo"
)
roc_auc_ovo_weighted_scorer = make_scorer(
    roc_auc_score, needs_proba=True, multi_class="ovo", average="weighted"
)
roc_auc_ovr_scorer = make_scorer(
    roc_auc_score, needs_proba=True, multi_class="ovr"
)
roc_auc_ovr_weighted_scorer = make_scorer(
    roc_auc_score, needs_proba=True, multi_class="ovr", average="weighted"
)

# Score functions for probabilistic classification.
neg_log_loss_scorer = make_scorer(
    log_loss, greater_is_better=False, needs_proba=True
)
neg_brier_score_scorer = make_scorer(
    brier_score_loss, greater_is_better=False, needs_proba=True
)
brier_score_loss_scorer = make_scorer(
    brier_score_loss, greater_is_better=False, needs_proba=True
)

# Clustering scores.
adjusted_rand_scorer = make_scorer(adjusted_rand_score)
rand_scorer = make_scorer(rand_score)
homogeneity_scorer = make_scorer(homogeneity_score)
completeness_scorer = make_scorer(completeness_score)
v_measure_scorer = make_scorer(v_measure_score)
mutual_info_scorer = make_scorer(mutual_info_score)
adjusted_mutual_info_scorer = make_scorer(adjusted_mutual_info_score)
normalized_mutual_info_scorer = make_scorer(normalized_mutual_info_score)
fowlkes_mallows_scorer = make_scorer(fowlkes_mallows_score)
|
||||
|
||||
|
||||
# TODO(1.3) Remove
|
||||
class _DeprecatedScorers(dict):
|
||||
"""A temporary class to deprecate SCORERS."""
|
||||
|
||||
def __getitem__(self, item):
|
||||
warnings.warn(
|
||||
"sklearn.metrics.SCORERS is deprecated and will be removed in v1.3. "
|
||||
"Please use sklearn.metrics.get_scorer_names to get a list of available "
|
||||
"scorers and sklearn.metrics.get_metric to get scorer.",
|
||||
FutureWarning,
|
||||
)
|
||||
return super().__getitem__(item)
|
||||
|
||||
|
||||
# Registry mapping each public scoring name to its scorer instance.
_SCORERS = {
    "explained_variance": explained_variance_scorer,
    "r2": r2_scorer,
    "max_error": max_error_scorer,
    "matthews_corrcoef": matthews_corrcoef_scorer,
    "neg_median_absolute_error": neg_median_absolute_error_scorer,
    "neg_mean_absolute_error": neg_mean_absolute_error_scorer,
    "neg_mean_absolute_percentage_error": (
        neg_mean_absolute_percentage_error_scorer
    ),
    "neg_mean_squared_error": neg_mean_squared_error_scorer,
    "neg_mean_squared_log_error": neg_mean_squared_log_error_scorer,
    "neg_root_mean_squared_error": neg_root_mean_squared_error_scorer,
    "neg_mean_poisson_deviance": neg_mean_poisson_deviance_scorer,
    "neg_mean_gamma_deviance": neg_mean_gamma_deviance_scorer,
    "accuracy": accuracy_scorer,
    "top_k_accuracy": top_k_accuracy_scorer,
    "roc_auc": roc_auc_scorer,
    "roc_auc_ovr": roc_auc_ovr_scorer,
    "roc_auc_ovo": roc_auc_ovo_scorer,
    "roc_auc_ovr_weighted": roc_auc_ovr_weighted_scorer,
    "roc_auc_ovo_weighted": roc_auc_ovo_weighted_scorer,
    "balanced_accuracy": balanced_accuracy_scorer,
    "average_precision": average_precision_scorer,
    "neg_log_loss": neg_log_loss_scorer,
    "neg_brier_score": neg_brier_score_scorer,
    # Cluster metrics that use supervised evaluation
    "adjusted_rand_score": adjusted_rand_scorer,
    "rand_score": rand_scorer,
    "homogeneity_score": homogeneity_scorer,
    "completeness_score": completeness_scorer,
    "v_measure_score": v_measure_scorer,
    "mutual_info_score": mutual_info_scorer,
    "adjusted_mutual_info_score": adjusted_mutual_info_scorer,
    "normalized_mutual_info_score": normalized_mutual_info_scorer,
    "fowlkes_mallows_score": fowlkes_mallows_scorer,
}
|
||||
|
||||
|
||||
def get_scorer_names():
    """Get the names of all available scorers.

    These names can be passed to :func:`~sklearn.metrics.get_scorer` to
    retrieve the scorer object.

    Returns
    -------
    list of str
        Names of all available scorers.
    """
    # Iterating a dict yields its keys; sorted() returns a new list.
    return sorted(_SCORERS)
|
||||
|
||||
|
||||
# Register the averaging-aware classification scorers: the plain name uses
# binary averaging, and one "<name>_<average>" entry is added per averaging mode.
for name, metric in (
    ("precision", precision_score),
    ("recall", recall_score),
    ("f1", f1_score),
    ("jaccard", jaccard_score),
):
    _SCORERS[name] = make_scorer(metric, average="binary")
    for average in ("macro", "micro", "samples", "weighted"):
        qualified_name = f"{name}_{average}"
        _SCORERS[qualified_name] = make_scorer(
            metric, pos_label=None, average=average
        )

# Deprecated public alias of the registry; item access emits a FutureWarning.
SCORERS = _DeprecatedScorers(_SCORERS)
|
||||
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
The :mod:`sklearn.metrics.cluster` submodule contains evaluation metrics for
|
||||
cluster analysis results. There are two forms of evaluation:
|
||||
|
||||
- supervised, which uses ground truth class values for each sample.
|
||||
- unsupervised, which does not and measures the 'quality' of the model itself.
|
||||
"""
|
||||
from ._supervised import adjusted_mutual_info_score
|
||||
from ._supervised import normalized_mutual_info_score
|
||||
from ._supervised import adjusted_rand_score
|
||||
from ._supervised import rand_score
|
||||
from ._supervised import completeness_score
|
||||
from ._supervised import contingency_matrix
|
||||
from ._supervised import pair_confusion_matrix
|
||||
from ._supervised import expected_mutual_information
|
||||
from ._supervised import homogeneity_completeness_v_measure
|
||||
from ._supervised import homogeneity_score
|
||||
from ._supervised import mutual_info_score
|
||||
from ._supervised import v_measure_score
|
||||
from ._supervised import fowlkes_mallows_score
|
||||
from ._supervised import entropy
|
||||
from ._unsupervised import silhouette_samples
|
||||
from ._unsupervised import silhouette_score
|
||||
from ._unsupervised import calinski_harabasz_score
|
||||
from ._unsupervised import davies_bouldin_score
|
||||
from ._bicluster import consensus_score
|
||||
|
||||
__all__ = [
|
||||
"adjusted_mutual_info_score",
|
||||
"normalized_mutual_info_score",
|
||||
"adjusted_rand_score",
|
||||
"rand_score",
|
||||
"completeness_score",
|
||||
"pair_confusion_matrix",
|
||||
"contingency_matrix",
|
||||
"expected_mutual_information",
|
||||
"homogeneity_completeness_v_measure",
|
||||
"homogeneity_score",
|
||||
"mutual_info_score",
|
||||
"v_measure_score",
|
||||
"fowlkes_mallows_score",
|
||||
"entropy",
|
||||
"silhouette_samples",
|
||||
"silhouette_score",
|
||||
"calinski_harabasz_score",
|
||||
"davies_bouldin_score",
|
||||
"consensus_score",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,88 @@
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from ...utils.validation import check_consistent_length, check_array
|
||||
|
||||
__all__ = ["consensus_score"]
|
||||
|
||||
|
||||
def _check_rows_and_columns(a, b):
    """Unpack the row and column arrays of `a` and `b` and check their shape.

    Returns the validated arrays as ``(a_rows, a_cols, b_rows, b_cols)``.
    """
    check_consistent_length(*a)
    check_consistent_length(*b)

    def _validated(indicator):
        # Indicators are validated without forcing a 2D shape.
        return check_array(indicator, ensure_2d=False)

    a_rows, a_cols = (_validated(part) for part in a)
    b_rows, b_cols = (_validated(part) for part in b)
    return a_rows, a_cols, b_rows, b_cols
|
||||
|
||||
|
||||
def _jaccard(a_rows, a_cols, b_rows, b_cols):
|
||||
"""Jaccard coefficient on the elements of the two biclusters."""
|
||||
intersection = (a_rows * b_rows).sum() * (a_cols * b_cols).sum()
|
||||
|
||||
a_size = a_rows.sum() * a_cols.sum()
|
||||
b_size = b_rows.sum() * b_cols.sum()
|
||||
|
||||
return intersection / (a_size + b_size - intersection)
|
||||
|
||||
|
||||
def _pairwise_similarity(a, b, similarity):
    """Compute the pairwise similarity matrix of two sets of biclusters.

    result[i, j] is ``similarity`` evaluated on a's bicluster i and b's
    bicluster j (the Jaccard coefficient when `similarity` is `_jaccard`).
    """
    a_rows, a_cols, b_rows, b_cols = _check_rows_and_columns(a, b)
    n_a = a_rows.shape[0]
    n_b = b_rows.shape[0]
    scores = [
        [similarity(a_rows[i], a_cols[i], b_rows[j], b_cols[j]) for j in range(n_b)]
        for i in range(n_a)
    ]
    return np.array(scores)
|
||||
|
||||
|
||||
def consensus_score(a, b, *, similarity="jaccard"):
    """The similarity of two sets of biclusters.

    Similarity between individual biclusters is computed. Then the
    best matching between sets is found using the Hungarian algorithm.
    The final score is the sum of similarities divided by the size of
    the larger set.

    Read more in the :ref:`User Guide <biclustering>`.

    Parameters
    ----------
    a : (rows, columns)
        Tuple of row and column indicators for a set of biclusters.

    b : (rows, columns)
        Another set of biclusters like ``a``.

    similarity : 'jaccard' or callable, default='jaccard'
        May be the string "jaccard" to use the Jaccard coefficient, or
        any function that takes four arguments, each of which is a 1d
        indicator vector: (a_rows, a_columns, b_rows, b_columns).

    References
    ----------

    * Hochreiter, Bodenhofer, et. al., 2010. `FABIA: factor analysis
      for bicluster acquisition
      <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2881408/>`__.

    """
    if similarity == "jaccard":
        similarity = _jaccard
    matrix = _pairwise_similarity(a, b, similarity)
    # The assignment solver minimises cost, so similarities are turned into
    # costs by subtracting them from 1.
    row_indices, col_indices = linear_sum_assignment(1.0 - matrix)
    larger_set_size = max(len(a[0]), len(b[0]))
    matched_total = matrix[row_indices, col_indices].sum()
    return matched_total / larger_set_size
|
||||
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,366 @@
|
||||
"""Unsupervised evaluation metrics."""
|
||||
|
||||
# Authors: Robert Layton <robertlayton@gmail.com>
|
||||
# Arnaud Fouchet <foucheta@gmail.com>
|
||||
# Thierry Guillemot <thierry.guillemot.work@gmail.com>
|
||||
# License: BSD 3 clause
|
||||
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...utils import check_random_state
|
||||
from ...utils import check_X_y
|
||||
from ...utils import _safe_indexing
|
||||
from ..pairwise import pairwise_distances_chunked
|
||||
from ..pairwise import pairwise_distances
|
||||
from ...preprocessing import LabelEncoder
|
||||
|
||||
|
||||
def check_number_of_labels(n_labels, n_samples):
    """Check that number of labels are valid.

    Parameters
    ----------
    n_labels : int
        Number of labels.

    n_samples : int
        Number of samples.

    Raises
    ------
    ValueError
        If `n_labels` is not strictly between 1 and `n_samples`.
    """
    if 1 < n_labels < n_samples:
        return
    raise ValueError(
        "Number of labels is %d. Valid values are 2 to n_samples - 1 (inclusive)"
        % n_labels
    )
|
||||
|
||||
|
||||
def silhouette_score(
    X, labels, *, metric="euclidean", sample_size=None, random_state=None, **kwds
):
    """Compute the mean Silhouette Coefficient of all samples.

    The Silhouette Coefficient is calculated using the mean intra-cluster
    distance (``a``) and the mean nearest-cluster distance (``b``) for each
    sample.  The Silhouette Coefficient for a sample is ``(b - a) / max(a,
    b)``.  To clarify, ``b`` is the distance between a sample and the nearest
    cluster that the sample is not a part of.
    Note that Silhouette Coefficient is only defined if number of labels
    is ``2 <= n_labels <= n_samples - 1``.

    This function returns the mean Silhouette Coefficient over all samples.
    To obtain the values for each sample, use :func:`silhouette_samples`.

    The best value is 1 and the worst value is -1. Values near 0 indicate
    overlapping clusters. Negative values generally indicate that a sample has
    been assigned to the wrong cluster, as a different cluster is more similar.

    Read more in the :ref:`User Guide <silhouette_coefficient>`.

    Parameters
    ----------
    X : array-like of shape (n_samples_a, n_samples_a) if metric == \
            "precomputed" or (n_samples_a, n_features) otherwise
        An array of pairwise distances between samples, or a feature array.

    labels : array-like of shape (n_samples,)
        Predicted labels for each sample.

    metric : str or callable, default='euclidean'
        The metric to use when calculating distance between instances in a
        feature array. If metric is a string, it must be one of the options
        allowed by :func:`metrics.pairwise.pairwise_distances
        <sklearn.metrics.pairwise.pairwise_distances>`. If ``X`` is
        the distance array itself, use ``metric="precomputed"``.

    sample_size : int, default=None
        The size of the sample to use when computing the Silhouette Coefficient
        on a random subset of the data.
        If ``sample_size is None``, no sampling is used.

    random_state : int, RandomState instance or None, default=None
        Determines random number generation for selecting a subset of samples.
        Used when ``sample_size is not None``.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    **kwds : optional keyword parameters
        Any further parameters are passed directly to the distance function.
        If using a scipy.spatial.distance metric, the parameters are still
        metric dependent. See the scipy docs for usage examples.

    Returns
    -------
    silhouette : float
        Mean Silhouette Coefficient for all samples.

    References
    ----------

    .. [1] `Peter J. Rousseeuw (1987). "Silhouettes: a Graphical Aid to the
       Interpretation and Validation of Cluster Analysis". Computational
       and Applied Mathematics 20: 53-65.
       <https://www.sciencedirect.com/science/article/pii/0377042787901257>`_

    .. [2] `Wikipedia entry on the Silhouette Coefficient
           <https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_
    """
    if sample_size is not None:
        X, labels = check_X_y(X, labels, accept_sparse=["csc", "csr"])
        rng = check_random_state(random_state)
        subset = rng.permutation(X.shape[0])[:sample_size]
        X = X[subset]
        if metric == "precomputed":
            # A distance matrix must be subset on both axes.
            X = X.T[subset].T
        labels = labels[subset]

    per_sample = silhouette_samples(X, labels, metric=metric, **kwds)
    return np.mean(per_sample)
|
||||
|
||||
|
||||
def _silhouette_reduce(D_chunk, start, labels, label_freqs):
|
||||
"""Accumulate silhouette statistics for vertical chunk of X.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
D_chunk : array-like of shape (n_chunk_samples, n_samples)
|
||||
Precomputed distances for a chunk.
|
||||
start : int
|
||||
First index in the chunk.
|
||||
labels : array-like of shape (n_samples,)
|
||||
Corresponding cluster labels, encoded as {0, ..., n_clusters-1}.
|
||||
label_freqs : array-like
|
||||
Distribution of cluster labels in ``labels``.
|
||||
"""
|
||||
# accumulate distances from each sample to each cluster
|
||||
clust_dists = np.zeros((len(D_chunk), len(label_freqs)), dtype=D_chunk.dtype)
|
||||
for i in range(len(D_chunk)):
|
||||
clust_dists[i] += np.bincount(
|
||||
labels, weights=D_chunk[i], minlength=len(label_freqs)
|
||||
)
|
||||
|
||||
# intra_index selects intra-cluster distances within clust_dists
|
||||
intra_index = (np.arange(len(D_chunk)), labels[start : start + len(D_chunk)])
|
||||
# intra_clust_dists are averaged over cluster size outside this function
|
||||
intra_clust_dists = clust_dists[intra_index]
|
||||
# of the remaining distances we normalise and extract the minimum
|
||||
clust_dists[intra_index] = np.inf
|
||||
clust_dists /= label_freqs
|
||||
inter_clust_dists = clust_dists.min(axis=1)
|
||||
return intra_clust_dists, inter_clust_dists
|
||||
|
||||
|
||||
def silhouette_samples(X, labels, *, metric="euclidean", **kwds):
    """Compute the Silhouette Coefficient for each sample.

    The Silhouette Coefficient is a measure of how well samples are clustered
    with samples that are similar to themselves. Clustering models with a high
    Silhouette Coefficient are said to be dense, where samples in the same
    cluster are similar to each other, and well separated, where samples in
    different clusters are not very similar to each other.

    The Silhouette Coefficient is calculated using the mean intra-cluster
    distance (``a``) and the mean nearest-cluster distance (``b``) for each
    sample.  The Silhouette Coefficient for a sample is ``(b - a) / max(a,
    b)``.
    Note that Silhouette Coefficient is only defined if number of labels
    is 2 ``<= n_labels <= n_samples - 1``.

    This function returns the Silhouette Coefficient for each sample.

    The best value is 1 and the worst value is -1. Values near 0 indicate
    overlapping clusters.

    Read more in the :ref:`User Guide <silhouette_coefficient>`.

    Parameters
    ----------
    X : array-like of shape (n_samples_a, n_samples_a) if metric == \
            "precomputed" or (n_samples_a, n_features) otherwise
        An array of pairwise distances between samples, or a feature array.

    labels : array-like of shape (n_samples,)
        Label values for each sample.

    metric : str or callable, default='euclidean'
        The metric to use when calculating distance between instances in a
        feature array. If metric is a string, it must be one of the options
        allowed by :func:`sklearn.metrics.pairwise.pairwise_distances`.
        If ``X`` is the distance array itself, use "precomputed" as the metric.
        Precomputed distance matrices must have 0 along the diagonal.

    **kwds : optional keyword parameters
        Any further parameters are passed directly to the distance function.
        If using a ``scipy.spatial.distance`` metric, the parameters are still
        metric dependent. See the scipy docs for usage examples.

    Returns
    -------
    silhouette : array-like of shape (n_samples,)
        Silhouette Coefficients for each sample.

    Raises
    ------
    ValueError
        If ``metric == "precomputed"`` and the diagonal of ``X`` contains
        non-zero entries (beyond a small floating-point tolerance).

    References
    ----------

    .. [1] `Peter J. Rousseeuw  (1987). "Silhouettes: a Graphical Aid to the
       Interpretation and Validation of Cluster Analysis". Computational
       and Applied Mathematics 20: 53-65.
       <https://www.sciencedirect.com/science/article/pii/0377042787901257>`_

    .. [2] `Wikipedia entry on the Silhouette Coefficient
       <https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_
    """
    X, labels = check_X_y(X, labels, accept_sparse=["csc", "csr"])

    # Check for non-zero diagonal entries in precomputed distance matrix
    if metric == "precomputed":
        # Plain message string: the exception is built exactly once at the
        # raise site instead of wrapping one ValueError inside another.
        diag_error_msg = (
            "The precomputed distance matrix contains non-zero "
            "elements on the diagonal. Use np.fill_diagonal(X, 0)."
        )
        if X.dtype.kind == "f":
            # Tolerate floating-point round-off on the diagonal.
            atol = np.finfo(X.dtype).eps * 100
            if np.any(np.abs(np.diagonal(X)) > atol):
                raise ValueError(diag_error_msg)
        elif np.any(np.diagonal(X) != 0):  # integral dtype
            raise ValueError(diag_error_msg)

    le = LabelEncoder()
    labels = le.fit_transform(labels)
    n_samples = len(labels)
    label_freqs = np.bincount(labels)
    check_number_of_labels(len(le.classes_), n_samples)

    # The distance computation is chunked; each chunk is reduced to its
    # per-sample intra- and nearest-cluster statistics.
    kwds["metric"] = metric
    reduce_func = functools.partial(
        _silhouette_reduce, labels=labels, label_freqs=label_freqs
    )
    results = zip(*pairwise_distances_chunked(X, reduce_func=reduce_func, **kwds))
    intra_clust_dists, inter_clust_dists = results
    intra_clust_dists = np.concatenate(intra_clust_dists)
    inter_clust_dists = np.concatenate(inter_clust_dists)

    # A sample is not compared with itself, hence the "- 1"; singleton
    # clusters give a zero denominator, handled through nan_to_num below.
    denom = (label_freqs - 1).take(labels, mode="clip")
    with np.errstate(divide="ignore", invalid="ignore"):
        intra_clust_dists /= denom

    sil_samples = inter_clust_dists - intra_clust_dists
    with np.errstate(divide="ignore", invalid="ignore"):
        sil_samples /= np.maximum(intra_clust_dists, inter_clust_dists)
    # nan values are for clusters of size 1, and should be 0
    return np.nan_to_num(sil_samples)
|
||||
|
||||
|
||||
def calinski_harabasz_score(X, labels):
    """Compute the Calinski and Harabasz score.

    It is also known as the Variance Ratio Criterion.

    The score is defined as the ratio of the between-cluster dispersion to
    the within-cluster dispersion.

    Read more in the :ref:`User Guide <calinski_harabasz_index>`.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        A list of ``n_features``-dimensional data points. Each row corresponds
        to a single data point.

    labels : array-like of shape (n_samples,)
        Predicted labels for each sample.

    Returns
    -------
    score : float
        The resulting Calinski-Harabasz score.

    References
    ----------
    .. [1] `T. Calinski and J. Harabasz, 1974. "A dendrite method for cluster
       analysis". Communications in Statistics
       <https://www.tandfonline.com/doi/abs/10.1080/03610927408827101>`_
    """
    X, labels = check_X_y(X, labels)
    encoder = LabelEncoder()
    labels = encoder.fit_transform(labels)

    n_samples = X.shape[0]
    n_labels = len(encoder.classes_)
    check_number_of_labels(n_labels, n_samples)

    overall_mean = np.mean(X, axis=0)
    between_disp = 0.0
    within_disp = 0.0
    for k in range(n_labels):
        members = X[labels == k]
        center = np.mean(members, axis=0)
        # Between-cluster term is weighted by the cluster size.
        between_disp += len(members) * np.sum((center - overall_mean) ** 2)
        within_disp += np.sum((members - center) ** 2)

    # Zero within-cluster dispersion would divide by zero; the score is
    # defined as 1.0 in that case.
    if within_disp == 0.0:
        return 1.0
    return between_disp * (n_samples - n_labels) / (within_disp * (n_labels - 1.0))
|
||||
|
||||
|
||||
def davies_bouldin_score(X, labels):
    """Compute the Davies-Bouldin score.

    The score is defined as the average similarity measure of each cluster with
    its most similar cluster, where similarity is the ratio of within-cluster
    distances to between-cluster distances. Thus, clusters which are farther
    apart and less dispersed will result in a better score.

    The minimum score is zero, with lower values indicating better clustering.

    Read more in the :ref:`User Guide <davies-bouldin_index>`.

    .. versionadded:: 0.20

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        A list of ``n_features``-dimensional data points. Each row corresponds
        to a single data point.

    labels : array-like of shape (n_samples,)
        Predicted labels for each sample.

    Returns
    -------
    score: float
        The resulting Davies-Bouldin score.

    References
    ----------
    .. [1] Davies, David L.; Bouldin, Donald W. (1979).
       `"A Cluster Separation Measure"
       <https://ieeexplore.ieee.org/document/4766909>`__.
       IEEE Transactions on Pattern Analysis and Machine Intelligence.
       PAMI-1 (2): 224-227
    """
    X, labels = check_X_y(X, labels)
    encoder = LabelEncoder()
    labels = encoder.fit_transform(labels)
    n_samples, _ = X.shape
    n_labels = len(encoder.classes_)
    check_number_of_labels(n_labels, n_samples)

    n_features = len(X[0])
    centroids = np.zeros((n_labels, n_features), dtype=float)
    intra_dists = np.zeros(n_labels)
    for k in range(n_labels):
        members = _safe_indexing(X, labels == k)
        center = members.mean(axis=0)
        centroids[k] = center
        # Mean distance of the cluster members to their centroid.
        intra_dists[k] = np.average(pairwise_distances(members, [center]))

    centroid_distances = pairwise_distances(centroids)

    degenerate = np.allclose(intra_dists, 0) or np.allclose(centroid_distances, 0)
    if degenerate:
        return 0.0

    # Zero distances (e.g. a centroid to itself) become inf so the
    # corresponding ratios vanish instead of dividing by zero.
    centroid_distances[centroid_distances == 0] = np.inf
    pairwise_spread = intra_dists[:, None] + intra_dists
    worst_ratio = np.max(pairwise_spread / centroid_distances, axis=1)
    return np.mean(worst_ratio)
|
||||
@@ -0,0 +1,27 @@
|
||||
import os
|
||||
|
||||
import numpy
|
||||
from numpy.distutils.misc_util import Configuration
|
||||
|
||||
|
||||
def configuration(parent_package="", top_path=None):
    """Build the numpy.distutils configuration of the ``cluster`` subpackage.

    Registers the ``_expected_mutual_info_fast`` extension and the ``tests``
    subpackage, then returns the configuration object.
    """
    config = Configuration("cluster", parent_package, top_path)

    # The extension is linked against libm on POSIX systems only.
    libraries = ["m"] if os.name == "posix" else []

    config.add_extension(
        "_expected_mutual_info_fast",
        sources=["_expected_mutual_info_fast.pyx"],
        include_dirs=[numpy.get_include()],
        libraries=libraries,
    )
    config.add_subpackage("tests")
    return config
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow building this subpackage on its own.
    from numpy.distutils.core import setup

    setup_kwargs = configuration().todict()
    setup(**setup_kwargs)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,57 @@
|
||||
"""Testing for bicluster metrics module"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sklearn.utils._testing import assert_almost_equal
|
||||
|
||||
from sklearn.metrics.cluster._bicluster import _jaccard
|
||||
from sklearn.metrics import consensus_score
|
||||
|
||||
|
||||
def test_jaccard():
    """Check ``_jaccard`` on biclusters with known overlaps."""
    first_half = np.array([True, True, False, False])
    everything = np.array([True, True, True, True])
    middle = np.array([False, True, True, False])
    second_half = np.array([False, False, True, True])

    # Each case compares the (first_half, first_half) bicluster with another
    # bicluster using the same mask for rows and columns.
    cases = [
        (first_half, 1),
        (everything, 0.25),
        (middle, 1.0 / 7),
        (second_half, 0),
    ]
    for other, expected in cases:
        assert _jaccard(first_half, first_half, other, other) == expected
|
||||
|
||||
|
||||
def test_consensus_score():
    """Check ``consensus_score`` on identical and mismatched biclusterings."""
    a = [[True, True, False, False], [False, False, True, True]]
    b = a[::-1]

    # Same biclusters, possibly listed in a different order.
    matching = [
        ((a, a), (a, a)),
        ((a, a), (b, b)),
        ((a, b), (a, b)),
        ((a, b), (b, a)),
    ]
    for first, second in matching:
        assert consensus_score(first, second) == 1

    # Row and column memberships that never agree.
    mismatched = [
        ((a, a), (b, a)),
        ((a, a), (a, b)),
        ((b, b), (a, b)),
        ((b, b), (b, a)),
    ]
    for first, second in mismatched:
        assert consensus_score(first, second) == 0
|
||||
|
||||
|
||||
def test_consensus_score_issue2445():
    """Different number of biclusters in A and B"""
    membership = [
        [True, True, False, False],
        [False, False, True, True],
        [False, False, False, True],
    ]
    a_rows = np.array(membership)
    a_cols = np.array(membership)
    kept = [0, 2]
    subset = (a_rows[kept], a_cols[kept])
    s = consensus_score((a_rows, a_cols), subset)
    # B contains 2 of the 3 biclusters in A, so score should be 2/3
    assert_almost_equal(s, 2.0 / 3.0)
|
||||
@@ -0,0 +1,219 @@
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from sklearn.metrics.cluster import adjusted_mutual_info_score
|
||||
from sklearn.metrics.cluster import adjusted_rand_score
|
||||
from sklearn.metrics.cluster import rand_score
|
||||
from sklearn.metrics.cluster import completeness_score
|
||||
from sklearn.metrics.cluster import fowlkes_mallows_score
|
||||
from sklearn.metrics.cluster import homogeneity_score
|
||||
from sklearn.metrics.cluster import mutual_info_score
|
||||
from sklearn.metrics.cluster import normalized_mutual_info_score
|
||||
from sklearn.metrics.cluster import v_measure_score
|
||||
from sklearn.metrics.cluster import silhouette_score
|
||||
from sklearn.metrics.cluster import calinski_harabasz_score
|
||||
from sklearn.metrics.cluster import davies_bouldin_score
|
||||
|
||||
from sklearn.utils._testing import assert_allclose
|
||||
|
||||
|
||||
# Dictionaries of metrics
# ------------------------
# These dictionaries associate a name with each metric function so that a
# particular metric is easy to look up and call:
#   - SUPERVISED_METRICS: all supervised cluster metrics (those given a
#     ground truth value)
#   - UNSUPERVISED_METRICS: all unsupervised cluster metrics
#
# They are used to systematically test some invariance properties, e.g.
# invariance toward several input layouts.
#

SUPERVISED_METRICS = dict(
    adjusted_mutual_info_score=adjusted_mutual_info_score,
    adjusted_rand_score=adjusted_rand_score,
    rand_score=rand_score,
    completeness_score=completeness_score,
    homogeneity_score=homogeneity_score,
    mutual_info_score=mutual_info_score,
    normalized_mutual_info_score=normalized_mutual_info_score,
    v_measure_score=v_measure_score,
    fowlkes_mallows_score=fowlkes_mallows_score,
)

UNSUPERVISED_METRICS = dict(
    silhouette_score=silhouette_score,
    silhouette_manhattan=partial(silhouette_score, metric="manhattan"),
    calinski_harabasz_score=calinski_harabasz_score,
    davies_bouldin_score=davies_bouldin_score,
)
|
||||
|
||||
# Lists of metrics with common properties
# ---------------------------------------
# Each list names the metrics that share a property, so that the property can
# be tested systematically, e.g. SYMMETRIC_METRICS lists all metrics that are
# symmetric with respect to their input arguments y_true and y_pred.
#
# --------------------------------------------------------------------
# Symmetric with respect to their input arguments y_true and y_pred.
# Symmetric metrics only apply to supervised clusters.
SYMMETRIC_METRICS = [
    "adjusted_rand_score",
    "rand_score",
    "v_measure_score",
    "mutual_info_score",
    "adjusted_mutual_info_score",
    "normalized_mutual_info_score",
    "fowlkes_mallows_score",
]

# Supervised metrics whose value depends on the argument order.
NON_SYMMETRIC_METRICS = ["homogeneity_score", "completeness_score"]

# Metrics whose upper bound is 1
NORMALIZED_METRICS = [
    "adjusted_rand_score",
    "rand_score",
    "homogeneity_score",
    "completeness_score",
    "v_measure_score",
    "adjusted_mutual_info_score",
    "fowlkes_mallows_score",
    "normalized_mutual_info_score",
]


# Two fixed random labelings (3 clusters, 30 samples) shared by the tests.
rng = np.random.RandomState(0)
y1 = rng.randint(3, size=30)
y2 = rng.randint(3, size=30)
|
||||
|
||||
|
||||
def test_symmetric_non_symmetric_union():
    """Every supervised metric must be listed as symmetric or non-symmetric."""
    classified = sorted(SYMMETRIC_METRICS + NON_SYMMETRIC_METRICS)
    registered = sorted(SUPERVISED_METRICS)
    assert classified == registered
|
||||
|
||||
|
||||
# 0.22 AMI and NMI changes
|
||||
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
    "metric_name, y1, y2", [(name, y1, y2) for name in SYMMETRIC_METRICS]
)
def test_symmetry(metric_name, y1, y2):
    """Swapping the two labelings must not change a symmetric metric."""
    metric = SUPERVISED_METRICS[metric_name]
    forward = metric(y1, y2)
    backward = metric(y2, y1)
    assert forward == pytest.approx(backward)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "metric_name, y1, y2", [(name, y1, y2) for name in NON_SYMMETRIC_METRICS]
)
def test_non_symmetry(metric_name, y1, y2):
    """Swapping the two labelings must change a non-symmetric metric."""
    metric = SUPERVISED_METRICS[metric_name]
    forward = metric(y1, y2)
    backward = metric(y2, y1)
    assert forward != pytest.approx(backward)
|
||||
|
||||
|
||||
# 0.22 AMI and NMI changes
|
||||
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("metric_name", NORMALIZED_METRICS)
def test_normalized_output(metric_name):
    """Check the bounds of normalized metrics on small hand-made labelings."""
    upper_bound_1 = [0, 0, 0, 1, 1, 1]
    upper_bound_2 = [0, 0, 0, 1, 1, 1]
    metric = SUPERVISED_METRICS[metric_name]

    # Partially agreeing labelings score strictly above 0 ...
    assert metric([0, 0, 0, 1, 1], [0, 0, 0, 1, 2]) > 0.0
    assert metric([0, 0, 1, 1, 2], [0, 0, 1, 1, 1]) > 0.0
    # ... and strictly below 1, whichever labeling is passed first (the
    # original repeated the first of these two assertions verbatim).
    assert metric([0, 0, 0, 1, 2], [0, 1, 1, 1, 1]) < 1.0
    assert metric([0, 1, 1, 1, 1], [0, 0, 0, 1, 2]) < 1.0
    # Identical labelings reach the upper bound.
    assert metric(upper_bound_1, upper_bound_2) == pytest.approx(1.0)

    # One single cluster versus all-singleton clusters: never negative.
    lower_bound_1 = [0, 0, 0, 0, 0, 0]
    lower_bound_2 = [0, 1, 2, 3, 4, 5]
    score = np.array(
        [metric(lower_bound_1, lower_bound_2), metric(lower_bound_2, lower_bound_1)]
    )
    assert not (score < 0).any()
|
||||
|
||||
|
||||
# 0.22 AMI and NMI changes
|
||||
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("metric_name", chain(SUPERVISED_METRICS, UNSUPERVISED_METRICS))
def test_permute_labels(metric_name):
    """Check that exchanging labels 0 and 1 leaves every score unchanged."""
    y_label = np.array([0, 0, 0, 1, 1, 0, 1])
    y_pred = np.array([1, 0, 1, 0, 1, 1, 0])
    swapped_label = 1 - y_label
    swapped_pred = 1 - y_pred

    if metric_name in SUPERVISED_METRICS:
        metric = SUPERVISED_METRICS[metric_name]
        score_1 = metric(y_pred, y_label)
        permuted_inputs = [
            (swapped_pred, y_label),
            (swapped_pred, swapped_label),
            (y_pred, swapped_label),
        ]
        for first, second in permuted_inputs:
            assert_allclose(score_1, metric(first, second))
    else:
        metric = UNSUPERVISED_METRICS[metric_name]
        X = np.random.randint(10, size=(7, 10))
        score_1 = metric(X, y_pred)
        assert_allclose(score_1, metric(X, swapped_pred))
|
||||
|
||||
|
||||
# 0.22 AMI and NMI changes
|
||||
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize("metric_name", chain(SUPERVISED_METRICS, UNSUPERVISED_METRICS))
def test_format_invariance(metric_name):
    """Check that scores do not depend on how the labels are encoded.

    For all clustering metrics the labels can be given as arrays or lists,
    made of positive or negative ints, or of strings.
    """
    y_true = [0, 0, 0, 0, 1, 1, 1, 1]
    y_pred = [0, 1, 2, 3, 4, 5, 6, 7]

    def generate_formats(y):
        # Yield (labels, description) pairs encoding the same partition.
        as_array = np.array(y)
        as_strs = [str(x) + "-a" for x in as_array.tolist()]
        yield as_array, "array of ints"
        yield as_array.tolist(), "list of ints"
        yield as_strs, "list of strs"
        yield np.array(as_strs, dtype=object), "array of strs"
        yield as_array - 1, "including negative ints"
        yield as_array + 1, "strictly positive ints"

    if metric_name in SUPERVISED_METRICS:
        metric = SUPERVISED_METRICS[metric_name]
        score_1 = metric(y_true, y_pred)
        paired_formats = zip(generate_formats(y_true), generate_formats(y_pred))
        for (y_true_fmt, _), (y_pred_fmt, _) in paired_formats:
            assert score_1 == metric(y_true_fmt, y_pred_fmt)
    else:
        metric = UNSUPERVISED_METRICS[metric_name]
        X = np.random.randint(10, size=(8, 10))
        score_1 = metric(X, y_true)
        # The dtype of X must not matter either.
        assert score_1 == metric(X.astype(float), y_true)
        for y_true_fmt, _ in generate_formats(y_true):
            assert score_1 == metric(X, y_true_fmt)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("metric", SUPERVISED_METRICS.values())
def test_single_sample(metric):
    """Supervised metrics must accept labelings holding a single sample."""
    # only the supervised metrics support single sample
    for i in (0, 1):
        for j in (0, 1):
            metric([i], [j])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "metric_name, metric_func", dict(SUPERVISED_METRICS, **UNSUPERVISED_METRICS).items()
)
def test_inf_nan_input(metric_name, metric_func):
    """Check that every metric rejects labels containing NaN or infinity."""
    if metric_name in SUPERVISED_METRICS:
        invalids = [
            ([0, 1], [np.inf, np.inf]),
            ([0, 1], [np.nan, np.nan]),
            ([0, 1], [np.nan, np.inf]),
        ]
    else:
        X = np.random.randint(10, size=(2, 10))
        invalids = [(X, [np.inf, np.inf]), (X, [np.nan, np.nan]), (X, [np.nan, np.inf])]

    # One ``pytest.raises`` per input: with the loop inside a single context
    # the first error ends the block and the remaining inputs never run.
    for args in invalids:
        with pytest.raises(ValueError, match=r"contains (NaN|infinity)"):
            metric_func(*args)
|
||||
@@ -0,0 +1,483 @@
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from sklearn.metrics.cluster import adjusted_mutual_info_score
|
||||
from sklearn.metrics.cluster import adjusted_rand_score
|
||||
from sklearn.metrics.cluster import rand_score
|
||||
from sklearn.metrics.cluster import completeness_score
|
||||
from sklearn.metrics.cluster import contingency_matrix
|
||||
from sklearn.metrics.cluster import pair_confusion_matrix
|
||||
from sklearn.metrics.cluster import entropy
|
||||
from sklearn.metrics.cluster import expected_mutual_information
|
||||
from sklearn.metrics.cluster import fowlkes_mallows_score
|
||||
from sklearn.metrics.cluster import homogeneity_completeness_v_measure
|
||||
from sklearn.metrics.cluster import homogeneity_score
|
||||
from sklearn.metrics.cluster import mutual_info_score
|
||||
from sklearn.metrics.cluster import normalized_mutual_info_score
|
||||
from sklearn.metrics.cluster import v_measure_score
|
||||
from sklearn.metrics.cluster._supervised import _generalized_average
|
||||
from sklearn.metrics.cluster._supervised import check_clusterings
|
||||
|
||||
from sklearn.utils import assert_all_finite
|
||||
from sklearn.utils._testing import assert_almost_equal
|
||||
from numpy.testing import assert_array_equal, assert_array_almost_equal, assert_allclose
|
||||
|
||||
|
||||
# Supervised scores exercised by the generic tests in this module.
score_funcs = [
    adjusted_rand_score,
    rand_score,
    homogeneity_score,
    completeness_score,
    v_measure_score,
    adjusted_mutual_info_score,
    normalized_mutual_info_score,
]
|
||||
|
||||
|
||||
def test_error_messages_on_wrong_input():
    """Check the error raised for inconsistent or non-1D label inputs."""
    # (expected message pattern, labels_true, labels_pred)
    bad_inputs = [
        (
            r"Found input variables with inconsistent numbers " r"of samples: \[2, 3\]",
            [0, 1],
            [1, 1, 1],
        ),
        (r"labels_true must be 1D: shape is \(2", [[0, 1], [1, 0]], [1, 1, 1]),
        (r"labels_pred must be 1D: shape is \(2", [0, 1, 0], [[1, 1], [0, 0]]),
    ]
    for score_func in score_funcs:
        for expected, labels_true, labels_pred in bad_inputs:
            with pytest.raises(ValueError, match=expected):
                score_func(labels_true, labels_pred)
|
||||
|
||||
|
||||
def test_generalized_average():
    """Check the ordering of the averaging methods of ``_generalized_average``."""
    methods = ["min", "geometric", "arithmetic", "max"]

    # Distinct values: min <= geometric <= arithmetic <= max.
    a, b = 1, 2
    lowest, geometric, arithmetic, highest = (
        _generalized_average(a, b, method) for method in methods
    )
    assert lowest <= geometric <= arithmetic <= highest

    # Equal values: every method gives the same result.
    c, d = 12, 12
    lowest, geometric, arithmetic, highest = (
        _generalized_average(c, d, method) for method in methods
    )
    assert lowest == geometric == arithmetic == highest
|
||||
|
||||
|
||||
def test_perfect_matches():
    """Check that labelings describing the same partition score 1.0."""
    perfect_pairs = [
        ([], []),
        ([0], [1]),
        ([0, 0, 0], [0, 0, 0]),
        ([0, 1, 0], [42, 7, 42]),
        ([0.0, 1.0, 0.0], [42.0, 7.0, 42.0]),
        ([0.0, 1.0, 2.0], [42.0, 7.0, 2.0]),
        ([0, 1, 2], [42, 7, 2]),
    ]
    for score_func in score_funcs:
        for labels_true, labels_pred in perfect_pairs:
            assert score_func(labels_true, labels_pred) == pytest.approx(1.0)

    # The same must hold for every averaging method of NMI and AMI.
    score_funcs_with_changing_means = [
        normalized_mutual_info_score,
        adjusted_mutual_info_score,
    ]
    means = {"min", "geometric", "arithmetic", "max"}
    for score_func in score_funcs_with_changing_means:
        for mean in means:
            for labels_true, labels_pred in perfect_pairs:
                score = score_func(labels_true, labels_pred, average_method=mean)
                assert score == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_homogeneous_but_not_complete_labeling():
    """Homogeneous but not complete clustering."""
    scores = homogeneity_completeness_v_measure(
        [0, 0, 0, 1, 1, 1], [0, 0, 0, 1, 2, 2]
    )
    # Expected (homogeneity, completeness, v-measure), to 2 decimals.
    for actual, expected in zip(scores, (1.00, 0.69, 0.81)):
        assert_almost_equal(actual, expected, 2)
|
||||
|
||||
|
||||
def test_complete_but_not_homogeneous_labeling():
    """Complete but not homogeneous clustering."""
    scores = homogeneity_completeness_v_measure(
        [0, 0, 1, 1, 2, 2], [0, 0, 1, 1, 1, 1]
    )
    # Expected (homogeneity, completeness, v-measure), to 2 decimals.
    for actual, expected in zip(scores, (0.58, 1.00, 0.73)):
        assert_almost_equal(actual, expected, 2)
|
||||
|
||||
|
||||
def test_not_complete_and_not_homogeneous_labeling():
    """Neither complete nor homogeneous but not so bad either."""
    scores = homogeneity_completeness_v_measure(
        [0, 0, 0, 1, 1, 1], [0, 1, 0, 1, 2, 2]
    )
    # Expected (homogeneity, completeness, v-measure), to 2 decimals.
    for actual, expected in zip(scores, (0.67, 0.42, 0.52)):
        assert_almost_equal(actual, expected, 2)
|
||||
|
||||
|
||||
def test_beta_parameter():
    """Check ``beta`` in homogeneity_completeness_v_measure and v_measure_score."""
    labels_true = [0, 0, 0, 1, 1, 1]
    labels_pred = [0, 1, 0, 1, 2, 2]

    beta_test = 0.2
    h_test = 0.67
    c_test = 0.42
    # V-measure recomputed by hand from the expected h and c.
    v_test = (1 + beta_test) * h_test * c_test / (beta_test * h_test + c_test)

    scores = homogeneity_completeness_v_measure(
        labels_true, labels_pred, beta=beta_test
    )
    for actual, expected in zip(scores, (h_test, c_test, v_test)):
        assert_almost_equal(actual, expected, 2)

    v = v_measure_score(labels_true, labels_pred, beta=beta_test)
    assert_almost_equal(v, v_test, 2)
|
||||
|
||||
|
||||
def test_non_consecutive_labels():
    """Regression tests for labels with gaps."""
    consecutive_true = [0, 0, 0, 1, 1, 1]
    consecutive_pred = [0, 1, 0, 1, 2, 2]
    gapped_true = [0, 0, 0, 2, 2, 2]
    gapped_pred = [0, 4, 0, 4, 2, 2]

    # A gap in either labeling must not change h, c or v.
    for labels_true, labels_pred in [
        (gapped_true, consecutive_pred),
        (consecutive_true, gapped_pred),
    ]:
        h, c, v = homogeneity_completeness_v_measure(labels_true, labels_pred)
        assert_almost_equal(h, 0.67, 2)
        assert_almost_equal(c, 0.42, 2)
        assert_almost_equal(v, 0.52, 2)

    ari_1 = adjusted_rand_score(consecutive_true, consecutive_pred)
    ari_2 = adjusted_rand_score(consecutive_true, gapped_pred)
    for ari in (ari_1, ari_2):
        assert_almost_equal(ari, 0.24, 2)

    ri_1 = rand_score(consecutive_true, consecutive_pred)
    ri_2 = rand_score(consecutive_true, gapped_pred)
    for ri in (ri_1, ri_2):
        assert_almost_equal(ri, 0.66, 2)
|
||||
|
||||
|
||||
def uniform_labelings_scores(score_func, n_samples, k_range, n_runs=10, seed=42):
    """Compute a score for pairs of random uniform cluster labelings.

    For every number of clusters ``k`` in ``k_range``, draw ``n_runs`` pairs
    of labelings of ``n_samples`` labels uniformly from ``{0, ..., k - 1}``
    and score each pair with ``score_func``.

    Returns
    -------
    scores : ndarray of shape (len(k_range), n_runs)
    """
    rng = np.random.RandomState(seed)
    scores = np.zeros((len(k_range), n_runs))
    for row, n_clusters in enumerate(k_range):
        for run in range(n_runs):
            # Draw order (a, then b) is kept so results stay reproducible.
            labels_a = rng.randint(low=0, high=n_clusters, size=n_samples)
            labels_b = rng.randint(low=0, high=n_clusters, size=n_samples)
            scores[row, run] = score_func(labels_a, labels_b)
    return scores
|
||||
|
||||
|
||||
def test_adjustment_for_chance():
    """Check that adjusted scores are almost zero on random labels."""
    scores = uniform_labelings_scores(
        adjusted_rand_score,
        n_samples=100,
        k_range=[2, 10, 50, 90],
        n_runs=10,
    )

    # Largest absolute score per number of clusters.
    max_abs_scores = np.abs(scores).max(axis=1)
    assert_array_almost_equal(max_abs_scores, [0.02, 0.03, 0.03, 0.02], 2)
|
||||
|
||||
|
||||
def test_adjusted_mutual_info_score():
    """Compute the Adjusted Mutual Information and test against known values."""
    labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
    labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2])
    expected_mi = 0.41022

    # Mutual information
    assert_almost_equal(mutual_info_score(labels_a, labels_b), expected_mi, 5)

    # with provided sparse contingency
    C = contingency_matrix(labels_a, labels_b, sparse=True)
    mi = mutual_info_score(labels_a, labels_b, contingency=C)
    assert_almost_equal(mi, expected_mi, 5)

    # with provided dense contingency
    C = contingency_matrix(labels_a, labels_b)
    mi = mutual_info_score(labels_a, labels_b, contingency=C)
    assert_almost_equal(mi, expected_mi, 5)

    # Expected mutual information
    emi = expected_mutual_information(C, C.sum())
    assert_almost_equal(emi, 0.15042, 5)

    # Adjusted mutual information
    ami = adjusted_mutual_info_score(labels_a, labels_b)
    assert_almost_equal(ami, 0.27821, 5)
    ami = adjusted_mutual_info_score([1, 1, 2, 2], [2, 2, 3, 3])
    assert ami == pytest.approx(1.0)

    # Test with a very large array
    a110 = np.tile(labels_a, 110)
    b110 = np.tile(labels_b, 110)
    ami = adjusted_mutual_info_score(a110, b110)
    assert_almost_equal(ami, 0.38, 2)
|
||||
|
||||
|
||||
def test_expected_mutual_info_overflow():
    """Regression test: a contingency cell above 2**16 must not overflow.

    The overflow happened in np.outer and resulted in EMI > 1.
    """
    n_samples = 70000
    contingency = np.array([[n_samples]])
    assert expected_mutual_information(contingency, n_samples) <= 1
|
||||
|
||||
|
||||
def test_int_overflow_mutual_info_fowlkes_mallows_score():
    """Test overflow in mutual_info_classif and fowlkes_mallows_score."""
    # For x-labels 1..5: how many samples carry y == 0 and y == 1.
    counts = [(52632, 2529), (14660, 793), (3271, 204), (814, 39), (316, 20)]

    x_labels, y_labels = [], []
    for x_value, (n_zeros, n_ones) in enumerate(counts, start=1):
        x_labels += [x_value] * (n_zeros + n_ones)
        y_labels += [0] * n_zeros + [1] * n_ones
    x = np.array(x_labels)
    y = np.array(y_labels)

    assert_all_finite(mutual_info_score(x, y))
    assert_all_finite(fowlkes_mallows_score(x, y))
|
||||
|
||||
|
||||
def test_entropy():
    """Check ``entropy`` on a mixed, an empty and a constant labeling."""
    assert_almost_equal(entropy([0, 0, 42.0]), 0.6365141, 5)
    # Empty labeling.
    assert_almost_equal(entropy([]), 1)
    # A single cluster carries no information.
    assert entropy([1, 1, 1, 1]) == 0
|
||||
|
||||
|
||||
def test_contingency_matrix():
    """Compare ``contingency_matrix`` with a 2D histogram of the labels."""
    labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
    labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2])

    # One bin per label value 1, 2, 3.
    edges = np.arange(1, 5)
    expected, _, _ = np.histogram2d(labels_a, labels_b, bins=(edges, edges))

    assert_array_almost_equal(contingency_matrix(labels_a, labels_b), expected)
    # ``eps`` is added to every cell.
    shifted = contingency_matrix(labels_a, labels_b, eps=0.1)
    assert_array_almost_equal(shifted, expected + 0.1)
|
||||
|
||||
|
||||
def test_contingency_matrix_sparse():
    """Check that sparse and dense contingency matrices agree."""
    labels_a = np.array([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
    labels_b = np.array([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2])

    dense = contingency_matrix(labels_a, labels_b)
    from_sparse = contingency_matrix(labels_a, labels_b, sparse=True).toarray()
    assert_array_almost_equal(dense, from_sparse)

    # ``eps`` cannot be combined with a sparse result.
    with pytest.raises(ValueError, match="Cannot set 'eps' when sparse=True"):
        contingency_matrix(labels_a, labels_b, eps=1e-10, sparse=True)
|
||||
|
||||
|
||||
def test_exactly_zero_info_score():
    """Check numerical stability when the shared information is exactly zero.

    A constant labelling is compared with an all-distinct labelling for sizes
    10, 100, 1000 and 10000; every information-based score must be (close to)
    zero, whatever the averaging method.
    """
    for i in np.logspace(1, 4, 4).astype(int):
        labels_a, labels_b = (np.ones(i, dtype=int), np.arange(i, dtype=int))

        # Default averaging. The NMI check used to be written twice; the
        # redundant copy has been dropped.
        assert normalized_mutual_info_score(labels_a, labels_b) == pytest.approx(0.0)
        assert v_measure_score(labels_a, labels_b) == pytest.approx(0.0)
        assert adjusted_mutual_info_score(labels_a, labels_b) == pytest.approx(0.0)

        # Every supported normalisation of the mutual information.
        for method in ["min", "geometric", "arithmetic", "max"]:
            assert adjusted_mutual_info_score(
                labels_a, labels_b, average_method=method
            ) == pytest.approx(0.0)
            assert normalized_mutual_info_score(
                labels_a, labels_b, average_method=method
            ) == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_v_measure_and_mutual_information(seed=36):
    """Check the relation between v-measure, entropy and mutual information.

    The v-measure is the harmonic mean of homogeneity ``MI / H(a)`` and
    completeness ``MI / H(b)``, i.e. ``2 * MI / (H(a) + H(b))``, which is also
    the arithmetic-mean normalised mutual information.

    Parameters
    ----------
    seed : int, default=36
        Seed of the random generator used to draw the two labellings.
    """
    for i in np.logspace(1, 4, 4).astype(int):
        random_state = np.random.RandomState(seed)
        labels_a, labels_b = (
            random_state.randint(0, 10, i),
            random_state.randint(0, 10, i),
        )
        # The identity is exact, so compare at the default precision. The
        # previous ``decimal=0`` tolerated an error of 1.5 on a score bounded
        # by 1, which made the assertion impossible to fail.
        assert_almost_equal(
            v_measure_score(labels_a, labels_b),
            2.0
            * mutual_info_score(labels_a, labels_b)
            / (entropy(labels_a) + entropy(labels_b)),
        )
        avg = "arithmetic"
        assert_almost_equal(
            v_measure_score(labels_a, labels_b),
            normalized_mutual_info_score(labels_a, labels_b, average_method=avg),
        )
|
||||
|
||||
|
||||
def test_fowlkes_mallows_score():
    """Check ``fowlkes_mallows_score`` on a general, a perfect and a worst case."""
    # General case, compared with the hand-computed value.
    general = fowlkes_mallows_score([0, 0, 0, 1, 1, 1], [0, 0, 1, 1, 2, 2])
    assert_almost_equal(general, 4.0 / np.sqrt(12.0 * 6.0))

    # Perfect match, up to a renaming of the labels.
    relabelled = fowlkes_mallows_score([0, 0, 0, 1, 1, 1], [1, 1, 1, 0, 0, 0])
    assert_almost_equal(relabelled, 1.0)

    # Worst case: one single cluster against only singletons.
    worst = fowlkes_mallows_score([0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5])
    assert_almost_equal(worst, 0.0)
|
||||
|
||||
|
||||
def test_fowlkes_mallows_score_properties():
    """Check symmetry and label-permutation invariance of the FMI score."""
    # Handcrafted example.
    labels_a = np.array([0, 0, 0, 1, 1, 2])
    labels_b = np.array([1, 1, 2, 2, 0, 0])
    # FMI = TP / sqrt((TP + FP) * (TP + FN))
    expected = 1.0 / np.sqrt((1.0 + 3.0) * (1.0 + 2.0))

    labelling_pairs = [
        # original order
        (labels_a, labels_b),
        # symmetric property: arguments swapped
        (labels_b, labels_a),
        # permutation property: labels of the first argument renamed
        ((labels_a + 1) % 3, labels_b),
        # symmetric and permutation (both together)
        (labels_b, (labels_a + 2) % 3),
    ]
    for first, second in labelling_pairs:
        assert_almost_equal(fowlkes_mallows_score(first, second), expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "labels_true, labels_pred",
    [
        (["a"] * 6, [1, 1, 0, 0, 1, 1]),
        ([1] * 6, [1, 1, 0, 0, 1, 1]),
        ([1, 1, 0, 0, 1, 1], ["a"] * 6),
        ([1, 1, 0, 0, 1, 1], [1] * 6),
        (["a"] * 6, ["a"] * 6),
    ],
)
def test_mutual_info_score_positive_constant_label(labels_true, labels_pred):
    """Check that MI is exactly 0 when one or both labellings are constant.

    Non-regression test for #16355.
    """
    mutual_info = mutual_info_score(labels_true, labels_pred)
    assert mutual_info == 0
|
||||
|
||||
|
||||
def test_check_clustering_error():
    """Check the warning raised by ``check_clusterings`` on continuous values."""
    rng = np.random.RandomState(42)
    noise = rng.rand(500)
    wavelength = np.linspace(0.01, 1, 500) * 1e-6

    expected_warning = (
        "Clustering metrics expects discrete values but received "
        "continuous values for label, and continuous values for "
        "target"
    )
    with pytest.warns(UserWarning, match=expected_warning):
        check_clusterings(wavelength, noise)
|
||||
|
||||
|
||||
def test_pair_confusion_matrix_fully_dispersed():
    """Edge case: every element is its own cluster in both clusterings."""
    N = 100
    singletons = list(range(N))
    # All N * (N - 1) ordered pairs are separated in both clusterings.
    expected = np.array([[N * (N - 1), 0], [0, 0]])

    result = pair_confusion_matrix(singletons, singletons)
    assert_array_equal(result, expected)
|
||||
|
||||
|
||||
def test_pair_confusion_matrix_single_cluster():
    """Edge case: both clusterings consist of one single cluster."""
    N = 100
    one_cluster = np.zeros((N,))
    # All N * (N - 1) ordered pairs are grouped together in both clusterings.
    expected = np.array([[0, 0], [0, N * (N - 1)]])

    result = pair_confusion_matrix(one_cluster, one_cluster)
    assert_array_equal(result, expected)
|
||||
|
||||
|
||||
def test_pair_confusion_matrix():
    """Regular case: compare with a brute-force count over all ordered pairs."""
    n = 10
    N = n**2
    clustering1 = np.hstack([[i + 1] * n for i in range(n)])
    clustering2 = np.hstack([[i + 1] * (n + 1) for i in range(n)])[:N]

    # Basic quadratic implementation: for every ordered pair (i, j) with
    # i != j, record whether both samples share a cluster in each clustering.
    together_1 = clustering1[:, np.newaxis] == clustering1[np.newaxis, :]
    together_2 = clustering2[:, np.newaxis] == clustering2[np.newaxis, :]
    distinct_samples = ~np.eye(N, dtype=bool)

    expected = np.zeros(shape=(2, 2), dtype=np.int64)
    for same_cluster_1 in (0, 1):
        for same_cluster_2 in (0, 1):
            expected[same_cluster_1, same_cluster_2] = np.sum(
                (together_1 == same_cluster_1)
                & (together_2 == same_cluster_2)
                & distinct_samples
            )

    assert_array_equal(pair_confusion_matrix(clustering1, clustering2), expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "clustering1, clustering2",
    [(list(range(100)), list(range(100))), (np.zeros((100,)), np.zeros((100,)))],
)
def test_rand_score_edge_cases(clustering1, clustering2):
    """Check that ``rand_score`` is 1 on the two degenerate clusterings.

    Edge case 1: every element is its own cluster.
    Edge case 2: only one cluster.
    """
    score = rand_score(clustering1, clustering2)
    assert_allclose(score, 1.0)
|
||||
|
||||
|
||||
def test_rand_score():
    """Regular case: compare ``rand_score`` with hand-counted pair agreements."""
    clustering1 = [0, 0, 0, 1, 1, 1]
    clustering2 = [0, 1, 0, 1, 2, 2]

    # Pair confusion matrix. Counts refer to ordered pairs, hence the factor 2.
    n_ordered_pairs = 5 * 6
    both_together = 2 * 2  # ordered pairs (1, 3), (5, 6)
    only_first_together = 2 * 4  # ordered pairs (1, 2), (2, 3), (4, 5), (4, 6)
    only_second_together = 2 * 1  # ordered pair (2, 4)
    # the remaining pairs
    both_apart = (
        n_ordered_pairs - both_together - only_second_together - only_first_together
    )

    # Rand score: fraction of ordered pairs on which the clusterings agree.
    agreements = both_apart + both_together
    total = both_apart + only_second_together + only_first_together + both_together
    expected = agreements / total

    assert_allclose(rand_score(clustering1, clustering2), expected)
|
||||
|
||||
|
||||
def test_adjusted_rand_score_overflow():
    """Check that a large amount of data does not overflow `adjusted_rand_score`.

    Any ``RuntimeWarning`` raised during the computation is turned into an
    error, so the test fails if one is emitted.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/20305
    """
    n_samples = 100_000
    rng = np.random.RandomState(0)
    y_true = rng.randint(0, 2, n_samples, dtype=np.int8)
    y_pred = rng.randint(0, 2, n_samples, dtype=np.int8)

    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        adjusted_rand_score(y_true, y_pred)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("average_method", ["min", "arithmetic", "geometric", "max"])
def test_normalized_mutual_info_score_bounded(average_method):
    """Check that NMI lies in [0, 1) for labellings that do not match perfectly.

    Non-regression test for issue #13836
    """
    constant = [0] * 469
    one_changed = [1] + constant[1:]
    other_changed = [0, 1] + constant[2:]

    # A constant labelling shares no information with any other labelling.
    score = normalized_mutual_info_score(
        constant, one_changed, average_method=average_method
    )
    assert score == 0

    # Non-constant labellings that do not match perfectly.
    score = normalized_mutual_info_score(
        one_changed, other_changed, average_method=average_method
    )
    assert 0 <= score < 1
|
||||
@@ -0,0 +1,370 @@
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import pytest
|
||||
from scipy.sparse import csr_matrix
|
||||
|
||||
from sklearn import datasets
|
||||
from sklearn.utils._testing import assert_array_equal
|
||||
from sklearn.metrics.cluster import silhouette_score
|
||||
from sklearn.metrics.cluster import silhouette_samples
|
||||
from sklearn.metrics import pairwise_distances
|
||||
from sklearn.metrics.cluster import calinski_harabasz_score
|
||||
from sklearn.metrics.cluster import davies_bouldin_score
|
||||
|
||||
|
||||
def test_silhouette():
    """Check the Silhouette Coefficient on iris for dense and sparse inputs.

    For each container the score computed from a precomputed distance matrix
    must match the score computed from the raw data, with and without
    sub-sampling, and the sparse results must match the dense ones.
    """
    dataset = datasets.load_iris()
    X_dense = dataset.data
    X_csr = csr_matrix(X_dense)
    X_dok = sp.dok_matrix(X_dense)
    X_lil = sp.lil_matrix(X_dense)
    y = dataset.target

    # X_dense comes first so that the dense reference scores are defined
    # before the sparse containers are compared against them.
    for X in [X_dense, X_csr, X_dok, X_lil]:
        D = pairwise_distances(X, metric="euclidean")
        # Given that the actual labels are used, we can assume that S would be
        # positive.
        score_precomputed = silhouette_score(D, y, metric="precomputed")
        assert score_precomputed > 0
        # Test without calculating D
        score_euclidean = silhouette_score(X, y, metric="euclidean")
        # NOTE: ``pytest.approx(a, b)`` on its own asserts nothing (``b`` is
        # interpreted as ``rel``); the comparison must be written explicitly.
        assert score_precomputed == pytest.approx(score_euclidean)

        if X is X_dense:
            score_dense_without_sampling = score_precomputed
        else:
            assert score_euclidean == pytest.approx(score_dense_without_sampling)

        # Test with sampling
        score_precomputed = silhouette_score(
            D, y, metric="precomputed", sample_size=int(X.shape[0] / 2), random_state=0
        )
        score_euclidean = silhouette_score(
            X, y, metric="euclidean", sample_size=int(X.shape[0] / 2), random_state=0
        )
        assert score_precomputed > 0
        assert score_euclidean > 0
        assert score_euclidean == pytest.approx(score_precomputed)

        if X is X_dense:
            score_dense_with_sampling = score_precomputed
        else:
            assert score_euclidean == pytest.approx(score_dense_with_sampling)
|
||||
|
||||
|
||||
def test_cluster_size_1():
    """Check silhouette values for singleton and zero-spread clusters.

    The Silhouette Coefficient must be 0 when there is 1 sample in a cluster
    (cluster 0). The case where identical samples are the only members of a
    cluster (cluster 2) is also covered. To our knowledge, this case is not
    discussed in reference material, and we choose for it a sample score of 1.
    """
    X = [[0.0], [1.0], [1.0], [2.0], [3.0], [3.0]]
    labels = np.array([0, 1, 1, 1, 2, 2])

    # Cluster 0: 1 sample -> score of 0 by Rousseeuw's convention
    # Cluster 1: intra-cluster = [.5, .5, 1]
    #            inter-cluster = [1, 1, 1]
    #            silhouette    = [.5, .5, 0]
    # Cluster 2: intra-cluster = [0, 0]
    #            inter-cluster = [arbitrary, arbitrary]
    #            silhouette    = [1., 1.]
    expected_samples = [0, 0.5, 0.5, 0, 1, 1]

    mean_silhouette = silhouette_score(X, labels)
    assert not np.isnan(mean_silhouette)

    per_sample = silhouette_samples(X, labels)
    assert_array_equal(per_sample, expected_samples)
|
||||
|
||||
|
||||
def test_silhouette_paper_example():
    """Check per-sample and mean silhouettes against Rousseeuw (1987).

    The dissimilarities between 12 countries (Table 1) are clustered in two
    ways (Figures 2 and 3); the silhouette values printed in the paper must be
    reproduced to two decimal places.
    """
    # Data from Table 1: strict lower triangle of the 12 x 12 dissimilarity
    # matrix, listed row by row (one row of the triangle per line below).
    # fmt: off
    lower = [
        5.58,
        7.00, 6.50,
        7.08, 7.00, 3.83,
        4.83, 5.08, 8.17, 5.83,
        2.17, 5.75, 6.67, 6.92, 4.92,
        6.42, 5.00, 5.58, 6.00, 4.67, 6.42,
        3.42, 5.50, 6.42, 6.42, 5.00, 3.92, 6.17,
        2.50, 4.92, 6.25, 7.33, 4.50, 2.25, 6.33, 2.75,
        6.08, 6.67, 4.25, 2.67, 6.00, 6.17, 6.17, 6.92, 6.17,
        5.25, 6.83, 4.50, 3.75, 5.75, 5.42, 6.08, 5.83, 6.67, 3.67,
        4.75, 3.00, 6.08, 6.67, 5.00, 5.58, 4.83, 6.17, 5.67, 6.50, 6.92,
    ]
    # fmt: on
    D = np.zeros((12, 12))
    D[np.tril_indices(12, -1)] = lower
    # Mirror the lower triangle to obtain the full symmetric matrix.
    D += D.T

    # Row/column order of D.
    names = [
        "BEL",
        "BRA",
        "CHI",
        "CUB",
        "EGY",
        "FRA",
        "IND",
        "ISR",
        "USA",
        "USS",
        "YUG",
        "ZAI",
    ]

    # Data from Figure 2
    labels1 = [1, 1, 2, 2, 1, 1, 2, 1, 1, 2, 2, 1]
    expected1 = {
        "USA": 0.43,
        "BEL": 0.39,
        "FRA": 0.35,
        "ISR": 0.30,
        "BRA": 0.22,
        "EGY": 0.20,
        "ZAI": 0.19,
        "CUB": 0.40,
        "USS": 0.34,
        "CHI": 0.33,
        "YUG": 0.26,
        "IND": -0.04,
    }
    score1 = 0.28

    # Data from Figure 3
    labels2 = [1, 2, 3, 3, 1, 1, 2, 1, 1, 3, 3, 2]
    expected2 = {
        "USA": 0.47,
        "FRA": 0.44,
        "BEL": 0.42,
        "ISR": 0.37,
        "EGY": 0.02,
        "ZAI": 0.28,
        "BRA": 0.25,
        "IND": 0.17,
        "CUB": 0.48,
        "USS": 0.44,
        "YUG": 0.31,
        "CHI": 0.31,
    }
    score2 = 0.33

    for labels, expected, score in [
        (labels1, expected1, score1),
        (labels2, expected2, score2),
    ]:
        # Reorder the expected values so that they follow the rows of D.
        expected = np.array([expected[name] for name in names])
        labels = np.array(labels)
        # we check to 2dp because that's what's in the paper
        # NOTE: a bare ``pytest.approx(...)`` call asserts nothing; the
        # comparisons have to be explicit ``assert ... == approx(...)``.
        samples = silhouette_samples(D, labels, metric="precomputed")
        assert samples == pytest.approx(expected, abs=1e-2)
        mean_score = silhouette_score(D, labels, metric="precomputed")
        assert mean_score == pytest.approx(score, abs=1e-2)
|
||||
|
||||
|
||||
def test_correct_labelsize():
    """Check that ``silhouette_score`` requires 1 < n_labels < n_samples."""
    X = datasets.load_iris().data
    n_samples = X.shape[0]

    invalid_labellings = [
        np.arange(n_samples),  # n_labels = n_samples
        np.zeros(n_samples),  # n_labels = 1
    ]
    for y in invalid_labellings:
        n_labels = len(np.unique(y))
        err_msg = (
            r"Number of labels is %d\. Valid values are 2 "
            r"to n_samples - 1 \(inclusive\)" % n_labels
        )
        with pytest.raises(ValueError, match=err_msg):
            silhouette_score(X, y)
|
||||
|
||||
|
||||
def test_non_encoded_labels():
    """Check that silhouettes do not depend on the numeric values of the labels."""
    iris = datasets.load_iris()
    X, labels = iris.data, iris.target
    # Same partition, but the label values are neither contiguous nor 0-based.
    shifted_labels = labels * 2 + 10

    assert silhouette_score(X, shifted_labels) == silhouette_score(X, labels)
    assert_array_equal(
        silhouette_samples(X, shifted_labels), silhouette_samples(X, labels)
    )
|
||||
|
||||
|
||||
def test_non_numpy_labels():
    """Check that plain Python lists give the same score as numpy arrays."""
    iris = datasets.load_iris()
    X, y = iris.data, iris.target

    score_from_lists = silhouette_score(list(X), list(y))
    assert score_from_lists == silhouette_score(X, y)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", (np.float32, np.float64))
def test_silhouette_nonzero_diag(dtype):
    """Make sure silhouette_samples requires the diagonal to be zero.

    Non-regression test for #12178
    """
    # Construct a zero-diagonal matrix from six 1D points.
    points = np.array([[0.2, 0.1, 0.12, 1.34, 1.11, 1.6]], dtype=dtype).T
    dists = pairwise_distances(points)
    labels = [0, 0, 0, 1, 1, 1]
    eps = np.finfo(dists.dtype).eps

    # small values on the diagonal are OK
    dists[2][2] = eps * 10
    silhouette_samples(dists, labels, metric="precomputed")

    # values bigger than eps * 100 are not
    dists[2][2] = eps * 1000
    with pytest.raises(ValueError, match="contains non-zero"):
        silhouette_samples(dists, labels, metric="precomputed")
|
||||
|
||||
|
||||
def assert_raises_on_only_one_label(func):
    """Assert that `func` raises when all samples share one single label."""
    rng = np.random.RandomState(seed=0)
    X = rng.rand(10, 2)
    single_label = np.zeros(10)
    with pytest.raises(ValueError, match="Number of labels is"):
        func(X, single_label)
|
||||
|
||||
|
||||
def assert_raises_on_all_points_same_cluster(func):
    """Assert that `func` raises when every sample has its own label.

    Despite the name of this helper, the labels passed to `func` are all
    distinct (``np.arange(10)``), i.e. there are as many clusters as samples.
    """
    rng = np.random.RandomState(seed=0)
    X = rng.rand(10, 2)
    distinct_labels = np.arange(10)
    with pytest.raises(ValueError, match="Number of labels is"):
        func(X, distinct_labels)
|
||||
|
||||
|
||||
def test_calinski_harabasz_score():
    """Check ``calinski_harabasz_score`` on invalid, degenerate and general data."""
    assert_raises_on_only_one_label(calinski_harabasz_score)

    assert_raises_on_all_points_same_cluster(calinski_harabasz_score)

    # Assert the value is 1. when all samples are equals
    assert 1.0 == calinski_harabasz_score(np.ones((10, 2)), [0] * 5 + [1] * 5)

    # Assert the value is 0. when all the mean cluster are equal
    assert 0.0 == calinski_harabasz_score([[-1, -1], [1, 1]] * 10, [0] * 10 + [1] * 10)

    # General case (with non numpy arrays)
    X = (
        [[0, 0], [1, 1]] * 5
        + [[3, 3], [4, 4]] * 5
        + [[0, 4], [1, 3]] * 5
        + [[3, 1], [4, 0]] * 5
    )
    labels = [0] * 10 + [1] * 10 + [2] * 10 + [3] * 10
    # NOTE: ``pytest.approx(a, b)`` alone asserts nothing (``b`` is read as
    # ``rel``); the comparison has to be an explicit assertion.
    assert calinski_harabasz_score(X, labels) == pytest.approx(
        45 * (40 - 4) / (5 * (4 - 1))
    )
|
||||
|
||||
|
||||
def test_davies_bouldin_score():
    """Check ``davies_bouldin_score`` on invalid, degenerate and general data."""
    assert_raises_on_only_one_label(davies_bouldin_score)
    assert_raises_on_all_points_same_cluster(davies_bouldin_score)

    # Assert the value is 0. when all samples are equals
    assert davies_bouldin_score(np.ones((10, 2)), [0] * 5 + [1] * 5) == pytest.approx(
        0.0
    )

    # Assert the value is 0. when all the mean cluster are equal
    assert davies_bouldin_score(
        [[-1, -1], [1, 1]] * 10, [0] * 10 + [1] * 10
    ) == pytest.approx(0.0)

    # General case (with non numpy arrays)
    X = (
        [[0, 0], [1, 1]] * 5
        + [[3, 3], [4, 4]] * 5
        + [[0, 4], [1, 3]] * 5
        + [[3, 1], [4, 0]] * 5
    )
    labels = [0] * 10 + [1] * 10 + [2] * 10 + [3] * 10
    # NOTE: ``pytest.approx(a, b)`` alone asserts nothing (``b`` is read as
    # ``rel``); the comparison has to be an explicit assertion.
    assert davies_bouldin_score(X, labels) == pytest.approx(2 * np.sqrt(0.5) / 3)

    # Ensure divide by zero warning is not raised in general case
    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        davies_bouldin_score(X, labels)

    # General case - cluster have one sample
    X = [[0, 0], [2, 2], [3, 3], [5, 5]]
    labels = [0, 0, 1, 2]
    assert davies_bouldin_score(X, labels) == pytest.approx((5.0 / 4) / 3)
|
||||
|
||||
|
||||
def test_silhouette_score_integer_precomputed():
    """Check that silhouette_score works for precomputed metrics that are integers.

    Non-regression test for #22107.
    """
    labels = [0, 0, 1]

    zero_diagonal = [[0, 1, 2], [1, 0, 1], [2, 1, 0]]
    score = silhouette_score(zero_diagonal, labels, metric="precomputed")
    assert score == pytest.approx(1 / 6)

    # A non-zero integer on the diagonal raises an error.
    nonzero_diagonal = [[1, 1, 2], [1, 0, 1], [2, 1, 0]]
    with pytest.raises(ValueError, match="contains non-zero"):
        silhouette_score(nonzero_diagonal, labels, metric="precomputed")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from numpy.distutils.misc_util import Configuration
|
||||
|
||||
|
||||
def configuration(parent_package="", top_path=None):
    """Return the numpy.distutils configuration of the ``metrics`` subpackage.

    Registers the sub-packages and the three Cython extensions
    (``_pairwise_fast``, ``_dist_metrics`` and
    ``_pairwise_distances_reduction``).
    """
    config = Configuration("metrics", parent_package, top_path)

    # The math library is only linked explicitly on POSIX systems. The same
    # list object is handed to every extension.
    libraries = ["m"] if os.name == "posix" else []
    numpy_include = np.get_include()

    for subpackage in ("_plot", "_plot.tests", "cluster"):
        config.add_subpackage(subpackage)

    config.add_extension(
        "_pairwise_fast", sources=["_pairwise_fast.pyx"], libraries=libraries
    )

    config.add_extension(
        "_dist_metrics",
        sources=["_dist_metrics.pyx"],
        include_dirs=[numpy_include, os.path.join(numpy_include, "numpy")],
        libraries=libraries,
    )

    # This extension is compiled as C++11.
    config.add_extension(
        "_pairwise_distances_reduction",
        sources=["_pairwise_distances_reduction.pyx"],
        include_dirs=[numpy_include, os.path.join(numpy_include, "numpy")],
        language="c++",
        libraries=libraries,
        extra_compile_args=["-std=c++11"],
    )

    config.add_subpackage("tests")

    return config
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow building this subpackage on its own by running the file directly.
    from numpy.distutils.core import setup

    metrics_configuration = configuration()
    setup(**metrics_configuration.todict())
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,333 @@
|
||||
import itertools
|
||||
import pickle
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_array_almost_equal
|
||||
|
||||
import pytest
|
||||
|
||||
import scipy.sparse as sp
|
||||
from scipy.spatial.distance import cdist
|
||||
from sklearn.metrics import DistanceMetric
|
||||
from sklearn.metrics._dist_metrics import BOOL_METRICS
|
||||
from sklearn.utils import check_random_state
|
||||
from sklearn.utils._testing import create_memmap_backed_data
|
||||
from sklearn.utils.fixes import sp_version, parse_version
|
||||
|
||||
|
||||
def dist_func(x1, x2, p):
    """Return ``(sum((x1 - x2) ** p)) ** (1 / p)``.

    No absolute value is taken, so this is the Minkowski distance of order
    `p` only when the powered differences are non-negative (e.g. even `p`).
    """
    difference = x1 - x2
    powered_sum = np.sum(difference**p)
    return powered_sum ** (1.0 / p)
|
||||
|
||||
|
||||
# Module-level fixtures shared by the tests below. They all draw from one
# single seeded generator, so the order of the statements determines the data.
rng = check_random_state(0)
d = 4  # number of features
n1 = 20  # number of samples in X1
n2 = 25  # number of samples in X2
X1 = rng.random_sample((n1, d)).astype("float64", copy=False)
X2 = rng.random_sample((n2, d)).astype("float64", copy=False)

# Same data returned by the memmap helper, to also run the tests on
# memmap-backed inputs (presumably read-only; see the helper for details).
[X1_mmap, X2_mmap] = create_memmap_backed_data([X1, X2])

# make boolean arrays: ones and zeros
X1_bool = X1.round(0)
X2_bool = X2.round(0)

[X1_bool_mmap, X2_bool_mmap] = create_memmap_backed_data([X1_bool, X2_bool])


# Matrix passed as ``VI`` to the "mahalanobis" metric; V @ V.T is symmetric.
V = rng.random_sample((d, d))
VI = np.dot(V, V.T)


# List of (metric name, parameter grid) pairs. Each grid maps a keyword
# argument to the tuple of values to try; the tests iterate over the
# cartesian product of these tuples.
METRICS_DEFAULT_PARAMS = [
    ("euclidean", {}),
    ("cityblock", {}),
    ("minkowski", dict(p=(1, 1.5, 2, 3))),
    ("chebyshev", {}),
    ("seuclidean", dict(V=(rng.random_sample(d),))),
    ("mahalanobis", dict(VI=(VI,))),
    ("hamming", {}),
    ("canberra", {}),
    ("braycurtis", {}),
]
if sp_version >= parse_version("1.8.0.dev0"):
    # Starting from scipy 1.8.0.dev0, minkowski now accepts w, the weighting
    # parameter directly and using it is preferred over using wminkowski.
    METRICS_DEFAULT_PARAMS.append(
        ("minkowski", dict(p=(1, 1.5, 3), w=(rng.random_sample(d),))),
    )
else:
    # For previous versions of scipy, this was possible through a dedicated
    # metric (deprecated in 1.6 and removed in 1.8).
    METRICS_DEFAULT_PARAMS.append(
        ("wminkowski", dict(p=(1, 1.5, 3), w=(rng.random_sample(d),))),
    )
|
||||
|
||||
|
||||
def check_cdist(metric, kwargs, X1, X2):
    """Check ``DistanceMetric.pairwise(X1, X2)`` against ``scipy`` ``cdist``.

    Parameters
    ----------
    metric : str
        Name of the metric, understood by both scipy and ``DistanceMetric``.
    kwargs : dict
        Extra parameters of the metric.
    X1, X2 : array-like
        The two sets of points between which distances are computed.
    """
    if metric == "wminkowski" and sp_version >= parse_version("1.6.0"):
        # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
        with pytest.warns(DeprecationWarning):
            D_scipy_cdist = cdist(X1, X2, metric, **kwargs)
    else:
        # No warning is expected here. ``pytest.warns(None)`` used to cover
        # this case but is deprecated in pytest 7 and an error in pytest 8.
        D_scipy_cdist = cdist(X1, X2, metric, **kwargs)

    dm = DistanceMetric.get_metric(metric, **kwargs)
    D_sklearn = dm.pairwise(X1, X2)
    assert_array_almost_equal(D_sklearn, D_scipy_cdist)
|
||||
|
||||
|
||||
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
@pytest.mark.parametrize("metric_param_grid", METRICS_DEFAULT_PARAMS)
@pytest.mark.parametrize("X1, X2", [(X1, X2), (X1_mmap, X2_mmap)])
def test_cdist(metric_param_grid, X1, X2):
    """Check two-argument ``pairwise`` against scipy for every parameter set."""
    metric, param_grid = metric_param_grid
    param_names = param_grid.keys()
    # One run per element of the cartesian product of the parameter values.
    for param_values in itertools.product(*param_grid.values()):
        if metric == "mahalanobis":
            # See: https://github.com/scipy/scipy/issues/13861
            # Possibly caused by: https://github.com/joblib/joblib/issues/563
            pytest.xfail(
                "scipy#13861: cdist with 'mahalanobis' fails on joblib memmap data"
            )
        kwargs = dict(zip(param_names, param_values))
        check_cdist(metric, kwargs, X1, X2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("metric", BOOL_METRICS)
@pytest.mark.parametrize(
    "X1_bool, X2_bool", [(X1_bool, X2_bool), (X1_bool_mmap, X2_bool_mmap)]
)
def test_cdist_bool_metric(metric, X1_bool, X2_bool):
    """Check boolean metrics against scipy on in-memory and memmap data."""
    D_true = cdist(X1_bool, X2_bool, metric)
    # Forward the parametrized data: the helper used to ignore it and always
    # worked on the module-level arrays, so the memmap variant never reached
    # ``DistanceMetric``.
    check_cdist_bool(metric, D_true, X1_bool, X2_bool)


def check_cdist_bool(metric, D_true, X=None, Y=None):
    """Check two-argument ``pairwise`` of a boolean metric against `D_true`.

    Parameters
    ----------
    metric : str
        Name of the boolean metric.
    D_true : ndarray
        Reference distances between the rows of `X` and the rows of `Y`.
    X, Y : array-like, default=None
        Input data. Default to the module-level ``X1_bool`` and ``X2_bool``
        so that existing three-argument calls keep working.
    """
    if X is None:
        X = X1_bool
    if Y is None:
        Y = X2_bool
    dm = DistanceMetric.get_metric(metric)
    D12 = dm.pairwise(X, Y)
    assert_array_almost_equal(D12, D_true)
|
||||
|
||||
|
||||
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
@pytest.mark.parametrize("metric_param_grid", METRICS_DEFAULT_PARAMS)
@pytest.mark.parametrize("X1, X2", [(X1, X2), (X1_mmap, X2_mmap)])
def test_pdist(metric_param_grid, X1, X2):
    """Check one-argument ``pairwise`` against scipy for every parameter set."""
    metric, param_grid = metric_param_grid
    keys = param_grid.keys()
    for vals in itertools.product(*param_grid.values()):
        kwargs = dict(zip(keys, vals))
        if metric == "mahalanobis":
            # See: https://github.com/scipy/scipy/issues/13861
            pytest.xfail("scipy#13861: pdist with 'mahalanobis' fails on memmap data")
        elif metric == "wminkowski":
            if sp_version >= parse_version("1.8.0"):
                pytest.skip("wminkowski will be removed in SciPy 1.8.0")

            # wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
            if sp_version >= parse_version("1.6.0"):
                with pytest.warns(DeprecationWarning):
                    D_true = cdist(X1, X1, metric, **kwargs)
            else:
                # No warning expected. ``pytest.warns(None)`` is deprecated in
                # pytest 7 and an error in pytest 8, hence the plain call.
                D_true = cdist(X1, X1, metric, **kwargs)
        else:
            D_true = cdist(X1, X1, metric, **kwargs)

        # Forward the parametrized data so that the memmap variant is also
        # exercised by ``DistanceMetric`` and not only by scipy.
        check_pdist(metric, kwargs, D_true, X1)


@pytest.mark.parametrize("metric", BOOL_METRICS)
@pytest.mark.parametrize("X1_bool", [X1_bool, X1_bool_mmap])
def test_pdist_bool_metrics(metric, X1_bool):
    """Check one-argument ``pairwise`` of boolean metrics against scipy."""
    D_true = cdist(X1_bool, X1_bool, metric)
    check_pdist_bool(metric, D_true, X1_bool)


def check_pdist(metric, kwargs, D_true, X=None):
    """Check ``DistanceMetric.pairwise(X)`` against the reference `D_true`.

    Parameters
    ----------
    metric : str
        Name of the metric.
    kwargs : dict
        Extra parameters of the metric.
    D_true : ndarray
        Reference distances between all pairs of rows of `X`.
    X : array-like, default=None
        Input data. Defaults to the module-level ``X1`` so that existing
        three-argument calls keep working.
    """
    if X is None:
        X = X1
    dm = DistanceMetric.get_metric(metric, **kwargs)
    D12 = dm.pairwise(X)
    assert_array_almost_equal(D12, D_true)


def check_pdist_bool(metric, D_true, X=None):
    """Check ``pairwise(X)`` of a boolean metric against the reference `D_true`.

    Parameters
    ----------
    metric : str
        Name of the boolean metric.
    D_true : ndarray
        Reference distances; NaN entries may be overwritten in place.
    X : array-like, default=None
        Input data. Defaults to the module-level ``X1_bool`` so that existing
        two-argument calls keep working.
    """
    if X is None:
        X = X1_bool
    dm = DistanceMetric.get_metric(metric)
    D12 = dm.pairwise(X)
    # Based on https://github.com/scipy/scipy/pull/7373
    # When comparing two all-zero vectors, scipy>=1.2.0 jaccard metric
    # was changed to return 0, instead of nan.
    if metric == "jaccard" and sp_version < parse_version("1.2.0"):
        D_true[np.isnan(D_true)] = 0
    assert_array_almost_equal(D12, D_true)
|
||||
|
||||
|
||||
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
@pytest.mark.parametrize("writable_kwargs", [True, False])
@pytest.mark.parametrize("metric_param_grid", METRICS_DEFAULT_PARAMS)
def test_pickle(writable_kwargs, metric_param_grid):
    """Check the pickle round trip with writable and read-only array parameters."""
    metric, param_grid = metric_param_grid
    param_names = param_grid.keys()
    for param_values in itertools.product(*param_grid.values()):
        has_array = any(isinstance(value, np.ndarray) for value in param_values)
        if has_array:
            # Work on copies so that the shared module-level arrays keep
            # their own writeable flag.
            param_values = copy.deepcopy(param_values)
            for value in param_values:
                if isinstance(value, np.ndarray):
                    value.setflags(write=writable_kwargs)
        kwargs = dict(zip(param_names, param_values))
        check_pickle(metric, kwargs)
|
||||
|
||||
|
||||
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
@pytest.mark.parametrize("metric", BOOL_METRICS)
@pytest.mark.parametrize("X1_bool", [X1_bool, X1_bool_mmap])
def test_pickle_bool_metrics(metric, X1_bool):
    """Check that a pickled boolean metric computes the same distances."""
    original = DistanceMetric.get_metric(metric)
    distances_before = original.pairwise(X1_bool)

    restored = pickle.loads(pickle.dumps(original))
    distances_after = restored.pairwise(X1_bool)

    assert_array_almost_equal(distances_before, distances_after)
|
||||
|
||||
|
||||
def check_pickle(metric, kwargs):
    """Check that `metric` gives the same distances on ``X1`` after pickling."""
    original = DistanceMetric.get_metric(metric, **kwargs)
    distances_before = original.pairwise(X1)

    restored = pickle.loads(pickle.dumps(original))
    distances_after = restored.pairwise(X1)

    assert_array_almost_equal(distances_before, distances_after)
|
||||
|
||||
|
||||
def test_haversine_metric():
    """Check the haversine metric against a direct Python implementation."""

    def haversine_slow(x1, x2):
        # Reference formula; index 0 is used as latitude and index 1 as
        # longitude.
        return 2 * np.arcsin(
            np.sqrt(
                np.sin(0.5 * (x1[0] - x2[0])) ** 2
                + np.cos(x1[0]) * np.cos(x2[0]) * np.sin(0.5 * (x1[1] - x2[1])) ** 2
            )
        )

    # Use a locally seeded generator so that the test is reproducible and does
    # not depend on (or alter) the global numpy random state.
    local_rng = np.random.RandomState(0)
    X = local_rng.random_sample((10, 2))

    haversine = DistanceMetric.get_metric("haversine")

    D1 = haversine.pairwise(X)
    D2 = np.zeros_like(D1)
    for i, x1 in enumerate(X):
        for j, x2 in enumerate(X):
            D2[i, j] = haversine_slow(x1, x2)

    assert_array_almost_equal(D1, D2)
    # The reduced distance is expected to be sin(d / 2) ** 2.
    assert_array_almost_equal(haversine.dist_to_rdist(D1), np.sin(0.5 * D2) ** 2)
|
||||
|
||||
|
||||
def test_pyfunc_metric():
    """Check that a Python callable metric matches the built-in euclidean one.

    The comparison is made both before and after a pickle round trip.
    """
    # Use a locally seeded generator so that the test is reproducible and does
    # not depend on (or alter) the global numpy random state.
    local_rng = np.random.RandomState(0)
    X = local_rng.random_sample((10, 3))

    euclidean = DistanceMetric.get_metric("euclidean")
    pyfunc = DistanceMetric.get_metric("pyfunc", func=dist_func, p=2)

    # Check if both callable metric and predefined metric initialized
    # DistanceMetric object is picklable
    euclidean_pkl = pickle.loads(pickle.dumps(euclidean))
    pyfunc_pkl = pickle.loads(pickle.dumps(pyfunc))

    D1 = euclidean.pairwise(X)
    D2 = pyfunc.pairwise(X)

    D1_pkl = euclidean_pkl.pairwise(X)
    D2_pkl = pyfunc_pkl.pairwise(X)

    assert_array_almost_equal(D1, D2)
    assert_array_almost_equal(D1_pkl, D2_pkl)
|
||||
|
||||
|
||||
def test_input_data_size():
    """Check that a callable metric receives full-length feature vectors.

    Regression test for #6288
    Previously, a metric requiring a particular input dimension would fail
    """
    n_features = 3

    def custom_metric(x, y):
        # Fails if the metric is ever called on vectors of another length.
        assert x.shape[0] == 3
        return np.sum((x - y) ** 2)

    X = check_random_state(0).rand(10, n_features)

    squared_euclidean = DistanceMetric.get_metric("pyfunc", func=custom_metric)
    euclidean = DistanceMetric.get_metric("euclidean")
    assert_array_almost_equal(
        squared_euclidean.pairwise(X), euclidean.pairwise(X) ** 2
    )
|
||||
|
||||
|
||||
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
def test_readonly_kwargs():
    """Check that metrics can be built from read-only array parameters.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/21685
    """
    rng = check_random_state(0)

    weights = rng.rand(100)
    VI = rng.rand(10, 10)
    for array in (weights, VI):
        array.setflags(write=False)

    # Those distances metrics have to support readonly buffers.
    DistanceMetric.get_metric("seuclidean", V=weights)
    DistanceMetric.get_metric("wminkowski", p=1, w=weights)
    DistanceMetric.get_metric("mahalanobis", VI=VI)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "w, err_type, err_msg",
    [
        (np.array([1, 1.5, -13]), ValueError, "w cannot contain negative weights"),
        (np.array([1, 1.5, np.nan]), ValueError, "w contains NaN"),
        (
            sp.csr_matrix([1, 1.5, 1]),
            TypeError,
            "A sparse matrix was passed, but dense data is required",
        ),
        (np.array(["a", "b", "c"]), ValueError, "could not convert string to float"),
        (np.array([]), ValueError, "a minimum of 1 is required"),
    ],
)
def test_minkowski_metric_validate_weights_values(w, err_type, err_msg):
    """Check the errors raised for invalid Minkowski weight vectors."""
    with pytest.raises(err_type, match=err_msg):
        DistanceMetric.get_metric("minkowski", w=w, p=3)
|
||||
|
||||
|
||||
def test_minkowski_metric_validate_weights_size():
    """Check the error raised when len(w) differs from the number of features."""
    # `w` is drawn with d + 1 entries; the expected message compares its
    # length against X1.shape[1].
    oversized_w = rng.random_sample(d + 1)
    expected_msg = (
        "MinkowskiDistance: the size of w must match "
        f"the number of features \\({X1.shape[1]}\\). "
        f"Currently len\\(w\\)={oversized_w.shape[0]}."
    )

    # Building the metric succeeds; the mismatch is only reported by pairwise.
    dm = DistanceMetric.get_metric("minkowski", p=3, w=oversized_w)
    with pytest.raises(ValueError, match=expected_msg):
        dm.pairwise(X1, X2)
|
||||
|
||||
|
||||
# TODO: Remove in 1.3 when wminkowski is removed
def test_wminkowski_deprecated():
    """Check that requesting the "wminkowski" metric emits a FutureWarning."""
    weights = rng.random_sample(d)
    expected_warning = "WMinkowskiDistance is deprecated in version 1.1"
    with pytest.warns(FutureWarning, match=expected_warning):
        DistanceMetric.get_metric("wminkowski", p=3, w=weights)
|
||||
|
||||
|
||||
# TODO: Remove in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
@pytest.mark.parametrize("p", [1, 1.5, 3])
def test_wminkowski_minkowski_equivalence(p):
    """Check "wminkowski" and weighted "minkowski" yield the same distances."""
    w = rng.random_sample(d)
    # Weights are rescaled for consistency w.r.t scipy 1.8 refactoring of 'minkowski'
    rescaled_w = w ** (1 / p)

    wminkowski = DistanceMetric.get_metric("wminkowski", p=p, w=rescaled_w)
    minkowski = DistanceMetric.get_metric("minkowski", p=p, w=w)

    assert_array_almost_equal(wminkowski.pairwise(X1, X2), minkowski.pairwise(X1, X2))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,551 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import threadpoolctl
|
||||
from numpy.testing import assert_array_equal, assert_allclose
|
||||
from scipy.sparse import csr_matrix
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from sklearn.metrics._pairwise_distances_reduction import (
|
||||
PairwiseDistancesReduction,
|
||||
PairwiseDistancesArgKmin,
|
||||
PairwiseDistancesRadiusNeighborhood,
|
||||
_sqeuclidean_row_norms,
|
||||
)
|
||||
|
||||
from sklearn.metrics import euclidean_distances
|
||||
from sklearn.utils.fixes import sp_version, parse_version
|
||||
|
||||
# Metrics supported by both scipy.spatial.distance.cdist and
# PairwiseDistancesReduction.
# This common subset makes it possible to write tests that check the results
# of concrete PairwiseDistancesReduction on some metrics against references
# computed with the scipy and numpy APIs.
CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS = [
    "braycurtis", "canberra", "chebyshev", "cityblock",
    "euclidean", "minkowski", "seuclidean",
]  # fmt: skip
|
||||
|
||||
|
||||
def _get_metric_params_list(metric: str, n_features: int, seed: int = 1):
    """Return a list of dummy DistanceMetric kwargs to use in tests.

    Parameters
    ----------
    metric : str
        Name of the metric the kwargs are built for.

    n_features : int
        Length of the random arrays drawn for `w` and `V`.

    seed : int, default=1
        Seed of the RandomState used to draw those arrays.

    Returns
    -------
    list of dict
        Candidate kwargs; `[{}]` for metrics handled by none of the branches.
    """
    # Arrays are only drawn inside the branch of the requested metric, so that
    # no unneeded data structure is computed.
    rng = np.random.RandomState(seed)

    if metric == "minkowski":
        params = [{"p": 1.5}, {"p": 2}, {"p": 3}, {"p": np.inf}]
        # TODO: remove the test once we no longer support scipy < 1.8.0.
        # Recent scipy versions accept weights in the Minkowski metric directly.
        if sp_version >= parse_version("1.8.0.dev0"):
            params.append({"p": 3, "w": rng.rand(n_features)})
        return params

    # TODO: remove this case for "wminkowski" once we no longer support scipy < 1.8.0.
    if metric == "wminkowski":
        raw_weights = rng.random_sample(n_features)
        params = [{"p": 1.5, "w": raw_weights / raw_weights.sum()}]
        # wminkowski was removed in scipy 1.8.0 but should work for previous
        # versions.
        if sp_version < parse_version("1.8.0.dev0"):
            params.append({"p": 3, "w": rng.rand(n_features)})
        return params

    if metric == "seuclidean":
        return [{"V": rng.rand(n_features)}]

    # Case of: "euclidean", "manhattan", "chebyshev", "haversine" or any other
    # metric: no kwargs is needed.
    return [{}]
|
||||
|
||||
|
||||
def assert_argkmin_results_equality(ref_dist, dist, ref_indices, indices):
    """Assert that two argkmin results agree.

    Indices are compared first and must be exactly equal; distances are then
    compared with a relative tolerance of 1e-7.
    """
    indices_msg = "Query vectors have different neighbors' indices"
    distances_msg = "Query vectors have different neighbors' distances"

    assert_array_equal(ref_indices, indices, err_msg=indices_msg)
    assert_allclose(ref_dist, dist, rtol=1e-7, err_msg=distances_msg)
|
||||
|
||||
|
||||
def assert_radius_neighborhood_results_equality(ref_dist, dist, ref_indices, indices):
    """Assert that two radius-neighborhood results agree, query by query.

    The inputs are arrays of arrays, hence each query vector is checked
    individually: indices must be exactly equal and distances must be close
    with a relative tolerance of 1e-7. The number of queries is read from
    `ref_dist.shape[0]`.
    """
    n_queries = ref_dist.shape[0]
    for query_idx in range(n_queries):
        indices_msg = f"Query vector #{query_idx} has different neighbors' indices"
        distances_msg = f"Query vector #{query_idx} has different neighbors' distances"

        assert_array_equal(
            ref_indices[query_idx], indices[query_idx], err_msg=indices_msg
        )
        assert_allclose(
            ref_dist[query_idx], dist[query_idx], rtol=1e-7, err_msg=distances_msg
        )
|
||||
|
||||
|
||||
# Maps each concrete reduction class to the helper comparing its results.
ASSERT_RESULT = {}
ASSERT_RESULT[PairwiseDistancesArgKmin] = assert_argkmin_results_equality
ASSERT_RESULT[PairwiseDistancesRadiusNeighborhood] = (
    assert_radius_neighborhood_results_equality
)
|
||||
|
||||
|
||||
def test_pairwise_distances_reduction_is_usable_for():
    """Check which (X, Y, metric) combinations are reported as usable."""
    rng = np.random.RandomState(0)
    X = rng.rand(100, 10)
    Y = rng.rand(100, 10)
    metric = "euclidean"
    is_usable_for = PairwiseDistancesReduction.is_usable_for

    # Arrays as produced by `rng.rand` with the "euclidean" metric are usable.
    assert is_usable_for(X, Y, metric)

    # int64 datasets and the "pyfunc" metric are not.
    assert not is_usable_for(X.astype(np.int64), Y.astype(np.int64), metric)
    assert not is_usable_for(X, Y, metric="pyfunc")

    # TODO: remove once 32 bits datasets are supported
    for X_32, Y_32 in ((X.astype(np.float32), Y), (X, Y.astype(np.int32))):
        assert not is_usable_for(X_32, Y_32, metric)

    # TODO: remove once sparse matrices are supported
    for X_sparse, Y_sparse in ((csr_matrix(X), Y), (X, csr_matrix(Y))):
        assert not is_usable_for(X_sparse, Y_sparse, metric)
|
||||
|
||||
|
||||
def test_argkmin_factory_method_wrong_usages():
    """Check the errors and the warning raised by invalid `compute` arguments."""
    rng = np.random.RandomState(1)
    X = rng.rand(100, 10)
    Y = rng.rand(100, 10)
    k = 5
    metric = "euclidean"
    compute = PairwiseDistancesArgKmin.compute

    # Non-float64 X, then non-float64 Y.
    with pytest.raises(
        ValueError,
        match=(
            "Only 64bit float datasets are supported at this time, "
            "got: X.dtype=float32 and Y.dtype=float64"
        ),
    ):
        compute(X=X.astype(np.float32), Y=Y, k=k, metric=metric)

    with pytest.raises(
        ValueError,
        match=(
            "Only 64bit float datasets are supported at this time, "
            "got: X.dtype=float64 and Y.dtype=int32"
        ),
    ):
        compute(X=X, Y=Y.astype(np.int32), k=k, metric=metric)

    # k must be at least 1.
    with pytest.raises(ValueError, match="k == -1, must be >= 1."):
        compute(X=X, Y=Y, k=-1, metric=metric)

    with pytest.raises(ValueError, match="k == 0, must be >= 1."):
        compute(X=X, Y=Y, k=0, metric=metric)

    # Unknown metric name.
    with pytest.raises(ValueError, match="Unrecognized metric"):
        compute(X=X, Y=Y, k=k, metric="wrong metric")

    # 1D X, then Fortran-ordered X.
    with pytest.raises(
        ValueError, match=r"Buffer has wrong number of dimensions \(expected 2, got 1\)"
    ):
        compute(X=np.array([1.0, 2.0]), Y=Y, k=k, metric=metric)

    with pytest.raises(ValueError, match="ndarray is not C-contiguous"):
        compute(X=np.asfortranarray(X), Y=Y, k=k, metric=metric)

    # metric_kwargs that cannot be used are reported with a warning.
    unused_metric_kwargs = {"p": 3}
    expected_warning = (
        r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this"
        r" case \("
        r"FastEuclideanPairwiseDistancesArgKmin\) and will be ignored."
    )
    with pytest.warns(UserWarning, match=expected_warning):
        compute(X=X, Y=Y, k=k, metric=metric, metric_kwargs=unused_metric_kwargs)
|
||||
|
||||
|
||||
def test_radius_neighborhood_factory_method_wrong_usages():
    """Check the errors and the warning raised by invalid `compute` arguments."""
    rng = np.random.RandomState(1)
    X = rng.rand(100, 10)
    Y = rng.rand(100, 10)
    radius = 5
    metric = "euclidean"
    compute = PairwiseDistancesRadiusNeighborhood.compute

    # Non-float64 X.
    msg = (
        "Only 64bit float datasets are supported at this time, "
        "got: X.dtype=float32 and Y.dtype=float64"
    )
    with pytest.raises(ValueError, match=msg):
        compute(X=X.astype(np.float32), Y=Y, radius=radius, metric=metric)

    # Non-float64 Y.
    msg = (
        "Only 64bit float datasets are supported at this time, "
        "got: X.dtype=float64 and Y.dtype=int32"
    )
    with pytest.raises(ValueError, match=msg):
        compute(X=X, Y=Y.astype(np.int32), radius=radius, metric=metric)

    # Negative radius.
    with pytest.raises(ValueError, match="radius == -1.0, must be >= 0."):
        compute(X=X, Y=Y, radius=-1, metric=metric)

    # Unknown metric name.
    with pytest.raises(ValueError, match="Unrecognized metric"):
        compute(X=X, Y=Y, radius=radius, metric="wrong metric")

    # 1D X.
    msg = r"Buffer has wrong number of dimensions \(expected 2, got 1\)"
    with pytest.raises(ValueError, match=msg):
        compute(X=np.array([1.0, 2.0]), Y=Y, radius=radius, metric=metric)

    # Fortran-ordered X.
    with pytest.raises(ValueError, match="ndarray is not C-contiguous"):
        compute(X=np.asfortranarray(X), Y=Y, radius=radius, metric=metric)

    # metric_kwargs that cannot be used are reported with a warning.
    unused_metric_kwargs = {"p": 3}
    msg = (
        r"Some metric_kwargs have been passed \({'p': 3}\) but aren't usable for this"
        r" case \(FastEuclideanPairwiseDistancesRadiusNeighborhood\) and will be"
        r" ignored."
    )
    with pytest.warns(UserWarning, match=msg):
        compute(
            X=X, Y=Y, radius=radius, metric=metric, metric_kwargs=unused_metric_kwargs
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [100, 1000])
@pytest.mark.parametrize("chunk_size", [50, 512, 1024])
@pytest.mark.parametrize(
    "PairwiseDistancesReduction",
    [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood],
)
def test_chunk_size_agnosticism(
    global_random_seed,
    PairwiseDistancesReduction,
    n_samples,
    chunk_size,
    n_features=100,
    dtype=np.float64,
):
    """Check that results do not depend on the chunk size."""
    rng = np.random.RandomState(global_random_seed)
    spread = 100
    X = rng.rand(n_samples, n_features).astype(dtype) * spread
    Y = rng.rand(n_samples, n_features).astype(dtype) * spread

    # Third positional argument of `compute`: 10 for argkmin, a radius otherwise.
    if PairwiseDistancesReduction is PairwiseDistancesArgKmin:
        parameter = 10
    else:
        # Scaling the radius slightly with the numbers of dimensions
        parameter = 10 ** np.log(n_features)

    # Reference computed without passing any chunk size.
    ref_dist, ref_indices = PairwiseDistancesReduction.compute(
        X, Y, parameter, return_distance=True
    )

    # Same computation with the chunk size under test.
    dist, indices = PairwiseDistancesReduction.compute(
        X, Y, parameter, chunk_size=chunk_size, return_distance=True
    )

    check_results = ASSERT_RESULT[PairwiseDistancesReduction]
    check_results(ref_dist, dist, ref_indices, indices)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [100, 1000])
@pytest.mark.parametrize("chunk_size", [50, 512, 1024])
@pytest.mark.parametrize(
    "PairwiseDistancesReduction",
    [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood],
)
def test_n_threads_agnosticism(
    global_random_seed,
    PairwiseDistancesReduction,
    n_samples,
    chunk_size,
    n_features=100,
    dtype=np.float64,
):
    """Check that results do not depend on the number of threads.

    The same computation is run once without any thread limit and once with
    the "openmp" thread pools limited to a single thread; both runs must
    produce the same results.
    """
    rng = np.random.RandomState(global_random_seed)
    spread = 100
    X = rng.rand(n_samples, n_features).astype(dtype) * spread
    Y = rng.rand(n_samples, n_features).astype(dtype) * spread

    parameter = (
        10
        if PairwiseDistancesReduction is PairwiseDistancesArgKmin
        # Scaling the radius slightly with the numbers of dimensions
        else 10 ** np.log(n_features)
    )

    # `chunk_size` is forwarded to both runs: it used to be parametrized but
    # ignored, so every chunk-size variant of this test ran the exact same
    # computation.
    ref_dist, ref_indices = PairwiseDistancesReduction.compute(
        X,
        Y,
        parameter,
        chunk_size=chunk_size,
        return_distance=True,
    )

    with threadpoolctl.threadpool_limits(limits=1, user_api="openmp"):
        dist, indices = PairwiseDistancesReduction.compute(
            X,
            Y,
            parameter,
            chunk_size=chunk_size,
            return_distance=True,
        )

    ASSERT_RESULT[PairwiseDistancesReduction](ref_dist, dist, ref_indices, indices)
|
||||
|
||||
|
||||
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
@pytest.mark.parametrize("n_samples", [100, 1000])
@pytest.mark.parametrize("metric", PairwiseDistancesReduction.valid_metrics())
@pytest.mark.parametrize(
    "PairwiseDistancesReduction",
    [PairwiseDistancesArgKmin, PairwiseDistancesRadiusNeighborhood],
)
def test_strategies_consistency(
    global_random_seed,
    PairwiseDistancesReduction,
    metric,
    n_samples,
    n_features=10,
    dtype=np.float64,
):
    """Check that both parallelization strategies return the same results."""
    rng = np.random.RandomState(global_random_seed)
    spread = 100
    X = rng.rand(n_samples, n_features).astype(dtype) * spread
    Y = rng.rand(n_samples, n_features).astype(dtype) * spread

    # Haversine distance only accepts 2D data
    if metric == "haversine":
        X = np.ascontiguousarray(X[:, :2])
        Y = np.ascontiguousarray(Y[:, :2])

    if PairwiseDistancesReduction is PairwiseDistancesArgKmin:
        parameter = 10
    else:
        # Scaling the radius slightly with the numbers of dimensions
        parameter = 10 ** np.log(n_features)

    results = {}
    for strategy in ("parallel_on_X", "parallel_on_Y"):
        # Taking the first set of kwargs; rebuilt for each run.
        metric_kwargs = _get_metric_params_list(
            metric, n_features, seed=global_random_seed
        )[0]
        results[strategy] = PairwiseDistancesReduction.compute(
            X,
            Y,
            parameter,
            metric=metric,
            metric_kwargs=metric_kwargs,
            # To be sure to use parallelization
            chunk_size=n_samples // 4,
            strategy=strategy,
            return_distance=True,
        )

    dist_par_X, indices_par_X = results["parallel_on_X"]
    dist_par_Y, indices_par_Y = results["parallel_on_Y"]

    ASSERT_RESULT[PairwiseDistancesReduction](
        dist_par_X, dist_par_Y, indices_par_X, indices_par_Y
    )
|
||||
|
||||
|
||||
# "Concrete PairwiseDistancesReductions"-specific tests
|
||||
|
||||
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
@pytest.mark.parametrize("n_features", [50, 500])
@pytest.mark.parametrize("translation", [0, 1e6])
@pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS)
@pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y"))
def test_pairwise_distances_argkmin(
    global_random_seed,
    n_features,
    translation,
    metric,
    strategy,
    n_samples=100,
    k=10,
    dtype=np.float64,
):
    """Check PairwiseDistancesArgKmin against a brute-force reference."""
    rng = np.random.RandomState(global_random_seed)
    spread = 1000
    X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread
    Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread

    # Haversine distance only accepts 2D data
    if metric == "haversine":
        X = np.ascontiguousarray(X[:, :2])
        Y = np.ascontiguousarray(Y[:, :2])

    metric_kwargs = _get_metric_params_list(metric, n_features)[0]

    # Reference for argkmin results: the full distance matrix.
    if metric == "euclidean":
        # Compare to scikit-learn GEMM optimized implementation
        dist_matrix = euclidean_distances(X, Y)
    else:
        dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs)

    # Taking argkmin (indices of the k smallest values)
    argkmin_indices_ref = np.argsort(dist_matrix, axis=1)[:, :k]
    # Getting the associated distances, row by row
    argkmin_distances_ref = np.take_along_axis(
        dist_matrix, argkmin_indices_ref, axis=1
    )

    argkmin_distances, argkmin_indices = PairwiseDistancesArgKmin.compute(
        X,
        Y,
        k,
        metric=metric,
        metric_kwargs=metric_kwargs,
        return_distance=True,
        # So as to have more than a chunk, forcing parallelism.
        chunk_size=n_samples // 4,
        strategy=strategy,
    )

    check_results = ASSERT_RESULT[PairwiseDistancesArgKmin]
    check_results(
        argkmin_distances, argkmin_distances_ref, argkmin_indices, argkmin_indices_ref
    )
|
||||
|
||||
|
||||
# TODO: Remove filterwarnings in 1.3 when wminkowski is removed
@pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
@pytest.mark.parametrize("n_features", [50, 500])
@pytest.mark.parametrize("translation", [0, 1e6])
@pytest.mark.parametrize("metric", CDIST_PAIRWISE_DISTANCES_REDUCTION_COMMON_METRICS)
@pytest.mark.parametrize("strategy", ("parallel_on_X", "parallel_on_Y"))
def test_pairwise_distances_radius_neighbors(
    global_random_seed,
    n_features,
    translation,
    metric,
    strategy,
    n_samples=100,
    dtype=np.float64,
):
    """Check PairwiseDistancesRadiusNeighborhood against a brute-force reference.

    The reference neighborhoods are extracted from the full distance matrix
    and sorted by increasing distance, then compared with the sorted results
    of `PairwiseDistancesRadiusNeighborhood.compute`.
    """
    rng = np.random.RandomState(global_random_seed)
    spread = 1000
    radius = spread * np.log(n_features)
    X = translation + rng.rand(n_samples, n_features).astype(dtype) * spread
    Y = translation + rng.rand(n_samples, n_features).astype(dtype) * spread

    metric_kwargs = _get_metric_params_list(
        metric, n_features, seed=global_random_seed
    )[0]

    # Reference for radius neighborhood results
    if metric == "euclidean":
        # Compare to scikit-learn GEMM optimized implementation
        dist_matrix = euclidean_distances(X, Y)
    else:
        dist_matrix = cdist(X, Y, metric=metric, **metric_kwargs)

    # Getting the neighbors for a given radius.
    # Rows can have different numbers of neighbors, so the references are
    # stored in explicit object arrays: calling `np.array` on a list of
    # arrays of unequal lengths raises on recent NumPy versions.
    n_queries = dist_matrix.shape[0]
    neigh_indices_ref = np.empty(n_queries, dtype=object)
    neigh_distances_ref = np.empty(n_queries, dtype=object)

    for row_idx, row in enumerate(dist_matrix):
        ind = np.arange(row.shape[0])[row <= radius]
        dist = row[ind]

        # Sort the neighbors of this query by increasing distance.
        sort = np.argsort(dist)
        neigh_indices_ref[row_idx] = ind[sort]
        neigh_distances_ref[row_idx] = dist[sort]

    neigh_distances, neigh_indices = PairwiseDistancesRadiusNeighborhood.compute(
        X,
        Y,
        radius,
        metric=metric,
        metric_kwargs=metric_kwargs,
        return_distance=True,
        # So as to have more than a chunk, forcing parallelism.
        chunk_size=n_samples // 4,
        strategy=strategy,
        sort_results=True,
    )

    ASSERT_RESULT[PairwiseDistancesRadiusNeighborhood](
        neigh_distances, neigh_distances_ref, neigh_indices, neigh_indices_ref
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_samples", [100, 1000])
@pytest.mark.parametrize("n_features", [5, 10, 100])
@pytest.mark.parametrize("num_threads", [1, 2, 8])
def test_sqeuclidean_row_norms(
    global_random_seed,
    n_samples,
    n_features,
    num_threads,
    dtype=np.float64,
):
    """Check `_sqeuclidean_row_norms` against squared `np.linalg.norm` of rows."""
    rng = np.random.RandomState(global_random_seed)
    spread = 100
    X = rng.rand(n_samples, n_features).astype(dtype) * spread

    computed = np.asarray(_sqeuclidean_row_norms(X, num_threads=num_threads))
    expected = np.linalg.norm(X, axis=1) ** 2

    assert_allclose(expected, computed)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,615 @@
|
||||
import numpy as np
|
||||
from scipy import optimize
|
||||
from numpy.testing import assert_allclose
|
||||
from scipy.special import factorial, xlogy
|
||||
from itertools import product
|
||||
import pytest
|
||||
|
||||
from sklearn.utils._testing import assert_almost_equal
|
||||
from sklearn.utils._testing import assert_array_equal
|
||||
from sklearn.utils._testing import assert_array_almost_equal
|
||||
from sklearn.dummy import DummyRegressor
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
|
||||
from sklearn.metrics import explained_variance_score
|
||||
from sklearn.metrics import mean_absolute_error
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from sklearn.metrics import mean_squared_log_error
|
||||
from sklearn.metrics import median_absolute_error
|
||||
from sklearn.metrics import mean_absolute_percentage_error
|
||||
from sklearn.metrics import max_error
|
||||
from sklearn.metrics import mean_pinball_loss
|
||||
from sklearn.metrics import r2_score
|
||||
from sklearn.metrics import mean_tweedie_deviance
|
||||
from sklearn.metrics import d2_tweedie_score
|
||||
from sklearn.metrics import d2_pinball_score
|
||||
from sklearn.metrics import d2_absolute_error_score
|
||||
from sklearn.metrics import make_scorer
|
||||
|
||||
from sklearn.metrics._regression import _check_reg_targets
|
||||
|
||||
from sklearn.exceptions import UndefinedMetricWarning
|
||||
|
||||
|
||||
def test_regression_metrics(n_samples=50):
    """Check regression metrics against hand-computed values.

    The first part uses `y_true = arange(n_samples)` with predictions shifted
    by +1 and -1. The second part uses strictly positive targets with
    `y_pred = 2 * y_true` to check the Tweedie deviances and D2 scores.
    """

    def pinball_loss(y_true, y_pred, alpha):
        # Element-wise reference implementation of the pinball loss.
        return alpha * np.maximum(y_true - y_pred, 0) + (1 - alpha) * np.maximum(
            y_pred - y_true, 0
        )

    y_true = np.arange(n_samples)
    y_pred = y_true + 1
    y_pred_2 = y_true - 1

    assert_almost_equal(mean_squared_error(y_true, y_pred), 1.0)
    assert_almost_equal(
        mean_squared_log_error(y_true, y_pred),
        mean_squared_error(np.log(1 + y_true), np.log(1 + y_pred)),
    )
    assert_almost_equal(mean_absolute_error(y_true, y_pred), 1.0)
    assert_almost_equal(mean_pinball_loss(y_true, y_pred), 0.5)
    assert_almost_equal(mean_pinball_loss(y_true, y_pred_2), 0.5)
    assert_almost_equal(mean_pinball_loss(y_true, y_pred, alpha=0.4), 0.6)
    assert_almost_equal(mean_pinball_loss(y_true, y_pred_2, alpha=0.4), 0.4)
    assert_almost_equal(median_absolute_error(y_true, y_pred), 1.0)
    # y_true contains 0, hence the MAPE is expected to be huge but finite.
    mape = mean_absolute_percentage_error(y_true, y_pred)
    assert np.isfinite(mape)
    assert mape > 1e6
    assert_almost_equal(max_error(y_true, y_pred), 1.0)
    assert_almost_equal(r2_score(y_true, y_pred), 0.995, 2)
    assert_almost_equal(r2_score(y_true, y_pred, force_finite=False), 0.995, 2)
    assert_almost_equal(explained_variance_score(y_true, y_pred), 1.0)
    assert_almost_equal(
        explained_variance_score(y_true, y_pred, force_finite=False), 1.0
    )
    # With power=0 the Tweedie metrics are compared to their squared-error
    # counterparts.
    assert_almost_equal(
        mean_tweedie_deviance(y_true, y_pred, power=0),
        mean_squared_error(y_true, y_pred),
    )
    assert_almost_equal(
        d2_tweedie_score(y_true, y_pred, power=0), r2_score(y_true, y_pred)
    )
    dev_median = np.abs(y_true - np.median(y_true)).sum()
    assert_array_almost_equal(
        d2_absolute_error_score(y_true, y_pred),
        1 - np.abs(y_true - y_pred).sum() / dev_median,
    )
    alpha = 0.2
    y_quantile = np.percentile(y_true, q=alpha * 100)
    assert_almost_equal(
        d2_pinball_score(y_true, y_pred, alpha=alpha),
        1
        - pinball_loss(y_true, y_pred, alpha).sum()
        / pinball_loss(y_true, y_quantile, alpha).sum(),
    )
    assert_almost_equal(
        d2_absolute_error_score(y_true, y_pred),
        d2_pinball_score(y_true, y_pred, alpha=0.5),
    )

    # Tweedie deviance needs positive y_pred, except for p=0,
    # p>=2 needs positive y_true
    # results evaluated by sympy
    y_true = np.arange(1, 1 + n_samples)
    y_pred = 2 * y_true
    n = n_samples
    assert_almost_equal(
        mean_tweedie_deviance(y_true, y_pred, power=-1),
        5 / 12 * n * (n**2 + 2 * n + 1),
    )
    assert_almost_equal(
        mean_tweedie_deviance(y_true, y_pred, power=1), (n + 1) * (1 - np.log(2))
    )
    assert_almost_equal(
        mean_tweedie_deviance(y_true, y_pred, power=2), 2 * np.log(2) - 1
    )
    assert_almost_equal(
        mean_tweedie_deviance(y_true, y_pred, power=3 / 2),
        ((6 * np.sqrt(2) - 8) / n) * np.sqrt(y_true).sum(),
    )
    assert_almost_equal(
        mean_tweedie_deviance(y_true, y_pred, power=3), np.sum(1 / y_true) / (4 * n)
    )

    dev_mean = 2 * np.mean(xlogy(y_true, 2 * y_true / (n + 1)))
    assert_almost_equal(
        d2_tweedie_score(y_true, y_pred, power=1),
        1 - (n + 1) * (1 - np.log(2)) / dev_mean,
    )

    dev_mean = 2 * np.log((n + 1) / 2) - 2 / n * np.log(factorial(n))
    assert_almost_equal(
        d2_tweedie_score(y_true, y_pred, power=2), 1 - (2 * np.log(2) - 1) / dev_mean
    )
|
||||
|
||||
|
||||
def test_mean_squared_error_multioutput_raw_value_squared():
    """Check `squared=False` is honoured with `multioutput="raw_values"`.

    Non-regression test for
    https://github.com/scikit-learn/scikit-learn/pull/16323
    """
    y_true, y_pred = [[1]], [[10]]
    mse = mean_squared_error(y_true, y_pred, multioutput="raw_values", squared=True)
    rmse = mean_squared_error(y_true, y_pred, multioutput="raw_values", squared=False)
    assert np.sqrt(mse) == pytest.approx(rmse)
|
||||
|
||||
|
||||
def test_multioutput_regression():
    """Check regression metrics on a small 3-samples, 4-outputs binary problem."""
    y_true = np.array([[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]])
    y_pred = np.array([[0, 0, 0, 1], [1, 0, 1, 1], [0, 0, 0, 1]])

    assert_almost_equal(
        mean_squared_error(y_true, y_pred), (1.0 / 3 + 2.0 / 3 + 2.0 / 3) / 4.0
    )
    assert_almost_equal(
        mean_squared_error(y_true, y_pred, squared=False), 0.454, decimal=2
    )
    assert_almost_equal(mean_squared_log_error(y_true, y_pred), 0.200, decimal=2)

    # mean_absolute_error and mean_squared_error are equal because
    # it is a binary problem.
    assert_almost_equal(mean_absolute_error(y_true, y_pred), (1.0 + 2.0 / 3) / 4.0)

    assert_almost_equal(mean_pinball_loss(y_true, y_pred), (1.0 + 2.0 / 3) / 8.0)

    mape = np.around(mean_absolute_percentage_error(y_true, y_pred), decimals=2)
    assert np.isfinite(mape)
    assert mape > 1e6
    assert_almost_equal(median_absolute_error(y_true, y_pred), (1.0 + 1.0) / 4.0)

    assert_almost_equal(
        r2_score(y_true, y_pred, multioutput="variance_weighted"), 1.0 - 5.0 / 2
    )
    assert_almost_equal(
        r2_score(y_true, y_pred, multioutput="uniform_average"), -0.875
    )

    score = d2_pinball_score(y_true, y_pred, alpha=0.5, multioutput="raw_values")
    raw_expected_score = [
        1
        - np.abs(y_true[:, col] - y_pred[:, col]).sum()
        / np.abs(y_true[:, col] - np.median(y_true[:, col])).sum()
        for col in range(y_true.shape[1])
    ]
    # in the last case, the denominator vanishes and hence we get nan,
    # but since the numerator vanishes as well the expected score is 1.0
    raw_expected_score = np.where(np.isnan(raw_expected_score), 1, raw_expected_score)
    assert_array_almost_equal(score, raw_expected_score)

    score = d2_pinball_score(y_true, y_pred, alpha=0.5, multioutput="uniform_average")
    assert_almost_equal(score, raw_expected_score.mean())

    # constant `y_true` with force_finite=True leads to 1. or 0.
    yc = [5.0, 5.0]
    assert_almost_equal(r2_score(yc, [5.0, 5.0], multioutput="variance_weighted"), 1.0)
    assert_almost_equal(r2_score(yc, [5.0, 5.1], multioutput="variance_weighted"), 0.0)

    # Setting force_finite=False results in the nan for 4th output propagating
    for multioutput in ("variance_weighted", "uniform_average"):
        non_finite = r2_score(
            y_true, y_pred, multioutput=multioutput, force_finite=False
        )
        assert_almost_equal(non_finite, np.nan)

    # Dropping the 4th output to check `force_finite=False` for nominal
    y_true = y_true[:, :-1]
    y_pred = y_pred[:, :-1]
    for multioutput in ("variance_weighted", "uniform_average"):
        finite = r2_score(y_true, y_pred, multioutput=multioutput)
        raw = r2_score(y_true, y_pred, multioutput=multioutput, force_finite=False)
        assert_almost_equal(finite, raw)

    # constant `y_true` with force_finite=False leads to NaN or -Inf.
    for yc_pred, expected in (([5.0, 5.0], np.nan), ([5.0, 6.0], -np.inf)):
        raw = r2_score(
            yc, yc_pred, multioutput="variance_weighted", force_finite=False
        )
        assert_almost_equal(raw, expected)
|
||||
|
||||
|
||||
def test_regression_metrics_at_limits():
    """Exercise the regression metrics on degenerate and invalid inputs.

    The test walks through single-sample inputs, perfect predictions, the
    ``force_finite`` switch of R² and explained variance, the error raised
    by the mean squared log error on negative values, and the errors raised
    by the Tweedie-based metrics for several values of ``power``.
    """

    def assert_zero_inputs_rejected(power, pattern):
        # Both Tweedie-based metrics are expected to raise a ValueError
        # matching `pattern` when fed all-zero targets and predictions.
        with pytest.raises(ValueError, match=pattern):
            mean_tweedie_deviance([0.0], [0.0], power=power)
        with pytest.raises(ValueError, match=pattern):
            d2_tweedie_score([0.0] * 2, [0.0] * 2, power=power)

    # Single-sample case
    # Note: for r2 and d2_tweedie see also test_regression_single_sample
    assert_almost_equal(mean_squared_error([0.0], [0.0]), 0.0)
    assert_almost_equal(mean_squared_error([0.0], [0.0], squared=False), 0.0)
    for loss in (
        mean_squared_log_error,
        mean_absolute_error,
        mean_pinball_loss,
        mean_absolute_percentage_error,
        median_absolute_error,
        max_error,
    ):
        assert_almost_equal(loss([0.0], [0.0]), 0.0)
    assert_almost_equal(explained_variance_score([0.0], [0.0]), 1.0)

    # Perfect cases
    for score in (r2_score, d2_pinball_score):
        assert_almost_equal(score([0.0, 1], [0.0, 1]), 1.0)

    # Non-finite cases
    # R² and explained variance have a fix by default for non-finite cases
    for score in (r2_score, explained_variance_score):
        assert_almost_equal(score([0, 0], [1, -1]), 0.0)
        assert_almost_equal(score([0, 0], [1, -1], force_finite=False), -np.inf)
        assert_almost_equal(score([1, 1], [1, 1]), 1.0)
        assert_almost_equal(score([1, 1], [1, 1], force_finite=False), np.nan)

    # The same message is expected whichever argument holds the negative value.
    negative_msg = (
        "Mean Squared Logarithmic Error cannot be used when targets "
        "contain negative values."
    )
    negative_inputs = (
        ([-1.0], [-1.0]),
        ([1.0, 2.0, 3.0], [1.0, -2.0, 3.0]),
        ([1.0, -2.0, 3.0], [1.0, 2.0, 3.0]),
    )
    for first, second in negative_inputs:
        with pytest.raises(ValueError, match=negative_msg):
            mean_squared_log_error(first, second)

    # Tweedie deviance error
    power = -1.2
    assert_allclose(
        mean_tweedie_deviance([0], [1.0], power=power), 2 / (2 - power), rtol=1e-3
    )
    assert_zero_inputs_rejected(
        power, "can only be used on strictly positive y_pred."
    )

    assert_almost_equal(mean_tweedie_deviance([0.0], [0.0], power=0), 0.0, 2)

    power = 1.0
    assert_zero_inputs_rejected(
        power, "only be used on non-negative y and strictly positive y_pred."
    )

    power = 1.5
    assert_allclose(mean_tweedie_deviance([0.0], [1.0], power=power), 2 / (2 - power))
    assert_zero_inputs_rejected(
        power, "only be used on non-negative y and strictly positive y_pred."
    )

    for power in (2.0, 3.0):
        assert_allclose(
            mean_tweedie_deviance([1.0], [1.0], power=power), 0.00, atol=1e-8
        )
        assert_zero_inputs_rejected(
            power, "can only be used on strictly positive y and y_pred."
        )

    power = 0.5
    assert_zero_inputs_rejected(power, "is only defined for power<=0 and power>=1")
|
||||
|
||||
|
||||
def test__check_reg_targets():
    """Run ``_check_reg_targets`` on every ordered pair of example targets.

    Pairs sharing both the target type and the number of outputs must be
    accepted and returned with the expected type and values; every other
    pair must raise a ``ValueError``.
    """
    # All of length 3
    EXAMPLES = [
        ("continuous", [1, 2, 3], 1),
        ("continuous", [[1], [2], [3]], 1),
        ("continuous-multioutput", [[1, 1], [2, 2], [3, 1]], 2),
        ("continuous-multioutput", [[5, 1], [4, 2], [3, 1]], 2),
        ("continuous-multioutput", [[1, 3, 4], [2, 2, 2], [3, 1, 1]], 3),
    ]

    for first, second in product(EXAMPLES, repeat=2):
        type1, y1, n_out1 = first
        type2, y2, n_out2 = second

        # Mismatching type or output count: the pair must be rejected.
        if type1 != type2 or n_out1 != n_out2:
            with pytest.raises(ValueError):
                _check_reg_targets(y1, y2, None)
            continue

        y_type, y_check1, y_check2, _ = _check_reg_targets(y1, y2, None)
        assert type1 == y_type

        if type1 == "continuous":
            # Single-output targets are compared against a column-shaped copy.
            expected1 = np.reshape(y1, (-1, 1))
            expected2 = np.reshape(y2, (-1, 1))
        else:
            expected1, expected2 = y1, y2
        assert_array_equal(y_check1, expected1)
        assert_array_equal(y_check2, expected2)
|
||||
|
||||
|
||||
def test__check_reg_targets_exception():
    """An unrecognised ``multioutput`` string must raise a ``ValueError``.

    The error message is matched against a pattern that includes the repr
    of the offending value.
    """
    invalid_multioutput = "this_value_is_not_valid"
    pattern_template = (
        "Allowed 'multioutput' string values are.+You provided multioutput={!r}"
    )
    expected_message = pattern_template.format(invalid_multioutput)

    with pytest.raises(ValueError, match=expected_message):
        _check_reg_targets([1, 2, 3], [[1], [2], [3]], invalid_multioutput)
|
||||
|
||||
|
||||
def test_regression_multioutput_array():
    """Check the per-output (``multioutput="raw_values"``) metric results.

    Several small two-output datasets are scored and the returned arrays are
    compared with hard-coded expectations, including the ``force_finite=False``
    variants of R² and explained variance and the rejection of
    ``"variance_weighted"`` by the pinball-based metrics.
    """
    raw = {"multioutput": "raw_values"}

    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]

    # The pinball-based metrics do not accept "variance_weighted".
    err_msg = (
        "multioutput is expected to be 'raw_values' "
        "or 'uniform_average' but we got 'variance_weighted' instead."
    )
    for pinball_metric in (mean_pinball_loss, d2_pinball_score):
        with pytest.raises(ValueError, match=err_msg):
            pinball_metric(y_true, y_pred, multioutput="variance_weighted")

    assert_array_almost_equal(
        mean_squared_error(y_true, y_pred, **raw), [0.125, 0.5625], decimal=2
    )
    assert_array_almost_equal(
        mean_absolute_error(y_true, y_pred, **raw), [0.25, 0.625], decimal=2
    )
    assert_array_almost_equal(
        mean_pinball_loss(y_true, y_pred, **raw), [0.25 / 2, 0.625 / 2], decimal=2
    )
    assert_array_almost_equal(
        mean_absolute_percentage_error(y_true, y_pred, **raw),
        [0.0778, 0.2262],
        decimal=2,
    )
    assert_array_almost_equal(
        r2_score(y_true, y_pred, **raw), [0.95, 0.93], decimal=2
    )
    assert_array_almost_equal(
        explained_variance_score(y_true, y_pred, **raw), [0.95, 0.93], decimal=2
    )
    assert_array_almost_equal(
        d2_pinball_score(y_true, y_pred, alpha=0.5, **raw), [0.833, 0.722], decimal=2
    )
    assert_array_almost_equal(
        explained_variance_score(y_true, y_pred, force_finite=False, **raw),
        [0.95, 0.93],
        decimal=2,
    )

    # mean_absolute_error and mean_squared_error are equal because
    # it is a binary problem.
    y_true = [[0, 0]] * 4
    y_pred = [[1, 1]] * 4
    binary_expectations = (
        (mean_squared_error, [1.0, 1.0]),
        (mean_absolute_error, [1.0, 1.0]),
        (mean_pinball_loss, [0.5, 0.5]),
        (r2_score, [0.0, 0.0]),
        (d2_pinball_score, [0.0, 0.0]),
    )
    for metric, expected in binary_expectations:
        assert_array_almost_equal(metric(y_true, y_pred, **raw), expected, decimal=2)

    y_true = [[0, -1], [0, 1]]
    y_pred = [[2, 2], [1, 1]]
    r2_raw = r2_score(y_true, y_pred, **raw)
    assert_array_almost_equal(r2_raw, [0, -3.5], decimal=2)
    assert np.mean(r2_raw) == r2_score(y_true, y_pred, multioutput="uniform_average")
    assert_array_almost_equal(
        explained_variance_score(y_true, y_pred, **raw), [0, -1.25], decimal=2
    )
    assert_array_almost_equal(
        explained_variance_score(y_true, y_pred, force_finite=False, **raw),
        [-np.inf, -1.25],
        decimal=2,
    )

    # Checking for the condition in which both numerator and denominator is
    # zero.
    y_true = [[1, 3], [1, 2]]
    y_pred = [[1, 4], [1, 1]]

    r2_finite = r2_score(y_true, y_pred, **raw)
    assert_array_almost_equal(r2_finite, [1.0, -3.0], decimal=2)
    assert np.mean(r2_finite) == r2_score(
        y_true, y_pred, multioutput="uniform_average"
    )
    r2_unforced = r2_score(y_true, y_pred, force_finite=False, **raw)
    assert_array_almost_equal(r2_unforced, [np.nan, -3.0], decimal=2)
    assert_almost_equal(
        np.mean(r2_unforced),
        r2_score(y_true, y_pred, multioutput="uniform_average", force_finite=False),
    )

    evs_finite = explained_variance_score(y_true, y_pred, **raw)
    assert_array_almost_equal(evs_finite, [1.0, -3.0], decimal=2)
    assert np.mean(evs_finite) == explained_variance_score(y_true, y_pred)
    assert_array_almost_equal(
        d2_pinball_score(y_true, y_pred, alpha=0.5, **raw), [1.0, -1.0], decimal=2
    )
    evs_unforced = explained_variance_score(y_true, y_pred, force_finite=False, **raw)
    assert_array_almost_equal(evs_unforced, [np.nan, -3.0], decimal=2)
    assert_almost_equal(
        np.mean(evs_unforced),
        explained_variance_score(y_true, y_pred, force_finite=False),
    )

    # Handling msle separately as it does not accept negative inputs.
    y_true = np.array([[0.5, 1], [1, 2], [7, 6]])
    y_pred = np.array([[0.5, 2], [1, 2.5], [8, 8]])
    msle = mean_squared_log_error(y_true, y_pred, **raw)
    # Reference value: plain MSE computed on log(1 + y).
    msle_reference = mean_squared_error(np.log(1 + y_true), np.log(1 + y_pred), **raw)
    assert_array_almost_equal(msle, msle_reference, decimal=2)
|
||||
|
||||
|
||||
def test_regression_custom_weights():
    """Check metrics when ``multioutput`` is an explicit list of weights.

    Each metric is evaluated on a two-output dataset with per-output weights
    and compared with a hard-coded expected value; the mean squared log error
    is checked against the mean squared error of the log-transformed data.
    """
    y_true = [[1, 2], [2.5, -1], [4.5, 3], [5, 7]]
    y_pred = [[1, 1], [2, -1], [5, 4], [5, 6.5]]
    weights = [0.4, 0.6]

    assert_almost_equal(
        mean_squared_error(y_true, y_pred, multioutput=weights), 0.39, decimal=2
    )
    assert_almost_equal(
        mean_squared_error(y_true, y_pred, multioutput=weights, squared=False),
        0.59,
        decimal=2,
    )
    assert_almost_equal(
        mean_absolute_error(y_true, y_pred, multioutput=weights), 0.475, decimal=3
    )
    assert_almost_equal(
        mean_absolute_percentage_error(y_true, y_pred, multioutput=weights),
        0.1668,
        decimal=2,
    )
    assert_almost_equal(r2_score(y_true, y_pred, multioutput=weights), 0.94, decimal=2)
    assert_almost_equal(
        explained_variance_score(y_true, y_pred, multioutput=weights),
        0.94,
        decimal=2,
    )
    assert_almost_equal(
        d2_pinball_score(y_true, y_pred, alpha=0.5, multioutput=weights),
        0.766,
        decimal=2,
    )
    assert_almost_equal(
        explained_variance_score(
            y_true, y_pred, multioutput=weights, force_finite=False
        ),
        0.94,
        decimal=2,
    )

    # Handling msle separately as it does not accept negative inputs.
    y_true = np.array([[0.5, 1], [1, 2], [7, 6]])
    y_pred = np.array([[0.5, 2], [1, 2.5], [8, 8]])
    msle_weights = [0.3, 0.7]
    msle = mean_squared_log_error(y_true, y_pred, multioutput=msle_weights)
    # Reference value: weighted MSE computed on log(1 + y).
    msle_reference = mean_squared_error(
        np.log(1 + y_true), np.log(1 + y_pred), multioutput=msle_weights
    )
    assert_almost_equal(msle, msle_reference, decimal=2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("metric", [r2_score, d2_tweedie_score, d2_pinball_score])
def test_regression_single_sample(metric):
    """With a single sample the score must be NaN and a warning be emitted."""
    expected_warning = "not well-defined with less than two samples."

    # Trigger the warning
    with pytest.warns(UndefinedMetricWarning, match=expected_warning):
        assert np.isnan(metric([0], [1]))
|
||||
|
||||
|
||||
def test_tweedie_deviance_continuity():
    """Compare the Tweedie deviance at ``power`` 0, 1 and 2 with nearby powers.

    For each of these powers the deviance is evaluated at the exact value and
    at a value 1e-10 away from it, and the two results must agree.
    """
    n_samples = 100

    y_true = np.random.RandomState(0).rand(n_samples) + 0.1
    y_pred = np.random.RandomState(1).rand(n_samples) + 0.1

    # (power next to the limit, limit power, extra assert_allclose tolerances)
    #
    # As we get closer to the limit, with 1e-12 difference the absolute
    # tolerance to pass the below check increases. There are likely
    # numerical precision issues on the edges of different definition
    # regions.
    comparisons = (
        (0 - 1e-10, 0, {}),
        (1 + 1e-10, 1, {"atol": 1e-6}),
        (2 - 1e-10, 2, {"atol": 1e-6}),
        (2 + 1e-10, 2, {"atol": 1e-6}),
    )
    for nearby_power, limit_power, tolerances in comparisons:
        assert_allclose(
            mean_tweedie_deviance(y_true, y_pred, power=nearby_power),
            mean_tweedie_deviance(y_true, y_pred, power=limit_power),
            **tolerances,
        )
|
||||
|
||||
|
||||
def test_mean_absolute_percentage_error():
    """Predictions scaled by 1.2 must give a MAPE of approximately 0.2."""
    rng = np.random.RandomState(42)
    y_true = rng.exponential(size=100)
    y_pred = 1.2 * y_true

    mape = mean_absolute_percentage_error(y_true, y_pred)
    assert mape == pytest.approx(0.2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "distribution", ["normal", "lognormal", "exponential", "uniform"]
)
@pytest.mark.parametrize("target_quantile", [0.05, 0.5, 0.75])
def test_mean_pinball_loss_on_constant_predictions(distribution, target_quantile):
    """Check the pinball loss of constant predictors against the quantile.

    Data is drawn from the named ``RandomState`` distribution. The loss of
    every candidate constant prediction must not be lower than the loss at
    the empirical ``target_quantile``, must match the analytical formula,
    and a numerical minimisation must recover that quantile.
    """
    if not hasattr(np, "quantile"):
        pytest.skip(
            "This test requires a more recent version of numpy "
            "with support for np.quantile."
        )

    # Check that the pinball loss is minimized by the empirical quantile.
    n_samples = 3000
    rng = np.random.RandomState(42)
    sampler = getattr(rng, distribution)
    data = sampler(size=n_samples)

    def constant_loss(value):
        # Pinball loss of the predictor that always answers `value`.
        constant_pred = np.full(n_samples, fill_value=value)
        return mean_pinball_loss(data, constant_pred, alpha=target_quantile)

    # Compute the best possible pinball loss for any constant predictor:
    best_pred = np.quantile(data, target_quantile)
    best_pbl = constant_loss(best_pred)
    machine_eps = np.finfo(best_pbl.dtype).eps

    # Evaluate the loss on a grid of quantiles
    for pred in np.quantile(data, np.linspace(0, 1, 100)):
        # Compute the pinball loss of a constant predictor:
        pbl = constant_loss(pred)

        # Check that the loss of this constant predictor is greater or equal
        # than the loss of using the optimal quantile (up to machine
        # precision):
        assert pbl >= best_pbl - machine_eps

        # Check that the value of the pinball loss matches the analytical
        # formula.
        below_total = (pred - data[data < pred]).sum()
        above_total = (data[data >= pred] - pred).sum()
        expected_pbl = (
            below_total * (1 - target_quantile) + above_total * target_quantile
        )
        expected_pbl /= n_samples
        assert_almost_equal(expected_pbl, pbl)

    # Check that we can actually recover the target_quantile by minimizing the
    # pinball loss w.r.t. the constant prediction quantile.
    result = optimize.minimize(constant_loss, data.mean(), method="Nelder-Mead")
    assert result.success
    # The minimum is not unique with limited data, hence the large tolerance.
    assert result.x == pytest.approx(best_pred, rel=1e-2)
    assert result.fun == pytest.approx(best_pbl)
|
||||
|
||||
|
||||
def test_dummy_quantile_parameter_tuning():
    """Tune a quantile ``DummyRegressor`` with a pinball-loss scorer.

    Integration test to check that it is possible to use the pinball loss to
    tune the hyperparameter of a quantile regressor. This is conceptually
    similar to the previous test but using the scikit-learn estimator and
    scoring API instead.
    """
    n_samples = 1000
    rng = np.random.RandomState(0)
    X = rng.normal(size=(n_samples, 5))  # Ignored
    y = rng.exponential(size=n_samples)

    all_quantiles = [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]
    for alpha in all_quantiles:
        # greater_is_better=False marks the pinball loss as a loss for the
        # scorer.
        neg_mean_pinball_loss = make_scorer(
            mean_pinball_loss, alpha=alpha, greater_is_better=False
        )
        search = GridSearchCV(
            DummyRegressor(strategy="quantile", quantile=0.25),
            param_grid={"quantile": all_quantiles},
            scoring=neg_mean_pinball_loss,
        )
        grid_search = search.fit(X, y)

        # The selected quantile must coincide with the scorer's alpha.
        assert grid_search.best_params_["quantile"] == pytest.approx(alpha)
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user