From 68de1a72c76b715dce7d6748b2f5d7939fc82a37 Mon Sep 17 00:00:00 2001
From: Bart Chrzaszcz
Date: Wed, 22 Jan 2025 09:16:28 -0800
Subject: [PATCH] #sdy Add Shardy translate mesh pass.

For a module with a top level mesh symbol, this allows you to translate it and any uses of it in shardings to the new axis names.

Currently this is only to support translating mesh axis names, not sizes.

PiperOrigin-RevId: 718412595
---
 shardy/dialect/sdy/transforms/import/BUILD    |   1 +
 .../dialect/sdy/transforms/import/passes.td   |  19 +++
 .../import/test/translate_mesh.mlir           |  15 ++
 .../sdy/transforms/import/translate_mesh.cc   | 133 ++++++++++++++++++
 4 files changed, 168 insertions(+)
 create mode 100644 shardy/dialect/sdy/transforms/import/test/translate_mesh.mlir
 create mode 100644 shardy/dialect/sdy/transforms/import/translate_mesh.cc

diff --git a/shardy/dialect/sdy/transforms/import/BUILD b/shardy/dialect/sdy/transforms/import/BUILD
index f5028c29..8b1a7459 100644
--- a/shardy/dialect/sdy/transforms/import/BUILD
+++ b/shardy/dialect/sdy/transforms/import/BUILD
@@ -48,6 +48,7 @@ cc_library(
         "lift_inlined_meshes.cc",
         "manual_axes_cleanup.cc",
         "sharding_group_import.cc",
+        "translate_mesh.cc",
     ],
     hdrs = [
         "passes.h",
diff --git a/shardy/dialect/sdy/transforms/import/passes.td b/shardy/dialect/sdy/transforms/import/passes.td
index c2ed7f5e..ddfedc12 100644
--- a/shardy/dialect/sdy/transforms/import/passes.td
+++ b/shardy/dialect/sdy/transforms/import/passes.td
@@ -153,3 +153,22 @@ def ManualAxesCleanupPass : Pass<"sdy-manual-axes-cleanup", "ModuleOp"> {
   }];
   let dependentDialects = ["mlir::sdy::SdyDialect"];
 }
+
+def TranslateMeshPass : Pass<"sdy-translate-mesh", "ModuleOp"> {
+  let summary = "Renames the axes of a top level mesh and all uses of them.";
+  let description = [{
+    Translates the top level mesh `old-mesh-name` to use the axis names in
+    `axis-names`, given in the same order as the axes of the old mesh. Any use
+    of an old axis name in a sharding is renamed to the new axis name. Only
+    the axis names are translated: the axis sizes are left unchanged.
+  }];
+  let dependentDialects = ["mlir::sdy::SdyDialect"];
+
+  let options = [
+    ListOption<"axisNames", "axis-names", "std::string",
+               "Names of the new axes.">,
+    Option<"oldMeshName", "old-mesh-name", "std::string",
+           /*default=*/"",
+           "The name of the old mesh to be replaced.">
+  ];
+}
diff --git a/shardy/dialect/sdy/transforms/import/test/translate_mesh.mlir b/shardy/dialect/sdy/transforms/import/test/translate_mesh.mlir
new file mode 100644
index 00000000..4839da09
--- /dev/null
+++ b/shardy/dialect/sdy/transforms/import/test/translate_mesh.mlir
@@ -0,0 +1,15 @@
+// RUN: sdy_opt %s -sdy-translate-mesh="old-mesh-name=my_mesh axis-names='data,model'" 2>&1 | FileCheck %s
+
+// CHECK-LABEL: @my_mesh
+// CHECK-SAME{LITERAL}: <["data"=2, "model"=4]>
+sdy.mesh @my_mesh = <["a"=2, "b"=4]>
+
+// CHECK-NOT: <["a"=2, "b"=4]>
+
+// CHECK-LABEL: @foo
+func.func @foo(%arg0 : tensor<8x8xf32>) -> tensor<8x8xf32> {
+  // CHECK-NEXT: stablehlo.add
+  // CHECK-SAME{LITERAL}: #sdy.sharding_per_value<[<@my_mesh, [{"data", ?}p1, {}], replicated={"model"}>]>
+  %0 = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@my_mesh, [{"a", ?}p1, {}], replicated={"b"}>]>} : tensor<8x8xf32>
+  return %0 : tensor<8x8xf32>
+}
diff --git a/shardy/dialect/sdy/transforms/import/translate_mesh.cc b/shardy/dialect/sdy/transforms/import/translate_mesh.cc
new file mode 100644
index 00000000..25de9d41
--- /dev/null
+++ b/shardy/dialect/sdy/transforms/import/translate_mesh.cc
@@ -0,0 +1,133 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iterator>
+#include <memory>  // IWYU pragma: keep
+#include <string>
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "shardy/dialect/sdy/ir/dialect.h"
+#include "shardy/dialect/sdy/transforms/common/sharding_walker.h"
+#include "shardy/dialect/sdy/transforms/import/passes.h"  // IWYU pragma: keep
+
+namespace mlir {
+namespace sdy {
+
+#define GEN_PASS_DEF_TRANSLATEMESHPASS
+#include "shardy/dialect/sdy/transforms/import/passes.h.inc"
+
+namespace {
+
+// Renames the axes of top-level mesh `oldMeshName` to `newMeshAxisNames`
+// (matched by position; sizes and device ids are kept) and rewrites the axis
+// names in every sharding that references that mesh by symbol. Emits an error
+// and fails if the mesh is missing or the number of axes differs.
+LogicalResult translateMesh(ModuleOp moduleOp, StringRef oldMeshName,
+                            ArrayRef<std::string> newMeshAxisNames) {
+  MLIRContext* context = moduleOp.getContext();
+  auto meshOp = SymbolTable::lookupNearestSymbolFrom<MeshOp>(
+      moduleOp, SymbolRefAttr::get(context, oldMeshName));
+  if (!meshOp) {
+    return moduleOp.emitError()
+           << "Mesh " << oldMeshName << " not found in module.";
+  }
+  MeshAttr oldMesh = meshOp.getMesh();
+  if (oldMesh.getAxes().size() != newMeshAxisNames.size()) {
+    return moduleOp.emitError()
+           << "Both meshes must have the same number of axes.";
+  }
+  // Build the old -> new name map and the renamed axes in a single pass.
+  llvm::StringMap<StringRef> oldToNewAxis;
+  SmallVector<MeshAxisAttr> newAxes;
+  newAxes.reserve(newMeshAxisNames.size());
+  bool sameMesh = true;
+  for (auto [oldAxis, newAxisName] :
+       llvm::zip_equal(oldMesh.getAxes(), newMeshAxisNames)) {
+    oldToNewAxis[oldAxis.getName()] = newAxisName;
+    newAxes.push_back(
+        MeshAxisAttr::get(context, newAxisName, oldAxis.getSize()));
+    sameMesh = sameMesh && oldAxis.getName() == newAxisName;
+  }
+  // Exit early if no axis is actually renamed.
+  if (sameMesh) {
+    return success();
+  }
+  // Maps each axis ref to its new name, keeping the sub-axis info. `lookup`
+  // is used instead of `operator[]` so the map is never mutated here.
+  auto translateAxes = [&](ArrayRef<AxisRefAttr> oldAxisRefs) {
+    SmallVector<AxisRefAttr> newAxisRefs;
+    newAxisRefs.reserve(oldAxisRefs.size());
+    llvm::transform(oldAxisRefs, std::back_inserter(newAxisRefs),
+                    [&](AxisRefAttr oldAxisRef) {
+                      return AxisRefAttr::get(
+                          context, oldToNewAxis.lookup(oldAxisRef.getName()),
+                          oldAxisRef.getSubAxisInfo());
+                    });
+    return newAxisRefs;
+  };
+  StringAttr meshName = StringAttr::get(context, oldMeshName);
+  // NOTE(review): only `TensorShardingAttr`s are rewritten; confirm whether
+  // axis names stored elsewhere (e.g. manual axes) also need translating.
+  transformShardings(
+      moduleOp, [&](TensorShardingAttr oldSharding) -> TensorShardingAttr {
+        // Shardings bound to another mesh (or to an inlined mesh) must be
+        // left alone: their axis names are unrelated to `oldMeshName`.
+        auto meshRef = dyn_cast<FlatSymbolRefAttr>(oldSharding.getMeshOrRef());
+        if (!meshRef || meshRef.getValue() != oldMeshName) {
+          return oldSharding;
+        }
+        SmallVector<DimensionShardingAttr> newDimShardings;
+        for (DimensionShardingAttr oldDimSharding :
+             oldSharding.getDimShardings()) {
+          newDimShardings.push_back(DimensionShardingAttr::get(
+              context, translateAxes(oldDimSharding.getAxes()),
+              oldDimSharding.getIsClosed(), oldDimSharding.getPriority()));
+        }
+        return TensorShardingAttr::get(
+            context, meshName, newDimShardings,
+            translateAxes(oldSharding.getReplicatedAxes()));
+      });
+  // Rename the axes in place on the existing mesh op, which keeps its symbol.
+  meshOp.setMeshAttr(MeshAttr::get(context, newAxes, oldMesh.getDeviceIds()));
+  return success();
+}
+
+struct TranslateMeshPass
+    : public impl::TranslateMeshPassBase<TranslateMeshPass> {
+  using TranslateMeshPassBase::TranslateMeshPassBase;
+
+  void runOnOperation() final {
+    if (translateMesh(getOperation(), oldMeshName, llvm::to_vector(axisNames))
+            .failed()) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+}  // namespace sdy
+}  // namespace mlir