[SDY] Preprocess sharding groups which have shardings prior to propagation to validate there are no inter-group conflicts. #107
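Summary: sharding group members that already carry a sharding before propagation (for example via sharding constraints or entry-point attributes) are now validated for consistency during import, and the group's common sharding is applied to every member after group unification. A minimal repro of the kind of module this now rejects, adapted from the new tests in this PR (the mesh and group id are illustrative):

sdy.mesh @mesh = <["a"=2, "b"=2]>

func.func @main(
    %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
    %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}, {}]>}) {
  // %arg0 and %arg1 carry different pre-propagation shardings, so grouping
  // them together is a conflict that import now reports.
  sdy.sharding_group %arg0 group_id = 555 : tensor<8x8xf32>
  sdy.sharding_group %arg1 group_id = 555 : tensor<8x8xf32>
  func.return
}

Because the check runs after unification and canonicalization of group ids, the emitted error refers to the canonicalized group id (0 here) rather than 555.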

3 changes: 3 additions & 0 deletions shardy/dialect/sdy/transforms/import/import_pipeline.cc
@@ -36,6 +36,9 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) {
pm.addNestedPass<func::FuncOp>(createConstantSplitterPass());
pm.addNestedPass<func::FuncOp>(createAddDataFlowEdgesPass());
pm.addNestedPass<func::FuncOp>(createApplyShardingConstraintsPass());
// The sharding group import pass must run after applying sharding
// constraints. This ensures we can detect sharding conflicts between group
// members which have pre-propagation shardings due to sharding constraints.
pm.addPass(createShardingGroupImportPass());
pm.addPass(createImportMaximalShardingPass());

81 changes: 77 additions & 4 deletions shardy/dialect/sdy/transforms/import/sharding_group_import.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"

namespace mlir {
namespace sdy {
@@ -41,8 +42,11 @@ using llvm::SmallVector;

using ValueToShardingGroup =
llvm::DenseMap<Value, llvm::SmallVector<ShardingGroupOp>>;
using GroupIdToValues = llvm::DenseMap<int64_t, SmallVector<Value>>;
using TensorShape = ArrayRef<int64_t>;

void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) {
void unifyShardingGroups(ValueToShardingGroup& tensorToGroups,
GroupIdToValues& groupIdToReindexedTensors) {
if (tensorToGroups.empty()) {
return;
}
@@ -75,6 +79,7 @@ void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) {
for (ShardingGroupOp op : groupsForTensor) {
op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue(
op.getGroupId())]);
groupIdToReindexedTensors[op.getGroupId()].push_back(op.getInput());
}
}
}
@@ -83,14 +88,15 @@ LogicalResult buildShardingGroupMappingAndValidateGroups(
ModuleOp module, ValueToShardingGroup& tensorToGroups) {
// Map to hold validation info for shard groups within manual computations.
DenseMap<int64_t, ManualComputationOp> groupToManualComp;
DenseMap<int64_t, TensorShape> groupToTensorShape;

// While walking the graph we simultaneously build up the tensorToGroups
// mapping (which will be used for unification) while also validating that
// the structure of sharding groups meets expectations.
WalkResult result = module.walk([&](ShardingGroupOp op) {
tensorToGroups[op.getInput()].push_back(op);

// Validate sharding groups. All values in a group should have either:
// All values in a sharding group should have either:
// 1) No manual computation op parent
// 2) The same manual computation op parent.
// If a group has no manual computation op parent, 'groupToManualComp'
@@ -108,11 +114,68 @@ LogicalResult buildShardingGroupMappingAndValidateGroups(
return WalkResult::interrupt();
}

// All values in a sharding group should have the same shape. It is possible
// to relax this constraint to only require that ranks match (when not in
// conservative mode). However, GSPMD requires tensor shapes to be
// equivalent, so we maintain this stricter requirement for parity.
TensorShape shape = getTensorShape(op.getInput());
auto [shapeIt, shapeInserted] = groupToTensorShape.try_emplace(groupId, shape);
if (!shapeInserted && shapeIt->getSecond() != shape) {
op.emitError(
"ShardingGroupOps values must have the same shape for groupId: ")
<< groupId;
return WalkResult::interrupt();
}

return WalkResult::advance();
});
return failure(result.wasInterrupted());
}

LogicalResult validateCompatibilityAndApplyInitialShardingConstraints(
ModuleOp module, GroupIdToValues& groupIdToValues) {
DenseMap<int64_t, TensorShardingAttr> groupIdToSharding;
// Sharding constraints can only conflict with sharding groups when the
// constrained value is a member of some sharding group, so it is sufficient
// to validate the consistency of shardings of values in ShardingGroupOps.
WalkResult result = module.walk([&](ShardingGroupOp shardingGroupOp) {
TensorShardingAttr sharding = getSharding(shardingGroupOp.getInput());
// Conflicts can only occur when two or more values in a group have
// shardings and those shardings differ; values without a sharding never
// conflict.
if (!sharding) {
return WalkResult::advance();
}

int64_t groupId = shardingGroupOp.getGroupId();
auto [it, inserted] = groupIdToSharding.try_emplace(groupId, sharding);
if (!inserted && it->second != sharding) {
shardingGroupOp.emitError(
"Inconsistent shardings prior to propagation for ShardingGroupOps "
"with canonicalized groupId: ")
<< groupId;
return WalkResult::interrupt();
}
return WalkResult::advance();
});

if (result.wasInterrupted()) {
return failure();
}

// Apply each group's initial sharding to all values in the group.
for (auto& [groupId, sharding] : groupIdToSharding) {
if (!sharding) {
continue;
}
for (Value value : groupIdToValues[groupId]) {
setSharding(value, sharding);
}
}

return success();
}

struct ShardingGroupImportPass
: public impl::ShardingGroupImportPassBase<ShardingGroupImportPass> {
using ShardingGroupImportPassBase::ShardingGroupImportPassBase;
@@ -121,12 +184,22 @@ struct ShardingGroupImportPass
// Extract the sharding group ids and the tensor -> {group_id} mapping from
// the high-level module and validate that all sharding group constraints
// are met.
ValueToShardingGroup tensorToGroups;
if (failed(buildShardingGroupMappingAndValidateGroups(getOperation(),
ModuleOp module = getOperation();
if (failed(buildShardingGroupMappingAndValidateGroups(module,
tensorToGroups))) {
signalPassFailure();
}

unifyShardingGroups(tensorToGroups);
GroupIdToValues groupIdToReindexedTensors;
unifyShardingGroups(tensorToGroups, groupIdToReindexedTensors);

// This pass assumes sharding constraints have already been applied to
// values. Compatibility is validated after group unification so that
// conflicts within the unified groups are detected.
if (failed(validateCompatibilityAndApplyInitialShardingConstraints(
module, groupIdToReindexedTensors))) {
signalPassFailure();
}
}
};

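Before the test diffs below, a compact illustrative sketch of the end-to-end behavior on a well-formed module (the mesh and group ids are illustrative; the exact CHECK lines are in the tests that follow). Groups 10 and 20 share %0, so unification merges them, and %arg0's initial sharding is applied to every member of the unified group:

sdy.mesh @mesh = <["a"=2]>

func.func @main(
    %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
    %arg1: tensor<8x8xf32>) {
  %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
  sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32>
  sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32>
  sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32>
  sdy.sharding_group %arg1 group_id = 20 : tensor<8x8xf32>
  func.return
}

After the pass, %arg1 and %0 also carry #sdy.sharding<@mesh, [{"a"}, {}]>; since %arg0 is the only member with an initial sharding, there is no conflict to report.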
@@ -168,3 +168,67 @@ func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
sdy.sharding_group %0 group_id = 7331 : tensor<8x8xf32>
func.return %0: tensor<8x8xf32>
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// Disallow creation of sharding groups which have values with different shapes.
func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
%0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
%1 = stablehlo.constant dense<0.0> : tensor<8x8x1xf32>
sdy.sharding_group %arg0 group_id = 23 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 23 : tensor<8x8xf32>
// expected-error@below {{ShardingGroupOps values must have the same shape for groupId: 23}}
sdy.sharding_group %1 group_id = 23 : tensor<8x8x1xf32>
func.return %0: tensor<8x8xf32>
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// Throw error for sharding groups which have incompatible shardings inferred
// from initial constraints.
func.func @main(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
%arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}, {}]>}) {
// Sharding group and sharding constraint compatibility checks happen after
// unification + canonicalization of group ids, which is why the group id
// below (555) corresponds to group id 0 in the expected error.
sdy.sharding_group %arg0 group_id = 555 : tensor<8x8xf32>
// expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}}
sdy.sharding_group %arg1 group_id = 555 : tensor<8x8xf32>
func.return
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// Throw error for sharding groups which have incompatible shardings inferred
// from initial constraints.
func.func @main(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
%arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) {

%0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
%1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>

sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32>
sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32>

// The sharding group below causes the above sharding groups to be merged;
// by transitivity this implies that all of {%arg0, %arg1, %0, %1} should
// have the same sharding. Note that %0 and %1 are compatible by themselves,
// but %arg0 and %arg1 are not, due to their initial shardings.
sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32>
// expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}}
sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32>
func.return
}

@@ -79,3 +79,69 @@ func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0:
sdy.sharding_group %arg2 group_id = 123456 : tensor<4xf32>
func.return
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// CHECK-LABEL: set_existing_shardings_for_sharding_group_members
func.func @set_existing_shardings_for_sharding_group_members(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
%arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}) {
// CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
%0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>

sdy.sharding_group %arg0 group_id = 43210 : tensor<8x8xf32>
sdy.sharding_group %arg1 group_id = 43210 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 43210 : tensor<8x8xf32>
func.return
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// CHECK-LABEL: transitively_update_shardings_for_sharding_group_members
func.func @transitively_update_shardings_for_sharding_group_members(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
%arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) {
// CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
// CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
%0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
%1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>

sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32>
sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32>
sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32>
sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32>
func.return
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// CHECK-LABEL: set_existing_shards_for_disjoint_groups
func.func @set_existing_shards_for_disjoint_groups(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
%arg1: tensor<8x8xf32>,
%arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>},
%arg3: tensor<8x8xf32>) {
// CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
%0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
// CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
%1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
// CHECK: %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8xf32>
%2 = stablehlo.constant dense<0.0> : tensor<8x8xf32>

sdy.sharding_group %arg0 group_id = 11111 : tensor<8x8xf32>
sdy.sharding_group %arg1 group_id = 11111 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 11111 : tensor<8x8xf32>

sdy.sharding_group %arg2 group_id = 22222 : tensor<8x8xf32>
sdy.sharding_group %arg3 group_id = 22222 : tensor<8x8xf32>
sdy.sharding_group %1 group_id = 22222 : tensor<8x8xf32>
func.return
}