|  | // SPDX-License-Identifier: GPL-2.0 | 
|  | #include <net/tcp.h> | 
|  | #include <net/strparser.h> | 
|  | #include <net/xfrm.h> | 
|  | #include <net/esp.h> | 
|  | #include <net/espintcp.h> | 
|  | #include <linux/skmsg.h> | 
|  | #include <net/inet_common.h> | 
|  | #if IS_ENABLED(CONFIG_IPV6) | 
|  | #include <net/ipv6_stubs.h> | 
|  | #endif | 
|  |  | 
|  | static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb, | 
|  | struct sock *sk) | 
|  | { | 
|  | if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf || | 
|  | !sk_rmem_schedule(sk, skb, skb->truesize)) { | 
|  | kfree_skb(skb); | 
|  | return; | 
|  | } | 
|  |  | 
|  | skb_set_owner_r(skb, sk); | 
|  |  | 
|  | memset(skb->cb, 0, sizeof(skb->cb)); | 
|  | skb_queue_tail(&ctx->ike_queue, skb); | 
|  | ctx->saved_data_ready(sk); | 
|  | } | 
|  |  | 
|  | static void handle_esp(struct sk_buff *skb, struct sock *sk) | 
|  | { | 
|  | skb_reset_transport_header(skb); | 
|  | memset(skb->cb, 0, sizeof(skb->cb)); | 
|  |  | 
|  | rcu_read_lock(); | 
|  | skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif); | 
|  | local_bh_disable(); | 
|  | #if IS_ENABLED(CONFIG_IPV6) | 
|  | if (sk->sk_family == AF_INET6) | 
|  | ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP); | 
|  | else | 
|  | #endif | 
|  | xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP); | 
|  | local_bh_enable(); | 
|  | rcu_read_unlock(); | 
|  | } | 
|  |  | 
|  | static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb) | 
|  | { | 
|  | struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx, | 
|  | strp); | 
|  | struct strp_msg *rxm = strp_msg(skb); | 
|  | u32 nonesp_marker; | 
|  | int err; | 
|  |  | 
|  | err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker, | 
|  | sizeof(nonesp_marker)); | 
|  | if (err < 0) { | 
|  | kfree_skb(skb); | 
|  | return; | 
|  | } | 
|  |  | 
|  | /* remove header, leave non-ESP marker/SPI */ | 
|  | if (!__pskb_pull(skb, rxm->offset + 2)) { | 
|  | kfree_skb(skb); | 
|  | return; | 
|  | } | 
|  |  | 
|  | if (pskb_trim(skb, rxm->full_len - 2) != 0) { | 
|  | kfree_skb(skb); | 
|  | return; | 
|  | } | 
|  |  | 
|  | if (nonesp_marker == 0) | 
|  | handle_nonesp(ctx, skb, strp->sk); | 
|  | else | 
|  | handle_esp(skb, strp->sk); | 
|  | } | 
|  |  | 
|  | static int espintcp_parse(struct strparser *strp, struct sk_buff *skb) | 
|  | { | 
|  | struct strp_msg *rxm = strp_msg(skb); | 
|  | __be16 blen; | 
|  | u16 len; | 
|  | int err; | 
|  |  | 
|  | if (skb->len < rxm->offset + 2) | 
|  | return 0; | 
|  |  | 
|  | err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen)); | 
|  | if (err < 0) | 
|  | return err; | 
|  |  | 
|  | len = be16_to_cpu(blen); | 
|  | if (len < 6) | 
|  | return -EINVAL; | 
|  |  | 
|  | return len; | 
|  | } | 
|  |  | 
|  | static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, | 
|  | int nonblock, int flags, int *addr_len) | 
|  | { | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  | struct sk_buff *skb; | 
|  | int err = 0; | 
|  | int copied; | 
|  | int off = 0; | 
|  |  | 
|  | flags |= nonblock ? MSG_DONTWAIT : 0; | 
|  |  | 
|  | skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err); | 
|  | if (!skb) | 
|  | return err; | 
|  |  | 
|  | copied = len; | 
|  | if (copied > skb->len) | 
|  | copied = skb->len; | 
|  | else if (copied < skb->len) | 
|  | msg->msg_flags |= MSG_TRUNC; | 
|  |  | 
|  | err = skb_copy_datagram_msg(skb, 0, msg, copied); | 
|  | if (unlikely(err)) { | 
|  | kfree_skb(skb); | 
|  | return err; | 
|  | } | 
|  |  | 
|  | if (flags & MSG_TRUNC) | 
|  | copied = skb->len; | 
|  | kfree_skb(skb); | 
|  | return copied; | 
|  | } | 
|  |  | 
|  | int espintcp_queue_out(struct sock *sk, struct sk_buff *skb) | 
|  | { | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  |  | 
|  | if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog) | 
|  | return -ENOBUFS; | 
|  |  | 
|  | __skb_queue_tail(&ctx->out_queue, skb); | 
|  |  | 
|  | return 0; | 
|  | } | 
|  | EXPORT_SYMBOL_GPL(espintcp_queue_out); | 
|  |  | 
/* The espintcp length prefix is a 2-byte field whose value also counts the
 * prefix itself, so the largest payload that fits is 0xffff - 2 bytes.
 */
#define MAX_ESPINTCP_MSG (0xffff - 2)
|  |  | 
|  | static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg, | 
|  | int flags) | 
|  | { | 
|  | do { | 
|  | int ret; | 
|  |  | 
|  | ret = skb_send_sock_locked(sk, emsg->skb, | 
|  | emsg->offset, emsg->len); | 
|  | if (ret < 0) | 
|  | return ret; | 
|  |  | 
|  | emsg->len -= ret; | 
|  | emsg->offset += ret; | 
|  | } while (emsg->len > 0); | 
|  |  | 
|  | kfree_skb(emsg->skb); | 
|  | memset(emsg, 0, sizeof(*emsg)); | 
|  |  | 
|  | return 0; | 
|  | } | 
|  |  | 
/* Transmit the pending sk_msg held in @emsg, one scatterlist entry at a
 * time, with do_tcp_sendpages().
 *
 * Sending starts at entry skmsg->sg.start. emsg->offset is the byte offset
 * within that first entry from which to resume. MSG_SENDPAGE_NOTLAST is
 * passed for every entry except the one marked as last (sg_mark_end() in
 * espintcp_sendmsg()).
 *
 * For each fully sent entry, the page reference is dropped and its length
 * is uncharged from the socket.
 *
 * If do_tcp_sendpages() returns a negative error (e.g. -EAGAIN), the resume
 * point is saved before returning that error:
 * - sg.start is advanced past the fully sent entries;
 * - emsg->offset records how far into the current entry we got.
 * A later call therefore continues where this one stopped.
 *
 * On success @emsg is cleared and 0 is returned.
 */
static int espintcp_sendskmsg_locked(struct sock *sk,
				     struct espintcp_msg *emsg, int flags)
{
	struct sk_msg *skmsg = &emsg->skmsg;
	struct scatterlist *sg;
	int done = 0;	/* number of entries fully sent in this call */
	int ret;

	flags |= MSG_SENDPAGE_NOTLAST;
	sg = &skmsg->sg.data[skmsg->sg.start];
	do {
		size_t size = sg->length - emsg->offset;
		int offset = sg->offset + emsg->offset;
		struct page *p;

		/* the resume offset only applies to the first entry */
		emsg->offset = 0;

		/* no more pages follow this one */
		if (sg_is_last(sg))
			flags &= ~MSG_SENDPAGE_NOTLAST;

		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, flags);
		if (ret < 0) {
			/* remember where to resume within this entry */
			emsg->offset = offset - sg->offset;
			skmsg->sg.start += done;
			return ret;
		}

		/* partial send: continue with the rest of the same page */
		if (ret != size) {
			offset += ret;
			size -= ret;
			goto retry;
		}

		done++;
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
	} while (sg);

	memset(emsg, 0, sizeof(*emsg));

	return 0;
}
|  |  | 
|  | static int espintcp_push_msgs(struct sock *sk) | 
|  | { | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  | struct espintcp_msg *emsg = &ctx->partial; | 
|  | int err; | 
|  |  | 
|  | if (!emsg->len) | 
|  | return 0; | 
|  |  | 
|  | if (ctx->tx_running) | 
|  | return -EAGAIN; | 
|  | ctx->tx_running = 1; | 
|  |  | 
|  | if (emsg->skb) | 
|  | err = espintcp_sendskb_locked(sk, emsg, 0); | 
|  | else | 
|  | err = espintcp_sendskmsg_locked(sk, emsg, 0); | 
|  | if (err == -EAGAIN) { | 
|  | ctx->tx_running = 0; | 
|  | return 0; | 
|  | } | 
|  | if (!err) | 
|  | memset(emsg, 0, sizeof(*emsg)); | 
|  |  | 
|  | ctx->tx_running = 0; | 
|  |  | 
|  | return err; | 
|  | } | 
|  |  | 
|  | int espintcp_push_skb(struct sock *sk, struct sk_buff *skb) | 
|  | { | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  | struct espintcp_msg *emsg = &ctx->partial; | 
|  | unsigned int len; | 
|  | int offset; | 
|  |  | 
|  | if (sk->sk_state != TCP_ESTABLISHED) { | 
|  | kfree_skb(skb); | 
|  | return -ECONNRESET; | 
|  | } | 
|  |  | 
|  | offset = skb_transport_offset(skb); | 
|  | len = skb->len - offset; | 
|  |  | 
|  | espintcp_push_msgs(sk); | 
|  |  | 
|  | if (emsg->len) { | 
|  | kfree_skb(skb); | 
|  | return -ENOBUFS; | 
|  | } | 
|  |  | 
|  | skb_set_owner_w(skb, sk); | 
|  |  | 
|  | emsg->offset = offset; | 
|  | emsg->len = len; | 
|  | emsg->skb = skb; | 
|  |  | 
|  | espintcp_push_msgs(sk); | 
|  |  | 
|  | return 0; | 
|  | } | 
|  | EXPORT_SYMBOL_GPL(espintcp_push_skb); | 
|  |  | 
|  | static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) | 
|  | { | 
|  | long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  | struct espintcp_msg *emsg = &ctx->partial; | 
|  | struct iov_iter pfx_iter; | 
|  | struct kvec pfx_iov = {}; | 
|  | size_t msglen = size + 2; | 
|  | char buf[2] = {0}; | 
|  | int err, end; | 
|  |  | 
|  | if (msg->msg_flags) | 
|  | return -EOPNOTSUPP; | 
|  |  | 
|  | if (size > MAX_ESPINTCP_MSG) | 
|  | return -EMSGSIZE; | 
|  |  | 
|  | if (msg->msg_controllen) | 
|  | return -EOPNOTSUPP; | 
|  |  | 
|  | lock_sock(sk); | 
|  |  | 
|  | err = espintcp_push_msgs(sk); | 
|  | if (err < 0) { | 
|  | err = -ENOBUFS; | 
|  | goto unlock; | 
|  | } | 
|  |  | 
|  | sk_msg_init(&emsg->skmsg); | 
|  | while (1) { | 
|  | /* only -ENOMEM is possible since we don't coalesce */ | 
|  | err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0); | 
|  | if (!err) | 
|  | break; | 
|  |  | 
|  | err = sk_stream_wait_memory(sk, &timeo); | 
|  | if (err) | 
|  | goto fail; | 
|  | } | 
|  |  | 
|  | *((__be16 *)buf) = cpu_to_be16(msglen); | 
|  | pfx_iov.iov_base = buf; | 
|  | pfx_iov.iov_len = sizeof(buf); | 
|  | iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len); | 
|  |  | 
|  | err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg, | 
|  | pfx_iov.iov_len); | 
|  | if (err < 0) | 
|  | goto fail; | 
|  |  | 
|  | err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size); | 
|  | if (err < 0) | 
|  | goto fail; | 
|  |  | 
|  | end = emsg->skmsg.sg.end; | 
|  | emsg->len = size; | 
|  | sk_msg_iter_var_prev(end); | 
|  | sg_mark_end(sk_msg_elem(&emsg->skmsg, end)); | 
|  |  | 
|  | tcp_rate_check_app_limited(sk); | 
|  |  | 
|  | err = espintcp_push_msgs(sk); | 
|  | /* this message could be partially sent, keep it */ | 
|  | if (err < 0) | 
|  | goto unlock; | 
|  | release_sock(sk); | 
|  |  | 
|  | return size; | 
|  |  | 
|  | fail: | 
|  | sk_msg_free(sk, &emsg->skmsg); | 
|  | memset(emsg, 0, sizeof(*emsg)); | 
|  | unlock: | 
|  | release_sock(sk); | 
|  | return err; | 
|  | } | 
|  |  | 
/* IPv4 espintcp proto/ops: filled in once by espintcp_init(). */
static struct proto espintcp_prot __ro_after_init;
static struct proto_ops espintcp_ops __ro_after_init;
/* IPv6 variants: built lazily by espintcp_init_sk() the first time an IPv6
 * socket attaches the ULP, cloned from that socket's own proto/ops under
 * tcpv6_prot_mutex. They are written after init, so they cannot be
 * __ro_after_init. (Presumably built lazily because the IPv6 TCP proto is
 * not available to espintcp_init() — confirm.)
 */
static struct proto espintcp6_prot;
static struct proto_ops espintcp6_ops;
static DEFINE_MUTEX(tcpv6_prot_mutex);
|  |  | 
|  | static void espintcp_data_ready(struct sock *sk) | 
|  | { | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  |  | 
|  | strp_data_ready(&ctx->strp); | 
|  | } | 
|  |  | 
|  | static void espintcp_tx_work(struct work_struct *work) | 
|  | { | 
|  | struct espintcp_ctx *ctx = container_of(work, | 
|  | struct espintcp_ctx, work); | 
|  | struct sock *sk = ctx->strp.sk; | 
|  |  | 
|  | lock_sock(sk); | 
|  | if (!ctx->tx_running) | 
|  | espintcp_push_msgs(sk); | 
|  | release_sock(sk); | 
|  | } | 
|  |  | 
|  | static void espintcp_write_space(struct sock *sk) | 
|  | { | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  |  | 
|  | schedule_work(&ctx->work); | 
|  | ctx->saved_write_space(sk); | 
|  | } | 
|  |  | 
|  | static void espintcp_destruct(struct sock *sk) | 
|  | { | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  |  | 
|  | ctx->saved_destruct(sk); | 
|  | kfree(ctx); | 
|  | } | 
|  |  | 
|  | bool tcp_is_ulp_esp(struct sock *sk) | 
|  | { | 
|  | return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot; | 
|  | } | 
|  | EXPORT_SYMBOL_GPL(tcp_is_ulp_esp); | 
|  |  | 
/* Defined further down; declared here because espintcp_init_sk() builds
 * the IPv6 proto/ops on first use.
 */
static void build_protos(struct proto *espintcp_prot,
			 struct proto_ops *espintcp_ops,
			 const struct proto *orig_prot,
			 const struct proto_ops *orig_ops);
/* TCP ULP ->init: turn a TCP socket into an espintcp socket.
 *
 * Steps, in order:
 * - allocate the espintcp context;
 * - set up the stream parser that splits the byte stream into
 *   length-prefixed messages;
 * - switch sk_prot/ops to the espintcp variants (building the IPv6 ones on
 *   first use);
 * - hook sk_data_ready, sk_write_space and sk_destruct, saving the
 *   originals in the context.
 *
 * Returns 0 on success, -EBUSY if sk_user_data is already in use (e.g. by
 * sockmap), -ENOMEM, or the error returned by strp_init().
 */
static int espintcp_init_sk(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct strp_callbacks cb = {
		.rcv_msg = espintcp_rcv,
		.parse_msg = espintcp_parse,
	};
	struct espintcp_ctx *ctx;
	int err;

	/* sockmap is not compatible with espintcp */
	if (sk->sk_user_data)
		return -EBUSY;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return -ENOMEM;

	err = strp_init(&ctx->strp, sk, &cb);
	if (err)
		goto free;

	/* drop the socket's cached dst; NOTE(review): the reason is not
	 * visible in this file
	 */
	__sk_dst_reset(sk);

	/* have the parser check for data that may already be queued */
	strp_check_rcv(&ctx->strp);
	skb_queue_head_init(&ctx->ike_queue);
	skb_queue_head_init(&ctx->out_queue);

	if (sk->sk_family == AF_INET) {
		sk->sk_prot = &espintcp_prot;
		sk->sk_socket->ops = &espintcp_ops;
	} else {
		/* build the IPv6 variants once, cloned from this socket's
		 * current proto/ops; a non-NULL recvmsg marks them as built
		 */
		mutex_lock(&tcpv6_prot_mutex);
		if (!espintcp6_prot.recvmsg)
			build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
		mutex_unlock(&tcpv6_prot_mutex);

		sk->sk_prot = &espintcp6_prot;
		sk->sk_socket->ops = &espintcp6_ops;
	}
	/* save the original callbacks; espintcp's replacements chain to
	 * them
	 */
	ctx->saved_data_ready = sk->sk_data_ready;
	ctx->saved_write_space = sk->sk_write_space;
	ctx->saved_destruct = sk->sk_destruct;
	sk->sk_data_ready = espintcp_data_ready;
	sk->sk_write_space = espintcp_write_space;
	sk->sk_destruct = espintcp_destruct;
	/* publish the context (presumably what espintcp_getctx() reads) */
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	INIT_WORK(&ctx->work, espintcp_tx_work);

	/* avoid using task_frag */
	sk->sk_allocation = GFP_ATOMIC;

	return 0;

free:
	kfree(ctx);
	return err;
}
|  |  | 
|  | static void espintcp_release(struct sock *sk) | 
|  | { | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  | struct sk_buff_head queue; | 
|  | struct sk_buff *skb; | 
|  |  | 
|  | __skb_queue_head_init(&queue); | 
|  | skb_queue_splice_init(&ctx->out_queue, &queue); | 
|  |  | 
|  | while ((skb = __skb_dequeue(&queue))) | 
|  | espintcp_push_skb(sk, skb); | 
|  |  | 
|  | tcp_release_cb(sk); | 
|  | } | 
|  |  | 
/* ->close for espintcp sockets: tear down espintcp state, then close TCP.
 *
 * Steps, in order:
 * - stop the stream parser;
 * - switch sk_prot back to tcp_prot;
 * - wait for the deferred transmit work;
 * - release the parser;
 * - drop all queued outgoing and IKE skbs, plus any pending partial
 *   message;
 * - call tcp_close().
 *
 * NOTE(review): sk_prot is reset to the IPv4 tcp_prot unconditionally,
 * while for IPv6 sockets espintcp6_prot was cloned from the socket's own
 * proto in espintcp_init_sk(). IPv6 sockets may need their original proto
 * restored here instead — verify.
 */
static void espintcp_close(struct sock *sk, long timeout)
{
	struct espintcp_ctx *ctx = espintcp_getctx(sk);
	struct espintcp_msg *emsg = &ctx->partial;

	strp_stop(&ctx->strp);

	sk->sk_prot = &tcp_prot;
	/* compiler barrier: keep the sk_prot switch ahead of the teardown
	 * below (presumably so nothing re-enters espintcp hooks) — confirm
	 */
	barrier();

	cancel_work_sync(&ctx->work);
	strp_done(&ctx->strp);

	skb_queue_purge(&ctx->out_queue);
	skb_queue_purge(&ctx->ike_queue);

	/* free whichever kind of message is still pending */
	if (emsg->len) {
		if (emsg->skb)
			kfree_skb(emsg->skb);
		else
			sk_msg_free(sk, &emsg->skmsg);
	}

	tcp_close(sk, timeout);
}
|  |  | 
|  | static __poll_t espintcp_poll(struct file *file, struct socket *sock, | 
|  | poll_table *wait) | 
|  | { | 
|  | __poll_t mask = datagram_poll(file, sock, wait); | 
|  | struct sock *sk = sock->sk; | 
|  | struct espintcp_ctx *ctx = espintcp_getctx(sk); | 
|  |  | 
|  | if (!skb_queue_empty(&ctx->ike_queue)) | 
|  | mask |= EPOLLIN | EPOLLRDNORM; | 
|  |  | 
|  | return mask; | 
|  | } | 
|  |  | 
|  | static void build_protos(struct proto *espintcp_prot, | 
|  | struct proto_ops *espintcp_ops, | 
|  | const struct proto *orig_prot, | 
|  | const struct proto_ops *orig_ops) | 
|  | { | 
|  | memcpy(espintcp_prot, orig_prot, sizeof(struct proto)); | 
|  | memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops)); | 
|  | espintcp_prot->sendmsg = espintcp_sendmsg; | 
|  | espintcp_prot->recvmsg = espintcp_recvmsg; | 
|  | espintcp_prot->close = espintcp_close; | 
|  | espintcp_prot->release_cb = espintcp_release; | 
|  | espintcp_ops->poll = espintcp_poll; | 
|  | } | 
|  |  | 
|  | static struct tcp_ulp_ops espintcp_ulp __read_mostly = { | 
|  | .name = "espintcp", | 
|  | .owner = THIS_MODULE, | 
|  | .init = espintcp_init_sk, | 
|  | }; | 
|  |  | 
|  | void __init espintcp_init(void) | 
|  | { | 
|  | build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops); | 
|  |  | 
|  | tcp_register_ulp(&espintcp_ulp); | 
|  | } |