first commit
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
The :mod:`sklearn.tree` module includes decision tree-based models for
|
||||
classification and regression.
|
||||
"""
|
||||
|
||||
from ._classes import BaseDecisionTree
|
||||
from ._classes import DecisionTreeClassifier
|
||||
from ._classes import DecisionTreeRegressor
|
||||
from ._classes import ExtraTreeClassifier
|
||||
from ._classes import ExtraTreeRegressor
|
||||
from ._export import export_graphviz, plot_tree, export_text
|
||||
|
||||
__all__ = [
|
||||
"BaseDecisionTree",
|
||||
"DecisionTreeClassifier",
|
||||
"DecisionTreeRegressor",
|
||||
"ExtraTreeClassifier",
|
||||
"ExtraTreeRegressor",
|
||||
"export_graphviz",
|
||||
"plot_tree",
|
||||
"export_text",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -0,0 +1,79 @@
|
||||
# Authors: Gilles Louppe <g.louppe@gmail.com>
|
||||
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
|
||||
# Brian Holt <bdholt1@gmail.com>
|
||||
# Joel Nothman <joel.nothman@gmail.com>
|
||||
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||||
# Jacob Schreiber <jmschreiber91@gmail.com>
|
||||
#
|
||||
# License: BSD 3 clause
|
||||
|
||||
# See _criterion.pyx for implementation details.
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
from ._tree cimport DTYPE_t # Type of X
|
||||
from ._tree cimport DOUBLE_t # Type of y, sample_weight
|
||||
from ._tree cimport SIZE_t # Type for indices and counters
|
||||
from ._tree cimport INT32_t # Signed 32 bit integer
|
||||
from ._tree cimport UINT32_t # Unsigned 32 bit integer
|
||||
|
||||
cdef class Criterion:
|
||||
# The criterion computes the impurity of a node and the reduction of
|
||||
# impurity of a split on that node. It also computes the output statistics
|
||||
# such as the mean in regression and class probabilities in classification.
|
||||
|
||||
# Internal structures
|
||||
cdef const DOUBLE_t[:, ::1] y # Values of y
|
||||
cdef DOUBLE_t* sample_weight # Sample weights
|
||||
|
||||
cdef SIZE_t* samples # Sample indices in X, y
|
||||
cdef SIZE_t start # samples[start:pos] are the samples in the left node
|
||||
cdef SIZE_t pos # samples[pos:end] are the samples in the right node
|
||||
cdef SIZE_t end
|
||||
|
||||
cdef SIZE_t n_outputs # Number of outputs
|
||||
cdef SIZE_t n_samples # Number of samples
|
||||
cdef SIZE_t n_node_samples # Number of samples in the node (end-start)
|
||||
cdef double weighted_n_samples # Weighted number of samples (in total)
|
||||
cdef double weighted_n_node_samples # Weighted number of samples in the node
|
||||
cdef double weighted_n_left # Weighted number of samples in the left node
|
||||
cdef double weighted_n_right # Weighted number of samples in the right node
|
||||
|
||||
# The criterion object is maintained such that left and right collected
|
||||
# statistics correspond to samples[start:pos] and samples[pos:end].
|
||||
|
||||
# Methods
|
||||
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
|
||||
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
|
||||
SIZE_t end) nogil except -1
|
||||
cdef int reset(self) nogil except -1
|
||||
cdef int reverse_reset(self) nogil except -1
|
||||
cdef int update(self, SIZE_t new_pos) nogil except -1
|
||||
cdef double node_impurity(self) nogil
|
||||
cdef void children_impurity(self, double* impurity_left,
|
||||
double* impurity_right) nogil
|
||||
cdef void node_value(self, double* dest) nogil
|
||||
cdef double impurity_improvement(self, double impurity_parent,
|
||||
double impurity_left,
|
||||
double impurity_right) nogil
|
||||
cdef double proxy_impurity_improvement(self) nogil
|
||||
|
||||
cdef class ClassificationCriterion(Criterion):
|
||||
"""Abstract criterion for classification."""
|
||||
|
||||
cdef SIZE_t[::1] n_classes
|
||||
cdef SIZE_t max_n_classes
|
||||
|
||||
cdef double[:, ::1] sum_total # The sum of the weighted count of each label.
|
||||
cdef double[:, ::1] sum_left # Same as above, but for the left side of the split
|
||||
cdef double[:, ::1] sum_right # Same as above, but for the right side of the split
|
||||
|
||||
cdef class RegressionCriterion(Criterion):
|
||||
"""Abstract regression criterion."""
|
||||
|
||||
cdef double sq_sum_total
|
||||
|
||||
cdef double[::1] sum_total # The sum of w*y.
|
||||
cdef double[::1] sum_left # Same as above, but for the left side of the split
|
||||
cdef double[::1] sum_right # Same as above, but for the right side of the split
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,188 @@
|
||||
# Authors: William Mill (bill@billmill.org)
|
||||
# License: BSD 3 clause
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DrawTree:
|
||||
def __init__(self, tree, parent=None, depth=0, number=1):
|
||||
self.x = -1.0
|
||||
self.y = depth
|
||||
self.tree = tree
|
||||
self.children = [
|
||||
DrawTree(c, self, depth + 1, i + 1) for i, c in enumerate(tree.children)
|
||||
]
|
||||
self.parent = parent
|
||||
self.thread = None
|
||||
self.mod = 0
|
||||
self.ancestor = self
|
||||
self.change = self.shift = 0
|
||||
self._lmost_sibling = None
|
||||
# this is the number of the node in its group of siblings 1..n
|
||||
self.number = number
|
||||
|
||||
def left(self):
|
||||
return self.thread or len(self.children) and self.children[0]
|
||||
|
||||
def right(self):
|
||||
return self.thread or len(self.children) and self.children[-1]
|
||||
|
||||
def lbrother(self):
|
||||
n = None
|
||||
if self.parent:
|
||||
for node in self.parent.children:
|
||||
if node == self:
|
||||
return n
|
||||
else:
|
||||
n = node
|
||||
return n
|
||||
|
||||
def get_lmost_sibling(self):
|
||||
if not self._lmost_sibling and self.parent and self != self.parent.children[0]:
|
||||
self._lmost_sibling = self.parent.children[0]
|
||||
return self._lmost_sibling
|
||||
|
||||
lmost_sibling = property(get_lmost_sibling)
|
||||
|
||||
def __str__(self):
|
||||
return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def max_extents(self):
|
||||
extents = [c.max_extents() for c in self.children]
|
||||
extents.append((self.x, self.y))
|
||||
return np.max(extents, axis=0)
|
||||
|
||||
|
||||
def buchheim(tree):
|
||||
dt = first_walk(DrawTree(tree))
|
||||
min = second_walk(dt)
|
||||
if min < 0:
|
||||
third_walk(dt, -min)
|
||||
return dt
|
||||
|
||||
|
||||
def third_walk(tree, n):
|
||||
tree.x += n
|
||||
for c in tree.children:
|
||||
third_walk(c, n)
|
||||
|
||||
|
||||
def first_walk(v, distance=1.0):
|
||||
if len(v.children) == 0:
|
||||
if v.lmost_sibling:
|
||||
v.x = v.lbrother().x + distance
|
||||
else:
|
||||
v.x = 0.0
|
||||
else:
|
||||
default_ancestor = v.children[0]
|
||||
for w in v.children:
|
||||
first_walk(w)
|
||||
default_ancestor = apportion(w, default_ancestor, distance)
|
||||
# print("finished v =", v.tree, "children")
|
||||
execute_shifts(v)
|
||||
|
||||
midpoint = (v.children[0].x + v.children[-1].x) / 2
|
||||
|
||||
w = v.lbrother()
|
||||
if w:
|
||||
v.x = w.x + distance
|
||||
v.mod = v.x - midpoint
|
||||
else:
|
||||
v.x = midpoint
|
||||
return v
|
||||
|
||||
|
||||
def apportion(v, default_ancestor, distance):
|
||||
w = v.lbrother()
|
||||
if w is not None:
|
||||
# in buchheim notation:
|
||||
# i == inner; o == outer; r == right; l == left; r = +; l = -
|
||||
vir = vor = v
|
||||
vil = w
|
||||
vol = v.lmost_sibling
|
||||
sir = sor = v.mod
|
||||
sil = vil.mod
|
||||
sol = vol.mod
|
||||
while vil.right() and vir.left():
|
||||
vil = vil.right()
|
||||
vir = vir.left()
|
||||
vol = vol.left()
|
||||
vor = vor.right()
|
||||
vor.ancestor = v
|
||||
shift = (vil.x + sil) - (vir.x + sir) + distance
|
||||
if shift > 0:
|
||||
move_subtree(ancestor(vil, v, default_ancestor), v, shift)
|
||||
sir = sir + shift
|
||||
sor = sor + shift
|
||||
sil += vil.mod
|
||||
sir += vir.mod
|
||||
sol += vol.mod
|
||||
sor += vor.mod
|
||||
if vil.right() and not vor.right():
|
||||
vor.thread = vil.right()
|
||||
vor.mod += sil - sor
|
||||
else:
|
||||
if vir.left() and not vol.left():
|
||||
vol.thread = vir.left()
|
||||
vol.mod += sir - sol
|
||||
default_ancestor = v
|
||||
return default_ancestor
|
||||
|
||||
|
||||
def move_subtree(wl, wr, shift):
|
||||
subtrees = wr.number - wl.number
|
||||
# print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees,
|
||||
# 'shift', shift)
|
||||
# print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees
|
||||
wr.change -= shift / subtrees
|
||||
wr.shift += shift
|
||||
wl.change += shift / subtrees
|
||||
wr.x += shift
|
||||
wr.mod += shift
|
||||
|
||||
|
||||
def execute_shifts(v):
|
||||
shift = change = 0
|
||||
for w in v.children[::-1]:
|
||||
# print("shift:", w, shift, w.change)
|
||||
w.x += shift
|
||||
w.mod += shift
|
||||
change += w.change
|
||||
shift += w.shift + change
|
||||
|
||||
|
||||
def ancestor(vil, v, default_ancestor):
|
||||
# the relevant text is at the bottom of page 7 of
|
||||
# "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al,
|
||||
# (2002)
|
||||
# http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.16.8757&rep=rep1&type=pdf
|
||||
if vil.ancestor in v.parent.children:
|
||||
return vil.ancestor
|
||||
else:
|
||||
return default_ancestor
|
||||
|
||||
|
||||
def second_walk(v, m=0, depth=0, min=None):
|
||||
v.x += m
|
||||
v.y = depth
|
||||
|
||||
if min is None or v.x < min:
|
||||
min = v.x
|
||||
|
||||
for w in v.children:
|
||||
min = second_walk(w, m + v.mod, depth + 1, min)
|
||||
|
||||
return min
|
||||
|
||||
|
||||
class Tree:
|
||||
def __init__(self, label="", node_id=-1, *children):
|
||||
self.label = label
|
||||
self.node_id = node_id
|
||||
if children:
|
||||
self.children = children
|
||||
else:
|
||||
self.children = []
|
||||
Binary file not shown.
@@ -0,0 +1,93 @@
|
||||
# Authors: Gilles Louppe <g.louppe@gmail.com>
|
||||
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
|
||||
# Brian Holt <bdholt1@gmail.com>
|
||||
# Joel Nothman <joel.nothman@gmail.com>
|
||||
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||||
# Jacob Schreiber <jmschreiber91@gmail.com>
|
||||
#
|
||||
# License: BSD 3 clause
|
||||
|
||||
# See _splitter.pyx for details.
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
from ._criterion cimport Criterion
|
||||
|
||||
from ._tree cimport DTYPE_t # Type of X
|
||||
from ._tree cimport DOUBLE_t # Type of y, sample_weight
|
||||
from ._tree cimport SIZE_t # Type for indices and counters
|
||||
from ._tree cimport INT32_t # Signed 32 bit integer
|
||||
from ._tree cimport UINT32_t # Unsigned 32 bit integer
|
||||
|
||||
cdef struct SplitRecord:
|
||||
# Data to track sample split
|
||||
SIZE_t feature # Which feature to split on.
|
||||
SIZE_t pos # Split samples array at the given position,
|
||||
# i.e. count of samples below threshold for feature.
|
||||
# pos is >= end if the node is a leaf.
|
||||
double threshold # Threshold to split at.
|
||||
double improvement # Impurity improvement given parent node.
|
||||
double impurity_left # Impurity of the left split.
|
||||
double impurity_right # Impurity of the right split.
|
||||
|
||||
cdef class Splitter:
|
||||
# The splitter searches in the input space for a feature and a threshold
|
||||
# to split the samples samples[start:end].
|
||||
#
|
||||
# The impurity computations are delegated to a criterion object.
|
||||
|
||||
# Internal structures
|
||||
cdef public Criterion criterion # Impurity criterion
|
||||
cdef public SIZE_t max_features # Number of features to test
|
||||
cdef public SIZE_t min_samples_leaf # Min samples in a leaf
|
||||
cdef public double min_weight_leaf # Minimum weight in a leaf
|
||||
|
||||
cdef object random_state # Random state
|
||||
cdef UINT32_t rand_r_state # sklearn_rand_r random number state
|
||||
|
||||
cdef SIZE_t* samples # Sample indices in X, y
|
||||
cdef SIZE_t n_samples # X.shape[0]
|
||||
cdef double weighted_n_samples # Weighted number of samples
|
||||
cdef SIZE_t* features # Feature indices in X
|
||||
cdef SIZE_t* constant_features # Constant features indices
|
||||
cdef SIZE_t n_features # X.shape[1]
|
||||
cdef DTYPE_t* feature_values # temp. array holding feature values
|
||||
|
||||
cdef SIZE_t start # Start position for the current node
|
||||
cdef SIZE_t end # End position for the current node
|
||||
|
||||
cdef const DOUBLE_t[:, ::1] y
|
||||
cdef DOUBLE_t* sample_weight
|
||||
|
||||
# The samples vector `samples` is maintained by the Splitter object such
|
||||
# that the samples contained in a node are contiguous. With this setting,
|
||||
# `node_split` reorganizes the node samples `samples[start:end]` in two
|
||||
# subsets `samples[start:pos]` and `samples[pos:end]`.
|
||||
|
||||
# The 1-d `features` array of size n_features contains the features
|
||||
# indices and allows fast sampling without replacement of features.
|
||||
|
||||
# The 1-d `constant_features` array of size n_features holds in
|
||||
# `constant_features[:n_constant_features]` the feature ids with
|
||||
# constant values for all the samples that reached a specific node.
|
||||
# The value `n_constant_features` is given by the parent node to its
|
||||
# child nodes. The content of the range `[n_constant_features:]` is left
|
||||
# undefined, but preallocated for performance reasons
|
||||
# This allows optimization with depth-based tree building.
|
||||
|
||||
# Methods
|
||||
cdef int init(self, object X, const DOUBLE_t[:, ::1] y,
|
||||
DOUBLE_t* sample_weight) except -1
|
||||
|
||||
cdef int node_reset(self, SIZE_t start, SIZE_t end,
|
||||
double* weighted_n_node_samples) nogil except -1
|
||||
|
||||
cdef int node_split(self,
|
||||
double impurity, # Impurity of the node
|
||||
SplitRecord* split,
|
||||
SIZE_t* n_constant_features) nogil except -1
|
||||
|
||||
cdef void node_value(self, double* dest) nogil
|
||||
|
||||
cdef double node_impurity(self) nogil
|
||||
Binary file not shown.
@@ -0,0 +1,103 @@
|
||||
# Authors: Gilles Louppe <g.louppe@gmail.com>
|
||||
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
|
||||
# Brian Holt <bdholt1@gmail.com>
|
||||
# Joel Nothman <joel.nothman@gmail.com>
|
||||
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||||
# Jacob Schreiber <jmschreiber91@gmail.com>
|
||||
# Nelson Liu <nelson@nelsonliu.me>
|
||||
#
|
||||
# License: BSD 3 clause
|
||||
|
||||
# See _tree.pyx for details.
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
ctypedef np.npy_float32 DTYPE_t # Type of X
|
||||
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
|
||||
ctypedef np.npy_intp SIZE_t # Type for indices and counters
|
||||
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
|
||||
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
|
||||
|
||||
from ._splitter cimport Splitter
|
||||
from ._splitter cimport SplitRecord
|
||||
|
||||
cdef struct Node:
|
||||
# Base storage structure for the nodes in a Tree object
|
||||
|
||||
SIZE_t left_child # id of the left child of the node
|
||||
SIZE_t right_child # id of the right child of the node
|
||||
SIZE_t feature # Feature used for splitting the node
|
||||
DOUBLE_t threshold # Threshold value at the node
|
||||
DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion)
|
||||
SIZE_t n_node_samples # Number of samples at the node
|
||||
DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node
|
||||
|
||||
|
||||
cdef class Tree:
|
||||
# The Tree object is a binary tree structure constructed by the
|
||||
# TreeBuilder. The tree structure is used for predictions and
|
||||
# feature importances.
|
||||
|
||||
# Input/Output layout
|
||||
cdef public SIZE_t n_features # Number of features in X
|
||||
cdef SIZE_t* n_classes # Number of classes in y[:, k]
|
||||
cdef public SIZE_t n_outputs # Number of outputs in y
|
||||
cdef public SIZE_t max_n_classes # max(n_classes)
|
||||
|
||||
# Inner structures: values are stored separately from node structure,
|
||||
# since size is determined at runtime.
|
||||
cdef public SIZE_t max_depth # Max depth of the tree
|
||||
cdef public SIZE_t node_count # Counter for node IDs
|
||||
cdef public SIZE_t capacity # Capacity of tree, in terms of nodes
|
||||
cdef Node* nodes # Array of nodes
|
||||
cdef double* value # (capacity, n_outputs, max_n_classes) array of values
|
||||
cdef SIZE_t value_stride # = n_outputs * max_n_classes
|
||||
|
||||
# Methods
|
||||
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
|
||||
SIZE_t feature, double threshold, double impurity,
|
||||
SIZE_t n_node_samples,
|
||||
double weighted_n_node_samples) nogil except -1
|
||||
cdef int _resize(self, SIZE_t capacity) nogil except -1
|
||||
cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1
|
||||
|
||||
cdef np.ndarray _get_value_ndarray(self)
|
||||
cdef np.ndarray _get_node_ndarray(self)
|
||||
|
||||
cpdef np.ndarray predict(self, object X)
|
||||
|
||||
cpdef np.ndarray apply(self, object X)
|
||||
cdef np.ndarray _apply_dense(self, object X)
|
||||
cdef np.ndarray _apply_sparse_csr(self, object X)
|
||||
|
||||
cpdef object decision_path(self, object X)
|
||||
cdef object _decision_path_dense(self, object X)
|
||||
cdef object _decision_path_sparse_csr(self, object X)
|
||||
|
||||
cpdef compute_feature_importances(self, normalize=*)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tree builder
|
||||
# =============================================================================
|
||||
|
||||
cdef class TreeBuilder:
|
||||
# The TreeBuilder recursively builds a Tree object from training samples,
|
||||
# using a Splitter object for splitting internal nodes and assigning
|
||||
# values to leaves.
|
||||
#
|
||||
# This class controls the various stopping criteria and the node splitting
|
||||
# evaluation order, e.g. depth-first or best-first.
|
||||
|
||||
cdef Splitter splitter # Splitting algorithm
|
||||
|
||||
cdef SIZE_t min_samples_split # Minimum number of samples in an internal node
|
||||
cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf
|
||||
cdef double min_weight_leaf # Minimum weight in a leaf
|
||||
cdef SIZE_t max_depth # Maximal tree depth
|
||||
cdef double min_impurity_decrease # Impurity threshold for early stopping
|
||||
|
||||
cpdef build(self, Tree tree, object X, np.ndarray y,
|
||||
np.ndarray sample_weight=*)
|
||||
cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight)
|
||||
Binary file not shown.
@@ -0,0 +1,111 @@
|
||||
# Authors: Gilles Louppe <g.louppe@gmail.com>
|
||||
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
|
||||
# Arnaud Joly <arnaud.v.joly@gmail.com>
|
||||
# Jacob Schreiber <jmschreiber91@gmail.com>
|
||||
# Nelson Liu <nelson@nelsonliu.me>
|
||||
#
|
||||
# License: BSD 3 clause
|
||||
|
||||
# See _utils.pyx for details.
|
||||
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
from ._tree cimport Node
|
||||
from ..neighbors._quad_tree cimport Cell
|
||||
|
||||
ctypedef np.npy_float32 DTYPE_t # Type of X
|
||||
ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
|
||||
ctypedef np.npy_intp SIZE_t # Type for indices and counters
|
||||
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
|
||||
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
|
||||
|
||||
|
||||
cdef enum:
|
||||
# Max value for our rand_r replacement (near the bottom).
|
||||
# We don't use RAND_MAX because it's different across platforms and
|
||||
# particularly tiny on Windows/MSVC.
|
||||
RAND_R_MAX = 0x7FFFFFFF
|
||||
|
||||
|
||||
# safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or
|
||||
# raises a MemoryError. It never calls free, since that's __dealloc__'s job.
|
||||
# cdef DTYPE_t *p = NULL
|
||||
# safe_realloc(&p, n)
|
||||
# is equivalent to p = malloc(n * sizeof(*p)) with error checking.
|
||||
ctypedef fused realloc_ptr:
|
||||
# Add pointer types here as needed.
|
||||
(DTYPE_t*)
|
||||
(SIZE_t*)
|
||||
(unsigned char*)
|
||||
(WeightedPQueueRecord*)
|
||||
(DOUBLE_t*)
|
||||
(DOUBLE_t**)
|
||||
(Node*)
|
||||
(Cell*)
|
||||
(Node**)
|
||||
|
||||
cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *
|
||||
|
||||
|
||||
cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size)
|
||||
|
||||
|
||||
cdef SIZE_t rand_int(SIZE_t low, SIZE_t high,
|
||||
UINT32_t* random_state) nogil
|
||||
|
||||
|
||||
cdef double rand_uniform(double low, double high,
|
||||
UINT32_t* random_state) nogil
|
||||
|
||||
|
||||
cdef double log(double x) nogil
|
||||
|
||||
# =============================================================================
|
||||
# WeightedPQueue data structure
|
||||
# =============================================================================
|
||||
|
||||
# A record stored in the WeightedPQueue
|
||||
cdef struct WeightedPQueueRecord:
|
||||
DOUBLE_t data
|
||||
DOUBLE_t weight
|
||||
|
||||
cdef class WeightedPQueue:
|
||||
cdef SIZE_t capacity
|
||||
cdef SIZE_t array_ptr
|
||||
cdef WeightedPQueueRecord* array_
|
||||
|
||||
cdef bint is_empty(self) nogil
|
||||
cdef int reset(self) nogil except -1
|
||||
cdef SIZE_t size(self) nogil
|
||||
cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
|
||||
cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
|
||||
cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
|
||||
cdef int peek(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
|
||||
cdef DOUBLE_t get_weight_from_index(self, SIZE_t index) nogil
|
||||
cdef DOUBLE_t get_value_from_index(self, SIZE_t index) nogil
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WeightedMedianCalculator data structure
|
||||
# =============================================================================
|
||||
|
||||
cdef class WeightedMedianCalculator:
|
||||
cdef SIZE_t initial_capacity
|
||||
cdef WeightedPQueue samples
|
||||
cdef DOUBLE_t total_weight
|
||||
cdef SIZE_t k
|
||||
cdef DOUBLE_t sum_w_0_k # represents sum(weights[0:k])
|
||||
# = w[0] + w[1] + ... + w[k-1]
|
||||
|
||||
cdef SIZE_t size(self) nogil
|
||||
cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1
|
||||
cdef int reset(self) nogil except -1
|
||||
cdef int update_median_parameters_post_push(
|
||||
self, DOUBLE_t data, DOUBLE_t weight,
|
||||
DOUBLE_t original_median) nogil
|
||||
cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil
|
||||
cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil
|
||||
cdef int update_median_parameters_post_remove(
|
||||
self, DOUBLE_t data, DOUBLE_t weight,
|
||||
DOUBLE_t original_median) nogil
|
||||
cdef DOUBLE_t get_median(self) nogil
|
||||
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
import numpy
|
||||
from numpy.distutils.misc_util import Configuration
|
||||
|
||||
|
||||
def configuration(parent_package="", top_path=None):
|
||||
config = Configuration("tree", parent_package, top_path)
|
||||
libraries = []
|
||||
if os.name == "posix":
|
||||
libraries.append("m")
|
||||
config.add_extension(
|
||||
"_tree",
|
||||
sources=["_tree.pyx"],
|
||||
include_dirs=[numpy.get_include()],
|
||||
libraries=libraries,
|
||||
language="c++",
|
||||
extra_compile_args=["-O3"],
|
||||
)
|
||||
config.add_extension(
|
||||
"_splitter",
|
||||
sources=["_splitter.pyx"],
|
||||
include_dirs=[numpy.get_include()],
|
||||
libraries=libraries,
|
||||
extra_compile_args=["-O3"],
|
||||
)
|
||||
config.add_extension(
|
||||
"_criterion",
|
||||
sources=["_criterion.pyx"],
|
||||
include_dirs=[numpy.get_include()],
|
||||
libraries=libraries,
|
||||
extra_compile_args=["-O3"],
|
||||
)
|
||||
config.add_extension(
|
||||
"_utils",
|
||||
sources=["_utils.pyx"],
|
||||
include_dirs=[numpy.get_include()],
|
||||
libraries=libraries,
|
||||
extra_compile_args=["-O3"],
|
||||
)
|
||||
|
||||
config.add_subpackage("tests")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from numpy.distutils.core import setup
|
||||
|
||||
setup(**configuration().todict())
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,512 @@
|
||||
"""
|
||||
Testing for export functions of decision trees (sklearn.tree.export).
|
||||
"""
|
||||
from re import finditer, search
|
||||
from textwrap import dedent
|
||||
|
||||
from numpy.random import RandomState
|
||||
import pytest
|
||||
|
||||
from sklearn.base import is_classifier
|
||||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
|
||||
from sklearn.ensemble import GradientBoostingClassifier
|
||||
from sklearn.tree import export_graphviz, plot_tree, export_text
|
||||
from io import StringIO
|
||||
from sklearn.exceptions import NotFittedError
|
||||
|
||||
# toy sample
|
||||
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
|
||||
y = [-1, -1, -1, 1, 1, 1]
|
||||
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
|
||||
w = [1, 1, 1, 0.5, 0.5, 0.5]
|
||||
y_degraded = [1, 1, 1, 1, 1, 1]
|
||||
|
||||
|
||||
def test_graphviz_toy():
|
||||
# Check correctness of export_graphviz
|
||||
clf = DecisionTreeClassifier(
|
||||
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
|
||||
)
|
||||
clf.fit(X, y)
|
||||
|
||||
# Test export code
|
||||
contents1 = export_graphviz(clf, out_file=None)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, fontname="helvetica"] ;\n'
|
||||
'edge [fontname="helvetica"] ;\n'
|
||||
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
|
||||
'value = [3, 3]"] ;\n'
|
||||
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
|
||||
"0 -> 1 [labeldistance=2.5, labelangle=45, "
|
||||
'headlabel="True"] ;\n'
|
||||
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
|
||||
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
|
||||
'headlabel="False"] ;\n'
|
||||
"}"
|
||||
)
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test with feature_names
|
||||
contents1 = export_graphviz(
|
||||
clf, feature_names=["feature0", "feature1"], out_file=None
|
||||
)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, fontname="helvetica"] ;\n'
|
||||
'edge [fontname="helvetica"] ;\n'
|
||||
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
|
||||
'value = [3, 3]"] ;\n'
|
||||
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
|
||||
"0 -> 1 [labeldistance=2.5, labelangle=45, "
|
||||
'headlabel="True"] ;\n'
|
||||
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
|
||||
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
|
||||
'headlabel="False"] ;\n'
|
||||
"}"
|
||||
)
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test with class_names
|
||||
contents1 = export_graphviz(clf, class_names=["yes", "no"], out_file=None)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, fontname="helvetica"] ;\n'
|
||||
'edge [fontname="helvetica"] ;\n'
|
||||
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
|
||||
'value = [3, 3]\\nclass = yes"] ;\n'
|
||||
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
|
||||
'class = yes"] ;\n'
|
||||
"0 -> 1 [labeldistance=2.5, labelangle=45, "
|
||||
'headlabel="True"] ;\n'
|
||||
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
|
||||
'class = no"] ;\n'
|
||||
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
|
||||
'headlabel="False"] ;\n'
|
||||
"}"
|
||||
)
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test plot_options
|
||||
contents1 = export_graphviz(
|
||||
clf,
|
||||
filled=True,
|
||||
impurity=False,
|
||||
proportion=True,
|
||||
special_characters=True,
|
||||
rounded=True,
|
||||
out_file=None,
|
||||
fontname="sans",
|
||||
)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, style="filled, rounded", color="black", '
|
||||
'fontname="sans"] ;\n'
|
||||
'edge [fontname="sans"] ;\n'
|
||||
"0 [label=<X<SUB>0</SUB> ≤ 0.0<br/>samples = 100.0%<br/>"
|
||||
'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n'
|
||||
"1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, "
|
||||
'fillcolor="#e58139"] ;\n'
|
||||
"0 -> 1 [labeldistance=2.5, labelangle=45, "
|
||||
'headlabel="True"] ;\n'
|
||||
"2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, "
|
||||
'fillcolor="#399de5"] ;\n'
|
||||
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
|
||||
'headlabel="False"] ;\n'
|
||||
"}"
|
||||
)
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test max_depth
|
||||
contents1 = export_graphviz(clf, max_depth=0, class_names=True, out_file=None)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, fontname="helvetica"] ;\n'
|
||||
'edge [fontname="helvetica"] ;\n'
|
||||
'0 [label="X[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
|
||||
'value = [3, 3]\\nclass = y[0]"] ;\n'
|
||||
'1 [label="(...)"] ;\n'
|
||||
"0 -> 1 ;\n"
|
||||
'2 [label="(...)"] ;\n'
|
||||
"0 -> 2 ;\n"
|
||||
"}"
|
||||
)
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test max_depth with plot_options
|
||||
contents1 = export_graphviz(
|
||||
clf, max_depth=0, filled=True, out_file=None, node_ids=True
|
||||
)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, style="filled", color="black", '
|
||||
'fontname="helvetica"] ;\n'
|
||||
'edge [fontname="helvetica"] ;\n'
|
||||
'0 [label="node #0\\nX[0] <= 0.0\\ngini = 0.5\\n'
|
||||
'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n'
|
||||
'1 [label="(...)", fillcolor="#C0C0C0"] ;\n'
|
||||
"0 -> 1 ;\n"
|
||||
'2 [label="(...)", fillcolor="#C0C0C0"] ;\n'
|
||||
"0 -> 2 ;\n"
|
||||
"}"
|
||||
)
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test multi-output with weighted samples
|
||||
clf = DecisionTreeClassifier(
|
||||
max_depth=2, min_samples_split=2, criterion="gini", random_state=2
|
||||
)
|
||||
clf = clf.fit(X, y2, sample_weight=w)
|
||||
|
||||
contents1 = export_graphviz(clf, filled=True, impurity=False, out_file=None)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, style="filled", color="black", '
|
||||
'fontname="helvetica"] ;\n'
|
||||
'edge [fontname="helvetica"] ;\n'
|
||||
'0 [label="X[0] <= 0.0\\nsamples = 6\\n'
|
||||
"value = [[3.0, 1.5, 0.0]\\n"
|
||||
'[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n'
|
||||
'1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n'
|
||||
'[3, 0, 0]]", fillcolor="#e58139"] ;\n'
|
||||
"0 -> 1 [labeldistance=2.5, labelangle=45, "
|
||||
'headlabel="True"] ;\n'
|
||||
'2 [label="X[0] <= 1.5\\nsamples = 3\\n'
|
||||
"value = [[0.0, 1.5, 0.0]\\n"
|
||||
'[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n'
|
||||
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
|
||||
'headlabel="False"] ;\n'
|
||||
'3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n'
|
||||
'[0, 1, 0]]", fillcolor="#e58139"] ;\n'
|
||||
"2 -> 3 ;\n"
|
||||
'4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n'
|
||||
'[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n'
|
||||
"2 -> 4 ;\n"
|
||||
"}"
|
||||
)
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test regression output with plot_options
|
||||
clf = DecisionTreeRegressor(
|
||||
max_depth=3, min_samples_split=2, criterion="squared_error", random_state=2
|
||||
)
|
||||
clf.fit(X, y)
|
||||
|
||||
contents1 = export_graphviz(
|
||||
clf,
|
||||
filled=True,
|
||||
leaves_parallel=True,
|
||||
out_file=None,
|
||||
rotate=True,
|
||||
rounded=True,
|
||||
fontname="sans",
|
||||
)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, style="filled, rounded", color="black", '
|
||||
'fontname="sans"] ;\n'
|
||||
"graph [ranksep=equally, splines=polyline] ;\n"
|
||||
'edge [fontname="sans"] ;\n'
|
||||
"rankdir=LR ;\n"
|
||||
'0 [label="X[0] <= 0.0\\nsquared_error = 1.0\\nsamples = 6\\n'
|
||||
'value = 0.0", fillcolor="#f2c09c"] ;\n'
|
||||
'1 [label="squared_error = 0.0\\nsamples = 3\\'
|
||||
'nvalue = -1.0", '
|
||||
'fillcolor="#ffffff"] ;\n'
|
||||
"0 -> 1 [labeldistance=2.5, labelangle=-45, "
|
||||
'headlabel="True"] ;\n'
|
||||
'2 [label="squared_error = 0.0\\nsamples = 3\\nvalue = 1.0", '
|
||||
'fillcolor="#e58139"] ;\n'
|
||||
"0 -> 2 [labeldistance=2.5, labelangle=45, "
|
||||
'headlabel="False"] ;\n'
|
||||
"{rank=same ; 0} ;\n"
|
||||
"{rank=same ; 1; 2} ;\n"
|
||||
"}"
|
||||
)
|
||||
|
||||
assert contents1 == contents2
|
||||
|
||||
# Test classifier with degraded learning set
|
||||
clf = DecisionTreeClassifier(max_depth=3)
|
||||
clf.fit(X, y_degraded)
|
||||
|
||||
contents1 = export_graphviz(clf, filled=True, out_file=None)
|
||||
contents2 = (
|
||||
"digraph Tree {\n"
|
||||
'node [shape=box, style="filled", color="black", '
|
||||
'fontname="helvetica"] ;\n'
|
||||
'edge [fontname="helvetica"] ;\n'
|
||||
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", '
|
||||
'fillcolor="#ffffff"] ;\n'
|
||||
"}"
|
||||
)
|
||||
|
||||
|
||||
def test_graphviz_errors():
|
||||
# Check for errors of export_graphviz
|
||||
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
|
||||
|
||||
# Check not-fitted decision tree error
|
||||
out = StringIO()
|
||||
with pytest.raises(NotFittedError):
|
||||
export_graphviz(clf, out)
|
||||
|
||||
clf.fit(X, y)
|
||||
|
||||
# Check if it errors when length of feature_names
|
||||
# mismatches with number of features
|
||||
message = "Length of feature_names, 1 does not match number of features, 2"
|
||||
with pytest.raises(ValueError, match=message):
|
||||
export_graphviz(clf, None, feature_names=["a"])
|
||||
|
||||
message = "Length of feature_names, 3 does not match number of features, 2"
|
||||
with pytest.raises(ValueError, match=message):
|
||||
export_graphviz(clf, None, feature_names=["a", "b", "c"])
|
||||
|
||||
# Check error when argument is not an estimator
|
||||
message = "is not an estimator instance"
|
||||
with pytest.raises(TypeError, match=message):
|
||||
export_graphviz(clf.fit(X, y).tree_)
|
||||
|
||||
# Check class_names error
|
||||
out = StringIO()
|
||||
with pytest.raises(IndexError):
|
||||
export_graphviz(clf, out, class_names=[])
|
||||
|
||||
# Check precision error
|
||||
out = StringIO()
|
||||
with pytest.raises(ValueError, match="should be greater or equal"):
|
||||
export_graphviz(clf, out, precision=-1)
|
||||
with pytest.raises(ValueError, match="should be an integer"):
|
||||
export_graphviz(clf, out, precision="1")
|
||||
|
||||
|
||||
def test_friedman_mse_in_graphviz():
|
||||
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
|
||||
clf.fit(X, y)
|
||||
dot_data = StringIO()
|
||||
export_graphviz(clf, out_file=dot_data)
|
||||
|
||||
clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
|
||||
clf.fit(X, y)
|
||||
for estimator in clf.estimators_:
|
||||
export_graphviz(estimator[0], out_file=dot_data)
|
||||
|
||||
for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
|
||||
assert "friedman_mse" in finding.group()
|
||||
|
||||
|
||||
def test_precision():
|
||||
|
||||
rng_reg = RandomState(2)
|
||||
rng_clf = RandomState(8)
|
||||
for X, y, clf in zip(
|
||||
(rng_reg.random_sample((5, 2)), rng_clf.random_sample((1000, 4))),
|
||||
(rng_reg.random_sample((5,)), rng_clf.randint(2, size=(1000,))),
|
||||
(
|
||||
DecisionTreeRegressor(
|
||||
criterion="friedman_mse", random_state=0, max_depth=1
|
||||
),
|
||||
DecisionTreeClassifier(max_depth=1, random_state=0),
|
||||
),
|
||||
):
|
||||
|
||||
clf.fit(X, y)
|
||||
for precision in (4, 3):
|
||||
dot_data = export_graphviz(
|
||||
clf, out_file=None, precision=precision, proportion=True
|
||||
)
|
||||
|
||||
# With the current random state, the impurity and the threshold
|
||||
# will have the number of precision set in the export_graphviz
|
||||
# function. We will check the number of precision with a strict
|
||||
# equality. The value reported will have only 2 precision and
|
||||
# therefore, only a less equal comparison will be done.
|
||||
|
||||
# check value
|
||||
for finding in finditer(r"value = \d+\.\d+", dot_data):
|
||||
assert len(search(r"\.\d+", finding.group()).group()) <= precision + 1
|
||||
# check impurity
|
||||
if is_classifier(clf):
|
||||
pattern = r"gini = \d+\.\d+"
|
||||
else:
|
||||
pattern = r"friedman_mse = \d+\.\d+"
|
||||
|
||||
# check impurity
|
||||
for finding in finditer(pattern, dot_data):
|
||||
assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
|
||||
# check threshold
|
||||
for finding in finditer(r"<= \d+\.\d+", dot_data):
|
||||
assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
|
||||
|
||||
|
||||
def test_export_text_errors():
|
||||
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
|
||||
clf.fit(X, y)
|
||||
|
||||
err_msg = "max_depth bust be >= 0, given -1"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
export_text(clf, max_depth=-1)
|
||||
err_msg = "feature_names must contain 2 elements, got 1"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
export_text(clf, feature_names=["a"])
|
||||
err_msg = "decimals must be >= 0, given -1"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
export_text(clf, decimals=-1)
|
||||
err_msg = "spacing must be > 0, given 0"
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
export_text(clf, spacing=0)
|
||||
|
||||
|
||||
def test_export_text():
|
||||
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
|
||||
clf.fit(X, y)
|
||||
|
||||
expected_report = dedent(
|
||||
"""
|
||||
|--- feature_1 <= 0.00
|
||||
| |--- class: -1
|
||||
|--- feature_1 > 0.00
|
||||
| |--- class: 1
|
||||
"""
|
||||
).lstrip()
|
||||
|
||||
assert export_text(clf) == expected_report
|
||||
# testing that leaves at level 1 are not truncated
|
||||
assert export_text(clf, max_depth=0) == expected_report
|
||||
# testing that the rest of the tree is truncated
|
||||
assert export_text(clf, max_depth=10) == expected_report
|
||||
|
||||
expected_report = dedent(
|
||||
"""
|
||||
|--- b <= 0.00
|
||||
| |--- class: -1
|
||||
|--- b > 0.00
|
||||
| |--- class: 1
|
||||
"""
|
||||
).lstrip()
|
||||
assert export_text(clf, feature_names=["a", "b"]) == expected_report
|
||||
|
||||
expected_report = dedent(
|
||||
"""
|
||||
|--- feature_1 <= 0.00
|
||||
| |--- weights: [3.00, 0.00] class: -1
|
||||
|--- feature_1 > 0.00
|
||||
| |--- weights: [0.00, 3.00] class: 1
|
||||
"""
|
||||
).lstrip()
|
||||
assert export_text(clf, show_weights=True) == expected_report
|
||||
|
||||
expected_report = dedent(
|
||||
"""
|
||||
|- feature_1 <= 0.00
|
||||
| |- class: -1
|
||||
|- feature_1 > 0.00
|
||||
| |- class: 1
|
||||
"""
|
||||
).lstrip()
|
||||
assert export_text(clf, spacing=1) == expected_report
|
||||
|
||||
X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
|
||||
y_l = [-1, -1, -1, 1, 1, 1, 2]
|
||||
clf = DecisionTreeClassifier(max_depth=4, random_state=0)
|
||||
clf.fit(X_l, y_l)
|
||||
expected_report = dedent(
|
||||
"""
|
||||
|--- feature_1 <= 0.00
|
||||
| |--- class: -1
|
||||
|--- feature_1 > 0.00
|
||||
| |--- truncated branch of depth 2
|
||||
"""
|
||||
).lstrip()
|
||||
assert export_text(clf, max_depth=0) == expected_report
|
||||
|
||||
X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
|
||||
y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]
|
||||
|
||||
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
|
||||
reg.fit(X_mo, y_mo)
|
||||
|
||||
expected_report = dedent(
|
||||
"""
|
||||
|--- feature_1 <= 0.0
|
||||
| |--- value: [-1.0, -1.0]
|
||||
|--- feature_1 > 0.0
|
||||
| |--- value: [1.0, 1.0]
|
||||
"""
|
||||
).lstrip()
|
||||
assert export_text(reg, decimals=1) == expected_report
|
||||
assert export_text(reg, decimals=1, show_weights=True) == expected_report
|
||||
|
||||
X_single = [[-2], [-1], [-1], [1], [1], [2]]
|
||||
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
|
||||
reg.fit(X_single, y_mo)
|
||||
|
||||
expected_report = dedent(
|
||||
"""
|
||||
|--- first <= 0.0
|
||||
| |--- value: [-1.0, -1.0]
|
||||
|--- first > 0.0
|
||||
| |--- value: [1.0, 1.0]
|
||||
"""
|
||||
).lstrip()
|
||||
assert export_text(reg, decimals=1, feature_names=["first"]) == expected_report
|
||||
assert (
|
||||
export_text(reg, decimals=1, show_weights=True, feature_names=["first"])
|
||||
== expected_report
|
||||
)
|
||||
|
||||
|
||||
def test_plot_tree_entropy(pyplot):
|
||||
# mostly smoke tests
|
||||
# Check correctness of export_graphviz for criterion = entropy
|
||||
clf = DecisionTreeClassifier(
|
||||
max_depth=3, min_samples_split=2, criterion="entropy", random_state=2
|
||||
)
|
||||
clf.fit(X, y)
|
||||
|
||||
# Test export code
|
||||
feature_names = ["first feat", "sepal_width"]
|
||||
nodes = plot_tree(clf, feature_names=feature_names)
|
||||
assert len(nodes) == 3
|
||||
assert (
|
||||
nodes[0].get_text()
|
||||
== "first feat <= 0.0\nentropy = 1.0\nsamples = 6\nvalue = [3, 3]"
|
||||
)
|
||||
assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
|
||||
assert nodes[2].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"
|
||||
|
||||
|
||||
def test_plot_tree_gini(pyplot):
|
||||
# mostly smoke tests
|
||||
# Check correctness of export_graphviz for criterion = gini
|
||||
clf = DecisionTreeClassifier(
|
||||
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
|
||||
)
|
||||
clf.fit(X, y)
|
||||
|
||||
# Test export code
|
||||
feature_names = ["first feat", "sepal_width"]
|
||||
nodes = plot_tree(clf, feature_names=feature_names)
|
||||
assert len(nodes) == 3
|
||||
assert (
|
||||
nodes[0].get_text()
|
||||
== "first feat <= 0.0\ngini = 0.5\nsamples = 6\nvalue = [3, 3]"
|
||||
)
|
||||
assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
|
||||
assert nodes[2].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"
|
||||
|
||||
|
||||
def test_not_fitted_tree(pyplot):
|
||||
|
||||
# Testing if not fitted tree throws the correct error
|
||||
clf = DecisionTreeRegressor()
|
||||
with pytest.raises(NotFittedError):
|
||||
plot_tree(clf)
|
||||
@@ -0,0 +1,48 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from sklearn.tree._reingold_tilford import buchheim, Tree
|
||||
|
||||
simple_tree = Tree("", 0, Tree("", 1), Tree("", 2))
|
||||
|
||||
bigger_tree = Tree(
|
||||
"",
|
||||
0,
|
||||
Tree(
|
||||
"",
|
||||
1,
|
||||
Tree("", 3),
|
||||
Tree("", 4, Tree("", 7), Tree("", 8)),
|
||||
),
|
||||
Tree("", 2, Tree("", 5), Tree("", 6)),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
|
||||
def test_buchheim(tree, n_nodes):
|
||||
def walk_tree(draw_tree):
|
||||
res = [(draw_tree.x, draw_tree.y)]
|
||||
for child in draw_tree.children:
|
||||
# parents higher than children:
|
||||
assert child.y == draw_tree.y + 1
|
||||
res.extend(walk_tree(child))
|
||||
if len(draw_tree.children):
|
||||
# these trees are always binary
|
||||
# parents are centered above children
|
||||
assert (
|
||||
draw_tree.x == (draw_tree.children[0].x + draw_tree.children[1].x) / 2
|
||||
)
|
||||
return res
|
||||
|
||||
layout = buchheim(tree)
|
||||
coordinates = walk_tree(layout)
|
||||
assert len(coordinates) == n_nodes
|
||||
# test that x values are unique per depth / level
|
||||
# we could also do it quicker using defaultdicts..
|
||||
depth = 0
|
||||
while True:
|
||||
x_at_this_depth = [node[0] for node in coordinates if node[1] == depth]
|
||||
if not x_at_this_depth:
|
||||
# reached all leafs
|
||||
break
|
||||
assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
|
||||
depth += 1
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user