Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend CustomCallOp backend_config to take a DictionaryAttr #2415

Merged
merged 19 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,8 @@ void CustomCallOp::getEffects(
}

mlir::Attribute CustomCallOp::getBackendConfigOrDefault() {
if (getBackendConfig().has_value()) return getBackendConfig().value();
auto backendConfig = getBackendConfig();
if (backendConfig.has_value()) return backendConfig.value();

if (getApiVersion() ==
mlir::stablehlo::CustomCallApiVersion::API_VERSION_TYPED_FFI)
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(1, 2, 1); }
static Version getCurrentVersion() { return Version(1, 3, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def VHLO_Dialect : Dialect {
0.20.0: Remove `padding` attribute from `dynamic_conv`.
1.0.0: Increase compatibility guarantees to 5 years backward, 2 years forward (no functional changes relative to 0.20.0).
1.2.0: Introduce `si2` and `ui2` types.
1.3.0: Extend `custom_call` op `backend_config` to support `DictionaryAttr`
}];

let useDefaultAttributePrinterParser = 0;
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/dialect/VhloEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,18 @@ def VHLO_ComparisonTypeAttrV1
// CustomCallApiVersion
//===----------------------------------------------------------------------===//

// TODO(#1187): CustomCallApiVersion is not part of the StableHLO spec.
def VHLO_CUSTOM_CALL_API_VERSION_V1_UNSPECIFIED : I32EnumAttrCase<"API_VERSION_UNSPECIFIED", 0>;
def VHLO_CUSTOM_CALL_API_VERSION_V1_ORIGINAL : I32EnumAttrCase<"API_VERSION_ORIGINAL", 1>;
def VHLO_CUSTOM_CALL_API_VERSION_V1_STATUS_RETURNING : I32EnumAttrCase<"API_VERSION_STATUS_RETURNING", 2>;
def VHLO_CUSTOM_CALL_API_VERSION_V1_STATUS_RETURNING_UNIFIED : I32EnumAttrCase<"API_VERSION_STATUS_RETURNING_UNIFIED", 3>;
def VHLO_CUSTOM_CALL_API_VERSION_V1_TYPED_FFI : I32EnumAttrCase<"API_VERSION_TYPED_FFI", 4>;

def VHLO_CustomCallApiVersionV1 : VHLO_I32EnumAttr<"CustomCallApiVersionV1", [
VHLO_CUSTOM_CALL_API_VERSION_V1_UNSPECIFIED,
VHLO_CUSTOM_CALL_API_VERSION_V1_ORIGINAL,
VHLO_CUSTOM_CALL_API_VERSION_V1_STATUS_RETURNING,
VHLO_CUSTOM_CALL_API_VERSION_V1_STATUS_RETURNING_UNIFIED
VHLO_CUSTOM_CALL_API_VERSION_V1_STATUS_RETURNING_UNIFIED,
VHLO_CUSTOM_CALL_API_VERSION_V1_TYPED_FFI
]> {}

def VHLO_CustomCallApiVersionAttrV1
Expand Down
14 changes: 8 additions & 6 deletions stablehlo/dialect/VhloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,13 +334,15 @@ LogicalResult verifyConstraint_0_17_0(mlir::Operation* op,
return success();
}

LogicalResult verifyConstraint_1_1_0(mlir::Operation* op,
LogicalResult verifyConstraint_1_3_0(mlir::Operation* op,
Version targetVersion) {
auto customCallOp = dyn_cast<mlir::vhlo::CustomCallOpV1>(op);
if (customCallOp &&
isa<mlir::DictionaryAttr>(customCallOp.getBackendConfig()) &&
targetVersion < Version(1, 1, 0))
auto customCallOp = cast<mlir::vhlo::CustomCallOpV1>(op);
if (targetVersion < Version(1, 3, 0) &&
(isa<vhlo::DictionaryV1Attr>(customCallOp.getBackendConfig()) ||
mlir::cast<CustomCallApiVersionV1Attr>(customCallOp.getApiVersion())
.getValue() == CustomCallApiVersionV1::API_VERSION_TYPED_FFI)) {
return failure();
abhigunj marked this conversation as resolved.
Show resolved Hide resolved
}
return success();
}

Expand Down Expand Up @@ -378,7 +380,7 @@ LogicalResult SelectAndScatterOpV1::validateConstraint(mlir::Operation* op,

LogicalResult CustomCallOpV1::validateConstraint(mlir::Operation* op,
Version targetVersion) {
return verifyConstraint_1_1_0(op, targetVersion);
return verifyConstraint_1_3_0(op, targetVersion);
}

} // namespace vhlo
Expand Down
Loading
Loading