| // SPDX-License-Identifier: GPL-2.0-only | 
 | /* | 
 |  * Cryptographic API. | 
 |  * | 
 |  * Copyright (c) 2017-present, Facebook, Inc. | 
 |  */ | 
 | #include <linux/crypto.h> | 
 | #include <linux/init.h> | 
 | #include <linux/interrupt.h> | 
 | #include <linux/mm.h> | 
 | #include <linux/module.h> | 
 | #include <linux/net.h> | 
 | #include <linux/vmalloc.h> | 
 | #include <linux/zstd.h> | 
 | #include <crypto/internal/acompress.h> | 
 | #include <crypto/scatterwalk.h> | 
 |  | 
 |  | 
 | #define ZSTD_DEF_LEVEL		3 | 
 | #define ZSTD_MAX_WINDOWLOG	18 | 
 | #define ZSTD_MAX_SIZE		BIT(ZSTD_MAX_WINDOWLOG) | 
 |  | 
 | struct zstd_ctx { | 
 | 	zstd_cctx *cctx; | 
 | 	zstd_dctx *dctx; | 
 | 	size_t wksp_size; | 
 | 	zstd_parameters params; | 
 | 	u8 wksp[] __aligned(8); | 
 | }; | 
 |  | 
 | static DEFINE_MUTEX(zstd_stream_lock); | 
 |  | 
 | static void *zstd_alloc_stream(void) | 
 | { | 
 | 	zstd_parameters params; | 
 | 	struct zstd_ctx *ctx; | 
 | 	size_t wksp_size; | 
 |  | 
 | 	params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE); | 
 |  | 
 | 	wksp_size = max_t(size_t, | 
 | 			  zstd_cstream_workspace_bound(¶ms.cParams), | 
 | 			  zstd_dstream_workspace_bound(ZSTD_MAX_SIZE)); | 
 | 	if (!wksp_size) | 
 | 		return ERR_PTR(-EINVAL); | 
 |  | 
 | 	ctx = kvmalloc(sizeof(*ctx) + wksp_size, GFP_KERNEL); | 
 | 	if (!ctx) | 
 | 		return ERR_PTR(-ENOMEM); | 
 |  | 
 | 	ctx->params = params; | 
 | 	ctx->wksp_size = wksp_size; | 
 |  | 
 | 	return ctx; | 
 | } | 
 |  | 
 | static void zstd_free_stream(void *ctx) | 
 | { | 
 | 	kvfree(ctx); | 
 | } | 
 |  | 
 | static struct crypto_acomp_streams zstd_streams = { | 
 | 	.alloc_ctx = zstd_alloc_stream, | 
 | 	.free_ctx = zstd_free_stream, | 
 | }; | 
 |  | 
 | static int zstd_init(struct crypto_acomp *acomp_tfm) | 
 | { | 
 | 	int ret = 0; | 
 |  | 
 | 	mutex_lock(&zstd_stream_lock); | 
 | 	ret = crypto_acomp_alloc_streams(&zstd_streams); | 
 | 	mutex_unlock(&zstd_stream_lock); | 
 |  | 
 | 	return ret; | 
 | } | 
 |  | 
 | static void zstd_exit(struct crypto_acomp *acomp_tfm) | 
 | { | 
 | 	crypto_acomp_free_streams(&zstd_streams); | 
 | } | 
 |  | 
 | static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx, | 
 | 			     const void *src, void *dst, unsigned int *dlen) | 
 | { | 
 | 	unsigned int out_len; | 
 |  | 
 | 	ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size); | 
 | 	if (!ctx->cctx) | 
 | 		return -EINVAL; | 
 |  | 
 | 	out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen, | 
 | 				     &ctx->params); | 
 | 	if (zstd_is_error(out_len)) | 
 | 		return -EINVAL; | 
 |  | 
 | 	*dlen = out_len; | 
 |  | 
 | 	return 0; | 
 | } | 
 |  | 
 | static int zstd_compress(struct acomp_req *req) | 
 | { | 
 | 	struct crypto_acomp_stream *s; | 
 | 	unsigned int pos, scur, dcur; | 
 | 	unsigned int total_out = 0; | 
 | 	bool data_available = true; | 
 | 	zstd_out_buffer outbuf; | 
 | 	struct acomp_walk walk; | 
 | 	zstd_in_buffer inbuf; | 
 | 	struct zstd_ctx *ctx; | 
 | 	size_t pending_bytes; | 
 | 	size_t num_bytes; | 
 | 	int ret; | 
 |  | 
 | 	s = crypto_acomp_lock_stream_bh(&zstd_streams); | 
 | 	ctx = s->ctx; | 
 |  | 
 | 	ret = acomp_walk_virt(&walk, req, true); | 
 | 	if (ret) | 
 | 		goto out; | 
 |  | 
 | 	ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size); | 
 | 	if (!ctx->cctx) { | 
 | 		ret = -EINVAL; | 
 | 		goto out; | 
 | 	} | 
 |  | 
 | 	do { | 
 | 		dcur = acomp_walk_next_dst(&walk); | 
 | 		if (!dcur) { | 
 | 			ret = -ENOSPC; | 
 | 			goto out; | 
 | 		} | 
 |  | 
 | 		outbuf.pos = 0; | 
 | 		outbuf.dst = (u8 *)walk.dst.virt.addr; | 
 | 		outbuf.size = dcur; | 
 |  | 
 | 		do { | 
 | 			scur = acomp_walk_next_src(&walk); | 
 | 			if (dcur == req->dlen && scur == req->slen) { | 
 | 				ret = zstd_compress_one(req, ctx, walk.src.virt.addr, | 
 | 							walk.dst.virt.addr, &total_out); | 
 | 				acomp_walk_done_src(&walk, scur); | 
 | 				acomp_walk_done_dst(&walk, dcur); | 
 | 				goto out; | 
 | 			} | 
 |  | 
 | 			if (scur) { | 
 | 				inbuf.pos = 0; | 
 | 				inbuf.src = walk.src.virt.addr; | 
 | 				inbuf.size = scur; | 
 | 			} else { | 
 | 				data_available = false; | 
 | 				break; | 
 | 			} | 
 |  | 
 | 			num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf); | 
 | 			if (ZSTD_isError(num_bytes)) { | 
 | 				ret = -EIO; | 
 | 				goto out; | 
 | 			} | 
 |  | 
 | 			pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf); | 
 | 			if (ZSTD_isError(pending_bytes)) { | 
 | 				ret = -EIO; | 
 | 				goto out; | 
 | 			} | 
 | 			acomp_walk_done_src(&walk, inbuf.pos); | 
 | 		} while (dcur != outbuf.pos); | 
 |  | 
 | 		total_out += outbuf.pos; | 
 | 		acomp_walk_done_dst(&walk, dcur); | 
 | 	} while (data_available); | 
 |  | 
 | 	pos = outbuf.pos; | 
 | 	num_bytes = zstd_end_stream(ctx->cctx, &outbuf); | 
 | 	if (ZSTD_isError(num_bytes)) | 
 | 		ret = -EIO; | 
 | 	else | 
 | 		total_out += (outbuf.pos - pos); | 
 |  | 
 | out: | 
 | 	if (ret) | 
 | 		req->dlen = 0; | 
 | 	else | 
 | 		req->dlen = total_out; | 
 |  | 
 | 	crypto_acomp_unlock_stream_bh(s); | 
 |  | 
 | 	return ret; | 
 | } | 
 |  | 
 | static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx, | 
 | 			       const void *src, void *dst, unsigned int *dlen) | 
 | { | 
 | 	size_t out_len; | 
 |  | 
 | 	ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size); | 
 | 	if (!ctx->dctx) | 
 | 		return -EINVAL; | 
 |  | 
 | 	out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen); | 
 | 	if (zstd_is_error(out_len)) | 
 | 		return -EINVAL; | 
 |  | 
 | 	*dlen = out_len; | 
 |  | 
 | 	return 0; | 
 | } | 
 |  | 
 | static int zstd_decompress(struct acomp_req *req) | 
 | { | 
 | 	struct crypto_acomp_stream *s; | 
 | 	unsigned int total_out = 0; | 
 | 	unsigned int scur, dcur; | 
 | 	zstd_out_buffer outbuf; | 
 | 	struct acomp_walk walk; | 
 | 	zstd_in_buffer inbuf; | 
 | 	struct zstd_ctx *ctx; | 
 | 	size_t pending_bytes; | 
 | 	int ret; | 
 |  | 
 | 	s = crypto_acomp_lock_stream_bh(&zstd_streams); | 
 | 	ctx = s->ctx; | 
 |  | 
 | 	ret = acomp_walk_virt(&walk, req, true); | 
 | 	if (ret) | 
 | 		goto out; | 
 |  | 
 | 	ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size); | 
 | 	if (!ctx->dctx) { | 
 | 		ret = -EINVAL; | 
 | 		goto out; | 
 | 	} | 
 |  | 
 | 	do { | 
 | 		scur = acomp_walk_next_src(&walk); | 
 | 		if (scur) { | 
 | 			inbuf.pos = 0; | 
 | 			inbuf.size = scur; | 
 | 			inbuf.src = walk.src.virt.addr; | 
 | 		} else { | 
 | 			break; | 
 | 		} | 
 |  | 
 | 		do { | 
 | 			dcur = acomp_walk_next_dst(&walk); | 
 | 			if (dcur == req->dlen && scur == req->slen) { | 
 | 				ret = zstd_decompress_one(req, ctx, walk.src.virt.addr, | 
 | 							  walk.dst.virt.addr, &total_out); | 
 | 				acomp_walk_done_dst(&walk, dcur); | 
 | 				acomp_walk_done_src(&walk, scur); | 
 | 				goto out; | 
 | 			} | 
 |  | 
 | 			if (!dcur) { | 
 | 				ret = -ENOSPC; | 
 | 				goto out; | 
 | 			} | 
 |  | 
 | 			outbuf.pos = 0; | 
 | 			outbuf.dst = (u8 *)walk.dst.virt.addr; | 
 | 			outbuf.size = dcur; | 
 |  | 
 | 			pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf); | 
 | 			if (ZSTD_isError(pending_bytes)) { | 
 | 				ret = -EIO; | 
 | 				goto out; | 
 | 			} | 
 |  | 
 | 			total_out += outbuf.pos; | 
 |  | 
 | 			acomp_walk_done_dst(&walk, outbuf.pos); | 
 | 		} while (inbuf.pos != scur); | 
 |  | 
 | 		acomp_walk_done_src(&walk, scur); | 
 | 	} while (ret == 0); | 
 |  | 
 | out: | 
 | 	if (ret) | 
 | 		req->dlen = 0; | 
 | 	else | 
 | 		req->dlen = total_out; | 
 |  | 
 | 	crypto_acomp_unlock_stream_bh(s); | 
 |  | 
 | 	return ret; | 
 | } | 
 |  | 
 | static struct acomp_alg zstd_acomp = { | 
 | 	.base = { | 
 | 		.cra_name = "zstd", | 
 | 		.cra_driver_name = "zstd-generic", | 
 | 		.cra_flags = CRYPTO_ALG_REQ_VIRT, | 
 | 		.cra_module = THIS_MODULE, | 
 | 	}, | 
 | 	.init = zstd_init, | 
 | 	.exit = zstd_exit, | 
 | 	.compress = zstd_compress, | 
 | 	.decompress = zstd_decompress, | 
 | }; | 
 |  | 
 | static int __init zstd_mod_init(void) | 
 | { | 
 | 	return crypto_register_acomp(&zstd_acomp); | 
 | } | 
 |  | 
 | static void __exit zstd_mod_fini(void) | 
 | { | 
 | 	crypto_unregister_acomp(&zstd_acomp); | 
 | } | 
 |  | 
 | module_init(zstd_mod_init); | 
 | module_exit(zstd_mod_fini); | 
 |  | 
 | MODULE_LICENSE("GPL"); | 
 | MODULE_DESCRIPTION("Zstd Compression Algorithm"); | 
 | MODULE_ALIAS_CRYPTO("zstd"); |