Skip to content

Commit

Permalink
Disable streams for the DML EP (#19481)
Browse files Browse the repository at this point in the history
There's currently a bug in the allocation planner when reusing buffers
and more than one streams are used that make it possible (although
rarely) to reach a reference count of 0 for a buffer that is still being
used. Since DML doesn't benefit from multiple streams, disabling it is
the safest option for now.

This is a high priority issue that we need to fix for 1.17.1 since it
breaks stable diffusion. Identifying the perfect fix and fixing the
underlying issue would be too risky for a patch release, especially
given the limited time that we have.

#19480
  • Loading branch information
PatriceVignola authored Feb 10, 2024
1 parent 0e984ef commit 1182b55
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 11 deletions.
9 changes: 7 additions & 2 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,13 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()

# enable stream for all the non-minimal build
if (NOT onnxruntime_MINIMAL_BUILD)
# Enable stream for all the non-minimal build, except for DML. There's currently a bug
# in the allocation planner when reusing buffers and more than one streams are used that
# make it possible (although rarely) to reach a reference count of 0 for a buffer that is
# still being used. Since DML doesn't benefit from multiple streams, disabling it is the
# safest option for now.
# https://github.com/microsoft/onnxruntime/issues/19480
if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
add_compile_definitions(ORT_ENABLE_STREAM)
endif()

Expand Down
21 changes: 17 additions & 4 deletions onnxruntime/test/framework/allocation_planner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,23 @@ class PlannerTest : public ::testing::Test {

if (invoke_createPlan_explicityly) {
onnxruntime::GraphViewer graph_viewer{graph_};
status = SequentialPlanner::CreatePlan(nullptr, graph_viewer, outer_scope_node_args, execution_providers_,
kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context,
MockStreamHandleRegsitry(), /* {{kCpuExecutionProvider, 1}}, {},*/
ORT_TSTR(""), DefaultLoggingManager().DefaultLogger(), plan_);
status = SequentialPlanner::CreatePlan(
nullptr,
graph_viewer,
outer_scope_node_args,
execution_providers_,
kernel_create_info_map,
{},
{},
state_->GetOrtValueNameIdxMap(),
test_context,
#ifdef ORT_ENABLE_STREAM
MockStreamHandleRegsitry(),
#endif
/* {{kCpuExecutionProvider, 1}}, {},*/
ORT_TSTR(""),
DefaultLoggingManager().DefaultLogger(),
plan_);

EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
// AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size());
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/framework/bfc_arena_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ struct StreamMock : public Stream {
Status CleanUpOnRunEnd() override { return Status::OK(); }
};

#ifdef ORT_ENABLE_STREAM
TEST(StreamAwareArenaTest, TwoStreamAllocation) {
StreamAwareArena a(std::unique_ptr<IAllocator>(new CPUAllocator()), 1 << 30, false);
CheckStats(&a, 0, 0, 0, 0);
Expand Down Expand Up @@ -413,6 +414,7 @@ TEST(StreamAwareArenaTest, TestSecureTheChunk) {
EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked";
a.Free(p2);
}
#endif

TEST(BFCArenaTest, TestExtendStrategy) {
int64_t extend_delta_bytes = 0;
Expand Down
55 changes: 50 additions & 5 deletions onnxruntime/test/framework/execution_frame_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,16 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));

vector<OrtValue> outputs;
ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state);
ExecutionFrame frame(
{},
{},
{},
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

int start_index = frame.GetNodeOffset(node->Index());
ASSERT_EQ(start_index, 0);
Expand Down Expand Up @@ -150,7 +159,16 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) {
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));

vector<OrtValue> outputs;
ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state);
ExecutionFrame frame(
{},
{},
{},
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

int start_index = frame.GetNodeOffset(node->Index());
ASSERT_EQ(start_index, 0);
Expand Down Expand Up @@ -216,7 +234,16 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) {
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK());

vector<OrtValue> outputs;
ExecutionFrame frame(AsSpan({x_idx}), AsSpan({value}), AsSpan({y_idx}), outputs, {}, {}, state);
ExecutionFrame frame(
AsSpan({x_idx}),
AsSpan({value}),
AsSpan({y_idx}),
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0);
Tensor* p_tensor_arg_0 = p_ml_value ? p_ml_value->GetMutable<Tensor>() : nullptr;
Expand Down Expand Up @@ -299,7 +326,16 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
std::vector<float>(6, 1.0f), &v3);

std::vector<OrtValue> outputs;
ExecutionFrame frame(AsSpan({x1_idx, x2_idx, x3_idx}), AsSpan({v1, v2, v3}), AsSpan({t3_idx}), outputs, {}, {}, state);
ExecutionFrame frame(
AsSpan({x1_idx, x2_idx, x3_idx}),
AsSpan({v1, v2, v3}),
AsSpan({t3_idx}),
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

OrtValue& mlvalue3 = *frame.GetMutableNodeInputOrOutputMLValue(3);
OrtValue& mlvalue4 = *frame.GetMutableNodeInputOrOutputMLValue(4);
Expand Down Expand Up @@ -388,7 +424,16 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) {
CreateMLValue<float>(cpu_allocator, std::vector<int64_t>{2, 2}, std::vector<float>(4, 1.0f), &t_value);

vector<OrtValue> outputs;
ExecutionFrame frame(AsSpan({x_idx}), AsSpan({x_value}), AsSpan({y_idx}), outputs, {}, {}, state);
ExecutionFrame frame(
AsSpan({x_idx}),
AsSpan({x_value}),
AsSpan({y_idx}),
outputs,
{},
#ifdef ORT_ENABLE_STREAM
{},
#endif
state);

ASSERT_FALSE(frame.GetMutableNodeInputOrOutputMLValue(t_idx)->IsTensor());
ASSERT_STATUS_OK(frame.SetOutputMLValue(t_idx, t_value));
Expand Down

0 comments on commit 1182b55

Please sign in to comment.