From 1182b5509ba2604856d02cf22795d6874252892e Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sat, 10 Feb 2024 00:34:34 -0800 Subject: [PATCH] Disable streams for the DML EP (#19481) There's currently a bug in the allocation planner: when buffers are reused and more than one stream is used, it is possible (although rare) to reach a reference count of 0 for a buffer that is still being used. Since DML doesn't benefit from multiple streams, disabling them is the safest option for now. This is a high-priority issue that we need to fix for 1.17.1 since it breaks Stable Diffusion. Identifying the perfect fix and fixing the underlying issue would be too risky for a patch release, especially given the limited time that we have. https://github.com/microsoft/onnxruntime/issues/19480 --- cmake/adjust_global_compile_flags.cmake | 9 ++- .../test/framework/allocation_planner_test.cc | 21 +++++-- onnxruntime/test/framework/bfc_arena_test.cc | 2 + .../test/framework/execution_frame_test.cc | 55 +++++++++++++++++-- 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 2c7bf9f1c2f5c..a56864ebf4644 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -92,8 +92,13 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# enable stream for all the non-minimal build -if (NOT onnxruntime_MINIMAL_BUILD) +# Enable stream for all the non-minimal build, except for DML. There's currently a bug +# in the allocation planner when reusing buffers and more than one streams are used that +# make it possible (although rarely) to reach a reference count of 0 for a buffer that is +# still being used. Since DML doesn't benefit from multiple streams, disabling it is the +# safest option for now. 
+# https://github.com/microsoft/onnxruntime/issues/19480 +if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML) add_compile_definitions(ORT_ENABLE_STREAM) endif() diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index b174ee4138be3..d7b1de5c930c5 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -327,10 +327,23 @@ class PlannerTest : public ::testing::Test { if (invoke_createPlan_explicityly) { onnxruntime::GraphViewer graph_viewer{graph_}; - status = SequentialPlanner::CreatePlan(nullptr, graph_viewer, outer_scope_node_args, execution_providers_, - kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context, - MockStreamHandleRegsitry(), /* {{kCpuExecutionProvider, 1}}, {},*/ - ORT_TSTR(""), DefaultLoggingManager().DefaultLogger(), plan_); + status = SequentialPlanner::CreatePlan( + nullptr, + graph_viewer, + outer_scope_node_args, + execution_providers_, + kernel_create_info_map, + {}, + {}, + state_->GetOrtValueNameIdxMap(), + test_context, +#ifdef ORT_ENABLE_STREAM + MockStreamHandleRegsitry(), +#endif + /* {{kCpuExecutionProvider, 1}}, {},*/ + ORT_TSTR(""), + DefaultLoggingManager().DefaultLogger(), + plan_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); // AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size()); diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index 0d3e4449da939..e9f734057da1c 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -337,6 +337,7 @@ struct StreamMock : public Stream { Status CleanUpOnRunEnd() override { return Status::OK(); } }; +#ifdef ORT_ENABLE_STREAM TEST(StreamAwareArenaTest, TwoStreamAllocation) { StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30, false); CheckStats(&a, 0, 0, 0, 0); @@ -413,6 
+414,7 @@ TEST(StreamAwareArenaTest, TestSecureTheChunk) { EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked"; a.Free(p2); } +#endif TEST(BFCArenaTest, TestExtendStrategy) { int64_t extend_delta_bytes = 0; diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index ec572ce9deed8..60752d7456d97 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -75,7 +75,16 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; - ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state); + ExecutionFrame frame( + {}, + {}, + {}, + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); int start_index = frame.GetNodeOffset(node->Index()); ASSERT_EQ(start_index, 0); @@ -150,7 +159,16 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) { ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; - ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state); + ExecutionFrame frame( + {}, + {}, + {}, + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); int start_index = frame.GetNodeOffset(node->Index()); ASSERT_EQ(start_index, 0); @@ -216,7 +234,16 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK()); vector outputs; - ExecutionFrame frame(AsSpan({x_idx}), AsSpan({value}), AsSpan({y_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x_idx}), + AsSpan({value}), + AsSpan({y_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0); Tensor* p_tensor_arg_0 = p_ml_value ? 
p_ml_value->GetMutable() : nullptr; @@ -299,7 +326,16 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { std::vector(6, 1.0f), &v3); std::vector outputs; - ExecutionFrame frame(AsSpan({x1_idx, x2_idx, x3_idx}), AsSpan({v1, v2, v3}), AsSpan({t3_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x1_idx, x2_idx, x3_idx}), + AsSpan({v1, v2, v3}), + AsSpan({t3_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); OrtValue& mlvalue3 = *frame.GetMutableNodeInputOrOutputMLValue(3); OrtValue& mlvalue4 = *frame.GetMutableNodeInputOrOutputMLValue(4); @@ -388,7 +424,16 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) { CreateMLValue(cpu_allocator, std::vector{2, 2}, std::vector(4, 1.0f), &t_value); vector outputs; - ExecutionFrame frame(AsSpan({x_idx}), AsSpan({x_value}), AsSpan({y_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x_idx}), + AsSpan({x_value}), + AsSpan({y_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); ASSERT_FALSE(frame.GetMutableNodeInputOrOutputMLValue(t_idx)->IsTensor()); ASSERT_STATUS_OK(frame.SetOutputMLValue(t_idx, t_value));