test_user_copy: Improve failure output

Improve the test() macro so that values can be printed in the failure
message and so that the result can be compared against specific values.

This helps for debugging, and allows test_user_copy to be more easily
extended to test in loops without the failure output becoming confusing.

We make use of the value comparison to ensure that failing copies return
a specific value (PAGE_SIZE) for the number of bytes that couldn't be
copied.

Signed-off-by: James Hogan <james.hogan@imgtec.com>
Cc: linux-kernel@vger.kernel.org
Cc: Kees Cook <keescook@chromium.org>
diff --git a/lib/test_user_copy.c b/lib/test_user_copy.c
index 1a8d71a..964622b 100644
--- a/lib/test_user_copy.c
+++ b/lib/test_user_copy.c
@@ -43,12 +43,12 @@
 # define TEST_U64
 #endif
 
-#define test(condition, msg)		\
-({					\
-	int cond = (condition);		\
-	if (cond)			\
-		pr_warn("%s\n", msg);	\
-	cond;				\
+#define test(exp, val, msg, ...)			\
+({							\
+	int cond = (exp) != (val);			\
+	if (cond)					\
+		pr_warn(msg "\n", ##__VA_ARGS__);	\
+	cond;						\
 })
 
 static int __init test_user_copy_init(void)
@@ -57,7 +57,7 @@
 	char *kmem;
 	char __user *usermem;
 	char *bad_usermem;
-	unsigned long user_addr;
+	unsigned long user_addr, n;
 	u8 val_u8;
 	u16 val_u16;
 	u32 val_u32;
@@ -85,23 +85,23 @@
 	 * Legitimate usage: none of these copies should fail.
 	 */
 	memset(kmem, 0x3a, PAGE_SIZE * 2);
-	ret |= test(copy_to_user(usermem, kmem, PAGE_SIZE),
-		    "legitimate copy_to_user failed");
+	ret |= test(0, n = copy_to_user(usermem, kmem, PAGE_SIZE),
+		    "legitimate copy_to_user failed to write %lu bytes", n);
 	memset(kmem, 0x0, PAGE_SIZE);
-	ret |= test(copy_from_user(kmem, usermem, PAGE_SIZE),
-		    "legitimate copy_from_user failed");
-	ret |= test(memcmp(kmem, kmem + PAGE_SIZE, PAGE_SIZE),
+	ret |= test(0, n = copy_from_user(kmem, usermem, PAGE_SIZE),
+		    "legitimate copy_from_user failed to read %lu bytes", n);
+	ret |= test(0, memcmp(kmem, kmem + PAGE_SIZE, PAGE_SIZE),
 		    "legitimate usercopy failed to copy data");
 
 #define test_legit(size, check)						  \
 	do {								  \
 		val_##size = check;					  \
-		ret |= test(put_user(val_##size, (size __user *)usermem), \
+		ret |= test(0, put_user(val_##size, (size __user *)usermem), \
 		    "legitimate put_user (" #size ") failed");		  \
 		val_##size = 0;						  \
-		ret |= test(get_user(val_##size, (size __user *)usermem), \
+		ret |= test(0, get_user(val_##size, (size __user *)usermem), \
 		    "legitimate get_user (" #size ") failed");		  \
-		ret |= test(val_##size != check,			  \
+		ret |= test(0, val_##size != check,			  \
 		    "legitimate get_user (" #size ") failed to do copy"); \
 		if (val_##size != check) {				  \
 			pr_info("0x%llx != 0x%llx\n",			  \
@@ -127,12 +127,14 @@
 	memset(kmem + PAGE_SIZE, 0, PAGE_SIZE);
 
 	/* Reject kernel-to-kernel copies through copy_from_user(). */
-	ret |= test(!copy_from_user(kmem, (char __user *)(kmem + PAGE_SIZE),
-				    PAGE_SIZE),
-		    "illegal all-kernel copy_from_user passed");
+	ret |= test(PAGE_SIZE,
+		    n = copy_from_user(kmem, (char __user *)(kmem + PAGE_SIZE),
+				       PAGE_SIZE),
+		    "illegal all-kernel copy_from_user failed to read %lu bytes instead of %lu",
+		    n, PAGE_SIZE);
 
 	/* Destination half of buffer should have been zeroed. */
-	ret |= test(memcmp(kmem + PAGE_SIZE, kmem, PAGE_SIZE),
+	ret |= test(0, memcmp(kmem + PAGE_SIZE, kmem, PAGE_SIZE),
 		    "zeroing failure for illegal all-kernel copy_from_user");
 
 #if 0
@@ -142,29 +144,35 @@
 	 * to be tested in LKDTM instead, since this test module does not
 	 * expect to explode.
 	 */
-	ret |= test(!copy_from_user(bad_usermem, (char __user *)kmem,
-				    PAGE_SIZE),
-		    "illegal reversed copy_from_user passed");
+	ret |= test(PAGE_SIZE,
+		    n = copy_from_user(bad_usermem, (char __user *)kmem,
+				       PAGE_SIZE),
+		    "illegal reversed copy_from_user failed to read %lu bytes instead of %lu",
+		    n, PAGE_SIZE);
 #endif
-	ret |= test(!copy_to_user((char __user *)kmem, kmem + PAGE_SIZE,
-				  PAGE_SIZE),
-		    "illegal all-kernel copy_to_user passed");
-	ret |= test(!copy_to_user((char __user *)kmem, bad_usermem,
-				  PAGE_SIZE),
-		    "illegal reversed copy_to_user passed");
+	ret |= test(PAGE_SIZE,
+		    n = copy_to_user((char __user *)kmem, kmem + PAGE_SIZE,
+				     PAGE_SIZE),
+		    "illegal all-kernel copy_to_user failed to read %lu bytes instead of %lu",
+		    n, PAGE_SIZE);
+	ret |= test(PAGE_SIZE,
+		    n = copy_to_user((char __user *)kmem, bad_usermem,
+				     PAGE_SIZE),
+		    "illegal reversed copy_to_user failed to read %lu bytes instead of %lu",
+		    n, PAGE_SIZE);
 
 #define test_illegal(size, check)					    \
 	do {								    \
 		val_##size = (check);					    \
-		ret |= test(!get_user(val_##size, (size __user *)kmem),	    \
+		ret |= test(0, !get_user(val_##size, (size __user *)kmem),  \
 		    "illegal get_user (" #size ") passed");		    \
-		ret |= test(val_##size != (size)0,			    \
+		ret |= test((size)0, val_##size,			    \
 		    "zeroing failure for illegal get_user (" #size ")");    \
 		if (val_##size != (size)0) {				    \
 			pr_info("0x%llx != 0\n",			    \
 				(unsigned long long)val_##size);	    \
 		}							    \
-		ret |= test(!put_user(val_##size, (size __user *)kmem),	    \
+		ret |= test(0, !put_user(val_##size, (size __user *)kmem),  \
 		    "illegal put_user (" #size ") passed");		    \
 	} while (0)