
[XLA:CLIENT] Why force using U32 indices type when converting torch gather op #15654

Open
Nullkooland opened this issue Aug 2, 2024 · 2 comments


@Nullkooland

Nullkooland commented Aug 2, 2024

This commit, a1c04b8, introduces code that forces xla::TorchGather to convert an indices tensor with element type i64 to u32:

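```cpp
// If the indices are 64-bit and the gathered dimension is smaller than
// UINT32_MAX, narrow the indices to U32 before emitting the gather.
if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
    input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
  index = ConvertElementType(index, U32);
  index_shape.set_element_type(U32);
}
```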
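The condition above can be illustrated outside XLA with a small standalone sketch. It assumes plain `std::vector` buffers in place of `XlaOp`/`Shape`, and `NarrowIndicesIfSafe` is a hypothetical name, not an XLA function. Since torch.gather only accepts indices in `[0, dim_size)`, narrowing to 32-bit unsigned preserves every value whenever the gathered dimension is below `UINT32_MAX`. The sketch shows that the rewrite is lossless; it says nothing about why XLA chose to do it.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical stand-in for the XLA check: narrow 64-bit gather indices to
// 32-bit unsigned when the gathered dimension is smaller than UINT32_MAX.
// Returns std::nullopt when the dimension is too large, in which case the
// caller keeps the original 64-bit indices.
std::optional<std::vector<uint32_t>> NarrowIndicesIfSafe(
    const std::vector<int64_t>& indices, int64_t dim_size) {
  assert(dim_size >= 0);
  if (static_cast<uint64_t>(dim_size) >=
      std::numeric_limits<uint32_t>::max()) {
    return std::nullopt;
  }
  std::vector<uint32_t> narrowed;
  narrowed.reserve(indices.size());
  for (int64_t i : indices) {
    // torch.gather only accepts indices in [0, dim_size), so each value fits.
    assert(i >= 0 && i < dim_size);
    narrowed.push_back(static_cast<uint32_t>(i));
  }
  return narrowed;
}
```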

As a result, StableHLO exported with torch_xla has the following pattern around stablehlo.gather:

[Screenshot: gather_ui32_indices]

However, when stablehlo.gather is lowered to the MLIR tensor dialect's tensor.gather, the ui32 indices cause an error, because tensor.gather requires its indices operand to have a signless integer element type. It is also inconsistent: every other index type in the IR is i64, and only the indices of this gather are ui32.
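One reason the ui32 element type is more than cosmetic for a signless-only consumer: a signless integer carries no signedness, so a rewrite that converts the ui32 indices to i64 to satisfy tensor.gather would need to zero-extend them. The sketch below uses plain C++ integers as a stand-in for MLIR types (the helper names are hypothetical) and shows that sign-extending the same 32 bits would corrupt any index at or above 2^31.

```cpp
#include <cassert>
#include <cstdint>

// Zero-extension: the value-preserving way to widen an unsigned 32-bit index.
int64_t ZeroExtendIndex(uint32_t bits) { return static_cast<int64_t>(bits); }

// Sign-extension: what results if the same 32 bits are treated as signed
// before widening. Written out explicitly instead of casting through int32_t.
int64_t SignExtendIndex(uint32_t bits) {
  const int64_t value = static_cast<int64_t>(bits);
  return bits >= 0x80000000u ? value - (int64_t{1} << 32) : value;
}

// Example: an index just above 2^31 survives zero-extension only.
void CheckWideningExample() {
  assert(ZeroExtendIndex(3000000000u) == 3000000000LL);
  assert(SignExtendIndex(3000000000u) == -1294967296LL);
}
```

Indices below 2^31 come out the same either way, which is why this kind of mismatch tends to go unnoticed on small tensors.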

@blakehechtman Could you have a look? This commit looks like a HACK to me.
Is it possible to revert this?

@cheshire
Member

cheshire commented Aug 5, 2024

The commit is 5 years old, and thus unlikely to be reverted.

> However, when lowering stablehlo.gather to MLIR tensor dialect's tensor.gather

This seems outside of the scope of OpenXLA, more like a stablehlo/tensor dialect interop issue?

@Nullkooland
Author

Nullkooland commented Aug 16, 2024

> The commit is 5 years old, and thus unlikely to be reverted.
>
> > However, when lowering stablehlo.gather to MLIR tensor dialect's tensor.gather
>
> This seems outside of the scope of OpenXLA, more like a stablehlo/tensor dialect interop issue?

I don't think this is a StableHLO-to-tensor dialect conversion issue, since stablehlo.gather accepts i64 indices. After removing the forced U32 index conversion from XLA, the same torch gather op is exported as:
[Screenshot: stablehlo_gather_i64_indices]

I don't see why we need this u32 HACK.
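For comparison, torch.gather on a 2-D tensor computes `out[i][j] = input[index[i][j]][j]` for `dim = 0` and `out[i][j] = input[i][index[i][j]]` for `dim = 1`. The reference sketch below evaluates those semantics directly on `int64_t` indices with no narrowing step. `Tensor2D` and `TorchGather2D` are hypothetical helpers using row-major `std::vector` storage; this is not how XLA or torch_xla lower the op.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical row-major 2-D tensor; sizes are int64_t like torch's shapes.
struct Tensor2D {
  int64_t rows = 0;
  int64_t cols = 0;
  std::vector<float> data;
  float at(int64_t r, int64_t c) const {
    return data[static_cast<size_t>(r * cols + c)];
  }
};

// Reference semantics of torch.gather(input, dim, index) for 2-D inputs.
// Indices stay int64_t throughout; nothing is narrowed to 32 bits.
Tensor2D TorchGather2D(const Tensor2D& input, int dim,
                       const std::vector<std::vector<int64_t>>& index) {
  if (dim != 0 && dim != 1) throw std::invalid_argument("dim must be 0 or 1");
  Tensor2D out;
  out.rows = static_cast<int64_t>(index.size());
  out.cols = out.rows == 0 ? 0 : static_cast<int64_t>(index[0].size());
  out.data.reserve(static_cast<size_t>(out.rows * out.cols));
  const int64_t limit = dim == 0 ? input.rows : input.cols;
  for (int64_t i = 0; i < out.rows; ++i) {
    // The index tensor must be rectangular.
    assert(static_cast<int64_t>(index[i].size()) == out.cols);
    for (int64_t j = 0; j < out.cols; ++j) {
      const int64_t k = index[i][j];
      // The gathered coordinate must be in range along dim.
      if (k < 0 || k >= limit) throw std::out_of_range("index out of range");
      // The other coordinate must also fit inside input.
      if ((dim == 0 && j >= input.cols) || (dim == 1 && i >= input.rows))
        throw std::out_of_range("index shape exceeds input shape");
      out.data.push_back(dim == 0 ? input.at(k, j) : input.at(i, k));
    }
  }
  return out;
}
```

Keeping the indices as `int64_t` end to end means the index type in the sketch matches the i64 used elsewhere, which is the consistency the original post asks for.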
