Run rustfmt

This commit is contained in:
Nadja Reitzenstein 2022-05-05 15:50:44 +02:00
parent 475cb9b9b4
commit 2d9f30b55b
74 changed files with 1243 additions and 1053 deletions

View File

@ -1,7 +1,8 @@
use walkdir::{WalkDir, DirEntry}; use walkdir::{DirEntry, WalkDir};
fn is_hidden(entry: &DirEntry) -> bool { fn is_hidden(entry: &DirEntry) -> bool {
entry.file_name() entry
.file_name()
.to_str() .to_str()
.map(|s| s.starts_with('.')) .map(|s| s.starts_with('.'))
.unwrap_or(false) .unwrap_or(false)
@ -22,15 +23,15 @@ fn main() {
.filter_map(Result::ok) // Filter all entries that access failed on .filter_map(Result::ok) // Filter all entries that access failed on
.filter(|e| !e.file_type().is_dir()) // Filter directories .filter(|e| !e.file_type().is_dir()) // Filter directories
// Filter non-schema files // Filter non-schema files
.filter(|e| e.file_name() .filter(|e| {
e.file_name()
.to_str() .to_str()
.map(|s| s.ends_with(".capnp")) .map(|s| s.ends_with(".capnp"))
.unwrap_or(false) .unwrap_or(false)
) })
{ {
println!("Collecting schema file {}", entry.path().display()); println!("Collecting schema file {}", entry.path().display());
compile_command compile_command.file(entry.path());
.file(entry.path());
} }
println!("Compiling schemas..."); println!("Compiling schemas...");
@ -53,16 +54,18 @@ fn main() {
.filter_map(Result::ok) // Filter all entries that access failed on .filter_map(Result::ok) // Filter all entries that access failed on
.filter(|e| !e.file_type().is_dir()) // Filter directories .filter(|e| !e.file_type().is_dir()) // Filter directories
// Filter non-schema files // Filter non-schema files
.filter(|e| e.file_name() .filter(|e| {
e.file_name()
.to_str() .to_str()
.map(|s| s.ends_with(".capnp")) .map(|s| s.ends_with(".capnp"))
.unwrap_or(false) .unwrap_or(false)
) })
{ {
println!("Collecting schema file {}", entry.path().display()); println!("Collecting schema file {}", entry.path().display());
compile_command compile_command.file(entry.path());
.file(entry.path());
} }
compile_command.run().expect("Failed to generate extra API code"); compile_command
.run()
.expect("Failed to generate extra API code");
} }

View File

@ -1,4 +1,3 @@
//! FabAccess generated API bindings //! FabAccess generated API bindings
//! //!
//! This crate contains slightly nicer and better documented bindings for the FabAccess API. //! This crate contains slightly nicer and better documented bindings for the FabAccess API.

View File

@ -1,6 +1,5 @@
pub use capnpc::schema_capnp; pub use capnpc::schema_capnp;
#[cfg(feature = "generated")] #[cfg(feature = "generated")]
pub mod authenticationsystem_capnp { pub mod authenticationsystem_capnp {
include!(concat!(env!("OUT_DIR"), "/authenticationsystem_capnp.rs")); include!(concat!(env!("OUT_DIR"), "/authenticationsystem_capnp.rs"));

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use futures_util::future; use futures_util::future;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use std::collections::HashMap;
use crate::actors::Actor; use crate::actors::Actor;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;

View File

@ -3,7 +3,7 @@ use crate::resources::state::State;
use crate::{Config, ResourcesHandle}; use crate::{Config, ResourcesHandle};
use async_compat::CompatExt; use async_compat::CompatExt;
use executor::pool::Executor; use executor::pool::Executor;
use futures_signals::signal::{Signal}; use futures_signals::signal::Signal;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use rumqttc::{AsyncClient, ConnectionError, Event, Incoming, MqttOptions}; use rumqttc::{AsyncClient, ConnectionError, Event, Incoming, MqttOptions};
@ -18,15 +18,15 @@ use std::time::Duration;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rumqttc::ConnectReturnCode::Success; use rumqttc::ConnectReturnCode::Success;
use rustls::{RootCertStore};
use url::Url;
use crate::actors::dummy::Dummy; use crate::actors::dummy::Dummy;
use crate::actors::process::Process; use crate::actors::process::Process;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;
use rustls::RootCertStore;
use url::Url;
mod shelly;
mod process;
mod dummy; mod dummy;
mod process;
mod shelly;
pub trait Actor { pub trait Actor {
fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()>; fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()>;
@ -102,7 +102,7 @@ static ROOT_CERTS: Lazy<RootCertStore> = Lazy::new(|| {
} else { } else {
tracing::info!(loaded, "certificates loaded"); tracing::info!(loaded, "certificates loaded");
} }
}, }
Err(error) => { Err(error) => {
tracing::error!(%error, "failed to load system certificates"); tracing::error!(%error, "failed to load system certificates");
} }
@ -219,8 +219,10 @@ pub fn load(executor: Executor, config: &Config, resources: ResourcesHandle) ->
.compat(), .compat(),
); );
let mut actor_map: HashMap<String, _> = config.actor_connections.iter() let mut actor_map: HashMap<String, _> = config
.filter_map(|(k,v)| { .actor_connections
.iter()
.filter_map(|(k, v)| {
if let Some(resource) = resources.get_by_id(v) { if let Some(resource) = resources.get_by_id(v) {
Some((k.clone(), resource.get_signal())) Some((k.clone(), resource.get_signal()))
} else { } else {
@ -258,8 +260,6 @@ fn load_single(
"Dummy" => Some(Box::new(Dummy::new(name.clone(), params.clone()))), "Dummy" => Some(Box::new(Dummy::new(name.clone(), params.clone()))),
"Process" => Process::new(name.clone(), params).map(|a| a.into_boxed_actuator()), "Process" => Process::new(name.clone(), params).map(|a| a.into_boxed_actuator()),
"Shelly" => Some(Box::new(Shelly::new(name.clone(), client, params))), "Shelly" => Some(Box::new(Shelly::new(name.clone(), client, params))),
_ => { _ => None,
None
}
} }
} }

View File

@ -1,6 +1,6 @@
use futures_util::future::BoxFuture;
use std::collections::HashMap; use std::collections::HashMap;
use std::process::{Command, Stdio}; use std::process::{Command, Stdio};
use futures_util::future::BoxFuture;
use crate::actors::Actor; use crate::actors::Actor;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;
@ -16,10 +16,9 @@ pub struct Process {
impl Process { impl Process {
pub fn new(name: String, params: &HashMap<String, String>) -> Option<Self> { pub fn new(name: String, params: &HashMap<String, String>) -> Option<Self> {
let cmd = params.get("cmd").map(|s| s.to_string())?; let cmd = params.get("cmd").map(|s| s.to_string())?;
let args = params.get("args").map(|argv| let args = params
argv.split_whitespace() .get("args")
.map(|s| s.to_string()) .map(|argv| argv.split_whitespace().map(|s| s.to_string()).collect())
.collect())
.unwrap_or_else(Vec::new); .unwrap_or_else(Vec::new);
Some(Self { name, cmd, args }) Some(Self { name, cmd, args })
@ -48,22 +47,22 @@ impl Actor for Process {
command.arg("inuse").arg(by.id.as_str()); command.arg("inuse").arg(by.id.as_str());
} }
ArchivedStatus::ToCheck(by) => { ArchivedStatus::ToCheck(by) => {
command.arg("tocheck") command.arg("tocheck").arg(by.id.as_str());
.arg(by.id.as_str());
} }
ArchivedStatus::Blocked(by) => { ArchivedStatus::Blocked(by) => {
command.arg("blocked") command.arg("blocked").arg(by.id.as_str());
.arg(by.id.as_str()); }
ArchivedStatus::Disabled => {
command.arg("disabled");
} }
ArchivedStatus::Disabled => { command.arg("disabled"); },
ArchivedStatus::Reserved(by) => { ArchivedStatus::Reserved(by) => {
command.arg("reserved") command.arg("reserved").arg(by.id.as_str());
.arg(by.id.as_str());
} }
} }
let name = self.name.clone(); let name = self.name.clone();
Box::pin(async move { match command.output() { Box::pin(async move {
match command.output() {
Ok(retv) if retv.status.success() => { Ok(retv) if retv.status.success() => {
tracing::trace!("Actor was successful"); tracing::trace!("Actor was successful");
let outstr = String::from_utf8_lossy(&retv.stdout); let outstr = String::from_utf8_lossy(&retv.stdout);
@ -83,6 +82,7 @@ impl Actor for Process {
} }
} }
Err(error) => tracing::warn!(%name, ?error, "process actor failed to run cmd"), Err(error) => tracing::warn!(%name, ?error, "process actor failed to run cmd"),
}}) }
})
} }
} }

View File

@ -1,11 +1,11 @@
use std::collections::HashMap;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use std::collections::HashMap;
use rumqttc::{AsyncClient, QoS};
use crate::actors::Actor; use crate::actors::Actor;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;
use crate::resources::modules::fabaccess::ArchivedStatus; use crate::resources::modules::fabaccess::ArchivedStatus;
use crate::resources::state::State; use crate::resources::state::State;
use rumqttc::{AsyncClient, QoS};
/// An actuator for a Shellie connected listening on one MQTT broker /// An actuator for a Shellie connected listening on one MQTT broker
/// ///
@ -28,7 +28,11 @@ impl Shelly {
tracing::debug!(%name,%topic,"Starting shelly module"); tracing::debug!(%name,%topic,"Starting shelly module");
Shelly { name, client, topic, } Shelly {
name,
client,
topic,
}
} }
/// Set the name to a new one. This changes the shelly that will be activated /// Set the name to a new one. This changes the shelly that will be activated
@ -38,7 +42,6 @@ impl Shelly {
} }
} }
impl Actor for Shelly { impl Actor for Shelly {
fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()> { fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()> {
tracing::debug!(?state, name=%self.name, tracing::debug!(?state, name=%self.name,

View File

@ -1,11 +1,11 @@
use once_cell::sync::OnceCell;
use std::fs::{File, OpenOptions}; use std::fs::{File, OpenOptions};
use std::io; use std::io;
use std::io::{LineWriter, Write}; use std::io::{LineWriter, Write};
use std::sync::Mutex; use std::sync::Mutex;
use once_cell::sync::OnceCell;
use crate::Config; use crate::Config;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use serde_json::Serializer; use serde_json::Serializer;
pub static AUDIT: OnceCell<AuditLog> = OnceCell::new(); pub static AUDIT: OnceCell<AuditLog> = OnceCell::new();
@ -26,7 +26,10 @@ impl AuditLog {
pub fn new(config: &Config) -> io::Result<&'static Self> { pub fn new(config: &Config) -> io::Result<&'static Self> {
AUDIT.get_or_try_init(|| { AUDIT.get_or_try_init(|| {
tracing::debug!(path = %config.auditlog_path.display(), "Initializing audit log"); tracing::debug!(path = %config.auditlog_path.display(), "Initializing audit log");
let fd = OpenOptions::new().create(true).append(true).open(&config.auditlog_path)?; let fd = OpenOptions::new()
.create(true)
.append(true)
.open(&config.auditlog_path)?;
let writer = Mutex::new(LineWriter::new(fd)); let writer = Mutex::new(LineWriter::new(fd));
Ok(Self { writer }) Ok(Self { writer })
}) })
@ -34,7 +37,11 @@ impl AuditLog {
pub fn log(&self, machine: &str, state: &str) -> io::Result<()> { pub fn log(&self, machine: &str, state: &str) -> io::Result<()> {
let timestamp = chrono::Utc::now().timestamp(); let timestamp = chrono::Utc::now().timestamp();
let line = AuditLogLine { timestamp, machine, state }; let line = AuditLogLine {
timestamp,
machine,
state,
};
tracing::debug!(?line, "writing audit log line"); tracing::debug!(?line, "writing audit log line");
@ -42,7 +49,8 @@ impl AuditLog {
let mut writer: &mut LineWriter<File> = &mut *guard; let mut writer: &mut LineWriter<File> = &mut *guard;
let mut ser = Serializer::new(&mut writer); let mut ser = Serializer::new(&mut writer);
line.serialize(&mut ser).expect("failed to serialize audit log line"); line.serialize(&mut ser)
.expect("failed to serialize audit log line");
writer.write("\n".as_bytes())?; writer.write("\n".as_bytes())?;
Ok(()) Ok(())
} }

View File

@ -19,8 +19,8 @@ pub static FABFIRE: Mechanism = Mechanism {
first: Side::Client, first: Side::Client,
}; };
use rsasl::property::{Property, PropertyDefinition, PropertyQ};
use std::marker::PhantomData; use std::marker::PhantomData;
use rsasl::property::{Property, PropertyQ, PropertyDefinition};
// All Property types must implement Debug. // All Property types must implement Debug.
#[derive(Debug)] #[derive(Debug)]
// The `PhantomData` in the constructor is only used so external crates can't construct this type. // The `PhantomData` in the constructor is only used so external crates can't construct this type.

View File

@ -1,17 +1,17 @@
use std::fmt::{Debug, Display, Formatter}; use desfire::desfire::desfire::MAX_BYTES_PER_TRANSACTION;
use std::io::Write; use desfire::desfire::Desfire;
use desfire::error::Error as DesfireError;
use desfire::iso7816_4::apduresponse::APDUResponse;
use rsasl::error::{MechanismError, MechanismErrorKind, SASLError, SessionError}; use rsasl::error::{MechanismError, MechanismErrorKind, SASLError, SessionError};
use rsasl::mechanism::Authentication; use rsasl::mechanism::Authentication;
use rsasl::SASL;
use rsasl::session::{SessionData, StepResult};
use serde::{Deserialize, Serialize};
use desfire::desfire::Desfire;
use desfire::iso7816_4::apduresponse::APDUResponse;
use desfire::error::{Error as DesfireError};
use std::convert::TryFrom;
use std::sync::Arc;
use desfire::desfire::desfire::MAX_BYTES_PER_TRANSACTION;
use rsasl::property::AuthId; use rsasl::property::AuthId;
use rsasl::session::{SessionData, StepResult};
use rsasl::SASL;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::fmt::{Debug, Display, Formatter};
use std::io::Write;
use std::sync::Arc;
use crate::authentication::fabfire::FabFireCardKey; use crate::authentication::fabfire::FabFireCardKey;
@ -37,7 +37,9 @@ impl Debug for FabFireError {
FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic), FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic),
FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token), FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token),
FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn), FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn),
FabFireError::InvalidCredentials(credentials) => write!(f, "InvalidCredentials: {}", credentials), FabFireError::InvalidCredentials(credentials) => {
write!(f, "InvalidCredentials: {}", credentials)
}
FabFireError::Session(err) => write!(f, "Session: {}", err), FabFireError::Session(err) => write!(f, "Session: {}", err),
} }
} }
@ -53,7 +55,9 @@ impl Display for FabFireError {
FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic), FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic),
FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token), FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token),
FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn), FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn),
FabFireError::InvalidCredentials(credentials) => write!(f, "InvalidCredentials: {}", credentials), FabFireError::InvalidCredentials(credentials) => {
write!(f, "InvalidCredentials: {}", credentials)
}
FabFireError::Session(err) => write!(f, "Session: {}", err), FabFireError::Session(err) => write!(f, "Session: {}", err),
} }
} }
@ -107,16 +111,22 @@ enum CardCommand {
addn_txt: Option<String>, addn_txt: Option<String>,
}, },
sendPICC { sendPICC {
#[serde(deserialize_with = "hex::deserialize", serialize_with = "hex::serialize_upper")] #[serde(
data: Vec<u8> deserialize_with = "hex::deserialize",
serialize_with = "hex::serialize_upper"
)]
data: Vec<u8>,
}, },
readPICC { readPICC {
#[serde(deserialize_with = "hex::deserialize", serialize_with = "hex::serialize_upper")] #[serde(
data: Vec<u8> deserialize_with = "hex::deserialize",
serialize_with = "hex::serialize_upper"
)]
data: Vec<u8>,
}, },
haltPICC, haltPICC,
Key { Key {
data: String data: String,
}, },
ConfirmUser, ConfirmUser,
} }
@ -145,18 +155,35 @@ const MAGIC: &'static str = "FABACCESS\0DESFIRE\01.0\0";
impl FabFire { impl FabFire {
pub fn new_server(_sasl: &SASL) -> Result<Box<dyn Authentication>, SASLError> { pub fn new_server(_sasl: &SASL) -> Result<Box<dyn Authentication>, SASLError> {
Ok(Box::new(Self { step: Step::New, card_info: None, key_info: None, auth_info: None, app_id: 1, local_urn: "urn:fabaccess:lab:innovisionlab".to_string(), desfire: Desfire { card: None, session_key: None, cbc_iv: None } })) Ok(Box::new(Self {
step: Step::New,
card_info: None,
key_info: None,
auth_info: None,
app_id: 1,
local_urn: "urn:fabaccess:lab:innovisionlab".to_string(),
desfire: Desfire {
card: None,
session_key: None,
cbc_iv: None,
},
}))
} }
} }
impl Authentication for FabFire { impl Authentication for FabFire {
fn step(&mut self, session: &mut SessionData, input: Option<&[u8]>, writer: &mut dyn Write) -> StepResult { fn step(
&mut self,
session: &mut SessionData,
input: Option<&[u8]>,
writer: &mut dyn Write,
) -> StepResult {
match self.step { match self.step {
Step::New => { Step::New => {
tracing::trace!("Step: New"); tracing::trace!("Step: New");
//receive card info (especially card UID) from reader //receive card info (especially card UID) from reader
return match input { return match input {
None => { Err(SessionError::InputDataRequired) } None => Err(SessionError::InputDataRequired),
Some(cardinfo) => { Some(cardinfo) => {
self.card_info = match serde_json::from_slice(cardinfo) { self.card_info = match serde_json::from_slice(cardinfo) {
Ok(card_info) => Some(card_info), Ok(card_info) => Some(card_info),
@ -170,7 +197,10 @@ impl Authentication for FabFire {
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
tracing::error!("Failed to convert APDUCommand to Vec<u8>: {:?}", e); tracing::error!(
"Failed to convert APDUCommand to Vec<u8>: {:?}",
e
);
return Err(FabFireError::SerializationError.into()); return Err(FabFireError::SerializationError.into());
} }
}, },
@ -183,7 +213,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::SelectApp; self.step = Step::SelectApp;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -198,30 +230,39 @@ impl Authentication for FabFire {
tracing::trace!("Step: SelectApp"); tracing::trace!("Step: SelectApp");
// check that we successfully selected the application // check that we successfully selected the application
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e); tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
apdu_response.check().map_err(|e| FabFireError::CardError(e))?; apdu_response
.check()
.map_err(|e| FabFireError::CardError(e))?;
// request the contents of the file containing the magic string // request the contents of the file containing the magic string
const MAGIC_FILE_ID: u8 = 0x01; const MAGIC_FILE_ID: u8 = 0x01;
let buf = match self.desfire.read_data_chunk_cmd(MAGIC_FILE_ID, 0, MAGIC.len()) { let buf = match self
.desfire
.read_data_chunk_cmd(MAGIC_FILE_ID, 0, MAGIC.len())
{
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -238,7 +279,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::VerifyMagic; self.step = Step::VerifyMagic;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -251,25 +294,28 @@ impl Authentication for FabFire {
tracing::trace!("Step: VerifyMagic"); tracing::trace!("Step: VerifyMagic");
// verify the magic string to determine that we have a valid fabfire card // verify the magic string to determine that we have a valid fabfire card
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e); tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
@ -291,11 +337,15 @@ impl Authentication for FabFire {
} }
} }
// request the contents of the file containing the URN // request the contents of the file containing the URN
const URN_FILE_ID: u8 = 0x02; const URN_FILE_ID: u8 = 0x02;
let buf = match self.desfire.read_data_chunk_cmd(URN_FILE_ID, 0, self.local_urn.as_bytes().len()) { // TODO: support urn longer than 47 Bytes let buf = match self.desfire.read_data_chunk_cmd(
URN_FILE_ID,
0,
self.local_urn.as_bytes().len(),
) {
// TODO: support urn longer than 47 Bytes
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -312,7 +362,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::GetURN; self.step = Step::GetURN;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -325,32 +377,39 @@ impl Authentication for FabFire {
tracing::trace!("Step: GetURN"); tracing::trace!("Step: GetURN");
// parse the urn and match it to our local urn // parse the urn and match it to our local urn
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e); tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
Some(data) => { Some(data) => {
let received_urn = String::from_utf8(data).unwrap(); let received_urn = String::from_utf8(data).unwrap();
if received_urn != self.local_urn { if received_urn != self.local_urn {
tracing::error!("URN mismatch: {:?} != {:?}", received_urn, self.local_urn); tracing::error!(
"URN mismatch: {:?} != {:?}",
received_urn,
self.local_urn
);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
} }
@ -368,7 +427,12 @@ impl Authentication for FabFire {
// request the contents of the file containing the URN // request the contents of the file containing the URN
const TOKEN_FILE_ID: u8 = 0x03; const TOKEN_FILE_ID: u8 = 0x03;
let buf = match self.desfire.read_data_chunk_cmd(TOKEN_FILE_ID, 0, MAX_BYTES_PER_TRANSACTION) { // TODO: support data longer than 47 Bytes let buf = match self.desfire.read_data_chunk_cmd(
TOKEN_FILE_ID,
0,
MAX_BYTES_PER_TRANSACTION,
) {
// TODO: support data longer than 47 Bytes
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -385,7 +449,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::GetToken; self.step = Step::GetToken;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -398,43 +464,52 @@ impl Authentication for FabFire {
// println!("Step: GetToken"); // println!("Step: GetToken");
// parse the token and select the appropriate user // parse the token and select the appropriate user
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e); tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
Some(data) => { Some(data) => {
let token = String::from_utf8(data).unwrap(); let token = String::from_utf8(data).unwrap();
session.set_property::<AuthId>(Arc::new(token.trim_matches(char::from(0)).to_string())); session.set_property::<AuthId>(Arc::new(
let key = match session.get_property_or_callback::<FabFireCardKey>() { token.trim_matches(char::from(0)).to_string(),
));
let key = match session.get_property_or_callback::<FabFireCardKey>()
{
Ok(Some(key)) => Box::from(key.as_slice()), Ok(Some(key)) => Box::from(key.as_slice()),
Ok(None) => { Ok(None) => {
tracing::error!("No keys on file for token"); tracing::error!("No keys on file for token");
return Err(FabFireError::InvalidCredentials("No keys on file for token".to_string()).into()); return Err(FabFireError::InvalidCredentials(
"No keys on file for token".to_string(),
)
.into());
} }
Err(e) => { Err(e) => {
tracing::error!("Failed to get key: {:?}", e); tracing::error!("Failed to get key: {:?}", e);
return Err(FabFireError::Session(e).into()); return Err(FabFireError::Session(e).into());
} }
}; };
self.key_info = Some(KeyInfo{ key_id: 0x01, key }); self.key_info = Some(KeyInfo { key_id: 0x01, key });
} }
None => { None => {
tracing::error!("No data in response"); tracing::error!("No data in response");
@ -448,7 +523,10 @@ impl Authentication for FabFire {
} }
} }
let buf = match self.desfire.authenticate_iso_aes_challenge_cmd(self.key_info.as_ref().unwrap().key_id) { let buf = match self
.desfire
.authenticate_iso_aes_challenge_cmd(self.key_info.as_ref().unwrap().key_id)
{
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -465,7 +543,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::Authenticate1; self.step = Step::Authenticate1;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -477,25 +557,28 @@ impl Authentication for FabFire {
Step::Authenticate1 => { Step::Authenticate1 => {
tracing::trace!("Step: Authenticate1"); tracing::trace!("Step: Authenticate1");
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Failed to deserialize response: {:?}", e); tracing::error!("Failed to deserialize response: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
@ -506,13 +589,19 @@ impl Authentication for FabFire {
//TODO: Check if we need a CSPRNG here //TODO: Check if we need a CSPRNG here
let rnd_a: [u8; 16] = rand::random(); let rnd_a: [u8; 16] = rand::random();
let (cmd_challenge_response, let (cmd_challenge_response, rnd_b, iv) = self
rnd_b, .desfire
iv) = self.desfire.authenticate_iso_aes_response_cmd( .authenticate_iso_aes_response_cmd(
rnd_b_enc, rnd_b_enc,
&*(self.key_info.as_ref().unwrap().key), &*(self.key_info.as_ref().unwrap().key),
&rnd_a).unwrap(); &rnd_a,
self.auth_info = Some(AuthInfo { rnd_a: Vec::<u8>::from(rnd_a), rnd_b, iv }); )
.unwrap();
self.auth_info = Some(AuthInfo {
rnd_a: Vec::<u8>::from(rnd_a),
rnd_b,
iv,
});
let buf = match Vec::<u8>::try_from(cmd_challenge_response) { let buf = match Vec::<u8>::try_from(cmd_challenge_response) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -524,7 +613,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::Authenticate2; self.step = Step::Authenticate2;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -548,39 +639,48 @@ impl Authentication for FabFire {
Step::Authenticate2 => { Step::Authenticate2 => {
// println!("Step: Authenticate2"); // println!("Step: Authenticate2");
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Failed to deserialize response: {:?}", e); tracing::error!("Failed to deserialize response: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Got invalid response: {:?}", response); tracing::error!("Got invalid response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
Some(data) => { Some(data) => match self.auth_info.as_ref() {
match self.auth_info.as_ref() { None => {
None => { return Err(FabFireError::ParseError.into()); } return Err(FabFireError::ParseError.into());
}
Some(auth_info) => { Some(auth_info) => {
if self.desfire.authenticate_iso_aes_verify( if self
.desfire
.authenticate_iso_aes_verify(
data.as_slice(), data.as_slice(),
auth_info.rnd_a.as_slice(), auth_info.rnd_a.as_slice(),
auth_info.rnd_b.as_slice(), &*(self.key_info.as_ref().unwrap().key), auth_info.rnd_b.as_slice(),
auth_info.iv.as_slice()).is_ok() { &*(self.key_info.as_ref().unwrap().key),
auth_info.iv.as_slice(),
let cmd = CardCommand::message{ )
.is_ok()
{
let cmd = CardCommand::message {
msg_id: Some(4), msg_id: Some(4),
clr_txt: None, clr_txt: None,
addn_txt: Some("".to_string()), addn_txt: Some("".to_string()),
@ -588,18 +688,24 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::Authenticate1; self.step = Step::Authenticate1;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
return Ok(rsasl::session::Step::Done(Some(send_buf.len()))) .write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
return Ok(rsasl::session::Step::Done(Some(
send_buf.len(),
)));
} }
Err(e) => { Err(e) => {
tracing::error!("Failed to serialize command: {:?}", e); tracing::error!(
"Failed to serialize command: {:?}",
e
);
Err(FabFireError::SerializationError.into()) Err(FabFireError::SerializationError.into())
} }
}; };
} }
} }
} },
}
None => { None => {
tracing::error!("got empty response"); tracing::error!("got empty response");
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
@ -608,7 +714,9 @@ impl Authentication for FabFire {
} }
Err(_e) => { Err(_e) => {
tracing::error!("Got invalid response: {:?}", apdu_response); tracing::error!("Got invalid response: {:?}", apdu_response);
return Err(FabFireError::InvalidCredentials(format!("{}", apdu_response)).into()); return Err(
FabFireError::InvalidCredentials(format!("{}", apdu_response)).into(),
);
} }
} }
} }

View File

@ -1,6 +1,5 @@
use crate::users::Users; use crate::users::Users;
use rsasl::error::{SessionError}; use rsasl::error::SessionError;
use rsasl::mechname::Mechname; use rsasl::mechname::Mechname;
use rsasl::property::{AuthId, Password}; use rsasl::property::{AuthId, Password};
use rsasl::session::{Session, SessionData}; use rsasl::session::{Session, SessionData};
@ -31,12 +30,18 @@ impl rsasl::callback::Callback for Callback {
match property { match property {
fabfire::FABFIRECARDKEY => { fabfire::FABFIRECARDKEY => {
let authcid = session.get_property_or_callback::<AuthId>()?; let authcid = session.get_property_or_callback::<AuthId>()?;
let user = self.users.get_user(authcid.unwrap().as_ref()) let user = self
.users
.get_user(authcid.unwrap().as_ref())
.ok_or(SessionError::AuthenticationFailure)?; .ok_or(SessionError::AuthenticationFailure)?;
let kv = user.userdata.kv.get("cardkey") let kv = user
.userdata
.kv
.get("cardkey")
.ok_or(SessionError::AuthenticationFailure)?; .ok_or(SessionError::AuthenticationFailure)?;
let card_key = <[u8; 16]>::try_from(hex::decode(kv) let card_key = <[u8; 16]>::try_from(
.map_err(|_| SessionError::AuthenticationFailure)?) hex::decode(kv).map_err(|_| SessionError::AuthenticationFailure)?,
)
.map_err(|_| SessionError::AuthenticationFailure)?; .map_err(|_| SessionError::AuthenticationFailure)?;
session.set_property::<FabFireCardKey>(Arc::new(card_key)); session.set_property::<FabFireCardKey>(Arc::new(card_key));
Ok(()) Ok(())
@ -60,9 +65,7 @@ impl rsasl::callback::Callback for Callback {
.ok_or(SessionError::no_property::<AuthId>())?; .ok_or(SessionError::no_property::<AuthId>())?;
tracing::debug!(authid=%authnid, "SIMPLE validation requested"); tracing::debug!(authid=%authnid, "SIMPLE validation requested");
if let Some(user) = self if let Some(user) = self.users.get_user(authnid.as_str()) {
.users
.get_user(authnid.as_str()) {
let passwd = session let passwd = session
.get_property::<Password>() .get_property::<Password>()
.ok_or(SessionError::no_property::<Password>())?; .ok_or(SessionError::no_property::<Password>())?;
@ -84,7 +87,7 @@ impl rsasl::callback::Callback for Callback {
_ => { _ => {
tracing::error!(?validation, "Unimplemented validation requested"); tracing::error!(?validation, "Unimplemented validation requested");
Err(SessionError::no_validate(validation)) Err(SessionError::no_validate(validation))
}, }
} }
} }
} }
@ -111,10 +114,12 @@ impl AuthenticationHandle {
let mut rsasl = SASL::new(); let mut rsasl = SASL::new();
rsasl.install_callback(Arc::new(Callback::new(userdb))); rsasl.install_callback(Arc::new(Callback::new(userdb)));
let mechs: Vec<&'static str> = rsasl.server_mech_list().into_iter() let mechs: Vec<&'static str> = rsasl
.server_mech_list()
.into_iter()
.map(|m| m.mechanism.as_str()) .map(|m| m.mechanism.as_str())
.collect(); .collect();
tracing::info!(available_mechs=mechs.len(), "initialized sasl backend"); tracing::info!(available_mechs = mechs.len(), "initialized sasl backend");
tracing::debug!(?mechs, "available mechs"); tracing::debug!(?mechs, "available mechs");
Self { Self {

View File

@ -1,9 +1,6 @@
use crate::authorization::roles::Roles;
use crate::authorization::roles::{Roles};
use crate::Users; use crate::Users;
pub mod permissions; pub mod permissions;
pub mod roles; pub mod roles;

View File

@ -1,10 +1,9 @@
//! Access control logic //! Access control logic
//! //!
use std::fmt;
use std::cmp::Ordering; use std::cmp::Ordering;
use std::convert::{TryFrom, Into}; use std::convert::{Into, TryFrom};
use std::fmt;
fn is_sep_char(c: char) -> bool { fn is_sep_char(c: char) -> bool {
c == '.' c == '.'
@ -20,7 +19,7 @@ pub struct PrivilegesBuf {
/// Which permission is required to write parts of this thing /// Which permission is required to write parts of this thing
pub write: PermissionBuf, pub write: PermissionBuf,
/// Which permission is required to manage all parts of this thing /// Which permission is required to manage all parts of this thing
pub manage: PermissionBuf pub manage: PermissionBuf,
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
@ -39,13 +38,17 @@ impl PermissionBuf {
#[inline(always)] #[inline(always)]
/// Allocate an empty `PermissionBuf` /// Allocate an empty `PermissionBuf`
pub fn new() -> Self { pub fn new() -> Self {
PermissionBuf { inner: String::new() } PermissionBuf {
inner: String::new(),
}
} }
#[inline(always)] #[inline(always)]
/// Allocate a `PermissionBuf` with the given capacity given to the internal [`String`] /// Allocate a `PermissionBuf` with the given capacity given to the internal [`String`]
pub fn with_capacity(cap: usize) -> Self { pub fn with_capacity(cap: usize) -> Self {
PermissionBuf { inner: String::with_capacity(cap) } PermissionBuf {
inner: String::with_capacity(cap),
}
} }
#[inline(always)] #[inline(always)]
@ -59,7 +62,13 @@ impl PermissionBuf {
pub fn _push(&mut self, perm: &Permission) { pub fn _push(&mut self, perm: &Permission) {
// in general we always need a separator unless the last byte is one or the string is empty // in general we always need a separator unless the last byte is one or the string is empty
let need_sep = self.inner.chars().rev().next().map(|c| !is_sep_char(c)).unwrap_or(false); let need_sep = self
.inner
.chars()
.rev()
.next()
.map(|c| !is_sep_char(c))
.unwrap_or(false);
if need_sep { if need_sep {
self.inner.push('.') self.inner.push('.')
} }
@ -73,7 +82,9 @@ impl PermissionBuf {
#[inline] #[inline]
pub fn from_perm(perm: &Permission) -> Self { pub fn from_perm(perm: &Permission) -> Self {
Self { inner: perm.as_str().to_string() } Self {
inner: perm.as_str().to_string(),
}
} }
#[inline(always)] #[inline(always)]
@ -162,12 +173,14 @@ impl PartialOrd for Permission {
} }
} }
match (l,r) { match (l, r) {
(None, None) => Some(Ordering::Equal), (None, None) => Some(Ordering::Equal),
(Some(_), None) => Some(Ordering::Less), (Some(_), None) => Some(Ordering::Less),
(None, Some(_)) => Some(Ordering::Greater), (None, Some(_)) => Some(Ordering::Greater),
(Some(_), Some(_)) => unreachable!("Broken contract in Permission::partial_cmp: sides \ (Some(_), Some(_)) => unreachable!(
should never be both Some!"), "Broken contract in Permission::partial_cmp: sides \
should never be both Some!"
),
} }
} }
} }
@ -208,7 +221,7 @@ impl PermRule {
pub fn match_perm<P: AsRef<Permission> + ?Sized>(&self, perm: &P) -> bool { pub fn match_perm<P: AsRef<Permission> + ?Sized>(&self, perm: &P) -> bool {
match self { match self {
PermRule::Base(ref base) => base.as_permission() == perm.as_ref(), PermRule::Base(ref base) => base.as_permission() == perm.as_ref(),
PermRule::Children(ref parent) => parent.as_permission() > perm.as_ref() , PermRule::Children(ref parent) => parent.as_permission() > perm.as_ref(),
PermRule::Subtree(ref parent) => parent.as_permission() >= perm.as_ref(), PermRule::Subtree(ref parent) => parent.as_permission() >= perm.as_ref(),
} }
} }
@ -217,12 +230,9 @@ impl PermRule {
impl fmt::Display for PermRule { impl fmt::Display for PermRule {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
PermRule::Base(perm) PermRule::Base(perm) => write!(f, "{}", perm),
=> write!(f, "{}", perm), PermRule::Children(parent) => write!(f, "{}.+", parent),
PermRule::Children(parent) PermRule::Subtree(parent) => write!(f, "{}.*", parent),
=> write!(f,"{}.+", parent),
PermRule::Subtree(parent)
=> write!(f,"{}.*", parent),
} }
} }
} }
@ -234,7 +244,7 @@ impl Into<String> for PermRule {
PermRule::Children(mut perm) => { PermRule::Children(mut perm) => {
perm.push(Permission::new("+")); perm.push(Permission::new("+"));
perm.into_string() perm.into_string()
}, }
PermRule::Subtree(mut perm) => { PermRule::Subtree(mut perm) => {
perm.push(Permission::new("+")); perm.push(Permission::new("+"));
perm.into_string() perm.into_string()
@ -252,15 +262,19 @@ impl TryFrom<String> for PermRule {
if len <= 2 { if len <= 2 {
Err("Input string for PermRule is too short") Err("Input string for PermRule is too short")
} else { } else {
match &input[len-2..len] { match &input[len - 2..len] {
".+" => { ".+" => {
input.truncate(len-2); input.truncate(len - 2);
Ok(PermRule::Children(PermissionBuf::from_string_unchecked(input))) Ok(PermRule::Children(PermissionBuf::from_string_unchecked(
}, input,
)))
}
".*" => { ".*" => {
input.truncate(len-2); input.truncate(len - 2);
Ok(PermRule::Subtree(PermissionBuf::from_string_unchecked(input))) Ok(PermRule::Subtree(PermissionBuf::from_string_unchecked(
}, input,
)))
}
_ => Ok(PermRule::Base(PermissionBuf::from_string_unchecked(input))), _ => Ok(PermRule::Base(PermissionBuf::from_string_unchecked(input))),
} }
} }
@ -273,8 +287,10 @@ mod tests {
#[test] #[test]
fn permission_ord_test() { fn permission_ord_test() {
assert!(PermissionBuf::from_string_unchecked("bffh.perm".to_string()) assert!(
> PermissionBuf::from_string_unchecked("bffh.perm.sub".to_string())); PermissionBuf::from_string_unchecked("bffh.perm".to_string())
> PermissionBuf::from_string_unchecked("bffh.perm.sub".to_string())
);
} }
#[test] #[test]
@ -316,11 +332,9 @@ mod tests {
fn format_and_read_compatible() { fn format_and_read_compatible() {
use std::convert::TryInto; use std::convert::TryInto;
let testdata = vec![ let testdata = vec![("testrole", "testsource"), ("", "norole"), ("nosource", "")]
("testrole", "testsource"), .into_iter()
("", "norole"), .map(|(n, s)| (n.to_string(), s.to_string()));
("nosource", "")
].into_iter().map(|(n,s)| (n.to_string(), s.to_string()));
for (name, source) in testdata { for (name, source) in testdata {
let role = RoleIdentifier { name, source }; let role = RoleIdentifier { name, source };
@ -337,19 +351,24 @@ mod tests {
} }
} }
#[test] #[test]
fn rules_from_string_test() { fn rules_from_string_test() {
assert_eq!( assert_eq!(
PermRule::Base(PermissionBuf::from_string_unchecked("bffh.perm".to_string())), PermRule::Base(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm".to_string()).unwrap() PermRule::try_from("bffh.perm".to_string()).unwrap()
); );
assert_eq!( assert_eq!(
PermRule::Children(PermissionBuf::from_string_unchecked("bffh.perm".to_string())), PermRule::Children(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm.+".to_string()).unwrap() PermRule::try_from("bffh.perm.+".to_string()).unwrap()
); );
assert_eq!( assert_eq!(
PermRule::Subtree(PermissionBuf::from_string_unchecked("bffh.perm".to_string())), PermRule::Subtree(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm.*".to_string()).unwrap() PermRule::try_from("bffh.perm.*".to_string()).unwrap()
); );
} }

View File

@ -1,8 +1,8 @@
use crate::authorization::permissions::{PermRule, Permission};
use crate::users::db::UserData;
use once_cell::sync::OnceCell;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fmt; use std::fmt;
use once_cell::sync::OnceCell;
use crate::authorization::permissions::{Permission, PermRule};
use crate::users::db::UserData;
static ROLES: OnceCell<HashMap<String, Role>> = OnceCell::new(); static ROLES: OnceCell<HashMap<String, Role>> = OnceCell::new();
@ -27,7 +27,6 @@ impl Roles {
self.roles.get(roleid) self.roles.get(roleid)
} }
/// Tally a role dependency tree into a set /// Tally a role dependency tree into a set
/// ///
/// A Default implementation exists which adapter may overwrite with more efficient /// A Default implementation exists which adapter may overwrite with more efficient
@ -62,10 +61,11 @@ impl Roles {
output output
} }
fn permitted_tally(&self, fn permitted_tally(
&self,
roles: &mut HashSet<String>, roles: &mut HashSet<String>,
role_id: &String, role_id: &String,
perm: &Permission perm: &Permission,
) -> bool { ) -> bool {
if let Some(role) = self.get(role_id) { if let Some(role) = self.get(role_id) {
// Only check and tally parents of a role at the role itself if it's the first time we // Only check and tally parents of a role at the role itself if it's the first time we
@ -130,7 +130,10 @@ pub struct Role {
impl Role { impl Role {
pub fn new(parents: Vec<String>, permissions: Vec<PermRule>) -> Self { pub fn new(parents: Vec<String>, permissions: Vec<PermRule>) -> Self {
Self { parents, permissions } Self {
parents,
permissions,
}
} }
} }

View File

@ -5,7 +5,6 @@ use rsasl::property::AuthId;
use rsasl::session::{Session, Step}; use rsasl::session::{Session, Step};
use std::io::Cursor; use std::io::Cursor;
use crate::capnp::session::APISession; use crate::capnp::session::APISession;
use crate::session::SessionManager; use crate::session::SessionManager;
use api::authenticationsystem_capnp::authentication::{ use api::authenticationsystem_capnp::authentication::{

View File

@ -2,7 +2,7 @@ use std::fmt::Formatter;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::path::PathBuf; use std::path::PathBuf;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::config::deser_option; use crate::config::deser_option;
@ -14,7 +14,11 @@ use crate::config::deser_option;
pub struct Listen { pub struct Listen {
pub address: String, pub address: String,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")] #[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub port: Option<u16>, pub port: Option<u16>,
} }

View File

@ -1,13 +1,13 @@
use std::net::SocketAddr;
pub use api::connection_capnp::bootstrap::Client;
use api::connection_capnp::bootstrap; use api::connection_capnp::bootstrap;
pub use api::connection_capnp::bootstrap::Client;
use std::net::SocketAddr;
use capnp::capability::Promise;
use capnp_rpc::pry;
use rsasl::mechname::Mechname;
use crate::authentication::AuthenticationHandle; use crate::authentication::AuthenticationHandle;
use crate::capnp::authenticationsystem::Authentication; use crate::capnp::authenticationsystem::Authentication;
use crate::session::SessionManager; use crate::session::SessionManager;
use capnp::capability::Promise;
use capnp_rpc::pry;
use rsasl::mechname::Mechname;
/// Cap'n Proto API Handler /// Cap'n Proto API Handler
pub struct BootCap { pub struct BootCap {
@ -17,7 +17,11 @@ pub struct BootCap {
} }
impl BootCap { impl BootCap {
pub fn new(peer_addr: SocketAddr, authentication: AuthenticationHandle, sessionmanager: SessionManager) -> Self { pub fn new(
peer_addr: SocketAddr,
authentication: AuthenticationHandle,
sessionmanager: SessionManager,
) -> Self {
tracing::trace!(%peer_addr, "bootstrapping RPC"); tracing::trace!(%peer_addr, "bootstrapping RPC");
Self { Self {
peer_addr, peer_addr,
@ -62,12 +66,14 @@ impl bootstrap::Server for BootCap {
tracing::trace!("mechanisms"); tracing::trace!("mechanisms");
let builder = result.get(); let builder = result.get();
let mechs: Vec<_> = self.authentication.list_available_mechs() let mechs: Vec<_> = self
.authentication
.list_available_mechs()
.into_iter() .into_iter()
.map(|m| m.as_str()) .map(|m| m.as_str())
.collect(); .collect();
let mut mechbuilder = builder.init_mechs(mechs.len() as u32); let mut mechbuilder = builder.init_mechs(mechs.len() as u32);
for (i,m) in mechs.iter().enumerate() { for (i, m) in mechs.iter().enumerate() {
mechbuilder.set(i as u32, m); mechbuilder.set(i as u32, m);
} }

View File

@ -1,14 +1,11 @@
use crate::capnp::machine::Machine;
use crate::resources::search::ResourcesHandle;
use crate::resources::Resource;
use crate::session::SessionHandle; use crate::session::SessionHandle;
use api::machinesystem_capnp::machine_system::{ use crate::RESOURCES;
info, use api::machinesystem_capnp::machine_system::info;
};
use capnp::capability::Promise; use capnp::capability::Promise;
use capnp_rpc::pry; use capnp_rpc::pry;
use crate::capnp::machine::Machine;
use crate::RESOURCES;
use crate::resources::Resource;
use crate::resources::search::ResourcesHandle;
#[derive(Clone)] #[derive(Clone)]
pub struct Machines { pub struct Machines {
@ -19,7 +16,10 @@ pub struct Machines {
impl Machines { impl Machines {
pub fn new(session: SessionHandle) -> Self { pub fn new(session: SessionHandle) -> Self {
// FIXME no unwrap bad // FIXME no unwrap bad
Self { session, resources: RESOURCES.get().unwrap().clone() } Self {
session,
resources: RESOURCES.get().unwrap().clone(),
}
} }
} }
@ -29,7 +29,9 @@ impl info::Server for Machines {
_: info::GetMachineListParams, _: info::GetMachineListParams,
mut result: info::GetMachineListResults, mut result: info::GetMachineListResults,
) -> Promise<(), ::capnp::Error> { ) -> Promise<(), ::capnp::Error> {
let machine_list: Vec<(usize, &Resource)> = self.resources.list_all() let machine_list: Vec<(usize, &Resource)> = self
.resources
.list_all()
.into_iter() .into_iter()
.filter(|resource| resource.visible(&self.session)) .filter(|resource| resource.visible(&self.session))
.enumerate() .enumerate()

View File

@ -1,6 +1,5 @@
use async_net::TcpListener; use async_net::TcpListener;
use capnp_rpc::rpc_twoparty_capnp::Side; use capnp_rpc::rpc_twoparty_capnp::Side;
use capnp_rpc::twoparty::VatNetwork; use capnp_rpc::twoparty::VatNetwork;
use capnp_rpc::RpcSystem; use capnp_rpc::RpcSystem;
@ -69,9 +68,7 @@ impl APIServer {
listens listens
.into_iter() .into_iter()
.map(|a| async move { .map(|a| async move { (async_net::resolve(a.to_tuple()).await, a) })
(async_net::resolve(a.to_tuple()).await, a)
})
.collect::<FuturesUnordered<_>>() .collect::<FuturesUnordered<_>>()
.filter_map(|(res, addr)| async move { .filter_map(|(res, addr)| async move {
match res { match res {
@ -111,7 +108,13 @@ impl APIServer {
tracing::warn!("No usable listen addresses configured for the API server!"); tracing::warn!("No usable listen addresses configured for the API server!");
} }
Ok(Self::new(executor, sockets, acceptor, sessionmanager, authentication)) Ok(Self::new(
executor,
sockets,
acceptor,
sessionmanager,
authentication,
))
} }
pub async fn handle_until(self, stop: impl Future) { pub async fn handle_until(self, stop: impl Future) {
@ -129,10 +132,11 @@ impl APIServer {
} else { } else {
tracing::error!(?stream, "failing a TCP connection with no peer addr"); tracing::error!(?stream, "failing a TCP connection with no peer addr");
} }
}, }
Err(e) => tracing::warn!("Failed to accept stream: {}", e), Err(e) => tracing::warn!("Failed to accept stream: {}", e),
} }
}).await; })
.await;
tracing::info!("closing down API handler"); tracing::info!("closing down API handler");
} }
@ -153,7 +157,11 @@ impl APIServer {
let (rx, tx) = futures_lite::io::split(stream); let (rx, tx) = futures_lite::io::split(stream);
let vat = VatNetwork::new(rx, tx, Side::Server, Default::default()); let vat = VatNetwork::new(rx, tx, Side::Server, Default::default());
let bootstrap: connection::Client = capnp_rpc::new_client(connection::BootCap::new(peer_addr, self.authentication.clone(), self.sessionmanager.clone())); let bootstrap: connection::Client = capnp_rpc::new_client(connection::BootCap::new(
peer_addr,
self.authentication.clone(),
self.sessionmanager.clone(),
));
if let Err(e) = RpcSystem::new(Box::new(vat), Some(bootstrap.client)).await { if let Err(e) = RpcSystem::new(Box::new(vat), Some(bootstrap.client)).await {
tracing::error!("Error during RPC handling: {}", e); tracing::error!("Error during RPC handling: {}", e);

View File

@ -10,6 +10,4 @@ impl Permissions {
} }
} }
impl PermissionSystem for Permissions { impl PermissionSystem for Permissions {}
}

View File

@ -1,12 +1,10 @@
use api::authenticationsystem_capnp::response::successful::Builder;
use crate::authorization::permissions::Permission; use crate::authorization::permissions::Permission;
use api::authenticationsystem_capnp::response::successful::Builder;
use crate::capnp::machinesystem::Machines; use crate::capnp::machinesystem::Machines;
use crate::capnp::permissionsystem::Permissions; use crate::capnp::permissionsystem::Permissions;
use crate::capnp::user_system::Users; use crate::capnp::user_system::Users;
use crate::session::{SessionHandle}; use crate::session::SessionHandle;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct APISession; pub struct APISession;

View File

@ -1,10 +1,10 @@
use crate::authorization::permissions::Permission;
use crate::session::SessionHandle;
use crate::users::{db, UserRef};
use api::general_capnp::optional;
use api::user_capnp::user::{self, admin, info, manage};
use capnp::capability::Promise; use capnp::capability::Promise;
use capnp_rpc::pry; use capnp_rpc::pry;
use crate::session::SessionHandle;
use api::user_capnp::user::{admin, info, manage, self};
use api::general_capnp::optional;
use crate::authorization::permissions::Permission;
use crate::users::{db, UserRef};
#[derive(Clone)] #[derive(Clone)]
pub struct User { pub struct User {
@ -22,7 +22,11 @@ impl User {
Self::new(session, user) Self::new(session, user)
} }
pub fn build_optional(session: &SessionHandle, user: Option<UserRef>, builder: optional::Builder<user::Owned>) { pub fn build_optional(
session: &SessionHandle,
user: Option<UserRef>,
builder: optional::Builder<user::Owned>,
) {
if let Some(user) = user.and_then(|u| session.users.get_user(u.get_username())) { if let Some(user) = user.and_then(|u| session.users.get_user(u.get_username())) {
let builder = builder.init_just(); let builder = builder.init_just();
Self::fill(&session, user, builder); Self::fill(&session, user, builder);
@ -102,12 +106,18 @@ impl admin::Server for User {
let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name()); let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name());
if let Some(_role) = self.session.roles.get(rolename) { if let Some(_role) = self.session.roles.get(rolename) {
let mut target = self.session.users.get_user(self.user.get_username()).unwrap(); let mut target = self
.session
.users
.get_user(self.user.get_username())
.unwrap();
// Only update if needed // Only update if needed
if !target.userdata.roles.iter().any(|r| r.as_str() == rolename) { if !target.userdata.roles.iter().any(|r| r.as_str() == rolename) {
target.userdata.roles.push(rolename.to_string()); target.userdata.roles.push(rolename.to_string());
self.session.users.put_user(self.user.get_username(), &target); self.session
.users
.put_user(self.user.get_username(), &target);
} }
} }
@ -121,22 +131,24 @@ impl admin::Server for User {
let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name()); let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name());
if let Some(_role) = self.session.roles.get(rolename) { if let Some(_role) = self.session.roles.get(rolename) {
let mut target = self.session.users.get_user(self.user.get_username()).unwrap(); let mut target = self
.session
.users
.get_user(self.user.get_username())
.unwrap();
// Only update if needed // Only update if needed
if target.userdata.roles.iter().any(|r| r.as_str() == rolename) { if target.userdata.roles.iter().any(|r| r.as_str() == rolename) {
target.userdata.roles.retain(|r| r.as_str() != rolename); target.userdata.roles.retain(|r| r.as_str() != rolename);
self.session.users.put_user(self.user.get_username(), &target); self.session
.users
.put_user(self.user.get_username(), &target);
} }
} }
Promise::ok(()) Promise::ok(())
} }
fn pwd( fn pwd(&mut self, _: admin::PwdParams, _: admin::PwdResults) -> Promise<(), ::capnp::Error> {
&mut self,
_: admin::PwdParams,
_: admin::PwdResults,
) -> Promise<(), ::capnp::Error> {
Promise::err(::capnp::Error::unimplemented( Promise::err(::capnp::Error::unimplemented(
"method not implemented".to_string(), "method not implemented".to_string(),
)) ))

View File

@ -1,15 +1,12 @@
use api::usersystem_capnp::user_system::{info, manage, search};
use capnp::capability::Promise; use capnp::capability::Promise;
use capnp_rpc::pry; use capnp_rpc::pry;
use api::usersystem_capnp::user_system::{
info, manage, search
};
use crate::capnp::user::User; use crate::capnp::user::User;
use crate::session::SessionHandle; use crate::session::SessionHandle;
use crate::users::{db, UserRef}; use crate::users::{db, UserRef};
#[derive(Clone)] #[derive(Clone)]
pub struct Users { pub struct Users {
session: SessionHandle, session: SessionHandle,
@ -40,7 +37,8 @@ impl manage::Server for Users {
mut result: manage::GetUserListResults, mut result: manage::GetUserListResults,
) -> Promise<(), ::capnp::Error> { ) -> Promise<(), ::capnp::Error> {
let userdb = self.session.users.into_inner(); let userdb = self.session.users.into_inner();
let users = pry!(userdb.get_all() let users = pry!(userdb
.get_all()
.map_err(|e| capnp::Error::failed(format!("UserDB error: {:?}", e)))); .map_err(|e| capnp::Error::failed(format!("UserDB error: {:?}", e))));
let mut builder = result.get().init_user_list(users.len() as u32); let mut builder = result.get().init_user_list(users.len() as u32);
for (i, (_, user)) in users.into_iter().enumerate() { for (i, (_, user)) in users.into_iter().enumerate() {

View File

@ -1,8 +1,6 @@
use std::path::Path;
use crate::Config; use crate::Config;
use std::path::Path;
pub fn read_config_file(path: impl AsRef<Path>) -> Result<Config, serde_dhall::Error> { pub fn read_config_file(path: impl AsRef<Path>) -> Result<Config, serde_dhall::Error> {
serde_dhall::from_file(path) serde_dhall::from_file(path).parse().map_err(Into::into)
.parse()
.map_err(Into::into)
} }

View File

@ -1,13 +1,13 @@
use std::default::Default;
use std::path::{PathBuf};
use std::collections::HashMap; use std::collections::HashMap;
use std::default::Default;
use std::path::PathBuf;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
mod dhall; mod dhall;
pub use dhall::read_config_file as read; pub use dhall::read_config_file as read;
use crate::authorization::permissions::{PrivilegesBuf}; use crate::authorization::permissions::PrivilegesBuf;
use crate::authorization::roles::Role; use crate::authorization::roles::Role;
use crate::capnp::{Listen, TlsListen}; use crate::capnp::{Listen, TlsListen};
use crate::logging::LogConfig; use crate::logging::LogConfig;
@ -23,13 +23,25 @@ pub struct MachineDescription {
pub name: String, pub name: String,
/// An optional description of the Machine. /// An optional description of the Machine.
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")] #[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub description: Option<String>, pub description: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")] #[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub wiki: Option<String>, pub wiki: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")] #[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub category: Option<String>, pub category: Option<String>,
/// The permission required /// The permission required
@ -83,48 +95,49 @@ impl Config {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModuleConfig { pub struct ModuleConfig {
pub module: String, pub module: String,
pub params: HashMap<String, String> pub params: HashMap<String, String>,
} }
pub(crate) fn deser_option<'de, D, T>(d: D) -> std::result::Result<Option<T>, D::Error> pub(crate) fn deser_option<'de, D, T>(d: D) -> std::result::Result<Option<T>, D::Error>
where D: serde::Deserializer<'de>, T: serde::Deserialize<'de>, where
D: serde::Deserializer<'de>,
T: serde::Deserialize<'de>,
{ {
Ok(T::deserialize(d).ok()) Ok(T::deserialize(d).ok())
} }
impl Default for Config { impl Default for Config {
fn default() -> Self { fn default() -> Self {
let mut actors: HashMap::<String, ModuleConfig> = HashMap::new(); let mut actors: HashMap<String, ModuleConfig> = HashMap::new();
let mut initiators: HashMap::<String, ModuleConfig> = HashMap::new(); let mut initiators: HashMap<String, ModuleConfig> = HashMap::new();
let machines = HashMap::new(); let machines = HashMap::new();
actors.insert("Actor".to_string(), ModuleConfig { actors.insert(
"Actor".to_string(),
ModuleConfig {
module: "Shelly".to_string(), module: "Shelly".to_string(),
params: HashMap::new(), params: HashMap::new(),
}); },
initiators.insert("Initiator".to_string(), ModuleConfig { );
initiators.insert(
"Initiator".to_string(),
ModuleConfig {
module: "TCP-Listen".to_string(), module: "TCP-Listen".to_string(),
params: HashMap::new(), params: HashMap::new(),
}); },
);
Config { Config {
listens: vec![ listens: vec![Listen {
Listen {
address: "127.0.0.1".to_string(), address: "127.0.0.1".to_string(),
port: None, port: None,
} }],
],
actors, actors,
initiators, initiators,
machines, machines,
mqtt_url: "tcp://localhost:1883".to_string(), mqtt_url: "tcp://localhost:1883".to_string(),
actor_connections: vec![ actor_connections: vec![("Testmachine".to_string(), "Actor".to_string())],
("Testmachine".to_string(), "Actor".to_string()), init_connections: vec![("Initiator".to_string(), "Testmachine".to_string())],
],
init_connections: vec![
("Initiator".to_string(), "Testmachine".to_string()),
],
db_path: PathBuf::from("/run/bffh/database"), db_path: PathBuf::from("/run/bffh/database"),
auditlog_path: PathBuf::from("/var/log/bffh/audit.log"), auditlog_path: PathBuf::from("/var/log/bffh/audit.log"),
@ -133,7 +146,7 @@ impl Default for Config {
tlsconfig: TlsListen { tlsconfig: TlsListen {
certfile: PathBuf::from("./bffh.crt"), certfile: PathBuf::from("./bffh.crt"),
keyfile: PathBuf::from("./bffh.key"), keyfile: PathBuf::from("./bffh.key"),
.. Default::default() ..Default::default()
}, },
tlskeylog: None, tlskeylog: None,

View File

@ -2,6 +2,6 @@ mod raw;
pub use raw::RawDB; pub use raw::RawDB;
mod typed; mod typed;
pub use typed::{DB, ArchivedValue, Adapter, AlignedAdapter}; pub use typed::{Adapter, AlignedAdapter, ArchivedValue, DB};
pub type Error = lmdb::Error; pub type Error = lmdb::Error;

View File

@ -1,10 +1,4 @@
use lmdb::{ use lmdb::{DatabaseFlags, Environment, RwTransaction, Transaction, WriteFlags};
Transaction,
RwTransaction,
Environment,
DatabaseFlags,
WriteFlags,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RawDB { pub struct RawDB {
@ -16,12 +10,21 @@ impl RawDB {
env.open_db(name).map(|db| Self { db }) env.open_db(name).map(|db| Self { db })
} }
pub fn create(env: &Environment, name: Option<&str>, flags: DatabaseFlags) -> lmdb::Result<Self> { pub fn create(
env: &Environment,
name: Option<&str>,
flags: DatabaseFlags,
) -> lmdb::Result<Self> {
env.create_db(name, flags).map(|db| Self { db }) env.create_db(name, flags).map(|db| Self { db })
} }
pub fn get<'txn, T: Transaction, K>(&self, txn: &'txn T, key: &K) -> lmdb::Result<Option<&'txn [u8]>> pub fn get<'txn, T: Transaction, K>(
where K: AsRef<[u8]> &self,
txn: &'txn T,
key: &K,
) -> lmdb::Result<Option<&'txn [u8]>>
where
K: AsRef<[u8]>,
{ {
match txn.get(self.db, key) { match txn.get(self.db, key) {
Ok(buf) => Ok(Some(buf)), Ok(buf) => Ok(Some(buf)),
@ -30,23 +33,36 @@ impl RawDB {
} }
} }
pub fn put<K, V>(&self, txn: &mut RwTransaction, key: &K, value: &V, flags: WriteFlags) pub fn put<K, V>(
-> lmdb::Result<()> &self,
where K: AsRef<[u8]>, txn: &mut RwTransaction,
key: &K,
value: &V,
flags: WriteFlags,
) -> lmdb::Result<()>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>, V: AsRef<[u8]>,
{ {
txn.put(self.db, key, value, flags) txn.put(self.db, key, value, flags)
} }
pub fn reserve<'txn, K>(&self, txn: &'txn mut RwTransaction, key: &K, size: usize, flags: WriteFlags) pub fn reserve<'txn, K>(
-> lmdb::Result<&'txn mut [u8]> &self,
where K: AsRef<[u8]> txn: &'txn mut RwTransaction,
key: &K,
size: usize,
flags: WriteFlags,
) -> lmdb::Result<&'txn mut [u8]>
where
K: AsRef<[u8]>,
{ {
txn.reserve(self.db, key, size, flags) txn.reserve(self.db, key, size, flags)
} }
pub fn del<K, V>(&self, txn: &mut RwTransaction, key: &K, value: Option<&V>) -> lmdb::Result<()> pub fn del<K, V>(&self, txn: &mut RwTransaction, key: &K, value: Option<&V>) -> lmdb::Result<()>
where K: AsRef<[u8]>, where
K: AsRef<[u8]>,
V: AsRef<[u8]>, V: AsRef<[u8]>,
{ {
txn.del(self.db, key, value.map(AsRef::as_ref)) txn.del(self.db, key, value.map(AsRef::as_ref))
@ -60,7 +76,10 @@ impl RawDB {
cursor.iter_start() cursor.iter_start()
} }
pub fn open_ro_cursor<'txn, T: Transaction>(&self, txn: &'txn T) -> lmdb::Result<lmdb::RoCursor<'txn>> { pub fn open_ro_cursor<'txn, T: Transaction>(
&self,
txn: &'txn T,
) -> lmdb::Result<lmdb::RoCursor<'txn>> {
txn.open_ro_cursor(self.db) txn.open_ro_cursor(self.db)
} }
} }

View File

@ -119,7 +119,11 @@ impl<A> DB<A> {
} }
impl<A: Adapter> DB<A> { impl<A: Adapter> DB<A> {
pub fn get<T: Transaction>(&self, txn: &T, key: &impl AsRef<[u8]>) -> Result<Option<A::Item>, db::Error> { pub fn get<T: Transaction>(
&self,
txn: &T,
key: &impl AsRef<[u8]>,
) -> Result<Option<A::Item>, db::Error> {
Ok(self.db.get(txn, key)?.map(A::decode)) Ok(self.db.get(txn, key)?.map(A::decode))
} }
@ -129,8 +133,7 @@ impl<A: Adapter> DB<A> {
key: &impl AsRef<[u8]>, key: &impl AsRef<[u8]>,
value: &A::Item, value: &A::Item,
flags: WriteFlags, flags: WriteFlags,
) -> Result<(), db::Error> ) -> Result<(), db::Error> {
{
let len = A::encoded_len(value); let len = A::encoded_len(value);
let buf = self.db.reserve(txn, key, len, flags)?; let buf = self.db.reserve(txn, key, len, flags)?;
assert_eq!(buf.len(), len, "Reserved buffer is not of requested size!"); assert_eq!(buf.len(), len, "Reserved buffer is not of requested size!");
@ -146,11 +149,12 @@ impl<A: Adapter> DB<A> {
self.db.clear(txn) self.db.clear(txn)
} }
pub fn get_all<'txn, T: Transaction>(&self, txn: &'txn T) -> Result<impl IntoIterator<Item=(&'txn [u8], A::Item)>, db::Error> { pub fn get_all<'txn, T: Transaction>(
&self,
txn: &'txn T,
) -> Result<impl IntoIterator<Item = (&'txn [u8], A::Item)>, db::Error> {
let mut cursor = self.db.open_ro_cursor(txn)?; let mut cursor = self.db.open_ro_cursor(txn)?;
let it = cursor.iter_start(); let it = cursor.iter_start();
Ok(it.filter_map(|buf| buf.ok().map(|(kbuf,vbuf)| { Ok(it.filter_map(|buf| buf.ok().map(|(kbuf, vbuf)| (kbuf, A::decode(vbuf)))))
(kbuf, A::decode(vbuf))
})))
} }
} }

View File

@ -1,7 +1,7 @@
use std::io;
use std::fmt;
use rsasl::error::SessionError;
use crate::db; use crate::db;
use rsasl::error::SessionError;
use std::fmt;
use std::io;
type DBError = db::Error; type DBError = db::Error;
@ -21,19 +21,19 @@ impl fmt::Display for Error {
match self { match self {
Error::SASL(e) => { Error::SASL(e) => {
write!(f, "SASL Error: {}", e) write!(f, "SASL Error: {}", e)
}, }
Error::IO(e) => { Error::IO(e) => {
write!(f, "IO Error: {}", e) write!(f, "IO Error: {}", e)
}, }
Error::Boxed(e) => { Error::Boxed(e) => {
write!(f, "{}", e) write!(f, "{}", e)
}, }
Error::Capnp(e) => { Error::Capnp(e) => {
write!(f, "Cap'n Proto Error: {}", e) write!(f, "Cap'n Proto Error: {}", e)
}, }
Error::DB(e) => { Error::DB(e) => {
write!(f, "DB Error: {:?}", e) write!(f, "DB Error: {:?}", e)
}, }
Error::Denied => { Error::Denied => {
write!(f, "You do not have the permission required to do that.") write!(f, "You do not have the permission required to do that.")
} }

View File

@ -1,9 +1,9 @@
use std::fs::{File, OpenOptions};
use std::{fmt, io};
use std::fmt::Formatter; use std::fmt::Formatter;
use std::fs::{File, OpenOptions};
use std::io::Write; use std::io::Write;
use std::path::Path; use std::path::Path;
use std::sync::Mutex; use std::sync::Mutex;
use std::{fmt, io};
// Internal mutable state for KeyLogFile // Internal mutable state for KeyLogFile
struct KeyLogFileInner { struct KeyLogFileInner {
@ -18,10 +18,7 @@ impl fmt::Debug for KeyLogFileInner {
impl KeyLogFileInner { impl KeyLogFileInner {
fn new(path: impl AsRef<Path>) -> io::Result<Self> { fn new(path: impl AsRef<Path>) -> io::Result<Self> {
let file = OpenOptions::new() let file = OpenOptions::new().append(true).create(true).open(path)?;
.append(true)
.create(true)
.open(path)?;
Ok(Self { Ok(Self {
file, file,

View File

@ -16,9 +16,9 @@ pub mod db;
/// Shared error type /// Shared error type
pub mod error; pub mod error;
pub mod users;
pub mod authentication; pub mod authentication;
pub mod authorization; pub mod authorization;
pub mod users;
/// Resources /// Resources
pub mod resources; pub mod resources;
@ -31,38 +31,34 @@ pub mod capnp;
pub mod utils; pub mod utils;
mod tls; mod audit;
mod keylog; mod keylog;
mod logging; mod logging;
mod audit;
mod session; mod session;
mod tls;
use std::sync::Arc;
use std::sync::{Arc};
use anyhow::Context; use anyhow::Context;
use futures_util::StreamExt; use futures_util::StreamExt;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use signal_hook::consts::signal::*;
use executor::pool::Executor;
use crate::audit::AuditLog; use crate::audit::AuditLog;
use crate::authentication::AuthenticationHandle; use crate::authentication::AuthenticationHandle;
use crate::authorization::roles::Roles; use crate::authorization::roles::Roles;
use crate::capnp::APIServer; use crate::capnp::APIServer;
use crate::config::{Config}; use crate::config::Config;
use crate::resources::modules::fabaccess::MachineState; use crate::resources::modules::fabaccess::MachineState;
use crate::resources::Resource;
use crate::resources::search::ResourcesHandle; use crate::resources::search::ResourcesHandle;
use crate::resources::state::db::StateDB; use crate::resources::state::db::StateDB;
use crate::resources::Resource;
use crate::session::SessionManager; use crate::session::SessionManager;
use crate::tls::TlsConfig; use crate::tls::TlsConfig;
use crate::users::db::UserDB; use crate::users::db::UserDB;
use crate::users::Users; use crate::users::Users;
use executor::pool::Executor;
use signal_hook::consts::signal::*;
pub const VERSION_STRING: &'static str = env!("BFFHD_VERSION_STRING"); pub const VERSION_STRING: &'static str = env!("BFFHD_VERSION_STRING");
pub const RELEASE_STRING: &'static str = env!("BFFHD_RELEASE_STRING"); pub const RELEASE_STRING: &'static str = env!("BFFHD_RELEASE_STRING");
@ -81,7 +77,7 @@ pub static RESOURCES: OnceCell<ResourcesHandle> = OnceCell::new();
impl Diflouroborane { impl Diflouroborane {
pub fn new(config: Config) -> anyhow::Result<Self> { pub fn new(config: Config) -> anyhow::Result<Self> {
logging::init(&config.logging); logging::init(&config.logging);
tracing::info!(version=VERSION_STRING, "Starting BFFH"); tracing::info!(version = VERSION_STRING, "Starting BFFH");
let span = tracing::info_span!("setup"); let span = tracing::info_span!("setup");
let _guard = span.enter(); let _guard = span.enter();
@ -89,8 +85,8 @@ impl Diflouroborane {
let executor = Executor::new(); let executor = Executor::new();
let env = StateDB::open_env(&config.db_path)?; let env = StateDB::open_env(&config.db_path)?;
let statedb = StateDB::create_with_env(env.clone()) let statedb =
.context("Failed to open state DB file")?; StateDB::create_with_env(env.clone()).context("Failed to open state DB file")?;
let users = Users::new(env.clone()).context("Failed to open users DB file")?; let users = Users::new(env.clone()).context("Failed to open users DB file")?;
let roles = Roles::new(config.roles.clone()); let roles = Roles::new(config.roles.clone());
@ -98,31 +94,43 @@ impl Diflouroborane {
let _audit_log = AuditLog::new(&config).context("Failed to initialize audit log")?; let _audit_log = AuditLog::new(&config).context("Failed to initialize audit log")?;
let resources = ResourcesHandle::new(config.machines.iter().map(|(id, desc)| { let resources = ResourcesHandle::new(config.machines.iter().map(|(id, desc)| {
Resource::new(Arc::new(resources::Inner::new(id.to_string(), statedb.clone(), desc.clone()))) Resource::new(Arc::new(resources::Inner::new(
id.to_string(),
statedb.clone(),
desc.clone(),
)))
})); }));
RESOURCES.set(resources.clone()); RESOURCES.set(resources.clone());
Ok(Self {
Ok(Self { config, executor, statedb, users, roles, resources }) config,
executor,
statedb,
users,
roles,
resources,
})
} }
pub fn run(&mut self) -> anyhow::Result<()> { pub fn run(&mut self) -> anyhow::Result<()> {
let mut signals = signal_hook_async_std::Signals::new(&[ let mut signals = signal_hook_async_std::Signals::new(&[SIGINT, SIGQUIT, SIGTERM])
SIGINT, .context("Failed to construct signal handler")?;
SIGQUIT,
SIGTERM,
]).context("Failed to construct signal handler")?;
actors::load(self.executor.clone(), &self.config, self.resources.clone())?; actors::load(self.executor.clone(), &self.config, self.resources.clone())?;
let tlsconfig = TlsConfig::new(self.config.tlskeylog.as_ref(), !self.config.is_quiet())?; let tlsconfig = TlsConfig::new(self.config.tlskeylog.as_ref(), !self.config.is_quiet())?;
let acceptor = tlsconfig.make_tls_acceptor(&self.config.tlsconfig)?; let acceptor = tlsconfig.make_tls_acceptor(&self.config.tlsconfig)?;
let sessionmanager = SessionManager::new(self.users.clone(), self.roles.clone()); let sessionmanager = SessionManager::new(self.users.clone(), self.roles.clone());
let authentication = AuthenticationHandle::new(self.users.clone()); let authentication = AuthenticationHandle::new(self.users.clone());
let apiserver = self.executor.run(APIServer::bind(self.executor.clone(), &self.config.listens, acceptor, sessionmanager, authentication))?; let apiserver = self.executor.run(APIServer::bind(
self.executor.clone(),
&self.config.listens,
acceptor,
sessionmanager,
authentication,
))?;
let (mut tx, rx) = async_oneshot::oneshot(); let (mut tx, rx) = async_oneshot::oneshot();
@ -142,4 +150,3 @@ impl Diflouroborane {
Ok(()) Ok(())
} }
} }

View File

@ -1,7 +1,6 @@
use tracing_subscriber::{EnvFilter}; use tracing_subscriber::EnvFilter;
use serde::{Deserialize, Serialize};
use serde::{Serialize, Deserialize};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogConfig { pub struct LogConfig {
@ -32,8 +31,7 @@ pub fn init(config: &LogConfig) {
EnvFilter::from_env("BFFH_LOG") EnvFilter::from_env("BFFH_LOG")
}; };
let builder = tracing_subscriber::fmt() let builder = tracing_subscriber::fmt().with_env_filter(filter);
.with_env_filter(filter);
let format = config.format.to_lowercase(); let format = config.format.to_lowercase();
match format.as_str() { match format.as_str() {

View File

@ -1,16 +1,16 @@
use rkyv::{Archive, Serialize, Deserialize}; use rkyv::{Archive, Deserialize, Serialize};
#[derive(
Clone,
Debug,
PartialEq,
Eq,
Archive,
Serialize,
Deserialize,
#[derive(Clone, Debug, PartialEq, Eq)] serde::Serialize,
#[derive(Archive, Serialize, Deserialize)] serde::Deserialize,
#[derive(serde::Serialize, serde::Deserialize)] )]
pub struct Resource { pub struct Resource {
uuid: u128, uuid: u128,
id: String, id: String,

View File

@ -1,21 +1,21 @@
use futures_signals::signal::{Mutable, Signal};
use rkyv::Infallible; use rkyv::Infallible;
use std::ops::Deref; use std::ops::Deref;
use std::sync::Arc; use std::sync::Arc;
use futures_signals::signal::{Mutable, Signal};
use rkyv::{Archived, Deserialize};
use rkyv::option::ArchivedOption;
use rkyv::ser::Serializer;
use rkyv::ser::serializers::AllocSerializer;
use crate::audit::AUDIT; use crate::audit::AUDIT;
use crate::authorization::permissions::PrivilegesBuf; use crate::authorization::permissions::PrivilegesBuf;
use crate::config::MachineDescription; use crate::config::MachineDescription;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;
use crate::resources::modules::fabaccess::{MachineState, Status, ArchivedStatus}; use crate::resources::modules::fabaccess::{ArchivedStatus, MachineState, Status};
use crate::resources::state::db::StateDB; use crate::resources::state::db::StateDB;
use crate::resources::state::State; use crate::resources::state::State;
use crate::session::SessionHandle; use crate::session::SessionHandle;
use crate::users::UserRef; use crate::users::UserRef;
use rkyv::option::ArchivedOption;
use rkyv::ser::serializers::AllocSerializer;
use rkyv::ser::Serializer;
use rkyv::{Archived, Deserialize};
pub mod db; pub mod db;
pub mod search; pub mod search;
@ -43,27 +43,35 @@ impl Inner {
let update = state.to_state(); let update = state.to_state();
let mut serializer = AllocSerializer::<1024>::default(); let mut serializer = AllocSerializer::<1024>::default();
serializer.serialize_value(&update).expect("failed to serialize new default state"); serializer
.serialize_value(&update)
.expect("failed to serialize new default state");
let val = ArchivedValue::new(serializer.into_serializer().into_inner()); let val = ArchivedValue::new(serializer.into_serializer().into_inner());
db.put(&id.as_bytes(), &val).unwrap(); db.put(&id.as_bytes(), &val).unwrap();
val val
}; };
let signal = Mutable::new(state); let signal = Mutable::new(state);
Self { id, db, signal, desc } Self {
id,
db,
signal,
desc,
}
} }
pub fn signal(&self) -> impl Signal<Item=ArchivedValue<State>> { pub fn signal(&self) -> impl Signal<Item = ArchivedValue<State>> {
Box::pin(self.signal.signal_cloned()) Box::pin(self.signal.signal_cloned())
} }
fn get_state(&self) -> ArchivedValue<State> { fn get_state(&self) -> ArchivedValue<State> {
self.db.get(self.id.as_bytes()) self.db
.get(self.id.as_bytes())
.expect("lmdb error") .expect("lmdb error")
.expect("state should never be None") .expect("state should never be None")
} }
fn get_state_ref(&self) -> impl Deref<Target=ArchivedValue<State>> + '_ { fn get_state_ref(&self) -> impl Deref<Target = ArchivedValue<State>> + '_ {
self.signal.lock_ref() self.signal.lock_ref()
} }
@ -76,7 +84,10 @@ impl Inner {
self.db.put(&self.id.as_bytes(), &state).unwrap(); self.db.put(&self.id.as_bytes(), &state).unwrap();
tracing::trace!("Updated DB, sending update signal"); tracing::trace!("Updated DB, sending update signal");
AUDIT.get().unwrap().log(self.id.as_str(), &format!("{}", state)); AUDIT
.get()
.unwrap()
.log(self.id.as_str(), &format!("{}", state));
self.signal.set(state); self.signal.set(state);
tracing::trace!("Sent update signal"); tracing::trace!("Sent update signal");
@ -85,7 +96,7 @@ impl Inner {
#[derive(Clone)] #[derive(Clone)]
pub struct Resource { pub struct Resource {
inner: Arc<Inner> inner: Arc<Inner>,
} }
impl Resource { impl Resource {
@ -97,7 +108,7 @@ impl Resource {
self.inner.get_state() self.inner.get_state()
} }
pub fn get_state_ref(&self) -> impl Deref<Target=ArchivedValue<State>> + '_ { pub fn get_state_ref(&self) -> impl Deref<Target = ArchivedValue<State>> + '_ {
self.inner.get_state_ref() self.inner.get_state_ref()
} }
@ -109,7 +120,7 @@ impl Resource {
self.inner.desc.name.as_str() self.inner.desc.name.as_str()
} }
pub fn get_signal(&self) -> impl Signal<Item=ArchivedValue<State>> { pub fn get_signal(&self) -> impl Signal<Item = ArchivedValue<State>> {
self.inner.signal() self.inner.signal()
} }
@ -125,13 +136,13 @@ impl Resource {
let state = self.get_state_ref(); let state = self.get_state_ref();
let state: &Archived<State> = state.as_ref(); let state: &Archived<State> = state.as_ref();
match &state.inner.state { match &state.inner.state {
ArchivedStatus::Blocked(user) | ArchivedStatus::Blocked(user)
ArchivedStatus::InUse(user) | | ArchivedStatus::InUse(user)
ArchivedStatus::Reserved(user) | | ArchivedStatus::Reserved(user)
ArchivedStatus::ToCheck(user) => { | ArchivedStatus::ToCheck(user) => {
let user = Deserialize::<UserRef, _>::deserialize(user, &mut Infallible).unwrap(); let user = Deserialize::<UserRef, _>::deserialize(user, &mut Infallible).unwrap();
Some(user) Some(user)
}, }
_ => None, _ => None,
} }
} }
@ -158,7 +169,8 @@ impl Resource {
let old = self.inner.get_state(); let old = self.inner.get_state();
let oldref: &Archived<State> = old.as_ref(); let oldref: &Archived<State> = old.as_ref();
let previous: &Archived<Option<UserRef>> = &oldref.inner.previous; let previous: &Archived<Option<UserRef>> = &oldref.inner.previous;
let previous = Deserialize::<Option<UserRef>, _>::deserialize(previous, &mut rkyv::Infallible) let previous =
Deserialize::<Option<UserRef>, _>::deserialize(previous, &mut rkyv::Infallible)
.expect("Infallible deserializer failed"); .expect("Infallible deserializer failed");
let new = MachineState { state, previous }; let new = MachineState { state, previous };
self.set_state(new); self.set_state(new);

View File

@ -1,14 +1,13 @@
use std::fmt;
use std::fmt::{Write};
use crate::utils::oid::ObjectIdentifier; use crate::utils::oid::ObjectIdentifier;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rkyv::{Archive, Archived, Deserialize, Infallible}; use rkyv::{Archive, Archived, Deserialize, Infallible};
use std::fmt;
use std::fmt::Write;
use std::str::FromStr; use std::str::FromStr;
//use crate::oidvalue; //use crate::oidvalue;
use crate::resources::state::{State}; use crate::resources::state::State;
use crate::users::UserRef; use crate::users::UserRef;
@ -80,13 +79,12 @@ impl MachineState {
} }
pub fn from(dbstate: &Archived<State>) -> Self { pub fn from(dbstate: &Archived<State>) -> Self {
let state: &Archived<MachineState> = &dbstate.inner; let state: &Archived<MachineState> = &dbstate.inner;
Deserialize::deserialize(state, &mut Infallible).unwrap() Deserialize::deserialize(state, &mut Infallible).unwrap()
} }
pub fn to_state(&self) -> State { pub fn to_state(&self) -> State {
State { State {
inner: self.clone() inner: self.clone(),
} }
} }

View File

@ -1,6 +1,3 @@
pub mod fabaccess; pub mod fabaccess;
pub trait MachineModel { pub trait MachineModel {}
}

View File

@ -1,13 +1,13 @@
use crate::resources::Resource;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use crate::resources::Resource;
struct Inner { struct Inner {
id: HashMap<String, Resource>, id: HashMap<String, Resource>,
} }
impl Inner { impl Inner {
pub fn new(resources: impl IntoIterator<Item=Resource>) -> Self { pub fn new(resources: impl IntoIterator<Item = Resource>) -> Self {
let mut id = HashMap::new(); let mut id = HashMap::new();
for resource in resources { for resource in resources {
@ -25,13 +25,13 @@ pub struct ResourcesHandle {
} }
impl ResourcesHandle { impl ResourcesHandle {
pub fn new(resources: impl IntoIterator<Item=Resource>) -> Self { pub fn new(resources: impl IntoIterator<Item = Resource>) -> Self {
Self { Self {
inner: Arc::new(Inner::new(resources)), inner: Arc::new(Inner::new(resources)),
} }
} }
pub fn list_all(&self) -> impl IntoIterator<Item=&Resource> { pub fn list_all(&self) -> impl IntoIterator<Item = &Resource> {
self.inner.id.values() self.inner.id.values()
} }

View File

@ -1,9 +1,6 @@
use crate::db; use crate::db;
use crate::db::{ArchivedValue, RawDB, DB, AlignedAdapter}; use crate::db::{AlignedAdapter, ArchivedValue, RawDB, DB};
use lmdb::{ use lmdb::{DatabaseFlags, Environment, EnvironmentFlags, Transaction, WriteFlags};
DatabaseFlags, Environment, EnvironmentFlags, Transaction,
WriteFlags,
};
use std::{path::Path, sync::Arc}; use std::{path::Path, sync::Arc};
use crate::resources::state::State; use crate::resources::state::State;
@ -67,7 +64,7 @@ impl StateDB {
pub fn get_all<'txn, T: Transaction>( pub fn get_all<'txn, T: Transaction>(
&self, &self,
txn: &'txn T, txn: &'txn T,
) -> Result<impl IntoIterator<Item = (&'txn [u8], ArchivedValue<State>)>, db::Error, > { ) -> Result<impl IntoIterator<Item = (&'txn [u8], ArchivedValue<State>)>, db::Error> {
self.db.get_all(txn) self.db.get_all(txn)
} }

View File

@ -1,37 +1,27 @@
use std::{
fmt,
hash::{
Hasher
},
};
use std::fmt::{Debug, Display, Formatter}; use std::fmt::{Debug, Display, Formatter};
use std::{fmt, hash::Hasher};
use std::ops::Deref; use std::ops::Deref;
use rkyv::{out_field, Archive, Deserialize, Serialize};
use rkyv::{Archive, Deserialize, out_field, Serialize};
use serde::de::{Error, MapAccess, Unexpected}; use serde::de::{Error, MapAccess, Unexpected};
use serde::Deserializer;
use serde::ser::SerializeMap; use serde::ser::SerializeMap;
use serde::Deserializer;
use crate::MachineState;
use crate::resources::modules::fabaccess::OID_VALUE; use crate::resources::modules::fabaccess::OID_VALUE;
use crate::MachineState;
use crate::utils::oid::ObjectIdentifier; use crate::utils::oid::ObjectIdentifier;
pub mod value;
pub mod db; pub mod db;
pub mod value;
#[derive(Archive, Serialize, Deserialize)] #[derive(Archive, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[derive(Clone, PartialEq, Eq)]
#[archive_attr(derive(Debug))] #[archive_attr(derive(Debug))]
pub struct State { pub struct State {
pub inner: MachineState, pub inner: MachineState,
} }
impl fmt::Debug for State { impl fmt::Debug for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut sf = f.debug_struct("State"); let mut sf = f.debug_struct("State");
@ -51,7 +41,8 @@ impl fmt::Display for ArchivedState {
impl serde::Serialize for State { impl serde::Serialize for State {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: serde::Serializer where
S: serde::Serializer,
{ {
let mut ser = serializer.serialize_map(Some(1))?; let mut ser = serializer.serialize_map(Some(1))?;
ser.serialize_entry(OID_VALUE.deref(), &self.inner)?; ser.serialize_entry(OID_VALUE.deref(), &self.inner)?;
@ -60,7 +51,8 @@ impl serde::Serialize for State {
} }
impl<'de> serde::Deserialize<'de> for State { impl<'de> serde::Deserialize<'de> for State {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> where
D: Deserializer<'de>,
{ {
deserializer.deserialize_map(StateVisitor) deserializer.deserialize_map(StateVisitor)
} }
@ -74,12 +66,13 @@ impl<'de> serde::de::Visitor<'de> for StateVisitor {
write!(formatter, "a map from OIDs to value objects") write!(formatter, "a map from OIDs to value objects")
} }
fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
{ let oid: ObjectIdentifier = map.next_key()?.ok_or(A::Error::missing_field("oid"))?;
let oid: ObjectIdentifier = map.next_key()?
.ok_or(A::Error::missing_field("oid"))?;
if oid != *OID_VALUE.deref() { if oid != *OID_VALUE.deref() {
return Err(A::Error::invalid_value(Unexpected::Other("Unknown OID"), &"OID of fabaccess state")) return Err(A::Error::invalid_value(
Unexpected::Other("Unknown OID"),
&"OID of fabaccess state",
));
} }
let val: MachineState = map.next_value()?; let val: MachineState = map.next_value()?;
Ok(State { inner: val }) Ok(State { inner: val })
@ -88,8 +81,8 @@ impl<'de> serde::de::Visitor<'de> for StateVisitor {
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::*;
use super::value::*; use super::value::*;
use super::*;
pub(crate) fn gen_random() -> State { pub(crate) fn gen_random() -> State {
let amt: u8 = rand::random::<u8>() % 20; let amt: u8 = rand::random::<u8>() % 20;
@ -97,7 +90,7 @@ pub mod tests {
let mut sb = State::build(); let mut sb = State::build();
for _ in 0..amt { for _ in 0..amt {
let oid = crate::utils::oid::tests::gen_random(); let oid = crate::utils::oid::tests::gen_random();
sb = match rand::random::<u32>()%12 { sb = match rand::random::<u32>() % 12 {
0 => sb.add(oid, Box::new(rand::random::<bool>())), 0 => sb.add(oid, Box::new(rand::random::<bool>())),
1 => sb.add(oid, Box::new(rand::random::<u8>())), 1 => sb.add(oid, Box::new(rand::random::<u8>())),
2 => sb.add(oid, Box::new(rand::random::<u16>())), 2 => sb.add(oid, Box::new(rand::random::<u16>())),

View File

@ -1,10 +1,12 @@
use std::{hash::Hash}; use std::hash::Hash;
use ptr_meta::{DynMetadata, Pointee}; use ptr_meta::{DynMetadata, Pointee};
use rkyv::{out_field, Archive, ArchivePointee, ArchiveUnsized, Archived, ArchivedMetadata, Serialize, SerializeUnsized, RelPtr}; use rkyv::{
out_field, Archive, ArchivePointee, ArchiveUnsized, Archived, ArchivedMetadata, RelPtr,
Serialize, SerializeUnsized,
};
use rkyv_dyn::{DynError, DynSerializer}; use rkyv_dyn::{DynError, DynSerializer};
use crate::utils::oid::ObjectIdentifier; use crate::utils::oid::ObjectIdentifier;
// Not using linkme because dynamically loaded modules // Not using linkme because dynamically loaded modules
@ -16,7 +18,6 @@ use serde::ser::SerializeMap;
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::Deref; use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
@ -61,7 +62,7 @@ struct NewStateBuilder {
// turns into // turns into
struct NewState { struct NewState {
inner: ArchivedVec<ArchivedMetaBox<dyn ArchivedStateValue>> inner: ArchivedVec<ArchivedMetaBox<dyn ArchivedStateValue>>,
} }
impl NewState { impl NewState {
pub fn get_value<T: TypeOid>(&self) -> Option<&T> { pub fn get_value<T: TypeOid>(&self) -> Option<&T> {
@ -148,7 +149,7 @@ pub trait TypeOid {
} }
impl<T> SerializeDynOid for T impl<T> SerializeDynOid for T
where where
T: for<'a> Serialize<dyn DynSerializer + 'a>, T: for<'a> Serialize<dyn DynSerializer + 'a>,
T::Archived: TypeOid, T::Archived: TypeOid,
{ {
@ -371,7 +372,7 @@ pub mod macros {
stringify!($y) stringify!($y)
} }
} }
} };
} }
#[macro_export] #[macro_export]
@ -380,16 +381,15 @@ pub mod macros {
unsafe impl $crate::resources::state::value::RegisteredImpl for $z { unsafe impl $crate::resources::state::value::RegisteredImpl for $z {
fn vtable() -> usize { fn vtable() -> usize {
unsafe { unsafe {
::core::mem::transmute(ptr_meta::metadata( ::core::mem::transmute(ptr_meta::metadata(::core::ptr::null::<$z>()
::core::ptr::null::<$z>() as *const dyn $crate::resources::state::value::ArchivedStateValue as *const dyn $crate::resources::state::value::ArchivedStateValue))
))
} }
} }
fn debug_info() -> $crate::resources::state::value::ImplDebugInfo { fn debug_info() -> $crate::resources::state::value::ImplDebugInfo {
$crate::debug_info!() $crate::debug_info!()
} }
} }
} };
} }
#[macro_export] #[macro_export]

View File

@ -0,0 +1 @@

View File

@ -1,20 +1,13 @@
use crate::authorization::permissions::Permission; use crate::authorization::permissions::Permission;
use crate::authorization::roles::{Roles}; use crate::authorization::roles::Roles;
use crate::resources::Resource; use crate::resources::Resource;
use crate::Users;
use crate::users::{db, UserRef}; use crate::users::{db, UserRef};
use crate::Users;
#[derive(Clone)] #[derive(Clone)]
pub struct SessionManager { pub struct SessionManager {
users: Users, users: Users,
roles: Roles, roles: Roles,
// cache: SessionCache // todo // cache: SessionCache // todo
} }
impl SessionManager { impl SessionManager {
@ -52,33 +45,39 @@ impl SessionHandle {
} }
pub fn get_user(&self) -> db::User { pub fn get_user(&self) -> db::User {
self.users.get_user(self.user.get_username()).expect("Failed to get user self") self.users
.get_user(self.user.get_username())
.expect("Failed to get user self")
} }
pub fn has_disclose(&self, resource: &Resource) -> bool { pub fn has_disclose(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) { if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().disclose) self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().disclose)
} else { } else {
false false
} }
} }
pub fn has_read(&self, resource: &Resource) -> bool { pub fn has_read(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) { if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().read) self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().read)
} else { } else {
false false
} }
} }
pub fn has_write(&self, resource: &Resource) -> bool { pub fn has_write(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) { if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().write) self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().write)
} else { } else {
false false
} }
} }
pub fn has_manage(&self, resource: &Resource) -> bool { pub fn has_manage(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) { if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().manage) self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().manage)
} else { } else {
false false
} }

View File

@ -4,26 +4,39 @@ use std::io::BufReader;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use futures_rustls::TlsAcceptor;
use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite};
use rustls::version::{TLS12, TLS13};
use tracing::{Level};
use crate::capnp::TlsListen; use crate::capnp::TlsListen;
use futures_rustls::TlsAcceptor;
use rustls::version::{TLS12, TLS13};
use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite};
use tracing::Level;
use crate::keylog::KeyLogFile; use crate::keylog::KeyLogFile;
fn lookup_cipher_suite(name: &str) -> Option<SupportedCipherSuite> { fn lookup_cipher_suite(name: &str) -> Option<SupportedCipherSuite> {
match name { match name {
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => {
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384), Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), }
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => {
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384), Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384)
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), }
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256)
}
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)
}
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384)
}
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)
}
"TLS13_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS13_AES_128_GCM_SHA256), "TLS13_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS13_AES_128_GCM_SHA256),
"TLS13_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS13_AES_256_GCM_SHA384), "TLS13_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS13_AES_256_GCM_SHA384),
"TLS13_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256), "TLS13_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256)
}
_ => None, _ => None,
} }
} }
@ -43,7 +56,6 @@ impl TlsConfig {
} }
if let Some(path) = keylogfile { if let Some(path) = keylogfile {
let keylog = Some(KeyLogFile::new(path).map(|ok| Arc::new(ok))?); let keylog = Some(KeyLogFile::new(path).map(|ok| Arc::new(ok))?);
Ok(Self { keylog }) Ok(Self { keylog })
} else { } else {

View File

@ -1,15 +1,15 @@
use lmdb::{DatabaseFlags, Environment, RwTransaction, Transaction, WriteFlags}; use lmdb::{DatabaseFlags, Environment, RwTransaction, Transaction, WriteFlags};
use std::collections::{HashMap};
use rkyv::Infallible; use rkyv::Infallible;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Context; use anyhow::Context;
use std::sync::Arc;
use rkyv::{Deserialize};
use rkyv::ser::Serializer;
use rkyv::ser::serializers::AllocSerializer;
use crate::db; use crate::db;
use crate::db::{AlignedAdapter, ArchivedValue, DB, RawDB}; use crate::db::{AlignedAdapter, ArchivedValue, RawDB, DB};
use rkyv::ser::serializers::AllocSerializer;
use rkyv::ser::Serializer;
use rkyv::Deserialize;
#[derive( #[derive(
Clone, Clone,
@ -30,8 +30,7 @@ pub struct User {
impl User { impl User {
pub fn check_password(&self, pwd: &[u8]) -> anyhow::Result<bool> { pub fn check_password(&self, pwd: &[u8]) -> anyhow::Result<bool> {
if let Some(ref encoded) = self.userdata.passwd { if let Some(ref encoded) = self.userdata.passwd {
argon2::verify_encoded(encoded, pwd) argon2::verify_encoded(encoded, pwd).context("Stored password is an invalid string")
.context("Stored password is an invalid string")
} else { } else {
Ok(false) Ok(false)
} }
@ -48,23 +47,23 @@ impl User {
id: username.to_string(), id: username.to_string(),
userdata: UserData { userdata: UserData {
passwd: Some(hash), passwd: Some(hash),
.. Default::default() ..Default::default()
} },
} }
} }
} }
#[derive( #[derive(
Clone, Clone,
PartialEq, PartialEq,
Eq, Eq,
Debug, Debug,
Default, Default,
rkyv::Archive, rkyv::Archive,
rkyv::Serialize, rkyv::Serialize,
rkyv::Deserialize, rkyv::Deserialize,
serde::Serialize, serde::Serialize,
serde::Deserialize, serde::Deserialize,
)] )]
/// Data on an user to base decisions on /// Data on an user to base decisions on
/// ///
@ -85,12 +84,19 @@ pub struct UserData {
impl UserData { impl UserData {
pub fn new(roles: Vec<String>) -> Self { pub fn new(roles: Vec<String>) -> Self {
Self { roles, kv: HashMap::new(), passwd: None } Self {
roles,
kv: HashMap::new(),
passwd: None,
}
} }
pub fn new_with_kv(roles: Vec<String>, kv: HashMap<String, String>) -> Self { pub fn new_with_kv(roles: Vec<String>, kv: HashMap<String, String>) -> Self {
Self { roles, kv, passwd: None } Self {
roles,
kv,
passwd: None,
}
} }
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -140,7 +146,12 @@ impl UserDB {
Ok(()) Ok(())
} }
pub fn put_txn(&self, txn: &mut RwTransaction, uid: &str, user: &User) -> Result<(), db::Error> { pub fn put_txn(
&self,
txn: &mut RwTransaction,
uid: &str,
user: &User,
) -> Result<(), db::Error> {
let mut serializer = AllocSerializer::<1024>::default(); let mut serializer = AllocSerializer::<1024>::default();
serializer.serialize_value(user).expect("rkyv error"); serializer.serialize_value(user).expect("rkyv error");
let v = serializer.into_serializer().into_inner(); let v = serializer.into_serializer().into_inner();
@ -169,7 +180,8 @@ impl UserDB {
let mut out = Vec::new(); let mut out = Vec::new();
for (uid, user) in iter { for (uid, user) in iter {
let uid = unsafe { std::str::from_utf8_unchecked(uid).to_string() }; let uid = unsafe { std::str::from_utf8_unchecked(uid).to_string() };
let user: User = Deserialize::<User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap(); let user: User =
Deserialize::<User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap();
out.push((uid, user)); out.push((uid, user));
} }

View File

@ -10,8 +10,6 @@ use std::sync::Arc;
pub mod db; pub mod db;
use crate::users::db::UserData; use crate::users::db::UserData;
use crate::UserDB; use crate::UserDB;
@ -87,10 +85,9 @@ impl Users {
pub fn get_user(&self, uid: &str) -> Option<db::User> { pub fn get_user(&self, uid: &str) -> Option<db::User> {
tracing::trace!(uid, "Looking up user"); tracing::trace!(uid, "Looking up user");
self.userdb self.userdb.get(uid).unwrap().map(|user| {
.get(uid) Deserialize::<db::User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap()
.unwrap() })
.map(|user| Deserialize::<db::User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap())
} }
pub fn put_user(&self, uid: &str, user: &db::User) -> Result<(), lmdb::Error> { pub fn put_user(&self, uid: &str, user: &db::User) -> Result<(), lmdb::Error> {

View File

@ -1,20 +1,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
struct Locales { struct Locales {
map: HashMap<&'static str, HashMap<&'static str, &'static str>> map: HashMap<&'static str, HashMap<&'static str, &'static str>>,
} }
impl Locales { impl Locales {
pub fn get(&self, lang: &str, msg: &str) pub fn get(&self, lang: &str, msg: &str) -> Option<(&'static str, &'static str)> {
-> Option<(&'static str, &'static str)> self.map
{ .get(msg)
self.map.get(msg).and_then(|map| { .and_then(|map| map.get_key_value(lang).map(|(k, v)| (*k, *v)))
map.get_key_value(lang).map(|(k,v)| (*k, *v))
})
} }
pub fn available(&self, _msg: &str) -> &[&'static str] { pub fn available(&self, _msg: &str) -> &[&'static str] {
@ -22,8 +18,8 @@ impl Locales {
} }
} }
static LANG: Lazy<Locales> = Lazy::new(|| { static LANG: Lazy<Locales> = Lazy::new(|| Locales {
Locales { map: HashMap::new() } map: HashMap::new(),
}); });
struct L10NString { struct L10NString {

View File

@ -52,16 +52,16 @@
//! [Object Identifiers]: https://en.wikipedia.org/wiki/Object_identifier //! [Object Identifiers]: https://en.wikipedia.org/wiki/Object_identifier
//! [ITU]: https://en.wikipedia.org/wiki/International_Telecommunications_Union //! [ITU]: https://en.wikipedia.org/wiki/International_Telecommunications_Union
use rkyv::{Archive, Serialize}; use crate::utils::varint::VarU128;
use rkyv::ser::Serializer;
use rkyv::vec::{ArchivedVec, VecResolver}; use rkyv::vec::{ArchivedVec, VecResolver};
use rkyv::{Archive, Serialize};
use std::convert::TryFrom; use std::convert::TryFrom;
use std::ops::Deref; use std::convert::TryInto;
use std::fmt; use std::fmt;
use std::fmt::Formatter; use std::fmt::Formatter;
use rkyv::ser::Serializer; use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use crate::utils::varint::VarU128;
use std::convert::TryInto;
type Node = u128; type Node = u128;
type VarNode = VarU128; type VarNode = VarU128;
@ -69,8 +69,8 @@ type VarNode = VarU128;
/// Convenience module for quickly importing the public interface (e.g., `use oid::prelude::*`) /// Convenience module for quickly importing the public interface (e.g., `use oid::prelude::*`)
pub mod prelude { pub mod prelude {
pub use super::ObjectIdentifier; pub use super::ObjectIdentifier;
pub use super::ObjectIdentifierRoot::*;
pub use super::ObjectIdentifierError; pub use super::ObjectIdentifierError;
pub use super::ObjectIdentifierRoot::*;
pub use core::convert::{TryFrom, TryInto}; pub use core::convert::{TryFrom, TryInto};
} }
@ -132,7 +132,8 @@ impl ObjectIdentifier {
let mut parsing_big_int = false; let mut parsing_big_int = false;
let mut big_int: Node = 0; let mut big_int: Node = 0;
for i in 1..nodes.len() { for i in 1..nodes.len() {
if !parsing_big_int && nodes[i] < 128 {} else { if !parsing_big_int && nodes[i] < 128 {
} else {
if big_int > 0 { if big_int > 0 {
if big_int >= Node::MAX >> 7 { if big_int >= Node::MAX >> 7 {
return Err(ObjectIdentifierError::IllegalChildNodeValue); return Err(ObjectIdentifierError::IllegalChildNodeValue);
@ -149,9 +150,11 @@ impl ObjectIdentifier {
Ok(Self { nodes }) Ok(Self { nodes })
} }
pub fn build<B: AsRef<[Node]>>(root: ObjectIdentifierRoot, first: u8, children: B) pub fn build<B: AsRef<[Node]>>(
-> Result<Self, ObjectIdentifierError> root: ObjectIdentifierRoot,
{ first: u8,
children: B,
) -> Result<Self, ObjectIdentifierError> {
if first > 40 { if first > 40 {
return Err(ObjectIdentifierError::IllegalFirstChildNode); return Err(ObjectIdentifierError::IllegalFirstChildNode);
} }
@ -163,7 +166,9 @@ impl ObjectIdentifier {
let var: VarNode = child.into(); let var: VarNode = child.into();
vec.extend_from_slice(var.as_bytes()) vec.extend_from_slice(var.as_bytes())
} }
Ok(Self { nodes: vec.into_boxed_slice() }) Ok(Self {
nodes: vec.into_boxed_slice(),
})
} }
#[inline(always)] #[inline(always)]
@ -196,12 +201,14 @@ impl FromStr for ObjectIdentifier {
fn from_str(value: &str) -> Result<Self, Self::Err> { fn from_str(value: &str) -> Result<Self, Self::Err> {
let mut nodes = value.split("."); let mut nodes = value.split(".");
let root = nodes.next() let root = nodes
.next()
.and_then(|n| n.parse::<u8>().ok()) .and_then(|n| n.parse::<u8>().ok())
.and_then(|n| n.try_into().ok()) .and_then(|n| n.try_into().ok())
.ok_or(ObjectIdentifierError::IllegalRootNode)?; .ok_or(ObjectIdentifierError::IllegalRootNode)?;
let first = nodes.next() let first = nodes
.next()
.and_then(|n| parse_string_first_node(n).ok()) .and_then(|n| parse_string_first_node(n).ok())
.ok_or(ObjectIdentifierError::IllegalFirstChildNode)?; .ok_or(ObjectIdentifierError::IllegalFirstChildNode)?;
@ -238,7 +245,7 @@ impl fmt::Debug for ObjectIdentifier {
#[repr(transparent)] #[repr(transparent)]
pub struct ArchivedObjectIdentifier { pub struct ArchivedObjectIdentifier {
archived: ArchivedVec<u8> archived: ArchivedVec<u8>,
} }
impl Deref for ArchivedObjectIdentifier { impl Deref for ArchivedObjectIdentifier {
@ -250,8 +257,12 @@ impl Deref for ArchivedObjectIdentifier {
impl fmt::Debug for ArchivedObjectIdentifier { impl fmt::Debug for ArchivedObjectIdentifier {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}", &convert_to_string(self.archived.as_slice()) write!(
.unwrap_or_else(|e| format!("Invalid OID: {:?}", e))) f,
"{}",
&convert_to_string(self.archived.as_slice())
.unwrap_or_else(|e| format!("Invalid OID: {:?}", e))
)
} }
} }
@ -275,7 +286,8 @@ impl Archive for &'static ObjectIdentifier {
} }
impl<S: Serializer + ?Sized> Serialize<S> for ObjectIdentifier impl<S: Serializer + ?Sized> Serialize<S> for ObjectIdentifier
where [u8]: rkyv::SerializeUnsized<S> where
[u8]: rkyv::SerializeUnsized<S>,
{ {
fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> { fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
ArchivedVec::serialize_from_slice(self.nodes.as_ref(), serializer) ArchivedVec::serialize_from_slice(self.nodes.as_ref(), serializer)
@ -340,8 +352,7 @@ fn convert_to_string(nodes: &[u8]) -> Result<String, ObjectIdentifierError> {
impl Into<String> for &ObjectIdentifier { impl Into<String> for &ObjectIdentifier {
fn into(self) -> String { fn into(self) -> String {
convert_to_string(&self.nodes) convert_to_string(&self.nodes).expect("Valid OID object couldn't be serialized.")
.expect("Valid OID object couldn't be serialized.")
} }
} }
@ -468,16 +479,13 @@ mod serde_support {
} }
} }
impl ser::Serialize for ArchivedObjectIdentifier { impl ser::Serialize for ArchivedObjectIdentifier {
fn serialize<S>( fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
&self,
serializer: S,
) -> Result<S::Ok, S::Error>
where where
S: ser::Serializer, S: ser::Serializer,
{ {
if serializer.is_human_readable() { if serializer.is_human_readable() {
let encoded: String = convert_to_string(self.deref()) let encoded: String =
.expect("Failed to convert valid OID to String"); convert_to_string(self.deref()).expect("Failed to convert valid OID to String");
serializer.serialize_str(&encoded) serializer.serialize_str(&encoded)
} else { } else {
serializer.serialize_bytes(self.deref()) serializer.serialize_bytes(self.deref())
@ -498,8 +506,7 @@ pub(crate) mod tests {
children.push(rand::random()); children.push(rand::random());
} }
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 25, children) ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 25, children).unwrap()
.unwrap()
} }
#[test] #[test]
@ -508,7 +515,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::ItuT, ObjectIdentifierRoot::ItuT,
0x01, 0x01,
vec![1, 2, 3, 5, 8, 13, 21], vec![1, 2, 3, 5, 8, 13, 21],
).unwrap(); )
.unwrap();
let buffer: Vec<u8> = bincode::serialize(&expected).unwrap(); let buffer: Vec<u8> = bincode::serialize(&expected).unwrap();
let actual = bincode::deserialize(&buffer).unwrap(); let actual = bincode::deserialize(&buffer).unwrap();
assert_eq!(expected, actual); assert_eq!(expected, actual);
@ -517,11 +525,7 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_binary_root_node_0() { fn encode_binary_root_node_0() {
let expected: Vec<u8> = vec![0]; let expected: Vec<u8> = vec![0];
let oid = ObjectIdentifier::build( let oid = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]).unwrap();
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
).unwrap();
let actual: Vec<u8> = oid.into(); let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -529,11 +533,7 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_binary_root_node_1() { fn encode_binary_root_node_1() {
let expected: Vec<u8> = vec![40]; let expected: Vec<u8> = vec![40];
let oid = ObjectIdentifier::build( let oid = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]).unwrap();
ObjectIdentifierRoot::Iso,
0x00,
vec![],
).unwrap();
let actual: Vec<u8> = oid.into(); let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -541,11 +541,8 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_binary_root_node_2() { fn encode_binary_root_node_2() {
let expected: Vec<u8> = vec![80]; let expected: Vec<u8> = vec![80];
let oid = ObjectIdentifier::build( let oid =
ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]).unwrap();
0x00,
vec![],
).unwrap();
let actual: Vec<u8> = oid.into(); let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -557,7 +554,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::ItuT, ObjectIdentifierRoot::ItuT,
0x01, 0x01,
vec![1, 2, 3, 5, 8, 13, 21], vec![1, 2, 3, 5, 8, 13, 21],
).unwrap(); )
.unwrap();
let actual: Vec<u8> = oid.into(); let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -572,7 +570,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifierRoot::JointIsoItuT,
39, 39,
vec![42, 2501, 65535, 2147483647, 1235, 2352], vec![42, 2501, 65535, 2147483647, 1235, 2352],
).unwrap(); )
.unwrap();
let actual: Vec<u8> = (oid).into(); let actual: Vec<u8> = (oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -580,11 +579,7 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_string_root_node_0() { fn encode_string_root_node_0() {
let expected = "0.0"; let expected = "0.0";
let oid = ObjectIdentifier::build( let oid = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]).unwrap();
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
).unwrap();
let actual: String = (oid).into(); let actual: String = (oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -592,11 +587,7 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_string_root_node_1() { fn encode_string_root_node_1() {
let expected = "1.0"; let expected = "1.0";
let oid = ObjectIdentifier::build( let oid = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]).unwrap();
ObjectIdentifierRoot::Iso,
0x00,
vec![],
).unwrap();
let actual: String = (&oid).into(); let actual: String = (&oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -604,11 +595,8 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_string_root_node_2() { fn encode_string_root_node_2() {
let expected = "2.0"; let expected = "2.0";
let oid = ObjectIdentifier::build( let oid =
ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]).unwrap();
0x00,
vec![],
).unwrap();
let actual: String = (&oid).into(); let actual: String = (&oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -620,7 +608,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::ItuT, ObjectIdentifierRoot::ItuT,
0x01, 0x01,
vec![1, 2, 3, 5, 8, 13, 21], vec![1, 2, 3, 5, 8, 13, 21],
).unwrap(); )
.unwrap();
let actual: String = (&oid).into(); let actual: String = (&oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -632,40 +621,29 @@ pub(crate) mod tests {
ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifierRoot::JointIsoItuT,
39, 39,
vec![42, 2501, 65535, 2147483647, 1235, 2352], vec![42, 2501, 65535, 2147483647, 1235, 2352],
).unwrap(); )
.unwrap();
let actual: String = (&oid).into(); let actual: String = (&oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_binary_root_node_0() { fn parse_binary_root_node_0() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]);
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
);
let actual = vec![0x00].try_into(); let actual = vec![0x00].try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_binary_root_node_1() { fn parse_binary_root_node_1() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]);
ObjectIdentifierRoot::Iso,
0x00,
vec![],
);
let actual = vec![40].try_into(); let actual = vec![40].try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_binary_root_node_2() { fn parse_binary_root_node_2() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]);
ObjectIdentifierRoot::JointIsoItuT,
0x00,
vec![],
);
let actual = vec![80].try_into(); let actual = vec![80].try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -698,33 +676,21 @@ pub(crate) mod tests {
#[test] #[test]
fn parse_string_root_node_0() { fn parse_string_root_node_0() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]);
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
);
let actual = "0.0".try_into(); let actual = "0.0".try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_string_root_node_1() { fn parse_string_root_node_1() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]);
ObjectIdentifierRoot::Iso,
0x00,
vec![],
);
let actual = "1.0".try_into(); let actual = "1.0".try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_string_root_node_2() { fn parse_string_root_node_2() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]);
ObjectIdentifierRoot::JointIsoItuT,
0x00,
vec![],
);
let actual = "2.0".try_into(); let actual = "2.0".try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -852,10 +818,11 @@ pub(crate) mod tests {
#[test] #[test]
fn parse_string_large_children_ok() { fn parse_string_large_children_ok() {
let expected = let expected = ObjectIdentifier::build(
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifierRoot::JointIsoItuT,
25, 25,
vec![190754093376743485973207716749546715206, vec![
190754093376743485973207716749546715206,
255822649272987943607843257596365752308, 255822649272987943607843257596365752308,
15843412533224453995377625663329542022, 15843412533224453995377625663329542022,
6457999595881951503805148772927347934, 6457999595881951503805148772927347934,
@ -863,7 +830,9 @@ pub(crate) mod tests {
195548685662657784196186957311035194990, 195548685662657784196186957311035194990,
233020488258340943072303499291936117654, 233020488258340943072303499291936117654,
193307160423854019916786016773068715190, 193307160423854019916786016773068715190,
]).unwrap(); ],
)
.unwrap();
let actual = "2.25.190754093376743485973207716749546715206.\ let actual = "2.25.190754093376743485973207716749546715206.\
255822649272987943607843257596365752308.\ 255822649272987943607843257596365752308.\
15843412533224453995377625663329542022.\ 15843412533224453995377625663329542022.\
@ -871,18 +840,17 @@ pub(crate) mod tests {
19545192863105095042881850060069531734.\ 19545192863105095042881850060069531734.\
195548685662657784196186957311035194990.\ 195548685662657784196186957311035194990.\
233020488258340943072303499291936117654.\ 233020488258340943072303499291936117654.\
193307160423854019916786016773068715190".try_into().unwrap(); 193307160423854019916786016773068715190"
.try_into()
.unwrap();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn encode_to_string() { fn encode_to_string() {
let expected = String::from("1.2.3.4"); let expected = String::from("1.2.3.4");
let actual: String = ObjectIdentifier::build( let actual: String = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 2, vec![3, 4])
ObjectIdentifierRoot::Iso, .unwrap()
2,
vec![3, 4],
).unwrap()
.into(); .into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -890,11 +858,8 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_to_bytes() { fn encode_to_bytes() {
let expected = vec![0x2A, 0x03, 0x04]; let expected = vec![0x2A, 0x03, 0x04];
let actual: Vec<u8> = ObjectIdentifier::build( let actual: Vec<u8> = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 2, vec![3, 4])
ObjectIdentifierRoot::Iso, .unwrap()
2,
vec![3, 4],
).unwrap()
.into(); .into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }

View File

@ -1,11 +1,10 @@
use uuid::Uuid;
use api::general_capnp::u_u_i_d::{Builder, Reader}; use api::general_capnp::u_u_i_d::{Builder, Reader};
use uuid::Uuid;
pub fn uuid_to_api(uuid: Uuid, mut builder: Builder) { pub fn uuid_to_api(uuid: Uuid, mut builder: Builder) {
let [a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p] let [a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p] = uuid.as_u128().to_ne_bytes();
= uuid.as_u128().to_ne_bytes(); let lower = u64::from_ne_bytes([a, b, c, d, e, f, g, h]);
let lower = u64::from_ne_bytes([a,b,c,d,e,f,g,h]); let upper = u64::from_ne_bytes([i, j, k, l, m, n, o, p]);
let upper = u64::from_ne_bytes([i,j,k,l,m,n,o,p]);
builder.set_uuid0(lower); builder.set_uuid0(lower);
builder.set_uuid1(upper); builder.set_uuid1(upper);
} }
@ -13,8 +12,8 @@ pub fn uuid_to_api(uuid: Uuid, mut builder: Builder) {
pub fn api_to_uuid(reader: Reader) -> Uuid { pub fn api_to_uuid(reader: Reader) -> Uuid {
let lower: u64 = reader.reborrow().get_uuid0(); let lower: u64 = reader.reborrow().get_uuid0();
let upper: u64 = reader.get_uuid1(); let upper: u64 = reader.get_uuid1();
let [a,b,c,d,e,f,g,h] = lower.to_ne_bytes(); let [a, b, c, d, e, f, g, h] = lower.to_ne_bytes();
let [i,j,k,l,m,n,o,p] = upper.to_ne_bytes(); let [i, j, k, l, m, n, o, p] = upper.to_ne_bytes();
let num = u128::from_ne_bytes([a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p]); let num = u128::from_ne_bytes([a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p]);
Uuid::from_u128(num) Uuid::from_u128(num)
} }

View File

@ -27,7 +27,6 @@ impl<const N: usize> VarUInt<N> {
pub const fn into_bytes(self) -> [u8; N] { pub const fn into_bytes(self) -> [u8; N] {
self.bytes self.bytes
} }
} }
impl<const N: usize> Default for VarUInt<N> { impl<const N: usize> Default for VarUInt<N> {
@ -52,7 +51,7 @@ macro_rules! convert_from {
let bytes = this.as_mut_bytes(); let bytes = this.as_mut_bytes();
let mut more = 0u8; let mut more = 0u8;
let mut idx: usize = bytes.len()-1; let mut idx: usize = bytes.len() - 1;
while num > 0x7f { while num > 0x7f {
bytes[idx] = ((num & 0x7f) as u8 | more); bytes[idx] = ((num & 0x7f) as u8 | more);
@ -65,7 +64,7 @@ macro_rules! convert_from {
this.offset = idx; this.offset = idx;
this this
} }
} };
} }
macro_rules! convert_into { macro_rules! convert_into {
@ -84,7 +83,7 @@ macro_rules! convert_into {
let mut shift = 0; let mut shift = 0;
for neg in 1..=len { for neg in 1..=len {
let idx = len-neg; let idx = len - neg;
let val = (bytes[idx] & 0x7f) as $x; let val = (bytes[idx] & 0x7f) as $x;
let shifted = val << shift; let shifted = val << shift;
out |= shifted; out |= shifted;
@ -93,7 +92,7 @@ macro_rules! convert_into {
out out
} }
} };
} }
macro_rules! impl_convert_from_to { macro_rules! impl_convert_from_to {
@ -105,7 +104,7 @@ macro_rules! impl_convert_from_to {
impl Into<$num> for VarUInt<$req> { impl Into<$num> for VarUInt<$req> {
convert_into! { $num } convert_into! { $num }
} }
} };
} }
impl_convert_from_to!(u8, 2, VarU8); impl_convert_from_to!(u8, 2, VarU8);
@ -123,8 +122,9 @@ type VarUsize = VarU32;
type VarUsize = VarU16; type VarUsize = VarU16;
impl<T, const N: usize> From<&T> for VarUInt<N> impl<T, const N: usize> From<&T> for VarUInt<N>
where T: Copy, where
VarUInt<N>: From<T> T: Copy,
VarUInt<N>: From<T>,
{ {
fn from(t: &T) -> Self { fn from(t: &T) -> Self {
(*t).into() (*t).into()

View File

@ -1,12 +1,9 @@
use clap::{Arg, Command}; use clap::{Arg, Command};
use diflouroborane::{config, Diflouroborane}; use diflouroborane::{config, Diflouroborane};
use std::str::FromStr; use std::str::FromStr;
use std::{env, io, io::Write, path::PathBuf}; use std::{env, io, io::Write, path::PathBuf};
use nix::NixPath; use nix::NixPath;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
@ -125,12 +122,19 @@ fn main() -> anyhow::Result<()> {
unimplemented!() unimplemented!()
} else if matches.is_present("load") { } else if matches.is_present("load") {
let bffh = Diflouroborane::new(config)?; let bffh = Diflouroborane::new(config)?;
if bffh.users.load_file(matches.value_of("load").unwrap()).is_ok() { if bffh
.users
.load_file(matches.value_of("load").unwrap())
.is_ok()
{
tracing::info!("loaded users from {}", matches.value_of("load").unwrap()); tracing::info!("loaded users from {}", matches.value_of("load").unwrap());
} else { } else {
tracing::error!("failed to load users from {}", matches.value_of("load").unwrap()); tracing::error!(
"failed to load users from {}",
matches.value_of("load").unwrap()
);
} }
return Ok(()) return Ok(());
} else { } else {
let keylog = matches.value_of("keylog"); let keylog = matches.value_of("keylog");
// When passed an empty string (i.e no value) take the value from the env // When passed an empty string (i.e no value) take the value from the env

View File

@ -4,48 +4,57 @@ fn main() {
println!(">>> Building version number..."); println!(">>> Building version number...");
let rustc = std::env::var("RUSTC").unwrap(); let rustc = std::env::var("RUSTC").unwrap();
let out = Command::new(rustc).arg("--version") let out = Command::new(rustc)
.arg("--version")
.output() .output()
.expect("failed to run `rustc --version`"); .expect("failed to run `rustc --version`");
let rustc_version = String::from_utf8(out.stdout) let rustc_version =
.expect("rustc --version returned invalid UTF-8"); String::from_utf8(out.stdout).expect("rustc --version returned invalid UTF-8");
let rustc_version = rustc_version.trim(); let rustc_version = rustc_version.trim();
println!("cargo:rustc-env=CARGO_RUSTC_VERSION={}", rustc_version); println!("cargo:rustc-env=CARGO_RUSTC_VERSION={}", rustc_version);
println!("cargo:rerun-if-env-changed=BFFHD_BUILD_TAGGED_RELEASE"); println!("cargo:rerun-if-env-changed=BFFHD_BUILD_TAGGED_RELEASE");
let tagged_release = option_env!("BFFHD_BUILD_TAGGED_RELEASE") == Some("1"); let tagged_release = option_env!("BFFHD_BUILD_TAGGED_RELEASE") == Some("1");
let version_string = if tagged_release { let version_string = if tagged_release {
format!("{version} [{rustc}]", format!(
"{version} [{rustc}]",
version = env!("CARGO_PKG_VERSION"), version = env!("CARGO_PKG_VERSION"),
rustc = rustc_version) rustc = rustc_version
)
} else { } else {
// Build version number using the current git commit id // Build version number using the current git commit id
let out = Command::new("git").arg("rev-list") let out = Command::new("git")
.arg("rev-list")
.args(["HEAD", "-1"]) .args(["HEAD", "-1"])
.output() .output()
.expect("failed to run `git rev-list HEAD -1`"); .expect("failed to run `git rev-list HEAD -1`");
let owned_gitrev = String::from_utf8(out.stdout) let owned_gitrev =
.expect("git rev-list output was not valid UTF8"); String::from_utf8(out.stdout).expect("git rev-list output was not valid UTF8");
let gitrev = owned_gitrev.trim(); let gitrev = owned_gitrev.trim();
let abbrev = match gitrev.len(){ let abbrev = match gitrev.len() {
0 => "unknown", 0 => "unknown",
_ => &gitrev[0..9], _ => &gitrev[0..9],
}; };
let out = Command::new("git").arg("log") let out = Command::new("git")
.arg("log")
.args(["-1", "--format=%as"]) .args(["-1", "--format=%as"])
.output() .output()
.expect("failed to run `git log -1 --format=\"format:%as\"`"); .expect("failed to run `git log -1 --format=\"format:%as\"`");
let commit_date = String::from_utf8(out.stdout) let commit_date = String::from_utf8(out.stdout).expect("git log output was not valid UTF8");
.expect("git log output was not valid UTF8");
let commit_date = commit_date.trim(); let commit_date = commit_date.trim();
format!("{version} ({gitrev} {date}) [{rustc}]", format!(
version=env!("CARGO_PKG_VERSION"), "{version} ({gitrev} {date}) [{rustc}]",
gitrev=abbrev, version = env!("CARGO_PKG_VERSION"),
date=commit_date, gitrev = abbrev,
rustc=rustc_version) date = commit_date,
rustc = rustc_version
)
}; };
println!("cargo:rustc-env=BFFHD_VERSION_STRING={}", version_string); println!("cargo:rustc-env=BFFHD_VERSION_STRING={}", version_string);
println!("cargo:rustc-env=BFFHD_RELEASE_STRING=\"BFFH {}\"", version_string); println!(
"cargo:rustc-env=BFFHD_RELEASE_STRING=\"BFFH {}\"",
version_string
);
} }

View File

@ -1,5 +1,5 @@
use sdk::BoxFuture;
use sdk::initiators::{Initiator, InitiatorError, ResourceID, UpdateSink}; use sdk::initiators::{Initiator, InitiatorError, ResourceID, UpdateSink};
use sdk::BoxFuture;
#[sdk::module] #[sdk::module]
struct Dummy { struct Dummy {
@ -10,11 +10,17 @@ struct Dummy {
} }
impl Initiator for Dummy { impl Initiator for Dummy {
fn start_for(&mut self, machine: ResourceID) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> { fn start_for(
&mut self,
machine: ResourceID,
) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> {
todo!() todo!()
} }
fn run(&mut self, request: &mut UpdateSink) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> { fn run(
&mut self,
request: &mut UpdateSink,
) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> {
todo!() todo!()
} }
} }

View File

@ -1,10 +1,10 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use std::sync::Mutex;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{braced, parse_macro_input, Field, Ident, Token, Visibility, Type}; use std::sync::Mutex;
use syn::parse::{Parse, ParseStream}; use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated; use syn::punctuated::Punctuated;
use syn::token::Brace; use syn::token::Brace;
use syn::{braced, parse_macro_input, Field, Ident, Token, Type, Visibility};
mod keywords { mod keywords {
syn::custom_keyword!(initiator); syn::custom_keyword!(initiator);
@ -32,8 +32,10 @@ impl Parse for ModuleAttrs {
} else if lookahead.peek(keywords::sensor) { } else if lookahead.peek(keywords::sensor) {
Ok(ModuleAttrs::Sensor) Ok(ModuleAttrs::Sensor)
} else { } else {
Err(input.error("Module type must be empty or one of \"initiator\", \"actor\", or \ Err(input.error(
\"sensor\"")) "Module type must be empty or one of \"initiator\", \"actor\", or \
\"sensor\"",
))
} }
} }
} }

View File

@ -1,10 +1,4 @@
pub use diflouroborane::{ pub use diflouroborane::{
initiators::{ initiators::{Initiator, InitiatorError, UpdateError, UpdateSink},
UpdateSink,
UpdateError,
Initiator,
InitiatorError,
},
resource::claim::ResourceID, resource::claim::ResourceID,
}; };

View File

@ -1,5 +1,4 @@
#[forbid(private_in_public)] #[forbid(private_in_public)]
pub use sdk_proc::module; pub use sdk_proc::module;
pub use futures_util::future::BoxFuture; pub use futures_util::future::BoxFuture;

View File

@ -1,19 +1,19 @@
use executor::prelude::*;
use criterion::{black_box, criterion_group, criterion_main, Criterion}; use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::prelude::*;
fn increment(b: &mut Criterion) { fn increment(b: &mut Criterion) {
let mut sum = 0; let mut sum = 0;
let executor = Executor::new(); let executor = Executor::new();
b.bench_function("Executor::run", |b| b.iter(|| { b.bench_function("Executor::run", |b| {
executor.run( b.iter(|| {
async { executor.run(async {
(0..10_000_000).for_each(|_| { (0..10_000_000).for_each(|_| {
sum += 1; sum += 1;
}); });
}, });
); })
})); });
black_box(sum); black_box(sum);
} }

View File

@ -1,8 +1,8 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::load_balancer; use executor::load_balancer;
use executor::prelude::*; use executor::prelude::*;
use futures_timer::Delay; use futures_timer::Delay;
use std::time::Duration; use std::time::Duration;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
#[cfg(feature = "tokio-runtime")] #[cfg(feature = "tokio-runtime")]
mod benches { mod benches {
@ -27,7 +27,6 @@ mod benches {
pub fn spawn_single(b: &mut Criterion) { pub fn spawn_single(b: &mut Criterion) {
_spawn_single(b); _spawn_single(b);
} }
} }
criterion_group!(spawn, benches::spawn_lot, benches::spawn_single); criterion_group!(spawn, benches::spawn_lot, benches::spawn_single);
@ -36,29 +35,29 @@ criterion_main!(spawn);
// Benchmark for a 10K burst task spawn // Benchmark for a 10K burst task spawn
fn _spawn_lot(b: &mut Criterion) { fn _spawn_lot(b: &mut Criterion) {
let executor = Executor::new(); let executor = Executor::new();
b.bench_function("spawn_lot", |b| b.iter(|| { b.bench_function("spawn_lot", |b| {
b.iter(|| {
let _ = (0..10_000) let _ = (0..10_000)
.map(|_| { .map(|_| {
executor.spawn( executor.spawn(async {
async {
let duration = Duration::from_millis(1); let duration = Duration::from_millis(1);
Delay::new(duration).await; Delay::new(duration).await;
}, })
)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
})); })
});
} }
// Benchmark for a single task spawn // Benchmark for a single task spawn
fn _spawn_single(b: &mut Criterion) { fn _spawn_single(b: &mut Criterion) {
let executor = Executor::new(); let executor = Executor::new();
b.bench_function("spawn single", |b| b.iter(|| { b.bench_function("spawn single", |b| {
executor.spawn( b.iter(|| {
async { executor.spawn(async {
let duration = Duration::from_millis(1); let duration = Duration::from_millis(1);
Delay::new(duration).await; Delay::new(duration).await;
}, });
); })
})); });
} }

View File

@ -1,7 +1,7 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::load_balancer::{core_count, get_cores, stats, SmpStats}; use executor::load_balancer::{core_count, get_cores, stats, SmpStats};
use executor::placement; use executor::placement;
use std::thread; use std::thread;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
fn stress_stats<S: SmpStats + Sync + Send>(stats: &'static S) { fn stress_stats<S: SmpStats + Sync + Send>(stats: &'static S) {
let mut handles = Vec::with_capacity(*core_count()); let mut handles = Vec::with_capacity(*core_count());
@ -27,9 +27,11 @@ fn stress_stats<S: SmpStats + Sync + Send>(stats: &'static S) {
// 158,278 ns/iter (+/- 117,103) // 158,278 ns/iter (+/- 117,103)
fn lockless_stats_bench(b: &mut Criterion) { fn lockless_stats_bench(b: &mut Criterion) {
b.bench_function("stress_stats", |b| b.iter(|| { b.bench_function("stress_stats", |b| {
b.iter(|| {
stress_stats(stats()); stress_stats(stats());
})); })
});
} }
fn lockless_stats_bad_load(b: &mut Criterion) { fn lockless_stats_bad_load(b: &mut Criterion) {
@ -45,9 +47,11 @@ fn lockless_stats_bad_load(b: &mut Criterion) {
} }
} }
b.bench_function("get_sorted_load", |b| b.iter(|| { b.bench_function("get_sorted_load", |b| {
b.iter(|| {
let _sorted_load = stats.get_sorted_load(); let _sorted_load = stats.get_sorted_load();
})); })
});
} }
fn lockless_stats_good_load(b: &mut Criterion) { fn lockless_stats_good_load(b: &mut Criterion) {
@ -59,11 +63,17 @@ fn lockless_stats_good_load(b: &mut Criterion) {
stats.store_load(i, i); stats.store_load(i, i);
} }
b.bench_function("get_sorted_load", |b| b.iter(|| { b.bench_function("get_sorted_load", |b| {
b.iter(|| {
let _sorted_load = stats.get_sorted_load(); let _sorted_load = stats.get_sorted_load();
})); })
});
} }
criterion_group!(stats_bench, lockless_stats_bench, lockless_stats_bad_load, criterion_group!(
lockless_stats_good_load); stats_bench,
lockless_stats_bench,
lockless_stats_bad_load,
lockless_stats_good_load
);
criterion_main!(stats_bench); criterion_main!(stats_bench);

View File

@ -1,12 +1,12 @@
use executor::pool;
use executor::prelude::*;
use futures_util::{stream::FuturesUnordered, Stream};
use futures_util::{FutureExt, StreamExt};
use lightproc::prelude::RecoverableHandle;
use std::io::Write; use std::io::Write;
use std::panic::resume_unwind; use std::panic::resume_unwind;
use std::rc::Rc; use std::rc::Rc;
use std::time::Duration; use std::time::Duration;
use futures_util::{stream::FuturesUnordered, Stream};
use futures_util::{FutureExt, StreamExt};
use executor::pool;
use executor::prelude::*;
use lightproc::prelude::RecoverableHandle;
fn main() { fn main() {
tracing_subscriber::fmt() tracing_subscriber::fmt()
@ -24,9 +24,9 @@ fn main() {
let executor = Executor::new(); let executor = Executor::new();
let mut handles: FuturesUnordered<RecoverableHandle<usize>> = (0..2000).map(|n| { let mut handles: FuturesUnordered<RecoverableHandle<usize>> = (0..2000)
executor.spawn( .map(|n| {
async move { executor.spawn(async move {
let m: u64 = rand::random::<u64>() % 200; let m: u64 = rand::random::<u64>() % 200;
tracing::debug!("Will sleep {} * 1 ms", m); tracing::debug!("Will sleep {} * 1 ms", m);
// simulate some really heavy load. // simulate some really heavy load.
@ -34,9 +34,9 @@ fn main() {
async_std::task::sleep(Duration::from_millis(1)).await; async_std::task::sleep(Duration::from_millis(1)).await;
} }
return n; return n;
}, })
) })
}).collect(); .collect();
//let handle = handles.fuse().all(|opt| async move { opt.is_some() }); //let handle = handles.fuse().all(|opt| async move { opt.is_some() });
/* Futures passed to `spawn` need to be `Send` so this won't work: /* Futures passed to `spawn` need to be `Send` so this won't work:
@ -58,12 +58,12 @@ fn main() {
// However, you can't pass it a future outright but have to hand it a generator creating the // However, you can't pass it a future outright but have to hand it a generator creating the
// future on the correct thread. // future on the correct thread.
let fut = async { let fut = async {
let local_futs: FuturesUnordered<_> = (0..200).map(|ref n| { let local_futs: FuturesUnordered<_> = (0..200)
.map(|ref n| {
let n = *n; let n = *n;
let exe = executor.clone(); let exe = executor.clone();
async move { async move {
exe.spawn( exe.spawn(async {
async {
let tid = std::thread::current().id(); let tid = std::thread::current().id();
tracing::info!("spawn_local({}) is on thread {:?}", n, tid); tracing::info!("spawn_local({}) is on thread {:?}", n, tid);
exe.spawn_local(async move { exe.spawn_local(async move {
@ -86,10 +86,11 @@ fn main() {
*rc *rc
}) })
})
.await
} }
).await })
} .collect();
}).collect();
local_futs local_futs
}; };
@ -108,12 +109,10 @@ fn main() {
async_std::task::sleep(Duration::from_secs(20)).await; async_std::task::sleep(Duration::from_secs(20)).await;
tracing::info!("This is taking too long."); tracing::info!("This is taking too long.");
}; };
executor.run( executor.run(async {
async {
let res = futures_util::select! { let res = futures_util::select! {
_ = a.fuse() => {}, _ = a.fuse() => {},
_ = b.fuse() => {}, _ = b.fuse() => {},
}; };
}, });
);
} }

View File

@ -28,10 +28,10 @@
#![forbid(unused_import_braces)] #![forbid(unused_import_braces)]
pub mod load_balancer; pub mod load_balancer;
pub mod manage;
pub mod placement; pub mod placement;
pub mod pool; pub mod pool;
pub mod run; pub mod run;
pub mod manage;
mod thread_manager; mod thread_manager;
mod worker; mod worker;

View File

@ -1,6 +1,2 @@
/// View and Manage the current processes of this executor /// View and Manage the current processes of this executor
pub struct Manager { pub struct Manager {}
}

View File

@ -7,19 +7,19 @@
//! [`spawn`]: crate::pool::spawn //! [`spawn`]: crate::pool::spawn
//! [`Worker`]: crate::run_queue::Worker //! [`Worker`]: crate::run_queue::Worker
use std::cell::Cell; use crate::run::block;
use crate::thread_manager::{ThreadManager, DynamicRunner}; use crate::thread_manager::{DynamicRunner, ThreadManager};
use crate::worker::{Sleeper, WorkerThread};
use crossbeam_deque::{Injector, Stealer};
use lightproc::lightproc::LightProc; use lightproc::lightproc::LightProc;
use lightproc::recoverable_handle::RecoverableHandle; use lightproc::recoverable_handle::RecoverableHandle;
use std::cell::Cell;
use std::future::Future; use std::future::Future;
use std::iter::Iterator; use std::iter::Iterator;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use crossbeam_deque::{Injector, Stealer};
use crate::run::block;
use crate::worker::{Sleeper, WorkerThread};
#[derive(Debug)] #[derive(Debug)]
struct Spooler<'a> { struct Spooler<'a> {
@ -31,10 +31,13 @@ struct Spooler<'a> {
impl Spooler<'_> { impl Spooler<'_> {
pub fn new() -> Self { pub fn new() -> Self {
let spool = Arc::new(Injector::new()); let spool = Arc::new(Injector::new());
let threads = Box::leak(Box::new( let threads = Box::leak(Box::new(ThreadManager::new(2, AsyncRunner, spool.clone())));
ThreadManager::new(2, AsyncRunner, spool.clone())));
threads.initialize(); threads.initialize();
Self { spool, threads, _marker: PhantomData } Self {
spool,
threads,
_marker: PhantomData,
}
} }
} }
@ -53,9 +56,7 @@ impl<'a, 'executor: 'a> Executor<'executor> {
fn schedule(&self) -> impl Fn(LightProc) + 'a { fn schedule(&self) -> impl Fn(LightProc) + 'a {
let task_queue = self.spooler.spool.clone(); let task_queue = self.spooler.spool.clone();
move |lightproc: LightProc| { move |lightproc: LightProc| task_queue.push(lightproc)
task_queue.push(lightproc)
}
} }
/// ///
@ -98,8 +99,7 @@ impl<'a, 'executor: 'a> Executor<'executor> {
F: Future<Output = R> + Send + 'a, F: Future<Output = R> + Send + 'a,
R: Send + 'a, R: Send + 'a,
{ {
let (task, handle) = let (task, handle) = LightProc::recoverable(future, self.schedule());
LightProc::recoverable(future, self.schedule());
task.schedule(); task.schedule();
handle handle
} }
@ -109,8 +109,7 @@ impl<'a, 'executor: 'a> Executor<'executor> {
F: Future<Output = R> + 'a, F: Future<Output = R> + 'a,
R: Send + 'a, R: Send + 'a,
{ {
let (task, handle) = let (task, handle) = LightProc::recoverable(future, schedule_local());
LightProc::recoverable(future, schedule_local());
task.schedule(); task.schedule();
handle handle
} }
@ -174,17 +173,20 @@ impl DynamicRunner for AsyncRunner {
sleeper sleeper
} }
fn run_static<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>, park_timeout: Duration) -> ! { fn run_static<'b>(
fences: impl Iterator<Item = &'b Stealer<LightProc>>,
park_timeout: Duration,
) -> ! {
let worker = get_worker(); let worker = get_worker();
worker.run_timeout(fences, park_timeout) worker.run_timeout(fences, park_timeout)
} }
fn run_dynamic<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) -> ! { fn run_dynamic<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) -> ! {
let worker = get_worker(); let worker = get_worker();
worker.run(fences) worker.run(fences)
} }
fn run_standalone<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) { fn run_standalone<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) {
let worker = get_worker(); let worker = get_worker();
worker.run_once(fences) worker.run_once(fences)
} }
@ -196,10 +198,9 @@ thread_local! {
fn get_worker() -> &'static WorkerThread<'static, LightProc> { fn get_worker() -> &'static WorkerThread<'static, LightProc> {
WORKER.with(|cell| { WORKER.with(|cell| {
let worker = unsafe { let worker = unsafe { &*cell.as_ptr() as &'static Option<WorkerThread<_>> };
&*cell.as_ptr() as &'static Option<WorkerThread<_>> worker
}; .as_ref()
worker.as_ref()
.expect("AsyncRunner running outside Executor context") .expect("AsyncRunner running outside Executor context")
}) })
} }

View File

@ -45,13 +45,18 @@
//! Throughput hogs determined by a combination of job in / job out frequency and current scheduler task assignment frequency. //! Throughput hogs determined by a combination of job in / job out frequency and current scheduler task assignment frequency.
//! Threshold of EMA difference is eluded by machine epsilon for floating point arithmetic errors. //! Threshold of EMA difference is eluded by machine epsilon for floating point arithmetic errors.
use crate::worker::Sleeper;
use crate::{load_balancer, placement}; use crate::{load_balancer, placement};
use core::fmt; use core::fmt;
use crossbeam_channel::bounded;
use crossbeam_deque::{Injector, Stealer};
use crossbeam_queue::ArrayQueue; use crossbeam_queue::ArrayQueue;
use fmt::{Debug, Formatter}; use fmt::{Debug, Formatter};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use lightproc::lightproc::LightProc;
use placement::CoreId; use placement::CoreId;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use std::{ use std::{
sync::{ sync::{
@ -60,12 +65,7 @@ use std::{
}, },
thread, thread,
}; };
use std::sync::{Arc, RwLock};
use crossbeam_channel::bounded;
use crossbeam_deque::{Injector, Stealer};
use tracing::{debug, trace}; use tracing::{debug, trace};
use lightproc::lightproc::LightProc;
use crate::worker::Sleeper;
/// The default thread park timeout before checking for new tasks. /// The default thread park timeout before checking for new tasks.
const THREAD_PARK_TIMEOUT: Duration = Duration::from_millis(1); const THREAD_PARK_TIMEOUT: Duration = Duration::from_millis(1);
@ -113,10 +113,12 @@ lazy_static! {
pub trait DynamicRunner { pub trait DynamicRunner {
fn setup(task_queue: Arc<Injector<LightProc>>) -> Sleeper<LightProc>; fn setup(task_queue: Arc<Injector<LightProc>>) -> Sleeper<LightProc>;
fn run_static<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>, fn run_static<'b>(
park_timeout: Duration) -> !; fences: impl Iterator<Item = &'b Stealer<LightProc>>,
fn run_dynamic<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) -> !; park_timeout: Duration,
fn run_standalone<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>); ) -> !;
fn run_dynamic<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) -> !;
fn run_standalone<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>);
} }
/// The `ThreadManager` is creates and destroys worker threads depending on demand according to /// The `ThreadManager` is creates and destroys worker threads depending on demand according to
@ -183,11 +185,14 @@ impl<Runner: Debug> Debug for ThreadManager<Runner> {
} }
fmt.debug_struct("DynamicPoolManager") fmt.debug_struct("DynamicPoolManager")
.field("thread pool", &ThreadCount( .field(
"thread pool",
&ThreadCount(
&self.static_threads, &self.static_threads,
&self.dynamic_threads, &self.dynamic_threads,
&self.parked_threads.len(), &self.parked_threads.len(),
)) ),
)
.field("runner", &self.runner) .field("runner", &self.runner)
.field("last_frequency", &self.last_frequency) .field("last_frequency", &self.last_frequency)
.finish() .finish()
@ -195,7 +200,11 @@ impl<Runner: Debug> Debug for ThreadManager<Runner> {
} }
impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> { impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
pub fn new(static_threads: usize, runner: Runner, task_queue: Arc<Injector<LightProc>>) -> Self { pub fn new(
static_threads: usize,
runner: Runner,
task_queue: Arc<Injector<LightProc>>,
) -> Self {
let dynamic_threads = 1.max(num_cpus::get().checked_sub(static_threads).unwrap_or(0)); let dynamic_threads = 1.max(num_cpus::get().checked_sub(static_threads).unwrap_or(0));
let parked_threads = ArrayQueue::new(1.max(static_threads + dynamic_threads)); let parked_threads = ArrayQueue::new(1.max(static_threads + dynamic_threads));
let fences = Arc::new(RwLock::new(Vec::new())); let fences = Arc::new(RwLock::new(Vec::new()));
@ -252,7 +261,10 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
}); });
// Dynamic thread manager that will allow us to unpark threads when needed // Dynamic thread manager that will allow us to unpark threads when needed
debug!("spooling up {} dynamic worker threads", self.dynamic_threads); debug!(
"spooling up {} dynamic worker threads",
self.dynamic_threads
);
(0..self.dynamic_threads).for_each(|_| { (0..self.dynamic_threads).for_each(|_| {
let tx = tx.clone(); let tx = tx.clone();
let fencelock = fencelock.clone(); let fencelock = fencelock.clone();
@ -302,10 +314,11 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
/// Provision threads takes a number of threads that need to be made available. /// Provision threads takes a number of threads that need to be made available.
/// It will try to unpark threads from the dynamic pool, and spawn more threads if needs be. /// It will try to unpark threads from the dynamic pool, and spawn more threads if needs be.
pub fn provision_threads(&'static self, pub fn provision_threads(
&'static self,
n: usize, n: usize,
fencelock: &Arc<RwLock<Vec<Stealer<LightProc>>>>) fencelock: &Arc<RwLock<Vec<Stealer<LightProc>>>>,
{ ) {
let rem = self.unpark_thread(n); let rem = self.unpark_thread(n);
if rem != 0 { if rem != 0 {
debug!("no more threads to unpark, spawning {} new threads", rem); debug!("no more threads to unpark, spawning {} new threads", rem);
@ -391,7 +404,5 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
/// on the request rate. /// on the request rate.
/// ///
/// It uses frequency based calculation to define work. Utilizing average processing rate. /// It uses frequency based calculation to define work. Utilizing average processing rate.
fn scale_pool(&'static self) { fn scale_pool(&'static self) {}
}
} }

View File

@ -1,10 +1,10 @@
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use crossbeam_deque::{Injector, Steal, Stealer, Worker}; use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use crossbeam_queue::SegQueue; use crossbeam_queue::SegQueue;
use crossbeam_utils::sync::{Parker, Unparker}; use crossbeam_utils::sync::{Parker, Unparker};
use lightproc::prelude::LightProc; use lightproc::prelude::LightProc;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
pub trait Runnable { pub trait Runnable {
fn run(self); fn run(self);
@ -61,8 +61,14 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
let unparker = parker.unparker().clone(); let unparker = parker.unparker().clone();
( (
Self { task_queue, tasks, local_tasks, parker, _marker }, Self {
Sleeper { stealer, unparker } task_queue,
tasks,
local_tasks,
parker,
_marker,
},
Sleeper { stealer, unparker },
) )
} }
@ -71,10 +77,8 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
} }
/// Run this worker thread "forever" (i.e. until the thread panics or is otherwise killed) /// Run this worker thread "forever" (i.e. until the thread panics or is otherwise killed)
pub fn run(&self, fences: impl Iterator<Item=&'a Stealer<T>>) -> ! { pub fn run(&self, fences: impl Iterator<Item = &'a Stealer<T>>) -> ! {
let fences: Vec<Stealer<T>> = fences let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
.map(|stealer| stealer.clone())
.collect();
loop { loop {
self.run_inner(&fences); self.run_inner(&fences);
@ -82,10 +86,12 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
} }
} }
pub fn run_timeout(&self, fences: impl Iterator<Item=&'a Stealer<T>>, timeout: Duration) -> ! { pub fn run_timeout(
let fences: Vec<Stealer<T>> = fences &self,
.map(|stealer| stealer.clone()) fences: impl Iterator<Item = &'a Stealer<T>>,
.collect(); timeout: Duration,
) -> ! {
let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
loop { loop {
self.run_inner(&fences); self.run_inner(&fences);
@ -93,10 +99,8 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
} }
} }
pub fn run_once(&self, fences: impl Iterator<Item=&'a Stealer<T>>) { pub fn run_once(&self, fences: impl Iterator<Item = &'a Stealer<T>>) {
let fences: Vec<Stealer<T>> = fences let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
.map(|stealer| stealer.clone())
.collect();
self.run_inner(fences); self.run_inner(fences);
} }
@ -123,17 +127,19 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
Steal::Success(task) => { Steal::Success(task) => {
task.run(); task.run();
continue 'work; continue 'work;
}, }
// If there is no more work to steal from the global queue, try other // If there is no more work to steal from the global queue, try other
// workers next // workers next
Steal::Empty => break, Steal::Empty => break,
// If a race condition occurred try again with backoff // If a race condition occurred try again with backoff
Steal::Retry => for _ in 0..(1 << i) { Steal::Retry => {
for _ in 0..(1 << i) {
core::hint::spin_loop(); core::hint::spin_loop();
i += 1; i += 1;
}, }
}
} }
} }
@ -145,7 +151,7 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
Steal::Success(task) => { Steal::Success(task) => {
task.run(); task.run();
continue 'work; continue 'work;
}, }
// If no other worker has work to do we're done once again. // If no other worker has work to do we're done once again.
Steal::Empty => break, Steal::Empty => break,
@ -169,6 +175,6 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
} }
#[inline(always)] #[inline(always)]
fn select_fence<'a, T>(fences: impl Iterator<Item=&'a Stealer<T>>) -> Option<&'a Stealer<T>> { fn select_fence<'a, T>(fences: impl Iterator<Item = &'a Stealer<T>>) -> Option<&'a Stealer<T>> {
fences.max_by_key(|fence| fence.len()) fences.max_by_key(|fence| fence.len())
} }

View File

@ -1,8 +1,8 @@
use std::io::Write; use executor::prelude::{spawn, ProcStack};
use executor::run::run; use executor::run::run;
use std::io::Write;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use executor::prelude::{ProcStack, spawn};
#[cfg(feature = "tokio-runtime")] #[cfg(feature = "tokio-runtime")]
mod tokio_tests { mod tokio_tests {
@ -21,13 +21,11 @@ mod no_tokio_tests {
} }
fn run_test() { fn run_test() {
let handle = spawn( let handle = spawn(async {
async {
let duration = Duration::from_millis(1); let duration = Duration::from_millis(1);
thread::sleep(duration); thread::sleep(duration);
//42 //42
}, });
);
let output = run(handle, ProcStack {}); let output = run(handle, ProcStack {});

View File

@ -1,11 +1,11 @@
use executor::blocking; use executor::blocking;
use executor::prelude::ProcStack;
use executor::run::run; use executor::run::run;
use futures_util::future::join_all; use futures_util::future::join_all;
use lightproc::recoverable_handle::RecoverableHandle; use lightproc::recoverable_handle::RecoverableHandle;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use std::time::Instant; use std::time::Instant;
use executor::prelude::ProcStack;
// Test for slow joins without task bursts during joins. // Test for slow joins without task bursts during joins.
#[test] #[test]
@ -17,12 +17,10 @@ fn slow_join() {
// Send an initial batch of million bursts. // Send an initial batch of million bursts.
let handles = (0..1_000_000) let handles = (0..1_000_000)
.map(|_| { .map(|_| {
blocking::spawn_blocking( blocking::spawn_blocking(async {
async {
let duration = Duration::from_millis(1); let duration = Duration::from_millis(1);
thread::sleep(duration); thread::sleep(duration);
}, })
)
}) })
.collect::<Vec<RecoverableHandle<()>>>(); .collect::<Vec<RecoverableHandle<()>>>();
@ -35,12 +33,10 @@ fn slow_join() {
// Spawn yet another batch of work on top of it // Spawn yet another batch of work on top of it
let handles = (0..10_000) let handles = (0..10_000)
.map(|_| { .map(|_| {
blocking::spawn_blocking( blocking::spawn_blocking(async {
async {
let duration = Duration::from_millis(100); let duration = Duration::from_millis(100);
thread::sleep(duration); thread::sleep(duration);
}, })
)
}) })
.collect::<Vec<RecoverableHandle<()>>>(); .collect::<Vec<RecoverableHandle<()>>>();
@ -63,12 +59,10 @@ fn slow_join_interrupted() {
// Send an initial batch of million bursts. // Send an initial batch of million bursts.
let handles = (0..1_000_000) let handles = (0..1_000_000)
.map(|_| { .map(|_| {
blocking::spawn_blocking( blocking::spawn_blocking(async {
async {
let duration = Duration::from_millis(1); let duration = Duration::from_millis(1);
thread::sleep(duration); thread::sleep(duration);
}, })
)
}) })
.collect::<Vec<RecoverableHandle<()>>>(); .collect::<Vec<RecoverableHandle<()>>>();
@ -82,12 +76,10 @@ fn slow_join_interrupted() {
// Spawn yet another batch of work on top of it // Spawn yet another batch of work on top of it
let handles = (0..10_000) let handles = (0..10_000)
.map(|_| { .map(|_| {
blocking::spawn_blocking( blocking::spawn_blocking(async {
async {
let duration = Duration::from_millis(100); let duration = Duration::from_millis(100);
thread::sleep(duration); thread::sleep(duration);
}, })
)
}) })
.collect::<Vec<RecoverableHandle<()>>>(); .collect::<Vec<RecoverableHandle<()>>>();
@ -111,12 +103,10 @@ fn longhauling_task_join() {
// First batch of overhauling tasks // First batch of overhauling tasks
let _ = (0..100_000) let _ = (0..100_000)
.map(|_| { .map(|_| {
blocking::spawn_blocking( blocking::spawn_blocking(async {
async {
let duration = Duration::from_millis(1000); let duration = Duration::from_millis(1000);
thread::sleep(duration); thread::sleep(duration);
}, })
)
}) })
.collect::<Vec<RecoverableHandle<()>>>(); .collect::<Vec<RecoverableHandle<()>>>();
@ -127,12 +117,10 @@ fn longhauling_task_join() {
// Send yet another medium sized batch to see how it scales. // Send yet another medium sized batch to see how it scales.
let handles = (0..10_000) let handles = (0..10_000)
.map(|_| { .map(|_| {
blocking::spawn_blocking( blocking::spawn_blocking(async {
async {
let duration = Duration::from_millis(100); let duration = Duration::from_millis(100);
thread::sleep(duration); thread::sleep(duration);
}, })
)
}) })
.collect::<Vec<RecoverableHandle<()>>>(); .collect::<Vec<RecoverableHandle<()>>>();

View File

@ -1,11 +1,11 @@
use std::any::Any;
use std::fmt::Debug;
use std::ops::Deref;
use crossbeam::channel::{unbounded, Sender}; use crossbeam::channel::{unbounded, Sender};
use futures_executor as executor; use futures_executor as executor;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use lightproc::prelude::*; use lightproc::prelude::*;
use std::any::Any;
use std::fmt::Debug;
use std::future::Future; use std::future::Future;
use std::ops::Deref;
use std::thread; use std::thread;
fn spawn_on_thread<F, R>(future: F) -> RecoverableHandle<R> fn spawn_on_thread<F, R>(future: F) -> RecoverableHandle<R>
@ -30,20 +30,17 @@ where
} }
let schedule = |t| (QUEUE.deref()).send(t).unwrap(); let schedule = |t| (QUEUE.deref()).send(t).unwrap();
let (proc, handle) = LightProc::recoverable( let (proc, handle) = LightProc::recoverable(future, schedule);
future,
schedule
);
let handle = handle let handle = handle.on_panic(
.on_panic(|err: Box<dyn Any + Send>| { |err: Box<dyn Any + Send>| match err.downcast::<&'static str>() {
match err.downcast::<&'static str>() {
Ok(reason) => println!("Future panicked: {}", &reason), Ok(reason) => println!("Future panicked: {}", &reason),
Err(err) => Err(err) => println!(
println!("Future panicked with a non-text reason of typeid {:?}", "Future panicked with a non-text reason of typeid {:?}",
err.type_id()), err.type_id()
} ),
}); },
);
proc.schedule(); proc.schedule();

View File

@ -14,15 +14,10 @@ where
{ {
let (sender, receiver) = channel::unbounded(); let (sender, receiver) = channel::unbounded();
let future = async move { let future = async move { fut.await };
fut.await
};
let schedule = move |t| sender.send(t).unwrap(); let schedule = move |t| sender.send(t).unwrap();
let (proc, handle) = LightProc::build( let (proc, handle) = LightProc::build(future, schedule);
future,
schedule,
);
proc.schedule(); proc.schedule();

View File

@ -77,7 +77,8 @@ impl LightProc {
/// }); /// });
/// ``` /// ```
pub fn recoverable<'a, F, R, S>(future: F, schedule: S) -> (Self, RecoverableHandle<R>) pub fn recoverable<'a, F, R, S>(future: F, schedule: S) -> (Self, RecoverableHandle<R>)
where F: Future<Output=R> + 'a, where
F: Future<Output = R> + 'a,
R: 'a, R: 'a,
S: Fn(LightProc) + 'a, S: Fn(LightProc) + 'a,
{ {
@ -115,7 +116,8 @@ impl LightProc {
/// ); /// );
/// ``` /// ```
pub fn build<'a, F, R, S>(future: F, schedule: S) -> (Self, ProcHandle<R>) pub fn build<'a, F, R, S>(future: F, schedule: S) -> (Self, ProcHandle<R>)
where F: Future<Output=R> + 'a, where
F: Future<Output = R> + 'a,
R: 'a, R: 'a,
S: Fn(LightProc) + 'a, S: Fn(LightProc) + 'a,
{ {

View File

@ -44,12 +44,10 @@ impl ProcData {
let (flags, references) = state.parts(); let (flags, references) = state.parts();
let new = State::new(flags | CLOSED, references); let new = State::new(flags | CLOSED, references);
// Mark the proc as closed. // Mark the proc as closed.
match self.state.compare_exchange_weak( match self
state, .state
new, .compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire)
Ordering::AcqRel, {
Ordering::Acquire,
) {
Ok(_) => { Ok(_) => {
// Notify the awaiter that the proc has been closed. // Notify the awaiter that the proc has been closed.
if state.is_awaiter() { if state.is_awaiter() {
@ -117,7 +115,8 @@ impl ProcData {
// Release the lock. If we've cleared the awaiter, then also unset the awaiter flag. // Release the lock. If we've cleared the awaiter, then also unset the awaiter flag.
if new_is_none { if new_is_none {
self.state.fetch_and((!LOCKED & !AWAITER).into(), Ordering::Release); self.state
.fetch_and((!LOCKED & !AWAITER).into(), Ordering::Release);
} else { } else {
self.state.fetch_and((!LOCKED).into(), Ordering::Release); self.state.fetch_and((!LOCKED).into(), Ordering::Release);
} }
@ -142,9 +141,7 @@ impl Debug for ProcData {
.field("ref_count", &state.get_refcount()) .field("ref_count", &state.get_refcount())
.finish() .finish()
} else { } else {
fmt.debug_struct("ProcData") fmt.debug_struct("ProcData").field("state", &state).finish()
.field("state", &state)
.finish()
} }
} }
} }

View File

@ -273,9 +273,7 @@ where
let raw = Self::from_ptr(ptr); let raw = Self::from_ptr(ptr);
// Decrement the reference count. // Decrement the reference count.
let new = (*raw.pdata) let new = (*raw.pdata).state.fetch_sub(1, Ordering::AcqRel);
.state
.fetch_sub(1, Ordering::AcqRel);
let new = new.set_refcount(new.get_refcount().saturating_sub(1)); let new = new.set_refcount(new.get_refcount().saturating_sub(1));
// If this was the last reference to the proc and the `ProcHandle` has been dropped as // If this was the last reference to the proc and the `ProcHandle` has been dropped as
@ -444,13 +442,12 @@ where
// was woken and then clean up its resources. // was woken and then clean up its resources.
let (flags, references) = state.parts(); let (flags, references) = state.parts();
let flags = if state.is_closed() { let flags = if state.is_closed() {
flags & !( RUNNING | SCHEDULED ) flags & !(RUNNING | SCHEDULED)
} else { } else {
flags & !RUNNING flags & !RUNNING
}; };
let new = State::new(flags, references); let new = State::new(flags, references);
// Mark the proc as not running. // Mark the proc as not running.
match (*raw.pdata).state.compare_exchange_weak( match (*raw.pdata).state.compare_exchange_weak(
state, state,
@ -502,7 +499,7 @@ impl<'a, F, R, S> Copy for RawProc<'a, F, R, S> {}
/// A guard that closes the proc if polling its future panics. /// A guard that closes the proc if polling its future panics.
struct Guard<'a, F, R, S>(RawProc<'a, F, R, S>) struct Guard<'a, F, R, S>(RawProc<'a, F, R, S>)
where where
F: Future<Output = R> + 'a, F: Future<Output = R> + 'a,
R: 'a, R: 'a,
S: Fn(LightProc) + 'a; S: Fn(LightProc) + 'a;

View File

@ -1,9 +1,9 @@
//! //!
//! Handle for recoverable process //! Handle for recoverable process
use std::any::Any;
use crate::proc_data::ProcData; use crate::proc_data::ProcData;
use crate::proc_handle::ProcHandle; use crate::proc_handle::ProcHandle;
use crate::state::State; use crate::state::State;
use std::any::Any;
use std::fmt::{self, Debug, Formatter}; use std::fmt::{self, Debug, Formatter};
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
@ -80,12 +80,12 @@ impl<R> RecoverableHandle<R> {
/// }); /// });
/// ``` /// ```
pub fn on_panic<F>(mut self, callback: F) -> Self pub fn on_panic<F>(mut self, callback: F) -> Self
where F: FnOnce(Box<dyn Any + Send>) + Send + Sync + 'static, where
F: FnOnce(Box<dyn Any + Send>) + Send + Sync + 'static,
{ {
self.panicked = Some(Box::new(callback)); self.panicked = Some(Box::new(callback));
self self
} }
} }
impl<R> Future for RecoverableHandle<R> { impl<R> Future for RecoverableHandle<R> {
@ -102,7 +102,7 @@ impl<R> Future for RecoverableHandle<R> {
} }
Poll::Ready(None) Poll::Ready(None)
}, }
} }
} }
} }

View File

@ -73,25 +73,22 @@ bitflags::bitflags! {
#[repr(packed)] #[repr(packed)]
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct State { pub struct State {
bytes: [u8; 8] bytes: [u8; 8],
} }
impl State { impl State {
#[inline(always)] #[inline(always)]
pub const fn new(flags: StateFlags, references: u32) -> Self { pub const fn new(flags: StateFlags, references: u32) -> Self {
let [a,b,c,d] = references.to_ne_bytes(); let [a, b, c, d] = references.to_ne_bytes();
let [e,f,g,h] = flags.bits.to_ne_bytes(); let [e, f, g, h] = flags.bits.to_ne_bytes();
Self::from_bytes([a,b,c,d,e,f,g,h]) Self::from_bytes([a, b, c, d, e, f, g, h])
} }
#[inline(always)] #[inline(always)]
pub const fn parts(self: Self) -> (StateFlags, u32) { pub const fn parts(self: Self) -> (StateFlags, u32) {
let [a,b,c,d,e,f,g,h] = self.bytes; let [a, b, c, d, e, f, g, h] = self.bytes;
let refcount = u32::from_ne_bytes([a,b,c,d]); let refcount = u32::from_ne_bytes([a, b, c, d]);
let state = unsafe { let state = unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e, f, g, h])) };
StateFlags::from_bits_unchecked(u32::from_ne_bytes([e,f,g,h]))
};
(state, refcount) (state, refcount)
} }
@ -101,8 +98,8 @@ impl State {
/// Note that the reference counter only tracks the `LightProc` and `Waker`s. The `ProcHandle` is /// Note that the reference counter only tracks the `LightProc` and `Waker`s. The `ProcHandle` is
/// tracked separately by the `HANDLE` flag. /// tracked separately by the `HANDLE` flag.
pub const fn get_refcount(self) -> u32 { pub const fn get_refcount(self) -> u32 {
let [a,b,c,d,_,_,_,_] = self.bytes; let [a, b, c, d, _, _, _, _] = self.bytes;
u32::from_ne_bytes([a,b,c,d]) u32::from_ne_bytes([a, b, c, d])
} }
#[inline(always)] #[inline(always)]
@ -116,7 +113,7 @@ impl State {
#[inline(always)] #[inline(always)]
pub const fn get_flags(self) -> StateFlags { pub const fn get_flags(self) -> StateFlags {
let [_, _, _, _, e, f, g, h] = self.bytes; let [_, _, _, _, e, f, g, h] = self.bytes;
unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e,f,g,h])) } unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e, f, g, h])) }
} }
#[inline(always)] #[inline(always)]
@ -207,10 +204,10 @@ impl AtomicState {
current: State, current: State,
new: State, new: State,
success: Ordering, success: Ordering,
failure: Ordering failure: Ordering,
) -> Result<State, State> ) -> Result<State, State> {
{ self.inner
self.inner.compare_exchange(current.into_u64(), new.into_u64(), success, failure) .compare_exchange(current.into_u64(), new.into_u64(), success, failure)
.map(|u| State::from_u64(u)) .map(|u| State::from_u64(u))
.map_err(|u| State::from_u64(u)) .map_err(|u| State::from_u64(u))
} }
@ -220,37 +217,37 @@ impl AtomicState {
current: State, current: State,
new: State, new: State,
success: Ordering, success: Ordering,
failure: Ordering failure: Ordering,
) -> Result<State, State> ) -> Result<State, State> {
{ self.inner
self.inner.compare_exchange_weak(current.into_u64(), new.into_u64(), success, failure) .compare_exchange_weak(current.into_u64(), new.into_u64(), success, failure)
.map(|u| State::from_u64(u)) .map(|u| State::from_u64(u))
.map_err(|u| State::from_u64(u)) .map_err(|u| State::from_u64(u))
} }
pub fn fetch_or(&self, val: StateFlags, order: Ordering) -> State { pub fn fetch_or(&self, val: StateFlags, order: Ordering) -> State {
let [a,b,c,d] = val.bits.to_ne_bytes(); let [a, b, c, d] = val.bits.to_ne_bytes();
let store = u64::from_ne_bytes([0,0,0,0,a,b,c,d]); let store = u64::from_ne_bytes([0, 0, 0, 0, a, b, c, d]);
State::from_u64(self.inner.fetch_or(store, order)) State::from_u64(self.inner.fetch_or(store, order))
} }
pub fn fetch_and(&self, val: StateFlags, order: Ordering) -> State { pub fn fetch_and(&self, val: StateFlags, order: Ordering) -> State {
let [a,b,c,d] = val.bits.to_ne_bytes(); let [a, b, c, d] = val.bits.to_ne_bytes();
let store = u64::from_ne_bytes([!0,!0,!0,!0,a,b,c,d]); let store = u64::from_ne_bytes([!0, !0, !0, !0, a, b, c, d]);
State::from_u64(self.inner.fetch_and(store, order)) State::from_u64(self.inner.fetch_and(store, order))
} }
// FIXME: Do this properly // FIXME: Do this properly
pub fn fetch_add(&self, val: u32, order: Ordering) -> State { pub fn fetch_add(&self, val: u32, order: Ordering) -> State {
let [a,b,c,d] = val.to_ne_bytes(); let [a, b, c, d] = val.to_ne_bytes();
let store = u64::from_ne_bytes([a,b,c,d,0,0,0,0]); let store = u64::from_ne_bytes([a, b, c, d, 0, 0, 0, 0]);
State::from_u64(self.inner.fetch_add(store, order)) State::from_u64(self.inner.fetch_add(store, order))
} }
// FIXME: Do this properly // FIXME: Do this properly
pub fn fetch_sub(&self, val: u32, order: Ordering) -> State { pub fn fetch_sub(&self, val: u32, order: Ordering) -> State {
let [a,b,c,d] = val.to_ne_bytes(); let [a, b, c, d] = val.to_ne_bytes();
let store = u64::from_ne_bytes([a,b,c,d,0,0,0,0]); let store = u64::from_ne_bytes([a, b, c, d, 0, 0, 0, 0]);
State::from_u64(self.inner.fetch_sub(store, order)) State::from_u64(self.inner.fetch_sub(store, order))
} }
} }