using Capnp.Rpc;
using FabAccessAPI.Exceptions;
using FabAccessAPI.Schema;
using NLog;
using S22.Sasl;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

namespace FabAccessAPI
{
    /// <summary>
    /// Client-side wrapper around a Cap'n Proto <see cref="TcpRpcClient"/> connection to a FabAccess server:
    /// opens a TLS connection, reads server information from the bootstrap capability, performs SASL
    /// authentication and exposes the resulting <see cref="Session"/>.
    /// </summary>
    public class API : IAPI
    {
        #region Logger
        private static readonly Logger Log = LogManager.GetCurrentClassLogger();
        #endregion

        #region Private Members
        /// <summary>How long <see cref="_ConnectAsync"/> waits for the TCP/TLS connection before failing.</summary>
        private const int ConnectTimeoutMilliseconds = 3000;

        private TcpRpcClient _TcpRpcClient;
        private IBootstrap _Bootstrap;

        // Per-instance: these guard this instance's connection state, so unrelated API instances
        // must not serialize each other.
        private readonly SemaphoreSlim _ConnectSemaphore = new SemaphoreSlim(1, 1);
        private readonly SemaphoreSlim _ReconnectSemaphore = new SemaphoreSlim(1, 1);
        #endregion

        #region Constructors
        public API()
        {

        }
        #endregion

        #region Events
        /// <summary>
        /// Raised on Connected, Disconnected, Reconnected and ConnectionLoss transitions.
        /// </summary>
        public event EventHandler<ConnectionStatusChange> ConnectionStatusChanged;

        /// <summary>
        /// Handler for <see cref="TcpRpcClient.ConnectionStateChanged"/> of the adopted client.
        /// When an Active connection goes Down, the client is dropped and
        /// <see cref="ConnectionStatusChange.ConnectionLoss"/> is raised. <see cref="ConnectionData"/> is kept
        /// so <see cref="Reconnect"/> can reuse it.
        /// </summary>
        public void OnTcpRpcConnectionChanged(object sender, ConnectionStateChange args)
        {
            if (args.LastState == ConnectionState.Active && args.NewState == ConnectionState.Down)
            {
                Log.Trace("TcpRpcClient Event ConnectionLoss");

                // Clear state and detach from the dead client BEFORE notifying listeners, so a listener
                // that reconnects synchronously does not have its new client overwritten afterwards.
                TcpRpcClient lostClient = _TcpRpcClient;
                _TcpRpcClient = null;
                if (lostClient != null)
                {
                    lostClient.ConnectionStateChanged -= OnTcpRpcConnectionChanged;
                }

                ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.ConnectionLoss);
            }
        }

        /// <summary>
        /// Removes every handler subscribed to <see cref="ConnectionStatusChanged"/>.
        /// </summary>
        public void UnbindAllEvents()
        {
            // Inside the declaring class a field-like event can simply be reset; this drops every handler.
            ConnectionStatusChanged = null;
        }
        #endregion

        #region Members
        /// <summary>Connection data of the last successful <see cref="Connect"/>; null after <see cref="Disconnect"/>.</summary>
        public ConnectionData ConnectionData { get; private set; }

        /// <summary>Server information read from the bootstrap capability during <see cref="Connect"/>.</summary>
        public ConnectionInfo ConnectionInfo { get; private set; }

        /// <summary>True when a client has been adopted and its state is <see cref="ConnectionState.Active"/>.</summary>
        public bool IsConnected
        {
            get
            {
                return _TcpRpcClient != null && _TcpRpcClient.State == ConnectionState.Active;
            }
        }

        /// <summary>Session obtained from SASL authentication; null when not connected.</summary>
        public Session Session { get; private set; }
        #endregion

        #region Methods
        /// <summary>
        /// Connect to server with ConnectionData. An existing connection is disconnected first.
        /// The passed (or newly created) <paramref name="tcpRpcClient"/> is owned by this instance afterwards
        /// and is disposed on failure or on <see cref="Disconnect"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionData"/> is null.</exception>
        /// <exception cref="AuthenticationException"></exception>
        /// <exception cref="ConnectingFailedException"></exception>
        public async Task Connect(ConnectionData connectionData, TcpRpcClient tcpRpcClient = null)
        {
            if (connectionData == null)
            {
                throw new ArgumentNullException(nameof(connectionData));
            }

            await _ConnectSemaphore.WaitAsync();
            try
            {
                if (IsConnected)
                {
                    await Disconnect();
                }

                if (tcpRpcClient == null)
                {
                    tcpRpcClient = new TcpRpcClient();
                }

                try
                {
                    await _ConnectAsync(tcpRpcClient, connectionData).ConfigureAwait(false);

                    _Bootstrap = tcpRpcClient.GetMain<IBootstrap>();
                    ConnectionInfo = await _GetConnectionInfo(_Bootstrap).ConfigureAwait(false);

                    Session = await _Authenticate(connectionData).ConfigureAwait(false);
                    ConnectionData = connectionData;

                    _TcpRpcClient = tcpRpcClient;
                    tcpRpcClient.ConnectionStateChanged += OnTcpRpcConnectionChanged;
                    ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Connected);
                    Log.Info("API connected");
                }
                catch (System.Exception ex)
                {
                    // Disconnect() only disposes a client that was already adopted into _TcpRpcClient;
                    // dispose a half-connected (or timed-out) one here so it does not leak.
                    if (!ReferenceEquals(_TcpRpcClient, tcpRpcClient))
                    {
                        _SafeDispose(tcpRpcClient);
                    }

                    await Disconnect().ConfigureAwait(false);
                    Log.Warn(ex, "API connecting failed");
                    throw; // preserve the original stack trace
                }
            }
            finally
            {
                _ConnectSemaphore.Release();
            }
        }

        /// <summary>
        /// Disposes the current client (whatever its state), clears all connection state and raises
        /// <see cref="ConnectionStatusChange.Disconnected"/>. Completes synchronously.
        /// </summary>
        public Task Disconnect()
        {
            TcpRpcClient client = _TcpRpcClient;
            _TcpRpcClient = null;
            if (client != null)
            {
                // Detach first so disposing the client is not reported as a ConnectionLoss.
                client.ConnectionStateChanged -= OnTcpRpcConnectionChanged;
                _SafeDispose(client);
            }

            _Bootstrap = null;
            Session = null;
            ConnectionData = null;
            ConnectionInfo = null;

            ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Disconnected);

            Log.Info("API disconnected");

            return Task.CompletedTask;
        }

        /// <summary>
        /// Reconnects with the retained <see cref="ConnectionData"/> when not currently connected.
        /// Raises <see cref="ConnectionStatusChange.Reconnected"/> only if the API is connected afterwards.
        /// </summary>
        /// <exception cref="AuthenticationException"></exception>
        /// <exception cref="ConnectingFailedException"></exception>
        public async Task Reconnect()
        {
            await _ReconnectSemaphore.WaitAsync();
            try
            {
                if (ConnectionData != null && IsConnected == false)
                {
                    await Connect(ConnectionData);
                }

                if (IsConnected)
                {
                    ConnectionStatusChanged?.Invoke(this, ConnectionStatusChange.Reconnected);
                    Log.Info("API reconnected");
                }
                else
                {
                    Log.Warn("API reconnect skipped: no connection data available");
                }
            }
            finally
            {
                _ReconnectSemaphore.Release();
            }
        }

        /// <summary>
        /// Opens a throw-away connection, reads the server's <see cref="ConnectionInfo"/> and disposes
        /// the client again. Does not authenticate and does not change this instance's state.
        /// </summary>
        /// <exception cref="ConnectingFailedException">Any failure; the cause is the inner exception.</exception>
        public async Task<ConnectionInfo> TestConnection(ConnectionData connectionData, TcpRpcClient tcpRpcClient = null)
        {
            TcpRpcClient testClient = tcpRpcClient;
            try
            {
                if (testClient == null)
                {
                    testClient = new TcpRpcClient();
                }

                await _ConnectAsync(testClient, connectionData).ConfigureAwait(false);
                IBootstrap testBootstrap = testClient.GetMain<IBootstrap>();

                return await _GetConnectionInfo(testBootstrap).ConfigureAwait(false);
            }
            catch (System.Exception ex)
            {
                throw new ConnectingFailedException("Test connection failed", ex);
            }
            finally
            {
                // Dispose on success AND failure so a failed test does not leak a socket.
                _SafeDispose(testClient);
            }
        }
        #endregion

        #region Private Methods
        /// <summary>
        /// Validate Certificate.
        /// SECURITY: currently accepts ANY server certificate, so the TLS channel is not authenticated
        /// and is open to man-in-the-middle attacks.
        /// TODO: Do some validation (e.g. pinning or a trust store for the server certificate).
        /// </summary>
        private static bool _RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
        {
            // TODO Cert Check
            return true;
        }

        /// <summary>
        /// Disposes a client, logging instead of propagating failures so cleanup paths
        /// (which may run inside a catch block) never mask the original error.
        /// </summary>
        private static void _SafeDispose(TcpRpcClient client)
        {
            if (client == null)
            {
                return;
            }

            try
            {
                client.Dispose();
            }
            catch (System.Exception ex)
            {
                Log.Warn(ex, "Disposing TcpRpcClient failed");
            }
        }

        /// <summary>
        /// Connect to server async with ConnectionData: installs a TLS midlayer (target host "bffhd"),
        /// starts connecting and waits at most <see cref="ConnectTimeoutMilliseconds"/>.
        /// </summary>
        /// <exception cref="AuthenticationException">TLS Error</exception>
        /// <exception cref="ConnectingFailedException">Timeout, or based on RPC Exception</exception>
        private async Task _ConnectAsync(TcpRpcClient rpcClient, ConnectionData connectionData)
        {
            rpcClient.InjectMidlayer((tcpstream) =>
            {
                var sslStream = new SslStream(tcpstream, false, new RemoteCertificateValidationCallback(_RemoteCertificateValidationCallback));
                try
                {
                    sslStream.AuthenticateAsClient("bffhd");
                    return sslStream;
                }
                catch (System.Exception)
                {
                    // Any handshake failure (not only AuthenticationException): release the stream, then rethrow.
                    sslStream.Close();
                    throw;
                }
            });

            try
            {
                Task timeoutTask = Task.Delay(ConnectTimeoutMilliseconds);
                rpcClient.Connect(connectionData.Host.Host, connectionData.Host.Port);

                // Decide by which task actually finished first; checking timeoutTask.IsCompleted afterwards
                // could report a timeout for a connection that succeeded just before the deadline.
                Task finished = await Task.WhenAny(rpcClient.WhenConnected, timeoutTask).ConfigureAwait(false);
                if (finished == timeoutTask)
                {
                    throw new ConnectingFailedException("Connection timeout");
                }

                // Propagate a faulted connection attempt.
                await finished.ConfigureAwait(false);
            }
            catch (RpcException exception) when (string.Equals(exception.Message, "TcpRpcClient is unable to connect", StringComparison.Ordinal))
            {
                throw new ConnectingFailedException("RPC Connecting failed", exception);
            }
        }

        /// <summary>
        /// Create ConnectionInfo from Bootstrap (API version, SASL mechanisms, server name and release).
        /// </summary>
        private async Task<ConnectionInfo> _GetConnectionInfo(IBootstrap bootstrap)
        {
            var apiVersion = await bootstrap.GetAPIVersion().ConfigureAwait(false);
            var mechanisms = await bootstrap.Mechanisms().ConfigureAwait(false);
            // One RPC round-trip supplies both the server name (Item1) and release (Item2).
            var serverRelease = await bootstrap.GetServerRelease().ConfigureAwait(false);

            return new ConnectionInfo()
            {
                APIVersion = apiVersion,
                Mechanisms = new List<string>(mechanisms),
                ServerName = serverRelease.Item1,
                ServerRelease = serverRelease.Item2,
            };
        }

        /// <summary>
        /// Authenticate connection with ConnectionData: opens an authentication capability for the
        /// configured mechanism on the bootstrap and runs the SASL exchange over it.
        /// </summary>
        /// <exception cref="UnsupportedMechanismException"></exception>
        /// <exception cref="InvalidCredentialsException"></exception>
        /// <exception cref="AuthenticationFailedException"></exception>
        private async Task<Session> _Authenticate(ConnectionData connectionData)
        {
            string mechanism = MechanismString.ToString(connectionData.Mechanism);
            IAuthentication? authentication = await _Bootstrap.CreateSession(mechanism).ConfigureAwait(false);

            return await _SASLAuthenticate(authentication, mechanism, connectionData.Properties).ConfigureAwait(false);
        }

        /// <summary>
        /// Authenticate Connection to get Session: drives the SASL challenge/response loop until the
        /// client mechanism completes or the server reports failure.
        /// </summary>
        /// <param name="authentication">Server-side authentication capability.</param>
        /// <param name="mech">SASL mechanism name passed to <see cref="SaslFactory.Create"/>.</param>
        /// <param name="properties">Mechanism properties (e.g. credentials); may be null.</param>
        /// <exception cref="BadMechanismException"></exception>
        /// <exception cref="InvalidCredentialsException"></exception>
        /// <exception cref="AuthenticationFailedException"></exception>
        private async Task<Session> _SASLAuthenticate(IAuthentication authentication, string mech, Dictionary<string, object> properties)
        {
            SaslMechanism? saslMechanism = SaslFactory.Create(mech);
            if (properties != null)
            {
                foreach (KeyValuePair<string, object> entry in properties)
                {
                    saslMechanism.Properties.Add(entry.Key, entry.Value);
                }
            }

            // Mechanisms with an initial response send it in the first step; others start with empty data.
            byte[] data = saslMechanism.HasInitial
                ? saslMechanism.GetResponse(Array.Empty<byte>())
                : Array.Empty<byte>();

            Response? response = await authentication.Step(data).ConfigureAwait(false);
            while (!saslMechanism.IsCompleted)
            {
                if (response.Failed != null)
                {
                    break;
                }
                if (response.Challenge != null)
                {
                    byte[]? additional = saslMechanism.GetResponse(response.Challenge.ToArray());
                    response = await authentication.Step(additional).ConfigureAwait(false);
                }
                else
                {
                    // Neither failure nor challenge while the client side is still incomplete.
                    throw new AuthenticationFailedException();
                }
            }

            if (response.Successful != null)
            {
                return response.Successful.Session;
            }
            else if (response.Failed != null)
            {
                switch (response.Failed.Code)
                {
                    case Response.Error.badMechanism:
                        throw new BadMechanismException();
                    case Response.Error.invalidCredentials:
                        throw new InvalidCredentialsException();
                    case Response.Error.aborted:
                    case Response.Error.failed:
                    default:
                        throw new AuthenticationFailedException(response.Failed.AdditionalData?.ToArray() ?? Array.Empty<byte>());
                }
            }
            else
            {
                throw new AuthenticationFailedException();
            }
        }
        #endregion
    }
}