git: bd7762c86986 - main - pf: add a rule rb tree

From: Mateusz Guzik <mjg_at_FreeBSD.org>
Date: Mon, 28 Mar 2022 11:47:32 UTC
The branch main has been updated by mjg:

URL: https://cgit.FreeBSD.org/src/commit/?id=bd7762c86986537a5b393537b85de31b1556735b

commit bd7762c86986537a5b393537b85de31b1556735b
Author:     Mateusz Guzik <mjg@FreeBSD.org>
AuthorDate: 2022-02-28 20:17:32 +0000
Commit:     Mateusz Guzik <mjg@FreeBSD.org>
CommitDate: 2022-03-28 11:45:03 +0000

    pf: add a rule rb tree
    
    The rule's MD5 sum is used as the key.
    
    This gets rid of the quadratic rule traversal when "keep_counters" is
    set.
    
    Reviewed by:    kp
    Sponsored by:   Rubicon Communications, LLC ("Netgate")
---
 sys/net/pfvar.h           |  5 ++++
 sys/netpfil/pf/pf_ioctl.c | 70 +++++++++++++++++++++++++++++++++--------------
 2 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/sys/net/pfvar.h b/sys/net/pfvar.h
index b83a6d90f8d6..ccc81ea137b9 100644
--- a/sys/net/pfvar.h
+++ b/sys/net/pfvar.h
@@ -673,6 +673,9 @@ union pf_krule_ptr {
 	u_int32_t		 nr;
 };
 
+RB_HEAD(pf_krule_global, pf_krule);
+RB_PROTOTYPE(pf_krule_global, pf_krule, entry_global, pf_krule_compare);
+
 struct pf_krule {
 	struct pf_rule_addr	 src;
 	struct pf_rule_addr	 dst;
@@ -770,6 +773,7 @@ struct pf_krule {
 		u_int16_t		port;
 	}			divert;
 	u_int8_t		 md5sum[PF_MD5_DIGEST_LENGTH];
+	RB_ENTRY(pf_krule)	 entry_global;
 
 #ifdef PF_WANT_32_TO_64_COUNTER
 	LIST_ENTRY(pf_krule)	 allrulelist;
@@ -1140,6 +1144,7 @@ struct pf_kruleset {
 			u_int32_t		 rcount;
 			u_int32_t		 ticket;
 			int			 open;
+			struct pf_krule_global 	 *tree;
 		}			 active, inactive;
 	}			 rules[PF_RULESET_MAX];
 	struct pf_kanchor	*anchor;
diff --git a/sys/netpfil/pf/pf_ioctl.c b/sys/netpfil/pf/pf_ioctl.c
index 724ca9b700db..ae07fe80fbf8 100644
--- a/sys/netpfil/pf/pf_ioctl.c
+++ b/sys/netpfil/pf/pf_ioctl.c
@@ -141,6 +141,11 @@ static int		 pf_import_kaltq(struct pfioc_altq_v1 *,
 
 VNET_DEFINE(struct pf_krule,	pf_default_rule);
 
+static __inline int             pf_krule_compare(struct pf_krule *,
+				    struct pf_krule *);
+
+RB_GENERATE(pf_krule_global, pf_krule, entry_global, pf_krule_compare);
+
 #ifdef ALTQ
 VNET_DEFINE_STATIC(int,		pf_altq_running);
 #define	V_pf_altq_running	VNET(pf_altq_running)
@@ -530,6 +535,7 @@ pf_free_rule(struct pf_krule *rule)
 {
 
 	PF_RULES_WASSERT();
+	PF_CONFIG_ASSERT();
 
 	if (rule->tag)
 		tag_unref(&V_pf_tags, rule->tag);
@@ -1141,6 +1147,7 @@ out:
 static int
 pf_begin_rules(u_int32_t *ticket, int rs_num, const char *anchor)
 {
+	struct pf_krule_global *tree;
 	struct pf_kruleset	*rs;
 	struct pf_krule		*rule;
 
@@ -1148,9 +1155,18 @@ pf_begin_rules(u_int32_t *ticket, int rs_num, const char *anchor)
 
 	if (rs_num < 0 || rs_num >= PF_RULESET_MAX)
 		return (EINVAL);
+	tree = malloc(sizeof(struct pf_krule_global), M_TEMP, M_NOWAIT);
+	if (tree == NULL)
+		return (ENOMEM);
+	RB_INIT(tree);
 	rs = pf_find_or_create_kruleset(anchor);
-	if (rs == NULL)
+	if (rs == NULL) {
+		free(tree, M_TEMP);
 		return (EINVAL);
+	}
+	free(rs->rules[rs_num].inactive.tree, M_TEMP);
+	rs->rules[rs_num].inactive.tree = tree;
+
 	while ((rule = TAILQ_FIRST(rs->rules[rs_num].inactive.ptr)) != NULL) {
 		pf_unlink_rule(rs->rules[rs_num].inactive.ptr, rule);
 		rs->rules[rs_num].inactive.rcount--;
@@ -1275,19 +1291,20 @@ pf_hash_rule(struct pf_krule *rule)
 	MD5Final(rule->md5sum, &ctx);
 }
 
-static bool
+static int
 pf_krule_compare(struct pf_krule *a, struct pf_krule *b)
 {
 
-	return (memcmp(a->md5sum, b->md5sum, PF_MD5_DIGEST_LENGTH) == 0);
+	return (memcmp(a->md5sum, b->md5sum, PF_MD5_DIGEST_LENGTH));
 }
 
 static int
 pf_commit_rules(u_int32_t ticket, int rs_num, char *anchor)
 {
 	struct pf_kruleset	*rs;
-	struct pf_krule		*rule, **old_array, *tail;
+	struct pf_krule		*rule, **old_array, *old_rule;
 	struct pf_krulequeue	*old_rules;
+	struct pf_krule_global  *old_tree;
 	int			 error;
 	u_int32_t		 old_rcount;
 
@@ -1311,40 +1328,43 @@ pf_commit_rules(u_int32_t ticket, int rs_num, char *anchor)
 	old_rules = rs->rules[rs_num].active.ptr;
 	old_rcount = rs->rules[rs_num].active.rcount;
 	old_array = rs->rules[rs_num].active.ptr_array;
+	old_tree = rs->rules[rs_num].active.tree;
 
 	rs->rules[rs_num].active.ptr =
 	    rs->rules[rs_num].inactive.ptr;
 	rs->rules[rs_num].active.ptr_array =
 	    rs->rules[rs_num].inactive.ptr_array;
+	rs->rules[rs_num].active.tree =
+	    rs->rules[rs_num].inactive.tree;
 	rs->rules[rs_num].active.rcount =
 	    rs->rules[rs_num].inactive.rcount;
 
 	/* Attempt to preserve counter information. */
-	if (V_pf_status.keep_counters) {
+	if (V_pf_status.keep_counters && old_tree != NULL) {
 		TAILQ_FOREACH(rule, rs->rules[rs_num].active.ptr,
 		    entries) {
-			tail = TAILQ_FIRST(old_rules);
-			while ((tail != NULL) && ! pf_krule_compare(tail, rule))
-				tail = TAILQ_NEXT(tail, entries);
-			if (tail != NULL) {
-				pf_counter_u64_critical_enter();
-				pf_counter_u64_add_protected(&rule->evaluations,
-				    pf_counter_u64_fetch(&tail->evaluations));
-				pf_counter_u64_add_protected(&rule->packets[0],
-				    pf_counter_u64_fetch(&tail->packets[0]));
-				pf_counter_u64_add_protected(&rule->packets[1],
-				    pf_counter_u64_fetch(&tail->packets[1]));
-				pf_counter_u64_add_protected(&rule->bytes[0],
-				    pf_counter_u64_fetch(&tail->bytes[0]));
-				pf_counter_u64_add_protected(&rule->bytes[1],
-				    pf_counter_u64_fetch(&tail->bytes[1]));
-				pf_counter_u64_critical_exit();
+			old_rule = RB_FIND(pf_krule_global, old_tree, rule);
+			if (old_rule == NULL) {
+				continue;
 			}
+			pf_counter_u64_critical_enter();
+			pf_counter_u64_add_protected(&rule->evaluations,
+			    pf_counter_u64_fetch(&old_rule->evaluations));
+			pf_counter_u64_add_protected(&rule->packets[0],
+			    pf_counter_u64_fetch(&old_rule->packets[0]));
+			pf_counter_u64_add_protected(&rule->packets[1],
+			    pf_counter_u64_fetch(&old_rule->packets[1]));
+			pf_counter_u64_add_protected(&rule->bytes[0],
+			    pf_counter_u64_fetch(&old_rule->bytes[0]));
+			pf_counter_u64_add_protected(&rule->bytes[1],
+			    pf_counter_u64_fetch(&old_rule->bytes[1]));
+			pf_counter_u64_critical_exit();
 		}
 	}
 
 	rs->rules[rs_num].inactive.ptr = old_rules;
 	rs->rules[rs_num].inactive.ptr_array = old_array;
+	rs->rules[rs_num].inactive.tree = NULL;
 	rs->rules[rs_num].inactive.rcount = old_rcount;
 
 	rs->rules[rs_num].active.ticket =
@@ -1362,6 +1382,7 @@ pf_commit_rules(u_int32_t ticket, int rs_num, char *anchor)
 	rs->rules[rs_num].inactive.rcount = 0;
 	rs->rules[rs_num].inactive.open = 0;
 	pf_remove_if_empty_kruleset(rs);
+	free(old_tree, M_TEMP);
 
 	return (0);
 }
@@ -2207,6 +2228,13 @@ pf_ioctl_addrule(struct pf_krule *rule, uint32_t ticket,
 
 	PF_RULES_WUNLOCK();
 	pf_hash_rule(rule);
+	if (RB_INSERT(pf_krule_global, ruleset->rules[rs_num].inactive.tree, rule) != NULL) {
+		PF_RULES_WLOCK();
+		pf_free_rule(rule);
+		rule = NULL;
+		error = EINVAL;
+		ERROUT(error);
+	}
 	PF_CONFIG_UNLOCK();
 
 	return (0);