From cd8dc73f65cd564ba114088b7d8fd98307272f13 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Mon, 10 Feb 2025 08:06:55 -0800 Subject: [PATCH] Extract conflicting collective analysis PiperOrigin-RevId: 725217631 --- xla/service/BUILD | 15 ++ xla/service/collective_conflict_analysis.cc | 166 +++++++++++++++++++ xla/service/collective_conflict_analysis.h | 59 +++++++ xla/service/collective_permute_decomposer.cc | 151 +---------------- 4 files changed, 241 insertions(+), 150 deletions(-) create mode 100644 xla/service/collective_conflict_analysis.cc create mode 100644 xla/service/collective_conflict_analysis.h diff --git a/xla/service/BUILD b/xla/service/BUILD index 42ef16c97ec00..56cae089d1dc1 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -324,6 +324,7 @@ cc_library( hdrs = ["collective_permute_decomposer.h"], deps = [ ":call_graph", + ":collective_conflict_analysis", ":collective_ops_utils", ":collective_permute_cycle", ":pattern_matcher", @@ -364,6 +365,20 @@ xla_cc_test( ], ) +cc_library( + name = "collective_conflict_analysis", + srcs = ["collective_conflict_analysis.cc"], + hdrs = ["collective_conflict_analysis.h"], + deps = [ + ":collective_ops_utils", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + ], +) + cc_library( name = "constant_value", srcs = ["constant_value.cc"], diff --git a/xla/service/collective_conflict_analysis.cc b/xla/service/collective_conflict_analysis.cc new file mode 100644 index 0000000000000..73ce75e9897ed --- /dev/null +++ b/xla/service/collective_conflict_analysis.cc @@ -0,0 +1,166 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/collective_conflict_analysis.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/service/collective_ops_utils.h" + +namespace xla { + +void AbstractReplicaGroups::merge_groups(int64_t replica_id, + int64_t other_replica_id) { + if (get_index(replica_id) == -1 && get_index(other_replica_id) == -1) { + set_index(replica_id, groups.size()); + set_index(other_replica_id, groups.size()); + groups.push_back({replica_id, other_replica_id}); + return; + } + if (get_index(replica_id) == get_index(other_replica_id)) return; + if (get_index(replica_id) == -1) { + std::swap(replica_id, other_replica_id); + } + CHECK_NE(get_index(replica_id), -1); + if (get_index(other_replica_id) == -1) { + set_index(other_replica_id, get_index(replica_id)); + groups[get_index(replica_id)].insert(other_replica_id); + return; + } + CHECK(get_index(replica_id) != -1 && get_index(other_replica_id) != -1 && + get_index(replica_id) != get_index(other_replica_id)); + auto& other_set = groups[get_index(other_replica_id)]; + for (int64_t replica_id_in_other_set : other_set) { + groups[get_index(replica_id)].insert(replica_id_in_other_set); + set_index(replica_id_in_other_set, get_index(replica_id)); + } + other_set.clear(); +} + +bool IsConflictingAbstractReplicaGroups(AbstractReplicaGroups& lhs, + AbstractReplicaGroups& rhs) { + std::vector frequency(lhs.groups.size(), 0); + for (auto& rhs_group : rhs.groups) { + std::fill(frequency.begin(), frequency.end(), 0); + for (int64_t rhs_replica_id : rhs_group) { + int64_t i = lhs.get_index(rhs_replica_id); + if (i == -1) continue; + if (++frequency[i] >= 2) return true; + } + } + return false; +} + +void GetAbstractReplicaGroups(HloInstruction* instr, + AbstractReplicaGroups& groups) { + // Abstract from source-target pairs of collective-permute to abstract replica + // groups. + if (instr->opcode() == HloOpcode::kCollectivePermute) { + auto* cp = Cast(instr); + for (auto& [i, j] : cp->source_target_pairs()) { + groups.merge_groups(i, j); + } + return; + } + + // Abstract from source-target pairs of send/recv to abstract replica groups. + auto add_replica_group = [&groups](const ReplicaGroup& replica_group) { + auto& ids = replica_group.replica_ids(); + if (ids.empty()) return; + int64_t leader_id = ids[0]; + for (int64_t i = 1; i < ids.size(); ++i) { + groups.merge_groups(leader_id, ids[i]); + } + }; + if (instr->opcode() == HloOpcode::kSend || + instr->opcode() == HloOpcode::kRecv) { + auto* sr = Cast(instr); + CHECK(!sr->is_host_transfer()); + std::optional source_target_pairs_str = + sr->frontend_attributes().map().at(kSendRecvSourceTargetPairsAttr); + CHECK(source_target_pairs_str.has_value()); + absl::StatusOr> source_target_pairs = + ParseReplicaGroupsOnly(*source_target_pairs_str); + CHECK(source_target_pairs.ok() && "Expect valid source_target_pairs"); + for (auto& replica_group : *source_target_pairs) { + add_replica_group(replica_group); + } + return; + } + + // Convert normal replica groups to abstract replica groups. + for (auto& replica_group : GetCollectiveReplicaGroups(instr)) { + add_replica_group(replica_group); + } +} + +std::vector FindAllConflictingCollectives( + const HloComputation* computation, + const std::vector& seed_collectives) { + absl::flat_hash_set seen; + + // Get the supremum of all abstract replica groups of the seed collectives + // we're starting with. + AbstractReplicaGroups abstract_replica_groups_supremum; + for (HloInstruction* instr : seed_collectives) { + GetAbstractReplicaGroups(instr, abstract_replica_groups_supremum); + seen.insert(instr); + } + + // Try finding more and more conflicting collectives until we reach a + // fixpoint. This is needed because we may get a coarser supremum with each + // new conflicting collective. + std::vector conflicing_collectives; + bool fixpoint_reached; + do { + fixpoint_reached = true; + + // Look at each collective in the computation. + for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { + // Skip if not a collective or already considered for the supremum. + if (!IsNonFusionCollective(instr) || seen.contains(instr)) continue; + + // Check if this collective is already conflicting with the coarsest + // abstract replica groups. If it does, add to the conflicting collectives + // and update the supremum. + AbstractReplicaGroups groups; + GetAbstractReplicaGroups(instr, groups); + if (IsConflictingAbstractReplicaGroups( + groups, abstract_replica_groups_supremum)) { + conflicing_collectives.push_back(instr); + GetAbstractReplicaGroups(instr, abstract_replica_groups_supremum); + seen.insert(instr); + fixpoint_reached = false; + } + } + } while (!fixpoint_reached); + + return conflicing_collectives; +} + +} // namespace xla diff --git a/xla/service/collective_conflict_analysis.h b/xla/service/collective_conflict_analysis.h new file mode 100644 index 0000000000000..4c0218d2e7ba4 --- /dev/null +++ b/xla/service/collective_conflict_analysis.h @@ -0,0 +1,59 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_COLLECTIVE_CONFLICT_ANALYSIS_H_ +#define XLA_SERVICE_COLLECTIVE_CONFLICT_ANALYSIS_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_module.h" + +namespace xla { + +struct AbstractReplicaGroups { + // Holds groups of abstract replica ids. + std::vector> groups; + + // Maps abstract replica id to index in groups. + std::vector index_map; + + int64_t get_index(int64_t replica_id) { + while (index_map.size() <= replica_id) index_map.push_back(-1); + return index_map[replica_id]; + } + + void set_index(int64_t replica_id, int64_t index) { + while (index_map.size() <= replica_id) index_map.push_back(-1); + index_map[replica_id] = index; + } + + void merge_groups(int64_t replica_id, int64_t other_replica_id); +}; + +bool IsConflictingAbstractReplicaGroups(AbstractReplicaGroups& lhs, + AbstractReplicaGroups& rhs); + +void GetAbstractReplicaGroups(HloInstruction* instr, + AbstractReplicaGroups& groups); + +std::vector FindAllConflictingCollectives( + const HloComputation* computation, + const std::vector& seed_collectives); + +} // namespace xla + +#endif // XLA_SERVICE_COLLECTIVE_CONFLICT_ANALYSIS_H_ diff --git a/xla/service/collective_permute_decomposer.cc b/xla/service/collective_permute_decomposer.cc index 7e1fc08837bfc..c6a97b4441f27 100644 --- a/xla/service/collective_permute_decomposer.cc +++ b/xla/service/collective_permute_decomposer.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/service/call_graph.h" +#include "xla/service/collective_conflict_analysis.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/collective_permute_cycle.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -220,156 +221,6 @@ CheckCyclePatterns(HloCollectivePermuteInstruction* cp0, return std::nullopt; } -namespace { - -struct AbstractReplicaGroups { - // Holds groups of abstract replica ids. - std::vector> groups; - - // Maps abstract replica id to index in groups. - std::vector index_map; - - int64_t get_index(int64_t replica_id) { - while (index_map.size() <= replica_id) index_map.push_back(-1); - return index_map[replica_id]; - } - - void set_index(int64_t replica_id, int64_t index) { - while (index_map.size() <= replica_id) index_map.push_back(-1); - index_map[replica_id] = index; - } - - void merge_groups(int64_t replica_id, int64_t other_replica_id) { - if (get_index(replica_id) == -1 && get_index(other_replica_id) == -1) { - set_index(replica_id, groups.size()); - set_index(other_replica_id, groups.size()); - groups.push_back({replica_id, other_replica_id}); - return; - } - if (get_index(replica_id) == get_index(other_replica_id)) return; - if (get_index(replica_id) == -1) { - std::swap(replica_id, other_replica_id); - } - CHECK_NE(get_index(replica_id), -1); - if (get_index(other_replica_id) == -1) { - set_index(other_replica_id, get_index(replica_id)); - groups[get_index(replica_id)].insert(other_replica_id); - return; - } - CHECK(get_index(replica_id) != -1 && get_index(other_replica_id) != -1 && - get_index(replica_id) != get_index(other_replica_id)); - auto& other_set = groups[get_index(other_replica_id)]; - for (int64_t replica_id_in_other_set : other_set) { - groups[get_index(replica_id)].insert(replica_id_in_other_set); - set_index(replica_id_in_other_set, get_index(replica_id)); - } - other_set.clear(); - } -}; - -} // namespace - -static bool IsConflictingAbstractReplicaGroups(AbstractReplicaGroups& lhs, - AbstractReplicaGroups& rhs) { - std::vector frequency(lhs.groups.size(), 0); - for (auto& rhs_group : rhs.groups) { - std::fill(frequency.begin(), frequency.end(), 0); - for (int64_t rhs_replica_id : rhs_group) { - int64_t i = lhs.get_index(rhs_replica_id); - if (i == -1) continue; - if (++frequency[i] >= 2) return true; - } - } - return false; -} - -static void GetAbstractReplicaGroups(HloInstruction* instr, - AbstractReplicaGroups& groups) { - // Abstract from source-target pairs of collective-permute to abstract replica - // groups. - if (instr->opcode() == HloOpcode::kCollectivePermute) { - auto* cp = Cast(instr); - for (auto& [i, j] : cp->source_target_pairs()) { - groups.merge_groups(i, j); - } - return; - } - - // Abstract from source-target pairs of send/recv to abstract replica groups. - auto add_replica_group = [&groups](const ReplicaGroup& replica_group) { - auto& ids = replica_group.replica_ids(); - if (ids.empty()) return; - int64_t leader_id = ids[0]; - for (int64_t i = 1; i < ids.size(); ++i) { - groups.merge_groups(leader_id, ids[i]); - } - }; - if (instr->opcode() == HloOpcode::kSend || - instr->opcode() == HloOpcode::kRecv) { - auto* sr = Cast(instr); - CHECK(!sr->is_host_transfer()); - std::optional source_target_pairs_str = - sr->frontend_attributes().map().at(kSendRecvSourceTargetPairsAttr); - CHECK(source_target_pairs_str.has_value()); - absl::StatusOr> source_target_pairs = - ParseReplicaGroupsOnly(*source_target_pairs_str); - CHECK(source_target_pairs.ok() && "Expect valid source_target_pairs"); - for (auto& replica_group : *source_target_pairs) { - add_replica_group(replica_group); - } - return; - } - - // Convert normal replica groups to abstract replica groups. - for (auto& replica_group : GetCollectiveReplicaGroups(instr)) { - add_replica_group(replica_group); - } -} - -static std::vector FindAllConflictingCollectives( - const HloComputation* computation, - std::vector& seed_collectives) { - absl::flat_hash_set seen; - - // Get the supremum of all abstract replica groups of the seed collectives - // we're starting with. - AbstractReplicaGroups abstract_replica_groups_supremum; - for (HloInstruction* instr : seed_collectives) { - GetAbstractReplicaGroups(instr, abstract_replica_groups_supremum); - seen.insert(instr); - } - - // Try finding more and more conflicting collectives until we reach a - // fixpoint. This is needed because we may get a coarser supremum with each - // new conflicting collective. - std::vector conflicing_collectives; - bool fixpoint_reached; - do { - fixpoint_reached = true; - - // Look at each collective in the computation. - for (HloInstruction* instr : computation->MakeInstructionPostOrder()) { - // Skip if not a collective or already considered for the supremum. - if (!IsNonFusionCollective(instr) || seen.contains(instr)) continue; - - // Check if this collective is already conflicting with the coarsest - // abstract replica groups. If it does, add to the conflicting collectives - // and update the supremum. - AbstractReplicaGroups groups; - GetAbstractReplicaGroups(instr, groups); - if (IsConflictingAbstractReplicaGroups( - groups, abstract_replica_groups_supremum)) { - conflicing_collectives.push_back(instr); - GetAbstractReplicaGroups(instr, abstract_replica_groups_supremum); - seen.insert(instr); - fixpoint_reached = false; - } - } - } while (!fixpoint_reached); - - return conflicing_collectives; -} - static std::vector FindAllConflictingCollectives( HloComputation* computation, const std::vector& cps) {