git: 3b3c08c13586 - main - tcp: cleanup functions related to socket option handling

From: Michael Tuexen <tuexen_at_FreeBSD.org>
Date: Thu, 03 Feb 2022 13:46:46 UTC
The branch main has been updated by tuexen:

URL: https://cgit.FreeBSD.org/src/commit/?id=3b3c08c13586e23e1625425a60eaee79a3aed590

commit 3b3c08c13586e23e1625425a60eaee79a3aed590
Author:     Michael Tuexen <tuexen@FreeBSD.org>
AuthorDate: 2022-02-02 08:20:43 +0000
Commit:     Michael Tuexen <tuexen@FreeBSD.org>
CommitDate: 2022-02-02 08:27:59 +0000

    tcp: cleanup functions related to socket option handling
    
    Consistently only pass the inp and the sopt around. Don't pass the
    so around, since in an upcoming commit tcp_ctloutput_set() will be
    called from a context different from setsockopt(). Also expect
    the inp to be locked when calling tcp_ctloutput_[gs]et(); this is
    also required for the upcoming use by tcpsso, a command line tool
    to set socket options.
    Reviewed by:            glebius, rscheff
    Sponsored by:           Netflix, Inc.
    Differential Revision:  https://reviews.freebsd.org/D34151
---
 sys/netinet/tcp_stacks/bbr.c  | 46 +++++++++++++++---------------
 sys/netinet/tcp_stacks/rack.c | 65 ++++++++++++++++++++++++-------------------
 sys/netinet/tcp_usrreq.c      | 52 +++++++++++++++++++---------------
 sys/netinet/tcp_var.h         |  5 ++--
 4 files changed, 91 insertions(+), 77 deletions(-)

diff --git a/sys/netinet/tcp_stacks/bbr.c b/sys/netinet/tcp_stacks/bbr.c
index 1ddcd18be8c6..c5cf8a46880f 100644
--- a/sys/netinet/tcp_stacks/bbr.c
+++ b/sys/netinet/tcp_stacks/bbr.c
@@ -519,8 +519,7 @@ bbr_log_pacing_delay_calc(struct tcp_bbr *bbr, uint16_t gain, uint32_t len,
     uint32_t cts, uint32_t usecs, uint64_t bw, uint32_t override, int mod);
 
 static int
-bbr_ctloutput(struct socket *so, struct sockopt *sopt, struct inpcb *inp,
-    struct tcpcb *tp);
+bbr_ctloutput(struct inpcb *inp, struct sockopt *sopt);
 
 static inline uint8_t
 bbr_state_val(struct tcp_bbr *bbr)
@@ -14235,16 +14234,17 @@ struct tcp_function_block __tcp_bbr = {
  * option.
  */
 static int
-bbr_set_sockopt(struct socket *so, struct sockopt *sopt,
-		struct inpcb *inp, struct tcpcb *tp, struct tcp_bbr *bbr)
+bbr_set_sockopt(struct inpcb *inp, struct sockopt *sopt)
 {
 	struct epoch_tracker et;
+	struct tcpcb *tp;
+	struct tcp_bbr *bbr;
 	int32_t error = 0, optval;
 
 	switch (sopt->sopt_level) {
 	case IPPROTO_IPV6:
 	case IPPROTO_IP:
-		return (tcp_default_ctloutput(so, sopt, inp, tp));
+		return (tcp_default_ctloutput(inp, sopt));
 	}
 
 	switch (sopt->sopt_name) {
@@ -14293,7 +14293,7 @@ bbr_set_sockopt(struct socket *so, struct sockopt *sopt,
 	case TCP_BBR_RETRAN_WTSO:
 		break;
 	default:
-		return (tcp_default_ctloutput(so, sopt, inp, tp));
+		return (tcp_default_ctloutput(inp, sopt));
 		break;
 	}
 	INP_WUNLOCK(inp);
@@ -14629,7 +14629,7 @@ bbr_set_sockopt(struct socket *so, struct sockopt *sopt,
 		}
 		break;
 	default:
-		return (tcp_default_ctloutput(so, sopt, inp, tp));
+		return (tcp_default_ctloutput(inp, sopt));
 		break;
 	}
 #ifdef NETFLIX_STATS
@@ -14643,11 +14643,18 @@ bbr_set_sockopt(struct socket *so, struct sockopt *sopt,
  * return 0 on success, error-num on failure
  */
 static int
-bbr_get_sockopt(struct socket *so, struct sockopt *sopt,
-    struct inpcb *inp, struct tcpcb *tp, struct tcp_bbr *bbr)
+bbr_get_sockopt(struct inpcb *inp, struct sockopt *sopt)
 {
+	struct tcpcb *tp;
+	struct tcp_bbr *bbr;
 	int32_t error, optval;
 
+	tp = intotcpcb(inp);
+	bbr = (struct tcp_bbr *)tp->t_fb_ptr;
+	if (bbr == NULL) {
+		INP_WUNLOCK(inp);
+		return (EINVAL);
+	}
 	/*
 	 * Because all our options are either boolean or an int, we can just
 	 * pull everything into optval and then unlock and copy. If we ever
@@ -14781,7 +14788,7 @@ bbr_get_sockopt(struct socket *so, struct sockopt *sopt,
 			optval |= BBR_INCL_ENET_OH;
 		break;
 	default:
-		return (tcp_default_ctloutput(so, sopt, inp, tp));
+		return (tcp_default_ctloutput(inp, sopt));
 		break;
 	}
 	INP_WUNLOCK(inp);
@@ -14793,24 +14800,15 @@ bbr_get_sockopt(struct socket *so, struct sockopt *sopt,
  * return 0 on success, error-num on failure
  */
 static int
-bbr_ctloutput(struct socket *so, struct sockopt *sopt, struct inpcb *inp, struct tcpcb *tp)
+bbr_ctloutput(struct inpcb *inp, struct sockopt *sopt)
 {
-	int32_t error = EINVAL;
-	struct tcp_bbr *bbr;
-
-	bbr = (struct tcp_bbr *)tp->t_fb_ptr;
-	if (bbr == NULL) {
-		/* Huh? */
-		goto out;
-	}
 	if (sopt->sopt_dir == SOPT_SET) {
-		return (bbr_set_sockopt(so, sopt, inp, tp, bbr));
+		return (bbr_set_sockopt(inp, sopt));
 	} else if (sopt->sopt_dir == SOPT_GET) {
-		return (bbr_get_sockopt(so, sopt, inp, tp, bbr));
+		return (bbr_get_sockopt(inp, sopt));
+	} else {
+		panic("%s: sopt_dir $%d", __func__, sopt->sopt_dir);
 	}
-out:
-	INP_WUNLOCK(inp);
-	return (error);
 }
 
 static const char *bbr_stack_names[] = {
diff --git a/sys/netinet/tcp_stacks/rack.c b/sys/netinet/tcp_stacks/rack.c
index 63e9d5220b19..04174b8a76e8 100644
--- a/sys/netinet/tcp_stacks/rack.c
+++ b/sys/netinet/tcp_stacks/rack.c
@@ -449,8 +449,7 @@ rack_cong_signal(struct tcpcb *tp,
 		 uint32_t type, uint32_t ack);
 static void rack_counter_destroy(void);
 static int
-rack_ctloutput(struct socket *so, struct sockopt *sopt,
-    struct inpcb *inp, struct tcpcb *tp);
+rack_ctloutput(struct inpcb *inp, struct sockopt *sopt);
 static int32_t rack_ctor(void *mem, int32_t size, void *arg, int32_t how);
 static void
 rack_set_pace_segments(struct tcpcb *tp, struct tcp_rack *rack, uint32_t line, uint64_t *fill_override);
@@ -477,8 +476,7 @@ static struct rack_sendmap *rack_find_lowest_rsm(struct tcp_rack *rack);
 static void rack_free(struct tcp_rack *rack, struct rack_sendmap *rsm);
 static void rack_fini(struct tcpcb *tp, int32_t tcb_is_purged);
 static int
-rack_get_sockopt(struct socket *so, struct sockopt *sopt,
-    struct inpcb *inp, struct tcpcb *tp, struct tcp_rack *rack);
+rack_get_sockopt(struct sockopt *sopt, struct inpcb *inp);
 static void
 rack_do_goodput_measurement(struct tcpcb *tp, struct tcp_rack *rack,
 			    tcp_seq th_ack, int line, uint8_t quality);
@@ -508,8 +506,7 @@ rack_proc_sack_blk(struct tcpcb *tp, struct tcp_rack *rack,
 static void rack_post_recovery(struct tcpcb *tp, uint32_t th_seq);
 static void rack_remxt_tmr(struct tcpcb *tp);
 static int
-rack_set_sockopt(struct socket *so, struct sockopt *sopt,
-    struct inpcb *inp, struct tcpcb *tp, struct tcp_rack *rack);
+rack_set_sockopt(struct inpcb *inp, struct sockopt *sopt);
 static void rack_set_state(struct tcpcb *tp, struct tcp_rack *rack);
 static int32_t rack_stopall(struct tcpcb *tp);
 static void
@@ -20437,18 +20434,32 @@ static struct tcp_function_block __tcp_rack = {
  * option.
  */
 static int
-rack_set_sockopt(struct socket *so, struct sockopt *sopt,
-    struct inpcb *inp, struct tcpcb *tp, struct tcp_rack *rack)
+rack_set_sockopt(struct inpcb *inp, struct sockopt *sopt)
 {
 #ifdef INET6
-	struct ip6_hdr *ip6 = (struct ip6_hdr *)rack->r_ctl.fsb.tcp_ip_hdr;
+	struct ip6_hdr *ip6;
 #endif
 #ifdef INET
-	struct ip *ip = (struct ip *)rack->r_ctl.fsb.tcp_ip_hdr;
+	struct ip *ip;
 #endif
+	struct tcpcb *tp;
+	struct tcp_rack *rack;
 	uint64_t loptval;
 	int32_t error = 0, optval;
 
+	tp = intotcpcb(inp);
+	rack = (struct tcp_rack *)tp->t_fb_ptr;
+	if (rack == NULL) {
+		INP_WUNLOCK(inp);
+		return (EINVAL);
+	}
+#ifdef INET6
+	ip6 = (struct ip6_hdr *)rack->r_ctl.fsb.tcp_ip_hdr;
+#endif
+#ifdef INET
+	ip = (struct ip *)rack->r_ctl.fsb.tcp_ip_hdr;
+#endif
+
 	switch (sopt->sopt_level) {
 #ifdef INET6
 	case IPPROTO_IPV6:
@@ -20545,7 +20556,7 @@ rack_set_sockopt(struct socket *so, struct sockopt *sopt,
 		break;
 	default:
 		/* Filter off all unknown options to the base stack */
-		return (tcp_default_ctloutput(so, sopt, inp, tp));
+		return (tcp_default_ctloutput(inp, sopt));
 		break;
 	}
 	INP_WUNLOCK(inp);
@@ -20648,9 +20659,10 @@ rack_fill_info(struct tcpcb *tp, struct tcp_info *ti)
 }
 
 static int
-rack_get_sockopt(struct socket *so, struct sockopt *sopt,
-    struct inpcb *inp, struct tcpcb *tp, struct tcp_rack *rack)
+rack_get_sockopt(struct inpcb *inp, struct sockopt *sopt)
 {
+	struct tcpcb *tp;
+	struct tcp_rack *rack;
 	int32_t error, optval;
 	uint64_t val, loptval;
 	struct	tcp_info ti;
@@ -20661,6 +20673,12 @@ rack_get_sockopt(struct socket *so, struct sockopt *sopt,
 	 * impact to this routine.
 	 */
 	error = 0;
+	tp = intotcpcb(inp);
+	rack = (struct tcp_rack *)tp->t_fb_ptr;
+	if (rack == NULL) {
+		INP_WUNLOCK(inp);
+		return (EINVAL);
+	}
 	switch (sopt->sopt_name) {
 	case TCP_INFO:
 		/* First get the info filled */
@@ -20901,7 +20919,7 @@ rack_get_sockopt(struct socket *so, struct sockopt *sopt,
 		optval = rack->r_ctl.timer_slop;
 		break;
 	default:
-		return (tcp_default_ctloutput(so, sopt, inp, tp));
+		return (tcp_default_ctloutput(inp, sopt));
 		break;
 	}
 	INP_WUNLOCK(inp);
@@ -20915,24 +20933,15 @@ rack_get_sockopt(struct socket *so, struct sockopt *sopt,
 }
 
 static int
-rack_ctloutput(struct socket *so, struct sockopt *sopt, struct inpcb *inp, struct tcpcb *tp)
+rack_ctloutput(struct inpcb *inp, struct sockopt *sopt)
 {
-	int32_t error = EINVAL;
-	struct tcp_rack *rack;
-
-	rack = (struct tcp_rack *)tp->t_fb_ptr;
-	if (rack == NULL) {
-		/* Huh? */
-		goto out;
-	}
 	if (sopt->sopt_dir == SOPT_SET) {
-		return (rack_set_sockopt(so, sopt, inp, tp, rack));
+		return (rack_set_sockopt(inp, sopt));
 	} else if (sopt->sopt_dir == SOPT_GET) {
-		return (rack_get_sockopt(so, sopt, inp, tp, rack));
+		return (rack_get_sockopt(inp, sopt));
+	} else {
+		panic("%s: sopt_dir $%d", __func__, sopt->sopt_dir);
 	}
-out:
-	INP_WUNLOCK(inp);
-	return (error);
 }
 
 static const char *rack_stack_names[] = {
diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
index db3f85b43acc..f2652811b86a 100644
--- a/sys/netinet/tcp_usrreq.c
+++ b/sys/netinet/tcp_usrreq.c
@@ -1719,8 +1719,10 @@ tcp_ctloutput_set(struct inpcb *inp, struct sockopt *sopt)
 	int error = 0;
 
 	MPASS(sopt->sopt_dir == SOPT_SET);
+	INP_WLOCK_ASSERT(inp);
 
 	if (sopt->sopt_level != IPPROTO_TCP) {
+		INP_WUNLOCK(inp);
 #ifdef INET6
 		if (inp->inp_vflag & INP_IPV6PROTO)
 			error = ip6_ctloutput(inp->inp_socket, sopt);
@@ -1768,6 +1770,11 @@ tcp_ctloutput_set(struct inpcb *inp, struct sockopt *sopt)
 		default:
 			return (error);
 		}
+		INP_WLOCK(inp);
+		if (inp->inp_flags & (INP_TIMEWAIT | INP_DROPPED)) {
+			INP_WUNLOCK(inp);
+			return (ECONNRESET);
+		}
 	} else if (sopt->sopt_name == TCP_FUNCTION_BLK) {
 		/*
 		 * Protect the TCP option TCP_FUNCTION_BLK so
@@ -1776,6 +1783,7 @@ tcp_ctloutput_set(struct inpcb *inp, struct sockopt *sopt)
 		struct tcp_function_set fsn;
 		struct tcp_function_block *blk;
 
+		INP_WUNLOCK(inp);
 		error = sooptcopyin(sopt, &fsn, sizeof fsn, sizeof fsn);
 		if (error)
 			return (error);
@@ -1871,15 +1879,10 @@ err_out:
 		return (error);
 	}
 
-	INP_WLOCK(inp);
-	if (inp->inp_flags & (INP_TIMEWAIT | INP_DROPPED)) {
-		INP_WUNLOCK(inp);
-		return (ECONNRESET);
-	}
 	tp = intotcpcb(inp);
 
-	/* Pass in the INP locked, caller must unlock it. */
-	return (tp->t_fb->tfb_tcp_ctloutput(inp->inp_socket, sopt, inp, tp));
+	/* Pass in the INP locked, callee must unlock it. */
+	return (tp->t_fb->tfb_tcp_ctloutput(inp, sopt));
 }
 
 static int
@@ -1889,8 +1892,10 @@ tcp_ctloutput_get(struct inpcb *inp, struct sockopt *sopt)
 	struct	tcpcb *tp;
 
 	MPASS(sopt->sopt_dir == SOPT_GET);
+	INP_WLOCK_ASSERT(inp);
 
 	if (sopt->sopt_level != IPPROTO_TCP) {
+		INP_WUNLOCK(inp);
 #ifdef INET6
 		if (inp->inp_vflag & INP_IPV6PROTO)
 			error = ip6_ctloutput(inp->inp_socket, sopt);
@@ -1903,11 +1908,6 @@ tcp_ctloutput_get(struct inpcb *inp, struct sockopt *sopt)
 #endif
 		return (error);
 	}
-	INP_WLOCK(inp);
-	if (inp->inp_flags & (INP_TIMEWAIT | INP_DROPPED)) {
-		INP_WUNLOCK(inp);
-		return (ECONNRESET);
-	}
 	tp = intotcpcb(inp);
 	if (((sopt->sopt_name == TCP_FUNCTION_BLK) ||
 	     (sopt->sopt_name == TCP_FUNCTION_ALIAS))) {
@@ -1928,8 +1928,8 @@ tcp_ctloutput_get(struct inpcb *inp, struct sockopt *sopt)
 		return (error);
 	}
 
-	/* Pass in the INP locked, caller must unlock it. */
-	return (tp->t_fb->tfb_tcp_ctloutput(inp->inp_socket, sopt, inp, tp));
+	/* Pass in the INP locked, callee must unlock it. */
+	return (tp->t_fb->tfb_tcp_ctloutput(inp, sopt));
 }
 
 int
@@ -1940,6 +1940,11 @@ tcp_ctloutput(struct socket *so, struct sockopt *sopt)
 	inp = sotoinpcb(so);
 	KASSERT(inp != NULL, ("tcp_ctloutput: inp == NULL"));
 
+	INP_WLOCK(inp);
+	if (inp->inp_flags & (INP_TIMEWAIT | INP_DROPPED)) {
+		INP_WUNLOCK(inp);
+		return (ECONNRESET);
+	}
 	if (sopt->sopt_dir == SOPT_SET)
 		return (tcp_ctloutput_set(inp, sopt));
 	else if (sopt->sopt_dir == SOPT_GET)
@@ -1991,10 +1996,11 @@ copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
 extern struct cc_algo newreno_cc_algo;
 
 static int
-tcp_congestion(struct socket *so, struct sockopt *sopt, struct inpcb *inp, struct tcpcb *tp)
+tcp_congestion(struct inpcb *inp, struct sockopt *sopt)
 {
 	struct cc_algo *algo;
 	void *ptr = NULL;
+	struct tcpcb *tp;
 	struct cc_var cc_mem;
 	char	buf[TCP_CA_NAME_MAX];
 	size_t mem_sz;
@@ -2103,8 +2109,9 @@ no_mem_needed:
 }
 
 int
-tcp_default_ctloutput(struct socket *so, struct sockopt *sopt, struct inpcb *inp, struct tcpcb *tp)
+tcp_default_ctloutput(struct inpcb *inp, struct sockopt *sopt)
 {
+	struct tcpcb *tp;
 	int	error, opt, optval;
 	u_int	ui;
 	struct	tcp_info ti;
@@ -2119,6 +2126,7 @@ tcp_default_ctloutput(struct socket *so, struct sockopt *sopt, struct inpcb *inp
 
 	INP_WLOCK_ASSERT(inp);
 
+	tp = intotcpcb(inp);
 	switch (sopt->sopt_level) {
 #ifdef INET6
 	case IPPROTO_IPV6:
@@ -2317,7 +2325,7 @@ unlock_and_done:
 			break;
 
 		case TCP_CONGESTION:
-			error = tcp_congestion(so, sopt, inp, tp);
+			error = tcp_congestion(inp, sopt);
 			break;
 
 		case TCP_REUSPORT_LB_NUMA:
@@ -2336,7 +2344,7 @@ unlock_and_done:
 			error = copyin_tls_enable(sopt, &tls);
 			if (error)
 				break;
-			error = ktls_enable_tx(so, &tls);
+			error = ktls_enable_tx(inp->inp_socket, &tls);
 			break;
 		case TCP_TXTLS_MODE:
 			INP_WUNLOCK(inp);
@@ -2345,7 +2353,7 @@ unlock_and_done:
 				return (error);
 
 			INP_WLOCK_RECHECK(inp);
-			error = ktls_set_tx_mode(so, ui);
+			error = ktls_set_tx_mode(inp->inp_socket, ui);
 			INP_WUNLOCK(inp);
 			break;
 		case TCP_RXTLS_ENABLE:
@@ -2354,7 +2362,7 @@ unlock_and_done:
 			    sizeof(tls));
 			if (error)
 				break;
-			error = ktls_enable_rx(so, &tls);
+			error = ktls_enable_rx(inp->inp_socket, &tls);
 			break;
 #endif
 
@@ -2699,14 +2707,14 @@ unhold:
 #endif
 #ifdef KERN_TLS
 		case TCP_TXTLS_MODE:
-			error = ktls_get_tx_mode(so, &optval);
+			error = ktls_get_tx_mode(inp->inp_socket, &optval);
 			INP_WUNLOCK(inp);
 			if (error == 0)
 				error = sooptcopyout(sopt, &optval,
 				    sizeof(optval));
 			break;
 		case TCP_RXTLS_MODE:
-			error = ktls_get_rx_mode(so, &optval);
+			error = ktls_get_rx_mode(inp->inp_socket, &optval);
 			INP_WUNLOCK(inp);
 			if (error == 0)
 				error = sooptcopyout(sopt, &optval,
diff --git a/sys/netinet/tcp_var.h b/sys/netinet/tcp_var.h
index 07788ada3985..ccfd9a8f11e2 100644
--- a/sys/netinet/tcp_var.h
+++ b/sys/netinet/tcp_var.h
@@ -358,8 +358,7 @@ struct tcp_function_block {
 			    struct socket *, struct tcpcb *,
 			    int, int, uint8_t,
 			    int, struct timeval *);
-	int     (*tfb_tcp_ctloutput)(struct socket *so, struct sockopt *sopt,
-			    struct inpcb *inp, struct tcpcb *tp);
+	int     (*tfb_tcp_ctloutput)(struct inpcb *inp, struct sockopt *sopt);
 	/* Optional memory allocation/free routine */
 	int	(*tfb_tcp_fb_init)(struct tcpcb *);
 	void	(*tfb_tcp_fb_fini)(struct tcpcb *, int);
@@ -1128,7 +1127,7 @@ int find_tcp_function_alias(struct tcp_function_block *blk, struct tcp_function_
 void tcp_switch_back_to_default(struct tcpcb *tp);
 struct tcp_function_block *
 find_and_ref_tcp_fb(struct tcp_function_block *fs);
-int tcp_default_ctloutput(struct socket *so, struct sockopt *sopt, struct inpcb *inp, struct tcpcb *tp);
+int tcp_default_ctloutput(struct inpcb *inp, struct sockopt *sopt);
 
 extern counter_u64_t tcp_inp_lro_direct_queue;
 extern counter_u64_t tcp_inp_lro_wokeup_queue;