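dscpPolicy - further improvements to policy matching

Skip policies bound to a different interface before scoring them, and
reject a policy as soon as any of its present fields fails to match,
instead of accumulating a match mask and comparing it afterwards.
Drop DST_PORT_MASK_FLAG: the destination port range is now always
checked, and narrower ranges score higher. PROTO_MASK_FLAG is
renumbered from 16 to 8. Compare ports for equality directly in
network byte order, and implement v6_equal() on top of a new
v6_not_equal() built from two 64-bit XORs.
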
Bug: 237485762
Test: atest DscpPolicyTest
Signed-off-by: Maciej Żenczykowski <[email protected]>
Change-Id: Ia88b14609cad4604523e3fc41860c980ee11abe0
diff --git a/bpf_progs/dscpPolicy.c b/bpf_progs/dscpPolicy.c
index f308931..6e85710 100644
--- a/bpf_progs/dscpPolicy.c
+++ b/bpf_progs/dscpPolicy.c
@@ -121,10 +121,13 @@
 
     RuleEntry* existing_rule = bpf_socket_policy_cache_map_lookup_elem(&cookie);
 
-    if (existing_rule && v6_equal(src_ip, existing_rule->src_ip) &&
-            v6_equal(dst_ip, existing_rule->dst_ip) && skb->ifindex == existing_rule->ifindex &&
-        ntohs(sport) == htons(existing_rule->src_port) &&
-        ntohs(dport) == htons(existing_rule->dst_port) && protocol == existing_rule->proto) {
+    if (existing_rule &&
+        v6_equal(src_ip, existing_rule->src_ip) &&
+        v6_equal(dst_ip, existing_rule->dst_ip) &&
+        skb->ifindex == existing_rule->ifindex &&
+        sport == existing_rule->src_port &&
+        dport == existing_rule->dst_port &&
+        protocol == existing_rule->proto) {
         if (existing_rule->dscp_val < 0) return;
         if (ipv4) {
             uint8_t newTos = UPDATE_TOS(existing_rule->dscp_val, tos);
@@ -145,8 +148,6 @@
     int8_t new_dscp = -1;
 
     for (register uint64_t i = 0; i < MAX_POLICIES; i++) {
-        int score = 0;
-        uint8_t temp_mask = 0;
         // Using a uint64 in for loop prevents infinite loop during BPF load,
         // but the key is uint32, so convert back.
         uint32_t key = i;
@@ -158,38 +159,35 @@
             policy = bpf_ipv6_dscp_policies_map_lookup_elem(&key);
         }
 
-        // If the policy lookup failed, present_fields is 0, or iface index does not match
-        // index on skb buff, then we can continue to next policy.
-        if (!policy || policy->present_fields == 0 || policy->ifindex != skb->ifindex) continue;
+        // If the policy lookup failed, just continue (this should not ever happen)
+        if (!policy) continue;
 
-        if ((policy->present_fields & SRC_IP_MASK_FLAG) == SRC_IP_MASK_FLAG &&
-            v6_equal(src_ip, policy->src_ip)) {
-            score++;
-            temp_mask |= SRC_IP_MASK_FLAG;
-        }
-        if ((policy->present_fields & DST_IP_MASK_FLAG) == DST_IP_MASK_FLAG &&
-            v6_equal(dst_ip, policy->dst_ip)) {
-            score++;
-            temp_mask |= DST_IP_MASK_FLAG;
-        }
-        if ((policy->present_fields & SRC_PORT_MASK_FLAG) == SRC_PORT_MASK_FLAG &&
-            ntohs(sport) == htons(policy->src_port)) {
-            score++;
-            temp_mask |= SRC_PORT_MASK_FLAG;
-        }
-        if ((policy->present_fields & DST_PORT_MASK_FLAG) == DST_PORT_MASK_FLAG &&
-            ntohs(dport) >= htons(policy->dst_port_start) &&
-            ntohs(dport) <= htons(policy->dst_port_end)) {
-            score++;
-            temp_mask |= DST_PORT_MASK_FLAG;
-        }
-        if ((policy->present_fields & PROTO_MASK_FLAG) == PROTO_MASK_FLAG &&
-            protocol == policy->proto) {
-            score++;
-            temp_mask |= PROTO_MASK_FLAG;
-        }
+        // If policy iface index does not match skb, then skip to next policy.
+        if (policy->ifindex != skb->ifindex) continue;
 
-        if (score > best_score && temp_mask == policy->present_fields) {
+        int score = 0;
+
+        if (policy->present_fields & PROTO_MASK_FLAG) {
+            if (protocol != policy->proto) continue;
+            score += 0xFFFF;
+        }
+        if (policy->present_fields & SRC_IP_MASK_FLAG) {
+            if (v6_not_equal(src_ip, policy->src_ip)) continue;
+            score += 0xFFFF;
+        }
+        if (policy->present_fields & DST_IP_MASK_FLAG) {
+            if (v6_not_equal(dst_ip, policy->dst_ip)) continue;
+            score += 0xFFFF;
+        }
+        if (policy->present_fields & SRC_PORT_MASK_FLAG) {
+            if (sport != policy->src_port) continue;
+            score += 0xFFFF;
+        }
+        if (ntohs(dport) < ntohs(policy->dst_port_start)) continue;
+        if (ntohs(dport) > ntohs(policy->dst_port_end)) continue;
+        score += 0xFFFF + ntohs(policy->dst_port_start) - ntohs(policy->dst_port_end);
+
+        if (score > best_score) {
             best_score = score;
             new_dscp = policy->dscp_val;
         }
diff --git a/bpf_progs/dscpPolicy.h b/bpf_progs/dscpPolicy.h
index c1db6ab..1a6e14d 100644
--- a/bpf_progs/dscpPolicy.h
+++ b/bpf_progs/dscpPolicy.h
@@ -20,16 +20,22 @@
 #define SRC_IP_MASK_FLAG     1
 #define DST_IP_MASK_FLAG     2
 #define SRC_PORT_MASK_FLAG   4
-#define DST_PORT_MASK_FLAG   8
-#define PROTO_MASK_FLAG      16
+#define PROTO_MASK_FLAG      8
 
 #define STRUCT_SIZE(name, size) _Static_assert(sizeof(name) == (size), "Incorrect struct size.")
 
-#define v6_equal(a, b) \
-    (((a.s6_addr32[0] ^ b.s6_addr32[0]) | \
-      (a.s6_addr32[1] ^ b.s6_addr32[1]) | \
-      (a.s6_addr32[2] ^ b.s6_addr32[2]) | \
-      (a.s6_addr32[3] ^ b.s6_addr32[3])) == 0)
+// First (i.e. high) 64 bits of an IPv6 address, in network byte order (type-punned load)
+#define v6_hi_be64(v) (*(uint64_t*)&((v).s6_addr32[0]))
+
+// Last (i.e. low) 64 bits; NOTE(review): relies on 64-bit loads here being allowed - confirm
+#define v6_lo_be64(v) (*(uint64_t*)&((v).s6_addr32[2]))
+
+// Evaluates to a non-zero u64 iff IPv6 addresses a and b differ in any bit
+#define v6_not_equal(a, b) ((v6_hi_be64(a) ^ v6_hi_be64(b)) \
+                          | (v6_lo_be64(a) ^ v6_lo_be64(b)))
+
+// Evaluates to 1 if a == b, else 0 (logical negation of v6_not_equal)
+#define v6_equal(a, b) (!v6_not_equal((a), (b)))
 
 // TODO: these are already defined in packages/modules/Connectivity/bpf_progs/bpf_net_helpers.h.
 // smove to common location in future.