-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SNOW-1649172]: Fix `loc` set when setting DataFrame row with Series value
#2213
base: main
Are you sure you want to change the base?
Changes from 13 commits
3151ed7
2bd792f
9e2a26d
66f01bc
c3b9582
a9aceb9
c18ae1f
c159e3a
2289960
89401a8
25ccbb9
8f75bec
f8797d8
c252eb5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1889,6 +1889,7 @@ def _set_2d_labels_helper_for_frame_item( | |
assert len(index.data_column_snowflake_quoted_identifiers) == len( | ||
item.index_column_snowflake_quoted_identifiers | ||
), "TODO: SNOW-966427 handle it well in multiindex case" | ||
|
||
if not matching_item_rows_by_label: | ||
index = index.ensure_row_position_column() | ||
left_on = [index.row_position_snowflake_quoted_identifier] | ||
|
@@ -2125,6 +2126,7 @@ def set_frame_2d_labels( | |
matching_item_rows_by_label: bool, | ||
index_is_bool_indexer: bool, | ||
deduplicate_columns: bool, | ||
frame_is_df_and_item_is_series: bool, | ||
) -> InternalFrame: | ||
""" | ||
Helper function to handle the general loc set functionality. The general idea here is to join the key from ``index`` | ||
|
@@ -2151,6 +2153,7 @@ def set_frame_2d_labels( | |
index_is_bool_indexer: if True, the index is a boolean indexer. Note we only handle boolean indexer with | ||
item is a SnowflakeQueryCompiler here. | ||
deduplicate_columns: if True, deduplicate columns from ``columns``. | ||
frame_is_df_and_item_is_series: Whether item is from a Series object and is being assigned to a DataFrame object | ||
Returns: | ||
New frame where values have been set | ||
""" | ||
|
@@ -2213,6 +2216,36 @@ def set_frame_2d_labels( | |
index_is_frame = isinstance(index, InternalFrame) | ||
item_is_frame = isinstance(item, InternalFrame) | ||
item_is_scalar = is_scalar(item) | ||
original_index = index | ||
# If `item` is from a Series (rather than a DataFrame), flip the Series item values to apply them | ||
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# across columns rather than rows. | ||
if frame_is_df_and_item_is_series and (columns == slice(None) or len(columns) > 1): # type: ignore[arg-type] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you wrap this in a function and use the function name to summarize what it does? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What does this mean? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. this There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be done in |
||
# If columns is slice(None), we are setting all columns in the InternalFrame. | ||
matching_item_columns_by_label = True | ||
col_len = ( | ||
len(internal_frame.data_column_snowflake_quoted_identifiers) | ||
if columns == slice(None) | ||
else len(columns) # type: ignore[arg-type] | ||
) | ||
item = get_item_series_as_single_row_frame( | ||
item, col_len, move_index_to_cols=True | ||
) | ||
|
||
if is_scalar(index): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What happens if index is not a scalar? |
||
new_item = item.append_column("__index__", pandas_lit(index)) | ||
item = InternalFrame.create( | ||
ordered_dataframe=new_item.ordered_dataframe, | ||
data_column_pandas_labels=item.data_column_pandas_labels, | ||
data_column_snowflake_quoted_identifiers=item.data_column_snowflake_quoted_identifiers, | ||
data_column_pandas_index_names=item.data_column_pandas_index_names, | ||
index_column_pandas_labels=item.index_column_pandas_labels, | ||
index_column_snowflake_quoted_identifiers=[ | ||
new_item.data_column_snowflake_quoted_identifiers[-1] | ||
], | ||
data_column_types=item.cached_data_column_snowpark_pandas_types, | ||
index_column_types=[item.cached_data_column_snowpark_pandas_types[-1]], | ||
) | ||
index = pd.Series([index])._query_compiler._modin_frame | ||
|
||
assert not isinstance(index, slice) or index == slice( | ||
None | ||
|
@@ -2409,7 +2442,7 @@ def generate_updated_expr_for_existing_col( | |
|
||
if index_is_scalar: | ||
col_obj = iff( | ||
result_frame_index_col.equal_null(pandas_lit(index)), | ||
result_frame_index_col.equal_null(pandas_lit(original_index)), | ||
col_obj, | ||
original_col, | ||
) | ||
|
@@ -2468,7 +2501,7 @@ def generate_updated_expr_for_new_col( | |
return SnowparkPandasColumn(pandas_lit(None), None) | ||
if index_is_scalar: | ||
new_column = iff( | ||
result_frame_index_col.equal_null(pandas_lit(index)), | ||
result_frame_index_col.equal_null(pandas_lit(original_index)), | ||
new_column, | ||
pandas_lit(None), | ||
) | ||
|
@@ -2604,7 +2637,6 @@ def set_frame_2d_positional( | |
index = _get_adjusted_key_frame_by_row_pos_int_frame(internal_frame, index) | ||
|
||
assert isinstance(index_data_type, (_IntegralType, BooleanType)) | ||
|
||
if isinstance(item, InternalFrame): | ||
# If item is a Series (rather than a DataFrame), then we need to flip the Series item values so they apply across | ||
# columns rather than rows. | ||
|
@@ -2918,7 +2950,9 @@ def get_kv_frame_from_index_and_item_frames( | |
|
||
|
||
def get_item_series_as_single_row_frame( | ||
item: InternalFrame, num_columns: int | ||
item: InternalFrame, | ||
num_columns: int, | ||
move_index_to_cols: Optional[bool] = False, | ||
) -> InternalFrame: | ||
""" | ||
Get an internal frame that transposes a single data column into a frame with a single row. For example, if the | ||
|
@@ -2940,13 +2974,18 @@ def get_item_series_as_single_row_frame( | |
---------- | ||
num_columns: Number of columns in the return frame | ||
item: Item frame that contains a single column of values. | ||
move_index_to_cols: Whether to use the index as the column names. | ||
|
||
Returns | ||
------- | ||
Frame containing single row with columns for each row. | ||
""" | ||
item = item.ensure_row_position_column() | ||
item_series_pandas_labels = list(range(num_columns)) | ||
item_series_pandas_labels = ( | ||
list(range(num_columns)) | ||
if not move_index_to_cols | ||
else item.index_columns_pandas_index().values | ||
) | ||
|
||
# This is a 2 step process. | ||
# | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1830,6 +1830,15 @@ def loc(): | |
viper 0 0 | ||
sidewinder 0 0 | ||
|
||
Setting the values with a Series item. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @sfc-gh-helmeleegy, this is the example I added. |
||
|
||
>>> df.loc["viper"] = pd.Series([99, 99], index=["max_speed", "shield"]) | ||
>>> df | ||
max_speed shield | ||
cobra 30 10 | ||
viper 99 99 | ||
sidewinder 0 0 | ||
|
||
**Getting values on a DataFrame with an index that has integer labels** | ||
|
||
Another example using integers for the index | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3205,3 +3205,68 @@ def test_raise_set_cell_with_list_like_value_error(): | |
s.iloc[0] = [0, 0] | ||
with pytest.raises(NotImplementedError): | ||
s.to_frame().iloc[0, 0] = [0, 0] | ||
|
||
|
||
@sql_count_checker(query_count=1, join_count=3) | ||
@pytest.mark.parametrize("index", [list("ABC"), [0, 1, 2]]) | ||
def test_df_iloc_set_row_from_series(index): | ||
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=list("ABC")) | ||
snow_df = pd.DataFrame(native_df) | ||
|
||
def ilocset(df): | ||
series = ( | ||
pd.Series([1, 4, 9], index=index) | ||
if isinstance(df, pd.DataFrame) | ||
else native_pd.Series([1, 4, 9], index=index) | ||
) | ||
df.iloc[1] = series | ||
return df | ||
|
||
eval_snowpark_pandas_result( | ||
snow_df, | ||
native_df, | ||
ilocset, | ||
) | ||
|
||
|
||
@sql_count_checker(query_count=1, join_count=3) | ||
@pytest.mark.parametrize("index", [[3, 4, 5], [0, 1, 2]]) | ||
def test_df_iloc_full_set_row_from_series(index): | ||
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]]) | ||
snow_df = pd.DataFrame(native_df) | ||
|
||
def ilocset(df): | ||
series = ( | ||
pd.Series([1, 4, 9], index=index) | ||
if isinstance(df, pd.DataFrame) | ||
else native_pd.Series([1, 4, 9], index=index) | ||
) | ||
df.iloc[:] = series | ||
return df | ||
|
||
eval_snowpark_pandas_result( | ||
snow_df, | ||
native_df, | ||
ilocset, | ||
) | ||
|
||
|
||
@sql_count_checker(query_count=1, join_count=3) | ||
def test_df_iloc_full_set_row_from_series_int_and_string_indexes(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can combine this test with the previous one. |
||
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=list("ABC")) | ||
snow_df = pd.DataFrame(native_df) | ||
|
||
def ilocset(df): | ||
series = ( | ||
pd.Series([1, 4, 9], index=list("ABC")) | ||
if isinstance(df, pd.DataFrame) | ||
else native_pd.Series([1, 4, 9], index=list("ABC")) | ||
) | ||
df.iloc[:] = series | ||
return df | ||
|
||
eval_snowpark_pandas_result( | ||
snow_df, | ||
native_df, | ||
ilocset, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4068,6 +4068,74 @@ def test_df_loc_get_with_timedelta_and_none_key(): | |
assert_frame_equal(snow_df.loc[None], expected_df, check_column_type=False) | ||
|
||
|
||
@sql_count_checker(query_count=2, join_count=4) | ||
@pytest.mark.parametrize("index", [list("ABC"), [0, 1, 2]]) | ||
def test_df_loc_set_row_from_series(index): | ||
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=list("ABC")) | ||
snow_df = pd.DataFrame(native_df) | ||
|
||
def locset(df): | ||
series = ( | ||
pd.Series([1, 4, 9], index=index) | ||
if isinstance(df, pd.DataFrame) | ||
else native_pd.Series([1, 4, 9], index=index) | ||
) | ||
df.loc[1] = series | ||
return df | ||
|
||
eval_snowpark_pandas_result( | ||
snow_df, | ||
native_df, | ||
locset, | ||
) | ||
|
||
|
||
@sql_count_checker(query_count=2, join_count=1) | ||
@pytest.mark.parametrize("index", [[3, 4, 5], [0, 1, 2]]) | ||
def test_df_loc_full_set_row_from_series_pandas_errors(index): | ||
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]]) | ||
snow_df = pd.DataFrame(native_df) | ||
|
||
with pytest.raises(ValueError, match="setting an array element with a sequence."): | ||
native_df.loc[:] = native_pd.Series([1, 4, 9], index=index) | ||
|
||
def locset(df): | ||
series = ( | ||
pd.Series([1, 4, 9], index=index) | ||
if isinstance(df, pd.DataFrame) | ||
else native_pd.Series([1, 4, 9], index=index) | ||
) | ||
if isinstance(df, pd.DataFrame): | ||
df.loc[:] = series | ||
else: | ||
if index == [0, 1, 2]: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be better to just compare the result with the expected result; these steps can be confusing. |
||
df.loc[0] = series | ||
df.loc[1] = None | ||
else: | ||
df.loc[[0, 1]] = None | ||
return df | ||
|
||
eval_snowpark_pandas_result( | ||
snow_df, | ||
native_df, | ||
locset, | ||
) | ||
|
||
|
||
@sql_count_checker(query_count=1) | ||
def test_df_loc_full_set_row_from_series_errors(): | ||
# We error here because our join columns are an int (item.index) | ||
# and a string (value.index) column respectively, and we do not | ||
# support joins between those. | ||
snow_df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=list("ABC")) | ||
|
||
with pytest.raises( | ||
SnowparkSQLException, match="Numeric value 'A' is not recognized" | ||
): | ||
snow_df.loc[:] = pd.Series([1, 4, 9], index=list("ABC")) | ||
snow_df.to_pandas() # Force materialization. | ||
|
||
|
||
@sql_count_checker(query_count=0) | ||
def test_df_loc_invalid_key(): | ||
# Bug fix: SNOW-1320674 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens when item is an Index or a list? Does the behavior match pandas? Can you verify that too?