Buffer aliasing behavior in the presence of control flow #16793

Open
gspschmid opened this issue Sep 4, 2024 · 8 comments
@gspschmid
Contributor

Async operations require a mechanism to keep buffers alive from op-start until the corresponding op-done. A natural way to express the intent to keep, say, the input buffer x to send-start(x) alive is to make send-start "pass the input buffer through" by instructing XLA to alias its first input with its first output, and then passing the aliased output to send-done. Async HLO Instructions in fact leverage this very idea. In pseudocode:

x' = send-start(x)  # send-start aliases input 0 with output 0
# other ops
y = send-done(x')  # original x should be alive until now, allowing send to keep using the buffer

Can we actually rely on this pattern working even when arbitrary aliasing operations are applied to x' before it eventually flows into the matching done op (send-done above)?

I've been experimenting with custom calls that mimic this pattern (but are not known to XLA to be asynchronous) and encountered a situation where this assumption breaks down. Consider the following example:

import jax
import jax.numpy as jnp

def show(x):
  # a primitive that aliases x and prints the address of x
  (...)

@jax.jit
def f_good(x):
  x = show(x)
  y = jax.lax.cond(x != 999, (lambda: x), (lambda: x))
  return show(y)

@jax.jit
def f_bad(x):
  x = show(x)
  y = jax.lax.cond(x != 999, (lambda: x), (lambda: jnp.zeros_like(x)))
  return show(y)

def example():
  print('f_good:')
  f_good(123)

  print('\nf_bad:')
  f_bad(123)

which produces

f_good:
SHOW in=0x7f294a000100
SHOW in=0x7f294a000100

f_bad:
SHOW in=0x7f294a000280
SHOW in=0x7f294a000100

Notably, in the case of f_bad we deal with two distinct buffers, i.e. the input is implicitly copied at some point, breaking our initial assumption.

Is this expected behavior or a bug in bufferization/copying? I haven't investigated, but perhaps XLA-native async ops rely on special treatment during live range computation (see the XLA source comment "// If the instruction is in an asynchronous context, extend the live range")?

Reproducer: https://gist.github.com/gspschmid/372bba804b48c4abbf5c94f19b2b32cd
HLO f_good (before optimizations)
HLO f_bad (before optimizations)

@gspschmid
Contributor Author

cc @ezhulenev @frgossen @mattjj @nouiz

@frgossen
Member

frgossen commented Sep 4, 2024

> Is this expected behavior or a bug in bufferization/copying? I haven't investigated, but perhaps XLA-native async ops rely on special treatment during live range computation [...]?

Today, I would say this is expected behaviour in XLA. Async ops do not get special treatment yet but they do require that the buffer passed from async start to async done does not have any other uses. That way, you can rely on there not being a copy.

The pass you'd want to look at is copy insertion. Especially around while and cond ops, it inserts copies and then tries to remove them, which it will only do if it can prove that the live ranges do not overlap.
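To make the "remove a copy only if live ranges do not overlap" rule concrete, here is a toy model in plain Python. It is only a sketch of the idea, not XLA's copy insertion pass: the `LiveRange` type, the integer schedule positions, and `can_share_buffer` are all hypothetical names introduced for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LiveRange:
    """Hypothetical live range over integer schedule positions."""
    start: int  # position where the value is defined
    end: int    # position of the value's last use


def overlaps(a: LiveRange, b: LiveRange) -> bool:
    # Two ranges interfere if each starts strictly before the other ends.
    # Ranges that merely touch (a's last use defines b) do not interfere.
    return a.start < b.end and b.start < a.end


def can_share_buffer(src: LiveRange, dst: LiveRange) -> bool:
    # In this toy model, a copy from src to dst may be dropped (so that
    # both values live in one buffer) only if the ranges never overlap.
    return not overlaps(src, dst)


# x dies exactly where y is defined: the copy could be removed.
print(can_share_buffer(LiveRange(0, 3), LiveRange(3, 8)))
# x is still needed after y is defined: the copy has to stay.
print(can_share_buffer(LiveRange(0, 5), LiveRange(3, 8)))
```

In this model, any additional use of the original buffer after the aliased value is defined is enough to keep the copy.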

@gspschmid
Contributor Author

gspschmid commented Sep 5, 2024

Does that mean that XLA cannot take an async op in a conditional branch/loop and move the async-done outside the branch/loop body? Or is it that it can, but possibly incorrectly so (given the chance of copies)?

> Async ops do not get special treatment yet but they do require that the buffer passed from async start to async done does not have any other uses. That way, you can rely on there not being a copy.

That makes me wonder whether there is actually any JAX-surface-level (or at least StableHLO) pattern that would guarantee a buffer being kept alive. In the presence of earlier program transformations and, say, host offloading, even relying on non-copying for straightline code seems like a tenuous assumption.

@frgossen
Member

frgossen commented Sep 5, 2024

> Does that mean that XLA cannot take an async op in a conditional branch/loop and move the async-done outside the branch/loop body? Or is that it does, but possibly incorrectly so (given the possibility of copies)?

It does not support that today but I'm working on it for loops.

> that would guarantee a buffer being kept alive

I don't think there is at the moment, other than no loops/conds plus single-use chains. Is that what you mean by "straightline code"?

@ezhulenev
Member

I think we need special treatment for async ops: we should be able to prove with dataflow analysis that an async start is actually consumed by the corresponding done operation, with nothing in between. I suspect the same issue will show up in pipeline partitioning and in reordering send/recv start and done.
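As a rough illustration of that kind of check, the sketch below walks a tiny, made-up instruction list and verifies that every `*-start` result has exactly one user and that this user is the matching `*-done`. The tuple encoding `(result, kind, operands)` and the function name are assumptions for this example only; XLA's real dataflow analysis works on HLO values and is considerably more involved.

```python
from collections import defaultdict


def async_chains_are_clean(ops):
    """ops: list of (result_name, op_kind, operand_names); hypothetical IR."""
    users = defaultdict(list)  # value name -> names of ops that consume it
    kinds = {}                 # value name -> kind of the op that defines it
    for result, kind, operands in ops:
        kinds[result] = kind
        for operand in operands:
            users[operand].append(result)

    for result, kind, _ in ops:
        if not kind.endswith("-start"):
            continue
        expected_done = kind[: -len("-start")] + "-done"
        consumers = users[result]
        # The start result must flow into exactly one op: its matching done.
        if len(consumers) != 1 or kinds[consumers[0]] != expected_done:
            return False
    return True


straight = [
    ("x", "parameter", []),
    ("s", "send-start", ["x"]),
    ("d", "send-done", ["s"]),
]
through_cond = [
    ("x", "parameter", []),
    ("s", "send-start", ["x"]),
    ("c", "conditional", ["s"]),  # an extra op sits between start and done
    ("d", "send-done", ["c"]),
]
print(async_chains_are_clean(straight))
print(async_chains_are_clean(through_cond))
```

The second program fails the check because the start result reaches the done only through the conditional, which is where a copy could be introduced.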

@frgossen
Member

frgossen commented Sep 5, 2024

> I think we need a special treatment for async ops, we should be able to prove with dataflow analysis that async start is actually consumed by a corresponding done operation, and nothing in between. I suspect same issue will show up in pipeline partitioning and reordering send/recv start and done.

Yes, this shows up in pipeline parallelism. This is the case I want to fix for loops.

@nouiz
Contributor

nouiz commented Sep 5, 2024

Why focus on async ops?
Why not focus on the aliasing behavior?
What does the optimized HLO look like? It would show whether and where the copy happens.
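One way to get at the optimized HLO from JAX is the ahead-of-time API (`jit(...).lower(...).compile().as_text()`). The sketch below applies it to a function with the same control-flow shape as f_bad but without the custom `show` primitive, which is not reproduced here. The `count_copies` helper is a hypothetical text heuristic, not an XLA API, and the number of copies it reports will depend on the backend and XLA version.

```python
import re

import jax
import jax.numpy as jnp


def f(x):
    # Same cond structure as f_bad, minus the custom `show` calls.
    return jax.lax.cond(x != 999, lambda: x, lambda: jnp.zeros_like(x))


def count_copies(hlo_text: str) -> int:
    # Rough heuristic: count textual uses of the HLO `copy` opcode.
    return len(re.findall(r"\bcopy\(", hlo_text))


lowered = jax.jit(f).lower(jnp.int32(123))
unoptimized = lowered.as_text()          # module before XLA optimizations
optimized = lowered.compile().as_text()  # module after XLA optimizations

print(optimized)
print("copy instructions found:", count_copies(optimized))
```

Comparing the two dumps for f_good-like and f_bad-like functions should indicate which pass output first contains the extra copy.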

@ezhulenev
Member

I think the problem is that today you can’t distinguish between mutable and immutable aliasing, and by default XLA assumes an in-place update when it sees aliased buffers.
