Re: git: 85df11a1dec6 - main - ktls: deep copy tls_enable struct for in-kernel tcp consumers

From: Mark Johnston <markj_at_freebsd.org>
Date: Fri, 15 Mar 2024 11:02:15 UTC
On Wed, Mar 13, 2024 at 09:08:29PM +0000, Richard Scheffenegger wrote:
> The branch main has been updated by rscheff:
> 
> URL: https://cgit.FreeBSD.org/src/commit/?id=85df11a1dec6eab9efbce9fd20712402a8e7ac7c
> 
> commit 85df11a1dec6eab9efbce9fd20712402a8e7ac7c
> Author:     Richard Scheffenegger <rscheff@FreeBSD.org>
> AuthorDate: 2024-03-13 11:35:51 +0000
> Commit:     Richard Scheffenegger <rscheff@FreeBSD.org>
> CommitDate: 2024-03-13 12:23:13 +0000
> 
>     ktls: deep copy tls_enable struct for in-kernel tcp consumers
>     
>     Doing a deep copy of the keys early allows users of the
>     tls_enable structure to assume kernel memory.
>     This enables the socket options to be set by kernel threads.
>     
>     Reviewed By:    #transport, tuexen, jhb, rrs
>     Sponsored by:   NetApp, Inc.
>     X-NetApp-PR:    #79
>     Differential Revision:  https://reviews.freebsd.org/D44250
> ---
>  sys/kern/uipc_ktls.c     | 96 ++++++++++++++++++++++++++++++++++++++++--------
>  sys/netinet/tcp_usrreq.c | 44 ++++------------------
>  sys/sys/ktls.h           | 17 +++++----
>  3 files changed, 97 insertions(+), 60 deletions(-)
> 
> diff --git a/sys/kern/uipc_ktls.c b/sys/kern/uipc_ktls.c
> index deba6940bbee..df296090ec97 100644
> --- a/sys/kern/uipc_ktls.c
> +++ b/sys/kern/uipc_ktls.c
> @@ -297,10 +297,86 @@ SYSCTL_COUNTER_U64(_kern_ipc_tls_toe, OID_AUTO, chacha20, CTLFLAG_RD,
>  
>  static MALLOC_DEFINE(M_KTLS, "ktls", "Kernel TLS");
>  
> +static void ktls_reclaim_thread(void *ctx);
>  static void ktls_reset_receive_tag(void *context, int pending);
>  static void ktls_reset_send_tag(void *context, int pending);
>  static void ktls_work_thread(void *ctx);
> -static void ktls_reclaim_thread(void *ctx);
> +
> +int
> +ktls_copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
> +{
> +	struct tls_enable_v0 tls_v0;
> +	int error;
> +	uint8_t *cipher_key = NULL, *iv = NULL, *auth_key = NULL;
> +
> +	if (sopt->sopt_valsize == sizeof(tls_v0)) {
> +		error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0), sizeof(tls_v0));
> +		if (error != 0)
> +			goto done;
> +		memset(tls, 0, sizeof(*tls));
> +		tls->cipher_key = tls_v0.cipher_key;
> +		tls->iv = tls_v0.iv;
> +		tls->auth_key = tls_v0.auth_key;
> +		tls->cipher_algorithm = tls_v0.cipher_algorithm;
> +		tls->cipher_key_len = tls_v0.cipher_key_len;
> +		tls->iv_len = tls_v0.iv_len;
> +		tls->auth_algorithm = tls_v0.auth_algorithm;
> +		tls->auth_key_len = tls_v0.auth_key_len;
> +		tls->flags = tls_v0.flags;
> +		tls->tls_vmajor = tls_v0.tls_vmajor;
> +		tls->tls_vminor = tls_v0.tls_vminor;
> +	} else
> +		error = sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls));
> +
> +	if (error != 0)
> +		goto done;
> +
> +	/*
> +	 * Now do a deep copy of the variable-length arrays in the struct, so that
> +	 * subsequent consumers of it can reliably assume kernel memory. This
> +	 * requires doing our own allocations, which we will free in the
> +	 * error paths so that our caller need only worry about outstanding
> +	 * allocations existing on successful return.
> +	 */
> +	cipher_key = malloc(tls->cipher_key_len, M_KTLS, M_WAITOK);
> +	iv = malloc(tls->iv_len, M_KTLS, M_WAITOK);
> +	auth_key = malloc(tls->auth_key_len, M_KTLS, M_WAITOK);

Hi Richard,

These lengths need to be validated against minimum and maximum bounds
before they are passed to malloc(M_WAITOK) above. They are provided by
userspace and thus can't be trusted.
See https://syzkaller.appspot.com/bug?extid=72022fa9163fa958b66c

> +	if (sopt->sopt_td != NULL) {
> +		error = copyin(tls->cipher_key, cipher_key, tls->cipher_key_len);
> +		if (error != 0)
> +			goto done;
> +		error = copyin(tls->iv, iv, tls->iv_len);
> +		if (error != 0)
> +			goto done;
> +		error = copyin(tls->auth_key, auth_key, tls->auth_key_len);
> +		if (error != 0)
> +			goto done;
> +	} else {
> +		bcopy(tls->cipher_key, cipher_key, tls->cipher_key_len);
> +		bcopy(tls->iv, iv, tls->iv_len);
> +		bcopy(tls->auth_key, auth_key, tls->auth_key_len);
> +	}
> +	tls->cipher_key = cipher_key;
> +	tls->iv = iv;
> +	tls->auth_key = auth_key;
> +
> +done:
> +	if (error != 0) {
> +		zfree(cipher_key, M_KTLS);
> +		zfree(iv, M_KTLS);
> +		zfree(auth_key, M_KTLS);
> +	}
> +
> +	return (error);
> +}
> +
> +void
> +ktls_cleanup_tls_enable(struct tls_enable *tls)
> +{
> +	zfree(__DECONST(void *, tls->cipher_key), M_KTLS);
> +	zfree(__DECONST(void *, tls->iv), M_KTLS);
> +	zfree(__DECONST(void *, tls->auth_key), M_KTLS);
> +}
>  
>  static u_int
>  ktls_get_cpu(struct socket *so)
> @@ -702,18 +778,12 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
>  		tls->params.auth_key_len = en->auth_key_len;
>  		tls->params.auth_key = malloc(en->auth_key_len, M_KTLS,
>  		    M_WAITOK);
> -		error = copyin(en->auth_key, tls->params.auth_key,
> -		    en->auth_key_len);
> -		if (error)
> -			goto out;
> +		bcopy(en->auth_key, tls->params.auth_key, en->auth_key_len);
>  	}
>  
>  	tls->params.cipher_key_len = en->cipher_key_len;
>  	tls->params.cipher_key = malloc(en->cipher_key_len, M_KTLS, M_WAITOK);
> -	error = copyin(en->cipher_key, tls->params.cipher_key,
> -	    en->cipher_key_len);
> -	if (error)
> -		goto out;
> +	bcopy(en->cipher_key, tls->params.cipher_key, en->cipher_key_len);
>  
>  	/*
>  	 * This holds the implicit portion of the nonce for AEAD
> @@ -722,9 +792,7 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
>  	 */
>  	if (en->iv_len != 0) {
>  		tls->params.iv_len = en->iv_len;
> -		error = copyin(en->iv, tls->params.iv, en->iv_len);
> -		if (error)
> -			goto out;
> +		bcopy(en->iv, tls->params.iv, en->iv_len);
>  
>  		/*
>  		 * For TLS 1.2 with GCM, generate an 8-byte nonce as a
> @@ -740,10 +808,6 @@ ktls_create_session(struct socket *so, struct tls_enable *en,
>  
>  	*tlsp = tls;
>  	return (0);
> -
> -out:
> -	ktls_free(tls);
> -	return (error);
>  }
>  
>  static struct ktls_session *
> diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
> index a73d2a15c1d5..916fe33e8704 100644
> --- a/sys/netinet/tcp_usrreq.c
> +++ b/sys/netinet/tcp_usrreq.c
> @@ -1914,37 +1914,6 @@ CTASSERT(TCP_CA_NAME_MAX <= TCP_LOG_ID_LEN);
>  CTASSERT(TCP_LOG_REASON_LEN <= TCP_LOG_ID_LEN);
>  #endif
>  
> -#ifdef KERN_TLS
> -static int
> -copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
> -{
> -	struct tls_enable_v0 tls_v0;
> -	int error;
> -
> -	if (sopt->sopt_valsize == sizeof(tls_v0)) {
> -		error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0),
> -		    sizeof(tls_v0));
> -		if (error)
> -			return (error);
> -		memset(tls, 0, sizeof(*tls));
> -		tls->cipher_key = tls_v0.cipher_key;
> -		tls->iv = tls_v0.iv;
> -		tls->auth_key = tls_v0.auth_key;
> -		tls->cipher_algorithm = tls_v0.cipher_algorithm;
> -		tls->cipher_key_len = tls_v0.cipher_key_len;
> -		tls->iv_len = tls_v0.iv_len;
> -		tls->auth_algorithm = tls_v0.auth_algorithm;
> -		tls->auth_key_len = tls_v0.auth_key_len;
> -		tls->flags = tls_v0.flags;
> -		tls->tls_vmajor = tls_v0.tls_vmajor;
> -		tls->tls_vminor = tls_v0.tls_vminor;
> -		return (0);
> -	}
> -
> -	return (sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls)));
> -}
> -#endif
> -
>  extern struct cc_algo newreno_cc_algo;
>  
>  static int
> @@ -2292,15 +2261,16 @@ unlock_and_done:
>  #ifdef KERN_TLS
>  		case TCP_TXTLS_ENABLE:
>  			INP_WUNLOCK(inp);
> -			error = copyin_tls_enable(sopt, &tls);
> -			if (error)
> +			error = ktls_copyin_tls_enable(sopt, &tls);
> +			if (error != 0)
>  				break;
>  			error = ktls_enable_tx(so, &tls);
> +			ktls_cleanup_tls_enable(&tls);
>  			break;
>  		case TCP_TXTLS_MODE:
>  			INP_WUNLOCK(inp);
>  			error = sooptcopyin(sopt, &ui, sizeof(ui), sizeof(ui));
> -			if (error)
> +			if (error != 0)
>  				return (error);
>  
>  			INP_WLOCK_RECHECK(inp);
> @@ -2309,11 +2279,11 @@ unlock_and_done:
>  			break;
>  		case TCP_RXTLS_ENABLE:
>  			INP_WUNLOCK(inp);
> -			error = sooptcopyin(sopt, &tls, sizeof(tls),
> -			    sizeof(tls));
> -			if (error)
> +			error = ktls_copyin_tls_enable(sopt, &tls);
> +			if (error != 0)
>  				break;
>  			error = ktls_enable_rx(so, &tls);
> +			ktls_cleanup_tls_enable(&tls);
>  			break;
>  #endif
>  		case TCP_MAXUNACKTIME:
> diff --git a/sys/sys/ktls.h b/sys/sys/ktls.h
> index 693864394ffe..9b3433f4b1fd 100644
> --- a/sys/sys/ktls.h
> +++ b/sys/sys/ktls.h
> @@ -174,6 +174,7 @@ struct m_snd_tag;
>  struct mbuf;
>  struct sockbuf;
>  struct socket;
> +struct sockopt;
>  
>  struct ktls_session {
>  	struct ktls_ocf_session *ocf_session;
> @@ -213,27 +214,29 @@ typedef enum {
>  } ktls_mbuf_crypto_st_t;
>  
>  void ktls_check_rx(struct sockbuf *sb);
> -ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int len);
> +void ktls_cleanup_tls_enable(struct tls_enable *tls);
> +int ktls_copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls);
>  void ktls_disable_ifnet(void *arg);
>  int ktls_enable_rx(struct socket *so, struct tls_enable *en);
>  int ktls_enable_tx(struct socket *so, struct tls_enable *en);
> +void ktls_enqueue(struct mbuf *m, struct socket *so, int page_count);
> +void ktls_enqueue_to_free(struct mbuf *m);
>  void ktls_destroy(struct ktls_session *tls);
>  void ktls_frame(struct mbuf *m, struct ktls_session *tls, int *enqueue_cnt,
>      uint8_t record_type);
> -bool ktls_permit_empty_frames(struct ktls_session *tls);
> -void ktls_seq(struct sockbuf *sb, struct mbuf *m);
> -void ktls_enqueue(struct mbuf *m, struct socket *so, int page_count);
> -void ktls_enqueue_to_free(struct mbuf *m);
>  int ktls_get_rx_mode(struct socket *so, int *modep);
> -int ktls_set_tx_mode(struct socket *so, int mode);
>  int ktls_get_tx_mode(struct socket *so, int *modep);
>  int ktls_get_rx_sequence(struct inpcb *inp, uint32_t *tcpseq, uint64_t *tlsseq);
>  void ktls_input_ifp_mismatch(struct sockbuf *sb, struct ifnet *ifp);
> -int ktls_output_eagain(struct inpcb *inp, struct ktls_session *tls);
> +ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int len);
>  #ifdef RATELIMIT
>  int ktls_modify_txrtlmt(struct ktls_session *tls, uint64_t max_pacing_rate);
>  #endif
> +int ktls_output_eagain(struct inpcb *inp, struct ktls_session *tls);
>  bool ktls_pending_rx_info(struct sockbuf *sb, uint64_t *seqnop, size_t *residp);
> +bool ktls_permit_empty_frames(struct ktls_session *tls);
> +void ktls_seq(struct sockbuf *sb, struct mbuf *m);
> +int ktls_set_tx_mode(struct socket *so, int mode);
>  
>  static inline struct ktls_session *
>  ktls_hold(struct ktls_session *tls)
>