Run rustfmt

This commit is contained in:
Nadja Reitzenstein 2022-05-05 15:50:44 +02:00
parent 475cb9b9b4
commit 2d9f30b55b
74 changed files with 1243 additions and 1053 deletions

View File

@ -1,7 +1,8 @@
use walkdir::{WalkDir, DirEntry};
use walkdir::{DirEntry, WalkDir};
fn is_hidden(entry: &DirEntry) -> bool {
entry.file_name()
entry
.file_name()
.to_str()
.map(|s| s.starts_with('.'))
.unwrap_or(false)
@ -22,15 +23,15 @@ fn main() {
.filter_map(Result::ok) // Filter all entries that access failed on
.filter(|e| !e.file_type().is_dir()) // Filter directories
// Filter non-schema files
.filter(|e| e.file_name()
.to_str()
.map(|s| s.ends_with(".capnp"))
.unwrap_or(false)
)
.filter(|e| {
e.file_name()
.to_str()
.map(|s| s.ends_with(".capnp"))
.unwrap_or(false)
})
{
println!("Collecting schema file {}", entry.path().display());
compile_command
.file(entry.path());
compile_command.file(entry.path());
}
println!("Compiling schemas...");
@ -53,16 +54,18 @@ fn main() {
.filter_map(Result::ok) // Filter all entries that access failed on
.filter(|e| !e.file_type().is_dir()) // Filter directories
// Filter non-schema files
.filter(|e| e.file_name()
.to_str()
.map(|s| s.ends_with(".capnp"))
.unwrap_or(false)
)
.filter(|e| {
e.file_name()
.to_str()
.map(|s| s.ends_with(".capnp"))
.unwrap_or(false)
})
{
println!("Collecting schema file {}", entry.path().display());
compile_command
.file(entry.path());
compile_command.file(entry.path());
}
compile_command.run().expect("Failed to generate extra API code");
compile_command
.run()
.expect("Failed to generate extra API code");
}

View File

@ -1,4 +1,3 @@
//! FabAccess generated API bindings
//!
//! This crate contains slightly nicer and better documented bindings for the FabAccess API.

View File

@ -1,6 +1,5 @@
pub use capnpc::schema_capnp;
#[cfg(feature = "generated")]
pub mod authenticationsystem_capnp {
include!(concat!(env!("OUT_DIR"), "/authenticationsystem_capnp.rs"));

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use futures_util::future;
use futures_util::future::BoxFuture;
use std::collections::HashMap;
use crate::actors::Actor;
use crate::db::ArchivedValue;

View File

@ -3,7 +3,7 @@ use crate::resources::state::State;
use crate::{Config, ResourcesHandle};
use async_compat::CompatExt;
use executor::pool::Executor;
use futures_signals::signal::{Signal};
use futures_signals::signal::Signal;
use futures_util::future::BoxFuture;
use rumqttc::{AsyncClient, ConnectionError, Event, Incoming, MqttOptions};
@ -18,15 +18,15 @@ use std::time::Duration;
use once_cell::sync::Lazy;
use rumqttc::ConnectReturnCode::Success;
use rustls::{RootCertStore};
use url::Url;
use crate::actors::dummy::Dummy;
use crate::actors::process::Process;
use crate::db::ArchivedValue;
use rustls::RootCertStore;
use url::Url;
mod shelly;
mod process;
mod dummy;
mod process;
mod shelly;
pub trait Actor {
fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()>;
@ -102,7 +102,7 @@ static ROOT_CERTS: Lazy<RootCertStore> = Lazy::new(|| {
} else {
tracing::info!(loaded, "certificates loaded");
}
},
}
Err(error) => {
tracing::error!(%error, "failed to load system certificates");
}
@ -219,8 +219,10 @@ pub fn load(executor: Executor, config: &Config, resources: ResourcesHandle) ->
.compat(),
);
let mut actor_map: HashMap<String, _> = config.actor_connections.iter()
.filter_map(|(k,v)| {
let mut actor_map: HashMap<String, _> = config
.actor_connections
.iter()
.filter_map(|(k, v)| {
if let Some(resource) = resources.get_by_id(v) {
Some((k.clone(), resource.get_signal()))
} else {
@ -258,8 +260,6 @@ fn load_single(
"Dummy" => Some(Box::new(Dummy::new(name.clone(), params.clone()))),
"Process" => Process::new(name.clone(), params).map(|a| a.into_boxed_actuator()),
"Shelly" => Some(Box::new(Shelly::new(name.clone(), client, params))),
_ => {
None
}
_ => None,
}
}

View File

@ -1,6 +1,6 @@
use futures_util::future::BoxFuture;
use std::collections::HashMap;
use std::process::{Command, Stdio};
use futures_util::future::BoxFuture;
use crate::actors::Actor;
use crate::db::ArchivedValue;
@ -16,11 +16,10 @@ pub struct Process {
impl Process {
pub fn new(name: String, params: &HashMap<String, String>) -> Option<Self> {
let cmd = params.get("cmd").map(|s| s.to_string())?;
let args = params.get("args").map(|argv|
argv.split_whitespace()
.map(|s| s.to_string())
.collect())
.unwrap_or_else(Vec::new);
let args = params
.get("args")
.map(|argv| argv.split_whitespace().map(|s| s.to_string()).collect())
.unwrap_or_else(Vec::new);
Some(Self { name, cmd, args })
}
@ -48,41 +47,42 @@ impl Actor for Process {
command.arg("inuse").arg(by.id.as_str());
}
ArchivedStatus::ToCheck(by) => {
command.arg("tocheck")
.arg(by.id.as_str());
command.arg("tocheck").arg(by.id.as_str());
}
ArchivedStatus::Blocked(by) => {
command.arg("blocked")
.arg(by.id.as_str());
command.arg("blocked").arg(by.id.as_str());
}
ArchivedStatus::Disabled => {
command.arg("disabled");
}
ArchivedStatus::Disabled => { command.arg("disabled"); },
ArchivedStatus::Reserved(by) => {
command.arg("reserved")
.arg(by.id.as_str());
command.arg("reserved").arg(by.id.as_str());
}
}
let name = self.name.clone();
Box::pin(async move { match command.output() {
Ok(retv) if retv.status.success() => {
tracing::trace!("Actor was successful");
let outstr = String::from_utf8_lossy(&retv.stdout);
for line in outstr.lines() {
tracing::debug!(%name, %line, "actor stdout");
}
}
Ok(retv) => {
tracing::warn!(%name, ?state, code=?retv.status,
"Actor returned nonzero exitcode"
);
if !retv.stderr.is_empty() {
let errstr = String::from_utf8_lossy(&retv.stderr);
for line in errstr.lines() {
tracing::warn!(%name, %line, "actor stderr");
Box::pin(async move {
match command.output() {
Ok(retv) if retv.status.success() => {
tracing::trace!("Actor was successful");
let outstr = String::from_utf8_lossy(&retv.stdout);
for line in outstr.lines() {
tracing::debug!(%name, %line, "actor stdout");
}
}
Ok(retv) => {
tracing::warn!(%name, ?state, code=?retv.status,
"Actor returned nonzero exitcode"
);
if !retv.stderr.is_empty() {
let errstr = String::from_utf8_lossy(&retv.stderr);
for line in errstr.lines() {
tracing::warn!(%name, %line, "actor stderr");
}
}
}
Err(error) => tracing::warn!(%name, ?error, "process actor failed to run cmd"),
}
Err(error) => tracing::warn!(%name, ?error, "process actor failed to run cmd"),
}})
})
}
}

View File

@ -1,11 +1,11 @@
use std::collections::HashMap;
use futures_util::future::BoxFuture;
use std::collections::HashMap;
use rumqttc::{AsyncClient, QoS};
use crate::actors::Actor;
use crate::db::ArchivedValue;
use crate::resources::modules::fabaccess::ArchivedStatus;
use crate::resources::state::State;
use rumqttc::{AsyncClient, QoS};
/// An actuator for a Shellie connected listening on one MQTT broker
///
@ -28,7 +28,11 @@ impl Shelly {
tracing::debug!(%name,%topic,"Starting shelly module");
Shelly { name, client, topic, }
Shelly {
name,
client,
topic,
}
}
/// Set the name to a new one. This changes the shelly that will be activated
@ -38,7 +42,6 @@ impl Shelly {
}
}
impl Actor for Shelly {
fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()> {
tracing::debug!(?state, name=%self.name,

View File

@ -1,11 +1,11 @@
use once_cell::sync::OnceCell;
use std::fs::{File, OpenOptions};
use std::io;
use std::io::{LineWriter, Write};
use std::sync::Mutex;
use once_cell::sync::OnceCell;
use crate::Config;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use serde_json::Serializer;
pub static AUDIT: OnceCell<AuditLog> = OnceCell::new();
@ -26,7 +26,10 @@ impl AuditLog {
pub fn new(config: &Config) -> io::Result<&'static Self> {
AUDIT.get_or_try_init(|| {
tracing::debug!(path = %config.auditlog_path.display(), "Initializing audit log");
let fd = OpenOptions::new().create(true).append(true).open(&config.auditlog_path)?;
let fd = OpenOptions::new()
.create(true)
.append(true)
.open(&config.auditlog_path)?;
let writer = Mutex::new(LineWriter::new(fd));
Ok(Self { writer })
})
@ -34,7 +37,11 @@ impl AuditLog {
pub fn log(&self, machine: &str, state: &str) -> io::Result<()> {
let timestamp = chrono::Utc::now().timestamp();
let line = AuditLogLine { timestamp, machine, state };
let line = AuditLogLine {
timestamp,
machine,
state,
};
tracing::debug!(?line, "writing audit log line");
@ -42,7 +49,8 @@ impl AuditLog {
let mut writer: &mut LineWriter<File> = &mut *guard;
let mut ser = Serializer::new(&mut writer);
line.serialize(&mut ser).expect("failed to serialize audit log line");
line.serialize(&mut ser)
.expect("failed to serialize audit log line");
writer.write("\n".as_bytes())?;
Ok(())
}

View File

@ -19,8 +19,8 @@ pub static FABFIRE: Mechanism = Mechanism {
first: Side::Client,
};
use rsasl::property::{Property, PropertyDefinition, PropertyQ};
use std::marker::PhantomData;
use rsasl::property::{Property, PropertyQ, PropertyDefinition};
// All Property types must implement Debug.
#[derive(Debug)]
// The `PhantomData` in the constructor is only used so external crates can't construct this type.

View File

@ -1,17 +1,17 @@
use std::fmt::{Debug, Display, Formatter};
use std::io::Write;
use desfire::desfire::desfire::MAX_BYTES_PER_TRANSACTION;
use desfire::desfire::Desfire;
use desfire::error::Error as DesfireError;
use desfire::iso7816_4::apduresponse::APDUResponse;
use rsasl::error::{MechanismError, MechanismErrorKind, SASLError, SessionError};
use rsasl::mechanism::Authentication;
use rsasl::SASL;
use rsasl::session::{SessionData, StepResult};
use serde::{Deserialize, Serialize};
use desfire::desfire::Desfire;
use desfire::iso7816_4::apduresponse::APDUResponse;
use desfire::error::{Error as DesfireError};
use std::convert::TryFrom;
use std::sync::Arc;
use desfire::desfire::desfire::MAX_BYTES_PER_TRANSACTION;
use rsasl::property::AuthId;
use rsasl::session::{SessionData, StepResult};
use rsasl::SASL;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::fmt::{Debug, Display, Formatter};
use std::io::Write;
use std::sync::Arc;
use crate::authentication::fabfire::FabFireCardKey;
@ -37,7 +37,9 @@ impl Debug for FabFireError {
FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic),
FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token),
FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn),
FabFireError::InvalidCredentials(credentials) => write!(f, "InvalidCredentials: {}", credentials),
FabFireError::InvalidCredentials(credentials) => {
write!(f, "InvalidCredentials: {}", credentials)
}
FabFireError::Session(err) => write!(f, "Session: {}", err),
}
}
@ -53,7 +55,9 @@ impl Display for FabFireError {
FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic),
FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token),
FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn),
FabFireError::InvalidCredentials(credentials) => write!(f, "InvalidCredentials: {}", credentials),
FabFireError::InvalidCredentials(credentials) => {
write!(f, "InvalidCredentials: {}", credentials)
}
FabFireError::Session(err) => write!(f, "Session: {}", err),
}
}
@ -107,16 +111,22 @@ enum CardCommand {
addn_txt: Option<String>,
},
sendPICC {
#[serde(deserialize_with = "hex::deserialize", serialize_with = "hex::serialize_upper")]
data: Vec<u8>
#[serde(
deserialize_with = "hex::deserialize",
serialize_with = "hex::serialize_upper"
)]
data: Vec<u8>,
},
readPICC {
#[serde(deserialize_with = "hex::deserialize", serialize_with = "hex::serialize_upper")]
data: Vec<u8>
#[serde(
deserialize_with = "hex::deserialize",
serialize_with = "hex::serialize_upper"
)]
data: Vec<u8>,
},
haltPICC,
Key {
data: String
data: String,
},
ConfirmUser,
}
@ -145,18 +155,35 @@ const MAGIC: &'static str = "FABACCESS\0DESFIRE\01.0\0";
impl FabFire {
pub fn new_server(_sasl: &SASL) -> Result<Box<dyn Authentication>, SASLError> {
Ok(Box::new(Self { step: Step::New, card_info: None, key_info: None, auth_info: None, app_id: 1, local_urn: "urn:fabaccess:lab:innovisionlab".to_string(), desfire: Desfire { card: None, session_key: None, cbc_iv: None } }))
Ok(Box::new(Self {
step: Step::New,
card_info: None,
key_info: None,
auth_info: None,
app_id: 1,
local_urn: "urn:fabaccess:lab:innovisionlab".to_string(),
desfire: Desfire {
card: None,
session_key: None,
cbc_iv: None,
},
}))
}
}
impl Authentication for FabFire {
fn step(&mut self, session: &mut SessionData, input: Option<&[u8]>, writer: &mut dyn Write) -> StepResult {
fn step(
&mut self,
session: &mut SessionData,
input: Option<&[u8]>,
writer: &mut dyn Write,
) -> StepResult {
match self.step {
Step::New => {
tracing::trace!("Step: New");
//receive card info (especially card UID) from reader
return match input {
None => { Err(SessionError::InputDataRequired) }
None => Err(SessionError::InputDataRequired),
Some(cardinfo) => {
self.card_info = match serde_json::from_slice(cardinfo) {
Ok(card_info) => Some(card_info),
@ -170,7 +197,10 @@ impl Authentication for FabFire {
Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data,
Err(e) => {
tracing::error!("Failed to convert APDUCommand to Vec<u8>: {:?}", e);
tracing::error!(
"Failed to convert APDUCommand to Vec<u8>: {:?}",
e
);
return Err(FabFireError::SerializationError.into());
}
},
@ -183,7 +213,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) {
Ok(send_buf) => {
self.step = Step::SelectApp;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?;
writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
}
Err(e) => {
@ -198,30 +230,39 @@ impl Authentication for FabFire {
tracing::trace!("Step: SelectApp");
// check that we successfully selected the application
let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); }
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) {
None => {
return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response,
Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into());
}
}
},
};
let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) }
CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => {
tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into());
}
};
apdu_response.check().map_err(|e| FabFireError::CardError(e))?;
apdu_response
.check()
.map_err(|e| FabFireError::CardError(e))?;
// request the contents of the file containing the magic string
const MAGIC_FILE_ID: u8 = 0x01;
let buf = match self.desfire.read_data_chunk_cmd(MAGIC_FILE_ID, 0, MAGIC.len()) {
let buf = match self
.desfire
.read_data_chunk_cmd(MAGIC_FILE_ID, 0, MAGIC.len())
{
Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data,
Err(e) => {
@ -238,7 +279,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) {
Ok(send_buf) => {
self.step = Step::VerifyMagic;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?;
writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
}
Err(e) => {
@ -251,25 +294,28 @@ impl Authentication for FabFire {
tracing::trace!("Step: VerifyMagic");
// verify the magic string to determine that we have a valid fabfire card
let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); }
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) {
None => {
return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response,
Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into());
}
}
},
};
let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) }
CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => {
tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into());
}
};
match apdu_response.check() {
Ok(_) => {
match apdu_response.body {
@ -291,11 +337,15 @@ impl Authentication for FabFire {
}
}
// request the contents of the file containing the URN
const URN_FILE_ID: u8 = 0x02;
let buf = match self.desfire.read_data_chunk_cmd(URN_FILE_ID, 0, self.local_urn.as_bytes().len()) { // TODO: support urn longer than 47 Bytes
let buf = match self.desfire.read_data_chunk_cmd(
URN_FILE_ID,
0,
self.local_urn.as_bytes().len(),
) {
// TODO: support urn longer than 47 Bytes
Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data,
Err(e) => {
@ -312,7 +362,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) {
Ok(send_buf) => {
self.step = Step::GetURN;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?;
writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
}
Err(e) => {
@ -325,32 +377,39 @@ impl Authentication for FabFire {
tracing::trace!("Step: GetURN");
// parse the urn and match it to our local urn
let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); }
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) {
None => {
return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response,
Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into());
}
}
},
};
let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) }
CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => {
tracing::error!("Unexpected response: {:?}", response);
tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into());
}
};
match apdu_response.check() {
Ok(_) => {
match apdu_response.body {
Some(data) => {
let received_urn = String::from_utf8(data).unwrap();
if received_urn != self.local_urn {
tracing::error!("URN mismatch: {:?} != {:?}", received_urn, self.local_urn);
tracing::error!(
"URN mismatch: {:?} != {:?}",
received_urn,
self.local_urn
);
return Err(FabFireError::ParseError.into());
}
}
@ -361,14 +420,19 @@ impl Authentication for FabFire {
};
}
Err(e) => {
tracing::error!("Got invalid APDUResponse: {:?}", e);
tracing::error!("Got invalid APDUResponse: {:?}", e);
return Err(FabFireError::ParseError.into());
}
}
// request the contents of the file containing the URN
const TOKEN_FILE_ID: u8 = 0x03;
let buf = match self.desfire.read_data_chunk_cmd(TOKEN_FILE_ID, 0, MAX_BYTES_PER_TRANSACTION) { // TODO: support data longer than 47 Bytes
let buf = match self.desfire.read_data_chunk_cmd(
TOKEN_FILE_ID,
0,
MAX_BYTES_PER_TRANSACTION,
) {
// TODO: support data longer than 47 Bytes
Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data,
Err(e) => {
@ -385,7 +449,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) {
Ok(send_buf) => {
self.step = Step::GetToken;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?;
writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
}
Err(e) => {
@ -398,43 +464,52 @@ impl Authentication for FabFire {
// println!("Step: GetToken");
// parse the token and select the appropriate user
let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); }
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) {
None => {
return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response,
Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into());
}
}
},
};
let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) }
CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => {
tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into());
}
};
match apdu_response.check() {
Ok(_) => {
match apdu_response.body {
Some(data) => {
let token = String::from_utf8(data).unwrap();
session.set_property::<AuthId>(Arc::new(token.trim_matches(char::from(0)).to_string()));
let key = match session.get_property_or_callback::<FabFireCardKey>() {
session.set_property::<AuthId>(Arc::new(
token.trim_matches(char::from(0)).to_string(),
));
let key = match session.get_property_or_callback::<FabFireCardKey>()
{
Ok(Some(key)) => Box::from(key.as_slice()),
Ok(None) => {
tracing::error!("No keys on file for token");
return Err(FabFireError::InvalidCredentials("No keys on file for token".to_string()).into());
return Err(FabFireError::InvalidCredentials(
"No keys on file for token".to_string(),
)
.into());
}
Err(e) => {
tracing::error!("Failed to get key: {:?}", e);
return Err(FabFireError::Session(e).into());
}
};
self.key_info = Some(KeyInfo{ key_id: 0x01, key });
self.key_info = Some(KeyInfo { key_id: 0x01, key });
}
None => {
tracing::error!("No data in response");
@ -448,7 +523,10 @@ impl Authentication for FabFire {
}
}
let buf = match self.desfire.authenticate_iso_aes_challenge_cmd(self.key_info.as_ref().unwrap().key_id) {
let buf = match self
.desfire
.authenticate_iso_aes_challenge_cmd(self.key_info.as_ref().unwrap().key_id)
{
Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data,
Err(e) => {
@ -465,7 +543,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) {
Ok(send_buf) => {
self.step = Step::Authenticate1;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?;
writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
}
Err(e) => {
@ -477,25 +557,28 @@ impl Authentication for FabFire {
Step::Authenticate1 => {
tracing::trace!("Step: Authenticate1");
let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); }
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) {
None => {
return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response,
Err(e) => {
tracing::error!("Failed to deserialize response: {:?}", e);
return Err(e.into());
}
}
},
};
let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) }
CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => {
tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into());
}
};
match apdu_response.check() {
Ok(_) => {
match apdu_response.body {
@ -506,13 +589,19 @@ impl Authentication for FabFire {
//TODO: Check if we need a CSPRNG here
let rnd_a: [u8; 16] = rand::random();
let (cmd_challenge_response,
let (cmd_challenge_response, rnd_b, iv) = self
.desfire
.authenticate_iso_aes_response_cmd(
rnd_b_enc,
&*(self.key_info.as_ref().unwrap().key),
&rnd_a,
)
.unwrap();
self.auth_info = Some(AuthInfo {
rnd_a: Vec::<u8>::from(rnd_a),
rnd_b,
iv) = self.desfire.authenticate_iso_aes_response_cmd(
rnd_b_enc,
&*(self.key_info.as_ref().unwrap().key),
&rnd_a).unwrap();
self.auth_info = Some(AuthInfo { rnd_a: Vec::<u8>::from(rnd_a), rnd_b, iv });
iv,
});
let buf = match Vec::<u8>::try_from(cmd_challenge_response) {
Ok(data) => data,
Err(e) => {
@ -524,7 +613,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) {
Ok(send_buf) => {
self.step = Step::Authenticate2;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?;
writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
}
Err(e) => {
@ -548,58 +639,73 @@ impl Authentication for FabFire {
Step::Authenticate2 => {
// println!("Step: Authenticate2");
let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); }
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) {
None => {
return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response,
Err(e) => {
tracing::error!("Failed to deserialize response: {:?}", e);
return Err(e.into());
}
}
},
};
let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) }
CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => {
tracing::error!("Got invalid response: {:?}", response);
return Err(FabFireError::ParseError.into());
}
};
match apdu_response.check() {
Ok(_) => {
match apdu_response.body {
Some(data) => {
match self.auth_info.as_ref() {
None => { return Err(FabFireError::ParseError.into()); }
Some(auth_info) => {
if self.desfire.authenticate_iso_aes_verify(
Some(data) => match self.auth_info.as_ref() {
None => {
return Err(FabFireError::ParseError.into());
}
Some(auth_info) => {
if self
.desfire
.authenticate_iso_aes_verify(
data.as_slice(),
auth_info.rnd_a.as_slice(),
auth_info.rnd_b.as_slice(), &*(self.key_info.as_ref().unwrap().key),
auth_info.iv.as_slice()).is_ok() {
let cmd = CardCommand::message{
msg_id: Some(4),
clr_txt: None,
addn_txt: Some("".to_string()),
};
return match serde_json::to_vec(&cmd) {
Ok(send_buf) => {
self.step = Step::Authenticate1;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?;
return Ok(rsasl::session::Step::Done(Some(send_buf.len())))
}
Err(e) => {
tracing::error!("Failed to serialize command: {:?}", e);
Err(FabFireError::SerializationError.into())
}
};
}
auth_info.rnd_b.as_slice(),
&*(self.key_info.as_ref().unwrap().key),
auth_info.iv.as_slice(),
)
.is_ok()
{
let cmd = CardCommand::message {
msg_id: Some(4),
clr_txt: None,
addn_txt: Some("".to_string()),
};
return match serde_json::to_vec(&cmd) {
Ok(send_buf) => {
self.step = Step::Authenticate1;
writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
return Ok(rsasl::session::Step::Done(Some(
send_buf.len(),
)));
}
Err(e) => {
tracing::error!(
"Failed to serialize command: {:?}",
e
);
Err(FabFireError::SerializationError.into())
}
};
}
}
}
},
None => {
tracing::error!("got empty response");
return Err(FabFireError::ParseError.into());
@ -608,7 +714,9 @@ impl Authentication for FabFire {
}
Err(_e) => {
tracing::error!("Got invalid response: {:?}", apdu_response);
return Err(FabFireError::InvalidCredentials(format!("{}", apdu_response)).into());
return Err(
FabFireError::InvalidCredentials(format!("{}", apdu_response)).into(),
);
}
}
}
@ -616,4 +724,4 @@ impl Authentication for FabFire {
return Ok(rsasl::session::Step::Done(None));
}
}
}

View File

@ -1,6 +1,5 @@
use crate::users::Users;
use rsasl::error::{SessionError};
use rsasl::error::SessionError;
use rsasl::mechname::Mechname;
use rsasl::property::{AuthId, Password};
use rsasl::session::{Session, SessionData};
@ -31,13 +30,19 @@ impl rsasl::callback::Callback for Callback {
match property {
fabfire::FABFIRECARDKEY => {
let authcid = session.get_property_or_callback::<AuthId>()?;
let user = self.users.get_user(authcid.unwrap().as_ref())
let user = self
.users
.get_user(authcid.unwrap().as_ref())
.ok_or(SessionError::AuthenticationFailure)?;
let kv = user.userdata.kv.get("cardkey")
let kv = user
.userdata
.kv
.get("cardkey")
.ok_or(SessionError::AuthenticationFailure)?;
let card_key = <[u8; 16]>::try_from(hex::decode(kv)
.map_err(|_| SessionError::AuthenticationFailure)?)
.map_err(|_| SessionError::AuthenticationFailure)?;
let card_key = <[u8; 16]>::try_from(
hex::decode(kv).map_err(|_| SessionError::AuthenticationFailure)?,
)
.map_err(|_| SessionError::AuthenticationFailure)?;
session.set_property::<FabFireCardKey>(Arc::new(card_key));
Ok(())
}
@ -60,9 +65,7 @@ impl rsasl::callback::Callback for Callback {
.ok_or(SessionError::no_property::<AuthId>())?;
tracing::debug!(authid=%authnid, "SIMPLE validation requested");
if let Some(user) = self
.users
.get_user(authnid.as_str()) {
if let Some(user) = self.users.get_user(authnid.as_str()) {
let passwd = session
.get_property::<Password>()
.ok_or(SessionError::no_property::<Password>())?;
@ -84,7 +87,7 @@ impl rsasl::callback::Callback for Callback {
_ => {
tracing::error!(?validation, "Unimplemented validation requested");
Err(SessionError::no_validate(validation))
},
}
}
}
}
@ -111,10 +114,12 @@ impl AuthenticationHandle {
let mut rsasl = SASL::new();
rsasl.install_callback(Arc::new(Callback::new(userdb)));
let mechs: Vec<&'static str> = rsasl.server_mech_list().into_iter()
let mechs: Vec<&'static str> = rsasl
.server_mech_list()
.into_iter()
.map(|m| m.mechanism.as_str())
.collect();
tracing::info!(available_mechs=mechs.len(), "initialized sasl backend");
tracing::info!(available_mechs = mechs.len(), "initialized sasl backend");
tracing::debug!(?mechs, "available mechs");
Self {

View File

@ -1,9 +1,6 @@
use crate::authorization::roles::{Roles};
use crate::authorization::roles::Roles;
use crate::Users;
pub mod permissions;
pub mod roles;
@ -22,4 +19,4 @@ impl AuthorizationHandle {
let user = self.users.get_user(uid.as_ref())?;
Some(user.userdata.roles.clone())
}
}
}

View File

@ -1,10 +1,9 @@
//! Access control logic
//!
use std::fmt;
use std::cmp::Ordering;
use std::convert::{TryFrom, Into};
use std::convert::{Into, TryFrom};
use std::fmt;
fn is_sep_char(c: char) -> bool {
c == '.'
@ -20,7 +19,7 @@ pub struct PrivilegesBuf {
/// Which permission is required to write parts of this thing
pub write: PermissionBuf,
/// Which permission is required to manage all parts of this thing
pub manage: PermissionBuf
pub manage: PermissionBuf,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
@ -39,13 +38,17 @@ impl PermissionBuf {
#[inline(always)]
/// Allocate an empty `PermissionBuf`
pub fn new() -> Self {
PermissionBuf { inner: String::new() }
PermissionBuf {
inner: String::new(),
}
}
#[inline(always)]
/// Allocate a `PermissionBuf` with the given capacity given to the internal [`String`]
pub fn with_capacity(cap: usize) -> Self {
PermissionBuf { inner: String::with_capacity(cap) }
PermissionBuf {
inner: String::with_capacity(cap),
}
}
#[inline(always)]
@ -59,7 +62,13 @@ impl PermissionBuf {
pub fn _push(&mut self, perm: &Permission) {
// in general we always need a separator unless the last byte is one or the string is empty
let need_sep = self.inner.chars().rev().next().map(|c| !is_sep_char(c)).unwrap_or(false);
let need_sep = self
.inner
.chars()
.rev()
.next()
.map(|c| !is_sep_char(c))
.unwrap_or(false);
if need_sep {
self.inner.push('.')
}
@ -73,7 +82,9 @@ impl PermissionBuf {
#[inline]
pub fn from_perm(perm: &Permission) -> Self {
Self { inner: perm.as_str().to_string() }
Self {
inner: perm.as_str().to_string(),
}
}
#[inline(always)]
@ -119,7 +130,7 @@ impl fmt::Display for PermissionBuf {
#[derive(PartialEq, Eq, Hash, Debug)]
#[repr(transparent)]
/// A borrowed permission string
///
///
/// Permissions have total equality and partial ordering.
/// Specifically permissions on the same path in a tree can be compared for specificity.
/// This means that ```(bffh.perm) > (bffh.perm.sub) == true```
@ -141,7 +152,7 @@ impl Permission {
}
#[inline(always)]
pub fn iter(&self) -> std::str::Split<char> {
pub fn iter(&self) -> std::str::Split<char> {
self.0.split('.')
}
}
@ -162,12 +173,14 @@ impl PartialOrd for Permission {
}
}
match (l,r) {
match (l, r) {
(None, None) => Some(Ordering::Equal),
(Some(_), None) => Some(Ordering::Less),
(None, Some(_)) => Some(Ordering::Greater),
(Some(_), Some(_)) => unreachable!("Broken contract in Permission::partial_cmp: sides \
should never be both Some!"),
(Some(_), Some(_)) => unreachable!(
"Broken contract in Permission::partial_cmp: sides \
should never be both Some!"
),
}
}
}
@ -183,7 +196,7 @@ impl AsRef<Permission> for Permission {
#[serde(try_from = "String")]
#[serde(into = "String")]
pub enum PermRule {
/// The permission is precise,
/// The permission is precise,
///
/// i.e. `Base("bffh.perm")` grants bffh.perm but does not grant permission for bffh.perm.sub
Base(PermissionBuf),
@ -208,7 +221,7 @@ impl PermRule {
pub fn match_perm<P: AsRef<Permission> + ?Sized>(&self, perm: &P) -> bool {
match self {
PermRule::Base(ref base) => base.as_permission() == perm.as_ref(),
PermRule::Children(ref parent) => parent.as_permission() > perm.as_ref() ,
PermRule::Children(ref parent) => parent.as_permission() > perm.as_ref(),
PermRule::Subtree(ref parent) => parent.as_permission() >= perm.as_ref(),
}
}
@ -217,12 +230,9 @@ impl PermRule {
impl fmt::Display for PermRule {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PermRule::Base(perm)
=> write!(f, "{}", perm),
PermRule::Children(parent)
=> write!(f,"{}.+", parent),
PermRule::Subtree(parent)
=> write!(f,"{}.*", parent),
PermRule::Base(perm) => write!(f, "{}", perm),
PermRule::Children(parent) => write!(f, "{}.+", parent),
PermRule::Subtree(parent) => write!(f, "{}.*", parent),
}
}
}
@ -234,7 +244,7 @@ impl Into<String> for PermRule {
PermRule::Children(mut perm) => {
perm.push(Permission::new("+"));
perm.into_string()
},
}
PermRule::Subtree(mut perm) => {
perm.push(Permission::new("+"));
perm.into_string()
@ -252,15 +262,19 @@ impl TryFrom<String> for PermRule {
if len <= 2 {
Err("Input string for PermRule is too short")
} else {
match &input[len-2..len] {
match &input[len - 2..len] {
".+" => {
input.truncate(len-2);
Ok(PermRule::Children(PermissionBuf::from_string_unchecked(input)))
},
input.truncate(len - 2);
Ok(PermRule::Children(PermissionBuf::from_string_unchecked(
input,
)))
}
".*" => {
input.truncate(len-2);
Ok(PermRule::Subtree(PermissionBuf::from_string_unchecked(input)))
},
input.truncate(len - 2);
Ok(PermRule::Subtree(PermissionBuf::from_string_unchecked(
input,
)))
}
_ => Ok(PermRule::Base(PermissionBuf::from_string_unchecked(input))),
}
}
@ -273,8 +287,10 @@ mod tests {
#[test]
fn permission_ord_test() {
assert!(PermissionBuf::from_string_unchecked("bffh.perm".to_string())
> PermissionBuf::from_string_unchecked("bffh.perm.sub".to_string()));
assert!(
PermissionBuf::from_string_unchecked("bffh.perm".to_string())
> PermissionBuf::from_string_unchecked("bffh.perm.sub".to_string())
);
}
#[test]
@ -316,11 +332,9 @@ mod tests {
fn format_and_read_compatible() {
use std::convert::TryInto;
let testdata = vec![
("testrole", "testsource"),
("", "norole"),
("nosource", "")
].into_iter().map(|(n,s)| (n.to_string(), s.to_string()));
let testdata = vec![("testrole", "testsource"), ("", "norole"), ("nosource", "")]
.into_iter()
.map(|(n, s)| (n.to_string(), s.to_string()));
for (name, source) in testdata {
let role = RoleIdentifier { name, source };
@ -337,19 +351,24 @@ mod tests {
}
}
#[test]
fn rules_from_string_test() {
assert_eq!(
PermRule::Base(PermissionBuf::from_string_unchecked("bffh.perm".to_string())),
PermRule::Base(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm".to_string()).unwrap()
);
assert_eq!(
PermRule::Children(PermissionBuf::from_string_unchecked("bffh.perm".to_string())),
PermRule::Children(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm.+".to_string()).unwrap()
);
assert_eq!(
PermRule::Subtree(PermissionBuf::from_string_unchecked("bffh.perm".to_string())),
PermRule::Subtree(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm.*".to_string()).unwrap()
);
}

View File

@ -1,8 +1,8 @@
use crate::authorization::permissions::{PermRule, Permission};
use crate::users::db::UserData;
use once_cell::sync::OnceCell;
use std::collections::{HashMap, HashSet};
use std::fmt;
use once_cell::sync::OnceCell;
use crate::authorization::permissions::{Permission, PermRule};
use crate::users::db::UserData;
static ROLES: OnceCell<HashMap<String, Role>> = OnceCell::new();
@ -27,7 +27,6 @@ impl Roles {
self.roles.get(roleid)
}
/// Tally a role dependency tree into a set
///
/// A Default implementation exists which adapter may overwrite with more efficient
@ -62,10 +61,11 @@ impl Roles {
output
}
fn permitted_tally(&self,
roles: &mut HashSet<String>,
role_id: &String,
perm: &Permission
fn permitted_tally(
&self,
roles: &mut HashSet<String>,
role_id: &String,
perm: &Permission,
) -> bool {
if let Some(role) = self.get(role_id) {
// Only check and tally parents of a role at the role itself if it's the first time we
@ -130,7 +130,10 @@ pub struct Role {
impl Role {
pub fn new(parents: Vec<String>, permissions: Vec<PermRule>) -> Self {
Self { parents, permissions }
Self {
parents,
permissions,
}
}
}
@ -157,4 +160,4 @@ impl fmt::Display for Role {
Ok(())
}
}
}

View File

@ -5,7 +5,6 @@ use rsasl::property::AuthId;
use rsasl::session::{Session, Step};
use std::io::Cursor;
use crate::capnp::session::APISession;
use crate::session::SessionManager;
use api::authenticationsystem_capnp::authentication::{

View File

@ -2,7 +2,7 @@ use std::fmt::Formatter;
use std::net::ToSocketAddrs;
use std::path::PathBuf;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use crate::config::deser_option;
@ -14,7 +14,11 @@ use crate::config::deser_option;
pub struct Listen {
pub address: String,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")]
#[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub port: Option<u16>,
}
@ -56,4 +60,4 @@ pub struct TlsListen {
}
// The default port in the non-assignable i.e. free-use area
pub const DEFAULT_PORT: u16 = 59661;
pub const DEFAULT_PORT: u16 = 59661;

View File

@ -1,13 +1,13 @@
use std::net::SocketAddr;
pub use api::connection_capnp::bootstrap::Client;
use api::connection_capnp::bootstrap;
pub use api::connection_capnp::bootstrap::Client;
use std::net::SocketAddr;
use capnp::capability::Promise;
use capnp_rpc::pry;
use rsasl::mechname::Mechname;
use crate::authentication::AuthenticationHandle;
use crate::capnp::authenticationsystem::Authentication;
use crate::session::SessionManager;
use capnp::capability::Promise;
use capnp_rpc::pry;
use rsasl::mechname::Mechname;
/// Cap'n Proto API Handler
pub struct BootCap {
@ -17,7 +17,11 @@ pub struct BootCap {
}
impl BootCap {
pub fn new(peer_addr: SocketAddr, authentication: AuthenticationHandle, sessionmanager: SessionManager) -> Self {
pub fn new(
peer_addr: SocketAddr,
authentication: AuthenticationHandle,
sessionmanager: SessionManager,
) -> Self {
tracing::trace!(%peer_addr, "bootstrapping RPC");
Self {
peer_addr,
@ -62,12 +66,14 @@ impl bootstrap::Server for BootCap {
tracing::trace!("mechanisms");
let builder = result.get();
let mechs: Vec<_> = self.authentication.list_available_mechs()
let mechs: Vec<_> = self
.authentication
.list_available_mechs()
.into_iter()
.map(|m| m.as_str())
.collect();
let mut mechbuilder = builder.init_mechs(mechs.len() as u32);
for (i,m) in mechs.iter().enumerate() {
for (i, m) in mechs.iter().enumerate() {
mechbuilder.set(i as u32, m);
}

View File

@ -1,14 +1,11 @@
use crate::capnp::machine::Machine;
use crate::resources::search::ResourcesHandle;
use crate::resources::Resource;
use crate::session::SessionHandle;
use api::machinesystem_capnp::machine_system::{
info,
};
use crate::RESOURCES;
use api::machinesystem_capnp::machine_system::info;
use capnp::capability::Promise;
use capnp_rpc::pry;
use crate::capnp::machine::Machine;
use crate::RESOURCES;
use crate::resources::Resource;
use crate::resources::search::ResourcesHandle;
#[derive(Clone)]
pub struct Machines {
@ -19,7 +16,10 @@ pub struct Machines {
impl Machines {
pub fn new(session: SessionHandle) -> Self {
// FIXME no unwrap bad
Self { session, resources: RESOURCES.get().unwrap().clone() }
Self {
session,
resources: RESOURCES.get().unwrap().clone(),
}
}
}
@ -29,7 +29,9 @@ impl info::Server for Machines {
_: info::GetMachineListParams,
mut result: info::GetMachineListResults,
) -> Promise<(), ::capnp::Error> {
let machine_list: Vec<(usize, &Resource)> = self.resources.list_all()
let machine_list: Vec<(usize, &Resource)> = self
.resources
.list_all()
.into_iter()
.filter(|resource| resource.visible(&self.session))
.enumerate()

View File

@ -1,6 +1,5 @@
use async_net::TcpListener;
use capnp_rpc::rpc_twoparty_capnp::Side;
use capnp_rpc::twoparty::VatNetwork;
use capnp_rpc::RpcSystem;
@ -69,9 +68,7 @@ impl APIServer {
listens
.into_iter()
.map(|a| async move {
(async_net::resolve(a.to_tuple()).await, a)
})
.map(|a| async move { (async_net::resolve(a.to_tuple()).await, a) })
.collect::<FuturesUnordered<_>>()
.filter_map(|(res, addr)| async move {
match res {
@ -111,7 +108,13 @@ impl APIServer {
tracing::warn!("No usable listen addresses configured for the API server!");
}
Ok(Self::new(executor, sockets, acceptor, sessionmanager, authentication))
Ok(Self::new(
executor,
sockets,
acceptor,
sessionmanager,
authentication,
))
}
pub async fn handle_until(self, stop: impl Future) {
@ -129,10 +132,11 @@ impl APIServer {
} else {
tracing::error!(?stream, "failing a TCP connection with no peer addr");
}
},
}
Err(e) => tracing::warn!("Failed to accept stream: {}", e),
}
}).await;
})
.await;
tracing::info!("closing down API handler");
}
@ -153,7 +157,11 @@ impl APIServer {
let (rx, tx) = futures_lite::io::split(stream);
let vat = VatNetwork::new(rx, tx, Side::Server, Default::default());
let bootstrap: connection::Client = capnp_rpc::new_client(connection::BootCap::new(peer_addr, self.authentication.clone(), self.sessionmanager.clone()));
let bootstrap: connection::Client = capnp_rpc::new_client(connection::BootCap::new(
peer_addr,
self.authentication.clone(),
self.sessionmanager.clone(),
));
if let Err(e) = RpcSystem::new(Box::new(vat), Some(bootstrap.client)).await {
tracing::error!("Error during RPC handling: {}", e);

View File

@ -10,6 +10,4 @@ impl Permissions {
}
}
impl PermissionSystem for Permissions {
}
impl PermissionSystem for Permissions {}

View File

@ -1,12 +1,10 @@
use api::authenticationsystem_capnp::response::successful::Builder;
use crate::authorization::permissions::Permission;
use api::authenticationsystem_capnp::response::successful::Builder;
use crate::capnp::machinesystem::Machines;
use crate::capnp::permissionsystem::Permissions;
use crate::capnp::user_system::Users;
use crate::session::{SessionHandle};
use crate::session::SessionHandle;
#[derive(Debug, Clone)]
pub struct APISession;
@ -39,4 +37,4 @@ impl APISession {
b.set_info(capnp_rpc::new_client(Permissions::new(session)));
}
}
}
}

View File

@ -1,10 +1,10 @@
use crate::authorization::permissions::Permission;
use crate::session::SessionHandle;
use crate::users::{db, UserRef};
use api::general_capnp::optional;
use api::user_capnp::user::{self, admin, info, manage};
use capnp::capability::Promise;
use capnp_rpc::pry;
use crate::session::SessionHandle;
use api::user_capnp::user::{admin, info, manage, self};
use api::general_capnp::optional;
use crate::authorization::permissions::Permission;
use crate::users::{db, UserRef};
#[derive(Clone)]
pub struct User {
@ -22,7 +22,11 @@ impl User {
Self::new(session, user)
}
pub fn build_optional(session: &SessionHandle, user: Option<UserRef>, builder: optional::Builder<user::Owned>) {
pub fn build_optional(
session: &SessionHandle,
user: Option<UserRef>,
builder: optional::Builder<user::Owned>,
) {
if let Some(user) = user.and_then(|u| session.users.get_user(u.get_username())) {
let builder = builder.init_just();
Self::fill(&session, user, builder);
@ -102,12 +106,18 @@ impl admin::Server for User {
let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name());
if let Some(_role) = self.session.roles.get(rolename) {
let mut target = self.session.users.get_user(self.user.get_username()).unwrap();
let mut target = self
.session
.users
.get_user(self.user.get_username())
.unwrap();
// Only update if needed
if !target.userdata.roles.iter().any(|r| r.as_str() == rolename) {
target.userdata.roles.push(rolename.to_string());
self.session.users.put_user(self.user.get_username(), &target);
self.session
.users
.put_user(self.user.get_username(), &target);
}
}
@ -121,22 +131,24 @@ impl admin::Server for User {
let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name());
if let Some(_role) = self.session.roles.get(rolename) {
let mut target = self.session.users.get_user(self.user.get_username()).unwrap();
let mut target = self
.session
.users
.get_user(self.user.get_username())
.unwrap();
// Only update if needed
if target.userdata.roles.iter().any(|r| r.as_str() == rolename) {
target.userdata.roles.retain(|r| r.as_str() != rolename);
self.session.users.put_user(self.user.get_username(), &target);
self.session
.users
.put_user(self.user.get_username(), &target);
}
}
Promise::ok(())
}
fn pwd(
&mut self,
_: admin::PwdParams,
_: admin::PwdResults,
) -> Promise<(), ::capnp::Error> {
fn pwd(&mut self, _: admin::PwdParams, _: admin::PwdResults) -> Promise<(), ::capnp::Error> {
Promise::err(::capnp::Error::unimplemented(
"method not implemented".to_string(),
))

View File

@ -1,15 +1,12 @@
use api::usersystem_capnp::user_system::{info, manage, search};
use capnp::capability::Promise;
use capnp_rpc::pry;
use api::usersystem_capnp::user_system::{
info, manage, search
};
use crate::capnp::user::User;
use crate::session::SessionHandle;
use crate::users::{db, UserRef};
#[derive(Clone)]
pub struct Users {
session: SessionHandle,
@ -40,7 +37,8 @@ impl manage::Server for Users {
mut result: manage::GetUserListResults,
) -> Promise<(), ::capnp::Error> {
let userdb = self.session.users.into_inner();
let users = pry!(userdb.get_all()
let users = pry!(userdb
.get_all()
.map_err(|e| capnp::Error::failed(format!("UserDB error: {:?}", e))));
let mut builder = result.get().init_user_list(users.len() as u32);
for (i, (_, user)) in users.into_iter().enumerate() {
@ -113,4 +111,4 @@ impl search::Server for Users {
User::build_optional(&self.session, Some(userref), result.get());
Promise::ok(())
}
}
}

View File

@ -1,8 +1,6 @@
use std::path::Path;
use crate::Config;
use std::path::Path;
pub fn read_config_file(path: impl AsRef<Path>) -> Result<Config, serde_dhall::Error> {
serde_dhall::from_file(path)
.parse()
.map_err(Into::into)
}
serde_dhall::from_file(path).parse().map_err(Into::into)
}

View File

@ -1,13 +1,13 @@
use std::default::Default;
use std::path::{PathBuf};
use std::collections::HashMap;
use std::default::Default;
use std::path::PathBuf;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
mod dhall;
pub use dhall::read_config_file as read;
use crate::authorization::permissions::{PrivilegesBuf};
use crate::authorization::permissions::PrivilegesBuf;
use crate::authorization::roles::Role;
use crate::capnp::{Listen, TlsListen};
use crate::logging::LogConfig;
@ -23,13 +23,25 @@ pub struct MachineDescription {
pub name: String,
/// An optional description of the Machine.
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")]
#[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub description: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")]
#[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub wiki: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")]
#[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub category: Option<String>,
/// The permission required
@ -83,48 +95,49 @@ impl Config {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModuleConfig {
pub module: String,
pub params: HashMap<String, String>
pub params: HashMap<String, String>,
}
pub(crate) fn deser_option<'de, D, T>(d: D) -> std::result::Result<Option<T>, D::Error>
where D: serde::Deserializer<'de>, T: serde::Deserialize<'de>,
where
D: serde::Deserializer<'de>,
T: serde::Deserialize<'de>,
{
Ok(T::deserialize(d).ok())
}
impl Default for Config {
fn default() -> Self {
let mut actors: HashMap::<String, ModuleConfig> = HashMap::new();
let mut initiators: HashMap::<String, ModuleConfig> = HashMap::new();
let mut actors: HashMap<String, ModuleConfig> = HashMap::new();
let mut initiators: HashMap<String, ModuleConfig> = HashMap::new();
let machines = HashMap::new();
actors.insert("Actor".to_string(), ModuleConfig {
module: "Shelly".to_string(),
params: HashMap::new(),
});
initiators.insert("Initiator".to_string(), ModuleConfig {
module: "TCP-Listen".to_string(),
params: HashMap::new(),
});
actors.insert(
"Actor".to_string(),
ModuleConfig {
module: "Shelly".to_string(),
params: HashMap::new(),
},
);
initiators.insert(
"Initiator".to_string(),
ModuleConfig {
module: "TCP-Listen".to_string(),
params: HashMap::new(),
},
);
Config {
listens: vec![
Listen {
address: "127.0.0.1".to_string(),
port: None,
}
],
listens: vec![Listen {
address: "127.0.0.1".to_string(),
port: None,
}],
actors,
initiators,
machines,
mqtt_url: "tcp://localhost:1883".to_string(),
actor_connections: vec![
("Testmachine".to_string(), "Actor".to_string()),
],
init_connections: vec![
("Initiator".to_string(), "Testmachine".to_string()),
],
actor_connections: vec![("Testmachine".to_string(), "Actor".to_string())],
init_connections: vec![("Initiator".to_string(), "Testmachine".to_string())],
db_path: PathBuf::from("/run/bffh/database"),
auditlog_path: PathBuf::from("/var/log/bffh/audit.log"),
@ -133,7 +146,7 @@ impl Default for Config {
tlsconfig: TlsListen {
certfile: PathBuf::from("./bffh.crt"),
keyfile: PathBuf::from("./bffh.key"),
.. Default::default()
..Default::default()
},
tlskeylog: None,

View File

@ -2,6 +2,6 @@ mod raw;
pub use raw::RawDB;
mod typed;
pub use typed::{DB, ArchivedValue, Adapter, AlignedAdapter};
pub use typed::{Adapter, AlignedAdapter, ArchivedValue, DB};
pub type Error = lmdb::Error;
pub type Error = lmdb::Error;

View File

@ -1,10 +1,4 @@
use lmdb::{
Transaction,
RwTransaction,
Environment,
DatabaseFlags,
WriteFlags,
};
use lmdb::{DatabaseFlags, Environment, RwTransaction, Transaction, WriteFlags};
#[derive(Debug, Clone)]
pub struct RawDB {
@ -15,13 +9,22 @@ impl RawDB {
pub fn open(env: &Environment, name: Option<&str>) -> lmdb::Result<Self> {
env.open_db(name).map(|db| Self { db })
}
pub fn create(env: &Environment, name: Option<&str>, flags: DatabaseFlags) -> lmdb::Result<Self> {
pub fn create(
env: &Environment,
name: Option<&str>,
flags: DatabaseFlags,
) -> lmdb::Result<Self> {
env.create_db(name, flags).map(|db| Self { db })
}
pub fn get<'txn, T: Transaction, K>(&self, txn: &'txn T, key: &K) -> lmdb::Result<Option<&'txn [u8]>>
where K: AsRef<[u8]>
pub fn get<'txn, T: Transaction, K>(
&self,
txn: &'txn T,
key: &K,
) -> lmdb::Result<Option<&'txn [u8]>>
where
K: AsRef<[u8]>,
{
match txn.get(self.db, key) {
Ok(buf) => Ok(Some(buf)),
@ -30,24 +33,37 @@ impl RawDB {
}
}
pub fn put<K, V>(&self, txn: &mut RwTransaction, key: &K, value: &V, flags: WriteFlags)
-> lmdb::Result<()>
where K: AsRef<[u8]>,
V: AsRef<[u8]>,
pub fn put<K, V>(
&self,
txn: &mut RwTransaction,
key: &K,
value: &V,
flags: WriteFlags,
) -> lmdb::Result<()>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
{
txn.put(self.db, key, value, flags)
}
pub fn reserve<'txn, K>(&self, txn: &'txn mut RwTransaction, key: &K, size: usize, flags: WriteFlags)
-> lmdb::Result<&'txn mut [u8]>
where K: AsRef<[u8]>
pub fn reserve<'txn, K>(
&self,
txn: &'txn mut RwTransaction,
key: &K,
size: usize,
flags: WriteFlags,
) -> lmdb::Result<&'txn mut [u8]>
where
K: AsRef<[u8]>,
{
txn.reserve(self.db, key, size, flags)
}
pub fn del<K, V>(&self, txn: &mut RwTransaction, key: &K, value: Option<&V>) -> lmdb::Result<()>
where K: AsRef<[u8]>,
V: AsRef<[u8]>,
where
K: AsRef<[u8]>,
V: AsRef<[u8]>,
{
txn.del(self.db, key, value.map(AsRef::as_ref))
}
@ -60,7 +76,10 @@ impl RawDB {
cursor.iter_start()
}
pub fn open_ro_cursor<'txn, T: Transaction>(&self, txn: &'txn T) -> lmdb::Result<lmdb::RoCursor<'txn>> {
pub fn open_ro_cursor<'txn, T: Transaction>(
&self,
txn: &'txn T,
) -> lmdb::Result<lmdb::RoCursor<'txn>> {
txn.open_ro_cursor(self.db)
}
}

View File

@ -119,7 +119,11 @@ impl<A> DB<A> {
}
impl<A: Adapter> DB<A> {
pub fn get<T: Transaction>(&self, txn: &T, key: &impl AsRef<[u8]>) -> Result<Option<A::Item>, db::Error> {
pub fn get<T: Transaction>(
&self,
txn: &T,
key: &impl AsRef<[u8]>,
) -> Result<Option<A::Item>, db::Error> {
Ok(self.db.get(txn, key)?.map(A::decode))
}
@ -129,8 +133,7 @@ impl<A: Adapter> DB<A> {
key: &impl AsRef<[u8]>,
value: &A::Item,
flags: WriteFlags,
) -> Result<(), db::Error>
{
) -> Result<(), db::Error> {
let len = A::encoded_len(value);
let buf = self.db.reserve(txn, key, len, flags)?;
assert_eq!(buf.len(), len, "Reserved buffer is not of requested size!");
@ -146,11 +149,12 @@ impl<A: Adapter> DB<A> {
self.db.clear(txn)
}
pub fn get_all<'txn, T: Transaction>(&self, txn: &'txn T) -> Result<impl IntoIterator<Item=(&'txn [u8], A::Item)>, db::Error> {
pub fn get_all<'txn, T: Transaction>(
&self,
txn: &'txn T,
) -> Result<impl IntoIterator<Item = (&'txn [u8], A::Item)>, db::Error> {
let mut cursor = self.db.open_ro_cursor(txn)?;
let it = cursor.iter_start();
Ok(it.filter_map(|buf| buf.ok().map(|(kbuf,vbuf)| {
(kbuf, A::decode(vbuf))
})))
Ok(it.filter_map(|buf| buf.ok().map(|(kbuf, vbuf)| (kbuf, A::decode(vbuf)))))
}
}
}

View File

@ -1,7 +1,7 @@
use std::io;
use std::fmt;
use rsasl::error::SessionError;
use crate::db;
use rsasl::error::SessionError;
use std::fmt;
use std::io;
type DBError = db::Error;
@ -21,19 +21,19 @@ impl fmt::Display for Error {
match self {
Error::SASL(e) => {
write!(f, "SASL Error: {}", e)
},
}
Error::IO(e) => {
write!(f, "IO Error: {}", e)
},
}
Error::Boxed(e) => {
write!(f, "{}", e)
},
}
Error::Capnp(e) => {
write!(f, "Cap'n Proto Error: {}", e)
},
}
Error::DB(e) => {
write!(f, "DB Error: {:?}", e)
},
}
Error::Denied => {
write!(f, "You do not have the permission required to do that.")
}
@ -71,4 +71,4 @@ impl From<DBError> for Error {
}
}
pub type Result<T> = std::result::Result<T, Error>;
pub type Result<T> = std::result::Result<T, Error>;

View File

@ -1,9 +1,9 @@
use std::fs::{File, OpenOptions};
use std::{fmt, io};
use std::fmt::Formatter;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::Path;
use std::sync::Mutex;
use std::{fmt, io};
// Internal mutable state for KeyLogFile
struct KeyLogFileInner {
@ -18,10 +18,7 @@ impl fmt::Debug for KeyLogFileInner {
impl KeyLogFileInner {
fn new(path: impl AsRef<Path>) -> io::Result<Self> {
let file = OpenOptions::new()
.append(true)
.create(true)
.open(path)?;
let file = OpenOptions::new().append(true).create(true).open(path)?;
Ok(Self {
file,
@ -70,4 +67,4 @@ impl rustls::KeyLog for KeyLogFile {
}
}
}
}
}

View File

@ -16,9 +16,9 @@ pub mod db;
/// Shared error type
pub mod error;
pub mod users;
pub mod authentication;
pub mod authorization;
pub mod users;
/// Resources
pub mod resources;
@ -31,38 +31,34 @@ pub mod capnp;
pub mod utils;
mod tls;
mod audit;
mod keylog;
mod logging;
mod audit;
mod session;
mod tls;
use std::sync::{Arc};
use std::sync::Arc;
use anyhow::Context;
use futures_util::StreamExt;
use once_cell::sync::OnceCell;
use signal_hook::consts::signal::*;
use executor::pool::Executor;
use crate::audit::AuditLog;
use crate::authentication::AuthenticationHandle;
use crate::authorization::roles::Roles;
use crate::capnp::APIServer;
use crate::config::{Config};
use crate::config::Config;
use crate::resources::modules::fabaccess::MachineState;
use crate::resources::Resource;
use crate::resources::search::ResourcesHandle;
use crate::resources::state::db::StateDB;
use crate::resources::Resource;
use crate::session::SessionManager;
use crate::tls::TlsConfig;
use crate::users::db::UserDB;
use crate::users::Users;
use executor::pool::Executor;
use signal_hook::consts::signal::*;
pub const VERSION_STRING: &'static str = env!("BFFHD_VERSION_STRING");
pub const RELEASE_STRING: &'static str = env!("BFFHD_RELEASE_STRING");
@ -81,7 +77,7 @@ pub static RESOURCES: OnceCell<ResourcesHandle> = OnceCell::new();
impl Diflouroborane {
pub fn new(config: Config) -> anyhow::Result<Self> {
logging::init(&config.logging);
tracing::info!(version=VERSION_STRING, "Starting BFFH");
tracing::info!(version = VERSION_STRING, "Starting BFFH");
let span = tracing::info_span!("setup");
let _guard = span.enter();
@ -89,8 +85,8 @@ impl Diflouroborane {
let executor = Executor::new();
let env = StateDB::open_env(&config.db_path)?;
let statedb = StateDB::create_with_env(env.clone())
.context("Failed to open state DB file")?;
let statedb =
StateDB::create_with_env(env.clone()).context("Failed to open state DB file")?;
let users = Users::new(env.clone()).context("Failed to open users DB file")?;
let roles = Roles::new(config.roles.clone());
@ -98,31 +94,43 @@ impl Diflouroborane {
let _audit_log = AuditLog::new(&config).context("Failed to initialize audit log")?;
let resources = ResourcesHandle::new(config.machines.iter().map(|(id, desc)| {
Resource::new(Arc::new(resources::Inner::new(id.to_string(), statedb.clone(), desc.clone())))
Resource::new(Arc::new(resources::Inner::new(
id.to_string(),
statedb.clone(),
desc.clone(),
)))
}));
RESOURCES.set(resources.clone());
Ok(Self { config, executor, statedb, users, roles, resources })
Ok(Self {
config,
executor,
statedb,
users,
roles,
resources,
})
}
pub fn run(&mut self) -> anyhow::Result<()> {
let mut signals = signal_hook_async_std::Signals::new(&[
SIGINT,
SIGQUIT,
SIGTERM,
]).context("Failed to construct signal handler")?;
let mut signals = signal_hook_async_std::Signals::new(&[SIGINT, SIGQUIT, SIGTERM])
.context("Failed to construct signal handler")?;
actors::load(self.executor.clone(), &self.config, self.resources.clone())?;
let tlsconfig = TlsConfig::new(self.config.tlskeylog.as_ref(), !self.config.is_quiet())?;
let acceptor = tlsconfig.make_tls_acceptor(&self.config.tlsconfig)?;
let sessionmanager = SessionManager::new(self.users.clone(), self.roles.clone());
let authentication = AuthenticationHandle::new(self.users.clone());
let apiserver = self.executor.run(APIServer::bind(self.executor.clone(), &self.config.listens, acceptor, sessionmanager, authentication))?;
let apiserver = self.executor.run(APIServer::bind(
self.executor.clone(),
&self.config.listens,
acceptor,
sessionmanager,
authentication,
))?;
let (mut tx, rx) = async_oneshot::oneshot();
@ -142,4 +150,3 @@ impl Diflouroborane {
Ok(())
}
}

View File

@ -1,7 +1,6 @@
use tracing_subscriber::{EnvFilter};
use tracing_subscriber::EnvFilter;
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogConfig {
@ -32,8 +31,7 @@ pub fn init(config: &LogConfig) {
EnvFilter::from_env("BFFH_LOG")
};
let builder = tracing_subscriber::fmt()
.with_env_filter(filter);
let builder = tracing_subscriber::fmt().with_env_filter(filter);
let format = config.format.to_lowercase();
match format.as_str() {
@ -43,4 +41,4 @@ pub fn init(config: &LogConfig) {
_ => builder.init(),
}
tracing::info!(format = format.as_str(), "Logging initialized")
}
}

View File

@ -1,19 +1,19 @@
use rkyv::{Archive, Serialize, Deserialize};
use rkyv::{Archive, Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Archive, Serialize, Deserialize)]
#[derive(serde::Serialize, serde::Deserialize)]
#[derive(
Clone,
Debug,
PartialEq,
Eq,
Archive,
Serialize,
Deserialize,
serde::Serialize,
serde::Deserialize,
)]
pub struct Resource {
uuid: u128,
id: String,
name_idx: u64,
description_idx: u64,
}
}

View File

@ -1,21 +1,21 @@
use futures_signals::signal::{Mutable, Signal};
use rkyv::Infallible;
use std::ops::Deref;
use std::sync::Arc;
use futures_signals::signal::{Mutable, Signal};
use rkyv::{Archived, Deserialize};
use rkyv::option::ArchivedOption;
use rkyv::ser::Serializer;
use rkyv::ser::serializers::AllocSerializer;
use crate::audit::AUDIT;
use crate::authorization::permissions::PrivilegesBuf;
use crate::config::MachineDescription;
use crate::db::ArchivedValue;
use crate::resources::modules::fabaccess::{MachineState, Status, ArchivedStatus};
use crate::resources::modules::fabaccess::{ArchivedStatus, MachineState, Status};
use crate::resources::state::db::StateDB;
use crate::resources::state::State;
use crate::session::SessionHandle;
use crate::users::UserRef;
use rkyv::option::ArchivedOption;
use rkyv::ser::serializers::AllocSerializer;
use rkyv::ser::Serializer;
use rkyv::{Archived, Deserialize};
pub mod db;
pub mod search;
@ -43,27 +43,35 @@ impl Inner {
let update = state.to_state();
let mut serializer = AllocSerializer::<1024>::default();
serializer.serialize_value(&update).expect("failed to serialize new default state");
serializer
.serialize_value(&update)
.expect("failed to serialize new default state");
let val = ArchivedValue::new(serializer.into_serializer().into_inner());
db.put(&id.as_bytes(), &val).unwrap();
val
};
let signal = Mutable::new(state);
Self { id, db, signal, desc }
Self {
id,
db,
signal,
desc,
}
}
pub fn signal(&self) -> impl Signal<Item=ArchivedValue<State>> {
pub fn signal(&self) -> impl Signal<Item = ArchivedValue<State>> {
Box::pin(self.signal.signal_cloned())
}
fn get_state(&self) -> ArchivedValue<State> {
self.db.get(self.id.as_bytes())
self.db
.get(self.id.as_bytes())
.expect("lmdb error")
.expect("state should never be None")
}
fn get_state_ref(&self) -> impl Deref<Target=ArchivedValue<State>> + '_ {
fn get_state_ref(&self) -> impl Deref<Target = ArchivedValue<State>> + '_ {
self.signal.lock_ref()
}
@ -76,7 +84,10 @@ impl Inner {
self.db.put(&self.id.as_bytes(), &state).unwrap();
tracing::trace!("Updated DB, sending update signal");
AUDIT.get().unwrap().log(self.id.as_str(), &format!("{}", state));
AUDIT
.get()
.unwrap()
.log(self.id.as_str(), &format!("{}", state));
self.signal.set(state);
tracing::trace!("Sent update signal");
@ -85,7 +96,7 @@ impl Inner {
#[derive(Clone)]
pub struct Resource {
inner: Arc<Inner>
inner: Arc<Inner>,
}
impl Resource {
@ -97,7 +108,7 @@ impl Resource {
self.inner.get_state()
}
pub fn get_state_ref(&self) -> impl Deref<Target=ArchivedValue<State>> + '_ {
pub fn get_state_ref(&self) -> impl Deref<Target = ArchivedValue<State>> + '_ {
self.inner.get_state_ref()
}
@ -109,7 +120,7 @@ impl Resource {
self.inner.desc.name.as_str()
}
pub fn get_signal(&self) -> impl Signal<Item=ArchivedValue<State>> {
pub fn get_signal(&self) -> impl Signal<Item = ArchivedValue<State>> {
self.inner.signal()
}
@ -125,13 +136,13 @@ impl Resource {
let state = self.get_state_ref();
let state: &Archived<State> = state.as_ref();
match &state.inner.state {
ArchivedStatus::Blocked(user) |
ArchivedStatus::InUse(user) |
ArchivedStatus::Reserved(user) |
ArchivedStatus::ToCheck(user) => {
ArchivedStatus::Blocked(user)
| ArchivedStatus::InUse(user)
| ArchivedStatus::Reserved(user)
| ArchivedStatus::ToCheck(user) => {
let user = Deserialize::<UserRef, _>::deserialize(user, &mut Infallible).unwrap();
Some(user)
},
}
_ => None,
}
}
@ -158,8 +169,9 @@ impl Resource {
let old = self.inner.get_state();
let oldref: &Archived<State> = old.as_ref();
let previous: &Archived<Option<UserRef>> = &oldref.inner.previous;
let previous = Deserialize::<Option<UserRef>, _>::deserialize(previous, &mut rkyv::Infallible)
.expect("Infallible deserializer failed");
let previous =
Deserialize::<Option<UserRef>, _>::deserialize(previous, &mut rkyv::Infallible)
.expect("Infallible deserializer failed");
let new = MachineState { state, previous };
self.set_state(new);
}

View File

@ -1,14 +1,13 @@
use std::fmt;
use std::fmt::{Write};
use crate::utils::oid::ObjectIdentifier;
use once_cell::sync::Lazy;
use rkyv::{Archive, Archived, Deserialize, Infallible};
use std::fmt;
use std::fmt::Write;
use std::str::FromStr;
//use crate::oidvalue;
use crate::resources::state::{State};
use crate::resources::state::State;
use crate::users::UserRef;
@ -80,13 +79,12 @@ impl MachineState {
}
pub fn from(dbstate: &Archived<State>) -> Self {
let state: &Archived<MachineState> = &dbstate.inner;
Deserialize::deserialize(state, &mut Infallible).unwrap()
}
pub fn to_state(&self) -> State {
State {
inner: self.clone()
inner: self.clone(),
}
}

View File

@ -1,6 +1,3 @@
pub mod fabaccess;
pub trait MachineModel {
}
pub trait MachineModel {}

View File

@ -1,13 +1,13 @@
use crate::resources::Resource;
use std::collections::HashMap;
use std::sync::Arc;
use crate::resources::Resource;
struct Inner {
id: HashMap<String, Resource>,
}
impl Inner {
pub fn new(resources: impl IntoIterator<Item=Resource>) -> Self {
pub fn new(resources: impl IntoIterator<Item = Resource>) -> Self {
let mut id = HashMap::new();
for resource in resources {
@ -25,13 +25,13 @@ pub struct ResourcesHandle {
}
impl ResourcesHandle {
pub fn new(resources: impl IntoIterator<Item=Resource>) -> Self {
pub fn new(resources: impl IntoIterator<Item = Resource>) -> Self {
Self {
inner: Arc::new(Inner::new(resources)),
}
}
pub fn list_all(&self) -> impl IntoIterator<Item=&Resource> {
pub fn list_all(&self) -> impl IntoIterator<Item = &Resource> {
self.inner.id.values()
}

View File

@ -1,9 +1,6 @@
use crate::db;
use crate::db::{ArchivedValue, RawDB, DB, AlignedAdapter};
use lmdb::{
DatabaseFlags, Environment, EnvironmentFlags, Transaction,
WriteFlags,
};
use crate::db::{AlignedAdapter, ArchivedValue, RawDB, DB};
use lmdb::{DatabaseFlags, Environment, EnvironmentFlags, Transaction, WriteFlags};
use std::{path::Path, sync::Arc};
use crate::resources::state::State;
@ -67,7 +64,7 @@ impl StateDB {
pub fn get_all<'txn, T: Transaction>(
&self,
txn: &'txn T,
) -> Result<impl IntoIterator<Item = (&'txn [u8], ArchivedValue<State>)>, db::Error, > {
) -> Result<impl IntoIterator<Item = (&'txn [u8], ArchivedValue<State>)>, db::Error> {
self.db.get_all(txn)
}

View File

@ -1,37 +1,27 @@
use std::{
fmt,
hash::{
Hasher
},
};
use std::fmt::{Debug, Display, Formatter};
use std::{fmt, hash::Hasher};
use std::ops::Deref;
use rkyv::{Archive, Deserialize, out_field, Serialize};
use rkyv::{out_field, Archive, Deserialize, Serialize};
use serde::de::{Error, MapAccess, Unexpected};
use serde::Deserializer;
use serde::ser::SerializeMap;
use serde::Deserializer;
use crate::MachineState;
use crate::resources::modules::fabaccess::OID_VALUE;
use crate::MachineState;
use crate::utils::oid::ObjectIdentifier;
pub mod value;
pub mod db;
pub mod value;
#[derive(Archive, Serialize, Deserialize)]
#[derive(Clone, PartialEq, Eq)]
#[derive(Archive, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[archive_attr(derive(Debug))]
pub struct State {
pub inner: MachineState,
}
impl fmt::Debug for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut sf = f.debug_struct("State");
@ -51,7 +41,8 @@ impl fmt::Display for ArchivedState {
impl serde::Serialize for State {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: serde::Serializer
where
S: serde::Serializer,
{
let mut ser = serializer.serialize_map(Some(1))?;
ser.serialize_entry(OID_VALUE.deref(), &self.inner)?;
@ -60,7 +51,8 @@ impl serde::Serialize for State {
}
impl<'de> serde::Deserialize<'de> for State {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de>
where
D: Deserializer<'de>,
{
deserializer.deserialize_map(StateVisitor)
}
@ -74,12 +66,13 @@ impl<'de> serde::de::Visitor<'de> for StateVisitor {
write!(formatter, "a map from OIDs to value objects")
}
fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error>
{
let oid: ObjectIdentifier = map.next_key()?
.ok_or(A::Error::missing_field("oid"))?;
fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
let oid: ObjectIdentifier = map.next_key()?.ok_or(A::Error::missing_field("oid"))?;
if oid != *OID_VALUE.deref() {
return Err(A::Error::invalid_value(Unexpected::Other("Unknown OID"), &"OID of fabaccess state"))
return Err(A::Error::invalid_value(
Unexpected::Other("Unknown OID"),
&"OID of fabaccess state",
));
}
let val: MachineState = map.next_value()?;
Ok(State { inner: val })
@ -88,8 +81,8 @@ impl<'de> serde::de::Visitor<'de> for StateVisitor {
#[cfg(test)]
pub mod tests {
use super::*;
use super::value::*;
use super::*;
pub(crate) fn gen_random() -> State {
let amt: u8 = rand::random::<u8>() % 20;
@ -97,7 +90,7 @@ pub mod tests {
let mut sb = State::build();
for _ in 0..amt {
let oid = crate::utils::oid::tests::gen_random();
sb = match rand::random::<u32>()%12 {
sb = match rand::random::<u32>() % 12 {
0 => sb.add(oid, Box::new(rand::random::<bool>())),
1 => sb.add(oid, Box::new(rand::random::<u8>())),
2 => sb.add(oid, Box::new(rand::random::<u16>())),
@ -156,4 +149,4 @@ pub mod tests {
assert_eq!(stateC, stateB);
}
}
}

View File

@ -1,10 +1,12 @@
use std::{hash::Hash};
use std::hash::Hash;
use ptr_meta::{DynMetadata, Pointee};
use rkyv::{out_field, Archive, ArchivePointee, ArchiveUnsized, Archived, ArchivedMetadata, Serialize, SerializeUnsized, RelPtr};
use rkyv::{
out_field, Archive, ArchivePointee, ArchiveUnsized, Archived, ArchivedMetadata, RelPtr,
Serialize, SerializeUnsized,
};
use rkyv_dyn::{DynError, DynSerializer};
use crate::utils::oid::ObjectIdentifier;
// Not using linkme because dynamically loaded modules
@ -16,7 +18,6 @@ use serde::ser::SerializeMap;
use std::collections::HashMap;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
@ -61,7 +62,7 @@ struct NewStateBuilder {
// turns into
struct NewState {
inner: ArchivedVec<ArchivedMetaBox<dyn ArchivedStateValue>>
inner: ArchivedVec<ArchivedMetaBox<dyn ArchivedStateValue>>,
}
impl NewState {
pub fn get_value<T: TypeOid>(&self) -> Option<&T> {
@ -148,9 +149,9 @@ pub trait TypeOid {
}
impl<T> SerializeDynOid for T
where
T: for<'a> Serialize<dyn DynSerializer + 'a>,
T::Archived: TypeOid,
where
T: for<'a> Serialize<dyn DynSerializer + 'a>,
T::Archived: TypeOid,
{
fn archived_type_oid(&self) -> &'static ObjectIdentifier {
Archived::<T>::type_oid()
@ -371,7 +372,7 @@ pub mod macros {
stringify!($y)
}
}
}
};
}
#[macro_export]
@ -380,16 +381,15 @@ pub mod macros {
unsafe impl $crate::resources::state::value::RegisteredImpl for $z {
fn vtable() -> usize {
unsafe {
::core::mem::transmute(ptr_meta::metadata(
::core::ptr::null::<$z>() as *const dyn $crate::resources::state::value::ArchivedStateValue
))
::core::mem::transmute(ptr_meta::metadata(::core::ptr::null::<$z>()
as *const dyn $crate::resources::state::value::ArchivedStateValue))
}
}
fn debug_info() -> $crate::resources::state::value::ImplDebugInfo {
$crate::debug_info!()
}
}
}
};
}
#[macro_export]
@ -801,4 +801,4 @@ mod tests {
}
}
*/
*/

View File

@ -0,0 +1 @@

View File

@ -1,20 +1,13 @@
use crate::authorization::permissions::Permission;
use crate::authorization::roles::{Roles};
use crate::authorization::roles::Roles;
use crate::resources::Resource;
use crate::Users;
use crate::users::{db, UserRef};
use crate::Users;
#[derive(Clone)]
pub struct SessionManager {
users: Users,
roles: Roles,
// cache: SessionCache // todo
}
impl SessionManager {
@ -52,33 +45,39 @@ impl SessionHandle {
}
pub fn get_user(&self) -> db::User {
self.users.get_user(self.user.get_username()).expect("Failed to get user self")
self.users
.get_user(self.user.get_username())
.expect("Failed to get user self")
}
pub fn has_disclose(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().disclose)
self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().disclose)
} else {
false
}
}
pub fn has_read(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().read)
self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().read)
} else {
false
}
}
pub fn has_write(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().write)
self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().write)
} else {
false
}
}
pub fn has_manage(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().manage)
self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().manage)
} else {
false
}
@ -90,4 +89,4 @@ impl SessionHandle {
false
}
}
}
}

View File

@ -4,26 +4,39 @@ use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use futures_rustls::TlsAcceptor;
use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite};
use rustls::version::{TLS12, TLS13};
use tracing::{Level};
use crate::capnp::TlsListen;
use futures_rustls::TlsAcceptor;
use rustls::version::{TLS12, TLS13};
use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite};
use tracing::Level;
use crate::keylog::KeyLogFile;
fn lookup_cipher_suite(name: &str) -> Option<SupportedCipherSuite> {
match name {
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384),
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256),
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256),
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)
}
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => {
Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384)
}
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256)
}
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)
}
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384)
}
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)
}
"TLS13_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS13_AES_128_GCM_SHA256),
"TLS13_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS13_AES_256_GCM_SHA384),
"TLS13_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256),
"TLS13_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256)
}
_ => None,
}
}
@ -43,7 +56,6 @@ impl TlsConfig {
}
if let Some(path) = keylogfile {
let keylog = Some(KeyLogFile::new(path).map(|ok| Arc::new(ok))?);
Ok(Self { keylog })
} else {

View File

@ -1,15 +1,15 @@
use lmdb::{DatabaseFlags, Environment, RwTransaction, Transaction, WriteFlags};
use std::collections::{HashMap};
use rkyv::Infallible;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Context;
use std::sync::Arc;
use rkyv::{Deserialize};
use rkyv::ser::Serializer;
use rkyv::ser::serializers::AllocSerializer;
use crate::db;
use crate::db::{AlignedAdapter, ArchivedValue, DB, RawDB};
use crate::db::{AlignedAdapter, ArchivedValue, RawDB, DB};
use rkyv::ser::serializers::AllocSerializer;
use rkyv::ser::Serializer;
use rkyv::Deserialize;
#[derive(
Clone,
@ -30,8 +30,7 @@ pub struct User {
impl User {
pub fn check_password(&self, pwd: &[u8]) -> anyhow::Result<bool> {
if let Some(ref encoded) = self.userdata.passwd {
argon2::verify_encoded(encoded, pwd)
.context("Stored password is an invalid string")
argon2::verify_encoded(encoded, pwd).context("Stored password is an invalid string")
} else {
Ok(false)
}
@ -48,23 +47,23 @@ impl User {
id: username.to_string(),
userdata: UserData {
passwd: Some(hash),
.. Default::default()
}
..Default::default()
},
}
}
}
#[derive(
Clone,
PartialEq,
Eq,
Debug,
Default,
rkyv::Archive,
rkyv::Serialize,
rkyv::Deserialize,
serde::Serialize,
serde::Deserialize,
Clone,
PartialEq,
Eq,
Debug,
Default,
rkyv::Archive,
rkyv::Serialize,
rkyv::Deserialize,
serde::Serialize,
serde::Deserialize,
)]
/// Data on an user to base decisions on
///
@ -85,12 +84,19 @@ pub struct UserData {
impl UserData {
pub fn new(roles: Vec<String>) -> Self {
Self { roles, kv: HashMap::new(), passwd: None }
Self {
roles,
kv: HashMap::new(),
passwd: None,
}
}
pub fn new_with_kv(roles: Vec<String>, kv: HashMap<String, String>) -> Self {
Self { roles, kv, passwd: None }
Self {
roles,
kv,
passwd: None,
}
}
}
#[derive(Clone, Debug)]
@ -140,7 +146,12 @@ impl UserDB {
Ok(())
}
pub fn put_txn(&self, txn: &mut RwTransaction, uid: &str, user: &User) -> Result<(), db::Error> {
pub fn put_txn(
&self,
txn: &mut RwTransaction,
uid: &str,
user: &User,
) -> Result<(), db::Error> {
let mut serializer = AllocSerializer::<1024>::default();
serializer.serialize_value(user).expect("rkyv error");
let v = serializer.into_serializer().into_inner();
@ -169,10 +180,11 @@ impl UserDB {
let mut out = Vec::new();
for (uid, user) in iter {
let uid = unsafe { std::str::from_utf8_unchecked(uid).to_string() };
let user: User = Deserialize::<User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap();
let user: User =
Deserialize::<User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap();
out.push((uid, user));
}
Ok(out)
}
}
}

View File

@ -10,8 +10,6 @@ use std::sync::Arc;
pub mod db;
use crate::users::db::UserData;
use crate::UserDB;
@ -87,10 +85,9 @@ impl Users {
pub fn get_user(&self, uid: &str) -> Option<db::User> {
tracing::trace!(uid, "Looking up user");
self.userdb
.get(uid)
.unwrap()
.map(|user| Deserialize::<db::User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap())
self.userdb.get(uid).unwrap().map(|user| {
Deserialize::<db::User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap()
})
}
pub fn put_user(&self, uid: &str, user: &db::User) -> Result<(), lmdb::Error> {

View File

@ -1,20 +1,16 @@
use std::collections::HashMap;
use once_cell::sync::Lazy;
struct Locales {
map: HashMap<&'static str, HashMap<&'static str, &'static str>>
map: HashMap<&'static str, HashMap<&'static str, &'static str>>,
}
impl Locales {
pub fn get(&self, lang: &str, msg: &str)
-> Option<(&'static str, &'static str)>
{
self.map.get(msg).and_then(|map| {
map.get_key_value(lang).map(|(k,v)| (*k, *v))
})
pub fn get(&self, lang: &str, msg: &str) -> Option<(&'static str, &'static str)> {
self.map
.get(msg)
.and_then(|map| map.get_key_value(lang).map(|(k, v)| (*k, *v)))
}
pub fn available(&self, _msg: &str) -> &[&'static str] {
@ -22,8 +18,8 @@ impl Locales {
}
}
static LANG: Lazy<Locales> = Lazy::new(|| {
Locales { map: HashMap::new() }
static LANG: Lazy<Locales> = Lazy::new(|| Locales {
map: HashMap::new(),
});
struct L10NString {
@ -59,4 +55,4 @@ impl l10n::Server for L10NString {
Promise::ok(())
}
}
*/
*/

View File

@ -7,4 +7,4 @@ pub mod varint;
/// Localization strings
pub mod l10nstring;
pub mod uuid;
pub mod uuid;

View File

@ -52,16 +52,16 @@
//! [Object Identifiers]: https://en.wikipedia.org/wiki/Object_identifier
//! [ITU]: https://en.wikipedia.org/wiki/International_Telecommunications_Union
use rkyv::{Archive, Serialize};
use crate::utils::varint::VarU128;
use rkyv::ser::Serializer;
use rkyv::vec::{ArchivedVec, VecResolver};
use rkyv::{Archive, Serialize};
use std::convert::TryFrom;
use std::ops::Deref;
use std::convert::TryInto;
use std::fmt;
use std::fmt::Formatter;
use rkyv::ser::Serializer;
use std::ops::Deref;
use std::str::FromStr;
use crate::utils::varint::VarU128;
use std::convert::TryInto;
type Node = u128;
type VarNode = VarU128;
@ -69,8 +69,8 @@ type VarNode = VarU128;
/// Convenience module for quickly importing the public interface (e.g., `use oid::prelude::*`)
pub mod prelude {
pub use super::ObjectIdentifier;
pub use super::ObjectIdentifierRoot::*;
pub use super::ObjectIdentifierError;
pub use super::ObjectIdentifierRoot::*;
pub use core::convert::{TryFrom, TryInto};
}
@ -132,7 +132,8 @@ impl ObjectIdentifier {
let mut parsing_big_int = false;
let mut big_int: Node = 0;
for i in 1..nodes.len() {
if !parsing_big_int && nodes[i] < 128 {} else {
if !parsing_big_int && nodes[i] < 128 {
} else {
if big_int > 0 {
if big_int >= Node::MAX >> 7 {
return Err(ObjectIdentifierError::IllegalChildNodeValue);
@ -149,9 +150,11 @@ impl ObjectIdentifier {
Ok(Self { nodes })
}
pub fn build<B: AsRef<[Node]>>(root: ObjectIdentifierRoot, first: u8, children: B)
-> Result<Self, ObjectIdentifierError>
{
pub fn build<B: AsRef<[Node]>>(
root: ObjectIdentifierRoot,
first: u8,
children: B,
) -> Result<Self, ObjectIdentifierError> {
if first > 40 {
return Err(ObjectIdentifierError::IllegalFirstChildNode);
}
@ -163,7 +166,9 @@ impl ObjectIdentifier {
let var: VarNode = child.into();
vec.extend_from_slice(var.as_bytes())
}
Ok(Self { nodes: vec.into_boxed_slice() })
Ok(Self {
nodes: vec.into_boxed_slice(),
})
}
#[inline(always)]
@ -196,12 +201,14 @@ impl FromStr for ObjectIdentifier {
fn from_str(value: &str) -> Result<Self, Self::Err> {
let mut nodes = value.split(".");
let root = nodes.next()
let root = nodes
.next()
.and_then(|n| n.parse::<u8>().ok())
.and_then(|n| n.try_into().ok())
.ok_or(ObjectIdentifierError::IllegalRootNode)?;
let first = nodes.next()
let first = nodes
.next()
.and_then(|n| parse_string_first_node(n).ok())
.ok_or(ObjectIdentifierError::IllegalFirstChildNode)?;
@ -238,7 +245,7 @@ impl fmt::Debug for ObjectIdentifier {
#[repr(transparent)]
pub struct ArchivedObjectIdentifier {
archived: ArchivedVec<u8>
archived: ArchivedVec<u8>,
}
impl Deref for ArchivedObjectIdentifier {
@ -250,8 +257,12 @@ impl Deref for ArchivedObjectIdentifier {
impl fmt::Debug for ArchivedObjectIdentifier {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}", &convert_to_string(self.archived.as_slice())
.unwrap_or_else(|e| format!("Invalid OID: {:?}", e)))
write!(
f,
"{}",
&convert_to_string(self.archived.as_slice())
.unwrap_or_else(|e| format!("Invalid OID: {:?}", e))
)
}
}
@ -275,7 +286,8 @@ impl Archive for &'static ObjectIdentifier {
}
impl<S: Serializer + ?Sized> Serialize<S> for ObjectIdentifier
where [u8]: rkyv::SerializeUnsized<S>
where
[u8]: rkyv::SerializeUnsized<S>,
{
fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
ArchivedVec::serialize_from_slice(self.nodes.as_ref(), serializer)
@ -340,8 +352,7 @@ fn convert_to_string(nodes: &[u8]) -> Result<String, ObjectIdentifierError> {
impl Into<String> for &ObjectIdentifier {
fn into(self) -> String {
convert_to_string(&self.nodes)
.expect("Valid OID object couldn't be serialized.")
convert_to_string(&self.nodes).expect("Valid OID object couldn't be serialized.")
}
}
@ -468,16 +479,13 @@ mod serde_support {
}
}
impl ser::Serialize for ArchivedObjectIdentifier {
fn serialize<S>(
&self,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: ser::Serializer,
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: ser::Serializer,
{
if serializer.is_human_readable() {
let encoded: String = convert_to_string(self.deref())
.expect("Failed to convert valid OID to String");
let encoded: String =
convert_to_string(self.deref()).expect("Failed to convert valid OID to String");
serializer.serialize_str(&encoded)
} else {
serializer.serialize_bytes(self.deref())
@ -498,8 +506,7 @@ pub(crate) mod tests {
children.push(rand::random());
}
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 25, children)
.unwrap()
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 25, children).unwrap()
}
#[test]
@ -508,7 +515,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::ItuT,
0x01,
vec![1, 2, 3, 5, 8, 13, 21],
).unwrap();
)
.unwrap();
let buffer: Vec<u8> = bincode::serialize(&expected).unwrap();
let actual = bincode::deserialize(&buffer).unwrap();
assert_eq!(expected, actual);
@ -517,11 +525,7 @@ pub(crate) mod tests {
#[test]
fn encode_binary_root_node_0() {
let expected: Vec<u8> = vec![0];
let oid = ObjectIdentifier::build(
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
).unwrap();
let oid = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]).unwrap();
let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual);
}
@ -529,11 +533,7 @@ pub(crate) mod tests {
#[test]
fn encode_binary_root_node_1() {
let expected: Vec<u8> = vec![40];
let oid = ObjectIdentifier::build(
ObjectIdentifierRoot::Iso,
0x00,
vec![],
).unwrap();
let oid = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]).unwrap();
let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual);
}
@ -541,11 +541,8 @@ pub(crate) mod tests {
#[test]
fn encode_binary_root_node_2() {
let expected: Vec<u8> = vec![80];
let oid = ObjectIdentifier::build(
ObjectIdentifierRoot::JointIsoItuT,
0x00,
vec![],
).unwrap();
let oid =
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]).unwrap();
let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual);
}
@ -557,7 +554,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::ItuT,
0x01,
vec![1, 2, 3, 5, 8, 13, 21],
).unwrap();
)
.unwrap();
let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual);
}
@ -572,7 +570,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::JointIsoItuT,
39,
vec![42, 2501, 65535, 2147483647, 1235, 2352],
).unwrap();
)
.unwrap();
let actual: Vec<u8> = (oid).into();
assert_eq!(expected, actual);
}
@ -580,11 +579,7 @@ pub(crate) mod tests {
#[test]
fn encode_string_root_node_0() {
let expected = "0.0";
let oid = ObjectIdentifier::build(
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
).unwrap();
let oid = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]).unwrap();
let actual: String = (oid).into();
assert_eq!(expected, actual);
}
@ -592,11 +587,7 @@ pub(crate) mod tests {
#[test]
fn encode_string_root_node_1() {
let expected = "1.0";
let oid = ObjectIdentifier::build(
ObjectIdentifierRoot::Iso,
0x00,
vec![],
).unwrap();
let oid = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]).unwrap();
let actual: String = (&oid).into();
assert_eq!(expected, actual);
}
@ -604,11 +595,8 @@ pub(crate) mod tests {
#[test]
fn encode_string_root_node_2() {
let expected = "2.0";
let oid = ObjectIdentifier::build(
ObjectIdentifierRoot::JointIsoItuT,
0x00,
vec![],
).unwrap();
let oid =
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]).unwrap();
let actual: String = (&oid).into();
assert_eq!(expected, actual);
}
@ -620,7 +608,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::ItuT,
0x01,
vec![1, 2, 3, 5, 8, 13, 21],
).unwrap();
)
.unwrap();
let actual: String = (&oid).into();
assert_eq!(expected, actual);
}
@ -632,40 +621,29 @@ pub(crate) mod tests {
ObjectIdentifierRoot::JointIsoItuT,
39,
vec![42, 2501, 65535, 2147483647, 1235, 2352],
).unwrap();
)
.unwrap();
let actual: String = (&oid).into();
assert_eq!(expected, actual);
}
#[test]
fn parse_binary_root_node_0() {
let expected = ObjectIdentifier::build(
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
);
let expected = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]);
let actual = vec![0x00].try_into();
assert_eq!(expected, actual);
}
#[test]
fn parse_binary_root_node_1() {
let expected = ObjectIdentifier::build(
ObjectIdentifierRoot::Iso,
0x00,
vec![],
);
let expected = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]);
let actual = vec![40].try_into();
assert_eq!(expected, actual);
}
#[test]
fn parse_binary_root_node_2() {
let expected = ObjectIdentifier::build(
ObjectIdentifierRoot::JointIsoItuT,
0x00,
vec![],
);
let expected = ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]);
let actual = vec![80].try_into();
assert_eq!(expected, actual);
}
@ -698,33 +676,21 @@ pub(crate) mod tests {
#[test]
fn parse_string_root_node_0() {
let expected = ObjectIdentifier::build(
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
);
let expected = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]);
let actual = "0.0".try_into();
assert_eq!(expected, actual);
}
#[test]
fn parse_string_root_node_1() {
let expected = ObjectIdentifier::build(
ObjectIdentifierRoot::Iso,
0x00,
vec![],
);
let expected = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]);
let actual = "1.0".try_into();
assert_eq!(expected, actual);
}
#[test]
fn parse_string_root_node_2() {
let expected = ObjectIdentifier::build(
ObjectIdentifierRoot::JointIsoItuT,
0x00,
vec![],
);
let expected = ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]);
let actual = "2.0".try_into();
assert_eq!(expected, actual);
}
@ -852,18 +818,21 @@ pub(crate) mod tests {
#[test]
fn parse_string_large_children_ok() {
let expected =
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT,
25,
vec![190754093376743485973207716749546715206,
255822649272987943607843257596365752308,
15843412533224453995377625663329542022,
6457999595881951503805148772927347934,
19545192863105095042881850060069531734,
195548685662657784196186957311035194990,
233020488258340943072303499291936117654,
193307160423854019916786016773068715190,
]).unwrap();
let expected = ObjectIdentifier::build(
ObjectIdentifierRoot::JointIsoItuT,
25,
vec![
190754093376743485973207716749546715206,
255822649272987943607843257596365752308,
15843412533224453995377625663329542022,
6457999595881951503805148772927347934,
19545192863105095042881850060069531734,
195548685662657784196186957311035194990,
233020488258340943072303499291936117654,
193307160423854019916786016773068715190,
],
)
.unwrap();
let actual = "2.25.190754093376743485973207716749546715206.\
255822649272987943607843257596365752308.\
15843412533224453995377625663329542022.\
@ -871,31 +840,27 @@ pub(crate) mod tests {
19545192863105095042881850060069531734.\
195548685662657784196186957311035194990.\
233020488258340943072303499291936117654.\
193307160423854019916786016773068715190".try_into().unwrap();
193307160423854019916786016773068715190"
.try_into()
.unwrap();
assert_eq!(expected, actual);
}
#[test]
fn encode_to_string() {
let expected = String::from("1.2.3.4");
let actual: String = ObjectIdentifier::build(
ObjectIdentifierRoot::Iso,
2,
vec![3, 4],
).unwrap()
.into();
let actual: String = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 2, vec![3, 4])
.unwrap()
.into();
assert_eq!(expected, actual);
}
#[test]
fn encode_to_bytes() {
let expected = vec![0x2A, 0x03, 0x04];
let actual: Vec<u8> = ObjectIdentifier::build(
ObjectIdentifierRoot::Iso,
2,
vec![3, 4],
).unwrap()
.into();
let actual: Vec<u8> = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 2, vec![3, 4])
.unwrap()
.into();
assert_eq!(expected, actual);
}
}
}

View File

@ -1,11 +1,10 @@
use uuid::Uuid;
use api::general_capnp::u_u_i_d::{Builder, Reader};
use uuid::Uuid;
pub fn uuid_to_api(uuid: Uuid, mut builder: Builder) {
let [a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p]
= uuid.as_u128().to_ne_bytes();
let lower = u64::from_ne_bytes([a,b,c,d,e,f,g,h]);
let upper = u64::from_ne_bytes([i,j,k,l,m,n,o,p]);
let [a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p] = uuid.as_u128().to_ne_bytes();
let lower = u64::from_ne_bytes([a, b, c, d, e, f, g, h]);
let upper = u64::from_ne_bytes([i, j, k, l, m, n, o, p]);
builder.set_uuid0(lower);
builder.set_uuid1(upper);
}
@ -13,8 +12,8 @@ pub fn uuid_to_api(uuid: Uuid, mut builder: Builder) {
pub fn api_to_uuid(reader: Reader) -> Uuid {
let lower: u64 = reader.reborrow().get_uuid0();
let upper: u64 = reader.get_uuid1();
let [a,b,c,d,e,f,g,h] = lower.to_ne_bytes();
let [i,j,k,l,m,n,o,p] = upper.to_ne_bytes();
let num = u128::from_ne_bytes([a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p]);
let [a, b, c, d, e, f, g, h] = lower.to_ne_bytes();
let [i, j, k, l, m, n, o, p] = upper.to_ne_bytes();
let num = u128::from_ne_bytes([a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p]);
Uuid::from_u128(num)
}
}

View File

@ -27,7 +27,6 @@ impl<const N: usize> VarUInt<N> {
pub const fn into_bytes(self) -> [u8; N] {
self.bytes
}
}
impl<const N: usize> Default for VarUInt<N> {
@ -52,7 +51,7 @@ macro_rules! convert_from {
let bytes = this.as_mut_bytes();
let mut more = 0u8;
let mut idx: usize = bytes.len()-1;
let mut idx: usize = bytes.len() - 1;
while num > 0x7f {
bytes[idx] = ((num & 0x7f) as u8 | more);
@ -65,7 +64,7 @@ macro_rules! convert_from {
this.offset = idx;
this
}
}
};
}
macro_rules! convert_into {
@ -84,7 +83,7 @@ macro_rules! convert_into {
let mut shift = 0;
for neg in 1..=len {
let idx = len-neg;
let idx = len - neg;
let val = (bytes[idx] & 0x7f) as $x;
let shifted = val << shift;
out |= shifted;
@ -93,7 +92,7 @@ macro_rules! convert_into {
out
}
}
};
}
macro_rules! impl_convert_from_to {
@ -105,7 +104,7 @@ macro_rules! impl_convert_from_to {
impl Into<$num> for VarUInt<$req> {
convert_into! { $num }
}
}
};
}
impl_convert_from_to!(u8, 2, VarU8);
@ -123,8 +122,9 @@ type VarUsize = VarU32;
type VarUsize = VarU16;
impl<T, const N: usize> From<&T> for VarUInt<N>
where T: Copy,
VarUInt<N>: From<T>
where
T: Copy,
VarUInt<N>: From<T>,
{
fn from(t: &T) -> Self {
(*t).into()
@ -162,4 +162,4 @@ mod tests {
let expected: &[u8] = &[129, 0];
assert_eq!(vi.as_bytes(), expected)
}
}
}

View File

@ -1,12 +1,9 @@
use clap::{Arg, Command};
use diflouroborane::{config, Diflouroborane};
use std::str::FromStr;
use std::{env, io, io::Write, path::PathBuf};
use nix::NixPath;
fn main() -> anyhow::Result<()> {
@ -125,12 +122,19 @@ fn main() -> anyhow::Result<()> {
unimplemented!()
} else if matches.is_present("load") {
let bffh = Diflouroborane::new(config)?;
if bffh.users.load_file(matches.value_of("load").unwrap()).is_ok() {
if bffh
.users
.load_file(matches.value_of("load").unwrap())
.is_ok()
{
tracing::info!("loaded users from {}", matches.value_of("load").unwrap());
} else {
tracing::error!("failed to load users from {}", matches.value_of("load").unwrap());
tracing::error!(
"failed to load users from {}",
matches.value_of("load").unwrap()
);
}
return Ok(())
return Ok(());
} else {
let keylog = matches.value_of("keylog");
// When passed an empty string (i.e no value) take the value from the env

View File

@ -4,48 +4,57 @@ fn main() {
println!(">>> Building version number...");
let rustc = std::env::var("RUSTC").unwrap();
let out = Command::new(rustc).arg("--version")
.output()
.expect("failed to run `rustc --version`");
let rustc_version = String::from_utf8(out.stdout)
.expect("rustc --version returned invalid UTF-8");
let out = Command::new(rustc)
.arg("--version")
.output()
.expect("failed to run `rustc --version`");
let rustc_version =
String::from_utf8(out.stdout).expect("rustc --version returned invalid UTF-8");
let rustc_version = rustc_version.trim();
println!("cargo:rustc-env=CARGO_RUSTC_VERSION={}", rustc_version);
println!("cargo:rerun-if-env-changed=BFFHD_BUILD_TAGGED_RELEASE");
let tagged_release = option_env!("BFFHD_BUILD_TAGGED_RELEASE") == Some("1");
let version_string = if tagged_release {
format!("{version} [{rustc}]",
version = env!("CARGO_PKG_VERSION"),
rustc = rustc_version)
format!(
"{version} [{rustc}]",
version = env!("CARGO_PKG_VERSION"),
rustc = rustc_version
)
} else {
// Build version number using the current git commit id
let out = Command::new("git").arg("rev-list")
.args(["HEAD", "-1"])
.output()
.expect("failed to run `git rev-list HEAD -1`");
let owned_gitrev = String::from_utf8(out.stdout)
.expect("git rev-list output was not valid UTF8");
let out = Command::new("git")
.arg("rev-list")
.args(["HEAD", "-1"])
.output()
.expect("failed to run `git rev-list HEAD -1`");
let owned_gitrev =
String::from_utf8(out.stdout).expect("git rev-list output was not valid UTF8");
let gitrev = owned_gitrev.trim();
let abbrev = match gitrev.len(){
let abbrev = match gitrev.len() {
0 => "unknown",
_ => &gitrev[0..9],
};
let out = Command::new("git").arg("log")
.args(["-1", "--format=%as"])
.output()
.expect("failed to run `git log -1 --format=\"format:%as\"`");
let commit_date = String::from_utf8(out.stdout)
.expect("git log output was not valid UTF8");
let out = Command::new("git")
.arg("log")
.args(["-1", "--format=%as"])
.output()
.expect("failed to run `git log -1 --format=\"format:%as\"`");
let commit_date = String::from_utf8(out.stdout).expect("git log output was not valid UTF8");
let commit_date = commit_date.trim();
format!("{version} ({gitrev} {date}) [{rustc}]",
version=env!("CARGO_PKG_VERSION"),
gitrev=abbrev,
date=commit_date,
rustc=rustc_version)
format!(
"{version} ({gitrev} {date}) [{rustc}]",
version = env!("CARGO_PKG_VERSION"),
gitrev = abbrev,
date = commit_date,
rustc = rustc_version
)
};
println!("cargo:rustc-env=BFFHD_VERSION_STRING={}", version_string);
println!("cargo:rustc-env=BFFHD_RELEASE_STRING=\"BFFH {}\"", version_string);
println!(
"cargo:rustc-env=BFFHD_RELEASE_STRING=\"BFFH {}\"",
version_string
);
}

View File

@ -1,5 +1,5 @@
use sdk::BoxFuture;
use sdk::initiators::{Initiator, InitiatorError, ResourceID, UpdateSink};
use sdk::BoxFuture;
#[sdk::module]
struct Dummy {
@ -10,11 +10,17 @@ struct Dummy {
}
impl Initiator for Dummy {
fn start_for(&mut self, machine: ResourceID) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> {
fn start_for(
&mut self,
machine: ResourceID,
) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> {
todo!()
}
fn run(&mut self, request: &mut UpdateSink) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> {
fn run(
&mut self,
request: &mut UpdateSink,
) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> {
todo!()
}
}

View File

@ -1,10 +1,10 @@
use proc_macro::TokenStream;
use std::sync::Mutex;
use quote::{format_ident, quote};
use syn::{braced, parse_macro_input, Field, Ident, Token, Visibility, Type};
use std::sync::Mutex;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::Brace;
use syn::{braced, parse_macro_input, Field, Ident, Token, Type, Visibility};
mod keywords {
syn::custom_keyword!(initiator);
@ -32,8 +32,10 @@ impl Parse for ModuleAttrs {
} else if lookahead.peek(keywords::sensor) {
Ok(ModuleAttrs::Sensor)
} else {
Err(input.error("Module type must be empty or one of \"initiator\", \"actor\", or \
\"sensor\""))
Err(input.error(
"Module type must be empty or one of \"initiator\", \"actor\", or \
\"sensor\"",
))
}
}
}
@ -84,4 +86,4 @@ pub fn module(attr: TokenStream, tokens: TokenStream) -> TokenStream {
}
};
output.into()
}
}

View File

@ -1,10 +1,4 @@
pub use diflouroborane::{
initiators::{
UpdateSink,
UpdateError,
Initiator,
InitiatorError,
},
initiators::{Initiator, InitiatorError, UpdateError, UpdateSink},
resource::claim::ResourceID,
};

View File

@ -1,5 +1,4 @@
#[forbid(private_in_public)]
pub use sdk_proc::module;
pub use futures_util::future::BoxFuture;

View File

@ -1,22 +1,22 @@
use executor::prelude::*;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::prelude::*;
fn increment(b: &mut Criterion) {
let mut sum = 0;
let executor = Executor::new();
b.bench_function("Executor::run", |b| b.iter(|| {
executor.run(
async {
b.bench_function("Executor::run", |b| {
b.iter(|| {
executor.run(async {
(0..10_000_000).for_each(|_| {
sum += 1;
});
},
);
}));
});
})
});
black_box(sum);
}
criterion_group!(perf, increment);
criterion_main!(perf);
criterion_main!(perf);

View File

@ -1,8 +1,8 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::load_balancer;
use executor::prelude::*;
use futures_timer::Delay;
use std::time::Duration;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
#[cfg(feature = "tokio-runtime")]
mod benches {
@ -27,7 +27,6 @@ mod benches {
pub fn spawn_single(b: &mut Criterion) {
_spawn_single(b);
}
}
criterion_group!(spawn, benches::spawn_lot, benches::spawn_single);
@ -36,29 +35,29 @@ criterion_main!(spawn);
// Benchmark for a 10K burst task spawn
fn _spawn_lot(b: &mut Criterion) {
let executor = Executor::new();
b.bench_function("spawn_lot", |b| b.iter(|| {
let _ = (0..10_000)
.map(|_| {
executor.spawn(
async {
b.bench_function("spawn_lot", |b| {
b.iter(|| {
let _ = (0..10_000)
.map(|_| {
executor.spawn(async {
let duration = Duration::from_millis(1);
Delay::new(duration).await;
},
)
})
.collect::<Vec<_>>();
}));
})
})
.collect::<Vec<_>>();
})
});
}
// Benchmark for a single task spawn
fn _spawn_single(b: &mut Criterion) {
let executor = Executor::new();
b.bench_function("spawn single", |b| b.iter(|| {
executor.spawn(
async {
b.bench_function("spawn single", |b| {
b.iter(|| {
executor.spawn(async {
let duration = Duration::from_millis(1);
Delay::new(duration).await;
},
);
}));
});
})
});
}

View File

@ -1,7 +1,7 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::load_balancer::{core_count, get_cores, stats, SmpStats};
use executor::placement;
use std::thread;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
fn stress_stats<S: SmpStats + Sync + Send>(stats: &'static S) {
let mut handles = Vec::with_capacity(*core_count());
@ -27,9 +27,11 @@ fn stress_stats<S: SmpStats + Sync + Send>(stats: &'static S) {
// 158,278 ns/iter (+/- 117,103)
fn lockless_stats_bench(b: &mut Criterion) {
b.bench_function("stress_stats", |b| b.iter(|| {
stress_stats(stats());
}));
b.bench_function("stress_stats", |b| {
b.iter(|| {
stress_stats(stats());
})
});
}
fn lockless_stats_bad_load(b: &mut Criterion) {
@ -45,9 +47,11 @@ fn lockless_stats_bad_load(b: &mut Criterion) {
}
}
b.bench_function("get_sorted_load", |b| b.iter(|| {
let _sorted_load = stats.get_sorted_load();
}));
b.bench_function("get_sorted_load", |b| {
b.iter(|| {
let _sorted_load = stats.get_sorted_load();
})
});
}
fn lockless_stats_good_load(b: &mut Criterion) {
@ -59,11 +63,17 @@ fn lockless_stats_good_load(b: &mut Criterion) {
stats.store_load(i, i);
}
b.bench_function("get_sorted_load", |b| b.iter(|| {
let _sorted_load = stats.get_sorted_load();
}));
b.bench_function("get_sorted_load", |b| {
b.iter(|| {
let _sorted_load = stats.get_sorted_load();
})
});
}
criterion_group!(stats_bench, lockless_stats_bench, lockless_stats_bad_load,
lockless_stats_good_load);
criterion_group!(
stats_bench,
lockless_stats_bench,
lockless_stats_bad_load,
lockless_stats_good_load
);
criterion_main!(stats_bench);

View File

@ -1,12 +1,12 @@
use executor::pool;
use executor::prelude::*;
use futures_util::{stream::FuturesUnordered, Stream};
use futures_util::{FutureExt, StreamExt};
use lightproc::prelude::RecoverableHandle;
use std::io::Write;
use std::panic::resume_unwind;
use std::rc::Rc;
use std::time::Duration;
use futures_util::{stream::FuturesUnordered, Stream};
use futures_util::{FutureExt, StreamExt};
use executor::pool;
use executor::prelude::*;
use lightproc::prelude::RecoverableHandle;
fn main() {
tracing_subscriber::fmt()
@ -24,9 +24,9 @@ fn main() {
let executor = Executor::new();
let mut handles: FuturesUnordered<RecoverableHandle<usize>> = (0..2000).map(|n| {
executor.spawn(
async move {
let mut handles: FuturesUnordered<RecoverableHandle<usize>> = (0..2000)
.map(|n| {
executor.spawn(async move {
let m: u64 = rand::random::<u64>() % 200;
tracing::debug!("Will sleep {} * 1 ms", m);
// simulate some really heavy load.
@ -34,9 +34,9 @@ fn main() {
async_std::task::sleep(Duration::from_millis(1)).await;
}
return n;
},
)
}).collect();
})
})
.collect();
//let handle = handles.fuse().all(|opt| async move { opt.is_some() });
/* Futures passed to `spawn` need to be `Send` so this won't work:
@ -58,12 +58,12 @@ fn main() {
// However, you can't pass it a future outright but have to hand it a generator creating the
// future on the correct thread.
let fut = async {
let local_futs: FuturesUnordered<_> = (0..200).map(|ref n| {
let n = *n;
let exe = executor.clone();
async move {
exe.spawn(
async {
let local_futs: FuturesUnordered<_> = (0..200)
.map(|ref n| {
let n = *n;
let exe = executor.clone();
async move {
exe.spawn(async {
let tid = std::thread::current().id();
tracing::info!("spawn_local({}) is on thread {:?}", n, tid);
exe.spawn_local(async move {
@ -86,10 +86,11 @@ fn main() {
*rc
})
}
).await
}
}).collect();
})
.await
}
})
.collect();
local_futs
};
@ -108,12 +109,10 @@ fn main() {
async_std::task::sleep(Duration::from_secs(20)).await;
tracing::info!("This is taking too long.");
};
executor.run(
async {
let res = futures_util::select! {
_ = a.fuse() => {},
_ = b.fuse() => {},
};
},
);
executor.run(async {
let res = futures_util::select! {
_ = a.fuse() => {},
_ = b.fuse() => {},
};
});
}

View File

@ -28,10 +28,10 @@
#![forbid(unused_import_braces)]
pub mod load_balancer;
pub mod manage;
pub mod placement;
pub mod pool;
pub mod run;
pub mod manage;
mod thread_manager;
mod worker;

View File

@ -1,6 +1,2 @@
/// View and Manage the current processes of this executor
pub struct Manager {
}
pub struct Manager {}

View File

@ -7,19 +7,19 @@
//! [`spawn`]: crate::pool::spawn
//! [`Worker`]: crate::run_queue::Worker
use std::cell::Cell;
use crate::thread_manager::{ThreadManager, DynamicRunner};
use crate::run::block;
use crate::thread_manager::{DynamicRunner, ThreadManager};
use crate::worker::{Sleeper, WorkerThread};
use crossbeam_deque::{Injector, Stealer};
use lightproc::lightproc::LightProc;
use lightproc::recoverable_handle::RecoverableHandle;
use std::cell::Cell;
use std::future::Future;
use std::iter::Iterator;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::time::Duration;
use crossbeam_deque::{Injector, Stealer};
use crate::run::block;
use crate::worker::{Sleeper, WorkerThread};
#[derive(Debug)]
struct Spooler<'a> {
@ -31,10 +31,13 @@ struct Spooler<'a> {
impl Spooler<'_> {
pub fn new() -> Self {
let spool = Arc::new(Injector::new());
let threads = Box::leak(Box::new(
ThreadManager::new(2, AsyncRunner, spool.clone())));
let threads = Box::leak(Box::new(ThreadManager::new(2, AsyncRunner, spool.clone())));
threads.initialize();
Self { spool, threads, _marker: PhantomData }
Self {
spool,
threads,
_marker: PhantomData,
}
}
}
@ -53,9 +56,7 @@ impl<'a, 'executor: 'a> Executor<'executor> {
fn schedule(&self) -> impl Fn(LightProc) + 'a {
let task_queue = self.spooler.spool.clone();
move |lightproc: LightProc| {
task_queue.push(lightproc)
}
move |lightproc: LightProc| task_queue.push(lightproc)
}
///
@ -94,23 +95,21 @@ impl<'a, 'executor: 'a> Executor<'executor> {
/// # }
/// ```
pub fn spawn<F, R>(&self, future: F) -> RecoverableHandle<R>
where
F: Future<Output = R> + Send + 'a,
R: Send + 'a,
where
F: Future<Output = R> + Send + 'a,
R: Send + 'a,
{
let (task, handle) =
LightProc::recoverable(future, self.schedule());
let (task, handle) = LightProc::recoverable(future, self.schedule());
task.schedule();
handle
}
pub fn spawn_local<F, R>(&self, future: F) -> RecoverableHandle<R>
where
F: Future<Output = R> + 'a,
R: Send + 'a,
where
F: Future<Output = R> + 'a,
R: Send + 'a,
{
let (task, handle) =
LightProc::recoverable(future, schedule_local());
let (task, handle) = LightProc::recoverable(future, schedule_local());
task.schedule();
handle
}
@ -135,8 +134,8 @@ impl<'a, 'executor: 'a> Executor<'executor> {
/// );
/// ```
pub fn run<F, R>(&self, future: F) -> R
where
F: Future<Output = R>,
where
F: Future<Output = R>,
{
unsafe {
// An explicitly uninitialized `R`. Until `assume_init` is called this will not call any
@ -174,17 +173,20 @@ impl DynamicRunner for AsyncRunner {
sleeper
}
fn run_static<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>, park_timeout: Duration) -> ! {
fn run_static<'b>(
fences: impl Iterator<Item = &'b Stealer<LightProc>>,
park_timeout: Duration,
) -> ! {
let worker = get_worker();
worker.run_timeout(fences, park_timeout)
}
fn run_dynamic<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) -> ! {
fn run_dynamic<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) -> ! {
let worker = get_worker();
worker.run(fences)
}
fn run_standalone<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) {
fn run_standalone<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) {
let worker = get_worker();
worker.run_once(fences)
}
@ -196,10 +198,9 @@ thread_local! {
fn get_worker() -> &'static WorkerThread<'static, LightProc> {
WORKER.with(|cell| {
let worker = unsafe {
&*cell.as_ptr() as &'static Option<WorkerThread<_>>
};
worker.as_ref()
let worker = unsafe { &*cell.as_ptr() as &'static Option<WorkerThread<_>> };
worker
.as_ref()
.expect("AsyncRunner running outside Executor context")
})
}

View File

@ -45,13 +45,18 @@
//! Throughput hogs determined by a combination of job in / job out frequency and current scheduler task assignment frequency.
//! Threshold of EMA difference is eluded by machine epsilon for floating point arithmetic errors.
use crate::worker::Sleeper;
use crate::{load_balancer, placement};
use core::fmt;
use crossbeam_channel::bounded;
use crossbeam_deque::{Injector, Stealer};
use crossbeam_queue::ArrayQueue;
use fmt::{Debug, Formatter};
use lazy_static::lazy_static;
use lightproc::lightproc::LightProc;
use placement::CoreId;
use std::collections::VecDeque;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use std::{
sync::{
@ -60,12 +65,7 @@ use std::{
},
thread,
};
use std::sync::{Arc, RwLock};
use crossbeam_channel::bounded;
use crossbeam_deque::{Injector, Stealer};
use tracing::{debug, trace};
use lightproc::lightproc::LightProc;
use crate::worker::Sleeper;
/// The default thread park timeout before checking for new tasks.
const THREAD_PARK_TIMEOUT: Duration = Duration::from_millis(1);
@ -113,10 +113,12 @@ lazy_static! {
pub trait DynamicRunner {
fn setup(task_queue: Arc<Injector<LightProc>>) -> Sleeper<LightProc>;
fn run_static<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>,
park_timeout: Duration) -> !;
fn run_dynamic<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) -> !;
fn run_standalone<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>);
fn run_static<'b>(
fences: impl Iterator<Item = &'b Stealer<LightProc>>,
park_timeout: Duration,
) -> !;
fn run_dynamic<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) -> !;
fn run_standalone<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>);
}
/// The `ThreadManager` is creates and destroys worker threads depending on demand according to
@ -183,11 +185,14 @@ impl<Runner: Debug> Debug for ThreadManager<Runner> {
}
fmt.debug_struct("DynamicPoolManager")
.field("thread pool", &ThreadCount(
&self.static_threads,
&self.dynamic_threads,
&self.parked_threads.len(),
))
.field(
"thread pool",
&ThreadCount(
&self.static_threads,
&self.dynamic_threads,
&self.parked_threads.len(),
),
)
.field("runner", &self.runner)
.field("last_frequency", &self.last_frequency)
.finish()
@ -195,7 +200,11 @@ impl<Runner: Debug> Debug for ThreadManager<Runner> {
}
impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
pub fn new(static_threads: usize, runner: Runner, task_queue: Arc<Injector<LightProc>>) -> Self {
pub fn new(
static_threads: usize,
runner: Runner,
task_queue: Arc<Injector<LightProc>>,
) -> Self {
let dynamic_threads = 1.max(num_cpus::get().checked_sub(static_threads).unwrap_or(0));
let parked_threads = ArrayQueue::new(1.max(static_threads + dynamic_threads));
let fences = Arc::new(RwLock::new(Vec::new()));
@ -252,7 +261,10 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
});
// Dynamic thread manager that will allow us to unpark threads when needed
debug!("spooling up {} dynamic worker threads", self.dynamic_threads);
debug!(
"spooling up {} dynamic worker threads",
self.dynamic_threads
);
(0..self.dynamic_threads).for_each(|_| {
let tx = tx.clone();
let fencelock = fencelock.clone();
@ -302,10 +314,11 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
/// Provision threads takes a number of threads that need to be made available.
/// It will try to unpark threads from the dynamic pool, and spawn more threads if needs be.
pub fn provision_threads(&'static self,
n: usize,
fencelock: &Arc<RwLock<Vec<Stealer<LightProc>>>>)
{
pub fn provision_threads(
&'static self,
n: usize,
fencelock: &Arc<RwLock<Vec<Stealer<LightProc>>>>,
) {
let rem = self.unpark_thread(n);
if rem != 0 {
debug!("no more threads to unpark, spawning {} new threads", rem);
@ -391,7 +404,5 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
/// on the request rate.
///
/// It uses frequency based calculation to define work. Utilizing average processing rate.
fn scale_pool(&'static self) {
}
fn scale_pool(&'static self) {}
}

View File

@ -1,10 +1,10 @@
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use crossbeam_queue::SegQueue;
use crossbeam_utils::sync::{Parker, Unparker};
use lightproc::prelude::LightProc;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
pub trait Runnable {
fn run(self);
@ -61,8 +61,14 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
let unparker = parker.unparker().clone();
(
Self { task_queue, tasks, local_tasks, parker, _marker },
Sleeper { stealer, unparker }
Self {
task_queue,
tasks,
local_tasks,
parker,
_marker,
},
Sleeper { stealer, unparker },
)
}
@ -71,10 +77,8 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
}
/// Run this worker thread "forever" (i.e. until the thread panics or is otherwise killed)
pub fn run(&self, fences: impl Iterator<Item=&'a Stealer<T>>) -> ! {
let fences: Vec<Stealer<T>> = fences
.map(|stealer| stealer.clone())
.collect();
pub fn run(&self, fences: impl Iterator<Item = &'a Stealer<T>>) -> ! {
let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
loop {
self.run_inner(&fences);
@ -82,10 +86,12 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
}
}
pub fn run_timeout(&self, fences: impl Iterator<Item=&'a Stealer<T>>, timeout: Duration) -> ! {
let fences: Vec<Stealer<T>> = fences
.map(|stealer| stealer.clone())
.collect();
pub fn run_timeout(
&self,
fences: impl Iterator<Item = &'a Stealer<T>>,
timeout: Duration,
) -> ! {
let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
loop {
self.run_inner(&fences);
@ -93,10 +99,8 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
}
}
pub fn run_once(&self, fences: impl Iterator<Item=&'a Stealer<T>>) {
let fences: Vec<Stealer<T>> = fences
.map(|stealer| stealer.clone())
.collect();
pub fn run_once(&self, fences: impl Iterator<Item = &'a Stealer<T>>) {
let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
self.run_inner(fences);
}
@ -123,17 +127,19 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
Steal::Success(task) => {
task.run();
continue 'work;
},
}
// If there is no more work to steal from the global queue, try other
// workers next
Steal::Empty => break,
// If a race condition occurred try again with backoff
Steal::Retry => for _ in 0..(1 << i) {
core::hint::spin_loop();
i += 1;
},
Steal::Retry => {
for _ in 0..(1 << i) {
core::hint::spin_loop();
i += 1;
}
}
}
}
@ -145,7 +151,7 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
Steal::Success(task) => {
task.run();
continue 'work;
},
}
// If no other worker has work to do we're done once again.
Steal::Empty => break,
@ -169,6 +175,6 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
}
#[inline(always)]
fn select_fence<'a, T>(fences: impl Iterator<Item=&'a Stealer<T>>) -> Option<&'a Stealer<T>> {
fn select_fence<'a, T>(fences: impl Iterator<Item = &'a Stealer<T>>) -> Option<&'a Stealer<T>> {
fences.max_by_key(|fence| fence.len())
}
}

View File

@ -1,8 +1,8 @@
use std::io::Write;
use executor::prelude::{spawn, ProcStack};
use executor::run::run;
use std::io::Write;
use std::thread;
use std::time::Duration;
use executor::prelude::{ProcStack, spawn};
#[cfg(feature = "tokio-runtime")]
mod tokio_tests {
@ -21,13 +21,11 @@ mod no_tokio_tests {
}
fn run_test() {
let handle = spawn(
async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
//42
},
);
let handle = spawn(async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
//42
});
let output = run(handle, ProcStack {});

View File

@ -1,11 +1,11 @@
use executor::blocking;
use executor::prelude::ProcStack;
use executor::run::run;
use futures_util::future::join_all;
use lightproc::recoverable_handle::RecoverableHandle;
use std::thread;
use std::time::Duration;
use std::time::Instant;
use executor::prelude::ProcStack;
// Test for slow joins without task bursts during joins.
#[test]
@ -17,12 +17,10 @@ fn slow_join() {
// Send an initial batch of million bursts.
let handles = (0..1_000_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
},
)
blocking::spawn_blocking(async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
})
})
.collect::<Vec<RecoverableHandle<()>>>();
@ -35,12 +33,10 @@ fn slow_join() {
// Spawn yet another batch of work on top of it
let handles = (0..10_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
},
)
blocking::spawn_blocking(async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
})
})
.collect::<Vec<RecoverableHandle<()>>>();
@ -63,12 +59,10 @@ fn slow_join_interrupted() {
// Send an initial batch of million bursts.
let handles = (0..1_000_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
},
)
blocking::spawn_blocking(async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
})
})
.collect::<Vec<RecoverableHandle<()>>>();
@ -82,12 +76,10 @@ fn slow_join_interrupted() {
// Spawn yet another batch of work on top of it
let handles = (0..10_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
},
)
blocking::spawn_blocking(async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
})
})
.collect::<Vec<RecoverableHandle<()>>>();
@ -111,12 +103,10 @@ fn longhauling_task_join() {
// First batch of overhauling tasks
let _ = (0..100_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(1000);
thread::sleep(duration);
},
)
blocking::spawn_blocking(async {
let duration = Duration::from_millis(1000);
thread::sleep(duration);
})
})
.collect::<Vec<RecoverableHandle<()>>>();
@ -127,12 +117,10 @@ fn longhauling_task_join() {
// Send yet another medium sized batch to see how it scales.
let handles = (0..10_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
},
)
blocking::spawn_blocking(async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
})
})
.collect::<Vec<RecoverableHandle<()>>>();

View File

@ -1,11 +1,11 @@
use std::any::Any;
use std::fmt::Debug;
use std::ops::Deref;
use crossbeam::channel::{unbounded, Sender};
use futures_executor as executor;
use lazy_static::lazy_static;
use lightproc::prelude::*;
use std::any::Any;
use std::fmt::Debug;
use std::future::Future;
use std::ops::Deref;
use std::thread;
fn spawn_on_thread<F, R>(future: F) -> RecoverableHandle<R>
@ -30,20 +30,17 @@ where
}
let schedule = |t| (QUEUE.deref()).send(t).unwrap();
let (proc, handle) = LightProc::recoverable(
future,
schedule
);
let (proc, handle) = LightProc::recoverable(future, schedule);
let handle = handle
.on_panic(|err: Box<dyn Any + Send>| {
match err.downcast::<&'static str>() {
Ok(reason) => println!("Future panicked: {}", &reason),
Err(err) =>
println!("Future panicked with a non-text reason of typeid {:?}",
err.type_id()),
}
});
let handle = handle.on_panic(
|err: Box<dyn Any + Send>| match err.downcast::<&'static str>() {
Ok(reason) => println!("Future panicked: {}", &reason),
Err(err) => println!(
"Future panicked with a non-text reason of typeid {:?}",
err.type_id()
),
},
);
proc.schedule();

View File

@ -14,15 +14,10 @@ where
{
let (sender, receiver) = channel::unbounded();
let future = async move {
fut.await
};
let future = async move { fut.await };
let schedule = move |t| sender.send(t).unwrap();
let (proc, handle) = LightProc::build(
future,
schedule,
);
let (proc, handle) = LightProc::build(future, schedule);
proc.schedule();

View File

@ -77,9 +77,10 @@ impl LightProc {
/// });
/// ```
pub fn recoverable<'a, F, R, S>(future: F, schedule: S) -> (Self, RecoverableHandle<R>)
where F: Future<Output=R> + 'a,
R: 'a,
S: Fn(LightProc) + 'a,
where
F: Future<Output = R> + 'a,
R: 'a,
S: Fn(LightProc) + 'a,
{
let recovery_future = AssertUnwindSafe(future).catch_unwind();
let (proc, handle) = Self::build(recovery_future, schedule);
@ -115,9 +116,10 @@ impl LightProc {
/// );
/// ```
pub fn build<'a, F, R, S>(future: F, schedule: S) -> (Self, ProcHandle<R>)
where F: Future<Output=R> + 'a,
R: 'a,
S: Fn(LightProc) + 'a,
where
F: Future<Output = R> + 'a,
R: 'a,
S: Fn(LightProc) + 'a,
{
let raw_proc = RawProc::allocate(future, schedule);
let proc = LightProc { raw_proc };

View File

@ -44,12 +44,10 @@ impl ProcData {
let (flags, references) = state.parts();
let new = State::new(flags | CLOSED, references);
// Mark the proc as closed.
match self.state.compare_exchange_weak(
state,
new,
Ordering::AcqRel,
Ordering::Acquire,
) {
match self
.state
.compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire)
{
Ok(_) => {
// Notify the awaiter that the proc has been closed.
if state.is_awaiter() {
@ -117,7 +115,8 @@ impl ProcData {
// Release the lock. If we've cleared the awaiter, then also unset the awaiter flag.
if new_is_none {
self.state.fetch_and((!LOCKED & !AWAITER).into(), Ordering::Release);
self.state
.fetch_and((!LOCKED & !AWAITER).into(), Ordering::Release);
} else {
self.state.fetch_and((!LOCKED).into(), Ordering::Release);
}
@ -142,9 +141,7 @@ impl Debug for ProcData {
.field("ref_count", &state.get_refcount())
.finish()
} else {
fmt.debug_struct("ProcData")
.field("state", &state)
.finish()
fmt.debug_struct("ProcData").field("state", &state).finish()
}
}
}

View File

@ -273,9 +273,7 @@ where
let raw = Self::from_ptr(ptr);
// Decrement the reference count.
let new = (*raw.pdata)
.state
.fetch_sub(1, Ordering::AcqRel);
let new = (*raw.pdata).state.fetch_sub(1, Ordering::AcqRel);
let new = new.set_refcount(new.get_refcount().saturating_sub(1));
// If this was the last reference to the proc and the `ProcHandle` has been dropped as
@ -444,13 +442,12 @@ where
// was woken and then clean up its resources.
let (flags, references) = state.parts();
let flags = if state.is_closed() {
flags & !( RUNNING | SCHEDULED )
flags & !(RUNNING | SCHEDULED)
} else {
flags & !RUNNING
};
let new = State::new(flags, references);
// Mark the proc as not running.
match (*raw.pdata).state.compare_exchange_weak(
state,
@ -502,10 +499,10 @@ impl<'a, F, R, S> Copy for RawProc<'a, F, R, S> {}
/// A guard that closes the proc if polling its future panics.
struct Guard<'a, F, R, S>(RawProc<'a, F, R, S>)
where
F: Future<Output = R> + 'a,
R: 'a,
S: Fn(LightProc) + 'a;
where
F: Future<Output = R> + 'a,
R: 'a,
S: Fn(LightProc) + 'a;
impl<'a, F, R, S> Drop for Guard<'a, F, R, S>
where

View File

@ -1,9 +1,9 @@
//!
//! Handle for recoverable process
use std::any::Any;
use crate::proc_data::ProcData;
use crate::proc_handle::ProcHandle;
use crate::state::State;
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
@ -80,12 +80,12 @@ impl<R> RecoverableHandle<R> {
/// });
/// ```
pub fn on_panic<F>(mut self, callback: F) -> Self
where F: FnOnce(Box<dyn Any + Send>) + Send + Sync + 'static,
where
F: FnOnce(Box<dyn Any + Send>) + Send + Sync + 'static,
{
self.panicked = Some(Box::new(callback));
self
}
}
impl<R> Future for RecoverableHandle<R> {
@ -102,7 +102,7 @@ impl<R> Future for RecoverableHandle<R> {
}
Poll::Ready(None)
},
}
}
}
}

View File

@ -73,25 +73,22 @@ bitflags::bitflags! {
#[repr(packed)]
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct State {
bytes: [u8; 8]
bytes: [u8; 8],
}
impl State {
#[inline(always)]
pub const fn new(flags: StateFlags, references: u32) -> Self {
let [a,b,c,d] = references.to_ne_bytes();
let [e,f,g,h] = flags.bits.to_ne_bytes();
Self::from_bytes([a,b,c,d,e,f,g,h])
let [a, b, c, d] = references.to_ne_bytes();
let [e, f, g, h] = flags.bits.to_ne_bytes();
Self::from_bytes([a, b, c, d, e, f, g, h])
}
#[inline(always)]
pub const fn parts(self: Self) -> (StateFlags, u32) {
let [a,b,c,d,e,f,g,h] = self.bytes;
let refcount = u32::from_ne_bytes([a,b,c,d]);
let state = unsafe {
StateFlags::from_bits_unchecked(u32::from_ne_bytes([e,f,g,h]))
};
let [a, b, c, d, e, f, g, h] = self.bytes;
let refcount = u32::from_ne_bytes([a, b, c, d]);
let state = unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e, f, g, h])) };
(state, refcount)
}
@ -101,8 +98,8 @@ impl State {
/// Note that the reference counter only tracks the `LightProc` and `Waker`s. The `ProcHandle` is
/// tracked separately by the `HANDLE` flag.
pub const fn get_refcount(self) -> u32 {
let [a,b,c,d,_,_,_,_] = self.bytes;
u32::from_ne_bytes([a,b,c,d])
let [a, b, c, d, _, _, _, _] = self.bytes;
u32::from_ne_bytes([a, b, c, d])
}
#[inline(always)]
@ -116,7 +113,7 @@ impl State {
#[inline(always)]
pub const fn get_flags(self) -> StateFlags {
let [_, _, _, _, e, f, g, h] = self.bytes;
unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e,f,g,h])) }
unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e, f, g, h])) }
}
#[inline(always)]
@ -207,10 +204,10 @@ impl AtomicState {
current: State,
new: State,
success: Ordering,
failure: Ordering
) -> Result<State, State>
{
self.inner.compare_exchange(current.into_u64(), new.into_u64(), success, failure)
failure: Ordering,
) -> Result<State, State> {
self.inner
.compare_exchange(current.into_u64(), new.into_u64(), success, failure)
.map(|u| State::from_u64(u))
.map_err(|u| State::from_u64(u))
}
@ -220,37 +217,37 @@ impl AtomicState {
current: State,
new: State,
success: Ordering,
failure: Ordering
) -> Result<State, State>
{
self.inner.compare_exchange_weak(current.into_u64(), new.into_u64(), success, failure)
failure: Ordering,
) -> Result<State, State> {
self.inner
.compare_exchange_weak(current.into_u64(), new.into_u64(), success, failure)
.map(|u| State::from_u64(u))
.map_err(|u| State::from_u64(u))
}
pub fn fetch_or(&self, val: StateFlags, order: Ordering) -> State {
let [a,b,c,d] = val.bits.to_ne_bytes();
let store = u64::from_ne_bytes([0,0,0,0,a,b,c,d]);
let [a, b, c, d] = val.bits.to_ne_bytes();
let store = u64::from_ne_bytes([0, 0, 0, 0, a, b, c, d]);
State::from_u64(self.inner.fetch_or(store, order))
}
pub fn fetch_and(&self, val: StateFlags, order: Ordering) -> State {
let [a,b,c,d] = val.bits.to_ne_bytes();
let store = u64::from_ne_bytes([!0,!0,!0,!0,a,b,c,d]);
let [a, b, c, d] = val.bits.to_ne_bytes();
let store = u64::from_ne_bytes([!0, !0, !0, !0, a, b, c, d]);
State::from_u64(self.inner.fetch_and(store, order))
}
// FIXME: Do this properly
pub fn fetch_add(&self, val: u32, order: Ordering) -> State {
let [a,b,c,d] = val.to_ne_bytes();
let store = u64::from_ne_bytes([a,b,c,d,0,0,0,0]);
let [a, b, c, d] = val.to_ne_bytes();
let store = u64::from_ne_bytes([a, b, c, d, 0, 0, 0, 0]);
State::from_u64(self.inner.fetch_add(store, order))
}
// FIXME: Do this properly
pub fn fetch_sub(&self, val: u32, order: Ordering) -> State {
let [a,b,c,d] = val.to_ne_bytes();
let store = u64::from_ne_bytes([a,b,c,d,0,0,0,0]);
let [a, b, c, d] = val.to_ne_bytes();
let store = u64::from_ne_bytes([a, b, c, d, 0, 0, 0, 0]);
State::from_u64(self.inner.fetch_sub(store, order))
}
}