From 281d4665237f66ed1e162c6774acef08abbd2cb9 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Fri, 13 Sep 2024 14:00:39 -0700 Subject: [PATCH] SNOW-1660802 Implement dataframe groupby fillna --- CHANGELOG.md | 2 +- .../modin/supported/groupby_supported.rst | 3 +- .../snowpark/modin/pandas/groupby.py | 25 +- .../snowpark/modin/plugin/_internal/frame.py | 10 +- .../compiler/snowflake_query_compiler.py | 279 +++++++++++++++- .../modin/plugin/docstrings/groupby.py | 107 ++++++- .../modin/groupby/test_groupby_fillna.py | 297 ++++++++++++++++++ .../modin/groupby/test_groupby_first_last.py | 2 +- 8 files changed, 713 insertions(+), 12 deletions(-) create mode 100644 tests/integ/modin/groupby/test_groupby_fillna.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e0589d4a35..cc0603075c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ #### New Features - Added support for `TimedeltaIndex.mean` method. - +- Added support for `DataFrameGroupBy.fillna`. ## 1.22.1 (2024-09-11) This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. diff --git a/docs/source/modin/supported/groupby_supported.rst b/docs/source/modin/supported/groupby_supported.rst index 3bcf353821..1345d1b10b 100644 --- a/docs/source/modin/supported/groupby_supported.rst +++ b/docs/source/modin/supported/groupby_supported.rst @@ -106,7 +106,8 @@ Computations/descriptive stats +-----------------------------+---------------------------------+----------------------------------------------------+ | ``ffill`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``fillna`` | N | | +| ``fillna`` | P | GroupBy axis = 0 is supported. | +| | | Does not support ``downcast`` parameter | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``first`` | P | Does not support ``min_count`` parameter | +-----------------------------+---------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/pandas/groupby.py b/src/snowflake/snowpark/modin/pandas/groupby.py index 72ae6a2b00..4b2f8e0ff1 100644 --- a/src/snowflake/snowpark/modin/pandas/groupby.py +++ b/src/snowflake/snowpark/modin/pandas/groupby.py @@ -30,7 +30,7 @@ import pandas.core.groupby from modin.pandas import Series from pandas._libs.lib import NoDefault, no_default -from pandas._typing import AggFuncType, Axis, IndexLabel +from pandas._typing import AggFuncType, Axis, FillnaOptions, IndexLabel from pandas.core.dtypes.common import is_dict_like, is_list_like, is_numeric_dtype from pandas.errors import SpecificationError from pandas.io.formats.printing import PrettyDict @@ -992,9 +992,28 @@ def corr(self, **kwargs): # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions ErrorMessage.method_not_implemented_error(name="corr", class_="GroupBy") - def fillna(self, *args, **kwargs): + def fillna( + self, + value: Any = None, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, + ): # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions - ErrorMessage.method_not_implemented_error(name="fillna", class_="GroupBy") + query_compiler = self._query_compiler.groupby_fillna( + self._by, + self._axis, + self._kwargs, + value, + method, + axis, + inplace, + limit, + downcast, + ) + 
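+        # ``inplace`` is ignored here (the docstring documents it as "Ignored."): the
+        # query compiler builds a new frame, and a new DataFrame is always returned.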
+        return pd.DataFrame(query_compiler=query_compiler)
 
     def count(self):
         # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/frame.py b/src/snowflake/snowpark/modin/plugin/_internal/frame.py
index 25ca2fb8d2..3ea232f2df 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/frame.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/frame.py
@@ -554,6 +554,7 @@ def get_snowflake_quoted_identifiers_group_by_pandas_labels(
         self,
         pandas_labels: list[Hashable],
         include_index: bool = True,
+        include_data: bool = True,
     ) -> list[tuple[str, ...]]:
         """
         Map given pandas labels to names in underlying snowpark dataframe. Given labels can be data or index labels.
@@ -562,7 +563,8 @@
 
         Args:
             pandas_labels: A list of pandas labels.
-            include_index: Include the index columns in addition to data columns, default is True.
+            include_index: Include the index columns in the search; data columns are included only when ``include_data`` is True. Default is True.
+            include_data: Include the data columns in the search; index columns are included only when ``include_index`` is True. Default is True.
 
         Returns:
             A list of tuples for matched identifiers. Each element of list is a tuple of str containing matched
@@ -576,7 +578,11 @@
             filter(
                 lambda col: to_pandas_label(col.label) == label,
                 self.label_to_snowflake_quoted_identifier[
-                    (0 if include_index else self.num_index_columns) :
+                    (0 if include_index else self.num_index_columns) : (
+                        len(self.label_to_snowflake_quoted_identifier)
+                        if include_data
+                        else self.num_index_columns
+                    )
                 ],
             )
         )
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index 2f6ff69be6..7fc4713a43 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -12,6 +12,7 @@
 import uuid
 from collections.abc import Hashable, Iterable, Mapping, Sequence
 from datetime import timedelta, tzinfo
+from functools import reduce
 from typing import Any, Callable, List, Literal, Optional, Union, get_args
 
 import numpy as np
@@ -4100,7 +4101,7 @@ def groupby_apply(
         )
 
     def _fill_null_values_in_groupby(
-        self, method: str, by_list: list[str]
+        self, method: str, by_list: list[str], limit: Optional[int] = None
     ) -> dict[str, ColumnOrName]:
         """
         Fill null values in each column using method within each group.
@@ -4110,6 +4111,8 @@
             The method to use to fill null values.
         by_list: list[str]
             The list of columns to partition by during the fillna.
+        limit : int, optional
+            Maximum number of consecutive null values to fill.
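+            The limit applies within each group, since the window that implements
+            the fill is partitioned by ``by_list``.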
 
         Returns:
             dict: A mapping between column name and the Snowpark Column object with
                 null values filled.
         """
         method_is_ffill = method is FillNAMethod.FFILL_METHOD
         if method_is_ffill:
             func = last_value
-            window_start = Window.UNBOUNDED_PRECEDING
+            window_start = Window.UNBOUNDED_PRECEDING if limit is None else -1 * limit
             window_end = Window.CURRENT_ROW
         else:
             func = first_value
             window_start = Window.CURRENT_ROW
-            window_end = Window.UNBOUNDED_FOLLOWING
+            window_end = Window.UNBOUNDED_FOLLOWING if limit is None else limit
 
         return {
             snowflake_quoted_id: coalesce(
@@ -5327,6 +5330,276 @@ def groupby_value_counts(
             ignore_index=not as_index,  # When as_index=False, take the default positional index
         )
 
+    def groupby_fillna(
+        self,
+        by: Any,
+        axis: int,
+        groupby_kwargs: dict[str, Any],
+        value: Optional[
+            Union[Scalar, Hashable, Mapping, "pd.DataFrame", "pd.Series"]
+        ] = None,
+        method: Optional[FillnaOptions] = None,
+        fill_axis: Optional[int] = None,
+        inplace: bool = False,
+        limit: Optional[int] = None,
+        downcast: Optional[dict] = None,
+    ) -> "SnowflakeQueryCompiler":
+        """
+        Replace NaN values within each group using the provided method or value.
+
+        Args:
+            by: Used to determine the groups for the groupby.
+            axis: Group-by axis, 0 (index) or 1 (columns); currently only axis=0 is supported.
+            groupby_kwargs: Dict[str, Any]
+                Keyword arguments passed for the groupby.
+            value: Optional fill value.
+            method: Fill method, `ffill` or `bfill`; required if no `value` is specified.
+            fill_axis: Fill axis, 0 (index) or 1 (columns).
+            inplace: Not supported.
+            limit: Maximum number of consecutive NA values to fill.
+            downcast: Not supported.
+
+        Returns:
+            SnowflakeQueryCompiler: A new query compiler with NaN values filled using the method or value.
+        """
+        level = groupby_kwargs.get("level", None)
+
+        is_supported = check_is_groupby_supported_by_snowflake(
+            by=by, level=level, axis=axis
+        )
+        if not is_supported:
+            ErrorMessage.not_implemented(
+                f"Snowpark pandas GroupBy.fillna {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
+            )
+
+        if by is not None and not is_list_like(by):
+            by = [by]
+
+        if value is not None and method is not None:
+            raise ValueError("Cannot specify both 'value' and 'method'.")
+
+        if value is None and method is None:
+            raise ValueError("Must specify a fill 'value' or 'method'.")
+
+        if method is not None and method not in ["ffill", "bfill"]:
+            raise ValueError(
+                f"Invalid fill method. Expecting pad (ffill) or backfill (bfill). Got {method}"
+            )
+
+        if downcast:
+            ErrorMessage.not_implemented(
+                "Snowpark pandas fillna API doesn't yet support 'downcast' parameter"
+            )
+
+        if fill_axis is None:
+            fill_axis = 0
+
+        if level is not None:
+            by = extract_groupby_column_pandas_labels(self, by, level)
+
+        frame = self._modin_frame
+
+        data_column_group_keys = [
+            pandas_label
+            for pandas_label in frame.data_column_pandas_labels
+            if pandas_label in by
+        ]
+
+        data_column_group_keys_mask = [
+            pandas_label in data_column_group_keys
+            for pandas_label in frame.data_column_pandas_labels
+        ]
+
+        by_list_snowflake_quoted_identifiers: list[str]
+
+        # If any of a row's group-by key values are null, that row belongs to no group,
+        # so it must not be filled; enforce this through an expression.
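+        # This matches pandas: with the default dropna=True, rows whose group key is
+        # NaN are excluded from every group, and the result carries NaN for those rows.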
+        def groupby_null_expr(
+            col_expr: SnowparkColumn,
+            col_snowflake_quoted_identifier: str,
+            by_list_snowflake_quoted_identifiers: List[str],
+        ) -> SnowparkColumn:
+            return iff(
+                reduce(
+                    lambda b1, b2: b1 & b2,
+                    [
+                        (col(snowflake_quoted_identifier).is_not_null())
+                        for snowflake_quoted_identifier in by_list_snowflake_quoted_identifiers
+                    ],
+                ),
+                col_expr,
+                pandas_lit(None),
+            ).as_(col_snowflake_quoted_identifier)
+
+        # If no method is given, fill with the value; this is the same as DataFrame.fillna
+        # applied only to the non-group-key data columns.
+        if method is None:
+            qc = self._fillna_with_masking(
+                value=value,
+                self_is_series=False,
+                method=None,
+                axis=axis,
+                limit=limit,
+                downcast=downcast,
+                columns_mask=data_column_group_keys_mask,
+            )
+
+            frame = qc._modin_frame
+
+            # Resolve the group-by keys to Snowflake quoted identifiers.
+            by_list_snowflake_quoted_identifiers = [
+                snowflake_quoted_identifier[0]
+                for snowflake_quoted_identifier in frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
+                    by, include_index=True, include_data=True
+                )
+                if len(snowflake_quoted_identifier) > 0
+            ]
+
+            # Generate new snowflake quoted identifiers for output columns so they don't conflict with existing.
+            new_snowflake_quoted_identifiers = (
+                frame.ordered_dataframe.generate_snowflake_quoted_identifiers(
+                    pandas_labels=frame.data_column_pandas_labels
+                )
+            )
+
+            select_list = frame.index_column_snowflake_quoted_identifiers + [
+                groupby_null_expr(
+                    col(snowflake_quoted_identifier),
+                    new_snowflake_quoted_identifier,
+                    by_list_snowflake_quoted_identifiers,
+                )
+                for new_snowflake_quoted_identifier, snowflake_quoted_identifier in zip(
+                    new_snowflake_quoted_identifiers,
+                    frame.data_column_snowflake_quoted_identifiers,
+                )
+            ]
+        else:
+            # Resolve the group-by keys to Snowflake quoted identifiers.
+            by_list_snowflake_quoted_identifiers = [
+                snowflake_quoted_identifier[0]
+                for snowflake_quoted_identifier in frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
+                    by, include_index=True, include_data=True
+                )
+                if len(snowflake_quoted_identifier) > 0
+            ]
+
+            # Generate new snowflake quoted identifiers for output columns so they don't conflict with existing.
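+            # (The projection keeps the index columns under their original identifiers,
+            # so the aliased output columns must not collide with them or with the
+            # input data columns they read from.)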
+ new_snowflake_quoted_identifiers = ( + frame.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=frame.data_column_pandas_labels + ) + ) + + if fill_axis == 0: + columns_to_fillna_expr = self._fill_null_values_in_groupby( + method, by_list_snowflake_quoted_identifiers, limit + ) + + select_list = frame.index_column_snowflake_quoted_identifiers + [ + groupby_null_expr( + columns_to_fillna_expr[snowflake_quoted_identifier], + new_snowflake_quoted_identifier, + by_list_snowflake_quoted_identifiers, + ) + for new_snowflake_quoted_identifier, snowflake_quoted_identifier in zip( + new_snowflake_quoted_identifiers, + frame.data_column_snowflake_quoted_identifiers, + ) + ] + + elif fill_axis == 1: + coalesce_column_list: list[SnowparkColumn] = [] + select_list = [] + + data_column_pairs = list( + zip( + frame.data_column_snowflake_quoted_identifiers, + new_snowflake_quoted_identifiers, + data_column_group_keys_mask, + ) + ) + + if method == "bfill": + data_column_pairs.reverse() + + for ( + snowflake_quoted_identifier, + new_snowflake_quoted_identifier, + is_data_column_group_key, + ) in data_column_pairs: + if is_data_column_group_key: + select_list.append( + col(snowflake_quoted_identifier).as_( + new_snowflake_quoted_identifier + ) + ) + continue + + if len(coalesce_column_list) == 0: + select_item = col(snowflake_quoted_identifier) + else: + coalesce_expr = [ + snowflake_quoted_identifier + ] + coalesce_column_list + select_item = coalesce(*coalesce_expr) + + select_item = groupby_null_expr( + select_item, + new_snowflake_quoted_identifier, + by_list_snowflake_quoted_identifiers, + ) + + select_list.append(select_item) + coalesce_column_list.insert(0, col(snowflake_quoted_identifier)) + if limit is not None and len(coalesce_column_list) > limit: + del coalesce_column_list[-1] + + if method == "bfill": + select_list.reverse() + + select_list = ( + frame.index_column_snowflake_quoted_identifiers + select_list + ) + + new_ordered_dataframe = frame.ordered_dataframe.select(select_list) + + # If any group-by keys were original data (not index) columns, then we drop them in the final result. + # + # The methods ffill, bfill, pad and backfill of DataFrameGroupBy previously included the group labels in + # the return value, which was inconsistent with other groupby transforms. Now only the filled values + # are returned. 
+        # (GH 21521)
+        if len(data_column_group_keys) > 0:
+            data_column_pandas_labels, data_column_snowflake_quoted_identifiers = zip(
+                *[
+                    (pandas_label, snowflake_quoted_identifier)
+                    for pandas_label, snowflake_quoted_identifier in zip(
+                        frame.data_column_pandas_labels,
+                        new_snowflake_quoted_identifiers,
+                    )
+                    if pandas_label not in data_column_group_keys
+                ]
+            )
+        else:
+            data_column_pandas_labels, data_column_snowflake_quoted_identifiers = (
+                frame.data_column_pandas_labels,
+                new_snowflake_quoted_identifiers,
+            )
+
+        new_frame = InternalFrame.create(
+            ordered_dataframe=new_ordered_dataframe,
+            index_column_pandas_labels=frame.index_column_pandas_labels,
+            index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers,
+            data_column_pandas_labels=data_column_pandas_labels,
+            data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers,
+            data_column_pandas_index_names=frame.data_column_pandas_index_names,
+            data_column_types=None,
+            index_column_types=None,
+        )
+
+        return SnowflakeQueryCompiler(new_frame)
+
     def _get_dummies_helper(
         self,
         column: Hashable,
diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py b/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py
index a660546423..4331941429 100644
--- a/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py
+++ b/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py
@@ -1885,7 +1885,112 @@ def corr():
         pass
 
     def fillna():
-        pass
+        """
+        Fill NA/NaN values using the specified method within groups.
+
+        Parameters
+        ----------
+        value : scalar, dict, Series, or DataFrame
+            Value to use to fill holes (e.g. 0), alternately a dict/Series/DataFrame of values
+            specifying which value to use for each index (for a Series) or column (for a
+            DataFrame). Values not in the dict/Series/DataFrame will not be filled. This
+            value cannot be a list.
+
+        method : {'bfill', 'ffill', None}, default None
+            Method to use for filling holes. 'ffill' will propagate the last valid observation
+            forward within a group. 'bfill' will use the next valid observation to fill the gap.
+
+        axis : {0 or 'index', 1 or 'columns'}
+            Axis along which to fill missing values. When the DataFrameGroupBy axis
+            argument is 0, using axis=1 here will produce the same results as
+            DataFrame.fillna(). When the DataFrameGroupBy axis argument is 1, using
+            axis=0 or axis=1 here will produce the same results.
+
+        inplace : bool, default False
+            Ignored.
+
+        limit : int, default None
+            If method is specified, this is the maximum number of consecutive NaN values to
+            forward/backward fill within a group. In other words, if there is a gap with more than
+            this number of consecutive NaNs, it will only be partially filled. If method is not
+            specified, this is the maximum number of entries along the entire axis where NaNs
+            will be filled. Must be greater than 0 if not None.
+
+        downcast : dict, default None
+            A dict of item->dtype of what to downcast if possible, or the string 'infer' which will
+            try to downcast to an appropriate equal type (e.g. float64 to int64 if possible).
+
+            This parameter is not yet supported in Snowpark pandas.
+
+        Returns
+        -------
+        :class:`~modin.pandas.DataFrame`
+            Object with missing values filled.
+
+        Examples
+        --------
+        >>> df = pd.DataFrame(
+        ...     {
+        ...         "key": [0, 0, 1, 1, 1],
+        ...         "A": [np.nan, 2, np.nan, 3, np.nan],
+        ...         "B": [2, 3, np.nan, np.nan, np.nan],
+        ...         "C": [np.nan, np.nan, 2, np.nan, np.nan],
+        ...     }
+        ...
) + >>> df + key A B C + 0 0 NaN 2.0 NaN + 1 0 2.0 3.0 NaN + 2 1 NaN NaN 2.0 + 3 1 3.0 NaN NaN + 4 1 NaN NaN NaN + + Propagate non-null values forward or backward within each group along columns. + + >>> df.groupby("key").fillna(method="ffill") + A B C + 0 NaN 2.0 NaN + 1 2.0 3.0 NaN + 2 NaN NaN 2.0 + 3 3.0 NaN 2.0 + 4 3.0 NaN 2.0 + + >>> df.groupby("key").fillna(method="bfill") + A B C + 0 2.0 2.0 NaN + 1 2.0 3.0 NaN + 2 3.0 NaN 2.0 + 3 3.0 NaN NaN + 4 NaN NaN NaN + + Propagate non-null values forward or backward within each group along rows. + + >>> df.T.groupby(np.array([0, 0, 1, 1])).fillna(method="ffill").T + key A B C + 0 0.0 0.0 2.0 2.0 + 1 0.0 2.0 3.0 3.0 + 2 1.0 1.0 NaN 2.0 + 3 1.0 3.0 NaN NaN + 4 1.0 1.0 NaN NaN + + >>> df.T.groupby(np.array([0, 0, 1, 1])).fillna(method="bfill").T + key A B C + 0 0.0 NaN 2.0 NaN + 1 0.0 2.0 3.0 NaN + 2 1.0 NaN 2.0 2.0 + 3 1.0 3.0 NaN NaN + 4 1.0 NaN NaN NaN + + Only replace the first NaN element within a group along rows. + + >>> df.groupby("key").fillna(method="ffill", limit=1) + A B C + 0 NaN 2.0 NaN + 1 2.0 3.0 NaN + 2 NaN NaN 2.0 + 3 3.0 NaN 2.0 + 4 3.0 NaN NaN + """ def count(): """ diff --git a/tests/integ/modin/groupby/test_groupby_fillna.py b/tests/integ/modin/groupby/test_groupby_fillna.py new file mode 100644 index 0000000000..7b9ebeba0d --- /dev/null +++ b/tests/integ/modin/groupby/test_groupby_fillna.py @@ -0,0 +1,297 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +import modin.pandas as pd +import numpy as np +import pandas as native_pd +import pytest + +import snowflake.snowpark.modin.plugin # noqa: F401 +from snowflake.snowpark.exceptions import SnowparkSQLException +from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker +from tests.integ.modin.utils import eval_snowpark_pandas_result + +METHOD_OR_VALUES = [ + ("ffill", None), + ("bfill", None), + (None, 123), +] + +TEST_DF_DATA = { + "A": [None, 99, None, None, None, 98, 98, 98, None, 97], + "B": [88, None, None, None, 87, 88, 89, None, 86, None], + "C": [None, None, 1.99, 1.98, 1.97, 1.96, None, None, None, None], +} + +TEST_DF_INDEX_1 = native_pd.Index([0, 0, 0, 1, 1, 1, 1, 1, 2, 3], name="I") +TEST_DF_COLUMNS_1 = native_pd.Index(["A", "B", "C"], name="X") + +TEST_DF_DATA_2 = [ + [2, None, None, 99], + [2, 10, None, 98], + [2, None, 1.1, None], + [2, 15, None, 97], + [2, None, 1.1, None], + [1, None, 2.2, None], + [1, None, None, 96], + [1, None, 2.3], + [1, 20, 3.3, 95], + [2, None, None, 94], + [2, 30, None, None], + [2, None, 300, None], +] + +TEST_DF_INDEX_2 = pd.MultiIndex.from_tuples( + [ + (1, "a"), + (1, "a"), + (1, "a"), + (1, "a"), + (1, "a"), + (1, "b"), + (1, "b"), + (0, "a"), + (0, "a"), + (0, "a"), + (0, "a"), + (0, "a"), + ], + names=["I", "J"], +) +TEST_DF_COLUMNS_2 = pd.MultiIndex.from_tuples( + [(5, "A"), (5, "B"), (6, "C"), (6, "D")], names=["X", "Y"] +) + +TEST_DF_DATA_3 = ( + [[None, 100 + 10 * i, 200 + 10 * i] for i in range(6)] + + [[300 + 10 * i, None, 500 + 10 * i] for i in range(6)] + + [[400 + 10 * i, 600 + 10 * i, None] for i in range(6)] +) + +TEST_DF_INDEX_3 = native_pd.Index([50] * len(TEST_DF_DATA_3), name="I") +TEST_DF_COLUMNS_3 = native_pd.Index(["A", "B", "C"], name="X") + + +@pytest.mark.parametrize("method_or_value", METHOD_OR_VALUES) +@pytest.mark.parametrize("groupby_list", ["X", "key", ["X", "key"]]) +def test_groupby_fillna_basic(groupby_list, method_or_value): + method, value = method_or_value + native_df = native_pd.DataFrame( + { + "key": [0, 0, 1, 1, 1], + "A": [np.nan, 2, np.nan, 3, 
np.nan], + "B": [2, 3, np.nan, np.nan, np.nan], + "C": [np.nan, np.nan, 2, np.nan, np.nan], + }, + index=native_pd.Index(["A", "B", "C", "D", "E"], name="X"), + ) + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(groupby_list).fillna(method=method, value=value), + ) + + +@pytest.mark.parametrize("method_or_value", METHOD_OR_VALUES) +@pytest.mark.parametrize("fillna_axis", [0, 1]) +@pytest.mark.parametrize("by_list", ["I", 0, "A"]) +@sql_count_checker(query_count=1) +def test_groupby_fillna_single_index_ffill_bfill(method_or_value, by_list, fillna_axis): + method, value = method_or_value + native_df = native_pd.DataFrame( + TEST_DF_DATA, index=TEST_DF_INDEX_1, columns=TEST_DF_COLUMNS_1 + ) + snow_df = pd.DataFrame(native_df) + + if isinstance(by_list, int): + level = by_list + by_list = None + else: + level = None + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by_list, level=level).fillna( + method=method, value=value, axis=fillna_axis + ), + ) + + +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +@pytest.mark.parametrize("by_list", ["I", 0]) +@pytest.mark.parametrize("limit", [None, 1, 3, 5, 10]) +@sql_count_checker(query_count=1) +def test_groupby_fillna_ffill_bfill_with_limit_axis_0(method, by_list, limit): + native_df = native_pd.DataFrame( + TEST_DF_DATA_3, index=TEST_DF_INDEX_3, columns=TEST_DF_COLUMNS_3 + ) + snow_df = pd.DataFrame(native_df) + + if isinstance(by_list, int): + level = by_list + by_list = None + else: + level = None + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by_list, level=level).fillna( + method=method, axis=0, limit=limit + ), + ) + + +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +@pytest.mark.parametrize("by_list", ["X", 0]) +@pytest.mark.parametrize("limit", [None, 1, 3, 5, 10]) +@sql_count_checker(query_count=1) +def test_groupby_fillna_ffill_bfill_with_limit_axis_1(method, by_list, limit): + native_df = native_pd.DataFrame( + TEST_DF_DATA_3, index=TEST_DF_INDEX_3, columns=TEST_DF_COLUMNS_3 + ).T + snow_df = pd.DataFrame(native_df) + + if isinstance(by_list, int): + level = by_list + by_list = None + else: + level = None + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by_list, level=level).fillna( + method=method, axis=1, limit=limit + ), + ) + + +@pytest.mark.parametrize("method_or_value", METHOD_OR_VALUES) +@pytest.mark.parametrize("fillna_axis", [0, 1]) +@pytest.mark.parametrize("by", ["I", ["I", "J"]]) +@sql_count_checker(query_count=1) +def test_groupby_fillna_multiindex_ffill_bfill(method_or_value, fillna_axis, by): + method, value = method_or_value + native_df = native_pd.DataFrame( + TEST_DF_DATA_2, index=TEST_DF_INDEX_2, columns=TEST_DF_COLUMNS_2 + ) + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by=by, level=None, axis=0).fillna( + method=method, value=value, axis=fillna_axis + ), + ) + + +@pytest.mark.parametrize("method_or_value", METHOD_OR_VALUES) +@pytest.mark.parametrize("fillna_axis", [0, 1]) +@pytest.mark.parametrize("level", [0, 1, [0, 1]]) +@sql_count_checker(query_count=1) +def test_groupby_fillna_multiindex_ffill_bfill_with_level( + method_or_value, fillna_axis, level +): + method, value = method_or_value + native_df = native_pd.DataFrame( + TEST_DF_DATA_2, index=TEST_DF_INDEX_2, columns=TEST_DF_COLUMNS_2 + ) + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, 
+        native_df,
+        lambda df: df.groupby(by=None, level=level, axis=0).fillna(
+            method=method, value=value, axis=fillna_axis
+        ),
+    )
+
+
+@pytest.mark.parametrize("method_or_value", METHOD_OR_VALUES)
+@pytest.mark.parametrize("fillna_axis", [None, 1])
+@pytest.mark.parametrize(
+    "by_info", [(["I", "A"], 1), (["A"], 0), (["A", "B"], 1), (10, 0)]
+)
+def test_groupby_fillna_multiindex_ffill_bfill_negative(
+    method_or_value, fillna_axis, by_info
+):
+    method, value = method_or_value
+    by_list, expected_query_count = by_info
+    native_df = native_pd.DataFrame(
+        TEST_DF_DATA_2, index=TEST_DF_INDEX_2, columns=TEST_DF_COLUMNS_2
+    )
+    snow_df = pd.DataFrame(native_df)
+
+    if isinstance(by_list, int):
+        level = by_list
+        by_list = None
+    else:
+        level = None
+
+    with SqlCounter(query_count=expected_query_count):
+        eval_snowpark_pandas_result(
+            snow_df,
+            native_df,
+            lambda df: df.groupby(by_list, level=level, axis=0).fillna(
+                method=method, value=value, axis=fillna_axis
+            ),
+            expect_exception=True,
+            expect_exception_type=IndexError if level is not None else KeyError,
+        )
+
+
+@pytest.mark.parametrize(
+    "method_or_value", [("buzz", None), (None, None), ("ffill", 123)]
+)
+@sql_count_checker(query_count=0)
+def test_groupby_fillna_invalid_method_negative(method_or_value):
+    method, value = method_or_value
+    native_df = native_pd.DataFrame(
+        TEST_DF_DATA_2, index=TEST_DF_INDEX_2, columns=TEST_DF_COLUMNS_2
+    )
+    snow_df = pd.DataFrame(native_df)
+
+    eval_snowpark_pandas_result(
+        snow_df,
+        native_df,
+        lambda df: df.groupby(["I"], axis=0).fillna(method=method, value=value, axis=0),
+        expect_exception=True,
+        expect_exception_type=ValueError,
+    )
+
+
+@sql_count_checker(query_count=0)
+def test_groupby_fillna_value_not_type_compatible_negative():
+    native_df = native_pd.DataFrame(
+        TEST_DF_DATA, index=TEST_DF_INDEX_1, columns=TEST_DF_COLUMNS_1
+    )
+    snow_df = pd.DataFrame(native_df)
+
+    message = "Numeric value 'str' is not recognized"
+    # Native pandas is able to upcast the column to object type if the type of the fillna
+    # value is different from the column data type. However, in Snowpark pandas, we stay
+    # consistent with the Snowflake type system, and a SnowparkSQLException is raised if the type
+    # for the fillna value is not compatible with the column type.
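+    # Here the string fill value "str" cannot be coerced to the numeric column types,
+    # which surfaces as the "Numeric value 'str' is not recognized" error from Snowflake.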
+    with pytest.raises(SnowparkSQLException, match=message):
+        # call to_pandas to trigger the evaluation of the operation
+        snow_df.groupby("I").fillna(value="str").to_pandas()
+
+
+@sql_count_checker(query_count=0)
+def test_groupby_fillna_downcast_not_supported_negative():
+    native_df = native_pd.DataFrame(
+        TEST_DF_DATA, index=TEST_DF_INDEX_1, columns=TEST_DF_COLUMNS_1
+    )
+    snow_df = pd.DataFrame(native_df)
+
+    with pytest.raises(
+        NotImplementedError,
+        match="Snowpark pandas fillna API doesn't yet support 'downcast' parameter",
+    ):
+        # call to_pandas to trigger the evaluation of the operation
+        snow_df.groupby("I").fillna(method="ffill", downcast={"A": "str"}).to_pandas()
diff --git a/tests/integ/modin/groupby/test_groupby_first_last.py b/tests/integ/modin/groupby/test_groupby_first_last.py
index 5da35806dd..c580461b19 100644
--- a/tests/integ/modin/groupby/test_groupby_first_last.py
+++ b/tests/integ/modin/groupby/test_groupby_first_last.py
@@ -72,7 +72,7 @@ def test_groupby_first_last(by, as_index, skipna, method):
     # TODO: Add sort when SNOW-1481281 is resolved
     snowpark_pandas_df = pd.DataFrame(data_dictionary)
     pandas_df = snowpark_pandas_df.to_pandas()
-    with SqlCounter(query_count=1):
+    with SqlCounter(query_count=1):
         eval_snowpark_pandas_result(
             snowpark_pandas_df,
             pandas_df,