dma-api: Teach the "DMA-from-stack" check about vmapped stacks

If we're using CONFIG_VMAP_STACK and we manage to point an sg entry
at the stack, then either the sg page will be in highmem or sg_virt
will return the direct-map alias.  In neither case will the existing
check_for_stack() implementation realize that it's a stack page.

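Fix it by passing the page and offset to check_for_stack() instead of a
virtual address.  When the stack is vmalloced, compare the page against
the stack's backing pages from task_stack_vm_area(); otherwise keep using
object_is_on_stack() on the direct-map address.  The check is no longer
skipped for highmem pages, which a vmalloced stack may use.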

Signed-off-by: Andy Lutomirski <luto@kernel.org>
diff --git a/lib/dma-debug.c b/lib/dma-debug.c
index 51a76af..5b2e63c 100644
--- a/lib/dma-debug.c
+++ b/lib/dma-debug.c
@@ -22,6 +22,7 @@
 #include <linux/stacktrace.h>
 #include <linux/dma-debug.h>
 #include <linux/spinlock.h>
+#include <linux/vmalloc.h>
 #include <linux/debugfs.h>
 #include <linux/uaccess.h>
 #include <linux/export.h>
@@ -1162,11 +1163,35 @@
 	put_hash_bucket(bucket, &flags);
 }
 
-static void check_for_stack(struct device *dev, void *addr)
+/*
+ * Warn if @page + @offset lies on current's stack.  A vmalloced stack
+ * (CONFIG_VMAP_STACK) is matched against its backing pages instead.
+ */
+static void check_for_stack(struct device *dev,
+			    struct page *page, size_t offset)
 {
-	if (object_is_on_stack(addr))
-		err_printk(dev, NULL, "DMA-API: device driver maps memory from "
-				"stack [addr=%p]\n", addr);
+	void *addr;
+	struct vm_struct *stack_vm_area = task_stack_vm_area(current);
+
+	if (!stack_vm_area) {
+		/* Stack is direct-mapped. */
+		if (PageHighMem(page))
+			return;
+		addr = page_address(page) + offset;
+		if (object_is_on_stack(addr))
+			err_printk(dev, NULL, "DMA-API: device driver maps memory from stack [addr=%p]\n", addr);
+	} else {
+		/* Stack is vmalloced: look for @page among its backing pages. */
+		unsigned int i;
+
+		for (i = 0; i < stack_vm_area->nr_pages; i++) {
+			if (page != stack_vm_area->pages[i])
+				continue;
+			addr = (u8 *)current->stack + i * PAGE_SIZE + offset;
+			err_printk(dev, NULL, "DMA-API: device driver maps memory from stack [probable addr=%p]\n", addr);
+			break;
+		}
+	}
 }
 
 static inline bool overlap(void *addr, unsigned long len, void *start, void *end)
@@ -1289,10 +1314,11 @@
 	if (map_single)
 		entry->type = dma_debug_single;
 
+	check_for_stack(dev, page, offset);
+
 	if (!PageHighMem(page)) {
 		void *addr = page_address(page) + offset;
 
-		check_for_stack(dev, addr);
 		check_for_illegal_area(dev, addr, size);
 	}
 
@@ -1384,8 +1410,9 @@
 		entry->sg_call_ents   = nents;
 		entry->sg_mapped_ents = mapped_ents;
 
+		check_for_stack(dev, sg_page(s), s->offset);
+
 		if (!PageHighMem(sg_page(s))) {
-			check_for_stack(dev, sg_virt(s));
 			check_for_illegal_area(dev, sg_virt(s), sg_dma_len(s));
 		}