Asynchronous prefetching #5583

Open
quanvuong opened this issue Jul 28, 2024 · 8 comments
Labels: enhancement (New feature or request)

@quanvuong

Is this a new feature, an improvement, or a change to existing functionality?

New Feature

How would you describe the priority of this feature request

Nice to have (e.g. Adoption is possible, the feature will enhance the use-case even more).

Please provide a clear description of problem this feature solves

Instantiating the DALI data loader for JAX takes a long time because of prefetching.

Feature Description

It would be nice to have an asynchronous prefetching feature so that I can interleave jitting the model with the prefetching operations.

Describe your ideal solution

Have two functions:

def start_async_prefetch(self):

def block_till_ready_async_prefetch(self):

With these two functions, I can control when prefetching happens.
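
For example, a hypothetical usage sketch (none of these methods exist in DALI today; make_dali_jax_iterator, train_step, and example_batch are placeholders):

import jax

iterator = make_dali_jax_iterator(...)       # placeholder for the real iterator construction

iterator.start_async_prefetch()              # proposed: kick off prefetching in the background

# Interleave other expensive work while prefetching runs,
# e.g. ahead-of-time compilation of the training step.
compiled_step = jax.jit(train_step).lower(example_batch).compile()

iterator.block_till_ready_async_prefetch()   # proposed: wait only when the data is needed
batch = next(iterator)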

Describe any alternatives you have considered

No response

Additional context

No response

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report
quanvuong added the enhancement (New feature or request) label on Jul 28, 2024
@JanuszL
Contributor

JanuszL commented Aug 5, 2024

Hi @quanvuong,

Thank you for reaching out.
Currently, the exact functionality you are asking for is not fully exposed.
What you can try out is:

  • instantiate the iterator with prepare_first_batch=False so there is no prefetching when it is created
  • and (this should work, although it was not fully tested in this kind of use case) call iterator._schedule_runs(release_outputs=False) to start prefetching asynchronously, and then just call next() to wait for the data; a sketch follows below.

Please let us know if that works for your use case.
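
A rough, untested sketch of that workaround with the JAX iterator (the constructor arguments below are my assumption based on the other framework iterators; _schedule_runs is a private API and may change):

from nvidia.dali.plugin.jax import DALIGenericIterator

# Build the iterator without prefetching the first batch.
iterator = DALIGenericIterator(
    pipelines=[pipe],                  # 'pipe' is your already-defined DALI pipeline
    output_map=["images", "labels"],   # adjust to your pipeline outputs
    prepare_first_batch=False,
)

# Start prefetching asynchronously (private API, untested for this use case).
iterator._schedule_runs(release_outputs=False)

# ... jit-compile the model here while prefetching runs in the background ...

batch = next(iterator)                 # blocks until the prefetched data is ready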

JanuszL assigned JanuszL and unassigned mzient on Aug 5, 2024
@quanvuong
Author

That does seem to speed things up by quite a bit.

Instantiating the data iterator is still quite slow because I'm using the integration with JAX, which requires starting the worker pool with "spawn" (on an 8×H100 node, starting all the worker pools can take 20 minutes).

Do you have any advice on how to improve the speed here?

@JanuszL
Contributor

JanuszL commented Aug 12, 2024

Hi @quanvuong,

(on an 8×H100 node, starting all the worker pools can take 20 minutes)

This is surprising and not expected. Can you share self-contained repro code that we can run on our end for debugging?

@quanvuong
Author

I'm working on the self-contained repro; in the meantime, I have narrowed the slow operations down to these lines.

Specifically, going from s0 to s1 takes 80 seconds (in nvidia/dali/_multiproc/pool.py):

    def _start_processes(self, mp, start_method, write_sockets):
        try:
            import time  # timing instrumentation added for debugging
            s0 = time.time()
            for process in self._processes:
                process.start()
            s1 = time.time()
            task_queues = [
                worker_context.dedicated_task_queue
                for worker_context in self._workers_contexts
                if worker_context.dedicated_task_queue is not None
            ]
            if self._general_task_queue is not None:
                task_queues.append(self._general_task_queue)
            self._observer = Observer(mp, self._processes, task_queues, self._result_queue)
            s3 = time.time()
            if start_method != "fork":
                # NOTE when making any changes here, make sure to reflect them in the worker
                # process, so that it sets received handles to objects in the same order
                self._send_queue_handles(write_sockets)
                self._send_shm_handles(write_sockets)
            s4 = time.time()
            self._sync_initialized_workers()
            s5 = time.time()
            print(f"_start_processes: s1-s0 {s1-s0}")
            print(f"_start_processes: s3-s1 {s3-s1}")
            print(f"_start_processes: s4-s3 {s4-s3}")
            print(f"_start_processes: s5-s4 {s5-s4}")
        except:  # noqa: E722
            if self._observer is not None:
                self._observer.close()
                self._observer = None
            else:
                for proc in self._processes:
                    if proc.is_alive():
                        proc.terminate()
                for proc in self._processes:
                    if proc.pid is not None:
                        proc.join()
            raise

@stiepan
Member

stiepan commented Aug 13, 2024

Hi @quanvuong,

Because you are referring to nvidia/dali/_multiproc/pool.py, I assume you are using parallel external source in your iterator. It's a bit of a guess without more details on the code you wrote, but the place you pointed to will take more time when the callback/source you are passing to the external source is heavy. At that point, the Python multiprocessing package passes the serialized callbacks to the workers, so the bigger the serialized object is, the longer it takes to start the processes.

If that's the case, you can check https://docs.nvidia.com/deeplearning/dali/user-guide/docs/examples/general/data_loading/parallel_external_source.html#Serialization-and-heavy-setup to try to work around it by making sure that the serialized object is lighter.
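
For instance, one pattern from that page is to keep the pickled callable small and defer the heavy setup until it first runs inside a worker; a minimal sketch (HeavyReader is a made-up stand-in for whatever is expensive to construct):

import numpy as np

class LightweightCb:
    """Callable for parallel external_source whose pickled state stays small."""

    def __init__(self, data_dir):
        # Store only cheap, picklable configuration here.
        self.data_dir = data_dir
        self.reader = None               # heavy object is NOT created (or pickled) here

    def __call__(self, sample_info):
        if self.reader is None:
            # Heavy setup happens lazily, once per worker process,
            # after the small object has been unpickled there.
            self.reader = HeavyReader(self.data_dir)   # hypothetical heavy object
        return np.asarray(self.reader.read(sample_info.idx_in_epoch))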

@quanvuong
Author

Yes, I am using parallel external source in my iterator (with the JAX integration).

I have moved the heavy setup to get_state as recommended, and that reduces the time taken to instantiate the data iterator by 10-15%, but it is still quite slow. Any advice? Is there a profiler that I can use?

We are running on 8 H100 nodes, and instantiating the data iterator takes more than 10 minutes (about 2 minutes per GPU).

@JanuszL
Contributor

JanuszL commented Aug 19, 2024

Hi @quanvuong,

Can you bisect which part of the external source callback is slow (start with a callback that just calls np.ones, for example, and then gradually extend it, adding more logic/imports)? A minimal baseline is sketched below.
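
For example, a baseline to start the bisection from (pipeline parameters are illustrative):

import numpy as np
from nvidia.dali import pipeline_def, fn

def trivial_cb(sample_info):
    # Cheapest possible source: no extra imports, no state, almost nothing to serialize.
    return np.ones((224, 224, 3), dtype=np.uint8)

@pipeline_def(batch_size=4, device_id=0, num_threads=4,
              py_num_workers=4, py_start_method="spawn")
def baseline_pipeline():
    return fn.external_source(source=trivial_cb, parallel=True, batch=False)

if __name__ == "__main__":
    import time
    start = time.time()
    pipe = baseline_pipeline()
    pipe.start_py_workers()              # starts the worker processes
    print("worker start-up took", time.time() - start, "s")
    # Then swap trivial_cb for progressively larger pieces of the real callback
    # to see which part dominates the start-up time.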

@stiepan
Member

stiepan commented Aug 20, 2024

Hi @quanvuong,

To make sure that serialization is no longer the main factor contributing to the start-up time, you could use a custom pickler that wraps pickle and gives you some more information, such as the size of the callback once it is serialized:

import pickle
from nvidia.dali import pipeline_def, fn

from source import Cb

class PeekPickler:

    @classmethod
    def loads(cls, payload):
        return pickle.loads(payload)

    @classmethod
    def dumps(cls, obj):
        payload = pickle.dumps(obj)
        print("The payload size for the callback: ", obj, len(payload))
        return payload


@pipeline_def(
    batch_size=4,
    device_id=0,
    num_threads=4,
    py_num_workers=4,
    py_start_method="spawn",
    py_callback_pickler=PeekPickler
)
def pipeline():
    return fn.external_source(Cb(1024), parallel=True, batch=False)

Another thing that may contribute to the total start-up time of the workers (although I would expect it to show up as s5-s4 in the snippet you provided) is the time Python takes to set up all the imports and globals in the worker processes. Note that the main entrypoint file will be loaded and set up in each worker process, including recursively processing its imports. It may help to define the callback/source in a separate file and to make sure any heavy setup in the entrypoint file (along with any imports not needed for the callback) is protected with an if __name__ == "__main__": guard, as sketched below.
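
A sketch of that layout, reusing the names from the example above (file names are arbitrary):

# source.py -- keep this module light; only what the workers need.
import numpy as np

class Cb:
    def __init__(self, size):
        self.size = size

    def __call__(self, sample_info):
        return np.full(self.size, sample_info.idx_in_epoch, dtype=np.int32)


# train.py -- the entrypoint; heavy imports and setup stay behind the guard.
from nvidia.dali import pipeline_def, fn
from source import Cb

@pipeline_def(batch_size=4, device_id=0, num_threads=4,
              py_num_workers=4, py_start_method="spawn")
def pipeline():
    return fn.external_source(Cb(1024), parallel=True, batch=False)

if __name__ == "__main__":
    # Anything here (heavy imports, model setup, the training loop) is NOT
    # re-executed when the spawned worker processes import this file.
    pipe = pipeline()
    pipe.build()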
