
Multiple sharding policy in plugin.jax.data_iterator #5535

Open
sali1997s opened this issue Jun 19, 2024 · 3 comments
Labels: enhancement (New feature or request) · JAX (Issues related to DALI and JAX integration) · question (Further information is requested)

Comments

@sali1997s

Describe the question.

Is there a way to set a JAX sharding for each output of plugin.jax.data_iterator separately?
For example, I have a pipeline that has 2 outputs. I want the first output to be PartitionSpec('batch', 'model') and the second to be PartitionSpec('batch', None) or PartitionSpec('batch').
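
For illustration only, the desired configuration could be expressed roughly as below; the per-output mapping is purely hypothetical, since the current data_iterator accepts a single sharding that applies to all outputs:

from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# Assumption: 8 devices arranged as 4-way data parallel x 2-way model parallel.
mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=("batch", "model"))

# Hypothetical: one sharding per pipeline output, keyed by output name.
desired_shardings = {
    "out_0": NamedSharding(mesh, PartitionSpec("batch", "model")),
    "out_1": NamedSharding(mesh, PartitionSpec("batch", None)),
}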

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report
@sali1997s sali1997s added the question Further information is requested label Jun 19, 2024
@awolant awolant self-assigned this Jun 19, 2024
@awolant awolant added JAX Issues related to DALI and JAX integration enhancement New feature or request labels Jun 19, 2024
@awolant
Contributor

awolant commented Jun 19, 2024

Hello @sali1997s

Thanks for the question. Currently, something like this is unfortunately not supported. This enhancement is on our TODO list for the JAX integration.

Could you tell us more about your use case and how you would need this to work? In particular, how do you map this onto devices? Do you need both the CPU and the GPU?
I am asking because, with DALI pipelines working on a particular GPU, there are some design and performance considerations for this feature, and we would like input from users to influence these decisions. Thanks!

@sali1997s
Author

sali1997s commented Jun 21, 2024

Thank you for answering, @awolant!
Sorry, after thinking about my task more deeply, I came to the conclusion that partitioning the data over the batch dimension fully covers my needs.
I thought I needed more control over partitioning, but currently I don't.

However, I've found that the dataloader works only for data-parallel training; it currently doesn't support model parallelism. Here is a minimal reproducible example. Changing device_mesh to mesh_utils.create_device_mesh((4, 2)) makes it fail.

I also have a question about the interaction between @data_iterator (its size parameter) and the external source. When the number of samples is divisible by the shard size, it works as expected. Otherwise it fails with
WARNING:root:DALI iterator does not support resetting while epoch is not finished. Ignoring...
and doesn't run a second epoch. Is there something I can do about it?

from nvidia.dali.plugin.jax import data_iterator
from jax.experimental import mesh_utils
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import numpy as np

GLOBAL_BATCH_SIZE = 64

class DataSourceCallable:
    def __init__(self, batch_size, seed, shard_id, num_shards):
        self.rng = np.random.default_rng(seed=seed)
        self.batch_size = batch_size

        # Synthetic data: 10 global batches of 4-element float32 samples.
        self.files = self.rng.random((GLOBAL_BATCH_SIZE * 10, 4)).astype(np.float32)

        self.shard_id = shard_id
        self.num_shards = num_shards
     
        self.shard_size = len(self.files) // num_shards
        self.shard_offset = self.shard_size * shard_id

        # If the shard size is not divisible by the batch size, the last incomplete batch
        # will be omitted.
        self.full_iterations = self.shard_size // batch_size 
        # print(self.full_iterations, self.shard_size, batch_size, len(self.files))
        self.perm = None
        self.last_seen_epoch = (
            None  # so that we don't have to recompute the `self.perm` for every sample
        )
    def __call__(self, sample_info):
        if sample_info.iteration >= self.full_iterations:
            raise StopIteration()
        if self.last_seen_epoch != sample_info.epoch_idx:
            self.last_seen_epoch = sample_info.epoch_idx
            self.perm = np.random.default_rng(seed=42 + sample_info.epoch_idx).permutation( 
                len(self.files)
            )

        sample_idx = self.perm[sample_info.idx_in_epoch + self.shard_offset]

        return self.files[sample_idx, :]

if __name__ == "__main__":
    device_mesh = mesh_utils.create_device_mesh((8, 1))
    mesh = Mesh(device_mesh, axis_names=("batch", "model"))
    sharding = NamedSharding(mesh, PartitionSpec("batch",))

    @data_iterator(output_map=['out'], sharding=sharding, size=GLOBAL_BATCH_SIZE * 10, prepare_first_batch=False)
    def callable_pipeline(num_shards, shard_id):
        out, = fn.external_source(
            source=DataSourceCallable(GLOBAL_BATCH_SIZE // num_shards, num_shards=num_shards, shard_id=shard_id, seed=42),
            num_outputs=1,
            batch=False,
            # parallel=True,
            dtype=[types.FLOAT],
        )
        return out.gpu()
    
    dataloader = callable_pipeline(batch_size=GLOBAL_BATCH_SIZE)

    for el in dataloader:
        print(el['out'].sharding)
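
For reference, the change mentioned above that triggers the failure is only in the mesh setup, e.g. (assuming 8 local devices split into a 4-way batch axis and a 2-way model axis):

    device_mesh = mesh_utils.create_device_mesh((4, 2))
    mesh = Mesh(device_mesh, axis_names=("batch", "model"))
    sharding = NamedSharding(mesh, PartitionSpec("batch"))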

@awolant
Contributor

awolant commented Jul 4, 2024

Thanks for the reproduction. This is definitely a feature that could be added to the DALI JAX support to enhance its functionality. For the first version of this integration layer we focused only on the most common and simple cases.

When it comes to your question about the external source, unfortunately there is currently no way to do something like this. As I said, for this first version we wanted it to work in the most common and stable case.

In your use case, how would you expect this to work? I am asking just to get feedback about possible improvements to the JAX integration. Would you like the missing samples to be filled/duplicated somehow?
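
For the record, one way the "fill/duplicate" option could behave, sketched against the callable from the reproduction above; this is only an assumption about the desired behaviour, not something DALI currently does:

import math

class PaddedSourceCallable:
    # Hypothetical variant of DataSourceCallable: instead of dropping the last
    # incomplete batch, it pads it by repeating samples from the same shard.
    def __init__(self, files, batch_size, shard_id, num_shards):
        self.files = files
        self.shard_size = len(files) // num_shards
        self.shard_offset = self.shard_size * shard_id
        # Round up instead of down, so the last (incomplete) batch is kept.
        self.full_iterations = math.ceil(self.shard_size / batch_size)

    def __call__(self, sample_info):
        if sample_info.iteration >= self.full_iterations:
            raise StopIteration
        # Wrap indices past the end of the shard, duplicating earlier samples.
        idx_in_shard = sample_info.idx_in_epoch % self.shard_size
        return self.files[idx_in_shard + self.shard_offset]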
