Skip to content

Commit

Permalink
As proposed in RFC: int2 in StableHLO (openxla#2403), this PR adds su…
Browse files Browse the repository at this point in the history
…pport for these types to StableHLO.
  • Loading branch information
superbobry committed Jul 2, 2024
1 parent f372def commit 05e6a35
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 19 deletions.
6 changes: 3 additions & 3 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ which may allow us to remove tuple types from StableHLO
TensorElementType ::= BooleanType | IntegerType | FloatType | ComplexType
BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
ComplexType ::= 'complex' '<' ComplexElementType '>'
Expand All @@ -259,7 +259,7 @@ values of type `tensor<T>`).

* **Boolean type** represents boolean values `true` and `false`.
* **Integer types** can be either signed (`si`) or unsigned (`ui`) and have
one of the supported bit widths (`4`, `8`, `16`, `32` or `64`).
one of the supported bit widths (`2`, `4`, `8`, `16`, `32` or `64`).
Signed `siN` types represent integer values from `-2^(N-1)` to `2^(N-1)-1`
inclusive, and unsigned `uiN` types represent integer values from `0` to
`2^N-1` inclusive.
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;

// TODO(hinsu): Use signed integers instead of signless integer which is being
// used for legacy reasons.
def HLO_SInt : SignlessIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[4, 8, 16, 32, 64]>;
def HLO_SInt : SignlessIntOfWidths<[2, 4, 8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[2, 4, 8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

def HLO_Float : AnyTypeOf<[F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2,
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(1, 1, 7); }
static Version getCurrentVersion() { return Version(1, 2, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
Expand Down
20 changes: 20 additions & 0 deletions stablehlo/dialect/VhloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ enum TypeCode {
/// }
kIndexV1Type = 9,

/// IntegerSI2V1Type {
/// }
kIntegerSI2V1Type = 31,

/// IntegerSI4V1Type {
/// }
kIntegerSI4V1Type = 10,
Expand All @@ -250,6 +254,10 @@ enum TypeCode {
/// }
kIntegerSI64V1Type = 14,

/// IntegerUI2V1Type {
/// }
kIntegerUI2V1Type = 32,

/// IntegerUI4V1Type {
/// }
kIntegerUI4V1Type = 15,
Expand Down Expand Up @@ -959,6 +967,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return readFunctionV1Type(reader);
case vhlo_encoding::kIndexV1Type:
return IndexV1Type::get(getContext());
case vhlo_encoding::kIntegerSI2V1Type:
return IntegerSI2V1Type::get(getContext());
case vhlo_encoding::kIntegerSI4V1Type:
return IntegerSI4V1Type::get(getContext());
case vhlo_encoding::kIntegerSI8V1Type:
Expand All @@ -969,6 +979,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return IntegerSI32V1Type::get(getContext());
case vhlo_encoding::kIntegerSI64V1Type:
return IntegerSI64V1Type::get(getContext());
case vhlo_encoding::kIntegerUI2V1Type:
return IntegerUI2V1Type::get(getContext());
case vhlo_encoding::kIntegerUI4V1Type:
return IntegerUI4V1Type::get(getContext());
case vhlo_encoding::kIntegerUI8V1Type:
Expand Down Expand Up @@ -1059,6 +1071,10 @@ LogicalResult VhloBytecodeInterface::writeType(
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kIndexV1Type), success();
})
.Case([&](IntegerSI2V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kIntegerSI2V1Type), success();
})
.Case([&](IntegerSI4V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kIntegerSI4V1Type), success();
Expand All @@ -1079,6 +1095,10 @@ LogicalResult VhloBytecodeInterface::writeType(
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kIntegerSI64V1Type), success();
})
.Case([&](IntegerUI2V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kIntegerUI2V1Type), success();
})
.Case([&](IntegerUI4V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kIntegerUI4V1Type), success();
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def VHLO_Dialect : Dialect {
0.19.0: Introduce `composite` operation.
0.20.0: Remove `padding` attribute from `dynamic_conv`.
1.0.0: Increase compatibility guarantees to 5 years backward, 2 years forward (no functional changes relative to 0.20.0).
1.2.0: Introduce `si2` and `ui2` types.
}];

let useDefaultAttributePrinterParser = 0;
Expand Down
9 changes: 9 additions & 0 deletions stablehlo/dialect/VhloTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ Type convertBuiltinIntegerType(IntegerType type) {
bool isSignless = type.isSignless();
auto ctx = type.getContext();
switch (type.getWidth()) {
case 2:
return isSignless ? cast<Type>(IntegerSI2V1Type::get(ctx))
: cast<Type>(IntegerUI2V1Type::get(ctx));
case 4:
return isSignless ? cast<Type>(IntegerSI4V1Type::get(ctx))
: cast<Type>(IntegerUI4V1Type::get(ctx));
Expand Down Expand Up @@ -183,6 +186,9 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
});
addConversion(
[&](IndexV1Type type) { return IndexType::get(type.getContext()); });
addConversion([&](IntegerSI2V1Type type) {
return IntegerType::get(type.getContext(), 2);
});
addConversion([&](IntegerSI4V1Type type) {
return IntegerType::get(type.getContext(), 4);
});
Expand All @@ -198,6 +204,9 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
addConversion([&](IntegerSI64V1Type type) {
return IntegerType::get(type.getContext(), 64);
});
addConversion([&](IntegerUI2V1Type type) {
return IntegerType::get(type.getContext(), 2, IntegerType::Unsigned);
});
addConversion([&](IntegerUI4V1Type type) {
return IntegerType::get(type.getContext(), 4, IntegerType::Unsigned);
});
Expand Down
6 changes: 6 additions & 0 deletions stablehlo/dialect/VhloTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def VHLO_FunctionV1 : VHLO_TypeDef<"FunctionV1", "func_v1", "0.9.0", "current">
// dynamism RFC.
def VHLO_IndexV1 : VHLO_TypeDef<"IndexV1", "index_v1", "0.9.0", "current">;

// Corresponds to the 'si2' FloatType from the StableHLO spec.
def VHLO_IntegerSI2V1 : VHLO_TypeDef<"IntegerSI2V1", "i2_v1", "1.2.0", "current">;

// Corresponds to the 'si4' FloatType from the StableHLO spec.
def VHLO_IntegerSI4V1 : VHLO_TypeDef<"IntegerSI4V1", "i4_v1", "0.9.0", "current">;

Expand All @@ -134,6 +137,9 @@ def VHLO_IntegerSI32V1 : VHLO_TypeDef<"IntegerSI32V1", "i32_v1", "0.9.0", "curre
// Corresponds to the 'si64' FloatType from the StableHLO spec.
def VHLO_IntegerSI64V1 : VHLO_TypeDef<"IntegerSI64V1", "i64_v1", "0.9.0", "current">;

// Corresponds to the 'ui2' FloatType from the StableHLO spec.
def VHLO_IntegerUI2V1 : VHLO_TypeDef<"IntegerUI2V1", "ui2_v1", "1.2.0", "current">;

// Corresponds to the 'ui4' FloatType from the StableHLO spec.
def VHLO_IntegerUI4V1 : VHLO_TypeDef<"IntegerUI4V1", "ui4_v1", "0.9.0", "current">;

Expand Down
18 changes: 12 additions & 6 deletions stablehlo/reference/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ Element Tensor::get(const Index &index) const {
if (isSupportedIntegerType(elementType)) {
IntegerType intTy = cast<IntegerType>(elementType);

if (elementType.isSignlessInteger(4) || elementType.isSignlessInteger(8)) {
if (elementType.isSignlessInteger(2) || elementType.isSignlessInteger(4) ||
elementType.isSignlessInteger(8)) {
auto elementData = reinterpret_cast<const int8_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
Expand All @@ -174,7 +175,8 @@ Element Tensor::get(const Index &index) const {
auto elementData = reinterpret_cast<const int64_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
intTy.isSignedInteger()));
} else if (elementType.isUnsignedInteger(4) ||
} else if (elementType.isUnsignedInteger(2) ||
elementType.isUnsignedInteger(4) ||
elementType.isUnsignedInteger(8)) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APInt(intTy.getWidth(), *elementData,
Expand Down Expand Up @@ -270,7 +272,8 @@ void Tensor::set(const Index &index, const Element &element) {
// integers which was added in MHLO for legacy reasons. Going forward,
// StableHLO will adopt signfull integer semantics with signed and unsigned
// integer variants.
if (elementType.isSignlessInteger(4) || elementType.isSignlessInteger(8)) {
if (elementType.isSignlessInteger(2) || elementType.isSignlessInteger(4) ||
elementType.isSignlessInteger(8)) {
auto elementData = reinterpret_cast<int8_t *>(elementPtr);
auto value = element.getIntegerValue();
*elementData = (int8_t)value.getSExtValue();
Expand Down Expand Up @@ -299,7 +302,8 @@ void Tensor::set(const Index &index, const Element &element) {
}

// Handle unsigned integer types.
if (elementType.isUnsignedInteger(4) || elementType.isUnsignedInteger(8)) {
if (elementType.isUnsignedInteger(2) || elementType.isUnsignedInteger(4) ||
elementType.isUnsignedInteger(8)) {
auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
auto value = element.getIntegerValue();
*elementData = (uint8_t)value.getZExtValue();
Expand Down Expand Up @@ -428,7 +432,8 @@ Tensor makeTensor(DenseElementsAttr attr) {
}

// Handle signed integer types.
if (elementType.isSignlessInteger(4) || elementType.isSignlessInteger(8)) {
if (elementType.isSignlessInteger(2) || elementType.isSignlessInteger(4) ||
elementType.isSignlessInteger(8)) {
auto intValues = llvm::map_to_vector(
attr.getValues<APInt>(),
[&](APInt value) -> int8_t { return value.getSExtValue(); });
Expand Down Expand Up @@ -461,7 +466,8 @@ Tensor makeTensor(DenseElementsAttr attr) {
}

// Handle unsigned integer types.
if (elementType.isUnsignedInteger(4) || elementType.isUnsignedInteger(8)) {
if (elementType.isUnsignedInteger(2) || elementType.isUnsignedInteger(4) ||
elementType.isUnsignedInteger(8)) {
auto intValues = llvm::map_to_vector(
attr.getValues<APInt>(),
[&](APInt value) -> uint8_t { return value.getZExtValue(); });
Expand Down
12 changes: 6 additions & 6 deletions stablehlo/reference/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ namespace mlir {
namespace stablehlo {

bool isSupportedUnsignedIntegerType(Type type) {
return type.isUnsignedInteger(4) || type.isUnsignedInteger(8) ||
type.isUnsignedInteger(16) || type.isUnsignedInteger(32) ||
type.isUnsignedInteger(64);
return type.isUnsignedInteger(2) || type.isUnsignedInteger(4) ||
type.isUnsignedInteger(8) || type.isUnsignedInteger(16) ||
type.isUnsignedInteger(32) || type.isUnsignedInteger(64);
}

bool isSupportedSignedIntegerType(Type type) {
// TODO(#22): StableHLO, as bootstrapped from MHLO, inherits signless
// integers which was added in MHLO for legacy reasons. Going forward,
// StableHLO will adopt signfull integer semantics with signed and unsigned
// integer variants.
return type.isSignlessInteger(4) || type.isSignlessInteger(8) ||
type.isSignlessInteger(16) || type.isSignlessInteger(32) ||
type.isSignlessInteger(64);
return type.isSignlessInteger(2) || type.isSignlessInteger(4) ||
type.isSignlessInteger(8) || type.isSignlessInteger(16) ||
type.isSignlessInteger(32) || type.isSignlessInteger(64);
}

bool isSupportedBooleanType(Type type) { return type.isSignlessInteger(1); }
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/tests/CheckOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ llvm::Error evalExpectCloseOp(const Tensor &actual, const Tensor &expected,
lhsIt != actual.index_end(); ++lhsIt, ++rhsIt) {
auto e1 = actual.get(*lhsIt);
auto e2 = expected.get(*rhsIt);
uint64_t ulp_diff = ULPDifference(e1, e2);
size_t ulp_diff = ULPDifference(e1, e2);
if (ulp_diff > max_ulp_difference || ulp_diff < min_ulp_difference) {
output << "\n index=" << (*lhsIt) << ", actual=" << e1
<< ", expected=" << e2 << ", ULP difference=" << ulp_diff;
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/tests/interpret/constant.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

func.func @constant_op_test_si2() {
%0 = stablehlo.constant dense<[-2, -1, 0, 1]> : tensor<4xi2>
check.expect_eq_const %0, dense<[-2, -1, 0, 1]> : tensor<4xi2>
func.return
}

// -----

func.func @constant_op_test_ui2() {
%0 = stablehlo.constant dense<[0, 1, 2, 3]> : tensor<4xui2>
check.expect_eq_const %0, dense<[0, 1, 2, 3]> : tensor<4xui2>
func.return
}

// -----

func.func @constant_op_test_si4() {
%0 = stablehlo.constant dense<[-8, -1, 0, 1, 7]> : tensor<5xi4>
check.expect_eq_const %0, dense<[-8, -1, 0, 1, 7]> : tensor<5xi4>
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2525,6 +2525,14 @@ func.func @type_i1(%arg0: tensor<i1>, %arg1: tensor<i1>) -> tensor<i1> {
func.return %0 : tensor<i1>
}

// CHECK-LABEL: "type_i2"
// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
func.func @type_i2(%arg0: tensor<i2>, %arg1: tensor<i2>) -> tensor<i2> {
// CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.i2_v1>, !vhlo.tensor_v1<!vhlo.i2_v1>) -> !vhlo.tensor_v1<!vhlo.i2_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i2>, tensor<i2>) -> tensor<i2>
func.return %0 : tensor<i2>
}

// CHECK-LABEL: "type_i4"
// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
func.func @type_i4(%arg0: tensor<i4>, %arg1: tensor<i4>) -> tensor<i4> {
Expand Down Expand Up @@ -2565,6 +2573,14 @@ func.func @type_i64(%arg0: tensor<i64>, %arg1: tensor<i64>) -> tensor<i64> {
func.return %0 : tensor<i64>
}

// CHECK-LABEL: "type_ui2"
// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
func.func @type_ui2(%arg0: tensor<ui2>, %arg1: tensor<ui2>) -> tensor<ui2> {
// CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<!vhlo.ui2_v1>, !vhlo.tensor_v1<!vhlo.ui2_v1>) -> !vhlo.tensor_v1<!vhlo.ui2_v1>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<ui2>, tensor<ui2>) -> tensor<ui2>
func.return %0 : tensor<ui2>
}

// CHECK-LABEL: "type_ui4"
// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}})
func.func @type_ui4(%arg0: tensor<ui4>, %arg1: tensor<ui4>) -> tensor<ui4> {
Expand Down

0 comments on commit 05e6a35

Please sign in to comment.