diff --git a/Capnp.Net.Runtime/Rpc/PendingAnswer.cs b/Capnp.Net.Runtime/Rpc/PendingAnswer.cs index 8692f6f..e4fb99f 100644 --- a/Capnp.Net.Runtime/Rpc/PendingAnswer.cs +++ b/Capnp.Net.Runtime/Rpc/PendingAnswer.cs @@ -1,4 +1,5 @@ -using System; +using Capnp.Util; +using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; @@ -9,7 +10,7 @@ namespace Capnp.Rpc { readonly CancellationTokenSource? _cts; readonly TaskCompletionSource _cancelCompleter; - readonly Task _answerTask; + readonly StrictlyOrderedAwaitTask _answerTask; public PendingAnswer(Task callTask, CancellationTokenSource? cts) { @@ -23,19 +24,16 @@ namespace Capnp.Rpc _cts = cts; _cancelCompleter = new TaskCompletionSource(); - _answerTask = CancelableAwaitWhenReady(); + _answerTask = CancelableAwaitWhenReady().EnforceAwaitOrder(); - Chain(async t => + TakeCapTableOwnership(); + } + + async void TakeCapTableOwnership() + { + try { - var aorcq = default(AnswerOrCounterquestion); - - try - { - aorcq = await t; - } - catch - { - } + var aorcq = await _answerTask; if (aorcq.Answer != null) { @@ -43,11 +41,35 @@ namespace Capnp.Rpc { foreach (var cap in aorcq.Answer.Caps) { - cap?.AddRef(); + cap.AddRef(); } } } - }); + } + catch + { + } + } + + async void ReleaseCapTableOwnership() + { + try + { + var aorcq = await _answerTask; + if (aorcq.Answer != null) + { + if (aorcq.Answer.Caps != null) + { + foreach (var cap in aorcq.Answer.Caps) + { + cap?.Release(); + } + } + } + } + catch + { + } } public CancellationToken CancellationToken => _cts?.Token ?? CancellationToken.None; @@ -60,105 +82,78 @@ namespace Capnp.Rpc _cancelCompleter.SetCanceled(); } - public void Chain(Action> func) + public void Chain(Action> func) { func(_answerTask); } public void Chain(PromisedAnswer.READER rd, Action> func) { - Chain(t => + async Task EvaluateProxy() { - async Task EvaluateProxy() + var aorcq = await _answerTask; + + if (aorcq.Answer != null) { - var aorcq = await t; + DeserializerState cur = aorcq.Answer; - if (aorcq.Answer != null) + foreach (var op in rd.Transform) { - DeserializerState cur = aorcq.Answer; - - foreach (var op in rd.Transform) + switch (op.which) { - switch (op.which) - { - case PromisedAnswer.Op.WHICH.GetPointerField: - try - { - cur = cur.StructReadPointer(op.GetPointerField); - } - catch (System.Exception) - { - throw new RpcException("Illegal pointer field in transformation operation"); - } - break; - - case PromisedAnswer.Op.WHICH.Noop: - break; - - default: - throw new ArgumentOutOfRangeException("Unknown transformation operation"); - } - } - - switch (cur.Kind) - { - case ObjectKind.Capability: + case PromisedAnswer.Op.WHICH.GetPointerField: try { - return new Proxy(aorcq.Answer.Caps![(int)cur.CapabilityIndex]); + cur = cur.StructReadPointer(op.GetPointerField); } - catch (ArgumentOutOfRangeException) + catch (System.Exception) { - throw new RpcException("Capability index out of range"); + throw new RpcException("Illegal pointer field in transformation operation"); } + break; - case ObjectKind.Nil: - return new Proxy(NullCapability.Instance); + case PromisedAnswer.Op.WHICH.Noop: + break; default: - throw new ArgumentOutOfRangeException("Transformation did not result in a capability"); + throw new ArgumentOutOfRangeException("Unknown transformation operation"); } } - else + + switch (cur.Kind) { - var path = MemberAccessPath.Deserialize(rd); - var cap = new RemoteAnswerCapability(aorcq.Counterquestion!, path); - return new Proxy(cap); + case ObjectKind.Capability: + try + { + return new Proxy(aorcq.Answer.Caps![(int)cur.CapabilityIndex]); + } + catch (ArgumentOutOfRangeException) + { + throw new RpcException("Capability index out of range"); + } + + case ObjectKind.Nil: + return new Proxy(NullCapability.Instance); + + default: + throw new ArgumentOutOfRangeException("Transformation did not result in a capability"); } } + else + { + var path = MemberAccessPath.Deserialize(rd); + var cap = new RemoteAnswerCapability(aorcq.Counterquestion!, path); + return new Proxy(cap); + } + } - func(EvaluateProxy()); - }); + func(EvaluateProxy()); } public void Dispose() { _cts?.Dispose(); - - Chain(async t => - { - AnswerOrCounterquestion aorcq; - - try - { - aorcq = await t; - } - catch - { - return; - } - - if (aorcq.Answer != null) - { - if (aorcq.Answer.Caps != null) - { - foreach (var cap in aorcq.Answer.Caps) - { - cap?.Release(); - } - } - } - }); + ReleaseCapTableOwnership(); } } } \ No newline at end of file