v5: update transmit to use new checksum code

Test: TreeHugger
  m libapf_v5 && a_test apf_{assemble,checksum,dns,run}_test NetworkStackTests:android.net.apf.Apf{,V5}Test
Signed-off-by: Maciej Żenczykowski <[email protected]>
Change-Id: Iac2de6b0f06aeda0c674e9857406502b3c6502c7
diff --git a/disassembler.c b/disassembler.c
index a89f213..3a70f82 100644
--- a/disassembler.c
+++ b/disassembler.c
@@ -270,11 +270,17 @@
                         bprintf("%d", alloc_len);
                     }
                     break;
-                case TRANSMITDISCARD_EXT_OPCODE:
-                    if (reg_num == 0) {
-                        print_opcode("discard");
-                    } else  {
-                        print_opcode("transmit");
+                case TRANSMIT_EXT_OPCODE:
+                    print_opcode(reg_num ? "transmitudp" : "transmit");
+                    u8 ip_ofs = DECODE_IMM(1);
+                    u8 csum_ofs = DECODE_IMM(1);
+                    if (csum_ofs < 255) {
+                        u8 csum_start = DECODE_IMM(1);
+                        u16 partial_csum = DECODE_IMM(2);
+                        bprintf("ip_ofs=%d, csum_ofs=%d, csum_start=%d, partial_csum=0x%04x",
+                                ip_ofs, csum_ofs, csum_start, partial_csum);
+                    } else {
+                        bprintf("ip_ofs=%d", ip_ofs);
                     }
                     break;
                 case EWRITE1_EXT_OPCODE: print_opcode("ewrite1"); bprintf("r%d", reg_num); break;
diff --git a/v5/apf.h b/v5/apf.h
index 8f5e37d..9c6eb5f 100644
--- a/v5/apf.h
+++ b/v5/apf.h
@@ -205,11 +205,15 @@
  */
 #define ALLOCATE_EXT_OPCODE 36
 /* Transmit and deallocate the buffer (transmission can be delayed until the program
- * terminates). R=0 means discard the buffer, R=1 means transmit the buffer.
- * "e.g. trans"
- * "e.g. discard"
+ * terminates).  Length of buffer is the output buffer pointer (0 means discard).
+ * R=1 iff udp style L4 checksum
+ * u8 imm2 - ip header offset from start of buffer (255 for non-ip packets)
+ * u8 imm3 - offset from start of buffer to store L4 checksum (255 for no L4 checksum)
+ * u8 imm4 - offset from start of buffer to begin L4 checksum calculation (present iff imm3 != 255)
+ * u16 imm5 - partial checksum value to include in L4 checksum (present iff imm3 != 255)
+ * "e.g. transmit"
  */
-#define TRANSMITDISCARD_EXT_OPCODE 37
+#define TRANSMIT_EXT_OPCODE 37
 /* Write 1, 2 or 4 byte value from register to the output buffer and auto-increment the
  * output buffer pointer.
  * e.g. "ewrite1 r0"
diff --git a/v5/apf_interpreter.c b/v5/apf_interpreter.c
index 7d15ea0..1d7a30d 100644
--- a/v5/apf_interpreter.c
+++ b/v5/apf_interpreter.c
@@ -276,11 +276,15 @@
  */
 #define ALLOCATE_EXT_OPCODE 36
 /* Transmit and deallocate the buffer (transmission can be delayed until the program
- * terminates). R=0 means discard the buffer, R=1 means transmit the buffer.
- * "e.g. trans"
- * "e.g. discard"
+ * terminates).  Length of buffer is the output buffer pointer (0 means discard).
+ * R=1 iff udp style L4 checksum
+ * u8 imm2 - ip header offset from start of buffer (255 for non-ip packets)
+ * u8 imm3 - offset from start of buffer to store L4 checksum (255 for no L4 checksum)
+ * u8 imm4 - offset from start of buffer to begin L4 checksum calculation (present iff imm3 != 255)
+ * u16 imm5 - partial checksum value to include in L4 checksum (present iff imm3 != 255)
+ * "e.g. transmit"
  */
-#define TRANSMITDISCARD_EXT_OPCODE 37
+#define TRANSMIT_EXT_OPCODE 37
 /* Write 1, 2 or 4 byte value from register to the output buffer and auto-increment the
  * output buffer pointer.
  * e.g. "ewrite1 r0"
@@ -889,14 +893,14 @@
                     } else {
                         DECODE_IMM(tx_buf_len, 2); /* 2nd imm, at worst 6 bytes past prog_len */
                     }
-                    /* checksumming functions requires minimum 74 byte buffer for correctness */
-                    if (tx_buf_len < 74) tx_buf_len = 74;
+                    /* checksumming functions requires minimum 266 byte buffer for correctness */
+                    if (tx_buf_len < 266) tx_buf_len = 266;
                     tx_buf = apf_allocate_buffer(ctx, tx_buf_len);
                     if (!tx_buf) { counter[-3]++; return PASS_PACKET; } /* allocate failure */
                     memset(tx_buf, 0, tx_buf_len);
                     mem.named.tx_buf_offset = 0;
                     break;
-                  case TRANSMITDISCARD_EXT_OPCODE:
+                  case TRANSMIT_EXT_OPCODE:
                     ASSERT_RETURN(tx_buf != NULL);
                     u32 pkt_len = mem.named.tx_buf_offset;
                     /* If pkt_len > allocate_buffer_len, it means sth. wrong */
@@ -910,7 +914,18 @@
                     /* tx_buf_len cannot be large because we'd run out of RAM, */
                     /* so the above unsigned comparison effectively guarantees casting pkt_len */
                     /* to a signed value does not result in it going negative. */
-                    int dscp = calculate_checksum_and_return_dscp(tx_buf, (s32)pkt_len);
+                    u8 ip_ofs, csum_ofs;
+                    u8 csum_start = 0;
+                    u16 partial_csum = 0;
+                    DECODE_IMM(ip_ofs, 1);            /* 2nd imm, at worst 5 bytes past prog_len */
+                    DECODE_IMM(csum_ofs, 1);          /* 3rd imm, at worst 6 bytes past prog_len */
+                    if (csum_ofs < 255) {
+                        DECODE_IMM(csum_start, 1);    /* 4th imm, at worst 7 bytes past prog_len */
+                        DECODE_IMM(partial_csum, 2);  /* 5th imm, at worst 9 bytes past prog_len */
+                    }
+                    int dscp = csum_and_return_dscp(tx_buf, (s32)pkt_len, ip_ofs,
+                                                    partial_csum, csum_start, csum_ofs,
+                                                    (bool)reg_num);
                     int ret = apf_transmit_buffer(ctx, tx_buf, pkt_len, dscp);
                     tx_buf = NULL;
                     tx_buf_len = 0;
diff --git a/v5/apf_interpreter_source.c b/v5/apf_interpreter_source.c
index 782bf00..3834b6b 100644
--- a/v5/apf_interpreter_source.c
+++ b/v5/apf_interpreter_source.c
@@ -308,14 +308,14 @@
                     } else {
                         DECODE_IMM(tx_buf_len, 2); // 2nd imm, at worst 6 bytes past prog_len
                     }
-                    // checksumming functions requires minimum 74 byte buffer for correctness
-                    if (tx_buf_len < 74) tx_buf_len = 74;
+                    // checksumming functions requires minimum 266 byte buffer for correctness
+                    if (tx_buf_len < 266) tx_buf_len = 266;
                     tx_buf = apf_allocate_buffer(ctx, tx_buf_len);
                     if (!tx_buf) { counter[-3]++; return PASS_PACKET; } // allocate failure
                     memset(tx_buf, 0, tx_buf_len);
                     mem.named.tx_buf_offset = 0;
                     break;
-                  case TRANSMITDISCARD_EXT_OPCODE:
+                  case TRANSMIT_EXT_OPCODE:
                     ASSERT_RETURN(tx_buf != NULL);
                     u32 pkt_len = mem.named.tx_buf_offset;
                     // If pkt_len > allocate_buffer_len, it means sth. wrong
@@ -329,7 +329,18 @@
                     // tx_buf_len cannot be large because we'd run out of RAM,
                     // so the above unsigned comparison effectively guarantees casting pkt_len
                     // to a signed value does not result in it going negative.
-                    int dscp = calculate_checksum_and_return_dscp(tx_buf, (s32)pkt_len);
+                    u8 ip_ofs, csum_ofs;
+                    u8 csum_start = 0;
+                    u16 partial_csum = 0;
+                    DECODE_IMM(ip_ofs, 1);            // 2nd imm, at worst 5 bytes past prog_len
+                    DECODE_IMM(csum_ofs, 1);          // 3rd imm, at worst 6 bytes past prog_len
+                    if (csum_ofs < 255) {
+                        DECODE_IMM(csum_start, 1);    // 4th imm, at worst 7 bytes past prog_len
+                        DECODE_IMM(partial_csum, 2);  // 5th imm, at worst 9 bytes past prog_len
+                    }
+                    int dscp = csum_and_return_dscp(tx_buf, (s32)pkt_len, ip_ofs,
+                                                    partial_csum, csum_start, csum_ofs,
+                                                    (bool)reg_num);
                     int ret = apf_transmit_buffer(ctx, tx_buf, pkt_len, dscp);
                     tx_buf = NULL;
                     tx_buf_len = 0;