coverage & fixes

This commit is contained in:
Christian Köllner 2020-04-10 00:01:12 +02:00
parent 747d350b20
commit 5a04f2f3da
12 changed files with 472 additions and 109 deletions

View File

@ -158,5 +158,23 @@ namespace Capnp.Net.Runtime.Tests
{ {
NewDtbdctTestbed().RunTest(Testsuite.ImportReceiverCanceled); NewDtbdctTestbed().RunTest(Testsuite.ImportReceiverCanceled);
} }
[TestMethod]
public void ButNoTailCall()
{
NewDtbdctTestbed().RunTest(Testsuite.ButNoTailCall);
}
[TestMethod]
public void SecondIsTailCall()
{
NewDtbdctTestbed().RunTest(Testsuite.SecondIsTailCall);
}
[TestMethod]
public void ReexportSenderPromise()
{
NewDtbdctTestbed().RunTest(Testsuite.ReexportSenderPromise);
}
} }
} }

View File

@ -681,6 +681,72 @@ namespace Capnp.Net.Runtime.Tests.GenImpls
} }
} }
} }
class TestTailCallerImpl3 : ITestTailCaller
{
public TestTailCallerImpl3()
{
}
public void Dispose()
{
}
public Task<TestTailCallee.TailResult> Foo(int i, ITestTailCallee callee, CancellationToken cancellationToken_)
{
using (callee)
{
var task1 = callee.Foo(i, "from TestTailCaller 1", cancellationToken_);
async void FinishTask()
{
var r = await task1;
r.C.Dispose();
}
FinishTask();
var task2 = callee.Foo(i, "from TestTailCaller 2", cancellationToken_);
async void AssertIsTailCall()
{
try
{
await task2;
Assert.Fail("Not a tail call");
}
catch (TailCallNoDataException)
{
}
}
AssertIsTailCall();
return task2;
}
}
}
class TestTailCallerImpl4 : ITestTailCaller
{
public TestTailCallerImpl4()
{
}
public void Dispose()
{
}
public async Task<TestTailCallee.TailResult> Foo(int i, ITestTailCallee callee, CancellationToken cancellationToken_)
{
await Task.Yield();
using (callee)
{
return await callee.Foo(i, "from TestTailCaller", cancellationToken_);
}
}
}
#endregion TestTailCaller #endregion TestTailCaller
#region TestTailCallee #region TestTailCallee
@ -993,7 +1059,7 @@ namespace Capnp.Net.Runtime.Tests.GenImpls
{ {
if (_echoCounter++ < 20) if (_echoCounter++ < 20)
{ {
return Task.FromResult<ITestCallOrder>(((Proxy)cap).Cast<ITestMoreStuff>(true).Echo(cap).Eager()); return Task.FromResult(((Proxy)cap).Cast<ITestMoreStuff>(false).Echo(cap).Eager());
} }
else else
{ {
@ -1056,6 +1122,192 @@ namespace Capnp.Net.Runtime.Tests.GenImpls
} }
} }
class TestMoreStuffImpl4 : ITestMoreStuff, ITestCallOrder
{
readonly TaskCompletionSource<ITestInterface> _heldCap = new TaskCompletionSource<ITestInterface>();
public Task<string> CallFoo(ITestInterface cap, CancellationToken cancellationToken_ = default)
{
using (cap)
{
return cap.Foo(123, true);
}
}
public Task<string> CallFooWhenResolved(ITestInterface Cap, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task<string> CallHeld(CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public async void Dispose()
{
using (var cap = await _heldCap.Task)
{
}
}
public Task<ITestCallOrder> Echo(ITestCallOrder cap, CancellationToken cancellationToken_ = default)
{
using (var target = ((Proxy)cap).Cast<ITestMoreStuff>(false))
{
return Task.FromResult(target.Echo(cap).Eager());
}
}
public Task ExpectCancel(ITestInterface Cap, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
uint _counter;
public Task<uint> GetCallSequence(uint expected, CancellationToken cancellationToken_ = default)
{
Assert.AreEqual(_counter, expected);
return Task.FromResult(_counter++);
}
public Task<string> GetEnormousString(CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task<ITestHandle> GetHandle(CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task<ITestInterface> GetHeld(CancellationToken cancellationToken_ = default)
{
return Task.FromResult(_heldCap.Task.Eager(true));
}
public Task<ITestMoreStuff> GetNull(CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task Hold(ITestInterface Cap, CancellationToken cancellationToken_ = default)
{
_heldCap.SetResult(Cap);
return Task.CompletedTask;
}
public Task<(string, string)> MethodWithDefaults(string A, uint B, string C, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task MethodWithNullDefault(string A, ITestInterface B, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task<ITestInterface> NeverReturn(ITestInterface Cap, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
}
class TestMoreStuffImpl5 : ITestMoreStuff, ITestCallOrder
{
readonly TaskCompletionSource<ITestInterface> _heldCap = new TaskCompletionSource<ITestInterface>();
public Task<string> CallFoo(ITestInterface cap, CancellationToken cancellationToken_ = default)
{
using (cap)
{
return cap.Foo(123, true);
}
}
public Task<string> CallFooWhenResolved(ITestInterface Cap, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task<string> CallHeld(CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public async void Dispose()
{
using (var cap = await _heldCap.Task)
{
}
}
TaskCompletionSource<int> _echoEnabled = new TaskCompletionSource<int>();
public void EnableEcho() => _echoEnabled.SetResult(0);
public async Task<ITestCallOrder> Echo(ITestCallOrder cap, CancellationToken cancellationToken_ = default)
{
await _echoEnabled.Task;
return cap;
}
public Task ExpectCancel(ITestInterface Cap, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
uint _counter;
public Task<uint> GetCallSequence(uint expected, CancellationToken cancellationToken_ = default)
{
Assert.AreEqual(_counter, expected);
return Task.FromResult(_counter++);
}
public Task<string> GetEnormousString(CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task<ITestHandle> GetHandle(CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task<ITestInterface> GetHeld(CancellationToken cancellationToken_ = default)
{
return Task.FromResult(_heldCap.Task.Eager(true));
}
public Task<ITestMoreStuff> GetNull(CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task Hold(ITestInterface Cap, CancellationToken cancellationToken_ = default)
{
_heldCap.SetResult(Cap);
return Task.CompletedTask;
}
public Task<(string, string)> MethodWithDefaults(string A, uint B, string C, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task MethodWithNullDefault(string A, ITestInterface B, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
public Task<ITestInterface> NeverReturn(ITestInterface Cap, CancellationToken cancellationToken_ = default)
{
throw new NotImplementedException();
}
}
#endregion TestMoreStuff #endregion TestMoreStuff
#region TestHandle #region TestHandle

View File

@ -261,20 +261,29 @@ namespace Capnp.Net.Runtime.Tests
{ {
var echoTask = main.Echo(Proxy.Share<ITestCallOrder>(main)); var echoTask = main.Echo(Proxy.Share<ITestCallOrder>(main));
Assert.IsTrue(echoTask.Wait(MediumNonDbgTimeout)); Assert.IsTrue(echoTask.Wait(MediumNonDbgTimeout));
var echo = echoTask.Result; using (var echo = echoTask.Result)
var list = new Task<uint>[1000];
for (uint i = 0; i < list.Length; i++)
{ {
list[i] = echo.GetCallSequence(i); var list = new Task<uint>[1000];
} for (uint i = 0; i < list.Length; i++)
Assert.IsTrue(Task.WaitAll(list, MediumNonDbgTimeout)); {
for (uint i = 0; i < list.Length; i++) list[i] = echo.GetCallSequence(i);
{ }
Assert.AreEqual(i, list[i].Result); Assert.IsTrue(Task.WaitAll(list, MediumNonDbgTimeout));
for (uint i = 0; i < list.Length; i++)
{
Assert.AreEqual(i, list[i].Result);
}
} }
} }
} }
} }
} }
[TestMethod]
public void NoTailCallMt()
{
NewLocalhostTcpTestbed().RunTest(Testsuite.NoTailCallMt);
}
} }
} }

View File

@ -415,20 +415,22 @@ namespace Capnp.Net.Runtime.Tests
var callee = new TestTailCalleeImpl(calleeCallCount); var callee = new TestTailCalleeImpl(calleeCallCount);
var promise = main.Foo(456, callee, default); var promise = main.Foo(456, callee, default);
var dependentCall0 = promise.C().GetCallSequence(0, default); using (var c = promise.C())
{
var dependentCall0 = c.GetCallSequence(0, default);
Assert.IsTrue(promise.Wait(MediumNonDbgTimeout)); Assert.IsTrue(promise.Wait(MediumNonDbgTimeout));
Assert.AreEqual(456u, promise.Result.I); Assert.AreEqual(456u, promise.Result.I);
Assert.AreEqual("from TestTailCaller", promise.Result.T); Assert.AreEqual("from TestTailCaller", promise.Result.T);
var dependentCall1 = promise.C().GetCallSequence(0, default); var dependentCall1 = c.GetCallSequence(0, default);
var dependentCall2 = promise.C().GetCallSequence(0, default); var dependentCall2 = c.GetCallSequence(0, default);
AssertOutput(stdout, "foo");
Assert.IsTrue(dependentCall0.Wait(MediumNonDbgTimeout));
Assert.IsTrue(dependentCall1.Wait(MediumNonDbgTimeout));
Assert.IsTrue(dependentCall2.Wait(MediumNonDbgTimeout));
AssertOutput(stdout, "foo");
Assert.IsTrue(dependentCall0.Wait(MediumNonDbgTimeout));
Assert.IsTrue(dependentCall1.Wait(MediumNonDbgTimeout));
Assert.IsTrue(dependentCall2.Wait(MediumNonDbgTimeout));
}
Assert.AreEqual(1, calleeCallCount.CallCount); Assert.AreEqual(1, calleeCallCount.CallCount);
} }
} }
@ -523,28 +525,30 @@ namespace Capnp.Net.Runtime.Tests
using (var main = client.GetMain<ITestMoreStuff>()) using (var main = client.GetMain<ITestMoreStuff>())
{ {
var tcs = new TaskCompletionSource<ITestInterface>(); var tcs = new TaskCompletionSource<ITestInterface>();
var eager = tcs.Task.Eager(true); using (var eager = tcs.Task.Eager(true))
{
var request = main.CallFoo(Proxy.Share(eager), default);
AssertOutput(stdout, "callFoo");
var request2 = main.CallFooWhenResolved(Proxy.Share(eager), default);
AssertOutput(stdout, "callFooWhenResolved");
var request = main.CallFoo(eager, default); var gcs = main.GetCallSequence(0, default);
AssertOutput(stdout, "callFoo"); AssertOutput(stdout, "getCallSequence");
var request2 = main.CallFooWhenResolved(eager, default); Assert.IsTrue(gcs.Wait(MediumNonDbgTimeout));
AssertOutput(stdout, "callFooWhenResolved"); Assert.AreEqual(2u, gcs.Result);
var gcs = main.GetCallSequence(0, default); var chainedCallCount = new Counters();
AssertOutput(stdout, "getCallSequence"); var tiimpl = new TestInterfaceImpl(chainedCallCount);
Assert.IsTrue(gcs.Wait(MediumNonDbgTimeout)); tcs.SetResult(tiimpl);
Assert.AreEqual(2u, gcs.Result);
var chainedCallCount = new Counters(); Assert.IsTrue(request.Wait(MediumNonDbgTimeout));
var tiimpl = new TestInterfaceImpl(chainedCallCount); Assert.IsTrue(request2.Wait(MediumNonDbgTimeout));
tcs.SetResult(tiimpl);
Assert.IsTrue(request.Wait(MediumNonDbgTimeout)); Assert.AreEqual("bar", request.Result);
Assert.IsTrue(request2.Wait(MediumNonDbgTimeout)); Assert.AreEqual("bar", request2.Result);
Assert.AreEqual(2, chainedCallCount.CallCount);
Assert.AreEqual("bar", request.Result); }
Assert.AreEqual("bar", request2.Result);
Assert.AreEqual(2, chainedCallCount.CallCount);
AssertOutput(stdout, "fin"); AssertOutput(stdout, "fin");
AssertOutput(stdout, "fin"); AssertOutput(stdout, "fin");

View File

@ -749,5 +749,74 @@ namespace Capnp.Net.Runtime.Tests
Assert.IsTrue(foo.IsCanceled); Assert.IsTrue(foo.IsCanceled);
} }
} }
public static void ButNoTailCall(ITestbed testbed)
{
var impl = new TestMoreStuffImpl4();
using (var main = testbed.ConnectMain<ITestMoreStuff>(impl))
{
var peer = new TestMoreStuffImpl5();
var heldTask = main.Echo(peer);
testbed.MustComplete(heldTask);
var r = heldTask.Result as IResolvingCapability;
peer.EnableEcho();
testbed.MustComplete(r.WhenResolved);
heldTask.Result.Dispose();
}
}
public static void SecondIsTailCall(ITestbed testbed)
{
var impl = new TestTailCallerImpl3();
using (var main = testbed.ConnectMain<ITestTailCaller>(impl))
{
var callee = new TestTailCalleeImpl(new Counters());
var task = main.Foo(123, callee);
testbed.MustComplete(task);
Assert.AreEqual("from TestTailCaller 2", task.Result.T);
}
}
public static void NoTailCallMt(ITestbed testbed)
{
var impl = new TestTailCallerImpl4();
using (var main = testbed.ConnectMain<ITestTailCaller>(impl))
using (var callee = Proxy.Share<ITestTailCallee>(new TestTailCalleeImpl(new Counters())))
{
var tasks = ParallelEnumerable
.Range(0, 1000)
.Select(async i =>
{
var r = await main.Foo(i, Proxy.Share(callee));
Assert.AreEqual((uint)i, r.I);
})
.ToArray();
testbed.MustComplete(tasks);
Assert.IsFalse(tasks.Any(t => t.IsCanceled || t.IsFaulted));
}
}
public static void ReexportSenderPromise(ITestbed testbed)
{
var impl = new TestTailCallerImpl(new Counters());
using (var main = testbed.ConnectMain<ITestTailCaller>(impl))
{
var tcs = new TaskCompletionSource<ITestTailCallee>();
using (var promise = Proxy.Share(tcs.Task.Eager(true)))
{
var task1 = main.Foo(1, Proxy.Share(promise));
var task2 = main.Foo(2, Proxy.Share(promise));
var callee = new TestTailCalleeImpl(new Counters());
tcs.SetResult(callee);
testbed.MustComplete(task1, task2);
}
}
}
} }
} }

View File

@ -242,11 +242,16 @@ namespace Capnp.Net.Runtime.Tests
{ {
(_server, _client) = SetupClientServerPair(); (_server, _client) = SetupClientServerPair();
_client.WhenConnected.Wait(MediumNonDbgTimeout); _client.WhenConnected.Wait(MediumNonDbgTimeout);
Assert.IsTrue(SpinWait.SpinUntil(() => _server.ConnectionCount > 0, MediumNonDbgTimeout));
var conn = _server.Connections[0];
using (_server) using (_server)
using (_client) using (_client)
{ {
action(this); action(this);
Assert.IsTrue(SpinWait.SpinUntil(() => _client.SendCount == conn.RecvCount, MediumNonDbgTimeout));
Assert.IsTrue(SpinWait.SpinUntil(() => conn.SendCount == _client.RecvCount, MediumNonDbgTimeout));
} }
} }

View File

@ -29,7 +29,7 @@
</PropertyGroup> </PropertyGroup>
<PropertyGroup> <PropertyGroup>
<DefineConstants></DefineConstants> <DefineConstants>DebugFinalizers</DefineConstants>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Release|netstandard2.0|AnyCPU'"> <PropertyGroup Condition="'$(Configuration)|$(TargetFramework)|$(Platform)'=='Release|netstandard2.0|AnyCPU'">

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -11,7 +12,7 @@ namespace Capnp.Rpc
public static class Impatient public static class Impatient
{ {
static readonly ConditionalWeakTable<Task, IPromisedAnswer> _taskTable = new ConditionalWeakTable<Task, IPromisedAnswer>(); static readonly ConditionalWeakTable<Task, IPromisedAnswer> _taskTable = new ConditionalWeakTable<Task, IPromisedAnswer>();
static readonly ThreadLocal<IRpcEndpoint?> _askingEndpoint = new ThreadLocal<IRpcEndpoint?>(); static readonly ThreadLocal<Stack<IRpcEndpoint>> _askingEndpoint = new ThreadLocal<Stack<IRpcEndpoint>>(() => new Stack<IRpcEndpoint>());
/// <summary> /// <summary>
/// Attaches a continuation to the given promise and registers the resulting task for pipelining. /// Attaches a continuation to the given promise and registers the resulting task for pipelining.
@ -171,8 +172,17 @@ namespace Capnp.Rpc
internal static IRpcEndpoint? AskingEndpoint internal static IRpcEndpoint? AskingEndpoint
{ {
get => _askingEndpoint.Value; get => _askingEndpoint.Value!.Count > 0 ? _askingEndpoint.Value.Peek() : null;
set { _askingEndpoint.Value = value; } }
internal static void PushAskingEndpoint(IRpcEndpoint endpoint)
{
_askingEndpoint.Value!.Push(endpoint);
}
internal static void PopAskingEndpoint()
{
_askingEndpoint.Value!.Pop();
} }
/// <summary> /// <summary>

View File

@ -66,14 +66,6 @@ namespace Capnp.Rpc
_inParams = inParams; _inParams = inParams;
StateFlags = inParams == null ? State.Sent : State.None; StateFlags = inParams == null ? State.Sent : State.None;
if (inParams != null)
{
foreach (var cap in inParams.Caps!)
{
cap.AddRef();
}
}
if (target != null) if (target != null)
{ {
target.AddRef(); target.AddRef();
@ -282,22 +274,6 @@ namespace Capnp.Rpc
return new RemoteAnswerCapability(this, access, proxyTask); return new RemoteAnswerCapability(this, access, proxyTask);
} }
static void ReleaseCaps(ConsumedCapability? target, SerializerState? inParams)
{
if (inParams != null)
{
foreach (var cap in inParams.Caps!)
{
cap.Release();
}
}
if (target != null)
{
target.Release();
}
}
static void ReleaseOutCaps(DeserializerState outParams) static void ReleaseOutCaps(DeserializerState outParams)
{ {
foreach (var cap in outParams.Caps!) foreach (var cap in outParams.Caps!)
@ -327,8 +303,8 @@ namespace Capnp.Rpc
Debug.Assert(msg.Call!.Target.which != MessageTarget.WHICH.undefined); Debug.Assert(msg.Call!.Target.which != MessageTarget.WHICH.undefined);
var call = msg.Call; var call = msg.Call;
call.QuestionId = QuestionId; call.QuestionId = QuestionId;
call.SendResultsTo.which = IsTailCall ? call.SendResultsTo.which = IsTailCall ?
Call.sendResultsTo.WHICH.Yourself : Call.sendResultsTo.WHICH.Yourself :
Call.sendResultsTo.WHICH.Caller; Call.sendResultsTo.WHICH.Caller;
try try
@ -341,7 +317,7 @@ namespace Capnp.Rpc
OnException(exception); OnException(exception);
} }
ReleaseCaps(target, inParams); target?.Release();
} }
/// <summary> /// <summary>

View File

@ -1,5 +1,6 @@
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using System; using System;
using System.Diagnostics;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -25,10 +26,6 @@ namespace Capnp.Rpc
return BareProxy.FromImpl(obj).Cast<T>(true); return BareProxy.FromImpl(obj).Cast<T>(true);
} }
#if DebugFinalizers
ILogger Logger { get; } = Logging.CreateLogger<Proxy>();
#endif
bool _disposedValue = false; bool _disposedValue = false;
/// <summary> /// <summary>
@ -56,15 +53,23 @@ namespace Capnp.Rpc
return CapabilityReflection.CreateProxy<T>(ConsumedCap) as T; return CapabilityReflection.CreateProxy<T>(ConsumedCap) as T;
} }
ConsumedCapability? _consumedCap;
/// <summary> /// <summary>
/// Underlying low-level capability /// Underlying low-level capability
/// </summary> /// </summary>
protected internal ConsumedCapability? ConsumedCap { get; private set; } protected internal ConsumedCapability? ConsumedCap => _disposedValue ?
throw new ObjectDisposedException(nameof(Proxy)) : _consumedCap;
/// <summary> /// <summary>
/// Whether is this a broken capability. /// Whether is this a broken capability.
/// </summary> /// </summary>
public bool IsNull => ConsumedCap == null; public bool IsNull => _consumedCap == null;
/// <summary>
/// Whether <see cref="Dispose()"/> was called on this Proxy.
/// </summary>
public bool IsDisposed => _disposedValue;
static async void DisposeCtrWhenReturned(CancellationTokenRegistration ctr, IPromisedAnswer answer) static async void DisposeCtrWhenReturned(CancellationTokenRegistration ctr, IPromisedAnswer answer)
{ {
@ -134,12 +139,12 @@ namespace Capnp.Rpc
if (cap == null) if (cap == null)
return; return;
ConsumedCap = cap; _consumedCap = cap;
cap.AddRef(); cap.AddRef();
#if DebugFinalizers #if DebugFinalizers
if (ConsumedCap != null) if (_consumedCap != null)
ConsumedCap.OwningProxy = this; _consumedCap.OwningProxy = this;
#endif #endif
} }
@ -166,14 +171,14 @@ namespace Capnp.Rpc
{ {
if (disposing) if (disposing)
{ {
ConsumedCap?.Release(); _consumedCap?.Release();
} }
else else
{ {
// When called from the Finalizer, we must not throw. // When called from the Finalizer, we must not throw.
// But when reference counting goes wrong, ConsumedCapability.Release() will throw an InvalidOperationException. // But when reference counting goes wrong, ConsumedCapability.Release() will throw an InvalidOperationException.
// The only option here is to suppress that exception. // The only option here is to suppress that exception.
try { ConsumedCap?.Release(); } try { _consumedCap?.Release(); }
catch { } catch { }
} }
@ -187,7 +192,7 @@ namespace Capnp.Rpc
~Proxy() ~Proxy()
{ {
#if DebugFinalizers #if DebugFinalizers
Logger?.LogWarning($"Caught orphaned Proxy, created from here: {CreatorStackTrace}."); Debugger.Log(0, "DebugFinalizers", $"Caught orphaned Proxy, created from here: {CreatorStackTrace}.");
#endif #endif
Dispose(false); Dispose(false);

View File

@ -66,7 +66,7 @@ namespace Capnp.Rpc
Dismissed Dismissed
} }
static readonly ThreadLocal<PendingQuestion?> _tailCall = new ThreadLocal<PendingQuestion?>(); static readonly ThreadLocal<PendingQuestion?> _deferredCall = new ThreadLocal<PendingQuestion?>();
static readonly ThreadLocal<bool> _canDeferCalls = new ThreadLocal<bool>(); static readonly ThreadLocal<bool> _canDeferCalls = new ThreadLocal<bool>();
ILogger Logger { get; } = Logging.CreateLogger<RpcEndpoint>(); ILogger Logger { get; } = Logging.CreateLogger<RpcEndpoint>();
@ -81,6 +81,7 @@ namespace Capnp.Rpc
readonly Dictionary<uint, PendingAnswer> _answerTable = new Dictionary<uint, PendingAnswer>(); readonly Dictionary<uint, PendingAnswer> _answerTable = new Dictionary<uint, PendingAnswer>();
readonly Dictionary<uint, TaskCompletionSource<int>> _pendingDisembargos = new Dictionary<uint, TaskCompletionSource<int>>(); readonly Dictionary<uint, TaskCompletionSource<int>> _pendingDisembargos = new Dictionary<uint, TaskCompletionSource<int>>();
readonly object _reentrancyBlocker = new object(); readonly object _reentrancyBlocker = new object();
readonly object _callReturnBlocker = new object();
long _recvCount; long _recvCount;
long _sendCount; long _sendCount;
@ -284,16 +285,8 @@ namespace Capnp.Rpc
if (_revExportTable.TryGetValue(providedCapability, out uint id)) if (_revExportTable.TryGetValue(providedCapability, out uint id))
{ {
_exportTable[id].AddRef();
first = false; first = false;
if (_exportTable.TryGetValue(id, out var rc))
{
rc.AddRef();
}
else
{
Logger.LogError("Inconsistent export table: Capability with id {0} exists in reverse table only", id);
}
} }
else else
{ {
@ -305,7 +298,6 @@ namespace Capnp.Rpc
_revExportTable.Add(providedCapability, id); _revExportTable.Add(providedCapability, id);
_exportTable.Add(id, new RefCounted<Skeleton>(providedCapability)); _exportTable.Add(id, new RefCounted<Skeleton>(providedCapability));
first = true; first = true;
} }
@ -407,8 +399,22 @@ namespace Capnp.Rpc
} }
} }
void DispatchDeferredCalls()
{
var call = _deferredCall.Value;
_deferredCall.Value = null;
call?.Send();
}
void ProcessCall(Call.READER req) void ProcessCall(Call.READER req)
{
lock (_callReturnBlocker)
{
ProcessCallLocked(req);
}
}
void ProcessCallLocked(Call.READER req)
{ {
Return.WRITER SetupReturn(MessageBuilder mb) Return.WRITER SetupReturn(MessageBuilder mb)
{ {
@ -420,8 +426,10 @@ namespace Capnp.Rpc
return ret; return ret;
} }
void ReturnCall(Action<Return.WRITER> why) void ReturnCallNoCapTable(Action<Return.WRITER> why)
{ {
DispatchDeferredCalls();
var mb = MessageBuilder.Create(); var mb = MessageBuilder.Create();
mb.InitCapTable(); mb.InitCapTable();
var ret = SetupReturn(mb); var ret = SetupReturn(mb);
@ -430,7 +438,10 @@ namespace Capnp.Rpc
try try
{ {
Tx(mb.Frame); lock (_callReturnBlocker)
{
Tx(mb.Frame);
}
} }
catch (RpcException exception) catch (RpcException exception)
{ {
@ -473,7 +484,7 @@ namespace Capnp.Rpc
{ {
Debug.Fail("Either answer or counter question must be present"); Debug.Fail("Either answer or counter question must be present");
} }
else if (aorcq.Answer != null || aorcq.Counterquestion != _tailCall.Value) else if (aorcq.Answer != null || aorcq.Counterquestion != _deferredCall.Value)
{ {
var results = aorcq.Answer ?? (DynamicSerializerState)(await aorcq.Counterquestion!.WhenReturned); var results = aorcq.Answer ?? (DynamicSerializerState)(await aorcq.Counterquestion!.WhenReturned);
var ret = SetupReturn(results.MsgBuilder!); var ret = SetupReturn(results.MsgBuilder!);
@ -484,12 +495,13 @@ namespace Capnp.Rpc
ret.which = Return.WHICH.Results; ret.which = Return.WHICH.Results;
ret.Results!.Content = results.Rewrap<DynamicSerializerState>(); ret.Results!.Content = results.Rewrap<DynamicSerializerState>();
ret.ReleaseParamCaps = releaseParamCaps; ret.ReleaseParamCaps = releaseParamCaps;
DispatchDeferredCalls();
ExportCapTableAndSend(results, ret.Results); ExportCapTableAndSend(results, ret.Results);
pendingAnswer.CapTable = ret.Results.CapTable; pendingAnswer.CapTable = ret.Results.CapTable;
break; break;
case Call.sendResultsTo.WHICH.Yourself: case Call.sendResultsTo.WHICH.Yourself:
ReturnCall(ret2 => ReturnCallNoCapTable(ret2 =>
{ {
ret2.which = Return.WHICH.ResultsSentElsewhere; ret2.which = Return.WHICH.ResultsSentElsewhere;
ret2.ReleaseParamCaps = releaseParamCaps; ret2.ReleaseParamCaps = releaseParamCaps;
@ -499,11 +511,11 @@ namespace Capnp.Rpc
} }
else if (aorcq.Counterquestion != null) else if (aorcq.Counterquestion != null)
{ {
_tailCall.Value = null; _deferredCall.Value = null;
aorcq.Counterquestion.IsTailCall = true; aorcq.Counterquestion.IsTailCall = true;
aorcq.Counterquestion.Send(); aorcq.Counterquestion.Send();
ReturnCall(ret2 => ReturnCallNoCapTable(ret2 =>
{ {
ret2.which = Return.WHICH.TakeFromOtherQuestion; ret2.which = Return.WHICH.TakeFromOtherQuestion;
ret2.TakeFromOtherQuestion = aorcq.Counterquestion.QuestionId; ret2.TakeFromOtherQuestion = aorcq.Counterquestion.QuestionId;
@ -513,7 +525,7 @@ namespace Capnp.Rpc
} }
catch (TaskCanceledException) catch (TaskCanceledException)
{ {
ReturnCall(ret => ReturnCallNoCapTable(ret =>
{ {
ret.which = Return.WHICH.Canceled; ret.which = Return.WHICH.Canceled;
ret.ReleaseParamCaps = releaseParamCaps; ret.ReleaseParamCaps = releaseParamCaps;
@ -521,7 +533,7 @@ namespace Capnp.Rpc
} }
catch (System.Exception exception) catch (System.Exception exception)
{ {
ReturnCall(ret => ReturnCallNoCapTable(ret =>
{ {
ret.which = Return.WHICH.Exception; ret.which = Return.WHICH.Exception;
ret.Exception!.Reason = exception.Message; ret.Exception!.Reason = exception.Message;
@ -543,7 +555,7 @@ namespace Capnp.Rpc
} }
finally finally
{ {
ReturnCall(ret => ReturnCallNoCapTable(ret =>
{ {
ret.which = Return.WHICH.ResultsSentElsewhere; ret.which = Return.WHICH.ResultsSentElsewhere;
ret.ReleaseParamCaps = releaseParamCaps; ret.ReleaseParamCaps = releaseParamCaps;
@ -600,7 +612,7 @@ namespace Capnp.Rpc
} }
_canDeferCalls.Value = true; _canDeferCalls.Value = true;
Impatient.AskingEndpoint = this; Impatient.PushAskingEndpoint(this);
try try
{ {
@ -683,10 +695,8 @@ namespace Capnp.Rpc
finally finally
{ {
_canDeferCalls.Value = false; _canDeferCalls.Value = false;
Impatient.AskingEndpoint = null; Impatient.PopAskingEndpoint();
var call = _tailCall.Value; DispatchDeferredCalls();
_tailCall.Value = null;
call?.Send();
} }
} }
@ -1419,8 +1429,8 @@ namespace Capnp.Rpc
if (_canDeferCalls.Value) if (_canDeferCalls.Value)
{ {
_tailCall.Value?.Send(); DispatchDeferredCalls();
_tailCall.Value = question; _deferredCall.Value = question;
} }
else else
{ {

View File

@ -76,7 +76,12 @@ namespace Capnp.Rpc
protected override void Dispose(bool disposing) protected override void Dispose(bool disposing)
{ {
Proxy.Dispose(); if (disposing)
Proxy.Dispose();
else
try { Proxy.Dispose(); }
catch { }
base.Dispose(disposing); base.Dispose(disposing);
} }
} }