dscpPolicy.c - cache result-less lookups as well

It is most definitely worthwhile to cache negative lookups as well!

Test: TreeHugger, atest DscpPolicyTest
Signed-off-by: Maciej Żenczykowski <[email protected]>
Change-Id: Iab1a57a2611a891642fef0c5897918c16e0ca540
diff --git a/bpf_progs/dscpPolicy.c b/bpf_progs/dscpPolicy.c
index e55f473..f308931 100644
--- a/bpf_progs/dscpPolicy.c
+++ b/bpf_progs/dscpPolicy.c
@@ -34,8 +34,8 @@
 #include "dscpPolicy.h"
 
 #define ECN_MASK 3
-#define IP4_OFFSET(field, header) (header + offsetof(struct iphdr, field))
-#define UPDATE_TOS(dscp, tos) (dscp << 2) | (tos & ECN_MASK)
+#define IP4_OFFSET(field, header) ((header) + offsetof(struct iphdr, field))
+#define UPDATE_TOS(dscp, tos) ((dscp) << 2) | ((tos) & ECN_MASK)
 
 DEFINE_BPF_MAP_GRW(socket_policy_cache_map, HASH, uint64_t, RuleEntry, CACHE_MAP_SIZE, AID_SYSTEM)
 
@@ -125,6 +125,7 @@
             v6_equal(dst_ip, existing_rule->dst_ip) && skb->ifindex == existing_rule->ifindex &&
         ntohs(sport) == htons(existing_rule->src_port) &&
         ntohs(dport) == htons(existing_rule->dst_port) && protocol == existing_rule->proto) {
+        if (existing_rule->dscp_val < 0) return;
         if (ipv4) {
             uint8_t newTos = UPDATE_TOS(existing_rule->dscp_val, tos);
             bpf_l3_csum_replace(skb, IP4_OFFSET(check, l2_header_size), htons(tos), htons(newTos),
@@ -140,8 +141,8 @@
     }
 
     // Linear scan ipv4_dscp_policies_map since no stored params match skb.
-    int best_score = -1;
-    uint32_t best_match = 0;
+    int best_score = 0;
+    int8_t new_dscp = -1;
 
     for (register uint64_t i = 0; i < MAX_POLICIES; i++) {
         int score = 0;
@@ -189,25 +190,10 @@
         }
 
         if (score > best_score && temp_mask == policy->present_fields) {
-            best_match = i;
             best_score = score;
-        }
-    }
-
-    uint8_t new_dscp = 0;
-    if (best_score > 0) {
-        DscpPolicy* policy;
-        if (ipv4) {
-            policy = bpf_ipv4_dscp_policies_map_lookup_elem(&best_match);
-        } else {
-            policy = bpf_ipv6_dscp_policies_map_lookup_elem(&best_match);
-        }
-
-        if (policy) {
             new_dscp = policy->dscp_val;
         }
-    } else
-        return;
+    }
 
     RuleEntry value = {
         .src_ip = src_ip,
@@ -219,9 +205,11 @@
         .dscp_val = new_dscp,
     };
 
-    // Update map with new policy.
+    // Update cache with found policy.
     bpf_socket_policy_cache_map_update_elem(&cookie, &value, BPF_ANY);
 
+    if (new_dscp < 0) return;
+
     // Need to store bytes after updating map or program will not load.
     if (ipv4) {
         uint8_t new_tos = UPDATE_TOS(new_dscp, tos);
diff --git a/bpf_progs/dscpPolicy.h b/bpf_progs/dscpPolicy.h
index 8968668..c1db6ab 100644
--- a/bpf_progs/dscpPolicy.h
+++ b/bpf_progs/dscpPolicy.h
@@ -50,7 +50,7 @@
     __be16 dst_port_start;
     __be16 dst_port_end;
     uint8_t proto;
-    uint8_t dscp_val;
+    int8_t dscp_val;  // -1 none, or 0..63 DSCP value
     uint8_t present_fields;
     uint8_t pad[3];
 } DscpPolicy;
@@ -63,7 +63,7 @@
     __be16 src_port;
     __be16 dst_port;
     __u8 proto;
-    __u8 dscp_val;
+    __s8 dscp_val;  // -1 none, or 0..63 DSCP value
     __u8 pad[2];
 } RuleEntry;
 STRUCT_SIZE(RuleEntry, 2 * 16 + 1 * 4 + 2 * 2 + 2 * 1 + 2);  // 44
diff --git a/service/src/com/android/server/connectivity/DscpPolicyTracker.java b/service/src/com/android/server/connectivity/DscpPolicyTracker.java
index 8cb3213..2bfad10 100644
--- a/service/src/com/android/server/connectivity/DscpPolicyTracker.java
+++ b/service/src/com/android/server/connectivity/DscpPolicyTracker.java
@@ -185,7 +185,7 @@
                         new DscpPolicyValue(policy.getSourceAddress(),
                             policy.getDestinationAddress(), ifIndex,
                             policy.getSourcePort(), policy.getDestinationPortRange(),
-                            (short) policy.getProtocol(), (short) policy.getDscpValue()));
+                            (short) policy.getProtocol(), (byte) policy.getDscpValue()));
             }
 
             // Add v6 policy to mBpfDscpIpv6Policies if source and destination address
@@ -196,7 +196,7 @@
                         new DscpPolicyValue(policy.getSourceAddress(),
                                 policy.getDestinationAddress(), ifIndex,
                                 policy.getSourcePort(), policy.getDestinationPortRange(),
-                                (short) policy.getProtocol(), (short) policy.getDscpValue()));
+                                (short) policy.getProtocol(), (byte) policy.getDscpValue()));
             }
 
             ifacePolicies.put(policy.getPolicyId(), addIndex);
diff --git a/service/src/com/android/server/connectivity/DscpPolicyValue.java b/service/src/com/android/server/connectivity/DscpPolicyValue.java
index 6e4e7eb..4bb41da 100644
--- a/service/src/com/android/server/connectivity/DscpPolicyValue.java
+++ b/service/src/com/android/server/connectivity/DscpPolicyValue.java
@@ -52,8 +52,8 @@
     @Field(order = 6, type = Type.U8)
     public final short proto;
 
-    @Field(order = 7, type = Type.U8)
-    public final short dscp;
+    @Field(order = 7, type = Type.S8)
+    public final byte dscp;
 
     @Field(order = 8, type = Type.U8, padding = 3)
     public final short mask;
@@ -100,7 +100,7 @@
             InetAddress.parseNumericAddress("::").getAddress();
 
     private short makeMask(final byte[] src46, final byte[] dst46, final int srcPort,
-            final int dstPortStart, final short proto, final short dscp) {
+            final int dstPortStart, final short proto, final byte dscp) {
         short mask = 0;
         if (src46 != EMPTY_ADDRESS_FIELD) {
             mask |= SRC_IP_MASK;
@@ -122,7 +122,7 @@
 
     private DscpPolicyValue(final InetAddress src46, final InetAddress dst46, final long ifIndex,
             final int srcPort, final int dstPortStart, final int dstPortEnd, final short proto,
-            final short dscp) {
+            final byte dscp) {
         this.src46 = toAddressField(src46);
         this.dst46 = toAddressField(dst46);
         this.ifIndex = ifIndex;
@@ -142,7 +142,7 @@
 
     public DscpPolicyValue(final InetAddress src46, final InetAddress dst46, final long ifIndex,
             final int srcPort, final Range<Integer> dstPort, final short proto,
-            final short dscp) {
+            final byte dscp) {
         this(src46, dst46, ifIndex, srcPort, dstPort != null ? dstPort.getLower() : -1,
                 dstPort != null ? dstPort.getUpper() : -1, proto, dscp);
     }
@@ -150,7 +150,7 @@
     public static final DscpPolicyValue NONE = new DscpPolicyValue(
             null /* src46 */, null /* dst46 */, 0 /* ifIndex */, -1 /* srcPort */,
             -1 /* dstPortStart */, -1 /* dstPortEnd */, (short) -1 /* proto */,
-            (short) 0 /* dscp */);
+            (byte) -1 /* dscp */);
 
     @Override
     public String toString() {