Build for GPU fails due to nccl error #16711

Open
juuso-oskari opened this issue Sep 2, 2024 · 8 comments

@juuso-oskari

I'm trying to build the XLA for GPU according to this guide: https://openxla.org/xla/developer_guide. Configuration goes just fine:

$ docker exec xla ./configure.py --backend=CUDA
INFO:root:Trying to find path to clang...
INFO:root:Found path to clang at /usr/lib/llvm-18/bin/clang
INFO:root:Running echo __clang_major__ | /usr/lib/llvm-18/bin/clang -E -P -
INFO:root:/usr/lib/llvm-18/bin/clang reports major version 18.
INFO:root:Trying to find path to nvidia-smi...
INFO:root:Found path to nvidia-smi at /usr/bin/nvidia-smi
INFO:root:Found CUDA compute capabilities: ['8.9']
INFO:root:Writing bazelrc to /xla/xla_configure.bazelrc...

But then when I try:

$ docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/...
(truncated output)
Repository rule _tf_http_archive defined at:
  /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/tsl/third_party/repo.bzl:89:35: in <toplevel>
ERROR: Analysis of target '//xla/tsl/cuda:nccl' failed; build aborted: Analysis failed
INFO: Elapsed time: 169.944s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (288 packages loaded, 15201 targets configured)

So it seems that building the nccl library fails. If I try:

$ docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/tsl/cuda:nccl
INFO: Reading 'startup' options from /xla/.bazelrc: --windows_enable_symlinks
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'build' from /xla/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'build' from /etc/bazel.bazelrc:
  'build' options: --action_env=DOCKER_CACHEBUSTER=1724544836993635635 --host_action_env=DOCKER_HOST_CACHEBUSTER=1724544837092374632
INFO: Reading rc options for 'build' from /xla/.bazelrc:
  'build' options: --define framework_shared_object=true --define tsl_protobuf_header_only=true --define=use_fast_cpp_protos=true --define=allow_oversize_protos=true --spawn_strategy=standalone -c opt --announce_rc --define=grpc_no_ares=true --noincompatible_remove_legacy_whole_archive --features=-force_no_whole_archive --enable_platform_specific_config --define=with_xla_support=true --config=short_logs --config=v2 --experimental_cc_shared_library --experimental_link_static_libraries_once=false --incompatible_enforce_config_setting_visibility
INFO: Reading rc options for 'build' from /xla/xla_configure.bazelrc:
  'build' options: --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-18/bin/clang --repo_env CC=/usr/lib/llvm-18/bin/clang --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang --config nvcc_clang --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang --config nonccl --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 --action_env PYTHON_BIN_PATH=/usr/bin/python3 --python_path /usr/bin/python3 --copt -Wno-sign-compare --copt -Wno-error=unused-command-line-argument --copt -Wno-gnu-offsetof-extensions --build_tag_filters -no_oss --test_tag_filters -no_oss
INFO: Found applicable config definition build:short_logs in file /xla/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:v2 in file /xla/.bazelrc: --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
INFO: Found applicable config definition build:nvcc_clang in file /xla/.bazelrc: --config=cuda --action_env=TF_NVCC_CLANG=1 --@local_config_cuda//:cuda_compiler=nvcc
INFO: Found applicable config definition build:cuda in file /xla/.bazelrc: --repo_env TF_NEED_CUDA=1 --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda --repo_env=HERMETIC_CUDA_VERSION=12.3.2 --repo_env=HERMETIC_CUDNN_VERSION=9.3.0 --@local_config_cuda//cuda:include_cuda_libs=true
INFO: Found applicable config definition build:cuda in file /xla/xla_configure.bazelrc: --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=8.9
INFO: Found applicable config definition build:nonccl in file /xla/.bazelrc: --define=no_nccl_support=true
INFO: Found applicable config definition build:linux in file /xla/.bazelrc: --host_copt=-w --copt=-Wno-all --copt=-Wno-extra --copt=-Wno-deprecated --copt=-Wno-deprecated-declarations --copt=-Wno-ignored-attributes --copt=-Wno-array-bounds --copt=-Wunused-result --copt=-Werror=unused-result --copt=-Wswitch --copt=-Werror=switch --copt=-Wno-error=unused-but-set-variable --define=PREFIX=/usr --define=LIBDIR=$(PREFIX)/lib --define=INCLUDEDIR=$(PREFIX)/include --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --config=dynamic_kernels --experimental_guard_against_concurrent_changes
INFO: Found applicable config definition build:dynamic_kernels in file /xla/.bazelrc: --define=dynamic_loaded_kernels=true --copt=-DAUTOLOAD_DYNAMIC_KERNELS
Loading: 
Loading: 
Loading: 0 packages loaded
INFO: Build option --define has changed, discarding analysis cache.
Analyzing: target //xla/tsl/cuda:nccl (0 packages loaded, 0 targets configured)
ERROR: /xla/xla/tsl/cuda/BUILD.bazel:336:11: no such target '@local_config_nccl//:nccl_headers': target 'nccl_headers' not declared in package '' defined by /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/local_config_nccl/BUILD (Tip: use `query "@local_config_nccl//:*"` to see all the targets in that package) and referenced by '//xla/tsl/cuda:nccl'
ERROR: Analysis of target '//xla/tsl/cuda:nccl' failed; build aborted: Analysis failed
INFO: Elapsed time: 0.274s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (0 packages loaded, 267 targets configured)

I get a bit more verbose output. I've seen a couple of related unsolved issues, #11604 and #10616, but nothing in them has worked for me yet.

@juuso-oskari (Author)

Update: I managed to continue the compilation by running all the docker commands with sudo. I suppose rootless docker runs would be possible with correct configuration of the NVIDIA Container Toolkit (though this failed for me): https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html.

But I still get an error at a later stage of compilation:

....
[46,478 / 52,250] Compiling xla/python/ifrt_proxy/client/array.cc; 16s processwrapper-sandbox ... (32 actions, 31 running)
ERROR: /xla/xla/tests/BUILD:2734:12: Linking xla/tests/local_client_aot_test failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //xla/tests:local_client_aot_test) external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc @bazel-out/k8-opt/bin/xla/tests/local_client_aot_test-2.params

@akuegel (Member) commented Sep 3, 2024

Compiling/running that test should not be a blocker. I would suggest skipping it in your build command with --build_tag_filters.

As far as I can tell, this test is only supposed to work on these architectures:

https://github.com/openxla/xla/blob/main/xla/tests/local_client_aot_test_helper.cc#L66
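For reference, the failing target can be excluded either with --build_tag_filters (if the target carries a suitable tag) or with a negative target pattern, which Bazel accepts after a "--" separator. A sketch of the second approach, using the target name from the error output above (not verified against the current BUILD files):

```shell
# Build everything under //xla/... except the failing AOT test target.
# Negative target patterns must come after the "--" separator.
bazel build --test_output=all --spawn_strategy=sandboxed \
  -- //xla/... -//xla/tests:local_client_aot_test
```

The negative-pattern form is often simpler than tag filtering when only a single known target needs to be skipped.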

@frgossen (Member) commented Sep 6, 2024

@Tixxx, do you run into this, too?

@Tixxx (Contributor) commented Sep 6, 2024

Can you try running configure.py with the NCCL option?
python configure.py --backend cuda --nccl

@juuso-oskari (Author)

@Tixxx It still fails with the same error even though I run with --nccl.

@Xinyu302 commented Sep 9, 2024

I have the same error...

@Xinyu302 commented Sep 9, 2024

> @Tixxx It still fails with the same error even though I run with --nccl.

Maybe you can try this:

export TF_NCCL_USE_STUB=1

@ybaturina

This command should work too:
bazel build --test_output=all //xla/tsl/cuda:nccl --repo_env=TF_NCCL_USE_STUB=1
