Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost

Pull virtio updates from Michael Tsirkin:
 "vhost,virtio and vdpa features, fixes, and cleanups:

   - mac vlan filter and stats support in mlx5 vdpa

   - irq hardening in virtio

   - performance improvements in virtio crypto

   - polling i/o support in virtio blk

   - ASID support in vhost

   - fixes, cleanups all over the place"

* tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost: (64 commits)
  vdpa: ifcvf: set pci driver data in probe
  vdpa/mlx5: Add RX MAC VLAN filter support
  vdpa/mlx5: Remove flow counter from steering
  vhost: rename vhost_work_dev_flush
  vhost-test: drop flush after vhost_dev_cleanup
  vhost-scsi: drop flush after vhost_dev_cleanup
  vhost_vsock: simplify vhost_vsock_flush()
  vhost_test: remove vhost_test_flush_vq()
  vhost_net: get rid of vhost_net_flush_vq() and extra flush calls
  vhost: flush dev once during vhost_dev_stop
  vhost: get rid of vhost_poll_flush() wrapper
  vhost-vdpa: return -EFAULT on copy_to_user() failure
  vdpasim: Off by one in vdpasim_set_group_asid()
  virtio: Directly use ida_alloc()/free()
  virtio: use WARN_ON() to warning illegal status value
  virtio: harden vring IRQ
  virtio: allow to unbreak virtqueue
  virtio-ccw: implement synchronize_cbs()
  virtio-mmio: implement synchronize_cbs()
  virtio-pci: implement synchronize_cbs()
  ...
diff --git a/drivers/block/virtio_blk.c b/drivers/block/virtio_blk.c
index d624cc8..6fc7850 100644
--- a/drivers/block/virtio_blk.c
+++ b/drivers/block/virtio_blk.c
@@ -37,6 +37,10 @@
 		 "0 for no limit. "
 		 "Values > nr_cpu_ids truncated to nr_cpu_ids.");
 
+static unsigned int poll_queues;
+module_param(poll_queues, uint, 0644);
+MODULE_PARM_DESC(poll_queues, "The number of dedicated virtqueues for polling I/O");
+
 static int major;
 static DEFINE_IDA(vd_index_ida);
 
@@ -74,6 +78,7 @@
 
 	/* num of vqs */
 	int num_vqs;
+	int io_queues[HCTX_MAX_TYPES];
 	struct virtio_blk_vq *vqs;
 };
 
@@ -96,8 +101,7 @@
 	}
 }
 
-static int virtblk_add_req(struct virtqueue *vq, struct virtblk_req *vbr,
-		struct scatterlist *data_sg, bool have_data)
+static int virtblk_add_req(struct virtqueue *vq, struct virtblk_req *vbr)
 {
 	struct scatterlist hdr, status, *sgs[3];
 	unsigned int num_out = 0, num_in = 0;
@@ -105,11 +109,11 @@
 	sg_init_one(&hdr, &vbr->out_hdr, sizeof(vbr->out_hdr));
 	sgs[num_out++] = &hdr;
 
-	if (have_data) {
+	if (vbr->sg_table.nents) {
 		if (vbr->out_hdr.type & cpu_to_virtio32(vq->vdev, VIRTIO_BLK_T_OUT))
-			sgs[num_out++] = data_sg;
+			sgs[num_out++] = vbr->sg_table.sgl;
 		else
-			sgs[num_out + num_in++] = data_sg;
+			sgs[num_out + num_in++] = vbr->sg_table.sgl;
 	}
 
 	sg_init_one(&status, &vbr->status, sizeof(vbr->status));
@@ -299,18 +303,13 @@
 		virtqueue_notify(vq->vq);
 }
 
-static blk_status_t virtio_queue_rq(struct blk_mq_hw_ctx *hctx,
-			   const struct blk_mq_queue_data *bd)
+static blk_status_t virtblk_prep_rq(struct blk_mq_hw_ctx *hctx,
+					struct virtio_blk *vblk,
+					struct request *req,
+					struct virtblk_req *vbr)
 {
-	struct virtio_blk *vblk = hctx->queue->queuedata;
-	struct request *req = bd->rq;
-	struct virtblk_req *vbr = blk_mq_rq_to_pdu(req);
-	unsigned long flags;
 	int num;
-	int qid = hctx->queue_num;
-	bool notify = false;
 	blk_status_t status;
-	int err;
 
 	status = virtblk_setup_cmd(vblk->vdev, req, vbr);
 	if (unlikely(status))
@@ -318,14 +317,35 @@
 
 	blk_mq_start_request(req);
 
+	/* virtblk_map_data() returns -errno; sg_table.nents is unsigned, so test num */
 	num = virtblk_map_data(hctx, req, vbr);
 	if (unlikely(num < 0)) {
 		virtblk_cleanup_cmd(req);
 		return BLK_STS_RESOURCE;
 	}
+	vbr->sg_table.nents = num;
 
+	return BLK_STS_OK;
+}
+
+static blk_status_t virtio_queue_rq(struct blk_mq_hw_ctx *hctx,
+			   const struct blk_mq_queue_data *bd)
+{
+	struct virtio_blk *vblk = hctx->queue->queuedata;
+	struct request *req = bd->rq;
+	struct virtblk_req *vbr = blk_mq_rq_to_pdu(req);
+	unsigned long flags;
+	int qid = hctx->queue_num;
+	bool notify = false;
+	blk_status_t status;
+	int err;
+
+	status = virtblk_prep_rq(hctx, vblk, req, vbr);
+	if (unlikely(status))
+		return status;
+
 	spin_lock_irqsave(&vblk->vqs[qid].lock, flags);
-	err = virtblk_add_req(vblk->vqs[qid].vq, vbr, vbr->sg_table.sgl, num);
+	err = virtblk_add_req(vblk->vqs[qid].vq, vbr);
 	if (err) {
 		virtqueue_kick(vblk->vqs[qid].vq);
 		/* Don't stop the queue if -ENOMEM: we may have failed to
@@ -355,6 +372,75 @@
 	return BLK_STS_OK;
 }
 
+static bool virtblk_prep_rq_batch(struct request *req)
+{
+	struct virtio_blk *vblk = req->mq_hctx->queue->queuedata;
+	struct virtblk_req *vbr = blk_mq_rq_to_pdu(req);
+
+	req->mq_hctx->tags->rqs[req->tag] = req;
+
+	return virtblk_prep_rq(req->mq_hctx, vblk, req, vbr) == BLK_STS_OK;
+}
+
+static bool virtblk_add_req_batch(struct virtio_blk_vq *vq,
+					struct request **rqlist,
+					struct request **requeue_list)
+{
+	unsigned long flags;
+	int err;
+	bool kick;
+
+	spin_lock_irqsave(&vq->lock, flags);
+
+	while (!rq_list_empty(*rqlist)) {
+		struct request *req = rq_list_pop(rqlist);
+		struct virtblk_req *vbr = blk_mq_rq_to_pdu(req);
+
+		err = virtblk_add_req(vq->vq, vbr);
+		if (err) {
+			virtblk_unmap_data(req, vbr);
+			virtblk_cleanup_cmd(req);
+			rq_list_add(requeue_list, req);
+		}
+	}
+
+	kick = virtqueue_kick_prepare(vq->vq);
+	spin_unlock_irqrestore(&vq->lock, flags);
+
+	return kick;
+}
+
+static void virtio_queue_rqs(struct request **rqlist)
+{
+	struct request *req, *next, *prev = NULL;
+	struct request *requeue_list = NULL;
+
+	rq_list_for_each_safe(rqlist, req, next) {
+		struct virtio_blk_vq *vq = req->mq_hctx->driver_data;
+		bool kick;
+
+		if (!virtblk_prep_rq_batch(req)) {
+			rq_list_move(rqlist, &requeue_list, req, prev);
+			req = prev;
+			if (!req)
+				continue;
+		}
+
+		if (!next || req->mq_hctx != next->mq_hctx) {
+			req->rq_next = NULL;
+			kick = virtblk_add_req_batch(vq, rqlist, &requeue_list);
+			if (kick)
+				virtqueue_notify(vq->vq);
+
+			*rqlist = next;
+			prev = NULL;
+		} else
+			prev = req;
+	}
+
+	*rqlist = requeue_list;
+}
+
 /* return id (s/n) string for *disk to *id_str
  */
 static int virtblk_get_id(struct gendisk *disk, char *id_str)
@@ -512,6 +598,7 @@
 	const char **names;
 	struct virtqueue **vqs;
 	unsigned short num_vqs;
+	unsigned int num_poll_vqs;
 	struct virtio_device *vdev = vblk->vdev;
 	struct irq_affinity desc = { 0, };
 
@@ -520,6 +607,7 @@
 				   &num_vqs);
 	if (err)
 		num_vqs = 1;
+
 	if (!err && !num_vqs) {
 		dev_err(&vdev->dev, "MQ advertised but zero queues reported\n");
 		return -EINVAL;
@@ -529,6 +617,17 @@
 			min_not_zero(num_request_queues, nr_cpu_ids),
 			num_vqs);
 
+	num_poll_vqs = min_t(unsigned int, poll_queues, num_vqs - 1);
+
+	vblk->io_queues[HCTX_TYPE_DEFAULT] = num_vqs - num_poll_vqs;
+	vblk->io_queues[HCTX_TYPE_READ] = 0;
+	vblk->io_queues[HCTX_TYPE_POLL] = num_poll_vqs;
+
+	dev_info(&vdev->dev, "%d/%d/%d default/read/poll queues\n",
+				vblk->io_queues[HCTX_TYPE_DEFAULT],
+				vblk->io_queues[HCTX_TYPE_READ],
+				vblk->io_queues[HCTX_TYPE_POLL]);
+
 	vblk->vqs = kmalloc_array(num_vqs, sizeof(*vblk->vqs), GFP_KERNEL);
 	if (!vblk->vqs)
 		return -ENOMEM;
@@ -541,12 +640,18 @@
 		goto out;
 	}
 
-	for (i = 0; i < num_vqs; i++) {
+	for (i = 0; i < num_vqs - num_poll_vqs; i++) {
 		callbacks[i] = virtblk_done;
 		snprintf(vblk->vqs[i].name, VQ_NAME_LEN, "req.%d", i);
 		names[i] = vblk->vqs[i].name;
 	}
 
+	for (; i < num_vqs; i++) {
+		callbacks[i] = NULL;
+		snprintf(vblk->vqs[i].name, VQ_NAME_LEN, "req_poll.%d", i);
+		names[i] = vblk->vqs[i].name;
+	}
+
 	/* Discover virtqueues and write information to configuration.  */
 	err = virtio_find_vqs(vdev, num_vqs, vqs, callbacks, names, &desc);
 	if (err)
@@ -692,16 +797,90 @@
 static int virtblk_map_queues(struct blk_mq_tag_set *set)
 {
 	struct virtio_blk *vblk = set->driver_data;
+	int i, qoff;
 
-	return blk_mq_virtio_map_queues(&set->map[HCTX_TYPE_DEFAULT],
-					vblk->vdev, 0);
+	for (i = 0, qoff = 0; i < set->nr_maps; i++) {
+		struct blk_mq_queue_map *map = &set->map[i];
+
+		map->nr_queues = vblk->io_queues[i];
+		map->queue_offset = qoff;
+		qoff += map->nr_queues;
+
+		if (map->nr_queues == 0)
+			continue;
+
+		/*
+		 * Regular queues have interrupts and hence CPU affinity is
+		 * defined by the core virtio code, but polling queues have
+		 * no interrupts so we let the block layer assign CPU affinity.
+		 */
+		if (i == HCTX_TYPE_POLL)
+			blk_mq_map_queues(&set->map[i]);
+		else
+			blk_mq_virtio_map_queues(&set->map[i], vblk->vdev, 0);
+	}
+
+	return 0;
+}
+
+static void virtblk_complete_batch(struct io_comp_batch *iob)
+{
+	struct request *req;
+
+	rq_list_for_each(&iob->req_list, req) {
+		virtblk_unmap_data(req, blk_mq_rq_to_pdu(req));
+		virtblk_cleanup_cmd(req);
+	}
+	blk_mq_end_request_batch(iob);
+}
+
+static int virtblk_poll(struct blk_mq_hw_ctx *hctx, struct io_comp_batch *iob)
+{
+	struct virtio_blk *vblk = hctx->queue->queuedata;
+	struct virtio_blk_vq *vq = hctx->driver_data;
+	struct virtblk_req *vbr;
+	unsigned long flags;
+	unsigned int len;
+	int found = 0;
+
+	spin_lock_irqsave(&vq->lock, flags);
+
+	while ((vbr = virtqueue_get_buf(vq->vq, &len)) != NULL) {
+		struct request *req = blk_mq_rq_from_pdu(vbr);
+
+		found++;
+		if (!blk_mq_add_to_batch(req, iob, vbr->status,
+						virtblk_complete_batch))
+			blk_mq_complete_request(req);
+	}
+
+	if (found)
+		blk_mq_start_stopped_hw_queues(vblk->disk->queue, true);
+
+	spin_unlock_irqrestore(&vq->lock, flags);
+
+	return found;
+}
+
+static int virtblk_init_hctx(struct blk_mq_hw_ctx *hctx, void *data,
+			  unsigned int hctx_idx)
+{
+	struct virtio_blk *vblk = data;
+	struct virtio_blk_vq *vq = &vblk->vqs[hctx_idx];
+
+	WARN_ON(vblk->tag_set.tags[hctx_idx] != hctx->tags);
+	hctx->driver_data = vq;
+	return 0;
 }
 
 static const struct blk_mq_ops virtio_mq_ops = {
 	.queue_rq	= virtio_queue_rq,
+	.queue_rqs	= virtio_queue_rqs,
 	.commit_rqs	= virtio_commit_rqs,
+	.init_hctx	= virtblk_init_hctx,
 	.complete	= virtblk_request_done,
 	.map_queues	= virtblk_map_queues,
+	.poll		= virtblk_poll,
 };
 
 static unsigned int virtblk_queue_depth;
@@ -778,6 +957,9 @@
 		sizeof(struct scatterlist) * VIRTIO_BLK_INLINE_SG_CNT;
 	vblk->tag_set.driver_data = vblk;
 	vblk->tag_set.nr_hw_queues = vblk->num_vqs;
+	vblk->tag_set.nr_maps = 1;
+	if (vblk->io_queues[HCTX_TYPE_POLL])
+		vblk->tag_set.nr_maps = 3;
 
 	err = blk_mq_alloc_tag_set(&vblk->tag_set);
 	if (err)
diff --git a/drivers/crypto/virtio/virtio_crypto_akcipher_algs.c b/drivers/crypto/virtio/virtio_crypto_akcipher_algs.c
index f3ec942..2a60d05 100644
--- a/drivers/crypto/virtio/virtio_crypto_akcipher_algs.c
+++ b/drivers/crypto/virtio/virtio_crypto_akcipher_algs.c
@@ -90,9 +90,13 @@
 	}
 
 	akcipher_req = vc_akcipher_req->akcipher_req;
-	if (vc_akcipher_req->opcode != VIRTIO_CRYPTO_AKCIPHER_VERIFY)
+	if (vc_akcipher_req->opcode != VIRTIO_CRYPTO_AKCIPHER_VERIFY) {
+		/* device-reported length may be less than dst_len; never copy more */
+		akcipher_req->dst_len = min_t(size_t, akcipher_req->dst_len,
+					      len - sizeof(vc_req->status));
 		sg_copy_from_buffer(akcipher_req->dst, sg_nents(akcipher_req->dst),
 				    vc_akcipher_req->dst_buf, akcipher_req->dst_len);
+	}
 	virtio_crypto_akcipher_finalize_req(vc_akcipher_req, akcipher_req, error);
 }
 
@@ -103,54 +106,56 @@
 	struct scatterlist outhdr_sg, key_sg, inhdr_sg, *sgs[3];
 	struct virtio_crypto *vcrypto = ctx->vcrypto;
 	uint8_t *pkey;
-	unsigned int inlen;
 	int err;
 	unsigned int num_out = 0, num_in = 0;
+	struct virtio_crypto_op_ctrl_req *ctrl;
+	struct virtio_crypto_session_input *input;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
 
 	pkey = kmemdup(key, keylen, GFP_ATOMIC);
 	if (!pkey)
 		return -ENOMEM;
 
-	spin_lock(&vcrypto->ctrl_lock);
-	memcpy(&vcrypto->ctrl.header, header, sizeof(vcrypto->ctrl.header));
-	memcpy(&vcrypto->ctrl.u, para, sizeof(vcrypto->ctrl.u));
-	vcrypto->input.status = cpu_to_le32(VIRTIO_CRYPTO_ERR);
+	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
+	if (!vc_ctrl_req) {
+		err = -ENOMEM;
+		goto out;
+	}
 
-	sg_init_one(&outhdr_sg, &vcrypto->ctrl, sizeof(vcrypto->ctrl));
+	ctrl = &vc_ctrl_req->ctrl;
+	memcpy(&ctrl->header, header, sizeof(ctrl->header));
+	memcpy(&ctrl->u, para, sizeof(ctrl->u));
+	input = &vc_ctrl_req->input;
+	input->status = cpu_to_le32(VIRTIO_CRYPTO_ERR);
+
+	sg_init_one(&outhdr_sg, ctrl, sizeof(*ctrl));
 	sgs[num_out++] = &outhdr_sg;
 
 	sg_init_one(&key_sg, pkey, keylen);
 	sgs[num_out++] = &key_sg;
 
-	sg_init_one(&inhdr_sg, &vcrypto->input, sizeof(vcrypto->input));
+	sg_init_one(&inhdr_sg, input, sizeof(*input));
 	sgs[num_out + num_in++] = &inhdr_sg;
 
-	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out, num_in, vcrypto, GFP_ATOMIC);
+	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
 	if (err < 0)
 		goto out;
 
-	virtqueue_kick(vcrypto->ctrl_vq);
-	while (!virtqueue_get_buf(vcrypto->ctrl_vq, &inlen) &&
-	       !virtqueue_is_broken(vcrypto->ctrl_vq))
-		cpu_relax();
-
-	if (le32_to_cpu(vcrypto->input.status) != VIRTIO_CRYPTO_OK) {
+	if (le32_to_cpu(input->status) != VIRTIO_CRYPTO_OK) {
+		pr_err("virtio_crypto: Create session failed status: %u\n",
+			le32_to_cpu(input->status));
 		err = -EINVAL;
 		goto out;
 	}
 
-	ctx->session_id = le64_to_cpu(vcrypto->input.session_id);
+	ctx->session_id = le64_to_cpu(input->session_id);
 	ctx->session_valid = true;
 	err = 0;
 
 out:
-	spin_unlock(&vcrypto->ctrl_lock);
+	kfree(vc_ctrl_req);
 	kfree_sensitive(pkey);
 
-	if (err < 0)
-		pr_err("virtio_crypto: Create session failed status: %u\n",
-			le32_to_cpu(vcrypto->input.status));
-
 	return err;
 }
 
@@ -159,37 +164,41 @@
 	struct scatterlist outhdr_sg, inhdr_sg, *sgs[2];
 	struct virtio_crypto_destroy_session_req *destroy_session;
 	struct virtio_crypto *vcrypto = ctx->vcrypto;
-	unsigned int num_out = 0, num_in = 0, inlen;
+	unsigned int num_out = 0, num_in = 0;
 	int err;
+	struct virtio_crypto_op_ctrl_req *ctrl;
+	struct virtio_crypto_inhdr *ctrl_status;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
 
-	spin_lock(&vcrypto->ctrl_lock);
-	if (!ctx->session_valid) {
-		err = 0;
-		goto out;
-	}
-	vcrypto->ctrl_status.status = VIRTIO_CRYPTO_ERR;
-	vcrypto->ctrl.header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_DESTROY_SESSION);
-	vcrypto->ctrl.header.queue_id = 0;
+	if (!ctx->session_valid)
+		return 0;
 
-	destroy_session = &vcrypto->ctrl.u.destroy_session;
+	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
+	if (!vc_ctrl_req)
+		return -ENOMEM;
+
+	ctrl_status = &vc_ctrl_req->ctrl_status;
+	ctrl_status->status = VIRTIO_CRYPTO_ERR;
+	ctrl = &vc_ctrl_req->ctrl;
+	ctrl->header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_DESTROY_SESSION);
+	ctrl->header.queue_id = 0;
+
+	destroy_session = &ctrl->u.destroy_session;
 	destroy_session->session_id = cpu_to_le64(ctx->session_id);
 
-	sg_init_one(&outhdr_sg, &vcrypto->ctrl, sizeof(vcrypto->ctrl));
+	sg_init_one(&outhdr_sg, ctrl, sizeof(*ctrl));
 	sgs[num_out++] = &outhdr_sg;
 
-	sg_init_one(&inhdr_sg, &vcrypto->ctrl_status.status, sizeof(vcrypto->ctrl_status.status));
+	sg_init_one(&inhdr_sg, &ctrl_status->status, sizeof(ctrl_status->status));
 	sgs[num_out + num_in++] = &inhdr_sg;
 
-	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out, num_in, vcrypto, GFP_ATOMIC);
+	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
 	if (err < 0)
 		goto out;
 
-	virtqueue_kick(vcrypto->ctrl_vq);
-	while (!virtqueue_get_buf(vcrypto->ctrl_vq, &inlen) &&
-	       !virtqueue_is_broken(vcrypto->ctrl_vq))
-		cpu_relax();
-
-	if (vcrypto->ctrl_status.status != VIRTIO_CRYPTO_OK) {
+	if (ctrl_status->status != VIRTIO_CRYPTO_OK) {
+		pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n",
+			ctrl_status->status, destroy_session->session_id);
 		err = -EINVAL;
 		goto out;
 	}
@@ -198,11 +207,7 @@
 	ctx->session_valid = false;
 
 out:
-	spin_unlock(&vcrypto->ctrl_lock);
-	if (err < 0) {
-		pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n",
-			vcrypto->ctrl_status.status, destroy_session->session_id);
-	}
+	kfree(vc_ctrl_req);
 
 	return err;
 }
diff --git a/drivers/crypto/virtio/virtio_crypto_common.h b/drivers/crypto/virtio/virtio_crypto_common.h
index e693d4e..59a4c02 100644
--- a/drivers/crypto/virtio/virtio_crypto_common.h
+++ b/drivers/crypto/virtio/virtio_crypto_common.h
@@ -13,6 +13,7 @@
 #include <crypto/aead.h>
 #include <crypto/aes.h>
 #include <crypto/engine.h>
+#include <uapi/linux/virtio_crypto.h>
 
 
 /* Internal representation of a data virtqueue */
@@ -65,11 +66,6 @@
 	/* Maximum size of per request */
 	u64 max_size;
 
-	/* Control VQ buffers: protected by the ctrl_lock */
-	struct virtio_crypto_op_ctrl_req ctrl;
-	struct virtio_crypto_session_input input;
-	struct virtio_crypto_inhdr ctrl_status;
-
 	unsigned long status;
 	atomic_t ref_count;
 	struct list_head list;
@@ -85,6 +81,18 @@
 	__u64 session_id;
 };
 
+/*
+ * Per-call control virtqueue request. It contains padding fields, so it must
+ * be zeroed (callers use kzalloc) before being sent, to avoid leaking data.
+ * E.g. virtio_crypto_ctrl_request::ctrl::u::destroy_session::padding[48]
+ */
+struct virtio_crypto_ctrl_request {
+	struct virtio_crypto_op_ctrl_req ctrl;		/* header + params sent to the device */
+	struct virtio_crypto_session_input input;	/* create-session status and session id */
+	struct virtio_crypto_inhdr ctrl_status;		/* destroy-session status */
+	struct completion compl;			/* completed from the control vq callback */
+};
+
 struct virtio_crypto_request;
 typedef void (*virtio_crypto_data_callback)
 		(struct virtio_crypto_request *vc_req, int len);
@@ -134,5 +142,8 @@
 void virtio_crypto_skcipher_algs_unregister(struct virtio_crypto *vcrypto);
 int virtio_crypto_akcipher_algs_register(struct virtio_crypto *vcrypto);
 void virtio_crypto_akcipher_algs_unregister(struct virtio_crypto *vcrypto);
+int virtio_crypto_ctrl_vq_request(struct virtio_crypto *vcrypto, struct scatterlist *sgs[],
+				  unsigned int out_sgs, unsigned int in_sgs,
+				  struct virtio_crypto_ctrl_request *vc_ctrl_req);
 
 #endif /* _VIRTIO_CRYPTO_COMMON_H */
diff --git a/drivers/crypto/virtio/virtio_crypto_core.c b/drivers/crypto/virtio/virtio_crypto_core.c
index c6f482d..1198bd3 100644
--- a/drivers/crypto/virtio/virtio_crypto_core.c
+++ b/drivers/crypto/virtio/virtio_crypto_core.c
@@ -22,6 +22,56 @@
 	}
 }
 
+static void virtio_crypto_ctrlq_callback(struct virtio_crypto_ctrl_request *vc_ctrl_req)
+{
+	complete(&vc_ctrl_req->compl);
+}
+
+static void virtcrypto_ctrlq_callback(struct virtqueue *vq)
+{
+	struct virtio_crypto *vcrypto = vq->vdev->priv;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
+	unsigned long flags;
+	unsigned int len;
+
+	spin_lock_irqsave(&vcrypto->ctrl_lock, flags);
+	do {
+		virtqueue_disable_cb(vq);
+		while ((vc_ctrl_req = virtqueue_get_buf(vq, &len)) != NULL) {
+			spin_unlock_irqrestore(&vcrypto->ctrl_lock, flags);
+			virtio_crypto_ctrlq_callback(vc_ctrl_req);
+			spin_lock_irqsave(&vcrypto->ctrl_lock, flags);
+		}
+		if (unlikely(virtqueue_is_broken(vq)))
+			break;
+	} while (!virtqueue_enable_cb(vq));
+	spin_unlock_irqrestore(&vcrypto->ctrl_lock, flags);
+}
+
+int virtio_crypto_ctrl_vq_request(struct virtio_crypto *vcrypto, struct scatterlist *sgs[],
+		unsigned int out_sgs, unsigned int in_sgs,
+		struct virtio_crypto_ctrl_request *vc_ctrl_req)
+{
+	int err;
+	unsigned long flags;
+
+	init_completion(&vc_ctrl_req->compl);
+
+	spin_lock_irqsave(&vcrypto->ctrl_lock, flags);
+	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, out_sgs, in_sgs, vc_ctrl_req, GFP_ATOMIC);
+	if (err < 0) {
+		spin_unlock_irqrestore(&vcrypto->ctrl_lock, flags);
+		return err;
+	}
+
+	virtqueue_kick(vcrypto->ctrl_vq);
+	spin_unlock_irqrestore(&vcrypto->ctrl_lock, flags);
+
+	wait_for_completion(&vc_ctrl_req->compl);
+
+	return 0;
+}
+
 static void virtcrypto_dataq_callback(struct virtqueue *vq)
 {
 	struct virtio_crypto *vcrypto = vq->vdev->priv;
@@ -73,7 +123,7 @@
 		goto err_names;
 
 	/* Parameters for control virtqueue */
-	callbacks[total_vqs - 1] = NULL;
+	callbacks[total_vqs - 1] = virtcrypto_ctrlq_callback;
 	names[total_vqs - 1] = "controlq";
 
 	/* Allocate/initialize parameters for data virtqueues */
@@ -94,7 +144,8 @@
 		spin_lock_init(&vi->data_vq[i].lock);
 		vi->data_vq[i].vq = vqs[i];
 		/* Initialize crypto engine */
-		vi->data_vq[i].engine = crypto_engine_alloc_init(dev, 1);
+		vi->data_vq[i].engine = crypto_engine_alloc_init_and_set(dev, true, NULL, true,
+						virtqueue_get_vring_size(vqs[i]));
 		if (!vi->data_vq[i].engine) {
 			ret = -ENOMEM;
 			goto err_engine;
diff --git a/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c b/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c
index a618c46..e553cca 100644
--- a/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c
+++ b/drivers/crypto/virtio/virtio_crypto_skcipher_algs.c
@@ -118,11 +118,14 @@
 		int encrypt)
 {
 	struct scatterlist outhdr, key_sg, inhdr, *sgs[3];
-	unsigned int tmp;
 	struct virtio_crypto *vcrypto = ctx->vcrypto;
 	int op = encrypt ? VIRTIO_CRYPTO_OP_ENCRYPT : VIRTIO_CRYPTO_OP_DECRYPT;
 	int err;
 	unsigned int num_out = 0, num_in = 0;
+	struct virtio_crypto_op_ctrl_req *ctrl;
+	struct virtio_crypto_session_input *input;
+	struct virtio_crypto_sym_create_session_req *sym_create_session;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
 
 	/*
 	 * Avoid to do DMA from the stack, switch to using
@@ -133,26 +136,29 @@
 	if (!cipher_key)
 		return -ENOMEM;
 
-	spin_lock(&vcrypto->ctrl_lock);
+	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
+	if (!vc_ctrl_req) {
+		err = -ENOMEM;
+		goto out;
+	}
+
 	/* Pad ctrl header */
-	vcrypto->ctrl.header.opcode =
-		cpu_to_le32(VIRTIO_CRYPTO_CIPHER_CREATE_SESSION);
-	vcrypto->ctrl.header.algo = cpu_to_le32(alg);
+	ctrl = &vc_ctrl_req->ctrl;
+	ctrl->header.opcode = cpu_to_le32(VIRTIO_CRYPTO_CIPHER_CREATE_SESSION);
+	ctrl->header.algo = cpu_to_le32(alg);
 	/* Set the default dataqueue id to 0 */
-	vcrypto->ctrl.header.queue_id = 0;
+	ctrl->header.queue_id = 0;
 
-	vcrypto->input.status = cpu_to_le32(VIRTIO_CRYPTO_ERR);
+	input = &vc_ctrl_req->input;
+	input->status = cpu_to_le32(VIRTIO_CRYPTO_ERR);
 	/* Pad cipher's parameters */
-	vcrypto->ctrl.u.sym_create_session.op_type =
-		cpu_to_le32(VIRTIO_CRYPTO_SYM_OP_CIPHER);
-	vcrypto->ctrl.u.sym_create_session.u.cipher.para.algo =
-		vcrypto->ctrl.header.algo;
-	vcrypto->ctrl.u.sym_create_session.u.cipher.para.keylen =
-		cpu_to_le32(keylen);
-	vcrypto->ctrl.u.sym_create_session.u.cipher.para.op =
-		cpu_to_le32(op);
+	sym_create_session = &ctrl->u.sym_create_session;
+	sym_create_session->op_type = cpu_to_le32(VIRTIO_CRYPTO_SYM_OP_CIPHER);
+	sym_create_session->u.cipher.para.algo = ctrl->header.algo;
+	sym_create_session->u.cipher.para.keylen = cpu_to_le32(keylen);
+	sym_create_session->u.cipher.para.op = cpu_to_le32(op);
 
-	sg_init_one(&outhdr, &vcrypto->ctrl, sizeof(vcrypto->ctrl));
+	sg_init_one(&outhdr, ctrl, sizeof(*ctrl));
 	sgs[num_out++] = &outhdr;
 
 	/* Set key */
@@ -160,45 +166,30 @@
 	sgs[num_out++] = &key_sg;
 
 	/* Return status and session id back */
-	sg_init_one(&inhdr, &vcrypto->input, sizeof(vcrypto->input));
+	sg_init_one(&inhdr, input, sizeof(*input));
 	sgs[num_out + num_in++] = &inhdr;
 
-	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out,
-				num_in, vcrypto, GFP_ATOMIC);
-	if (err < 0) {
-		spin_unlock(&vcrypto->ctrl_lock);
-		kfree_sensitive(cipher_key);
-		return err;
-	}
-	virtqueue_kick(vcrypto->ctrl_vq);
+	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
+	if (err < 0)
+		goto out;
 
-	/*
-	 * Trapping into the hypervisor, so the request should be
-	 * handled immediately.
-	 */
-	while (!virtqueue_get_buf(vcrypto->ctrl_vq, &tmp) &&
-	       !virtqueue_is_broken(vcrypto->ctrl_vq))
-		cpu_relax();
-
-	if (le32_to_cpu(vcrypto->input.status) != VIRTIO_CRYPTO_OK) {
-		spin_unlock(&vcrypto->ctrl_lock);
+	if (le32_to_cpu(input->status) != VIRTIO_CRYPTO_OK) {
 		pr_err("virtio_crypto: Create session failed status: %u\n",
-			le32_to_cpu(vcrypto->input.status));
-		kfree_sensitive(cipher_key);
-		return -EINVAL;
+			le32_to_cpu(input->status));
+		err = -EINVAL;
+		goto out;
 	}
 
 	if (encrypt)
-		ctx->enc_sess_info.session_id =
-			le64_to_cpu(vcrypto->input.session_id);
+		ctx->enc_sess_info.session_id = le64_to_cpu(input->session_id);
 	else
-		ctx->dec_sess_info.session_id =
-			le64_to_cpu(vcrypto->input.session_id);
+		ctx->dec_sess_info.session_id = le64_to_cpu(input->session_id);
 
-	spin_unlock(&vcrypto->ctrl_lock);
-
+	err = 0;
+out:
+	kfree(vc_ctrl_req);
 	kfree_sensitive(cipher_key);
-	return 0;
+	return err;
 }
 
 static int virtio_crypto_alg_skcipher_close_session(
@@ -206,60 +197,55 @@
 		int encrypt)
 {
 	struct scatterlist outhdr, status_sg, *sgs[2];
-	unsigned int tmp;
 	struct virtio_crypto_destroy_session_req *destroy_session;
 	struct virtio_crypto *vcrypto = ctx->vcrypto;
 	int err;
 	unsigned int num_out = 0, num_in = 0;
+	struct virtio_crypto_op_ctrl_req *ctrl;
+	struct virtio_crypto_inhdr *ctrl_status;
+	struct virtio_crypto_ctrl_request *vc_ctrl_req;
 
-	spin_lock(&vcrypto->ctrl_lock);
-	vcrypto->ctrl_status.status = VIRTIO_CRYPTO_ERR;
+	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
+	if (!vc_ctrl_req)
+		return -ENOMEM;
+
+	ctrl_status = &vc_ctrl_req->ctrl_status;
+	ctrl_status->status = VIRTIO_CRYPTO_ERR;
 	/* Pad ctrl header */
-	vcrypto->ctrl.header.opcode =
-		cpu_to_le32(VIRTIO_CRYPTO_CIPHER_DESTROY_SESSION);
+	ctrl = &vc_ctrl_req->ctrl;
+	ctrl->header.opcode = cpu_to_le32(VIRTIO_CRYPTO_CIPHER_DESTROY_SESSION);
 	/* Set the default virtqueue id to 0 */
-	vcrypto->ctrl.header.queue_id = 0;
+	ctrl->header.queue_id = 0;
 
-	destroy_session = &vcrypto->ctrl.u.destroy_session;
+	destroy_session = &ctrl->u.destroy_session;
 
 	if (encrypt)
-		destroy_session->session_id =
-			cpu_to_le64(ctx->enc_sess_info.session_id);
+		destroy_session->session_id = cpu_to_le64(ctx->enc_sess_info.session_id);
 	else
-		destroy_session->session_id =
-			cpu_to_le64(ctx->dec_sess_info.session_id);
+		destroy_session->session_id = cpu_to_le64(ctx->dec_sess_info.session_id);
 
-	sg_init_one(&outhdr, &vcrypto->ctrl, sizeof(vcrypto->ctrl));
+	sg_init_one(&outhdr, ctrl, sizeof(*ctrl));
 	sgs[num_out++] = &outhdr;
 
 	/* Return status and session id back */
-	sg_init_one(&status_sg, &vcrypto->ctrl_status.status,
-		sizeof(vcrypto->ctrl_status.status));
+	sg_init_one(&status_sg, &ctrl_status->status, sizeof(ctrl_status->status));
 	sgs[num_out + num_in++] = &status_sg;
 
-	err = virtqueue_add_sgs(vcrypto->ctrl_vq, sgs, num_out,
-			num_in, vcrypto, GFP_ATOMIC);
-	if (err < 0) {
-		spin_unlock(&vcrypto->ctrl_lock);
-		return err;
-	}
-	virtqueue_kick(vcrypto->ctrl_vq);
+	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
+	if (err < 0)
+		goto out;
 
-	while (!virtqueue_get_buf(vcrypto->ctrl_vq, &tmp) &&
-	       !virtqueue_is_broken(vcrypto->ctrl_vq))
-		cpu_relax();
-
-	if (vcrypto->ctrl_status.status != VIRTIO_CRYPTO_OK) {
-		spin_unlock(&vcrypto->ctrl_lock);
+	if (ctrl_status->status != VIRTIO_CRYPTO_OK) {
 		pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n",
-			vcrypto->ctrl_status.status,
-			destroy_session->session_id);
-
-		return -EINVAL;
+			ctrl_status->status, le64_to_cpu(destroy_session->session_id));
+		err = -EINVAL;
+		goto out;
 	}
-	spin_unlock(&vcrypto->ctrl_lock);
 
-	return 0;
+	err = 0;
+out:
+	kfree(vc_ctrl_req);
+	return err;
 }
 
 static int virtio_crypto_alg_skcipher_init_sessions(
diff --git a/drivers/s390/virtio/virtio_ccw.c b/drivers/s390/virtio/virtio_ccw.c
index d35e7a3..97e51c3 100644
--- a/drivers/s390/virtio/virtio_ccw.c
+++ b/drivers/s390/virtio/virtio_ccw.c
@@ -62,6 +62,7 @@
 	unsigned int revision; /* Transport revision */
 	wait_queue_head_t wait_q;
 	spinlock_t lock;
+	rwlock_t irq_lock;
 	struct mutex io_lock; /* Serializes I/O requests */
 	struct list_head virtqueues;
 	bool is_thinint;
@@ -970,6 +971,10 @@
 	ccw->flags = 0;
 	ccw->count = sizeof(status);
 	ccw->cda = (__u32)(unsigned long)&vcdev->dma_area->status;
+	/* We use ssch for setting the status which is a serializing
+	 * instruction that guarantees the memory writes have
+	 * completed before ssch.
+	 */
 	ret = ccw_io_helper(vcdev, ccw, VIRTIO_CCW_DOING_WRITE_STATUS);
 	/* Write failed? We assume status is unchanged. */
 	if (ret)
@@ -984,6 +989,30 @@
 	return dev_name(&vcdev->cdev->dev);
 }
 
+static void virtio_ccw_synchronize_cbs(struct virtio_device *vdev)
+{
+	struct virtio_ccw_device *vcdev = to_vc_device(vdev);
+	struct airq_info *info = vcdev->airq_info;
+
+	if (info) {
+		/*
+		 * This device uses adapter interrupts: synchronize with
+		 * vring_interrupt() called by virtio_airq_handler()
+		 * via the indicator area lock.
+		 */
+		write_lock_irq(&info->lock);
+		write_unlock_irq(&info->lock);
+	} else {
+		/* This device uses classic interrupts: synchronize
+		 * with vring_interrupt() called by
+		 * virtio_ccw_int_handler() via the per-device
+		 * irq_lock
+		 */
+		write_lock_irq(&vcdev->irq_lock);
+		write_unlock_irq(&vcdev->irq_lock);
+	}
+}
+
 static const struct virtio_config_ops virtio_ccw_config_ops = {
 	.get_features = virtio_ccw_get_features,
 	.finalize_features = virtio_ccw_finalize_features,
@@ -995,6 +1024,7 @@
 	.find_vqs = virtio_ccw_find_vqs,
 	.del_vqs = virtio_ccw_del_vqs,
 	.bus_name = virtio_ccw_bus_name,
+	.synchronize_cbs = virtio_ccw_synchronize_cbs,
 };
 
 
@@ -1106,6 +1136,8 @@
 			vcdev->err = -EIO;
 	}
 	virtio_ccw_check_activity(vcdev, activity);
+	/* Interrupts are disabled here */
+	read_lock(&vcdev->irq_lock);
 	for_each_set_bit(i, indicators(vcdev),
 			 sizeof(*indicators(vcdev)) * BITS_PER_BYTE) {
 		/* The bit clear must happen before the vring kick. */
@@ -1114,6 +1146,7 @@
 		vq = virtio_ccw_vq_by_ind(vcdev, i);
 		vring_interrupt(0, vq);
 	}
+	read_unlock(&vcdev->irq_lock);
 	if (test_bit(0, indicators2(vcdev))) {
 		virtio_config_changed(&vcdev->vdev);
 		clear_bit(0, indicators2(vcdev));
@@ -1284,6 +1317,7 @@
 	init_waitqueue_head(&vcdev->wait_q);
 	INIT_LIST_HEAD(&vcdev->virtqueues);
 	spin_lock_init(&vcdev->lock);
+	rwlock_init(&vcdev->irq_lock);
 	mutex_init(&vcdev->io_lock);
 
 	spin_lock_irqsave(get_ccwdev_lock(cdev), flags);
diff --git a/drivers/vdpa/alibaba/eni_vdpa.c b/drivers/vdpa/alibaba/eni_vdpa.c
index f480d54..5a09a09 100644
--- a/drivers/vdpa/alibaba/eni_vdpa.c
+++ b/drivers/vdpa/alibaba/eni_vdpa.c
@@ -470,7 +470,7 @@
 		return ret;
 
 	eni_vdpa = vdpa_alloc_device(struct eni_vdpa, vdpa,
-				     dev, &eni_vdpa_ops, NULL, false);
+				     dev, &eni_vdpa_ops, 1, 1, NULL, false);
 	if (IS_ERR(eni_vdpa)) {
 		ENI_ERR(pdev, "failed to allocate vDPA structure\n");
 		return PTR_ERR(eni_vdpa);
diff --git a/drivers/vdpa/ifcvf/ifcvf_main.c b/drivers/vdpa/ifcvf/ifcvf_main.c
index 4366320..0a56707 100644
--- a/drivers/vdpa/ifcvf/ifcvf_main.c
+++ b/drivers/vdpa/ifcvf/ifcvf_main.c
@@ -290,16 +290,16 @@
 	struct ifcvf_hw *vf = &adapter->vf;
 	int config_vector, ret;
 
-	if (vf->msix_vector_status == MSIX_VECTOR_DEV_SHARED)
-		return 0;
-
 	if (vf->msix_vector_status == MSIX_VECTOR_PER_VQ_AND_CONFIG)
-		/* vector 0 ~ vf->nr_vring for vqs, num vf->nr_vring vector for config interrupt */
 		config_vector = vf->nr_vring;
-
-	if (vf->msix_vector_status ==  MSIX_VECTOR_SHARED_VQ_AND_CONFIG)
+	else if (vf->msix_vector_status ==  MSIX_VECTOR_SHARED_VQ_AND_CONFIG)
 		/* vector 0 for vqs and 1 for config interrupt */
 		config_vector = 1;
+	else if (vf->msix_vector_status == MSIX_VECTOR_DEV_SHARED)
+		/* re-use the vqs vector */
+		return 0;
+	else
+		return -EINVAL;
 
 	snprintf(vf->config_msix_name, 256, "ifcvf[%s]-config\n",
 		 pci_name(pdev));
@@ -626,6 +626,11 @@
 	return  vf->config_size;
 }
 
+static u32 ifcvf_vdpa_get_vq_group(struct vdpa_device *vdpa, u16 idx)
+{
+	return 0;
+}
+
 static void ifcvf_vdpa_get_config(struct vdpa_device *vdpa_dev,
 				  unsigned int offset,
 				  void *buf, unsigned int len)
@@ -704,6 +709,7 @@
 	.get_device_id	= ifcvf_vdpa_get_device_id,
 	.get_vendor_id	= ifcvf_vdpa_get_vendor_id,
 	.get_vq_align	= ifcvf_vdpa_get_vq_align,
+	.get_vq_group	= ifcvf_vdpa_get_vq_group,
 	.get_config_size	= ifcvf_vdpa_get_config_size,
 	.get_config	= ifcvf_vdpa_get_config,
 	.set_config	= ifcvf_vdpa_set_config,
@@ -758,14 +764,13 @@
 	pdev = ifcvf_mgmt_dev->pdev;
 	dev = &pdev->dev;
 	adapter = vdpa_alloc_device(struct ifcvf_adapter, vdpa,
-				    dev, &ifc_vdpa_ops, name, false);
+				    dev, &ifc_vdpa_ops, 1, 1, name, false);
 	if (IS_ERR(adapter)) {
 		IFCVF_ERR(pdev, "Failed to allocate vDPA structure");
 		return PTR_ERR(adapter);
 	}
 
 	ifcvf_mgmt_dev->adapter = adapter;
-	pci_set_drvdata(pdev, ifcvf_mgmt_dev);
 
 	vf = &adapter->vf;
 	vf->dev_type = get_dev_type(pdev);
@@ -880,6 +885,8 @@
 		goto err;
 	}
 
+	pci_set_drvdata(pdev, ifcvf_mgmt_dev);
+
 	return 0;
 
 err:
diff --git a/drivers/vdpa/mlx5/core/mlx5_vdpa.h b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
index daaf7b5..4410409 100644
--- a/drivers/vdpa/mlx5/core/mlx5_vdpa.h
+++ b/drivers/vdpa/mlx5/core/mlx5_vdpa.h
@@ -61,6 +61,8 @@
 	struct vringh_kiov riov;
 	struct vringh_kiov wiov;
 	unsigned short head;
+	unsigned int received_desc;
+	unsigned int completed_desc;
 };
 
 struct mlx5_vdpa_wq_ent {
diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c b/drivers/vdpa/mlx5/net/mlx5_vnet.c
index e0de440..b7a9554 100644
--- a/drivers/vdpa/mlx5/net/mlx5_vnet.c
+++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c
@@ -48,6 +48,8 @@
 
 #define MLX5_FEATURE(_mvdev, _feature) (!!((_mvdev)->actual_features & BIT_ULL(_feature)))
 
+#define MLX5V_UNTAGGED 0x1000
+
 struct mlx5_vdpa_net_resources {
 	u32 tisn;
 	u32 tdn;
@@ -119,6 +121,7 @@
 	struct mlx5_vdpa_umem umem2;
 	struct mlx5_vdpa_umem umem3;
 
+	u32 counter_set_id;
 	bool initialized;
 	int index;
 	u32 virtq_id;
@@ -143,6 +146,8 @@
 	return idx <= mvdev->max_idx;
 }
 
+#define MLX5V_MACVLAN_SIZE 256
+
 struct mlx5_vdpa_net {
 	struct mlx5_vdpa_dev mvdev;
 	struct mlx5_vdpa_net_resources res;
@@ -154,17 +159,22 @@
 	 * since memory map might change and we need to destroy and create
 	 * resources while driver in operational.
 	 */
-	struct mutex reslock;
+	struct rw_semaphore reslock;
 	struct mlx5_flow_table *rxft;
-	struct mlx5_fc *rx_counter;
-	struct mlx5_flow_handle *rx_rule_ucast;
-	struct mlx5_flow_handle *rx_rule_mcast;
 	bool setup;
 	u32 cur_num_vqs;
 	u32 rqt_size;
 	struct notifier_block nb;
 	struct vdpa_callback config_cb;
 	struct mlx5_vdpa_wq_ent cvq_ent;
+	struct hlist_head macvlan_hash[MLX5V_MACVLAN_SIZE];
+};
+
+struct macvlan_node {
+	struct hlist_node hlist;
+	struct mlx5_flow_handle *ucast_rule;
+	struct mlx5_flow_handle *mcast_rule;
+	u64 macvlan;
 };
 
 static void free_resources(struct mlx5_vdpa_net *ndev);
@@ -818,6 +828,12 @@
 	       (!!(features & BIT_ULL(VIRTIO_NET_F_GUEST_CSUM)) << 6);
 }
 
+static bool counters_supported(const struct mlx5_vdpa_dev *mvdev)
+{
+	return MLX5_CAP_GEN_64(mvdev->mdev, general_obj_types) &
+	       BIT_ULL(MLX5_OBJ_TYPE_VIRTIO_Q_COUNTERS);
+}
+
 static int create_virtqueue(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
 {
 	int inlen = MLX5_ST_SZ_BYTES(create_virtio_net_q_in);
@@ -872,6 +888,8 @@
 	MLX5_SET(virtio_q, vq_ctx, umem_3_id, mvq->umem3.id);
 	MLX5_SET(virtio_q, vq_ctx, umem_3_size, mvq->umem3.size);
 	MLX5_SET(virtio_q, vq_ctx, pd, ndev->mvdev.res.pdn);
+	if (counters_supported(&ndev->mvdev))
+		MLX5_SET(virtio_q, vq_ctx, counter_set_id, mvq->counter_set_id);
 
 	err = mlx5_cmd_exec(ndev->mvdev.mdev, in, inlen, out, sizeof(out));
 	if (err)
@@ -1135,6 +1153,47 @@
 	return err;
 }
 
+static int counter_set_alloc(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
+{
+	u32 in[MLX5_ST_SZ_DW(create_virtio_q_counters_in)] = {};
+	u32 out[MLX5_ST_SZ_DW(create_virtio_q_counters_out)] = {};
+	void *cmd_hdr;
+	int err;
+
+	if (!counters_supported(&ndev->mvdev))
+		return 0;
+
+	cmd_hdr = MLX5_ADDR_OF(create_virtio_q_counters_in, in, hdr);
+
+	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, opcode, MLX5_CMD_OP_CREATE_GENERAL_OBJECT);
+	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_type, MLX5_OBJ_TYPE_VIRTIO_Q_COUNTERS);
+	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, uid, ndev->mvdev.res.uid);
+
+	err = mlx5_cmd_exec(ndev->mvdev.mdev, in, sizeof(in), out, sizeof(out));
+	if (err)
+		return err;
+
+	mvq->counter_set_id = MLX5_GET(general_obj_out_cmd_hdr, out, obj_id);
+
+	return 0;
+}
+
+static void counter_set_dealloc(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
+{
+	u32 in[MLX5_ST_SZ_DW(destroy_virtio_q_counters_in)] = {};
+	u32 out[MLX5_ST_SZ_DW(destroy_virtio_q_counters_out)] = {};
+
+	if (!counters_supported(&ndev->mvdev))
+		return;
+
+	MLX5_SET(destroy_virtio_q_counters_in, in, hdr.opcode, MLX5_CMD_OP_DESTROY_GENERAL_OBJECT);
+	MLX5_SET(destroy_virtio_q_counters_in, in, hdr.obj_id, mvq->counter_set_id);
+	MLX5_SET(destroy_virtio_q_counters_in, in, hdr.uid, ndev->mvdev.res.uid);
+	MLX5_SET(destroy_virtio_q_counters_in, in, hdr.obj_type, MLX5_OBJ_TYPE_VIRTIO_Q_COUNTERS);
+	if (mlx5_cmd_exec(ndev->mvdev.mdev, in, sizeof(in), out, sizeof(out)))
+		mlx5_vdpa_warn(&ndev->mvdev, "dealloc counter set 0x%x\n", mvq->counter_set_id);
+}
+
 static int setup_vq(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq)
 {
 	u16 idx = mvq->index;
@@ -1162,6 +1221,10 @@
 	if (err)
-		goto err_connect;
+		goto err_counter; /* no counter set has been allocated yet */
 
+	err = counter_set_alloc(ndev, mvq);
+	if (err)
+		goto err_counter;
+
 	err = create_virtqueue(ndev, mvq);
 	if (err)
 		goto err_connect;
@@ -1179,6 +1242,8 @@
 	return 0;
 
 err_connect:
+	counter_set_dealloc(ndev, mvq);
+err_counter:
 	qp_destroy(ndev, &mvq->vqqp);
 err_vqqp:
 	qp_destroy(ndev, &mvq->fwqp);
@@ -1223,6 +1288,7 @@
 
 	suspend_vq(ndev, mvq);
 	destroy_virtqueue(ndev, mvq);
+	counter_set_dealloc(ndev, mvq);
 	qp_destroy(ndev, &mvq->vqqp);
 	qp_destroy(ndev, &mvq->fwqp);
 	cq_destroy(ndev, mvq->index);
@@ -1347,12 +1413,17 @@
 	mlx5_vdpa_destroy_tir(&ndev->mvdev, ndev->res.tirn);
 }
 
-static int add_fwd_to_tir(struct mlx5_vdpa_net *ndev)
+#define MAX_STEERING_ENT 0x8000
+#define MAX_STEERING_GROUPS 2
+
+static int mlx5_vdpa_add_mac_vlan_rules(struct mlx5_vdpa_net *ndev, u8 *mac,
+					u16 vid, bool tagged,
+					struct mlx5_flow_handle **ucast,
+					struct mlx5_flow_handle **mcast)
 {
-	struct mlx5_flow_destination dest[2] = {};
-	struct mlx5_flow_table_attr ft_attr = {};
+	struct mlx5_flow_destination dest = {};
 	struct mlx5_flow_act flow_act = {};
-	struct mlx5_flow_namespace *ns;
+	struct mlx5_flow_handle *rule;
 	struct mlx5_flow_spec *spec;
 	void *headers_c;
 	void *headers_v;
@@ -1365,85 +1436,197 @@
 		return -ENOMEM;
 
 	spec->match_criteria_enable = MLX5_MATCH_OUTER_HEADERS;
-	ft_attr.max_fte = 2;
-	ft_attr.autogroup.max_num_groups = 2;
-
-	ns = mlx5_get_flow_namespace(ndev->mvdev.mdev, MLX5_FLOW_NAMESPACE_BYPASS);
-	if (!ns) {
-		mlx5_vdpa_warn(&ndev->mvdev, "failed to get flow namespace\n");
-		err = -EOPNOTSUPP;
-		goto err_ns;
-	}
-
-	ndev->rxft = mlx5_create_auto_grouped_flow_table(ns, &ft_attr);
-	if (IS_ERR(ndev->rxft)) {
-		err = PTR_ERR(ndev->rxft);
-		goto err_ns;
-	}
-
-	ndev->rx_counter = mlx5_fc_create(ndev->mvdev.mdev, false);
-	if (IS_ERR(ndev->rx_counter)) {
-		err = PTR_ERR(ndev->rx_counter);
-		goto err_fc;
-	}
-
 	headers_c = MLX5_ADDR_OF(fte_match_param, spec->match_criteria, outer_headers);
-	dmac_c = MLX5_ADDR_OF(fte_match_param, headers_c, outer_headers.dmac_47_16);
-	memset(dmac_c, 0xff, ETH_ALEN);
 	headers_v = MLX5_ADDR_OF(fte_match_param, spec->match_value, outer_headers);
+	dmac_c = MLX5_ADDR_OF(fte_match_param, headers_c, outer_headers.dmac_47_16);
 	dmac_v = MLX5_ADDR_OF(fte_match_param, headers_v, outer_headers.dmac_47_16);
-	ether_addr_copy(dmac_v, ndev->config.mac);
-
-	flow_act.action = MLX5_FLOW_CONTEXT_ACTION_FWD_DEST | MLX5_FLOW_CONTEXT_ACTION_COUNT;
-	dest[0].type = MLX5_FLOW_DESTINATION_TYPE_TIR;
-	dest[0].tir_num = ndev->res.tirn;
-	dest[1].type = MLX5_FLOW_DESTINATION_TYPE_COUNTER;
-	dest[1].counter_id = mlx5_fc_id(ndev->rx_counter);
-	ndev->rx_rule_ucast = mlx5_add_flow_rules(ndev->rxft, spec, &flow_act, dest, 2);
-
-	if (IS_ERR(ndev->rx_rule_ucast)) {
-		err = PTR_ERR(ndev->rx_rule_ucast);
-		ndev->rx_rule_ucast = NULL;
-		goto err_rule_ucast;
+	memset(dmac_c, 0xff, ETH_ALEN);
+	ether_addr_copy(dmac_v, mac);
+	/* Always match on cvlan_tag so an untagged rule only sees untagged frames. */
+	MLX5_SET(fte_match_set_lyr_2_4, headers_c, cvlan_tag, 1);
+	if (tagged) {
+		MLX5_SET(fte_match_set_lyr_2_4, headers_v, cvlan_tag, 1);
+		MLX5_SET_TO_ONES(fte_match_set_lyr_2_4, headers_c, first_vid);
+		/* The VLAN id is the value to match; headers_c already masks all of it. */
+		MLX5_SET(fte_match_set_lyr_2_4, headers_v, first_vid, vid);
 	}
+	flow_act.action = MLX5_FLOW_CONTEXT_ACTION_FWD_DEST;
+	dest.type = MLX5_FLOW_DESTINATION_TYPE_TIR;
+	dest.tir_num = ndev->res.tirn;
+	rule = mlx5_add_flow_rules(ndev->rxft, spec, &flow_act, &dest, 1);
+	if (IS_ERR(rule)) {
+		kvfree(spec);
+		return PTR_ERR(rule);
+	}
+
+	*ucast = rule;
 
+	/* Multicast rule: same VLAN match, DMAC reduced to the group bit. */
 	memset(dmac_c, 0, ETH_ALEN);
 	memset(dmac_v, 0, ETH_ALEN);
 	dmac_c[0] = 1;
 	dmac_v[0] = 1;
-	flow_act.action = MLX5_FLOW_CONTEXT_ACTION_FWD_DEST;
-	ndev->rx_rule_mcast = mlx5_add_flow_rules(ndev->rxft, spec, &flow_act, dest, 1);
-	if (IS_ERR(ndev->rx_rule_mcast)) {
-		err = PTR_ERR(ndev->rx_rule_mcast);
-		ndev->rx_rule_mcast = NULL;
-		goto err_rule_mcast;
+	rule = mlx5_add_flow_rules(ndev->rxft, spec, &flow_act, &dest, 1);
+	kvfree(spec);
+	if (IS_ERR(rule)) {
+		err = PTR_ERR(rule);
+		goto err_mcast;
 	}
 
-	kvfree(spec);
+	*mcast = rule;
 	return 0;
 
-err_rule_mcast:
-	mlx5_del_flow_rules(ndev->rx_rule_ucast);
-	ndev->rx_rule_ucast = NULL;
-err_rule_ucast:
-	mlx5_fc_destroy(ndev->mvdev.mdev, ndev->rx_counter);
-err_fc:
-	mlx5_destroy_flow_table(ndev->rxft);
-err_ns:
-	kvfree(spec);
+err_mcast:
+	mlx5_del_flow_rules(*ucast);
 	return err;
 }
 
-static void remove_fwd_to_tir(struct mlx5_vdpa_net *ndev)
+/* Delete the unicast and multicast rules of one MAC/VLAN entry. */
+static void mlx5_vdpa_del_mac_vlan_rules(struct mlx5_vdpa_net *ndev,
+					 struct mlx5_flow_handle *ucast,
+					 struct mlx5_flow_handle *mcast)
 {
-	if (!ndev->rx_rule_ucast)
+	mlx5_del_flow_rules(ucast);
+	mlx5_del_flow_rules(mcast);
+}
+
+/*
+ * Build the MAC/VLAN hash key: VLAN in bits 63..48, MAC in bits 47..0.
+ * Untagged entries store MLX5V_UNTAGGED in the VLAN field, so a tagged
+ * VLAN value equal to MLX5V_UNTAGGED would alias the untagged key.
+ */
+static u64 search_val(u8 *mac, u16 vlan, bool tagged)
+{
+	u64 val;
+
+	if (!tagged)
+		vlan = MLX5V_UNTAGGED;
+
+	val = (u64)vlan << 48 |
+	      (u64)mac[0] << 40 |
+	      (u64)mac[1] << 32 |
+	      (u64)mac[2] << 24 |
+	      (u64)mac[3] << 16 |
+	      (u64)mac[4] << 8 |
+	      (u64)mac[5];
+
+	return val;
+}
+
+/* Return the hash table entry whose key equals @value, or NULL. */
+static struct macvlan_node *mac_vlan_lookup(struct mlx5_vdpa_net *ndev, u64 value)
+{
+	struct macvlan_node *pos;
+	u32 idx;
+
+	idx = hash_64(value, 8); /* 2^8 == MLX5V_MACVLAN_SIZE buckets */
+	hlist_for_each_entry(pos, &ndev->macvlan_hash[idx], hlist) {
+		if (pos->macvlan == value)
+			return pos;
+	}
+	return NULL;
+}
+
+/*
+ * Install unicast and multicast steering rules for @mac (tagged with @vid
+ * when @tagged is true) and record them in ndev->macvlan_hash.
+ * Returns 0 on success, -EEXIST if the pair is already present, -ENOMEM,
+ * or the error returned by mlx5_vdpa_add_mac_vlan_rules().
+ */
+static int mac_vlan_add(struct mlx5_vdpa_net *ndev, u8 *mac, u16 vid, bool tagged)
+{
+	struct macvlan_node *ptr;
+	u64 val;
+	u32 idx;
+	int err;
+
+	val = search_val(mac, vid, tagged);
+	if (mac_vlan_lookup(ndev, val))
+		return -EEXIST;
+
+	ptr = kzalloc(sizeof(*ptr), GFP_KERNEL);
+	if (!ptr)
+		return -ENOMEM;
+
+	/* Program rules for @mac itself, the same address the key was built from. */
+	err = mlx5_vdpa_add_mac_vlan_rules(ndev, mac, vid, tagged,
+					   &ptr->ucast_rule, &ptr->mcast_rule);
+	if (err)
+		goto err_add;
+
+	ptr->macvlan = val;
+	idx = hash_64(val, 8);
+	hlist_add_head(&ptr->hlist, &ndev->macvlan_hash[idx]);
+	return 0;
+
+err_add:
+	kfree(ptr);
+	return err;
+}
+
+static void mac_vlan_del(struct mlx5_vdpa_net *ndev, u8 *mac, u16 vlan, bool tagged)
+{
+	struct macvlan_node *ptr;
+
+	ptr = mac_vlan_lookup(ndev, search_val(mac, vlan, tagged));
+	if (!ptr)
 		return;
 
-	mlx5_del_flow_rules(ndev->rx_rule_mcast);
-	ndev->rx_rule_mcast = NULL;
-	mlx5_del_flow_rules(ndev->rx_rule_ucast);
-	ndev->rx_rule_ucast = NULL;
-	mlx5_fc_destroy(ndev->mvdev.mdev, ndev->rx_counter);
+	hlist_del(&ptr->hlist);
+	mlx5_vdpa_del_mac_vlan_rules(ndev, ptr->ucast_rule, ptr->mcast_rule);
+	kfree(ptr);
+}
+
+static void clear_mac_vlan_table(struct mlx5_vdpa_net *ndev)
+{
+	struct macvlan_node *pos;
+	struct hlist_node *n;
+	int i;
+
+	for (i = 0; i < MLX5V_MACVLAN_SIZE; i++) {
+		hlist_for_each_entry_safe(pos, n, &ndev->macvlan_hash[i], hlist) {
+			hlist_del(&pos->hlist);
+			mlx5_vdpa_del_mac_vlan_rules(ndev, pos->ucast_rule, pos->mcast_rule);
+			kfree(pos);
+		}
+	}
+}
+
+static int setup_steering(struct mlx5_vdpa_net *ndev)
+{
+	struct mlx5_flow_table_attr ft_attr = {};
+	struct mlx5_flow_namespace *ns;
+	int err;
+
+	ft_attr.max_fte = MAX_STEERING_ENT;
+	ft_attr.autogroup.max_num_groups = MAX_STEERING_GROUPS;
+
+	ns = mlx5_get_flow_namespace(ndev->mvdev.mdev, MLX5_FLOW_NAMESPACE_BYPASS);
+	if (!ns) {
+		mlx5_vdpa_warn(&ndev->mvdev, "failed to get flow namespace\n");
+		return -EOPNOTSUPP;
+	}
+
+	ndev->rxft = mlx5_create_auto_grouped_flow_table(ns, &ft_attr);
+	if (IS_ERR(ndev->rxft)) {
+		mlx5_vdpa_warn(&ndev->mvdev, "failed to create flow table\n");
+		return PTR_ERR(ndev->rxft);
+	}
+
+	err = mac_vlan_add(ndev, ndev->config.mac, 0, false);
+	if (err)
+		goto err_add;
+
+	return 0;
+
+err_add:
+	mlx5_destroy_flow_table(ndev->rxft);
+	return err;
+}
+
+static void teardown_steering(struct mlx5_vdpa_net *ndev)
+{
+	clear_mac_vlan_table(ndev);
 	mlx5_destroy_flow_table(ndev->rxft);
 }
 
@@ -1494,9 +1658,10 @@
 
 		/* Need recreate the flow table entry, so that the packet could forward back
 		 */
-		remove_fwd_to_tir(ndev);
+		/* The old untagged entry is keyed by the previous MAC (mac_back). */
+		mac_vlan_del(ndev, mac_back, 0, false);
 
-		if (add_fwd_to_tir(ndev)) {
+		if (mac_vlan_add(ndev, ndev->config.mac, 0, false)) {
 			mlx5_vdpa_warn(mvdev, "failed to insert forward rules, try to restore\n");
 
 			/* Although it hardly run here, we still need double check */
@@ -1520,7 +1684,7 @@
 
 			memcpy(ndev->config.mac, mac_back, ETH_ALEN);
 
-			if (add_fwd_to_tir(ndev))
+			if (mac_vlan_add(ndev, ndev->config.mac, 0, false))
 				mlx5_vdpa_warn(mvdev, "restore forward rules failed: insert forward rules failed\n");
 
 			break;
@@ -1622,6 +1786,61 @@
 	return status;
 }
 
+/*
+ * Handle a VIRTIO_NET_CTRL_VLAN command read from the control virtqueue.
+ * VIRTIO_NET_CTRL_VLAN_ADD installs steering rules for the device MAC tagged
+ * with the VLAN id that follows the command; VIRTIO_NET_CTRL_VLAN_DEL removes
+ * them (deleting an id that was never added is not treated as an error).
+ * Returns VIRTIO_NET_OK on success and VIRTIO_NET_ERR otherwise.
+ */
+static virtio_net_ctrl_ack handle_ctrl_vlan(struct mlx5_vdpa_dev *mvdev, u8 cmd)
+{
+	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
+	virtio_net_ctrl_ack status = VIRTIO_NET_ERR;
+	struct mlx5_control_vq *cvq = &mvdev->cvq;
+	__virtio16 vlan;
+	size_t read;
+	u16 id;
+
+	switch (cmd) {
+	case VIRTIO_NET_CTRL_VLAN_ADD:
+		read = vringh_iov_pull_iotlb(&cvq->vring, &cvq->riov, &vlan, sizeof(vlan));
+		if (read != sizeof(vlan))
+			break;
+
+		id = mlx5vdpa16_to_cpu(mvdev, vlan);
+		/*
+		 * VLAN ids are 12 bits. MLX5V_UNTAGGED (0x1000) marks the untagged
+		 * entry in the MAC/VLAN table, so reject it and anything above.
+		 */
+		if (id >= MLX5V_UNTAGGED)
+			break;
+
+		if (mac_vlan_add(ndev, ndev->config.mac, id, true))
+			break;
+
+		status = VIRTIO_NET_OK;
+		break;
+	case VIRTIO_NET_CTRL_VLAN_DEL:
+		read = vringh_iov_pull_iotlb(&cvq->vring, &cvq->riov, &vlan, sizeof(vlan));
+		if (read != sizeof(vlan))
+			break;
+
+		id = mlx5vdpa16_to_cpu(mvdev, vlan);
+		/* Same check as for ADD: never let a DEL remove the untagged entry. */
+		if (id >= MLX5V_UNTAGGED)
+			break;
+
+		mac_vlan_del(ndev, ndev->config.mac, id, true);
+		status = VIRTIO_NET_OK;
+		break;
+	default:
+		break;
+	}
+
+	return status;
+}
+
 static void mlx5_cvq_kick_handler(struct work_struct *work)
 {
 	virtio_net_ctrl_ack status = VIRTIO_NET_ERR;
@@ -1638,7 +1838,7 @@
 	ndev = to_mlx5_vdpa_ndev(mvdev);
 	cvq = &mvdev->cvq;
 
-	mutex_lock(&ndev->reslock);
+	down_write(&ndev->reslock);
 
 	if (!(mvdev->status & VIRTIO_CONFIG_S_DRIVER_OK))
 		goto out;
@@ -1659,6 +1859,7 @@
 		if (read != sizeof(ctrl))
 			break;
 
+		cvq->received_desc++;
 		switch (ctrl.class) {
 		case VIRTIO_NET_CTRL_MAC:
 			status = handle_ctrl_mac(mvdev, ctrl.cmd);
@@ -1666,7 +1867,9 @@
 		case VIRTIO_NET_CTRL_MQ:
 			status = handle_ctrl_mq(mvdev, ctrl.cmd);
 			break;
-
+		case VIRTIO_NET_CTRL_VLAN:
+			status = handle_ctrl_vlan(mvdev, ctrl.cmd);
+			break;
 		default:
 			break;
 		}
@@ -1682,12 +1885,13 @@
 		if (vringh_need_notify_iotlb(&cvq->vring))
 			vringh_notify(&cvq->vring);
 
+		cvq->completed_desc++;
 		queue_work(mvdev->wq, &wqent->work);
 		break;
 	}
 
 out:
-	mutex_unlock(&ndev->reslock);
+	up_write(&ndev->reslock);
 }
 
 static void mlx5_vdpa_kick_vq(struct vdpa_device *vdev, u16 idx)
@@ -1888,6 +2092,11 @@
 	return PAGE_SIZE;
 }
 
+static u32 mlx5_vdpa_get_vq_group(struct vdpa_device *vdpa, u16 idx)
+{
+	return 0;
+}
+
 enum { MLX5_VIRTIO_NET_F_GUEST_CSUM = 1 << 9,
 	MLX5_VIRTIO_NET_F_CSUM = 1 << 10,
 	MLX5_VIRTIO_NET_F_HOST_TSO6 = 1 << 11,
@@ -1925,6 +2134,7 @@
 	mlx_vdpa_features |= BIT_ULL(VIRTIO_NET_F_MQ);
 	mlx_vdpa_features |= BIT_ULL(VIRTIO_NET_F_STATUS);
 	mlx_vdpa_features |= BIT_ULL(VIRTIO_NET_F_MTU);
+	mlx_vdpa_features |= BIT_ULL(VIRTIO_NET_F_CTRL_VLAN);
 
 	return mlx_vdpa_features;
 }
@@ -2185,7 +2395,7 @@
 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
 	int err;
 
-	WARN_ON(!mutex_is_locked(&ndev->reslock));
+	WARN_ON(!rwsem_is_locked(&ndev->reslock));
 
 	if (ndev->setup) {
 		mlx5_vdpa_warn(mvdev, "setup driver called for already setup driver\n");
@@ -2210,9 +2420,9 @@
 		goto err_tir;
 	}
 
-	err = add_fwd_to_tir(ndev);
+	err = setup_steering(ndev);
 	if (err) {
-		mlx5_vdpa_warn(mvdev, "add_fwd_to_tir\n");
+		mlx5_vdpa_warn(mvdev, "setup_steering\n");
 		goto err_fwd;
 	}
 	ndev->setup = true;
@@ -2233,12 +2443,12 @@
 static void teardown_driver(struct mlx5_vdpa_net *ndev)
 {
 
-	WARN_ON(!mutex_is_locked(&ndev->reslock));
+	WARN_ON(!rwsem_is_locked(&ndev->reslock));
 
 	if (!ndev->setup)
 		return;
 
-	remove_fwd_to_tir(ndev);
+	teardown_steering(ndev);
 	destroy_tir(ndev);
 	destroy_rqt(ndev);
 	teardown_virtqueues(ndev);
@@ -2263,7 +2473,7 @@
 
 	print_status(mvdev, status, true);
 
-	mutex_lock(&ndev->reslock);
+	down_write(&ndev->reslock);
 
 	if ((status ^ ndev->mvdev.status) & VIRTIO_CONFIG_S_DRIVER_OK) {
 		if (status & VIRTIO_CONFIG_S_DRIVER_OK) {
@@ -2279,14 +2489,14 @@
 	}
 
 	ndev->mvdev.status = status;
-	mutex_unlock(&ndev->reslock);
+	up_write(&ndev->reslock);
 	return;
 
 err_setup:
 	mlx5_vdpa_destroy_mr(&ndev->mvdev);
 	ndev->mvdev.status |= VIRTIO_CONFIG_S_FAILED;
 err_clear:
-	mutex_unlock(&ndev->reslock);
+	up_write(&ndev->reslock);
 }
 
 static int mlx5_vdpa_reset(struct vdpa_device *vdev)
@@ -2297,12 +2507,14 @@
 	print_status(mvdev, 0, true);
 	mlx5_vdpa_info(mvdev, "performing device reset\n");
 
-	mutex_lock(&ndev->reslock);
+	down_write(&ndev->reslock);
 	teardown_driver(ndev);
 	clear_vqs_ready(ndev);
 	mlx5_vdpa_destroy_mr(&ndev->mvdev);
 	ndev->mvdev.status = 0;
 	ndev->cur_num_vqs = 0;
+	ndev->mvdev.cvq.received_desc = 0;
+	ndev->mvdev.cvq.completed_desc = 0;
 	memset(ndev->event_cbs, 0, sizeof(*ndev->event_cbs) * (mvdev->max_vqs + 1));
 	ndev->mvdev.actual_features = 0;
 	++mvdev->generation;
@@ -2310,7 +2522,7 @@
 		if (mlx5_vdpa_create_mr(mvdev, NULL))
 			mlx5_vdpa_warn(mvdev, "create MR failed\n");
 	}
-	mutex_unlock(&ndev->reslock);
+	up_write(&ndev->reslock);
 
 	return 0;
 }
@@ -2343,14 +2555,15 @@
 	return mvdev->generation;
 }
 
-static int mlx5_vdpa_set_map(struct vdpa_device *vdev, struct vhost_iotlb *iotlb)
+static int mlx5_vdpa_set_map(struct vdpa_device *vdev, unsigned int asid,
+			     struct vhost_iotlb *iotlb)
 {
 	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
 	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
 	bool change_map;
 	int err;
 
-	mutex_lock(&ndev->reslock);
+	down_write(&ndev->reslock);
 
 	err = mlx5_vdpa_handle_set_map(mvdev, iotlb, &change_map);
 	if (err) {
@@ -2362,7 +2575,7 @@
 		err = mlx5_vdpa_change_map(mvdev, iotlb);
 
 err:
-	mutex_unlock(&ndev->reslock);
+	up_write(&ndev->reslock);
 	return err;
 }
 
@@ -2381,7 +2594,6 @@
 		mlx5_mpfs_del_mac(pfmdev, ndev->config.mac);
 	}
 	mlx5_vdpa_free_resources(&ndev->mvdev);
-	mutex_destroy(&ndev->reslock);
 	kfree(ndev->event_cbs);
 	kfree(ndev->vqs);
 }
@@ -2422,6 +2634,93 @@
 	return mvdev->actual_features;
 }
 
+static int counter_set_query(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *mvq,
+			     u64 *received_desc, u64 *completed_desc)
+{
+	u32 in[MLX5_ST_SZ_DW(query_virtio_q_counters_in)] = {};
+	u32 out[MLX5_ST_SZ_DW(query_virtio_q_counters_out)] = {};
+	void *cmd_hdr;
+	void *ctx;
+	int err;
+
+	if (!counters_supported(&ndev->mvdev))
+		return -EOPNOTSUPP;
+
+	if (mvq->fw_state != MLX5_VIRTIO_NET_Q_OBJECT_STATE_RDY)
+		return -EAGAIN;
+
+	cmd_hdr = MLX5_ADDR_OF(query_virtio_q_counters_in, in, hdr);
+
+	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, opcode, MLX5_CMD_OP_QUERY_GENERAL_OBJECT);
+	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_type, MLX5_OBJ_TYPE_VIRTIO_Q_COUNTERS);
+	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, uid, ndev->mvdev.res.uid);
+	MLX5_SET(general_obj_in_cmd_hdr, cmd_hdr, obj_id, mvq->counter_set_id);
+
+	err = mlx5_cmd_exec(ndev->mvdev.mdev, in, sizeof(in), out, sizeof(out));
+	if (err)
+		return err;
+
+	ctx = MLX5_ADDR_OF(query_virtio_q_counters_out, out, counters);
+	*received_desc = MLX5_GET64(virtio_q_counters, ctx, received_desc);
+	*completed_desc = MLX5_GET64(virtio_q_counters, ctx, completed_desc);
+	return 0;
+}
+
+static int mlx5_vdpa_get_vendor_vq_stats(struct vdpa_device *vdev, u16 idx,
+					 struct sk_buff *msg,
+					 struct netlink_ext_ack *extack)
+{
+	struct mlx5_vdpa_dev *mvdev = to_mvdev(vdev);
+	struct mlx5_vdpa_net *ndev = to_mlx5_vdpa_ndev(mvdev);
+	struct mlx5_vdpa_virtqueue *mvq;
+	struct mlx5_control_vq *cvq;
+	u64 received_desc;
+	u64 completed_desc;
+	int err = 0;
+
+	down_read(&ndev->reslock);
+	if (!is_index_valid(mvdev, idx)) {
+		NL_SET_ERR_MSG_MOD(extack, "virtqueue index is not valid");
+		err = -EINVAL;
+		goto out_err;
+	}
+
+	if (idx == ctrl_vq_idx(mvdev)) {
+		cvq = &mvdev->cvq;
+		received_desc = cvq->received_desc;
+		completed_desc = cvq->completed_desc;
+		goto out;
+	}
+
+	mvq = &ndev->vqs[idx];
+	err = counter_set_query(ndev, mvq, &received_desc, &completed_desc);
+	if (err) {
+		NL_SET_ERR_MSG_MOD(extack, "failed to query hardware");
+		goto out_err;
+	}
+
+out:
+	err = -EMSGSIZE;
+	if (nla_put_string(msg, VDPA_ATTR_DEV_VENDOR_ATTR_NAME, "received_desc"))
+		goto out_err;
+
+	if (nla_put_u64_64bit(msg, VDPA_ATTR_DEV_VENDOR_ATTR_VALUE, received_desc,
+			      VDPA_ATTR_PAD))
+		goto out_err;
+
+	if (nla_put_string(msg, VDPA_ATTR_DEV_VENDOR_ATTR_NAME, "completed_desc"))
+		goto out_err;
+
+	if (nla_put_u64_64bit(msg, VDPA_ATTR_DEV_VENDOR_ATTR_VALUE, completed_desc,
+			      VDPA_ATTR_PAD))
+		goto out_err;
+
+	err = 0;
+out_err:
+	up_read(&ndev->reslock);
+	return err;
+}
+
 static const struct vdpa_config_ops mlx5_vdpa_ops = {
 	.set_vq_address = mlx5_vdpa_set_vq_address,
 	.set_vq_num = mlx5_vdpa_set_vq_num,
@@ -2431,9 +2730,11 @@
 	.get_vq_ready = mlx5_vdpa_get_vq_ready,
 	.set_vq_state = mlx5_vdpa_set_vq_state,
 	.get_vq_state = mlx5_vdpa_get_vq_state,
+	.get_vendor_vq_stats = mlx5_vdpa_get_vendor_vq_stats,
 	.get_vq_notification = mlx5_get_vq_notification,
 	.get_vq_irq = mlx5_get_vq_irq,
 	.get_vq_align = mlx5_vdpa_get_vq_align,
+	.get_vq_group = mlx5_vdpa_get_vq_group,
 	.get_device_features = mlx5_vdpa_get_device_features,
 	.set_driver_features = mlx5_vdpa_set_driver_features,
 	.get_driver_features = mlx5_vdpa_get_driver_features,
@@ -2669,7 +2970,7 @@
 	}
 
 	ndev = vdpa_alloc_device(struct mlx5_vdpa_net, mvdev.vdev, mdev->device, &mlx5_vdpa_ops,
-				 name, false);
+				 1, 1, name, false);
 	if (IS_ERR(ndev))
 		return PTR_ERR(ndev);
 
@@ -2686,18 +2987,18 @@
 	}
 
 	init_mvqs(ndev);
-	mutex_init(&ndev->reslock);
+	init_rwsem(&ndev->reslock);
 	config = &ndev->config;
 
 	if (add_config->mask & BIT_ULL(VDPA_ATTR_DEV_NET_CFG_MTU)) {
 		err = config_func_mtu(mdev, add_config->net.mtu);
 		if (err)
-			goto err_mtu;
+			goto err_alloc;
 	}
 
 	err = query_mtu(mdev, &mtu);
 	if (err)
-		goto err_mtu;
+		goto err_alloc;
 
 	ndev->config.mtu = cpu_to_mlx5vdpa16(mvdev, mtu);
 
@@ -2711,14 +3012,14 @@
 	} else {
 		err = mlx5_query_nic_vport_mac_address(mdev, 0, 0, config->mac);
 		if (err)
-			goto err_mtu;
+			goto err_alloc;
 	}
 
 	if (!is_zero_ether_addr(config->mac)) {
 		pfmdev = pci_get_drvdata(pci_physfn(mdev->pdev));
 		err = mlx5_mpfs_add_mac(pfmdev, config->mac);
 		if (err)
-			goto err_mtu;
+			goto err_alloc;
 
 		ndev->mvdev.mlx_features |= BIT_ULL(VIRTIO_NET_F_MAC);
 	}
@@ -2768,8 +3069,6 @@
 err_mpfs:
 	if (!is_zero_ether_addr(config->mac))
 		mlx5_mpfs_del_mac(pfmdev, config->mac);
-err_mtu:
-	mutex_destroy(&ndev->reslock);
 err_alloc:
 	put_device(&mvdev->vdev.dev);
 	return err;
diff --git a/drivers/vdpa/vdpa.c b/drivers/vdpa/vdpa.c
index 2b75c00..f15fb11 100644
--- a/drivers/vdpa/vdpa.c
+++ b/drivers/vdpa/vdpa.c
@@ -18,14 +18,14 @@
 
 static LIST_HEAD(mdev_head);
 /* A global mutex that protects vdpa management device and device level operations. */
-static DEFINE_MUTEX(vdpa_dev_mutex);
+static DECLARE_RWSEM(vdpa_dev_lock);
 static DEFINE_IDA(vdpa_index_ida);
 
 void vdpa_set_status(struct vdpa_device *vdev, u8 status)
 {
-	mutex_lock(&vdev->cf_mutex);
+	down_write(&vdev->cf_lock);
 	vdev->config->set_status(vdev, status);
-	mutex_unlock(&vdev->cf_mutex);
+	up_write(&vdev->cf_lock);
 }
 EXPORT_SYMBOL(vdpa_set_status);
 
@@ -148,7 +148,6 @@
 		ops->free(vdev);
 
 	ida_simple_remove(&vdpa_index_ida, vdev->index);
-	mutex_destroy(&vdev->cf_mutex);
 	kfree(vdev->driver_override);
 	kfree(vdev);
 }
@@ -159,6 +158,8 @@
  * initialized but before registered.
  * @parent: the parent device
  * @config: the bus operations that is supported by this device
+ * @ngroups: number of groups supported by this device
+ * @nas: number of address spaces supported by this device
  * @size: size of the parent structure that contains private data
  * @name: name of the vdpa device; optional.
  * @use_va: indicate whether virtual address must be used by this device
@@ -171,6 +172,7 @@
  */
 struct vdpa_device *__vdpa_alloc_device(struct device *parent,
 					const struct vdpa_config_ops *config,
+					unsigned int ngroups, unsigned int nas,
 					size_t size, const char *name,
 					bool use_va)
 {
@@ -203,6 +205,8 @@
 	vdev->config = config;
 	vdev->features_valid = false;
 	vdev->use_va = use_va;
+	vdev->ngroups = ngroups;
+	vdev->nas = nas;
 
 	if (name)
 		err = dev_set_name(&vdev->dev, "%s", name);
@@ -211,7 +215,7 @@
 	if (err)
 		goto err_name;
 
-	mutex_init(&vdev->cf_mutex);
+	init_rwsem(&vdev->cf_lock);
 	device_initialize(&vdev->dev);
 
 	return vdev;
@@ -238,7 +242,7 @@
 
 	vdev->nvqs = nvqs;
 
-	lockdep_assert_held(&vdpa_dev_mutex);
+	lockdep_assert_held(&vdpa_dev_lock);
 	dev = bus_find_device(&vdpa_bus, NULL, dev_name(&vdev->dev), vdpa_name_match);
 	if (dev) {
 		put_device(dev);
@@ -278,9 +282,9 @@
 {
 	int err;
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_write(&vdpa_dev_lock);
 	err = __vdpa_register_device(vdev, nvqs);
-	mutex_unlock(&vdpa_dev_mutex);
+	up_write(&vdpa_dev_lock);
 	return err;
 }
 EXPORT_SYMBOL_GPL(vdpa_register_device);
@@ -293,7 +297,7 @@
  */
 void _vdpa_unregister_device(struct vdpa_device *vdev)
 {
-	lockdep_assert_held(&vdpa_dev_mutex);
+	lockdep_assert_held(&vdpa_dev_lock);
 	WARN_ON(!vdev->mdev);
 	device_unregister(&vdev->dev);
 }
@@ -305,9 +309,9 @@
  */
 void vdpa_unregister_device(struct vdpa_device *vdev)
 {
-	mutex_lock(&vdpa_dev_mutex);
+	down_write(&vdpa_dev_lock);
 	device_unregister(&vdev->dev);
-	mutex_unlock(&vdpa_dev_mutex);
+	up_write(&vdpa_dev_lock);
 }
 EXPORT_SYMBOL_GPL(vdpa_unregister_device);
 
@@ -352,9 +356,9 @@
 		return -EINVAL;
 
 	INIT_LIST_HEAD(&mdev->list);
-	mutex_lock(&vdpa_dev_mutex);
+	down_write(&vdpa_dev_lock);
 	list_add_tail(&mdev->list, &mdev_head);
-	mutex_unlock(&vdpa_dev_mutex);
+	up_write(&vdpa_dev_lock);
 	return 0;
 }
 EXPORT_SYMBOL_GPL(vdpa_mgmtdev_register);
@@ -371,14 +375,14 @@
 
 void vdpa_mgmtdev_unregister(struct vdpa_mgmt_dev *mdev)
 {
-	mutex_lock(&vdpa_dev_mutex);
+	down_write(&vdpa_dev_lock);
 
 	list_del(&mdev->list);
 
 	/* Filter out all the entries belong to this management device and delete it. */
 	bus_for_each_dev(&vdpa_bus, NULL, mdev, vdpa_match_remove);
 
-	mutex_unlock(&vdpa_dev_mutex);
+	up_write(&vdpa_dev_lock);
 }
 EXPORT_SYMBOL_GPL(vdpa_mgmtdev_unregister);
 
@@ -407,9 +411,9 @@
 void vdpa_get_config(struct vdpa_device *vdev, unsigned int offset,
 		     void *buf, unsigned int len)
 {
-	mutex_lock(&vdev->cf_mutex);
+	down_read(&vdev->cf_lock);
 	vdpa_get_config_unlocked(vdev, offset, buf, len);
-	mutex_unlock(&vdev->cf_mutex);
+	up_read(&vdev->cf_lock);
 }
 EXPORT_SYMBOL_GPL(vdpa_get_config);
 
@@ -423,9 +427,9 @@
 void vdpa_set_config(struct vdpa_device *vdev, unsigned int offset,
 		     const void *buf, unsigned int length)
 {
-	mutex_lock(&vdev->cf_mutex);
+	down_write(&vdev->cf_lock);
 	vdev->config->set_config(vdev, offset, buf, length);
-	mutex_unlock(&vdev->cf_mutex);
+	up_write(&vdev->cf_lock);
 }
 EXPORT_SYMBOL_GPL(vdpa_set_config);
 
@@ -532,17 +536,17 @@
 	if (!msg)
 		return -ENOMEM;
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_read(&vdpa_dev_lock);
 	mdev = vdpa_mgmtdev_get_from_attr(info->attrs);
 	if (IS_ERR(mdev)) {
-		mutex_unlock(&vdpa_dev_mutex);
+		up_read(&vdpa_dev_lock);
 		NL_SET_ERR_MSG_MOD(info->extack, "Fail to find the specified mgmt device");
 		err = PTR_ERR(mdev);
 		goto out;
 	}
 
 	err = vdpa_mgmtdev_fill(mdev, msg, info->snd_portid, info->snd_seq, 0);
-	mutex_unlock(&vdpa_dev_mutex);
+	up_read(&vdpa_dev_lock);
 	if (err)
 		goto out;
 	err = genlmsg_reply(msg, info);
@@ -561,7 +565,7 @@
 	int idx = 0;
 	int err;
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_read(&vdpa_dev_lock);
 	list_for_each_entry(mdev, &mdev_head, list) {
 		if (idx < start) {
 			idx++;
@@ -574,7 +578,7 @@
 		idx++;
 	}
 out:
-	mutex_unlock(&vdpa_dev_mutex);
+	up_read(&vdpa_dev_lock);
 	cb->args[0] = idx;
 	return msg->len;
 }
@@ -627,7 +631,7 @@
 	    !netlink_capable(skb, CAP_NET_ADMIN))
 		return -EPERM;
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_write(&vdpa_dev_lock);
 	mdev = vdpa_mgmtdev_get_from_attr(info->attrs);
 	if (IS_ERR(mdev)) {
 		NL_SET_ERR_MSG_MOD(info->extack, "Fail to find the specified management device");
@@ -643,7 +647,7 @@
 
 	err = mdev->ops->dev_add(mdev, name, &config);
 err:
-	mutex_unlock(&vdpa_dev_mutex);
+	up_write(&vdpa_dev_lock);
 	return err;
 }
 
@@ -659,7 +663,7 @@
 		return -EINVAL;
 	name = nla_data(info->attrs[VDPA_ATTR_DEV_NAME]);
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_write(&vdpa_dev_lock);
 	dev = bus_find_device(&vdpa_bus, NULL, name, vdpa_name_match);
 	if (!dev) {
 		NL_SET_ERR_MSG_MOD(info->extack, "device not found");
@@ -677,7 +681,7 @@
 mdev_err:
 	put_device(dev);
 dev_err:
-	mutex_unlock(&vdpa_dev_mutex);
+	up_write(&vdpa_dev_lock);
 	return err;
 }
 
@@ -743,7 +747,7 @@
 	if (!msg)
 		return -ENOMEM;
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_read(&vdpa_dev_lock);
 	dev = bus_find_device(&vdpa_bus, NULL, devname, vdpa_name_match);
 	if (!dev) {
 		NL_SET_ERR_MSG_MOD(info->extack, "device not found");
@@ -756,14 +760,19 @@
 		goto mdev_err;
 	}
 	err = vdpa_dev_fill(vdev, msg, info->snd_portid, info->snd_seq, 0, info->extack);
-	if (!err)
-		err = genlmsg_reply(msg, info);
+	if (err)
+		goto mdev_err;
+
+	err = genlmsg_reply(msg, info);
+	put_device(dev);
+	up_read(&vdpa_dev_lock);
+	return err;
+
 mdev_err:
 	put_device(dev);
 err:
-	mutex_unlock(&vdpa_dev_mutex);
-	if (err)
-		nlmsg_free(msg);
+	up_read(&vdpa_dev_lock);
+	nlmsg_free(msg);
 	return err;
 }
 
@@ -804,9 +813,9 @@
 	info.start_idx = cb->args[0];
 	info.idx = 0;
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_read(&vdpa_dev_lock);
 	bus_for_each_dev(&vdpa_bus, NULL, &info, vdpa_dev_dump);
-	mutex_unlock(&vdpa_dev_mutex);
+	up_read(&vdpa_dev_lock);
 	cb->args[0] = info.idx;
 	return msg->len;
 }
@@ -861,7 +870,7 @@
 	u8 status;
 	int err;
 
-	mutex_lock(&vdev->cf_mutex);
+	down_read(&vdev->cf_lock);
 	status = vdev->config->get_status(vdev);
 	if (!(status & VIRTIO_CONFIG_S_FEATURES_OK)) {
 		NL_SET_ERR_MSG_MOD(extack, "Features negotiation not completed");
@@ -898,14 +907,116 @@
 	if (err)
 		goto msg_err;
 
-	mutex_unlock(&vdev->cf_mutex);
+	up_read(&vdev->cf_lock);
 	genlmsg_end(msg, hdr);
 	return 0;
 
 msg_err:
 	genlmsg_cancel(msg, hdr);
 out:
-	mutex_unlock(&vdev->cf_mutex);
+	up_read(&vdev->cf_lock);
+	return err;
+}
+
+static int vdpa_fill_stats_rec(struct vdpa_device *vdev, struct sk_buff *msg,
+			       struct genl_info *info, u32 index)
+{
+	struct virtio_net_config net_cfg = {};
+	u64 drv_features;
+	u16 vqp_max;
+	u8 dev_status;
+
+	/* Stats are only reported once feature negotiation has completed. */
+	dev_status = vdev->config->get_status(vdev);
+	if (!(dev_status & VIRTIO_CONFIG_S_FEATURES_OK)) {
+		NL_SET_ERR_MSG_MOD(info->extack, "feature negotiation not complete");
+		return -EAGAIN;
+	}
+
+	/* The caller holds vdev->cf_lock, hence the _unlocked config read. */
+	vdpa_get_config_unlocked(vdev, 0, &net_cfg, sizeof(net_cfg));
+
+	vqp_max = le16_to_cpu(net_cfg.max_virtqueue_pairs);
+	if (nla_put_u16(msg, VDPA_ATTR_DEV_NET_CFG_MAX_VQP, vqp_max))
+		return -EMSGSIZE;
+
+	drv_features = vdev->config->get_driver_features(vdev);
+	if (nla_put_u64_64bit(msg, VDPA_ATTR_DEV_NEGOTIATED_FEATURES,
+			      drv_features, VDPA_ATTR_PAD))
+		return -EMSGSIZE;
+
+	if (nla_put_u32(msg, VDPA_ATTR_DEV_QUEUE_INDEX, index))
+		return -EMSGSIZE;
+
+	/* Remaining attributes come from the parent's vendor stats hook. */
+	return vdev->config->get_vendor_vq_stats(vdev, index, msg,
+						 info->extack);
+}
+
+static int vendor_stats_fill(struct vdpa_device *vdev, struct sk_buff *msg,
+			     struct genl_info *info, u32 index)
+{
+	int ret = -EOPNOTSUPP;
+
+	/*
+	 * Take cf_lock shared for the whole record; parents without a
+	 * get_vendor_vq_stats() hook report -EOPNOTSUPP.
+	 */
+	down_read(&vdev->cf_lock);
+	if (vdev->config->get_vendor_vq_stats)
+		ret = vdpa_fill_stats_rec(vdev, msg, info, index);
+	up_read(&vdev->cf_lock);
+
+	return ret;
+}
+
+static int vdpa_dev_vendor_stats_fill(struct vdpa_device *vdev,
+				      struct sk_buff *msg,
+				      struct genl_info *info, u32 index)
+{
+	u32 device_id;
+	void *hdr;
+	int err;
+	u32 portid = info->snd_portid;
+	u32 seq = info->snd_seq;
+
+	/* Reply: genl header, DEV_NAME, DEV_ID, then the queue's stats. */
+	hdr = genlmsg_put(msg, portid, seq, &vdpa_nl_family, 0,
+			  VDPA_CMD_DEV_VSTATS_GET);
+	if (!hdr)
+		return -EMSGSIZE;
+
+	if (nla_put_string(msg, VDPA_ATTR_DEV_NAME, dev_name(&vdev->dev))) {
+		err = -EMSGSIZE;
+		goto undo_msg;
+	}
+
+	device_id = vdev->config->get_device_id(vdev);
+	if (nla_put_u32(msg, VDPA_ATTR_DEV_ID, device_id)) {
+		err = -EMSGSIZE;
+		goto undo_msg;
+	}
+
+	switch (device_id) {
+	case VIRTIO_ID_NET:
+		if (index > VIRTIO_NET_CTRL_MQ_VQ_PAIRS_MAX) {
+			NL_SET_ERR_MSG_MOD(info->extack, "queue index exceeds max value");
+			err = -ERANGE;
+			break;
+		}
+		err = vendor_stats_fill(vdev, msg, info, index);
+		break;
+	default:
+		err = -EOPNOTSUPP;
+		break;
+	}
+	/* Never finalize a partially filled reply; cancel it instead. */
+	if (err)
+		goto undo_msg;
+	genlmsg_end(msg, hdr);
+	return 0;
+undo_msg:
+	genlmsg_cancel(msg, hdr);
 	return err;
 }
 
@@ -924,7 +1035,7 @@
 	if (!msg)
 		return -ENOMEM;
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_read(&vdpa_dev_lock);
 	dev = bus_find_device(&vdpa_bus, NULL, devname, vdpa_name_match);
 	if (!dev) {
 		NL_SET_ERR_MSG_MOD(info->extack, "device not found");
@@ -945,7 +1056,7 @@
 mdev_err:
 	put_device(dev);
 dev_err:
-	mutex_unlock(&vdpa_dev_mutex);
+	up_read(&vdpa_dev_lock);
 	if (err)
 		nlmsg_free(msg);
 	return err;
@@ -983,13 +1094,67 @@
 	info.start_idx = cb->args[0];
 	info.idx = 0;
 
-	mutex_lock(&vdpa_dev_mutex);
+	down_read(&vdpa_dev_lock);
 	bus_for_each_dev(&vdpa_bus, NULL, &info, vdpa_dev_config_dump);
-	mutex_unlock(&vdpa_dev_mutex);
+	up_read(&vdpa_dev_lock);
 	cb->args[0] = info.idx;
 	return msg->len;
 }
 
+static int vdpa_nl_cmd_dev_stats_get_doit(struct sk_buff *skb,
+					  struct genl_info *info)
+{
+	struct nlattr **attrs = info->attrs;
+	struct vdpa_device *vdev;
+	struct sk_buff *reply;
+	const char *devname;
+	struct device *dev;
+	u32 qidx;
+	int err;
+
+	/* Both the device name and the queue index are mandatory. */
+	if (!attrs[VDPA_ATTR_DEV_NAME] || !attrs[VDPA_ATTR_DEV_QUEUE_INDEX])
+		return -EINVAL;
+
+	devname = nla_data(attrs[VDPA_ATTR_DEV_NAME]);
+	reply = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
+	if (!reply)
+		return -ENOMEM;
+	qidx = nla_get_u32(attrs[VDPA_ATTR_DEV_QUEUE_INDEX]);
+
+	/* Lookup only, so the device list lock is taken shared. */
+	down_read(&vdpa_dev_lock);
+	dev = bus_find_device(&vdpa_bus, NULL, devname, vdpa_name_match);
+	if (!dev) {
+		NL_SET_ERR_MSG_MOD(info->extack, "device not found");
+		err = -ENODEV;
+		goto dev_err;
+	}
+	vdev = container_of(dev, struct vdpa_device, dev);
+	if (!vdev->mdev) {
+		NL_SET_ERR_MSG_MOD(info->extack, "unmanaged vdpa device");
+		err = -EINVAL;
+		goto mdev_err;
+	}
+
+	err = vdpa_dev_vendor_stats_fill(vdev, reply, info, qidx);
+	if (err)
+		goto mdev_err;
+
+	/* On this path @reply is handed to genlmsg_reply(), not freed. */
+	err = genlmsg_reply(reply, info);
+	put_device(dev);
+	up_read(&vdpa_dev_lock);
+	return err;
+
+mdev_err:
+	put_device(dev);	/* drop the bus_find_device() reference */
+dev_err:
+	nlmsg_free(reply);
+	up_read(&vdpa_dev_lock);
+	return err;
+}
+
 static const struct nla_policy vdpa_nl_policy[VDPA_ATTR_MAX + 1] = {
 	[VDPA_ATTR_MGMTDEV_BUS_NAME] = { .type = NLA_NUL_STRING },
 	[VDPA_ATTR_MGMTDEV_DEV_NAME] = { .type = NLA_STRING },
@@ -1030,6 +1195,12 @@
 		.doit = vdpa_nl_cmd_dev_config_get_doit,
 		.dumpit = vdpa_nl_cmd_dev_config_get_dumpit,
 	},
+	{
+		.cmd = VDPA_CMD_DEV_VSTATS_GET,
+		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
+		.doit = vdpa_nl_cmd_dev_stats_get_doit,
+		.flags = GENL_ADMIN_PERM,
+	},
 };
 
 static struct genl_family vdpa_nl_family __ro_after_init = {
diff --git a/drivers/vdpa/vdpa_sim/vdpa_sim.c b/drivers/vdpa/vdpa_sim/vdpa_sim.c
index ddbe142..0f28658 100644
--- a/drivers/vdpa/vdpa_sim/vdpa_sim.c
+++ b/drivers/vdpa/vdpa_sim/vdpa_sim.c
@@ -96,11 +96,17 @@
 {
 	int i;
 
-	for (i = 0; i < vdpasim->dev_attr.nvqs; i++)
-		vdpasim_vq_reset(vdpasim, &vdpasim->vqs[i]);
-
 	spin_lock(&vdpasim->iommu_lock);
-	vhost_iotlb_reset(vdpasim->iommu);
+
+	for (i = 0; i < vdpasim->dev_attr.nvqs; i++) {
+		vdpasim_vq_reset(vdpasim, &vdpasim->vqs[i]);
+		vringh_set_iotlb(&vdpasim->vqs[i].vring, &vdpasim->iommu[0],
+				 &vdpasim->iommu_lock);
+	}
+
+	for (i = 0; i < vdpasim->dev_attr.nas; i++)
+		vhost_iotlb_reset(&vdpasim->iommu[i]);
+
 	spin_unlock(&vdpasim->iommu_lock);
 
 	vdpasim->features = 0;
@@ -145,7 +151,7 @@
 	dma_addr = iova_dma_addr(&vdpasim->iova, iova);
 
 	spin_lock(&vdpasim->iommu_lock);
-	ret = vhost_iotlb_add_range(vdpasim->iommu, (u64)dma_addr,
+	ret = vhost_iotlb_add_range(&vdpasim->iommu[0], (u64)dma_addr,
 				    (u64)dma_addr + size - 1, (u64)paddr, perm);
 	spin_unlock(&vdpasim->iommu_lock);
 
@@ -161,7 +167,7 @@
 				size_t size)
 {
 	spin_lock(&vdpasim->iommu_lock);
-	vhost_iotlb_del_range(vdpasim->iommu, (u64)dma_addr,
+	vhost_iotlb_del_range(&vdpasim->iommu[0], (u64)dma_addr,
 			      (u64)dma_addr + size - 1);
 	spin_unlock(&vdpasim->iommu_lock);
 
@@ -251,6 +257,7 @@
 		ops = &vdpasim_config_ops;
 
 	vdpasim = vdpa_alloc_device(struct vdpasim, vdpa, NULL, ops,
+				    dev_attr->ngroups, dev_attr->nas,
 				    dev_attr->name, false);
 	if (IS_ERR(vdpasim)) {
 		ret = PTR_ERR(vdpasim);
@@ -278,16 +285,20 @@
 	if (!vdpasim->vqs)
 		goto err_iommu;
 
-	vdpasim->iommu = vhost_iotlb_alloc(max_iotlb_entries, 0);
+	vdpasim->iommu = kmalloc_array(vdpasim->dev_attr.nas,
+				       sizeof(*vdpasim->iommu), GFP_KERNEL);
 	if (!vdpasim->iommu)
 		goto err_iommu;
 
+	for (i = 0; i < vdpasim->dev_attr.nas; i++)
+		vhost_iotlb_init(&vdpasim->iommu[i], 0, 0);
+
 	vdpasim->buffer = kvmalloc(dev_attr->buffer_size, GFP_KERNEL);
 	if (!vdpasim->buffer)
 		goto err_iommu;
 
 	for (i = 0; i < dev_attr->nvqs; i++)
-		vringh_set_iotlb(&vdpasim->vqs[i].vring, vdpasim->iommu,
+		vringh_set_iotlb(&vdpasim->vqs[i].vring, &vdpasim->iommu[0],
 				 &vdpasim->iommu_lock);
 
 	ret = iova_cache_get();
@@ -353,11 +364,14 @@
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
 	struct vdpasim_virtqueue *vq = &vdpasim->vqs[idx];
+	bool old_ready;
 
 	spin_lock(&vdpasim->lock);
+	old_ready = vq->ready;
 	vq->ready = ready;
-	if (vq->ready)
+	if (vq->ready && !old_ready) {
 		vdpasim_queue_ready(vdpasim, idx);
+	}
 	spin_unlock(&vdpasim->lock);
 }
 
@@ -399,6 +413,15 @@
 	return VDPASIM_QUEUE_ALIGN;
 }
 
+static u32 vdpasim_get_vq_group(struct vdpa_device *vdpa, u16 idx)
+{
+	/*
+	 * RX (0) and TX (1) share group 0; the control virtqueue (2) is the
+	 * only member of group 1.
+	 */
+	return idx == 2 ? 1 : 0;
+}
+
 static u64 vdpasim_get_device_features(struct vdpa_device *vdpa)
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
@@ -534,20 +557,53 @@
 	return range;
 }
 
-static int vdpasim_set_map(struct vdpa_device *vdpa,
+static int vdpasim_set_group_asid(struct vdpa_device *vdpa, unsigned int group,
+				  unsigned int asid)
+{
+	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
+	struct vhost_iotlb *iommu;
+	int i;
+
+	/* Groups are 0..ngroups-1: group == ngroups is out of range too. */
+	if (group >= vdpasim->dev_attr.ngroups)
+		return -EINVAL;
+
+	if (asid >= vdpasim->dev_attr.nas)
+		return -EINVAL;
+
+	iommu = &vdpasim->iommu[asid];
+
+	/* Attach every virtqueue of @group to the IOTLB of @asid. */
+	spin_lock(&vdpasim->lock);
+	for (i = 0; i < vdpasim->dev_attr.nvqs; i++)
+		if (vdpasim_get_vq_group(vdpa, i) == group)
+			vringh_set_iotlb(&vdpasim->vqs[i].vring, iommu,
+					 &vdpasim->iommu_lock);
+	spin_unlock(&vdpasim->lock);
+
+	return 0;
+}
+
+static int vdpasim_set_map(struct vdpa_device *vdpa, unsigned int asid,
 			   struct vhost_iotlb *iotlb)
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
 	struct vhost_iotlb_map *map;
+	struct vhost_iotlb *iommu;
 	u64 start = 0ULL, last = 0ULL - 1;
 	int ret;
 
+	if (asid >= vdpasim->dev_attr.nas)
+		return -EINVAL;
+
 	spin_lock(&vdpasim->iommu_lock);
-	vhost_iotlb_reset(vdpasim->iommu);
+
+	iommu = &vdpasim->iommu[asid];
+	vhost_iotlb_reset(iommu);
 
 	for (map = vhost_iotlb_itree_first(iotlb, start, last); map;
 	     map = vhost_iotlb_itree_next(map, start, last)) {
-		ret = vhost_iotlb_add_range(vdpasim->iommu, map->start,
+		ret = vhost_iotlb_add_range(iommu, map->start,
 					    map->last, map->addr, map->perm);
 		if (ret)
 			goto err;
@@ -556,31 +612,39 @@
 	return 0;
 
 err:
-	vhost_iotlb_reset(vdpasim->iommu);
+	vhost_iotlb_reset(iommu);
 	spin_unlock(&vdpasim->iommu_lock);
 	return ret;
 }
 
-static int vdpasim_dma_map(struct vdpa_device *vdpa, u64 iova, u64 size,
+static int vdpasim_dma_map(struct vdpa_device *vdpa, unsigned int asid,
+			   u64 iova, u64 size,
 			   u64 pa, u32 perm, void *opaque)
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
 	int ret;
 
+	if (asid >= vdpasim->dev_attr.nas)	/* iommu[] has nas slots */
+		return -EINVAL;
+
 	spin_lock(&vdpasim->iommu_lock);
-	ret = vhost_iotlb_add_range_ctx(vdpasim->iommu, iova, iova + size - 1,
-					pa, perm, opaque);
+	ret = vhost_iotlb_add_range_ctx(&vdpasim->iommu[asid], iova,
+					iova + size - 1, pa, perm, opaque);
 	spin_unlock(&vdpasim->iommu_lock);
 
 	return ret;
 }
 
-static int vdpasim_dma_unmap(struct vdpa_device *vdpa, u64 iova, u64 size)
+static int vdpasim_dma_unmap(struct vdpa_device *vdpa, unsigned int asid,
+			     u64 iova, u64 size)
 {
 	struct vdpasim *vdpasim = vdpa_to_sim(vdpa);
 
+	if (asid >= vdpasim->dev_attr.nas)
+		return -EINVAL;
+
 	spin_lock(&vdpasim->iommu_lock);
-	vhost_iotlb_del_range(vdpasim->iommu, iova, iova + size - 1);
+	vhost_iotlb_del_range(&vdpasim->iommu[asid], iova, iova + size - 1);
 	spin_unlock(&vdpasim->iommu_lock);
 
 	return 0;
@@ -604,8 +668,7 @@
 	}
 
 	kvfree(vdpasim->buffer);
-	if (vdpasim->iommu)
-		vhost_iotlb_free(vdpasim->iommu);
+	vhost_iotlb_free(vdpasim->iommu);
 	kfree(vdpasim->vqs);
 	kfree(vdpasim->config);
 }
@@ -620,6 +683,7 @@
 	.set_vq_state           = vdpasim_set_vq_state,
 	.get_vq_state           = vdpasim_get_vq_state,
 	.get_vq_align           = vdpasim_get_vq_align,
+	.get_vq_group           = vdpasim_get_vq_group,
 	.get_device_features    = vdpasim_get_device_features,
 	.set_driver_features    = vdpasim_set_driver_features,
 	.get_driver_features    = vdpasim_get_driver_features,
@@ -635,6 +699,7 @@
 	.set_config             = vdpasim_set_config,
 	.get_generation         = vdpasim_get_generation,
 	.get_iova_range         = vdpasim_get_iova_range,
+	.set_group_asid         = vdpasim_set_group_asid,
 	.dma_map                = vdpasim_dma_map,
 	.dma_unmap              = vdpasim_dma_unmap,
 	.free                   = vdpasim_free,
@@ -650,6 +715,7 @@
 	.set_vq_state           = vdpasim_set_vq_state,
 	.get_vq_state           = vdpasim_get_vq_state,
 	.get_vq_align           = vdpasim_get_vq_align,
+	.get_vq_group           = vdpasim_get_vq_group,
 	.get_device_features    = vdpasim_get_device_features,
 	.set_driver_features    = vdpasim_set_driver_features,
 	.get_driver_features    = vdpasim_get_driver_features,
@@ -665,6 +731,7 @@
 	.set_config             = vdpasim_set_config,
 	.get_generation         = vdpasim_get_generation,
 	.get_iova_range         = vdpasim_get_iova_range,
+	.set_group_asid         = vdpasim_set_group_asid,
 	.set_map                = vdpasim_set_map,
 	.free                   = vdpasim_free,
 };
diff --git a/drivers/vdpa/vdpa_sim/vdpa_sim.h b/drivers/vdpa/vdpa_sim/vdpa_sim.h
index cd58e88..622782e 100644
--- a/drivers/vdpa/vdpa_sim/vdpa_sim.h
+++ b/drivers/vdpa/vdpa_sim/vdpa_sim.h
@@ -41,6 +41,8 @@
 	size_t buffer_size;
 	int nvqs;
 	u32 id;
+	u32 ngroups;
+	u32 nas;
 
 	work_func_t work_fn;
 	void (*get_config)(struct vdpasim *vdpasim, void *config);
@@ -63,6 +65,7 @@
 	u32 status;
 	u32 generation;
 	u64 features;
+	u32 groups;
 	/* spinlock to synchronize iommu table */
 	spinlock_t iommu_lock;
 };
diff --git a/drivers/vdpa/vdpa_sim/vdpa_sim_net.c b/drivers/vdpa/vdpa_sim/vdpa_sim_net.c
index d5324f6..5125976 100644
--- a/drivers/vdpa/vdpa_sim/vdpa_sim_net.c
+++ b/drivers/vdpa/vdpa_sim/vdpa_sim_net.c
@@ -26,9 +26,122 @@
 #define DRV_LICENSE  "GPL v2"
 
 #define VDPASIM_NET_FEATURES	(VDPASIM_FEATURES | \
-				 (1ULL << VIRTIO_NET_F_MAC))
+				 (1ULL << VIRTIO_NET_F_MAC) | \
+				 (1ULL << VIRTIO_NET_F_MTU) | \
+				 (1ULL << VIRTIO_NET_F_CTRL_VQ) | \
+				 (1ULL << VIRTIO_NET_F_CTRL_MAC_ADDR))
 
-#define VDPASIM_NET_VQ_NUM	2
+/* 3 virtqueues (RX, TX, CVQ), 2 address spaces, 2 virtqueue groups */
+#define VDPASIM_NET_VQ_NUM	3
+#define VDPASIM_NET_AS_NUM	2
+#define VDPASIM_NET_GROUP_NUM	2	/* group 0: RX + TX, group 1: CVQ */
+
+static void vdpasim_net_complete(struct vdpasim_virtqueue *vq, size_t len)
+{
+	/* Order the buffer writes before publishing the used entry. */
+	smp_wmb();
+	vringh_complete_iotlb(&vq->vring, vq->head, len);
+
+	/* Make the used entry visible before raising the interrupt. */
+	smp_wmb();
+
+	/* Notify the driver, if needed, with bottom halves disabled. */
+	local_bh_disable();
+	if (vringh_need_notify_iotlb(&vq->vring) > 0)
+		vringh_notify(&vq->vring);
+	local_bh_enable();
+}
+
+static bool receive_filter(struct vdpasim *vdpasim, size_t len)
+{
+	bool modern = vdpasim->features & (1ULL << VIRTIO_F_VERSION_1);
+	size_t hdr_len = modern ? sizeof(struct virtio_net_hdr_v1) :
+				  sizeof(struct virtio_net_hdr);
+	struct virtio_net_config *vio_config = vdpasim->config;
+
+	/* Drop frames too short to hold the virtio-net header and a MAC. */
+	if (len < ETH_ALEN + hdr_len)
+		return false;
+
+	/* MAC octets are binary (may be 0x00): compare all ETH_ALEN bytes. */
+	return memcmp(vdpasim->buffer + hdr_len, vio_config->mac,
+		      ETH_ALEN) == 0;
+}
+
+static virtio_net_ctrl_ack vdpasim_handle_ctrl_mac(struct vdpasim *vdpasim,
+						   u8 cmd)
+{
+	struct virtio_net_config *vio_config = vdpasim->config;
+	struct vdpasim_virtqueue *cvq = &vdpasim->vqs[2];
+	u8 mac[ETH_ALEN];
+	size_t read;
+
+	/* Only VIRTIO_NET_CTRL_MAC_ADDR_SET is implemented. */
+	if (cmd != VIRTIO_NET_CTRL_MAC_ADDR_SET)
+		return VIRTIO_NET_ERR;
+
+	/* Stage the address so a short read cannot corrupt config->mac. */
+	read = vringh_iov_pull_iotlb(&cvq->vring, &cvq->in_iov,
+				     mac, ETH_ALEN);
+	if (read != ETH_ALEN)
+		return VIRTIO_NET_ERR;
+
+	memcpy(vio_config->mac, mac, ETH_ALEN);
+	return VIRTIO_NET_OK;
+}
+
+static void vdpasim_handle_cvq(struct vdpasim *vdpasim)
+{
+	struct vdpasim_virtqueue *cvq = &vdpasim->vqs[2];
+	virtio_net_ctrl_ack status;
+	struct virtio_net_ctrl_hdr ctrl;
+	size_t read, write;
+	int err;
+
+	/* Nothing to do unless CTRL_VQ was negotiated and the CVQ is ready. */
+	if (!(vdpasim->features & (1ULL << VIRTIO_NET_F_CTRL_VQ)) ||
+	    !cvq->ready)
+		return;
+
+	while (true) {
+		err = vringh_getdesc_iotlb(&cvq->vring, &cvq->in_iov,
+					   &cvq->out_iov,
+					   &cvq->head, GFP_ATOMIC);
+		if (err <= 0)
+			break;
+
+		read = vringh_iov_pull_iotlb(&cvq->vring, &cvq->in_iov, &ctrl,
+					     sizeof(ctrl));
+		if (read != sizeof(ctrl))
+			break;
+
+		status = VIRTIO_NET_ERR;	/* never reuse a previous ack */
+		switch (ctrl.class) {
+		case VIRTIO_NET_CTRL_MAC:
+			status = vdpasim_handle_ctrl_mac(vdpasim, ctrl.cmd);
+			break;
+		default:
+			break;
+		}
+
+		/* Make sure data is written before advancing index */
+		smp_wmb();
+
+		write = vringh_iov_push_iotlb(&cvq->vring, &cvq->out_iov,
+					      &status, sizeof(status));
+		vringh_complete_iotlb(&cvq->vring, cvq->head, write);
+		vringh_kiov_cleanup(&cvq->in_iov);
+		vringh_kiov_cleanup(&cvq->out_iov);
+
+		/* Make sure used is visible before raising the interrupt. */
+		smp_wmb();
+
+		local_bh_disable();
+		if (cvq->cb)
+			cvq->cb(cvq->private);
+		local_bh_enable();
+	}
+}
 
 static void vdpasim_net_work(struct work_struct *work)
 {
@@ -36,7 +149,6 @@
 	struct vdpasim_virtqueue *txq = &vdpasim->vqs[1];
 	struct vdpasim_virtqueue *rxq = &vdpasim->vqs[0];
 	ssize_t read, write;
-	size_t total_write;
 	int pkts = 0;
 	int err;
 
@@ -45,53 +157,40 @@
 	if (!(vdpasim->status & VIRTIO_CONFIG_S_DRIVER_OK))
 		goto out;
 
+	vdpasim_handle_cvq(vdpasim);
+
 	if (!txq->ready || !rxq->ready)
 		goto out;
 
 	while (true) {
-		total_write = 0;
 		err = vringh_getdesc_iotlb(&txq->vring, &txq->out_iov, NULL,
 					   &txq->head, GFP_ATOMIC);
 		if (err <= 0)
 			break;
 
+		read = vringh_iov_pull_iotlb(&txq->vring, &txq->out_iov,
+					     vdpasim->buffer,
+					     PAGE_SIZE);
+
+		if (!receive_filter(vdpasim, read)) {
+			vdpasim_net_complete(txq, 0);
+			continue;
+		}
+
 		err = vringh_getdesc_iotlb(&rxq->vring, NULL, &rxq->in_iov,
 					   &rxq->head, GFP_ATOMIC);
 		if (err <= 0) {
-			vringh_complete_iotlb(&txq->vring, txq->head, 0);
+			vdpasim_net_complete(txq, 0);
 			break;
 		}
 
-		while (true) {
-			read = vringh_iov_pull_iotlb(&txq->vring, &txq->out_iov,
-						     vdpasim->buffer,
-						     PAGE_SIZE);
-			if (read <= 0)
-				break;
+		write = vringh_iov_push_iotlb(&rxq->vring, &rxq->in_iov,
+					      vdpasim->buffer, read);
+		if (write <= 0)
+			break;
 
-			write = vringh_iov_push_iotlb(&rxq->vring, &rxq->in_iov,
-						      vdpasim->buffer, read);
-			if (write <= 0)
-				break;
-
-			total_write += write;
-		}
-
-		/* Make sure data is wrote before advancing index */
-		smp_wmb();
-
-		vringh_complete_iotlb(&txq->vring, txq->head, 0);
-		vringh_complete_iotlb(&rxq->vring, rxq->head, total_write);
-
-		/* Make sure used is visible before rasing the interrupt. */
-		smp_wmb();
-
-		local_bh_disable();
-		if (vringh_need_notify_iotlb(&txq->vring) > 0)
-			vringh_notify(&txq->vring);
-		if (vringh_need_notify_iotlb(&rxq->vring) > 0)
-			vringh_notify(&rxq->vring);
-		local_bh_enable();
+		vdpasim_net_complete(txq, 0);
+		vdpasim_net_complete(rxq, write);
 
 		if (++pkts > 4) {
 			schedule_work(&vdpasim->work);
@@ -145,6 +244,8 @@
 	dev_attr.id = VIRTIO_ID_NET;
 	dev_attr.supported_features = VDPASIM_NET_FEATURES;
 	dev_attr.nvqs = VDPASIM_NET_VQ_NUM;
+	dev_attr.ngroups = VDPASIM_NET_GROUP_NUM;
+	dev_attr.nas = VDPASIM_NET_AS_NUM;
 	dev_attr.config_size = sizeof(struct virtio_net_config);
 	dev_attr.get_config = vdpasim_net_get_config;
 	dev_attr.work_fn = vdpasim_net_work;
diff --git a/drivers/vdpa/vdpa_user/vduse_dev.c b/drivers/vdpa/vdpa_user/vduse_dev.c
index f85d1a0..d503848 100644
--- a/drivers/vdpa/vdpa_user/vduse_dev.c
+++ b/drivers/vdpa/vdpa_user/vduse_dev.c
@@ -693,6 +693,7 @@
 }
 
 static int vduse_vdpa_set_map(struct vdpa_device *vdpa,
+				unsigned int asid,
 				struct vhost_iotlb *iotlb)
 {
 	struct vduse_dev *dev = vdpa_to_vduse(vdpa);
@@ -1495,7 +1496,7 @@
 		return -EEXIST;
 
 	vdev = vdpa_alloc_device(struct vduse_vdpa, vdpa, dev->dev,
-				 &vduse_vdpa_config_ops, name, true);
+				 &vduse_vdpa_config_ops, 1, 1, name, true);
 	if (IS_ERR(vdev))
 		return PTR_ERR(vdev);
 
diff --git a/drivers/vdpa/virtio_pci/vp_vdpa.c b/drivers/vdpa/virtio_pci/vp_vdpa.c
index cce101e..0452207 100644
--- a/drivers/vdpa/virtio_pci/vp_vdpa.c
+++ b/drivers/vdpa/virtio_pci/vp_vdpa.c
@@ -32,7 +32,7 @@
 
 struct vp_vdpa {
 	struct vdpa_device vdpa;
-	struct virtio_pci_modern_device mdev;
+	struct virtio_pci_modern_device *mdev;
 	struct vp_vring *vring;
 	struct vdpa_callback config_cb;
 	char msix_name[VP_VDPA_NAME_SIZE];
@@ -41,6 +41,12 @@
 	int vectors;
 };
 
+struct vp_vdpa_mgmtdev {
+	struct vdpa_mgmt_dev mgtdev;		/* registered with the vdpa core */
+	struct virtio_pci_modern_device *mdev;	/* allocated/freed in probe/remove */
+	struct vp_vdpa *vp_vdpa;		/* set by dev_add, cleared by dev_del */
+};
+
 static struct vp_vdpa *vdpa_to_vp(struct vdpa_device *vdpa)
 {
 	return container_of(vdpa, struct vp_vdpa, vdpa);
@@ -50,7 +56,12 @@
 {
 	struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
 
-	return &vp_vdpa->mdev;
+	return vp_vdpa->mdev;
+}
+
+static struct virtio_pci_modern_device *vp_vdpa_to_mdev(struct vp_vdpa *vp_vdpa)
+{
+	return vp_vdpa->mdev;	/* assigned in vp_vdpa_dev_add() */
 }
 
 static u64 vp_vdpa_get_device_features(struct vdpa_device *vdpa)
@@ -96,7 +107,7 @@
 
 static void vp_vdpa_free_irq(struct vp_vdpa *vp_vdpa)
 {
-	struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev;
+	struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
 	struct pci_dev *pdev = mdev->pci_dev;
 	int i;
 
@@ -143,7 +154,7 @@
 
 static int vp_vdpa_request_irq(struct vp_vdpa *vp_vdpa)
 {
-	struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev;
+	struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
 	struct pci_dev *pdev = mdev->pci_dev;
 	int i, ret, irq;
 	int queues = vp_vdpa->queues;
@@ -198,7 +209,7 @@
 static void vp_vdpa_set_status(struct vdpa_device *vdpa, u8 status)
 {
 	struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
-	struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev;
+	struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
 	u8 s = vp_vdpa_get_status(vdpa);
 
 	if (status & VIRTIO_CONFIG_S_DRIVER_OK &&
@@ -212,7 +223,7 @@
 static int vp_vdpa_reset(struct vdpa_device *vdpa)
 {
 	struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
-	struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev;
+	struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
 	u8 s = vp_vdpa_get_status(vdpa);
 
 	vp_modern_set_status(mdev, 0);
@@ -372,7 +383,7 @@
 			       void *buf, unsigned int len)
 {
 	struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
-	struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev;
+	struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
 	u8 old, new;
 	u8 *p;
 	int i;
@@ -392,7 +403,7 @@
 			       unsigned int len)
 {
 	struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
-	struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev;
+	struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
 	const u8 *p = buf;
 	int i;
 
@@ -412,7 +423,7 @@
 vp_vdpa_get_vq_notification(struct vdpa_device *vdpa, u16 qid)
 {
 	struct vp_vdpa *vp_vdpa = vdpa_to_vp(vdpa);
-	struct virtio_pci_modern_device *mdev = &vp_vdpa->mdev;
+	struct virtio_pci_modern_device *mdev = vp_vdpa_to_mdev(vp_vdpa);
 	struct vdpa_notification_area notify;
 
 	notify.addr = vp_vdpa->vring[qid].notify_pa;
@@ -454,38 +465,31 @@
 	pci_free_irq_vectors(data);
 }
 
-static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id)
+static int vp_vdpa_dev_add(struct vdpa_mgmt_dev *v_mdev, const char *name,
+			   const struct vdpa_dev_set_config *add_config)
 {
-	struct virtio_pci_modern_device *mdev;
+	struct vp_vdpa_mgmtdev *vp_vdpa_mgtdev =
+		container_of(v_mdev, struct vp_vdpa_mgmtdev, mgtdev);
+
+	struct virtio_pci_modern_device *mdev = vp_vdpa_mgtdev->mdev;
+	struct pci_dev *pdev = mdev->pci_dev;
 	struct device *dev = &pdev->dev;
-	struct vp_vdpa *vp_vdpa;
+	struct vp_vdpa *vp_vdpa = NULL;
 	int ret, i;
 
-	ret = pcim_enable_device(pdev);
-	if (ret)
-		return ret;
-
 	vp_vdpa = vdpa_alloc_device(struct vp_vdpa, vdpa,
-				    dev, &vp_vdpa_ops, NULL, false);
+				    dev, &vp_vdpa_ops, 1, 1, name, false);
+
 	if (IS_ERR(vp_vdpa)) {
 		dev_err(dev, "vp_vdpa: Failed to allocate vDPA structure\n");
 		return PTR_ERR(vp_vdpa);
 	}
 
-	mdev = &vp_vdpa->mdev;
-	mdev->pci_dev = pdev;
-
-	ret = vp_modern_probe(mdev);
-	if (ret) {
-		dev_err(&pdev->dev, "Failed to probe modern PCI device\n");
-		goto err;
-	}
-
-	pci_set_master(pdev);
-	pci_set_drvdata(pdev, vp_vdpa);
+	vp_vdpa_mgtdev->vp_vdpa = vp_vdpa;
 
 	vp_vdpa->vdpa.dma_dev = &pdev->dev;
 	vp_vdpa->queues = vp_modern_get_num_queues(mdev);
+	vp_vdpa->mdev = mdev;
 
 	ret = devm_add_action_or_reset(dev, vp_vdpa_free_irq_vectors, pdev);
 	if (ret) {
@@ -516,7 +520,8 @@
 	}
 	vp_vdpa->config_irq = VIRTIO_MSI_NO_VECTOR;
 
-	ret = vdpa_register_device(&vp_vdpa->vdpa, vp_vdpa->queues);
+	vp_vdpa->vdpa.mdev = &vp_vdpa_mgtdev->mgtdev;
+	ret = _vdpa_register_device(&vp_vdpa->vdpa, vp_vdpa->queues);
 	if (ret) {
 		dev_err(&pdev->dev, "Failed to register to vdpa bus\n");
 		goto err;
@@ -529,12 +534,104 @@
 	return ret;
 }
 
+static void vp_vdpa_dev_del(struct vdpa_mgmt_dev *v_mdev,
+			    struct vdpa_device *dev)
+{
+	struct vp_vdpa_mgmtdev *mgt = container_of(v_mdev,
+						   struct vp_vdpa_mgmtdev,
+						   mgtdev);
+
+	/* Unregister the single vp_vdpa created by dev_add, then forget it. */
+	_vdpa_unregister_device(&mgt->vp_vdpa->vdpa);
+	mgt->vp_vdpa = NULL;
+}
+
+static const struct vdpa_mgmtdev_ops vp_vdpa_mdev_ops = {
+	.dev_add = vp_vdpa_dev_add,	/* create and register the vp_vdpa */
+	.dev_del = vp_vdpa_dev_del,	/* unregister it again */
+};
+
+static int vp_vdpa_probe(struct pci_dev *pdev, const struct pci_device_id *id)
+{
+	struct vp_vdpa_mgmtdev *vp_vdpa_mgtdev = NULL;
+	struct vdpa_mgmt_dev *mgtdev;
+	struct device *dev = &pdev->dev;
+	struct virtio_pci_modern_device *mdev = NULL;
+	struct virtio_device_id *mdev_id = NULL;
+	int err;
+
+	vp_vdpa_mgtdev = kzalloc(sizeof(*vp_vdpa_mgtdev), GFP_KERNEL);
+	if (!vp_vdpa_mgtdev)
+		return -ENOMEM;
+
+	mgtdev = &vp_vdpa_mgtdev->mgtdev;
+	mgtdev->ops = &vp_vdpa_mdev_ops;
+	mgtdev->device = dev;
+
+	mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
+	if (!mdev) {
+		err = -ENOMEM;
+		goto mdev_err;
+	}
+
+	/* id_table is walked until a device == 0 entry: one id + sentinel. */
+	mdev_id = kcalloc(2, sizeof(*mdev_id), GFP_KERNEL);
+	if (!mdev_id) {
+		err = -ENOMEM;
+		goto mdev_id_err;
+	}
+
+	vp_vdpa_mgtdev->mdev = mdev;
+	mdev->pci_dev = pdev;
+
+	err = pcim_enable_device(pdev);
+	if (err)
+		goto probe_err;
+
+	err = vp_modern_probe(mdev);
+	if (err) {
+		dev_err(&pdev->dev, "Failed to probe modern PCI device\n");
+		goto probe_err;
+	}
+
+	mdev_id->device = mdev->id.device;
+	mdev_id->vendor = mdev->id.vendor;
+	mgtdev->id_table = mdev_id;
+	mgtdev->max_supported_vqs = vp_modern_get_num_queues(mdev);
+	mgtdev->supported_features = vp_modern_get_features(mdev);
+	pci_set_master(pdev);
+	pci_set_drvdata(pdev, vp_vdpa_mgtdev);
+
+	err = vdpa_mgmtdev_register(mgtdev);
+	if (err) {
+		dev_err(&pdev->dev, "Failed to register vdpa mgmtdev device\n");
+		goto register_err;
+	}
+
+	return 0;
+
+register_err:
+	vp_modern_remove(vp_vdpa_mgtdev->mdev);
+probe_err:
+	kfree(mdev_id);
+mdev_id_err:
+	kfree(mdev);
+mdev_err:
+	kfree(vp_vdpa_mgtdev);
+	return err;
+}
+
 static void vp_vdpa_remove(struct pci_dev *pdev)
 {
-	struct vp_vdpa *vp_vdpa = pci_get_drvdata(pdev);
+	struct vp_vdpa_mgmtdev *vp_vdpa_mgtdev = pci_get_drvdata(pdev);
+	struct virtio_pci_modern_device *mdev = vp_vdpa_mgtdev->mdev;
 
-	vp_modern_remove(&vp_vdpa->mdev);
-	vdpa_unregister_device(&vp_vdpa->vdpa);
+	/* Drop the vdpa device(s) before tearing down the PCI transport. */
+	vdpa_mgmtdev_unregister(&vp_vdpa_mgtdev->mgtdev);
+	vp_modern_remove(mdev);
+	kfree(vp_vdpa_mgtdev->mgtdev.id_table);	/* allocated in probe */
+	kfree(mdev);
+	kfree(vp_vdpa_mgtdev);
 }
 
 static struct pci_driver vp_vdpa_driver = {
diff --git a/drivers/vhost/iotlb.c b/drivers/vhost/iotlb.c
index 5829cf2..ea61330 100644
--- a/drivers/vhost/iotlb.c
+++ b/drivers/vhost/iotlb.c
@@ -126,6 +126,23 @@
 EXPORT_SYMBOL_GPL(vhost_iotlb_del_range);
 
 /**
+ * vhost_iotlb_init - initialize caller-provided vhost IOTLB storage to empty
+ * @iotlb: the IOTLB to initialize (e.g. embedded in a larger struct)
+ * @limit: maximum number of IOTLB entries
+ * @flags: VHOST_IOTLB_FLAG_XXX
+ */
+void vhost_iotlb_init(struct vhost_iotlb *iotlb, unsigned int limit,
+		      unsigned int flags)
+{
+	INIT_LIST_HEAD(&iotlb->list);
+	iotlb->root = RB_ROOT_CACHED;
+	iotlb->nmaps = 0;
+	iotlb->limit = limit;
+	iotlb->flags = flags;
+}
+EXPORT_SYMBOL_GPL(vhost_iotlb_init);
+
+/**
  * vhost_iotlb_alloc - add a new vhost IOTLB
  * @limit: maximum number of IOTLB entries
  * @flags: VHOST_IOTLB_FLAG_XXX
@@ -139,11 +156,7 @@
 	if (!iotlb)
 		return NULL;
 
-	iotlb->root = RB_ROOT_CACHED;
-	iotlb->limit = limit;
-	iotlb->nmaps = 0;
-	iotlb->flags = flags;
-	INIT_LIST_HEAD(&iotlb->list);
+	vhost_iotlb_init(iotlb, limit, flags);
 
 	return iotlb;
 }
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 297b5db..68e4ecd 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -1374,16 +1374,9 @@
 	*rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
 }
 
-static void vhost_net_flush_vq(struct vhost_net *n, int index)
-{
-	vhost_poll_flush(n->poll + index);
-	vhost_poll_flush(&n->vqs[index].vq.poll);
-}
-
 static void vhost_net_flush(struct vhost_net *n)
 {
-	vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
-	vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
+	vhost_dev_flush(&n->dev);
 	if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
 		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 		n->tx_flush = true;
@@ -1572,7 +1565,7 @@
 	}
 
 	if (oldsock) {
-		vhost_net_flush_vq(n, index);
+		vhost_dev_flush(&n->dev);
 		sockfd_put(oldsock);
 	}
 
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index 532e204..ffd9e6c 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -1436,7 +1436,7 @@
 		kref_put(&old_inflight[i]->kref, vhost_scsi_done_inflight);
 
 	/* Flush both the vhost poll and vhost work */
-	vhost_work_dev_flush(&vs->dev);
+	vhost_dev_flush(&vs->dev);
 
 	/* Wait for all reqs issued before the flush to be finished */
 	for (i = 0; i < VHOST_SCSI_MAX_VQ; i++)
@@ -1827,8 +1827,6 @@
 	vhost_scsi_clear_endpoint(vs, &t);
 	vhost_dev_stop(&vs->dev);
 	vhost_dev_cleanup(&vs->dev);
-	/* Jobs can re-queue themselves in evt kick handler. Do extra flush. */
-	vhost_scsi_flush(vs);
 	kfree(vs->dev.vqs);
 	kvfree(vs);
 	return 0;
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 05740cb..bc8e7fb 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -144,14 +144,9 @@
 	*privatep = vhost_test_stop_vq(n, n->vqs + VHOST_TEST_VQ);
 }
 
-static void vhost_test_flush_vq(struct vhost_test *n, int index)
-{
-	vhost_poll_flush(&n->vqs[index].poll);
-}
-
 static void vhost_test_flush(struct vhost_test *n)
 {
-	vhost_test_flush_vq(n, VHOST_TEST_VQ);
+	vhost_dev_flush(&n->dev);	/* whole device, not just VHOST_TEST_VQ */
 }
 
 static int vhost_test_release(struct inode *inode, struct file *f)
@@ -163,9 +158,6 @@
 	vhost_test_flush(n);
 	vhost_dev_stop(&n->dev);
 	vhost_dev_cleanup(&n->dev);
-	/* We do an extra flush before freeing memory,
-	 * since jobs can re-queue themselves. */
-	vhost_test_flush(n);
 	kfree(n->dev.vqs);
 	kfree(n);
 	return 0;
@@ -210,7 +202,7 @@
 			goto err;
 
 		if (oldpriv) {
-			vhost_test_flush_vq(n, index);
+			vhost_test_flush(n);
 		}
 	}
 
@@ -303,7 +295,7 @@
 	mutex_unlock(&vq->mutex);
 
 	if (enable) {
-		vhost_test_flush_vq(n, index);
+		vhost_test_flush(n);
 	}
 
 	mutex_unlock(&n->dev.mutex);
diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c
index 4c2f0bd..935a1d0 100644
--- a/drivers/vhost/vdpa.c
+++ b/drivers/vhost/vdpa.c
@@ -28,17 +28,27 @@
 enum {
 	VHOST_VDPA_BACKEND_FEATURES =
 	(1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
-	(1ULL << VHOST_BACKEND_F_IOTLB_BATCH),
+	(1ULL << VHOST_BACKEND_F_IOTLB_BATCH) |
+	(1ULL << VHOST_BACKEND_F_IOTLB_ASID),	/* multiple address spaces */
 };
 
 #define VHOST_VDPA_DEV_MAX (1U << MINORBITS)
 
+#define VHOST_VDPA_IOTLB_BUCKETS 16	/* hash buckets for vhost_vdpa::as[] */
+
+struct vhost_vdpa_as {
+	struct hlist_node hash_link;	/* link in bucket as[id % BUCKETS] */
+	struct vhost_iotlb iotlb;	/* mappings of this address space */
+	u32 id;				/* the ASID */
+};
+
 struct vhost_vdpa {
 	struct vhost_dev vdev;
 	struct iommu_domain *domain;
 	struct vhost_virtqueue *vqs;
 	struct completion completion;
 	struct vdpa_device *vdpa;
+	struct hlist_head as[VHOST_VDPA_IOTLB_BUCKETS];
 	struct device dev;
 	struct cdev cdev;
 	atomic_t opened;
@@ -48,12 +58,89 @@
 	struct eventfd_ctx *config_ctx;
 	int in_batch;
 	struct vdpa_iova_range range;
+	u32 batch_asid;
 };
 
 static DEFINE_IDA(vhost_vdpa_ida);
 
 static dev_t vhost_vdpa_major;
 
+static inline u32 iotlb_to_asid(struct vhost_iotlb *iotlb)
+{
+	struct vhost_vdpa_as *as = container_of(iotlb, struct
+						vhost_vdpa_as, iotlb);
+	return as->id;
+}
+
+static struct vhost_vdpa_as *asid_to_as(struct vhost_vdpa *v, u32 asid)
+{
+	struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
+	struct vhost_vdpa_as *as;
+
+	hlist_for_each_entry(as, head, hash_link)
+		if (as->id == asid)
+			return as;
+
+	return NULL;
+}
+
+static struct vhost_iotlb *asid_to_iotlb(struct vhost_vdpa *v, u32 asid)
+{
+	struct vhost_vdpa_as *as = asid_to_as(v, asid);
+
+	if (!as)
+		return NULL;
+
+	return &as->iotlb;
+}
+
+static struct vhost_vdpa_as *vhost_vdpa_alloc_as(struct vhost_vdpa *v, u32 asid)
+{
+	struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
+	struct vhost_vdpa_as *as;
+
+	if (asid_to_as(v, asid))
+		return NULL;
+
+	if (asid >= v->vdpa->nas)
+		return NULL;
+
+	as = kmalloc(sizeof(*as), GFP_KERNEL);
+	if (!as)
+		return NULL;
+
+	vhost_iotlb_init(&as->iotlb, 0, 0);
+	as->id = asid;
+	hlist_add_head(&as->hash_link, head);
+
+	return as;
+}
+
+static struct vhost_vdpa_as *vhost_vdpa_find_alloc_as(struct vhost_vdpa *v,
+						      u32 asid)
+{
+	struct vhost_vdpa_as *as = asid_to_as(v, asid);
+
+	if (as)
+		return as;
+
+	return vhost_vdpa_alloc_as(v, asid);
+}
+
+static int vhost_vdpa_remove_as(struct vhost_vdpa *v, u32 asid)
+{
+	struct vhost_vdpa_as *as = asid_to_as(v, asid);
+
+	if (!as)
+		return -EINVAL;
+
+	hlist_del(&as->hash_link);
+	vhost_iotlb_reset(&as->iotlb);
+	kfree(as);
+
+	return 0;
+}
+
 static void handle_vq_kick(struct vhost_work *work)
 {
 	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
@@ -411,6 +498,22 @@
 			return -EFAULT;
 		ops->set_vq_ready(vdpa, idx, s.num);
 		return 0;
+	case VHOST_VDPA_GET_VRING_GROUP:
+		s.index = idx;
+		s.num = ops->get_vq_group(vdpa, idx);
+		if (s.num >= vdpa->ngroups)
+			return -EIO;
+		else if (copy_to_user(argp, &s, sizeof(s)))
+			return -EFAULT;
+		return 0;
+	case VHOST_VDPA_SET_GROUP_ASID:
+		if (copy_from_user(&s, argp, sizeof(s)))
+			return -EFAULT;
+		if (s.num >= vdpa->nas)
+			return -EINVAL;
+		if (!ops->set_group_asid)
+			return -EOPNOTSUPP;
+		return ops->set_group_asid(vdpa, idx, s.num);
 	case VHOST_GET_VRING_BASE:
 		r = ops->get_vq_state(v->vdpa, idx, &vq_state);
 		if (r)
@@ -505,6 +608,15 @@
 	case VHOST_VDPA_GET_VRING_NUM:
 		r = vhost_vdpa_get_vring_num(v, argp);
 		break;
+	case VHOST_VDPA_GET_GROUP_NUM:
+		if (copy_to_user(argp, &v->vdpa->ngroups,
+				 sizeof(v->vdpa->ngroups)))
+			r = -EFAULT;
+		break;
+	case VHOST_VDPA_GET_AS_NUM:
+		if (copy_to_user(argp, &v->vdpa->nas, sizeof(v->vdpa->nas)))
+			r = -EFAULT;
+		break;
 	case VHOST_SET_LOG_BASE:
 	case VHOST_SET_LOG_FD:
 		r = -ENOIOCTLCMD;
@@ -537,10 +649,11 @@
 	return r;
 }
 
-static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, u64 start, u64 last)
+static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v,
+				struct vhost_iotlb *iotlb,
+				u64 start, u64 last)
 {
 	struct vhost_dev *dev = &v->vdev;
-	struct vhost_iotlb *iotlb = dev->iotlb;
 	struct vhost_iotlb_map *map;
 	struct page *page;
 	unsigned long pfn, pinned;
@@ -559,10 +672,10 @@
 	}
 }
 
-static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, u64 start, u64 last)
+static void vhost_vdpa_va_unmap(struct vhost_vdpa *v,
+				struct vhost_iotlb *iotlb,
+				u64 start, u64 last)
 {
-	struct vhost_dev *dev = &v->vdev;
-	struct vhost_iotlb *iotlb = dev->iotlb;
 	struct vhost_iotlb_map *map;
 	struct vdpa_map_file *map_file;
 
@@ -574,23 +687,16 @@
 	}
 }
 
-static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v, u64 start, u64 last)
+static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
+				   struct vhost_iotlb *iotlb,
+				   u64 start, u64 last)
 {
 	struct vdpa_device *vdpa = v->vdpa;
 
 	if (vdpa->use_va)
-		return vhost_vdpa_va_unmap(v, start, last);
+		return vhost_vdpa_va_unmap(v, iotlb, start, last);
 
-	return vhost_vdpa_pa_unmap(v, start, last);
-}
-
-static void vhost_vdpa_iotlb_free(struct vhost_vdpa *v)
-{
-	struct vhost_dev *dev = &v->vdev;
-
-	vhost_vdpa_iotlb_unmap(v, 0ULL, 0ULL - 1);
-	kfree(dev->iotlb);
-	dev->iotlb = NULL;
+	return vhost_vdpa_pa_unmap(v, iotlb, start, last);
 }
 
 static int perm_to_iommu_flags(u32 perm)
@@ -615,30 +721,31 @@
 	return flags | IOMMU_CACHE;
 }
 
-static int vhost_vdpa_map(struct vhost_vdpa *v, u64 iova,
-			  u64 size, u64 pa, u32 perm, void *opaque)
+static int vhost_vdpa_map(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
+			  u64 iova, u64 size, u64 pa, u32 perm, void *opaque)
 {
 	struct vhost_dev *dev = &v->vdev;
 	struct vdpa_device *vdpa = v->vdpa;
 	const struct vdpa_config_ops *ops = vdpa->config;
+	u32 asid = iotlb_to_asid(iotlb);
 	int r = 0;
 
-	r = vhost_iotlb_add_range_ctx(dev->iotlb, iova, iova + size - 1,
+	r = vhost_iotlb_add_range_ctx(iotlb, iova, iova + size - 1,
 				      pa, perm, opaque);
 	if (r)
 		return r;
 
 	if (ops->dma_map) {
-		r = ops->dma_map(vdpa, iova, size, pa, perm, opaque);
+		r = ops->dma_map(vdpa, asid, iova, size, pa, perm, opaque);
 	} else if (ops->set_map) {
 		if (!v->in_batch)
-			r = ops->set_map(vdpa, dev->iotlb);
+			r = ops->set_map(vdpa, asid, iotlb);
 	} else {
 		r = iommu_map(v->domain, iova, pa, size,
 			      perm_to_iommu_flags(perm));
 	}
 	if (r) {
-		vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
+		vhost_iotlb_del_range(iotlb, iova, iova + size - 1);
 		return r;
 	}
 
@@ -648,25 +755,34 @@
 	return 0;
 }
 
-static void vhost_vdpa_unmap(struct vhost_vdpa *v, u64 iova, u64 size)
+static void vhost_vdpa_unmap(struct vhost_vdpa *v,
+			     struct vhost_iotlb *iotlb,
+			     u64 iova, u64 size)
 {
-	struct vhost_dev *dev = &v->vdev;
 	struct vdpa_device *vdpa = v->vdpa;
 	const struct vdpa_config_ops *ops = vdpa->config;
+	u32 asid = iotlb_to_asid(iotlb);
 
-	vhost_vdpa_iotlb_unmap(v, iova, iova + size - 1);
+	vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1);
 
 	if (ops->dma_map) {
-		ops->dma_unmap(vdpa, iova, size);
+		ops->dma_unmap(vdpa, asid, iova, size);
 	} else if (ops->set_map) {
 		if (!v->in_batch)
-			ops->set_map(vdpa, dev->iotlb);
+			ops->set_map(vdpa, asid, iotlb);
 	} else {
 		iommu_unmap(v->domain, iova, size);
 	}
+
+	/* If we are in the middle of batch processing, delay the free
+	 * of AS until BATCH_END.
+	 */
+	if (!v->in_batch && !iotlb->nmaps)
+		vhost_vdpa_remove_as(v, asid);
 }
 
 static int vhost_vdpa_va_map(struct vhost_vdpa *v,
+			     struct vhost_iotlb *iotlb,
 			     u64 iova, u64 size, u64 uaddr, u32 perm)
 {
 	struct vhost_dev *dev = &v->vdev;
@@ -696,7 +812,7 @@
 		offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
 		map_file->offset = offset;
 		map_file->file = get_file(vma->vm_file);
-		ret = vhost_vdpa_map(v, map_iova, map_size, uaddr,
+		ret = vhost_vdpa_map(v, iotlb, map_iova, map_size, uaddr,
 				     perm, map_file);
 		if (ret) {
 			fput(map_file->file);
@@ -709,7 +825,7 @@
 		map_iova += map_size;
 	}
 	if (ret)
-		vhost_vdpa_unmap(v, iova, map_iova - iova);
+		vhost_vdpa_unmap(v, iotlb, iova, map_iova - iova);
 
 	mmap_read_unlock(dev->mm);
 
@@ -717,6 +833,7 @@
 }
 
 static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
+			     struct vhost_iotlb *iotlb,
 			     u64 iova, u64 size, u64 uaddr, u32 perm)
 {
 	struct vhost_dev *dev = &v->vdev;
@@ -780,7 +897,7 @@
 			if (last_pfn && (this_pfn != last_pfn + 1)) {
 				/* Pin a contiguous chunk of memory */
 				csize = PFN_PHYS(last_pfn - map_pfn + 1);
-				ret = vhost_vdpa_map(v, iova, csize,
+				ret = vhost_vdpa_map(v, iotlb, iova, csize,
 						     PFN_PHYS(map_pfn),
 						     perm, NULL);
 				if (ret) {
@@ -810,7 +927,7 @@
 	}
 
 	/* Pin the rest chunk */
-	ret = vhost_vdpa_map(v, iova, PFN_PHYS(last_pfn - map_pfn + 1),
+	ret = vhost_vdpa_map(v, iotlb, iova, PFN_PHYS(last_pfn - map_pfn + 1),
 			     PFN_PHYS(map_pfn), perm, NULL);
 out:
 	if (ret) {
@@ -830,7 +947,7 @@
 			for (pfn = map_pfn; pfn <= last_pfn; pfn++)
 				unpin_user_page(pfn_to_page(pfn));
 		}
-		vhost_vdpa_unmap(v, start, size);
+		vhost_vdpa_unmap(v, iotlb, start, size);
 	}
 unlock:
 	mmap_read_unlock(dev->mm);
@@ -841,11 +958,10 @@
 }
 
 static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
+					   struct vhost_iotlb *iotlb,
 					   struct vhost_iotlb_msg *msg)
 {
-	struct vhost_dev *dev = &v->vdev;
 	struct vdpa_device *vdpa = v->vdpa;
-	struct vhost_iotlb *iotlb = dev->iotlb;
 
 	if (msg->iova < v->range.first || !msg->size ||
 	    msg->iova > U64_MAX - msg->size + 1 ||
@@ -857,19 +973,21 @@
 		return -EEXIST;
 
 	if (vdpa->use_va)
-		return vhost_vdpa_va_map(v, msg->iova, msg->size,
+		return vhost_vdpa_va_map(v, iotlb, msg->iova, msg->size,
 					 msg->uaddr, msg->perm);
 
-	return vhost_vdpa_pa_map(v, msg->iova, msg->size, msg->uaddr,
+	return vhost_vdpa_pa_map(v, iotlb, msg->iova, msg->size, msg->uaddr,
 				 msg->perm);
 }
 
-static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev,
+static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
 					struct vhost_iotlb_msg *msg)
 {
 	struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
 	struct vdpa_device *vdpa = v->vdpa;
 	const struct vdpa_config_ops *ops = vdpa->config;
+	struct vhost_iotlb *iotlb = NULL;
+	struct vhost_vdpa_as *as = NULL;
 	int r = 0;
 
 	mutex_lock(&dev->mutex);
@@ -878,20 +996,47 @@
 	if (r)
 		goto unlock;
 
+	if (v->in_batch && v->batch_asid != asid) {
+		dev_info(&v->dev, "batch id %u asid %u\n",
+			 v->batch_asid, asid);
+		r = -EINVAL;
+		goto unlock;
+	}
+
+	if (msg->type == VHOST_IOTLB_UPDATE ||
+	    msg->type == VHOST_IOTLB_BATCH_BEGIN) {
+		as = vhost_vdpa_find_alloc_as(v, asid);
+		if (!as)
+			dev_err(&v->dev, "can't find and alloc asid %u\n", asid);
+	} else {
+		as = asid_to_as(v, asid);
+		if (!as)
+			dev_err(&v->dev, "no iotlb for asid %u\n", asid);
+	}
+
+	if (!as) {
+		r = -EINVAL;
+		goto unlock;
+	}
+	iotlb = &as->iotlb;
+
 	switch (msg->type) {
 	case VHOST_IOTLB_UPDATE:
-		r = vhost_vdpa_process_iotlb_update(v, msg);
+		r = vhost_vdpa_process_iotlb_update(v, iotlb, msg);
 		break;
 	case VHOST_IOTLB_INVALIDATE:
-		vhost_vdpa_unmap(v, msg->iova, msg->size);
+		vhost_vdpa_unmap(v, iotlb, msg->iova, msg->size);
 		break;
 	case VHOST_IOTLB_BATCH_BEGIN:
+		v->batch_asid = asid;
 		v->in_batch = true;
 		break;
 	case VHOST_IOTLB_BATCH_END:
 		if (v->in_batch && ops->set_map)
-			ops->set_map(vdpa, dev->iotlb);
+			ops->set_map(vdpa, asid, iotlb);
 		v->in_batch = false;
+		if (!iotlb->nmaps)
+			vhost_vdpa_remove_as(v, asid);
 		break;
 	default:
 		r = -EINVAL;
@@ -977,6 +1122,21 @@
 	}
 }
 
+static void vhost_vdpa_cleanup(struct vhost_vdpa *v)
+{
+	struct vhost_vdpa_as *as;
+	u32 asid;
+
+	for (asid = 0; asid < v->vdpa->nas; asid++) {
+		as = asid_to_as(v, asid);
+		if (as)
+			vhost_vdpa_iotlb_unmap(v, &as->iotlb, 0ULL, 0ULL - 1);
+		vhost_vdpa_remove_as(v, asid);
+	}
+	vhost_dev_cleanup(&v->vdev);
+	kfree(v->vdev.vqs);
+}
+
 static int vhost_vdpa_open(struct inode *inode, struct file *filep)
 {
 	struct vhost_vdpa *v;
@@ -1010,15 +1170,9 @@
 	vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
 		       vhost_vdpa_process_iotlb_msg);
 
-	dev->iotlb = vhost_iotlb_alloc(0, 0);
-	if (!dev->iotlb) {
-		r = -ENOMEM;
-		goto err_init_iotlb;
-	}
-
 	r = vhost_vdpa_alloc_domain(v);
 	if (r)
-		goto err_init_iotlb;
+		goto err_alloc_domain;
 
 	vhost_vdpa_set_iova_range(v);
 
@@ -1026,9 +1180,8 @@
 
 	return 0;
 
-err_init_iotlb:
-	vhost_dev_cleanup(&v->vdev);
-	kfree(vqs);
+err_alloc_domain:
+	vhost_vdpa_cleanup(v);
 err:
 	atomic_dec(&v->opened);
 	return r;
@@ -1052,11 +1205,9 @@
 	vhost_vdpa_clean_irq(v);
 	vhost_vdpa_reset(v);
 	vhost_dev_stop(&v->vdev);
-	vhost_vdpa_iotlb_free(v);
 	vhost_vdpa_free_domain(v);
 	vhost_vdpa_config_put(v);
-	vhost_dev_cleanup(&v->vdev);
-	kfree(v->vdev.vqs);
+	vhost_vdpa_cleanup(v);
 	mutex_unlock(&d->mutex);
 
 	atomic_dec(&v->opened);
@@ -1152,7 +1303,14 @@
 	const struct vdpa_config_ops *ops = vdpa->config;
 	struct vhost_vdpa *v;
 	int minor;
-	int r;
+	int i, r;
+
+	/* We can't support platform IOMMU device with more than 1
+	 * group or as
+	 */
+	if (!ops->set_map && !ops->dma_map &&
+	    (vdpa->ngroups > 1 || vdpa->nas > 1))
+		return -EOPNOTSUPP;
 
 	v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
 	if (!v)
@@ -1196,6 +1354,9 @@
 	init_completion(&v->completion);
 	vdpa_set_drvdata(vdpa, v);
 
+	for (i = 0; i < VHOST_VDPA_IOTLB_BUCKETS; i++)
+		INIT_HLIST_HEAD(&v->as[i]);
+
 	return 0;
 
 err:
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index d02173f..4009782 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -231,7 +231,7 @@
 }
 EXPORT_SYMBOL_GPL(vhost_poll_stop);
 
-void vhost_work_dev_flush(struct vhost_dev *dev)
+void vhost_dev_flush(struct vhost_dev *dev)
 {
 	struct vhost_flush_struct flush;
 
@@ -243,15 +243,7 @@
 		wait_for_completion(&flush.wait_event);
 	}
 }
-EXPORT_SYMBOL_GPL(vhost_work_dev_flush);
-
-/* Flush any work that has been scheduled. When calling this, don't hold any
- * locks that are also used by the callback. */
-void vhost_poll_flush(struct vhost_poll *poll)
-{
-	vhost_work_dev_flush(poll->dev);
-}
-EXPORT_SYMBOL_GPL(vhost_poll_flush);
+EXPORT_SYMBOL_GPL(vhost_dev_flush);
 
 void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
 {
@@ -468,7 +460,7 @@
 		    struct vhost_virtqueue **vqs, int nvqs,
 		    int iov_limit, int weight, int byte_weight,
 		    bool use_worker,
-		    int (*msg_handler)(struct vhost_dev *dev,
+		    int (*msg_handler)(struct vhost_dev *dev, u32 asid,
 				       struct vhost_iotlb_msg *msg))
 {
 	struct vhost_virtqueue *vq;
@@ -538,7 +530,7 @@
 	attach.owner = current;
 	vhost_work_init(&attach.work, vhost_attach_cgroups_work);
 	vhost_work_queue(dev, &attach.work);
-	vhost_work_dev_flush(dev);
+	vhost_dev_flush(dev);
 	return attach.ret;
 }
 
@@ -661,11 +653,11 @@
 	int i;
 
 	for (i = 0; i < dev->nvqs; ++i) {
-		if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) {
+		if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
 			vhost_poll_stop(&dev->vqs[i]->poll);
-			vhost_poll_flush(&dev->vqs[i]->poll);
-		}
 	}
+
+	vhost_dev_flush(dev);
 }
 EXPORT_SYMBOL_GPL(vhost_dev_stop);
 
@@ -1090,11 +1082,14 @@
 	return true;
 }
 
-static int vhost_process_iotlb_msg(struct vhost_dev *dev,
+static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
 				   struct vhost_iotlb_msg *msg)
 {
 	int ret = 0;
 
+	if (asid != 0)
+		return -EINVAL;
+
 	mutex_lock(&dev->mutex);
 	vhost_dev_lock_vqs(dev);
 	switch (msg->type) {
@@ -1141,6 +1136,7 @@
 	struct vhost_iotlb_msg msg;
 	size_t offset;
 	int type, ret;
+	u32 asid = 0;
 
 	ret = copy_from_iter(&type, sizeof(type), from);
 	if (ret != sizeof(type)) {
@@ -1156,7 +1152,16 @@
 		offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
 		break;
 	case VHOST_IOTLB_MSG_V2:
-		offset = sizeof(__u32);
+		if (vhost_backend_has_feature(dev->vqs[0],
+					      VHOST_BACKEND_F_IOTLB_ASID)) {
+			ret = copy_from_iter(&asid, sizeof(asid), from);
+			if (ret != sizeof(asid)) {
+				ret = -EINVAL;
+				goto done;
+			}
+			offset = 0;
+		} else
+			offset = sizeof(__u32);
 		break;
 	default:
 		ret = -EINVAL;
@@ -1178,9 +1183,9 @@
 	}
 
 	if (dev->msg_handler)
-		ret = dev->msg_handler(dev, &msg);
+		ret = dev->msg_handler(dev, asid, &msg);
 	else
-		ret = vhost_process_iotlb_msg(dev, &msg);
+		ret = vhost_process_iotlb_msg(dev, asid, &msg);
 	if (ret) {
 		ret = -EFAULT;
 		goto done;
@@ -1719,7 +1724,7 @@
 	mutex_unlock(&vq->mutex);
 
 	if (pollstop && vq->handle_kick)
-		vhost_poll_flush(&vq->poll);
+		vhost_dev_flush(vq->poll.dev);
 	return r;
 }
 EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 638bb64..d910910 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -44,9 +44,8 @@
 		     __poll_t mask, struct vhost_dev *dev);
 int vhost_poll_start(struct vhost_poll *poll, struct file *file);
 void vhost_poll_stop(struct vhost_poll *poll);
-void vhost_poll_flush(struct vhost_poll *poll);
 void vhost_poll_queue(struct vhost_poll *poll);
-void vhost_work_dev_flush(struct vhost_dev *dev);
+void vhost_dev_flush(struct vhost_dev *dev);
 
 struct vhost_log {
 	u64 addr;
@@ -161,7 +160,7 @@
 	int byte_weight;
 	u64 kcov_handle;
 	bool use_worker;
-	int (*msg_handler)(struct vhost_dev *dev,
+	int (*msg_handler)(struct vhost_dev *dev, u32 asid,
 			   struct vhost_iotlb_msg *msg);
 };
 
@@ -169,7 +168,7 @@
 void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs,
 		    int nvqs, int iov_limit, int weight, int byte_weight,
 		    bool use_worker,
-		    int (*msg_handler)(struct vhost_dev *dev,
+		    int (*msg_handler)(struct vhost_dev *dev, u32 asid,
 				       struct vhost_iotlb_msg *msg));
 long vhost_dev_set_owner(struct vhost_dev *dev);
 bool vhost_dev_has_owner(struct vhost_dev *dev);
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index e6c9d41..3683304 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -705,12 +705,7 @@
 
 static void vhost_vsock_flush(struct vhost_vsock *vsock)
 {
-	int i;
-
-	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
-		if (vsock->vqs[i].handle_kick)
-			vhost_poll_flush(&vsock->vqs[i].poll);
-	vhost_work_dev_flush(&vsock->dev);
+	vhost_dev_flush(&vsock->dev);
 }
 
 static void vhost_vsock_reset_orphans(struct sock *sk)
diff --git a/drivers/virtio/virtio.c b/drivers/virtio/virtio.c
index 22f15f4..ef04a96 100644
--- a/drivers/virtio/virtio.c
+++ b/drivers/virtio/virtio.c
@@ -169,7 +169,7 @@
 /* Do some validation, then set FEATURES_OK */
 static int virtio_features_ok(struct virtio_device *dev)
 {
-	unsigned status;
+	unsigned int status;
 	int ret;
 
 	might_sleep();
@@ -220,6 +220,15 @@
  * */
 void virtio_reset_device(struct virtio_device *dev)
 {
+	/*
+	 * The below virtio_synchronize_cbs() guarantees that any
+	 * interrupt for this line arriving after
+	 * virtio_synchronize_cbs() has completed is guaranteed to see
+	 * vq->broken as true.
+	 */
+	virtio_break_device(dev);
+	virtio_synchronize_cbs(dev);
+
 	dev->config->reset(dev);
 }
 EXPORT_SYMBOL_GPL(virtio_reset_device);
@@ -413,7 +422,7 @@
 	device_initialize(&dev->dev);
 
 	/* Assign a unique device index and hence name. */
-	err = ida_simple_get(&virtio_index_ida, 0, 0, GFP_KERNEL);
+	err = ida_alloc(&virtio_index_ida, GFP_KERNEL);
 	if (err < 0)
 		goto out;
 
@@ -428,16 +437,16 @@
 	dev->config_enabled = false;
 	dev->config_change_pending = false;
 
+	INIT_LIST_HEAD(&dev->vqs);
+	spin_lock_init(&dev->vqs_list_lock);
+
 	/* We always start by resetting the device, in case a previous
 	 * driver messed it up.  This also tests that code path a little. */
-	dev->config->reset(dev);
+	virtio_reset_device(dev);
 
 	/* Acknowledge that we've seen the device. */
 	virtio_add_status(dev, VIRTIO_CONFIG_S_ACKNOWLEDGE);
 
-	INIT_LIST_HEAD(&dev->vqs);
-	spin_lock_init(&dev->vqs_list_lock);
-
 	/*
 	 * device_add() causes the bus infrastructure to look for a matching
 	 * driver.
@@ -451,7 +460,7 @@
 out_of_node_put:
 	of_node_put(dev->dev.of_node);
 out_ida_remove:
-	ida_simple_remove(&virtio_index_ida, dev->index);
+	ida_free(&virtio_index_ida, dev->index);
 out:
 	virtio_add_status(dev, VIRTIO_CONFIG_S_FAILED);
 	return err;
@@ -469,7 +478,7 @@
 	int index = dev->index; /* save for after device release */
 
 	device_unregister(&dev->dev);
-	ida_simple_remove(&virtio_index_ida, index);
+	ida_free(&virtio_index_ida, index);
 }
 EXPORT_SYMBOL_GPL(unregister_virtio_device);
 
@@ -496,7 +505,7 @@
 
 	/* We always start by resetting the device, in case a previous
 	 * driver messed it up. */
-	dev->config->reset(dev);
+	virtio_reset_device(dev);
 
 	/* Acknowledge that we've seen the device. */
 	virtio_add_status(dev, VIRTIO_CONFIG_S_ACKNOWLEDGE);
@@ -526,8 +535,9 @@
 			goto err;
 	}
 
-	/* Finally, tell the device we're all set */
-	virtio_add_status(dev, VIRTIO_CONFIG_S_DRIVER_OK);
+	/* If restore didn't do it, mark device DRIVER_OK ourselves. */
+	if (!(dev->config->get_status(dev) & VIRTIO_CONFIG_S_DRIVER_OK))
+		virtio_device_ready(dev);
 
 	virtio_config_enable(dev);
 
diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c
index f4c34a2..b9737da 100644
--- a/drivers/virtio/virtio_balloon.c
+++ b/drivers/virtio/virtio_balloon.c
@@ -27,7 +27,7 @@
  * multiple balloon pages.  All memory counters in this driver are in balloon
  * page units.
  */
-#define VIRTIO_BALLOON_PAGES_PER_PAGE (unsigned)(PAGE_SIZE >> VIRTIO_BALLOON_PFN_SHIFT)
+#define VIRTIO_BALLOON_PAGES_PER_PAGE (unsigned int)(PAGE_SIZE >> VIRTIO_BALLOON_PFN_SHIFT)
 #define VIRTIO_BALLOON_ARRAY_PFNS_MAX 256
 /* Maximum number of (4k) pages to deflate on OOM notifications. */
 #define VIRTIO_BALLOON_OOM_NR_PAGES 256
@@ -208,10 +208,10 @@
 					  page_to_balloon_pfn(page) + i);
 }
 
-static unsigned fill_balloon(struct virtio_balloon *vb, size_t num)
+static unsigned int fill_balloon(struct virtio_balloon *vb, size_t num)
 {
-	unsigned num_allocated_pages;
-	unsigned num_pfns;
+	unsigned int num_allocated_pages;
+	unsigned int num_pfns;
 	struct page *page;
 	LIST_HEAD(pages);
 
@@ -272,9 +272,9 @@
 	}
 }
 
-static unsigned leak_balloon(struct virtio_balloon *vb, size_t num)
+static unsigned int leak_balloon(struct virtio_balloon *vb, size_t num)
 {
-	unsigned num_freed_pages;
+	unsigned int num_freed_pages;
 	struct page *page;
 	struct balloon_dev_info *vb_dev_info = &vb->vb_dev_info;
 	LIST_HEAD(pages);
diff --git a/drivers/virtio/virtio_mmio.c b/drivers/virtio/virtio_mmio.c
index 56128b9..f9a36bc 100644
--- a/drivers/virtio/virtio_mmio.c
+++ b/drivers/virtio/virtio_mmio.c
@@ -144,8 +144,8 @@
 	return 0;
 }
 
-static void vm_get(struct virtio_device *vdev, unsigned offset,
-		   void *buf, unsigned len)
+static void vm_get(struct virtio_device *vdev, unsigned int offset,
+		   void *buf, unsigned int len)
 {
 	struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev);
 	void __iomem *base = vm_dev->base + VIRTIO_MMIO_CONFIG;
@@ -186,8 +186,8 @@
 	}
 }
 
-static void vm_set(struct virtio_device *vdev, unsigned offset,
-		   const void *buf, unsigned len)
+static void vm_set(struct virtio_device *vdev, unsigned int offset,
+		   const void *buf, unsigned int len)
 {
 	struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev);
 	void __iomem *base = vm_dev->base + VIRTIO_MMIO_CONFIG;
@@ -253,6 +253,11 @@
 	/* We should never be setting status to 0. */
 	BUG_ON(status == 0);
 
+	/*
+	 * Per memory-barriers.txt, wmb() is not needed to guarantee
+	 * that the cache coherent memory writes have completed
+	 * before writing to the MMIO region.
+	 */
 	writel(status, vm_dev->base + VIRTIO_MMIO_STATUS);
 }
 
@@ -345,7 +350,14 @@
 	free_irq(platform_get_irq(vm_dev->pdev, 0), vm_dev);
 }
 
-static struct virtqueue *vm_setup_vq(struct virtio_device *vdev, unsigned index,
+static void vm_synchronize_cbs(struct virtio_device *vdev)
+{
+	struct virtio_mmio_device *vm_dev = to_virtio_mmio_device(vdev);
+
+	synchronize_irq(platform_get_irq(vm_dev->pdev, 0));
+}
+
+static struct virtqueue *vm_setup_vq(struct virtio_device *vdev, unsigned int index,
 				  void (*callback)(struct virtqueue *vq),
 				  const char *name, bool ctx)
 {
@@ -455,7 +467,7 @@
 	return ERR_PTR(err);
 }
 
-static int vm_find_vqs(struct virtio_device *vdev, unsigned nvqs,
+static int vm_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
 		       struct virtqueue *vqs[],
 		       vq_callback_t *callbacks[],
 		       const char * const names[],
@@ -541,6 +553,7 @@
 	.finalize_features = vm_finalize_features,
 	.bus_name	= vm_bus_name,
 	.get_shm_region = vm_get_shm_region,
+	.synchronize_cbs = vm_synchronize_cbs,
 };
 
 
@@ -657,7 +670,7 @@
 	int err;
 	struct resource resources[2] = {};
 	char *str;
-	long long int base, size;
+	long long base, size;
 	unsigned int irq;
 	int processed, consumed = 0;
 	struct platform_device *pdev;
diff --git a/drivers/virtio/virtio_pci_common.c b/drivers/virtio/virtio_pci_common.c
index d724f67..ca51fcc 100644
--- a/drivers/virtio/virtio_pci_common.c
+++ b/drivers/virtio/virtio_pci_common.c
@@ -104,8 +104,8 @@
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
 	const char *name = dev_name(&vp_dev->vdev.dev);
-	unsigned flags = PCI_IRQ_MSIX;
-	unsigned i, v;
+	unsigned int flags = PCI_IRQ_MSIX;
+	unsigned int i, v;
 	int err = -ENOMEM;
 
 	vp_dev->msix_vectors = nvectors;
@@ -171,7 +171,7 @@
 	return err;
 }
 
-static struct virtqueue *vp_setup_vq(struct virtio_device *vdev, unsigned index,
+static struct virtqueue *vp_setup_vq(struct virtio_device *vdev, unsigned int index,
 				     void (*callback)(struct virtqueue *vq),
 				     const char *name,
 				     bool ctx,
@@ -254,8 +254,7 @@
 
 	if (vp_dev->msix_affinity_masks) {
 		for (i = 0; i < vp_dev->msix_vectors; i++)
-			if (vp_dev->msix_affinity_masks[i])
-				free_cpumask_var(vp_dev->msix_affinity_masks[i]);
+			free_cpumask_var(vp_dev->msix_affinity_masks[i]);
 	}
 
 	if (vp_dev->msix_enabled) {
@@ -276,7 +275,7 @@
 	vp_dev->vqs = NULL;
 }
 
-static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned nvqs,
+static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned int nvqs,
 		struct virtqueue *vqs[], vq_callback_t *callbacks[],
 		const char * const names[], bool per_vq_vectors,
 		const bool *ctx,
@@ -350,7 +349,7 @@
 	return err;
 }
 
-static int vp_find_vqs_intx(struct virtio_device *vdev, unsigned nvqs,
+static int vp_find_vqs_intx(struct virtio_device *vdev, unsigned int nvqs,
 		struct virtqueue *vqs[], vq_callback_t *callbacks[],
 		const char * const names[], const bool *ctx)
 {
@@ -389,7 +388,7 @@
 }
 
 /* the config->find_vqs() implementation */
-int vp_find_vqs(struct virtio_device *vdev, unsigned nvqs,
+int vp_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
 		struct virtqueue *vqs[], vq_callback_t *callbacks[],
 		const char * const names[], const bool *ctx,
 		struct irq_affinity *desc)
diff --git a/drivers/virtio/virtio_pci_common.h b/drivers/virtio/virtio_pci_common.h
index eb17a29..23112d84 100644
--- a/drivers/virtio/virtio_pci_common.h
+++ b/drivers/virtio/virtio_pci_common.h
@@ -38,7 +38,7 @@
 	struct list_head node;
 
 	/* MSI-X vector (or none) */
-	unsigned msix_vector;
+	unsigned int msix_vector;
 };
 
 /* Our device structure */
@@ -68,16 +68,16 @@
 	 * and I'm too lazy to allocate each name separately. */
 	char (*msix_names)[256];
 	/* Number of available vectors */
-	unsigned msix_vectors;
+	unsigned int msix_vectors;
 	/* Vectors allocated, excluding per-vq vectors if any */
-	unsigned msix_used_vectors;
+	unsigned int msix_used_vectors;
 
 	/* Whether we have vector per vq */
 	bool per_vq_vectors;
 
 	struct virtqueue *(*setup_vq)(struct virtio_pci_device *vp_dev,
 				      struct virtio_pci_vq_info *info,
-				      unsigned idx,
+				      unsigned int idx,
 				      void (*callback)(struct virtqueue *vq),
 				      const char *name,
 				      bool ctx,
@@ -108,7 +108,7 @@
 /* the config->del_vqs() implementation */
 void vp_del_vqs(struct virtio_device *vdev);
 /* the config->find_vqs() implementation */
-int vp_find_vqs(struct virtio_device *vdev, unsigned nvqs,
+int vp_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
 		struct virtqueue *vqs[], vq_callback_t *callbacks[],
 		const char * const names[], const bool *ctx,
 		struct irq_affinity *desc);
diff --git a/drivers/virtio/virtio_pci_legacy.c b/drivers/virtio/virtio_pci_legacy.c
index 6f4e34c..a5e5721 100644
--- a/drivers/virtio/virtio_pci_legacy.c
+++ b/drivers/virtio/virtio_pci_legacy.c
@@ -45,8 +45,8 @@
 }
 
 /* virtio config->get() implementation */
-static void vp_get(struct virtio_device *vdev, unsigned offset,
-		   void *buf, unsigned len)
+static void vp_get(struct virtio_device *vdev, unsigned int offset,
+		   void *buf, unsigned int len)
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
 	void __iomem *ioaddr = vp_dev->ldev.ioaddr +
@@ -61,8 +61,8 @@
 
 /* the config->set() implementation.  it's symmetric to the config->get()
  * implementation */
-static void vp_set(struct virtio_device *vdev, unsigned offset,
-		   const void *buf, unsigned len)
+static void vp_set(struct virtio_device *vdev, unsigned int offset,
+		   const void *buf, unsigned int len)
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
 	void __iomem *ioaddr = vp_dev->ldev.ioaddr +
@@ -109,7 +109,7 @@
 
 static struct virtqueue *setup_vq(struct virtio_pci_device *vp_dev,
 				  struct virtio_pci_vq_info *info,
-				  unsigned index,
+				  unsigned int index,
 				  void (*callback)(struct virtqueue *vq),
 				  const char *name,
 				  bool ctx,
@@ -192,6 +192,7 @@
 	.reset		= vp_reset,
 	.find_vqs	= vp_find_vqs,
 	.del_vqs	= vp_del_vqs,
+	.synchronize_cbs = vp_synchronize_vectors,
 	.get_features	= vp_get_features,
 	.finalize_features = vp_finalize_features,
 	.bus_name	= vp_bus_name,
diff --git a/drivers/virtio/virtio_pci_modern.c b/drivers/virtio/virtio_pci_modern.c
index a2671a2..623906b 100644
--- a/drivers/virtio/virtio_pci_modern.c
+++ b/drivers/virtio/virtio_pci_modern.c
@@ -60,8 +60,8 @@
 }
 
 /* virtio config->get() implementation */
-static void vp_get(struct virtio_device *vdev, unsigned offset,
-		   void *buf, unsigned len)
+static void vp_get(struct virtio_device *vdev, unsigned int offset,
+		   void *buf, unsigned int len)
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
 	struct virtio_pci_modern_device *mdev = &vp_dev->mdev;
@@ -98,8 +98,8 @@
 
 /* the config->set() implementation.  it's symmetric to the config->get()
  * implementation */
-static void vp_set(struct virtio_device *vdev, unsigned offset,
-		   const void *buf, unsigned len)
+static void vp_set(struct virtio_device *vdev, unsigned int offset,
+		   const void *buf, unsigned int len)
 {
 	struct virtio_pci_device *vp_dev = to_vp_device(vdev);
 	struct virtio_pci_modern_device *mdev = &vp_dev->mdev;
@@ -183,7 +183,7 @@
 
 static struct virtqueue *setup_vq(struct virtio_pci_device *vp_dev,
 				  struct virtio_pci_vq_info *info,
-				  unsigned index,
+				  unsigned int index,
 				  void (*callback)(struct virtqueue *vq),
 				  const char *name,
 				  bool ctx,
@@ -248,7 +248,7 @@
 	return ERR_PTR(err);
 }
 
-static int vp_modern_find_vqs(struct virtio_device *vdev, unsigned nvqs,
+static int vp_modern_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
 			      struct virtqueue *vqs[],
 			      vq_callback_t *callbacks[],
 			      const char * const names[], const bool *ctx,
@@ -394,6 +394,7 @@
 	.reset		= vp_reset,
 	.find_vqs	= vp_modern_find_vqs,
 	.del_vqs	= vp_del_vqs,
+	.synchronize_cbs = vp_synchronize_vectors,
 	.get_features	= vp_get_features,
 	.finalize_features = vp_finalize_features,
 	.bus_name	= vp_bus_name,
@@ -411,6 +412,7 @@
 	.reset		= vp_reset,
 	.find_vqs	= vp_modern_find_vqs,
 	.del_vqs	= vp_del_vqs,
+	.synchronize_cbs = vp_synchronize_vectors,
 	.get_features	= vp_get_features,
 	.finalize_features = vp_finalize_features,
 	.bus_name	= vp_bus_name,
diff --git a/drivers/virtio/virtio_pci_modern_dev.c b/drivers/virtio/virtio_pci_modern_dev.c
index 591738a..a0fa14f 100644
--- a/drivers/virtio/virtio_pci_modern_dev.c
+++ b/drivers/virtio/virtio_pci_modern_dev.c
@@ -347,6 +347,7 @@
 err_map_isr:
 	pci_iounmap(pci_dev, mdev->common);
 err_map_common:
+	pci_release_selected_regions(pci_dev, mdev->modern_bars);
 	return err;
 }
 EXPORT_SYMBOL_GPL(vp_modern_probe);
@@ -466,6 +467,11 @@
 {
 	struct virtio_pci_common_cfg __iomem *cfg = mdev->common;
 
+	/*
+	 * Per memory-barriers.txt, wmb() is not needed to guarantee
+	 * that the cache coherent memory writes have completed
+	 * before writing to the MMIO region.
+	 */
 	vp_iowrite8(status, &cfg->device_status);
 }
 EXPORT_SYMBOL_GPL(vp_modern_set_status);
diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index cfb028c..13a7348 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -205,11 +205,9 @@
 
 #define to_vvq(_vq) container_of(_vq, struct vring_virtqueue, vq)
 
-static inline bool virtqueue_use_indirect(struct virtqueue *_vq,
+static inline bool virtqueue_use_indirect(struct vring_virtqueue *vq,
 					  unsigned int total_sg)
 {
-	struct vring_virtqueue *vq = to_vvq(_vq);
-
 	/*
 	 * If the host supports indirect descriptor tables, and we have multiple
 	 * buffers, then go indirect. FIXME: tune this threshold
@@ -499,7 +497,7 @@
 
 	head = vq->free_head;
 
-	if (virtqueue_use_indirect(_vq, total_sg))
+	if (virtqueue_use_indirect(vq, total_sg))
 		desc = alloc_indirect_split(_vq, total_sg, gfp);
 	else {
 		desc = NULL;
@@ -519,7 +517,7 @@
 		descs_used = total_sg;
 	}
 
-	if (vq->vq.num_free < descs_used) {
+	if (unlikely(vq->vq.num_free < descs_used)) {
 		pr_debug("Can't add buf len %i - avail = %i\n",
 			 descs_used, vq->vq.num_free);
 		/* FIXME: for historical reasons, we force a notify here if
@@ -811,7 +809,7 @@
 	}
 }
 
-static unsigned virtqueue_enable_cb_prepare_split(struct virtqueue *_vq)
+static unsigned int virtqueue_enable_cb_prepare_split(struct virtqueue *_vq)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
 	u16 last_used_idx;
@@ -836,7 +834,7 @@
 	return last_used_idx;
 }
 
-static bool virtqueue_poll_split(struct virtqueue *_vq, unsigned last_used_idx)
+static bool virtqueue_poll_split(struct virtqueue *_vq, unsigned int last_used_idx)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
 
@@ -1178,7 +1176,7 @@
 
 	BUG_ON(total_sg == 0);
 
-	if (virtqueue_use_indirect(_vq, total_sg)) {
+	if (virtqueue_use_indirect(vq, total_sg)) {
 		err = virtqueue_add_indirect_packed(vq, sgs, total_sg, out_sgs,
 						    in_sgs, data, gfp);
 		if (err != -ENOMEM) {
@@ -1488,7 +1486,7 @@
 	}
 }
 
-static unsigned virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
+static unsigned int virtqueue_enable_cb_prepare_packed(struct virtqueue *_vq)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
 
@@ -1690,7 +1688,7 @@
 	vq->we_own_ring = true;
 	vq->notify = notify;
 	vq->weak_barriers = weak_barriers;
-	vq->broken = false;
+	vq->broken = true;
 	vq->last_used_idx = 0;
 	vq->event_triggered = false;
 	vq->num_added = 0;
@@ -2027,7 +2025,7 @@
  * Caller must ensure we don't call this with other virtqueue
  * operations at the same time (except where noted).
  */
-unsigned virtqueue_enable_cb_prepare(struct virtqueue *_vq)
+unsigned int virtqueue_enable_cb_prepare(struct virtqueue *_vq)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
 
@@ -2048,7 +2046,7 @@
  *
  * This does not need to be serialized.
  */
-bool virtqueue_poll(struct virtqueue *_vq, unsigned last_used_idx)
+bool virtqueue_poll(struct virtqueue *_vq, unsigned int last_used_idx)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
 
@@ -2074,7 +2072,7 @@
  */
 bool virtqueue_enable_cb(struct virtqueue *_vq)
 {
-	unsigned last_used_idx = virtqueue_enable_cb_prepare(_vq);
+	unsigned int last_used_idx = virtqueue_enable_cb_prepare(_vq);
 
 	return !virtqueue_poll(_vq, last_used_idx);
 }
@@ -2136,8 +2134,11 @@
 		return IRQ_NONE;
 	}
 
-	if (unlikely(vq->broken))
-		return IRQ_HANDLED;
+	if (unlikely(vq->broken)) {
+		dev_warn_once(&vq->vq.vdev->dev,
+			      "virtio vring IRQ raised before DRIVER_OK\n");
+		return IRQ_NONE;
+	}
 
 	/* Just a hint for performance: so it's ok that this can be racy! */
 	if (vq->event)
@@ -2179,7 +2180,7 @@
 	vq->we_own_ring = false;
 	vq->notify = notify;
 	vq->weak_barriers = weak_barriers;
-	vq->broken = false;
+	vq->broken = true;
 	vq->last_used_idx = 0;
 	vq->event_triggered = false;
 	vq->num_added = 0;
@@ -2397,6 +2398,28 @@
 }
 EXPORT_SYMBOL_GPL(virtio_break_device);
 
+/*
+ * Mark every virtqueue of @dev as not broken so the driver can use
+ * the device. Callers may need to take additional locks to make the
+ * write to vq->broken visible. Only intended for specific cases
+ * (e.g. probing and restoring); this function should only be called
+ * by the core, not directly by the driver.
+ */
+void __virtio_unbreak_device(struct virtio_device *dev)
+{
+	struct virtqueue *_vq;
+
+	spin_lock(&dev->vqs_list_lock);
+	list_for_each_entry(_vq, &dev->vqs, list) {
+		struct vring_virtqueue *vq = to_vvq(_vq);
+
+		/* Pairs with READ_ONCE() in virtqueue_is_broken(). */
+		WRITE_ONCE(vq->broken, false);
+	}
+	spin_unlock(&dev->vqs_list_lock);
+}
+EXPORT_SYMBOL_GPL(__virtio_unbreak_device);
+
 dma_addr_t virtqueue_get_desc_addr(struct virtqueue *_vq)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
diff --git a/drivers/virtio/virtio_vdpa.c b/drivers/virtio/virtio_vdpa.c
index 7650455..c40f7de 100644
--- a/drivers/virtio/virtio_vdpa.c
+++ b/drivers/virtio/virtio_vdpa.c
@@ -53,16 +53,16 @@
 	return to_virtio_vdpa_device(vdev)->vdpa;
 }
 
-static void virtio_vdpa_get(struct virtio_device *vdev, unsigned offset,
-			    void *buf, unsigned len)
+static void virtio_vdpa_get(struct virtio_device *vdev, unsigned int offset,
+			    void *buf, unsigned int len)
 {
 	struct vdpa_device *vdpa = vd_get_vdpa(vdev);
 
 	vdpa_get_config(vdpa, offset, buf, len);
 }
 
-static void virtio_vdpa_set(struct virtio_device *vdev, unsigned offset,
-			    const void *buf, unsigned len)
+static void virtio_vdpa_set(struct virtio_device *vdev, unsigned int offset,
+			    const void *buf, unsigned int len)
 {
 	struct vdpa_device *vdpa = vd_get_vdpa(vdev);
 
@@ -184,7 +184,7 @@
 	}
 
 	/* Setup virtqueue callback */
-	cb.callback = virtio_vdpa_virtqueue_cb;
+	cb.callback = callback ? virtio_vdpa_virtqueue_cb : NULL;
 	cb.private = info;
 	ops->set_vq_cb(vdpa, index, &cb);
 	ops->set_vq_num(vdpa, index, virtqueue_get_vring_size(vq));
@@ -263,7 +263,7 @@
 		virtio_vdpa_del_vq(vq);
 }
 
-static int virtio_vdpa_find_vqs(struct virtio_device *vdev, unsigned nvqs,
+static int virtio_vdpa_find_vqs(struct virtio_device *vdev, unsigned int nvqs,
 				struct virtqueue *vqs[],
 				vq_callback_t *callbacks[],
 				const char * const names[],
diff --git a/include/linux/mlx5/mlx5_ifc.h b/include/linux/mlx5/mlx5_ifc.h
index 2cd7d61..fd7d083 100644
--- a/include/linux/mlx5/mlx5_ifc.h
+++ b/include/linux/mlx5/mlx5_ifc.h
@@ -87,6 +87,7 @@
 enum {
 	MLX5_OBJ_TYPE_GENEVE_TLV_OPT = 0x000b,
 	MLX5_OBJ_TYPE_VIRTIO_NET_Q = 0x000d,
+	MLX5_OBJ_TYPE_VIRTIO_Q_COUNTERS = 0x001c,
 	MLX5_OBJ_TYPE_MATCH_DEFINER = 0x0018,
 	MLX5_OBJ_TYPE_MKEY = 0xff01,
 	MLX5_OBJ_TYPE_QP = 0xff02,
diff --git a/include/linux/mlx5/mlx5_ifc_vdpa.h b/include/linux/mlx5/mlx5_ifc_vdpa.h
index 1a9c9d9..4414ed5 100644
--- a/include/linux/mlx5/mlx5_ifc_vdpa.h
+++ b/include/linux/mlx5/mlx5_ifc_vdpa.h
@@ -165,4 +165,43 @@
 	struct mlx5_ifc_general_obj_out_cmd_hdr_bits general_obj_out_cmd_hdr;
 };
 
+struct mlx5_ifc_virtio_q_counters_bits {
+	u8    modify_field_select[0x40];
+	u8    reserved_at_40[0x40];
+	u8    received_desc[0x40];
+	u8    completed_desc[0x40];
+	u8    error_cqes[0x20];
+	u8    bad_desc_errors[0x20];
+	u8    exceed_max_chain[0x20];
+	u8    invalid_buffer[0x20];
+	u8    reserved_at_180[0x280];
+};
+
+struct mlx5_ifc_create_virtio_q_counters_in_bits {
+	struct mlx5_ifc_general_obj_in_cmd_hdr_bits hdr;
+	struct mlx5_ifc_virtio_q_counters_bits virtio_q_counters;
+};
+
+struct mlx5_ifc_create_virtio_q_counters_out_bits {
+	struct mlx5_ifc_general_obj_out_cmd_hdr_bits hdr;
+	struct mlx5_ifc_virtio_q_counters_bits virtio_q_counters;
+};
+
+struct mlx5_ifc_destroy_virtio_q_counters_in_bits {
+	struct mlx5_ifc_general_obj_in_cmd_hdr_bits hdr;
+};
+
+struct mlx5_ifc_destroy_virtio_q_counters_out_bits {
+	struct mlx5_ifc_general_obj_out_cmd_hdr_bits hdr;
+};
+
+struct mlx5_ifc_query_virtio_q_counters_in_bits {
+	struct mlx5_ifc_general_obj_in_cmd_hdr_bits hdr;
+};
+
+struct mlx5_ifc_query_virtio_q_counters_out_bits {
+	struct mlx5_ifc_general_obj_out_cmd_hdr_bits hdr;
+	struct mlx5_ifc_virtio_q_counters_bits counters;
+};
+
 #endif /* __MLX5_IFC_VDPA_H_ */
diff --git a/include/linux/vdpa.h b/include/linux/vdpa.h
index 8943a20..15af802 100644
--- a/include/linux/vdpa.h
+++ b/include/linux/vdpa.h
@@ -66,9 +66,11 @@
  * @dma_dev: the actual device that is performing DMA
  * @driver_override: driver name to force a match
  * @config: the configuration ops for this device.
- * @cf_mutex: Protects get and set access to configuration layout.
+ * @cf_lock: Protects get and set access to configuration layout.
  * @index: device index
  * @features_valid: were features initialized? for legacy guests
+ * @ngroups: the number of virtqueue groups
+ * @nas: the number of address spaces
  * @use_va: indicate whether virtual address must be used by this device
  * @nvqs: maximum number of supported virtqueues
  * @mdev: management device pointer; caller must setup when registering device as part
@@ -79,12 +81,14 @@
 	struct device *dma_dev;
 	const char *driver_override;
 	const struct vdpa_config_ops *config;
-	struct mutex cf_mutex; /* Protects get/set config */
+	struct rw_semaphore cf_lock; /* Protects get/set config */
 	unsigned int index;
 	bool features_valid;
 	bool use_va;
 	u32 nvqs;
 	struct vdpa_mgmt_dev *mdev;
+	unsigned int ngroups;
+	unsigned int nas;
 };
 
 /**
@@ -172,6 +176,10 @@
  *				for the device
  *				@vdev: vdpa device
  *				Returns virtqueue algin requirement
+ * @get_vq_group:		Get the group id for a specific virtqueue
+ *				@vdev: vdpa device
+ *				@idx: virtqueue index
+ *				Returns u32: group id for this virtqueue
  * @get_device_features:	Get virtio features supported by the device
  *				@vdev: vdpa device
  *				Returns the virtio features support by the
@@ -232,10 +240,17 @@
  *				@vdev: vdpa device
  *				Returns the iova range supported by
  *				the device.
+ * @set_group_asid:		Set address space identifier for a
+ *				virtqueue group
+ *				@vdev: vdpa device
+ *				@group: virtqueue group
+ *				@asid: address space id for this group
+ *				Returns integer: success (0) or error (< 0)
  * @set_map:			Set device memory mapping (optional)
  *				Needed for device that using device
  *				specific DMA translation (on-chip IOMMU)
  *				@vdev: vdpa device
+ *				@asid: address space identifier
  *				@iotlb: vhost memory mapping to be
  *				used by the vDPA
  *				Returns integer: success (0) or error (< 0)
@@ -244,6 +259,7 @@
  *				specific DMA translation (on-chip IOMMU)
  *				and preferring incremental map.
  *				@vdev: vdpa device
+ *				@asid: address space identifier
  *				@iova: iova to be mapped
  *				@size: size of the area
  *				@pa: physical address for the map
@@ -255,6 +271,7 @@
  *				specific DMA translation (on-chip IOMMU)
  *				and preferring incremental unmap.
  *				@vdev: vdpa device
+ *				@asid: address space identifier
  *				@iova: iova to be unmapped
  *				@size: size of the area
  *				Returns integer: success (0) or error (< 0)
@@ -276,6 +293,9 @@
 			    const struct vdpa_vq_state *state);
 	int (*get_vq_state)(struct vdpa_device *vdev, u16 idx,
 			    struct vdpa_vq_state *state);
+	int (*get_vendor_vq_stats)(struct vdpa_device *vdev, u16 idx,
+				   struct sk_buff *msg,
+				   struct netlink_ext_ack *extack);
 	struct vdpa_notification_area
 	(*get_vq_notification)(struct vdpa_device *vdev, u16 idx);
 	/* vq irq is not expected to be changed once DRIVER_OK is set */
@@ -283,6 +303,7 @@
 
 	/* Device ops */
 	u32 (*get_vq_align)(struct vdpa_device *vdev);
+	u32 (*get_vq_group)(struct vdpa_device *vdev, u16 idx);
 	u64 (*get_device_features)(struct vdpa_device *vdev);
 	int (*set_driver_features)(struct vdpa_device *vdev, u64 features);
 	u64 (*get_driver_features)(struct vdpa_device *vdev);
@@ -304,10 +325,14 @@
 	struct vdpa_iova_range (*get_iova_range)(struct vdpa_device *vdev);
 
 	/* DMA ops */
-	int (*set_map)(struct vdpa_device *vdev, struct vhost_iotlb *iotlb);
-	int (*dma_map)(struct vdpa_device *vdev, u64 iova, u64 size,
-		       u64 pa, u32 perm, void *opaque);
-	int (*dma_unmap)(struct vdpa_device *vdev, u64 iova, u64 size);
+	int (*set_map)(struct vdpa_device *vdev, unsigned int asid,
+		       struct vhost_iotlb *iotlb);
+	int (*dma_map)(struct vdpa_device *vdev, unsigned int asid,
+		       u64 iova, u64 size, u64 pa, u32 perm, void *opaque);
+	int (*dma_unmap)(struct vdpa_device *vdev, unsigned int asid,
+			 u64 iova, u64 size);
+	int (*set_group_asid)(struct vdpa_device *vdev, unsigned int group,
+			      unsigned int asid);
 
 	/* Free device resources */
 	void (*free)(struct vdpa_device *vdev);
@@ -315,6 +340,7 @@
 
 struct vdpa_device *__vdpa_alloc_device(struct device *parent,
 					const struct vdpa_config_ops *config,
+					unsigned int ngroups, unsigned int nas,
 					size_t size, const char *name,
 					bool use_va);
 
@@ -325,17 +351,20 @@
  * @member: the name of struct vdpa_device within the @dev_struct
  * @parent: the parent device
  * @config: the bus operations that is supported by this device
+ * @ngroups: the number of virtqueue groups supported by this device
+ * @nas: the number of address spaces
  * @name: name of the vdpa device
  * @use_va: indicate whether virtual address must be used by this device
  *
  * Return allocated data structure or ERR_PTR upon error
  */
-#define vdpa_alloc_device(dev_struct, member, parent, config, name, use_va)   \
-			  container_of(__vdpa_alloc_device( \
-				       parent, config, \
-				       sizeof(dev_struct) + \
+#define vdpa_alloc_device(dev_struct, member, parent, config, ngroups, nas, \
+			  name, use_va) \
+			  container_of((__vdpa_alloc_device( \
+				       parent, config, ngroups, nas, \
+				       (sizeof(dev_struct) + \
 				       BUILD_BUG_ON_ZERO(offsetof( \
-				       dev_struct, member)), name, use_va), \
+				       dev_struct, member))), name, use_va)), \
 				       dev_struct, member)
 
 int vdpa_register_device(struct vdpa_device *vdev, u32 nvqs);
@@ -395,10 +424,10 @@
 	const struct vdpa_config_ops *ops = vdev->config;
 	int ret;
 
-	mutex_lock(&vdev->cf_mutex);
+	down_write(&vdev->cf_lock);
 	vdev->features_valid = false;
 	ret = ops->reset(vdev);
-	mutex_unlock(&vdev->cf_mutex);
+	up_write(&vdev->cf_lock);
 	return ret;
 }
 
@@ -417,9 +446,9 @@
 {
 	int ret;
 
-	mutex_lock(&vdev->cf_mutex);
+	down_write(&vdev->cf_lock);
 	ret = vdpa_set_features_unlocked(vdev, features);
-	mutex_unlock(&vdev->cf_mutex);
+	up_write(&vdev->cf_lock);
 
 	return ret;
 }
@@ -463,7 +492,7 @@
 struct vdpa_mgmt_dev {
 	struct device *device;
 	const struct vdpa_mgmtdev_ops *ops;
-	const struct virtio_device_id *id_table;
+	struct virtio_device_id *id_table;
 	u64 config_attr_mask;
 	struct list_head list;
 	u64 supported_features;
diff --git a/include/linux/vhost_iotlb.h b/include/linux/vhost_iotlb.h
index 2d0e2f5..e79a408 100644
--- a/include/linux/vhost_iotlb.h
+++ b/include/linux/vhost_iotlb.h
@@ -36,6 +36,8 @@
 			  u64 addr, unsigned int perm);
 void vhost_iotlb_del_range(struct vhost_iotlb *iotlb, u64 start, u64 last);
 
+void vhost_iotlb_init(struct vhost_iotlb *iotlb, unsigned int limit,
+		      unsigned int flags);
 struct vhost_iotlb *vhost_iotlb_alloc(unsigned int limit, unsigned int flags);
 void vhost_iotlb_free(struct vhost_iotlb *iotlb);
 void vhost_iotlb_reset(struct vhost_iotlb *iotlb);
diff --git a/include/linux/virtio.h b/include/linux/virtio.h
index 5464f398..d8fdf17 100644
--- a/include/linux/virtio.h
+++ b/include/linux/virtio.h
@@ -131,6 +131,7 @@
 bool is_virtio_device(struct device *dev);
 
 void virtio_break_device(struct virtio_device *dev);
+void __virtio_unbreak_device(struct virtio_device *dev);
 
 void virtio_config_changed(struct virtio_device *dev);
 #ifdef CONFIG_PM_SLEEP
diff --git a/include/linux/virtio_config.h b/include/linux/virtio_config.h
index b341dd6..9a36051 100644
--- a/include/linux/virtio_config.h
+++ b/include/linux/virtio_config.h
@@ -57,6 +57,11 @@
  *		include a NULL entry for vqs unused by driver
  *	Returns 0 on success or error status
  * @del_vqs: free virtqueues found by find_vqs().
+ * @synchronize_cbs: synchronize with the virtqueue callbacks (optional)
+ *      The function guarantees that all memory operations on the
+ *      queue before it are visible to the vring_interrupt() that is
+ *      called after it.
+ *      vdev: the virtio_device
  * @get_features: get the array of feature bits for this device.
  *	vdev: the virtio_device
  *	Returns the first 64 feature bits (all we currently need).
@@ -89,6 +94,7 @@
 			const char * const names[], const bool *ctx,
 			struct irq_affinity *desc);
 	void (*del_vqs)(struct virtio_device *);
+	void (*synchronize_cbs)(struct virtio_device *);
 	u64 (*get_features)(struct virtio_device *vdev);
 	int (*finalize_features)(struct virtio_device *vdev);
 	const char *(*bus_name)(struct virtio_device *vdev);
@@ -218,6 +224,25 @@
 }
 
 /**
+ * virtio_synchronize_cbs - synchronize with virtqueue callbacks
+ * @dev: the device
+ */
+static inline
+void virtio_synchronize_cbs(struct virtio_device *dev)
+{
+	if (dev->config->synchronize_cbs) {
+		dev->config->synchronize_cbs(dev);
+	} else {
+		/*
+		 * A best effort fallback to synchronize with
+		 * interrupts, preemption and softirq disabled
+		 * regions. See comment above synchronize_rcu().
+		 */
+		synchronize_rcu();
+	}
+}
+
+/**
  * virtio_device_ready - enable vq use in probe function
  * @vdev: the device
  *
@@ -230,7 +255,27 @@
 {
 	unsigned status = dev->config->get_status(dev);
 
-	BUG_ON(status & VIRTIO_CONFIG_S_DRIVER_OK);
+	WARN_ON(status & VIRTIO_CONFIG_S_DRIVER_OK);
+
+	/*
+	 * The virtio_synchronize_cbs() makes sure vring_interrupt()
+	 * will see the driver specific setup if it sees vq->broken
+	 * as false (even if the notifications come before DRIVER_OK).
+	 */
+	virtio_synchronize_cbs(dev);
+	__virtio_unbreak_device(dev);
+	/*
+	 * The transport should ensure the visibility of vq->broken
+	 * before setting DRIVER_OK. See the comments for the transport
+	 * specific set_status() method.
+	 *
+	 * A well behaved device will only notify a virtqueue after
+	 * DRIVER_OK. By the time the device sees DRIVER_OK it should
+	 * also "see" the earlier coherent memory write, done by the
+	 * driver, that set vq->broken to false. Any subsequent
+	 * vring_interrupt() on the driver side will therefore see
+	 * vq->broken as false, so no notification is lost.
+	 */
 	dev->config->set_status(dev, status | VIRTIO_CONFIG_S_DRIVER_OK);
 }
 
diff --git a/include/uapi/linux/vdpa.h b/include/uapi/linux/vdpa.h
index 1061d8d..25c55ca 100644
--- a/include/uapi/linux/vdpa.h
+++ b/include/uapi/linux/vdpa.h
@@ -18,6 +18,7 @@
 	VDPA_CMD_DEV_DEL,
 	VDPA_CMD_DEV_GET,		/* can dump */
 	VDPA_CMD_DEV_CONFIG_GET,	/* can dump */
+	VDPA_CMD_DEV_VSTATS_GET,
 };
 
 enum vdpa_attr {
@@ -46,6 +47,11 @@
 	VDPA_ATTR_DEV_NEGOTIATED_FEATURES,	/* u64 */
 	VDPA_ATTR_DEV_MGMTDEV_MAX_VQS,		/* u32 */
 	VDPA_ATTR_DEV_SUPPORTED_FEATURES,	/* u64 */
+
+	VDPA_ATTR_DEV_QUEUE_INDEX,              /* u32 */
+	VDPA_ATTR_DEV_VENDOR_ATTR_NAME,		/* string */
+	VDPA_ATTR_DEV_VENDOR_ATTR_VALUE,        /* u64 */
+
 	/* new attributes must be added above here */
 	VDPA_ATTR_MAX,
 };
diff --git a/include/uapi/linux/vhost.h b/include/uapi/linux/vhost.h
index 5d99e7c..cab645d 100644
--- a/include/uapi/linux/vhost.h
+++ b/include/uapi/linux/vhost.h
@@ -89,11 +89,6 @@
 
 /* Set or get vhost backend capability */
 
-/* Use message type V2 */
-#define VHOST_BACKEND_F_IOTLB_MSG_V2 0x1
-/* IOTLB can accept batching hints */
-#define VHOST_BACKEND_F_IOTLB_BATCH  0x2
-
 #define VHOST_SET_BACKEND_FEATURES _IOW(VHOST_VIRTIO, 0x25, __u64)
 #define VHOST_GET_BACKEND_FEATURES _IOR(VHOST_VIRTIO, 0x26, __u64)
 
@@ -150,11 +145,30 @@
 /* Get the valid iova range */
 #define VHOST_VDPA_GET_IOVA_RANGE	_IOR(VHOST_VIRTIO, 0x78, \
 					     struct vhost_vdpa_iova_range)
-
 /* Get the config size */
 #define VHOST_VDPA_GET_CONFIG_SIZE	_IOR(VHOST_VIRTIO, 0x79, __u32)
 
 /* Get the count of all virtqueues */
 #define VHOST_VDPA_GET_VQS_COUNT	_IOR(VHOST_VIRTIO, 0x80, __u32)
 
+/* Get the number of virtqueue groups. */
+#define VHOST_VDPA_GET_GROUP_NUM	_IOR(VHOST_VIRTIO, 0x81, __u32)
+
+/* Get the number of address spaces. */
+#define VHOST_VDPA_GET_AS_NUM		_IOR(VHOST_VIRTIO, 0x7A, __u32)
+
+/* Get the group for a virtqueue: read index, write group in num.
+ * The virtqueue index is stored in the index field of
+ * vhost_vring_state. The group for this specific virtqueue is
+ * returned via num field of vhost_vring_state.
+ */
+#define VHOST_VDPA_GET_VRING_GROUP	_IOWR(VHOST_VIRTIO, 0x7B,	\
+					      struct vhost_vring_state)
+/* Set the ASID for a virtqueue group. The group index is stored in
+ * the index field of vhost_vring_state, the ASID associated with this
+ * group is stored at num field of vhost_vring_state.
+ */
+#define VHOST_VDPA_SET_GROUP_ASID	_IOW(VHOST_VIRTIO, 0x7C, \
+					     struct vhost_vring_state)
+
 #endif
diff --git a/include/uapi/linux/vhost_types.h b/include/uapi/linux/vhost_types.h
index f7f6a3a..634cee4 100644
--- a/include/uapi/linux/vhost_types.h
+++ b/include/uapi/linux/vhost_types.h
@@ -87,7 +87,7 @@
 
 struct vhost_msg_v2 {
 	__u32 type;
-	__u32 reserved;
+	__u32 asid;
 	union {
 		struct vhost_iotlb_msg iotlb;
 		__u8 padding[64];
@@ -153,4 +153,13 @@
 /* vhost-net should add virtio_net_hdr for RX, and strip for TX packets. */
 #define VHOST_NET_F_VIRTIO_NET_HDR 27
 
+/* Use message type V2 */
+#define VHOST_BACKEND_F_IOTLB_MSG_V2 0x1
+/* IOTLB can accept batching hints */
+#define VHOST_BACKEND_F_IOTLB_BATCH  0x2
+/* IOTLB can accept address space identifier through V2 type of IOTLB
+ * message
+ */
+#define VHOST_BACKEND_F_IOTLB_ASID  0x3
+
 #endif