diff --git a/shardy/dialect/sdy/ir/utils.cc b/shardy/dialect/sdy/ir/utils.cc index 14b54471..5197b3f3 100644 --- a/shardy/dialect/sdy/ir/utils.cc +++ b/shardy/dialect/sdy/ir/utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "llvm/ADT/STLExtras.h" @@ -336,5 +337,35 @@ void removeShardingRules(Operation* rootOp) { }); } +std::optional getCommonMeshName( + ArrayRef operandShardings, + ArrayRef resultsShardings) { + StringRef meshName; + for (TensorShardingAttr sharding : llvm::concat( + operandShardings, resultsShardings)) { + if (sharding) { + if (meshName.empty()) { + meshName = sharding.getMeshName(); + } else if (meshName != sharding.getMeshName()) { + // Found more than one mesh name. + return std::nullopt; + } + } + } + return meshName.empty() ? std::nullopt : std::make_optional(meshName); +} + +ManualAxisToOwner getParentManualComputationOps(Operation* op) { + ManualAxisToOwner alreadyManualAxes; + auto parent = op->getParentOfType(); + while (parent) { + for (StringRef axisName : parent.getManualAxes()) { + alreadyManualAxes[axisName] = parent; + } + parent = parent->getParentOfType(); + } + return alreadyManualAxes; +} + } // namespace sdy } // namespace mlir diff --git a/shardy/dialect/sdy/ir/utils.h b/shardy/dialect/sdy/ir/utils.h index b4337b6d..4fbb89a6 100644 --- a/shardy/dialect/sdy/ir/utils.h +++ b/shardy/dialect/sdy/ir/utils.h @@ -17,9 +17,11 @@ limitations under the License. #define SHARDY_DIALECT_SDY_IR_UTILS_H_ #include +#include #include #include +#include "llvm/ADT/DenseMap.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Threading.h" @@ -260,6 +262,22 @@ void cloneRegionAndConvertTerminatorOp(Region& src, Region& dst) { cloneRegionAndConvertTerminatorOp(src, dst, rewriter); } +// Returns the common mesh name used by all the `TensorShardingAttr` or +// std::nullopt if there is none. +std::optional getCommonMeshName( + ArrayRef operandShardings, + ArrayRef resultsShardings); + +// Mapping between an axis name to the `ManualComputationOp` whose body is +// manual on. +using ManualAxisToOwner = llvm::SmallDenseMap; + +// Creates a mapping from axis name to the corresponding `ManualComputationOp`. +// +// ManualComputations op are allowed to be nested within each other, and this +// gives what manual axis was introduced by what `ManualComputationOp`. +ManualAxisToOwner getParentManualComputationOps(Operation* op); + } // namespace sdy } // namespace mlir diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc index b48c03a5..4f90b64a 100644 --- a/shardy/dialect/sdy/ir/verifiers.cc +++ b/shardy/dialect/sdy/ir/verifiers.cc @@ -55,10 +55,6 @@ using func::FuncOp; using ::llvm::SmallDenseMap; using ::llvm::SmallDenseSet; -// Mapping between an axis name to the ManualComputationOp whose body is manual -// on. -using ManualAxisToOwner = SmallDenseMap; - using EmitErrorFn = std::function; EmitErrorFn getEmitErrorFn(Operation* op) { @@ -174,22 +170,6 @@ LogicalResult verifySubAxes(ArrayRef subAxes, StringRef axisName, return success(); } -// ManualComputations op are allowed to be nested within each other. However, -// they cannot operate on the same manual axes. This function creates a mapping -// from a manual mesh axis name to the corresponding ManualComputationOp that -// operates on it to help with verifying this is the case. -ManualAxisToOwner getParentManualComputationOps(Operation* op) { - ManualAxisToOwner alreadyManualAxes; - auto parent = op->getParentOfType(); - while (parent) { - for (StringRef axisName : parent.getManualAxes()) { - alreadyManualAxes[axisName] = parent; - } - parent = parent->getParentOfType(); - } - return alreadyManualAxes; -} - LogicalResult emitBoundAxisInManualComputationError(EmitErrorFn emitError, StringRef boundAxis, Location parentLoc) { diff --git a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc index 6f0e8196..11d28675 100644 --- a/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/basic_propagation.cc @@ -209,27 +209,6 @@ void updateTensorShardings( mesh, notifyOpModified); } -// Returns the common mesh name used by all the `TensorShardingAttr` or -// std::nullopt if there is none. -std::optional getCommonMeshName( - ArrayRef operandShardings, - ArrayRef resultsShardings) { - StringRef meshName; - for (TensorShardingAttr sharding : llvm::concat( - operandShardings, resultsShardings)) { - if (sharding) { - if (meshName.empty()) { - meshName = sharding.getMeshName(); - } else if (meshName != sharding.getMeshName()) { - // Found more than one mesh name. - return std::nullopt; - } - } - } - - return meshName.empty() ? std::nullopt : std::make_optional(meshName); -} - // Propagates tensor shardings of the given `operands` and `results` according // to `shardingRule`. //