Allow extra fields in wrapper messages, more tests.
diff --git a/csharp/src/Google.Protobuf.Test/WellKnownTypes/WrappersTest.cs b/csharp/src/Google.Protobuf.Test/WellKnownTypes/WrappersTest.cs
index 8a9c3d0..4a425f7 100644
--- a/csharp/src/Google.Protobuf.Test/WellKnownTypes/WrappersTest.cs
+++ b/csharp/src/Google.Protobuf.Test/WellKnownTypes/WrappersTest.cs
@@ -386,7 +386,7 @@
}
[Test]
- public void UnknownFieldInWrapper()
+ public void UnknownFieldInWrapperInt32FastPath()
{
var stream = new MemoryStream();
var output = new CodedOutputStream(stream);
@@ -395,13 +395,40 @@
var valueTag = WireFormat.MakeTag(Int32Value.ValueFieldNumber, WireFormat.WireType.Varint);
output.WriteTag(wrapperTag);
- output.WriteLength(4); // unknownTag + value 5 + valueType + value 6, each 1 byte
+ // Wrapper message is just long enough - 6 bytes - to use the wrapper fast-path.
+ output.WriteLength(6); // unknownTag + value 5 + valueType, each 1 byte, + value 65536, 3 bytes
output.WriteTag(unknownTag);
output.WriteInt32((int) valueTag); // Sneakily "pretend" it's a tag when it's really a value
output.WriteTag(valueTag);
+ output.WriteInt32(65536);
+
+ output.Flush();
+ Assert.AreEqual(8, stream.Length); // tag (1 byte) + length (1 byte) + message (6 bytes)
+ stream.Position = 0;
+
+ var message = TestWellKnownTypes.Parser.ParseFrom(stream);
+ Assert.AreEqual(65536, message.Int32Field);
+ }
+
+ [Test]
+ public void UnknownFieldInWrapperInt32SlowPath()
+ {
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int32FieldFieldNumber, WireFormat.WireType.LengthDelimited);
+ var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
+ var valueTag = WireFormat.MakeTag(Int32Value.ValueFieldNumber, WireFormat.WireType.Varint);
+
+ output.WriteTag(wrapperTag);
+ // Wrapper message is too short to be used on the wrapper fast-path.
+ output.WriteLength(4); // unknownTag + value 5 + valueType + value 6, each 1 byte
+ output.WriteTag(unknownTag);
+ output.WriteInt32((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
+ output.WriteTag(valueTag);
output.WriteInt32(6);
output.Flush();
+ Assert.Less(stream.Length, 8); // tag (1 byte) + length (1 byte) + message
stream.Position = 0;
var message = TestWellKnownTypes.Parser.ParseFrom(stream);
@@ -409,6 +436,56 @@
}
[Test]
+ public void UnknownFieldInWrapperInt64FastPath()
+ {
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int64FieldFieldNumber, WireFormat.WireType.LengthDelimited);
+ var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
+ var valueTag = WireFormat.MakeTag(Int64Value.ValueFieldNumber, WireFormat.WireType.Varint);
+
+ output.WriteTag(wrapperTag);
+ // Wrapper message is just long enough - 10 bytes - to use the wrapper fast-path.
+ output.WriteLength(11); // unknownTag + value 5 + valueType, each 1 byte, + value 0xfffffffffffff, 8 bytes
+ output.WriteTag(unknownTag);
+ output.WriteInt64((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
+ output.WriteTag(valueTag);
+ output.WriteInt64(0xfffffffffffffL);
+
+ output.Flush();
+ Assert.AreEqual(13, stream.Length); // tag (1 byte) + length (1 byte) + message (11 bytes)
+ stream.Position = 0;
+
+ var message = TestWellKnownTypes.Parser.ParseFrom(stream);
+ Assert.AreEqual(0xfffffffffffffL, message.Int64Field);
+ }
+
+ [Test]
+ public void UnknownFieldInWrapperInt64SlowPath()
+ {
+ var stream = new MemoryStream();
+ var output = new CodedOutputStream(stream);
+ var wrapperTag = WireFormat.MakeTag(TestWellKnownTypes.Int64FieldFieldNumber, WireFormat.WireType.LengthDelimited);
+ var unknownTag = WireFormat.MakeTag(15, WireFormat.WireType.Varint);
+ var valueTag = WireFormat.MakeTag(Int64Value.ValueFieldNumber, WireFormat.WireType.Varint);
+
+ output.WriteTag(wrapperTag);
+ // Wrapper message is too short to be used on the wrapper fast-path.
+ output.WriteLength(4); // unknownTag + value 5 + valueType + value 6, each 1 byte
+ output.WriteTag(unknownTag);
+ output.WriteInt64((int)valueTag); // Sneakily "pretend" it's a tag when it's really a value
+ output.WriteTag(valueTag);
+ output.WriteInt64(6);
+
+ output.Flush();
+ Assert.Less(stream.Length, 12); // tag (1 byte) + length (1 byte) + message
+ stream.Position = 0;
+
+ var message = TestWellKnownTypes.Parser.ParseFrom(stream);
+ Assert.AreEqual(6L, message.Int64Field);
+ }
+
+ [Test]
public void ClearWithReflection()
{
// String and Bytes are the tricky ones here, as the CLR type of the property
diff --git a/csharp/src/Google.Protobuf/CodedInputStream.cs b/csharp/src/Google.Protobuf/CodedInputStream.cs
index cd091b2..44934f3 100644
--- a/csharp/src/Google.Protobuf/CodedInputStream.cs
+++ b/csharp/src/Google.Protobuf/CodedInputStream.cs
@@ -737,29 +737,76 @@
return false;
}
+ internal static float? ReadFloatWrapperLittleEndian(CodedInputStream input)
+ {
+ // length:1 + tag:1 + value:4 = 6 bytes
+ if (input.bufferPos + 6 <= input.bufferSize)
+ {
+ // The entire wrapper message is already contained in `buffer`.
+ int length = input.buffer[input.bufferPos];
+ if (length == 0)
+ {
+ input.bufferPos++;
+ return 0F;
+ }
+ // tag:1 + value:4 = length of 5 bytes
+ // field=1, type=32-bit = tag of 13
+ if (length != 5 || input.buffer[input.bufferPos + 1] != 13)
+ {
+ return ReadFloatWrapperSlow(input);
+ }
+ var result = BitConverter.ToSingle(input.buffer, input.bufferPos + 2);
+ input.bufferPos += 6;
+ return result;
+ }
+ else
+ {
+ return ReadFloatWrapperSlow(input);
+ }
+ }
+
+ internal static float? ReadFloatWrapperSlow(CodedInputStream input)
+ {
+ int length = input.ReadLength();
+ if (length == 0)
+ {
+ return 0F;
+ }
+ int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
+ float result = 0F;
+ do
+ {
+ // field=1, type=32-bit = tag of 13
+ if (input.ReadTag() == 13)
+ {
+ result = input.ReadFloat();
+ }
+ else
+ {
+ input.SkipLastField();
+ }
+ }
+ while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
+ return result;
+ }
+
internal static double? ReadDoubleWrapperLittleEndian(CodedInputStream input)
{
- // tag:1 + value:8 = 9 bytes
- const int expectedLength = 9;
- // field=1, type=64-bit = tag of 9
- const int expectedTag = 9;
// length:1 + tag:1 + value:8 = 10 bytes
if (input.bufferPos + 10 <= input.bufferSize)
{
+ // The entire wrapper message is already contained in `buffer`.
int length = input.buffer[input.bufferPos];
if (length == 0)
{
input.bufferPos++;
return 0D;
}
- if (length != expectedLength)
- {
- throw InvalidProtocolBufferException.InvalidWrapperMessageLength();
- }
+ // tag:1 + value:8 = length of 9 bytes
// field=1, type=64-bit = tag of 9
- if (input.buffer[input.bufferPos + 1] != expectedTag)
+ if (length != 9 || input.buffer[input.bufferPos + 1] != 9)
{
- throw InvalidProtocolBufferException.InvalidWrapperMessageTag();
+ return ReadDoubleWrapperSlow(input);
}
var result = BitConverter.ToDouble(input.buffer, input.bufferPos + 2);
input.bufferPos += 10;
@@ -767,50 +814,119 @@
}
else
{
- int length = input.ReadLength();
- if (length == 0)
- {
- return 0D;
- }
- if (length != expectedLength)
- {
- throw InvalidProtocolBufferException.InvalidWrapperMessageLength();
- }
- if (input.ReadTag() != expectedTag)
- {
- throw InvalidProtocolBufferException.InvalidWrapperMessageTag();
- }
- return input.ReadDouble();
+ return ReadDoubleWrapperSlow(input);
}
}
- internal static double? ReadDoubleWrapperBigEndian(CodedInputStream input)
+ internal static double? ReadDoubleWrapperSlow(CodedInputStream input)
{
int length = input.ReadLength();
if (length == 0)
{
return 0D;
}
- // tag:1 + value:8 = 9 bytes
- if (length != 9)
+ int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
+ double result = 0D;
+ do
{
- throw InvalidProtocolBufferException.InvalidWrapperMessageLength();
+ // field=1, type=64-bit = tag of 9
+ if (input.ReadTag() == 9)
+ {
+ result = input.ReadDouble();
+ }
+ else
+ {
+ input.SkipLastField();
+ }
}
- // field=1, type=64-bit = tag of 9
- if (input.ReadTag() != 9)
- {
- throw InvalidProtocolBufferException.InvalidWrapperMessageTag();
- }
- return input.ReadDouble();
+ while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
+ return result;
}
- internal static long? ReadInt64Wrapper(CodedInputStream input)
+ internal static bool? ReadBoolWrapper(CodedInputStream input)
+ {
+ return ReadUInt32Wrapper(input) != 0;
+ }
+
+ internal static uint? ReadUInt32Wrapper(CodedInputStream input)
+ {
+ // length:1 + tag:1 + value:5(varint32-max) = 7 bytes
+ if (input.bufferPos + 7 <= input.bufferSize)
+ {
+ // The entire wrapper message is already contained in `buffer`.
+ int pos0 = input.bufferPos;
+ int length = input.buffer[input.bufferPos++];
+ if (length == 0)
+ {
+ return 0;
+ }
+ // Length will always fit in a single byte.
+ if (length >= 128)
+ {
+ input.bufferPos = pos0;
+ return ReadUInt32WrapperSlow(input);
+ }
+ int finalBufferPos = input.bufferPos + length;
+ // field=1, type=varint = tag of 8
+ if (input.buffer[input.bufferPos++] != 8)
+ {
+ input.bufferPos = pos0;
+ return ReadUInt32WrapperSlow(input);
+ }
+ var result = input.ReadUInt32();
+ // Verify this message only contained a single field.
+ if (input.bufferPos != finalBufferPos)
+ {
+ input.bufferPos = pos0;
+ return ReadUInt32WrapperSlow(input);
+ }
+ return result;
+ }
+ else
+ {
+ return ReadUInt32WrapperSlow(input);
+ }
+ }
+
+ private static uint? ReadUInt32WrapperSlow(CodedInputStream input)
+ {
+ int length = input.ReadLength();
+ if (length == 0)
+ {
+ return 0;
+ }
+ int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
+ uint result = 0;
+ do
+ {
+ // field=1, type=varint = tag of 8
+ if (input.ReadTag() == 8)
+ {
+ result = input.ReadUInt32();
+ }
+ else
+ {
+ input.SkipLastField();
+ }
+ }
+ while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
+ return result;
+ }
+
+ internal static int? ReadInt32Wrapper(CodedInputStream input)
+ {
+ return (int?)ReadUInt32Wrapper(input);
+ }
+
+ internal static ulong? ReadUInt64Wrapper(CodedInputStream input)
{
// field=1, type=varint = tag of 8
const int expectedTag = 8;
// length:1 + tag:1 + value:10(varint64-max) = 12 bytes
if (input.bufferPos + 12 <= input.bufferSize)
{
+ // The entire wrapper message is already contained in `buffer`.
+ int pos0 = input.bufferPos;
int length = input.buffer[input.bufferPos++];
if (length == 0)
{
@@ -819,43 +935,61 @@
// Length will always fit in a single byte.
if (length >= 128)
{
- throw InvalidProtocolBufferException.InvalidWrapperMessageLength();
+ input.bufferPos = pos0;
+ return ReadUInt64WrapperSlow(input);
}
int finalBufferPos = input.bufferPos + length;
if (input.buffer[input.bufferPos++] != expectedTag)
{
- throw InvalidProtocolBufferException.InvalidWrapperMessageTag();
+ input.bufferPos = pos0;
+ return ReadUInt64WrapperSlow(input);
}
- var result = input.ReadInt64();
+ var result = input.ReadUInt64();
// Verify this message only contained a single field.
if (input.bufferPos != finalBufferPos)
{
- throw InvalidProtocolBufferException.InvalidWrapperMessageExtraFields();
+ input.bufferPos = pos0;
+ return ReadUInt64WrapperSlow(input);
}
return result;
}
else
{
- int length = input.ReadLength();
- if (length == 0)
- {
- return 0L;
- }
- int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
- if (input.ReadTag() != expectedTag)
- {
- throw InvalidProtocolBufferException.InvalidWrapperMessageTag();
- }
- var result = input.ReadInt64();
- // Verify this message only contained a single field.
- if (input.totalBytesRetired + input.bufferPos != finalBufferPos)
- {
- throw InvalidProtocolBufferException.InvalidWrapperMessageExtraFields();
- }
- return result;
+ return ReadUInt64WrapperSlow(input);
}
}
-
+
+ internal static ulong? ReadUInt64WrapperSlow(CodedInputStream input)
+ {
+ // field=1, type=varint = tag of 8
+ const int expectedTag = 8;
+ int length = input.ReadLength();
+ if (length == 0)
+ {
+ return 0L;
+ }
+ int finalBufferPos = input.totalBytesRetired + input.bufferPos + length;
+ ulong result = 0L;
+ do
+ {
+ if (input.ReadTag() == expectedTag)
+ {
+ result = input.ReadUInt64();
+ }
+ else
+ {
+ input.SkipLastField();
+ }
+ }
+ while (input.totalBytesRetired + input.bufferPos < finalBufferPos);
+ return result;
+ }
+
+ internal static long? ReadInt64Wrapper(CodedInputStream input)
+ {
+ return (long?)ReadUInt64Wrapper(input);
+ }
+
#endregion
#region Underlying reading primitives
diff --git a/csharp/src/Google.Protobuf/FieldCodec.cs b/csharp/src/Google.Protobuf/FieldCodec.cs
index 7689964..1971261 100644
--- a/csharp/src/Google.Protobuf/FieldCodec.cs
+++ b/csharp/src/Google.Protobuf/FieldCodec.cs
@@ -539,18 +539,21 @@
{ typeof(ByteString), ForBytes(WireFormat.MakeTag(WrappersReflection.WrapperValueFieldNumber, WireFormat.WireType.LengthDelimited)) }
};
- private static readonly Dictionary<System.Type, Func<object>> Readers = new Dictionary<System.Type, Func<object>>
+ private static readonly Dictionary<System.Type, object> Readers = new Dictionary<System.Type, object>
{
// TODO: Provide more optimized readers.
- { typeof(bool), null },
- { typeof(int), null },
- { typeof(long), () => (Func<CodedInputStream, long?>)CodedInputStream.ReadInt64Wrapper },
- { typeof(uint), null },
- { typeof(ulong), null },
- { typeof(float), null },
- { typeof(double), () => BitConverter.IsLittleEndian ?
+ { typeof(bool), (Func<CodedInputStream, bool?>)CodedInputStream.ReadBoolWrapper },
+ { typeof(int), (Func<CodedInputStream, int?>)CodedInputStream.ReadInt32Wrapper },
+ { typeof(long), (Func<CodedInputStream, long?>)CodedInputStream.ReadInt64Wrapper },
+ { typeof(uint), (Func<CodedInputStream, uint?>)CodedInputStream.ReadUInt32Wrapper },
+ { typeof(ulong), (Func<CodedInputStream, ulong?>)CodedInputStream.ReadUInt64Wrapper },
+ { typeof(float), BitConverter.IsLittleEndian ?
+ (Func<CodedInputStream, float?>)CodedInputStream.ReadFloatWrapperLittleEndian :
+ (Func<CodedInputStream, float?>)CodedInputStream.ReadFloatWrapperSlow },
+ { typeof(double), BitConverter.IsLittleEndian ?
(Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperLittleEndian :
- (Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperBigEndian },
+ (Func<CodedInputStream, double?>)CodedInputStream.ReadDoubleWrapperSlow },
+ // `string` and `ByteString` less performance-sensitive. Do not implement for now.
{ typeof(string), null },
{ typeof(ByteString), null },
};
@@ -571,7 +574,7 @@
internal static Func<CodedInputStream, T?> GetReader<T>() where T : struct
{
- Func<object> value;
+ object value;
if (!Readers.TryGetValue(typeof(T), out value))
{
throw new InvalidOperationException("Invalid type argument requested for wrapper reader: " + typeof(T));
@@ -583,7 +586,7 @@
return input => Read<T>(input, nestedCoded);
}
// Return optimized read for the wrapper type.
- return (Func<CodedInputStream, T?>)value();
+ return (Func<CodedInputStream, T?>)value;
}
internal static T Read<T>(CodedInputStream input, FieldCodec<T> codec)