Pass HPKE mode in HPKE createContext.

PiperOrigin-RevId: 553067158
diff --git a/java_src/src/main/java/com/google/crypto/tink/hybrid/internal/HpkeContext.java b/java_src/src/main/java/com/google/crypto/tink/hybrid/internal/HpkeContext.java
index 4d5b35a..c1077ac 100644
--- a/java_src/src/main/java/com/google/crypto/tink/hybrid/internal/HpkeContext.java
+++ b/java_src/src/main/java/com/google/crypto/tink/hybrid/internal/HpkeContext.java
@@ -58,6 +58,7 @@
 
   /** Helper function factored out to facilitate unit testing. */
   static HpkeContext createContext(
+      byte[] mode,
       byte[] encapsulatedKey,
       byte[] sharedSecret,
       HpkeKem kem,
@@ -68,7 +69,7 @@
     byte[] suiteId = HpkeUtil.hpkeSuiteId(kem.getKemId(), kdf.getKdfId(), aead.getAeadId());
     byte[] pskIdHash = kdf.labeledExtract(HpkeUtil.EMPTY_SALT, EMPTY_IKM, "psk_id_hash", suiteId);
     byte[] infoHash = kdf.labeledExtract(HpkeUtil.EMPTY_SALT, info, "info_hash", suiteId);
-    byte[] keyScheduleContext = Bytes.concat(HpkeUtil.BASE_MODE, pskIdHash, infoHash);
+    byte[] keyScheduleContext = Bytes.concat(mode, pskIdHash, infoHash);
     byte[] secret = kdf.labeledExtract(sharedSecret, EMPTY_IKM, "secret", suiteId);
 
     byte[] key = kdf.labeledExpand(secret, keyScheduleContext, "key", suiteId, aead.getKeyLength());
@@ -96,7 +97,7 @@
         kem.encapsulate(recipientPublicKey.getPublicKey().toByteArray());
     byte[] encapsulatedKey = encapOutput.getEncapsulatedKey();
     byte[] sharedSecret = encapOutput.getSharedSecret();
-    return createContext(encapsulatedKey, sharedSecret, kem, kdf, aead, info);
+    return createContext(HpkeUtil.BASE_MODE, encapsulatedKey, sharedSecret, kem, kdf, aead, info);
   }
 
   /**
@@ -119,7 +120,7 @@
       byte[] info)
       throws GeneralSecurityException {
     byte[] sharedSecret = kem.decapsulate(encapsulatedKey, recipientPrivateKey);
-    return createContext(encapsulatedKey, sharedSecret, kem, kdf, aead, info);
+    return createContext(HpkeUtil.BASE_MODE, encapsulatedKey, sharedSecret, kem, kdf, aead, info);
   }
 
   private static BigInteger maxSequenceNumber(int nonceLength) {
diff --git a/java_src/src/main/java/com/google/crypto/tink/hybrid/internal/HpkeUtil.java b/java_src/src/main/java/com/google/crypto/tink/hybrid/internal/HpkeUtil.java
index a719924..559e445 100644
--- a/java_src/src/main/java/com/google/crypto/tink/hybrid/internal/HpkeUtil.java
+++ b/java_src/src/main/java/com/google/crypto/tink/hybrid/internal/HpkeUtil.java
@@ -30,6 +30,7 @@
 public final class HpkeUtil {
   // HPKE mode identifiers.
   public static final byte[] BASE_MODE = intToByteArray(1, 0x0);
+  public static final byte[] AUTH_MODE = intToByteArray(1, 0x2);
 
   // HPKE KEM algorithm identifiers.
   public static final byte[] X25519_HKDF_SHA256_KEM_ID = intToByteArray(2, 0x20);
diff --git a/java_src/src/test/java/com/google/crypto/tink/hybrid/internal/HpkeContextTest.java b/java_src/src/test/java/com/google/crypto/tink/hybrid/internal/HpkeContextTest.java
index 8e2d09e..38fe78c 100644
--- a/java_src/src/test/java/com/google/crypto/tink/hybrid/internal/HpkeContextTest.java
+++ b/java_src/src/test/java/com/google/crypto/tink/hybrid/internal/HpkeContextTest.java
@@ -104,13 +104,25 @@
 
     HpkeContext encryptionContext =
         HpkeContext.createContext(
-            testSetup.encapsulatedKey, testSetup.sharedSecret, kem, kdf, aead, testSetup.info);
+            mode,
+            testSetup.encapsulatedKey,
+            testSetup.sharedSecret,
+            kem,
+            kdf,
+            aead,
+            testSetup.info);
     verifyContext(encryptionContext, testVector);
     verifyEncrypt(encryptionContext, testVector);
 
     HpkeContext decryptionContext =
         HpkeContext.createContext(
-            testSetup.encapsulatedKey, testSetup.sharedSecret, kem, kdf, aead, testSetup.info);
+            mode,
+            testSetup.encapsulatedKey,
+            testSetup.sharedSecret,
+            kem,
+            kdf,
+            aead,
+            testSetup.info);
     verifyContext(decryptionContext, testVector);
     verifyDecrypt(decryptionContext, testVector);
   }