linux/rust/pin-init/examples/mutex.rs
Benno Lossin 58cebd6888 rust: pin-init: examples, tests: add conditional compilation in order to compile under any feature combination
In the CI, all examples & tests should be run under all feature
combinations. Currently several examples & tests use `std` without
conditionally enabling it. Thus make them all compile under any feature
combination by conditionally disabling the code that uses e.g. `std`.

Link: fdfb70efdd
Link: https://lore.kernel.org/all/20250523125424.192843-2-lossin@kernel.org
Signed-off-by: Benno Lossin <lossin@kernel.org>
2025-06-11 21:13:56 +02:00

225 lines
5.8 KiB
Rust

// SPDX-License-Identifier: Apache-2.0 OR MIT
#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
#![allow(clippy::missing_safety_doc)]
use core::{
cell::{Cell, UnsafeCell},
marker::PhantomPinned,
ops::{Deref, DerefMut},
pin::Pin,
sync::atomic::{AtomicBool, Ordering},
};
#[cfg(feature = "std")]
use std::{
sync::Arc,
thread::{self, sleep, Builder, Thread},
time::Duration,
};
use pin_init::*;
#[allow(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
use linked_list::*;
/// A minimal spin lock built on a single [`AtomicBool`].
///
/// `false` means "unlocked" and `true` means "held". The lock is taken with
/// [`SpinLock::acquire`] and released when the returned [`SpinLockGuard`] is dropped.
pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    /// Busy-waits until the lock could be taken and returns a guard that releases it on drop.
    ///
    /// The lock is not re-entrant: acquiring it again while the current guard is still alive
    /// never returns.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait on a plain load until the lock looks free before retrying the
            // compare-exchange, so contended waiters do not keep issuing atomic writes.
            while self.inner.load(Ordering::Relaxed) {
                // With `std`, give the holder a chance to run.
                #[cfg(feature = "std")]
                thread::yield_now();
                // Without `std` there is no scheduler to yield to; at least tell the CPU that
                // this is a busy-wait loop.
                #[cfg(not(feature = "std"))]
                core::hint::spin_loop();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    // `new` is `const`, which a `Default` impl could not offer, so it stays the only constructor.
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

/// Proof that the referenced [`SpinLock`] is held; dropping it releases the lock.
#[must_use = "the spin lock is released as soon as the guard is dropped"]
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    /// Releases the lock. The `Release` store pairs with the `Acquire` in
    /// [`SpinLock::acquire`].
    #[inline]
    fn drop(&mut self) {
        self.0.inner.store(false, Ordering::Release);
    }
}
/// A mutex whose contended waiters queue up in an intrusive list.
///
/// The value is initialized in place through a `PinInit`, see `CMutex::new`. A thread that
/// finds the mutex taken links a stack-allocated `WaitEntry` into `wait_list` and, when the
/// `std` feature is enabled, parks itself until a dropped guard unparks it. Without `std` the
/// park call is compiled out and the waiter re-checks `locked` in a loop instead.
#[pin_data]
pub struct CMutex<T> {
/// Head of the list of `WaitEntry`s of the threads currently waiting in `lock`.
#[pin]
wait_list: ListHead,
/// Held around every access to `locked` and every change to `wait_list`.
spin_lock: SpinLock,
/// `true` from the moment `lock` hands out a guard until that guard is dropped.
locked: Cell<bool>,
/// The protected value; reached through `CMutexGuard` or `get_data_mut`.
#[pin]
data: UnsafeCell<T>,
}
impl<T> CMutex<T> {
/// Returns an initializer for an unlocked mutex whose protected value is initialized in place
/// by `val`.
#[inline]
pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
pin_init!(CMutex {
wait_list <- ListHead::new(),
spin_lock: SpinLock::new(),
locked: Cell::new(false),
// SAFETY: `UnsafeCell<T>` has the same in-memory representation as `T`, so running the
// initializer of `val` on the slot cast to `*mut T` initializes the `UnsafeCell<T>`; the
// closure reports exactly the success or failure that `val`'s initializer reports.
data <- unsafe {
pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
val.__pinned_init(slot.cast::<T>())
})
},
})
}
/// Blocks until the mutex is free, marks it as taken and returns the guard that gives access
/// to the protected value.
///
/// Dropping the returned guard unlocks the mutex again.
#[inline]
pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
// `locked` and `wait_list` are only touched while the spin lock is held.
let mut sguard = self.spin_lock.acquire();
if self.locked.get() {
// Contended: link a stack-allocated wait entry into the wait list so that the guard's
// `Drop` impl can find this waiter.
stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
// println!("wait list length: {}", self.wait_list.size());
while self.locked.get() {
// Give up the spin lock while waiting, otherwise the holder could never unlock.
drop(sguard);
// `park` is only available with `std`; without it this loop just re-checks `locked`.
#[cfg(feature = "std")]
thread::park();
// Re-check `locked` under the spin lock; a wake-up alone does not mean the mutex is free.
sguard = self.spin_lock.acquire();
}
// This does have an effect, as the ListHead inside wait_entry implements Drop!
// It runs while the spin lock is held again, so no guard drop can observe the entry
// while it is being torn down.
#[expect(clippy::drop_non_drop)]
drop(wait_entry);
}
self.locked.set(true);
// SAFETY (NOTE(review): rationale reconstructed, please confirm): the guard only dereferences
// to `self.data`, which is never moved through the guard; this presumes every `CMutex` is
// created pinned via the `PinInit` returned by `new`.
unsafe {
Pin::new_unchecked(CMutexGuard {
mtx: self,
_pin: PhantomPinned,
})
}
// `sguard` is dropped here, releasing the spin lock after `locked` has been set.
}
/// Returns a mutable reference to the protected value without taking the lock.
///
/// NOTE(review): `data` is a `#[pin]` field, yet this hands out a plain `&mut T`; confirm that
/// callers never move the value out through it.
#[allow(dead_code)]
pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
// SAFETY: we have an exclusive reference and thus nobody has access to data.
unsafe { &mut *self.data.get() }
}
}
// SAFETY: the mutex owns its `T`, so sending the mutex to another thread sends the `T` along
// with it, which `T: Send` permits. NOTE(review): this also presumes `ListHead` may change
// threads as long as it is only manipulated under `spin_lock` - confirm against `linked_list`.
unsafe impl<T: Send> Send for CMutex<T> {}
// SAFETY: `locked` and `wait_list` are only accessed while `spin_lock` is held, and `data` is
// only handed out through the guard returned by `lock` (or through `get_data_mut`, which needs
// exclusive access). At most one guard exists at a time, so sharing the mutex effectively moves
// access to `T` between threads, which requires `T: Send`.
unsafe impl<T: Send> Sync for CMutex<T> {}
/// Guard returned by `CMutex::lock`; while it is alive the mutex is held by its owner.
///
/// The guard dereferences to the protected value and unlocks the mutex when dropped.
pub struct CMutexGuard<'a, T> {
    /// The mutex this guard keeps locked.
    mtx: &'a CMutex<T>,
    /// Makes the guard `!Unpin`.
    _pin: PhantomPinned,
}

// SAFETY: sharing `&CMutexGuard<T>` between threads hands out `&T` through `Deref`, which is
// only sound for `T: Sync`. This explicit impl replaces the auto trait, which would otherwise
// make the guard `Sync` whenever `CMutex<T>: Sync`, i.e. already for any `T: Send` (for example
// `Cell<u32>`). `std::sync::MutexGuard` restricts its `Sync` impl in the same way.
unsafe impl<T: Sync> Sync for CMutexGuard<'_, T> {}
impl<T> Drop for CMutexGuard<'_, T> {
    /// Unlocks the mutex and, with the `std` feature, unparks the thread that owns the first
    /// entry of the wait list (if there is one).
    #[inline]
    fn drop(&mut self) {
        let mutex = self.mtx;
        // Held until the end of this function so that `locked` and the wait list are not
        // touched by anyone else in the meantime.
        let _spin = mutex.spin_lock.acquire();
        mutex.locked.set(false);
        if let Some(first) = mutex.wait_list.next() {
            // Each list element is the `wait_list` field of a `WaitEntry`; the pointer is
            // reinterpreted as a pointer to the whole entry.
            let _waiter = first.as_ptr().cast::<WaitEntry>();
            // The waiter re-checks `locked` after waking up, so a wake-up is all it needs.
            #[cfg(feature = "std")]
            unsafe {
                (*_waiter).thread.unpark()
            };
        }
    }
}
impl<T> Deref for CMutexGuard<'_, T> {
    type Target = T;

    /// Borrows the value protected by the mutex this guard keeps locked.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let data = self.mtx.data.get();
        // The guard only exists while the mutex is marked as locked on behalf of its owner.
        unsafe { &*data }
    }
}
impl<T> DerefMut for CMutexGuard<'_, T> {
    /// Mutably borrows the value protected by the mutex this guard keeps locked.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let data = self.mtx.data.get();
        // The guard is borrowed mutably and only exists while the mutex is marked as locked.
        unsafe { &mut *data }
    }
}
/// Entry that a blocked thread links into the wait list of a `CMutex` while it waits in
/// `CMutex::lock`.
///
/// The layout matters: the `Drop` impl of `CMutexGuard` casts the pointer it gets from the wait
/// list directly to a `*mut WaitEntry`, which relies on `#[repr(C)]` and on `wait_list` being
/// the first field.
#[pin_data]
#[repr(C)]
struct WaitEntry {
/// Link into the wait list of the mutex; must stay the first field (see above).
#[pin]
wait_list: ListHead,
/// Handle of the waiting thread, used to unpark it; only present with the `std` feature.
#[cfg(feature = "std")]
thread: Thread,
}
impl WaitEntry {
/// Returns an initializer for a `WaitEntry` whose list link is set up with
/// `ListHead::insert_prev(list)`.
///
/// With the `std` feature the entry additionally records the handle of the calling thread
/// (`thread::current()`); without it the entry consists of the list link only.
#[inline]
fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
// Exactly one of the two blocks below survives conditional compilation and becomes the
// return value.
#[cfg(feature = "std")]
{
pin_init!(Self {
thread: thread::current(),
wait_list <- ListHead::insert_prev(list),
})
}
#[cfg(not(feature = "std"))]
{
pin_init!(Self {
wait_list <- ListHead::insert_prev(list),
})
}
}
}
/// Stress test for `CMutex`: several threads increment a shared counter under the lock and the
/// final value is compared with the expected total.
///
/// The whole body is compiled out unless the `std` feature is enabled.
#[cfg_attr(test, test)]
#[allow(dead_code)]
fn main() {
    #[cfg(feature = "std")]
    {
        let thread_count = 20;
        // Keep the run short under Miri.
        let workload = if cfg!(miri) { 100 } else { 1_000 };
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();

        // Collecting eagerly spawns every worker before the first one is joined.
        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = mtx.clone();
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} halfway");
                        // Stagger the second half so that the workers contend differently.
                        sleep(Duration::from_millis((i as u64) * 10));
                        for _ in 0..workload {
                            *mtx.lock() += 1;
                        }
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread panicked");
        }

        println!("{:?}", &*mtx.lock());
        // Every worker adds `workload` twice.
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}