using Capnp.Rpc;
using FabAccessAPI.Exceptions;
using FabAccessAPI.Schema;
using NLog;
using S22.Sasl;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

namespace FabAccessAPI
{
    /// <summary>
    /// Client-side entry point of the FabAccess API: opens a TLS-wrapped Cap'n Proto RPC
    /// connection, authenticates via SASL and exposes the resulting <see cref="Session"/>.
    /// </summary>
    public class API : IAPI
    {
        #region Logger
        private static readonly Logger Log = LogManager.GetCurrentClassLogger();
        #endregion

        #region Private Members
        /// <summary>Maximum time to wait for the RPC connection to be established.</summary>
        private const int _ConnectTimeoutMilliseconds = 3000;

        /// <summary>Target host name handed to the TLS client handshake.</summary>
        private const string _TlsTargetHost = "bffhd";

        private TcpRpcClient _TcpRpcClient;
        private IBootstrap _Bootstrap;

        // Static on purpose (as before): connect/reconnect attempts are serialized across all instances.
        private static readonly SemaphoreSlim _ConnectSemaphore = new SemaphoreSlim(1, 1);
        private static readonly SemaphoreSlim _ReconnectSemaphore = new SemaphoreSlim(1, 1);
        #endregion

        #region Constructors
        public API()
        {

        }
        #endregion

        #region Events
        /// <summary>
        /// Raised whenever the connection status of this API instance changes.
        /// </summary>
        public event EventHandler<ConnectionStatusChange> ConnectionStatusChanged;

        /// <summary>
        /// Handler for the state changes of the underlying <see cref="TcpRpcClient"/>.
        /// A transition from Active to Down is reported as <see cref="ConnectionStatusChange.ConnectionLoss"/>.
        /// </summary>
        public void OnTcpRpcConnectionChanged(object sender, ConnectionStateChange args)
        {
            if (args.LastState == ConnectionState.Active && args.NewState == ConnectionState.Down)
            {
                Log.Trace("TcpRpcClient Event ConnectionLoss");

                // Drop the dead client before notifying, so a listener that reconnects right away
                // cannot have its fresh client overwritten by this handler afterwards.
                _TcpRpcClient = null;
                ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.ConnectionLoss);
            }
        }

        /// <summary>
        /// Remove every subscriber from <see cref="ConnectionStatusChanged"/>.
        /// </summary>
        public void UnbindAllEvents()
        {
            ConnectionStatusChanged = null;
        }
        #endregion

        #region Members
        public ConnectionData ConnectionData { get; private set; }

        public ConnectionInfo ConnectionInfo { get; private set; }

        /// <summary>
        /// True while a client is held and its state is Active.
        /// </summary>
        public bool IsConnected
        {
            get
            {
                // Single read of the field: the connection-loss handler may null it concurrently.
                return _TcpRpcClient?.State == ConnectionState.Active;
            }
        }

        public Session Session { get; private set; }
        #endregion

        #region Methods
        /// <summary>
        /// Connect to server with ConnectionData.
        /// An existing active connection is disconnected first. On success the client is kept,
        /// <see cref="ConnectionStatusChange.Connected"/> is raised and <see cref="Session"/>,
        /// <see cref="ConnectionInfo"/> and <see cref="ConnectionData"/> are populated.
        /// On failure the client is disposed, the state is reset via <see cref="Disconnect"/>
        /// and the original exception is rethrown.
        /// </summary>
        /// <param name="connectionData">Host, mechanism and SASL properties to connect with.</param>
        /// <param name="tcpRpcClient">Optional client to use; a new one is created when null.</param>
        public async Task Connect(ConnectionData connectionData, TcpRpcClient tcpRpcClient = null)
        {
            await _ConnectSemaphore.WaitAsync();
            try
            {
                if (IsConnected)
                {
                    await Disconnect();
                }

                if (tcpRpcClient == null)
                {
                    tcpRpcClient = new TcpRpcClient();
                }

                try
                {
                    await _ConnectAsync(tcpRpcClient, connectionData).ConfigureAwait(false);

                    _Bootstrap = tcpRpcClient.GetMain<IBootstrap>();
                    ConnectionInfo = await _GetConnectionInfo(_Bootstrap).ConfigureAwait(false);

                    Session = await _Authenticate(connectionData).ConfigureAwait(false);
                    ConnectionData = connectionData;

                    // Subscribe before publishing the client so a loss right after adoption is not missed.
                    tcpRpcClient.ConnectionStateChanged += OnTcpRpcConnectionChanged;
                    _TcpRpcClient = tcpRpcClient;

                    ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Connected);
                    Log.Info("API connected");
                }
                catch (System.Exception ex)
                {
                    // The client is only tracked in _TcpRpcClient after a successful connect;
                    // release it here, otherwise a failed attempt leaks the socket.
                    if (!ReferenceEquals(_TcpRpcClient, tcpRpcClient))
                    {
                        _ReleaseClient(tcpRpcClient);
                    }

                    await Disconnect().ConfigureAwait(false);
                    Log.Warn(ex, "API connecting failed");

                    // Rethrow without resetting the stack trace.
                    throw;
                }
            }
            finally
            {
                _ConnectSemaphore.Release();
            }
        }

        /// <summary>
        /// Dispose the current client (if any), reset all connection state and raise
        /// <see cref="ConnectionStatusChange.Disconnected"/>. The state-change handler is
        /// detached first, so a deliberate disconnect is not also reported as a connection loss.
        /// </summary>
        public Task Disconnect()
        {
            // Single read: the connection-loss handler may null the field concurrently.
            TcpRpcClient client = _TcpRpcClient;
            if (client != null)
            {
                _ReleaseClient(client);
            }

            _Bootstrap = null;
            Session = null;
            _TcpRpcClient = null;
            ConnectionData = null;
            ConnectionInfo = null;

            ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Disconnected);
            Log.Info("API disconnected");

            return Task.CompletedTask;
        }

        /// <summary>
        /// Connect again with the stored <see cref="ConnectionData"/> if there is one and the API
        /// is currently not connected.
        /// NOTE(review): <see cref="ConnectionStatusChange.Reconnected"/> is raised even when no
        /// connect was performed (no stored data or already connected) - confirm this is intended.
        /// </summary>
        public async Task Reconnect()
        {
            await _ReconnectSemaphore.WaitAsync();
            try
            {
                if (ConnectionData != null && !IsConnected)
                {
                    await Connect(ConnectionData);
                }

                ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Reconnected);
                Log.Info("API reconnected");
            }
            finally
            {
                _ReconnectSemaphore.Release();
            }
        }

        /// <summary>
        /// Open a temporary connection, read the server's <see cref="ConnectionInfo"/> and close
        /// the connection again. The state of this API instance is not touched.
        /// </summary>
        /// <param name="connectionData">Host to test.</param>
        /// <param name="tcpRpcClient">Optional client to use; a new one is created when null. It is always disposed.</param>
        /// <returns>The connection info reported by the server.</returns>
        /// <exception cref="ConnectingFailedException">Any failure while connecting or querying; the cause is the inner exception.</exception>
        public async Task<ConnectionInfo> TestConnection(ConnectionData connectionData, TcpRpcClient tcpRpcClient = null)
        {
            try
            {
                if (tcpRpcClient == null)
                {
                    tcpRpcClient = new TcpRpcClient();
                }

                await _ConnectAsync(tcpRpcClient, connectionData).ConfigureAwait(false);
                IBootstrap testBootstrap = tcpRpcClient.GetMain<IBootstrap>();

                return await _GetConnectionInfo(testBootstrap).ConfigureAwait(false);
            }
            catch (System.Exception ex)
            {
                // Keep the root cause instead of discarding it.
                throw new ConnectingFailedException("Connection test failed", ex);
            }
            finally
            {
                // Dispose on success and on failure alike.
                if (tcpRpcClient != null)
                {
                    _ReleaseClient(tcpRpcClient);
                }
            }
        }
        #endregion

        #region Private Methods
        /// <summary>
        /// Detach the state-change handler from a client and dispose it.
        /// Dispose failures are logged, never thrown, so cleanup cannot mask the original error.
        /// </summary>
        private void _ReleaseClient(TcpRpcClient client)
        {
            client.ConnectionStateChanged -= OnTcpRpcConnectionChanged;
            try
            {
                client.Dispose();
            }
            catch (System.Exception ex)
            {
                Log.Warn(ex, "Disposing TcpRpcClient failed");
            }
        }

        /// <summary>
        /// Validate Certificate
        /// TODO: Do some validation
        /// SECURITY: currently accepts every server certificate, whatever the reported policy errors.
        /// </summary>
        private static bool _RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
        {
            // TODO Cert Check
            return true;
        }

        /// <summary>
        /// Connect to server async with ConnectionData.
        /// Wraps the TCP stream in TLS and waits for the RPC connection for at most
        /// <see cref="_ConnectTimeoutMilliseconds"/> milliseconds.
        /// </summary>
        /// <exception cref="AuthenticationException">TLS Error</exception>
        /// <exception cref="ConnectingFailedException">Timeout, or the RPC client reported that it is unable to connect.</exception>
        private async Task _ConnectAsync(TcpRpcClient rpcClient, ConnectionData connectionData)
        {
            rpcClient.InjectMidlayer((tcpstream) =>
            {
                var sslStream = new SslStream(tcpstream, false, new RemoteCertificateValidationCallback(_RemoteCertificateValidationCallback));
                try
                {
                    sslStream.AuthenticateAsClient(_TlsTargetHost);
                    return sslStream;
                }
                catch (AuthenticationException)
                {
                    sslStream.Close();
                    throw;
                }
            });

            try
            {
                Task timeoutTask = Task.Delay(_ConnectTimeoutMilliseconds);
                rpcClient.Connect(connectionData.Host.Host, connectionData.Host.Port);

                Task finished = await Task.WhenAny(rpcClient.WhenConnected, timeoutTask).ConfigureAwait(false);

                // Decide by which task actually won; checking timeoutTask.IsCompleted afterwards
                // could report a timeout although the connection had been established.
                if (finished == timeoutTask)
                {
                    throw new ConnectingFailedException("Connection timeout");
                }

                // Propagate a connect failure, if any.
                await finished.ConfigureAwait(false);
            }
            catch (RpcException exception) when (string.Equals(exception.Message, "TcpRpcClient is unable to connect", StringComparison.Ordinal))
            {
                throw new ConnectingFailedException("RPC Connecting failed", exception);
            }
        }

        /// <summary>
        /// Create ConnectionInfo from Bootstrap.
        /// Queries API version, mechanisms and server release (one request each).
        /// </summary>
        private async Task<ConnectionInfo> _GetConnectionInfo(IBootstrap bootstrap)
        {
            var apiVersion = await bootstrap.GetAPIVersion().ConfigureAwait(false);
            var mechanisms = await bootstrap.Mechanisms().ConfigureAwait(false);

            // Fetched once; name and release come from the same response.
            var serverRelease = await bootstrap.GetServerRelease().ConfigureAwait(false);

            return new ConnectionInfo()
            {
                APIVersion = apiVersion,
                Mechanisms = mechanisms.ToList(),
                ServerName = serverRelease.Item1,
                ServerRelease = serverRelease.Item2,
            };
        }

        /// <summary>
        /// Authenticate connection with ConnectionData.
        /// Creates an authentication session on the current bootstrap for the configured
        /// mechanism and runs the SASL exchange.
        /// </summary>
        /// <returns>The session granted by the server.</returns>
        private async Task<Session> _Authenticate(ConnectionData connectionData)
        {
            string mechanism = MechanismString.ToString(connectionData.Mechanism);

            IAuthentication authentication = await _Bootstrap.CreateSession(mechanism).ConfigureAwait(false);

            return await _SASLAuthenticate(authentication, mechanism, connectionData.Properties).ConfigureAwait(false);
        }

        /// <summary>
        /// Authenticate Connection to get Session.
        /// Feeds server challenges into the SASL mechanism until the mechanism is completed
        /// or the server reports a failure.
        /// </summary>
        /// <param name="authentication">Server-side authentication capability.</param>
        /// <param name="mech">SASL mechanism name.</param>
        /// <param name="properties">Properties copied into the SASL mechanism; may be null.</param>
        /// <exception cref="BadMechanismException">Server rejected the mechanism.</exception>
        /// <exception cref="InvalidCredentialsException">Server rejected the credentials.</exception>
        /// <exception cref="AuthenticationFailedException">Any other failed or incomplete exchange.</exception>
        private async Task<Session> _SASLAuthenticate(IAuthentication authentication, string mech, Dictionary<string, object> properties)
        {
            SaslMechanism saslMechanism = SaslFactory.Create(mech);
            if (properties != null)
            {
                foreach (var entry in properties)
                {
                    saslMechanism.Properties.Add(entry.Key, entry.Value);
                }
            }

            byte[] data = Array.Empty<byte>();

            if (saslMechanism.HasInitial)
            {
                data = saslMechanism.GetResponse(Array.Empty<byte>());
            }

            Response response = await authentication.Step(data).ConfigureAwait(false);
            while (!saslMechanism.IsCompleted)
            {
                if (response.Failed != null)
                {
                    break;
                }

                if (response.Challenge == null)
                {
                    // Neither a failure nor a further challenge although the mechanism is not done.
                    throw new AuthenticationFailedException();
                }

                byte[] additional = saslMechanism.GetResponse(response.Challenge.ToArray());
                response = await authentication.Step(additional).ConfigureAwait(false);
            }

            if (response.Successful != null)
            {
                return response.Successful.Session;
            }

            if (response.Failed != null)
            {
                switch (response.Failed.Code)
                {
                    case Response.Error.badMechanism:
                        throw new BadMechanismException();
                    case Response.Error.invalidCredentials:
                        throw new InvalidCredentialsException();
                    case Response.Error.aborted:
                    case Response.Error.failed:
                    default:
                        throw new AuthenticationFailedException(response.Failed.AdditionalData.ToArray());
                }
            }

            throw new AuthenticationFailedException();
        }
        #endregion
    }
}