diff --git a/maitake-sync/src/mutex.rs b/maitake-sync/src/mutex.rs index af71c6e0..62b98176 100644 --- a/maitake-sync/src/mutex.rs +++ b/maitake-sync/src/mutex.rs @@ -348,8 +348,12 @@ where // === impl Lock === -impl<'a, T> Future for Lock<'a, T> { - type Output = MutexGuard<'a, T>; +impl<'a, T, L> Future for Lock<'a, T, L> +where + T: ?Sized, + L: ScopedRawMutex, +{ + type Output = MutexGuard<'a, T, L>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); diff --git a/maitake-sync/src/mutex/tests.rs b/maitake-sync/src/mutex/tests.rs index 813b7f65..43b7b147 100644 --- a/maitake-sync/src/mutex/tests.rs +++ b/maitake-sync/src/mutex/tests.rs @@ -1,3 +1,5 @@ +use mutex_traits::ScopedRawMutex; + use crate::loom::{self, future}; use crate::Mutex; @@ -54,3 +56,50 @@ fn basic_multi_threaded() { }) }); } + +struct NopRawMutex; + +unsafe impl ScopedRawMutex for NopRawMutex { + fn try_with_lock(&self, _: impl FnOnce() -> R) -> Option { + None + } + + fn with_lock(&self, _: impl FnOnce() -> R) -> R { + unimplemented!("this doesn't actually do anything") + } + + fn is_locked(&self) -> bool { + true + } +} + +fn assert_future(_: F) {} + +#[test] +fn lock_future_impls_future() { + loom::model(|| { + // Mutex with `DefaultMutex` as the `ScopedRawMutex` implementation + let mutex = Mutex::new(()); + assert_future(mutex.lock()); + + // Mutex with a custom `ScopedRawMutex` implementation + let mutex = Mutex::new_with_raw_mutex((), NopRawMutex); + assert_future(mutex.lock()); + }) +} + +#[test] +#[cfg(feature = "alloc")] +fn lock_owned_future_impls_future() { + loom::model(|| { + use alloc::sync::Arc; + + // Mutex with `DefaultMutex` as the `ScopedRawMutex` implementation + let mutex = Arc::new(Mutex::new(())); + assert_future(mutex.lock_owned()); + + // Mutex with a custom `ScopedRawMutex` implementation + let mutex = Arc::new(Mutex::new_with_raw_mutex((), NopRawMutex)); + assert_future(mutex.lock_owned()); + }) +}