Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend CustomCallOp backend_config to take a DictionaryAttr #2415

Merged
merged 19 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion stablehlo/dialect/StablehloEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,15 @@ def STABLEHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING :
I32EnumAttrCase<"API_VERSION_STATUS_RETURNING", 2>;
def STABLEHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED :
I32EnumAttrCase<"API_VERSION_STATUS_RETURNING_UNIFIED", 3>;
def STABLEHLO_CUSTOM_CALL_API_VERSION_TYPED_FFI :
I32EnumAttrCase<"API_VERSION_TYPED_FFI", 4>;
def StableHLO_CustomCallApiVersionAttr :
I32EnumAttr<"CustomCallApiVersion", "Custom call API version", [
STABLEHLO_CUSTOM_CALL_API_VERSION_UNSPECIFIED,
STABLEHLO_CUSTOM_CALL_API_VERSION_ORIGINAL,
STABLEHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING,
STABLEHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED
STABLEHLO_CUSTOM_CALL_API_VERSION_STATUS_RETURNING_UNIFIED,
STABLEHLO_CUSTOM_CALL_API_VERSION_TYPED_FFI
]> {
let cppNamespace = "::mlir::stablehlo";
}
Expand Down
24 changes: 24 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,20 @@ LogicalResult CustomCallOp::verify() {
<< "operand part has type " << operandPart
<< " and output part has type " << outputPart;
}
if (auto backendConfig = getBackendConfig()) {
if (getApiVersion() == CustomCallApiVersion::API_VERSION_TYPED_FFI) {
if (!isa<mlir::DictionaryAttr>(*backendConfig))
return emitOpError() << "backend_config for api_version "
<< stringifyCustomCallApiVersion(getApiVersion())
<< " must be a dictionary attribute.";
} else {
if (!isa<mlir::StringAttr>(*backendConfig))
return emitOpError() << "backend_config for api_version "
<< stringifyCustomCallApiVersion(getApiVersion())
<< " must be a string attribute.";
}
}

return success();
}

Expand All @@ -493,6 +507,16 @@ void CustomCallOp::getEffects(
effects.emplace_back(MemoryEffects::Read::get());
}

mlir::Attribute CustomCallOp::getBackendConfigOrDefault() {
if (getBackendConfig().has_value()) return getBackendConfig().value();
abhigunj marked this conversation as resolved.
Show resolved Hide resolved

if (getApiVersion() ==
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI)
return DictionaryAttr::get(getContext());

return StringAttr::get(getContext(), "");
}

//===----------------------------------------------------------------------===//
// CholeskyOp
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 12 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2340,13 +2340,19 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
Encapsulates an implementation-defined operation `call_target_name` that
takes `inputs` and `called_computations` and produces `results`.

Depending on the API version there are two ways to pass extra bits of static
information to the external function:
abhigunj marked this conversation as resolved.
Show resolved Hide resolved
1. Use `API_VERSION_TYPED_FFI` which allows passing a dictionary attribute.
2. Use a previous API version with a StringAttr to encode backend config.

See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call

Example:
```mlir
%results = stablehlo.custom_call @foo(%input0) {
backend_config = "bar",
backend_config = {bar = 42 : i32},
api_version = 4 : i32,
abhigunj marked this conversation as resolved.
Show resolved Hide resolved
called_computations = [@foo]
} : (tensor<f64>) -> tensor<f64>
```
Expand All @@ -2356,7 +2362,7 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
Variadic<HLO_CustomCallValue>:$inputs,
StrAttr:$call_target_name,
DefaultValuedOptionalAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedStrAttr<StrAttr, "">:$backend_config,
OptionalAttr<AnyAttrOf<[StrAttr, DictionaryAttr]>>:$backend_config,
// TODO(b/189822916): Remove this field when all clients are migrated to
// the status-returning API.
DefaultValuedOptionalAttr<
Expand All @@ -2380,6 +2386,10 @@ def StableHLO_CustomCallOp: StableHLO_Op<"custom_call",
custom<CustomCallTarget>($call_target_name) `(` $inputs `)`
attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = commonClassDeclaration # [{
mlir::Attribute getBackendConfigOrDefault();
}];
}

def StableHLO_DotOp: StableHLO_Op<"dot", [Pure]> {
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/dialect/VhloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,17 @@ LogicalResult verifyConstraint_0_17_0(mlir::Operation* op,
return failure();
return success();
}

LogicalResult verifyConstraint_1_1_0(mlir::Operation* op,
Version targetVersion) {
auto customCallOp = dyn_cast<mlir::vhlo::CustomCallOpV1>(op);
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
if (customCallOp &&
isa<mlir::DictionaryAttr>(customCallOp.getBackendConfig()) &&
targetVersion < Version(1, 1, 0))
return failure();
abhigunj marked this conversation as resolved.
Show resolved Hide resolved
return success();
}

} // namespace

LogicalResult AllReduceOpV1::validateConstraint(mlir::Operation* op,
Expand Down Expand Up @@ -365,5 +376,10 @@ LogicalResult SelectAndScatterOpV1::validateConstraint(mlir::Operation* op,
return verifyConstraint_0_17_0(op, targetVersion);
}

LogicalResult CustomCallOpV1::validateConstraint(mlir::Operation* op,
Version targetVersion) {
return verifyConstraint_1_1_0(op, targetVersion);
}

} // namespace vhlo
} // namespace mlir
3 changes: 2 additions & 1 deletion stablehlo/dialect/VhloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ def VHLO_CrossReplicaSumOpV1 : VHLO_Op<"cross-replica-sum_v1", "0.9.0", "current
// TODO(#740): output_operand_aliases is not part of the spec.
// CustomCallOp has proven to be one of the trickiest ops to fully spec.
// We're aiming to address all these todos by the release of StableHLO v1.0.
def VHLO_CustomCallOpV1 : VHLO_Op<"custom_call_v1", "0.9.0", "current"> {
def VHLO_CustomCallOpV1 : VHLO_Op<"custom_call_v1", "0.9.0", "current",
[DeclareOpInterfaceMethods<VHLO_VersionedOpConstraintInterface>]> {
let arguments = (ins
Variadic<VHLO_AnyType>:$inputs,
VHLO_AnyAttr:$call_target_name,
Expand Down
24 changes: 24 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4926,6 +4926,30 @@ func.func @custom_call_unranked_types(%arg0: tensor<*xf32>) -> tensor<*xf32> {

// -----

func.func @custom_call_with_dictionary_backend_config() {
// CHECK: stablehlo.custom_call @foo() {api_version = 4 : i32, backend_config = {foo = 42 : i32}}
"stablehlo.custom_call"() {api_version = 4 : i32, backend_config={foo = 42 : i32}, call_target_name = "foo"} : () -> ()
func.return
}

// -----

func.func @custom_call_with_incompatible_backend_config() {
// expected-error@+1 {{backend_config for api_version API_VERSION_TYPED_FFI must be a dictionary attribute}}
"stablehlo.custom_call"() {api_version = 4 : i32, backend_config="bar=42", call_target_name = "foo"} : () -> ()
func.return
}

// -----

func.func @custom_call_with_incompatible_backend_config() {
// expected-error@+1 {{backend_config for api_version API_VERSION_STATUS_RETURNING_UNIFIED must be a string attribute}}
"stablehlo.custom_call"() {api_version = 3 : i32, backend_config={bar = 42 : i32}, call_target_name = "foo"} : () -> ()
func.return
}

// -----

// Test custom attribute printing/parsing.
// We really just need one op as holder, use module: this is the simplest top-level.

Expand Down
1 change: 0 additions & 1 deletion stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,6 @@ func.func public @op_custom_call_empty_result_layout(%arg0: tensor<i64>) -> tens
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.i64_v1>) -> !vhlo.tuple_v1<>
%0 = "stablehlo.custom_call"(%arg0) <{
api_version = 2 : i32,
backend_config = "",
abhigunj marked this conversation as resolved.
Show resolved Hide resolved
call_target_name = "empty_output",
has_side_effect = true,
operand_layouts = [dense<> : tensor<0xindex>],
Expand Down
11 changes: 11 additions & 0 deletions stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_1_0.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.1.0' --verify-diagnostics --split-input-file %s

func.func @custom_call_dictionary_attr(%arg0: tensor<f32>) -> tensor<f32> {
GleasonK marked this conversation as resolved.
Show resolved Hide resolved
// expected-error @+1 {{failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal}}
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32,
backend_config={foo = 42 : i32}
} : (tensor<f32>) -> tensor<f32>
func.return %0 : tensor<f32>
}
Loading