| // SPDX-License-Identifier: GPL-2.0 | 
 | /* | 
 |  * Amazon Nitro Secure Module driver. | 
 |  * | 
 |  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | 
 |  * | 
 * The Nitro Secure Module implements commands via CBOR over virtio.
 * This driver exposes a raw-message ioctl on /dev/nsm that user
 * space can use to issue these commands, and also registers the
 * module as a hardware random number generator (hwrng) provider.
 |  */ | 
 |  | 
 | #include <linux/file.h> | 
 | #include <linux/fs.h> | 
 | #include <linux/interrupt.h> | 
 | #include <linux/hw_random.h> | 
 | #include <linux/miscdevice.h> | 
 | #include <linux/module.h> | 
 | #include <linux/mutex.h> | 
 | #include <linux/slab.h> | 
 | #include <linux/string.h> | 
 | #include <linux/uaccess.h> | 
 | #include <linux/uio.h> | 
 | #include <linux/virtio_config.h> | 
 | #include <linux/virtio_ids.h> | 
 | #include <linux/virtio.h> | 
 | #include <linux/wait.h> | 
 | #include <uapi/linux/nsm.h> | 
 |  | 
/* Timeout for an NSM virtqueue response, in milliseconds. */
#define NSM_DEFAULT_TIMEOUT_MSECS (120000) /* 2 minutes */
 |  | 
/*
 * Request buffer sent to the device: @len valid bytes in @data, which can
 * hold at most NSM_REQUEST_MAX_SIZE bytes.
 */
struct nsm_data_req {
	u32 len;	/* number of valid bytes in @data */
	u8  data[NSM_REQUEST_MAX_SIZE];
};
 |  | 
/*
 * Response buffer written by the device: @data can hold at most
 * NSM_RESPONSE_MAX_SIZE bytes, @len is the length the device reported.
 */
struct nsm_data_resp {
	u32 len;	/* bytes written by the device, as returned by virtqueue_get_buf() */
	u8  data[NSM_RESPONSE_MAX_SIZE];
};
 |  | 
/* Full NSM request/response message: one request and the buffer for its reply. */
struct nsm_msg {
	struct nsm_data_req req;	/* filled by the caller, read by the device */
	struct nsm_data_resp resp;	/* written by the device */
};
 |  | 
/* Per-device driver state, allocated in nsm_device_probe(). */
struct nsm {
	struct virtio_device *vdev;	/* backing virtio device */
	struct virtqueue     *vq;	/* single queue "nsm.vq.0" used for all requests */
	struct mutex          lock;	/* held by the request paths around @msg and @vq */
	struct completion     cmd_done;	/* completed by nsm_vq_callback() */
	struct miscdevice     misc;	/* the /dev/nsm node */
	struct hwrng          hwrng;	/* the "nsm-hwrng" provider */
	/* NOTE(review): not initialized or used anywhere in this file's visible code. */
	struct work_struct    misc_init;
	/* Single request/response buffer pair; only one request is in flight at a time. */
	struct nsm_msg        msg;
};
 |  | 
/* Virtio devices bound by this driver: the Nitro Secure Module from any vendor. */
static const struct virtio_device_id id_table[] = {
	{ VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID },
	{ 0 },	/* terminating entry */
};
 |  | 
 | static struct nsm *file_to_nsm(struct file *file) | 
 | { | 
 | 	return container_of(file->private_data, struct nsm, misc); | 
 | } | 
 |  | 
 | static struct nsm *hwrng_to_nsm(struct hwrng *rng) | 
 | { | 
 | 	return container_of(rng, struct nsm, hwrng); | 
 | } | 
 |  | 
/*
 * Minimal CBOR (RFC 8949) constants. The top 3 bits of an item's initial
 * byte hold its major type; the low 5 bits hold "additional information"
 * (an inline length, or the size of a length field that follows).
 */
#define CBOR_TYPE_MASK  0xE0
#define CBOR_TYPE_MAP 0xA0	/* major type 5: map */
#define CBOR_TYPE_TEXT 0x60	/* major type 3: text string */
/*
 * NOTE: 0x40 is CBOR major type 2, a *byte string*; CBOR arrays are major
 * type 4 (0x80). Despite the name, cbor_object_is_array() therefore matches
 * byte strings, which is what the GetRandom payload is parsed as.
 */
#define CBOR_TYPE_ARRAY 0x40
#define CBOR_HEADER_SIZE_SHORT 1	/* size of the initial byte */

#define CBOR_SHORT_SIZE_MAX_VALUE 23	/* largest length stored inline */
#define CBOR_LONG_SIZE_U8  24	/* 1-byte big-endian length follows */
#define CBOR_LONG_SIZE_U16 25	/* 2-byte big-endian length follows */
#define CBOR_LONG_SIZE_U32 26	/* 4-byte big-endian length follows */
#define CBOR_LONG_SIZE_U64 27	/* 8-byte big-endian length follows */
 |  | 
 | static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size) | 
 | { | 
 | 	if (cbor_object_size == 0 || cbor_object == NULL) | 
 | 		return false; | 
 |  | 
 | 	return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY; | 
 | } | 
 |  | 
 | static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array) | 
 | { | 
 | 	u8 cbor_short_size; | 
 | 	void *array_len_p; | 
 | 	u64 array_len; | 
 | 	u64 array_offset; | 
 |  | 
 | 	if (!cbor_object_is_array(cbor_object, cbor_object_size)) | 
 | 		return -EFAULT; | 
 |  | 
 | 	cbor_short_size = (cbor_object[0] & 0x1F); | 
 |  | 
 | 	/* Decoding byte array length */ | 
 | 	array_offset = CBOR_HEADER_SIZE_SHORT; | 
 | 	if (cbor_short_size >= CBOR_LONG_SIZE_U8) | 
 | 		array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8); | 
 |  | 
 | 	if (cbor_object_size < array_offset) | 
 | 		return -EFAULT; | 
 |  | 
 | 	array_len_p = &cbor_object[1]; | 
 |  | 
 | 	switch (cbor_short_size) { | 
 | 	case CBOR_SHORT_SIZE_MAX_VALUE: /* short encoding */ | 
 | 		array_len = cbor_short_size; | 
 | 		break; | 
 | 	case CBOR_LONG_SIZE_U8: | 
 | 		array_len = *(u8 *)array_len_p; | 
 | 		break; | 
 | 	case CBOR_LONG_SIZE_U16: | 
 | 		array_len = be16_to_cpup((__be16 *)array_len_p); | 
 | 		break; | 
 | 	case CBOR_LONG_SIZE_U32: | 
 | 		array_len = be32_to_cpup((__be32 *)array_len_p); | 
 | 		break; | 
 | 	case CBOR_LONG_SIZE_U64: | 
 | 		array_len = be64_to_cpup((__be64 *)array_len_p); | 
 | 		break; | 
 | 	} | 
 |  | 
 | 	if (cbor_object_size < array_offset) | 
 | 		return -EFAULT; | 
 |  | 
 | 	if (cbor_object_size - array_offset < array_len) | 
 | 		return -EFAULT; | 
 |  | 
 | 	if (array_len > INT_MAX) | 
 | 		return -EFAULT; | 
 |  | 
 | 	*cbor_array = cbor_object + array_offset; | 
 | 	return array_len; | 
 | } | 
 |  | 
 | /* Copy the request of a raw message to kernel space */ | 
 | static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req, | 
 | 			struct nsm_raw *raw) | 
 | { | 
 | 	/* Verify the user input size. */ | 
 | 	if (raw->request.len > sizeof(req->data)) | 
 | 		return -EMSGSIZE; | 
 |  | 
 | 	/* Copy the request payload */ | 
 | 	if (copy_from_user(req->data, u64_to_user_ptr(raw->request.addr), | 
 | 			   raw->request.len)) | 
 | 		return -EFAULT; | 
 |  | 
 | 	req->len = raw->request.len; | 
 |  | 
 | 	return 0; | 
 | } | 
 |  | 
 | /* Copy the response of a raw message back to user-space */ | 
 | static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp, | 
 | 			  struct nsm_raw *raw) | 
 | { | 
 | 	/* Truncate any message that does not fit. */ | 
 | 	raw->response.len = min_t(u64, raw->response.len, resp->len); | 
 |  | 
 | 	/* Copy the response content to user space */ | 
 | 	if (copy_to_user(u64_to_user_ptr(raw->response.addr), | 
 | 			 resp->data, raw->response.len)) | 
 | 		return -EFAULT; | 
 |  | 
 | 	return 0; | 
 | } | 
 |  | 
 | /* Virtqueue interrupt handler */ | 
 | static void nsm_vq_callback(struct virtqueue *vq) | 
 | { | 
 | 	struct nsm *nsm = vq->vdev->priv; | 
 |  | 
 | 	complete(&nsm->cmd_done); | 
 | } | 
 |  | 
 | /* Forward a message to the NSM device and wait for the response from it */ | 
 | static int nsm_sendrecv_msg_locked(struct nsm *nsm) | 
 | { | 
 | 	struct device *dev = &nsm->vdev->dev; | 
 | 	struct scatterlist sg_in, sg_out; | 
 | 	struct nsm_msg *msg = &nsm->msg; | 
 | 	struct virtqueue *vq = nsm->vq; | 
 | 	unsigned int len; | 
 | 	void *queue_buf; | 
 | 	bool kicked; | 
 | 	int rc; | 
 |  | 
 | 	/* Initialize scatter-gather lists with request and response buffers. */ | 
 | 	sg_init_one(&sg_out, msg->req.data, msg->req.len); | 
 | 	sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data)); | 
 |  | 
 | 	init_completion(&nsm->cmd_done); | 
 | 	/* Add the request buffer (read by the device). */ | 
 | 	rc = virtqueue_add_outbuf(vq, &sg_out, 1, msg->req.data, GFP_KERNEL); | 
 | 	if (rc) | 
 | 		return rc; | 
 |  | 
 | 	/* Add the response buffer (written by the device). */ | 
 | 	rc = virtqueue_add_inbuf(vq, &sg_in, 1, msg->resp.data, GFP_KERNEL); | 
 | 	if (rc) | 
 | 		goto cleanup; | 
 |  | 
 | 	kicked = virtqueue_kick(vq); | 
 | 	if (!kicked) { | 
 | 		/* Cannot kick the virtqueue. */ | 
 | 		rc = -EIO; | 
 | 		goto cleanup; | 
 | 	} | 
 |  | 
 | 	/* If the kick succeeded, wait for the device's response. */ | 
 | 	if (!wait_for_completion_io_timeout(&nsm->cmd_done, | 
 | 		msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) { | 
 | 		rc = -ETIMEDOUT; | 
 | 		goto cleanup; | 
 | 	} | 
 |  | 
 | 	queue_buf = virtqueue_get_buf(vq, &len); | 
 | 	if (!queue_buf || (queue_buf != msg->req.data)) { | 
 | 		dev_err(dev, "wrong request buffer."); | 
 | 		rc = -ENODATA; | 
 | 		goto cleanup; | 
 | 	} | 
 |  | 
 | 	queue_buf = virtqueue_get_buf(vq, &len); | 
 | 	if (!queue_buf || (queue_buf != msg->resp.data)) { | 
 | 		dev_err(dev, "wrong response buffer."); | 
 | 		rc = -ENODATA; | 
 | 		goto cleanup; | 
 | 	} | 
 |  | 
 | 	msg->resp.len = len; | 
 |  | 
 | 	rc = 0; | 
 |  | 
 | cleanup: | 
 | 	if (rc) { | 
 | 		/* Clean the virtqueue. */ | 
 | 		while (virtqueue_get_buf(vq, &len) != NULL) | 
 | 			; | 
 | 	} | 
 |  | 
 | 	return rc; | 
 | } | 
 |  | 
 | static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req) | 
 | { | 
 | 	/* | 
 | 	 * 69                          # text(9) | 
 | 	 *     47657452616E646F6D      # "GetRandom" | 
 | 	 */ | 
 | 	const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom"), | 
 | 			       'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' }; | 
 |  | 
 | 	memcpy(req->data, request, sizeof(request)); | 
 | 	req->len = sizeof(request); | 
 |  | 
 | 	return 0; | 
 | } | 
 |  | 
 | static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp, | 
 | 				 void *out, size_t max) | 
 | { | 
 | 	/* | 
 | 	 * A1                          # map(1) | 
 | 	 *     69                      # text(9) - Name of field | 
 | 	 *         47657452616E646F6D  # "GetRandom" | 
 | 	 * A1                          # map(1) - The field itself | 
 | 	 *     66                      # text(6) | 
 | 	 *         72616E646F6D        # "random" | 
 | 	 *	# The rest of the response is random data | 
 | 	 */ | 
 | 	const u8 response[] = { CBOR_TYPE_MAP + 1, | 
 | 				CBOR_TYPE_TEXT + strlen("GetRandom"), | 
 | 				'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm', | 
 | 				CBOR_TYPE_MAP + 1, | 
 | 				CBOR_TYPE_TEXT + strlen("random"), | 
 | 				'r', 'a', 'n', 'd', 'o', 'm' }; | 
 | 	struct device *dev = &nsm->vdev->dev; | 
 | 	u8 *rand_data = NULL; | 
 | 	u8 *resp_ptr = resp->data; | 
 | 	u64 resp_len = resp->len; | 
 | 	int rc; | 
 |  | 
 | 	if ((resp->len < sizeof(response) + 1) || | 
 | 	    (memcmp(resp_ptr, response, sizeof(response)) != 0)) { | 
 | 		dev_err(dev, "Invalid response for GetRandom"); | 
 | 		return -EFAULT; | 
 | 	} | 
 |  | 
 | 	resp_ptr += sizeof(response); | 
 | 	resp_len -= sizeof(response); | 
 |  | 
 | 	rc = cbor_object_get_array(resp_ptr, resp_len, &rand_data); | 
 | 	if (rc < 0) { | 
 | 		dev_err(dev, "GetRandom: Invalid CBOR encoding\n"); | 
 | 		return rc; | 
 | 	} | 
 |  | 
 | 	rc = min_t(size_t, rc, max); | 
 | 	memcpy(out, rand_data, rc); | 
 |  | 
 | 	return rc; | 
 | } | 
 |  | 
 | /* | 
 |  * HwRNG implementation | 
 |  */ | 
 | static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait) | 
 | { | 
 | 	struct nsm *nsm = hwrng_to_nsm(rng); | 
 | 	struct device *dev = &nsm->vdev->dev; | 
 | 	int rc = 0; | 
 |  | 
 | 	/* NSM always needs to wait for a response */ | 
 | 	if (!wait) | 
 | 		return 0; | 
 |  | 
 | 	mutex_lock(&nsm->lock); | 
 |  | 
 | 	rc = fill_req_get_random(nsm, &nsm->msg.req); | 
 | 	if (rc != 0) | 
 | 		goto out; | 
 |  | 
 | 	rc = nsm_sendrecv_msg_locked(nsm); | 
 | 	if (rc != 0) | 
 | 		goto out; | 
 |  | 
 | 	rc = parse_resp_get_random(nsm, &nsm->msg.resp, data, max); | 
 | 	if (rc < 0) | 
 | 		goto out; | 
 |  | 
 | 	dev_dbg(dev, "RNG: returning rand bytes = %d", rc); | 
 | out: | 
 | 	mutex_unlock(&nsm->lock); | 
 | 	return rc; | 
 | } | 
 |  | 
 | static long nsm_dev_ioctl(struct file *file, unsigned int cmd, | 
 | 	unsigned long arg) | 
 | { | 
 | 	void __user *argp = u64_to_user_ptr((u64)arg); | 
 | 	struct nsm *nsm = file_to_nsm(file); | 
 | 	struct nsm_raw raw; | 
 | 	int r = 0; | 
 |  | 
 | 	if (cmd != NSM_IOCTL_RAW) | 
 | 		return -EINVAL; | 
 |  | 
 | 	if (_IOC_SIZE(cmd) != sizeof(raw)) | 
 | 		return -EINVAL; | 
 |  | 
 | 	/* Copy user argument struct to kernel argument struct */ | 
 | 	r = -EFAULT; | 
 | 	if (copy_from_user(&raw, argp, _IOC_SIZE(cmd))) | 
 | 		goto out; | 
 |  | 
 | 	mutex_lock(&nsm->lock); | 
 |  | 
 | 	/* Convert kernel argument struct to device request */ | 
 | 	r = fill_req_raw(nsm, &nsm->msg.req, &raw); | 
 | 	if (r) | 
 | 		goto out; | 
 |  | 
 | 	/* Send message to NSM and read reply */ | 
 | 	r = nsm_sendrecv_msg_locked(nsm); | 
 | 	if (r) | 
 | 		goto out; | 
 |  | 
 | 	/* Parse device response into kernel argument struct */ | 
 | 	r = parse_resp_raw(nsm, &nsm->msg.resp, &raw); | 
 | 	if (r) | 
 | 		goto out; | 
 |  | 
 | 	/* Copy kernel argument struct back to user argument struct */ | 
 | 	r = -EFAULT; | 
 | 	if (copy_to_user(argp, &raw, sizeof(raw))) | 
 | 		goto out; | 
 |  | 
 | 	r = 0; | 
 |  | 
 | out: | 
 | 	mutex_unlock(&nsm->lock); | 
 | 	return r; | 
 | } | 
 |  | 
 | static int nsm_device_init_vq(struct virtio_device *vdev) | 
 | { | 
 | 	struct virtqueue *vq = virtio_find_single_vq(vdev, | 
 | 		nsm_vq_callback, "nsm.vq.0"); | 
 | 	struct nsm *nsm = vdev->priv; | 
 |  | 
 | 	if (IS_ERR(vq)) | 
 | 		return PTR_ERR(vq); | 
 |  | 
 | 	nsm->vq = vq; | 
 |  | 
 | 	return 0; | 
 | } | 
 |  | 
 | static const struct file_operations nsm_dev_fops = { | 
 | 	.unlocked_ioctl = nsm_dev_ioctl, | 
 | 	.compat_ioctl = compat_ptr_ioctl, | 
 | }; | 
 |  | 
 | /* Handler for probing the NSM device */ | 
 | static int nsm_device_probe(struct virtio_device *vdev) | 
 | { | 
 | 	struct device *dev = &vdev->dev; | 
 | 	struct nsm *nsm; | 
 | 	int rc; | 
 |  | 
 | 	nsm = devm_kzalloc(&vdev->dev, sizeof(*nsm), GFP_KERNEL); | 
 | 	if (!nsm) | 
 | 		return -ENOMEM; | 
 |  | 
 | 	vdev->priv = nsm; | 
 | 	nsm->vdev = vdev; | 
 |  | 
 | 	rc = nsm_device_init_vq(vdev); | 
 | 	if (rc) { | 
 | 		dev_err(dev, "queue failed to initialize: %d.\n", rc); | 
 | 		goto err_init_vq; | 
 | 	} | 
 |  | 
 | 	mutex_init(&nsm->lock); | 
 |  | 
 | 	/* Register as hwrng provider */ | 
 | 	nsm->hwrng = (struct hwrng) { | 
 | 		.read = nsm_rng_read, | 
 | 		.name = "nsm-hwrng", | 
 | 		.quality = 1000, | 
 | 	}; | 
 |  | 
 | 	rc = hwrng_register(&nsm->hwrng); | 
 | 	if (rc) { | 
 | 		dev_err(dev, "RNG initialization error: %d.\n", rc); | 
 | 		goto err_hwrng; | 
 | 	} | 
 |  | 
 | 	/* Register /dev/nsm device node */ | 
 | 	nsm->misc = (struct miscdevice) { | 
 | 		.minor	= MISC_DYNAMIC_MINOR, | 
 | 		.name	= "nsm", | 
 | 		.fops	= &nsm_dev_fops, | 
 | 		.mode	= 0666, | 
 | 	}; | 
 |  | 
 | 	rc = misc_register(&nsm->misc); | 
 | 	if (rc) { | 
 | 		dev_err(dev, "misc device registration error: %d.\n", rc); | 
 | 		goto err_misc; | 
 | 	} | 
 |  | 
 | 	return 0; | 
 |  | 
 | err_misc: | 
 | 	hwrng_unregister(&nsm->hwrng); | 
 | err_hwrng: | 
 | 	vdev->config->del_vqs(vdev); | 
 | err_init_vq: | 
 | 	return rc; | 
 | } | 
 |  | 
 | /* Handler for removing the NSM device */ | 
 | static void nsm_device_remove(struct virtio_device *vdev) | 
 | { | 
 | 	struct nsm *nsm = vdev->priv; | 
 |  | 
 | 	hwrng_unregister(&nsm->hwrng); | 
 |  | 
 | 	vdev->config->del_vqs(vdev); | 
 | 	misc_deregister(&nsm->misc); | 
 | } | 
 |  | 
 | /* NSM device configuration structure */ | 
 | static struct virtio_driver virtio_nsm_driver = { | 
 | 	.feature_table             = 0, | 
 | 	.feature_table_size        = 0, | 
 | 	.feature_table_legacy      = 0, | 
 | 	.feature_table_size_legacy = 0, | 
 | 	.driver.name               = KBUILD_MODNAME, | 
 | 	.id_table                  = id_table, | 
 | 	.probe                     = nsm_device_probe, | 
 | 	.remove                    = nsm_device_remove, | 
 | }; | 
 |  | 
 | module_virtio_driver(virtio_nsm_driver); | 
 | MODULE_DEVICE_TABLE(virtio, id_table); | 
 | MODULE_DESCRIPTION("Virtio NSM driver"); | 
 | MODULE_LICENSE("GPL"); |