|  | // SPDX-License-Identifier: GPL-2.0 | 
|  |  | 
|  | //! Procedural macro to run KUnit tests using a user-space like syntax. | 
|  | //! | 
|  | //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> | 
|  |  | 
|  | use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; | 
|  | use std::collections::HashMap; | 
|  | use std::fmt::Write; | 
|  |  | 
|  | pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { | 
|  | let attr = attr.to_string(); | 
|  |  | 
|  | if attr.is_empty() { | 
|  | panic!("Missing test name in `#[kunit_tests(test_name)]` macro") | 
|  | } | 
|  |  | 
|  | if attr.len() > 255 { | 
|  | panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") | 
|  | } | 
|  |  | 
|  | let mut tokens: Vec<_> = ts.into_iter().collect(); | 
|  |  | 
|  | // Scan for the `mod` keyword. | 
|  | tokens | 
|  | .iter() | 
|  | .find_map(|token| match token { | 
|  | TokenTree::Ident(ident) => match ident.to_string().as_str() { | 
|  | "mod" => Some(true), | 
|  | _ => None, | 
|  | }, | 
|  | _ => None, | 
|  | }) | 
|  | .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); | 
|  |  | 
|  | // Retrieve the main body. The main body should be the last token tree. | 
|  | let body = match tokens.pop() { | 
|  | Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, | 
|  | _ => panic!("Cannot locate main body of module"), | 
|  | }; | 
|  |  | 
|  | // Get the functions set as tests. Search for `[test]` -> `fn`. | 
|  | let mut body_it = body.stream().into_iter(); | 
|  | let mut tests = Vec::new(); | 
|  | let mut attributes: HashMap<String, TokenStream> = HashMap::new(); | 
|  | while let Some(token) = body_it.next() { | 
|  | match token { | 
|  | TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() { | 
|  | Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { | 
|  | if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() { | 
|  | // Collect attributes because we need to find which are tests. We also | 
|  | // need to copy `cfg` attributes so tests can be conditionally enabled. | 
|  | attributes | 
|  | .entry(name.to_string()) | 
|  | .or_default() | 
|  | .extend([token, TokenTree::Group(g)]); | 
|  | } | 
|  | continue; | 
|  | } | 
|  | _ => (), | 
|  | }, | 
|  | TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => { | 
|  | if let Some(TokenTree::Ident(test_name)) = body_it.next() { | 
|  | tests.push((test_name, attributes.remove("cfg").unwrap_or_default())) | 
|  | } | 
|  | } | 
|  |  | 
|  | _ => (), | 
|  | } | 
|  | attributes.clear(); | 
|  | } | 
|  |  | 
|  | // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. | 
|  | let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); | 
|  | tokens.insert( | 
|  | 0, | 
|  | TokenTree::Group(Group::new(Delimiter::None, config_kunit)), | 
|  | ); | 
|  |  | 
|  | // Generate the test KUnit test suite and a test case for each `#[test]`. | 
|  | // The code generated for the following test module: | 
|  | // | 
|  | // ``` | 
|  | // #[kunit_tests(kunit_test_suit_name)] | 
|  | // mod tests { | 
|  | //     #[test] | 
|  | //     fn foo() { | 
|  | //         assert_eq!(1, 1); | 
|  | //     } | 
|  | // | 
|  | //     #[test] | 
|  | //     fn bar() { | 
|  | //         assert_eq!(2, 2); | 
|  | //     } | 
|  | // } | 
|  | // ``` | 
|  | // | 
|  | // Looks like: | 
|  | // | 
|  | // ``` | 
|  | // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) { foo(); } | 
|  | // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); } | 
|  | // | 
|  | // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [ | 
|  | //     ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo), | 
|  | //     ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar), | 
|  | //     ::kernel::kunit::kunit_case_null(), | 
|  | // ]; | 
|  | // | 
|  | // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); | 
|  | // ``` | 
|  | let mut kunit_macros = "".to_owned(); | 
|  | let mut test_cases = "".to_owned(); | 
|  | let mut assert_macros = "".to_owned(); | 
|  | let path = crate::helpers::file(); | 
|  | let num_tests = tests.len(); | 
|  | for (test, cfg_attr) in tests { | 
|  | let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); | 
|  | // Append any `cfg` attributes the user might have written on their tests so we don't | 
|  | // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce | 
|  | // the length of the assert message. | 
|  | let kunit_wrapper = format!( | 
|  | r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) | 
|  | {{ | 
|  | (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; | 
|  | {cfg_attr} {{ | 
|  | (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; | 
|  | use ::kernel::kunit::is_test_result_ok; | 
|  | assert!(is_test_result_ok({test}())); | 
|  | }} | 
|  | }}"#, | 
|  | ); | 
|  | writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); | 
|  | writeln!( | 
|  | test_cases, | 
|  | "    ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name})," | 
|  | ) | 
|  | .unwrap(); | 
|  | writeln!( | 
|  | assert_macros, | 
|  | r#" | 
|  | /// Overrides the usual [`assert!`] macro with one that calls KUnit instead. | 
|  | #[allow(unused)] | 
|  | macro_rules! assert {{ | 
|  | ($cond:expr $(,)?) => {{{{ | 
|  | kernel::kunit_assert!("{test}", "{path}", 0, $cond); | 
|  | }}}} | 
|  | }} | 
|  |  | 
|  | /// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. | 
|  | #[allow(unused)] | 
|  | macro_rules! assert_eq {{ | 
|  | ($left:expr, $right:expr $(,)?) => {{{{ | 
|  | kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right); | 
|  | }}}} | 
|  | }} | 
|  | "# | 
|  | ) | 
|  | .unwrap(); | 
|  | } | 
|  |  | 
|  | writeln!(kunit_macros).unwrap(); | 
|  | writeln!( | 
|  | kunit_macros, | 
|  | "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases}    ::kernel::kunit::kunit_case_null(),\n];", | 
|  | num_tests + 1 | 
|  | ) | 
|  | .unwrap(); | 
|  |  | 
|  | writeln!( | 
|  | kunit_macros, | 
|  | "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" | 
|  | ) | 
|  | .unwrap(); | 
|  |  | 
|  | // Remove the `#[test]` macros. | 
|  | // We do this at a token level, in order to preserve span information. | 
|  | let mut new_body = vec![]; | 
|  | let mut body_it = body.stream().into_iter(); | 
|  |  | 
|  | while let Some(token) = body_it.next() { | 
|  | match token { | 
|  | TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { | 
|  | Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), | 
|  | Some(next) => { | 
|  | new_body.extend([token, next]); | 
|  | } | 
|  | _ => { | 
|  | new_body.push(token); | 
|  | } | 
|  | }, | 
|  | _ => { | 
|  | new_body.push(token); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | let mut final_body = TokenStream::new(); | 
|  | final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); | 
|  | final_body.extend(new_body); | 
|  | final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); | 
|  |  | 
|  | tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); | 
|  |  | 
|  | tokens.into_iter().collect() | 
|  | } |