Refactor hlo_sharding_util::ReshapeSharding by reducing the if-else branches.

We also highlight a TODO in this CL, which will be revisited later.

No behavior change.

PiperOrigin-RevId: 7267313
ZixuanJiang authored and Google-ML-Automation committed Feb 14, 2025
1 parent 9959142 commit da05dc5
Showing 1 changed file with 27 additions and 65 deletions.
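For context, hlo_sharding_util::ReshapeSharding tries to re-express an operand's tiled sharding in terms of the output shape of a reshape, returning std::nullopt when it cannot. Two illustrative examples, not taken from this change: a [6, 4] operand split 2 ways along dimension 0 and reshaped to [24] keeps a 2-way split on the single output dimension, while a [4, 6] operand split 4 ways along dimension 0 and reshaped to [2, 2, 6] ends up with tile dimensions [2, 2, 1]. A self-contained sketch of the gcd-based matching used by the refactored loop follows the diff below.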
92 changes: 27 additions & 65 deletions xla/hlo/utils/hlo_sharding_util.cc
@@ -915,7 +915,7 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
};

bool inplace_add_sharding_dim = false;
auto append_sharding_dim = [&](int64_t size) {
auto append_target_sharding_dim = [&](int64_t size) {
if (inplace_add_sharding_dim) {
target_tile_assignment_dimensions.back() *= size;
} else {
@@ -924,12 +924,8 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
inplace_add_sharding_dim = false;
};

while (!source_dims_stack.empty() || !target_dims_stack.empty()) {
if (Product(sharding_tile_dims_stack) == 1) {
// No more partitions left.
break;
}

while (!source_dims_stack.empty() && !target_dims_stack.empty() &&
Product(sharding_tile_dims_stack) != 1) {
int64_t source_dims_product = 1;
while (!sharding_tile_dims_stack.empty() &&
sharding_tile_dims_stack.back() == 1) {
@@ -940,90 +936,56 @@ std::optional<HloSharding> ReshapeSharding(const Shape& source_shape,
source_dims_product % target_dims_stack.back() == 0) {
source_dims_product /= target_dims_stack.back();
target_dims_stack.pop_back();
append_sharding_dim(1);
append_target_sharding_dim(1);
}
if (source_dims_product != 1) {
source_dims_push(source_dims_product, 1);
}

if (target_dims_stack.empty()) {
if (source_dims_stack.empty() || target_dims_stack.empty()) {
break;
}
int64_t s_size = source_dims_stack.back();
int64_t s_partitions = sharding_tile_dims_stack.back();
source_dims_pop();

int64_t t_size = target_dims_stack.back();
target_dims_stack.pop_back();

int64_t s_size = 1;
int64_t s_partitions = 1;
if (!source_dims_stack.empty()) {
s_size = source_dims_stack.back();
s_partitions = sharding_tile_dims_stack.back();
source_dims_pop();
}

if (s_size == t_size) {
// Same dimension.
append_sharding_dim(s_partitions);
} else if (s_partitions > 1 && s_size % s_partitions == 0 &&
t_size % s_partitions == 0) {
// If s_partitions evenly divides both s_size and t_size, we can add this
// sharding dim and work on shard sized shapes in the next iteration.
source_dims_push(s_size / s_partitions, 1);
target_dims_stack.push_back(t_size / s_partitions);
append_sharding_dim(s_partitions);
inplace_add_sharding_dim = true;
// Same dimension size.
append_target_sharding_dim(s_partitions);
} else if (t_size == 1) {
// Trivial dimension added.
append_sharding_dim(1);
append_target_sharding_dim(1);
source_dims_push(s_size, s_partitions);
} else if (s_size == 1) {
// Trivial dimension removed.
target_dims_stack.push_back(t_size);
if (s_partitions > 1) {
dims_to_replicate.push_back(source_dims_index);
}
} else if (s_size > t_size) {
// Dimension split.
if (s_size % s_partitions != 0) {
return std::nullopt;
}
if (s_size % t_size != 0) {
// Transpose is needed between source and target shapes.
append_sharding_dim(std::gcd(t_size, s_partitions));
break;
}
if (t_size % s_partitions == 0) {
append_sharding_dim(s_partitions);
// We have part of the s_size unprocessed, so put it back to stack.
source_dims_push(s_size / t_size, 1);
} else if (s_partitions % t_size == 0) {
append_sharding_dim(t_size);
// We have part of the s_size unprocessed, so put it back to stack.
source_dims_push(s_size / t_size, s_partitions / t_size);
} else if (s_partitions == 1) {
if (!source_dims_stack.empty() && sharding_tile_dims_stack.back() == 1) {
source_dims_stack.back() *= s_size;
} else {
append_sharding_dim(std::gcd(t_size, s_partitions));
break;
}
} else if (s_size % s_partitions != 0) {
// TODO(zixuanjiang): Although we can propagate the gcd(s_size,
// s_partitions), we return std::nullopt since the current partitioner
// relies on that to create halo exchange. Revisit it later.
return std::nullopt;
} else {
// Dimension merge. Also merge the source dimension with the next, and
// process it next time.
if (s_size % s_partitions != 0) {
return std::nullopt;
}
CHECK(!source_dims_stack.empty());
if (t_size % s_size != 0) {
// Transpose is needed between source and target shapes.
append_sharding_dim(std::gcd(t_size, s_partitions));
int64_t gcd = std::gcd(s_partitions, t_size);
if (gcd == 1) {
break;
}
if (sharding_tile_dims_stack.back() != 1 && s_size != s_partitions) {
// If the next dimension to combine is sharded, we require the
// current dimension's shard size to be 1. Otherwise, the new shard
// would be non-contiguous.
break;
}
source_dims_stack.back() *= s_size;
sharding_tile_dims_stack.back() *= s_partitions;
target_dims_stack.push_back(t_size);

source_dims_push(s_size / gcd, s_partitions / gcd);
target_dims_stack.push_back(t_size / gcd);
append_target_sharding_dim(gcd);
inplace_add_sharding_dim = true;
}
}
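To make the gcd-based matching above easier to follow in isolation, here is a minimal, self-contained C++17 sketch. It is not the XLA implementation and is not taken from this change: the helper name PropagateTileDims is invented, tile assignments are reduced to plain per-dimension partition counts, and replication, halo exchange, and the post-loop fallbacks of the real ReshapeSharding are omitted, so every unsupported case simply returns std::nullopt.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Simplified model of the gcd-based matching in the refactored loop. Given
// the source dimensions, one partition count per source dimension, and the
// target dimensions, it returns one partition count per target dimension,
// or std::nullopt when this simplified model cannot express the sharding.
std::optional<std::vector<int64_t>> PropagateTileDims(
    std::vector<int64_t> src_dims, std::vector<int64_t> src_tiles,
    std::vector<int64_t> dst_dims) {
  std::vector<int64_t> dst_tiles;
  bool merge_into_last = false;  // plays the role of inplace_add_sharding_dim
  auto append = [&](int64_t partitions) {
    if (merge_into_last) {
      dst_tiles.back() *= partitions;
    } else {
      dst_tiles.push_back(partitions);
    }
    merge_into_last = false;
  };
  auto partitions_left = [&](std::size_t from) {
    return std::accumulate(src_tiles.begin() + from, src_tiles.end(),
                           int64_t{1}, std::multiplies<int64_t>());
  };

  std::size_t si = 0, di = 0;
  while (si < src_dims.size() && di < dst_dims.size() &&
         partitions_left(si) != 1) {
    const int64_t s = src_dims[si], p = src_tiles[si], t = dst_dims[di];
    if (s == t) {
      // Same dimension size: carry the partition count over unchanged.
      append(p);
      ++si, ++di;
    } else if (t == 1) {
      // Trivial dimension added by the reshape.
      append(1);
      ++di;
    } else if (s == 1) {
      // Trivial dimension removed; a sharded size-1 dimension would need the
      // replication handling that this sketch omits.
      if (p > 1) return std::nullopt;
      ++si;
    } else if (s % p != 0) {
      // Uneven shards: bail out, mirroring the new early-return branch.
      return std::nullopt;
    } else {
      // Split or merge: peel off gcd(p, t) partitions for the current target
      // dimension, then keep matching the leftover factors s/g and t/g.
      const int64_t g = std::gcd(p, t);
      if (g == 1) return std::nullopt;  // the real code breaks and falls back
      append(g);
      merge_into_last = true;  // later contributions coalesce into this entry
      src_dims[si] = s / g;
      src_tiles[si] = p / g;
      dst_dims[di] = t / g;
    }
  }

  // Any leftover source partitions cannot be expressed on the target shape.
  if (partitions_left(si) != 1) return std::nullopt;
  // Target dimensions that never received partitions stay unsharded.
  while (dst_tiles.size() < dst_dims.size()) dst_tiles.push_back(1);
  return dst_tiles;
}

int main() {
  // [6, 4] split 2 ways on dim 0, reshaped to [24]: the split carries over.
  assert((PropagateTileDims({6, 4}, {2, 1}, {24}) == std::vector<int64_t>{2}));
  // [4, 6] split 4 ways on dim 0, reshaped to [2, 2, 6]: the split is
  // distributed over the two new major dimensions.
  assert((PropagateTileDims({4, 6}, {4, 1}, {2, 2, 6}) ==
          std::vector<int64_t>{2, 2, 1}));
  // [2, 3] split 3 ways on the minor dimension, reshaped to [6]: the shards
  // would be non-contiguous after flattening, so the sketch declines.
  assert(!PropagateTileDims({2, 3}, {1, 3}, {6}).has_value());
  return 0;
}

Under those assumptions, the three asserts mirror the examples given before the diff: a 2-way split of [6, 4] on dimension 0 survives a reshape to [24], a 4-way split of [4, 6] becomes a 2x2 split of [2, 2, 6], and a 3-way split on the minor dimension of [2, 3] is not propagated to [6].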

