first commit

This commit is contained in:
Carla Floricel
2022-08-02 09:52:52 -04:00
parent 417ea8660b
commit 05e52aa52b
10444 changed files with 2300232 additions and 0 deletions

View File

@@ -0,0 +1,22 @@
"""
The :mod:`sklearn.tree` module includes decision tree-based models for
classification and regression.
"""
from ._classes import BaseDecisionTree
from ._classes import DecisionTreeClassifier
from ._classes import DecisionTreeRegressor
from ._classes import ExtraTreeClassifier
from ._classes import ExtraTreeRegressor
from ._export import export_graphviz, plot_tree, export_text
__all__ = [
"BaseDecisionTree",
"DecisionTreeClassifier",
"DecisionTreeRegressor",
"ExtraTreeClassifier",
"ExtraTreeRegressor",
"export_graphviz",
"plot_tree",
"export_text",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,79 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
#
# License: BSD 3 clause
# See _criterion.pyx for implementation details.
import numpy as np
cimport numpy as np
from ._tree cimport DTYPE_t # Type of X
from ._tree cimport DOUBLE_t # Type of y, sample_weight
from ._tree cimport SIZE_t # Type for indices and counters
from ._tree cimport INT32_t # Signed 32 bit integer
from ._tree cimport UINT32_t # Unsigned 32 bit integer
cdef class Criterion:
# The criterion computes the impurity of a node and the reduction of
# impurity of a split on that node. It also computes the output statistics
# such as the mean in regression and class probabilities in classification.
# Internal structures
cdef const DOUBLE_t[:, ::1] y # Values of y
cdef DOUBLE_t* sample_weight # Sample weights
cdef SIZE_t* samples # Sample indices in X, y
cdef SIZE_t start # samples[start:pos] are the samples in the left node
cdef SIZE_t pos # samples[pos:end] are the samples in the right node
cdef SIZE_t end
cdef SIZE_t n_outputs # Number of outputs
cdef SIZE_t n_samples # Number of samples
cdef SIZE_t n_node_samples # Number of samples in the node (end-start)
cdef double weighted_n_samples # Weighted number of samples (in total)
cdef double weighted_n_node_samples # Weighted number of samples in the node
cdef double weighted_n_left # Weighted number of samples in the left node
cdef double weighted_n_right # Weighted number of samples in the right node
# The criterion object is maintained such that left and right collected
# statistics correspond to samples[start:pos] and samples[pos:end].
# Methods
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
SIZE_t end) nogil except -1
cdef int reset(self) nogil except -1
cdef int reverse_reset(self) nogil except -1
cdef int update(self, SIZE_t new_pos) nogil except -1
cdef double node_impurity(self) nogil
cdef void children_impurity(self, double* impurity_left,
double* impurity_right) nogil
cdef void node_value(self, double* dest) nogil
cdef double impurity_improvement(self, double impurity_parent,
double impurity_left,
double impurity_right) nogil
cdef double proxy_impurity_improvement(self) nogil
cdef class ClassificationCriterion(Criterion):
"""Abstract criterion for classification."""
cdef SIZE_t[::1] n_classes
cdef SIZE_t max_n_classes
cdef double[:, ::1] sum_total # The sum of the weighted count of each label.
cdef double[:, ::1] sum_left # Same as above, but for the left side of the split
cdef double[:, ::1] sum_right # Same as above, but for the right side of the split
cdef class RegressionCriterion(Criterion):
"""Abstract regression criterion."""
cdef double sq_sum_total
cdef double[::1] sum_total # The sum of w*y.
cdef double[::1] sum_left # Same as above, but for the left side of the split
cdef double[::1] sum_right # Same as above, but for the right side of the split

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,188 @@
# Authors: William Mill (bill@billmill.org)
# License: BSD 3 clause
import numpy as np
class DrawTree:
def __init__(self, tree, parent=None, depth=0, number=1):
self.x = -1.0
self.y = depth
self.tree = tree
self.children = [
DrawTree(c, self, depth + 1, i + 1) for i, c in enumerate(tree.children)
]
self.parent = parent
self.thread = None
self.mod = 0
self.ancestor = self
self.change = self.shift = 0
self._lmost_sibling = None
# this is the number of the node in its group of siblings 1..n
self.number = number
def left(self):
return self.thread or len(self.children) and self.children[0]
def right(self):
return self.thread or len(self.children) and self.children[-1]
def lbrother(self):
n = None
if self.parent:
for node in self.parent.children:
if node == self:
return n
else:
n = node
return n
def get_lmost_sibling(self):
if not self._lmost_sibling and self.parent and self != self.parent.children[0]:
self._lmost_sibling = self.parent.children[0]
return self._lmost_sibling
lmost_sibling = property(get_lmost_sibling)
def __str__(self):
return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)
def __repr__(self):
return self.__str__()
def max_extents(self):
extents = [c.max_extents() for c in self.children]
extents.append((self.x, self.y))
return np.max(extents, axis=0)
def buchheim(tree):
dt = first_walk(DrawTree(tree))
min = second_walk(dt)
if min < 0:
third_walk(dt, -min)
return dt
def third_walk(tree, n):
tree.x += n
for c in tree.children:
third_walk(c, n)
def first_walk(v, distance=1.0):
if len(v.children) == 0:
if v.lmost_sibling:
v.x = v.lbrother().x + distance
else:
v.x = 0.0
else:
default_ancestor = v.children[0]
for w in v.children:
first_walk(w)
default_ancestor = apportion(w, default_ancestor, distance)
# print("finished v =", v.tree, "children")
execute_shifts(v)
midpoint = (v.children[0].x + v.children[-1].x) / 2
w = v.lbrother()
if w:
v.x = w.x + distance
v.mod = v.x - midpoint
else:
v.x = midpoint
return v
def apportion(v, default_ancestor, distance):
w = v.lbrother()
if w is not None:
# in buchheim notation:
# i == inner; o == outer; r == right; l == left; r = +; l = -
vir = vor = v
vil = w
vol = v.lmost_sibling
sir = sor = v.mod
sil = vil.mod
sol = vol.mod
while vil.right() and vir.left():
vil = vil.right()
vir = vir.left()
vol = vol.left()
vor = vor.right()
vor.ancestor = v
shift = (vil.x + sil) - (vir.x + sir) + distance
if shift > 0:
move_subtree(ancestor(vil, v, default_ancestor), v, shift)
sir = sir + shift
sor = sor + shift
sil += vil.mod
sir += vir.mod
sol += vol.mod
sor += vor.mod
if vil.right() and not vor.right():
vor.thread = vil.right()
vor.mod += sil - sor
else:
if vir.left() and not vol.left():
vol.thread = vir.left()
vol.mod += sir - sol
default_ancestor = v
return default_ancestor
def move_subtree(wl, wr, shift):
subtrees = wr.number - wl.number
# print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees,
# 'shift', shift)
# print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees
wr.change -= shift / subtrees
wr.shift += shift
wl.change += shift / subtrees
wr.x += shift
wr.mod += shift
def execute_shifts(v):
shift = change = 0
for w in v.children[::-1]:
# print("shift:", w, shift, w.change)
w.x += shift
w.mod += shift
change += w.change
shift += w.shift + change
def ancestor(vil, v, default_ancestor):
# the relevant text is at the bottom of page 7 of
# "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al,
# (2002)
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf
if vil.ancestor in v.parent.children:
return vil.ancestor
else:
return default_ancestor
def second_walk(v, m=0, depth=0, min=None):
v.x += m
v.y = depth
if min is None or v.x < min:
min = v.x
for w in v.children:
min = second_walk(w, m + v.mod, depth + 1, min)
return min
class Tree:
def __init__(self, label="", node_id=-1, *children):
self.label = label
self.node_id = node_id
if children:
self.children = children
else:
self.children = []

View File

@@ -0,0 +1,93 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
#
# License: BSD 3 clause
# See _splitter.pyx for details.
import numpy as np
cimport numpy as np
from ._criterion cimport Criterion
from ._tree cimport DTYPE_t # Type of X
from ._tree cimport DOUBLE_t # Type of y, sample_weight
from ._tree cimport SIZE_t # Type for indices and counters
from ._tree cimport INT32_t # Signed 32 bit integer
from ._tree cimport UINT32_t # Unsigned 32 bit integer
cdef struct SplitRecord:
# Data to track sample split
SIZE_t feature # Which feature to split on.
SIZE_t pos # Split samples array at the given position,
# i.e. count of samples below threshold for feature.
# pos is >= end if the node is a leaf.
double threshold # Threshold to split at.
double improvement # Impurity improvement given parent node.
double impurity_left # Impurity of the left split.
double impurity_right # Impurity of the right split.
cdef class Splitter:
# The splitter searches in the input space for a feature and a threshold
# to split the samples samples[start:end].
#
# The impurity computations are delegated to a criterion object.
# Internal structures
cdef public Criterion criterion # Impurity criterion
cdef public SIZE_t max_features # Number of features to test
cdef public SIZE_t min_samples_leaf # Min samples in a leaf
cdef public double min_weight_leaf # Minimum weight in a leaf
cdef object random_state # Random state
cdef UINT32_t rand_r_state # sklearn_rand_r random number state
cdef SIZE_t* samples # Sample indices in X, y
cdef SIZE_t n_samples # X.shape[0]
cdef double weighted_n_samples # Weighted number of samples
cdef SIZE_t* features # Feature indices in X
cdef SIZE_t* constant_features # Constant features indices
cdef SIZE_t n_features # X.shape[1]
cdef DTYPE_t* feature_values # temp. array holding feature values
cdef SIZE_t start # Start position for the current node
cdef SIZE_t end # End position for the current node
cdef const DOUBLE_t[:, ::1] y
cdef DOUBLE_t* sample_weight
# The samples vector `samples` is maintained by the Splitter object such
# that the samples contained in a node are contiguous. With this setting,
# `node_split` reorganizes the node samples `samples[start:end]` in two
# subsets `samples[start:pos]` and `samples[pos:end]`.
# The 1-d `features` array of size n_features contains the features
# indices and allows fast sampling without replacement of features.
# The 1-d `constant_features` array of size n_features holds in
# `constant_features[:n_constant_features]` the feature ids with
# constant values for all the samples that reached a specific node.
# The value `n_constant_features` is given by the parent node to its
# child nodes. The content of the range `[n_constant_features:]` is left
# undefined, but preallocated for performance reasons
# This allows optimization with depth-based tree building.
# Methods
cdef int init(self, object X, const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight) except -1
cdef int node_reset(self, SIZE_t start, SIZE_t end,
double* weighted_n_node_samples) nogil except -1
cdef int node_split(self,
double impurity, # Impurity of the node
SplitRecord* split,
SIZE_t* n_constant_features) nogil except -1
cdef void node_value(self, double* dest) nogil
cdef double node_impurity(self) nogil

View File

@@ -0,0 +1,103 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause
# See _tree.pyx for details.
import numpy as np
cimport numpy as np
ctypedef np.npy_float32 DTYPE_t # Type of X
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
from ._splitter cimport Splitter
from ._splitter cimport SplitRecord
cdef struct Node:
# Base storage structure for the nodes in a Tree object
SIZE_t left_child # id of the left child of the node
SIZE_t right_child # id of the right child of the node
SIZE_t feature # Feature used for splitting the node
DOUBLE_t threshold # Threshold value at the node
DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion)
SIZE_t n_node_samples # Number of samples at the node
DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node
cdef class Tree:
# The Tree object is a binary tree structure constructed by the
# TreeBuilder. The tree structure is used for predictions and
# feature importances.
# Input/Output layout
cdef public SIZE_t n_features # Number of features in X
cdef SIZE_t* n_classes # Number of classes in y[:, k]
cdef public SIZE_t n_outputs # Number of outputs in y
cdef public SIZE_t max_n_classes # max(n_classes)
# Inner structures: values are stored separately from node structure,
# since size is determined at runtime.
cdef public SIZE_t max_depth # Max depth of the tree
cdef public SIZE_t node_count # Counter for node IDs
cdef public SIZE_t capacity # Capacity of tree, in terms of nodes
cdef Node* nodes # Array of nodes
cdef double* value # (capacity, n_outputs, max_n_classes) array of values
cdef SIZE_t value_stride # = n_outputs * max_n_classes
# Methods
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
SIZE_t feature, double threshold, double impurity,
SIZE_t n_node_samples,
double weighted_n_node_samples) nogil except -1
cdef int _resize(self, SIZE_t capacity) nogil except -1
cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1
cdef np.ndarray _get_value_ndarray(self)
cdef np.ndarray _get_node_ndarray(self)
cpdef np.ndarray predict(self, object X)
cpdef np.ndarray apply(self, object X)
cdef np.ndarray _apply_dense(self, object X)
cdef np.ndarray _apply_sparse_csr(self, object X)
cpdef object decision_path(self, object X)
cdef object _decision_path_dense(self, object X)
cdef object _decision_path_sparse_csr(self, object X)
cpdef compute_feature_importances(self, normalize=*)
# =============================================================================
# Tree builder
# =============================================================================
cdef class TreeBuilder:
# The TreeBuilder recursively builds a Tree object from training samples,
# using a Splitter object for splitting internal nodes and assigning
# values to leaves.
#
# This class controls the various stopping criteria and the node splitting
# evaluation order, e.g. depth-first or best-first.
cdef Splitter splitter # Splitting algorithm
cdef SIZE_t min_samples_split # Minimum number of samples in an internal node
cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf
cdef double min_weight_leaf # Minimum weight in a leaf
cdef SIZE_t max_depth # Maximal tree depth
cdef double min_impurity_decrease # Impurity threshold for early stopping
cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=*)
cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight)

View File

@@ -0,0 +1,111 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause
# See _utils.pyx for details.
import numpy as np
cimport numpy as np
from ._tree cimport Node
from ..neighbors._quad_tree cimport Cell
ctypedef np.npy_float32 DTYPE_t # Type of X
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
cdef enum:
# Max value for our rand_r replacement (near the bottom).
# We don't use RAND_MAX because it's different across platforms and
# particularly tiny on Windows/MSVC.
RAND_R_MAX = 0x7FFFFFFF
# safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or
# raises a MemoryError. It never calls free, since that's __dealloc__'s job.
# cdef DTYPE_t *p = NULL
# safe_realloc(&p, n)
# is equivalent to p = malloc(n * sizeof(*p)) with error checking.
ctypedef fused realloc_ptr:
# Add pointer types here as needed.
(DTYPE_t*)
(SIZE_t*)
(unsigned char*)
(WeightedPQueueRecord*)
(DOUBLE_t*)
(DOUBLE_t**)
(Node*)
(Cell*)
(Node**)
cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *
cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size)
cdef SIZE_t rand_int(SIZE_t low, SIZE_t high,
UINT32_t* random_state) nogil
cdef double rand_uniform(double low, double high,
UINT32_t* random_state) nogil
cdef double log(double x) nogil
# =============================================================================
# WeightedPQueue data structure
# =============================================================================
# A record stored in the WeightedPQueue
cdef struct WeightedPQueueRecord:
DOUBLE_t data
DOUBLE_t weight
cdef class WeightedPQueue:
cdef SIZE_t capacity
cdef SIZE_t array_ptr
cdef WeightedPQueueRecord* array_
cdef bint is_empty(self) nogil
cdef int reset(self) nogil except -1
cdef SIZE_t size(self) nogil
cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
cdef int peek(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
cdef DOUBLE_t get_weight_from_index(self, SIZE_t index) nogil
cdef DOUBLE_t get_value_from_index(self, SIZE_t index) nogil
# =============================================================================
# WeightedMedianCalculator data structure
# =============================================================================
cdef class WeightedMedianCalculator:
cdef SIZE_t initial_capacity
cdef WeightedPQueue samples
cdef DOUBLE_t total_weight
cdef SIZE_t k
cdef DOUBLE_t sum_w_0_k # represents sum(weights[0:k])
# = w[0] + w[1] + ... + w[k-1]
cdef SIZE_t size(self) nogil
cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
cdef int reset(self) nogil except -1
cdef int update_median_parameters_post_push(
self, DOUBLE_t data, DOUBLE_t weight,
DOUBLE_t original_median) nogil
cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
cdef int update_median_parameters_post_remove(
self, DOUBLE_t data, DOUBLE_t weight,
DOUBLE_t original_median) nogil
cdef DOUBLE_t get_median(self) nogil

View File

@@ -0,0 +1,50 @@
import os
import numpy
from numpy.distutils.misc_util import Configuration
def configuration(parent_package="", top_path=None):
config = Configuration("tree", parent_package, top_path)
libraries = []
if os.name == "posix":
libraries.append("m")
config.add_extension(
"_tree",
sources=["_tree.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
language="c++",
extra_compile_args=["-O3"],
)
config.add_extension(
"_splitter",
sources=["_splitter.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
extra_compile_args=["-O3"],
)
config.add_extension(
"_criterion",
sources=["_criterion.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
extra_compile_args=["-O3"],
)
config.add_extension(
"_utils",
sources=["_utils.pyx"],
include_dirs=[numpy.get_include()],
libraries=libraries,
extra_compile_args=["-O3"],
)
config.add_subpackage("tests")
return config
if __name__ == "__main__":
from numpy.distutils.core import setup
setup(**configuration().todict())

View File

@@ -0,0 +1,512 @@
"""
Testing for export functions of decision trees (sklearn.tree.export).
"""
from re import finditer, search
from textwrap import dedent
from numpy.random import RandomState
import pytest
from sklearn.base import is_classifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz, plot_tree, export_text
from io import StringIO
from sklearn.exceptions import NotFittedError
# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y = [-1, -1, -1, 1, 1, 1]
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
w = [1, 1, 1, 0.5, 0.5, 0.5]
y_degraded = [1, 1, 1, 1, 1, 1]
def test_graphviz_toy():
# Check correctness of export_graphviz
clf = DecisionTreeClassifier(
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
)
clf.fit(X, y)
# Test export code
contents1 = export_graphviz(clf, out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)
assert contents1 == contents2
# Test with feature_names
contents1 = export_graphviz(
clf, feature_names=["feature0", "feature1"], out_file=None
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)
assert contents1 == contents2
# Test with class_names
contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]\\nclass = yes"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
'class = yes"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
'class = no"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)
assert contents1 == contents2
# Test plot_options
contents1 = export_graphviz(
clf,
filled=True,
impurity=False,
proportion=True,
special_characters=True,
rounded=True,
out_file=None,
fontname="sans",
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled, rounded", color="black", '
'fontname="sans"] ;\n'
'edge [fontname="sans"] ;\n'
"0 [label=<X<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>"
'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n'
"1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, "
'fillcolor="#e58139"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
"2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, "
'fillcolor="#399de5"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)
assert contents1 == contents2
# Test max_depth
contents1 = export_graphviz(clf, max_depth=0, class_names=True, out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]\\nclass = y[0]"] ;\n'
'1 [label="(...)"] ;\n'
"0 -> 1 ;\n"
'2 [label="(...)"] ;\n'
"0 -> 2 ;\n"
"}"
)
assert contents1 == contents2
# Test max_depth with plot_options
contents1 = export_graphviz(
clf, max_depth=0, filled=True, out_file=None, node_ids=True
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled", color="black", '
'fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n'
'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n'
'1 [label="(...)", fillcolor="#C0C0C0"] ;\n'
"0 -> 1 ;\n"
'2 [label="(...)", fillcolor="#C0C0C0"] ;\n'
"0 -> 2 ;\n"
"}"
)
assert contents1 == contents2
# Test multi-output with weighted samples
clf = DecisionTreeClassifier(
max_depth=2, min_samples_split=2, criterion="gini", random_state=2
)
clf = clf.fit(X, y2, sample_weight=w)
contents1 = export_graphviz(clf, filled=True, impurity=False, out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled", color="black", '
'fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="X[0] <= 0.0\\nsamples = 6\\n'
"value = [[3.0, 1.5, 0.0]\\n"
'[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n'
'1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n'
'[3, 0, 0]]", fillcolor="#e58139"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="X[0] <= 1.5\\nsamples = 3\\n'
"value = [[0.0, 1.5, 0.0]\\n"
'[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
'3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n'
'[0, 1, 0]]", fillcolor="#e58139"] ;\n'
"2 -> 3 ;\n"
'4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n'
'[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n'
"2 -> 4 ;\n"
"}"
)
assert contents1 == contents2
# Test regression output with plot_options
clf = DecisionTreeRegressor(
max_depth=3, min_samples_split=2, criterion="squared_error", random_state=2
)
clf.fit(X, y)
contents1 = export_graphviz(
clf,
filled=True,
leaves_parallel=True,
out_file=None,
rotate=True,
rounded=True,
fontname="sans",
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled, rounded", color="black", '
'fontname="sans"] ;\n'
"graph [ranksep=equally, splines=polyline] ;\n"
'edge [fontname="sans"] ;\n'
"rankdir=LR ;\n"
'0 [label="X[0] <= 0.0\\nsquared_error = 1.0\\nsamples = 6\\n'
'value = 0.0", fillcolor="#f2c09c"] ;\n'
'1 [label="squared_error = 0.0\\nsamples = 3\\'
'nvalue = -1.0", '
'fillcolor="#ffffff"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=-45, "
'headlabel="True"] ;\n'
'2 [label="squared_error = 0.0\\nsamples = 3\\nvalue = 1.0", '
'fillcolor="#e58139"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=45, "
'headlabel="False"] ;\n'
"{rank=same ; 0} ;\n"
"{rank=same ; 1; 2} ;\n"
"}"
)
assert contents1 == contents2
# Test classifier with degraded learning set
clf = DecisionTreeClassifier(max_depth=3)
clf.fit(X, y_degraded)
contents1 = export_graphviz(clf, filled=True, out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled", color="black", '
'fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", '
'fillcolor="#ffffff"] ;\n'
"}"
)
def test_graphviz_errors():
# Check for errors of export_graphviz
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
# Check not-fitted decision tree error
out = StringIO()
with pytest.raises(NotFittedError):
export_graphviz(clf, out)
clf.fit(X, y)
# Check if it errors when length of feature_names
# mismatches with number of features
message = "Length of feature_names, 1 does not match number of features, 2"
with pytest.raises(ValueError, match=message):
export_graphviz(clf, None, feature_names=["a"])
message = "Length of feature_names, 3 does not match number of features, 2"
with pytest.raises(ValueError, match=message):
export_graphviz(clf, None, feature_names=["a", "b", "c"])
# Check error when argument is not an estimator
message = "is not an estimator instance"
with pytest.raises(TypeError, match=message):
export_graphviz(clf.fit(X, y).tree_)
# Check class_names error
out = StringIO()
with pytest.raises(IndexError):
export_graphviz(clf, out, class_names=[])
# Check precision error
out = StringIO()
with pytest.raises(ValueError, match="should be greater or equal"):
export_graphviz(clf, out, precision=-1)
with pytest.raises(ValueError, match="should be an integer"):
export_graphviz(clf, out, precision="1")
def test_friedman_mse_in_graphviz():
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
clf.fit(X, y)
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data)
clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
clf.fit(X, y)
for estimator in clf.estimators_:
export_graphviz(estimator[0], out_file=dot_data)
for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
assert "friedman_mse" in finding.group()
def test_precision():
rng_reg = RandomState(2)
rng_clf = RandomState(8)
for X, y, clf in zip(
(rng_reg.random_sample((5, 2)), rng_clf.random_sample((1000, 4))),
(rng_reg.random_sample((5,)), rng_clf.randint(2, size=(1000,))),
(
DecisionTreeRegressor(
criterion="friedman_mse", random_state=0, max_depth=1
),
DecisionTreeClassifier(max_depth=1, random_state=0),
),
):
clf.fit(X, y)
for precision in (4, 3):
dot_data = export_graphviz(
clf, out_file=None, precision=precision, proportion=True
)
# With the current random state, the impurity and the threshold
# will have the number of precision set in the export_graphviz
# function. We will check the number of precision with a strict
# equality. The value reported will have only 2 precision and
# therefore, only a less equal comparison will be done.
# check value
for finding in finditer(r"value = \d+\.\d+", dot_data):
assert len(search(r"\.\d+", finding.group()).group()) <= precision + 1
# check impurity
if is_classifier(clf):
pattern = r"gini = \d+\.\d+"
else:
pattern = r"friedman_mse = \d+\.\d+"
# check impurity
for finding in finditer(pattern, dot_data):
assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
# check threshold
for finding in finditer(r"<= \d+\.\d+", dot_data):
assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
def test_export_text_errors():
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
err_msg = "max_depth bust be >= 0, given -1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, max_depth=-1)
err_msg = "feature_names must contain 2 elements, got 1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, feature_names=["a"])
err_msg = "decimals must be >= 0, given -1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, decimals=-1)
err_msg = "spacing must be > 0, given 0"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, spacing=0)
def test_export_text():
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
expected_report = dedent(
"""
|--- feature_1 <= 0.00
| |--- class: -1
|--- feature_1 > 0.00
| |--- class: 1
"""
).lstrip()
assert export_text(clf) == expected_report
# testing that leaves at level 1 are not truncated
assert export_text(clf, max_depth=0) == expected_report
# testing that the rest of the tree is truncated
assert export_text(clf, max_depth=10) == expected_report
expected_report = dedent(
"""
|--- b <= 0.00
| |--- class: -1
|--- b > 0.00
| |--- class: 1
"""
).lstrip()
assert export_text(clf, feature_names=["a", "b"]) == expected_report
expected_report = dedent(
"""
|--- feature_1 <= 0.00
| |--- weights: [3.00, 0.00] class: -1
|--- feature_1 > 0.00
| |--- weights: [0.00, 3.00] class: 1
"""
).lstrip()
assert export_text(clf, show_weights=True) == expected_report
expected_report = dedent(
"""
|- feature_1 <= 0.00
| |- class: -1
|- feature_1 > 0.00
| |- class: 1
"""
).lstrip()
assert export_text(clf, spacing=1) == expected_report
X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
y_l = [-1, -1, -1, 1, 1, 1, 2]
clf = DecisionTreeClassifier(max_depth=4, random_state=0)
clf.fit(X_l, y_l)
expected_report = dedent(
"""
|--- feature_1 <= 0.00
| |--- class: -1
|--- feature_1 > 0.00
| |--- truncated branch of depth 2
"""
).lstrip()
assert export_text(clf, max_depth=0) == expected_report
X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(X_mo, y_mo)
expected_report = dedent(
"""
|--- feature_1 <= 0.0
| |--- value: [-1.0, -1.0]
|--- feature_1 > 0.0
| |--- value: [1.0, 1.0]
"""
).lstrip()
assert export_text(reg, decimals=1) == expected_report
assert export_text(reg, decimals=1, show_weights=True) == expected_report
X_single = [[-2], [-1], [-1], [1], [1], [2]]
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(X_single, y_mo)
expected_report = dedent(
"""
|--- first <= 0.0
| |--- value: [-1.0, -1.0]
|--- first > 0.0
| |--- value: [1.0, 1.0]
"""
).lstrip()
assert export_text(reg, decimals=1, feature_names=["first"]) == expected_report
assert (
export_text(reg, decimals=1, show_weights=True, feature_names=["first"])
== expected_report
)
def test_plot_tree_entropy(pyplot):
# mostly smoke tests
# Check correctness of export_graphviz for criterion = entropy
clf = DecisionTreeClassifier(
max_depth=3, min_samples_split=2, criterion="entropy", random_state=2
)
clf.fit(X, y)
# Test export code
feature_names = ["first feat", "sepal_width"]
nodes = plot_tree(clf, feature_names=feature_names)
assert len(nodes) == 3
assert (
nodes[0].get_text()
== "first feat <= 0.0\nentropy = 1.0\nsamples = 6\nvalue = [3, 3]"
)
assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"
def test_plot_tree_gini(pyplot):
# mostly smoke tests
# Check correctness of export_graphviz for criterion = gini
clf = DecisionTreeClassifier(
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
)
clf.fit(X, y)
# Test export code
feature_names = ["first feat", "sepal_width"]
nodes = plot_tree(clf, feature_names=feature_names)
assert len(nodes) == 3
assert (
nodes[0].get_text()
== "first feat <= 0.0\ngini = 0.5\nsamples = 6\nvalue = [3, 3]"
)
assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
assert nodes[2].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"
def test_not_fitted_tree(pyplot):
# Testing if not fitted tree throws the correct error
clf = DecisionTreeRegressor()
with pytest.raises(NotFittedError):
plot_tree(clf)

View File

@@ -0,0 +1,48 @@
import numpy as np
import pytest
from sklearn.tree._reingold_tilford import buchheim, Tree
simple_tree = Tree("", 0, Tree("", 1), Tree("", 2))
bigger_tree = Tree(
"",
0,
Tree(
"",
1,
Tree("", 3),
Tree("", 4, Tree("", 7), Tree("", 8)),
),
Tree("", 2, Tree("", 5), Tree("", 6)),
)
@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
def test_buchheim(tree, n_nodes):
def walk_tree(draw_tree):
res = [(draw_tree.x, draw_tree.y)]
for child in draw_tree.children:
# parents higher than children:
assert child.y == draw_tree.y + 1
res.extend(walk_tree(child))
if len(draw_tree.children):
# these trees are always binary
# parents are centered above children
assert (
draw_tree.x == (draw_tree.children[0].x + draw_tree.children[1].x) / 2
)
return res
layout = buchheim(tree)
coordinates = walk_tree(layout)
assert len(coordinates) == n_nodes
# test that x values are unique per depth / level
# we could also do it quicker using defaultdicts..
depth = 0
while True:
x_at_this_depth = [node[0] for node in coordinates if node[1] == depth]
if not x_at_this_depth:
# reached all leafs
break
assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
depth += 1

File diff suppressed because it is too large Load Diff