Merge "Fix decoding logic for fields with default values" into androidx-main
diff --git a/savedstate/savedstate/src/commonMain/kotlin/androidx/savedstate/serialization/SavedStateDecoder.kt b/savedstate/savedstate/src/commonMain/kotlin/androidx/savedstate/serialization/SavedStateDecoder.kt
index 65ccd1f..3f06dff 100644
--- a/savedstate/savedstate/src/commonMain/kotlin/androidx/savedstate/serialization/SavedStateDecoder.kt
+++ b/savedstate/savedstate/src/commonMain/kotlin/androidx/savedstate/serialization/SavedStateDecoder.kt
@@ -22,6 +22,7 @@
 import kotlinx.serialization.ExperimentalSerializationApi
 import kotlinx.serialization.SerializationException
 import kotlinx.serialization.descriptors.SerialDescriptor
+import kotlinx.serialization.descriptors.StructureKind
 import kotlinx.serialization.encoding.AbstractDecoder
 import kotlinx.serialization.encoding.CompositeDecoder
 import kotlinx.serialization.modules.EmptySerializersModule
@@ -73,9 +74,31 @@
     private var index = 0
 
     override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
-        if (index == savedState.read { size() }) return CompositeDecoder.DECODE_DONE
-        key = descriptor.getElementName(index)
-        return index++
+        val size =
+            if (descriptor.kind == StructureKind.LIST || descriptor.kind == StructureKind.MAP) {
+                // Use the number of elements encoded for collections.
+                savedState.read { size() }
+            } else {
+                // Elements may be skipped when encoding, so using `size()`
+                // here could end decoding before every field is visited.
+                descriptor.elementsCount
+            }
+        fun hasDefaultValueDefined(index: Int) = descriptor.isElementOptional(index)
+        fun presentInEncoding(index: Int) =
+            savedState.read {
+                val key = descriptor.getElementName(index)
+                contains(key)
+            }
+        // Skip optional elements omitted from the encoding (those left at their defaults).
+        while (index < size && hasDefaultValueDefined(index) && !presentInEncoding(index)) {
+            index++
+        }
+        if (index < size) {
+            key = descriptor.getElementName(index)
+            return index++
+        } else {
+            return CompositeDecoder.DECODE_DONE
+        }
     }
 
     override fun decodeBoolean(): Boolean = savedState.read { getBoolean(key) }
diff --git a/savedstate/savedstate/src/commonTest/kotlin/androidx/savedstate/SavedStateCodecTest.kt b/savedstate/savedstate/src/commonTest/kotlin/androidx/savedstate/SavedStateCodecTest.kt
index 89d6548..88c33e4 100644
--- a/savedstate/savedstate/src/commonTest/kotlin/androidx/savedstate/SavedStateCodecTest.kt
+++ b/savedstate/savedstate/src/commonTest/kotlin/androidx/savedstate/SavedStateCodecTest.kt
@@ -401,15 +401,19 @@
 
     @Test
     fun sealedClasses() {
-        Node.Add(Node.Operand(3), Node.Operand(5)).encodeDecode {
+        // Should use base type for encoding/decoding.
+        Node.Add(Node.Operand(3), Node.Operand(5)).encodeDecode<Node> {
             assertThat(size()).isEqualTo(2)
-            getSavedState("lhs").read {
-                assertThat(size()).isEqualTo(1)
-                assertThat(getInt("value")).isEqualTo(3)
-            }
-            getSavedState("rhs").read {
-                assertThat(size()).isEqualTo(1)
-                assertThat(getInt("value")).isEqualTo(5)
+            assertThat(getString("type")).isEqualTo("androidx.savedstate.Node.Add")
+            getSavedState("value").read {
+                getSavedState("lhs").read {
+                    assertThat(size()).isEqualTo(1)
+                    assertThat(getInt("value")).isEqualTo(3)
+                }
+                getSavedState("rhs").read {
+                    assertThat(size()).isEqualTo(1)
+                    assertThat(getInt("value")).isEqualTo(5)
+                }
             }
         }
     }
@@ -425,12 +429,24 @@
         }
 
         // Nullable with default value.
-        @Serializable data class B(val s: String? = "foo")
-        B().encodeDecode()
-        B(s = "bar").encodeDecode {
+        @Serializable data class B(val s: String? = "foo", val i: Int)
+        B(i = 3).encodeDecode {
             assertThat(size()).isEqualTo(1)
+            assertThat(getInt("i")).isEqualTo(3)
+        }
+        B(s = null, i = 3).encodeDecode {
+            assertThat(size()).isEqualTo(2)
+            assertThat(isNull("s")).isTrue()
+        }
+        B(s = "bar", i = 3).encodeDecode {
+            assertThat(size()).isEqualTo(2)
             assertThat(getString("s")).isEqualTo("bar")
         }
+        // The value of `s` is the same as its default value so it's omitted from encoding.
+        B(s = "foo", i = 3).encodeDecode {
+            assertThat(size()).isEqualTo(1)
+            assertThat(getInt("i")).isEqualTo(3)
+        }
 
         // Nullable without default value
         @Serializable data class C(val s: String?)
@@ -450,6 +466,22 @@
             assertThat(getInt("i")).isEqualTo(5)
             assertThat(getString("s")).isEqualTo("foo")
         }
+
+        // Nullable with null as default value.
+        @Serializable data class E(val s: String? = null)
+        // `null`s are normally encoded, but default values are not, so when `null`
+        // is the default value nothing is encoded.
+        E().encodeDecode()
+
+        // Nullable in parent
+        G(i = 3).encodeDecode<F> {
+            assertThat(size()).isEqualTo(2)
+            assertThat(getString("type")).isEqualTo("androidx.savedstate.G")
+            getSavedState("value").read {
+                assertThat(size()).isEqualTo(1)
+                assertThat(getInt("i")).isEqualTo(3)
+            }
+        }
     }
 
     @Test
@@ -568,6 +600,7 @@
 
 private typealias MyNestedTypeAlias = MyTypeAliasToInt
 
+@Serializable
 private sealed class Node {
     @Serializable data class Add(val lhs: Operand, val rhs: Operand) : Node()
 
@@ -610,3 +643,10 @@
         return MyColor(array[0], array[1], array[2])
     }
 }
+
+@Serializable
+private sealed class F {
+    val s: String? = null
+}
+
+@Serializable private data class G(val i: Int) : F()