fuse: allow dax code to be configured out

Add a Kconfig option (CONFIG_FUSE_DAX) that allows the dax code to be
compiled out.  Move the dax specific state out of struct fuse_conn and
struct fuse_inode into separately allocated fuse_conn_dax and
fuse_inode_dax objects, which are only allocated when dax is in use.

A separate spinlock for protecting the dax mappings has been split out
from fc->lock.

Signed-off-by: Miklos Szeredi <mszeredi@redhat.com>
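---

Not part of the commit message proper: a before/after sketch of the lock
split described above, based on the __dmap_add_to_free_pool() helper in the
dax.c hunks below.  The names match the patch; the fragment is only meant to
illustrate that the dax range bookkeeping now lives in the separately
allocated struct fuse_conn_dax (fcd) and is serialized by its own spinlock
instead of the connection-wide fc->lock.

	/* Before: dax range lists hang off struct fuse_conn, under fc->lock */
	spin_lock(&fc->lock);
	list_add_tail(&dmap->list, &fc->free_ranges);
	fc->nr_free_ranges++;
	wake_up(&fc->dax_range_waitq);
	spin_unlock(&fc->lock);

	/* After: the same bookkeeping moves into struct fuse_conn_dax and is
	 * protected by fcd->lock, so non-dax paths never touch this lock.
	 */
	spin_lock(&fcd->lock);
	list_add_tail(&dmap->list, &fcd->free_ranges);
	fcd->nr_free_ranges++;
	wake_up(&fcd->range_waitq);
	spin_unlock(&fcd->lock);
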
diff --git a/fs/fuse/Kconfig b/fs/fuse/Kconfig
index ad0e421..82c5340 100644
--- a/fs/fuse/Kconfig
+++ b/fs/fuse/Kconfig
@@ -39,3 +39,16 @@
 
 	  If you want to share files between guests or with the host, answer Y
 	  or M.
+
+config FUSE_DAX
+	bool "Virtio Filesystem Direct Host Memory Access support"
+	default y
+	depends on VIRTIO_FS
+	depends on FS_DAX
+	depends on DAX_DRIVER
+	help
+	  This allows bypassing the guest page cache and mapping the host page
+	  cache directly into the guest address space.
+
+	  If you want to allow mounting a Virtio Filesystem with the "dax"
+	  option, answer Y.
diff --git a/fs/fuse/Makefile b/fs/fuse/Makefile
index a0ec0ee..8c7021f 100644
--- a/fs/fuse/Makefile
+++ b/fs/fuse/Makefile
@@ -7,5 +7,7 @@
 obj-$(CONFIG_CUSE) += cuse.o
 obj-$(CONFIG_VIRTIO_FS) += virtiofs.o
 
-fuse-objs := dev.o dir.o file.o inode.o control.o xattr.o acl.o readdir.o dax.o
-virtiofs-y += virtio_fs.o
+fuse-y := dev.o dir.o file.o inode.o control.o xattr.o acl.o readdir.o
+fuse-$(CONFIG_FUSE_DAX) += dax.o
+
+virtiofs-y := virtio_fs.o
diff --git a/fs/fuse/dax.c b/fs/fuse/dax.c
index 77d2108..c7099ee 100644
--- a/fs/fuse/dax.c
+++ b/fs/fuse/dax.c
@@ -10,6 +10,15 @@
 #include <linux/iomap.h>
 #include <linux/uio.h>
 #include <linux/pfn_t.h>
+#include <linux/interval_tree.h>
+
+/*
+ * Default memory range size.  A power of 2 so it agrees with common FUSE_INIT
+ * map_alignment values 4KB and 64KB.
+ */
+#define FUSE_DAX_SHIFT	21
+#define FUSE_DAX_SZ	(1 << FUSE_DAX_SHIFT)
+#define FUSE_DAX_PAGES	(FUSE_DAX_SZ/PAGE_SIZE)
 
 /* Number of ranges reclaimer will try to free in one invocation */
 #define FUSE_DAX_RECLAIM_CHUNK		(10)
@@ -48,6 +57,41 @@
 	refcount_t refcnt;
 };
 
+/* Per-inode dax map */
+struct fuse_inode_dax {
+	/* Semaphore to protect modifications to the dmap tree */
+	struct rw_semaphore sem;
+
+	/* Sorted rb tree of struct fuse_dax_mapping elements */
+	struct rb_root_cached tree;
+	unsigned long nr;
+};
+
+
+struct fuse_conn_dax {
+	/* DAX device, non-NULL if DAX is supported */
+	struct dax_device *dev;
+
+	/* Lock protecting accesses to members of this structure */
+	spinlock_t lock;
+
+	/* List of memory ranges which are busy */
+	unsigned long nr_busy_ranges;
+	struct list_head busy_ranges;
+
+	/* Worker to free up memory ranges */
+	struct delayed_work free_work;
+
+	/* Wait queue for a dax range to become free */
+	wait_queue_head_t range_waitq;
+
+	/* DAX Window Free Ranges */
+	long nr_free_ranges;
+	struct list_head free_ranges;
+
+	unsigned long nr_ranges;
+};
+
 static inline struct fuse_dax_mapping *
 node_to_dmap(struct interval_tree_node *node)
 {
@@ -57,86 +101,88 @@
 	return container_of(node, struct fuse_dax_mapping, itn);
 }
 
-static struct fuse_dax_mapping *alloc_dax_mapping_reclaim(struct fuse_conn *fc,
-							struct inode *inode);
+static struct fuse_dax_mapping *
+alloc_dax_mapping_reclaim(struct fuse_conn_dax *fcd, struct inode *inode);
+
 static void
-__kick_dmap_free_worker(struct fuse_conn *fc, unsigned long delay_ms)
+__kick_dmap_free_worker(struct fuse_conn_dax *fcd, unsigned long delay_ms)
 {
 	unsigned long free_threshold;
 
 	/* If number of free ranges are below threshold, start reclaim */
-	free_threshold = max_t(unsigned long, fc->nr_ranges * FUSE_DAX_RECLAIM_THRESHOLD / 100,
+	free_threshold = max_t(unsigned long, fcd->nr_ranges * FUSE_DAX_RECLAIM_THRESHOLD / 100,
 			     1);
-	if (fc->nr_free_ranges < free_threshold)
-		queue_delayed_work(system_long_wq, &fc->dax_free_work,
+	if (fcd->nr_free_ranges < free_threshold)
+		queue_delayed_work(system_long_wq, &fcd->free_work,
 				   msecs_to_jiffies(delay_ms));
 }
 
-static void kick_dmap_free_worker(struct fuse_conn *fc, unsigned long delay_ms)
+static void kick_dmap_free_worker(struct fuse_conn_dax *fcd,
+				  unsigned long delay_ms)
 {
-	spin_lock(&fc->lock);
-	__kick_dmap_free_worker(fc, delay_ms);
-	spin_unlock(&fc->lock);
+	spin_lock(&fcd->lock);
+	__kick_dmap_free_worker(fcd, delay_ms);
+	spin_unlock(&fcd->lock);
 }
 
-static struct fuse_dax_mapping *alloc_dax_mapping(struct fuse_conn *fc)
+static struct fuse_dax_mapping *alloc_dax_mapping(struct fuse_conn_dax *fcd)
 {
 	struct fuse_dax_mapping *dmap = NULL;
 
-	spin_lock(&fc->lock);
+	spin_lock(&fcd->lock);
 
-	if (fc->nr_free_ranges <= 0) {
-		spin_unlock(&fc->lock);
+	if (fcd->nr_free_ranges <= 0) {
+		spin_unlock(&fcd->lock);
 		goto out_kick;
 	}
 
-	WARN_ON(list_empty(&fc->free_ranges));
+	WARN_ON(list_empty(&fcd->free_ranges));
 
 	/* Take a free range */
-	dmap = list_first_entry(&fc->free_ranges, struct fuse_dax_mapping,
+	dmap = list_first_entry(&fcd->free_ranges, struct fuse_dax_mapping,
 					list);
 	list_del_init(&dmap->list);
-	fc->nr_free_ranges--;
-	spin_unlock(&fc->lock);
+	fcd->nr_free_ranges--;
+	spin_unlock(&fcd->lock);
 
 out_kick:
-	kick_dmap_free_worker(fc, 0);
+	kick_dmap_free_worker(fcd, 0);
 	return dmap;
 }
 
-/* This assumes fc->lock is held */
-static void __dmap_remove_busy_list(struct fuse_conn *fc,
+/* This assumes fcd->lock is held */
+static void __dmap_remove_busy_list(struct fuse_conn_dax *fcd,
 				    struct fuse_dax_mapping *dmap)
 {
 	list_del_init(&dmap->busy_list);
-	WARN_ON(fc->nr_busy_ranges == 0);
-	fc->nr_busy_ranges--;
+	WARN_ON(fcd->nr_busy_ranges == 0);
+	fcd->nr_busy_ranges--;
 }
 
-static void dmap_remove_busy_list(struct fuse_conn *fc,
+static void dmap_remove_busy_list(struct fuse_conn_dax *fcd,
 				  struct fuse_dax_mapping *dmap)
 {
-	spin_lock(&fc->lock);
-	__dmap_remove_busy_list(fc, dmap);
-	spin_unlock(&fc->lock);
+	spin_lock(&fcd->lock);
+	__dmap_remove_busy_list(fcd, dmap);
+	spin_unlock(&fcd->lock);
 }
 
-/* This assumes fc->lock is held */
-static void __dmap_add_to_free_pool(struct fuse_conn *fc,
+/* This assumes fcd->lock is held */
+static void __dmap_add_to_free_pool(struct fuse_conn_dax *fcd,
 				struct fuse_dax_mapping *dmap)
 {
-	list_add_tail(&dmap->list, &fc->free_ranges);
-	fc->nr_free_ranges++;
-	wake_up(&fc->dax_range_waitq);
+	list_add_tail(&dmap->list, &fcd->free_ranges);
+	fcd->nr_free_ranges++;
+	wake_up(&fcd->range_waitq);
 }
 
-static void dmap_add_to_free_pool(struct fuse_conn *fc,
+static void dmap_add_to_free_pool(struct fuse_conn_dax *fcd,
 				struct fuse_dax_mapping *dmap)
 {
 	/* Return fuse_dax_mapping to free list */
-	spin_lock(&fc->lock);
-	__dmap_add_to_free_pool(fc, dmap);
-	spin_unlock(&fc->lock);
+	spin_lock(&fcd->lock);
+	__dmap_add_to_free_pool(fcd, dmap);
+	spin_unlock(&fcd->lock);
 }
 
 static int fuse_setup_one_mapping(struct inode *inode, unsigned long start_idx,
@@ -144,13 +190,14 @@
 				  bool upgrade)
 {
 	struct fuse_conn *fc = get_fuse_conn(inode);
+	struct fuse_conn_dax *fcd = fc->dax;
 	struct fuse_inode *fi = get_fuse_inode(inode);
 	struct fuse_setupmapping_in inarg;
 	loff_t offset = start_idx << FUSE_DAX_SHIFT;
 	FUSE_ARGS(args);
 	ssize_t err;
 
-	WARN_ON(fc->nr_free_ranges < 0);
+	WARN_ON(fcd->nr_free_ranges < 0);
 
 	/* Ask fuse daemon to setup mapping */
 	memset(&inarg, 0, sizeof(inarg));
@@ -178,21 +225,20 @@
 		 */
 		dmap->inode = inode;
 		dmap->itn.start = dmap->itn.last = start_idx;
-		/* Protected by fi->i_dmap_sem */
-		interval_tree_insert(&dmap->itn, &fi->dmap_tree);
-		fi->nr_dmaps++;
-		spin_lock(&fc->lock);
-		list_add_tail(&dmap->busy_list, &fc->busy_ranges);
-		fc->nr_busy_ranges++;
-		spin_unlock(&fc->lock);
+		/* Protected by fi->dax->sem */
+		interval_tree_insert(&dmap->itn, &fi->dax->tree);
+		fi->dax->nr++;
+		spin_lock(&fcd->lock);
+		list_add_tail(&dmap->busy_list, &fcd->busy_ranges);
+		fcd->nr_busy_ranges++;
+		spin_unlock(&fcd->lock);
 	}
 	return 0;
 }
 
-static int
-fuse_send_removemapping(struct inode *inode,
-			struct fuse_removemapping_in *inargp,
-			struct fuse_removemapping_one *remove_one)
+static int fuse_send_removemapping(struct inode *inode,
+				   struct fuse_removemapping_in *inargp,
+				   struct fuse_removemapping_one *remove_one)
 {
 	struct fuse_inode *fi = get_fuse_inode(inode);
 	struct fuse_conn *fc = get_fuse_conn(inode);
@@ -246,18 +292,18 @@
 
 /*
  * Cleanup dmap entry and add back to free list. This should be called with
- * fc->lock held.
+ * fcd->lock held.
  */
-static void dmap_reinit_add_to_free_pool(struct fuse_conn *fc,
+static void dmap_reinit_add_to_free_pool(struct fuse_conn_dax *fcd,
 					    struct fuse_dax_mapping *dmap)
 {
 	pr_debug("fuse: freeing memory range start_idx=0x%lx end_idx=0x%lx window_offset=0x%llx length=0x%llx\n",
 		 dmap->itn.start, dmap->itn.last, dmap->window_offset,
 		 dmap->length);
-	__dmap_remove_busy_list(fc, dmap);
+	__dmap_remove_busy_list(fcd, dmap);
 	dmap->inode = NULL;
 	dmap->itn.start = dmap->itn.last = 0;
-	__dmap_add_to_free_pool(fc, dmap);
+	__dmap_add_to_free_pool(fcd, dmap);
 }
 
 /*
@@ -266,7 +312,8 @@
  * called from evict_inode() path where we know all dmap entries can be
  * reclaimed.
  */
-static void inode_reclaim_dmap_range(struct fuse_conn *fc, struct inode *inode,
+static void inode_reclaim_dmap_range(struct fuse_conn_dax *fcd,
+				     struct inode *inode,
 				      loff_t start, loff_t end)
 {
 	struct fuse_inode *fi = get_fuse_inode(inode);
@@ -278,14 +325,14 @@
 	struct interval_tree_node *node;
 
 	while (1) {
-		node = interval_tree_iter_first(&fi->dmap_tree, start_idx,
+		node = interval_tree_iter_first(&fi->dax->tree, start_idx,
 						end_idx);
 		if (!node)
 			break;
 		dmap = node_to_dmap(node);
 		/* inode is going away. There should not be any users of dmap */
 		WARN_ON(refcount_read(&dmap->refcnt) > 1);
-		interval_tree_remove(&dmap->itn, &fi->dmap_tree);
+		interval_tree_remove(&dmap->itn, &fi->dax->tree);
 		num++;
 		list_add(&dmap->list, &to_remove);
 	}
@@ -294,29 +341,19 @@
 	if (list_empty(&to_remove))
 		return;
 
-	WARN_ON(fi->nr_dmaps < num);
-	fi->nr_dmaps -= num;
-	/*
-	 * During umount/shutdown, fuse connection is dropped first
-	 * and evict_inode() is called later. That means any
-	 * removemapping messages are going to fail. Send messages
-	 * only if connection is up. Otherwise fuse daemon is
-	 * responsible for cleaning up any leftover references and
-	 * mappings.
-	 */
-	if (fc->connected) {
-		err = dmap_removemapping_list(inode, num, &to_remove);
-		if (err) {
-			pr_warn("Failed to removemappings. start=0x%llx end=0x%llx\n",
-				start, end);
-		}
+	WARN_ON(fi->dax->nr < num);
+	fi->dax->nr -= num;
+	err = dmap_removemapping_list(inode, num, &to_remove);
+	if (err && err != -ENOTCONN) {
+		pr_warn("Failed to removemappings. start=0x%llx end=0x%llx\n",
+			start, end);
 	}
-	spin_lock(&fc->lock);
+	spin_lock(&fcd->lock);
 	list_for_each_entry_safe(dmap, n, &to_remove, list) {
 		list_del_init(&dmap->list);
-		dmap_reinit_add_to_free_pool(fc, dmap);
+		dmap_reinit_add_to_free_pool(fcd, dmap);
 	}
-	spin_unlock(&fc->lock);
+	spin_unlock(&fcd->lock);
 }
 
 static int dmap_removemapping_one(struct inode *inode,
@@ -335,20 +372,23 @@
 }
 
 /*
- * It is called from evict_inode() and by that time inode is going away. So
- * this function does not take any locks like fi->i_dmap_sem for traversing
+ * It is called from evict_inode() and by that time inode is going away. So this
+ * function does not take any locks like fi->dax->sem for traversing
  * that fuse inode interval tree. If that lock is taken then lock validator
  * complains of deadlock situation w.r.t fs_reclaim lock.
  */
-void fuse_cleanup_inode_mappings(struct inode *inode)
+void fuse_dax_inode_cleanup(struct inode *inode)
 {
 	struct fuse_conn *fc = get_fuse_conn(inode);
+	struct fuse_inode *fi = get_fuse_inode(inode);
+
 	/*
 	 * fuse_evict_inode() has already called truncate_inode_pages_final()
 	 * before we arrive here. So we should not have to worry about any
 	 * pages/exception entries still associated with inode.
 	 */
-	inode_reclaim_dmap_range(fc, inode, 0, -1);
+	inode_reclaim_dmap_range(fc->dax, inode, 0, -1);
+	WARN_ON(fi->dax->nr);
 }
 
 static void fuse_fill_iomap_hole(struct iomap *iomap, loff_t length)
@@ -359,8 +399,8 @@
 }
 
 static void fuse_fill_iomap(struct inode *inode, loff_t pos, loff_t length,
-			struct iomap *iomap, struct fuse_dax_mapping *dmap,
-			unsigned int flags)
+			    struct iomap *iomap, struct fuse_dax_mapping *dmap,
+			    unsigned int flags)
 {
 	loff_t offset, len;
 	loff_t i_size = i_size_read(inode);
@@ -380,7 +420,7 @@
 		iomap->type = IOMAP_MAPPED;
 		/*
 		 * increace refcnt so that reclaim code knows this dmap is in
-		 * use. This assumes i_dmap_sem mutex is held either
+		 * use. This assumes fi->dax->sem mutex is held either
 		 * shared/exclusive.
 		 */
 		refcount_inc(&dmap->refcnt);
@@ -399,7 +439,7 @@
 				      struct iomap *iomap)
 {
 	struct fuse_inode *fi = get_fuse_inode(inode);
-	struct fuse_conn *fc = get_fuse_conn(inode);
+	struct fuse_conn_dax *fcd = get_fuse_conn(inode)->dax;
 	struct fuse_dax_mapping *dmap, *alloc_dmap = NULL;
 	int ret;
 	bool writable = flags & IOMAP_WRITE;
@@ -417,11 +457,11 @@
 	 * range to become free and retry.
 	 */
 	if (flags & IOMAP_FAULT) {
-		alloc_dmap = alloc_dax_mapping(fc);
+		alloc_dmap = alloc_dax_mapping(fcd);
 		if (!alloc_dmap)
 			return -EAGAIN;
 	} else {
-		alloc_dmap = alloc_dax_mapping_reclaim(fc, inode);
+		alloc_dmap = alloc_dax_mapping_reclaim(fcd, inode);
 		if (IS_ERR(alloc_dmap))
 			return PTR_ERR(alloc_dmap);
 	}
@@ -434,17 +474,17 @@
 	 * Take write lock so that only one caller can try to setup mapping
 	 * and other waits.
 	 */
-	down_write(&fi->i_dmap_sem);
+	down_write(&fi->dax->sem);
 	/*
 	 * We dropped lock. Check again if somebody else setup
 	 * mapping already.
 	 */
-	node = interval_tree_iter_first(&fi->dmap_tree, start_idx, start_idx);
+	node = interval_tree_iter_first(&fi->dax->tree, start_idx, start_idx);
 	if (node) {
 		dmap = node_to_dmap(node);
 		fuse_fill_iomap(inode, pos, length, iomap, dmap, flags);
-		dmap_add_to_free_pool(fc, alloc_dmap);
-		up_write(&fi->i_dmap_sem);
+		dmap_add_to_free_pool(fcd, alloc_dmap);
+		up_write(&fi->dax->sem);
 		return 0;
 	}
 
@@ -452,12 +492,12 @@
 	ret = fuse_setup_one_mapping(inode, pos >> FUSE_DAX_SHIFT, alloc_dmap,
 				     writable, false);
 	if (ret < 0) {
-		dmap_add_to_free_pool(fc, alloc_dmap);
-		up_write(&fi->i_dmap_sem);
+		dmap_add_to_free_pool(fcd, alloc_dmap);
+		up_write(&fi->dax->sem);
 		return ret;
 	}
 	fuse_fill_iomap(inode, pos, length, iomap, alloc_dmap, flags);
-	up_write(&fi->i_dmap_sem);
+	up_write(&fi->dax->sem);
 	return 0;
 }
 
@@ -475,14 +515,14 @@
 	 * Take exclusive lock so that only one caller can try to setup
 	 * mapping and others wait.
 	 */
-	down_write(&fi->i_dmap_sem);
-	node = interval_tree_iter_first(&fi->dmap_tree, idx, idx);
+	down_write(&fi->dax->sem);
+	node = interval_tree_iter_first(&fi->dax->tree, idx, idx);
 
 	/* We are holding either inode lock or i_mmap_sem, and that should
 	 * ensure that dmap can't be truncated. We are holding a reference
 	 * on dmap and that should make sure it can't be reclaimed. So dmap
 	 * should still be there in tree despite the fact we dropped and
-	 * re-acquired the i_dmap_sem lock.
+	 * re-acquired the fi->dax->sem lock.
 	 */
 	ret = -EIO;
 	if (WARN_ON(!node))
@@ -490,7 +530,7 @@
 	dmap = node_to_dmap(node);
 
 	/* We took an extra reference on dmap to make sure its not reclaimd.
-	 * Now we hold i_dmap_sem lock and that reference is not needed
+	 * Now we hold fi->dax->sem lock and that reference is not needed
 	 * anymore. Drop it.
 	 */
 	if (refcount_dec_and_test(&dmap->refcnt)) {
@@ -515,7 +555,7 @@
 out_fill_iomap:
 	fuse_fill_iomap(inode, pos, length, iomap, dmap, flags);
 out_err:
-	up_write(&fi->i_dmap_sem);
+	up_write(&fi->dax->sem);
 	return ret;
 }
 
@@ -540,7 +580,7 @@
 	iomap->offset = pos;
 	iomap->flags = 0;
 	iomap->bdev = NULL;
-	iomap->dax_dev = fc->dax_dev;
+	iomap->dax_dev = fc->dax->dev;
 
 	/*
 	 * Both read/write and mmap path can race here. So we need something
@@ -549,33 +589,33 @@
 	 * For now, use a semaphore for this. It probably needs to be
 	 * optimized later.
 	 */
-	down_read(&fi->i_dmap_sem);
-	node = interval_tree_iter_first(&fi->dmap_tree, start_idx, start_idx);
+	down_read(&fi->dax->sem);
+	node = interval_tree_iter_first(&fi->dax->tree, start_idx, start_idx);
 	if (node) {
 		dmap = node_to_dmap(node);
 		if (writable && !dmap->writable) {
 			/* Upgrade read-only mapping to read-write. This will
-			 * require exclusive i_dmap_sem lock as we don't want
+			 * require exclusive fi->dax->sem lock as we don't want
 			 * two threads to be trying to this simultaneously
 			 * for same dmap. So drop shared lock and acquire
 			 * exclusive lock.
 			 *
-			 * Before dropping i_dmap_sem lock, take reference
+			 * Before dropping fi->dax->sem lock, take reference
 			 * on dmap so that its not freed by range reclaim.
 			 */
 			refcount_inc(&dmap->refcnt);
-			up_read(&fi->i_dmap_sem);
+			up_read(&fi->dax->sem);
 			pr_debug("%s: Upgrading mapping at offset 0x%llx length 0x%llx\n",
 				 __func__, pos, length);
 			return fuse_upgrade_dax_mapping(inode, pos, length,
 							flags, iomap);
 		} else {
 			fuse_fill_iomap(inode, pos, length, iomap, dmap, flags);
-			up_read(&fi->i_dmap_sem);
+			up_read(&fi->dax->sem);
 			return 0;
 		}
 	} else {
-		up_read(&fi->i_dmap_sem);
+		up_read(&fi->dax->sem);
 		pr_debug("%s: no mapping at offset 0x%llx length 0x%llx\n",
 				__func__, pos, length);
 		if (pos >= i_size_read(inode))
@@ -632,7 +672,7 @@
 }
 
 /* Should be called with fi->i_mmap_sem lock held exclusively */
-static int __fuse_break_dax_layouts(struct inode *inode, bool *retry,
+static int __fuse_dax_break_layouts(struct inode *inode, bool *retry,
 				    loff_t start, loff_t end)
 {
 	struct page *page;
@@ -648,7 +688,7 @@
 }
 
 /* dmap_end == 0 leads to unmapping of whole file */
-int fuse_break_dax_layouts(struct inode *inode, u64 dmap_start,
+int fuse_dax_break_layouts(struct inode *inode, u64 dmap_start,
 				  u64 dmap_end)
 {
 	bool	retry;
@@ -656,7 +696,7 @@
 
 	do {
 		retry = false;
-		ret = __fuse_break_dax_layouts(inode, &retry, dmap_start,
+		ret = __fuse_dax_break_layouts(inode, &retry, dmap_start,
 					       dmap_end);
 	} while (ret == 0 && retry);
 
@@ -743,14 +783,14 @@
 	return ret;
 }
 
-int fuse_dax_writepages(struct address_space *mapping,
-			struct writeback_control *wbc)
+static int fuse_dax_writepages(struct address_space *mapping,
+			       struct writeback_control *wbc)
 {
 
 	struct inode *inode = mapping->host;
 	struct fuse_conn *fc = get_fuse_conn(inode);
 
-	return dax_writeback_mapping_range(mapping, fc->dax_dev, wbc);
+	return dax_writeback_mapping_range(mapping, fc->dax->dev, wbc);
 }
 
 static vm_fault_t __fuse_dax_fault(struct vm_fault *vmf,
@@ -761,14 +801,14 @@
 	struct super_block *sb = inode->i_sb;
 	pfn_t pfn;
 	int error = 0;
-	struct fuse_conn *fc = get_fuse_conn(inode);
+	struct fuse_conn_dax *fcd = get_fuse_conn(inode)->dax;
 	bool retry = false;
 
 	if (write)
 		sb_start_pagefault(sb);
 retry:
-	if (retry && !(fc->nr_free_ranges > 0))
-		wait_event(fc->dax_range_waitq, (fc->nr_free_ranges > 0));
+	if (retry && !(fcd->nr_free_ranges > 0))
+		wait_event(fcd->range_waitq, (fcd->nr_free_ranges > 0));
 
 	/*
 	 * We need to serialize against not only truncate but also against
@@ -856,7 +896,7 @@
 	return ret;
 }
 
-static int reclaim_one_dmap_locked(struct fuse_conn *fc, struct inode *inode,
+static int reclaim_one_dmap_locked(struct inode *inode,
 				   struct fuse_dax_mapping *dmap)
 {
 	int ret;
@@ -871,35 +911,31 @@
 		return ret;
 
 	/* Remove dax mapping from inode interval tree now */
-	interval_tree_remove(&dmap->itn, &fi->dmap_tree);
-	fi->nr_dmaps--;
+	interval_tree_remove(&dmap->itn, &fi->dax->tree);
+	fi->dax->nr--;
 
-	/* It is possible that umount/shutodwn has killed the fuse connection
-	 * and worker thread is trying to reclaim memory in parallel. So check
-	 * if connection is still up or not otherwise don't send removemapping
-	 * message.
+	/* It is possible that umount/shutdown has killed the fuse connection
+	 * and worker thread is trying to reclaim memory in parallel.  Don't
+	 * warn in that case.
 	 */
-	if (fc->connected) {
-		ret = dmap_removemapping_one(inode, dmap);
-		if (ret) {
-			pr_warn("Failed to remove mapping. offset=0x%llx len=0x%llx ret=%d\n",
-				dmap->window_offset, dmap->length, ret);
-		}
+	ret = dmap_removemapping_one(inode, dmap);
+	if (ret && ret != -ENOTCONN) {
+		pr_warn("Failed to remove mapping. offset=0x%llx len=0x%llx ret=%d\n",
+			dmap->window_offset, dmap->length, ret);
 	}
 	return 0;
 }
 
 /* Find first mapped dmap for an inode and return file offset. Caller needs
- * to hold inode->i_dmap_sem lock either shared or exclusive.
+ * to hold fi->dax->sem lock either shared or exclusive.
  */
-static struct fuse_dax_mapping *inode_lookup_first_dmap(struct fuse_conn *fc,
-							struct inode *inode)
+static struct fuse_dax_mapping *inode_lookup_first_dmap(struct inode *inode)
 {
 	struct fuse_inode *fi = get_fuse_inode(inode);
 	struct fuse_dax_mapping *dmap;
 	struct interval_tree_node *node;
 
-	for (node = interval_tree_iter_first(&fi->dmap_tree, 0, -1); node;
+	for (node = interval_tree_iter_first(&fi->dax->tree, 0, -1); node;
 	     node = interval_tree_iter_next(node, 0, -1)) {
 		dmap = node_to_dmap(node);
 		/* still in use. */
@@ -917,7 +953,7 @@
  * it back to free pool.
  */
 static struct fuse_dax_mapping *
-inode_inline_reclaim_one_dmap(struct fuse_conn *fc, struct inode *inode,
+inode_inline_reclaim_one_dmap(struct fuse_conn_dax *fcd, struct inode *inode,
 			      bool *retry)
 {
 	struct fuse_inode *fi = get_fuse_inode(inode);
@@ -930,14 +966,14 @@
 	down_write(&fi->i_mmap_sem);
 
 	/* Lookup a dmap and corresponding file offset to reclaim. */
-	down_read(&fi->i_dmap_sem);
-	dmap = inode_lookup_first_dmap(fc, inode);
+	down_read(&fi->dax->sem);
+	dmap = inode_lookup_first_dmap(inode);
 	if (dmap) {
 		start_idx = dmap->itn.start;
 		dmap_start = start_idx << FUSE_DAX_SHIFT;
 		dmap_end = dmap_start + FUSE_DAX_SZ - 1;
 	}
-	up_read(&fi->i_dmap_sem);
+	up_read(&fi->dax->sem);
 
 	if (!dmap)
 		goto out_mmap_sem;
@@ -945,16 +981,16 @@
 	 * Make sure there are no references to inode pages using
 	 * get_user_pages()
 	 */
-	ret = fuse_break_dax_layouts(inode, dmap_start, dmap_end);
+	ret = fuse_dax_break_layouts(inode, dmap_start, dmap_end);
 	if (ret) {
-		pr_debug("fuse: fuse_break_dax_layouts() failed. err=%d\n",
+		pr_debug("fuse: fuse_dax_break_layouts() failed. err=%d\n",
 			 ret);
 		dmap = ERR_PTR(ret);
 		goto out_mmap_sem;
 	}
 
-	down_write(&fi->i_dmap_sem);
-	node = interval_tree_iter_first(&fi->dmap_tree, start_idx, start_idx);
+	down_write(&fi->dax->sem);
+	node = interval_tree_iter_first(&fi->dax->tree, start_idx, start_idx);
 	/* Range already got reclaimed by somebody else */
 	if (!node) {
 		if (retry)
@@ -971,14 +1007,14 @@
 		goto out_write_dmap_sem;
 	}
 
-	ret = reclaim_one_dmap_locked(fc, inode, dmap);
+	ret = reclaim_one_dmap_locked(inode, dmap);
 	if (ret < 0) {
 		dmap = ERR_PTR(ret);
 		goto out_write_dmap_sem;
 	}
 
 	/* Clean up dmap. Do not add back to free list */
-	dmap_remove_busy_list(fc, dmap);
+	dmap_remove_busy_list(fcd, dmap);
 	dmap->inode = NULL;
 	dmap->itn.start = dmap->itn.last = 0;
 
@@ -986,14 +1022,14 @@
 		 __func__, inode, dmap->window_offset, dmap->length);
 
 out_write_dmap_sem:
-	up_write(&fi->i_dmap_sem);
+	up_write(&fi->dax->sem);
 out_mmap_sem:
 	up_write(&fi->i_mmap_sem);
 	return dmap;
 }
 
-static struct fuse_dax_mapping *alloc_dax_mapping_reclaim(struct fuse_conn *fc,
-					struct inode *inode)
+static struct fuse_dax_mapping *
+alloc_dax_mapping_reclaim(struct fuse_conn_dax *fcd, struct inode *inode)
 {
 	struct fuse_dax_mapping *dmap;
 	struct fuse_inode *fi = get_fuse_inode(inode);
@@ -1001,11 +1037,11 @@
 	while (1) {
 		bool retry = false;
 
-		dmap = alloc_dax_mapping(fc);
+		dmap = alloc_dax_mapping(fcd);
 		if (dmap)
 			return dmap;
 
-		dmap = inode_inline_reclaim_one_dmap(fc, inode, &retry);
+		dmap = inode_inline_reclaim_one_dmap(fcd, inode, &retry);
 		/*
 		 * Either we got a mapping or it is an error, return in both
 		 * the cases.
@@ -1020,27 +1056,27 @@
 		 * if a deadlock is possible if we sleep with fi->i_mmap_sem
 		 * held and worker to free memory can't make progress due
 		 * to unavailability of fi->i_mmap_sem lock. So sleep
-		 * only if fi->nr_dmaps=0
+		 * only if fi->dax->nr=0
 		 */
 		if (retry)
 			continue;
 		/*
 		 * There are no mappings which can be reclaimed. Wait for one.
-		 * We are not holding fi->i_dmap_sem. So it is possible
+		 * We are not holding fi->dax->sem. So it is possible
 		 * that range gets added now. But as we are not holding
-		 * fi->i_mmap_sem, worker should still be able to free up
-		 * a range and wake us up.
+		 * fi->i_mmap_sem, worker should still be able to free up a
+		 * range and wake us up.
 		 */
-		if (!fi->nr_dmaps && !(fc->nr_free_ranges > 0)) {
-			if (wait_event_killable_exclusive(fc->dax_range_waitq,
-					(fc->nr_free_ranges > 0))) {
+		if (!fi->dax->nr && !(fcd->nr_free_ranges > 0)) {
+			if (wait_event_killable_exclusive(fcd->range_waitq,
+					(fcd->nr_free_ranges > 0))) {
 				return ERR_PTR(-EINTR);
 			}
 		}
 	}
 }
 
-static int lookup_and_reclaim_dmap_locked(struct fuse_conn *fc,
+static int lookup_and_reclaim_dmap_locked(struct fuse_conn_dax *fcd,
 					  struct inode *inode,
 					  unsigned long start_idx)
 {
@@ -1050,7 +1086,7 @@
 	struct interval_tree_node *node;
 
 	/* Find fuse dax mapping at file offset inode. */
-	node = interval_tree_iter_first(&fi->dmap_tree, start_idx, start_idx);
+	node = interval_tree_iter_first(&fi->dax->tree, start_idx, start_idx);
 
 	/* Range already got cleaned up by somebody else */
 	if (!node)
@@ -1061,14 +1097,14 @@
 	if (refcount_read(&dmap->refcnt) > 1)
 		return 0;
 
-	ret = reclaim_one_dmap_locked(fc, inode, dmap);
+	ret = reclaim_one_dmap_locked(inode, dmap);
 	if (ret < 0)
 		return ret;
 
 	/* Cleanup dmap entry and add back to free list */
-	spin_lock(&fc->lock);
-	dmap_reinit_add_to_free_pool(fc, dmap);
-	spin_unlock(&fc->lock);
+	spin_lock(&fcd->lock);
+	dmap_reinit_add_to_free_pool(fcd, dmap);
+	spin_unlock(&fcd->lock);
 	return ret;
 }
 
@@ -1076,10 +1112,11 @@
  * Free a range of memory.
  * Locking.
  * 1. Take fuse_inode->i_mmap_sem to block dax faults.
- * 2. Take fuse_inode->i_dmap_sem to protect interval tree and also to make
- *    sure read/write can not reuse a dmap which we might be freeing.
+ * 2. Take fuse_inode->dax->sem to protect interval tree and also to
+ *    make sure read/write can not reuse a dmap which we might be freeing.
  */
-static int lookup_and_reclaim_dmap(struct fuse_conn *fc, struct inode *inode,
+static int lookup_and_reclaim_dmap(struct fuse_conn_dax *fcd,
+				   struct inode *inode,
 				   unsigned long start_idx,
 				   unsigned long end_idx)
 {
@@ -1089,22 +1126,22 @@
 	loff_t dmap_end = (dmap_start + FUSE_DAX_SZ) - 1;
 
 	down_write(&fi->i_mmap_sem);
-	ret = fuse_break_dax_layouts(inode, dmap_start, dmap_end);
+	ret = fuse_dax_break_layouts(inode, dmap_start, dmap_end);
 	if (ret) {
-		pr_debug("virtio_fs: fuse_break_dax_layouts() failed. err=%d\n",
+		pr_debug("virtio_fs: fuse_dax_break_layouts() failed. err=%d\n",
 			 ret);
 		goto out_mmap_sem;
 	}
 
-	down_write(&fi->i_dmap_sem);
-	ret = lookup_and_reclaim_dmap_locked(fc, inode, start_idx);
-	up_write(&fi->i_dmap_sem);
+	down_write(&fi->dax->sem);
+	ret = lookup_and_reclaim_dmap_locked(fcd, inode, start_idx);
+	up_write(&fi->dax->sem);
 out_mmap_sem:
 	up_write(&fi->i_mmap_sem);
 	return ret;
 }
 
-static int try_to_free_dmap_chunks(struct fuse_conn *fc,
+static int try_to_free_dmap_chunks(struct fuse_conn_dax *fcd,
 				   unsigned long nr_to_free)
 {
 	struct fuse_dax_mapping *dmap, *pos, *temp;
@@ -1119,14 +1156,14 @@
 			break;
 
 		dmap = NULL;
-		spin_lock(&fc->lock);
+		spin_lock(&fcd->lock);
 
-		if (!fc->nr_busy_ranges) {
-			spin_unlock(&fc->lock);
+		if (!fcd->nr_busy_ranges) {
+			spin_unlock(&fcd->lock);
 			return 0;
 		}
 
-		list_for_each_entry_safe(pos, temp, &fc->busy_ranges,
+		list_for_each_entry_safe(pos, temp, &fcd->busy_ranges,
 						busy_list) {
 			/* skip this range if it's in use. */
 			if (refcount_read(&pos->refcnt) > 1)
@@ -1146,16 +1183,16 @@
 			 * selecting new element in next iteration of loop.
 			 */
 			dmap = pos;
-			list_move_tail(&dmap->busy_list, &fc->busy_ranges);
+			list_move_tail(&dmap->busy_list, &fcd->busy_ranges);
 			start_idx = end_idx = dmap->itn.start;
 			window_offset = dmap->window_offset;
 			break;
 		}
-		spin_unlock(&fc->lock);
+		spin_unlock(&fcd->lock);
 		if (!dmap)
 			return 0;
 
-		ret = lookup_and_reclaim_dmap(fc, inode, start_idx, end_idx);
+		ret = lookup_and_reclaim_dmap(fcd, inode, start_idx, end_idx);
 		iput(inode);
 		if (ret)
 			return ret;
@@ -1164,46 +1201,68 @@
 	return 0;
 }
 
-void fuse_dax_free_mem_worker(struct work_struct *work)
+static void fuse_dax_free_mem_worker(struct work_struct *work)
 {
 	int ret;
-	struct fuse_conn *fc = container_of(work, struct fuse_conn,
-						dax_free_work.work);
-	ret = try_to_free_dmap_chunks(fc, FUSE_DAX_RECLAIM_CHUNK);
+	struct fuse_conn_dax *fcd;
+
+	fcd = container_of(work, typeof(*fcd), free_work.work);
+	ret = try_to_free_dmap_chunks(fcd, FUSE_DAX_RECLAIM_CHUNK);
 	if (ret) {
 		pr_debug("fuse: try_to_free_dmap_chunks() failed with err=%d\n",
 			 ret);
 	}
 
 	/* If number of free ranges are still below threhold, requeue */
-	kick_dmap_free_worker(fc, 1);
+	kick_dmap_free_worker(fcd, 1);
 }
 
-void fuse_free_dax_mem_ranges(struct fuse_conn *fc)
+static void fuse_free_dax_mem_ranges(struct fuse_conn_dax *fcd)
 {
 	struct fuse_dax_mapping *range, *temp;
 
 	/* Free All allocated elements */
-	list_for_each_entry_safe(range, temp, &fc->free_ranges, list) {
+	list_for_each_entry_safe(range, temp, &fcd->free_ranges, list) {
 		list_del(&range->list);
 		if (!list_empty(&range->busy_list))
 			list_del(&range->busy_list);
 		kfree(range);
 	}
+	kfree(fcd);
 }
 
-#ifdef CONFIG_FS_DAX
-int fuse_dax_mem_range_init(struct fuse_conn *fc, struct dax_device *dax_dev)
+void fuse_dax_conn_free(struct fuse_conn *fc)
+{
+	if (fc->dax)
+		fuse_free_dax_mem_ranges(fc->dax);
+}
+
+int fuse_dax_conn_alloc(struct fuse_conn *fc, struct dax_device *dax_dev)
 {
 	long nr_pages, nr_ranges;
 	void *kaddr;
 	pfn_t pfn;
 	struct fuse_dax_mapping *range;
+	struct fuse_conn_dax *fcd;
 	phys_addr_t phys_addr;
 	int ret, id;
 	size_t dax_size = -1;
 	unsigned long i;
 
+	if (!dax_dev)
+		return 0;
+
+
+	fcd = kzalloc(sizeof(*fcd), GFP_KERNEL);
+	if (!fcd)
+		return -ENOMEM;
+
+	spin_lock_init(&fcd->lock);
+	init_waitqueue_head(&fcd->range_waitq);
+	INIT_LIST_HEAD(&fcd->free_ranges);
+	INIT_LIST_HEAD(&fcd->busy_ranges);
+	INIT_DELAYED_WORK(&fcd->free_work, fuse_dax_free_mem_worker);
+
 	id = dax_read_lock();
 	nr_pages = dax_direct_access(dax_dev, 0, PHYS_PFN(dax_size), &kaddr,
 					&pfn);
@@ -1232,20 +1291,71 @@
 		range->length = FUSE_DAX_SZ;
 		INIT_LIST_HEAD(&range->busy_list);
 		refcount_set(&range->refcnt, 1);
-		list_add_tail(&range->list, &fc->free_ranges);
+		list_add_tail(&range->list, &fcd->free_ranges);
 	}
 
-	fc->nr_free_ranges = nr_ranges;
-	fc->nr_ranges = nr_ranges;
+	fcd->nr_free_ranges = nr_ranges;
+	fcd->nr_ranges = nr_ranges;
+
+	fc->dax = fcd;
 	return 0;
 out_err:
 	/* Free All allocated elements */
-	fuse_free_dax_mem_ranges(fc);
+	fuse_free_dax_mem_ranges(fcd);
 	return ret;
 }
-#else /* !CONFIG_FS_DAX */
-int fuse_dax_mem_range_init(struct fuse_conn *fc, struct dax_device *dax_dev)
+
+bool fuse_dax_inode_alloc(struct super_block *sb, struct fuse_inode *fi)
 {
-	return 0;
+	struct fuse_conn *fc = get_fuse_conn_super(sb);
+
+	fi->dax = NULL;
+	if (fc->dax) {
+		fi->dax = kzalloc(sizeof(*fi->dax), GFP_KERNEL_ACCOUNT);
+		if (!fi->dax)
+			return false;
+
+		init_rwsem(&fi->dax->sem);
+		fi->dax->tree = RB_ROOT_CACHED;
+	}
+
+	return true;
 }
-#endif /* CONFIG_FS_DAX */
+
+static const struct address_space_operations fuse_dax_file_aops = {
+	.writepages	= fuse_dax_writepages,
+	.direct_IO	= noop_direct_IO,
+	.set_page_dirty	= noop_set_page_dirty,
+	.invalidatepage	= noop_invalidatepage,
+};
+
+void fuse_dax_inode_init(struct inode *inode)
+{
+	struct fuse_conn *fc = get_fuse_conn(inode);
+
+	if (!fc->dax)
+		return;
+
+	inode->i_flags |= S_DAX;
+	inode->i_data.a_ops = &fuse_dax_file_aops;
+}
+
+bool fuse_dax_check_alignment(struct fuse_conn *fc, unsigned int map_alignment)
+{
+	if (fc->dax && (map_alignment > FUSE_DAX_SHIFT)) {
+		pr_warn("FUSE: map_alignment %u incompatible with dax mem range size %u\n",
+			map_alignment, FUSE_DAX_SZ);
+		return false;
+	}
+	return true;
+}
+
+void fuse_dax_cancel_work(struct fuse_conn *fc)
+{
+	struct fuse_conn_dax *fcd = fc->dax;
+
+	if (fcd)
+		cancel_delayed_work_sync(&fcd->free_work);
+
+}
+EXPORT_SYMBOL_GPL(fuse_dax_cancel_work);
diff --git a/fs/fuse/dir.c b/fs/fuse/dir.c
index 4c7e29b..c4a0129 100644
--- a/fs/fuse/dir.c
+++ b/fs/fuse/dir.c
@@ -1516,10 +1516,10 @@
 		is_truncate = true;
 	}
 
-	if (IS_DAX(inode) && is_truncate) {
+	if (FUSE_IS_DAX(inode) && is_truncate) {
 		down_write(&fi->i_mmap_sem);
 		fault_blocked = true;
-		err = fuse_break_dax_layouts(inode, 0, 0);
+		err = fuse_dax_break_layouts(inode, 0, 0);
 		if (err) {
 			up_write(&fi->i_mmap_sem);
 			return err;
diff --git a/fs/fuse/file.c b/fs/fuse/file.c
index 458b6fa..7b53ae4 100644
--- a/fs/fuse/file.c
+++ b/fs/fuse/file.c
@@ -224,7 +224,7 @@
 			  fc->atomic_o_trunc &&
 			  fc->writeback_cache;
 	bool dax_truncate = (file->f_flags & O_TRUNC) &&
-			  fc->atomic_o_trunc && IS_DAX(inode);
+			  fc->atomic_o_trunc && FUSE_IS_DAX(inode);
 
 	err = generic_file_open(inode, file);
 	if (err)
@@ -237,7 +237,7 @@
 
 	if (dax_truncate) {
 		down_write(&get_fuse_inode(inode)->i_mmap_sem);
-		err = fuse_break_dax_layouts(inode, 0, 0);
+		err = fuse_dax_break_layouts(inode, 0, 0);
 		if (err)
 			goto out;
 	}
@@ -1558,7 +1558,7 @@
 	if (is_bad_inode(file_inode(file)))
 		return -EIO;
 
-	if (IS_DAX(inode))
+	if (FUSE_IS_DAX(inode))
 		return fuse_dax_read_iter(iocb, to);
 
 	if (ff->open_flags & FOPEN_DIRECT_IO)
@@ -1576,7 +1576,7 @@
 	if (is_bad_inode(file_inode(file)))
 		return -EIO;
 
-	if (IS_DAX(inode))
+	if (FUSE_IS_DAX(inode))
 		return fuse_dax_write_iter(iocb, from);
 
 	if (ff->open_flags & FOPEN_DIRECT_IO)
@@ -2340,7 +2340,7 @@
 	struct fuse_file *ff = file->private_data;
 
 	/* DAX mmap is superior to direct_io mmap */
-	if (IS_DAX(file_inode(file)))
+	if (FUSE_IS_DAX(file_inode(file)))
 		return fuse_dax_mmap(file, vma);
 
 	if (ff->open_flags & FOPEN_DIRECT_IO) {
@@ -3235,7 +3235,7 @@
 	bool lock_inode = !(mode & FALLOC_FL_KEEP_SIZE) ||
 			   (mode & FALLOC_FL_PUNCH_HOLE);
 
-	bool block_faults = IS_DAX(inode) && lock_inode;
+	bool block_faults = FUSE_IS_DAX(inode) && lock_inode;
 
 	if (mode & ~(FALLOC_FL_KEEP_SIZE | FALLOC_FL_PUNCH_HOLE))
 		return -EOPNOTSUPP;
@@ -3247,7 +3247,7 @@
 		inode_lock(inode);
 		if (block_faults) {
 			down_write(&fi->i_mmap_sem);
-			err = fuse_break_dax_layouts(inode, 0, 0);
+			err = fuse_dax_break_layouts(inode, 0, 0);
 			if (err)
 				goto out;
 		}
@@ -3467,17 +3467,9 @@
 	.write_end	= fuse_write_end,
 };
 
-static const struct address_space_operations fuse_dax_file_aops  = {
-	.writepages	= fuse_dax_writepages,
-	.direct_IO	= noop_direct_IO,
-	.set_page_dirty	= noop_set_page_dirty,
-	.invalidatepage	= noop_invalidatepage,
-};
-
 void fuse_init_file_inode(struct inode *inode)
 {
 	struct fuse_inode *fi = get_fuse_inode(inode);
-	struct fuse_conn *fc = get_fuse_conn(inode);
 
 	inode->i_fop = &fuse_file_operations;
 	inode->i_data.a_ops = &fuse_file_aops;
@@ -3487,10 +3479,7 @@
 	fi->writectr = 0;
 	init_waitqueue_head(&fi->page_waitq);
 	fi->writepages = RB_ROOT;
-	fi->dmap_tree = RB_ROOT_CACHED;
 
-	if (fc->dax_dev) {
-		inode->i_flags |= S_DAX;
-		inode->i_data.a_ops = &fuse_dax_file_aops;
-	}
+	if (IS_ENABLED(CONFIG_FUSE_DAX))
+		fuse_dax_inode_init(inode);
 }
diff --git a/fs/fuse/fuse_i.h b/fs/fuse/fuse_i.h
index ad63243..d3ad468 100644
--- a/fs/fuse/fuse_i.h
+++ b/fs/fuse/fuse_i.h
@@ -31,7 +31,6 @@
 #include <linux/pid_namespace.h>
 #include <linux/refcount.h>
 #include <linux/user_namespace.h>
-#include <linux/interval_tree.h>
 
 /** Default max number of pages that can be used in a single read request */
 #define FUSE_DEFAULT_MAX_PAGES_PER_REQ 32
@@ -48,14 +47,6 @@
 /** Number of dentries for each connection in the control filesystem */
 #define FUSE_CTL_NUM_DENTRIES 5
 
-/*
- * Default memory range size.  A power of 2 so it agrees with common FUSE_INIT
- * map_alignment values 4KB and 64KB.
- */
-#define FUSE_DAX_SZ	(2*1024*1024)
-#define FUSE_DAX_SHIFT	(21)
-#define FUSE_DAX_PAGES	(FUSE_DAX_SZ/PAGE_SIZE)
-
 /** List of active connections */
 extern struct list_head fuse_conn_list;
 
@@ -165,14 +156,12 @@
 	 */
 	struct rw_semaphore i_mmap_sem;
 
+#ifdef CONFIG_FUSE_DAX
 	/*
-	 * Semaphore to protect modifications to dmap_tree
+	 * Dax specific inode data
 	 */
-	struct rw_semaphore i_dmap_sem;
-
-	/** Sorted rb tree of struct fuse_dax_mapping elements */
-	struct rb_root_cached dmap_tree;
-	unsigned long nr_dmaps;
+	struct fuse_inode_dax *dax;
+#endif
 };
 
 /** FUSE inode state bits */
@@ -785,26 +774,10 @@
 	/** List of device instances belonging to this connection */
 	struct list_head devices;
 
-	/** DAX device, non-NULL if DAX is supported */
-	struct dax_device *dax_dev;
-
-	/* List of memory ranges which are busy */
-	unsigned long nr_busy_ranges;
-	struct list_head busy_ranges;
-
-	/* Worker to free up memory ranges */
-	struct delayed_work dax_free_work;
-
-	/* Wait queue for a dax range to become free */
-	wait_queue_head_t dax_range_waitq;
-
-	/*
-	 * DAX Window Free Ranges
-	 */
-	long nr_free_ranges;
-	struct list_head free_ranges;
-
-	unsigned long nr_ranges;
+#ifdef CONFIG_FUSE_DAX
+	/* Dax specific conn data */
+	struct fuse_conn_dax *dax;
+#endif
 };
 
 static inline struct fuse_conn *get_fuse_conn_super(struct super_block *sb)
@@ -1142,18 +1115,21 @@
  */
 u64 fuse_get_unique(struct fuse_iqueue *fiq);
 void fuse_free_conn(struct fuse_conn *fc);
-void fuse_dax_free_mem_worker(struct work_struct *work);
-void fuse_cleanup_inode_mappings(struct inode *inode);
 
 /* dax.c */
 
+#define FUSE_IS_DAX(inode) (IS_ENABLED(CONFIG_FUSE_DAX) && IS_DAX(inode))
+
 ssize_t fuse_dax_read_iter(struct kiocb *iocb, struct iov_iter *to);
 ssize_t fuse_dax_write_iter(struct kiocb *iocb, struct iov_iter *from);
-int fuse_dax_writepages(struct address_space *mapping,
-			struct writeback_control *wbc);
 int fuse_dax_mmap(struct file *file, struct vm_area_struct *vma);
-int fuse_break_dax_layouts(struct inode *inode, u64 dmap_start, u64 dmap_end);
-int fuse_dax_mem_range_init(struct fuse_conn *fc, struct dax_device *dax_dev);
-void fuse_free_dax_mem_ranges(struct fuse_conn *fc);
+int fuse_dax_break_layouts(struct inode *inode, u64 dmap_start, u64 dmap_end);
+int fuse_dax_conn_alloc(struct fuse_conn *fc, struct dax_device *dax_dev);
+void fuse_dax_conn_free(struct fuse_conn *fc);
+bool fuse_dax_inode_alloc(struct super_block *sb, struct fuse_inode *fi);
+void fuse_dax_inode_init(struct inode *inode);
+void fuse_dax_inode_cleanup(struct inode *inode);
+bool fuse_dax_check_alignment(struct fuse_conn *fc, unsigned int map_alignment);
+void fuse_dax_cancel_work(struct fuse_conn *fc);
 
 #endif /* _FS_FUSE_I_H */
diff --git a/fs/fuse/inode.c b/fs/fuse/inode.c
index 55b780b..b6d0ed3 100644
--- a/fs/fuse/inode.c
+++ b/fs/fuse/inode.c
@@ -84,18 +84,23 @@
 	fi->attr_version = 0;
 	fi->orig_ino = 0;
 	fi->state = 0;
-	fi->nr_dmaps = 0;
 	mutex_init(&fi->mutex);
 	init_rwsem(&fi->i_mmap_sem);
-	init_rwsem(&fi->i_dmap_sem);
 	spin_lock_init(&fi->lock);
 	fi->forget = fuse_alloc_forget();
-	if (!fi->forget) {
-		kmem_cache_free(fuse_inode_cachep, fi);
-		return NULL;
-	}
+	if (!fi->forget)
+		goto out_free;
+
+	if (IS_ENABLED(CONFIG_FUSE_DAX) && !fuse_dax_inode_alloc(sb, fi))
+		goto out_free_forget;
 
 	return &fi->inode;
+
+out_free_forget:
+	kfree(fi->forget);
+out_free:
+	kmem_cache_free(fuse_inode_cachep, fi);
+	return NULL;
 }
 
 static void fuse_free_inode(struct inode *inode)
@@ -115,10 +120,9 @@
 	clear_inode(inode);
 	if (inode->i_sb->s_flags & SB_ACTIVE) {
 		struct fuse_conn *fc = get_fuse_conn(inode);
-		if (IS_DAX(inode)) {
-			fuse_cleanup_inode_mappings(inode);
-			WARN_ON(fi->nr_dmaps);
-		}
+
+		if (FUSE_IS_DAX(inode))
+			fuse_dax_inode_cleanup(inode);
 		fuse_queue_forget(fc, fi->forget, fi->nodeid, fi->nlookup);
 		fi->forget = NULL;
 	}
@@ -594,8 +598,10 @@
 		if (sb->s_bdev && sb->s_blocksize != FUSE_DEFAULT_BLKSIZE)
 			seq_printf(m, ",blksize=%lu", sb->s_blocksize);
 	}
-	if (fc->dax_dev)
+#ifdef CONFIG_FUSE_DAX
+	if (fc->dax)
 		seq_puts(m, ",dax");
+#endif
 
 	return 0;
 }
@@ -636,7 +642,6 @@
 	refcount_set(&fc->count, 1);
 	atomic_set(&fc->dev_count, 1);
 	init_waitqueue_head(&fc->blocked_waitq);
-	init_waitqueue_head(&fc->dax_range_waitq);
 	fuse_iqueue_init(&fc->iq, fiq_ops, fiq_priv);
 	INIT_LIST_HEAD(&fc->bg_queue);
 	INIT_LIST_HEAD(&fc->entry);
@@ -654,9 +659,6 @@
 	fc->pid_ns = get_pid_ns(task_active_pid_ns(current));
 	fc->user_ns = get_user_ns(user_ns);
 	fc->max_pages = FUSE_DEFAULT_MAX_PAGES_PER_REQ;
-	INIT_LIST_HEAD(&fc->free_ranges);
-	INIT_LIST_HEAD(&fc->busy_ranges);
-	INIT_DELAYED_WORK(&fc->dax_free_work, fuse_dax_free_mem_worker);
 }
 EXPORT_SYMBOL_GPL(fuse_conn_init);
 
@@ -665,8 +667,8 @@
 	if (refcount_dec_and_test(&fc->count)) {
 		struct fuse_iqueue *fiq = &fc->iq;
 
-		if (fc->dax_dev)
-			fuse_free_dax_mem_ranges(fc);
+		if (IS_ENABLED(CONFIG_FUSE_DAX))
+			fuse_dax_conn_free(fc);
 		if (fiq->ops->release)
 			fiq->ops->release(fiq);
 		put_pid_ns(fc->pid_ns);
@@ -983,10 +985,9 @@
 					min_t(unsigned int, FUSE_MAX_MAX_PAGES,
 					max_t(unsigned int, arg->max_pages, 1));
 			}
-			if ((arg->flags & FUSE_MAP_ALIGNMENT) &&
-			    (FUSE_DAX_SZ % (1ul << arg->map_alignment))) {
-				pr_err("FUSE: map_alignment %u incompatible with dax mem range size %u\n",
-				       arg->map_alignment, FUSE_DAX_SZ);
+			if (IS_ENABLED(CONFIG_FUSE_DAX) &&
+			    arg->flags & FUSE_MAP_ALIGNMENT &&
+			    !fuse_dax_check_alignment(fc, arg->map_alignment)) {
 				ok = false;
 			}
 		} else {
@@ -1032,7 +1033,7 @@
 		FUSE_PARALLEL_DIROPS | FUSE_HANDLE_KILLPRIV | FUSE_POSIX_ACL |
 		FUSE_ABORT_ERROR | FUSE_MAX_PAGES | FUSE_CACHE_SYMLINKS |
 		FUSE_NO_OPENDIR_SUPPORT | FUSE_EXPLICIT_INVAL_DATA |
-		FUSE_MAP_ALIGNMENT;
+		(IS_ENABLED(CONFIG_FUSE_DAX) * FUSE_MAP_ALIGNMENT);
 	ia->args.opcode = FUSE_INIT;
 	ia->args.in_numargs = 1;
 	ia->args.in_args[0].size = sizeof(ia->in);
@@ -1204,19 +1205,17 @@
 	if (sb->s_user_ns != &init_user_ns)
 		sb->s_xattr = fuse_no_acl_xattr_handlers;
 
-	if (ctx->dax_dev) {
-		err = fuse_dax_mem_range_init(fc, ctx->dax_dev);
-		if (err) {
-			pr_debug("fuse_dax_mem_range_init() returned %d\n", err);
-			goto err_free_ranges;
-		}
+	if (IS_ENABLED(CONFIG_FUSE_DAX)) {
+		err = fuse_dax_conn_alloc(fc, ctx->dax_dev);
+		if (err)
+			goto err;
 	}
 
 	if (ctx->fudptr) {
 		err = -ENOMEM;
 		fud = fuse_dev_alloc_install(fc);
 		if (!fud)
-			goto err_free_ranges;
+			goto err_free_dax;
 	}
 
 	fc->dev = sb->s_dev;
@@ -1239,7 +1238,6 @@
 	fc->destroy = ctx->destroy;
 	fc->no_control = ctx->no_control;
 	fc->no_force_umount = ctx->no_force_umount;
-	fc->dax_dev = ctx->dax_dev;
 
 	err = -ENOMEM;
 	root = fuse_get_root_inode(sb, ctx->rootmode);
@@ -1272,9 +1270,9 @@
  err_dev_free:
 	if (fud)
 		fuse_dev_free(fud);
- err_free_ranges:
-	if (ctx->dax_dev)
-		fuse_free_dax_mem_ranges(fc);
+ err_free_dax:
+	if (IS_ENABLED(CONFIG_FUSE_DAX))
+		fuse_dax_conn_free(fc);
  err:
 	return err;
 }
diff --git a/fs/fuse/virtio_fs.c b/fs/fuse/virtio_fs.c
index 9fbb981..c27feb2 100644
--- a/fs/fuse/virtio_fs.c
+++ b/fs/fuse/virtio_fs.c
@@ -806,7 +806,7 @@
 	struct dev_pagemap *pgmap;
 	bool have_cache;
 
-	if (!IS_ENABLED(CONFIG_DAX_DRIVER))
+	if (!IS_ENABLED(CONFIG_FUSE_DAX))
 		return 0;
 
 	/* Get cache region */
@@ -1347,7 +1347,8 @@
 	/* Stop dax worker. Soon evict_inodes() will be called which will
 	 * free all memory ranges belonging to all inodes.
 	 */
-	cancel_delayed_work_sync(&fc->dax_free_work);
+	if (IS_ENABLED(CONFIG_FUSE_DAX))
+		fuse_dax_cancel_work(fc);
 
 	/* Stop forget queue. Soon destroy will be sent */
 	spin_lock(&fsvq->lock);