rust: devres: fix race condition due to nesting

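Unless `Devres::drop` manages to remove its devres action, it waits for
the action to complete. However, a `Devres` may itself be dropped from
within another devres action, e.g. when an `Arc<Devres<T>>` is owned by
an object registered through `devres::register_foreign()`. In this case
the wait never finishes and unbinding the driver blocks forever:

sysrq: Show Blocked State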
task:rmmod           state:D stack:0     pid:1331  tgid:1331  ppid:1330   task_flags:0x400100 flags:0x00000010
Call trace:
 __switch_to+0x190/0x294 (T)
 __schedule+0x878/0xf10
 schedule+0x4c/0xcc
 schedule_timeout+0x44/0x118
 wait_for_common+0xc0/0x18c
 wait_for_completion+0x18/0x24
 _RINvNtCs4gKlGRWyJ5S_4core3ptr13drop_in_placeINtNtNtCsgzhNYVB7wSz_6kernel4sync3arc3ArcINtNtBN_6devres6DevresmEEECsRdyc7Hyps3_15rust_driver_pci+0x68/0xe8 [rust_driver_pci]
 _RINvNvNtCsgzhNYVB7wSz_6kernel6devres16register_foreign8callbackINtNtCs4gKlGRWyJ5S_4core3pin3PinINtNtNtB6_5alloc4kbox3BoxINtNtNtB6_4sync3arc3ArcINtB4_6DevresmEENtNtB1A_9allocator7KmallocEEECsRdyc7Hyps3_15rust_driver_pci+0x34/0xc8 [rust_driver_pci]
 devm_action_release+0x14/0x20
 devres_release_all+0xb8/0x118
 device_release_driver_internal+0x1c4/0x28c
 driver_detach+0x94/0xd4
 bus_remove_driver+0xdc/0x11c
 driver_unregister+0x34/0x58
 pci_unregister_driver+0x20/0x80
 __arm64_sys_delete_module+0x1d8/0x254
 invoke_syscall+0x40/0xcc
 el0_svc_common+0x8c/0xd8
 do_el0_svc+0x1c/0x28
 el0_svc+0x54/0x1d4
 el0t_64_sync_handler+0x84/0x12c
 el0t_64_sync+0x198/0x19c

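Avoid blocking in `Devres::drop` by sharing the `Revocable<T>` between
`Devres` and its devres action through an `Arc`; each side holds its own
reference count, hence neither side has to wait for the other one.

Signed-off-by: Danilo Krummrich <dakr@kernel.org>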
diff --git a/rust/kernel/devres.rs b/rust/kernel/devres.rs
index cdc4967..392be82 100644
--- a/rust/kernel/devres.rs
+++ b/rust/kernel/devres.rs
@@ -21,30 +21,11 @@
     sync::{
         aref::ARef,
         rcu,
-        Completion, //
+        Arc, //
     },
-    types::{
-        ForeignOwnable,
-        Opaque,
-        ScopeGuard, //
-    },
+    types::ForeignOwnable,
 };
 
-use pin_init::Wrapper;
-
-/// [`Devres`] inner data accessed from [`Devres::callback`].
-#[pin_data]
-struct Inner<T: Send> {
-    #[pin]
-    data: Revocable<T>,
-    /// Tracks whether [`Devres::callback`] has been completed.
-    #[pin]
-    devm: Completion,
-    /// Tracks whether revoking [`Self::data`] has been completed.
-    #[pin]
-    revoke: Completion,
-}
-
 /// This abstraction is meant to be used by subsystems to containerize [`Device`] bound resources to
 /// manage their lifetime.
 ///
@@ -128,10 +109,6 @@ struct Inner<T: Send> {
 /// # Ok(())
 /// # }
 /// ```
-///
-/// # Invariants
-///
-/// `Self::inner` is guaranteed to be initialized and is always accessed read-only.
 #[pin_data(PinnedDrop)]
 pub struct Devres<T: Send> {
     dev: ARef<Device>,
@@ -140,13 +117,8 @@ pub struct Devres<T: Send> {
     /// Has to be stored, since Rust does not guarantee to always return the same address for a
     /// function. However, the C API uses the address as a key.
     callback: unsafe extern "C" fn(*mut c_void),
-    /// Contains all the fields shared with [`Self::callback`].
-    // TODO: Replace with `UnsafePinned`, once available.
-    //
-    // Subsequently, the `drop_in_place()` in `Devres::drop` and `Devres::new` as well as the
-    // explicit `Send` and `Sync' impls can be removed.
     #[pin]
-    inner: Opaque<Inner<T>>,
+    data: Arc<Revocable<T>>,
     _add_action: (),
 }
 
@@ -163,66 +135,45 @@ pub fn new<'a, E>(
         T: 'a,
         Error: From<E>,
     {
-        try_pin_init!(&this in Self {
+        try_pin_init!(Self {
             dev: dev.into(),
             callback: Self::devres_callback,
-            // INVARIANT: `inner` is properly initialized.
-            inner <- Opaque::pin_init(try_pin_init!(Inner {
-                    devm <- Completion::new(),
-                    revoke <- Completion::new(),
-                    data <- Revocable::new(data),
-            })),
+            data: Arc::pin_init(Revocable::new(data), GFP_KERNEL)?,
             // TODO: Replace with "initializer code blocks" [1] once available.
             //
             // [1] https://github.com/Rust-for-Linux/pin-init/pull/69
             _add_action: {
-                // SAFETY: `this` is a valid pointer to uninitialized memory.
-                let inner = unsafe { &raw mut (*this.as_ptr()).inner };
-
                 // SAFETY:
                 // - `dev.as_raw()` is a pointer to a valid bound device.
-                // - `inner` is guaranteed to be a valid for the duration of the lifetime of `Self`.
+                // - `data` is guaranteed to be valid for the duration of the lifetime of `Self`.
                 // - `devm_add_action()` is guaranteed not to call `callback` until `this` has been
                 //    properly initialized, because we require `dev` (i.e. the *bound* device) to
                 //    live at least as long as the returned `impl PinInit<Self, Error>`.
                 to_result(unsafe {
-                    bindings::devm_add_action(dev.as_raw(), Some(*callback), inner.cast())
-                }).inspect_err(|_| {
-                    let inner = Opaque::cast_into(inner);
-
-                    // SAFETY: `inner` is a valid pointer to an `Inner<T>` and valid for both reads
-                    // and writes.
-                    unsafe { core::ptr::drop_in_place(inner) };
+                    bindings::devm_add_action(
+                        dev.as_raw(),
+                        Some(*callback),
+                        Arc::as_ptr(&data).cast_mut().cast(),
+                    )
                 })?;
+
+                // Take an additional reference count, which is owned by the devres action.
+                core::mem::forget(data.clone());
             },
         })
     }
 
-    fn inner(&self) -> &Inner<T> {
-        // SAFETY: By the type invairants of `Self`, `inner` is properly initialized and always
-        // accessed read-only.
-        unsafe { &*self.inner.get() }
-    }
-
     fn data(&self) -> &Revocable<T> {
-        &self.inner().data
+        &self.data
     }
 
     #[allow(clippy::missing_safety_doc)]
     unsafe extern "C" fn devres_callback(ptr: *mut kernel::ffi::c_void) {
-        // SAFETY: In `Self::new` we've passed a valid pointer to `Inner` to `devm_add_action()`,
-        // hence `ptr` must be a valid pointer to `Inner`.
-        let inner = unsafe { &*ptr.cast::<Inner<T>>() };
+        // SAFETY: In `Self::new` we've passed a pointer obtained from `Arc::as_ptr()` to
+        // `devm_add_action()` and leaked a reference count for it, which we take over here.
+        let data = unsafe { Arc::from_raw(ptr.cast::<Revocable<T>>()) };
 
-        // Ensure that `inner` can't be used anymore after we signal completion of this callback.
-        let inner = ScopeGuard::new_with_data(inner, |inner| inner.devm.complete_all());
-
-        if !inner.data.revoke() {
-            // If `revoke()` returns false, it means that `Devres::drop` already started revoking
-            // `data` for us. Hence we have to wait until `Devres::drop` signals that it
-            // completed revoking `data`.
-            inner.revoke.wait_for_completion();
-        }
+        data.revoke();
     }
 
     fn remove_action(&self) -> bool {
@@ -234,7 +185,7 @@ fn remove_action(&self) -> bool {
             bindings::devm_remove_action_nowarn(
                 self.dev.as_raw(),
                 Some(self.callback),
-                core::ptr::from_ref(self.inner()).cast_mut().cast(),
+                core::ptr::from_ref(self.data()).cast_mut().cast(),
             )
         } == 0)
     }
@@ -320,24 +271,13 @@ fn drop(self: Pin<&mut Self>) {
         // anymore, hence it is safe not to wait for the grace period to finish.
         if unsafe { self.data().revoke_nosync() } {
             // We revoked `self.data` before the devres action did, hence try to remove it.
-            if !self.remove_action() {
-                // We could not remove the devres action, which means that it now runs concurrently,
-                // hence signal that `self.data` has been revoked by us successfully.
-                self.inner().revoke.complete_all();
-
-                // Wait for `Self::devres_callback` to be done using this object.
-                self.inner().devm.wait_for_completion();
+            if self.remove_action() {
+                // SAFETY: In `Self::new` we have taken an additional reference count for
+                // `devm_add_action`. Since `remove_action()` was successful, we have to drop this
+                // additional reference count.
+                drop(unsafe { Arc::from_raw(Arc::as_ptr(&self.data)) });
             }
-        } else {
-            // `Self::devres_callback` revokes `self.data` for us, hence wait for it to be done
-            // using this object.
-            self.inner().devm.wait_for_completion();
         }
-
-        // INVARIANT: At this point it is guaranteed that `inner` can't be accessed any more.
-        //
-        // SAFETY: `inner` is valid for dropping.
-        unsafe { core::ptr::drop_in_place(self.inner.get()) };
     }
 }