git: 1ebe8e0fad40 - stable/13 - ipsec: enter epoch before calling into ipsec_run_hhooks

From: Mateusz Guzik <mjg_at_FreeBSD.org>
Date: Mon, 11 Oct 2021 09:15:49 UTC
The branch stable/13 has been updated by mjg:

URL: https://cgit.FreeBSD.org/src/commit/?id=1ebe8e0fad409ec16b34c392e823c25ecd42876f

commit 1ebe8e0fad409ec16b34c392e823c25ecd42876f
Author:     Mateusz Guzik <mjg@FreeBSD.org>
AuthorDate: 2021-09-17 12:00:20 +0000
Commit:     Mateusz Guzik <mjg@FreeBSD.org>
CommitDate: 2021-10-11 09:10:31 +0000

    ipsec: enter epoch before calling into ipsec_run_hhooks
    
    pfil_run_hooks, which can eventually get called from there, asserts that the epoch has been entered.
    
    Reviewed by:    ae
    Sponsored by:   Rubicon Communications, LLC ("Netgate")
    Differential Revision: https://reviews.freebsd.org/D32007
    
    (cherry picked from commit 590d0715b348d0d8da0c0355cebd9dff18e39831)
---
 sys/netipsec/ipsec_input.c | 27 ++++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/sys/netipsec/ipsec_input.c b/sys/netipsec/ipsec_input.c
index 48acba68a1fe..2e2efe34842b 100644
--- a/sys/netipsec/ipsec_input.c
+++ b/sys/netipsec/ipsec_input.c
@@ -305,7 +305,7 @@ ipsec4_common_input_cb(struct mbuf *m, struct secasvar *sav, int skip,
 			    buf, sizeof(buf)), (u_long) ntohl(sav->spi)));
 			IPSEC_ISTAT(sproto, hdrops);
 			error = ENOBUFS;
-			goto bad;
+			goto bad_noepoch;
 		}
 
 		ip = mtod(m, struct ip *);
@@ -325,6 +325,11 @@ ipsec4_common_input_cb(struct mbuf *m, struct secasvar *sav, int skip,
 	    (prot == IPPROTO_UDP || prot == IPPROTO_TCP))
 		udp_ipsec_adjust_cksum(m, sav, prot, skip);
 
+	/*
+	 * Needed for ipsec_run_hooks and netisr_queue_src
+	 */
+	NET_EPOCH_ENTER(et);
+
 	IPSEC_INIT_CTX(&ctx, &m, NULL, sav, AF_INET, IPSEC_ENC_BEFORE);
 	if ((error = ipsec_run_hhooks(&ctx, HHOOK_TYPE_IPSEC_IN)) != 0)
 		goto bad;
@@ -424,18 +429,19 @@ ipsec4_common_input_cb(struct mbuf *m, struct secasvar *sav, int skip,
 	if (saidx->mode == IPSEC_MODE_TUNNEL)
 		error = ipsec_if_input(m, sav, af);
 	if (error == 0) {
-		NET_EPOCH_ENTER(et);
 		error = netisr_queue_src(isr_prot, (uintptr_t)sav->spi, m);
-		NET_EPOCH_EXIT(et);
 		if (error) {
 			IPSEC_ISTAT(sproto, qfull);
 			DPRINTF(("%s: queue full; proto %u packet dropped\n",
 			    __func__, sproto));
 		}
 	}
+	NET_EPOCH_EXIT(et);
 	key_freesav(&sav);
 	return (error);
 bad:
+	NET_EPOCH_EXIT(et);
+bad_noepoch:
 	key_freesav(&sav);
 	if (m != NULL)
 		m_freem(m);
@@ -512,6 +518,8 @@ ipsec6_common_input_cb(struct mbuf *m, struct secasvar *sav, int skip,
 		sproto == IPPROTO_IPCOMP,
 		("unexpected security protocol %u", sproto));
 
+	NET_EPOCH_ENTER(et);
+
 	/* Fix IPv6 header */
 	if (m->m_len < sizeof(struct ip6_hdr) &&
 	    (m = m_pullup(m, sizeof(struct ip6_hdr))) == NULL) {
@@ -623,16 +631,15 @@ ipsec6_common_input_cb(struct mbuf *m, struct secasvar *sav, int skip,
 		if (saidx->mode == IPSEC_MODE_TUNNEL)
 			error = ipsec_if_input(m, sav, af);
 		if (error == 0) {
-			NET_EPOCH_ENTER(et);
 			error = netisr_queue_src(isr_prot,
 			    (uintptr_t)sav->spi, m);
-			NET_EPOCH_EXIT(et);
 			if (error) {
 				IPSEC_ISTAT(sproto, qfull);
 				DPRINTF(("%s: queue full; proto %u packet"
 				    " dropped\n", __func__, sproto));
 			}
 		}
+		NET_EPOCH_EXIT(et);
 		key_freesav(&sav);
 		return (error);
 	}
@@ -642,12 +649,11 @@ ipsec6_common_input_cb(struct mbuf *m, struct secasvar *sav, int skip,
 	 */
 	nest = 0;
 	nxt = nxt8;
-	NET_EPOCH_ENTER(et);
 	while (nxt != IPPROTO_DONE) {
 		if (V_ip6_hdrnestlimit && (++nest > V_ip6_hdrnestlimit)) {
 			IP6STAT_INC(ip6s_toomanyhdr);
 			error = EINVAL;
-			goto bad_epoch;
+			goto bad;
 		}
 
 		/*
@@ -658,7 +664,7 @@ ipsec6_common_input_cb(struct mbuf *m, struct secasvar *sav, int skip,
 			IP6STAT_INC(ip6s_tooshort);
 			in6_ifstat_inc(m->m_pkthdr.rcvif, ifs6_in_truncated);
 			error = EINVAL;
-			goto bad_epoch;
+			goto bad;
 		}
 		/*
 		 * Enforce IPsec policy checking if we are seeing last header.
@@ -668,16 +674,15 @@ ipsec6_common_input_cb(struct mbuf *m, struct secasvar *sav, int skip,
 		if ((inet6sw[ip6_protox[nxt]].pr_flags & PR_LASTHDR) != 0 &&
 		    ipsec6_in_reject(m, NULL)) {
 			error = EINVAL;
-			goto bad_epoch;
+			goto bad;
 		}
 		nxt = (*inet6sw[ip6_protox[nxt]].pr_input)(&m, &skip, nxt);
 	}
 	NET_EPOCH_EXIT(et);
 	key_freesav(&sav);
 	return (0);
-bad_epoch:
-	NET_EPOCH_EXIT(et);
 bad:
+	NET_EPOCH_EXIT(et);
 	key_freesav(&sav);
 	if (m)
 		m_freem(m);