diff --git a/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc
index 910297e26644c..8e849bd5b9cbb 100644
--- a/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc
+++ b/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc
@@ -21,7 +21,9 @@ limitations under the License.
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
@@ -121,9 +123,16 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
     meshes_round_trip_attr = front_end_attr.get(xla::sdy::kMeshesRoundTripAttr);
   }
 
+  // Stash the errors in a MapVector, which maintains the order in which they
+  // are encountered. We do not emit an error within the walk because atom
+  // programs share a context and their compilations are dispatched in parallel.
+  // Any error emitted here could leak into a scoped diagnostic handler used
+  // while dispatching a compilation.
+  llvm::MapVector<CallOp, std::string> call_op_to_error;
+
   // Walk and dispatch the compilations in parallel.
-  auto compile_result =
-      module_op.walk([&](CallOp call_op) -> mlir::WalkResult {
+  module_op.walk(
+      [&](CallOp call_op) -> mlir::WalkResult {
         // Do not dispatch the atom program for compilation it has already been
         // dispatched.
         if (!call_to_compile_futures.contains(call_op)) {
@@ -132,21 +141,28 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
               llvm::dyn_cast<mlir::ModuleOp>(callee->getParentOp());
           if (callee.getSymName() != kCalleeMainFuncName ||
               callee_module == nullptr) {
-            return call_op.emitOpError()
-                   << "requires callee outlined as `" << kCalleeMainFuncName
-                   << "` function in a ModuleOp. Actual callee name: "
-                   << callee.getSymName() << ". Actual callee parent: "
-                   << callee->getParentOp()->getName();
+            // No need to clone the call op because it won't be modified if
+            // any error is encountered.
+            call_op_to_error.try_emplace(
+                call_op,
+                absl::StrCat(
+                    "requires callee outlined as `", kCalleeMainFuncName.str(),
+                    "` function in a ModuleOp. Actual callee name: ",
+                    callee.getSymName().str(), ". Actual callee parent: ",
+                    callee->getParentOp()->getName().getStringRef().str()));
+            return mlir::WalkResult::advance();
           }
           if (call_op->hasAttr(kIsSdyPartitioned)) {
             // Add the meshes roundtrip attribute to the callee module if the
             // atom program was partitioned with sdy.
             if (!meshes_round_trip_attr) {
-              return call_op.emitOpError()
-                     << "requires meshes roundtrip attribute to be set on the "
-                        "program module if the atom program was partitioned "
-                        "with sdy.";
+              call_op_to_error.try_emplace(
+                  call_op,
+                  "requires meshes roundtrip attribute to be set on the "
+                  "program module if the atom program was partitioned with "
+                  "sdy.");
+              return mlir::WalkResult::advance();
             }
             xla::sdy::setFrontendAttribute(
                 callee_module, xla::sdy::kMeshesRoundTripAttr,
@@ -156,9 +172,12 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
           absl::StatusOr compile_future =
               atom_program_compiler_.CompileModule(call_op, callee_module);
           if (!compile_future.ok()) {
-            return call_op.emitOpError()
-                   << "failed to dispatch compilation for atom executable: "
-                   << compile_future.status().ToString();
+            call_op_to_error.try_emplace(
+                call_op,
+                absl::StrCat(
+                    "failed to dispatch compilation of atom executable: ",
+                    compile_future.status().ToString()));
+            return mlir::WalkResult::advance();
           }
           // Clone the CallOp because it will be modified later, but we want
           // to keep the original to be able to access the future.
@@ -167,17 +186,14 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
         return mlir::WalkResult::advance();
       });
 
-  bool pass_failed = false;
-  if (compile_result.wasInterrupted()) {
-    pass_failed = true;
-  } else {
+  if (call_op_to_error.empty()) {
     // Map from the hash of the CallOp to the symbol ref of the
     // LoadedExecutableOp.
     llvm::DenseMap call_op_to_loaded_exec_op_ref;
 
     // Walk, wait on compilations, and generate LoadedExecutableOps.
-    auto result =
-        module_op.walk([&](CallOp call_op) -> mlir::WalkResult {
+    module_op.walk(
+        [&](CallOp call_op) -> mlir::WalkResult {
           mlir::SymbolRefAttr loaded_exec_op_ref;
           if (auto loaded_exec_op_ref_it =
                   call_op_to_loaded_exec_op_ref.find(call_op);
@@ -188,9 +204,12 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
           } else {
             auto compile_result = call_to_compile_futures[call_op].Await();
             if (!compile_result.ok()) {
-              return call_op.emitOpError()
-                     << "failed to compile to atom executable: "
-                     << compile_result.status().ToString();
+              call_op_to_error.try_emplace(
+                  call_op,
+                  absl::StrCat(
+                      "failed to compile atom executable: ",
+                      compile_result.status().ToString()));
+              return mlir::WalkResult::advance();
             }
             auto callee_module = llvm::dyn_cast<mlir::ModuleOp>(
                 call_op.getCalleeOp(symbol_table)->getParentOp());
@@ -198,9 +217,11 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
                 GenerateLoadedExecutableOp(callee_module, compile_result->name,
                                            call_op, builder);
             if (!symbol_ref.ok()) {
-              return call_op.emitOpError()
-                     << "failed to generate loaded executable op: "
-                     << symbol_ref.status().ToString();
+              call_op_to_error.try_emplace(
+                  call_op,
+                  absl::StrCat("failed to generate loaded executable op: ",
+                               symbol_ref.status().ToString()));
+              return mlir::WalkResult::advance();
             }
             loaded_exec_op_ref = *symbol_ref;
             // Clone the CallOp because it will be modified next, but we want to
@@ -210,9 +231,7 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
             CHECK(atom_executable_map_
                       ->try_emplace(compile_result->name,
                                     std::move(compile_result->executable))
-                      .second)
-                << "Failed to insert atom executable to map. Executable `"
-                << compile_result->name << "` already exists";
+                      .second);
           }
 
           // Generate CallLoadedExecutableOp.
@@ -228,23 +247,24 @@ void IfrtCompileAtomProgramPass::runOnOperation() {
           call_op.erase();
           return mlir::WalkResult::advance();
         });
-    if (result.wasInterrupted()) {
-      pass_failed = true;
-    }
     // Erase the CallOp clones that we're used as keys of the map.
     for (auto& [call_op, loaded_exec_op_ref] : call_op_to_loaded_exec_op_ref) {
       call_op.erase();
     }
   }
 
-  if (pass_failed) {
-    // Wait on all compile futures to ensure that they do not access
-    // this->compiler_ after the pass has been destructed. We don't care if
-    // the compilations succeed at this point because the pass has failed
-    // anyways.
+  if (!call_op_to_error.empty()) {
+    // Wait on all compile futures to ensure that 1) the errors emitted here
+    // do not leak into any scoped diagnostic handlers that might be created
+    // during compilation dispatch, and 2) this->compiler_ is not accessed after
+    // the pass has been destructed. We don't care if the compilations succeed
+    // at this point because the pass has failed anyways.
     for (auto& [call_op, future] : call_to_compile_futures) {
       (void)future.Await();
     }
+    for (auto& [call_op, error] : call_op_to_error) {
+      call_op.emitError(error);
+    }
     signalPassFailure();
   }
   // Erase the CallOp clones that we're used as keys of the map.
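
For context (not part of the patch): below is a minimal standalone C++ sketch of the deferred-diagnostics flow that the new comments describe, built only on the standard library. The op ids, the futures/errors containers, and the failure conditions are invented for illustration; llvm::MapVector is stood in for by a vector of pairs to keep first-encountered order.

// Hypothetical sketch of the pattern: record errors during the "walk" instead
// of emitting them immediately, drain every in-flight future, then report.
#include <cstdio>
#include <future>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Stand-ins for call_to_compile_futures and call_op_to_error.
  std::vector<std::pair<int, std::future<bool>>> futures;
  std::vector<std::pair<int, std::string>> errors;  // insertion-ordered

  // First "walk": dispatch each op, stashing dispatch failures without bailing.
  for (int op = 0; op < 4; ++op) {
    if (op == 2) {
      errors.emplace_back(op, "failed to dispatch compilation");
      continue;
    }
    futures.emplace_back(
        op, std::async(std::launch::async, [op] { return op != 3; }));
  }

  // Second "walk": only runs if every dispatch succeeded.
  if (errors.empty()) {
    for (auto& [op, fut] : futures) {
      if (!fut.get()) errors.emplace_back(op, "failed to compile");
    }
  }

  if (!errors.empty()) {
    // Drain any futures that were never awaited before reporting, so nothing
    // is still running when the diagnostics are finally emitted.
    for (auto& [op, fut] : futures) {
      if (fut.valid()) fut.wait();
    }
    for (const auto& [op, msg] : errors) {
      std::fprintf(stderr, "call op %d: %s\n", op, msg.c_str());
    }
    return 1;  // analogous to signalPassFailure()
  }
  return 0;
}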