using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

namespace Capnp
{
    /// <summary>
    /// Implements the heart of deserialization. This stateful helper struct exposes all functionality to traverse serialized data.
    /// Although it is public, you should not use it directly. Instead, use the reader, writer, and domain class adapters which are produced
    /// by the code generator.
    /// </summary>
    public struct DeserializerState: IStructDeserializer, IDisposable
    {
        /// <summary>
        /// A wire message is essentially a collection of memory blocks.
        /// </summary>
        public IReadOnlyList<Memory<ulong>> Segments { get; }

        /// <summary>
        /// Index of the segment (into the <see cref="Segments"/> property) which this state currently refers to.
        /// </summary>
        public uint CurrentSegmentIndex { get; private set; }

        /// <summary>
        /// Word offset within the current segment which this state currently refers to.
        /// </summary>
        public int Offset { get; set; }

        /// <summary>
        /// Context-dependent meaning: Usually the number of bytes traversed until this state was reached, to prevent amplification attacks.
        /// However, if this state is of Kind == ObjectKind.Value (an artificial category which will never occur on the wire but is used to
        /// internally represent lists of primitives as lists of structs), it contains the primitive's value.
        /// For <see cref="ObjectKind.Capability"/> states it holds the capability table index (see <see cref="CapabilityIndex"/>).
        /// </summary>
        public uint BytesTraversedOrData { get; set; }

        /// <summary>
        /// If this state currently represents a list, the number of list elements.
        /// </summary>
        public int ListElementCount { get; private set; }

        /// <summary>
        /// If this state currently represents a struct, the struct's data section word count.
        /// </summary>
        public ushort StructDataCount { get; set; }

        /// <summary>
        /// If this state currently represents a struct, the struct's pointer section word count.
        /// </summary>
        public ushort StructPtrCount { get; set; }

        /// <summary>
        /// The kind of object this state currently represents.
        /// </summary>
        public ObjectKind Kind { get; set; }

        bool _disposed;

        /// <summary>
        /// The capabilities imported from the capability table. Only valid in RPC context.
        /// </summary>
        public IList<Rpc.ConsumedCapability>? Caps { get; set; }

        /// <summary>
        /// Current segment (essentially Segments[CurrentSegmentIndex]), or an empty span if this state has no segments.
        /// </summary>
        public ReadOnlySpan<ulong> CurrentSegment => Segments != null ? Segments[(int)CurrentSegmentIndex].Span : default;

        // Synthetic starting state used by CreateRoot: a "struct" with an empty data section and exactly
        // one pointer, so that DecodePointer(0) interprets word 0 of segment 0 as the root pointer.
        DeserializerState(IReadOnlyList<Memory<ulong>> segments)
        {
            Segments = segments;
            CurrentSegmentIndex = 0;
            Offset = 0;
            BytesTraversedOrData = 0;
            ListElementCount = 0;
            StructDataCount = 0;
            StructPtrCount = 1;
            Kind = ObjectKind.Struct;
            Caps = null;
            _disposed = false;
        }

        /// <summary>
        /// Constructs a state representing a message root object.
        /// </summary>
        /// <param name="frame">the message</param>
        /// <returns>a state representing the object the root pointer refers to</returns>
        /// <exception cref="DeserializationException">invalid root pointer or traversal limit exceeded</exception>
        public static DeserializerState CreateRoot(WireFrame frame)
        {
            var state = new DeserializerState(frame.Segments);
            state.DecodePointer(0);
            return state;
        }

        /// <summary>
        /// Implicitly converts a serializer state into a deserializer state.
        /// The conversion is cheap, since it does not involve copying any payload.
        /// </summary>
        /// <param name="state">The serializer state to be converted</param>
        /// <exception cref="ArgumentNullException"><paramref name="state"/> is null</exception>
        /// <exception cref="InvalidOperationException"><paramref name="state"/> is not bound to a MessageBuilder</exception>
        /// <exception cref="ArgumentException"><paramref name="state"/> has an object kind that cannot be converted</exception>
        public static implicit operator DeserializerState(SerializerState state)
        {
            if (state == null)
                throw new ArgumentNullException(nameof(state));

            if (state.MsgBuilder == null)
                throw new InvalidOperationException("state is not bound to a MessageBuilder");

            switch (state.Kind)
            {
                case ObjectKind.ListOfBits:
                case ObjectKind.ListOfBytes:
                case ObjectKind.ListOfEmpty:
                case ObjectKind.ListOfInts:
                case ObjectKind.ListOfLongs:
                case ObjectKind.ListOfPointers:
                case ObjectKind.ListOfShorts:
                case ObjectKind.ListOfStructs:
                case ObjectKind.Nil:
                case ObjectKind.Struct:
                    return new DeserializerState(state.Allocator!.Segments)
                    {
                        CurrentSegmentIndex = state.SegmentIndex,
                        Offset = state.Offset,
                        ListElementCount = state.ListElementCount,
                        StructDataCount = state.StructDataCount,
                        StructPtrCount = state.StructPtrCount,
                        Kind = state.Kind,
                        Caps = state.Caps
                    };

                case ObjectKind.Capability:
                    return new DeserializerState(state.Allocator!.Segments)
                    {
                        Kind = ObjectKind.Capability,
                        Caps = state.Caps,
                        BytesTraversedOrData = state.CapabilityIndex
                    };

                default:
                    throw new ArgumentException("Unexpected type of object, cannot convert that into DeserializerState", nameof(state));
            }
        }

        /// <summary>
        /// Constructs a state representing the given value. This kind of state is artificial and beyond the Cap'n Proto specification.
        /// We need it to internally represent list of primitive values as lists of structs.
        /// </summary>
        /// <param name="value">the primitive value, stored in <see cref="BytesTraversedOrData"/></param>
        /// <returns>a state of kind <see cref="ObjectKind.Value"/></returns>
        public static DeserializerState MakeValueState(uint value)
        {
            return new DeserializerState()
            {
                BytesTraversedOrData = value,
                Kind = ObjectKind.Value
            };
        }

        /// <summary>
        /// Increments the number of bytes traversed and checks the results against the traversal limit.
        /// </summary>
        /// <param name="additionalBytesTraversed">Amount to increase the traversed bytes</param>
        /// <exception cref="DeserializationException">Traversal limit was reached (including the case where
        /// the counter itself would wrap around)</exception>
        public void IncrementBytesTraversed(uint additionalBytesTraversed)
        {
            // Explicit wrap-around guard: hostile input must surface as DeserializationException,
            // not as the OverflowException a checked addition would raise.
            if (additionalBytesTraversed > uint.MaxValue - BytesTraversedOrData)
                throw new DeserializationException("Traversal limit was reached");

            BytesTraversedOrData += additionalBytesTraversed;

            if (BytesTraversedOrData > SecurityOptions.TraversalLimit)
                throw new DeserializationException("Traversal limit was reached");
        }

        /// <summary>
        /// Memory span which represents this struct's data section (given this state actually represents a struct)
        /// </summary>
        public ReadOnlySpan<ulong> StructDataSection => CurrentSegment.Slice(Offset, StructDataCount);

        // Word spans holding the packed elements of primitive lists; element counts are rounded up to whole words.
        ReadOnlySpan<ulong> GetRawBits() => CurrentSegment.Slice(Offset, (ListElementCount + 63) / 64);
        ReadOnlySpan<ulong> GetRawBytes() => CurrentSegment.Slice(Offset, (ListElementCount + 7) / 8);
        ReadOnlySpan<ulong> GetRawShorts() => CurrentSegment.Slice(Offset, (ListElementCount + 3) / 4);
        ReadOnlySpan<ulong> GetRawInts() => CurrentSegment.Slice(Offset, (ListElementCount + 1) / 2);
        ReadOnlySpan<ulong> GetRawLongs() => CurrentSegment.Slice(Offset, ListElementCount);

        /// <summary>
        /// If this state represents a list of primitive values, returns the raw list data.
        /// Returns an empty span for any other kind of object.
        /// </summary>
        public ReadOnlySpan<ulong> RawData
        {
            get
            {
                return Kind switch
                {
                    ObjectKind.ListOfBits => GetRawBits(),
                    ObjectKind.ListOfBytes => GetRawBytes(),
                    ObjectKind.ListOfShorts => GetRawShorts(),
                    ObjectKind.ListOfInts => GetRawInts(),
                    ObjectKind.ListOfLongs => GetRawLongs(),
                    _ => default,
                };
            }
        }

        /// <summary>
        /// Checks that the struct or list this state refers to lies completely within the current segment.
        /// Other kinds (empty lists, capabilities, nil, values) are not checked.
        /// </summary>
        /// <exception cref="DeserializationException">the object exceeds the segment bounds</exception>
        void Validate()
        {
            try
            {
                switch (Kind)
                {
                    case ObjectKind.Struct:
                        CurrentSegment.Slice(Offset, StructDataCount + StructPtrCount);
                        break;

                    case ObjectKind.ListOfBits:
                        GetRawBits();
                        break;

                    case ObjectKind.ListOfBytes:
                        GetRawBytes();
                        break;

                    case ObjectKind.ListOfShorts:
                        GetRawShorts();
                        break;

                    case ObjectKind.ListOfInts:
                        GetRawInts();
                        break;

                    case ObjectKind.ListOfLongs:
                    case ObjectKind.ListOfPointers:
                        GetRawLongs();
                        break;

                    case ObjectKind.ListOfStructs:
                        // Offset points at the tag word and the elements follow it, so the list
                        // occupies 1 + count * stride words. Without the "+ 1" the last word escapes the check.
                        CurrentSegment.Slice(Offset, checked(1 + ListElementCount * (StructDataCount + StructPtrCount)));
                        break;
                }
            }
            catch (Exception problem)
            {
                throw new DeserializationException("Invalid wire pointer", problem);
            }
        }

        /// <summary>
        /// Interprets a pointer within the current segment and mutates this state to represent the pointer's target.
        /// </summary>
        /// <param name="offset">word offset relative to this.Offset within current segment</param>
        /// <exception cref="IndexOutOfRangeException">offset negative or out of range</exception>
        /// <exception cref="DeserializationException">invalid pointer data or traversal limit exceeded</exception>
        internal void DecodePointer(int offset)
        {
            if (offset < 0)
                throw new IndexOutOfRangeException(nameof(offset));

            WirePointer pointer = CurrentSegment[Offset + offset];

            int derefCount = 0;

            do
            {
                if (pointer.IsNull)
                {
                    this = default;
                    return;
                }

                switch (pointer.Kind)
                {
                    case PointerKind.Struct:
                        // The target offset is relative to the word following the pointer.
                        Offset = checked(pointer.Offset + Offset + offset + 1);
                        IncrementBytesTraversed(checked(8u * pointer.StructSize));
                        StructDataCount = pointer.StructDataCount;
                        StructPtrCount = pointer.StructPtrCount;
                        Kind = ObjectKind.Struct;
                        Validate();
                        return;

                    case PointerKind.List:
                        Offset = checked(pointer.Offset + Offset + offset + 1);
                        ListElementCount = pointer.ListElementCount;
                        StructDataCount = 0;
                        StructPtrCount = 0;

                        switch (pointer.ListKind)
                        {
                            case ListKind.ListOfEmpty: // e.g. List(void)
                                // the “traversal limit” should count a list of zero-sized elements as if each element were one word instead.
                                IncrementBytesTraversed(checked(8u * (uint)ListElementCount));
                                Kind = ObjectKind.ListOfEmpty;
                                break;

                            case ListKind.ListOfBits:
                                IncrementBytesTraversed(checked((uint)ListElementCount + 7) / 8);
                                Kind = ObjectKind.ListOfBits;
                                break;

                            case ListKind.ListOfBytes:
                                IncrementBytesTraversed((uint)ListElementCount);
                                Kind = ObjectKind.ListOfBytes;
                                break;

                            case ListKind.ListOfShorts:
                                IncrementBytesTraversed(checked(2u * (uint)ListElementCount));
                                Kind = ObjectKind.ListOfShorts;
                                break;

                            case ListKind.ListOfInts:
                                IncrementBytesTraversed(checked(4u * (uint)ListElementCount));
                                Kind = ObjectKind.ListOfInts;
                                break;

                            case ListKind.ListOfLongs:
                                IncrementBytesTraversed(checked(8u * (uint)ListElementCount));
                                Kind = ObjectKind.ListOfLongs;
                                break;

                            case ListKind.ListOfPointers:
                                IncrementBytesTraversed(checked(8u * (uint)ListElementCount));
                                Kind = ObjectKind.ListOfPointers;
                                break;

                            case ListKind.ListOfStructs:
                                {
                                    // Offset is derived from a signed wire offset and may be negative; check both
                                    // ends so malformed input yields DeserializationException, not a span indexing error.
                                    if (Offset < 0 || Offset >= CurrentSegment.Length)
                                        throw new DeserializationException("List of composites pointer exceeds segment bounds");

                                    WirePointer tag = CurrentSegment[Offset];
                                    if (tag.Kind != PointerKind.Struct)
                                        throw new DeserializationException("Unexpected: List of composites with non-struct type tag");

                                    // Per the Cap'n Proto encoding, a composite list pointer carries the payload word
                                    // count (tag word excluded). Count tag and payload separately so the largest
                                    // encodable word count cannot overflow the uint product.
                                    IncrementBytesTraversed(8u);
                                    IncrementBytesTraversed(checked(8u * (uint)pointer.ListElementCount));

                                    ListElementCount = tag.ListOfStructsElementCount;
                                    StructDataCount = tag.StructDataCount;
                                    StructPtrCount = tag.StructPtrCount;
                                    Kind = ObjectKind.ListOfStructs;
                                }
                                break;

                            default:
                                throw new InvalidProgramException();
                        }
                        Validate();
                        return;

                    case PointerKind.Far:

                        if (pointer.TargetSegmentIndex >= Segments.Count)
                            throw new DeserializationException("Error decoding pointer: Invalid target segment index");

                        CurrentSegmentIndex = pointer.TargetSegmentIndex;

                        if (pointer.IsDoubleFar)
                        {
                            // The double-far landing pad spans two words: a far pointer to the content,
                            // followed by the pointer (tag) describing that content.
                            if (pointer.LandingPadOffset >= CurrentSegment.Length - 1)
                                throw new DeserializationException("Error decoding double-far pointer: exceeds segment bounds");

                            Offset = 0;

                            WirePointer pointer1 = CurrentSegment[pointer.LandingPadOffset];
                            if (pointer1.Kind != PointerKind.Far || pointer1.IsDoubleFar)
                                throw new DeserializationException("Error decoding double-far pointer: convention broken");

                            WirePointer pointer2 = CurrentSegment[pointer.LandingPadOffset + 1];
                            if (pointer2.Kind == PointerKind.Far)
                                throw new DeserializationException("Error decoding double-far pointer: not followed by intra-segment pointer");

                            // The second hop's segment index also comes from untrusted wire data.
                            if (pointer1.TargetSegmentIndex >= Segments.Count)
                                throw new DeserializationException("Error decoding double-far pointer: Invalid target segment index");

                            CurrentSegmentIndex = pointer1.TargetSegmentIndex;
                            Offset = pointer1.LandingPadOffset;
                            pointer = pointer2;
                            // Cancels the "+ 1" applied when pointer2 is resolved, so its target is
                            // relative to pointer1's landing pad itself.
                            offset = -1;
                        }
                        else
                        {
                            Offset = 0;
                            offset = pointer.LandingPadOffset;

                            if (pointer.LandingPadOffset >= CurrentSegment.Length)
                                throw new DeserializationException("Error decoding pointer: exceeds segment bounds");

                            pointer = CurrentSegment[pointer.LandingPadOffset];
                        }
                        continue;

                    case PointerKind.Other:
                        // Capability pointer: reset everything but keep the capability table.
                        var tmp = Caps;
                        this = default;
                        Caps = tmp;
                        Kind = ObjectKind.Capability;
                        BytesTraversedOrData = pointer.CapabilityIndex;
                        return;

                    default:
                        throw new InvalidProgramException();
                }

            } while (++derefCount < SecurityOptions.RecursionLimit);

            throw new DeserializationException("Recursion limit reached while decoding a pointer");
        }

        /// <summary>
        /// Interprets a pointer within the current segment as capability pointer and returns the according low-level capability object from
        /// the capability table. Does not mutate this state.
        /// </summary>
        /// <param name="offset">Offset relative to this.Offset within current segment</param>
        /// <returns>the low-level capability object, or <see cref="Rpc.NullCapability.Instance"/> if it is a null pointer</returns>
        /// <exception cref="ArgumentOutOfRangeException">offset negative</exception>
        /// <exception cref="IndexOutOfRangeException">offset beyond the current segment</exception>
        /// <exception cref="InvalidOperationException">capability table not set</exception>
        /// <exception cref="Rpc.RpcException">not a capability pointer or invalid capability index</exception>
        internal Rpc.ConsumedCapability DecodeCapPointer(int offset)
        {
            if (offset < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(offset));
            }

            if (Caps == null)
            {
                throw new InvalidOperationException("Capability table not set");
            }

            WirePointer pointer = CurrentSegment[Offset + offset];

            if (pointer.IsNull)
            {
                // Despite this behavior is not officially specified,
                // the official C++ implementation seems to send null pointers for null caps.
                return Rpc.NullCapability.Instance;
            }

            if (pointer.Kind != PointerKind.Other)
            {
                throw new Rpc.RpcException("Expected a capability pointer, but got something different");
            }

            if (pointer.CapabilityIndex >= Caps.Count)
            {
                throw new Rpc.RpcException("Capability index out of range");
            }

            return Caps[(int)pointer.CapabilityIndex];
        }

        /// <summary>
        /// Reads a slice of up to 64 bits from this struct's data section, starting from the specified bit offset.
        /// The slice must be aligned within a 64 bit word boundary.
        /// A nil state, or a bit offset beyond the data section, reads as 0 (assumed backwards-compatible schema change).
        /// For artificial <see cref="ObjectKind.Value"/> states, bits are read from the 32-bit value; offsets of 32 or more read as 0.
        /// </summary>
        /// <param name="bitOffset">Start bit offset relative to the data section, little endian</param>
        /// <param name="bitCount">numbers of bits to read</param>
        /// <returns>the data</returns>
        /// <exception cref="ArgumentOutOfRangeException">non-aligned access (slice crosses a word boundary)</exception>
        /// <exception cref="DeserializationException">this state does not represent a struct</exception>
        public ulong StructReadData(ulong bitOffset, int bitCount)
        {
            switch (Kind)
            {
                case ObjectKind.Nil:
                    return 0;

                case ObjectKind.Struct:

                    int index = checked((int)(bitOffset / 64));
                    int relBitOffset = (int)(bitOffset % 64);

                    var data = StructDataSection;

                    if (index >= data.Length)
                        return 0; // Assume backwards-compatible change

                    if (relBitOffset + bitCount > 64)
                        throw new ArgumentOutOfRangeException(nameof(bitCount));

                    ulong word = data[index];

                    if (bitCount == 64)
                    {
                        // Special-cased because C# masks the shift count: 1ul << 64 == 1ul.
                        return word;
                    }
                    else
                    {
                        ulong mask = (1ul << bitCount) - 1;
                        return (word >> relBitOffset) & mask;
                    }

                case ObjectKind.Value:

                    if (bitOffset >= 32) return 0;
                    if (bitCount >= 32) return BytesTraversedOrData >> (int)bitOffset;
                    return (BytesTraversedOrData >> (int)bitOffset) & ((1u << bitCount) - 1);

                default:
                    throw new DeserializationException("This is not a struct");
            }
        }

        /// <summary>
        /// Decodes a pointer from this struct's pointer section and returns the state representing the pointer target.
        /// It is valid to specify an index beyond the pointer section, in which case a default state (representing the "null object")
        /// will be returned. This is to preserve upward compatibility with schema evolution.
        /// </summary>
        /// <param name="index">Index within the pointer section</param>
        /// <returns>the target state</returns>
        /// <exception cref="IndexOutOfRangeException">negative index</exception>
        /// <exception cref="DeserializationException">this state does not represent a struct,
        /// invalid pointer, or traversal limit exceeded</exception>
        public DeserializerState StructReadPointer(int index)
        {
            if (Kind != ObjectKind.Struct && Kind != ObjectKind.Nil)
                throw new DeserializationException("This is not a struct");

            // A small negative index would otherwise land inside the data section and
            // reinterpret a data word as a pointer.
            if (index < 0)
                throw new IndexOutOfRangeException(nameof(index));

            if (index >= StructPtrCount)
                return default;

            DeserializerState state = this;
            state.DecodePointer(index + StructDataCount);

            return state;
        }

        /// <summary>
        /// Reads the raw capability referenced by the given pointer-section slot.
        /// An index beyond the pointer section yields <see cref="Rpc.NullCapability.Instance"/>.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">negative index</exception>
        /// <exception cref="InvalidOperationException">not a struct, or capability table not set</exception>
        /// <exception cref="Rpc.RpcException">not a capability pointer or invalid capability index</exception>
        internal Rpc.ConsumedCapability StructReadRawCap(int index)
        {
            if (Kind != ObjectKind.Struct && Kind != ObjectKind.Nil)
                throw new InvalidOperationException("Allowed on structs only");

            // Same reason as in StructReadPointer: never reinterpret data words as pointers.
            if (index < 0)
                throw new ArgumentOutOfRangeException(nameof(index));

            if (index >= StructPtrCount)
                return Rpc.NullCapability.Instance;

            return DecodeCapPointer(index + StructDataCount);
        }

        /// <summary>
        /// Given this state represents a list (of anything), returns a ListDeserializer to further decode the list content.
        /// A nil state yields an empty list deserializer.
        /// </summary>
        /// <exception cref="DeserializationException">state does not represent a list</exception>
        public ListDeserializer RequireList()
        {
            // NOTE(review): the primitive element type arguments below were reconstructed (they were
            // stripped from this copy of the source) — confirm against ListOfPrimitivesDeserializer usage.
            switch (Kind)
            {
                case ObjectKind.ListOfBits:
                    return new ListOfBitsDeserializer(this, false);
                case ObjectKind.ListOfBytes:
                    return new ListOfPrimitivesDeserializer<byte>(this, ListKind.ListOfBytes);
                case ObjectKind.ListOfEmpty:
                    return new ListOfEmptyDeserializer(this);
                case ObjectKind.ListOfInts:
                    return new ListOfPrimitivesDeserializer<int>(this, ListKind.ListOfInts);
                case ObjectKind.ListOfLongs:
                    return new ListOfPrimitivesDeserializer<long>(this, ListKind.ListOfLongs);
                case ObjectKind.ListOfPointers:
                    return new ListOfPointersDeserializer(this);
                case ObjectKind.ListOfShorts:
                    return new ListOfPrimitivesDeserializer<short>(this, ListKind.ListOfShorts);
                case ObjectKind.ListOfStructs:
                    return new ListOfStructsDeserializer(this);
                case ObjectKind.Nil:
                    return new EmptyListDeserializer();
                default:
                    throw new DeserializationException("Cannot deserialize this object as list");
            }
        }

        /// <summary>
        /// Given this state represents a list of pointers, returns a ListOfCapsDeserializer for decoding it as list of capabilities.
        /// </summary>
        /// <typeparam name="T">Capability interface</typeparam>
        /// <exception cref="DeserializationException">state does not represent a list of pointers</exception>
        public ListOfCapsDeserializer<T> RequireCapList<T>() where T: class
        {
            switch (Kind)
            {
                case ObjectKind.ListOfPointers:
                    return new ListOfCapsDeserializer<T>(this);
                default:
                    throw new DeserializationException("Cannot deserialize this object as capability list");
            }
        }

        /// <summary>
        /// Convenience method. Given this state represents a struct, decodes text field from its pointer table.
        /// </summary>
        /// <param name="index">index within this struct's pointer table</param>
        /// <param name="defaultText">default text to return of pointer is null</param>
        /// <returns>the decoded text, or defaultText (which might be null)</returns>
        /// <exception cref="IndexOutOfRangeException">negative index</exception>
        /// <exception cref="DeserializationException">state does not represent a struct, invalid pointer,
        /// non-list-of-bytes pointer, traversal limit exceeded</exception>
        [return: NotNullIfNotNull("defaultText")]
        public string? ReadText(int index, string? defaultText = null)
        {
            return StructReadPointer(index).RequireList().CastText() ?? defaultText;
        }

        /// <summary>
        /// Convenience method. Given this state represents a struct, decodes a list deserializer field from its pointer table.
        /// </summary>
        /// <param name="index">index within this struct's pointer table</param>
        /// <returns>the list deserializer instance</returns>
        /// <exception cref="IndexOutOfRangeException">negative index</exception>
        /// <exception cref="DeserializationException">state does not represent a struct, invalid pointer,
        /// non-list pointer, traversal limit exceeded</exception>
        public ListDeserializer ReadList(int index)
        {
            return StructReadPointer(index).RequireList();
        }

        /// <summary>
        /// Convenience method. Given this state represents a struct, decodes a capability list field from its pointer table.
        /// </summary>
        /// <typeparam name="T">Capability interface</typeparam>
        /// <param name="index">index within this struct's pointer table</param>
        /// <returns>the capability list deserializer instance</returns>
        /// <exception cref="IndexOutOfRangeException">negative index</exception>
        /// <exception cref="DeserializationException">state does not represent a struct, invalid pointer,
        /// non-list-of-pointers pointer, traversal limit exceeded</exception>
        public ListOfCapsDeserializer<T> ReadCapList<T>(int index) where T : class
        {
            return StructReadPointer(index).RequireCapList<T>();
        }

        /// <summary>
        /// Convenience method. Given this state represents a struct, decodes a list of structs field from its pointer table.
        /// </summary>
        /// <typeparam name="T">Struct target representation type</typeparam>
        /// <param name="index">index within this struct's pointer table</param>
        /// <param name="cons">constructs a target representation type instance from the underlying deserializer state</param>
        /// <returns>the decoded list of structs</returns>
        /// <exception cref="IndexOutOfRangeException">negative index</exception>
        /// <exception cref="DeserializationException">state does not represent a struct, invalid pointer,
        /// non-list-of-{structs,pointers} pointer, traversal limit exceeded</exception>
        public IReadOnlyList<T> ReadListOfStructs<T>(int index, Func<DeserializerState, T> cons)
        {
            return ReadList(index).Cast(cons);
        }

        /// <summary>
        /// Convenience method. Given this state represents a struct, decodes a struct field from its pointer table.
        /// </summary>
        /// <typeparam name="T">Struct target representation type</typeparam>
        /// <param name="index">index within this struct's pointer table</param>
        /// <param name="cons">constructs a target representation type instance from the underlying deserializer state</param>
        /// <returns>the decoded struct</returns>
        /// <exception cref="IndexOutOfRangeException">negative index</exception>
        /// <exception cref="DeserializationException">state does not represent a struct, invalid pointer,
        /// non-struct pointer, traversal limit exceeded</exception>
        public T ReadStruct<T>(int index, Func<DeserializerState, T> cons)
        {
            return cons(StructReadPointer(index));
        }

        /// <summary>
        /// Convenience method. Given this state represents a struct, determines if a field is non-null.
        /// Does not decode the pointer target.
        /// </summary>
        /// <param name="index">index within this struct's pointer table</param>
        /// <returns>true if the field is non-null, false otherwise</returns>
        /// <exception cref="IndexOutOfRangeException">negative index or index beyond the pointer section</exception>
        /// <exception cref="DeserializationException">state does not represent a struct</exception>
        public bool IsStructFieldNonNull(int index)
        {
            if (Kind != ObjectKind.Struct && Kind != ObjectKind.Nil)
            {
                throw new DeserializationException("This is not a struct");
            }

            if (index < 0 || index >= StructPtrCount)
            {
                throw new IndexOutOfRangeException($"Invalid index {index}. Must be [0, {StructPtrCount}).");
            }

            var pointerOffset = index + StructDataCount;
            WirePointer pointer = CurrentSegment[Offset + pointerOffset];
            return !pointer.IsNull;
        }

        /// <summary>
        /// Given this state represents a capability, returns its index into the capability table;
        /// otherwise returns <c>~0u</c>.
        /// </summary>
        public uint CapabilityIndex => Kind == ObjectKind.Capability ? BytesTraversedOrData : ~0u;

        /// <summary>
        /// Given this state represents a struct, decodes a capability field from its pointer table.
        /// </summary>
        /// <typeparam name="T">Capability interface</typeparam>
        /// <param name="index">index within this struct's pointer table</param>
        /// <returns>capability instance or null if pointer was null</returns>
        /// <exception cref="ArgumentOutOfRangeException">negative index</exception>
        /// <exception cref="InvalidOperationException">state does not represent a struct, or capability table not set</exception>
        /// <exception cref="Rpc.RpcException">non-capability pointer or invalid capability index</exception>
        public T? ReadCap<T>(int index) where T: class
        {
            var cap = StructReadRawCap(index);
            return Rpc.CapabilityReflection.CreateProxy<T>(cap) as T;
        }

        /// <summary>
        /// Given this state represents a struct, decodes a capability field from its pointer table and
        /// returns it as bare (generic) proxy.
        /// </summary>
        /// <param name="index">index within this struct's pointer table</param>
        /// <returns>bare proxy wrapping the capability (a null pointer yields a proxy around the null capability)</returns>
        /// <exception cref="ArgumentOutOfRangeException">negative index</exception>
        /// <exception cref="InvalidOperationException">state does not represent a struct, or capability table not set</exception>
        /// <exception cref="Rpc.RpcException">non-capability pointer or invalid capability index</exception>
        public Rpc.BareProxy ReadCap(int index)
        {
            var cap = StructReadRawCap(index);
            return new Rpc.BareProxy(cap);
        }

        /// <summary>
        /// Given this state represents a capability, wraps it into a proxy instance for the desired interface.
        /// </summary>
        /// <typeparam name="T">Capability interface</typeparam>
        /// <returns>capability instance, or null if this state is nil</returns>
        /// <exception cref="DeserializationException">state does not represent a capability</exception>
        /// <exception cref="InvalidOperationException">capability table not set</exception>
        /// <exception cref="Rpc.RpcException">capability index out of range</exception>
        public T? RequireCap<T>() where T: class
        {
            if (Kind == ObjectKind.Nil)
                return null;

            if (Kind != ObjectKind.Capability)
                throw new DeserializationException("Expected a capability");

            if (Caps == null)
                throw new InvalidOperationException("Capability table not set");

            // Same range check as DecodeCapPointer: the index originates from untrusted wire data.
            if (CapabilityIndex >= Caps.Count)
                throw new Rpc.RpcException("Capability index out of range");

            return Rpc.CapabilityReflection.CreateProxy<T>(Caps[(int)CapabilityIndex]) as T;
        }

        /// <summary>
        /// Releases the capability table: calls Release() on every entry once, then clears <see cref="Caps"/>.
        /// Subsequent calls on this instance are no-ops.
        /// </summary>
        public void Dispose()
        {
            if (Caps != null && !_disposed)
            {
                foreach (var cap in Caps)
                {
                    cap.Release();
                }

                Caps = null;
                _disposed = true;
            }
        }
    }
}