2022-06-02 13:15:03 +02:00

337 lines
11 KiB
C#

using Capnp.Rpc;
using FabAccessAPI.Exceptions;
using FabAccessAPI.Schema;
using NLog;
using S22.Sasl;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
namespace FabAccessAPI
{
/// <summary>
/// Client for a FabAccess server. Opens a TLS-wrapped Cap'n Proto RPC connection,
/// authenticates via SASL and exposes the resulting <see cref="Session"/>.
/// Connection state transitions are reported through <see cref="ConnectionStatusChanged"/>.
/// </summary>
public class API : IAPI
{
    #region Logger
    private static readonly Logger Log = LogManager.GetCurrentClassLogger();
    #endregion

    #region Private Members
    /// <summary>Maximum time to wait for the RPC connection to be established.</summary>
    private const int _ConnectTimeoutMilliseconds = 3000;

    /// <summary>Target host name handed to the TLS client handshake.</summary>
    private const string _TlsTargetHost = "bffhd";

    private TcpRpcClient _TcpRpcClient;
    private IBootstrap _Bootstrap;

    // Both semaphores are static, i.e. they serialize Connect/Reconnect across ALL API instances.
    private static readonly SemaphoreSlim _ConnectSemaphore = new SemaphoreSlim(1, 1);
    private static readonly SemaphoreSlim _ReconnectSemaphore = new SemaphoreSlim(1, 1);
    #endregion

    #region Constructors
    public API()
    {
    }
    #endregion

    #region Events
    /// <summary>
    /// Raised on Connected, Disconnected, Reconnected and ConnectionLoss transitions.
    /// </summary>
    public event EventHandler<ConnectionStatusChange> ConnectionStatusChanged;

    /// <summary>
    /// Handler for the RPC client's state changes. An Active -> Down transition is
    /// reported as <see cref="ConnectionStatusChange.ConnectionLoss"/>.
    /// </summary>
    public void OnTcpRpcConnectionChanged(object sender, ConnectionStateChange args)
    {
        if (args.LastState == ConnectionState.Active && args.NewState == ConnectionState.Down)
        {
            Log.Trace("TcpRpcClient Event ConnectionLoss");
            ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.ConnectionLoss);

            // The client reference is only dropped here, it is not disposed.
            // ConnectionData is kept so that Reconnect() can use it.
            _TcpRpcClient = null;
        }
    }

    /// <summary>
    /// Remove every subscriber from <see cref="ConnectionStatusChanged"/>.
    /// </summary>
    public void UnbindAllEvents()
    {
        // Clearing the backing delegate is equivalent to removing each handler one by one.
        ConnectionStatusChanged = null;
    }
    #endregion

    #region Members
    /// <summary>Connection data of the current connection, null when disconnected.</summary>
    public ConnectionData ConnectionData { get; private set; }

    /// <summary>Server information of the current connection, null when disconnected.</summary>
    public ConnectionInfo ConnectionInfo { get; private set; }

    /// <summary>True while an RPC client exists and reports the Active state.</summary>
    public bool IsConnected
    {
        get
        {
            // Snapshot the field: the connection-state callback may null it concurrently.
            TcpRpcClient client = _TcpRpcClient;
            return client != null && client.State == ConnectionState.Active;
        }
    }

    /// <summary>Authenticated session, null when disconnected.</summary>
    public Session Session { get; private set; }
    #endregion

    #region Methods
    /// <summary>
    /// Connect to server with ConnectionData. An existing active connection is closed first.
    /// On failure the API is reset to the disconnected state and the original exception is rethrown.
    /// </summary>
    /// <param name="connectionData">Host, mechanism and credentials to connect with.</param>
    /// <param name="tcpRpcClient">Optional client to use; a new one is created when null. The API takes ownership and disposes it.</param>
    /// <exception cref="AuthenticationException"></exception>
    /// <exception cref="ConnectingFailedException"></exception>
    public async Task Connect(ConnectionData connectionData, TcpRpcClient tcpRpcClient = null)
    {
        await _ConnectSemaphore.WaitAsync();
        try
        {
            if (IsConnected)
            {
                await Disconnect();
            }

            if (tcpRpcClient == null)
            {
                tcpRpcClient = new TcpRpcClient();
            }

            try
            {
                await _ConnectAsync(tcpRpcClient, connectionData).ConfigureAwait(false);

                _Bootstrap = tcpRpcClient.GetMain<IBootstrap>();
                ConnectionInfo = await _GetConnectionInfo(_Bootstrap);
                Session = await _Authenticate(connectionData).ConfigureAwait(false);
                ConnectionData = connectionData;

                _TcpRpcClient = tcpRpcClient;
                tcpRpcClient.ConnectionStateChanged += OnTcpRpcConnectionChanged;

                ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Connected);
                Log.Info("API connected");
            }
            catch (System.Exception ex)
            {
                // Disconnect() only knows about the client stored in the field. When the
                // failure happened before that assignment, release the client here.
                if (!ReferenceEquals(_TcpRpcClient, tcpRpcClient))
                {
                    _DisposeQuietly(tcpRpcClient);
                }

                await Disconnect().ConfigureAwait(false);
                Log.Warn(ex, "API connecting failed");

                // Rethrow without resetting the stack trace.
                throw;
            }
        }
        finally
        {
            _ConnectSemaphore.Release();
        }
    }

    /// <summary>
    /// Close the connection (if any), reset all connection state and raise
    /// <see cref="ConnectionStatusChange.Disconnected"/>.
    /// </summary>
    public Task Disconnect()
    {
        // Take the client out of the field first so state is consistent while disposing.
        TcpRpcClient client = _TcpRpcClient;

        _Bootstrap = null;
        Session = null;
        _TcpRpcClient = null;
        ConnectionData = null;
        ConnectionInfo = null;

        if (client != null)
        {
            // Unsubscribe before disposing so an intentional disconnect is not
            // additionally reported as ConnectionLoss.
            client.ConnectionStateChanged -= OnTcpRpcConnectionChanged;
            client.Dispose();
        }

        ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Disconnected);
        Log.Info("API disconnected");

        return Task.CompletedTask;
    }

    /// <summary>
    /// Re-establish the connection with the stored <see cref="ConnectionData"/> when the
    /// API is not connected. <see cref="ConnectionStatusChange.Reconnected"/> is raised
    /// only if the API is connected afterwards.
    /// </summary>
    /// <exception cref="AuthenticationException"></exception>
    /// <exception cref="ConnectingFailedException"></exception>
    public async Task Reconnect()
    {
        await _ReconnectSemaphore.WaitAsync();
        try
        {
            ConnectionData connectionData = ConnectionData;
            if (connectionData != null && !IsConnected)
            {
                await Connect(connectionData);
            }

            if (IsConnected)
            {
                ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Reconnected);
                Log.Info("API reconnected");
            }
            else
            {
                // Nothing to reconnect with (no stored ConnectionData) or the connection is already gone again.
                Log.Warn("API not reconnected");
            }
        }
        finally
        {
            _ReconnectSemaphore.Release();
        }
    }

    /// <summary>
    /// Open a temporary connection, read the server's <see cref="ConnectionInfo"/> and close
    /// the connection again. Does not touch the state of this API instance.
    /// </summary>
    /// <param name="connectionData">Host to test.</param>
    /// <param name="tcpRpcClient">Optional client to use; a new one is created when null. It is always disposed.</param>
    /// <exception cref="ConnectingFailedException">Any failure; the cause is available as inner exception.</exception>
    public async Task<ConnectionInfo> TestConnection(ConnectionData connectionData, TcpRpcClient tcpRpcClient = null)
    {
        try
        {
            if (tcpRpcClient == null)
            {
                tcpRpcClient = new TcpRpcClient();
            }

            await _ConnectAsync(tcpRpcClient, connectionData).ConfigureAwait(false);
            IBootstrap testBootstrap = tcpRpcClient.GetMain<IBootstrap>();

            return await _GetConnectionInfo(testBootstrap).ConfigureAwait(false);
        }
        catch (System.Exception ex)
        {
            throw new ConnectingFailedException("Connection test failed", ex);
        }
        finally
        {
            // Release the temporary client on success AND on failure.
            _DisposeQuietly(tcpRpcClient);
        }
    }
    #endregion

    #region Private Methods
    /// <summary>
    /// Validate Certificate
    /// TODO: Do some validation
    /// </summary>
    /// <remarks>
    /// SECURITY: currently accepts every server certificate regardless of
    /// <paramref name="sslPolicyErrors"/>, so the TLS connection is not protected
    /// against man-in-the-middle attacks until validation is implemented.
    /// </remarks>
    private static bool _RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
    {
        // TODO Cert Check
        return true;
    }

    /// <summary>
    /// Dispose a client without letting a cleanup failure escape (it is only logged).
    /// </summary>
    private static void _DisposeQuietly(TcpRpcClient client)
    {
        if (client == null)
        {
            return;
        }

        try
        {
            client.Dispose();
        }
        catch (System.Exception ex)
        {
            Log.Warn(ex, "Disposing TcpRpcClient failed");
        }
    }

    /// <summary>
    /// Connect to server async with ConnectionData
    /// </summary>
    /// <exception cref="AuthenticationException">TLS Error</exception>
    /// <exception cref="ConnectingFailedException">Based on RPC Exception or connection timeout</exception>
    private async Task _ConnectAsync(TcpRpcClient rpcClient, ConnectionData connectionData)
    {
        // Wrap the raw TCP stream into TLS before the RPC layer uses it.
        rpcClient.InjectMidlayer((tcpstream) =>
        {
            var sslStream = new SslStream(tcpstream, false, new RemoteCertificateValidationCallback(_RemoteCertificateValidationCallback));
            try
            {
                sslStream.AuthenticateAsClient(_TlsTargetHost);
                return sslStream;
            }
            catch
            {
                // Close the stream on ANY handshake failure (not only AuthenticationException).
                sslStream.Close();
                throw;
            }
        });

        try
        {
            Task timeoutTask = Task.Delay(_ConnectTimeoutMilliseconds);
            rpcClient.Connect(connectionData.Host.Host, connectionData.Host.Port);

            // Decide by the task that actually finished first; checking
            // timeoutTask.IsCompleted afterwards would race with a late-but-successful connect.
            Task finishedTask = await Task.WhenAny(rpcClient.WhenConnected, timeoutTask).ConfigureAwait(false);
            if (finishedTask == timeoutTask)
            {
                throw new ConnectingFailedException("Connection timeout");
            }

            // Propagate a connection failure, if any.
            await finishedTask.ConfigureAwait(false);
        }
        catch (RpcException exception) when (string.Equals(exception.Message, "TcpRpcClient is unable to connect", StringComparison.Ordinal))
        {
            throw new ConnectingFailedException("RPC Connecting failed", exception);
        }
    }

    /// <summary>
    /// Create ConnectionInfo from Bootstrap
    /// </summary>
    private async Task<ConnectionInfo> _GetConnectionInfo(IBootstrap bootstrap)
    {
        var apiVersion = await bootstrap.GetAPIVersion().ConfigureAwait(false);
        var mechanisms = await bootstrap.Mechanisms().ConfigureAwait(false);

        // One round trip delivers both the server name and the release.
        var serverRelease = await bootstrap.GetServerRelease().ConfigureAwait(false);

        return new ConnectionInfo()
        {
            APIVersion = apiVersion,
            Mechanisms = new List<string>(mechanisms),
            ServerName = serverRelease.Item1,
            ServerRelease = serverRelease.Item2,
        };
    }

    /// <summary>
    /// Authenticate connection with ConnectionData
    /// </summary>
    /// <exception cref="UnsupportedMechanismException"></exception>
    /// <exception cref="InvalidCredentialsException"></exception>
    /// <exception cref="AuthenticationFailedException"></exception>
    private async Task<Session> _Authenticate(ConnectionData connectionData)
    {
        string mechanism = MechanismString.ToString(connectionData.Mechanism);

        IAuthentication authentication = await _Bootstrap.CreateSession(mechanism).ConfigureAwait(false);

        return await _SASLAuthenticate(authentication, mechanism, connectionData.Properties).ConfigureAwait(false);
    }

    /// <summary>
    /// Authenticate Connection to get Session
    /// </summary>
    /// <param name="authentication">Server side of the SASL exchange.</param>
    /// <param name="mech">SASL mechanism name.</param>
    /// <param name="properties">Mechanism properties (credentials) passed to the SASL mechanism.</param>
    /// <exception cref="BadMechanismException"></exception>
    /// <exception cref="InvalidCredentialsException"></exception>
    /// <exception cref="AuthenticationFailedException"></exception>
    private async Task<Session> _SASLAuthenticate(IAuthentication authentication, string mech, Dictionary<string, object> properties)
    {
        SaslMechanism saslMechanism = SaslFactory.Create(mech);
        foreach (KeyValuePair<string, object> entry in properties)
        {
            saslMechanism.Properties.Add(entry.Key, entry.Value);
        }

        // Mechanisms with an initial response send it in the first step, others start empty.
        byte[] data = saslMechanism.HasInitial
            ? saslMechanism.GetResponse(Array.Empty<byte>())
            : Array.Empty<byte>();

        Response response = await authentication.Step(data).ConfigureAwait(false);

        // Answer challenges until the mechanism is done or the server reports a failure.
        while (!saslMechanism.IsCompleted && response.Failed == null)
        {
            if (response.Challenge == null)
            {
                // Mechanism not finished, but the server sent neither a challenge nor a failure.
                throw new AuthenticationFailedException();
            }

            byte[] additional = saslMechanism.GetResponse(response.Challenge.ToArray());
            response = await authentication.Step(additional).ConfigureAwait(false);
        }

        if (response.Successful != null)
        {
            return response.Successful.Session;
        }

        if (response.Failed == null)
        {
            throw new AuthenticationFailedException();
        }

        switch (response.Failed.Code)
        {
            case Response.Error.badMechanism:
                throw new BadMechanismException();
            case Response.Error.invalidCredentials:
                throw new InvalidCredentialsException();
            case Response.Error.aborted:
            case Response.Error.failed:
            default:
                throw new AuthenticationFailedException(response.Failed.AdditionalData.ToArray());
        }
    }
    #endregion
}
}