Move miette towards edges of BFFH for more structured error reporting

This commit is contained in:
Nadja Reitzenstein 2022-12-25 11:54:36 +01:00
parent 0d2cd6f376
commit 81aa60a10f
15 changed files with 226 additions and 76 deletions

View File

@ -12,9 +12,10 @@ use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use miette::IntoDiagnostic; use miette::{Diagnostic, IntoDiagnostic};
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use thiserror::Error;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rumqttc::ConnectReturnCode::Success; use rumqttc::ConnectReturnCode::Success;
@ -111,11 +112,33 @@ static ROOT_CERTS: Lazy<RootCertStore> = Lazy::new(|| {
store store
}); });
pub fn load(executor: Executor, config: &Config, resources: ResourcesHandle) -> miette::Result<()> { #[derive(Debug, Error, Diagnostic)]
pub enum ActorError {
#[error("failed to parse MQTT url")]
UrlParseError(
#[from]
#[source]
url::ParseError,
),
#[error("MQTT config is invalid")]
InvalidConfig,
#[error("MQTT connection failed")]
ConnectionError(
#[from]
#[source]
rumqttc::ConnectionError,
),
}
pub fn load(
executor: Executor,
config: &Config,
resources: ResourcesHandle,
) -> Result<(), ActorError> {
let span = tracing::info_span!("loading actors"); let span = tracing::info_span!("loading actors");
let _guard = span; let _guard = span;
let mqtt_url = Url::parse(config.mqtt_url.as_str()).into_diagnostic()?; let mqtt_url = Url::parse(config.mqtt_url.as_str())?;
let (transport, default_port) = match mqtt_url.scheme() { let (transport, default_port) = match mqtt_url.scheme() {
"mqtts" | "ssl" => ( "mqtts" | "ssl" => (
rumqttc::Transport::tls_with_config( rumqttc::Transport::tls_with_config(
@ -132,12 +155,12 @@ pub fn load(executor: Executor, config: &Config, resources: ResourcesHandle) ->
scheme => { scheme => {
tracing::error!(%scheme, "MQTT url uses invalid scheme"); tracing::error!(%scheme, "MQTT url uses invalid scheme");
miette::bail!("invalid config"); return Err(ActorError::InvalidConfig);
} }
}; };
let host = mqtt_url.host_str().ok_or_else(|| { let host = mqtt_url.host_str().ok_or_else(|| {
tracing::error!("MQTT url must contain a hostname"); tracing::error!("MQTT url must contain a hostname");
miette::miette!("invalid config") ActorError::InvalidConfig
})?; })?;
let port = mqtt_url.port().unwrap_or(default_port); let port = mqtt_url.port().unwrap_or(default_port);
@ -168,7 +191,7 @@ pub fn load(executor: Executor, config: &Config, resources: ResourcesHandle) ->
} }
Err(error) => { Err(error) => {
tracing::error!(?error, "MQTT connection failed"); tracing::error!(?error, "MQTT connection failed");
miette::bail!("mqtt connection failed") return Err(ActorError::ConnectionError(error));
} }
} }

View File

@ -1,8 +1,10 @@
use miette::Diagnostic;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use std::fs::{File, OpenOptions}; use std::fs::{File, OpenOptions};
use std::io; use std::io;
use std::io::{LineWriter, Write}; use std::io::{LineWriter, Write};
use std::sync::Mutex; use std::sync::Mutex;
use thiserror::Error;
use crate::Config; use crate::Config;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -23,8 +25,13 @@ pub struct AuditLogLine<'a> {
state: &'a str, state: &'a str,
} }
#[derive(Debug, Error, Diagnostic)]
#[error(transparent)]
#[repr(transparent)]
pub struct Error(#[from] pub io::Error);
impl AuditLog { impl AuditLog {
pub fn new(config: &Config) -> io::Result<&'static Self> { pub fn new(config: &Config) -> Result<&'static Self, Error> {
AUDIT.get_or_try_init(|| { AUDIT.get_or_try_init(|| {
tracing::debug!(path = %config.auditlog_path.display(), "Initializing audit log"); tracing::debug!(path = %config.auditlog_path.display(), "Initializing audit log");
let fd = OpenOptions::new() let fd = OpenOptions::new()

View File

@ -2,7 +2,7 @@ mod server;
pub use server::FabFire; pub use server::FabFire;
use rsasl::mechname::Mechname; use rsasl::mechname::Mechname;
use rsasl::registry::{Mechanism, MECHANISMS, Side}; use rsasl::registry::{Mechanism, Side, MECHANISMS};
const MECHNAME: &'static Mechname = &Mechname::const_new_unchecked(b"X-FABFIRE"); const MECHNAME: &'static Mechname = &Mechname::const_new_unchecked(b"X-FABFIRE");
@ -10,8 +10,8 @@ const MECHNAME: &'static Mechname = &Mechname::const_new_unchecked(b"X-FABFIRE")
pub static FABFIRE: Mechanism = pub static FABFIRE: Mechanism =
Mechanism::build(MECHNAME, 300, None, Some(FabFire::new_server), Side::Client); Mechanism::build(MECHNAME, 300, None, Some(FabFire::new_server), Side::Client);
use std::marker::PhantomData;
use rsasl::property::SizedProperty; use rsasl::property::SizedProperty;
use std::marker::PhantomData;
// All Property types must implement Debug. // All Property types must implement Debug.
#[derive(Debug)] #[derive(Debug)]

View File

@ -3,7 +3,9 @@ use desfire::desfire::Desfire;
use desfire::error::Error as DesfireError; use desfire::error::Error as DesfireError;
use desfire::iso7816_4::apduresponse::APDUResponse; use desfire::iso7816_4::apduresponse::APDUResponse;
use rsasl::callback::SessionData; use rsasl::callback::SessionData;
use rsasl::mechanism::{Authentication, MechanismData, MechanismError, MechanismErrorKind, State, ThisProvider}; use rsasl::mechanism::{
Authentication, MechanismData, MechanismError, MechanismErrorKind, State, ThisProvider,
};
use rsasl::prelude::{MessageSent, SASLConfig, SASLError, SessionError}; use rsasl::prelude::{MessageSent, SASLConfig, SASLError, SessionError};
use rsasl::property::AuthId; use rsasl::property::AuthId;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -62,9 +64,7 @@ impl Display for FabFireError {
} }
} }
impl std::error::Error for FabFireError { impl std::error::Error for FabFireError {}
}
impl MechanismError for FabFireError { impl MechanismError for FabFireError {
fn kind(&self) -> MechanismErrorKind { fn kind(&self) -> MechanismErrorKind {
@ -496,7 +496,10 @@ impl Authentication for FabFire {
let token = String::from_utf8(data).unwrap(); let token = String::from_utf8(data).unwrap();
let prov = let prov =
ThisProvider::<AuthId>::with(token.trim_matches(char::from(0))); ThisProvider::<AuthId>::with(token.trim_matches(char::from(0)));
let key = session.need_with::<FabFireCardKey, _, _>(&prov, |key| Ok(Box::from(key.as_slice())))?; let key = session
.need_with::<FabFireCardKey, _, _>(&prov, |key| {
Ok(Box::from(key.as_slice()))
})?;
self.key_info = Some(KeyInfo { key_id: 0x01, key }); self.key_info = Some(KeyInfo { key_id: 0x01, key });
} }
None => { None => {

View File

@ -1,11 +1,11 @@
use crate::users::Users; use crate::users::Users;
use miette::{IntoDiagnostic, WrapErr}; use miette::{IntoDiagnostic, WrapErr};
use std::sync::Arc; use rsasl::callback::{CallbackError, Context, Request, SessionCallback, SessionData};
use rsasl::callback::{CallbackError, Request, SessionCallback, SessionData, Context};
use rsasl::mechanism::SessionError; use rsasl::mechanism::SessionError;
use rsasl::prelude::{Mechname, SASLConfig, SASLServer, Session, Validation}; use rsasl::prelude::{Mechname, SASLConfig, SASLServer, Session, Validation};
use rsasl::property::{AuthId, AuthzId, Password}; use rsasl::property::{AuthId, AuthzId, Password};
use rsasl::validate::{Validate, ValidationError}; use rsasl::validate::{Validate, ValidationError};
use std::sync::Arc;
use crate::authentication::fabfire::FabFireCardKey; use crate::authentication::fabfire::FabFireCardKey;
use crate::users::db::User; use crate::users::db::User;
@ -23,41 +23,55 @@ impl Callback {
} }
} }
impl SessionCallback for Callback { impl SessionCallback for Callback {
fn callback(&self, session_data: &SessionData, context: &Context, request: &mut Request) -> Result<(), SessionError> { fn callback(
&self,
session_data: &SessionData,
context: &Context,
request: &mut Request,
) -> Result<(), SessionError> {
if let Some(authid) = context.get_ref::<AuthId>() { if let Some(authid) = context.get_ref::<AuthId>() {
request.satisfy_with::<FabFireCardKey, _>(|| { request.satisfy_with::<FabFireCardKey, _>(|| {
let user = self.users.get_user(authid).ok_or(CallbackError::NoValue)?; let user = self.users.get_user(authid).ok_or(CallbackError::NoValue)?;
let kv = user.userdata.kv.get("cardkey").ok_or(CallbackError::NoValue)?; let kv = user
let card_key = <[u8; 16]>::try_from( .userdata
hex::decode(kv).map_err(|_| CallbackError::NoValue)?, .kv
).map_err(|_| CallbackError::NoValue)?; .get("cardkey")
.ok_or(CallbackError::NoValue)?;
let card_key =
<[u8; 16]>::try_from(hex::decode(kv).map_err(|_| CallbackError::NoValue)?)
.map_err(|_| CallbackError::NoValue)?;
Ok(card_key) Ok(card_key)
})?; })?;
} }
Ok(()) Ok(())
} }
fn validate(&self, session_data: &SessionData, context: &Context, validate: &mut Validate<'_>) -> Result<(), ValidationError> { fn validate(
&self,
session_data: &SessionData,
context: &Context,
validate: &mut Validate<'_>,
) -> Result<(), ValidationError> {
let span = tracing::info_span!(parent: &self.span, "validate"); let span = tracing::info_span!(parent: &self.span, "validate");
let _guard = span.enter(); let _guard = span.enter();
if validate.is::<V>() { if validate.is::<V>() {
match session_data.mechanism().mechanism.as_str() { match session_data.mechanism().mechanism.as_str() {
"PLAIN" => { "PLAIN" => {
let authcid = context.get_ref::<AuthId>() let authcid = context
.get_ref::<AuthId>()
.ok_or(ValidationError::MissingRequiredProperty)?; .ok_or(ValidationError::MissingRequiredProperty)?;
let authzid = context.get_ref::<AuthzId>(); let authzid = context.get_ref::<AuthzId>();
let password = context.get_ref::<Password>() let password = context
.get_ref::<Password>()
.ok_or(ValidationError::MissingRequiredProperty)?; .ok_or(ValidationError::MissingRequiredProperty)?;
if authzid.is_some() { if authzid.is_some() {
return Ok(()) return Ok(());
} }
if let Some(user) = self.users.get_user(authcid) { if let Some(user) = self.users.get_user(authcid) {
match user.check_password(password) { match user.check_password(password) {
Ok(true) => { Ok(true) => validate.finalize::<V>(user),
validate.finalize::<V>(user)
}
Ok(false) => { Ok(false) => {
tracing::warn!(authid=%authcid, "AUTH FAILED: bad password"); tracing::warn!(authid=%authcid, "AUTH FAILED: bad password");
} }

View File

@ -2,21 +2,21 @@ use capnp::capability::Promise;
use capnp::Error; use capnp::Error;
use capnp_rpc::pry; use capnp_rpc::pry;
use rsasl::mechname::Mechname; use rsasl::mechname::Mechname;
use rsasl::prelude::State as SaslState;
use rsasl::prelude::{MessageSent, Session};
use rsasl::property::AuthId; use rsasl::property::AuthId;
use std::fmt; use std::fmt;
use std::fmt::{Formatter, Write}; use std::fmt::{Formatter, Write};
use std::io::Cursor; use std::io::Cursor;
use rsasl::prelude::{MessageSent, Session};
use rsasl::prelude::State as SaslState;
use tracing::Span; use tracing::Span;
use crate::authentication::V;
use crate::capnp::session::APISession; use crate::capnp::session::APISession;
use crate::session::SessionManager; use crate::session::SessionManager;
use api::authenticationsystem_capnp::authentication::{ use api::authenticationsystem_capnp::authentication::{
AbortParams, AbortResults, Server as AuthenticationSystem, StepParams, StepResults, AbortParams, AbortResults, Server as AuthenticationSystem, StepParams, StepResults,
}; };
use api::authenticationsystem_capnp::{response, response::Error as ErrorCode}; use api::authenticationsystem_capnp::{response, response::Error as ErrorCode};
use crate::authentication::V;
const TARGET: &str = "bffh::api::authenticationsystem"; const TARGET: &str = "bffh::api::authenticationsystem";

View File

@ -1,5 +1,7 @@
use async_net::TcpListener; use miette::Diagnostic;
use thiserror::Error;
use async_net::TcpListener;
use capnp_rpc::rpc_twoparty_capnp::Side; use capnp_rpc::rpc_twoparty_capnp::Side;
use capnp_rpc::twoparty::VatNetwork; use capnp_rpc::twoparty::VatNetwork;
use capnp_rpc::RpcSystem; use capnp_rpc::RpcSystem;
@ -37,6 +39,10 @@ pub struct APIServer {
authentication: AuthenticationHandle, authentication: AuthenticationHandle,
} }
#[derive(Debug, Error, Diagnostic)]
#[error("Reached Void error, this should not be possible")]
pub enum Error {}
impl APIServer { impl APIServer {
pub fn new( pub fn new(
executor: Executor<'static>, executor: Executor<'static>,
@ -60,7 +66,7 @@ impl APIServer {
acceptor: TlsAcceptor, acceptor: TlsAcceptor,
sessionmanager: SessionManager, sessionmanager: SessionManager,
authentication: AuthenticationHandle, authentication: AuthenticationHandle,
) -> miette::Result<Self> { ) -> Result<Self, Error> {
let span = tracing::info_span!("binding API listen sockets"); let span = tracing::info_span!("binding API listen sockets");
let _guard = span.enter(); let _guard = span.enter();

View File

@ -13,9 +13,9 @@ pub type ErrorO = lmdb::Error;
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[repr(transparent)] #[derive(Clone, Debug, PartialEq, Eq, Error)]
#[derive(Debug, Error)]
#[error(transparent)] #[error(transparent)]
#[repr(transparent)]
pub struct Error(#[from] lmdb::Error); pub struct Error(#[from] lmdb::Error);
impl Diagnostic for Error { impl Diagnostic for Error {

View File

@ -1,5 +1,6 @@
use super::Initiator; use super::Initiator;
use super::InitiatorCallbacks; use super::InitiatorCallbacks;
use crate::resources::modules::fabaccess::Status;
use crate::resources::state::State; use crate::resources::state::State;
use crate::utils::linebuffer::LineBuffer; use crate::utils::linebuffer::LineBuffer;
use async_process::{Child, ChildStderr, ChildStdout, Command, Stdio}; use async_process::{Child, ChildStderr, ChildStdout, Command, Stdio};
@ -11,7 +12,6 @@ use std::future::Future;
use std::io; use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use crate::resources::modules::fabaccess::Status;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum InputMessage { pub enum InputMessage {
@ -63,7 +63,12 @@ struct ProcessState {
impl ProcessState { impl ProcessState {
pub fn new(stdout: ChildStdout, stderr: ChildStderr, child: Child) -> Self { pub fn new(stdout: ChildStdout, stderr: ChildStderr, child: Child) -> Self {
Self { stdout, stderr, stderr_closed: false, child } Self {
stdout,
stderr,
stderr_closed: false,
child,
}
} }
fn try_process(&mut self, buffer: &[u8], callbacks: &mut InitiatorCallbacks) -> usize { fn try_process(&mut self, buffer: &[u8], callbacks: &mut InitiatorCallbacks) -> usize {
@ -100,7 +105,9 @@ impl ProcessState {
let InputMessage::SetState(status) = state; let InputMessage::SetState(status) = state;
callbacks.set_status(status); callbacks.set_status(status);
} }
Err(error) => tracing::warn!(%error, "process initiator did not send a valid line"), Err(error) => {
tracing::warn!(%error, "process initiator did not send a valid line")
}
} }
} }
} }
@ -202,8 +209,8 @@ impl Future for Process {
impl Initiator for Process { impl Initiator for Process {
fn new(params: &HashMap<String, String>, callbacks: InitiatorCallbacks) -> miette::Result<Self> fn new(params: &HashMap<String, String>, callbacks: InitiatorCallbacks) -> miette::Result<Self>
where where
Self: Sized, Self: Sized,
{ {
let cmd = params let cmd = params
.get("cmd") .get("cmd")

View File

@ -8,6 +8,10 @@
//! This is the capnp component of the FabAccess project. //! This is the capnp component of the FabAccess project.
//! The entry point of bffhd can be found in [bin/bffhd/main.rs](../bin/bffhd/main.rs) //! The entry point of bffhd can be found in [bin/bffhd/main.rs](../bin/bffhd/main.rs)
use miette::Diagnostic;
use std::io;
use thiserror::Error;
pub mod config; pub mod config;
/// Internal Databases build on top of LMDB, a mmap()'ed B-tree DB optimized for reads /// Internal Databases build on top of LMDB, a mmap()'ed B-tree DB optimized for reads
@ -82,10 +86,58 @@ impl error::Description for SignalHandlerErr {
const CODE: &'static str = "signals::new"; const CODE: &'static str = "signals::new";
} }
#[derive(Debug, Error, Diagnostic)]
pub enum BFFHError {
#[error("DB operation failed")]
DBError(
#[from]
#[source]
db::Error,
),
#[error("failed to initialize global user store")]
UsersError(
#[from]
#[source]
users::Error,
),
#[error("failed to initialize state database")]
StateDBError(
#[from]
#[source]
resources::state::db::StateDBError,
),
#[error("audit log failed")]
AuditLogError(
#[from]
#[source]
audit::Error,
),
#[error("Failed to initialize signal handler")]
SignalsError(#[source] std::io::Error),
#[error("error in actor subsystem")]
ActorError(
#[from]
#[source]
actors::ActorError,
),
#[error("failed to initialize TLS config")]
TlsSetup(
#[from]
#[source]
tls::Error,
),
#[error("API handler failed")]
ApiError(
#[from]
#[source]
capnp::Error,
),
}
impl Diflouroborane { impl Diflouroborane {
pub fn setup() {} pub fn setup() {}
pub fn new(config: Config) -> miette::Result<Self> { pub fn new(config: Config) -> Result<Self, BFFHError> {
let mut server = logging::init(&config.logging); let mut server = logging::init(&config.logging);
let span = tracing::info_span!( let span = tracing::info_span!(
target: "bffh", target: "bffh",
@ -121,9 +173,7 @@ impl Diflouroborane {
let users = Users::new(env.clone())?; let users = Users::new(env.clone())?;
let roles = Roles::new(config.roles.clone()); let roles = Roles::new(config.roles.clone());
let _audit_log = AuditLog::new(&config) let _audit_log = AuditLog::new(&config)?;
.into_diagnostic()
.wrap_err("Failed to initialize audit log")?;
let resources = ResourcesHandle::new(config.machines.iter().map(|(id, desc)| { let resources = ResourcesHandle::new(config.machines.iter().map(|(id, desc)| {
Resource::new(Arc::new(resources::Inner::new( Resource::new(Arc::new(resources::Inner::new(
@ -145,10 +195,10 @@ impl Diflouroborane {
}) })
} }
pub fn run(&mut self) -> miette::Result<()> { pub fn run(&mut self) -> Result<(), BFFHError> {
let _guard = self.span.enter(); let _guard = self.span.enter();
let mut signals = signal_hook_async_std::Signals::new(&[SIGINT, SIGQUIT, SIGTERM]) let mut signals = signal_hook_async_std::Signals::new(&[SIGINT, SIGQUIT, SIGTERM])
.map_err(|ioerr| error::wrap::<SignalHandlerErr>(ioerr.into()))?; .map_err(BFFHError::SignalsError)?;
let sessionmanager = SessionManager::new(self.users.clone(), self.roles.clone()); let sessionmanager = SessionManager::new(self.users.clone(), self.roles.clone());
let authentication = AuthenticationHandle::new(self.users.clone()); let authentication = AuthenticationHandle::new(self.users.clone());
@ -162,8 +212,7 @@ impl Diflouroborane {
); );
actors::load(self.executor.clone(), &self.config, self.resources.clone())?; actors::load(self.executor.clone(), &self.config, self.resources.clone())?;
let tlsconfig = TlsConfig::new(self.config.tlskeylog.as_ref(), !self.config.is_quiet()) let tlsconfig = TlsConfig::new(self.config.tlskeylog.as_ref(), !self.config.is_quiet())?;
.into_diagnostic()?;
let acceptor = tlsconfig.make_tls_acceptor(&self.config.tlsconfig)?; let acceptor = tlsconfig.make_tls_acceptor(&self.config.tlsconfig)?;
let apiserver = self.executor.run(APIServer::bind( let apiserver = self.executor.run(APIServer::bind(

View File

@ -17,7 +17,7 @@ pub struct StateDB {
db: DB<AlignedAdapter<State>>, db: DB<AlignedAdapter<State>>,
} }
#[derive(Debug, Error, Diagnostic)] #[derive(Clone, Debug, PartialEq, Eq, Error, Diagnostic)]
pub enum StateDBError { pub enum StateDBError {
#[error("opening the state db environment failed")] #[error("opening the state db environment failed")]
#[diagnostic( #[diagnostic(

View File

@ -1,10 +1,10 @@
use crate::authorization::permissions::Permission; use crate::authorization::permissions::Permission;
use crate::authorization::roles::Roles; use crate::authorization::roles::Roles;
use crate::resources::Resource; use crate::resources::Resource;
use crate::users::db::User;
use crate::users::{db, UserRef}; use crate::users::{db, UserRef};
use crate::Users; use crate::Users;
use tracing::Span; use tracing::Span;
use crate::users::db::User;
#[derive(Clone)] #[derive(Clone)]
pub struct SessionManager { pub struct SessionManager {
@ -18,7 +18,9 @@ impl SessionManager {
} }
pub fn try_open(&self, parent: &Span, uid: impl AsRef<str>) -> Option<SessionHandle> { pub fn try_open(&self, parent: &Span, uid: impl AsRef<str>) -> Option<SessionHandle> {
self.users.get_user(uid.as_ref()).map(|user| self.open(parent, user)) self.users
.get_user(uid.as_ref())
.map(|user| self.open(parent, user))
} }
// TODO: make infallible // TODO: make infallible

View File

@ -1,17 +1,19 @@
use std::fs::File; use std::fs::File;
use std::io; use std::io;
use std::io::BufReader; use std::io::BufReader;
use std::path::Path; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use crate::capnp::TlsListen; use crate::capnp::TlsListen;
use futures_rustls::TlsAcceptor; use futures_rustls::TlsAcceptor;
use miette::IntoDiagnostic; use miette::Diagnostic;
use rustls::version::{TLS12, TLS13}; use rustls::version::{TLS12, TLS13};
use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite}; use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite};
use thiserror::Error;
use tracing::Level; use tracing::Level;
use crate::keylog::KeyLogFile; use crate::keylog::KeyLogFile;
use crate::tls::Error::KeyLogOpen;
fn lookup_cipher_suite(name: &str) -> Option<SupportedCipherSuite> { fn lookup_cipher_suite(name: &str) -> Option<SupportedCipherSuite> {
match name { match name {
@ -47,8 +49,32 @@ pub struct TlsConfig {
keylog: Option<Arc<KeyLogFile>>, keylog: Option<Arc<KeyLogFile>>,
} }
#[derive(Debug, Error, Diagnostic)]
pub enum Error {
#[error("failed to open certificate file at path {0}")]
OpenCertFile(PathBuf, #[source] io::Error),
#[error("failed to open private key file at path {0}")]
OpenKeyFile(PathBuf, #[source] io::Error),
#[error("failed to read system certs")]
SystemCertsFile(#[source] io::Error),
#[error("failed to read from key file")]
ReadKeyFile(#[source] io::Error),
#[error("private key file must contain a single PEM-encoded private key")]
KeyFileFormat,
#[error("invalid TLS version {0}")]
TlsVersion(String),
#[error("Initializing TLS context failed")]
Builder(
#[from]
#[source]
rustls::Error,
),
#[error("failed to initialize key log")]
KeyLogOpen(#[source] io::Error),
}
impl TlsConfig { impl TlsConfig {
pub fn new(keylogfile: Option<impl AsRef<Path>>, warn: bool) -> io::Result<Self> { pub fn new(keylogfile: Option<impl AsRef<Path>>, warn: bool) -> Result<Self, Error> {
let span = tracing::span!(Level::INFO, "tls"); let span = tracing::span!(Level::INFO, "tls");
let _guard = span.enter(); let _guard = span.enter();
@ -57,7 +83,11 @@ impl TlsConfig {
} }
if let Some(path) = keylogfile { if let Some(path) = keylogfile {
let keylog = Some(KeyLogFile::new(path).map(|ok| Arc::new(ok))?); let keylog = Some(
KeyLogFile::new(path)
.map(|ok| Arc::new(ok))
.map_err(KeyLogOpen)?,
);
Ok(Self { keylog }) Ok(Self { keylog })
} else { } else {
Ok(Self { keylog: None }) Ok(Self { keylog: None })
@ -75,27 +105,31 @@ impl TlsConfig {
} }
} }
pub fn make_tls_acceptor(&self, config: &TlsListen) -> miette::Result<TlsAcceptor> { pub fn make_tls_acceptor(&self, config: &TlsListen) -> Result<TlsAcceptor, Error> {
let span = tracing::debug_span!("tls"); let span = tracing::debug_span!("tls");
let _guard = span.enter(); let _guard = span.enter();
tracing::debug!(path = %config.certfile.as_path().display(), "reading certificates"); let path = config.certfile.as_path();
let mut certfp = BufReader::new(File::open(config.certfile.as_path()).into_diagnostic()?); tracing::debug!(path = %path.display(), "reading certificates");
let mut certfp =
BufReader::new(File::open(path).map_err(|e| Error::OpenCertFile(path.into(), e))?);
let certs = rustls_pemfile::certs(&mut certfp) let certs = rustls_pemfile::certs(&mut certfp)
.into_diagnostic()? .map_err(Error::SystemCertsFile)?
.into_iter() .into_iter()
.map(Certificate) .map(Certificate)
.collect(); .collect();
tracing::debug!(path = %config.keyfile.as_path().display(), "reading private key"); let path = config.keyfile.as_path();
let mut keyfp = BufReader::new(File::open(config.keyfile.as_path()).into_diagnostic()?); tracing::debug!(path = %path.display(), "reading private key");
let key = match rustls_pemfile::read_one(&mut keyfp).into_diagnostic()? { let mut keyfp =
BufReader::new(File::open(path).map_err(|err| Error::OpenKeyFile(path.into(), err))?);
let key = match rustls_pemfile::read_one(&mut keyfp).map_err(Error::ReadKeyFile)? {
Some(rustls_pemfile::Item::PKCS8Key(key) | rustls_pemfile::Item::RSAKey(key)) => { Some(rustls_pemfile::Item::PKCS8Key(key) | rustls_pemfile::Item::RSAKey(key)) => {
PrivateKey(key) PrivateKey(key)
} }
_ => { _ => {
tracing::error!("private key file invalid"); tracing::error!("private key file invalid");
miette::bail!("private key file must contain a PEM-encoded private key") return Err(Error::KeyFileFormat);
} }
}; };
@ -104,20 +138,19 @@ impl TlsConfig {
.with_safe_default_kx_groups(); .with_safe_default_kx_groups();
let tls_builder = if let Some(ref min) = config.tls_min_version { let tls_builder = if let Some(ref min) = config.tls_min_version {
match min.as_str() { let v = min.to_lowercase();
match v.as_str() {
"tls12" => tls_builder.with_protocol_versions(&[&TLS12]), "tls12" => tls_builder.with_protocol_versions(&[&TLS12]),
"tls13" => tls_builder.with_protocol_versions(&[&TLS13]), "tls13" => tls_builder.with_protocol_versions(&[&TLS13]),
x => miette::bail!("TLS version {} is invalid", x), _ => return Err(Error::TlsVersion(v)),
} }
} else { } else {
tls_builder.with_safe_default_protocol_versions() tls_builder.with_safe_default_protocol_versions()
} }?;
.into_diagnostic()?;
let mut tls_config = tls_builder let mut tls_config = tls_builder
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(certs, key) .with_single_cert(certs, key)?;
.into_diagnostic()?;
if let Some(keylog) = &self.keylog { if let Some(keylog) = &self.keylog {
tls_config.key_log = keylog.clone(); tls_config.key_log = keylog.clone();

View File

@ -11,6 +11,8 @@ use rkyv::ser::serializers::AllocSerializer;
use rkyv::ser::Serializer; use rkyv::ser::Serializer;
use rkyv::Deserialize; use rkyv::Deserialize;
pub use crate::db::Error;
#[derive( #[derive(
Clone, Clone,
PartialEq, PartialEq,

View File

@ -11,6 +11,7 @@ use clap::ArgMatches;
use miette::{Context, Diagnostic, IntoDiagnostic, SourceOffset, SourceSpan}; use miette::{Context, Diagnostic, IntoDiagnostic, SourceOffset, SourceSpan};
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use thiserror::Error; use thiserror::Error;
pub mod db; pub mod db;
@ -69,17 +70,20 @@ pub struct Users {
userdb: &'static UserDB, userdb: &'static UserDB,
} }
#[derive(Clone, Debug, PartialEq, Eq, Error, Diagnostic)]
#[error(transparent)]
#[repr(transparent)]
pub struct Error(#[from] pub db::Error);
impl Users { impl Users {
pub fn new(env: Arc<Environment>) -> miette::Result<Self> { pub fn new(env: Arc<Environment>) -> Result<Self, Error> {
let span = tracing::debug_span!("users", ?env, "Creating Users handle"); let span = tracing::debug_span!("users", ?env, "Creating Users handle");
let _guard = span.enter(); let _guard = span.enter();
let userdb = USERDB let userdb = USERDB.get_or_try_init(|| {
.get_or_try_init(|| { tracing::debug!("Global resource not yet initialized, initializing…");
tracing::debug!("Global resource not yet initialized, initializing…"); unsafe { UserDB::create(env) }
unsafe { UserDB::create(env) } })?;
})
.wrap_err("Failed to open userdb")?;
Ok(Self { userdb }) Ok(Self { userdb })
} }