git: 85df11a1dec6 - main - ktls: deep copy tls_enable struct for in-kernel tcp consumers

From: Richard Scheffenegger <rscheff_at_FreeBSD.org>
Date: Wed, 13 Mar 2024 21:08:29 UTC
The branch main has been updated by rscheff:

URL: https://cgit.FreeBSD.org/src/commit/?id=85df11a1dec6eab9efbce9fd20712402a8e7ac7c

commit 85df11a1dec6eab9efbce9fd20712402a8e7ac7c
Author:     Richard Scheffenegger <rscheff@FreeBSD.org>
AuthorDate: 2024-03-13 11:35:51 +0000
Commit:     Richard Scheffenegger <rscheff@FreeBSD.org>
CommitDate: 2024-03-13 12:23:13 +0000

    ktls: deep copy tls_enable struct for in-kernel tcp consumers
    
    Doing a deep copy of the keys early allows users of the
    tls_enable structure to assume kernel memory.
    This enables the socket options to be set by kernel threads.
    
    Reviewed By:    #transport, tuexen, jhb, rrs
    Sponsored by:   NetApp, Inc.
    X-NetApp-PR:    #79
    Differential Revision:  https://reviews.freebsd.org/D44250
---
 sys/kern/uipc_ktls.c     | 96 ++++++++++++++++++++++++++++++++++++++++--------
 sys/netinet/tcp_usrreq.c | 44 ++++------------------
 sys/sys/ktls.h           | 17 +++++----
 3 files changed, 97 insertions(+), 60 deletions(-)

diff --git a/sys/kern/uipc_ktls.c b/sys/kern/uipc_ktls.c
index deba6940bbee..df296090ec97 100644
--- a/sys/kern/uipc_ktls.c
+++ b/sys/kern/uipc_ktls.c
@@ -297,10 +297,114 @@ SYSCTL_COUNTER_U64(_kern_ipc_tls_toe, OID_AUTO, chacha20, CTLFLAG_RD,
 
 static MALLOC_DEFINE(M_KTLS, "ktls", "Kernel TLS");
 
+static void ktls_reclaim_thread(void *ctx);
 static void ktls_reset_receive_tag(void *context, int pending);
 static void ktls_reset_send_tag(void *context, int pending);
 static void ktls_work_thread(void *ctx);
-static void ktls_reclaim_thread(void *ctx);
+
+/*
+ * Copy in a struct tls_enable from a socket option and deep-copy the key
+ * and IV buffers it points to into M_KTLS allocations, so that later
+ * consumers only ever see kernel memory.  A struct tls_enable_v0 is
+ * accepted when the option has exactly that size.  A NULL sopt_td marks
+ * a request from a kernel thread; the buffers are then assumed to be
+ * kernel addresses and are copied with bcopy() rather than copyin().
+ *
+ * On success the caller owns the three buffers and must release them
+ * with ktls_cleanup_tls_enable().  On failure nothing is left allocated.
+ */
+int
+ktls_copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
+{
+	struct tls_enable_v0 tls_v0;
+	int error;
+	uint8_t *cipher_key = NULL, *iv = NULL, *auth_key = NULL;
+
+	if (sopt->sopt_valsize == sizeof(tls_v0)) {
+		error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0), sizeof(tls_v0));
+		if (error != 0)
+			goto done;
+		memset(tls, 0, sizeof(*tls));
+		tls->cipher_key = tls_v0.cipher_key;
+		tls->iv = tls_v0.iv;
+		tls->auth_key = tls_v0.auth_key;
+		tls->cipher_algorithm = tls_v0.cipher_algorithm;
+		tls->cipher_key_len = tls_v0.cipher_key_len;
+		tls->iv_len = tls_v0.iv_len;
+		tls->auth_algorithm = tls_v0.auth_algorithm;
+		tls->auth_key_len = tls_v0.auth_key_len;
+		tls->flags = tls_v0.flags;
+		tls->tls_vmajor = tls_v0.tls_vmajor;
+		tls->tls_vminor = tls_v0.tls_vminor;
+	} else
+		error = sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls));
+
+	if (error != 0)
+		goto done;
+
+	/*
+	 * The (signed) lengths come straight from the requester.  Bound them
+	 * before they are used as malloc(9) sizes below: a negative length
+	 * would become an enormous size_t request, and those M_WAITOK
+	 * allocations are not checked for failure.  This is only a sanity
+	 * bound; it does not validate the lengths against the algorithms.
+	 */
+	if (tls->cipher_key_len < 0 || tls->cipher_key_len > TLS_MAX_PARAM_SIZE ||
+	    tls->iv_len < 0 || tls->iv_len > TLS_MAX_PARAM_SIZE ||
+	    tls->auth_key_len < 0 || tls->auth_key_len > TLS_MAX_PARAM_SIZE) {
+		error = EINVAL;
+		goto done;
+	}
+
+	/*
+	 * Now do a deep copy of the variable-length arrays in the struct, so that
+	 * subsequent consumers of it can reliably assume kernel memory. This
+	 * requires doing our own allocations, which we will free in the
+	 * error paths so that our caller need only worry about outstanding
+	 * allocations existing on successful return.
+	 */
+	cipher_key = malloc(tls->cipher_key_len, M_KTLS, M_WAITOK);
+	iv = malloc(tls->iv_len, M_KTLS, M_WAITOK);
+	auth_key = malloc(tls->auth_key_len, M_KTLS, M_WAITOK);
+	if (sopt->sopt_td != NULL) {
+		error = copyin(tls->cipher_key, cipher_key, tls->cipher_key_len);
+		if (error != 0)
+			goto done;
+		error = copyin(tls->iv, iv, tls->iv_len);
+		if (error != 0)
+			goto done;
+		error = copyin(tls->auth_key, auth_key, tls->auth_key_len);
+		if (error != 0)
+			goto done;
+	} else {
+		bcopy(tls->cipher_key, cipher_key, tls->cipher_key_len);
+		bcopy(tls->iv, iv, tls->iv_len);
+		bcopy(tls->auth_key, auth_key, tls->auth_key_len);
+	}
+	tls->cipher_key = cipher_key;
+	tls->iv = iv;
+	tls->auth_key = auth_key;
+
+done:
+	if (error != 0) {
+		zfree(cipher_key, M_KTLS);
+		zfree(iv, M_KTLS);
+		zfree(auth_key, M_KTLS);
+	}
+
+	return (error);
+}
+
+/*
+ * Release the key and IV buffers allocated by ktls_copyin_tls_enable().
+ */
+void
+ktls_cleanup_tls_enable(struct tls_enable *tls)
+{
+	zfree(__DECONST(void *, tls->cipher_key), M_KTLS);
+	zfree(__DECONST(void *, tls->iv), M_KTLS);
+	zfree(__DECONST(void *, tls->auth_key), M_KTLS);
+}
 
 static u_int
 ktls_get_cpu(struct socket *so)
@@ -702,18 +778,12 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
 		tls->params.auth_key_len = en->auth_key_len;
 		tls->params.auth_key = malloc(en->auth_key_len, M_KTLS,
 		    M_WAITOK);
-		error = copyin(en->auth_key, tls->params.auth_key,
-		    en->auth_key_len);
-		if (error)
-			goto out;
+		bcopy(en->auth_key, tls->params.auth_key, en->auth_key_len);
 	}
 
 	tls->params.cipher_key_len = en->cipher_key_len;
 	tls->params.cipher_key = malloc(en->cipher_key_len, M_KTLS, M_WAITOK);
-	error = copyin(en->cipher_key, tls->params.cipher_key,
-	    en->cipher_key_len);
-	if (error)
-		goto out;
+	bcopy(en->cipher_key, tls->params.cipher_key, en->cipher_key_len);
 
 	/*
 	 * This holds the implicit portion of the nonce for AEAD
@@ -722,9 +792,7 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
 	 */
 	if (en->iv_len != 0) {
 		tls->params.iv_len = en->iv_len;
-		error = copyin(en->iv, tls->params.iv, en->iv_len);
-		if (error)
-			goto out;
+		bcopy(en->iv, tls->params.iv, en->iv_len);
 
 		/*
 		 * For TLS 1.2 with GCM, generate an 8-byte nonce as a
@@ -740,10 +808,6 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
 
 	*tlsp = tls;
 	return (0);
-
-out:
-	ktls_free(tls);
-	return (error);
 }
 
 static struct ktls_session *
diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
index a73d2a15c1d5..916fe33e8704 100644
--- a/sys/netinet/tcp_usrreq.c
+++ b/sys/netinet/tcp_usrreq.c
@@ -1914,37 +1914,6 @@ CTASSERT(TCP_CA_NAME_MAX <= TCP_LOG_ID_LEN);
 CTASSERT(TCP_LOG_REASON_LEN <= TCP_LOG_ID_LEN);
 #endif
 
-#ifdef KERN_TLS
-static int
-copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
-{
-	struct tls_enable_v0 tls_v0;
-	int error;
-
-	if (sopt->sopt_valsize == sizeof(tls_v0)) {
-		error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0),
-		    sizeof(tls_v0));
-		if (error)
-			return (error);
-		memset(tls, 0, sizeof(*tls));
-		tls->cipher_key = tls_v0.cipher_key;
-		tls->iv = tls_v0.iv;
-		tls->auth_key = tls_v0.auth_key;
-		tls->cipher_algorithm = tls_v0.cipher_algorithm;
-		tls->cipher_key_len = tls_v0.cipher_key_len;
-		tls->iv_len = tls_v0.iv_len;
-		tls->auth_algorithm = tls_v0.auth_algorithm;
-		tls->auth_key_len = tls_v0.auth_key_len;
-		tls->flags = tls_v0.flags;
-		tls->tls_vmajor = tls_v0.tls_vmajor;
-		tls->tls_vminor = tls_v0.tls_vminor;
-		return (0);
-	}
-
-	return (sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls)));
-}
-#endif
-
 extern struct cc_algo newreno_cc_algo;
 
 static int
@@ -2292,15 +2261,16 @@ unlock_and_done:
 #ifdef KERN_TLS
 		case TCP_TXTLS_ENABLE:
 			INP_WUNLOCK(inp);
-			error = copyin_tls_enable(sopt, &tls);
-			if (error)
+			error = ktls_copyin_tls_enable(sopt, &tls);
+			if (error != 0)
 				break;
 			error = ktls_enable_tx(so, &tls);
+			ktls_cleanup_tls_enable(&tls);
 			break;
 		case TCP_TXTLS_MODE:
 			INP_WUNLOCK(inp);
 			error = sooptcopyin(sopt, &ui, sizeof(ui), sizeof(ui));
-			if (error)
+			if (error != 0)
 				return (error);
 
 			INP_WLOCK_RECHECK(inp);
@@ -2309,11 +2279,11 @@ unlock_and_done:
 			break;
 		case TCP_RXTLS_ENABLE:
 			INP_WUNLOCK(inp);
-			error = sooptcopyin(sopt, &tls, sizeof(tls),
-			    sizeof(tls));
-			if (error)
+			error = ktls_copyin_tls_enable(sopt, &tls);
+			if (error != 0)
 				break;
 			error = ktls_enable_rx(so, &tls);
+			ktls_cleanup_tls_enable(&tls);
 			break;
 #endif
 		case TCP_MAXUNACKTIME:
diff --git a/sys/sys/ktls.h b/sys/sys/ktls.h
index 693864394ffe..9b3433f4b1fd 100644
--- a/sys/sys/ktls.h
+++ b/sys/sys/ktls.h
@@ -174,6 +174,7 @@ struct m_snd_tag;
 struct mbuf;
 struct sockbuf;
 struct socket;
+struct sockopt;
 
 struct ktls_session {
 	struct ktls_ocf_session *ocf_session;
@@ -213,27 +214,29 @@ typedef enum {
 } ktls_mbuf_crypto_st_t;
 
 void ktls_check_rx(struct sockbuf *sb);
-ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int len);
+void ktls_cleanup_tls_enable(struct tls_enable *tls);
+int ktls_copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls);
 void ktls_disable_ifnet(void *arg);
 int ktls_enable_rx(struct socket *so, struct tls_enable *en);
 int ktls_enable_tx(struct socket *so, struct tls_enable *en);
+void ktls_enqueue(struct mbuf *m, struct socket *so, int page_count);
+void ktls_enqueue_to_free(struct mbuf *m);
 void ktls_destroy(struct ktls_session *tls);
 void ktls_frame(struct mbuf *m, struct ktls_session *tls, int *enqueue_cnt,
     uint8_t record_type);
-bool ktls_permit_empty_frames(struct ktls_session *tls);
-void ktls_seq(struct sockbuf *sb, struct mbuf *m);
-void ktls_enqueue(struct mbuf *m, struct socket *so, int page_count);
-void ktls_enqueue_to_free(struct mbuf *m);
 int ktls_get_rx_mode(struct socket *so, int *modep);
-int ktls_set_tx_mode(struct socket *so, int mode);
 int ktls_get_tx_mode(struct socket *so, int *modep);
 int ktls_get_rx_sequence(struct inpcb *inp, uint32_t *tcpseq, uint64_t *tlsseq);
 void ktls_input_ifp_mismatch(struct sockbuf *sb, struct ifnet *ifp);
-int ktls_output_eagain(struct inpcb *inp, struct ktls_session *tls);
+ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int len);
 #ifdef RATELIMIT
 int ktls_modify_txrtlmt(struct ktls_session *tls, uint64_t max_pacing_rate);
 #endif
+int ktls_output_eagain(struct inpcb *inp, struct ktls_session *tls);
 bool ktls_pending_rx_info(struct sockbuf *sb, uint64_t *seqnop, size_t *residp);
+bool ktls_permit_empty_frames(struct ktls_session *tls);
+void ktls_seq(struct sockbuf *sb, struct mbuf *m);
+int ktls_set_tx_mode(struct socket *so, int mode);
 
 static inline struct ktls_session *
 ktls_hold(struct ktls_session *tls)