Multiple sharding policy in plugin.jax.data_iterator #5535
Hello @sali1997s, thanks for the question. Unfortunately, this is not supported at the moment. This enhancement is on our TODO list for the JAX integration. Could you tell us more about your use case and how you would need this to work? In particular, how would you map the outputs onto devices? Do you need both the CPU and the GPU?
Thank you for answering, @awolant! However, I've found that the dataloader works only for data-parallel training; it currently doesn't support model parallelism. Also, I have a question about
Thanks for the reproduction. This is definitely a feature that could be added to extend the functionality of DALI's JAX support. For the first version of this integration layer, we focused only on the most common and simple cases. As for your question about external source, unfortunately there is currently no way to do something like this. As I said, for this first version we wanted it to work in the most common and stable case. In your use case, how would you expect this to work? I am asking only to get feedback about possible improvements to the JAX integration. Would you like the missing samples to be filled or duplicated somehow?
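One possible answer to the "filled or duplicated" question, sketched in plain NumPy: pad a partial final batch so that its sample count divides evenly across devices, duplicating the last sample, and return a mask so the duplicates can be excluded from the loss or metrics. The helper `pad_to_multiple` is hypothetical and is not part of DALI; it only illustrates the idea.

```python
import numpy as np

def pad_to_multiple(batch: np.ndarray, multiple: int):
    """Pad the leading (sample) axis up to a multiple of `multiple` by
    duplicating the last sample. Returns (padded_batch, validity_mask)."""
    n = batch.shape[0]
    if n == 0:
        raise ValueError("cannot pad an empty batch")
    target = -(-n // multiple) * multiple  # ceil(n / multiple) * multiple
    pad = target - n
    if pad:
        # Repeat the last sample `pad` times along the sample axis.
        filler = np.repeat(batch[-1:], pad, axis=0)
        batch = np.concatenate([batch, filler], axis=0)
    # True for real samples, False for duplicated padding.
    mask = np.arange(target) < n
    return batch, mask

# A final batch of 5 samples, to be split across a hypothetical 4 devices.
last = np.arange(10).reshape(5, 2)
padded, mask = pad_to_multiple(last, 4)
print(padded.shape, mask.tolist())
# (8, 2) [True, True, True, True, True, False, False, False]
```

Duplicating a real sample (rather than zero-filling) keeps the padded data in-distribution, while the mask makes it explicit which rows should not contribute to training statistics.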
Describe the question.
Is there a way to set the JAX sharding separately for each output of plugin.jax.data_iterator?
For example, I have a pipeline with 2 outputs. I want the first output to have PartitionSpec('batch', 'model') and the second to have PartitionSpec('batch', None) or PartitionSpec('batch').
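As a workaround while per-output sharding is not supported by the iterator, each output can be placed on devices with its own `NamedSharding` via `jax.device_put`. This is a minimal sketch in plain JAX, assuming the pipeline outputs are available as host NumPy arrays; it does not use plugin.jax.data_iterator itself, and the mesh layout (all devices on `'batch'`, one on `'model'`) and the `place_outputs` helper are illustrative assumptions. Note that for a 2D array, `PartitionSpec('batch')` and `PartitionSpec('batch', None)` describe the same layout, since unlisted trailing dimensions are not sharded.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 2D mesh: every available device on the 'batch' axis, size 1 on 'model'.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("batch", "model"))

# One sharding per pipeline output (assumed output order: first, second).
output_shardings = (
    NamedSharding(mesh, P("batch", "model")),  # first output
    NamedSharding(mesh, P("batch")),           # second output, same as P("batch", None) for 2D
)

def place_outputs(outputs, shardings=output_shardings):
    """Hypothetical helper: put each output on devices with its own sharding."""
    return tuple(jax.device_put(x, s) for x, s in zip(outputs, shardings))

n = devices.shape[0]
first = np.ones((4 * n, 8), dtype=np.float32)
second = np.arange(4 * n, dtype=np.int32).reshape(4 * n, 1)
placed = place_outputs((first, second))
for arr in placed:
    print(arr.shape, arr.sharding.spec)
```

Because `jax.device_put` also reshards an existing `jax.Array`, the same helper could be applied to arrays that were already sharded along `'batch'` only, at the cost of an extra device-to-device transfer.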