Skip to content

Commit

Permalink
test structured bindings
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725127187
  • Loading branch information
tomnatan30 authored and copybara-github committed Feb 10, 2025
1 parent 07ea88d commit 02de9f2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
10 changes: 3 additions & 7 deletions shardy/dialect/sdy/transforms/propagation/basic_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -614,11 +614,7 @@ LogicalResult BasicPropagationPassImpl::propagate(
conservativePropagation, shardingGroupMap);
// We only need a single iteration (and another to confirm convergence), since
// we make sure ops whose sharding changes are added back to the worklist.
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
config.fold = false;
config.cseConstants = false;
GreedyRewriteConfig config{.useTopDownTraversal = true};
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns), config))) {
// We should always converge in 2 iterations, if we don't, something is
// wrong.
Expand All @@ -628,8 +624,8 @@ LogicalResult BasicPropagationPassImpl::propagate(
return failure();
}

// Pushes any shardings from the values returned in the terminator of the body
// of `funcOp` to the corresponding `funcOp` result type attrs.
// Pushes any shardings from tha values returned in the terminator of the body
// of `funcOp` to the coresponding `funcOp` result type attrs.
if (failed(propagateFuncResults(moduleOp, symbolTable, factorPropagation,
shardingGroupMap))) {
return failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ void saveShardingOriginsOnModule(

func::FuncOp funcOp = getEnclosingOfType<func::FuncOp>(owningOp);

// TODO(bartchr): Swap the map to store `ValueOrFuncResult` to avoid having
// TODO(bartchr): Swop the map to sture `ValueOrFuncResult` to avoid having
// to do this terminator finding logic just to set the func result attr.
OpOperand* terminatorOperand = getTerminatorOperand(value, funcOp);

Expand All @@ -317,10 +317,12 @@ void saveShardingOriginsOnModule(
builder.getDictionaryAttr(entries));
}
TypeSwitch<Operation*, void>(owningOp)
.Case<func::FuncOp>([&, value = value](func::FuncOp funcOp) {
funcOp.setArgAttr(cast<BlockArgument>(value).getArgNumber(),
kShardingOriginsAttr,
builder.getDictionaryAttr(entries));
.Case<func::FuncOp>([&](func::FuncOp funcOp) {
if (value) {
funcOp.setArgAttr(cast<BlockArgument>(value).getArgNumber(),
kShardingOriginsAttr,
builder.getDictionaryAttr(entries));
}
})
.Case<ShardingConstraintOp, DataFlowEdgeOp>([&](Operation* op) {
op->setAttr(kShardingOriginsAttr, builder.getDictionaryAttr(entries));
Expand Down

0 comments on commit 02de9f2

Please sign in to comment.