git: 7e65cfc9bbe5 - main - pf: Make pf_get_translation() more expressive

From: Mark Johnston <markj_at_FreeBSD.org>
Date: Mon, 19 Aug 2024 14:37:40 UTC
The branch main has been updated by markj:

URL: https://cgit.FreeBSD.org/src/commit/?id=7e65cfc9bbe5a9d735ef38f7ed49965b234b8a20

commit 7e65cfc9bbe5a9d735ef38f7ed49965b234b8a20
Author:     Mark Johnston <markj@FreeBSD.org>
AuthorDate: 2024-08-19 14:14:30 +0000
Commit:     Mark Johnston <markj@FreeBSD.org>
CommitDate: 2024-08-19 14:37:27 +0000

    pf: Make pf_get_translation() more expressive
    
    Currently pf_get_translation() returns a pointer to a matching
    nat/rdr/binat rule, or NULL if no rule was matched or an error occurred
    while applying the translation.  That is, we don't distinguish between
    errors and the lack of a matching rule.  Thus, if an error (e.g., a
    memory allocation failure or a state conflict) occurs, we simply handle
    the packet as if no translation rule was present.  This is not
    desirable.
    
    Make pf_get_translation() return the matching rule as an out-param and
    instead return a reason code which indicates whether there was no
    translation rule, or there was a translation rule and we failed to apply
    it, or there was a translation rule and we applied it successfully.
    
    Reviewed by:    kp, allanjude
    MFC after:      3 months
    Sponsored by:   Klara, Inc.
    Sponsored by:   Modirum
    Differential Revision:  https://reviews.freebsd.org/D45672
---
 sys/net/pfvar.h        |  5 +++--
 sys/netpfil/pf/pf.c    | 18 ++++++++++++----
 sys/netpfil/pf/pf_lb.c | 57 +++++++++++++++++++++++++++++++++-----------------
 3 files changed, 55 insertions(+), 25 deletions(-)

diff --git a/sys/net/pfvar.h b/sys/net/pfvar.h
index 863883c2d61e..d66e6f799761 100644
--- a/sys/net/pfvar.h
+++ b/sys/net/pfvar.h
@@ -2569,11 +2569,12 @@ u_short			 pf_map_addr(u_int8_t, struct pf_krule *,
 			    struct pf_addr *, struct pf_addr *,
 			    struct pfi_kkif **nkif, struct pf_addr *,
 			    struct pf_ksrc_node **);
-struct pf_krule		*pf_get_translation(struct pf_pdesc *, struct mbuf *,
+u_short			 pf_get_translation(struct pf_pdesc *, struct mbuf *,
 			    int, struct pfi_kkif *, struct pf_ksrc_node **,
 			    struct pf_state_key **, struct pf_state_key **,
 			    struct pf_addr *, struct pf_addr *,
-			    uint16_t, uint16_t, struct pf_kanchor_stackframe *);
+			    uint16_t, uint16_t, struct pf_kanchor_stackframe *,
+			    struct pf_krule **);
 
 struct pf_state_key	*pf_state_key_setup(struct pf_pdesc *, struct pf_addr *,
 			    struct pf_addr *, u_int16_t, u_int16_t);
diff --git a/sys/netpfil/pf/pf.c b/sys/netpfil/pf/pf.c
index 0547e29e04c2..2bbd231b3ee9 100644
--- a/sys/netpfil/pf/pf.c
+++ b/sys/netpfil/pf/pf.c
@@ -4605,7 +4605,7 @@ pf_test_rule(struct pf_krule **rm, struct pf_kstate **sm, struct pfi_kkif *kif,
 	struct pf_ksrc_node	*nsn = NULL;
 	struct tcphdr		*th = &pd->hdr.tcp;
 	struct pf_state_key	*sk = NULL, *nk = NULL;
-	u_short			 reason;
+	u_short			 reason, transerror;
 	int			 rewrite = 0, hdrlen = 0;
 	int			 tag = -1;
 	int			 asd = 0;
@@ -4618,6 +4618,8 @@ pf_test_rule(struct pf_krule **rm, struct pf_kstate **sm, struct pfi_kkif *kif,
 
 	PF_RULES_RASSERT();
 
+	SLIST_INIT(&match_rules);
+
 	if (inp != NULL) {
 		INP_LOCK_ASSERT(inp);
 		pd->lookup.uid = inp->inp_cred->cr_uid;
@@ -4686,8 +4688,17 @@ pf_test_rule(struct pf_krule **rm, struct pf_kstate **sm, struct pfi_kkif *kif,
 	r = TAILQ_FIRST(pf_main_ruleset.rules[PF_RULESET_FILTER].active.ptr);
 
 	/* check packet for BINAT/NAT/RDR */
-	if ((nr = pf_get_translation(pd, m, off, kif, &nsn, &sk,
-	    &nk, saddr, daddr, sport, dport, anchor_stack)) != NULL) {
+	transerror = pf_get_translation(pd, m, off, kif, &nsn, &sk,
+	    &nk, saddr, daddr, sport, dport, anchor_stack, &nr);
+	switch (transerror) {
+	default:
+		/* A translation error occurred. */
+		REASON_SET(&reason, transerror);
+		goto cleanup;
+	case PFRES_MAX:
+		/* No match. */
+		break;
+	case PFRES_MATCH:
 		KASSERT(sk != NULL, ("%s: null sk", __func__));
 		KASSERT(nk != NULL, ("%s: null nk", __func__));
 
@@ -4836,7 +4847,6 @@ pf_test_rule(struct pf_krule **rm, struct pf_kstate **sm, struct pfi_kkif *kif,
 		pd->nat_rule = nr;
 	}
 
-	SLIST_INIT(&match_rules);
 	while (r != NULL) {
 		pf_counter_u64_add(&r->evaluations, 1);
 		if (pfi_kkif_match(r->kif, kif) == r->ifnot)
diff --git a/sys/netpfil/pf/pf_lb.c b/sys/netpfil/pf/pf_lb.c
index 4b703d3d02da..68fc76233dab 100644
--- a/sys/netpfil/pf/pf_lb.c
+++ b/sys/netpfil/pf/pf_lb.c
@@ -591,22 +591,26 @@ done:
 	return (reason);
 }
 
-struct pf_krule *
+u_short
 pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
     struct pfi_kkif *kif, struct pf_ksrc_node **sn,
     struct pf_state_key **skp, struct pf_state_key **nkp,
     struct pf_addr *saddr, struct pf_addr *daddr,
-    uint16_t sport, uint16_t dport, struct pf_kanchor_stackframe *anchor_stack)
+    uint16_t sport, uint16_t dport, struct pf_kanchor_stackframe *anchor_stack,
+    struct pf_krule **rp)
 {
 	struct pf_krule	*r = NULL;
 	struct pf_addr	*naddr;
 	uint16_t	*nportp;
 	uint16_t	 low, high;
+	u_short		 reason;
 
 	PF_RULES_RASSERT();
 	KASSERT(*skp == NULL, ("*skp not NULL"));
 	KASSERT(*nkp == NULL, ("*nkp not NULL"));
 
+	*rp = NULL;
+
 	if (pd->dir == PF_OUT) {
 		r = pf_match_translation(pd, m, off, kif, saddr,
 		    sport, daddr, dport, PF_RULESET_BINAT, anchor_stack);
@@ -624,23 +628,23 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 	}
 
 	if (r == NULL)
-		return (NULL);
+		return (PFRES_MAX);
 
 	switch (r->action) {
 	case PF_NONAT:
 	case PF_NOBINAT:
 	case PF_NORDR:
-		return (NULL);
+		return (PFRES_MAX);
 	}
 
 	*skp = pf_state_key_setup(pd, saddr, daddr, sport, dport);
 	if (*skp == NULL)
-		return (NULL);
+		return (PFRES_MEMORY);
 	*nkp = pf_state_key_clone(*skp);
 	if (*nkp == NULL) {
 		uma_zfree(V_pf_state_key_z, *skp);
 		*skp = NULL;
-		return (NULL);
+		return (PFRES_MEMORY);
 	}
 
 	naddr = &(*nkp)->addr[1];
@@ -664,6 +668,7 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 				    r->rpool.mape.offset,
 				    r->rpool.mape.psidlen,
 				    r->rpool.mape.psid));
+				reason = PFRES_MAPFAILED;
 				goto notrans;
 			}
 		} else if (pf_get_sport(pd->af, pd->proto, r, saddr, sport,
@@ -671,6 +676,7 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 			DPFPRINTF(PF_DEBUG_MISC,
 			    ("pf: NAT proxy port allocation (%u-%u) failed\n",
 			    r->rpool.proxy_port[0], r->rpool.proxy_port[1]));
+			reason = PFRES_MAPFAILED;
 			goto notrans;
 		}
 		break;
@@ -682,8 +688,10 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 #ifdef INET
 				case AF_INET:
 					if (r->rpool.cur->addr.p.dyn->
-					    pfid_acnt4 < 1)
+					    pfid_acnt4 < 1) {
+						reason = PFRES_MAPFAILED;
 						goto notrans;
+					}
 					PF_POOLMASK(naddr,
 					    &r->rpool.cur->addr.p.dyn->
 					    pfid_addr4,
@@ -694,8 +702,10 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 #ifdef INET6
 				case AF_INET6:
 					if (r->rpool.cur->addr.p.dyn->
-					    pfid_acnt6 < 1)
+					    pfid_acnt6 < 1) {
+						reason = PFRES_MAPFAILED;
 						goto notrans;
+					}
 					PF_POOLMASK(naddr,
 					    &r->rpool.cur->addr.p.dyn->
 					    pfid_addr6,
@@ -715,8 +725,10 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 				switch (pd->af) {
 #ifdef INET
 				case AF_INET:
-					if (r->src.addr.p.dyn-> pfid_acnt4 < 1)
+					if (r->src.addr.p.dyn->pfid_acnt4 < 1) {
+						reason = PFRES_MAPFAILED;
 						goto notrans;
+					}
 					PF_POOLMASK(naddr,
 					    &r->src.addr.p.dyn->pfid_addr4,
 					    &r->src.addr.p.dyn->pfid_mask4,
@@ -725,8 +737,10 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 #endif /* INET */
 #ifdef INET6
 				case AF_INET6:
-					if (r->src.addr.p.dyn->pfid_acnt6 < 1)
+					if (r->src.addr.p.dyn->pfid_acnt6 < 1) {
+						reason = PFRES_MAPFAILED;
 						goto notrans;
+					}
 					PF_POOLMASK(naddr,
 					    &r->src.addr.p.dyn->pfid_addr6,
 					    &r->src.addr.p.dyn->pfid_mask6,
@@ -744,7 +758,8 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 		struct pf_state_key_cmp key;
 		uint16_t cut, low, high, nport;
 
-		if (pf_map_addr(pd->af, r, saddr, naddr, NULL, NULL, sn))
+		reason = pf_map_addr(pd->af, r, saddr, naddr, NULL, NULL, sn);
+		if (reason != 0)
 			goto notrans;
 		if ((r->rpool.opts & PF_POOL_TYPEMASK) == PF_POOL_BITMASK)
 			PF_POOLMASK(naddr, naddr, &r->rpool.cur->addr.v.a.mask,
@@ -815,12 +830,13 @@ pf_get_translation(struct pf_pdesc *pd, struct mbuf *m, int off,
 
 		DPFPRINTF(PF_DEBUG_MISC,
 		    ("pf: RDR source port allocation failed\n"));
-		if (0) {
+		reason = PFRES_MAPFAILED;
+		goto notrans;
+
 out:
-			DPFPRINTF(PF_DEBUG_MISC,
-			    ("pf: RDR source port allocation %u->%u\n",
-			    ntohs(sport), ntohs((*nkp)->port[0])));
-		}
+		DPFPRINTF(PF_DEBUG_MISC,
+		    ("pf: RDR source port allocation %u->%u\n",
+		    ntohs(sport), ntohs((*nkp)->port[0])));
 		break;
 	}
 	default:
@@ -828,14 +844,17 @@ out:
 	}
 
 	/* Return success only if translation really happened. */
-	if (bcmp(*skp, *nkp, sizeof(struct pf_state_key_cmp)))
-		return (r);
+	if (bcmp(*skp, *nkp, sizeof(struct pf_state_key_cmp))) {
+		*rp = r;
+		return (PFRES_MATCH);
+	}
 
+	reason = PFRES_MAX;
 notrans:
 	uma_zfree(V_pf_state_key_z, *nkp);
 	uma_zfree(V_pf_state_key_z, *skp);
 	*skp = *nkp = NULL;
 	*sn = NULL;
 
-	return (NULL);
+	return (reason);
 }