git: 96871af01382 - main - inpcb: use family specific sockaddr argument for bind functions

From: Gleb Smirnoff <glebius_at_FreeBSD.org>
Date: Wed, 15 Feb 2023 18:30:46 UTC
The branch main has been updated by glebius:

URL: https://cgit.FreeBSD.org/src/commit/?id=96871af01382ecaec59ccbf6999ba8ad76a5f9e9

commit 96871af01382ecaec59ccbf6999ba8ad76a5f9e9
Author:     Gleb Smirnoff <glebius@FreeBSD.org>
AuthorDate: 2023-02-15 18:30:16 +0000
Commit:     Gleb Smirnoff <glebius@FreeBSD.org>
CommitDate: 2023-02-15 18:30:16 +0000

    inpcb: use family specific sockaddr argument for bind functions
    
    Do the cast from sockaddr to either IPv4 or IPv6 sockaddr in the
    protocol's pr_bind method and from there on pass the family-specific
    sockaddr argument down the call stack.
    
    Reviewed by:            zlei, melifaro, markj
    Differential Revision:  https://reviews.freebsd.org/D38601
---
 sys/netinet/in_pcb.c       | 22 ++++++++++------------
 sys/netinet/in_pcb.h       |  4 ++--
 sys/netinet/tcp_usrreq.c   |  7 +++----
 sys/netinet/udp_usrreq.c   |  6 +++---
 sys/netinet6/in6_pcb.c     |  7 ++-----
 sys/netinet6/in6_pcb.h     |  2 +-
 sys/netinet6/udp6_usrreq.c | 12 +++++-------
 7 files changed, 26 insertions(+), 34 deletions(-)

diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c
index ad889f29de55..a23c89fe8033 100644
--- a/sys/netinet/in_pcb.c
+++ b/sys/netinet/in_pcb.c
@@ -655,21 +655,21 @@ out:
 
 #ifdef INET
 int
-in_pcbbind(struct inpcb *inp, struct sockaddr *nam, struct ucred *cred)
+in_pcbbind(struct inpcb *inp, struct sockaddr_in *sin, struct ucred *cred)
 {
 	int anonport, error;
 
-	KASSERT(nam == NULL || nam->sa_family == AF_INET,
-	    ("%s: invalid address family for %p", __func__, nam));
-	KASSERT(nam == NULL || nam->sa_len == sizeof(struct sockaddr_in),
-	    ("%s: invalid address length for %p", __func__, nam));
+	KASSERT(sin == NULL || sin->sin_family == AF_INET,
+	    ("%s: invalid address family for %p", __func__, sin));
+	KASSERT(sin == NULL || sin->sin_len == sizeof(struct sockaddr_in),
+	    ("%s: invalid address length for %p", __func__, sin));
 	INP_WLOCK_ASSERT(inp);
 	INP_HASH_WLOCK_ASSERT(inp->inp_pcbinfo);
 
 	if (inp->inp_lport != 0 || inp->inp_laddr.s_addr != INADDR_ANY)
 		return (EINVAL);
-	anonport = nam == NULL || ((struct sockaddr_in *)nam)->sin_port == 0;
-	error = in_pcbbind_setup(inp, nam, &inp->inp_laddr.s_addr,
+	anonport = sin == NULL || sin->sin_port == 0;
+	error = in_pcbbind_setup(inp, sin, &inp->inp_laddr.s_addr,
 	    &inp->inp_lport, cred);
 	if (error)
 		return (error);
@@ -901,11 +901,10 @@ in_pcbbind_check_bindmulti(const struct inpcb *ni, const struct inpcb *oi)
  * On error, the values of *laddrp and *lportp are not changed.
  */
 int
-in_pcbbind_setup(struct inpcb *inp, struct sockaddr *nam, in_addr_t *laddrp,
+in_pcbbind_setup(struct inpcb *inp, struct sockaddr_in *sin, in_addr_t *laddrp,
     u_short *lportp, struct ucred *cred)
 {
 	struct socket *so = inp->inp_socket;
-	struct sockaddr_in *sin;
 	struct inpcbinfo *pcbinfo = inp->inp_pcbinfo;
 	struct in_addr laddr;
 	u_short lport = 0;
@@ -925,15 +924,14 @@ in_pcbbind_setup(struct inpcb *inp, struct sockaddr *nam, in_addr_t *laddrp,
 	INP_HASH_LOCK_ASSERT(pcbinfo);
 
 	laddr.s_addr = *laddrp;
-	if (nam != NULL && laddr.s_addr != INADDR_ANY)
+	if (sin != NULL && laddr.s_addr != INADDR_ANY)
 		return (EINVAL);
 	if ((so->so_options & (SO_REUSEADDR|SO_REUSEPORT|SO_REUSEPORT_LB)) == 0)
 		lookupflags = INPLOOKUP_WILDCARD;
-	if (nam == NULL) {
+	if (sin == NULL) {
 		if ((error = prison_local_ip4(cred, &laddr)) != 0)
 			return (error);
 	} else {
-		sin = (struct sockaddr_in *)nam;
 		KASSERT(sin->sin_family == AF_INET,
 		    ("%s: invalid family for address %p", __func__, sin));
 		KASSERT(sin->sin_len == sizeof(*sin),
diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h
index c450685affcb..f15fd0db4dfb 100644
--- a/sys/netinet/in_pcb.h
+++ b/sys/netinet/in_pcb.h
@@ -739,8 +739,8 @@ int	in_pcbbind_check_bindmulti(const struct inpcb *ni,
 
 void	in_pcbpurgeif0(struct inpcbinfo *, struct ifnet *);
 int	in_pcballoc(struct socket *, struct inpcbinfo *);
-int	in_pcbbind(struct inpcb *, struct sockaddr *, struct ucred *);
-int	in_pcbbind_setup(struct inpcb *, struct sockaddr *, in_addr_t *,
+int	in_pcbbind(struct inpcb *, struct sockaddr_in *, struct ucred *);
+int	in_pcbbind_setup(struct inpcb *, struct sockaddr_in *, in_addr_t *,
 	    u_short *, struct ucred *);
 int	in_pcbconnect(struct inpcb *, struct sockaddr_in *, struct ucred *,
 	    bool);
diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
index 5c98e969c5ce..1dbf4659ad00 100644
--- a/sys/netinet/tcp_usrreq.c
+++ b/sys/netinet/tcp_usrreq.c
@@ -245,7 +245,7 @@ tcp_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
 	tp = intotcpcb(inp);
 #endif
 	INP_HASH_WLOCK(&V_tcbinfo);
-	error = in_pcbbind(inp, nam, td->td_ucred);
+	error = in_pcbbind(inp, sinp, td->td_ucred);
 	INP_HASH_WUNLOCK(&V_tcbinfo);
 out:
 	TCP_PROBE2(debug__user, tp, PRU_BIND);
@@ -309,14 +309,13 @@ tcp6_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
 			}
 			inp->inp_vflag |= INP_IPV4;
 			inp->inp_vflag &= ~INP_IPV6;
-			error = in_pcbbind(inp, (struct sockaddr *)&sin,
-			    td->td_ucred);
+			error = in_pcbbind(inp, &sin, td->td_ucred);
 			INP_HASH_WUNLOCK(&V_tcbinfo);
 			goto out;
 		}
 	}
 #endif
-	error = in6_pcbbind(inp, nam, td->td_ucred);
+	error = in6_pcbbind(inp, sin6, td->td_ucred);
 	INP_HASH_WUNLOCK(&V_tcbinfo);
 out:
 	if (error != 0)
diff --git a/sys/netinet/udp_usrreq.c b/sys/netinet/udp_usrreq.c
index f911a79519b6..f039675b1e55 100644
--- a/sys/netinet/udp_usrreq.c
+++ b/sys/netinet/udp_usrreq.c
@@ -1206,8 +1206,8 @@ udp_send(struct socket *so, int flags, struct mbuf *m, struct sockaddr *addr,
 			goto release;
 		}
 		INP_HASH_WLOCK(pcbinfo);
-		error = in_pcbbind_setup(inp, (struct sockaddr *)&src,
-		    &laddr.s_addr, &lport, td->td_ucred);
+		error = in_pcbbind_setup(inp, &src, &laddr.s_addr, &lport,
+		    td->td_ucred);
 		INP_HASH_WUNLOCK(pcbinfo);
 		if (error)
 			goto release;
@@ -1546,7 +1546,7 @@ udp_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
 
 	INP_WLOCK(inp);
 	INP_HASH_WLOCK(pcbinfo);
-	error = in_pcbbind(inp, nam, td->td_ucred);
+	error = in_pcbbind(inp, sinp, td->td_ucred);
 	INP_HASH_WUNLOCK(pcbinfo);
 	INP_WUNLOCK(inp);
 	return (error);
diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c
index 2b8e48090f28..92a1ea840af2 100644
--- a/sys/netinet6/in6_pcb.c
+++ b/sys/netinet6/in6_pcb.c
@@ -153,11 +153,9 @@ in6_pcbsetport(struct in6_addr *laddr, struct inpcb *inp, struct ucred *cred)
 }
 
 int
-in6_pcbbind(struct inpcb *inp, struct sockaddr *nam,
-    struct ucred *cred)
+in6_pcbbind(struct inpcb *inp, struct sockaddr_in6 *sin6, struct ucred *cred)
 {
 	struct socket *so = inp->inp_socket;
-	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)NULL;
 	struct inpcbinfo *pcbinfo = inp->inp_pcbinfo;
 	u_short	lport = 0;
 	int error, lookupflags = 0;
@@ -176,12 +174,11 @@ in6_pcbbind(struct inpcb *inp, struct sockaddr *nam,
 		return (EINVAL);
 	if ((so->so_options & (SO_REUSEADDR|SO_REUSEPORT|SO_REUSEPORT_LB)) == 0)
 		lookupflags = INPLOOKUP_WILDCARD;
-	if (nam == NULL) {
+	if (sin6 == NULL) {
 		if ((error = prison_local_ip6(cred, &inp->in6p_laddr,
 		    ((inp->inp_flags & IN6P_IPV6_V6ONLY) != 0))) != 0)
 			return (error);
 	} else {
-		sin6 = (struct sockaddr_in6 *)nam;
 		KASSERT(sin6->sin6_family == AF_INET6,
 		    ("%s: invalid address family for %p", __func__, sin6));
 		KASSERT(sin6->sin6_len == sizeof(*sin6),
diff --git a/sys/netinet6/in6_pcb.h b/sys/netinet6/in6_pcb.h
index 800d26e8b3d6..91131d1968bc 100644
--- a/sys/netinet6/in6_pcb.h
+++ b/sys/netinet6/in6_pcb.h
@@ -73,7 +73,7 @@
 
 void	in6_pcbpurgeif0(struct inpcbinfo *, struct ifnet *);
 void	in6_losing(struct inpcb *);
-int	in6_pcbbind(struct inpcb *, struct sockaddr *, struct ucred *);
+int	in6_pcbbind(struct inpcb *, struct sockaddr_in6 *, struct ucred *);
 int	in6_pcbconnect(struct inpcb *, struct sockaddr_in6 *, struct ucred *,
 	    bool);
 void	in6_pcbdisconnect(struct inpcb *);
diff --git a/sys/netinet6/udp6_usrreq.c b/sys/netinet6/udp6_usrreq.c
index 3e6c57a8c6ff..8a95e1623f9c 100644
--- a/sys/netinet6/udp6_usrreq.c
+++ b/sys/netinet6/udp6_usrreq.c
@@ -1020,6 +1020,7 @@ udp6_attach(struct socket *so, int proto, struct thread *td)
 static int
 udp6_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
 {
+	struct sockaddr_in6 *sin6_p;
 	struct inpcb *inp;
 	struct inpcbinfo *pcbinfo;
 	int error;
@@ -1034,16 +1035,14 @@ udp6_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
 	if (nam->sa_len != sizeof(struct sockaddr_in6))
 		return (EINVAL);
 
+	sin6_p = (struct sockaddr_in6 *)nam;
+
 	INP_WLOCK(inp);
 	INP_HASH_WLOCK(pcbinfo);
 	vflagsav = inp->inp_vflag;
 	inp->inp_vflag &= ~INP_IPV4;
 	inp->inp_vflag |= INP_IPV6;
 	if ((inp->inp_flags & IN6P_IPV6_V6ONLY) == 0) {
-		struct sockaddr_in6 *sin6_p;
-
-		sin6_p = (struct sockaddr_in6 *)nam;
-
 		if (IN6_IS_ADDR_UNSPECIFIED(&sin6_p->sin6_addr))
 			inp->inp_vflag |= INP_IPV4;
 #ifdef INET
@@ -1053,14 +1052,13 @@ udp6_bind(struct socket *so, struct sockaddr *nam, struct thread *td)
 			in6_sin6_2_sin(&sin, sin6_p);
 			inp->inp_vflag |= INP_IPV4;
 			inp->inp_vflag &= ~INP_IPV6;
-			error = in_pcbbind(inp, (struct sockaddr *)&sin,
-			    td->td_ucred);
+			error = in_pcbbind(inp, &sin, td->td_ucred);
 			goto out;
 		}
 #endif
 	}
 
-	error = in6_pcbbind(inp, nam, td->td_ucred);
+	error = in6_pcbbind(inp, sin6_p, td->td_ucred);
 #ifdef INET
 out:
 #endif