smb: server: make use of smbdirect_socket.send_io.lcredits.*

This introduces logic to prevent on overflow of
the send submission queue with ib_post_send() easier.

As we first get a local credit and then a remote credit
before we mark us as pending.

From reading the git history of the linux smbdirect
implementations in client and server) it was seen
that a peer granted more credits than we requested.
I guess that only happened because of bugs in our
implementation which was active as client and server.
I guess Windows won't do that.

So the local credits make sure we only use the amount
of credits we asked for.

Fixes: 0626e6641f6b ("cifsd: add server handler for central processing and tranport layers")
Cc: Namjae Jeon <linkinjeon@kernel.org>
Cc: Steve French <smfrench@gmail.com>
Cc: Tom Talpey <tom@talpey.com>
Cc: linux-cifs@vger.kernel.org
Cc: samba-technical@lists.samba.org
Signed-off-by: Stefan Metzmacher <metze@samba.org>
Acked-by: Namjae Jeon <linkinjeon@kernel.org>
Signed-off-by: Steve French <stfrench@microsoft.com>
diff --git a/fs/smb/server/transport_rdma.c b/fs/smb/server/transport_rdma.c
index 2aa8e4d..8aaa950 100644
--- a/fs/smb/server/transport_rdma.c
+++ b/fs/smb/server/transport_rdma.c
@@ -219,6 +219,7 @@ static void smb_direct_disconnect_wake_up_all(struct smbdirect_socket *sc)
 	 * in order to notice the broken connection.
 	 */
 	wake_up_all(&sc->status_wait);
+	wake_up_all(&sc->send_io.lcredits.wait_queue);
 	wake_up_all(&sc->send_io.credits.wait_queue);
 	wake_up_all(&sc->send_io.pending.zero_wait_queue);
 	wake_up_all(&sc->recv_io.reassembly.wait_queue);
@@ -916,6 +917,7 @@ static void send_done(struct ib_cq *cq, struct ib_wc *wc)
 {
 	struct smbdirect_send_io *sendmsg, *sibling, *next;
 	struct smbdirect_socket *sc;
+	int lcredits = 0;
 
 	sendmsg = container_of(wc->wr_cqe, struct smbdirect_send_io, cqe);
 	sc = sendmsg->socket;
@@ -930,9 +932,11 @@ static void send_done(struct ib_cq *cq, struct ib_wc *wc)
 	list_for_each_entry_safe(sibling, next, &sendmsg->sibling_list, sibling_list) {
 		list_del_init(&sibling->sibling_list);
 		smb_direct_free_sendmsg(sc, sibling);
+		lcredits += 1;
 	}
 	/* Note this frees wc->wr_cqe, but not wc */
 	smb_direct_free_sendmsg(sc, sendmsg);
+	lcredits += 1;
 
 	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
 		pr_err("Send error. status='%s (%d)', opcode=%d\n",
@@ -942,6 +946,9 @@ static void send_done(struct ib_cq *cq, struct ib_wc *wc)
 		return;
 	}
 
+	atomic_add(lcredits, &sc->send_io.lcredits.count);
+	wake_up(&sc->send_io.lcredits.wait_queue);
+
 	if (atomic_dec_and_test(&sc->send_io.pending.count))
 		wake_up(&sc->send_io.pending.zero_wait_queue);
 }
@@ -1081,6 +1088,23 @@ static int wait_for_credits(struct smbdirect_socket *sc,
 	} while (true);
 }
 
+static int wait_for_send_lcredit(struct smbdirect_socket *sc,
+				 struct smbdirect_send_batch *send_ctx)
+{
+	if (send_ctx && (atomic_read(&sc->send_io.lcredits.count) <= 1)) {
+		int ret;
+
+		ret = smb_direct_flush_send_list(sc, send_ctx, false);
+		if (ret)
+			return ret;
+	}
+
+	return wait_for_credits(sc,
+				&sc->send_io.lcredits.wait_queue,
+				&sc->send_io.lcredits.count,
+				1);
+}
+
 static int wait_for_send_credits(struct smbdirect_socket *sc,
 				 struct smbdirect_send_batch *send_ctx)
 {
@@ -1268,9 +1292,13 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
 	int data_length;
 	struct scatterlist sg[SMBDIRECT_SEND_IO_MAX_SGE - 1];
 
+	ret = wait_for_send_lcredit(sc, send_ctx);
+	if (ret)
+		goto lcredit_failed;
+
 	ret = wait_for_send_credits(sc, send_ctx);
 	if (ret)
-		return ret;
+		goto credit_failed;
 
 	data_length = 0;
 	for (i = 0; i < niov; i++)
@@ -1278,10 +1306,8 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
 
 	ret = smb_direct_create_header(sc, data_length, remaining_data_length,
 				       &msg);
-	if (ret) {
-		atomic_inc(&sc->send_io.credits.count);
-		return ret;
-	}
+	if (ret)
+		goto header_failed;
 
 	for (i = 0; i < niov; i++) {
 		struct ib_sge *sge;
@@ -1319,7 +1345,11 @@ static int smb_direct_post_send_data(struct smbdirect_socket *sc,
 	return 0;
 err:
 	smb_direct_free_sendmsg(sc, msg);
+header_failed:
 	atomic_inc(&sc->send_io.credits.count);
+credit_failed:
+	atomic_inc(&sc->send_io.lcredits.count);
+lcredit_failed:
 	return ret;
 }
 
@@ -1897,6 +1927,8 @@ static int smb_direct_init_params(struct smbdirect_socket *sc)
 		return -EINVAL;
 	}
 
+	atomic_set(&sc->send_io.lcredits.count, sp->send_credit_target);
+
 	maxpages = DIV_ROUND_UP(sp->max_read_write_size, PAGE_SIZE);
 	sc->rw_io.credits.max = rdma_rw_mr_factor(sc->ib.dev,
 						  sc->rdma.cm_id->port_num,