PoC to test libtorch support for onnxruntime-extensions #770

Draft · wants to merge 1 commit into base: main
31 changes: 31 additions & 0 deletions CMakeLists.txt
@@ -5,6 +5,17 @@
cmake_minimum_required(VERSION 3.25)
project(onnxruntime_extensions LANGUAGES C CXX)

message(STATUS "OCOS_LIBTORCH_PATH=$ENV{OCOS_LIBTORCH_PATH}")
if (DEFINED ENV{OCOS_LIBTORCH_PATH})
  # https://download.pytorch.org/libtorch/nightly/cu121/libtorch-shared-with-deps-latest.zip <-- NOT TESTED
  # https://download.pytorch.org/libtorch/nightly/cu121/libtorch-cxx11-abi-shared-with-deps-latest.zip <-- NOT TESTED
  # https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip <-- WORKS
  # TODO: Maybe use export _GLIBCXX_USE_CXX11_ABI=1 if building pytorch from source
  set(CMAKE_PREFIX_PATH $ENV{OCOS_LIBTORCH_PATH})
  find_package(Torch REQUIRED)
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
endif()
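
One quick way to confirm that find_package(Torch) wired up the include paths and libraries is a tiny host program compiled against the found package. This smoke test is illustrative only and not part of the change:

// Hypothetical smoke test: builds and runs only if the libtorch discovery above
// resolved the headers and libraries correctly.
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor t = torch::relu(torch::tensor({-1.0f, 2.0f}));
  std::cout << t << std::endl;  // prints 0 and 2
  return 0;
}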

# set(CMAKE_VERBOSE_MAKEFILE ON)
if(NOT CMAKE_BUILD_TYPE)
message(STATUS "Build type not set - using RelWithDebInfo")
@@ -175,6 +186,18 @@ if (MSVC)
# See https://developercommunity.visualstudio.com/t/Access-violation-with-std::mutex::lock-a/10664660#T-N10668856
# Remove this definition once the conda msvcp140.dll dll is updated.
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)

  if (DEFINED ENV{OCOS_LIBTORCH_PATH})
    # On Windows the libtorch DLLs have to be copied next to the binaries that
    # load them; according to https://github.com/pytorch/pytorch/issues/25457,
    # skipping this copy leads to memory errors.
    file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
    add_custom_command(TARGET ocos_operators
                       POST_BUILD
                       COMMAND ${CMAKE_COMMAND} -E copy_if_different
                       ${TORCH_DLLS}
                       $<TARGET_FILE_DIR:ocos_operators>)
  endif()
endif()

if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
@@ -596,6 +619,10 @@ target_include_directories(ocos_operators PUBLIC
${PROJECT_SOURCE_DIR}/base
${PROJECT_SOURCE_DIR}/operators)

if (DEFINED ENV{OCOS_LIBTORCH_PATH})
  target_include_directories(ocos_operators PUBLIC ${TORCH_INCLUDE_DIRS})
endif()

if (OCOS_USE_CUDA)
target_include_directories(ocos_operators PUBLIC ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
endif()
@@ -724,6 +751,10 @@ list(APPEND ocos_libraries noexcep_operators)
target_compile_definitions(ocos_operators PRIVATE ${OCOS_COMPILE_DEFINITIONS})
target_link_libraries(ocos_operators PRIVATE ${ocos_libraries})

if (DEFINED ENV{OCOS_LIBTORCH_PATH})
  target_link_libraries(ocos_operators PRIVATE "${TORCH_LIBRARIES}")
endif()

set (file_patterns "shared/lib/*.cc")
if (OCOS_ENABLE_C_API)
list(APPEND file_patterns "shared/api/*.h*" "shared/api/*.c" "shared/api/*.cc")
1 change: 1 addition & 0 deletions cmake/pytorch
Submodule pytorch added at e880cb
37 changes: 37 additions & 0 deletions operators/math/com_amd_myrelu.hpp
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <torch/torch.h>

#include "ocos.h"

// TODO: Good example for CPU/CUDA op
// https://github.com/microsoft/onnxruntime-extensions/pull/739/files

// TODO: Add DLPack support to ONNXRuntime-extensions for perf improvement
// https://github.com/microsoft/onnxruntime/pull/6968

// TODO: Make templates for Tensor<T>? Testing for Tensor<float>
// template <typename T>
OrtStatusPtr com_amd_myrelu(const ortc::Tensor<float>& input_ort,
                            ortc::Tensor<float>& out_ort) {
  int64_t input_size = input_ort.NumberOfElement();
  if (0 == input_size) {
    return nullptr;
  }

  // Massage the input into PyTorch format
  torch::Tensor X = torch::empty(input_ort.Shape()).contiguous();
  memcpy(X.data_ptr<float>(), input_ort.Data(), input_size * sizeof(float));  // TODO: replace with toDLPack + torch::Tensor

  // Do the computation
  auto out_torch = torch::relu(X);

  // Massage the output back into ORT format
  float* out_ort_ptr = out_ort.Allocate(input_ort.Shape());
  memcpy(out_ort_ptr, out_torch.data_ptr<float>(), input_size * sizeof(float));  // TODO: replace with toDLPack + ortc::Tensor conversion

  return nullptr;
}
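
The two memcpy calls above stage data between the ORT and torch buffers. Until a DLPack path exists, the copies could likely be avoided by wrapping the ORT buffers directly with torch::from_blob. The sketch below is illustrative only — com_amd_myrelu_noalloc is a hypothetical name, and it assumes the ORT buffers stay alive for the duration of the call:

// Sketch of a copy-free CPU variant: torch::from_blob views existing memory
// without taking ownership, so no memcpy staging is needed.
OrtStatusPtr com_amd_myrelu_noalloc(const ortc::Tensor<float>& input_ort,
                                    ortc::Tensor<float>& out_ort) {
  int64_t input_size = input_ort.NumberOfElement();
  if (0 == input_size) {
    return nullptr;
  }
  torch::Tensor X = torch::from_blob(const_cast<float*>(input_ort.Data()),
                                     input_ort.Shape(), torch::kFloat32);
  float* out_ort_ptr = out_ort.Allocate(input_ort.Shape());
  torch::Tensor Y = torch::from_blob(out_ort_ptr, input_ort.Shape(), torch::kFloat32);
  Y.copy_(torch::relu(X));  // relu allocates one temporary; copy_ writes into the ORT buffer
  return nullptr;
}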
17 changes: 17 additions & 0 deletions operators/math/cuda/com_amd_myrelu.cu
@@ -0,0 +1,17 @@
#include <cuda_runtime.h>
#include "com_amd_myrelu.cuh"

__global__ void com_amd_myrelu_kernel(const float* input, float* out, int input_size) {
  // TODO: Properly tune the launch configuration. The grid-stride loop keeps the
  // kernel correct for any <<<grid, block>>> shape, including the <<<1, 1>>> launch below.
  // Note: torch APIs cannot be called from device code, so the relu is computed directly.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size;
       i += blockDim.x * gridDim.x) {
    out[i] = fmaxf(input[i], 0.0f);
  }
}

void com_amd_myrelu_impl(cudaStream_t stream,
const float* input, float* out, int size) {
com_amd_myrelu_kernel<<<1, 1, 0, stream>>>(input, out, size);
}
9 changes: 9 additions & 0 deletions operators/math/cuda/com_amd_myrelu.cuh
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

void com_amd_myrelu_impl(cudaStream_t stream,
const float* input, float* out, int size);
25 changes: 25 additions & 0 deletions operators/math/cuda/com_amd_myrelu_def.cc
@@ -0,0 +1,25 @@
#include "com_amd_myrelu.cuh"
#include "narrow.h"
#include "com_amd_myrelu_def.h"
#include <cuda.h>
#include <cuda_runtime.h>

OrtStatusPtr com_amd_myrelu_cuda(Ort::Custom::CUDAKernelContext* ctx,
                                 const ortc::Tensor<float>& input,
                                 ortc::Tensor<float>& out0_tensor) {
  // TODO: Properly implement CUDA version
  int64_t input_size = input.NumberOfElement();
  if (0 == input_size) {
    return nullptr;
  }

  float* out_ptr = out0_tensor.Allocate(input.Shape());

  // With the CUDA EP the input buffer already lives on the device, so it can be
  // handed to the kernel directly; no host-side torch::Tensor staging is needed.
  com_amd_myrelu_impl(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                      input.Data(), out_ptr, static_cast<int>(input_size));
  return nullptr;
}
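
The "toDLPack" TODOs above point at DLPack as the eventual zero-copy exchange; the ORT side is the subject of the linked microsoft/onnxruntime#6968, while the torch half already exists in ATen's DLPack convertor. A minimal sketch of the torch half, assuming <ATen/DLConvertor.h> ships with the chosen libtorch build and with dlpack_roundtrip as a purely illustrative name:

// Sketch of the torch side of a DLPack exchange: at::toDLPack wraps a tensor
// as a DLManagedTensor that shares storage; at::fromDLPack adopts such a
// capsule back into a torch::Tensor without copying.
#include <ATen/DLConvertor.h>
#include <torch/torch.h>

torch::Tensor dlpack_roundtrip(const torch::Tensor& src) {
  DLManagedTensor* capsule = at::toDLPack(src);  // shares storage, no copy
  return at::fromDLPack(capsule);                // consumes the capsule
}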
9 changes: 9 additions & 0 deletions operators/math/cuda/com_amd_myrelu_def.h
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"

OrtStatusPtr com_amd_myrelu_cuda(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<float>& input,
ortc::Tensor<float>& out0_tensor);
7 changes: 6 additions & 1 deletion operators/math/math.cc
@@ -1,5 +1,6 @@
#include "ocos.h"
#include "negpos.hpp"
#include "com_amd_myrelu.hpp"
#ifdef ENABLE_DLIB
#include "dlib/inverse.hpp"
#include "dlib/stft_norm.hpp"
@@ -9,19 +10,23 @@

#ifdef USE_CUDA
#include "cuda/negpos_def.h"
#include "cuda/com_amd_myrelu_def.h"
#endif // USE_CUDA

FxLoadCustomOpFactory LoadCustomOpClasses_Math = []() -> CustomOpArray& {
  static OrtOpLoader op_loader(CustomCpuFuncV2("NegPos", neg_pos),
#ifdef USE_CUDA
                               CustomCudaFuncV2("NegPos", neg_pos_cuda),
                               CustomCudaFuncV2("MyReLu", com_amd_myrelu_cuda),
#endif
                               CustomCpuFuncV2("MyReLu", com_amd_myrelu),
#ifdef ENABLE_DLIB
                               CustomCpuFuncV2("Inverse", inverse),
                               CustomCpuStructV2("StftNorm", StftNormal),
#endif
                               CustomCpuFuncV2("SegmentExtraction", segment_extraction),
                               CustomCpuFuncV2("SegmentSum", segment_sum)
  );

#if defined(USE_CUDA)
// CustomCudaFunc("NegPos", neg_pos_cuda),
124 changes: 124 additions & 0 deletions test/test_myrelu.py
@@ -0,0 +1,124 @@
import io
import onnx
import unittest
import torch
import numpy as np
import onnxruntime as _ort
from onnxruntime_extensions import (
onnx_op, PyCustomOpDef,
get_library_path as _get_library_path)


import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms


def _create_test_model(device: str = "cpu", seed: int = 42):
    # Basic setup
    use_cuda = "cuda" in device.lower() and torch.cuda.is_available()
    torch.manual_seed(seed)

    device = torch.device(device)

    # Data loader setup
    export_kwargs = {'batch_size': 1}
    if use_cuda:
        export_kwargs.update({'num_workers': 1,
                              'pin_memory': True,
                              'shuffle': True})
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    export_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    export_loader = torch.utils.data.DataLoader(export_dataset, **export_kwargs)

    # Register a custom op symbolic so relu is exported to ONNX as MyReLu.
    # The domain must be "ai.onnx.contrib" to be compatible with onnxruntime-extensions.
    from torch.onnx import register_custom_op_symbolic

    def com_amd_relu_1(g, input):
        return g.op("ai.onnx.contrib::MyReLu", input).setType(input.type())

    register_custom_op_symbolic("::relu", com_amd_relu_1, 9)

    # Model
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, 5)
            self.conv2 = nn.Conv2d(10, 20, 5)
            self.conv2_drop = nn.Dropout2d()
            self.dropout = nn.Dropout(0.5)
            self.fc1 = nn.Linear(320, 50)
            self.fc2 = nn.Linear(50, 10)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            x = torch.max_pool2d(x, 2)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.conv2_drop(x)
            x = torch.max_pool2d(x, 2)
            x = self.relu(x)
            x = x.view(-1, 320)
            x = self.fc1(x)
            x = self.relu(x)
            x = self.dropout(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output

    # Export to ONNX with the custom op
    model = Net().to(device)
    input_sample = next(iter(export_loader))
    input_sample[0] = input_sample[0].to(device)  # torch.randn(1, 1, 28, 28, dtype=torch.float)
    input_sample[1] = input_sample[1].to(device)  # torch.randint(1, 10, (1,))
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, (input_sample[0],), f)
    model_proto = onnx.ModelProto.FromString(f.getvalue())
    return model_proto, input_sample

class TestPythonOp(unittest.TestCase):

    # Used to test the custom op implemented in Python.
    # The ONNX graph carries the custom op, which would be executed by the function below.
    # @classmethod
    # def setUpClass(cls):
    #     @onnx_op(op_type="MyReLu",
    #              inputs=[PyCustomOpDef.dt_float],
    #              outputs=[PyCustomOpDef.dt_float])
    #     def myrelu(x):
    #         return torch.relu(torch.from_numpy(x.copy()))

    def test_python_myrelu(self):
        # EPs = ['CPUExecutionProvider', 'CUDAExecutionProvider']
        EPs = ['CPUExecutionProvider']  # TODO: Test with CUDA
        DEVICEs = ["cpu", "cuda"]
        for dev_idx, ep in enumerate(EPs):
            so = _ort.SessionOptions()
            so.register_custom_ops_library(_get_library_path())
            onnx_model, inputs = _create_test_model(device=DEVICEs[dev_idx], seed=42)
            self.assertIn('op_type: "MyReLu"', str(onnx_model))
            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=[ep])
            out = sess.run(None, {'input.1': inputs[0].numpy(force=True)})

    def test_cc_myrelu(self):
        # EPs = ['CPUExecutionProvider', 'CUDAExecutionProvider']
        EPs = ['CPUExecutionProvider']  # TODO: Test with CUDA
        DEVICEs = ["cpu", "cuda"]
        for dev_idx, ep in enumerate(EPs):
            so = _ort.SessionOptions()
            so.register_custom_ops_library(_get_library_path())
            onnx_model, inputs = _create_test_model(device=DEVICEs[dev_idx], seed=42)
            self.assertIn('op_type: "MyReLu"', str(onnx_model))
            sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=[ep])
            out = sess.run(None, {'input.1': inputs[0].numpy(force=True)})


if __name__ == "__main__":
    unittest.main()