|  | // SPDX-License-Identifier: GPL-2.0-only | 
|  | /* | 
|  | * Copyright (c) 2021, 2022 Oracle.  All rights reserved. | 
|  | * | 
|  | * The AUTH_TLS credential is used only to probe a remote peer | 
|  | * for RPC-over-TLS support. | 
|  | */ | 
|  |  | 
|  | #include <linux/types.h> | 
|  | #include <linux/module.h> | 
|  | #include <linux/sunrpc/clnt.h> | 
|  |  | 
/*
 * Opaque body that tls_validate() requires in the AUTH_NULL reply
 * verifier before it accepts the peer as RPC-over-TLS capable.
 *
 * The token is a const array rather than a writable pointer variable,
 * and its length is derived from the literal so the two can never
 * disagree. The terminating NUL is not part of the on-the-wire token,
 * hence the "- 1".
 */
static const char starttls_token[] = "STARTTLS";
static const size_t starttls_len = sizeof(starttls_token) - 1;
|  |  | 
|  | static struct rpc_auth tls_auth; | 
|  | static struct rpc_cred tls_cred; | 
|  |  | 
|  | static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, | 
|  | const void *obj) | 
|  | { | 
|  | } | 
|  |  | 
|  | static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, | 
|  | void *obj) | 
|  | { | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | static const struct rpc_procinfo rpcproc_tls_probe = { | 
|  | .p_encode	= tls_encode_probe, | 
|  | .p_decode	= tls_decode_probe, | 
|  | }; | 
|  |  | 
|  | static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data) | 
|  | { | 
|  | task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT; | 
|  | rpc_call_start(task); | 
|  | } | 
|  |  | 
|  | static void rpc_tls_probe_call_done(struct rpc_task *task, void *data) | 
|  | { | 
|  | } | 
|  |  | 
|  | static const struct rpc_call_ops rpc_tls_probe_ops = { | 
|  | .rpc_call_prepare	= rpc_tls_probe_call_prepare, | 
|  | .rpc_call_done		= rpc_tls_probe_call_done, | 
|  | }; | 
|  |  | 
|  | static int tls_probe(struct rpc_clnt *clnt) | 
|  | { | 
|  | struct rpc_message msg = { | 
|  | .rpc_proc	= &rpcproc_tls_probe, | 
|  | }; | 
|  | struct rpc_task_setup task_setup_data = { | 
|  | .rpc_client	= clnt, | 
|  | .rpc_message	= &msg, | 
|  | .rpc_op_cred	= &tls_cred, | 
|  | .callback_ops	= &rpc_tls_probe_ops, | 
|  | .flags		= RPC_TASK_SOFT | RPC_TASK_SOFTCONN, | 
|  | }; | 
|  | struct rpc_task	*task; | 
|  | int status; | 
|  |  | 
|  | task = rpc_run_task(&task_setup_data); | 
|  | if (IS_ERR(task)) | 
|  | return PTR_ERR(task); | 
|  | status = task->tk_status; | 
|  | rpc_put_task(task); | 
|  | return status; | 
|  | } | 
|  |  | 
|  | static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args, | 
|  | struct rpc_clnt *clnt) | 
|  | { | 
|  | refcount_inc(&tls_auth.au_count); | 
|  | return &tls_auth; | 
|  | } | 
|  |  | 
|  | static void tls_destroy(struct rpc_auth *auth) | 
|  | { | 
|  | } | 
|  |  | 
|  | static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth, | 
|  | struct auth_cred *acred, int flags) | 
|  | { | 
|  | return get_rpccred(&tls_cred); | 
|  | } | 
|  |  | 
|  | static void tls_destroy_cred(struct rpc_cred *cred) | 
|  | { | 
|  | } | 
|  |  | 
|  | static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) | 
|  | { | 
|  | return 1; | 
|  | } | 
|  |  | 
|  | static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr) | 
|  | { | 
|  | __be32 *p; | 
|  |  | 
|  | p = xdr_reserve_space(xdr, 4 * XDR_UNIT); | 
|  | if (!p) | 
|  | return -EMSGSIZE; | 
|  | /* Credential */ | 
|  | *p++ = rpc_auth_tls; | 
|  | *p++ = xdr_zero; | 
|  | /* Verifier */ | 
|  | *p++ = rpc_auth_null; | 
|  | *p   = xdr_zero; | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | static int tls_refresh(struct rpc_task *task) | 
|  | { | 
|  | set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr) | 
|  | { | 
|  | __be32 *p; | 
|  | void *str; | 
|  |  | 
|  | p = xdr_inline_decode(xdr, XDR_UNIT); | 
|  | if (!p) | 
|  | return -EIO; | 
|  | if (*p != rpc_auth_null) | 
|  | return -EIO; | 
|  | if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len) | 
|  | return -EPROTONOSUPPORT; | 
|  | if (memcmp(str, starttls_token, starttls_len)) | 
|  | return -EPROTONOSUPPORT; | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | const struct rpc_authops authtls_ops = { | 
|  | .owner		= THIS_MODULE, | 
|  | .au_flavor	= RPC_AUTH_TLS, | 
|  | .au_name	= "NULL", | 
|  | .create		= tls_create, | 
|  | .destroy	= tls_destroy, | 
|  | .lookup_cred	= tls_lookup_cred, | 
|  | .ping		= tls_probe, | 
|  | }; | 
|  |  | 
|  | static struct rpc_auth tls_auth = { | 
|  | .au_cslack	= NUL_CALLSLACK, | 
|  | .au_rslack	= NUL_REPLYSLACK, | 
|  | .au_verfsize	= NUL_REPLYSLACK, | 
|  | .au_ralign	= NUL_REPLYSLACK, | 
|  | .au_ops		= &authtls_ops, | 
|  | .au_flavor	= RPC_AUTH_TLS, | 
|  | .au_count	= REFCOUNT_INIT(1), | 
|  | }; | 
|  |  | 
|  | static const struct rpc_credops tls_credops = { | 
|  | .cr_name	= "AUTH_TLS", | 
|  | .crdestroy	= tls_destroy_cred, | 
|  | .crmatch	= tls_match, | 
|  | .crmarshal	= tls_marshal, | 
|  | .crwrap_req	= rpcauth_wrap_req_encode, | 
|  | .crrefresh	= tls_refresh, | 
|  | .crvalidate	= tls_validate, | 
|  | .crunwrap_resp	= rpcauth_unwrap_resp_decode, | 
|  | }; | 
|  |  | 
|  | static struct rpc_cred tls_cred = { | 
|  | .cr_lru		= LIST_HEAD_INIT(tls_cred.cr_lru), | 
|  | .cr_auth	= &tls_auth, | 
|  | .cr_ops		= &tls_credops, | 
|  | .cr_count	= REFCOUNT_INIT(2), | 
|  | .cr_flags	= 1UL << RPCAUTH_CRED_UPTODATE, | 
|  | }; |