PendingAnswer should also use StrictlyOrderedAwaitTask<T>

This commit is contained in:
Christian Köllner 2020-04-11 21:21:17 +02:00
parent 19b36a1643
commit 197817a7d7

View File

@ -1,4 +1,5 @@
using System; using Capnp.Util;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -9,7 +10,7 @@ namespace Capnp.Rpc
{ {
readonly CancellationTokenSource? _cts; readonly CancellationTokenSource? _cts;
readonly TaskCompletionSource<AnswerOrCounterquestion> _cancelCompleter; readonly TaskCompletionSource<AnswerOrCounterquestion> _cancelCompleter;
readonly Task<AnswerOrCounterquestion> _answerTask; readonly StrictlyOrderedAwaitTask<AnswerOrCounterquestion> _answerTask;
public PendingAnswer(Task<AnswerOrCounterquestion> callTask, CancellationTokenSource? cts) public PendingAnswer(Task<AnswerOrCounterquestion> callTask, CancellationTokenSource? cts)
{ {
@ -23,19 +24,16 @@ namespace Capnp.Rpc
_cts = cts; _cts = cts;
_cancelCompleter = new TaskCompletionSource<AnswerOrCounterquestion>(); _cancelCompleter = new TaskCompletionSource<AnswerOrCounterquestion>();
_answerTask = CancelableAwaitWhenReady(); _answerTask = CancelableAwaitWhenReady().EnforceAwaitOrder();
Chain(async t => TakeCapTableOwnership();
}
async void TakeCapTableOwnership()
{
try
{ {
var aorcq = default(AnswerOrCounterquestion); var aorcq = await _answerTask;
try
{
aorcq = await t;
}
catch
{
}
if (aorcq.Answer != null) if (aorcq.Answer != null)
{ {
@ -43,11 +41,35 @@ namespace Capnp.Rpc
{ {
foreach (var cap in aorcq.Answer.Caps) foreach (var cap in aorcq.Answer.Caps)
{ {
cap?.AddRef(); cap.AddRef();
} }
} }
} }
}); }
catch
{
}
}
async void ReleaseCapTableOwnership()
{
try
{
var aorcq = await _answerTask;
if (aorcq.Answer != null)
{
if (aorcq.Answer.Caps != null)
{
foreach (var cap in aorcq.Answer.Caps)
{
cap?.Release();
}
}
}
}
catch
{
}
} }
public CancellationToken CancellationToken => _cts?.Token ?? CancellationToken.None; public CancellationToken CancellationToken => _cts?.Token ?? CancellationToken.None;
@ -60,105 +82,78 @@ namespace Capnp.Rpc
_cancelCompleter.SetCanceled(); _cancelCompleter.SetCanceled();
} }
public void Chain(Action<Task<AnswerOrCounterquestion>> func) public void Chain(Action<StrictlyOrderedAwaitTask<AnswerOrCounterquestion>> func)
{ {
func(_answerTask); func(_answerTask);
} }
public void Chain(PromisedAnswer.READER rd, Action<Task<Proxy>> func) public void Chain(PromisedAnswer.READER rd, Action<Task<Proxy>> func)
{ {
Chain(t => async Task<Proxy> EvaluateProxy()
{ {
async Task<Proxy> EvaluateProxy() var aorcq = await _answerTask;
if (aorcq.Answer != null)
{ {
var aorcq = await t; DeserializerState cur = aorcq.Answer;
if (aorcq.Answer != null) foreach (var op in rd.Transform)
{ {
DeserializerState cur = aorcq.Answer; switch (op.which)
foreach (var op in rd.Transform)
{ {
switch (op.which) case PromisedAnswer.Op.WHICH.GetPointerField:
{
case PromisedAnswer.Op.WHICH.GetPointerField:
try
{
cur = cur.StructReadPointer(op.GetPointerField);
}
catch (System.Exception)
{
throw new RpcException("Illegal pointer field in transformation operation");
}
break;
case PromisedAnswer.Op.WHICH.Noop:
break;
default:
throw new ArgumentOutOfRangeException("Unknown transformation operation");
}
}
switch (cur.Kind)
{
case ObjectKind.Capability:
try try
{ {
return new Proxy(aorcq.Answer.Caps![(int)cur.CapabilityIndex]); cur = cur.StructReadPointer(op.GetPointerField);
} }
catch (ArgumentOutOfRangeException) catch (System.Exception)
{ {
throw new RpcException("Capability index out of range"); throw new RpcException("Illegal pointer field in transformation operation");
} }
break;
case ObjectKind.Nil: case PromisedAnswer.Op.WHICH.Noop:
return new Proxy(NullCapability.Instance); break;
default: default:
throw new ArgumentOutOfRangeException("Transformation did not result in a capability"); throw new ArgumentOutOfRangeException("Unknown transformation operation");
} }
} }
else
switch (cur.Kind)
{ {
var path = MemberAccessPath.Deserialize(rd); case ObjectKind.Capability:
var cap = new RemoteAnswerCapability(aorcq.Counterquestion!, path); try
return new Proxy(cap); {
return new Proxy(aorcq.Answer.Caps![(int)cur.CapabilityIndex]);
}
catch (ArgumentOutOfRangeException)
{
throw new RpcException("Capability index out of range");
}
case ObjectKind.Nil:
return new Proxy(NullCapability.Instance);
default:
throw new ArgumentOutOfRangeException("Transformation did not result in a capability");
} }
} }
else
{
var path = MemberAccessPath.Deserialize(rd);
var cap = new RemoteAnswerCapability(aorcq.Counterquestion!, path);
return new Proxy(cap);
}
}
func(EvaluateProxy()); func(EvaluateProxy());
});
} }
public void Dispose() public void Dispose()
{ {
_cts?.Dispose(); _cts?.Dispose();
ReleaseCapTableOwnership();
Chain(async t =>
{
AnswerOrCounterquestion aorcq;
try
{
aorcq = await t;
}
catch
{
return;
}
if (aorcq.Answer != null)
{
if (aorcq.Answer.Caps != null)
{
foreach (var cap in aorcq.Answer.Caps)
{
cap?.Release();
}
}
}
});
} }
} }
} }