From 82bec193ccae54891e61a035ee41c000066da2b8 Mon Sep 17 00:00:00 2001
From: Bill Varcho
Date: Tue, 17 Sep 2024 17:37:39 -0700
Subject: [PATCH] [SDY] preprocess sharding groups which have shardings prior
 to propagation to validate there are no inter-group conflicts.

PiperOrigin-RevId: 675771176
---
 .../sdy/transforms/import/import_pipeline.cc  |  3 +
 .../import/sharding_group_import.cc           | 81 ++++++++++++++++++-
 ...r.mlir => sharding_group_constraints.mlir} | 64 +++++++++++++++
 .../import/test/sharding_group_import.mlir    | 66 +++++++++++++++
 4 files changed, 210 insertions(+), 4 deletions(-)
 rename shardy/dialect/sdy/transforms/import/test/{sharding_group_manual_computation_barrier.mlir => sharding_group_constraints.mlir} (73%)

diff --git a/shardy/dialect/sdy/transforms/import/import_pipeline.cc b/shardy/dialect/sdy/transforms/import/import_pipeline.cc
index 5ff4027..7c89a85 100644
--- a/shardy/dialect/sdy/transforms/import/import_pipeline.cc
+++ b/shardy/dialect/sdy/transforms/import/import_pipeline.cc
@@ -36,6 +36,9 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) {
   pm.addNestedPass<func::FuncOp>(createConstantSplitterPass());
   pm.addNestedPass<func::FuncOp>(createAddDataFlowEdgesPass());
   pm.addNestedPass<func::FuncOp>(createApplyShardingConstraintsPass());
+  // The sharding group import pass must run after applying sharding
+  // constraints. This ensures we can detect sharding conflicts between group
+  // members which have pre-propagation shardings due to sharding constraints.
   pm.addPass(createShardingGroupImportPass());
   pm.addPass(createImportMaximalShardingPass());
 
diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc
index 8b4f2d1..2679ced 100644
--- a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc
+++ b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "mlir/Pass/Pass.h"  // IWYU pragma: keep
 #include "mlir/Support/LLVM.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
+#include "shardy/dialect/sdy/ir/utils.h"
 
 namespace mlir {
 namespace sdy {
@@ -41,8 +42,11 @@ using llvm::SmallVector;
 
 using ValueToShardingGroup =
     llvm::DenseMap<Value, SmallVector<ShardingGroupOp>>;
+using GroupIdToValues = llvm::DenseMap<int64_t, SmallVector<Value>>;
+using TensorShape = ArrayRef<int64_t>;
 
-void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) {
+void unifyShardingGroups(ValueToShardingGroup& tensorToGroups,
+                         GroupIdToValues& groupIdToReindexedTensors) {
   if (tensorToGroups.empty()) {
     return;
   }
@@ -75,6 +79,7 @@ void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) {
     for (ShardingGroupOp op : groupsForTensor) {
       op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue(
           op.getGroupId())]);
+      groupIdToReindexedTensors[op.getGroupId()].push_back(op.getInput());
     }
   }
 }
@@ -83,6 +88,7 @@ LogicalResult buildShardingGroupMappingAndValidateGroups(
     ModuleOp module, ValueToShardingGroup& tensorToGroups) {
   // Map to hold validation info for shard groups within manual computations.
   DenseMap<int64_t, ManualComputationOp> groupToManualComp;
+  DenseMap<int64_t, TensorShape> groupToTensorShape;
 
   // While walking the graph we simultaneously build up the tensorToGroups
   // mapping (which will be used for unification) while also validating the
@@ -90,7 +96,7 @@ LogicalResult buildShardingGroupMappingAndValidateGroups(
   WalkResult result = module.walk([&](ShardingGroupOp op) {
     tensorToGroups[op.getInput()].push_back(op);
 
-    // Validate sharding groups. All values in a group should have either:
+    // All values in a sharding group should have either:
     // 1) No manual computation op parent
     // 2) The same manual computation op parent.
     // If a group has no manual computation op parent, 'groupToManualComp'
@@ -108,11 +114,68 @@ LogicalResult buildShardingGroupMappingAndValidateGroups(
       return WalkResult::interrupt();
     }
 
+    // All values in a sharding group should have the same shape. It is
+    // possible to relax this constraint to only require equal ranks (if we
+    // are not in conservative mode). However, GSPMD requires tensor shapes to
+    // be equivalent, so we will maintain this stricter requirement for parity.
+    TensorShape ts = getTensorShape(op.getInput());
+    auto [ts_it, ts_inserted] = groupToTensorShape.try_emplace(groupId, ts);
+    if (!ts_inserted && ts_it->getSecond() != ts) {
+      op.emitError(
+          "ShardingGroupOps values must have the same shape for groupId: ")
+          << groupId;
+      return WalkResult::interrupt();
+    }
+
     return WalkResult::advance();
   });
   return failure(result.wasInterrupted());
 }
 
+LogicalResult validateCompatibilityAndApplyInitialShardingConstraints(
+    ModuleOp module, GroupIdToValues& groupIdToValues) {
+  DenseMap<int64_t, TensorShardingAttr> groupIdToSharding;
+  // Sharding Constraints will only conflict with Sharding Groups if their value
+  // is a member of some sharding group. Because of this it is sufficient to
+  // only validate consistency of shardings of values in ShardingGroupOps.
+  WalkResult result = module.walk([&](ShardingGroupOp shardingGroupOp) {
+    TensorShardingAttr sharding = getSharding(shardingGroupOp.getInput());
+    // Conflicts only occur when there are two or more Values in a group
+    // which have a sharding and those shardings are different. If there is no
+    // sharding, then there will be no conflict.
+    if (!sharding) {
+      return WalkResult::advance();
+    }
+
+    int64_t groupId = shardingGroupOp.getGroupId();
+    auto [it, inserted] = groupIdToSharding.try_emplace(groupId, sharding);
+    if (!inserted && it->second != sharding) {
+      shardingGroupOp.emitError(
+          "Inconsistent shardings prior to propagation for ShardingGroupOps "
+          "with canonicalized groupId: ")
+          << groupId;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+
+  if (result.wasInterrupted()) {
+    return failure();
+  }
+
+  // Apply initial shardings to all values in the group.
+  for (auto& [groupId, sharding] : groupIdToSharding) {
+    if (!sharding) {
+      continue;
+    }
+    for (Value value : groupIdToValues[groupId]) {
+      setSharding(value, sharding);
+    }
+  }
+
+  return success();
+}
+
 struct ShardingGroupImportPass
     : public impl::ShardingGroupImportPassBase<ShardingGroupImportPass> {
   using ShardingGroupImportPassBase::ShardingGroupImportPassBase;
@@ -121,12 +184,22 @@ struct ShardingGroupImportPass
     // Extract the sharding group ids and tensor -> {group_id} mapping from the
     // high level module and validate any sharding group constraints are met.
     ValueToShardingGroup tensorToGroups;
-    if (failed(buildShardingGroupMappingAndValidateGroups(getOperation(),
+    ModuleOp module = getOperation();
+    if (failed(buildShardingGroupMappingAndValidateGroups(module,
                                                           tensorToGroups))) {
       signalPassFailure();
     }
 
-    unifyShardingGroups(tensorToGroups);
+    GroupIdToValues groupIdToReindexedTensors;
+    unifyShardingGroups(tensorToGroups, groupIdToReindexedTensors);
+
+    // This pass assumes sharding constraints are already applied to values.
+    // Compatibility constraints are applied after group unification to detect
+    // conflicts within the unified groups.
+    if (failed(validateCompatibilityAndApplyInitialShardingConstraints(
+            module, groupIdToReindexedTensors))) {
+      signalPassFailure();
+    }
   }
 };
 
diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_manual_computation_barrier.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir
similarity index 73%
rename from shardy/dialect/sdy/transforms/import/test/sharding_group_manual_computation_barrier.mlir
rename to shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir
index 4c386fc..1cf9f18 100644
--- a/shardy/dialect/sdy/transforms/import/test/sharding_group_manual_computation_barrier.mlir
+++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir
@@ -168,3 +168,67 @@ func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
   sdy.sharding_group %0 group_id = 7331 : tensor<8x8xf32>
   func.return %0: tensor<8x8xf32>
 }
+
+// -----
+
+sdy.mesh @mesh = <["a"=2, "b"=2]>
+
+// Disallow creation of sharding groups which have values with different shapes.
+func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
+  %1 = stablehlo.constant dense<0.0> : tensor<8x8x1xf32>
+  sdy.sharding_group %arg0 group_id = 23 : tensor<8x8xf32>
+  sdy.sharding_group %0 group_id = 23 : tensor<8x8xf32>
+  // expected-error@below {{ShardingGroupOps values must have the same shape for groupId: 23}}
+  sdy.sharding_group %1 group_id = 23 : tensor<8x8x1xf32>
+  func.return %0: tensor<8x8xf32>
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2, "b"=2]>
+
+// Throw error for sharding groups which have incompatible shardings inferred
+// from initial constraints.
+func.func @main(
+    %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
+    %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}, {}]>}) {
+  // %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}
+  // %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}, {}]>}
+  // Sharding Group and Sharding Constraint compatibility checks happen after
+  // unification + canonicalization of group ids, which is why the group id
+  // below (555) corresponds to group id 0 in the expected error.
+  sdy.sharding_group %arg0 group_id = 555 : tensor<8x8xf32>
+  // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}}
+  sdy.sharding_group %arg1 group_id = 555 : tensor<8x8xf32>
+  func.return
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2, "b"=2]>
+
+// Throw error for sharding groups which have incompatible shardings inferred
+// from initial constraints.
+func.func @main(
+    %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
+    %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) {
+
+  %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
+  %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
+
+  sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32>
+  sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32>
+  sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32>
+  sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32>
+
+  // The sharding group below will cause the above sharding groups to be merged;
+  // by transitivity this implies that all of {%arg0, %arg1, %0, %1} should have
+  // the same sharding. Note that %0 and %1 are compatible by themselves, but
+  // %arg0 and %arg1 are not due to their initial shardings.
+  sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32>
+  // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}}
+  sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32>
+  func.return
+}
+
diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir
index 7cd8589..1d05bd1 100644
--- a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir
+++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir
@@ -79,3 +79,69 @@ func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0:
   sdy.sharding_group %arg2 group_id = 123456 : tensor<4xf32>
   func.return
 }
+
+// -----
+
+sdy.mesh @mesh = <["a"=2, "b"=2]>
+
+// CHECK-LABEL: set_existing_shardings_for_sharding_group_members
+func.func @set_existing_shardings_for_sharding_group_members(
+    %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
+    %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}) {
+  // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
+  %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
+
+  sdy.sharding_group %arg0 group_id = 43210 : tensor<8x8xf32>
+  sdy.sharding_group %arg1 group_id = 43210 : tensor<8x8xf32>
+  sdy.sharding_group %0 group_id = 43210 : tensor<8x8xf32>
+  func.return
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2, "b"=2]>
+
+// CHECK-LABEL: transitively_update_shardings_for_sharding_group_members
+func.func @transitively_update_shardings_for_sharding_group_members(
+    %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
+    %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) {
+  // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
+  // CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
+  %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
+  %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
+
+  sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32>
+  sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32>
+  sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32>
+  sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32>
+  sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32>
+  sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32>
+  func.return
+}
+
+// -----
+
+sdy.mesh @mesh = <["a"=2, "b"=2]>
+
+// CHECK-LABEL: set_existing_shards_for_disjoint_groups
+func.func @set_existing_shards_for_disjoint_groups(
+    %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
+    %arg1: tensor<8x8xf32>,
+    %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>},
+    %arg3: tensor<8x8xf32>) {
+  // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
+  %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
+  // CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
+  %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
+  // CHECK: %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8xf32>
+  %2 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
+
+  sdy.sharding_group %arg0 group_id = 11111 : tensor<8x8xf32>
+  sdy.sharding_group %arg1 group_id = 11111 : tensor<8x8xf32>
+  sdy.sharding_group %0 group_id = 11111 : tensor<8x8xf32>
+
+  sdy.sharding_group %arg2 group_id = 22222 : tensor<8x8xf32>
+  sdy.sharding_group %arg3 group_id = 22222 : tensor<8x8xf32>
+  sdy.sharding_group %1 group_id = 22222 : tensor<8x8xf32>
+  func.return
+}
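
Illustrative sketch (not part of the patch): the expected-error checks in
sharding_group_constraints.mlir report a canonicalized groupId of 0 even though
the tests declare group ids 555, 10, 20, and 30. That is because
unifyShardingGroups() unions groups that share a member value and then
reindexes each equivalence class to a contiguous id starting at 0 before the
new compatibility validation runs. The standalone C++ snippet below mirrors
that union-find + reindex step; the concrete ids and the printing are
assumptions made for this example, not code from the pass.

#include <cstdint>
#include <cstdio>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"

int main() {
  // Groups 10 and 20 share %0, and groups 20 and 30 share %1 (as in the last
  // constraints test), so all three ids land in one equivalence class.
  llvm::EquivalenceClasses<int64_t> equivalences;
  equivalences.unionSets(10, 20);
  equivalences.unionSets(20, 30);

  // Reindex each class leader to a contiguous canonical id, in the spirit of
  // the reindexMap built by unifyShardingGroups().
  llvm::DenseMap<int64_t, int64_t> reindexMap;
  int64_t nextId = 0;
  for (int64_t groupId : {10, 20, 30}) {
    int64_t leader = equivalences.getLeaderValue(groupId);
    auto [it, inserted] = reindexMap.try_emplace(leader, nextId);
    if (inserted) ++nextId;
    // Prints "group 10 -> canonical 0", "group 20 -> canonical 0", ...
    std::printf("group %lld -> canonical %lld\n",
                static_cast<long long>(groupId),
                static_cast<long long>(it->second));
  }
  return 0;
}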