first commit
This commit is contained in:
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Hypothesis data generator helpers.
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from hypothesis import strategies as st
|
||||
from hypothesis.extra.dateutil import timezones as dateutil_timezones
|
||||
from hypothesis.extra.pytz import timezones as pytz_timezones
|
||||
|
||||
from pandas.compat import is_platform_windows
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from pandas.tseries.offsets import (
|
||||
BMonthBegin,
|
||||
BMonthEnd,
|
||||
BQuarterBegin,
|
||||
BQuarterEnd,
|
||||
BYearBegin,
|
||||
BYearEnd,
|
||||
MonthBegin,
|
||||
MonthEnd,
|
||||
QuarterBegin,
|
||||
QuarterEnd,
|
||||
YearBegin,
|
||||
YearEnd,
|
||||
)
|
||||
|
||||
# Strategies producing lists of 3 to 10 items in which any item may be ``None``.
OPTIONAL_INTS = st.lists(
    st.one_of(st.integers(), st.none()),
    min_size=3,
    max_size=10,
)

OPTIONAL_FLOATS = st.lists(
    st.one_of(st.floats(), st.none()),
    min_size=3,
    max_size=10,
)

OPTIONAL_TEXT = st.lists(
    st.one_of(st.none(), st.text()),
    min_size=3,
    max_size=10,
)

# Items are either ``None`` or a dict mapping text keys to integers.
OPTIONAL_DICTS = st.lists(
    st.one_of(st.none(), st.dictionaries(st.text(), st.integers())),
    min_size=3,
    max_size=10,
)

# Items are either ``None`` or a list of 3 to 10 text values.
OPTIONAL_LISTS = st.lists(
    st.one_of(st.none(), st.lists(st.text(), min_size=3, max_size=10)),
    min_size=3,
    max_size=10,
)

# Draws from any one of the list strategies defined above.
OPTIONAL_ONE_OF_ALL = st.one_of(
    OPTIONAL_DICTS,
    OPTIONAL_FLOATS,
    OPTIONAL_INTS,
    OPTIONAL_LISTS,
    OPTIONAL_TEXT,
)
|
||||
|
||||
# On Windows the naive-datetime strategy is bounded below at 1900-01-01;
# elsewhere the hypothesis defaults are used.
DATETIME_NO_TZ = (
    st.datetimes(min_value=datetime(1900, 1, 1))
    if is_platform_windows()
    else st.datetimes()
)

# Always 1900-01-01 (min and max coincide); only the timezone varies, and it
# may be absent or come from dateutil or pytz.
DATETIME_JAN_1_1900_OPTIONAL_TZ = st.datetimes(
    min_value=pd.Timestamp(1900, 1, 1).to_pydatetime(),
    max_value=pd.Timestamp(1900, 1, 1).to_pydatetime(),
    timezones=st.one_of(st.none(), dateutil_timezones(), pytz_timezones()),
)

# Naive datetimes spanning the range representable by ``pd.Timestamp``.
DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ = st.datetimes(
    min_value=pd.Timestamp.min.to_pydatetime(warn=False),
    max_value=pd.Timestamp.max.to_pydatetime(warn=False),
)

INT_NEG_999_TO_POS_999 = st.integers(min_value=-999, max_value=999)
|
||||
|
||||
# Year/quarter/month offset strategy. The per-type strategies are registered
# in conftest.py, as the offset classes don't carry enough runtime information
# (e.g. type hints) to infer how to build them.
YQM_OFFSET = st.one_of(
    *(
        st.from_type(offset_cls)
        for offset_cls in (
            MonthBegin,
            MonthEnd,
            BMonthBegin,
            BMonthEnd,
            QuarterBegin,
            QuarterEnd,
            BQuarterBegin,
            BQuarterEnd,
            YearBegin,
            YearEnd,
            BYearBegin,
            BYearEnd,
        )
    )
)
|
||||
@@ -0,0 +1,430 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import bz2
|
||||
from functools import wraps
|
||||
import gzip
|
||||
import socket
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
)
|
||||
import zipfile
|
||||
|
||||
from pandas._typing import (
|
||||
FilePath,
|
||||
ReadPickleBuffer,
|
||||
)
|
||||
from pandas.compat import get_lzma_file
|
||||
from pandas.compat._optional import import_optional_dependency
|
||||
|
||||
import pandas as pd
|
||||
from pandas._testing._random import rands
|
||||
from pandas._testing.contexts import ensure_clean
|
||||
|
||||
from pandas.io.common import urlopen
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
Series,
|
||||
)
|
||||
|
||||
# Substrings of exception messages that mark a failure as network-related.
# "timed out" also covers the longer variants such as
# "urlopen error timed out" and "socket.timeout: timed out".
_network_error_messages = (
    "timed out",
    "Server Hangup",
    "HTTP Error 503: Service Unavailable",
    "502: Proxy Error",
    "HTTP Error 502: internal error",
    "HTTP Error 502",
    "HTTP Error 503",
    "HTTP Error 403",
    "HTTP Error 400",
    "Temporary failure in name resolution",
    "Name or service not known",
    "Connection refused",
    "certificate verify",
)

# errno values (found on e.errno or e.reason.errno) that mark a failure as
# network-related.
_network_errno_vals = (
    101,  # Network is unreachable
    111,  # Connection refused
    110,  # Connection timed out
    104,  # Connection reset Error
    54,  # Connection reset by peer
    60,  # urllib.error.URLError: [Errno 60] Connection timed out
)
|
||||
|
||||
# Both of the above shouldn't mask real issues such as 404's
|
||||
# or refused connections (changed DNS).
|
||||
# But some tests (test_data yahoo) contact incredibly flakey
|
||||
# servers.
|
||||
|
||||
# and conditionally raise on exception types in _get_default_network_errors
|
||||
|
||||
|
||||
def _get_default_network_errors():
    """
    Return the tuple of exception classes treated as network errors by default.
    """
    # http.client and urllib.error are imported lazily because they pull in
    # many other stdlib modules.
    from http.client import HTTPException
    from urllib.error import URLError

    return (OSError, HTTPException, TimeoutError, URLError, socket.timeout)
|
||||
|
||||
|
||||
def optional_args(decorator):
    """
    Let a decorator be applied with or without arguments.

    A call with exactly one positional argument that is callable, and no
    keyword arguments, is taken to be the bare form::

        @my_decorator
        def function(): pass

    Any other call is taken to supply arguments for the decorator, and a
    function expecting the object to decorate is returned. Either way the
    wrapped decorator is eventually called as ``decorator(f, *args, **kwargs)``.
    """

    @wraps(decorator)
    def wrapper(*args, **kwargs):
        bare_usage = len(args) == 1 and not kwargs and callable(args[0])
        if bare_usage:
            # The lone positional argument is the target itself, so no extra
            # arguments are forwarded.
            return decorator(args[0])

        def apply(target):
            return decorator(target, *args, **kwargs)

        return apply

    return wrapper
|
||||
|
||||
|
||||
@optional_args
def network(
    t,
    url="https://www.google.com",
    raise_on_error=False,
    check_before_test=False,
    error_classes=None,
    skip_errnos=_network_errno_vals,
    _skip_on_messages=_network_error_messages,
):
    """
    Label a test as requiring network connection and, if an error is
    encountered, only raise if it does not find a network connection.

    This assumes an added contract to your test: you must assert that, under
    normal conditions, your test will ONLY fail if it does not have network
    connectivity.

    You can call this in 3 ways: as a standard decorator, with keyword
    arguments, or with a positional argument that is the url to check.

    Parameters
    ----------
    t : callable
        The test requiring network connectivity.
    url : path
        The url to test via ``pandas.io.common.urlopen`` to check
        for connectivity. Defaults to 'https://www.google.com'.
    raise_on_error : bool
        If True, errors that are not matched by ``skip_errnos`` or
        ``_skip_on_messages`` are re-raised instead of skipping the test.
    check_before_test : bool
        If True, checks connectivity before running the test case.
    error_classes : tuple or Exception
        error classes to ignore. If not in ``error_classes``, raises the error.
        Defaults to the classes returned by ``_get_default_network_errors``.
        Be careful about changing the error classes here.
    skip_errnos : iterable of int
        Any exception that has .errno or .reason.errno set to one
        of these values will be skipped with an appropriate
        message.
    _skip_on_messages: iterable of string
        any exception e for which one of the strings is
        a substring of str(e) (compared case-insensitively) will be skipped
        with an appropriate message. Intended to suppress errors where an
        errno isn't available.

    Notes
    -----
    * ``raise_on_error`` supersedes ``check_before_test``

    Returns
    -------
    t : callable
        The decorated test ``t``, with checks for connectivity errors.

    Example
    -------

    Tests decorated with @network will fail if it's possible to make a network
    connection to another URL (defaults to google.com)::

      >>> from pandas import _testing as tm
      >>> @tm.network
      ... def test_network():
      ...     with pd.io.common.urlopen("rabbit://bonanza.com"):
      ...         pass
      >>> test_network()  # doctest: +SKIP
      Traceback
         ...
      URLError: <urlopen error unknown url type: rabbit>

    You can specify alternative URLs::

      >>> @tm.network("https://www.yahoo.com")
      ... def test_something_with_yahoo():
      ...     raise OSError("Failure Message")
      >>> test_something_with_yahoo()  # doctest: +SKIP
      Traceback (most recent call last):
          ...
      OSError: Failure Message

    If you set check_before_test, it will check the url first and not run the
    test on failure::

      >>> @tm.network("failing://url.blaher", check_before_test=True)
      ... def test_something():
      ...     print("I ran!")
      ...     raise ValueError("Failure")
      >>> test_something()  # doctest: +SKIP
      Traceback (most recent call last):
          ...

    Errors not related to networking will always be raised.
    """
    import pytest

    if error_classes is None:
        error_classes = _get_default_network_errors()

    # Tag the test so that it can be recognised as network-dependent.
    t.network = True

    @wraps(t)
    def wrapper(*args, **kwargs):
        if (
            check_before_test
            and not raise_on_error
            and not can_connect(url, error_classes)
        ):
            pytest.skip(
                f"May not have network connectivity because cannot connect to {url}"
            )
        try:
            return t(*args, **kwargs)
        except Exception as err:
            errno = getattr(err, "errno", None)
            # Fall back to the wrapped error's errno (e.g. URLError.reason).
            # The attribute must be looked up on the exception itself, not on
            # the (empty) errno value.
            if not errno and hasattr(err, "reason"):
                # error: "Exception" has no attribute "reason"
                errno = getattr(err.reason, "errno", None)  # type: ignore[attr-defined]

            if errno in skip_errnos:
                pytest.skip(f"Skipping test due to known errno and error {err}")

            e_str = str(err)

            if any(m.lower() in e_str.lower() for m in _skip_on_messages):
                pytest.skip(
                    f"Skipping test because exception message is known and error {err}"
                )

            if not isinstance(err, error_classes) or raise_on_error:
                raise
            else:
                pytest.skip(
                    f"Skipping test due to lack of connectivity and error {err}"
                )

    return wrapper
|
||||
|
||||
|
||||
with_connectivity_check = network
|
||||
|
||||
|
||||
def can_connect(url, error_classes=None):
    """
    Report whether the given url can be opened successfully.

    Parameters
    ----------
    url : str
        The URL to try to connect to.
    error_classes : tuple of exception classes, optional
        Exceptions that are interpreted as a failed connection. Defaults to
        the classes returned by ``_get_default_network_errors``.

    Returns
    -------
    connectable : bool
        True if the url was opened without raising one of ``error_classes``
        and the response status was 200, False otherwise.
    """
    if error_classes is None:
        failures = _get_default_network_errors()
    else:
        failures = error_classes

    try:
        # Timeout just in case rate-limiting is applied
        with urlopen(url, timeout=20) as response:
            reachable = response.status == 200
    except failures:
        return False
    return reachable
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# File-IO
|
||||
|
||||
|
||||
def round_trip_pickle(
    obj: Any, path: FilePath | ReadPickleBuffer | None = None
) -> DataFrame | Series:
    """
    Write an object to a pickle file and load it back.

    Parameters
    ----------
    obj : any object
        The object to pickle and then re-read.
    path : str, path object or file-like object, default None
        The path where the pickled object is written and then read. When
        omitted, a name built from ten random characters is used.

    Returns
    -------
    pandas object
        The result of reading back the pickled object.
    """
    target = f"__{rands(10)}__.pickle" if path is None else path
    with ensure_clean(target) as temp_path:
        pd.to_pickle(obj, temp_path)
        return pd.read_pickle(temp_path)
|
||||
|
||||
|
||||
def round_trip_pathlib(writer, reader, path: str | None = None):
    """
    Write an object to a file given as a pathlib.Path and read it back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv )
    reader : callable
        IO reading function (e.g. pd.read_csv )
    path : str, default None
        The path where the object is written and then read.

    Returns
    -------
    pandas object
        Whatever ``reader`` returns for the written file.
    """
    import pytest

    pathlib_module = pytest.importorskip("pathlib")
    target = "___pathlib___" if path is None else path
    with ensure_clean(target) as tmp_name:
        writer(pathlib_module.Path(tmp_name))
        return reader(pathlib_module.Path(tmp_name))
|
||||
|
||||
|
||||
def round_trip_localpath(writer, reader, path: str | None = None):
    """
    Write an object to a file given as a py.path LocalPath and read it back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv )
    reader : callable
        IO reading function (e.g. pd.read_csv )
    path : str, default None
        The path where the object is written and then read.

    Returns
    -------
    pandas object
        Whatever ``reader`` returns for the written file.
    """
    import pytest

    make_local = pytest.importorskip("py.path").local
    target = "___localpath___" if path is None else path
    with ensure_clean(target) as tmp_name:
        writer(make_local(tmp_name))
        return reader(make_local(tmp_name))
|
||||
|
||||
|
||||
def write_to_compressed(compression, path, data, dest="test"):
    """
    Write data to a compressed file.

    Parameters
    ----------
    compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd'}
        The compression type to use.
    path : str
        The file path to write the data.
    data : str
        The data to write.
    dest : str, default "test"
        The name of the archive member to write (for ZIP only).

    Raises
    ------
    ValueError : An invalid compression value was passed in.
    """
    if compression == "zip":
        # ZIP is the odd one out: text-style mode and a named archive member.
        with zipfile.ZipFile(path, mode="w") as archive:
            archive.writestr(dest, data)
        return

    opener: Callable
    if compression == "gzip":
        opener = gzip.GzipFile
    elif compression == "bz2":
        opener = bz2.BZ2File
    elif compression == "zstd":
        opener = import_optional_dependency("zstandard").open
    elif compression == "xz":
        opener = get_lzma_file()
    else:
        raise ValueError(f"Unrecognized compression type: {compression}")

    with opener(path, mode="wb") as stream:
        stream.write(data)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Plotting
|
||||
|
||||
|
||||
def close(fignum=None):
    """
    Close the given matplotlib figure, or every open figure if none is given.
    """
    from matplotlib.pyplot import (
        close as _close,
        get_fignums,
    )

    targets = get_fignums() if fignum is None else [fignum]
    for number in targets:
        _close(number)
|
||||
@@ -0,0 +1,48 @@
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def randbool(size=(), p: float = 0.5):
    """
    Draw uniform samples of the given shape and report which are ``<= p``.
    """
    draws = np.random.rand(*size)
    return draws <= p
|
||||
|
||||
|
||||
# Pool of single characters (ASCII letters and digits) for random strings.
RANDS_CHARS = np.array(list(string.ascii_letters + string.digits), dtype=(np.str_, 1))
# Pool of single non-ASCII characters (26 consecutive code points starting at
# 1488) plus the ASCII digits. ``np.str_`` is used instead of the
# ``np.unicode_`` alias, which was removed in NumPy 2.0.
RANDU_CHARS = np.array(
    list("".join(map(chr, range(1488, 1488 + 26))) + string.digits),
    dtype=(np.str_, 1),
)


def rands_array(nchars, size, dtype="O"):
    """
    Generate an array of random strings drawn from ``RANDS_CHARS``.

    Parameters
    ----------
    nchars : int
        Length of each generated string.
    size : int or tuple of int
        Shape of the resulting array.
    dtype : dtype, default "O"
        dtype the result is cast to.
    """
    retval = (
        np.random.choice(RANDS_CHARS, size=nchars * np.prod(size))
        # Reinterpret every run of ``nchars`` single characters as one string.
        .view((np.str_, nchars))
        .reshape(size)
    )
    return retval.astype(dtype)


def randu_array(nchars, size, dtype="O"):
    """
    Generate an array of random strings drawn from ``RANDU_CHARS``.

    Parameters
    ----------
    nchars : int
        Length of each generated string.
    size : int or tuple of int
        Shape of the resulting array.
    dtype : dtype, default "O"
        dtype the result is cast to.
    """
    retval = (
        np.random.choice(RANDU_CHARS, size=nchars * np.prod(size))
        # Reinterpret every run of ``nchars`` single characters as one string.
        .view((np.str_, nchars))
        .reshape(size)
    )
    return retval.astype(dtype)
|
||||
|
||||
|
||||
def rands(nchars):
    """
    Generate one random string of ``nchars`` characters from ``RANDS_CHARS``.

    See `rands_array` if you want to create an array of random strings.
    """
    picked = np.random.choice(RANDS_CHARS, nchars)
    return "".join(picked)
|
||||
@@ -0,0 +1,205 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import re
|
||||
import sys
|
||||
from typing import (
|
||||
Sequence,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
import warnings
|
||||
|
||||
|
||||
@contextmanager
def assert_produces_warning(
    expected_warning: type[Warning] | bool | None = Warning,
    filter_level="always",
    check_stacklevel: bool = True,
    raise_on_extra_warnings: bool = True,
    match: str | None = None,
):
    """
    Context manager for code that should emit a specific warning, or none.

    Checks that the managed code raises the expected warning and, optionally,
    that no other warnings are raised. It is basically a wrapper around
    ``warnings.catch_warnings``.

    Parameters
    ----------
    expected_warning : {Warning, False, None}, default Warning
        The warning class expected. ``Warning`` is the base class for all
        warnings. To check that no warning is emitted, specify ``False``
        or ``None``.
    filter_level : str or None, default "always"
        Specifies whether warnings are ignored, displayed, or turned
        into errors.
        Valid values are:

        * "error" - turns matching warnings into exceptions
        * "ignore" - discard the warning
        * "always" - always emit a warning
        * "default" - print the warning the first time it is generated
          from each location
        * "module" - print the warning the first time it is generated
          from each module
        * "once" - print the warning the first time it is generated

    check_stacklevel : bool, default True
        If True, displays the line that called the function containing
        the warning to show where the function is called. Otherwise, the
        line that implements the function is displayed.
    raise_on_extra_warnings : bool, default True
        Whether extra warnings not of the type `expected_warning` should
        cause the test to fail.
    match : str, optional
        Match warning message.

    Examples
    --------
    >>> import warnings
    >>> with assert_produces_warning():
    ...     warnings.warn(UserWarning())
    ...
    >>> with assert_produces_warning(False):
    ...     warnings.warn(RuntimeWarning())
    ...
    Traceback (most recent call last):
        ...
    AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
    >>> with assert_produces_warning(UserWarning):
    ...     warnings.warn(RuntimeWarning())
    Traceback (most recent call last):
        ...
    AssertionError: Did not see expected warning of class 'UserWarning'.

    .. warning:: This is *not* thread-safe.
    """
    __tracebackhide__ = True

    with warnings.catch_warnings(record=True) as recorded:
        warnings.simplefilter(filter_level)
        yield recorded

        # The checks below only run when the managed code finished without
        # raising.
        if expected_warning:
            expected_warning = cast(Type[Warning], expected_warning)
            _assert_caught_expected_warning(
                caught_warnings=recorded,
                expected_warning=expected_warning,
                match=match,
                check_stacklevel=check_stacklevel,
            )

        if raise_on_extra_warnings:
            _assert_caught_no_extra_warnings(
                caught_warnings=recorded,
                expected_warning=expected_warning,
            )
|
||||
|
||||
|
||||
def _assert_caught_expected_warning(
|
||||
*,
|
||||
caught_warnings: Sequence[warnings.WarningMessage],
|
||||
expected_warning: type[Warning],
|
||||
match: str | None,
|
||||
check_stacklevel: bool,
|
||||
) -> None:
|
||||
"""Assert that there was the expected warning among the caught warnings."""
|
||||
saw_warning = False
|
||||
matched_message = False
|
||||
unmatched_messages = []
|
||||
|
||||
for actual_warning in caught_warnings:
|
||||
if issubclass(actual_warning.category, expected_warning):
|
||||
saw_warning = True
|
||||
|
||||
if check_stacklevel and issubclass(
|
||||
actual_warning.category, (FutureWarning, DeprecationWarning)
|
||||
):
|
||||
_assert_raised_with_correct_stacklevel(actual_warning)
|
||||
|
||||
if match is not None:
|
||||
if re.search(match, str(actual_warning.message)):
|
||||
matched_message = True
|
||||
else:
|
||||
unmatched_messages.append(actual_warning.message)
|
||||
|
||||
if not saw_warning:
|
||||
raise AssertionError(
|
||||
f"Did not see expected warning of class "
|
||||
f"{repr(expected_warning.__name__)}"
|
||||
)
|
||||
|
||||
if match and not matched_message:
|
||||
raise AssertionError(
|
||||
f"Did not see warning {repr(expected_warning.__name__)} "
|
||||
f"matching '{match}'. The emitted warning messages are "
|
||||
f"{unmatched_messages}"
|
||||
)
|
||||
|
||||
|
||||
def _assert_caught_no_extra_warnings(
    *,
    caught_warnings: Sequence[warnings.WarningMessage],
    expected_warning: type[Warning] | bool | None,
) -> None:
    """
    Assert that no extra warnings apart from the expected ones are caught.

    Raises ``AssertionError`` listing the category name, message, filename and
    line number of every unexpected warning.
    """
    # GH 44732: Don't make the CI flaky by filtering SSL-related
    # ResourceWarning from dependencies
    unclosed_ssl = (
        "unclosed transport <asyncio.sslproto._SSLProtocolTransport",
        "unclosed <ssl.SSLSocket",
    )
    extra_warnings = []

    for caught in caught_warnings:
        if not _is_unexpected_warning(caught, expected_warning):
            continue

        # GH#38630 pytest.filterwarnings does not suppress these.
        if caught.category == ResourceWarning:
            text = str(caught.message)
            if any(fragment in text for fragment in unclosed_ssl):
                continue
            # GH 44844: Matplotlib leaves font files open during the entire
            # process upon import. Don't make CI flaky if ResourceWarning
            # raised due to these open files.
            if any("matplotlib" in mod for mod in sys.modules):
                continue

        extra_warnings.append(
            (
                caught.category.__name__,
                caught.message,
                caught.filename,
                caught.lineno,
            )
        )

    if extra_warnings:
        raise AssertionError(f"Caused unexpected warning(s): {repr(extra_warnings)}")
|
||||
|
||||
|
||||
def _is_unexpected_warning(
|
||||
actual_warning: warnings.WarningMessage,
|
||||
expected_warning: type[Warning] | bool | None,
|
||||
) -> bool:
|
||||
"""Check if the actual warning issued is unexpected."""
|
||||
if actual_warning and not expected_warning:
|
||||
return True
|
||||
expected_warning = cast(Type[Warning], expected_warning)
|
||||
return bool(not issubclass(actual_warning.category, expected_warning))
|
||||
|
||||
|
||||
def _assert_raised_with_correct_stacklevel(
    actual_warning: warnings.WarningMessage,
) -> None:
    """
    Assert that the warning is attributed to the file of the frame four levels
    above this function on the call stack.
    """
    import inspect

    # NOTE(review): the fixed depth of 4 presumably corresponds to the user
    # code inside ``assert_produces_warning`` - confirm against the call chain.
    caller = inspect.getframeinfo(inspect.stack()[4][0])
    msg = (
        "Warning not set with correct stacklevel. "
        f"File where warning is raised: {actual_warning.filename} != "
        f"{caller.filename}. Warning message: {actual_warning.message}"
    )
    assert actual_warning.filename == caller.filename, msg
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Helpers for sharing tests between DataFrame/Series
|
||||
"""
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
|
||||
def get_dtype(obj):
    """
    Return the dtype of ``obj``; for a DataFrame, the dtype of its first column.
    """
    if not isinstance(obj, DataFrame):
        return obj.dtype
    # Note: we are assuming only one column
    return obj.dtypes.iat[0]
|
||||
|
||||
|
||||
def get_obj(df: DataFrame, klass):
    """
    For sharing tests using frame_or_series, either return the DataFrame
    unchanged or return its first column as a Series.
    """
    return df if klass is DataFrame else df._ixs(0, axis=1)
|
||||
@@ -0,0 +1,244 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
from pathlib import Path
|
||||
import random
|
||||
from shutil import rmtree
|
||||
import string
|
||||
import tempfile
|
||||
from typing import (
|
||||
IO,
|
||||
Any,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pandas import set_option
|
||||
|
||||
from pandas.io.common import get_handle
|
||||
|
||||
|
||||
@contextmanager
def decompress_file(path, compression):
    """
    Open a compressed file for binary reading and yield the file object.

    Parameters
    ----------
    path : str
        The path where the file is read from.

    compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd', None}
        Name of the decompression to use

    Returns
    -------
    file object
    """
    io_handles = get_handle(path, "rb", compression=compression, is_text=False)
    with io_handles:
        yield io_handles.handle
|
||||
|
||||
|
||||
@contextmanager
def set_timezone(tz: str):
    """
    Context manager for temporarily setting a timezone.

    The ``TZ`` environment variable is set to ``tz`` for the duration of the
    block and restored (or removed, if it was not set before) afterwards.

    Parameters
    ----------
    tz : str
        A string representing a valid timezone.

    Examples
    --------
    >>> from datetime import datetime
    >>> from dateutil.tz import tzlocal
    >>> tzlocal().tzname(datetime(2021, 1, 1))  # doctest: +SKIP
    'IST'

    >>> with set_timezone('US/Eastern'):
    ...     tzlocal().tzname(datetime(2021, 1, 1))
    ...
    'EST'
    """
    import os
    import time

    def setTZ(tz):
        if tz is None:
            os.environ.pop("TZ", None)
        else:
            os.environ["TZ"] = tz
        # Re-read TZ in both cases, so that removing the variable on exit
        # actually switches the process back to its original zone.
        # ``time.tzset`` does not exist on every platform.
        if hasattr(time, "tzset"):
            time.tzset()

    orig_tz = os.environ.get("TZ")
    setTZ(tz)
    try:
        yield
    finally:
        setTZ(orig_tz)
|
||||
|
||||
|
||||
@contextmanager
def ensure_clean(filename=None, return_filelike: bool = False, **kwargs: Any):
    """
    Gets a temporary path and agrees to remove on close.

    This implementation does not use tempfile.mkstemp to avoid having a file handle.
    If the code using the returned path wants to delete the file itself, windows
    requires that no program has a file handle to it.

    Parameters
    ----------
    filename : str (optional)
        suffix of the created file; it is appended to 30 random characters.
    return_filelike : bool (default False)
        if True, returns a file-like which is *always* cleaned. Necessary for
        savefig and other functions which want to append extensions.
    **kwargs
        Additional keywords are passed to open(). ``mode`` defaults to
        ``"w+b"``. Only used when ``return_filelike`` is True.

    Yields
    ------
    str or file object
        The path of the temporary file as a string, or an open handle to it
        when ``return_filelike`` is True.
    """
    folder = Path(tempfile.gettempdir())

    if filename is None:
        filename = ""
    filename = (
        "".join(random.choices(string.ascii_letters + string.digits, k=30)) + filename
    )
    path = folder / filename

    path.touch()

    handle_or_str: str | IO = str(path)

    try:
        # Opening happens inside the ``try`` so that the file created above is
        # still removed if ``open`` fails (e.g. because of an invalid mode).
        if return_filelike:
            kwargs.setdefault("mode", "w+b")
            handle_or_str = open(path, **kwargs)
        yield handle_or_str
    finally:
        if not isinstance(handle_or_str, str):
            handle_or_str.close()
        if path.is_file():
            path.unlink()
|
||||
|
||||
|
||||
@contextmanager
def ensure_clean_dir():
    """
    Create a temporary directory and remove it, with its contents, on exit.

    Yields
    ------
    Temporary directory path
    """
    tmp_dir = tempfile.mkdtemp(suffix="")
    try:
        yield tmp_dir
    finally:
        try:
            rmtree(tmp_dir)
        except OSError:
            # Removal is best-effort: a failure to delete is ignored.
            pass
|
||||
|
||||
|
||||
@contextmanager
def ensure_safe_environment_variables():
    """
    Context manager that restores the environment variables on exit.

    A snapshot of ``os.environ`` is taken on entry and reinstated on close, so
    variables set, changed or removed inside the block do not persist.
    """
    snapshot = os.environ.copy()
    try:
        yield
    finally:
        os.environ.clear()
        os.environ.update(snapshot)
|
||||
|
||||
|
||||
@contextmanager
def with_csv_dialect(name, **kwargs):
    """
    Context manager to temporarily register a CSV dialect for parsing CSV.

    The dialect is registered on entry and unregistered again on exit.

    Parameters
    ----------
    name : str
        The name of the dialect.
    kwargs : mapping
        The parameters for the dialect.

    Raises
    ------
    ValueError : the name of the dialect conflicts with a builtin one.

    See Also
    --------
    csv : Python's CSV library.
    """
    import csv

    builtin_names = ("excel", "excel-tab", "unix")
    if name in builtin_names:
        raise ValueError("Cannot override builtin dialect.")

    csv.register_dialect(name, **kwargs)
    try:
        yield
    finally:
        csv.unregister_dialect(name)
|
||||
|
||||
|
||||
@contextmanager
def use_numexpr(use, min_elements=None):
    """
    Temporarily set the ``compute.use_numexpr`` option and, optionally, the
    minimum element count stored in ``expressions._MIN_ELEMENTS``.

    Both values are put back to what they were on exit.
    """
    from pandas.core.computation import expressions as expr

    previous_use = expr.USE_NUMEXPR
    previous_min = expr._MIN_ELEMENTS
    new_min = previous_min if min_elements is None else min_elements

    set_option("compute.use_numexpr", use)
    expr._MIN_ELEMENTS = new_min
    try:
        yield
    finally:
        expr._MIN_ELEMENTS = previous_min
        set_option("compute.use_numexpr", previous_use)
|
||||
|
||||
|
||||
class RNGContext:
    """
    Context manager to set the numpy random number generator seed. Returns
    to the original state upon exiting the context manager.

    Parameters
    ----------
    seed : int
        Seed for numpy.random.seed

    Examples
    --------
    with RNGContext(42):
        np.random.randn()
    """

    def __init__(self, seed):
        self.seed = seed

    def __enter__(self):
        # Remember the global generator state before reseeding it.
        saved_state = np.random.get_state()
        self.start_state = saved_state
        np.random.seed(self.seed)

    def __exit__(self, exc_type, exc_value, traceback):
        # Put the global generator back exactly where it was on entry.
        np.random.set_state(self.start_state)
|
||||
Reference in New Issue
Block a user