fabaccess-bffh/bffhd/capnp/authenticationsystem.rs

202 lines
6.1 KiB
Rust
Raw Normal View History

2022-03-08 16:41:38 +01:00
use capnp::capability::Promise;
use capnp::Error;
use capnp_rpc::pry;
use rsasl::mechname::Mechname;
use rsasl::prelude::State as SaslState;
use rsasl::prelude::{MessageSent, Session};
2022-03-12 17:31:53 +01:00
use rsasl::property::AuthId;
2022-06-24 13:57:47 +02:00
use std::fmt;
use std::fmt::{Formatter, Write};
2022-03-12 17:31:53 +01:00
use std::io::Cursor;
2022-06-24 13:57:47 +02:00
use tracing::Span;
2022-03-08 16:41:38 +01:00
use crate::authentication::V;
2022-03-12 17:31:53 +01:00
use crate::capnp::session::APISession;
use crate::session::SessionManager;
2022-03-12 02:00:55 +01:00
use api::authenticationsystem_capnp::authentication::{
2022-03-12 17:31:53 +01:00
AbortParams, AbortResults, Server as AuthenticationSystem, StepParams, StepResults,
2022-03-08 16:41:38 +01:00
};
2022-03-12 17:31:53 +01:00
use api::authenticationsystem_capnp::{response, response::Error as ErrorCode};
2022-03-08 16:41:38 +01:00
2022-06-24 13:57:47 +02:00
/// Tracing target used for every span and event emitted by the authentication API.
const TARGET: &str = concat!("bffh::api::", "authenticationsystem");
2022-03-08 16:41:38 +01:00
pub struct Authentication {
2022-06-24 13:57:47 +02:00
span: Span,
2022-03-08 16:41:38 +01:00
state: State,
}
2022-03-12 17:31:53 +01:00
impl Authentication {
pub fn new(
parent: &Span,
mechanism: &Mechname, /* TODO: this is stored in session as well, get it out of there. */
2022-11-01 10:47:51 +01:00
session: Session<V>,
sessionmanager: SessionManager,
) -> Self {
let span = tracing::info_span!(
target: TARGET,
parent: parent,
"Authentication",
mechanism = mechanism.as_str()
);
2022-06-24 13:57:47 +02:00
tracing::trace!(
target: TARGET,
parent: &span,
"constructing valid authentication system"
);
2022-03-12 17:31:53 +01:00
Self {
2022-06-24 13:57:47 +02:00
span,
2022-03-12 17:31:53 +01:00
state: State::Running(session, sessionmanager),
}
}
pub fn invalid_mechanism() -> Self {
2022-06-24 13:57:47 +02:00
let span = tracing::info_span!(target: TARGET, "Authentication",);
tracing::trace!(
target: TARGET,
parent: &span,
"constructing invalid mechanism authentication system"
);
2022-03-12 17:31:53 +01:00
Self {
2022-06-24 13:57:47 +02:00
span,
2022-03-12 17:31:53 +01:00
state: State::InvalidMechanism,
}
}
fn build_error(&self, response: response::Builder) {
if let State::Running(_, _) = self.state {
return;
}
let mut builder = response.init_failed();
match self.state {
State::InvalidMechanism => builder.set_code(ErrorCode::BadMechanism),
State::Finished => builder.set_code(ErrorCode::Aborted),
State::Aborted => builder.set_code(ErrorCode::Aborted),
_ => unreachable!(),
}
}
}
2022-06-24 13:57:47 +02:00
impl fmt::Display for Authentication {
    /// Renders the capability as `Authentication(<state>)`, e.g. `Authentication(running)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let state_name = match self.state {
            State::Running(..) => "running",
            State::Aborted => "aborted",
            State::Finished => "finished",
            State::InvalidMechanism => "invalid mechanism",
        };
        write!(f, "Authentication({})", state_name)
    }
}
2022-03-08 16:41:38 +01:00
enum State {
InvalidMechanism,
Finished,
Aborted,
2022-11-01 10:47:51 +01:00
Running(Session<V>, SessionManager),
2022-03-08 16:41:38 +01:00
}
impl AuthenticationSystem for Authentication {
2022-03-08 16:41:38 +01:00
fn step(&mut self, params: StepParams, mut results: StepResults) -> Promise<(), Error> {
2022-06-24 13:57:47 +02:00
let _guard = self.span.enter();
let _span = tracing::trace_span!(target: TARGET, "step",).entered();
tracing::trace!(params.data = "<authentication data>", "method call");
#[repr(transparent)]
struct Response {
union_field: &'static str,
}
impl fmt::Display for Response {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str("Response(")?;
f.write_str(self.union_field)?;
f.write_char(')')
}
}
let mut response;
2022-03-12 17:31:53 +01:00
let mut builder = results.get();
if let State::Running(mut session, manager) =
std::mem::replace(&mut self.state, State::Aborted)
{
let data: &[u8] = pry!(pry!(params.get()).get_data());
2022-06-24 13:57:47 +02:00
2022-11-01 10:47:51 +01:00
let mut out = Vec::new();
2022-03-12 17:31:53 +01:00
match session.step(Some(data), &mut out) {
2022-10-05 17:28:47 +02:00
Ok(SaslState::Finished(sent)) => {
2022-03-12 17:31:53 +01:00
self.state = State::Finished;
2022-11-01 10:47:51 +01:00
if let Some(user) = session.validation() {
let session = manager.open(&self.span, user);
response = Response {
union_field: "successful",
};
let mut builder = builder.init_successful();
if sent == MessageSent::Yes {
builder.set_additional_data(out.as_slice());
}
APISession::build(session, builder)
} else {
let mut builder = builder.init_failed();
builder.set_code(ErrorCode::InvalidCredentials);
response = Response {
union_field: "error",
};
2022-03-12 17:31:53 +01:00
}
}
2022-10-05 17:28:47 +02:00
Ok(SaslState::Running) => {
2022-03-12 17:31:53 +01:00
self.state = State::Running(session, manager);
2022-11-01 10:47:51 +01:00
builder.set_challenge(out.as_slice());
2022-06-24 13:57:47 +02:00
response = Response {
union_field: "challenge",
};
2022-03-12 17:31:53 +01:00
}
2023-02-13 18:44:08 +01:00
Err(e) => {
tracing::error!(error = %e., "authentication failed");
2022-03-12 17:31:53 +01:00
self.state = State::Aborted;
self.build_error(builder);
2022-06-24 13:57:47 +02:00
response = Response {
union_field: "error",
};
2022-03-12 17:31:53 +01:00
}
}
} else {
self.build_error(builder);
2022-06-24 13:57:47 +02:00
response = Response {
union_field: "error",
};
2022-03-12 17:31:53 +01:00
}
2022-06-24 13:57:47 +02:00
tracing::trace!(
results = %response,
"method return"
);
2022-03-12 17:31:53 +01:00
Promise::ok(())
2022-03-08 16:41:38 +01:00
}
fn abort(&mut self, _: AbortParams, _: AbortResults) -> Promise<(), Error> {
2022-06-24 13:57:47 +02:00
let _guard = self.span.enter();
let _span = tracing::trace_span!(
target: TARGET,
parent: &self.span,
"abort",
)
.entered();
tracing::trace!("method call");
2022-03-12 01:27:58 +01:00
self.state = State::Aborted;
2022-06-24 13:57:47 +02:00
tracing::trace!("method return");
2022-03-08 16:41:38 +01:00
Promise::ok(())
}
2022-03-12 17:31:53 +01:00
}