import asyncio
import socket
import ssl
import capnp
import logging

# Cap'n Proto schemas, loaded once at import time via pycapnp. The paths are
# relative; NOTE(review): presumably resolved against the current working
# directory, so the process must run from the directory containing `schema/`
# — confirm.
connection_capnp = capnp.load('schema/connection.capnp')
# Not referenced by the code visible here; presumably used by other modules.
authenticationsystem_capnp = capnp.load('schema/authenticationsystem.capnp')


async def myreader(client, reader):
    """Pump bytes from the network stream into the Cap'n Proto client.

    Reads up to 4096 bytes at a time from ``reader`` and hands each chunk to
    ``client.write``. Returns once the stream reaches EOF.

    ``StreamReader.read`` returns ``b""`` at EOF *without suspending*. Looping
    on it unconditionally would therefore spin forever without ever yielding,
    blocking the entire event loop once the peer closes the connection.
    """
    while True:
        data = await reader.read(4096)
        if not data:
            # EOF: the peer closed the connection; stop pumping.
            break
        client.write(data)


async def mywriter(client, writer):
    """Forward bytes emitted by the Cap'n Proto client to the network stream.

    Repeatedly reads up to 4096 bytes from ``client``, converts the returned
    buffer with ``tobytes()``, writes it to ``writer`` and awaits
    ``writer.drain()``. Loops forever; it only ends when one of those calls
    raises.
    """
    while True:
        outgoing = (await client.read(4096)).tobytes()
        writer.write(outgoing)
        # Wait for the transport's write buffer to drain before the next chunk.
        await writer.drain()


async def connect(host, port, user, pw):
    """Open a TLS connection to the server and authenticate with SASL PLAIN.

    :param host: server hostname or address.
    :param port: server TCP port.
    :param user: user name sent as the PLAIN authentication identity.
    :param pw: password.
    :return: the session from the successful authentication response, or
        ``None`` if authentication did not succeed.
    :raises OSError: if neither the IPv6 nor the IPv4 connection attempt works.
    """
    # NOTE(review): certificate and hostname verification are disabled, so the
    # server's identity is not checked; presumably intentional for self-signed
    # deployments — confirm.
    ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    # Try IPv6 first, then fall back to IPv4. Only connection-level failures
    # (OSError: name resolution, refused connection, TLS errors) trigger the
    # fallback; unrelated programming errors propagate unchanged.
    try:
        reader, writer = await asyncio.open_connection(
            host, port, ssl=ctx, family=socket.AF_INET6
        )
    except OSError:
        reader, writer = await asyncio.open_connection(
            host, port, ssl=ctx, family=socket.AF_INET
        )

    # Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode);
    # the pump coroutines below move bytes between that pipe and the socket.
    client = capnp.TwoPartyClient()

    # Run both pumps in the background. The gather future is deliberately not
    # awaited: the pumps keep running while the RPC calls below are made.
    coroutines = [myreader(client, reader), mywriter(client, writer)]
    asyncio.gather(*coroutines, return_exceptions=True)

    boot = client.bootstrap().cast_as(connection_capnp.Bootstrap)
    auth = await boot.createSession("PLAIN").a_wait()
    # SASL PLAIN message: authzid NUL authcid NUL password (empty authzid).
    p = "\0" + user + "\0" + pw
    response = await auth.authentication.step(p).a_wait()

    which = response.which()
    if which == "successful":
        return response.successful.session
    # Log failures the same way the X-FABFIRE helpers do; only read `failed`
    # when it is the active union member.
    if which == "failed":
        logging.error(
            f"Authentication failed! {response.failed.code}, "
            f"additional info: {response.failed.additionalData}"
        )
    else:
        logging.error(f"Authentication failed! Unexpected response type: {which}")
    return None



async def connect_with_fabfire_initial(host, port, uid):
    """Open a TLS connection and start an X-FABFIRE authentication exchange.

    :param host: server hostname or address.
    :param port: server TCP port.
    :param uid: payload of the first X-FABFIRE step (presumably the card UID
        — confirm against the caller).
    :return: ``(auth, challenge)`` when the server replies with a challenge,
        where ``auth`` is the createSession result to pass to
        ``connect_with_fabfire_step``; otherwise ``None``.
    :raises OSError: if neither the IPv6 nor the IPv4 connection attempt works.
    """
    # NOTE(review): certificate and hostname verification are disabled, so the
    # server's identity is not checked; presumably intentional — confirm.
    ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    # Try IPv6 first, then fall back to IPv4. Only connection-level failures
    # (OSError) trigger the fallback; other errors propagate unchanged.
    try:
        reader, writer = await asyncio.open_connection(
            host, port, ssl=ctx, family=socket.AF_INET6
        )
    except OSError:
        reader, writer = await asyncio.open_connection(
            host, port, ssl=ctx, family=socket.AF_INET
        )

    # Start TwoPartyClient using TwoWayPipe (takes no arguments in this mode).
    client = capnp.TwoPartyClient()

    # Run both byte pumps in the background; the gather future is
    # deliberately not awaited.
    coroutines = [myreader(client, reader), mywriter(client, writer)]
    asyncio.gather(*coroutines, return_exceptions=True)

    boot = client.bootstrap().cast_as(connection_capnp.Bootstrap)
    auth = await boot.createSession("X-FABFIRE").a_wait()
    response = await auth.authentication.step(uid).a_wait()

    which = response.which()
    logging.debug(f"got response type: {which}")
    if which == "challenge":
        logging.debug(f"challenge: {response.challenge}")
        return auth, response.challenge
    # Only read `failed` when it is the active union member; any other reply
    # is unexpected at this point of the exchange.
    if which == "failed":
        logging.error(
            f"Auth failed: {response.failed.code}, "
            f"additional info: {response.failed.additionalData}"
        )
    else:
        logging.error(f"Auth failed: unexpected response type {which}")
    return None


async def connect_with_fabfire_step(auth, msg):
    """Send one X-FABFIRE authentication step and interpret the server reply.

    :param auth: the createSession result (first element returned by
        ``connect_with_fabfire_initial``).
    :param msg: payload for this step, typically the answer to the previous
        challenge.
    :return: ``(challenge, None)`` if the server sends another challenge
        (not done yet); ``(additionalData, session)`` once authentication
        succeeds (done); ``None`` if it fails or the reply type is unexpected.
    """
    response = await auth.authentication.step(msg).a_wait()
    which = response.which()
    if which == "challenge":
        logging.debug(f"challenge: {response.challenge}")
        # Not done: a None second element tells the caller to answer again.
        return response.challenge, None
    if which == "successful":
        logging.info(f"Auth completed successfully! Got additional Data: {response.successful.additionalData}")
        # Done: the second element is the authenticated session.
        return response.successful.additionalData, response.successful.session
    # Only read `failed` when it is the active union member.
    if which == "failed":
        logging.error(f"Auth failed: {response.failed.code}, additional info: {response.failed.additionalData}")
    else:
        logging.error(f"Auth failed: unexpected response type {which}")
    return None