Skip to content

Commit

Permalink
[pjrt] Link libdevice *before* running the optimization pipeline
Browse files Browse the repository at this point in the history
I also added a check for whether libdevice is actually needed, so that it is
not linked in when the kernel uses no libdevice functions.

PiperOrigin-RevId: 725803347
  • Loading branch information
superbobry authored and Google-ML-Automation committed Feb 11, 2025
1 parent e9c0445 commit a98b259
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions xla/pjrt/triton_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ limitations under the License.
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
Expand Down Expand Up @@ -256,17 +257,21 @@ absl::StatusOr<std::string> LLVMToPTX(mlir::ModuleOp module,

llvmModule->setDataLayout(machine->createDataLayout());

auto needsLibdevice =
llvm::any_of(llvmModule->functions(), [](const auto& f) {
return !f.isIntrinsic() && f.isDeclaration() &&
f.getName().starts_with("__nv_");
});
if (needsLibdevice) {
TF_RETURN_IF_ERROR(LinkLibdevice(llvmModule.get()));
}

auto transformer = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/machine.get());
if (auto error = transformer(llvmModule.get()); error) {
return absl::InternalError("Failed to optimize LLVM IR");
}

if (auto status = LinkLibdevice(llvmModule.get()); !status.ok()) {
// TODO(slebedev): Make this an error if the module requires libdevice.
LOG(ERROR) << "Failed to link libdevice: " << status;
}

std::string result;
{
llvm::raw_string_ostream stream(result);
Expand Down

0 comments on commit a98b259

Please sign in to comment.