From b762b16d1680e5ab5cb039f927828a22dd542762 Mon Sep 17 00:00:00 2001 From: Seher Ellis Date: Fri, 14 Feb 2025 13:06:04 -0800 Subject: [PATCH] [XLA:SchedulingAnnotations] Uniquify annotation ids for unrolled scheduling groups. PiperOrigin-RevId: 727036617 --- xla/service/BUILD | 21 +++++--- xla/service/scheduling_annotations_util.cc | 63 ++++++++++++++++++++++ xla/service/scheduling_annotations_util.h | 42 +++++++++++++++ xla/service/while_loop_unroller.cc | 30 +++++++++-- xla/service/while_loop_unroller_test.cc | 38 +++++++++---- 5 files changed, 172 insertions(+), 22 deletions(-) create mode 100644 xla/service/scheduling_annotations_util.cc create mode 100644 xla/service/scheduling_annotations_util.h diff --git a/xla/service/BUILD b/xla/service/BUILD index c21d5f0e8ec88..9f08aad2e7161 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -2930,21 +2930,20 @@ cc_library( ":call_inliner", ":collective_ops_utils", ":constant_value", - ":hlo_buffer", ":hlo_creation_utils", ":hlo_cse", - ":hlo_value", ":pattern_matcher", + ":scheduling_annotations_util", ":value_range", ":while_loop_constant_sinking", "//xla:comparison_util", "//xla:literal", "//xla:literal_util", "//xla:shape_util", + "//xla:side_effect_util", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", @@ -2961,8 +2960,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:statusor", ], ) @@ -2970,6 +2967,7 @@ xla_cc_test( name = "while_loop_unroller_test", srcs = ["while_loop_unroller_test.cc"], deps = [ + ":scheduling_annotations_util", ":while_loop_unroller", "//xla:literal", "//xla/hlo/ir:hlo", @@ -2978,11 +2976,11 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", 
"@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@tsl//tsl/platform:statusor", ], ) @@ -6565,4 +6563,15 @@ xla_cc_test( ], ) +cc_library( + name = "scheduling_annotations_util", + srcs = ["scheduling_annotations_util.cc"], + hdrs = ["scheduling_annotations_util.h"], + deps = [ + "//xla:side_effect_util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/strings", + ], +) + exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/xla/service/scheduling_annotations_util.cc b/xla/service/scheduling_annotations_util.cc new file mode 100644 index 0000000000000..f7c8479aa466b --- /dev/null +++ b/xla/service/scheduling_annotations_util.cc @@ -0,0 +1,63 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "xla/service/scheduling_annotations_util.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <optional>
+#include <string>
+
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/side_effect_util.h"
+
+namespace xla {
+
+// Returns the integer scheduling group id of `instruction`, if it has one.
+std::optional<int64_t> GetSchedulingAnnotation(
+    const HloInstruction* instruction) {
+  const auto& attrs = instruction->frontend_attributes().map();
+  const auto it = attrs.find(kXlaSchedulingGroupIdAttr);
+  int64_t annotation_id;
+  if (it == attrs.end() || !absl::SimpleAtoi(it->second, &annotation_id)) {
+    return std::nullopt;
+  }
+  return annotation_id;
+}
+
+// Writes `id` as the scheduling group attribute. operator[] inserts the key
+// when missing; find()->second would dereference end() in that case.
+void SetSchedulingAnnotation(HloInstruction* instruction, int64_t id) {
+  FrontendAttributes fas = instruction->frontend_attributes();
+  (*fas.mutable_map())[kXlaSchedulingGroupIdAttr] = absl::StrCat(id);
+  instruction->set_frontend_attributes(fas);
+}
+
+// Returns max(1, largest scheduling annotation id in `module` + 1).
+int64_t NextSchedulingId(const HloModule& module) {
+  int64_t next_scheduling_id = 1;
+  for (const HloComputation* comp : module.computations()) {
+    for (const HloInstruction* hlo : comp->instructions()) {
+      if (std::optional<int64_t> id = GetSchedulingAnnotation(hlo)) {
+        next_scheduling_id = std::max(next_scheduling_id, *id + 1);
+      }
+    }
+  }
+  return next_scheduling_id;
+}
+
+}  // namespace xla
diff --git a/xla/service/scheduling_annotations_util.h b/xla/service/scheduling_annotations_util.h
new file mode 100644
index 0000000000000..76a36e2828fa9
--- /dev/null
+++ b/xla/service/scheduling_annotations_util.h
@@ -0,0 +1,42 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_SCHEDULING_ANNOTATIONS_UTIL_H_
+#define XLA_SERVICE_SCHEDULING_ANNOTATIONS_UTIL_H_
+
+#include <cstdint>
+#include <optional>
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+
+namespace xla {
+
+// Returns the scheduling annotation id for the given instruction. If the
+// instruction does not have a scheduling annotation, or the annotation is not
+// an integer, returns std::nullopt.
+std::optional<int64_t> GetSchedulingAnnotation(
+    const HloInstruction* instruction);
+
+// Sets the scheduling annotation id for the given instruction.
+void SetSchedulingAnnotation(HloInstruction* instruction, int64_t id);
+
+// Returns the next available scheduling id for `module`: one more than the
+// largest scheduling id found in it, and never less than 1.
+int64_t NextSchedulingId(const HloModule& module);
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_SCHEDULING_ANNOTATIONS_UTIL_H_
diff --git a/xla/service/while_loop_unroller.cc b/xla/service/while_loop_unroller.cc
index 66fb50483b918..6716a968c904f 100644
--- a/xla/service/while_loop_unroller.cc
+++ b/xla/service/while_loop_unroller.cc
@@ -52,6 +52,7 @@ limitations under the License.
 #include "xla/service/hlo_creation_utils.h"
 #include "xla/service/hlo_cse.h"
 #include "xla/service/pattern_matcher.h"
+#include "xla/service/scheduling_annotations_util.h"
 #include "xla/service/value_range.h"
 #include "xla/service/while_loop_constant_sinking.h"
 #include "xla/shape.h"
@@ -59,8 +60,6 @@ limitations under the License.
#include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -186,7 +185,8 @@ absl::Status ReplaceInductionVarUses(HloComputation* body, absl::StatusOr> UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, WhileLoopConfig config, - const int64_t induction_value) { + const int64_t induction_value, + int64_t& next_scheduling_id) { // We clone the body since we are changing the computation. std::unique_ptr while_body_clone = while_op->while_body()->Clone( @@ -207,6 +207,7 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, induction_value_constant, config.induction_var_idx)); + absl::flat_hash_set seen_scheduling_ids; for (HloInstruction* body_inst : while_body_clone->instructions()) { // We need to assign a unique channel_id for the collective ops that are // unrolled within the while loop body or fusions containing collectives. @@ -217,6 +218,18 @@ UnrollSingleIterationOfTrivialLoop(HloInstruction* while_op, // channel_id across the module. collective->set_channel_id(unique_channel_id++); } + + // We need to assign a unique id to each scheduling group (of instructions) + // that are unrolled within the while loop body. + std::optional scheduling_id = GetSchedulingAnnotation(body_inst); + if (scheduling_id.has_value()) { + if (!seen_scheduling_ids.contains(scheduling_id.value())) { + seen_scheduling_ids.insert(scheduling_id.value()); + next_scheduling_id++; + } + SetSchedulingAnnotation(body_inst, next_scheduling_id); + } + // Handle DynamicGte and DynamicTuple custom-calls created during unstacking // pass. All custom-calls must be replaced for the loop to be unrolled // successfully. 
@@ -279,11 +292,15 @@ absl::StatusOr UnrollInternal(HloInstruction* while_op, HloComputation* computation = while_op->parent(); HloInstruction* unrolled_body_call_op; std::vector call_operands = {while_op->operands().at(0)}; + + int64_t next_scheduling_id = NextSchedulingId(*while_op->GetModule()); for (int64_t i = config.init; i < config.trip_count + config.init; ++i) { CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); HloComputation* unrolled_body = module->AddEmbeddedComputation( - UnrollSingleIterationOfTrivialLoop(while_op, config, i).value()); + UnrollSingleIterationOfTrivialLoop(while_op, config, i, + next_scheduling_id) + .value()); unrolled_body_call_op = computation->AddInstruction(HloInstruction::CreateCall( while_op->shape(), call_operands, unrolled_body)); @@ -318,11 +335,14 @@ absl::StatusOr UnrollInternalWrappedAndReturnReplacement( // We assume while has only one tuple parameter call_operands.emplace_back(std::move(p.value())); + int64_t next_scheduling_id = NextSchedulingId(*while_op->GetModule()); for (int64_t i = config.init; i < config.trip_count + config.init; ++i) { CHECK(OverflowSafeAdd(i, (int64_t)1).has_value()); HloComputation* unrolled_body = module->AddEmbeddedComputation( - UnrollSingleIterationOfTrivialLoop(while_op, config, i).value()); + UnrollSingleIterationOfTrivialLoop(while_op, config, i, + next_scheduling_id) + .value()); unrolled_body_call_op = body_builder.AddInstruction( HloInstruction::CreateCall(while_op->shape(), call_operands, diff --git a/xla/service/while_loop_unroller_test.cc b/xla/service/while_loop_unroller_test.cc index 8f9398239f1dc..6692220b59a58 100644 --- a/xla/service/while_loop_unroller_test.cc +++ b/xla/service/while_loop_unroller_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" @@ -33,9 +34,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" +#include "xla/service/scheduling_annotations_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -919,8 +920,8 @@ TEST_F(WhileLoopUnrollerTest, LoopWithCollective2) { get-tuple-element.29000 = s8[1,32,2048]{2,1,0:T(8,128)(4,1)S(1)} get-tuple-element(wide_param.41), index=0 get-tuple-element.29001 = s8[1,2048,4096]{2,1,0:T(8,128)(4,1)S(1)} get-tuple-element(wide_param.41), index=1 get-tuple-element.28990 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(wide_param.41), index=3 - collective-permute-start = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.28990), channel_id=18, 
source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11},{11,12},{12,13},{13,14},{14,15},{15,0},{16,17},{17,18},{18,19},{19,20},{20,21},{21,22},{22,23},{23,24},{24,25},{25,26},{26,27},{27,28},{28,29},{29,30},{30,31},{31,16},{32,33},{33,34},{34,35},{35,36},{36,37},{37,38},{38,39},{39,40},{40,41},{41,42},{42,43},{43,44},{44,45},{45,46},{46,47},{47,32},{48,49},{49,50},{50,51},{51,52},{52,53},{53,54},{54,55},{55,56},{56,57},{57,58},{58,59},{59,60},{60,61},{61,62},{62,63},{63,48},{64,65},{65,66},{66,67},{67,68},{68,69},{69,70},{70,71},{71,72},{72,73},{73,74},{74,75},{75,76},{76,77},{77,78},{78,79},{79,64},{80,81},{81,82},{82,83},{83,84},{84,85},{85,86},{86,87},{87,88},{88,89},{89,90},{90,91},{91,92},{92,93},{93,94},{94,95},{95,80},{96,97},{97,98},{98,99},{99,100},{100,101},{101,102},{102,103},{103,104},{104,105},{105,106},{106,107},{107,108},{108,109},{109,110},{110,111},{111,96},{112,113},{113,114},{114,115},{115,116},{116,117},{117,118},{118,119},{119,120},{120,121},{121,122},{122,123},{123,124},{124,125},{125,126},{126,127},{127,112},{128,129},{129,130},{130,131},{131,132},{132,133},{133,134},{134,135},{135,136},{136,137},{137,138},{138,139},{139,140},{140,141},{141,142},{142,143},{143,128},{144,145},{145,146},{146,147},{147,148},{148,149},{149,150},{150,151},{151,152},{152,153},{153,154},{154,155},{155,156},{156,157},{157,158},{158,159},{159,144},{160,161},{161,162},{162,163},{163,164},{164,165},{165,166},{166,167},{167,168},{168,169},{169,170},{170,171},{171,172},{172,173},{173,174},{174,175},{175,160},{176,177},{177,178},{178,179},{179,180},{180,181},{181,182},{182,183},{183,184},{184,185},{185,186},{186,187},{187,188},{188,189},{189,190},{190,191},{191,176},{192,193},{193,194},{194,195},{195,196},{196,197},{197,198},{198,199},{199,200},{200,201},{201,202},{202,203},{203,204},{204,205},{205,206},{206,207},{207,192},{208,209},{209,210},{210,211},{211,212},{212,213},{213,214},{214,215},{215,216},{216,217},{217,218},{218,219},{219,220}
,{220,221},{221,222},{222,223},{223,208},{224,225},{225,226},{226,227},{227,228},{228,229},{229,230},{230,231},{231,232},{232,233},{233,234},{234,235},{235,236},{236,237},{237,238},{238,239},{239,224},{240,241},{241,242},{242,243},{243,244},{244,245},{245,246},{246,247},{247,248},{248,249},{249,250},{250,251},{251,252},{252,253},{253,254},{254,255},{255,240}} - collective-permute-done = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start) + collective-permute-start = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.28990), channel_id=18, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11},{11,12},{12,13},{13,14},{14,15},{15,0},{16,17},{17,18},{18,19},{19,20},{20,21},{21,22},{22,23},{23,24},{24,25},{25,26},{26,27},{27,28},{28,29},{29,30},{30,31},{31,16},{32,33},{33,34},{34,35},{35,36},{36,37},{37,38},{38,39},{39,40},{40,41},{41,42},{42,43},{43,44},{44,45},{45,46},{46,47},{47,32},{48,49},{49,50},{50,51},{51,52},{52,53},{53,54},{54,55},{55,56},{56,57},{57,58},{58,59},{59,60},{60,61},{61,62},{62,63},{63,48},{64,65},{65,66},{66,67},{67,68},{68,69},{69,70},{70,71},{71,72},{72,73},{73,74},{74,75},{75,76},{76,77},{77,78},{78,79},{79,64},{80,81},{81,82},{82,83},{83,84},{84,85},{85,86},{86,87},{87,88},{88,89},{89,90},{90,91},{91,92},{92,93},{93,94},{94,95},{95,80},{96,97},{97,98},{98,99},{99,100},{100,101},{101,102},{102,103},{103,104},{104,105},{105,106},{106,107},{107,108},{108,109},{109,110},{110,111},{111,96},{112,113},{113,114},{114,115},{115,116},{116,117},{117,118},{118,119},{119,120},{120,121},{121,122},{122,123},{123,124},{124,125},{125,126},{126,127},{127,112},{128,129},{129,130},{130,131},{131,132},{132,133},{133,134},{134,135},{135,136},{136,137},{137,138},{138,139},{139,140},{140,141},{141,142},{142,143},{143,128},{144,145},{145,146},{146,147},{147,148},{148,149},{149,150},{150,151},{151,152},{152,153},{
153,154},{154,155},{155,156},{156,157},{157,158},{158,159},{159,144},{160,161},{161,162},{162,163},{163,164},{164,165},{165,166},{166,167},{167,168},{168,169},{169,170},{170,171},{171,172},{172,173},{173,174},{174,175},{175,160},{176,177},{177,178},{178,179},{179,180},{180,181},{181,182},{182,183},{183,184},{184,185},{185,186},{186,187},{187,188},{188,189},{189,190},{190,191},{191,176},{192,193},{193,194},{194,195},{195,196},{196,197},{197,198},{198,199},{199,200},{200,201},{201,202},{202,203},{203,204},{204,205},{205,206},{206,207},{207,192},{208,209},{209,210},{210,211},{211,212},{212,213},{213,214},{214,215},{215,216},{216,217},{217,218},{218,219},{219,220},{220,221},{221,222},{222,223},{223,208},{224,225},{225,226},{226,227},{227,228},{228,229},{229,230},{230,231},{231,232},{232,233},{233,234},{234,235},{235,236},{236,237},{237,238},{238,239},{239,224},{240,241},{241,242},{242,243},{243,244},{244,245},{245,246},{246,247},{247,248},{248,249},{249,250},{250,251},{251,252},{252,253},{253,254},{254,255},{255,240}}, frontend_attributes={_scheduling_group_id="0"} + collective-permute-done = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start), frontend_attributes={_scheduling_group_id="0"} get-tuple-element.29005 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=5 get-tuple-element.29006 = u32[256]{0:T(256)} get-tuple-element(wide_param.41), index=6 partition-id.101 = u32[] partition-id() @@ -944,14 +945,14 @@ TEST_F(WhileLoopUnrollerTest, LoopWithCollective2) { clamp.1713 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.401, get-tuple-element.29008) convert.8616 = s32[]{:T(128)S(6)} convert(clamp.1713) multiply.14831 = s32[]{:T(128)S(6)} multiply(convert.8616, get-tuple-element.29009) - fusion.4289 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} fusion(multiply.14830, bitcast.8823, multiply.14831, get-tuple-element.29000), kind=kOutput, calls=fused_computation.71.clone + fusion.4289 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} 
fusion(multiply.14830, bitcast.8823, multiply.14831, get-tuple-element.29000), kind=kOutput, calls=fused_computation.71.clone, frontend_attributes={_scheduling_group_id="0"} get-tuple-element.28989 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(wide_param.41), index=2 - collective-permute-start.1 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.28989), channel_id=17, source_target_pairs={{0,15},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10},{12,11},{13,12},{14,13},{15,14},{16,31},{17,16},{18,17},{19,18},{20,19},{21,20},{22,21},{23,22},{24,23},{25,24},{26,25},{27,26},{28,27},{29,28},{30,29},{31,30},{32,47},{33,32},{34,33},{35,34},{36,35},{37,36},{38,37},{39,38},{40,39},{41,40},{42,41},{43,42},{44,43},{45,44},{46,45},{47,46},{48,63},{49,48},{50,49},{51,50},{52,51},{53,52},{54,53},{55,54},{56,55},{57,56},{58,57},{59,58},{60,59},{61,60},{62,61},{63,62},{64,79},{65,64},{66,65},{67,66},{68,67},{69,68},{70,69},{71,70},{72,71},{73,72},{74,73},{75,74},{76,75},{77,76},{78,77},{79,78},{80,95},{81,80},{82,81},{83,82},{84,83},{85,84},{86,85},{87,86},{88,87},{89,88},{90,89},{91,90},{92,91},{93,92},{94,93},{95,94},{96,111},{97,96},{98,97},{99,98},{100,99},{101,100},{102,101},{103,102},{104,103},{105,104},{106,105},{107,106},{108,107},{109,108},{110,109},{111,110},{112,127},{113,112},{114,113},{115,114},{116,115},{117,116},{118,117},{119,118},{120,119},{121,120},{122,121},{123,122},{124,123},{125,124},{126,125},{127,126},{128,143},{129,128},{130,129},{131,130},{132,131},{133,132},{134,133},{135,134},{136,135},{137,136},{138,137},{139,138},{140,139},{141,140},{142,141},{143,142},{144,159},{145,144},{146,145},{147,146},{148,147},{149,148},{150,149},{151,150},{152,151},{153,152},{154,153},{155,154},{156,155},{157,156},{158,157},{159,158},{160,175},{161,160},{162,161},{163,162},{164,163},{165,164},{166,165},{167,166},{168,167},{169,168},{170,169},{171,170},{172,17
1},{173,172},{174,173},{175,174},{176,191},{177,176},{178,177},{179,178},{180,179},{181,180},{182,181},{183,182},{184,183},{185,184},{186,185},{187,186},{188,187},{189,188},{190,189},{191,190},{192,207},{193,192},{194,193},{195,194},{196,195},{197,196},{198,197},{199,198},{200,199},{201,200},{202,201},{203,202},{204,203},{205,204},{206,205},{207,206},{208,223},{209,208},{210,209},{211,210},{212,211},{213,212},{214,213},{215,214},{216,215},{217,216},{218,217},{219,218},{220,219},{221,220},{222,221},{223,222},{224,239},{225,224},{226,225},{227,226},{228,227},{229,228},{230,229},{231,230},{232,231},{233,232},{234,233},{235,234},{236,235},{237,236},{238,237},{239,238},{240,255},{241,240},{242,241},{243,242},{244,243},{245,244},{246,245},{247,246},{248,247},{249,248},{250,249},{251,250},{252,251},{253,252},{254,253},{255,254}} - collective-permute-done.1 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.1) + collective-permute-start.1 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.28989), channel_id=17, 
source_target_pairs={{0,15},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10},{12,11},{13,12},{14,13},{15,14},{16,31},{17,16},{18,17},{19,18},{20,19},{21,20},{22,21},{23,22},{24,23},{25,24},{26,25},{27,26},{28,27},{29,28},{30,29},{31,30},{32,47},{33,32},{34,33},{35,34},{36,35},{37,36},{38,37},{39,38},{40,39},{41,40},{42,41},{43,42},{44,43},{45,44},{46,45},{47,46},{48,63},{49,48},{50,49},{51,50},{52,51},{53,52},{54,53},{55,54},{56,55},{57,56},{58,57},{59,58},{60,59},{61,60},{62,61},{63,62},{64,79},{65,64},{66,65},{67,66},{68,67},{69,68},{70,69},{71,70},{72,71},{73,72},{74,73},{75,74},{76,75},{77,76},{78,77},{79,78},{80,95},{81,80},{82,81},{83,82},{84,83},{85,84},{86,85},{87,86},{88,87},{89,88},{90,89},{91,90},{92,91},{93,92},{94,93},{95,94},{96,111},{97,96},{98,97},{99,98},{100,99},{101,100},{102,101},{103,102},{104,103},{105,104},{106,105},{107,106},{108,107},{109,108},{110,109},{111,110},{112,127},{113,112},{114,113},{115,114},{116,115},{117,116},{118,117},{119,118},{120,119},{121,120},{122,121},{123,122},{124,123},{125,124},{126,125},{127,126},{128,143},{129,128},{130,129},{131,130},{132,131},{133,132},{134,133},{135,134},{136,135},{137,136},{138,137},{139,138},{140,139},{141,140},{142,141},{143,142},{144,159},{145,144},{146,145},{147,146},{148,147},{149,148},{150,149},{151,150},{152,151},{153,152},{154,153},{155,154},{156,155},{157,156},{158,157},{159,158},{160,175},{161,160},{162,161},{163,162},{164,163},{165,164},{166,165},{167,166},{168,167},{169,168},{170,169},{171,170},{172,171},{173,172},{174,173},{175,174},{176,191},{177,176},{178,177},{179,178},{180,179},{181,180},{182,181},{183,182},{184,183},{185,184},{186,185},{187,186},{188,187},{189,188},{190,189},{191,190},{192,207},{193,192},{194,193},{195,194},{196,195},{197,196},{198,197},{199,198},{200,199},{201,200},{202,201},{203,202},{204,203},{205,204},{206,205},{207,206},{208,223},{209,208},{210,209},{211,210},{212,211},{213,212},{214,213},{215,214},{216,215},{217,216},{218,217},{219,218}
,{220,219},{221,220},{222,221},{223,222},{224,239},{225,224},{226,225},{227,226},{228,227},{229,228},{230,229},{231,230},{232,231},{233,232},{234,233},{235,234},{236,235},{237,236},{238,237},{239,238},{240,255},{241,240},{242,241},{243,242},{244,243},{245,244},{246,245},{247,246},{248,247},{249,248},{250,249},{251,250},{252,251},{253,252},{254,253},{255,254}}, frontend_attributes={_scheduling_group_id="0"} + collective-permute-done.1 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.1), frontend_attributes={_scheduling_group_id="0"} fusion.4290 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) fusion(collective-permute-done, fusion.4289, collective-permute-done.1), kind=kLoop, calls=fused_computation.76.clone get-tuple-element.22079 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4290), index=0 - collective-permute-start.2 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.22079), channel_id=20, 
source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11},{11,12},{12,13},{13,14},{14,15},{15,0},{16,17},{17,18},{18,19},{19,20},{20,21},{21,22},{22,23},{23,24},{24,25},{25,26},{26,27},{27,28},{28,29},{29,30},{30,31},{31,16},{32,33},{33,34},{34,35},{35,36},{36,37},{37,38},{38,39},{39,40},{40,41},{41,42},{42,43},{43,44},{44,45},{45,46},{46,47},{47,32},{48,49},{49,50},{50,51},{51,52},{52,53},{53,54},{54,55},{55,56},{56,57},{57,58},{58,59},{59,60},{60,61},{61,62},{62,63},{63,48},{64,65},{65,66},{66,67},{67,68},{68,69},{69,70},{70,71},{71,72},{72,73},{73,74},{74,75},{75,76},{76,77},{77,78},{78,79},{79,64},{80,81},{81,82},{82,83},{83,84},{84,85},{85,86},{86,87},{87,88},{88,89},{89,90},{90,91},{91,92},{92,93},{93,94},{94,95},{95,80},{96,97},{97,98},{98,99},{99,100},{100,101},{101,102},{102,103},{103,104},{104,105},{105,106},{106,107},{107,108},{108,109},{109,110},{110,111},{111,96},{112,113},{113,114},{114,115},{115,116},{116,117},{117,118},{118,119},{119,120},{120,121},{121,122},{122,123},{123,124},{124,125},{125,126},{126,127},{127,112},{128,129},{129,130},{130,131},{131,132},{132,133},{133,134},{134,135},{135,136},{136,137},{137,138},{138,139},{139,140},{140,141},{141,142},{142,143},{143,128},{144,145},{145,146},{146,147},{147,148},{148,149},{149,150},{150,151},{151,152},{152,153},{153,154},{154,155},{155,156},{156,157},{157,158},{158,159},{159,144},{160,161},{161,162},{162,163},{163,164},{164,165},{165,166},{166,167},{167,168},{168,169},{169,170},{170,171},{171,172},{172,173},{173,174},{174,175},{175,160},{176,177},{177,178},{178,179},{179,180},{180,181},{181,182},{182,183},{183,184},{184,185},{185,186},{186,187},{187,188},{188,189},{189,190},{190,191},{191,176},{192,193},{193,194},{194,195},{195,196},{196,197},{197,198},{198,199},{199,200},{200,201},{201,202},{202,203},{203,204},{204,205},{205,206},{206,207},{207,192},{208,209},{209,210},{210,211},{211,212},{212,213},{213,214},{214,215},{215,216},{216,217},{217,218},{218,219},{219,220}
,{220,221},{221,222},{222,223},{223,208},{224,225},{225,226},{226,227},{227,228},{228,229},{229,230},{230,231},{231,232},{232,233},{233,234},{234,235},{235,236},{236,237},{237,238},{238,239},{239,224},{240,241},{241,242},{242,243},{243,244},{244,245},{245,246},{246,247},{247,248},{248,249},{249,250},{250,251},{251,252},{252,253},{253,254},{254,255},{255,240}} - collective-permute-done.2 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.2) + collective-permute-start.2 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.22079), channel_id=20, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11},{11,12},{12,13},{13,14},{14,15},{15,0},{16,17},{17,18},{18,19},{19,20},{20,21},{21,22},{22,23},{23,24},{24,25},{25,26},{26,27},{27,28},{28,29},{29,30},{30,31},{31,16},{32,33},{33,34},{34,35},{35,36},{36,37},{37,38},{38,39},{39,40},{40,41},{41,42},{42,43},{43,44},{44,45},{45,46},{46,47},{47,32},{48,49},{49,50},{50,51},{51,52},{52,53},{53,54},{54,55},{55,56},{56,57},{57,58},{58,59},{59,60},{60,61},{61,62},{62,63},{63,48},{64,65},{65,66},{66,67},{67,68},{68,69},{69,70},{70,71},{71,72},{72,73},{73,74},{74,75},{75,76},{76,77},{77,78},{78,79},{79,64},{80,81},{81,82},{82,83},{83,84},{84,85},{85,86},{86,87},{87,88},{88,89},{89,90},{90,91},{91,92},{92,93},{93,94},{94,95},{95,80},{96,97},{97,98},{98,99},{99,100},{100,101},{101,102},{102,103},{103,104},{104,105},{105,106},{106,107},{107,108},{108,109},{109,110},{110,111},{111,96},{112,113},{113,114},{114,115},{115,116},{116,117},{117,118},{118,119},{119,120},{120,121},{121,122},{122,123},{123,124},{124,125},{125,126},{126,127},{127,112},{128,129},{129,130},{130,131},{131,132},{132,133},{133,134},{134,135},{135,136},{136,137},{137,138},{138,139},{139,140},{140,141},{141,142},{142,143},{143,128},{144,145},{145,146},{146,147},{147,148},{148,149},{149,150},{150,151},{151,152},{152,
153},{153,154},{154,155},{155,156},{156,157},{157,158},{158,159},{159,144},{160,161},{161,162},{162,163},{163,164},{164,165},{165,166},{166,167},{167,168},{168,169},{169,170},{170,171},{171,172},{172,173},{173,174},{174,175},{175,160},{176,177},{177,178},{178,179},{179,180},{180,181},{181,182},{182,183},{183,184},{184,185},{185,186},{186,187},{187,188},{188,189},{189,190},{190,191},{191,176},{192,193},{193,194},{194,195},{195,196},{196,197},{197,198},{198,199},{199,200},{200,201},{201,202},{202,203},{203,204},{204,205},{205,206},{206,207},{207,192},{208,209},{209,210},{210,211},{211,212},{212,213},{213,214},{214,215},{215,216},{216,217},{217,218},{218,219},{219,220},{220,221},{221,222},{222,223},{223,208},{224,225},{225,226},{226,227},{227,228},{228,229},{229,230},{230,231},{231,232},{232,233},{233,234},{234,235},{235,236},{236,237},{237,238},{238,239},{239,224},{240,241},{241,242},{242,243},{243,244},{244,245},{245,246},{246,247},{247,248},{248,249},{249,250},{250,251},{251,252},{252,253},{253,254},{254,255},{255,240}}, frontend_attributes={_scheduling_group_id="1"} + collective-permute-done.2 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.2), frontend_attributes={_scheduling_group_id="1"} get-tuple-element.29011 = u32[]{:T(128)} get-tuple-element(wide_param.41), index=11 add.10209 = u32[]{:T(128)S(6)} add(get-tuple-element.28991, get-tuple-element.29011) subtract.2864 = u32[]{:T(128)S(6)} subtract(add.10204, add.10209) @@ -966,10 +967,10 @@ TEST_F(WhileLoopUnrollerTest, LoopWithCollective2) { clamp.1715 = u32[]{:T(128)S(6)} clamp(get-tuple-element.29005, and.403, get-tuple-element.29008) convert.8618 = s32[]{:T(128)S(6)} convert(clamp.1715) multiply.14833 = s32[]{:T(128)S(6)} multiply(convert.8618, get-tuple-element.29009) - fusion.4293 = s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} fusion(multiply.14832, bitcast.8824, multiply.14833, get-tuple-element.29000), kind=kOutput, calls=fused_computation.72.clone + fusion.4293 = 
s32[1,32,2,256]{3,1,2,0:T(8,128)S(1)} fusion(multiply.14832, bitcast.8824, multiply.14833, get-tuple-element.29000), kind=kOutput, calls=fused_computation.72.clone, frontend_attributes={_scheduling_group_id="1"} get-tuple-element.22080 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4290), index=1 - collective-permute-start.3 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.22080), channel_id=19, source_target_pairs={{0,15},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10},{12,11},{13,12},{14,13},{15,14},{16,31},{17,16},{18,17},{19,18},{20,19},{21,20},{22,21},{23,22},{24,23},{25,24},{26,25},{27,26},{28,27},{29,28},{30,29},{31,30},{32,47},{33,32},{34,33},{35,34},{36,35},{37,36},{38,37},{39,38},{40,39},{41,40},{42,41},{43,42},{44,43},{45,44},{46,45},{47,46},{48,63},{49,48},{50,49},{51,50},{52,51},{53,52},{54,53},{55,54},{56,55},{57,56},{58,57},{59,58},{60,59},{61,60},{62,61},{63,62},{64,79},{65,64},{66,65},{67,66},{68,67},{69,68},{70,69},{71,70},{72,71},{73,72},{74,73},{75,74},{76,75},{77,76},{78,77},{79,78},{80,95},{81,80},{82,81},{83,82},{84,83},{85,84},{86,85},{87,86},{88,87},{89,88},{90,89},{91,90},{92,91},{93,92},{94,93},{95,94},{96,111},{97,96},{98,97},{99,98},{100,99},{101,100},{102,101},{103,102},{104,103},{105,104},{106,105},{107,106},{108,107},{109,108},{110,109},{111,110},{112,127},{113,112},{114,113},{115,114},{116,115},{117,116},{118,117},{119,118},{120,119},{121,120},{122,121},{123,122},{124,123},{125,124},{126,125},{127,126},{128,143},{129,128},{130,129},{131,130},{132,131},{133,132},{134,133},{135,134},{136,135},{137,136},{138,137},{139,138},{140,139},{141,140},{142,141},{143,142},{144,159},{145,144},{146,145},{147,146},{148,147},{149,148},{150,149},{151,150},{152,151},{153,152},{154,153},{155,154},{156,155},{157,156},{158,157},{159,158},{160,175},{161,160},{162,161},{163,162},{164,163},{165,164},{166,165},{167,166},{168,167},{
169,168},{170,169},{171,170},{172,171},{173,172},{174,173},{175,174},{176,191},{177,176},{178,177},{179,178},{180,179},{181,180},{182,181},{183,182},{184,183},{185,184},{186,185},{187,186},{188,187},{189,188},{190,189},{191,190},{192,207},{193,192},{194,193},{195,194},{196,195},{197,196},{198,197},{199,198},{200,199},{201,200},{202,201},{203,202},{204,203},{205,204},{206,205},{207,206},{208,223},{209,208},{210,209},{211,210},{212,211},{213,212},{214,213},{215,214},{216,215},{217,216},{218,217},{219,218},{220,219},{221,220},{222,221},{223,222},{224,239},{225,224},{226,225},{227,226},{228,227},{229,228},{230,229},{231,230},{232,231},{233,232},{234,233},{235,234},{236,235},{237,236},{238,237},{239,238},{240,255},{241,240},{242,241},{243,242},{244,243},{245,244},{246,245},{247,246},{248,247},{249,248},{250,249},{251,250},{252,251},{253,252},{254,253},{255,254}} - collective-permute-done.3 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.3) + collective-permute-start.3 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}, u32[]{:S(2)}, u32[]{:S(2)}) collective-permute-start(get-tuple-element.22080), channel_id=19, 
source_target_pairs={{0,15},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10},{12,11},{13,12},{14,13},{15,14},{16,31},{17,16},{18,17},{19,18},{20,19},{21,20},{22,21},{23,22},{24,23},{25,24},{26,25},{27,26},{28,27},{29,28},{30,29},{31,30},{32,47},{33,32},{34,33},{35,34},{36,35},{37,36},{38,37},{39,38},{40,39},{41,40},{42,41},{43,42},{44,43},{45,44},{46,45},{47,46},{48,63},{49,48},{50,49},{51,50},{52,51},{53,52},{54,53},{55,54},{56,55},{57,56},{58,57},{59,58},{60,59},{61,60},{62,61},{63,62},{64,79},{65,64},{66,65},{67,66},{68,67},{69,68},{70,69},{71,70},{72,71},{73,72},{74,73},{75,74},{76,75},{77,76},{78,77},{79,78},{80,95},{81,80},{82,81},{83,82},{84,83},{85,84},{86,85},{87,86},{88,87},{89,88},{90,89},{91,90},{92,91},{93,92},{94,93},{95,94},{96,111},{97,96},{98,97},{99,98},{100,99},{101,100},{102,101},{103,102},{104,103},{105,104},{106,105},{107,106},{108,107},{109,108},{110,109},{111,110},{112,127},{113,112},{114,113},{115,114},{116,115},{117,116},{118,117},{119,118},{120,119},{121,120},{122,121},{123,122},{124,123},{125,124},{126,125},{127,126},{128,143},{129,128},{130,129},{131,130},{132,131},{133,132},{134,133},{135,134},{136,135},{137,136},{138,137},{139,138},{140,139},{141,140},{142,141},{143,142},{144,159},{145,144},{146,145},{147,146},{148,147},{149,148},{150,149},{151,150},{152,151},{153,152},{154,153},{155,154},{156,155},{157,156},{158,157},{159,158},{160,175},{161,160},{162,161},{163,162},{164,163},{165,164},{166,165},{167,166},{168,167},{169,168},{170,169},{171,170},{172,171},{173,172},{174,173},{175,174},{176,191},{177,176},{178,177},{179,178},{180,179},{181,180},{182,181},{183,182},{184,183},{185,184},{186,185},{187,186},{188,187},{189,188},{190,189},{191,190},{192,207},{193,192},{194,193},{195,194},{196,195},{197,196},{198,197},{199,198},{200,199},{201,200},{202,201},{203,202},{204,203},{205,204},{206,205},{207,206},{208,223},{209,208},{210,209},{211,210},{212,211},{213,212},{214,213},{215,214},{216,215},{217,216},{218,217},{219,218}
,{220,219},{221,220},{222,221},{223,222},{224,239},{225,224},{226,225},{227,226},{228,227},{229,228},{230,229},{231,230},{232,231},{233,232},{234,233},{235,234},{236,235},{237,236},{238,237},{239,238},{240,255},{241,240},{242,241},{243,242},{244,243},{245,244},{246,245},{247,246},{248,247},{249,248},{250,249},{251,250},{252,251},{253,252},{254,253},{255,254}}, frontend_attributes={_scheduling_group_id="1"} + collective-permute-done.3 = s32[1,32,256]{2,1,0:T(8,128)S(1)} collective-permute-done(collective-permute-start.3), frontend_attributes={_scheduling_group_id="1"} fusion.4294 = (s32[1,32,256]{2,1,0:T(8,128)S(1)}, s32[1,32,256]{2,1,0:T(8,128)S(1)}) fusion(collective-permute-done.2, fusion.4293, collective-permute-done.3), kind=kLoop, calls=fused_computation.74.clone get-tuple-element.29002 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4294), index=1 get-tuple-element.29003 = s32[1,32,256]{2,1,0:T(8,128)S(1)} get-tuple-element(fusion.4294), index=0 @@ -1025,6 +1026,21 @@ TEST_F(WhileLoopUnrollerTest, LoopWithCollective2) { // The total number of fusions in the unrolled version in the entry must be // equal to loop_trip_count * fusion_instr_count EXPECT_EQ(fusion_instr_count * 4, fusion_instr_count_after_unroll); + + // Each scheduling group should have exactly 5 instructions and there should + // be 8 groups in total. + absl::flat_hash_map num_instrs_per_group; + for (const HloInstruction* instr : + module->entry_computation()->instructions()) { + if (std::optional id = GetSchedulingAnnotation(instr)) { + num_instrs_per_group[id.value()]++; + } + } + for (const auto& [group_id, num_instrs] : num_instrs_per_group) { + EXPECT_EQ(num_instrs, 5); + VLOG(1) << "Group id: " << group_id << " num instrs: " << num_instrs; + } + EXPECT_EQ(num_instrs_per_group.size(), 8); } TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDS) {