Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Legalize quantized stablehlo operation using uniform_quantize/uniform_dequantize #2394

Merged
merged 10 commits into from
Jun 18, 2024
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,7 @@ cc_library(
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
"stablehlo/transforms/StablehloLegalizeQuantToInt.cpp",
"stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp",
"stablehlo/transforms/StablehloLegalizeToVhlo.cpp",
"stablehlo/transforms/StablehloRefineArguments.cpp",
"stablehlo/transforms/StablehloRefineShapes.cpp",
Expand Down
25 changes: 25 additions & 0 deletions docs/generated/stablehlo_passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,31 @@ _Convert from StableHLO quantized ops to StableHLO primitive ops._

Convert StableHLO programs using UniformQuantized types to semantically
equivalent integer math.
### `-stablehlo-legalize-quantized-op-to-qdq`

_Decompose StableHLO quantized ops using uniform quantize/dequantize ops._

Decompose StableHLO quantized programs using uniform quantize/dequantize
operations. For example, the following program

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```

Will become:

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32, 1.000000e+00>>, %arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>> {
%0 = stablehlo.uniform_dequantize %arg0 : (tensor<!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<f32>
%1 = stablehlo.uniform_dequantize %arg1 : (tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<f32>
%2 = stablehlo.add %0, %1 : tensor<f32>
%3 = stablehlo.uniform_quantize %2 : (tensor<f32>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
return %3 : tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
}
```
### `-stablehlo-legalize-to-vhlo`

_Legalize StableHLO to VHLO._
Expand Down
294 changes: 285 additions & 9 deletions stablehlo/tests/ops_stablehlo_quantized.mlir

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions stablehlo/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ add_mlir_dialect_library(StablehloPasses
StablehloLegalizeCompositeToCall.cpp
StablehloLegalizeDeprecatedOps.cpp
StablehloLegalizeQuantToInt.cpp
StablehloLegalizeQuantizedOpToQDQ.cpp
StablehloLegalizeToVhlo.cpp
StablehloRefineArguments.cpp
StablehloRefineShapes.cpp
Expand Down
5 changes: 5 additions & 0 deletions stablehlo/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns,
MLIRContext *context,
bool foldFloat);

/// Collection of rewrite patterns for lowering quantized StableHLO operations
/// using uniform dequantize/quantize operations.
///
/// `patterns` is the set the rewrite patterns are added to; `context` is the
/// MLIRContext handed to each pattern when it is constructed.
void populateStablehloLegalizeQuantizedOpToQDQPatterns(
RewritePatternSet *patterns, MLIRContext *context);

/// A subset of folding patterns for StableHLO that is necessary for shape
/// refinement.
void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns,
Expand Down
31 changes: 31 additions & 0 deletions stablehlo/transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,34 @@ def StablehloLegalizeQuantToIntPass : Pass<"stablehlo-legalize-quant-to-int", "m
"mlir::stablehlo::StablehloDialect",
];
}

// Declares the `stablehlo-legalize-quantized-op-to-qdq` pass. The pass is
// declared on `mlir::func::FuncOp` and lists the StableHLO dialect as its only
// dependent dialect.
// NOTE(review): the example in the description rewrites `stablehlo.add`;
// confirm that AddOp is actually covered by the patterns this pass registers.
def StablehloLegalizeQuantizedOpToQDQPass : Pass<"stablehlo-legalize-quantized-op-to-qdq", "mlir::func::FuncOp"> {
let summary = "Decompose StableHLO quantized ops using uniform quantize/dequantize ops.";

let description = [{
Decompose StableHLO quantized programs using uniform quantize/dequantize
operations. For example, the following program

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32,1.0:0>>, %arg1: tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>> {
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<!quant.uniform<i8:f32,1.0:0>>, tensor<!quant.uniform<i8:f32,2.0:1>>) -> tensor<!quant.uniform<i8:f32,3.0:2>>
func.return %0 : tensor<!quant.uniform<i8:f32,3.0:2>>
}
```

Will become:

```mlir
func.func @add(%arg0: tensor<!quant.uniform<i8:f32, 1.000000e+00>>, %arg1: tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>> {
%0 = stablehlo.uniform_dequantize %arg0 : (tensor<!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<f32>
%1 = stablehlo.uniform_dequantize %arg1 : (tensor<!quant.uniform<i8:f32, 2.000000e+00:1>>) -> tensor<f32>
%2 = stablehlo.add %0, %1 : tensor<f32>
%3 = stablehlo.uniform_quantize %2 : (tensor<f32>) -> tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
return %3 : tensor<!quant.uniform<i8:f32, 3.000000e+00:2>>
}
```
}];
let dependentDialects = [
"mlir::stablehlo::StablehloDialect",
];
}
135 changes: 135 additions & 0 deletions stablehlo/transforms/StablehloLegalizeQuantizedOpToQDQ.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/* Copyright 2024 The StableHLO Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h" // Include for TypeConverter
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/Passes.h"

namespace mlir {
namespace stablehlo {

#define GEN_PASS_DEF_STABLEHLOLEGALIZEQUANTIZEDOPTOQDQPASS
#include "stablehlo/transforms/Passes.h.inc"

namespace {

/// Returns true when at least one entry of `types` is a quantized type, or a
/// shaped type whose element type is quantized.
bool isAnyQuantizedTypes(TypeRange types) {
  for (Type candidate : types) {
    if (isa<quant::QuantizedType>(getElementTypeOrSelf(candidate))) {
      return true;
    }
  }
  return false;
}

/// Rewrites a `StablehloOpType` that touches quantized types into a
/// dequantize -> op -> quantize sequence:
///   1. every quantized operand is passed through `UniformDequantizeOp`,
///   2. the op is re-created from the new operands and the original attributes,
///   3. every result whose original type was quantized is passed through
///      `UniformQuantizeOp` back to that original type.
/// Ops with no quantized operand or result are left untouched.
template <typename StablehloOpType>
struct QuantizedStablehloOpConversion
    : public OpRewritePattern<StablehloOpType> {
  using OpRewritePattern<StablehloOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(StablehloOpType op,
                                PatternRewriter& rewriter) const override {
    Operation* quantizedOp = op.getOperation();

    // Bail out unless some operand or result is quantized.
    if (!(isAnyQuantizedTypes(quantizedOp->getOperandTypes()) ||
          isAnyQuantizedTypes(quantizedOp->getResultTypes()))) {
      return failure();
    }

    auto hasQuantizedElementType = [](Type type) {
      return isa<quant::QuantizedType>(getElementTypeOrSelf(type));
    };
    Location loc = quantizedOp->getLoc();

    // Step 1: dequantize the quantized operands, keep the others as they are.
    SmallVector<Value> plainOperands;
    for (Value input : quantizedOp->getOperands()) {
      if (!hasQuantizedElementType(input.getType())) {
        plainOperands.push_back(input);
        continue;
      }
      Value dequantized = rewriter.create<UniformDequantizeOp>(loc, input);
      plainOperands.push_back(dequantized);
    }

    // Step 2: re-create the same op on the new operands, carrying over the
    // attributes of the original op.
    Operation* plainOp = rewriter
                             .create<StablehloOpType>(loc, plainOperands,
                                                      quantizedOp->getAttrs())
                             .getOperation();

    // Step 3: quantize back every result that was originally quantized.
    SmallVector<Value> replacements;
    for (auto [quantizedResult, plainResult] :
         llvm::zip(quantizedOp->getResults(), plainOp->getResults())) {
      Type targetType = quantizedResult.getType();
      if (hasQuantizedElementType(targetType)) {
        Value requantized = rewriter.create<stablehlo::UniformQuantizeOp>(
            loc, targetType, plainResult);
        replacements.push_back(requantized);
      } else {
        replacements.push_back(plainResult);
      }
    }

    rewriter.replaceOp(op, replacements);
    return success();
  }
};

/// Pass that applies the quantized-op-to-QDQ rewrite patterns to the operation
/// it runs on.
class StablehloLegalizeQuantizedOpToQDQPass
    : public impl::StablehloLegalizeQuantizedOpToQDQPassBase<
          StablehloLegalizeQuantizedOpToQDQPass> {
 public:
  /// Populates the rewrite patterns and stores them, frozen, in `patterns` so
  /// that runOnOperation() does not have to rebuild them.
  LogicalResult initialize(MLIRContext* context) override {
    RewritePatternSet patterns_(context);
    populateStablehloLegalizeQuantizedOpToQDQPatterns(&patterns_, context);
    patterns = std::move(patterns_);
    return success();
  }

  /// Applies the stored patterns greedily. If the greedy driver reports
  /// failure, an error is emitted on the operation and the pass is marked as
  /// failed.
  void runOnOperation() override {
    auto func = getOperation();
    if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
      func.emitError(
          "Failed to converge StablehloLegalizeQuantizedOpToQDQ in ")
          << config.maxIterations << " iterations";
      // Pair the emitted error with a failed pass so the pass manager does not
      // report success for a pipeline that produced an error diagnostic.
      signalPassFailure();
    }
  }

 private:
  // Patterns built in initialize() and reused by runOnOperation().
  FrozenRewritePatternSet patterns;
  // Default-constructed greedy driver configuration; its maxIterations value
  // is reported in the non-convergence error.
  GreedyRewriteConfig config;
};

/// Adds one `QuantizedStablehloOpConversion` pattern to `patterns` for every
/// op type in the `OpTypes` pack, each constructed with `context`.
template <typename... OpTypes>
void populateStablehloLegalizeQuantizedOpToQDQPatterns(
    RewritePatternSet* patterns, MLIRContext* context) {
  RewritePatternSet& patternSet = *patterns;
  patternSet.add<QuantizedStablehloOpConversion<OpTypes>...>(context);
}

} // namespace

/// Adds to `patterns` one dequantize -> op -> quantize rewrite pattern for
/// each StableHLO op type listed below, constructing every pattern with
/// `context`.
void populateStablehloLegalizeQuantizedOpToQDQPatterns(
    RewritePatternSet* patterns, MLIRContext* context) {
  // The following list covers most of the operations which, according to the
  // StableHLO specification document, interpret the quantized operation using
  // the dequant-op-quant strategy. The ones excluded are AddOp, ConvolutionOp,
  // DotGeneralOp, and DynamicConvOp, which currently use the
  // `stablehlo-legalize-quant-to-int` pass for decomposition to primitive math
  // operations.
  populateStablehloLegalizeQuantizedOpToQDQPatterns<
      AbsOp, Atan2Op, BatchNormGradOp, BatchNormInferenceOp,
      BatchNormTrainingOp, CbrtOp, CeilOp, CholeskyOp, ClampOp, CompareOp,
      CosineOp, DivOp, Expm1Op, ExpOp, FloorOp, Log1pOp, LogisticOp, LogOp,
      MaxOp, MinOp, MulOp, NegOp, PowOp, ReducePrecisionOp, RemOp, RoundOp,
      RoundNearestEvenOp, RsqrtOp, SelectOp, SignOp, SineOp, SqrtOp, SubtractOp,
      TanhOp, TriangularSolveOp>(patterns, context);
}

} // namespace stablehlo
} // namespace mlir
Loading