Make meta calculation for merge more efficient #284

Open
wants to merge 7 commits into base: main

Changes from 4 commits
43 changes: 26 additions & 17 deletions dask_expr/_merge.py
@@ -37,6 +37,7 @@ class Merge(Expr):
         "suffixes",
         "indicator",
         "shuffle_backend",
+        "_meta",

Member:

This makes me nervous. I think that it's a good principle to only have necessary operands. Any derived state should be computed.

For example, if an optimization were to change some parameter here, like suffixes or something, I wouldn't want to worry about also modifying meta at the same time. It's nice to be able to rely on this invariant across the project.

If we want to include some other state in a custom constructor I would be more ok with that (although still nervous). In that case I'd want to make sure that the constructor always preserves type(self)(*self.operands) == self.
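
For illustration, a minimal sketch of that invariant with a simplified stand-in class (not the actual dask-expr Expr):

class SimpleExpr:
    # Toy expression type: identity is fully determined by the operands.
    def __init__(self, *operands):
        self.operands = list(operands)

    def __eq__(self, other):
        return type(self) is type(other) and self.operands == other.operands

expr = SimpleExpr("left", "right", "inner")
# Rebuilding purely from operands must round-trip to an equal object:
assert type(expr)(*expr.operands) == expr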

Collaborator Author (@phofl):

The downside of this is that we are stuck with the repeated computation; caching objects and so on won't help here, since meta will change. We genuinely change the object when we re-create it, which means that we will always trigger a fresh computation of meta. That by itself isn't bad, but non-empty meta computations are relatively expensive (empty meta won't work here).
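
Concretely, a minimal sketch of the problem (toy classes, not the real ones): functools.cached_property caches per instance, so every optimizer rewrite that recreates the expression pays for meta again:

import functools

CALLS = 0

class CachedMerge:
    def __init__(self, *operands):
        self.operands = list(operands)

    @functools.cached_property
    def _meta(self):
        global CALLS
        CALLS += 1  # stand-in for the expensive non-empty meta computation
        return ("meta", *self.operands)

m = CachedMerge("left", "right")
m._meta, m._meta                 # computed once, then served from this instance
m2 = CachedMerge(*m.operands)    # an optimizer rewrite recreates the object...
m2._meta                         # ...and the expensive computation runs again
assert CALLS == 2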

> For example, if an optimization were to change some parameter here, like suffixes or something, I wouldn't want to worry about also modifying meta at the same time. It's nice to be able to rely on this invariant across the project.

We can simply not pass meta in this case, which would trigger a fresh computation; this is only a fast path.

Member:

> We genuinely change the object when we re-create it, which means that we will always trigger a fresh computation of meta.

If the inputs to the pd.merge call are going to change, then it seems like we need to recompute meta anyway. If the inputs aren't changing but we're recomputing, then maybe that is a sign that we're caching on the wrong things. Maybe we should have a staticmethod or something instead.

> We can simply not pass meta in this case, which would trigger a fresh computation; this is only a fast path.

Imagine an optimization which did something like the following:

def subs(obj, old, new):
    operands = [
        new if operand == old else operand  # substitute old for new in all operands
        for operand in obj.operands
    ]
    return type(obj)(*operands)

This fails if we store derived state in operands, because _meta is in there too, and after such a substitution the stored _meta is stale; really we should choose to not include it any more.

I'm ok keeping derived state on the class. I just don't think that it should be in operands. This probably requires custom constructors.
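
A minimal sketch of what such a custom constructor could look like (hypothetical, not this PR's implementation): derived state is accepted as a keyword argument and only seeds the cache, so it never appears in operands:

import functools

def expensive_meta(operands):
    return ("meta", *operands)  # stand-in for the real non-empty meta computation

class Merge:
    def __init__(self, *operands, _meta=None):
        self.operands = list(operands)
        if _meta is not None:
            # Seed the cache: functools.cached_property consults the
            # instance __dict__ before running the property body.
            self.__dict__["_meta"] = _meta

    @functools.cached_property
    def _meta(self):
        return expensive_meta(self.operands)

m = Merge("left", "right")
m2 = Merge(*m.operands, _meta=m._meta)  # reuse known meta on recreation
# The operands-only invariant still holds: nothing derived is stored there.
assert type(m2)(*m2.operands).operands == m2.operands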

Collaborator Author (@phofl):

There are optimisations where meta changes but we compute the necessary information anyway, so we can simply adjust meta if we want to; Projections are a good example of this.

Lowering is an example where meta won't change in the case of merge, but you can't properly cache it either: we might introduce a shuffle, which means caching will fail for most cases.

I'll add a custom constructor here; that should work as well. Then we can see how that looks.

Member:

> Lowering is an example where meta won't change in the case of merge, but you can't properly cache it either: we might introduce a shuffle, which means caching will fail for most cases.

What about caching not on the object, but somewhere else? That way anyone who asks "what is the meta for these inputs?" will get the cached result, regardless of which object they're calling from. (This was my staticmethod suggestion from above.)
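
A minimal sketch of that idea, assuming pandas meta objects; compute_merge_meta and _MERGE_META_CACHE are hypothetical names, not dask-expr API:

from dask.base import tokenize

_MERGE_META_CACHE = {}

def compute_merge_meta(left_meta, right_meta, **merge_kwargs):
    # Key on a deterministic token of the inputs rather than on the
    # expression instance, so equivalent-but-recreated expressions
    # share the same cache entry.
    key = tokenize(left_meta, right_meta, merge_kwargs)
    if key not in _MERGE_META_CACHE:
        _MERGE_META_CACHE[key] = left_meta.merge(right_meta, **merge_kwargs)
    return _MERGE_META_CACHE[key]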

Collaborator Author (@phofl), Aug 30, 2023:

Sorry, my comment probably wasn't clear enough. That's how I understood what you were saying, but the problem is as follows:

  • That won't work if we introduce a Projection (which is something we could accept, although it wouldn't make me happy in this particular case).
  • Meta won't change when we lower the graph, but we will introduce a shuffle, so the inputs of the merge computation will change while lowering, which means that the cache wouldn't work anymore.

Caching won't help in either of these two cases, which is unfortunately where most of the time is spent.

Member:

Feels like we're maybe talking past each other here. Happy to chat live if you're free later.

     ]
     _defaults = {
         "how": "inner",
@@ -47,6 +48,7 @@ class Merge(Expr):
         "suffixes": ("_x", "_y"),
         "indicator": False,
         "shuffle_backend": None,
+        "_meta": None,
     }
 
     def __str__(self):
@@ -69,6 +71,8 @@ def kwargs(self):
 
     @functools.cached_property
     def _meta(self):
+        if self.operand("_meta") is not None:
+            return self.operand("_meta")
         left = meta_nonempty(self.left._meta)
         right = meta_nonempty(self.right._meta)
         return make_meta(left.merge(right, **self.kwargs))
@@ -104,7 +108,7 @@ def _lower(self):
             or right.npartitions == 1
             and how in ("left", "inner")
         ):
-            return BlockwiseMerge(left, right, **self.kwargs)
+            return BlockwiseMerge(left, right, **self.kwargs, _meta=self._meta)
 
         # Check if we are merging on indices with known divisions
         merge_indexed_left = (
@@ -165,6 +169,7 @@ def _lower(self):
                 indicator=self.indicator,
                 left_index=left_index,
                 right_index=right_index,
+                _meta=self._meta,
             )
 
         if shuffle_left_on:
@@ -186,7 +191,7 @@ def _lower(self):
             )
 
         # Blockwise merge
-        return BlockwiseMerge(left, right, **self.kwargs)
+        return BlockwiseMerge(left, right, **self.kwargs, _meta=self._meta)
 
     def _simplify_up(self, parent):
         if isinstance(parent, (Projection, Index)):
@@ -203,13 +208,20 @@ def _simplify_up(self, parent):
                 projection = [projection]
 
             left, right = self.left, self.right
-            left_on, right_on = self.left_on, self.right_on
+            if isinstance(self.left_on, list):
+                left_on = self.left_on
+            else:
+                left_on = [self.left_on] if self.left_on is not None else []
+            if isinstance(self.right_on, list):
+                right_on = self.right_on
+            else:
+                right_on = [self.right_on] if self.right_on is not None else []
             left_suffix, right_suffix = self.suffixes[0], self.suffixes[1]
             project_left, project_right = [], []
 
             # Find columns to project on the left
             for col in left.columns:
-                if left_on is not None and col in left_on or col in projection:
+                if col in left_on or col in projection:
                     project_left.append(col)
                 elif f"{col}{left_suffix}" in projection:
                     project_left.append(col)
@@ -220,7 +232,7 @@ def _simplify_up(self, parent):
 
             # Find columns to project on the right
             for col in right.columns:
-                if right_on is not None and col in right_on or col in projection:
+                if col in right_on or col in projection:
                     project_right.append(col)
                 elif f"{col}{right_suffix}" in projection:
                     project_right.append(col)
@@ -232,8 +244,13 @@ def _simplify_up(self, parent):
             if set(project_left) < set(left.columns) or set(project_right) < set(
                 right.columns
             ):
+                columns = left_on + right_on + projection
+                meta_cols = [col for col in self.columns if col in columns]
                 result = type(self)(
-                    left[project_left], right[project_right], *self.operands[2:]
+                    left[project_left],
+                    right[project_right],
+                    *self.operands[2:-1],
+                    _meta=self._meta[meta_cols],
                 )
                 if parent_columns is None:
                     return type(parent)(result)
@@ -252,6 +269,7 @@ class HashJoinP2P(Merge, PartitionsFiltered):
         "suffixes",
         "indicator",
         "_partitions",
+        "_meta",
     ]
     _defaults = {
         "how": "inner",
@@ -262,24 +280,15 @@ class HashJoinP2P(Merge, PartitionsFiltered):
         "suffixes": ("_x", "_y"),
         "indicator": False,
         "_partitions": None,
+        "_meta": None,
     }
 
     def _lower(self):
         return None
 
     @functools.cached_property
     def _meta(self):
-        left = self.left._meta.drop(columns=_HASH_COLUMN_NAME)
-        right = self.right._meta.drop(columns=_HASH_COLUMN_NAME)
-        return left.merge(
-            right,
-            left_on=self.left_on,
-            right_on=self.right_on,
-            indicator=self.indicator,
-            suffixes=self.suffixes,
-            left_index=self.left_index,
-            right_index=self.right_index,
-        )
+        return self.operand("_meta")
 
     def _layer(self) -> dict:
         dsk = {}
2 changes: 1 addition & 1 deletion dask_expr/io/parquet.py
@@ -557,7 +557,7 @@ def _dataset_info(self):
 
         return dataset_info
 
-    @property
+    @cached_property
     def _meta(self):
         meta = self._dataset_info["meta"]
         if self._series:
12 changes: 12 additions & 0 deletions dask_expr/tests/test_merge.py
@@ -164,3 +164,15 @@ def test_merge_len():
     query = df.merge(df2).index.optimize(fuse=False)
     expected = df[["x"]].merge(df2[["x"]]).index.optimize(fuse=False)
     assert query._name == expected._name
+
+
+def test_merge_optimize_subset_strings():
+    pdf = lib.DataFrame({"a": [1, 2], "aaa": 1})
+    pdf2 = lib.DataFrame({"b": [1, 2], "aaa": 1})
+    df = from_pandas(pdf)
+    df2 = from_pandas(pdf2)
+
+    query = df.merge(df2, on="aaa")[["aaa"]].optimize(fuse=False)
+    exp = df[["aaa"]].merge(df2[["aaa"]], on="aaa").optimize(fuse=False)
+    assert query._name == exp._name
+    assert_eq(query, pdf.merge(pdf2, on="aaa")[["aaa"]])