Skip to content

Commit

Permalink
Make simpler to reason about expansion logic from Iota to legacy repl…
Browse files Browse the repository at this point in the history
…ica groups.

PiperOrigin-RevId: 726730891
  • Loading branch information
toli-y authored and Google-ML-Automation committed Feb 14, 2025
1 parent 7cd8616 commit 9959142
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 47 deletions.
43 changes: 22 additions & 21 deletions xla/hlo/ir/collective_device_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,35 +81,36 @@ IotaReplicaGroupList IotaReplicaGroupList::FromProto(
proto.iota_transpose_perm().end()));
}

void CollectiveDeviceList::MaybeMaterializeFullReplicaGroupList() const {
if (replica_groups_ != nullptr && !replica_groups_->empty()) {
VLOG(10) << "Replica group list already materialized.";
return;
}
if (!iota_replica_group_list_.has_value()) {
VLOG(1) << "Replica group list not materialized because iota replica group "
"list is not present.";
return;
}
VLOG(10) << "Materializing full replica group list";

replica_groups_ = std::make_shared<std::vector<ReplicaGroup>>();
const int64_t num_replica_groups =
iota_replica_group_list_->num_replica_groups();
replica_groups_->reserve(num_replica_groups);

Array<int64_t> array = iota_replica_group_list_->ToArray();
namespace {
// Expands `iota` into an explicit list of replica groups.
//
// The iota description is converted to an array with ToArray(); every
// consecutive run of num_devices_per_group() elements of that array becomes
// the replica_ids of one ReplicaGroup. The groups are collected in a freshly
// allocated vector that is returned through a shared_ptr.
std::shared_ptr<std::vector<ReplicaGroup>> ExpandIota(
const IotaReplicaGroupList& iota) {
VLOG(3) << "Expanding iota replica group list: " << iota.ToString();
auto result = std::make_shared<std::vector<ReplicaGroup>>();
const int64_t num_replica_groups = iota.num_replica_groups();
// One output entry per replica group, so size the vector up front.
result->reserve(num_replica_groups);

Array<int64_t> array = iota.ToArray();
// Iota replica group list array must only have 2 dimensions.
DCHECK_EQ(array.num_dimensions(), 2);
// NOTE(review): `num_devices_per_group` is declared twice in a row below. The
// first declaration reads the member `iota_replica_group_list_`, the second
// reads the `iota` parameter; the first looks like the pre-change line kept
// by the diff view -- confirm which one belongs in the file.
const int64_t num_devices_per_group =
iota_replica_group_list_->num_devices_per_group();
const int64_t num_devices_per_group = iota.num_devices_per_group();
// The element count must be groups * devices-per-group so that the strided
// loop below stops exactly on array.end().
DCHECK_EQ(array.end() - array.begin(),
num_devices_per_group * num_replica_groups);
// Walk the array one group at a time, advancing by the group width.
for (auto it = array.begin(); it != array.end();
it += num_devices_per_group) {
// NOTE(review): `group` is likewise declared twice (once from the member
// `replica_groups_`, once from the local `result`) -- presumably old and new
// diff lines; confirm.
auto& group = replica_groups_->emplace_back();
auto& group = result->emplace_back();
// Copy the current run of device ids into the newly appended group.
*group.mutable_replica_ids() = {it, it + num_devices_per_group};
}
return result;
}
} // namespace

// Returns the explicit replica groups of this list.
//
// When no explicit list has been stored yet, it is built on demand from the
// iota replica group list (which must be present) and cached in
// `replica_groups_`, so later calls return the cached vector.
const std::vector<ReplicaGroup>& CollectiveDeviceList::replica_groups() const {
  // Already available: hand back the cached vector.
  if (replica_groups_ != nullptr) {
    return *replica_groups_;
  }

  // Nothing cached, so the iota form is the only possible source.
  CHECK(iota_replica_group_list_.has_value());
  replica_groups_ = ExpandIota(*iota_replica_group_list_);
  CHECK(replica_groups_ != nullptr);
  return *replica_groups_;
}

std::string CollectiveDeviceList::ToString(
Expand Down
16 changes: 5 additions & 11 deletions xla/hlo/ir/collective_device_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class IotaReplicaGroupList {
// replica groups, it may be used to represent these lists in compact forms.
class CollectiveDeviceList {
public:
explicit CollectiveDeviceList() = default;
explicit CollectiveDeviceList()
: replica_groups_(std::make_shared<std::vector<ReplicaGroup>>()) {};

explicit CollectiveDeviceList(absl::Span<const ReplicaGroup> replica_groups)
: replica_groups_(std::make_shared<std::vector<ReplicaGroup>>(
Expand All @@ -99,21 +100,15 @@ class CollectiveDeviceList {
const IotaReplicaGroupList& iota_replica_group_list)
: iota_replica_group_list_(iota_replica_group_list) {}

const std::vector<ReplicaGroup>& replica_groups() const {
MaybeMaterializeFullReplicaGroupList();
return *replica_groups_;
}

// Lazily expands iota if applicable.
const std::vector<ReplicaGroup>& replica_groups() const;
const std::optional<IotaReplicaGroupList>& iota_replica_group_list() const {
return iota_replica_group_list_;
}

std::string ToString(bool print_full_replica_group_list = false) const;

CollectiveDeviceListProto ToProto() const;

static CollectiveDeviceList FromProto(const CollectiveDeviceListProto& proto);

static CollectiveDeviceList FromProto(const HloInstructionProto& proto);

private:
Expand Down Expand Up @@ -142,8 +137,7 @@ class CollectiveDeviceList {

std::optional<IotaReplicaGroupList> iota_replica_group_list_;
// shared_ptr for fast copy.
mutable std::shared_ptr<std::vector<ReplicaGroup>> replica_groups_ =
std::make_shared<std::vector<ReplicaGroup>>();
mutable std::shared_ptr<std::vector<ReplicaGroup>> replica_groups_ = nullptr;
};

} // namespace xla
Expand Down
62 changes: 47 additions & 15 deletions xla/hlo/ir/collective_device_list_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,51 @@ CollectiveDeviceListProto CreateDeviceListProto(
}

// Checks the brace-list rendering of lists that carry explicit replica groups:
// the default-constructed list, lists of empty groups, and populated groups.
TEST(CollectiveDeviceListTest, DefaultListToString) {
  // A default-constructed list prints as "{}" with or without the flag.
  EXPECT_EQ(CollectiveDeviceList().ToString(), "{}");
  EXPECT_EQ(CollectiveDeviceList().ToString(true), "{}");
  EXPECT_EQ(CollectiveDeviceList().ToString(false), "{}");

  // Two groups without any ids still print as two empty brace pairs.
  std::vector<ReplicaGroup> two_empty_groups(2, ReplicaGroup());
  EXPECT_EQ(CollectiveDeviceList(two_empty_groups).ToString(), "{{},{}}");

  // An empty nested-vector input prints like the default list.
  std::vector<std::vector<int64_t>> no_groups;
  EXPECT_EQ(CollectiveDeviceList(no_groups).ToString(), "{}");

  // Populated groups: single id, two groups, one longer group.
  EXPECT_EQ(CollectiveDeviceList({{1}}).ToString(), "{{1}}");
  EXPECT_EQ(CollectiveDeviceList({{1, 2}, {3, 4}}).ToString(), "{{1,2},{3,4}}");
  EXPECT_EQ(CollectiveDeviceList({{1, 2, 3, 4, 5, 6, 7}}).ToString(),
            "{{1,2,3,4,5,6,7}}");
}

// Verifies that a copy of a list built from explicit groups returns the very
// same vector object from replica_groups() and prints identically.
TEST(CollectiveDeviceListTest, DeepCopy) {
// NOTE(review): `orig` is declared on two consecutive lines; the first looks
// like the pre-change line retained by the diff view -- confirm which one
// belongs in the file.
CollectiveDeviceList orig({{1, 2, 3, 4, 5, 6, 7}});
CollectiveDeviceList orig({{1, 2, 3, 4}});
CollectiveDeviceList copy = orig;
// Address comparison: both objects must refer to one and the same vector.
EXPECT_EQ(&orig.replica_groups(), &copy.replica_groups());
EXPECT_EQ(orig.ToString(), copy.ToString());
}

// Copies an iota-backed list before replica_groups() has been called on it and
// checks that the two objects end up with distinct expanded vectors.
TEST(CollectiveDeviceListTest, DeepCopyIotaBeforeExpansion) {
  CollectiveDeviceList source(IotaReplicaGroupList(2, 4));
  CollectiveDeviceList duplicate(source);

  // Each object holds its own iota description.
  EXPECT_NE(&source.iota_replica_group_list().value(),
            &duplicate.iota_replica_group_list().value());
  // Expansion happens after the copy, once per object, so the resulting
  // vectors live at different addresses.
  EXPECT_NE(&source.replica_groups(), &duplicate.replica_groups());
  // The printed form is nevertheless the same.
  EXPECT_EQ(source.ToString(), duplicate.ToString());
}

// Copies an iota-backed list after replica_groups() has already been called on
// it and checks that the copy shares the expanded vector with the source.
TEST(CollectiveDeviceListTest, DeepCopyIotaAfterExpansion) {
  CollectiveDeviceList source(IotaReplicaGroupList(2, 4));
  // Trigger the expansion before the copy is made and remember the result.
  const std::vector<ReplicaGroup>& expanded_before_copy =
      source.replica_groups();
  CollectiveDeviceList duplicate(source);

  // Each object holds its own iota description.
  EXPECT_NE(&source.iota_replica_group_list().value(),
            &duplicate.iota_replica_group_list().value());
  // The expanded vector is one and the same object for source and copy.
  EXPECT_EQ(&source.replica_groups(), &duplicate.replica_groups());
  EXPECT_EQ(&expanded_before_copy, &duplicate.replica_groups());
  EXPECT_EQ(source.ToString(), duplicate.ToString());
}

TEST(CollectiveDeviceListTest, DefaultListToProto) {
Expand Down Expand Up @@ -95,27 +130,24 @@ TEST(CollectiveDeviceListTest, DefaultListFromProto2) {
EXPECT_FALSE(list.iota_replica_group_list().has_value());
}

TEST(CollectiveDeviceListTest, IotaListToString) {
CollectiveDeviceList list(IotaReplicaGroupList(2, 10));
EXPECT_EQ(list.ToString(), "[2,10]<=[20]");
TEST(CollectiveDeviceListTest, IotaToString) {
EXPECT_EQ(CollectiveDeviceList(IotaReplicaGroupList(0, 0)).ToString(),
"[0,0]<=[0]");
EXPECT_EQ(CollectiveDeviceList(IotaReplicaGroupList(2, 10)).ToString(),
"[2,10]<=[20]");
}

TEST(CollectiveDeviceListTest,
IotaListToStringWithPrintingFullReplicaGroupList) {
TEST(CollectiveDeviceListTest, IotaToReplicaGroupString) {
CollectiveDeviceList list(IotaReplicaGroupList(2, 10));
EXPECT_EQ(list.ToString(/*print_full_replica_group_list=*/true),
EXPECT_EQ(list.ToString(false), "[2,10]<=[20]");
EXPECT_EQ(list.ToString(true),
"{{0,1,2,3,4,5,6,7,8,9},{10,11,12,13,14,15,16,17,18,19}}");
}

TEST(CollectiveDeviceListTest, IotaListToString2) {
CollectiveDeviceList list(IotaReplicaGroupList(2, 10, {4, 5}, {1, 0}));
EXPECT_EQ(list.ToString(), "[2,10]<=[4,5]T(1,0)");
}

TEST(CollectiveDeviceListTest,
IotaListToStringWithPrintingFullReplicaGroupList2) {
CollectiveDeviceList list(IotaReplicaGroupList(2, 10, {4, 5}, {1, 0}));
EXPECT_EQ(list.ToString(/*print_full_replica_group_list=*/true),
EXPECT_EQ(list.ToString(false), "[2,10]<=[4,5]T(1,0)");
EXPECT_EQ(list.ToString(true),
"{{0,5,10,15,1,6,11,16,2,7},{12,17,3,8,13,18,4,9,14,19}}");
}

Expand Down

0 comments on commit 9959142

Please sign in to comment.