git: 3212ad15abde - main - Add getsock

From: Mateusz Guzik <mjg_at_FreeBSD.org>
Date: Sat, 10 Sep 2022 19:51:37 UTC
The branch main has been updated by mjg:

URL: https://cgit.FreeBSD.org/src/commit/?id=3212ad15abde2bd40030c7818672fd488da548d1

commit 3212ad15abde2bd40030c7818672fd488da548d1
Author:     Mateusz Guzik <mjg@FreeBSD.org>
AuthorDate: 2022-09-07 15:41:55 +0000
Commit:     Mateusz Guzik <mjg@FreeBSD.org>
CommitDate: 2022-09-10 19:47:47 +0000

    Add getsock
    
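    All but one of the consumers of getsock_cap need only its first 4
    arguments: they pass NULL for fflagp and havecaps, or can read the
    flags from fp->f_flag instead. Add a 4-argument getsock() for them.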
---
 sys/compat/linux/linux_socket.c | 29 ++++++++++----------------
 sys/kern/kern_sendfile.c        |  3 +--
 sys/kern/uipc_syscalls.c        | 46 ++++++++++++++++++++++++-----------------
 sys/net/if_ovpn.c               |  4 +---
 sys/netinet/sctp_syscalls.c     | 13 ++++++------
 sys/sys/socketvar.h             |  2 ++
 6 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/sys/compat/linux/linux_socket.c b/sys/compat/linux/linux_socket.c
index 9434e40709bb..fa9c39673fee 100644
--- a/sys/compat/linux/linux_socket.c
+++ b/sys/compat/linux/linux_socket.c
@@ -970,7 +970,6 @@ linux_connect(struct thread *td, struct linux_connect_args *args)
 	struct socket *so;
 	struct sockaddr *sa;
 	struct file *fp;
-	u_int fflag;
 	int error;
 
 	error = linux_to_bsd_sockaddr(PTRIN(args->name), &sa,
@@ -988,14 +987,13 @@ linux_connect(struct thread *td, struct linux_connect_args *args)
 	 * when on a non-blocking socket. Instead it returns the
 	 * error getsockopt(SOL_SOCKET, SO_ERROR) would return on BSD.
 	 */
-	error = getsock_cap(td, args->s, &cap_connect_rights,
-	    &fp, &fflag, NULL);
+	error = getsock(td, args->s, &cap_connect_rights, &fp);
 	if (error != 0)
 		return (error);
 
 	error = EISCONN;
 	so = fp->f_data;
-	if (fflag & FNONBLOCK) {
+	if (atomic_load_int(&fp->f_flag) & FNONBLOCK) {
 		SOCK_LOCK(so);
 		if (so->so_emuldata == 0)
 			error = so->so_error;
@@ -1058,7 +1056,7 @@ linux_accept_common(struct thread *td, int s, l_uintptr_t addr,
 				error = EINVAL;
 			break;
 		case EINVAL:
-			error1 = getsock_cap(td, s, &cap_accept_rights, &fp1, NULL, NULL);
+			error1 = getsock(td, s, &cap_accept_rights, &fp1);
 			if (error1 != 0) {
 				error = error1;
 				break;
@@ -1207,7 +1205,7 @@ linux_send(struct thread *td, struct linux_send_args *args)
 		int tolen;
 	} */ bsd_args;
 	struct file *fp;
-	int error, fflag;
+	int error;
 
 	bsd_args.s = args->s;
 	bsd_args.buf = (caddr_t)PTRIN(args->msg);
@@ -1221,10 +1219,9 @@ linux_send(struct thread *td, struct linux_send_args *args)
 		 * Linux doesn't return ENOTCONN for non-blocking sockets.
 		 * Instead it returns the EAGAIN.
 		 */
-		error = getsock_cap(td, args->s, &cap_send_rights, &fp,
-		    &fflag, NULL);
+		error = getsock(td, args->s, &cap_send_rights, &fp);
 		if (error == 0) {
-			if (fflag & FNONBLOCK)
+			if (atomic_load_int(&fp->f_flag) & FNONBLOCK)
 				error = EAGAIN;
 			fdrop(fp, td);
 		}
@@ -1275,8 +1272,7 @@ linux_sendto(struct thread *td, struct linux_sendto_args *args)
 		return (linux_sendto_hdrincl(td, args));
 
 	bzero(&msg, sizeof(msg));
-	error = getsock_cap(td, args->s, &cap_send_connect_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, args->s, &cap_send_connect_rights, &fp);
 	if (error != 0)
 		return (error);
 	so = fp->f_data;
@@ -1366,7 +1362,7 @@ linux_sendmsg_common(struct thread *td, l_int s, struct l_msghdr *msghdr,
 	void *data;
 	l_size_t len;
 	l_size_t clen;
-	int error, fflag;
+	int error;
 
 	error = copyin(msghdr, &linux_msghdr, sizeof(linux_msghdr));
 	if (error != 0)
@@ -1409,8 +1405,7 @@ linux_sendmsg_common(struct thread *td, l_int s, struct l_msghdr *msghdr,
 		if (sa_family == AF_UNIX)
 			goto bad;
 
-		error = getsock_cap(td, s, &cap_send_rights, &fp,
-		    &fflag, NULL);
+		error = getsock(td, s, &cap_send_rights, &fp);
 		if (error != 0)
 			goto bad;
 		so = fp->f_data;
@@ -1908,8 +1903,7 @@ linux_recvmsg(struct thread *td, struct linux_recvmsg_args *args)
 	struct file *fp;
 	int error;
 
-	error = getsock_cap(td, args->s, &cap_recv_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, args->s, &cap_recv_rights, &fp);
 	if (error != 0)
 		return (error);
 	fdrop(fp, td);
@@ -1927,8 +1921,7 @@ linux_recvmmsg_common(struct thread *td, l_int s, struct l_mmsghdr *msg,
 	l_uint retval;
 	int error, datagrams;
 
-	error = getsock_cap(td, s, &cap_recv_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, s, &cap_recv_rights, &fp);
 	if (error != 0)
 		return (error);
 	datagrams = 0;
diff --git a/sys/kern/kern_sendfile.c b/sys/kern/kern_sendfile.c
index f444e38e153d..96f95e4c841f 100644
--- a/sys/kern/kern_sendfile.c
+++ b/sys/kern/kern_sendfile.c
@@ -653,8 +653,7 @@ sendfile_getsock(struct thread *td, int s, struct file **sock_fp,
 	/*
 	 * The socket must be a stream socket and connected.
 	 */
-	error = getsock_cap(td, s, &cap_send_rights,
-	    sock_fp, NULL, NULL);
+	error = getsock(td, s, &cap_send_rights, sock_fp);
 	if (error != 0)
 		return (error);
 	*so = (*sock_fp)->f_data;
diff --git a/sys/kern/uipc_syscalls.c b/sys/kern/uipc_syscalls.c
index e77475992d0b..c269bd09f139 100644
--- a/sys/kern/uipc_syscalls.c
+++ b/sys/kern/uipc_syscalls.c
@@ -110,6 +110,31 @@ getsock_cap(struct thread *td, int fd, cap_rights_t *rightsp,
 	return (0);
 }
 
+/*
+ * Like getsock_cap(), but without returning the file flags or filecaps.
+ * Looks up fd with fget_unlocked() using the capability rights in rightsp
+ * and verifies that it refers to a socket.  On success, *fpp is set to the
+ * file, which holds a reference the caller must release with fdrop();
+ * callers needing the flags can read fp->f_flag.  Returns ENOTSOCK if fd
+ * is not a socket, or the error from fget_unlocked().
+ */
+int
+getsock(struct thread *td, int fd, cap_rights_t *rightsp, struct file **fpp)
+{
+	struct file *fp;
+	int error;
+
+	error = fget_unlocked(td, fd, rightsp, &fp);
+	if (__predict_false(error != 0))
+		return (error);
+	if (__predict_false(fp->f_type != DTYPE_SOCKET)) {
+		fdrop(fp, td);
+		return (ENOTSOCK);
+	}
+	*fpp = fp;
+	return (0);
+}
+
 /*
  * System call interface to the socket abstraction.
  */
@@ -194,8 +219,7 @@ kern_bindat(struct thread *td, int dirfd, int fd, struct sockaddr *sa)
 
 	AUDIT_ARG_FD(fd);
 	AUDIT_ARG_SOCKADDR(td, dirfd, sa);
-	error = getsock_cap(td, fd, &cap_bind_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, fd, &cap_bind_rights, &fp);
 	if (error != 0)
 		return (error);
 	so = fp->f_data;
@@ -247,8 +271,7 @@ kern_listen(struct thread *td, int s, int backlog)
 	int error;
 
 	AUDIT_ARG_FD(s);
-	error = getsock_cap(td, s, &cap_listen_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, s, &cap_listen_rights, &fp);
 	if (error == 0) {
 		so = fp->f_data;
 #ifdef MAC
@@ -491,8 +514,7 @@ kern_connectat(struct thread *td, int dirfd, int fd, struct sockaddr *sa)
 
 	AUDIT_ARG_FD(fd);
 	AUDIT_ARG_SOCKADDR(td, dirfd, sa);
-	error = getsock_cap(td, fd, &cap_connect_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, fd, &cap_connect_rights, &fp);
 	if (error != 0)
 		return (error);
 	so = fp->f_data;
@@ -738,7 +760,7 @@ kern_sendit(struct thread *td, int s, struct msghdr *mp, int flags,
 		AUDIT_ARG_SOCKADDR(td, AT_FDCWD, mp->msg_name);
 		rights = &cap_send_connect_rights;
 	}
-	error = getsock_cap(td, s, rights, &fp, NULL, NULL);
+	error = getsock(td, s, rights, &fp);
 	if (error != 0) {
 		m_freem(control);
 		return (error);
@@ -916,8 +938,7 @@ kern_recvit(struct thread *td, int s, struct msghdr *mp, enum uio_seg fromseg,
 		*controlp = NULL;
 
 	AUDIT_ARG_FD(s);
-	error = getsock_cap(td, s, &cap_recv_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, s, &cap_recv_rights, &fp);
 	if (error != 0)
 		return (error);
 	so = fp->f_data;
@@ -1205,8 +1226,7 @@ kern_shutdown(struct thread *td, int s, int how)
 	int error;
 
 	AUDIT_ARG_FD(s);
-	error = getsock_cap(td, s, &cap_shutdown_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, s, &cap_shutdown_rights, &fp);
 	if (error == 0) {
 		so = fp->f_data;
 		error = soshutdown(so, how);
@@ -1263,8 +1283,7 @@ kern_setsockopt(struct thread *td, int s, int level, int name, const void *val,
 	}
 
 	AUDIT_ARG_FD(s);
-	error = getsock_cap(td, s, &cap_setsockopt_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, s, &cap_setsockopt_rights, &fp);
 	if (error == 0) {
 		so = fp->f_data;
 		error = sosetopt(so, &sopt);
@@ -1328,8 +1347,7 @@ kern_getsockopt(struct thread *td, int s, int level, int name, void *val,
 	}
 
 	AUDIT_ARG_FD(s);
-	error = getsock_cap(td, s, &cap_getsockopt_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, s, &cap_getsockopt_rights, &fp);
 	if (error == 0) {
 		so = fp->f_data;
 		error = sogetopt(so, &sopt);
@@ -1378,8 +1396,7 @@ kern_getsockname(struct thread *td, int fd, struct sockaddr **sa,
 	int error;
 
 	AUDIT_ARG_FD(fd);
-	error = getsock_cap(td, fd, &cap_getsockname_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, fd, &cap_getsockname_rights, &fp);
 	if (error != 0)
 		return (error);
 	so = fp->f_data;
@@ -1460,8 +1477,7 @@ kern_getpeername(struct thread *td, int fd, struct sockaddr **sa,
 	int error;
 
 	AUDIT_ARG_FD(fd);
-	error = getsock_cap(td, fd, &cap_getpeername_rights,
-	    &fp, NULL, NULL);
+	error = getsock(td, fd, &cap_getpeername_rights, &fp);
 	if (error != 0)
 		return (error);
 	so = fp->f_data;
diff --git a/sys/net/if_ovpn.c b/sys/net/if_ovpn.c
index a90c11c1dcbf..9e0829d996ce 100644
--- a/sys/net/if_ovpn.c
+++ b/sys/net/if_ovpn.c
@@ -449,7 +449,6 @@ ovpn_new_peer(struct ifnet *ifp, const nvlist_t *nvl)
 	struct ovpn_softc *sc = ifp->if_softc;
 	struct thread *td = curthread;
 	struct socket *so = NULL;
-	u_int fflag;
 	int fd;
 	uint32_t peerid;
 	int ret = 0, i;
@@ -476,8 +475,7 @@ ovpn_new_peer(struct ifnet *ifp, const nvlist_t *nvl)
 	fd = nvlist_get_number(nvl, "fd");
 
 	/* Look up the userspace process and use the fd to find the socket. */
-	ret = getsock_cap(td, fd, &cap_connect_rights, &fp,
-	    &fflag, NULL);
+	ret = getsock(td, fd, &cap_connect_rights, &fp);
 	if (ret != 0)
 		return (ret);
 
diff --git a/sys/netinet/sctp_syscalls.c b/sys/netinet/sctp_syscalls.c
index 2697d139300c..a58ba9c231f3 100644
--- a/sys/netinet/sctp_syscalls.c
+++ b/sys/netinet/sctp_syscalls.c
@@ -153,10 +153,11 @@ sys_sctp_peeloff(td, uap)
 	int error, fd;
 
 	AUDIT_ARG_FD(uap->sd);
-	error = getsock_cap(td, uap->sd, cap_rights_init_one(&rights, CAP_PEELOFF),
-	    &headfp, &fflag, NULL);
+	error = getsock(td, uap->sd, cap_rights_init_one(&rights, CAP_PEELOFF),
+	    &headfp);
 	if (error != 0)
 		goto done2;
+	fflag = atomic_load_int(&headfp->f_flag);
 	head = headfp->f_data;
 	if (head->so_proto->pr_protocol != IPPROTO_SCTP) {
 		error = EOPNOTSUPP;
@@ -252,7 +253,7 @@ sys_sctp_generic_sendmsg (td, uap)
 	}
 
 	AUDIT_ARG_FD(uap->sd);
-	error = getsock_cap(td, uap->sd, &rights, &fp, NULL, NULL);
+	error = getsock(td, uap->sd, &rights, &fp);
 	if (error != 0)
 		goto sctp_bad;
 #ifdef KTRACE
@@ -361,7 +362,7 @@ sys_sctp_generic_sendmsg_iov(td, uap)
 	}
 
 	AUDIT_ARG_FD(uap->sd);
-	error = getsock_cap(td, uap->sd, &rights, &fp, NULL, NULL);
+	error = getsock(td, uap->sd, &rights, &fp);
 	if (error != 0)
 		goto sctp_bad1;
 
@@ -472,8 +473,8 @@ sys_sctp_generic_recvmsg(td, uap)
 	int error, fromlen, i, msg_flags;
 
 	AUDIT_ARG_FD(uap->sd);
-	error = getsock_cap(td, uap->sd, cap_rights_init_one(&rights, CAP_RECV),
-	    &fp, NULL, NULL);
+	error = getsock(td, uap->sd, cap_rights_init_one(&rights, CAP_RECV),
+	    &fp);
 	if (error != 0)
 		return (error);
 #ifdef COMPAT_FREEBSD32
diff --git a/sys/sys/socketvar.h b/sys/sys/socketvar.h
index 0c60b9e13cf2..101c6f3f4513 100644
--- a/sys/sys/socketvar.h
+++ b/sys/sys/socketvar.h
@@ -450,6 +450,8 @@ int	getsockaddr(struct sockaddr **namp, const struct sockaddr *uaddr,
 	    size_t len);
 int	getsock_cap(struct thread *td, int fd, cap_rights_t *rightsp,
 	    struct file **fpp, u_int *fflagp, struct filecaps *havecaps);
+int	getsock(struct thread *td, int fd, cap_rights_t *rightsp,
+	    struct file **fpp);
 void	soabort(struct socket *so);
 int	soaccept(struct socket *so, struct sockaddr **nam);
 void	soaio_enqueue(struct task *task);