[WIP] Support `dataframe.dtype_backend` globally by jrbourbeau · Pull Request #9883 · dask/dask

Open · wants to merge 3 commits into main
81 changes: 81 additions & 0 deletions dask/dataframe/core.py
@@ -19,6 +19,7 @@
 )
 from tlz import first, merge, partition_all, remove, unique

+import dask
 import dask.array as da
 from dask import core
 from dask.array.core import Array, normalize_arg
@@ -37,6 +38,7 @@
     PANDAS_GT_140,
     PANDAS_GT_150,
     PANDAS_GT_200,
+    PANDAS_VERSION,
     check_numeric_only_deprecation,
 )
 from dask.dataframe.accessor import CachedAccessor, DatetimeAccessor, StringAccessor
@@ -318,6 +320,42 @@ def _scalar_binary(op, self, other, inv=False):
     return Scalar(graph, name, meta)


+def _is_pyarrow_dtype(dtype):
+    return isinstance(dtype, pd.ArrowDtype) or dtype == pd.StringDtype("pyarrow")
+
+
+def _convert_to_pyarrow_dtype(dtype):
+    import pyarrow as pa
+    from pandas.core.arrays.arrow.array import to_pyarrow_type
+    from pandas.core.dtypes.dtypes import BaseMaskedDtype, PandasExtensionDtype
+
+    # Already a pyarrow-backed dtype
+    if _is_pyarrow_dtype(dtype):
+        return dtype
+
+    if isinstance(dtype, PandasExtensionDtype):
+        base_dtype = dtype.base
+    elif isinstance(dtype, BaseMaskedDtype):
+        base_dtype = dtype.numpy_dtype
+    elif isinstance(dtype, pd.StringDtype):
+        base_dtype = np.dtype(str)
+    else:
+        base_dtype = dtype
+
+    if base_dtype == object:
+        # Convert objects to strings
+        pa_type = pa.string()
+    else:
+        pa_type = to_pyarrow_type(base_dtype)
+    if pa_type is None:
+        raise TypeError(f"Encountered {dtype} which is not compatible with pyarrow")
+
+    if pa_type == pa.string():
+        return pd.StringDtype("pyarrow")
+    else:
+        return pd.ArrowDtype(pa_type)
+
+
 class _Frame(DaskMethodsMixin, OperatorMethodMixin):
     """Superclass for DataFrame and Series

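To make the mapping concrete, here is a sketch of what the new `_convert_to_pyarrow_dtype` helper returns for a few representative inputs (assuming pandas>=1.5 and pyarrow are installed; the helper is private to `dask.dataframe.core` as added above):

```python
import numpy as np
import pandas as pd

from dask.dataframe.core import _convert_to_pyarrow_dtype

# numpy dtypes map to the equivalent pyarrow-backed pandas dtype
_convert_to_pyarrow_dtype(np.dtype("int64"))   # -> int64[pyarrow] (pd.ArrowDtype)
# masked (nullable) dtypes are converted via their numpy counterpart
_convert_to_pyarrow_dtype(pd.Int64Dtype())     # -> int64[pyarrow] (pd.ArrowDtype)
# object dtype is treated as strings
_convert_to_pyarrow_dtype(np.dtype("object"))  # -> string[pyarrow] (pd.StringDtype)
```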
@@ -349,6 +387,49 @@ def __init__(self, dsk, name, meta, divisions):
         self._meta = meta
         self.divisions = tuple(divisions)

+        # Optionally cast to `pyarrow`-backed dtypes based on the
+        # `dataframe.dtype_backend` config option
+        if dask.config.get("dataframe.dtype_backend") == "pyarrow":
+            try:
+                import pyarrow  # noqa: F401
+            except ImportError:
+                raise RuntimeError(
+                    "Using dask's `dataframe.dtype_backend = 'pyarrow'` configuration "
+                    "option requires `pyarrow` to be installed."
+                )
+            if not PANDAS_GT_150:
+                raise RuntimeError(
+                    "Using dask's `dataframe.dtype_backend = 'pyarrow'` configuration "
+                    "option requires pandas>=1.5.0 to be installed. "
+                    f"pandas={PANDAS_VERSION} is currently installed."
+                )
+
+            # Check whether or not all dtypes are already pyarrow-compatible.
+            # This avoids infinite recursion.
+            if (
+                (
+                    is_dataframe_like(meta)
+                    and not all(_is_pyarrow_dtype(dt) for dt in meta.dtypes)
+                )
+                or (is_series_like(meta) and not _is_pyarrow_dtype(meta.dtype))
+                or (is_index_like(meta) and not _is_pyarrow_dtype(meta.dtype))
+            ):
+
+                def to_pyarrow_dtypes(df):
+                    if is_dataframe_like(df):
+                        dtypes = {
+                            col: _convert_to_pyarrow_dtype(df[col].dtype) for col in df
+                        }
+                    else:
+                        dtypes = _convert_to_pyarrow_dtype(df.dtype)
+                    return df.astype(dtypes)
+
+                result = self.map_partitions(to_pyarrow_dtypes)
+                self.dask = result.dask
+                self._name = result._name
+                self._meta = result._meta
+                self.divisions = result.divisions
+
     def __dask_graph__(self):
         return self.dask

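In practice, the new `__init__` hook means any collection created while the option is active comes back with pyarrow-backed dtypes automatically, with the cast applied via `map_partitions`. A minimal sketch, mirroring the test further down (assumes pandas>=1.5 and pyarrow are installed):

```python
import pandas as pd
import dask
import dask.dataframe as dd

df = pd.DataFrame(
    {
        "x": [1, 2, 3],        # int64   -> "int64[pyarrow]"
        "y": [1.5, 2.5, 3.5],  # float64 -> "float64[pyarrow]"
        "z": ["a", "b", "c"],  # object  -> "string[pyarrow]"
    }
)

with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
    ddf = dd.from_pandas(df, npartitions=2)

print(ddf.dtypes)  # all pyarrow-backed
```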
5 changes: 3 additions & 2 deletions dask/dataframe/io/parquet/core.py
@@ -397,8 +397,9 @@ def read_parquet(
     pyarrow.parquet.ParquetDataset
     """

-    if use_nullable_dtypes:
-        use_nullable_dtypes = dask.config.get("dataframe.dtype_backend")
+    dtype_backend = dask.config.get("dataframe.dtype_backend")
+    if dtype_backend == "pyarrow" or use_nullable_dtypes:
+        use_nullable_dtypes = dtype_backend

     # "Pre-deprecation" warning for `chunksize`
     if chunksize:
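The effect on `read_parquet`: with the global option set, pyarrow-backed dtypes are requested even when `use_nullable_dtypes` is not passed explicitly. A sketch of the intended usage (the dataset path below is hypothetical):

```python
import dask
import dask.dataframe as dd

with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
    # use_nullable_dtypes is now picked up from the config automatically
    ddf = dd.read_parquet("s3://bucket/dataset.parquet")  # hypothetical path
```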
40 changes: 40 additions & 0 deletions dask/dataframe/io/tests/test_io.py
@@ -4,6 +4,7 @@
 import pandas as pd
 import pytest

+import dask
 import dask.array as da
 import dask.dataframe as dd
 from dask import config
@@ -274,6 +275,45 @@ def test_from_pandas_npartitions_duplicates(index):
     assert ddf.divisions == ("A", "B", "C", "C")


+def test_from_pandas_dtype_backend_config():
+    pytest.importorskip(
+        "pandas",
+        minversion="1.5.0",
+        reason="Requires support for pyarrow-backed dtypes",
+    )
+    pytest.importorskip("pyarrow", reason="Requires pyarrow")
+
+    # `dataframe.dtype_backend` defaults to normal `numpy`-backed dtypes.
+    # This matches what `pandas` does by default.
+    s = pd.Series([1, 2, 3, 4])
+    df = pd.DataFrame(
+        {
+            "x": [1, 2, 3, 4],
+            "y": [5.0, 6.0, 7.0, 8.0],
+            "z": ["foo", "bar", "ricky", "bobby"],
+        }
+    )
+
+    ds = dd.from_pandas(s, npartitions=2)
+    ddf = dd.from_pandas(df, npartitions=2)
+
+    assert_eq(s, ds)
+    assert_eq(df, ddf)
+
+    # When `dataframe.dtype_backend = "pyarrow"`, dask should automatically
+    # cast to `pyarrow`-backed dtypes
+    with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
+        ds = dd.from_pandas(s, npartitions=2)
+        ddf = dd.from_pandas(df, npartitions=2)
+
+    s_pyarrow = s.astype("int64[pyarrow]")
+    df_pyarrow = df.astype(
+        {"x": "int64[pyarrow]", "y": "float64[pyarrow]", "z": "string[pyarrow]"}
+    )
+    assert_eq(s_pyarrow, ds)
+    assert_eq(df_pyarrow, ddf)
+
+
 @pytest.mark.gpu
 def test_gpu_from_pandas_npartitions_duplicates():
     cudf = pytest.importorskip("cudf")
3 changes: 3 additions & 0 deletions dask/dataframe/utils.py
@@ -707,6 +707,9 @@ def valid_divisions(divisions):
     if not isinstance(divisions, (tuple, list)):
         return False

+    if pd.isna(divisions).any():
+        return False
+
     for i, x in enumerate(divisions[:-2]):
         if x >= divisions[i + 1]:
             return False
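And a quick sketch of the new `valid_divisions` guard (the helper lives in `dask.dataframe.utils`; behavior of the third call relies on the existing monotonicity check):

```python
import numpy as np

from dask.dataframe.utils import valid_divisions

valid_divisions([1, 2, 3])       # True: monotonically increasing, no missing values
valid_divisions([1, np.nan, 3])  # False: rejected by the new pd.isna check
valid_divisions([3, 2, 1])       # False: not monotonically increasing
```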