// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Copyright (C) 2017, Microsoft Corporation.
 * Copyright (C) 2018, LG Electronics.
 * Copyright (c) 2025, Stefan Metzmacher
 */

#include "internal.h"

/*
 * Take @credits read/write credits from sc->rw_io.credits.count,
 * sleeping until they become available; fails with -ENOTCONN once the
 * socket leaves the SMBDIRECT_SOCKET_CONNECTED state.
 */
static int smbdirect_connection_wait_for_rw_credits(struct smbdirect_socket *sc,
						    int credits)
{
	return smbdirect_socket_wait_for_credits(sc,
						 SMBDIRECT_SOCKET_CONNECTED,
						 -ENOTCONN,
						 &sc->rw_io.credits.wait_queue,
						 &sc->rw_io.credits.count,
						 credits);
}
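
/*
 * Usage pattern (illustrative, mirroring smbdirect_connection_rdma_xmit()
 * below): credits are taken before posting RDMA READ/WRITE work requests
 * and handed back once that work has completed:
 *
 *	ret = smbdirect_connection_wait_for_rw_credits(sc, credits_needed);
 *	if (ret < 0)
 *		return ret;
 *	...post the RDMA READ/WRITE work requests...
 *	atomic_add(credits_needed, &sc->rw_io.credits.count);
 *	wake_up(&sc->rw_io.credits.wait_queue);
 */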

/*
 * Number of rw credits a transfer consumes: one credit covers up to
 * sc->rw_io.credits.num_pages pages of the (possibly unaligned) buffer.
 */
static int smbdirect_connection_calc_rw_credits(struct smbdirect_socket *sc,
						const void *buf,
						size_t len)
{
	return DIV_ROUND_UP(smbdirect_get_buf_page_count(buf, len),
			    sc->rw_io.credits.num_pages);
}

/*
 * Fill @sg_list with the pages backing [buf, buf + size), taking the
 * offset of the first byte within its page into account.  Returns the
 * number of entries used, or a negative error code.
 */
static int smbdirect_connection_rdma_get_sg_list(void *buf,
						 size_t size,
						 struct scatterlist *sg_list,
						 size_t nentries)
{
	bool high = is_vmalloc_addr(buf);
	struct page *page;
	size_t offset, len;
	int i = 0;

	if (size == 0 || nentries < smbdirect_get_buf_page_count(buf, size))
		return -EINVAL;

	offset = offset_in_page(buf);
	buf -= offset;
	while (size > 0) {
		len = min_t(size_t, PAGE_SIZE - offset, size);

		/* vmalloc memory needs a different page lookup than kmap */
		if (high)
			page = vmalloc_to_page(buf);
		else
			page = kmap_to_page(buf);

		/* defensive: the nentries check above should prevent this */
		if (!sg_list)
			return -EINVAL;
		sg_set_page(sg_list, page, len, offset);
		sg_list = sg_next(sg_list);

		buf += PAGE_SIZE;
		size -= len;
		offset = 0;
		i++;
	}
	return i;
}
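
/*
 * Illustrative walk-through (4 KiB pages): for a buffer at page offset
 * 0x800 with size 0x1000, the loop emits two entries,
 *	sg[0]: length 0x800, offset 0x800	(tail of the first page)
 *	sg[1]: length 0x800, offset 0		(head of the second page)
 * matching smbdirect_get_buf_page_count() == 2 for that range.
 */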

static void smbdirect_connection_rw_io_free(struct smbdirect_rw_io *msg,
					    enum dma_data_direction dir)
{
	struct smbdirect_socket *sc = msg->socket;

	rdma_rw_ctx_destroy(&msg->rdma_ctx,
			    sc->ib.qp,
			    sc->ib.qp->port,
			    msg->sgt.sgl,
			    msg->sgt.nents,
			    dir);
	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
	kfree(msg);
}

static void smbdirect_connection_rdma_rw_done(struct ib_cq *cq, struct ib_wc *wc,
					      enum dma_data_direction dir)
{
	struct smbdirect_rw_io *msg =
		container_of(wc->wr_cqe, struct smbdirect_rw_io, cqe);
	struct smbdirect_socket *sc = msg->socket;

	if (wc->status != IB_WC_SUCCESS) {
		msg->error = -EIO;
		pr_err("read/write error. opcode = %d, status = %s(%d)\n",
		       wc->opcode, ib_wc_status_msg(wc->status), wc->status);
		/*
		 * A flush error just means the QP is already being torn
		 * down; any other failure triggers a socket cleanup.
		 */
		if (wc->status != IB_WC_WR_FLUSH_ERR)
			smbdirect_socket_schedule_cleanup(sc, msg->error);
	}

	complete(msg->completion);
}

static void smbdirect_connection_rdma_read_done(struct ib_cq *cq, struct ib_wc *wc)
{
	smbdirect_connection_rdma_rw_done(cq, wc, DMA_FROM_DEVICE);
}

static void smbdirect_connection_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc)
{
	smbdirect_connection_rdma_rw_done(cq, wc, DMA_TO_DEVICE);
}
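
/*
 * Illustrative completion flow: a failed RDMA READ runs
 * smbdirect_connection_rdma_read_done() ->
 * smbdirect_connection_rdma_rw_done(), which records -EIO in msg->error
 * and completes msg->completion; the thread sleeping in
 * smbdirect_connection_rdma_xmit() then picks the error up from the last
 * message of the posted chain.
 */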

int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc,
				   void *buf, size_t buf_len,
				   struct smbdirect_buffer_descriptor_v1 *desc,
				   size_t desc_len,
				   bool is_read)
{
	const struct smbdirect_socket_parameters *sp = &sc->parameters;
	enum dma_data_direction direction = is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE;
	struct smbdirect_rw_io *msg, *next_msg;
	size_t i;
	int ret;
	DECLARE_COMPLETION_ONSTACK(completion);
	struct ib_send_wr *first_wr;
	LIST_HEAD(msg_list);
	u8 *desc_buf;
	int credits_needed;
	size_t desc_buf_len, desc_num = 0;

	if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
		return -ENOTCONN;

	if (buf_len > sp->max_read_write_size)
		return -EINVAL;

	/* calculate needed credits */
	credits_needed = 0;
	desc_buf = buf;
	for (i = 0; i < desc_len / sizeof(*desc); i++) {
		if (!buf_len)
			break;

		desc_buf_len = le32_to_cpu(desc[i].length);
		if (!desc_buf_len)
			return -EINVAL;

		/*
		 * Clamp the final descriptor to the remaining buffer; the
		 * shared "buf_len -= desc_buf_len" below then reaches
		 * exactly zero instead of wrapping the size_t around.
		 */
		if (desc_buf_len > buf_len) {
			desc_buf_len = buf_len;
			desc[i].length = cpu_to_le32(desc_buf_len);
		}

		credits_needed += smbdirect_connection_calc_rw_credits(sc,
								       desc_buf,
								       desc_buf_len);
		desc_buf += desc_buf_len;
		buf_len -= desc_buf_len;
		desc_num++;
	}

	smbdirect_log_rdma_rw(sc, SMBDIRECT_LOG_INFO,
			      "RDMA %s, len %zu, needed credits %d\n",
			      str_read_write(is_read), buf_len, credits_needed);

	ret = smbdirect_connection_wait_for_rw_credits(sc, credits_needed);
	if (ret < 0)
		return ret;

	/* build rdma_rw_ctx for each descriptor */
	desc_buf = buf;
	for (i = 0; i < desc_num; i++) {
		size_t page_count;

		/* msg carries an inline sg_list of SG_CHUNK_SIZE entries */
		msg = kzalloc_flex(*msg, sg_list, SG_CHUNK_SIZE,
				   sc->rw_io.mem.gfp_mask);
		if (!msg) {
			ret = -ENOMEM;
			goto out;
		}

		desc_buf_len = le32_to_cpu(desc[i].length);
		page_count = smbdirect_get_buf_page_count(desc_buf, desc_buf_len);

		msg->socket = sc;
		msg->cqe.done = is_read ?
			smbdirect_connection_rdma_read_done :
			smbdirect_connection_rdma_write_done;
		/* all messages share the single on-stack completion */
		msg->completion = &completion;

		msg->sgt.sgl = &msg->sg_list[0];
		ret = sg_alloc_table_chained(&msg->sgt,
					     page_count,
					     msg->sg_list,
					     SG_CHUNK_SIZE);
		if (ret) {
			ret = -ENOMEM;
			goto free_msg;
		}

		ret = smbdirect_connection_rdma_get_sg_list(desc_buf,
							    desc_buf_len,
							    msg->sgt.sgl,
							    msg->sgt.orig_nents);
		if (ret < 0)
			goto free_table;

		ret = rdma_rw_ctx_init(&msg->rdma_ctx,
				       sc->ib.qp,
				       sc->ib.qp->port,
				       msg->sgt.sgl,
				       page_count,
				       0,
				       le64_to_cpu(desc[i].offset),
				       le32_to_cpu(desc[i].token),
				       direction);
		if (ret < 0) {
			pr_err("failed to init rdma_rw_ctx: %d\n", ret);
			goto free_table;
		}

		list_add_tail(&msg->list, &msg_list);
		desc_buf += desc_buf_len;
	}

	/*
	 * Concatenate the work requests of all rdma_rw_ctxs in reverse,
	 * so first_wr ends up heading one chain and only the tail ctx
	 * (the last entry of msg_list) carries the signaling cqe.
	 */
	first_wr = NULL;
	list_for_each_entry_reverse(msg, &msg_list, list) {
		first_wr = rdma_rw_ctx_wrs(&msg->rdma_ctx,
					   sc->ib.qp,
					   sc->ib.qp->port,
					   &msg->cqe,
					   first_wr);
	}

	ret = ib_post_send(sc->ib.qp, first_wr, NULL);
	if (ret) {
		pr_err("failed to post send wr for RDMA R/W: %d\n", ret);
		goto out;
	}

	/* wait for the tail message to signal completion of the chain */
	msg = list_last_entry(&msg_list, struct smbdirect_rw_io, list);
	wait_for_completion(&completion);
	ret = msg->error;

out:
	list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
		list_del(&msg->list);
		smbdirect_connection_rw_io_free(msg, direction);
	}

	/* hand the rw credits back and wake up any waiters */
	atomic_add(credits_needed, &sc->rw_io.credits.count);
	wake_up(&sc->rw_io.credits.wait_queue);
	return ret;

free_table:
	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
free_msg:
	kfree(msg);
	goto out;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_rdma_xmit);
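
/*
 * Hypothetical caller sketch (not part of this file; names and values are
 * made up for illustration): issuing an RDMA WRITE of a response buffer
 * described by a single peer-supplied buffer descriptor.
 *
 *	struct smbdirect_buffer_descriptor_v1 desc = {
 *		.offset = cpu_to_le64(remote_offset),
 *		.token  = cpu_to_le32(remote_token),
 *		.length = cpu_to_le32(resp_len),
 *	};
 *	int ret;
 *
 *	ret = smbdirect_connection_rdma_xmit(sc, resp_buf, resp_len,
 *					     &desc, sizeof(desc),
 *					     false); // is_read == false: WRITE
 *	if (ret < 0)
 *		... tear down the connection or fail the request ...
 */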