diff --git a/Capnp.Net.Runtime/Rpc/PendingAnswer.cs b/Capnp.Net.Runtime/Rpc/PendingAnswer.cs index 6660725..5959b7e 100644 --- a/Capnp.Net.Runtime/Rpc/PendingAnswer.cs +++ b/Capnp.Net.Runtime/Rpc/PendingAnswer.cs @@ -6,98 +6,41 @@ namespace Capnp.Rpc { class PendingAnswer: IDisposable { - readonly object _reentrancyBlocker = new object(); readonly CancellationTokenSource? _cts; - readonly TaskCompletionSource _whenCanceled; - Task _callTask; - Task? _initialTask; - Task? _chainedTask; - bool _disposed; + readonly TaskCompletionSource _cancelCompleter; + readonly Task _answerTask; public PendingAnswer(Task callTask, CancellationTokenSource? cts) { + async Task CancelableAwaitWhenReady() + { + return await await Task.WhenAny(callTask, _cancelCompleter.Task); + } + + if (callTask == null) + throw new ArgumentNullException(nameof(callTask)); + _cts = cts; - _callTask = callTask ?? throw new ArgumentNullException(nameof(callTask)); - _whenCanceled = new TaskCompletionSource(); + _cancelCompleter = new TaskCompletionSource(); + _answerTask = CancelableAwaitWhenReady(); } + public CancellationToken CancellationToken => _cts?.Token ?? CancellationToken.None; + public void Cancel() { _cts?.Cancel(); - _whenCanceled.SetResult(0); + _cancelCompleter.SetCanceled(); } - async Task InitialAwaitWhenReady() + public void Chain(Action> func) { - var which = await Task.WhenAny(_callTask, _whenCanceled.Task); - - if (which != _callTask) - { - throw new TaskCanceledException(); - } + func(_answerTask); } - async Task AwaitChainedTask(Task chainedTask, Func, Task> func) + public void Chain(PromisedAnswer.READER rd, Action> func) { - try - { - await chainedTask; - } - catch (System.Exception exception) - { - await func(Task.FromException(exception)); - throw; - } - - await func(_callTask); - } - - static async Task AwaitSeq(Task task1, Task task2) - { - await task1; - await task2; - } - - public void Chain(bool strictSync, Func, Task> func) - { - - lock (_reentrancyBlocker) - { - if (_disposed) - { - throw new ObjectDisposedException(nameof(PendingAnswer)); - } - - if (_initialTask == null) - { - _initialTask = InitialAwaitWhenReady(); - } - - Task followUpTask; - - if (strictSync) - { - followUpTask = AwaitChainedTask(_chainedTask ?? _initialTask, func); - } - else - { - followUpTask = AwaitChainedTask(_initialTask, func); - } - - if (_chainedTask != null) - { - _chainedTask = AwaitSeq(_chainedTask, followUpTask); - } - else - { - _chainedTask = followUpTask; - } - } - } - - public void Chain(bool strictSync, PromisedAnswer.READER rd, Func, Task> func) - { - Chain(strictSync, async t => + Chain(t => { async Task EvaluateProxy() { @@ -158,43 +101,13 @@ namespace Capnp.Rpc } } - await func(EvaluateProxy()); + func(EvaluateProxy()); }); } - public CancellationToken CancellationToken => _cts?.Token ?? CancellationToken.None; - - public async void Dispose() + public void Dispose() { - if (_cts != null) - { - Task? chainedTask; - - lock (_reentrancyBlocker) - { - if (_disposed) - { - return; - } - chainedTask = _chainedTask; - _disposed = true; - } - - if (chainedTask != null) - { - try - { - await chainedTask; - } - catch - { - } - finally - { - _cts.Dispose(); - } - } - } + _cts?.Dispose(); } } } \ No newline at end of file diff --git a/Capnp.Net.Runtime/Rpc/RpcEngine.cs b/Capnp.Net.Runtime/Rpc/RpcEngine.cs index 2378959..87460f8 100644 --- a/Capnp.Net.Runtime/Rpc/RpcEngine.cs +++ b/Capnp.Net.Runtime/Rpc/RpcEngine.cs @@ -398,7 +398,7 @@ namespace Capnp.Rpc switch (req.SendResultsTo.which) { case Call.sendResultsTo.WHICH.Caller: - pendingAnswer.Chain(false, async t => + pendingAnswer.Chain(async t => { try { @@ -466,7 +466,7 @@ namespace Capnp.Rpc break; case Call.sendResultsTo.WHICH.Yourself: - pendingAnswer.Chain(false, async t => + pendingAnswer.Chain(async t => { try { @@ -575,7 +575,6 @@ namespace Capnp.Rpc if (exists) { previousAnswer!.Chain( - false, req.Target.PromisedAnswer, async t => { @@ -679,7 +678,7 @@ namespace Capnp.Rpc if (exists) { - pendingAnswer!.Chain(false, async t => + pendingAnswer!.Chain(async t => { try { @@ -807,7 +806,7 @@ namespace Capnp.Rpc if (_answerTable.TryGetValue(promisedAnswer.QuestionId, out var previousAnswer)) { - previousAnswer.Chain(true, + previousAnswer.Chain( disembargo.Target.PromisedAnswer, async t => { @@ -924,7 +923,7 @@ namespace Capnp.Rpc { try { - answer.Chain(false, async t => + answer.Chain(async t => { var aorcq = await t; var results = aorcq.Answer; @@ -1246,7 +1245,7 @@ namespace Capnp.Rpc { var tcs = new TaskCompletionSource(); - pendingAnswer.Chain(false, + pendingAnswer.Chain( capDesc.ReceiverAnswer, async t => {