Merge branch 'release/v0.4.2'

* release/v0.4.2: (31 commits)
  Bump version to 0.4.2
  Archive Cargo.lock
  Absolute path to cargo2junit
  Install cargo2junit in test build
  whoops
  Ah yes, why bother with correct documentation anyway?
  Move rustup/cargo install to only the jobs that need them
  Allow rustfmt failure until we fix capnp gen being fmt'ed
  ...
  okay I guess?
  rustup
  actually we don't need clippy for non-MR things
  okay gitlab, be that way
  and use stable goddamit
  okay make builds work better for merges
  Actually, only build if linting checks out. And make Gitlab CI work.
  Try to get the Gitlab CI to cooperate.
  Build test harnish as part of the `build` step
  Make docker containers only be built when necessary
  Correct gitlab-ci.yml
  ...
This commit is contained in:
Nadja Reitzenstein 2022-05-13 18:32:28 +02:00
commit bfde6c03dc
82 changed files with 1344 additions and 1446 deletions

4
.gitignore vendored
View File

@ -1,5 +1 @@
/target /target
**/*.rs.bk
tags
/.idea/

View File

@ -1,86 +1,152 @@
# Official language image. Look for the different tagged releases at: # Define slightly different stages.
# https://hub.docker.com/r/library/rust/tags/ # Additionally, lint the code before anything else to fail more quickly
image: "rust:latest" stages:
- lint
- build
- test
- release
- dockerify
# Optional: Pick zero or more services to be used on all builds. default:
# Only needed when using a docker container to run your tests in. image: "rust:latest"
# Check out: http://docs.gitlab.com/ce/ci/docker/using_docker_images.html#what-is-a-service tags:
# services: - linux
# - mysql:latest - docker
# - redis:latest
# - postgres:latest
variables: variables:
GIT_SUBMODULE_STRATEGY: recursive GIT_SUBMODULE_STRATEGY: recursive
# CARGO_HOME: $CI_PROJECT_DIR/cargo CARGO_HOME: $CI_PROJECT_DIR/cargo
APT_CACHE_DIR: $CI_PROJECT_DIR/apt APT_CACHE_DIR: $CI_PROJECT_DIR/apt
# cache dependencies and build environment to speed up setup
cache:
key: "$CI_COMMIT_REF_SLUG"
paths:
- apt/
- cargo/
- target/
# install build dependencies # install build dependencies
before_script: before_script:
- apt-get update -yqq - apt-get update -yqq
- apt-get install -o dir::cache::archives="$APT_CACHE_DIR" -yqq --no-install-recommends capnproto build-essential cmake clang libclang-dev libgsasl7-dev - apt-get install -o dir::cache::archives="$APT_CACHE_DIR" -yqq --no-install-recommends capnproto build-essential cmake clang libclang-dev jq
- rustup update
- rustup component add rustfmt
- rustup component add clippy
# Use clippy to lint the project .lints:
stage: lint
allow_failure: true
only:
- merge_requests
# Use clippy lints
lint:clippy: lint:clippy:
allow_failure: true extends: .lints
script: script:
- rustc --version && cargo --version # Print version info for debugging - rustup component add clippy
- cargo clippy -- -D warnings - cargo clippy -V
only: - echo -e "\e[0Ksection_start:`date +%s`:clippy_output\r\e[0Kcargo clippy output"
- master - cargo clippy -- --no-deps
- development - echo -e "\e[0Ksection_end:`date +%s`:clippy_output\r\e[0K"
- merge_requests
tags:
- linux
- docker
# Use rustfmt to check formating of the project # Use rustfmt to check formating
lint:fmt: lint:fmt:
allow_failure: true extends: .lints
script: script:
- rustc --version && cargo --version # Print version info for debugging - rustup component add rustfmt
- cargo fmt -- --check # TODO: Do we want to enforce formating? - cargo fmt --version
- echo -e "\e[0Ksection_start:`date +%s`:rustfmt_output\r\e[0KChanges suggested by rustfmt"
- cargo fmt --check -- -v
- echo -e "\e[0Ksection_end:`date +%s`:rustfmt_output\r\e[0K"
# Check if the code builds on rust stable
stable:build:
stage: build
only: only:
- master - main
- development - development
- merge_requests - merge_requests
tags:
- linux
- docker
# Use cargo to test the project
test:cargo:
script: script:
- rustc --version && cargo --version # Print version info for debugging - rustc +stable --version && cargo --version
- cargo test --workspace --verbose - echo -e "\e[0Ksection_start:`date +%s`:build_output\r\e[0KOutput of cargo check"
- cargo check --verbose
- echo -e "\e[0Ksection_end:`date +%s`:build_output\r\e[0K"
stable:test:
stage: build
needs: ["stable:build"]
only: only:
- master - main
- development - development
- merge_requests - merge_requests
tags:
- linux
- docker
build:docker-master:
image:
name: gcr.io/kaniko-project/executor:v1.6.0-debug
entrypoint: [""]
before_script:
- ''
script: script:
- mkdir -p /kaniko/.docker - echo -e "\e[0Ksection_start:`date +%s`:build_output\r\e[0KOutput of cargo test --no-run"
- echo "{\"auths\":{\"$CI_REGISTRY\":{\"username\":\"$CI_REGISTRY_USER\",\"password\":\"$CI_REGISTRY_PASSWORD\"}}}" > /kaniko/.docker/config.json - cargo test --verbose --no-run --workspace
- /kaniko/executor --force --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/Dockerfile --destination $CI_REGISTRY_IMAGE:latest - echo -e "\e[0Ksection_end:`date +%s`:build_output\r\e[0K"
- cargo install --root $CARGO_HOME cargo2junit
.tests:
stage: test
needs: ["stable:test"]
script:
- cargo test --workspace $TEST_TARGET -- -Z unstable-options --format json --report-time | $CARGO_HOME/bin/cargo2junit > report.xml
artifacts:
when: always
reports:
junit:
- report.xml
only: only:
- master - main
tags: - development
- linux - merge_requests
- docker
# Run unit tests
unit test 1:3:
variables:
TEST_TARGET: "--lib"
extends: .tests
unit test 2:3:
variables:
TEST_TARGET: "--bins"
extends: .tests
unit test 3:3:
variables:
TEST_TARGET: "--examples"
extends: .tests
release_prepare:
stage: release
rules:
- if: $CI_COMMIT_TAG =~ "release/.*"
when: never
- if: $CI_COMMIT_BRANCH == "main"
script:
- VERSION="cargo metadata --format-version 1 | jq -C '.packages | .[] | select(.name == "diflouroborane") | .version' -r"
- echo $VERSION > release.env
artifacts:
reports:
dotenv: release.env
release_job:
stage: release
needs:
- job: release_prepare
artifacts: true
image: registry.gitlab.com/gitlab-org/release-cli:latest
rules:
- if: $CI_COMMIT_TAG =~ "release/.*"
when: never
- if: $CI_COMMIT_BRANCH == "main"
script:
- echo "Creating GitLab release…"
release:
name: "BFFH $VERSION"
description: "GitLab CI auto-created release"
tag_name: "release/$VERSION"
build:docker-releases: build:docker-releases:
stage: dockerify
image: image:
name: gcr.io/kaniko-project/executor:v1.6.0-debug name: gcr.io/kaniko-project/executor:v1.6.0-debug
entrypoint: [""] entrypoint: [""]
@ -90,13 +156,12 @@ build:docker-releases:
- mkdir -p /kaniko/.docker - mkdir -p /kaniko/.docker
- echo "{\"auths\":{\"$CI_REGISTRY\":{\"username\":\"$CI_REGISTRY_USER\",\"password\":\"$CI_REGISTRY_PASSWORD\"}}}" > /kaniko/.docker/config.json - echo "{\"auths\":{\"$CI_REGISTRY\":{\"username\":\"$CI_REGISTRY_USER\",\"password\":\"$CI_REGISTRY_PASSWORD\"}}}" > /kaniko/.docker/config.json
- /kaniko/executor --force --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/Dockerfile --destination $CI_REGISTRY_IMAGE:$CI_COMMIT_TAG - /kaniko/executor --force --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/Dockerfile --destination $CI_REGISTRY_IMAGE:$CI_COMMIT_TAG
only: rules:
- tags - if: $CI_COMMIT_TAG =~ "release/.*"
tags: when: never
- linux
- docker
build:docker-development: build:docker-development:
stage: dockerify
image: image:
name: gcr.io/kaniko-project/executor:v1.6.0-debug name: gcr.io/kaniko-project/executor:v1.6.0-debug
entrypoint: [""] entrypoint: [""]
@ -108,14 +173,3 @@ build:docker-development:
- /kaniko/executor --force --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/Dockerfile --destination $CI_REGISTRY_IMAGE:dev-latest - /kaniko/executor --force --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/Dockerfile --destination $CI_REGISTRY_IMAGE:dev-latest
only: only:
- development - development
tags:
- linux
- docker
# cache dependencies and build environment to speed up setup
cache:
key: "$CI_COMMIT_REF_SLUG"
paths:
- apt/
# - cargo/
- target/

2
.gitmodules vendored
View File

@ -1,4 +1,4 @@
[submodule "schema"] [submodule "schema"]
path = api/schema path = api/schema
url = https://gitlab.com/fabinfra/fabaccess/fabaccess-api url = ../fabaccess-api
branch = main branch = main

9
Cargo.lock generated
View File

@ -775,7 +775,7 @@ dependencies = [
[[package]] [[package]]
name = "diflouroborane" name = "diflouroborane"
version = "0.4.1" version = "0.4.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"api", "api",
@ -881,13 +881,6 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
[[package]]
name = "dummy"
version = "0.1.0"
dependencies = [
"sdk",
]
[[package]] [[package]]
name = "either" name = "either"
version = "1.6.1" version = "1.6.1"

View File

@ -1,6 +1,6 @@
[package] [package]
name = "diflouroborane" name = "diflouroborane"
version = "0.4.1" version = "0.4.2"
authors = [ "dequbed <me@dequbed.space>" authors = [ "dequbed <me@dequbed.space>"
, "Kai Jan Kriegel <kai@kjkriegel.de>" , "Kai Jan Kriegel <kai@kjkriegel.de>"
, "Joseph Langosch <thejoklla@gmail.com>" , "Joseph Langosch <thejoklla@gmail.com>"
@ -19,10 +19,14 @@ lto = "thin"
[lib] [lib]
path = "bffhd/lib.rs" path = "bffhd/lib.rs"
# Don't run unit tests on `cargo test --tests`, only run integration tests.
test = false
[[bin]] [[bin]]
name = "bffhd" name = "bffhd"
path = "bin/bffhd/main.rs" path = "bin/bffhd/main.rs"
# Don't run unit tests on `cargo test --tests`, only run integration tests.
test = false
[dependencies] [dependencies]
libc = "0.2.101" libc = "0.2.101"

View File

@ -1,7 +1,8 @@
use walkdir::{WalkDir, DirEntry}; use walkdir::{DirEntry, WalkDir};
fn is_hidden(entry: &DirEntry) -> bool { fn is_hidden(entry: &DirEntry) -> bool {
entry.file_name() entry
.file_name()
.to_str() .to_str()
.map(|s| s.starts_with('.')) .map(|s| s.starts_with('.'))
.unwrap_or(false) .unwrap_or(false)
@ -22,15 +23,15 @@ fn main() {
.filter_map(Result::ok) // Filter all entries that access failed on .filter_map(Result::ok) // Filter all entries that access failed on
.filter(|e| !e.file_type().is_dir()) // Filter directories .filter(|e| !e.file_type().is_dir()) // Filter directories
// Filter non-schema files // Filter non-schema files
.filter(|e| e.file_name() .filter(|e| {
e.file_name()
.to_str() .to_str()
.map(|s| s.ends_with(".capnp")) .map(|s| s.ends_with(".capnp"))
.unwrap_or(false) .unwrap_or(false)
) })
{ {
println!("Collecting schema file {}", entry.path().display()); println!("Collecting schema file {}", entry.path().display());
compile_command compile_command.file(entry.path());
.file(entry.path());
} }
println!("Compiling schemas..."); println!("Compiling schemas...");
@ -53,16 +54,18 @@ fn main() {
.filter_map(Result::ok) // Filter all entries that access failed on .filter_map(Result::ok) // Filter all entries that access failed on
.filter(|e| !e.file_type().is_dir()) // Filter directories .filter(|e| !e.file_type().is_dir()) // Filter directories
// Filter non-schema files // Filter non-schema files
.filter(|e| e.file_name() .filter(|e| {
e.file_name()
.to_str() .to_str()
.map(|s| s.ends_with(".capnp")) .map(|s| s.ends_with(".capnp"))
.unwrap_or(false) .unwrap_or(false)
) })
{ {
println!("Collecting schema file {}", entry.path().display()); println!("Collecting schema file {}", entry.path().display());
compile_command compile_command.file(entry.path());
.file(entry.path());
} }
compile_command.run().expect("Failed to generate extra API code"); compile_command
.run()
.expect("Failed to generate extra API code");
} }

View File

@ -1,4 +1,3 @@
//! FabAccess generated API bindings //! FabAccess generated API bindings
//! //!
//! This crate contains slightly nicer and better documented bindings for the FabAccess API. //! This crate contains slightly nicer and better documented bindings for the FabAccess API.

View File

@ -1,6 +1,5 @@
pub use capnpc::schema_capnp; pub use capnpc::schema_capnp;
#[cfg(feature = "generated")] #[cfg(feature = "generated")]
pub mod authenticationsystem_capnp { pub mod authenticationsystem_capnp {
include!(concat!(env!("OUT_DIR"), "/authenticationsystem_capnp.rs")); include!(concat!(env!("OUT_DIR"), "/authenticationsystem_capnp.rs"));

View File

@ -1,6 +1,6 @@
use std::collections::HashMap;
use futures_util::future; use futures_util::future;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use std::collections::HashMap;
use crate::actors::Actor; use crate::actors::Actor;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;

View File

@ -3,7 +3,7 @@ use crate::resources::state::State;
use crate::{Config, ResourcesHandle}; use crate::{Config, ResourcesHandle};
use async_compat::CompatExt; use async_compat::CompatExt;
use executor::pool::Executor; use executor::pool::Executor;
use futures_signals::signal::{Signal}; use futures_signals::signal::Signal;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use rumqttc::{AsyncClient, ConnectionError, Event, Incoming, MqttOptions}; use rumqttc::{AsyncClient, ConnectionError, Event, Incoming, MqttOptions};
@ -18,15 +18,15 @@ use std::time::Duration;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rumqttc::ConnectReturnCode::Success; use rumqttc::ConnectReturnCode::Success;
use rustls::{RootCertStore};
use url::Url;
use crate::actors::dummy::Dummy; use crate::actors::dummy::Dummy;
use crate::actors::process::Process; use crate::actors::process::Process;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;
use rustls::RootCertStore;
use url::Url;
mod shelly;
mod process;
mod dummy; mod dummy;
mod process;
mod shelly;
pub trait Actor { pub trait Actor {
fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()>; fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()>;
@ -102,7 +102,7 @@ static ROOT_CERTS: Lazy<RootCertStore> = Lazy::new(|| {
} else { } else {
tracing::info!(loaded, "certificates loaded"); tracing::info!(loaded, "certificates loaded");
} }
}, }
Err(error) => { Err(error) => {
tracing::error!(%error, "failed to load system certificates"); tracing::error!(%error, "failed to load system certificates");
} }
@ -219,8 +219,10 @@ pub fn load(executor: Executor, config: &Config, resources: ResourcesHandle) ->
.compat(), .compat(),
); );
let mut actor_map: HashMap<String, _> = config.actor_connections.iter() let mut actor_map: HashMap<String, _> = config
.filter_map(|(k,v)| { .actor_connections
.iter()
.filter_map(|(k, v)| {
if let Some(resource) = resources.get_by_id(v) { if let Some(resource) = resources.get_by_id(v) {
Some((k.clone(), resource.get_signal())) Some((k.clone(), resource.get_signal()))
} else { } else {
@ -258,8 +260,6 @@ fn load_single(
"Dummy" => Some(Box::new(Dummy::new(name.clone(), params.clone()))), "Dummy" => Some(Box::new(Dummy::new(name.clone(), params.clone()))),
"Process" => Process::new(name.clone(), params).map(|a| a.into_boxed_actuator()), "Process" => Process::new(name.clone(), params).map(|a| a.into_boxed_actuator()),
"Shelly" => Some(Box::new(Shelly::new(name.clone(), client, params))), "Shelly" => Some(Box::new(Shelly::new(name.clone(), client, params))),
_ => { _ => None,
None
}
} }
} }

View File

@ -1,6 +1,6 @@
use futures_util::future::BoxFuture;
use std::collections::HashMap; use std::collections::HashMap;
use std::process::{Command, Stdio}; use std::process::{Command, Stdio};
use futures_util::future::BoxFuture;
use crate::actors::Actor; use crate::actors::Actor;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;
@ -16,10 +16,9 @@ pub struct Process {
impl Process { impl Process {
pub fn new(name: String, params: &HashMap<String, String>) -> Option<Self> { pub fn new(name: String, params: &HashMap<String, String>) -> Option<Self> {
let cmd = params.get("cmd").map(|s| s.to_string())?; let cmd = params.get("cmd").map(|s| s.to_string())?;
let args = params.get("args").map(|argv| let args = params
argv.split_whitespace() .get("args")
.map(|s| s.to_string()) .map(|argv| argv.split_whitespace().map(|s| s.to_string()).collect())
.collect())
.unwrap_or_else(Vec::new); .unwrap_or_else(Vec::new);
Some(Self { name, cmd, args }) Some(Self { name, cmd, args })
@ -48,22 +47,22 @@ impl Actor for Process {
command.arg("inuse").arg(by.id.as_str()); command.arg("inuse").arg(by.id.as_str());
} }
ArchivedStatus::ToCheck(by) => { ArchivedStatus::ToCheck(by) => {
command.arg("tocheck") command.arg("tocheck").arg(by.id.as_str());
.arg(by.id.as_str());
} }
ArchivedStatus::Blocked(by) => { ArchivedStatus::Blocked(by) => {
command.arg("blocked") command.arg("blocked").arg(by.id.as_str());
.arg(by.id.as_str()); }
ArchivedStatus::Disabled => {
command.arg("disabled");
} }
ArchivedStatus::Disabled => { command.arg("disabled"); },
ArchivedStatus::Reserved(by) => { ArchivedStatus::Reserved(by) => {
command.arg("reserved") command.arg("reserved").arg(by.id.as_str());
.arg(by.id.as_str());
} }
} }
let name = self.name.clone(); let name = self.name.clone();
Box::pin(async move { match command.output() { Box::pin(async move {
match command.output() {
Ok(retv) if retv.status.success() => { Ok(retv) if retv.status.success() => {
tracing::trace!("Actor was successful"); tracing::trace!("Actor was successful");
let outstr = String::from_utf8_lossy(&retv.stdout); let outstr = String::from_utf8_lossy(&retv.stdout);
@ -83,6 +82,7 @@ impl Actor for Process {
} }
} }
Err(error) => tracing::warn!(%name, ?error, "process actor failed to run cmd"), Err(error) => tracing::warn!(%name, ?error, "process actor failed to run cmd"),
}}) }
})
} }
} }

View File

@ -1,11 +1,11 @@
use std::collections::HashMap;
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use std::collections::HashMap;
use rumqttc::{AsyncClient, QoS};
use crate::actors::Actor; use crate::actors::Actor;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;
use crate::resources::modules::fabaccess::ArchivedStatus; use crate::resources::modules::fabaccess::ArchivedStatus;
use crate::resources::state::State; use crate::resources::state::State;
use rumqttc::{AsyncClient, QoS};
/// An actuator for a Shellie connected listening on one MQTT broker /// An actuator for a Shellie connected listening on one MQTT broker
/// ///
@ -28,7 +28,11 @@ impl Shelly {
tracing::debug!(%name,%topic,"Starting shelly module"); tracing::debug!(%name,%topic,"Starting shelly module");
Shelly { name, client, topic, } Shelly {
name,
client,
topic,
}
} }
/// Set the name to a new one. This changes the shelly that will be activated /// Set the name to a new one. This changes the shelly that will be activated
@ -38,7 +42,6 @@ impl Shelly {
} }
} }
impl Actor for Shelly { impl Actor for Shelly {
fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()> { fn apply(&mut self, state: ArchivedValue<State>) -> BoxFuture<'static, ()> {
tracing::debug!(?state, name=%self.name, tracing::debug!(?state, name=%self.name,

View File

@ -1,11 +1,11 @@
use once_cell::sync::OnceCell;
use std::fs::{File, OpenOptions}; use std::fs::{File, OpenOptions};
use std::io; use std::io;
use std::io::{LineWriter, Write}; use std::io::{LineWriter, Write};
use std::sync::Mutex; use std::sync::Mutex;
use once_cell::sync::OnceCell;
use crate::Config; use crate::Config;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use serde_json::Serializer; use serde_json::Serializer;
pub static AUDIT: OnceCell<AuditLog> = OnceCell::new(); pub static AUDIT: OnceCell<AuditLog> = OnceCell::new();
@ -26,7 +26,10 @@ impl AuditLog {
pub fn new(config: &Config) -> io::Result<&'static Self> { pub fn new(config: &Config) -> io::Result<&'static Self> {
AUDIT.get_or_try_init(|| { AUDIT.get_or_try_init(|| {
tracing::debug!(path = %config.auditlog_path.display(), "Initializing audit log"); tracing::debug!(path = %config.auditlog_path.display(), "Initializing audit log");
let fd = OpenOptions::new().create(true).append(true).open(&config.auditlog_path)?; let fd = OpenOptions::new()
.create(true)
.append(true)
.open(&config.auditlog_path)?;
let writer = Mutex::new(LineWriter::new(fd)); let writer = Mutex::new(LineWriter::new(fd));
Ok(Self { writer }) Ok(Self { writer })
}) })
@ -34,7 +37,11 @@ impl AuditLog {
pub fn log(&self, machine: &str, state: &str) -> io::Result<()> { pub fn log(&self, machine: &str, state: &str) -> io::Result<()> {
let timestamp = chrono::Utc::now().timestamp(); let timestamp = chrono::Utc::now().timestamp();
let line = AuditLogLine { timestamp, machine, state }; let line = AuditLogLine {
timestamp,
machine,
state,
};
tracing::debug!(?line, "writing audit log line"); tracing::debug!(?line, "writing audit log line");
@ -42,7 +49,8 @@ impl AuditLog {
let mut writer: &mut LineWriter<File> = &mut *guard; let mut writer: &mut LineWriter<File> = &mut *guard;
let mut ser = Serializer::new(&mut writer); let mut ser = Serializer::new(&mut writer);
line.serialize(&mut ser).expect("failed to serialize audit log line"); line.serialize(&mut ser)
.expect("failed to serialize audit log line");
writer.write("\n".as_bytes())?; writer.write("\n".as_bytes())?;
Ok(()) Ok(())
} }

View File

@ -19,8 +19,8 @@ pub static FABFIRE: Mechanism = Mechanism {
first: Side::Client, first: Side::Client,
}; };
use rsasl::property::{Property, PropertyDefinition, PropertyQ};
use std::marker::PhantomData; use std::marker::PhantomData;
use rsasl::property::{Property, PropertyQ, PropertyDefinition};
// All Property types must implement Debug. // All Property types must implement Debug.
#[derive(Debug)] #[derive(Debug)]
// The `PhantomData` in the constructor is only used so external crates can't construct this type. // The `PhantomData` in the constructor is only used so external crates can't construct this type.

View File

@ -1,17 +1,17 @@
use std::fmt::{Debug, Display, Formatter}; use desfire::desfire::desfire::MAX_BYTES_PER_TRANSACTION;
use std::io::Write; use desfire::desfire::Desfire;
use desfire::error::Error as DesfireError;
use desfire::iso7816_4::apduresponse::APDUResponse;
use rsasl::error::{MechanismError, MechanismErrorKind, SASLError, SessionError}; use rsasl::error::{MechanismError, MechanismErrorKind, SASLError, SessionError};
use rsasl::mechanism::Authentication; use rsasl::mechanism::Authentication;
use rsasl::SASL;
use rsasl::session::{SessionData, StepResult};
use serde::{Deserialize, Serialize};
use desfire::desfire::Desfire;
use desfire::iso7816_4::apduresponse::APDUResponse;
use desfire::error::{Error as DesfireError};
use std::convert::TryFrom;
use std::sync::Arc;
use desfire::desfire::desfire::MAX_BYTES_PER_TRANSACTION;
use rsasl::property::AuthId; use rsasl::property::AuthId;
use rsasl::session::{SessionData, StepResult};
use rsasl::SASL;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::fmt::{Debug, Display, Formatter};
use std::io::Write;
use std::sync::Arc;
use crate::authentication::fabfire::FabFireCardKey; use crate::authentication::fabfire::FabFireCardKey;
@ -37,7 +37,9 @@ impl Debug for FabFireError {
FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic), FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic),
FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token), FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token),
FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn), FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn),
FabFireError::InvalidCredentials(credentials) => write!(f, "InvalidCredentials: {}", credentials), FabFireError::InvalidCredentials(credentials) => {
write!(f, "InvalidCredentials: {}", credentials)
}
FabFireError::Session(err) => write!(f, "Session: {}", err), FabFireError::Session(err) => write!(f, "Session: {}", err),
} }
} }
@ -53,7 +55,9 @@ impl Display for FabFireError {
FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic), FabFireError::InvalidMagic(magic) => write!(f, "InvalidMagic: {}", magic),
FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token), FabFireError::InvalidToken(token) => write!(f, "InvalidToken: {}", token),
FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn), FabFireError::InvalidURN(urn) => write!(f, "InvalidURN: {}", urn),
FabFireError::InvalidCredentials(credentials) => write!(f, "InvalidCredentials: {}", credentials), FabFireError::InvalidCredentials(credentials) => {
write!(f, "InvalidCredentials: {}", credentials)
}
FabFireError::Session(err) => write!(f, "Session: {}", err), FabFireError::Session(err) => write!(f, "Session: {}", err),
} }
} }
@ -107,16 +111,22 @@ enum CardCommand {
addn_txt: Option<String>, addn_txt: Option<String>,
}, },
sendPICC { sendPICC {
#[serde(deserialize_with = "hex::deserialize", serialize_with = "hex::serialize_upper")] #[serde(
data: Vec<u8> deserialize_with = "hex::deserialize",
serialize_with = "hex::serialize_upper"
)]
data: Vec<u8>,
}, },
readPICC { readPICC {
#[serde(deserialize_with = "hex::deserialize", serialize_with = "hex::serialize_upper")] #[serde(
data: Vec<u8> deserialize_with = "hex::deserialize",
serialize_with = "hex::serialize_upper"
)]
data: Vec<u8>,
}, },
haltPICC, haltPICC,
Key { Key {
data: String data: String,
}, },
ConfirmUser, ConfirmUser,
} }
@ -145,18 +155,35 @@ const MAGIC: &'static str = "FABACCESS\0DESFIRE\01.0\0";
impl FabFire { impl FabFire {
pub fn new_server(_sasl: &SASL) -> Result<Box<dyn Authentication>, SASLError> { pub fn new_server(_sasl: &SASL) -> Result<Box<dyn Authentication>, SASLError> {
Ok(Box::new(Self { step: Step::New, card_info: None, key_info: None, auth_info: None, app_id: 1, local_urn: "urn:fabaccess:lab:innovisionlab".to_string(), desfire: Desfire { card: None, session_key: None, cbc_iv: None } })) Ok(Box::new(Self {
step: Step::New,
card_info: None,
key_info: None,
auth_info: None,
app_id: 1,
local_urn: "urn:fabaccess:lab:innovisionlab".to_string(),
desfire: Desfire {
card: None,
session_key: None,
cbc_iv: None,
},
}))
} }
} }
impl Authentication for FabFire { impl Authentication for FabFire {
fn step(&mut self, session: &mut SessionData, input: Option<&[u8]>, writer: &mut dyn Write) -> StepResult { fn step(
&mut self,
session: &mut SessionData,
input: Option<&[u8]>,
writer: &mut dyn Write,
) -> StepResult {
match self.step { match self.step {
Step::New => { Step::New => {
tracing::trace!("Step: New"); tracing::trace!("Step: New");
//receive card info (especially card UID) from reader //receive card info (especially card UID) from reader
return match input { return match input {
None => { Err(SessionError::InputDataRequired) } None => Err(SessionError::InputDataRequired),
Some(cardinfo) => { Some(cardinfo) => {
self.card_info = match serde_json::from_slice(cardinfo) { self.card_info = match serde_json::from_slice(cardinfo) {
Ok(card_info) => Some(card_info), Ok(card_info) => Some(card_info),
@ -170,7 +197,10 @@ impl Authentication for FabFire {
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
tracing::error!("Failed to convert APDUCommand to Vec<u8>: {:?}", e); tracing::error!(
"Failed to convert APDUCommand to Vec<u8>: {:?}",
e
);
return Err(FabFireError::SerializationError.into()); return Err(FabFireError::SerializationError.into());
} }
}, },
@ -183,7 +213,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::SelectApp; self.step = Step::SelectApp;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -198,30 +230,39 @@ impl Authentication for FabFire {
tracing::trace!("Step: SelectApp"); tracing::trace!("Step: SelectApp");
// check that we successfully selected the application // check that we successfully selected the application
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e); tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
apdu_response.check().map_err(|e| FabFireError::CardError(e))?; apdu_response
.check()
.map_err(|e| FabFireError::CardError(e))?;
// request the contents of the file containing the magic string // request the contents of the file containing the magic string
const MAGIC_FILE_ID: u8 = 0x01; const MAGIC_FILE_ID: u8 = 0x01;
let buf = match self.desfire.read_data_chunk_cmd(MAGIC_FILE_ID, 0, MAGIC.len()) { let buf = match self
.desfire
.read_data_chunk_cmd(MAGIC_FILE_ID, 0, MAGIC.len())
{
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -238,7 +279,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::VerifyMagic; self.step = Step::VerifyMagic;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -251,25 +294,28 @@ impl Authentication for FabFire {
tracing::trace!("Step: VerifyMagic"); tracing::trace!("Step: VerifyMagic");
// verify the magic string to determine that we have a valid fabfire card // verify the magic string to determine that we have a valid fabfire card
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e); tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
@ -291,11 +337,15 @@ impl Authentication for FabFire {
} }
} }
// request the contents of the file containing the URN // request the contents of the file containing the URN
const URN_FILE_ID: u8 = 0x02; const URN_FILE_ID: u8 = 0x02;
let buf = match self.desfire.read_data_chunk_cmd(URN_FILE_ID, 0, self.local_urn.as_bytes().len()) { // TODO: support urn longer than 47 Bytes let buf = match self.desfire.read_data_chunk_cmd(
URN_FILE_ID,
0,
self.local_urn.as_bytes().len(),
) {
// TODO: support urn longer than 47 Bytes
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -312,7 +362,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::GetURN; self.step = Step::GetURN;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -325,32 +377,39 @@ impl Authentication for FabFire {
tracing::trace!("Step: GetURN"); tracing::trace!("Step: GetURN");
// parse the urn and match it to our local urn // parse the urn and match it to our local urn
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e); tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
Some(data) => { Some(data) => {
let received_urn = String::from_utf8(data).unwrap(); let received_urn = String::from_utf8(data).unwrap();
if received_urn != self.local_urn { if received_urn != self.local_urn {
tracing::error!("URN mismatch: {:?} != {:?}", received_urn, self.local_urn); tracing::error!(
"URN mismatch: {:?} != {:?}",
received_urn,
self.local_urn
);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
} }
@ -368,7 +427,12 @@ impl Authentication for FabFire {
// request the contents of the file containing the URN // request the contents of the file containing the URN
const TOKEN_FILE_ID: u8 = 0x03; const TOKEN_FILE_ID: u8 = 0x03;
let buf = match self.desfire.read_data_chunk_cmd(TOKEN_FILE_ID, 0, MAX_BYTES_PER_TRANSACTION) { // TODO: support data longer than 47 Bytes let buf = match self.desfire.read_data_chunk_cmd(
TOKEN_FILE_ID,
0,
MAX_BYTES_PER_TRANSACTION,
) {
// TODO: support data longer than 47 Bytes
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -385,7 +449,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::GetToken; self.step = Step::GetToken;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -398,43 +464,52 @@ impl Authentication for FabFire {
// println!("Step: GetToken"); // println!("Step: GetToken");
// parse the token and select the appropriate user // parse the token and select the appropriate user
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Deserializing data from card failed: {:?}", e); tracing::error!("Deserializing data from card failed: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
Some(data) => { Some(data) => {
let token = String::from_utf8(data).unwrap(); let token = String::from_utf8(data).unwrap();
session.set_property::<AuthId>(Arc::new(token.trim_matches(char::from(0)).to_string())); session.set_property::<AuthId>(Arc::new(
let key = match session.get_property_or_callback::<FabFireCardKey>() { token.trim_matches(char::from(0)).to_string(),
));
let key = match session.get_property_or_callback::<FabFireCardKey>()
{
Ok(Some(key)) => Box::from(key.as_slice()), Ok(Some(key)) => Box::from(key.as_slice()),
Ok(None) => { Ok(None) => {
tracing::error!("No keys on file for token"); tracing::error!("No keys on file for token");
return Err(FabFireError::InvalidCredentials("No keys on file for token".to_string()).into()); return Err(FabFireError::InvalidCredentials(
"No keys on file for token".to_string(),
)
.into());
} }
Err(e) => { Err(e) => {
tracing::error!("Failed to get key: {:?}", e); tracing::error!("Failed to get key: {:?}", e);
return Err(FabFireError::Session(e).into()); return Err(FabFireError::Session(e).into());
} }
}; };
self.key_info = Some(KeyInfo{ key_id: 0x01, key }); self.key_info = Some(KeyInfo { key_id: 0x01, key });
} }
None => { None => {
tracing::error!("No data in response"); tracing::error!("No data in response");
@ -448,7 +523,10 @@ impl Authentication for FabFire {
} }
} }
let buf = match self.desfire.authenticate_iso_aes_challenge_cmd(self.key_info.as_ref().unwrap().key_id) { let buf = match self
.desfire
.authenticate_iso_aes_challenge_cmd(self.key_info.as_ref().unwrap().key_id)
{
Ok(buf) => match Vec::<u8>::try_from(buf) { Ok(buf) => match Vec::<u8>::try_from(buf) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -465,7 +543,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::Authenticate1; self.step = Step::Authenticate1;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -477,25 +557,28 @@ impl Authentication for FabFire {
Step::Authenticate1 => { Step::Authenticate1 => {
tracing::trace!("Step: Authenticate1"); tracing::trace!("Step: Authenticate1");
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Failed to deserialize response: {:?}", e); tracing::error!("Failed to deserialize response: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Unexpected response: {:?}", response); tracing::error!("Unexpected response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
@ -506,13 +589,19 @@ impl Authentication for FabFire {
//TODO: Check if we need a CSPRNG here //TODO: Check if we need a CSPRNG here
let rnd_a: [u8; 16] = rand::random(); let rnd_a: [u8; 16] = rand::random();
let (cmd_challenge_response, let (cmd_challenge_response, rnd_b, iv) = self
rnd_b, .desfire
iv) = self.desfire.authenticate_iso_aes_response_cmd( .authenticate_iso_aes_response_cmd(
rnd_b_enc, rnd_b_enc,
&*(self.key_info.as_ref().unwrap().key), &*(self.key_info.as_ref().unwrap().key),
&rnd_a).unwrap(); &rnd_a,
self.auth_info = Some(AuthInfo { rnd_a: Vec::<u8>::from(rnd_a), rnd_b, iv }); )
.unwrap();
self.auth_info = Some(AuthInfo {
rnd_a: Vec::<u8>::from(rnd_a),
rnd_b,
iv,
});
let buf = match Vec::<u8>::try_from(cmd_challenge_response) { let buf = match Vec::<u8>::try_from(cmd_challenge_response) {
Ok(data) => data, Ok(data) => data,
Err(e) => { Err(e) => {
@ -524,7 +613,9 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::Authenticate2; self.step = Step::Authenticate2;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
.write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len()))) Ok(rsasl::session::Step::NeedsMore(Some(send_buf.len())))
} }
Err(e) => { Err(e) => {
@ -548,39 +639,48 @@ impl Authentication for FabFire {
Step::Authenticate2 => { Step::Authenticate2 => {
// println!("Step: Authenticate2"); // println!("Step: Authenticate2");
let response: CardCommand = match input { let response: CardCommand = match input {
None => { return Err(SessionError::InputDataRequired); } None => {
Some(buf) => match serde_json::from_slice(buf).map_err(|e| FabFireError::DeserializationError(e)) { return Err(SessionError::InputDataRequired);
}
Some(buf) => match serde_json::from_slice(buf)
.map_err(|e| FabFireError::DeserializationError(e))
{
Ok(response) => response, Ok(response) => response,
Err(e) => { Err(e) => {
tracing::error!("Failed to deserialize response: {:?}", e); tracing::error!("Failed to deserialize response: {:?}", e);
return Err(e.into()); return Err(e.into());
} }
} },
}; };
let apdu_response = match response { let apdu_response = match response {
CardCommand::readPICC { data } => { APDUResponse::new(&*data) } CardCommand::readPICC { data } => APDUResponse::new(&*data),
_ => { _ => {
tracing::error!("Got invalid response: {:?}", response); tracing::error!("Got invalid response: {:?}", response);
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
} }
}; };
match apdu_response.check() { match apdu_response.check() {
Ok(_) => { Ok(_) => {
match apdu_response.body { match apdu_response.body {
Some(data) => { Some(data) => match self.auth_info.as_ref() {
match self.auth_info.as_ref() { None => {
None => { return Err(FabFireError::ParseError.into()); } return Err(FabFireError::ParseError.into());
}
Some(auth_info) => { Some(auth_info) => {
if self.desfire.authenticate_iso_aes_verify( if self
.desfire
.authenticate_iso_aes_verify(
data.as_slice(), data.as_slice(),
auth_info.rnd_a.as_slice(), auth_info.rnd_a.as_slice(),
auth_info.rnd_b.as_slice(), &*(self.key_info.as_ref().unwrap().key), auth_info.rnd_b.as_slice(),
auth_info.iv.as_slice()).is_ok() { &*(self.key_info.as_ref().unwrap().key),
auth_info.iv.as_slice(),
let cmd = CardCommand::message{ )
.is_ok()
{
let cmd = CardCommand::message {
msg_id: Some(4), msg_id: Some(4),
clr_txt: None, clr_txt: None,
addn_txt: Some("".to_string()), addn_txt: Some("".to_string()),
@ -588,18 +688,24 @@ impl Authentication for FabFire {
return match serde_json::to_vec(&cmd) { return match serde_json::to_vec(&cmd) {
Ok(send_buf) => { Ok(send_buf) => {
self.step = Step::Authenticate1; self.step = Step::Authenticate1;
writer.write_all(&send_buf).map_err(|e| SessionError::Io { source: e })?; writer
return Ok(rsasl::session::Step::Done(Some(send_buf.len()))) .write_all(&send_buf)
.map_err(|e| SessionError::Io { source: e })?;
return Ok(rsasl::session::Step::Done(Some(
send_buf.len(),
)));
} }
Err(e) => { Err(e) => {
tracing::error!("Failed to serialize command: {:?}", e); tracing::error!(
"Failed to serialize command: {:?}",
e
);
Err(FabFireError::SerializationError.into()) Err(FabFireError::SerializationError.into())
} }
}; };
} }
} }
} },
}
None => { None => {
tracing::error!("got empty response"); tracing::error!("got empty response");
return Err(FabFireError::ParseError.into()); return Err(FabFireError::ParseError.into());
@ -608,7 +714,9 @@ impl Authentication for FabFire {
} }
Err(_e) => { Err(_e) => {
tracing::error!("Got invalid response: {:?}", apdu_response); tracing::error!("Got invalid response: {:?}", apdu_response);
return Err(FabFireError::InvalidCredentials(format!("{}", apdu_response)).into()); return Err(
FabFireError::InvalidCredentials(format!("{}", apdu_response)).into(),
);
} }
} }
} }

View File

@ -1,6 +1,5 @@
use crate::users::Users; use crate::users::Users;
use rsasl::error::{SessionError}; use rsasl::error::SessionError;
use rsasl::mechname::Mechname; use rsasl::mechname::Mechname;
use rsasl::property::{AuthId, Password}; use rsasl::property::{AuthId, Password};
use rsasl::session::{Session, SessionData}; use rsasl::session::{Session, SessionData};
@ -31,12 +30,18 @@ impl rsasl::callback::Callback for Callback {
match property { match property {
fabfire::FABFIRECARDKEY => { fabfire::FABFIRECARDKEY => {
let authcid = session.get_property_or_callback::<AuthId>()?; let authcid = session.get_property_or_callback::<AuthId>()?;
let user = self.users.get_user(authcid.unwrap().as_ref()) let user = self
.users
.get_user(authcid.unwrap().as_ref())
.ok_or(SessionError::AuthenticationFailure)?; .ok_or(SessionError::AuthenticationFailure)?;
let kv = user.userdata.kv.get("cardkey") let kv = user
.userdata
.kv
.get("cardkey")
.ok_or(SessionError::AuthenticationFailure)?; .ok_or(SessionError::AuthenticationFailure)?;
let card_key = <[u8; 16]>::try_from(hex::decode(kv) let card_key = <[u8; 16]>::try_from(
.map_err(|_| SessionError::AuthenticationFailure)?) hex::decode(kv).map_err(|_| SessionError::AuthenticationFailure)?,
)
.map_err(|_| SessionError::AuthenticationFailure)?; .map_err(|_| SessionError::AuthenticationFailure)?;
session.set_property::<FabFireCardKey>(Arc::new(card_key)); session.set_property::<FabFireCardKey>(Arc::new(card_key));
Ok(()) Ok(())
@ -60,9 +65,7 @@ impl rsasl::callback::Callback for Callback {
.ok_or(SessionError::no_property::<AuthId>())?; .ok_or(SessionError::no_property::<AuthId>())?;
tracing::debug!(authid=%authnid, "SIMPLE validation requested"); tracing::debug!(authid=%authnid, "SIMPLE validation requested");
if let Some(user) = self if let Some(user) = self.users.get_user(authnid.as_str()) {
.users
.get_user(authnid.as_str()) {
let passwd = session let passwd = session
.get_property::<Password>() .get_property::<Password>()
.ok_or(SessionError::no_property::<Password>())?; .ok_or(SessionError::no_property::<Password>())?;
@ -84,7 +87,7 @@ impl rsasl::callback::Callback for Callback {
_ => { _ => {
tracing::error!(?validation, "Unimplemented validation requested"); tracing::error!(?validation, "Unimplemented validation requested");
Err(SessionError::no_validate(validation)) Err(SessionError::no_validate(validation))
}, }
} }
} }
} }
@ -111,10 +114,12 @@ impl AuthenticationHandle {
let mut rsasl = SASL::new(); let mut rsasl = SASL::new();
rsasl.install_callback(Arc::new(Callback::new(userdb))); rsasl.install_callback(Arc::new(Callback::new(userdb)));
let mechs: Vec<&'static str> = rsasl.server_mech_list().into_iter() let mechs: Vec<&'static str> = rsasl
.server_mech_list()
.into_iter()
.map(|m| m.mechanism.as_str()) .map(|m| m.mechanism.as_str())
.collect(); .collect();
tracing::info!(available_mechs=mechs.len(), "initialized sasl backend"); tracing::info!(available_mechs = mechs.len(), "initialized sasl backend");
tracing::debug!(?mechs, "available mechs"); tracing::debug!(?mechs, "available mechs");
Self { Self {

View File

@ -1,9 +1,6 @@
use crate::authorization::roles::Roles;
use crate::authorization::roles::{Roles};
use crate::Users; use crate::Users;
pub mod permissions; pub mod permissions;
pub mod roles; pub mod roles;

View File

@ -1,10 +1,9 @@
//! Access control logic //! Access control logic
//! //!
use std::fmt;
use std::cmp::Ordering; use std::cmp::Ordering;
use std::convert::{TryFrom, Into}; use std::convert::{Into, TryFrom};
use std::fmt;
fn is_sep_char(c: char) -> bool { fn is_sep_char(c: char) -> bool {
c == '.' c == '.'
@ -20,7 +19,7 @@ pub struct PrivilegesBuf {
/// Which permission is required to write parts of this thing /// Which permission is required to write parts of this thing
pub write: PermissionBuf, pub write: PermissionBuf,
/// Which permission is required to manage all parts of this thing /// Which permission is required to manage all parts of this thing
pub manage: PermissionBuf pub manage: PermissionBuf,
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] #[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
@ -39,13 +38,17 @@ impl PermissionBuf {
#[inline(always)] #[inline(always)]
/// Allocate an empty `PermissionBuf` /// Allocate an empty `PermissionBuf`
pub fn new() -> Self { pub fn new() -> Self {
PermissionBuf { inner: String::new() } PermissionBuf {
inner: String::new(),
}
} }
#[inline(always)] #[inline(always)]
/// Allocate a `PermissionBuf` with the given capacity given to the internal [`String`] /// Allocate a `PermissionBuf` with the given capacity given to the internal [`String`]
pub fn with_capacity(cap: usize) -> Self { pub fn with_capacity(cap: usize) -> Self {
PermissionBuf { inner: String::with_capacity(cap) } PermissionBuf {
inner: String::with_capacity(cap),
}
} }
#[inline(always)] #[inline(always)]
@ -59,7 +62,13 @@ impl PermissionBuf {
pub fn _push(&mut self, perm: &Permission) { pub fn _push(&mut self, perm: &Permission) {
// in general we always need a separator unless the last byte is one or the string is empty // in general we always need a separator unless the last byte is one or the string is empty
let need_sep = self.inner.chars().rev().next().map(|c| !is_sep_char(c)).unwrap_or(false); let need_sep = self
.inner
.chars()
.rev()
.next()
.map(|c| !is_sep_char(c))
.unwrap_or(false);
if need_sep { if need_sep {
self.inner.push('.') self.inner.push('.')
} }
@ -73,7 +82,9 @@ impl PermissionBuf {
#[inline] #[inline]
pub fn from_perm(perm: &Permission) -> Self { pub fn from_perm(perm: &Permission) -> Self {
Self { inner: perm.as_str().to_string() } Self {
inner: perm.as_str().to_string(),
}
} }
#[inline(always)] #[inline(always)]
@ -162,12 +173,14 @@ impl PartialOrd for Permission {
} }
} }
match (l,r) { match (l, r) {
(None, None) => Some(Ordering::Equal), (None, None) => Some(Ordering::Equal),
(Some(_), None) => Some(Ordering::Less), (Some(_), None) => Some(Ordering::Less),
(None, Some(_)) => Some(Ordering::Greater), (None, Some(_)) => Some(Ordering::Greater),
(Some(_), Some(_)) => unreachable!("Broken contract in Permission::partial_cmp: sides \ (Some(_), Some(_)) => unreachable!(
should never be both Some!"), "Broken contract in Permission::partial_cmp: sides \
should never be both Some!"
),
} }
} }
} }
@ -208,7 +221,7 @@ impl PermRule {
pub fn match_perm<P: AsRef<Permission> + ?Sized>(&self, perm: &P) -> bool { pub fn match_perm<P: AsRef<Permission> + ?Sized>(&self, perm: &P) -> bool {
match self { match self {
PermRule::Base(ref base) => base.as_permission() == perm.as_ref(), PermRule::Base(ref base) => base.as_permission() == perm.as_ref(),
PermRule::Children(ref parent) => parent.as_permission() > perm.as_ref() , PermRule::Children(ref parent) => parent.as_permission() > perm.as_ref(),
PermRule::Subtree(ref parent) => parent.as_permission() >= perm.as_ref(), PermRule::Subtree(ref parent) => parent.as_permission() >= perm.as_ref(),
} }
} }
@ -217,12 +230,9 @@ impl PermRule {
impl fmt::Display for PermRule { impl fmt::Display for PermRule {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
PermRule::Base(perm) PermRule::Base(perm) => write!(f, "{}", perm),
=> write!(f, "{}", perm), PermRule::Children(parent) => write!(f, "{}.+", parent),
PermRule::Children(parent) PermRule::Subtree(parent) => write!(f, "{}.*", parent),
=> write!(f,"{}.+", parent),
PermRule::Subtree(parent)
=> write!(f,"{}.*", parent),
} }
} }
} }
@ -234,7 +244,7 @@ impl Into<String> for PermRule {
PermRule::Children(mut perm) => { PermRule::Children(mut perm) => {
perm.push(Permission::new("+")); perm.push(Permission::new("+"));
perm.into_string() perm.into_string()
}, }
PermRule::Subtree(mut perm) => { PermRule::Subtree(mut perm) => {
perm.push(Permission::new("+")); perm.push(Permission::new("+"));
perm.into_string() perm.into_string()
@ -252,15 +262,19 @@ impl TryFrom<String> for PermRule {
if len <= 2 { if len <= 2 {
Err("Input string for PermRule is too short") Err("Input string for PermRule is too short")
} else { } else {
match &input[len-2..len] { match &input[len - 2..len] {
".+" => { ".+" => {
input.truncate(len-2); input.truncate(len - 2);
Ok(PermRule::Children(PermissionBuf::from_string_unchecked(input))) Ok(PermRule::Children(PermissionBuf::from_string_unchecked(
}, input,
)))
}
".*" => { ".*" => {
input.truncate(len-2); input.truncate(len - 2);
Ok(PermRule::Subtree(PermissionBuf::from_string_unchecked(input))) Ok(PermRule::Subtree(PermissionBuf::from_string_unchecked(
}, input,
)))
}
_ => Ok(PermRule::Base(PermissionBuf::from_string_unchecked(input))), _ => Ok(PermRule::Base(PermissionBuf::from_string_unchecked(input))),
} }
} }
@ -273,8 +287,10 @@ mod tests {
#[test] #[test]
fn permission_ord_test() { fn permission_ord_test() {
assert!(PermissionBuf::from_string_unchecked("bffh.perm".to_string()) assert!(
> PermissionBuf::from_string_unchecked("bffh.perm.sub".to_string())); PermissionBuf::from_string_unchecked("bffh.perm".to_string())
> PermissionBuf::from_string_unchecked("bffh.perm.sub".to_string())
);
} }
#[test] #[test]
@ -312,44 +328,24 @@ mod tests {
assert!(rule.match_perm(&perm3)); assert!(rule.match_perm(&perm3));
} }
#[test]
fn format_and_read_compatible() {
use std::convert::TryInto;
let testdata = vec![
("testrole", "testsource"),
("", "norole"),
("nosource", "")
].into_iter().map(|(n,s)| (n.to_string(), s.to_string()));
for (name, source) in testdata {
let role = RoleIdentifier { name, source };
let fmt_string = format!("{}", &role);
println!("{:?} is formatted: {}", &role, &fmt_string);
let parsed: RoleIdentifier = fmt_string.try_into().unwrap();
println!("Which parses into {:?}", &parsed);
assert_eq!(role, parsed);
}
}
#[test] #[test]
fn rules_from_string_test() { fn rules_from_string_test() {
assert_eq!( assert_eq!(
PermRule::Base(PermissionBuf::from_string_unchecked("bffh.perm".to_string())), PermRule::Base(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm".to_string()).unwrap() PermRule::try_from("bffh.perm".to_string()).unwrap()
); );
assert_eq!( assert_eq!(
PermRule::Children(PermissionBuf::from_string_unchecked("bffh.perm".to_string())), PermRule::Children(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm.+".to_string()).unwrap() PermRule::try_from("bffh.perm.+".to_string()).unwrap()
); );
assert_eq!( assert_eq!(
PermRule::Subtree(PermissionBuf::from_string_unchecked("bffh.perm".to_string())), PermRule::Subtree(PermissionBuf::from_string_unchecked(
"bffh.perm".to_string()
)),
PermRule::try_from("bffh.perm.*".to_string()).unwrap() PermRule::try_from("bffh.perm.*".to_string()).unwrap()
); );
} }

View File

@ -1,8 +1,8 @@
use crate::authorization::permissions::{PermRule, Permission};
use crate::users::db::UserData;
use once_cell::sync::OnceCell;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fmt; use std::fmt;
use once_cell::sync::OnceCell;
use crate::authorization::permissions::{Permission, PermRule};
use crate::users::db::UserData;
static ROLES: OnceCell<HashMap<String, Role>> = OnceCell::new(); static ROLES: OnceCell<HashMap<String, Role>> = OnceCell::new();
@ -27,7 +27,6 @@ impl Roles {
self.roles.get(roleid) self.roles.get(roleid)
} }
/// Tally a role dependency tree into a set /// Tally a role dependency tree into a set
/// ///
/// A Default implementation exists which adapter may overwrite with more efficient /// A Default implementation exists which adapter may overwrite with more efficient
@ -62,10 +61,11 @@ impl Roles {
output output
} }
fn permitted_tally(&self, fn permitted_tally(
&self,
roles: &mut HashSet<String>, roles: &mut HashSet<String>,
role_id: &String, role_id: &String,
perm: &Permission perm: &Permission,
) -> bool { ) -> bool {
if let Some(role) = self.get(role_id) { if let Some(role) = self.get(role_id) {
// Only check and tally parents of a role at the role itself if it's the first time we // Only check and tally parents of a role at the role itself if it's the first time we
@ -130,7 +130,10 @@ pub struct Role {
impl Role { impl Role {
pub fn new(parents: Vec<String>, permissions: Vec<PermRule>) -> Self { pub fn new(parents: Vec<String>, permissions: Vec<PermRule>) -> Self {
Self { parents, permissions } Self {
parents,
permissions,
}
} }
} }

View File

@ -5,7 +5,6 @@ use rsasl::property::AuthId;
use rsasl::session::{Session, Step}; use rsasl::session::{Session, Step};
use std::io::Cursor; use std::io::Cursor;
use crate::capnp::session::APISession; use crate::capnp::session::APISession;
use crate::session::SessionManager; use crate::session::SessionManager;
use api::authenticationsystem_capnp::authentication::{ use api::authenticationsystem_capnp::authentication::{

View File

@ -2,7 +2,7 @@ use std::fmt::Formatter;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::path::PathBuf; use std::path::PathBuf;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
use crate::config::deser_option; use crate::config::deser_option;
@ -14,7 +14,11 @@ use crate::config::deser_option;
pub struct Listen { pub struct Listen {
pub address: String, pub address: String,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")] #[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub port: Option<u16>, pub port: Option<u16>,
} }

View File

@ -1,13 +1,13 @@
use std::net::SocketAddr;
pub use api::connection_capnp::bootstrap::Client;
use api::connection_capnp::bootstrap; use api::connection_capnp::bootstrap;
pub use api::connection_capnp::bootstrap::Client;
use std::net::SocketAddr;
use capnp::capability::Promise;
use capnp_rpc::pry;
use rsasl::mechname::Mechname;
use crate::authentication::AuthenticationHandle; use crate::authentication::AuthenticationHandle;
use crate::capnp::authenticationsystem::Authentication; use crate::capnp::authenticationsystem::Authentication;
use crate::session::SessionManager; use crate::session::SessionManager;
use capnp::capability::Promise;
use capnp_rpc::pry;
use rsasl::mechname::Mechname;
/// Cap'n Proto API Handler /// Cap'n Proto API Handler
pub struct BootCap { pub struct BootCap {
@ -17,7 +17,11 @@ pub struct BootCap {
} }
impl BootCap { impl BootCap {
pub fn new(peer_addr: SocketAddr, authentication: AuthenticationHandle, sessionmanager: SessionManager) -> Self { pub fn new(
peer_addr: SocketAddr,
authentication: AuthenticationHandle,
sessionmanager: SessionManager,
) -> Self {
tracing::trace!(%peer_addr, "bootstrapping RPC"); tracing::trace!(%peer_addr, "bootstrapping RPC");
Self { Self {
peer_addr, peer_addr,
@ -62,12 +66,14 @@ impl bootstrap::Server for BootCap {
tracing::trace!("mechanisms"); tracing::trace!("mechanisms");
let builder = result.get(); let builder = result.get();
let mechs: Vec<_> = self.authentication.list_available_mechs() let mechs: Vec<_> = self
.authentication
.list_available_mechs()
.into_iter() .into_iter()
.map(|m| m.as_str()) .map(|m| m.as_str())
.collect(); .collect();
let mut mechbuilder = builder.init_mechs(mechs.len() as u32); let mut mechbuilder = builder.init_mechs(mechs.len() as u32);
for (i,m) in mechs.iter().enumerate() { for (i, m) in mechs.iter().enumerate() {
mechbuilder.set(i as u32, m); mechbuilder.set(i as u32, m);
} }

View File

@ -1,14 +1,11 @@
use crate::capnp::machine::Machine;
use crate::resources::search::ResourcesHandle;
use crate::resources::Resource;
use crate::session::SessionHandle; use crate::session::SessionHandle;
use api::machinesystem_capnp::machine_system::{ use crate::RESOURCES;
info, use api::machinesystem_capnp::machine_system::info;
};
use capnp::capability::Promise; use capnp::capability::Promise;
use capnp_rpc::pry; use capnp_rpc::pry;
use crate::capnp::machine::Machine;
use crate::RESOURCES;
use crate::resources::Resource;
use crate::resources::search::ResourcesHandle;
#[derive(Clone)] #[derive(Clone)]
pub struct Machines { pub struct Machines {
@ -19,7 +16,10 @@ pub struct Machines {
impl Machines { impl Machines {
pub fn new(session: SessionHandle) -> Self { pub fn new(session: SessionHandle) -> Self {
// FIXME no unwrap bad // FIXME no unwrap bad
Self { session, resources: RESOURCES.get().unwrap().clone() } Self {
session,
resources: RESOURCES.get().unwrap().clone(),
}
} }
} }
@ -29,7 +29,9 @@ impl info::Server for Machines {
_: info::GetMachineListParams, _: info::GetMachineListParams,
mut result: info::GetMachineListResults, mut result: info::GetMachineListResults,
) -> Promise<(), ::capnp::Error> { ) -> Promise<(), ::capnp::Error> {
let machine_list: Vec<(usize, &Resource)> = self.resources.list_all() let machine_list: Vec<(usize, &Resource)> = self
.resources
.list_all()
.into_iter() .into_iter()
.filter(|resource| resource.visible(&self.session)) .filter(|resource| resource.visible(&self.session))
.enumerate() .enumerate()

View File

@ -1,6 +1,5 @@
use async_net::TcpListener; use async_net::TcpListener;
use capnp_rpc::rpc_twoparty_capnp::Side; use capnp_rpc::rpc_twoparty_capnp::Side;
use capnp_rpc::twoparty::VatNetwork; use capnp_rpc::twoparty::VatNetwork;
use capnp_rpc::RpcSystem; use capnp_rpc::RpcSystem;
@ -69,9 +68,7 @@ impl APIServer {
listens listens
.into_iter() .into_iter()
.map(|a| async move { .map(|a| async move { (async_net::resolve(a.to_tuple()).await, a) })
(async_net::resolve(a.to_tuple()).await, a)
})
.collect::<FuturesUnordered<_>>() .collect::<FuturesUnordered<_>>()
.filter_map(|(res, addr)| async move { .filter_map(|(res, addr)| async move {
match res { match res {
@ -111,7 +108,13 @@ impl APIServer {
tracing::warn!("No usable listen addresses configured for the API server!"); tracing::warn!("No usable listen addresses configured for the API server!");
} }
Ok(Self::new(executor, sockets, acceptor, sessionmanager, authentication)) Ok(Self::new(
executor,
sockets,
acceptor,
sessionmanager,
authentication,
))
} }
pub async fn handle_until(self, stop: impl Future) { pub async fn handle_until(self, stop: impl Future) {
@ -129,10 +132,11 @@ impl APIServer {
} else { } else {
tracing::error!(?stream, "failing a TCP connection with no peer addr"); tracing::error!(?stream, "failing a TCP connection with no peer addr");
} }
}, }
Err(e) => tracing::warn!("Failed to accept stream: {}", e), Err(e) => tracing::warn!("Failed to accept stream: {}", e),
} }
}).await; })
.await;
tracing::info!("closing down API handler"); tracing::info!("closing down API handler");
} }
@ -153,7 +157,11 @@ impl APIServer {
let (rx, tx) = futures_lite::io::split(stream); let (rx, tx) = futures_lite::io::split(stream);
let vat = VatNetwork::new(rx, tx, Side::Server, Default::default()); let vat = VatNetwork::new(rx, tx, Side::Server, Default::default());
let bootstrap: connection::Client = capnp_rpc::new_client(connection::BootCap::new(peer_addr, self.authentication.clone(), self.sessionmanager.clone())); let bootstrap: connection::Client = capnp_rpc::new_client(connection::BootCap::new(
peer_addr,
self.authentication.clone(),
self.sessionmanager.clone(),
));
if let Err(e) = RpcSystem::new(Box::new(vat), Some(bootstrap.client)).await { if let Err(e) = RpcSystem::new(Box::new(vat), Some(bootstrap.client)).await {
tracing::error!("Error during RPC handling: {}", e); tracing::error!("Error during RPC handling: {}", e);

View File

@ -10,6 +10,4 @@ impl Permissions {
} }
} }
impl PermissionSystem for Permissions { impl PermissionSystem for Permissions {}
}

View File

@ -1,12 +1,10 @@
use api::authenticationsystem_capnp::response::successful::Builder;
use crate::authorization::permissions::Permission; use crate::authorization::permissions::Permission;
use api::authenticationsystem_capnp::response::successful::Builder;
use crate::capnp::machinesystem::Machines; use crate::capnp::machinesystem::Machines;
use crate::capnp::permissionsystem::Permissions; use crate::capnp::permissionsystem::Permissions;
use crate::capnp::user_system::Users; use crate::capnp::user_system::Users;
use crate::session::{SessionHandle}; use crate::session::SessionHandle;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct APISession; pub struct APISession;

View File

@ -1,10 +1,10 @@
use crate::authorization::permissions::Permission;
use crate::session::SessionHandle;
use crate::users::{db, UserRef};
use api::general_capnp::optional;
use api::user_capnp::user::{self, admin, info, manage};
use capnp::capability::Promise; use capnp::capability::Promise;
use capnp_rpc::pry; use capnp_rpc::pry;
use crate::session::SessionHandle;
use api::user_capnp::user::{admin, info, manage, self};
use api::general_capnp::optional;
use crate::authorization::permissions::Permission;
use crate::users::{db, UserRef};
#[derive(Clone)] #[derive(Clone)]
pub struct User { pub struct User {
@ -22,7 +22,11 @@ impl User {
Self::new(session, user) Self::new(session, user)
} }
pub fn build_optional(session: &SessionHandle, user: Option<UserRef>, builder: optional::Builder<user::Owned>) { pub fn build_optional(
session: &SessionHandle,
user: Option<UserRef>,
builder: optional::Builder<user::Owned>,
) {
if let Some(user) = user.and_then(|u| session.users.get_user(u.get_username())) { if let Some(user) = user.and_then(|u| session.users.get_user(u.get_username())) {
let builder = builder.init_just(); let builder = builder.init_just();
Self::fill(&session, user, builder); Self::fill(&session, user, builder);
@ -102,12 +106,18 @@ impl admin::Server for User {
let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name()); let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name());
if let Some(_role) = self.session.roles.get(rolename) { if let Some(_role) = self.session.roles.get(rolename) {
let mut target = self.session.users.get_user(self.user.get_username()).unwrap(); let mut target = self
.session
.users
.get_user(self.user.get_username())
.unwrap();
// Only update if needed // Only update if needed
if !target.userdata.roles.iter().any(|r| r.as_str() == rolename) { if !target.userdata.roles.iter().any(|r| r.as_str() == rolename) {
target.userdata.roles.push(rolename.to_string()); target.userdata.roles.push(rolename.to_string());
self.session.users.put_user(self.user.get_username(), &target); self.session
.users
.put_user(self.user.get_username(), &target);
} }
} }
@ -121,22 +131,24 @@ impl admin::Server for User {
let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name()); let rolename = pry!(pry!(pry!(param.get()).get_role()).get_name());
if let Some(_role) = self.session.roles.get(rolename) { if let Some(_role) = self.session.roles.get(rolename) {
let mut target = self.session.users.get_user(self.user.get_username()).unwrap(); let mut target = self
.session
.users
.get_user(self.user.get_username())
.unwrap();
// Only update if needed // Only update if needed
if target.userdata.roles.iter().any(|r| r.as_str() == rolename) { if target.userdata.roles.iter().any(|r| r.as_str() == rolename) {
target.userdata.roles.retain(|r| r.as_str() != rolename); target.userdata.roles.retain(|r| r.as_str() != rolename);
self.session.users.put_user(self.user.get_username(), &target); self.session
.users
.put_user(self.user.get_username(), &target);
} }
} }
Promise::ok(()) Promise::ok(())
} }
fn pwd( fn pwd(&mut self, _: admin::PwdParams, _: admin::PwdResults) -> Promise<(), ::capnp::Error> {
&mut self,
_: admin::PwdParams,
_: admin::PwdResults,
) -> Promise<(), ::capnp::Error> {
Promise::err(::capnp::Error::unimplemented( Promise::err(::capnp::Error::unimplemented(
"method not implemented".to_string(), "method not implemented".to_string(),
)) ))

View File

@ -1,15 +1,12 @@
use api::usersystem_capnp::user_system::{info, manage, search};
use capnp::capability::Promise; use capnp::capability::Promise;
use capnp_rpc::pry; use capnp_rpc::pry;
use api::usersystem_capnp::user_system::{
info, manage, search
};
use crate::capnp::user::User; use crate::capnp::user::User;
use crate::session::SessionHandle; use crate::session::SessionHandle;
use crate::users::{db, UserRef}; use crate::users::{db, UserRef};
#[derive(Clone)] #[derive(Clone)]
pub struct Users { pub struct Users {
session: SessionHandle, session: SessionHandle,
@ -40,7 +37,8 @@ impl manage::Server for Users {
mut result: manage::GetUserListResults, mut result: manage::GetUserListResults,
) -> Promise<(), ::capnp::Error> { ) -> Promise<(), ::capnp::Error> {
let userdb = self.session.users.into_inner(); let userdb = self.session.users.into_inner();
let users = pry!(userdb.get_all() let users = pry!(userdb
.get_all()
.map_err(|e| capnp::Error::failed(format!("UserDB error: {:?}", e)))); .map_err(|e| capnp::Error::failed(format!("UserDB error: {:?}", e))));
let mut builder = result.get().init_user_list(users.len() as u32); let mut builder = result.get().init_user_list(users.len() as u32);
for (i, (_, user)) in users.into_iter().enumerate() { for (i, (_, user)) in users.into_iter().enumerate() {

View File

@ -1,8 +1,6 @@
use std::path::Path;
use crate::Config; use crate::Config;
use std::path::Path;
pub fn read_config_file(path: impl AsRef<Path>) -> Result<Config, serde_dhall::Error> { pub fn read_config_file(path: impl AsRef<Path>) -> Result<Config, serde_dhall::Error> {
serde_dhall::from_file(path) serde_dhall::from_file(path).parse().map_err(Into::into)
.parse()
.map_err(Into::into)
} }

View File

@ -1,13 +1,13 @@
use std::default::Default;
use std::path::{PathBuf};
use std::collections::HashMap; use std::collections::HashMap;
use std::default::Default;
use std::path::PathBuf;
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
mod dhall; mod dhall;
pub use dhall::read_config_file as read; pub use dhall::read_config_file as read;
use crate::authorization::permissions::{PrivilegesBuf}; use crate::authorization::permissions::PrivilegesBuf;
use crate::authorization::roles::Role; use crate::authorization::roles::Role;
use crate::capnp::{Listen, TlsListen}; use crate::capnp::{Listen, TlsListen};
use crate::logging::LogConfig; use crate::logging::LogConfig;
@ -23,13 +23,25 @@ pub struct MachineDescription {
pub name: String, pub name: String,
/// An optional description of the Machine. /// An optional description of the Machine.
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")] #[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub description: Option<String>, pub description: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")] #[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub wiki: Option<String>, pub wiki: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none", deserialize_with = "deser_option")] #[serde(
default,
skip_serializing_if = "Option::is_none",
deserialize_with = "deser_option"
)]
pub category: Option<String>, pub category: Option<String>,
/// The permission required /// The permission required
@ -83,48 +95,49 @@ impl Config {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModuleConfig { pub struct ModuleConfig {
pub module: String, pub module: String,
pub params: HashMap<String, String> pub params: HashMap<String, String>,
} }
pub(crate) fn deser_option<'de, D, T>(d: D) -> std::result::Result<Option<T>, D::Error> pub(crate) fn deser_option<'de, D, T>(d: D) -> std::result::Result<Option<T>, D::Error>
where D: serde::Deserializer<'de>, T: serde::Deserialize<'de>, where
D: serde::Deserializer<'de>,
T: serde::Deserialize<'de>,
{ {
Ok(T::deserialize(d).ok()) Ok(T::deserialize(d).ok())
} }
impl Default for Config { impl Default for Config {
fn default() -> Self { fn default() -> Self {
let mut actors: HashMap::<String, ModuleConfig> = HashMap::new(); let mut actors: HashMap<String, ModuleConfig> = HashMap::new();
let mut initiators: HashMap::<String, ModuleConfig> = HashMap::new(); let mut initiators: HashMap<String, ModuleConfig> = HashMap::new();
let machines = HashMap::new(); let machines = HashMap::new();
actors.insert("Actor".to_string(), ModuleConfig { actors.insert(
"Actor".to_string(),
ModuleConfig {
module: "Shelly".to_string(), module: "Shelly".to_string(),
params: HashMap::new(), params: HashMap::new(),
}); },
initiators.insert("Initiator".to_string(), ModuleConfig { );
initiators.insert(
"Initiator".to_string(),
ModuleConfig {
module: "TCP-Listen".to_string(), module: "TCP-Listen".to_string(),
params: HashMap::new(), params: HashMap::new(),
}); },
);
Config { Config {
listens: vec![ listens: vec![Listen {
Listen {
address: "127.0.0.1".to_string(), address: "127.0.0.1".to_string(),
port: None, port: None,
} }],
],
actors, actors,
initiators, initiators,
machines, machines,
mqtt_url: "tcp://localhost:1883".to_string(), mqtt_url: "tcp://localhost:1883".to_string(),
actor_connections: vec![ actor_connections: vec![("Testmachine".to_string(), "Actor".to_string())],
("Testmachine".to_string(), "Actor".to_string()), init_connections: vec![("Initiator".to_string(), "Testmachine".to_string())],
],
init_connections: vec![
("Initiator".to_string(), "Testmachine".to_string()),
],
db_path: PathBuf::from("/run/bffh/database"), db_path: PathBuf::from("/run/bffh/database"),
auditlog_path: PathBuf::from("/var/log/bffh/audit.log"), auditlog_path: PathBuf::from("/var/log/bffh/audit.log"),
@ -133,7 +146,7 @@ impl Default for Config {
tlsconfig: TlsListen { tlsconfig: TlsListen {
certfile: PathBuf::from("./bffh.crt"), certfile: PathBuf::from("./bffh.crt"),
keyfile: PathBuf::from("./bffh.key"), keyfile: PathBuf::from("./bffh.key"),
.. Default::default() ..Default::default()
}, },
tlskeylog: None, tlskeylog: None,

View File

@ -2,6 +2,6 @@ mod raw;
pub use raw::RawDB; pub use raw::RawDB;
mod typed; mod typed;
pub use typed::{DB, ArchivedValue, Adapter, AlignedAdapter}; pub use typed::{Adapter, AlignedAdapter, ArchivedValue, DB};
pub type Error = lmdb::Error; pub type Error = lmdb::Error;

View File

@ -1,10 +1,4 @@
use lmdb::{ use lmdb::{DatabaseFlags, Environment, RwTransaction, Transaction, WriteFlags};
Transaction,
RwTransaction,
Environment,
DatabaseFlags,
WriteFlags,
};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RawDB { pub struct RawDB {
@ -16,12 +10,21 @@ impl RawDB {
env.open_db(name).map(|db| Self { db }) env.open_db(name).map(|db| Self { db })
} }
pub fn create(env: &Environment, name: Option<&str>, flags: DatabaseFlags) -> lmdb::Result<Self> { pub fn create(
env: &Environment,
name: Option<&str>,
flags: DatabaseFlags,
) -> lmdb::Result<Self> {
env.create_db(name, flags).map(|db| Self { db }) env.create_db(name, flags).map(|db| Self { db })
} }
pub fn get<'txn, T: Transaction, K>(&self, txn: &'txn T, key: &K) -> lmdb::Result<Option<&'txn [u8]>> pub fn get<'txn, T: Transaction, K>(
where K: AsRef<[u8]> &self,
txn: &'txn T,
key: &K,
) -> lmdb::Result<Option<&'txn [u8]>>
where
K: AsRef<[u8]>,
{ {
match txn.get(self.db, key) { match txn.get(self.db, key) {
Ok(buf) => Ok(Some(buf)), Ok(buf) => Ok(Some(buf)),
@ -30,23 +33,36 @@ impl RawDB {
} }
} }
pub fn put<K, V>(&self, txn: &mut RwTransaction, key: &K, value: &V, flags: WriteFlags) pub fn put<K, V>(
-> lmdb::Result<()> &self,
where K: AsRef<[u8]>, txn: &mut RwTransaction,
key: &K,
value: &V,
flags: WriteFlags,
) -> lmdb::Result<()>
where
K: AsRef<[u8]>,
V: AsRef<[u8]>, V: AsRef<[u8]>,
{ {
txn.put(self.db, key, value, flags) txn.put(self.db, key, value, flags)
} }
pub fn reserve<'txn, K>(&self, txn: &'txn mut RwTransaction, key: &K, size: usize, flags: WriteFlags) pub fn reserve<'txn, K>(
-> lmdb::Result<&'txn mut [u8]> &self,
where K: AsRef<[u8]> txn: &'txn mut RwTransaction,
key: &K,
size: usize,
flags: WriteFlags,
) -> lmdb::Result<&'txn mut [u8]>
where
K: AsRef<[u8]>,
{ {
txn.reserve(self.db, key, size, flags) txn.reserve(self.db, key, size, flags)
} }
pub fn del<K, V>(&self, txn: &mut RwTransaction, key: &K, value: Option<&V>) -> lmdb::Result<()> pub fn del<K, V>(&self, txn: &mut RwTransaction, key: &K, value: Option<&V>) -> lmdb::Result<()>
where K: AsRef<[u8]>, where
K: AsRef<[u8]>,
V: AsRef<[u8]>, V: AsRef<[u8]>,
{ {
txn.del(self.db, key, value.map(AsRef::as_ref)) txn.del(self.db, key, value.map(AsRef::as_ref))
@ -60,7 +76,10 @@ impl RawDB {
cursor.iter_start() cursor.iter_start()
} }
pub fn open_ro_cursor<'txn, T: Transaction>(&self, txn: &'txn T) -> lmdb::Result<lmdb::RoCursor<'txn>> { pub fn open_ro_cursor<'txn, T: Transaction>(
&self,
txn: &'txn T,
) -> lmdb::Result<lmdb::RoCursor<'txn>> {
txn.open_ro_cursor(self.db) txn.open_ro_cursor(self.db)
} }
} }

View File

@ -119,7 +119,11 @@ impl<A> DB<A> {
} }
impl<A: Adapter> DB<A> { impl<A: Adapter> DB<A> {
pub fn get<T: Transaction>(&self, txn: &T, key: &impl AsRef<[u8]>) -> Result<Option<A::Item>, db::Error> { pub fn get<T: Transaction>(
&self,
txn: &T,
key: &impl AsRef<[u8]>,
) -> Result<Option<A::Item>, db::Error> {
Ok(self.db.get(txn, key)?.map(A::decode)) Ok(self.db.get(txn, key)?.map(A::decode))
} }
@ -129,8 +133,7 @@ impl<A: Adapter> DB<A> {
key: &impl AsRef<[u8]>, key: &impl AsRef<[u8]>,
value: &A::Item, value: &A::Item,
flags: WriteFlags, flags: WriteFlags,
) -> Result<(), db::Error> ) -> Result<(), db::Error> {
{
let len = A::encoded_len(value); let len = A::encoded_len(value);
let buf = self.db.reserve(txn, key, len, flags)?; let buf = self.db.reserve(txn, key, len, flags)?;
assert_eq!(buf.len(), len, "Reserved buffer is not of requested size!"); assert_eq!(buf.len(), len, "Reserved buffer is not of requested size!");
@ -146,11 +149,12 @@ impl<A: Adapter> DB<A> {
self.db.clear(txn) self.db.clear(txn)
} }
pub fn get_all<'txn, T: Transaction>(&self, txn: &'txn T) -> Result<impl IntoIterator<Item=(&'txn [u8], A::Item)>, db::Error> { pub fn get_all<'txn, T: Transaction>(
&self,
txn: &'txn T,
) -> Result<impl IntoIterator<Item = (&'txn [u8], A::Item)>, db::Error> {
let mut cursor = self.db.open_ro_cursor(txn)?; let mut cursor = self.db.open_ro_cursor(txn)?;
let it = cursor.iter_start(); let it = cursor.iter_start();
Ok(it.filter_map(|buf| buf.ok().map(|(kbuf,vbuf)| { Ok(it.filter_map(|buf| buf.ok().map(|(kbuf, vbuf)| (kbuf, A::decode(vbuf)))))
(kbuf, A::decode(vbuf))
})))
} }
} }

View File

@ -1,7 +1,7 @@
use std::io;
use std::fmt;
use rsasl::error::SessionError;
use crate::db; use crate::db;
use rsasl::error::SessionError;
use std::fmt;
use std::io;
type DBError = db::Error; type DBError = db::Error;
@ -21,19 +21,19 @@ impl fmt::Display for Error {
match self { match self {
Error::SASL(e) => { Error::SASL(e) => {
write!(f, "SASL Error: {}", e) write!(f, "SASL Error: {}", e)
}, }
Error::IO(e) => { Error::IO(e) => {
write!(f, "IO Error: {}", e) write!(f, "IO Error: {}", e)
}, }
Error::Boxed(e) => { Error::Boxed(e) => {
write!(f, "{}", e) write!(f, "{}", e)
}, }
Error::Capnp(e) => { Error::Capnp(e) => {
write!(f, "Cap'n Proto Error: {}", e) write!(f, "Cap'n Proto Error: {}", e)
}, }
Error::DB(e) => { Error::DB(e) => {
write!(f, "DB Error: {:?}", e) write!(f, "DB Error: {:?}", e)
}, }
Error::Denied => { Error::Denied => {
write!(f, "You do not have the permission required to do that.") write!(f, "You do not have the permission required to do that.")
} }

View File

@ -1,9 +1,9 @@
use std::fs::{File, OpenOptions};
use std::{fmt, io};
use std::fmt::Formatter; use std::fmt::Formatter;
use std::fs::{File, OpenOptions};
use std::io::Write; use std::io::Write;
use std::path::Path; use std::path::Path;
use std::sync::Mutex; use std::sync::Mutex;
use std::{fmt, io};
// Internal mutable state for KeyLogFile // Internal mutable state for KeyLogFile
struct KeyLogFileInner { struct KeyLogFileInner {
@ -18,10 +18,7 @@ impl fmt::Debug for KeyLogFileInner {
impl KeyLogFileInner { impl KeyLogFileInner {
fn new(path: impl AsRef<Path>) -> io::Result<Self> { fn new(path: impl AsRef<Path>) -> io::Result<Self> {
let file = OpenOptions::new() let file = OpenOptions::new().append(true).create(true).open(path)?;
.append(true)
.create(true)
.open(path)?;
Ok(Self { Ok(Self {
file, file,

View File

@ -16,9 +16,9 @@ pub mod db;
/// Shared error type /// Shared error type
pub mod error; pub mod error;
pub mod users;
pub mod authentication; pub mod authentication;
pub mod authorization; pub mod authorization;
pub mod users;
/// Resources /// Resources
pub mod resources; pub mod resources;
@ -31,38 +31,34 @@ pub mod capnp;
pub mod utils; pub mod utils;
mod tls; mod audit;
mod keylog; mod keylog;
mod logging; mod logging;
mod audit;
mod session; mod session;
mod tls;
use std::sync::Arc;
use std::sync::{Arc};
use anyhow::Context; use anyhow::Context;
use futures_util::StreamExt; use futures_util::StreamExt;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use signal_hook::consts::signal::*;
use executor::pool::Executor;
use crate::audit::AuditLog; use crate::audit::AuditLog;
use crate::authentication::AuthenticationHandle; use crate::authentication::AuthenticationHandle;
use crate::authorization::roles::Roles; use crate::authorization::roles::Roles;
use crate::capnp::APIServer; use crate::capnp::APIServer;
use crate::config::{Config}; use crate::config::Config;
use crate::resources::modules::fabaccess::MachineState; use crate::resources::modules::fabaccess::MachineState;
use crate::resources::Resource;
use crate::resources::search::ResourcesHandle; use crate::resources::search::ResourcesHandle;
use crate::resources::state::db::StateDB; use crate::resources::state::db::StateDB;
use crate::resources::Resource;
use crate::session::SessionManager; use crate::session::SessionManager;
use crate::tls::TlsConfig; use crate::tls::TlsConfig;
use crate::users::db::UserDB; use crate::users::db::UserDB;
use crate::users::Users; use crate::users::Users;
use executor::pool::Executor;
use signal_hook::consts::signal::*;
pub const VERSION_STRING: &'static str = env!("BFFHD_VERSION_STRING"); pub const VERSION_STRING: &'static str = env!("BFFHD_VERSION_STRING");
pub const RELEASE_STRING: &'static str = env!("BFFHD_RELEASE_STRING"); pub const RELEASE_STRING: &'static str = env!("BFFHD_RELEASE_STRING");
@ -81,7 +77,7 @@ pub static RESOURCES: OnceCell<ResourcesHandle> = OnceCell::new();
impl Diflouroborane { impl Diflouroborane {
pub fn new(config: Config) -> anyhow::Result<Self> { pub fn new(config: Config) -> anyhow::Result<Self> {
logging::init(&config.logging); logging::init(&config.logging);
tracing::info!(version=VERSION_STRING, "Starting BFFH"); tracing::info!(version = VERSION_STRING, "Starting BFFH");
let span = tracing::info_span!("setup"); let span = tracing::info_span!("setup");
let _guard = span.enter(); let _guard = span.enter();
@ -89,8 +85,8 @@ impl Diflouroborane {
let executor = Executor::new(); let executor = Executor::new();
let env = StateDB::open_env(&config.db_path)?; let env = StateDB::open_env(&config.db_path)?;
let statedb = StateDB::create_with_env(env.clone()) let statedb =
.context("Failed to open state DB file")?; StateDB::create_with_env(env.clone()).context("Failed to open state DB file")?;
let users = Users::new(env.clone()).context("Failed to open users DB file")?; let users = Users::new(env.clone()).context("Failed to open users DB file")?;
let roles = Roles::new(config.roles.clone()); let roles = Roles::new(config.roles.clone());
@ -98,31 +94,43 @@ impl Diflouroborane {
let _audit_log = AuditLog::new(&config).context("Failed to initialize audit log")?; let _audit_log = AuditLog::new(&config).context("Failed to initialize audit log")?;
let resources = ResourcesHandle::new(config.machines.iter().map(|(id, desc)| { let resources = ResourcesHandle::new(config.machines.iter().map(|(id, desc)| {
Resource::new(Arc::new(resources::Inner::new(id.to_string(), statedb.clone(), desc.clone()))) Resource::new(Arc::new(resources::Inner::new(
id.to_string(),
statedb.clone(),
desc.clone(),
)))
})); }));
RESOURCES.set(resources.clone()); RESOURCES.set(resources.clone());
Ok(Self {
Ok(Self { config, executor, statedb, users, roles, resources }) config,
executor,
statedb,
users,
roles,
resources,
})
} }
pub fn run(&mut self) -> anyhow::Result<()> { pub fn run(&mut self) -> anyhow::Result<()> {
let mut signals = signal_hook_async_std::Signals::new(&[ let mut signals = signal_hook_async_std::Signals::new(&[SIGINT, SIGQUIT, SIGTERM])
SIGINT, .context("Failed to construct signal handler")?;
SIGQUIT,
SIGTERM,
]).context("Failed to construct signal handler")?;
actors::load(self.executor.clone(), &self.config, self.resources.clone())?; actors::load(self.executor.clone(), &self.config, self.resources.clone())?;
let tlsconfig = TlsConfig::new(self.config.tlskeylog.as_ref(), !self.config.is_quiet())?; let tlsconfig = TlsConfig::new(self.config.tlskeylog.as_ref(), !self.config.is_quiet())?;
let acceptor = tlsconfig.make_tls_acceptor(&self.config.tlsconfig)?; let acceptor = tlsconfig.make_tls_acceptor(&self.config.tlsconfig)?;
let sessionmanager = SessionManager::new(self.users.clone(), self.roles.clone()); let sessionmanager = SessionManager::new(self.users.clone(), self.roles.clone());
let authentication = AuthenticationHandle::new(self.users.clone()); let authentication = AuthenticationHandle::new(self.users.clone());
let apiserver = self.executor.run(APIServer::bind(self.executor.clone(), &self.config.listens, acceptor, sessionmanager, authentication))?; let apiserver = self.executor.run(APIServer::bind(
self.executor.clone(),
&self.config.listens,
acceptor,
sessionmanager,
authentication,
))?;
let (mut tx, rx) = async_oneshot::oneshot(); let (mut tx, rx) = async_oneshot::oneshot();
@ -142,4 +150,3 @@ impl Diflouroborane {
Ok(()) Ok(())
} }
} }

View File

@ -1,7 +1,6 @@
use tracing_subscriber::{EnvFilter}; use tracing_subscriber::EnvFilter;
use serde::{Deserialize, Serialize};
use serde::{Serialize, Deserialize};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogConfig { pub struct LogConfig {
@ -32,8 +31,7 @@ pub fn init(config: &LogConfig) {
EnvFilter::from_env("BFFH_LOG") EnvFilter::from_env("BFFH_LOG")
}; };
let builder = tracing_subscriber::fmt() let builder = tracing_subscriber::fmt().with_env_filter(filter);
.with_env_filter(filter);
let format = config.format.to_lowercase(); let format = config.format.to_lowercase();
match format.as_str() { match format.as_str() {

View File

@ -1,16 +1,16 @@
use rkyv::{Archive, Serialize, Deserialize}; use rkyv::{Archive, Deserialize, Serialize};
#[derive(
Clone,
Debug,
PartialEq,
Eq,
Archive,
Serialize,
Deserialize,
#[derive(Clone, Debug, PartialEq, Eq)] serde::Serialize,
#[derive(Archive, Serialize, Deserialize)] serde::Deserialize,
#[derive(serde::Serialize, serde::Deserialize)] )]
pub struct Resource { pub struct Resource {
uuid: u128, uuid: u128,
id: String, id: String,

View File

@ -1,21 +1,21 @@
use futures_signals::signal::{Mutable, Signal};
use rkyv::Infallible; use rkyv::Infallible;
use std::ops::Deref; use std::ops::Deref;
use std::sync::Arc; use std::sync::Arc;
use futures_signals::signal::{Mutable, Signal};
use rkyv::{Archived, Deserialize};
use rkyv::option::ArchivedOption;
use rkyv::ser::Serializer;
use rkyv::ser::serializers::AllocSerializer;
use crate::audit::AUDIT; use crate::audit::AUDIT;
use crate::authorization::permissions::PrivilegesBuf; use crate::authorization::permissions::PrivilegesBuf;
use crate::config::MachineDescription; use crate::config::MachineDescription;
use crate::db::ArchivedValue; use crate::db::ArchivedValue;
use crate::resources::modules::fabaccess::{MachineState, Status, ArchivedStatus}; use crate::resources::modules::fabaccess::{ArchivedStatus, MachineState, Status};
use crate::resources::state::db::StateDB; use crate::resources::state::db::StateDB;
use crate::resources::state::State; use crate::resources::state::State;
use crate::session::SessionHandle; use crate::session::SessionHandle;
use crate::users::UserRef; use crate::users::UserRef;
use rkyv::option::ArchivedOption;
use rkyv::ser::serializers::AllocSerializer;
use rkyv::ser::Serializer;
use rkyv::{Archived, Deserialize};
pub mod db; pub mod db;
pub mod search; pub mod search;
@ -43,27 +43,35 @@ impl Inner {
let update = state.to_state(); let update = state.to_state();
let mut serializer = AllocSerializer::<1024>::default(); let mut serializer = AllocSerializer::<1024>::default();
serializer.serialize_value(&update).expect("failed to serialize new default state"); serializer
.serialize_value(&update)
.expect("failed to serialize new default state");
let val = ArchivedValue::new(serializer.into_serializer().into_inner()); let val = ArchivedValue::new(serializer.into_serializer().into_inner());
db.put(&id.as_bytes(), &val).unwrap(); db.put(&id.as_bytes(), &val).unwrap();
val val
}; };
let signal = Mutable::new(state); let signal = Mutable::new(state);
Self { id, db, signal, desc } Self {
id,
db,
signal,
desc,
}
} }
pub fn signal(&self) -> impl Signal<Item=ArchivedValue<State>> { pub fn signal(&self) -> impl Signal<Item = ArchivedValue<State>> {
Box::pin(self.signal.signal_cloned()) Box::pin(self.signal.signal_cloned())
} }
fn get_state(&self) -> ArchivedValue<State> { fn get_state(&self) -> ArchivedValue<State> {
self.db.get(self.id.as_bytes()) self.db
.get(self.id.as_bytes())
.expect("lmdb error") .expect("lmdb error")
.expect("state should never be None") .expect("state should never be None")
} }
fn get_state_ref(&self) -> impl Deref<Target=ArchivedValue<State>> + '_ { fn get_state_ref(&self) -> impl Deref<Target = ArchivedValue<State>> + '_ {
self.signal.lock_ref() self.signal.lock_ref()
} }
@ -76,7 +84,10 @@ impl Inner {
self.db.put(&self.id.as_bytes(), &state).unwrap(); self.db.put(&self.id.as_bytes(), &state).unwrap();
tracing::trace!("Updated DB, sending update signal"); tracing::trace!("Updated DB, sending update signal");
AUDIT.get().unwrap().log(self.id.as_str(), &format!("{}", state)); AUDIT
.get()
.unwrap()
.log(self.id.as_str(), &format!("{}", state));
self.signal.set(state); self.signal.set(state);
tracing::trace!("Sent update signal"); tracing::trace!("Sent update signal");
@ -85,7 +96,7 @@ impl Inner {
#[derive(Clone)] #[derive(Clone)]
pub struct Resource { pub struct Resource {
inner: Arc<Inner> inner: Arc<Inner>,
} }
impl Resource { impl Resource {
@ -97,7 +108,7 @@ impl Resource {
self.inner.get_state() self.inner.get_state()
} }
pub fn get_state_ref(&self) -> impl Deref<Target=ArchivedValue<State>> + '_ { pub fn get_state_ref(&self) -> impl Deref<Target = ArchivedValue<State>> + '_ {
self.inner.get_state_ref() self.inner.get_state_ref()
} }
@ -109,7 +120,7 @@ impl Resource {
self.inner.desc.name.as_str() self.inner.desc.name.as_str()
} }
pub fn get_signal(&self) -> impl Signal<Item=ArchivedValue<State>> { pub fn get_signal(&self) -> impl Signal<Item = ArchivedValue<State>> {
self.inner.signal() self.inner.signal()
} }
@ -125,13 +136,13 @@ impl Resource {
let state = self.get_state_ref(); let state = self.get_state_ref();
let state: &Archived<State> = state.as_ref(); let state: &Archived<State> = state.as_ref();
match &state.inner.state { match &state.inner.state {
ArchivedStatus::Blocked(user) | ArchivedStatus::Blocked(user)
ArchivedStatus::InUse(user) | | ArchivedStatus::InUse(user)
ArchivedStatus::Reserved(user) | | ArchivedStatus::Reserved(user)
ArchivedStatus::ToCheck(user) => { | ArchivedStatus::ToCheck(user) => {
let user = Deserialize::<UserRef, _>::deserialize(user, &mut Infallible).unwrap(); let user = Deserialize::<UserRef, _>::deserialize(user, &mut Infallible).unwrap();
Some(user) Some(user)
}, }
_ => None, _ => None,
} }
} }
@ -158,7 +169,8 @@ impl Resource {
let old = self.inner.get_state(); let old = self.inner.get_state();
let oldref: &Archived<State> = old.as_ref(); let oldref: &Archived<State> = old.as_ref();
let previous: &Archived<Option<UserRef>> = &oldref.inner.previous; let previous: &Archived<Option<UserRef>> = &oldref.inner.previous;
let previous = Deserialize::<Option<UserRef>, _>::deserialize(previous, &mut rkyv::Infallible) let previous =
Deserialize::<Option<UserRef>, _>::deserialize(previous, &mut rkyv::Infallible)
.expect("Infallible deserializer failed"); .expect("Infallible deserializer failed");
let new = MachineState { state, previous }; let new = MachineState { state, previous };
self.set_state(new); self.set_state(new);

View File

@ -1,14 +1,13 @@
use std::fmt;
use std::fmt::{Write};
use crate::utils::oid::ObjectIdentifier; use crate::utils::oid::ObjectIdentifier;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rkyv::{Archive, Archived, Deserialize, Infallible}; use rkyv::{Archive, Archived, Deserialize, Infallible};
use std::fmt;
use std::fmt::Write;
use std::str::FromStr; use std::str::FromStr;
//use crate::oidvalue; //use crate::oidvalue;
use crate::resources::state::{State}; use crate::resources::state::State;
use crate::users::UserRef; use crate::users::UserRef;
@ -80,13 +79,12 @@ impl MachineState {
} }
pub fn from(dbstate: &Archived<State>) -> Self { pub fn from(dbstate: &Archived<State>) -> Self {
let state: &Archived<MachineState> = &dbstate.inner; let state: &Archived<MachineState> = &dbstate.inner;
Deserialize::deserialize(state, &mut Infallible).unwrap() Deserialize::deserialize(state, &mut Infallible).unwrap()
} }
pub fn to_state(&self) -> State { pub fn to_state(&self) -> State {
State { State {
inner: self.clone() inner: self.clone(),
} }
} }

View File

@ -1,6 +1,3 @@
pub mod fabaccess; pub mod fabaccess;
pub trait MachineModel { pub trait MachineModel {}
}

View File

@ -1,13 +1,13 @@
use crate::resources::Resource;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use crate::resources::Resource;
struct Inner { struct Inner {
id: HashMap<String, Resource>, id: HashMap<String, Resource>,
} }
impl Inner { impl Inner {
pub fn new(resources: impl IntoIterator<Item=Resource>) -> Self { pub fn new(resources: impl IntoIterator<Item = Resource>) -> Self {
let mut id = HashMap::new(); let mut id = HashMap::new();
for resource in resources { for resource in resources {
@ -25,13 +25,13 @@ pub struct ResourcesHandle {
} }
impl ResourcesHandle { impl ResourcesHandle {
pub fn new(resources: impl IntoIterator<Item=Resource>) -> Self { pub fn new(resources: impl IntoIterator<Item = Resource>) -> Self {
Self { Self {
inner: Arc::new(Inner::new(resources)), inner: Arc::new(Inner::new(resources)),
} }
} }
pub fn list_all(&self) -> impl IntoIterator<Item=&Resource> { pub fn list_all(&self) -> impl IntoIterator<Item = &Resource> {
self.inner.id.values() self.inner.id.values()
} }

View File

@ -1,9 +1,6 @@
use crate::db; use crate::db;
use crate::db::{ArchivedValue, RawDB, DB, AlignedAdapter}; use crate::db::{AlignedAdapter, ArchivedValue, RawDB, DB};
use lmdb::{ use lmdb::{DatabaseFlags, Environment, EnvironmentFlags, Transaction, WriteFlags};
DatabaseFlags, Environment, EnvironmentFlags, Transaction,
WriteFlags,
};
use std::{path::Path, sync::Arc}; use std::{path::Path, sync::Arc};
use crate::resources::state::State; use crate::resources::state::State;
@ -67,7 +64,7 @@ impl StateDB {
pub fn get_all<'txn, T: Transaction>( pub fn get_all<'txn, T: Transaction>(
&self, &self,
txn: &'txn T, txn: &'txn T,
) -> Result<impl IntoIterator<Item = (&'txn [u8], ArchivedValue<State>)>, db::Error, > { ) -> Result<impl IntoIterator<Item = (&'txn [u8], ArchivedValue<State>)>, db::Error> {
self.db.get_all(txn) self.db.get_all(txn)
} }
@ -83,34 +80,5 @@ impl StateDB {
mod tests { mod tests {
use super::*; use super::*;
use crate::resource::state::value::Vec3u8;
use crate::resource::state::value::{OID_COLOUR, OID_INTENSITY, OID_POWERED};
use std::ops::Deref; use std::ops::Deref;
#[test]
fn construct_state() {
let tmpdir = tempfile::tempdir().unwrap();
let mut tmppath = tmpdir.path().to_owned();
tmppath.push("db");
let db = StateDB::create(tmppath).unwrap();
let b = State::build()
.add(OID_COLOUR.clone(), Box::new(Vec3u8 { a: 1, b: 2, c: 3 }))
.add(OID_POWERED.clone(), Box::new(true))
.add(OID_INTENSITY.clone(), Box::new(1023))
.finish();
println!("({}) {:?}", b.hash(), b);
let c = State::build()
.add(OID_COLOUR.clone(), Box::new(Vec3u8 { a: 1, b: 2, c: 3 }))
.add(OID_POWERED.clone(), Box::new(true))
.add(OID_INTENSITY.clone(), Box::new(1023))
.finish();
let key = rand::random();
db.update(key, &b, &c).unwrap();
let d = db.get_input(key).unwrap().unwrap();
let e = db.get_output(key).unwrap().unwrap();
assert_eq!(&b, d.deref());
assert_eq!(&c, e.deref());
}
} }

View File

@ -1,37 +1,27 @@
use std::{
fmt,
hash::{
Hasher
},
};
use std::fmt::{Debug, Display, Formatter}; use std::fmt::{Debug, Display, Formatter};
use std::{fmt, hash::Hasher};
use std::ops::Deref; use std::ops::Deref;
use rkyv::{out_field, Archive, Deserialize, Serialize};
use rkyv::{Archive, Deserialize, out_field, Serialize};
use serde::de::{Error, MapAccess, Unexpected}; use serde::de::{Error, MapAccess, Unexpected};
use serde::Deserializer;
use serde::ser::SerializeMap; use serde::ser::SerializeMap;
use serde::Deserializer;
use crate::MachineState;
use crate::resources::modules::fabaccess::OID_VALUE; use crate::resources::modules::fabaccess::OID_VALUE;
use crate::MachineState;
use crate::utils::oid::ObjectIdentifier; use crate::utils::oid::ObjectIdentifier;
pub mod value;
pub mod db; pub mod db;
pub mod value;
#[derive(Archive, Serialize, Deserialize)] #[derive(Archive, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[derive(Clone, PartialEq, Eq)]
#[archive_attr(derive(Debug))] #[archive_attr(derive(Debug))]
pub struct State { pub struct State {
pub inner: MachineState, pub inner: MachineState,
} }
impl fmt::Debug for State { impl fmt::Debug for State {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut sf = f.debug_struct("State"); let mut sf = f.debug_struct("State");
@ -51,7 +41,8 @@ impl fmt::Display for ArchivedState {
impl serde::Serialize for State { impl serde::Serialize for State {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where S: serde::Serializer where
S: serde::Serializer,
{ {
let mut ser = serializer.serialize_map(Some(1))?; let mut ser = serializer.serialize_map(Some(1))?;
ser.serialize_entry(OID_VALUE.deref(), &self.inner)?; ser.serialize_entry(OID_VALUE.deref(), &self.inner)?;
@ -60,7 +51,8 @@ impl serde::Serialize for State {
} }
impl<'de> serde::Deserialize<'de> for State { impl<'de> serde::Deserialize<'de> for State {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where D: Deserializer<'de> where
D: Deserializer<'de>,
{ {
deserializer.deserialize_map(StateVisitor) deserializer.deserialize_map(StateVisitor)
} }
@ -74,12 +66,13 @@ impl<'de> serde::de::Visitor<'de> for StateVisitor {
write!(formatter, "a map from OIDs to value objects") write!(formatter, "a map from OIDs to value objects")
} }
fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
{ let oid: ObjectIdentifier = map.next_key()?.ok_or(A::Error::missing_field("oid"))?;
let oid: ObjectIdentifier = map.next_key()?
.ok_or(A::Error::missing_field("oid"))?;
if oid != *OID_VALUE.deref() { if oid != *OID_VALUE.deref() {
return Err(A::Error::invalid_value(Unexpected::Other("Unknown OID"), &"OID of fabaccess state")) return Err(A::Error::invalid_value(
Unexpected::Other("Unknown OID"),
&"OID of fabaccess state",
));
} }
let val: MachineState = map.next_value()?; let val: MachineState = map.next_value()?;
Ok(State { inner: val }) Ok(State { inner: val })
@ -88,72 +81,6 @@ impl<'de> serde::de::Visitor<'de> for StateVisitor {
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::*;
use super::value::*; use super::value::*;
use super::*;
pub(crate) fn gen_random() -> State {
let amt: u8 = rand::random::<u8>() % 20;
let mut sb = State::build();
for _ in 0..amt {
let oid = crate::utils::oid::tests::gen_random();
sb = match rand::random::<u32>()%12 {
0 => sb.add(oid, Box::new(rand::random::<bool>())),
1 => sb.add(oid, Box::new(rand::random::<u8>())),
2 => sb.add(oid, Box::new(rand::random::<u16>())),
3 => sb.add(oid, Box::new(rand::random::<u32>())),
4 => sb.add(oid, Box::new(rand::random::<u64>())),
5 => sb.add(oid, Box::new(rand::random::<u128>())),
6 => sb.add(oid, Box::new(rand::random::<i8>())),
7 => sb.add(oid, Box::new(rand::random::<i16>())),
8 => sb.add(oid, Box::new(rand::random::<i32>())),
9 => sb.add(oid, Box::new(rand::random::<i64>())),
10 => sb.add(oid, Box::new(rand::random::<i128>())),
11 => sb.add(oid, Box::new(rand::random::<Vec3u8>())),
_ => unreachable!(),
}
}
sb.finish()
}
#[test]
fn test_equal_state_is_eq() {
let stateA = State::build()
.add(OID_POWERED.clone(), Box::new(false))
.add(OID_INTENSITY.clone(), Box::new(1024))
.finish();
let stateB = State::build()
.add(OID_POWERED.clone(), Box::new(false))
.add(OID_INTENSITY.clone(), Box::new(1024))
.finish();
assert_eq!(stateA, stateB);
}
#[test]
fn test_unequal_state_is_ne() {
let stateA = State::build()
.add(OID_POWERED.clone(), Box::new(true))
.add(OID_INTENSITY.clone(), Box::new(512))
.finish();
let stateB = State::build()
.add(OID_POWERED.clone(), Box::new(false))
.add(OID_INTENSITY.clone(), Box::new(1024))
.finish();
assert_ne!(stateA, stateB);
}
#[test]
fn test_state_is_clone() {
let stateA = gen_random();
let stateB = stateA.clone();
let stateC = stateB.clone();
drop(stateA);
assert_eq!(stateC, stateB);
}
} }

View File

@ -1,10 +1,12 @@
use std::{hash::Hash}; use std::hash::Hash;
use ptr_meta::{DynMetadata, Pointee}; use ptr_meta::{DynMetadata, Pointee};
use rkyv::{out_field, Archive, ArchivePointee, ArchiveUnsized, Archived, ArchivedMetadata, Serialize, SerializeUnsized, RelPtr}; use rkyv::{
out_field, Archive, ArchivePointee, ArchiveUnsized, Archived, ArchivedMetadata, RelPtr,
Serialize, SerializeUnsized,
};
use rkyv_dyn::{DynError, DynSerializer}; use rkyv_dyn::{DynError, DynSerializer};
use crate::utils::oid::ObjectIdentifier; use crate::utils::oid::ObjectIdentifier;
// Not using linkme because dynamically loaded modules // Not using linkme because dynamically loaded modules
@ -16,7 +18,6 @@ use serde::ser::SerializeMap;
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::Deref; use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
@ -61,7 +62,7 @@ struct NewStateBuilder {
// turns into // turns into
struct NewState { struct NewState {
inner: ArchivedVec<ArchivedMetaBox<dyn ArchivedStateValue>> inner: ArchivedVec<ArchivedMetaBox<dyn ArchivedStateValue>>,
} }
impl NewState { impl NewState {
pub fn get_value<T: TypeOid>(&self) -> Option<&T> { pub fn get_value<T: TypeOid>(&self) -> Option<&T> {
@ -148,7 +149,7 @@ pub trait TypeOid {
} }
impl<T> SerializeDynOid for T impl<T> SerializeDynOid for T
where where
T: for<'a> Serialize<dyn DynSerializer + 'a>, T: for<'a> Serialize<dyn DynSerializer + 'a>,
T::Archived: TypeOid, T::Archived: TypeOid,
{ {
@ -371,7 +372,7 @@ pub mod macros {
stringify!($y) stringify!($y)
} }
} }
} };
} }
#[macro_export] #[macro_export]
@ -380,16 +381,15 @@ pub mod macros {
unsafe impl $crate::resources::state::value::RegisteredImpl for $z { unsafe impl $crate::resources::state::value::RegisteredImpl for $z {
fn vtable() -> usize { fn vtable() -> usize {
unsafe { unsafe {
::core::mem::transmute(ptr_meta::metadata( ::core::mem::transmute(ptr_meta::metadata(::core::ptr::null::<$z>()
::core::ptr::null::<$z>() as *const dyn $crate::resources::state::value::ArchivedStateValue as *const dyn $crate::resources::state::value::ArchivedStateValue))
))
} }
} }
fn debug_info() -> $crate::resources::state::value::ImplDebugInfo { fn debug_info() -> $crate::resources::state::value::ImplDebugInfo {
$crate::debug_info!() $crate::debug_info!()
} }
} }
} };
} }
#[macro_export] #[macro_export]

View File

@ -0,0 +1 @@

View File

@ -1,20 +1,13 @@
use crate::authorization::permissions::Permission; use crate::authorization::permissions::Permission;
use crate::authorization::roles::{Roles}; use crate::authorization::roles::Roles;
use crate::resources::Resource; use crate::resources::Resource;
use crate::Users;
use crate::users::{db, UserRef}; use crate::users::{db, UserRef};
use crate::Users;
#[derive(Clone)] #[derive(Clone)]
pub struct SessionManager { pub struct SessionManager {
users: Users, users: Users,
roles: Roles, roles: Roles,
// cache: SessionCache // todo // cache: SessionCache // todo
} }
impl SessionManager { impl SessionManager {
@ -52,33 +45,39 @@ impl SessionHandle {
} }
pub fn get_user(&self) -> db::User { pub fn get_user(&self) -> db::User {
self.users.get_user(self.user.get_username()).expect("Failed to get user self") self.users
.get_user(self.user.get_username())
.expect("Failed to get user self")
} }
pub fn has_disclose(&self, resource: &Resource) -> bool { pub fn has_disclose(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) { if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().disclose) self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().disclose)
} else { } else {
false false
} }
} }
pub fn has_read(&self, resource: &Resource) -> bool { pub fn has_read(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) { if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().read) self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().read)
} else { } else {
false false
} }
} }
pub fn has_write(&self, resource: &Resource) -> bool { pub fn has_write(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) { if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().write) self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().write)
} else { } else {
false false
} }
} }
pub fn has_manage(&self, resource: &Resource) -> bool { pub fn has_manage(&self, resource: &Resource) -> bool {
if let Some(user) = self.users.get_user(self.user.get_username()) { if let Some(user) = self.users.get_user(self.user.get_username()) {
self.roles.is_permitted(&user.userdata, &resource.get_required_privs().manage) self.roles
.is_permitted(&user.userdata, &resource.get_required_privs().manage)
} else { } else {
false false
} }

View File

@ -4,26 +4,39 @@ use std::io::BufReader;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use futures_rustls::TlsAcceptor;
use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite};
use rustls::version::{TLS12, TLS13};
use tracing::{Level};
use crate::capnp::TlsListen; use crate::capnp::TlsListen;
use futures_rustls::TlsAcceptor;
use rustls::version::{TLS12, TLS13};
use rustls::{Certificate, PrivateKey, ServerConfig, SupportedCipherSuite};
use tracing::Level;
use crate::keylog::KeyLogFile; use crate::keylog::KeyLogFile;
fn lookup_cipher_suite(name: &str) -> Option<SupportedCipherSuite> { fn lookup_cipher_suite(name: &str) -> Option<SupportedCipherSuite> {
match name { match name {
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => {
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384), Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), }
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => {
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384), Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384)
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), }
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256)
}
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)
}
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384)
}
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)
}
"TLS13_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS13_AES_128_GCM_SHA256), "TLS13_AES_128_GCM_SHA256" => Some(rustls::cipher_suite::TLS13_AES_128_GCM_SHA256),
"TLS13_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS13_AES_256_GCM_SHA384), "TLS13_AES_256_GCM_SHA384" => Some(rustls::cipher_suite::TLS13_AES_256_GCM_SHA384),
"TLS13_CHACHA20_POLY1305_SHA256" => Some(rustls::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256), "TLS13_CHACHA20_POLY1305_SHA256" => {
Some(rustls::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256)
}
_ => None, _ => None,
} }
} }
@ -43,7 +56,6 @@ impl TlsConfig {
} }
if let Some(path) = keylogfile { if let Some(path) = keylogfile {
let keylog = Some(KeyLogFile::new(path).map(|ok| Arc::new(ok))?); let keylog = Some(KeyLogFile::new(path).map(|ok| Arc::new(ok))?);
Ok(Self { keylog }) Ok(Self { keylog })
} else { } else {

View File

@ -1,15 +1,15 @@
use lmdb::{DatabaseFlags, Environment, RwTransaction, Transaction, WriteFlags}; use lmdb::{DatabaseFlags, Environment, RwTransaction, Transaction, WriteFlags};
use std::collections::{HashMap};
use rkyv::Infallible; use rkyv::Infallible;
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Context; use anyhow::Context;
use std::sync::Arc;
use rkyv::{Deserialize};
use rkyv::ser::Serializer;
use rkyv::ser::serializers::AllocSerializer;
use crate::db; use crate::db;
use crate::db::{AlignedAdapter, ArchivedValue, DB, RawDB}; use crate::db::{AlignedAdapter, ArchivedValue, RawDB, DB};
use rkyv::ser::serializers::AllocSerializer;
use rkyv::ser::Serializer;
use rkyv::Deserialize;
#[derive( #[derive(
Clone, Clone,
@ -30,8 +30,7 @@ pub struct User {
impl User { impl User {
pub fn check_password(&self, pwd: &[u8]) -> anyhow::Result<bool> { pub fn check_password(&self, pwd: &[u8]) -> anyhow::Result<bool> {
if let Some(ref encoded) = self.userdata.passwd { if let Some(ref encoded) = self.userdata.passwd {
argon2::verify_encoded(encoded, pwd) argon2::verify_encoded(encoded, pwd).context("Stored password is an invalid string")
.context("Stored password is an invalid string")
} else { } else {
Ok(false) Ok(false)
} }
@ -48,23 +47,23 @@ impl User {
id: username.to_string(), id: username.to_string(),
userdata: UserData { userdata: UserData {
passwd: Some(hash), passwd: Some(hash),
.. Default::default() ..Default::default()
} },
} }
} }
} }
#[derive( #[derive(
Clone, Clone,
PartialEq, PartialEq,
Eq, Eq,
Debug, Debug,
Default, Default,
rkyv::Archive, rkyv::Archive,
rkyv::Serialize, rkyv::Serialize,
rkyv::Deserialize, rkyv::Deserialize,
serde::Serialize, serde::Serialize,
serde::Deserialize, serde::Deserialize,
)] )]
/// Data on an user to base decisions on /// Data on an user to base decisions on
/// ///
@ -85,12 +84,19 @@ pub struct UserData {
impl UserData { impl UserData {
pub fn new(roles: Vec<String>) -> Self { pub fn new(roles: Vec<String>) -> Self {
Self { roles, kv: HashMap::new(), passwd: None } Self {
roles,
kv: HashMap::new(),
passwd: None,
}
} }
pub fn new_with_kv(roles: Vec<String>, kv: HashMap<String, String>) -> Self { pub fn new_with_kv(roles: Vec<String>, kv: HashMap<String, String>) -> Self {
Self { roles, kv, passwd: None } Self {
roles,
kv,
passwd: None,
}
} }
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -140,7 +146,12 @@ impl UserDB {
Ok(()) Ok(())
} }
pub fn put_txn(&self, txn: &mut RwTransaction, uid: &str, user: &User) -> Result<(), db::Error> { pub fn put_txn(
&self,
txn: &mut RwTransaction,
uid: &str,
user: &User,
) -> Result<(), db::Error> {
let mut serializer = AllocSerializer::<1024>::default(); let mut serializer = AllocSerializer::<1024>::default();
serializer.serialize_value(user).expect("rkyv error"); serializer.serialize_value(user).expect("rkyv error");
let v = serializer.into_serializer().into_inner(); let v = serializer.into_serializer().into_inner();
@ -169,7 +180,8 @@ impl UserDB {
let mut out = Vec::new(); let mut out = Vec::new();
for (uid, user) in iter { for (uid, user) in iter {
let uid = unsafe { std::str::from_utf8_unchecked(uid).to_string() }; let uid = unsafe { std::str::from_utf8_unchecked(uid).to_string() };
let user: User = Deserialize::<User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap(); let user: User =
Deserialize::<User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap();
out.push((uid, user)); out.push((uid, user));
} }

View File

@ -10,8 +10,6 @@ use std::sync::Arc;
pub mod db; pub mod db;
use crate::users::db::UserData; use crate::users::db::UserData;
use crate::UserDB; use crate::UserDB;
@ -87,10 +85,9 @@ impl Users {
pub fn get_user(&self, uid: &str) -> Option<db::User> { pub fn get_user(&self, uid: &str) -> Option<db::User> {
tracing::trace!(uid, "Looking up user"); tracing::trace!(uid, "Looking up user");
self.userdb self.userdb.get(uid).unwrap().map(|user| {
.get(uid) Deserialize::<db::User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap()
.unwrap() })
.map(|user| Deserialize::<db::User, _>::deserialize(user.as_ref(), &mut Infallible).unwrap())
} }
pub fn put_user(&self, uid: &str, user: &db::User) -> Result<(), lmdb::Error> { pub fn put_user(&self, uid: &str, user: &db::User) -> Result<(), lmdb::Error> {

View File

@ -1,20 +1,16 @@
use std::collections::HashMap; use std::collections::HashMap;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
struct Locales { struct Locales {
map: HashMap<&'static str, HashMap<&'static str, &'static str>> map: HashMap<&'static str, HashMap<&'static str, &'static str>>,
} }
impl Locales { impl Locales {
pub fn get(&self, lang: &str, msg: &str) pub fn get(&self, lang: &str, msg: &str) -> Option<(&'static str, &'static str)> {
-> Option<(&'static str, &'static str)> self.map
{ .get(msg)
self.map.get(msg).and_then(|map| { .and_then(|map| map.get_key_value(lang).map(|(k, v)| (*k, *v)))
map.get_key_value(lang).map(|(k,v)| (*k, *v))
})
} }
pub fn available(&self, _msg: &str) -> &[&'static str] { pub fn available(&self, _msg: &str) -> &[&'static str] {
@ -22,8 +18,8 @@ impl Locales {
} }
} }
static LANG: Lazy<Locales> = Lazy::new(|| { static LANG: Lazy<Locales> = Lazy::new(|| Locales {
Locales { map: HashMap::new() } map: HashMap::new(),
}); });
struct L10NString { struct L10NString {

View File

@ -52,16 +52,16 @@
//! [Object Identifiers]: https://en.wikipedia.org/wiki/Object_identifier //! [Object Identifiers]: https://en.wikipedia.org/wiki/Object_identifier
//! [ITU]: https://en.wikipedia.org/wiki/International_Telecommunications_Union //! [ITU]: https://en.wikipedia.org/wiki/International_Telecommunications_Union
use rkyv::{Archive, Serialize}; use crate::utils::varint::VarU128;
use rkyv::ser::Serializer;
use rkyv::vec::{ArchivedVec, VecResolver}; use rkyv::vec::{ArchivedVec, VecResolver};
use rkyv::{Archive, Serialize};
use std::convert::TryFrom; use std::convert::TryFrom;
use std::ops::Deref; use std::convert::TryInto;
use std::fmt; use std::fmt;
use std::fmt::Formatter; use std::fmt::Formatter;
use rkyv::ser::Serializer; use std::ops::Deref;
use std::str::FromStr; use std::str::FromStr;
use crate::utils::varint::VarU128;
use std::convert::TryInto;
type Node = u128; type Node = u128;
type VarNode = VarU128; type VarNode = VarU128;
@ -69,8 +69,8 @@ type VarNode = VarU128;
/// Convenience module for quickly importing the public interface (e.g., `use oid::prelude::*`) /// Convenience module for quickly importing the public interface (e.g., `use oid::prelude::*`)
pub mod prelude { pub mod prelude {
pub use super::ObjectIdentifier; pub use super::ObjectIdentifier;
pub use super::ObjectIdentifierRoot::*;
pub use super::ObjectIdentifierError; pub use super::ObjectIdentifierError;
pub use super::ObjectIdentifierRoot::*;
pub use core::convert::{TryFrom, TryInto}; pub use core::convert::{TryFrom, TryInto};
} }
@ -132,7 +132,8 @@ impl ObjectIdentifier {
let mut parsing_big_int = false; let mut parsing_big_int = false;
let mut big_int: Node = 0; let mut big_int: Node = 0;
for i in 1..nodes.len() { for i in 1..nodes.len() {
if !parsing_big_int && nodes[i] < 128 {} else { if !parsing_big_int && nodes[i] < 128 {
} else {
if big_int > 0 { if big_int > 0 {
if big_int >= Node::MAX >> 7 { if big_int >= Node::MAX >> 7 {
return Err(ObjectIdentifierError::IllegalChildNodeValue); return Err(ObjectIdentifierError::IllegalChildNodeValue);
@ -149,9 +150,11 @@ impl ObjectIdentifier {
Ok(Self { nodes }) Ok(Self { nodes })
} }
pub fn build<B: AsRef<[Node]>>(root: ObjectIdentifierRoot, first: u8, children: B) pub fn build<B: AsRef<[Node]>>(
-> Result<Self, ObjectIdentifierError> root: ObjectIdentifierRoot,
{ first: u8,
children: B,
) -> Result<Self, ObjectIdentifierError> {
if first > 40 { if first > 40 {
return Err(ObjectIdentifierError::IllegalFirstChildNode); return Err(ObjectIdentifierError::IllegalFirstChildNode);
} }
@ -163,7 +166,9 @@ impl ObjectIdentifier {
let var: VarNode = child.into(); let var: VarNode = child.into();
vec.extend_from_slice(var.as_bytes()) vec.extend_from_slice(var.as_bytes())
} }
Ok(Self { nodes: vec.into_boxed_slice() }) Ok(Self {
nodes: vec.into_boxed_slice(),
})
} }
#[inline(always)] #[inline(always)]
@ -196,12 +201,14 @@ impl FromStr for ObjectIdentifier {
fn from_str(value: &str) -> Result<Self, Self::Err> { fn from_str(value: &str) -> Result<Self, Self::Err> {
let mut nodes = value.split("."); let mut nodes = value.split(".");
let root = nodes.next() let root = nodes
.next()
.and_then(|n| n.parse::<u8>().ok()) .and_then(|n| n.parse::<u8>().ok())
.and_then(|n| n.try_into().ok()) .and_then(|n| n.try_into().ok())
.ok_or(ObjectIdentifierError::IllegalRootNode)?; .ok_or(ObjectIdentifierError::IllegalRootNode)?;
let first = nodes.next() let first = nodes
.next()
.and_then(|n| parse_string_first_node(n).ok()) .and_then(|n| parse_string_first_node(n).ok())
.ok_or(ObjectIdentifierError::IllegalFirstChildNode)?; .ok_or(ObjectIdentifierError::IllegalFirstChildNode)?;
@ -238,7 +245,7 @@ impl fmt::Debug for ObjectIdentifier {
#[repr(transparent)] #[repr(transparent)]
pub struct ArchivedObjectIdentifier { pub struct ArchivedObjectIdentifier {
archived: ArchivedVec<u8> archived: ArchivedVec<u8>,
} }
impl Deref for ArchivedObjectIdentifier { impl Deref for ArchivedObjectIdentifier {
@ -250,8 +257,12 @@ impl Deref for ArchivedObjectIdentifier {
impl fmt::Debug for ArchivedObjectIdentifier { impl fmt::Debug for ArchivedObjectIdentifier {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "{}", &convert_to_string(self.archived.as_slice()) write!(
.unwrap_or_else(|e| format!("Invalid OID: {:?}", e))) f,
"{}",
&convert_to_string(self.archived.as_slice())
.unwrap_or_else(|e| format!("Invalid OID: {:?}", e))
)
} }
} }
@ -275,7 +286,8 @@ impl Archive for &'static ObjectIdentifier {
} }
impl<S: Serializer + ?Sized> Serialize<S> for ObjectIdentifier impl<S: Serializer + ?Sized> Serialize<S> for ObjectIdentifier
where [u8]: rkyv::SerializeUnsized<S> where
[u8]: rkyv::SerializeUnsized<S>,
{ {
fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> { fn serialize(&self, serializer: &mut S) -> Result<Self::Resolver, S::Error> {
ArchivedVec::serialize_from_slice(self.nodes.as_ref(), serializer) ArchivedVec::serialize_from_slice(self.nodes.as_ref(), serializer)
@ -340,8 +352,7 @@ fn convert_to_string(nodes: &[u8]) -> Result<String, ObjectIdentifierError> {
impl Into<String> for &ObjectIdentifier { impl Into<String> for &ObjectIdentifier {
fn into(self) -> String { fn into(self) -> String {
convert_to_string(&self.nodes) convert_to_string(&self.nodes).expect("Valid OID object couldn't be serialized.")
.expect("Valid OID object couldn't be serialized.")
} }
} }
@ -468,16 +479,13 @@ mod serde_support {
} }
} }
impl ser::Serialize for ArchivedObjectIdentifier { impl ser::Serialize for ArchivedObjectIdentifier {
fn serialize<S>( fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
&self,
serializer: S,
) -> Result<S::Ok, S::Error>
where where
S: ser::Serializer, S: ser::Serializer,
{ {
if serializer.is_human_readable() { if serializer.is_human_readable() {
let encoded: String = convert_to_string(self.deref()) let encoded: String =
.expect("Failed to convert valid OID to String"); convert_to_string(self.deref()).expect("Failed to convert valid OID to String");
serializer.serialize_str(&encoded) serializer.serialize_str(&encoded)
} else { } else {
serializer.serialize_bytes(self.deref()) serializer.serialize_bytes(self.deref())
@ -498,30 +506,13 @@ pub(crate) mod tests {
children.push(rand::random()); children.push(rand::random());
} }
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 25, children) ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 25, children).unwrap()
.unwrap()
}
#[test]
fn bincode_serde_roundtrip() {
let expected = ObjectIdentifier::build(
ObjectIdentifierRoot::ItuT,
0x01,
vec![1, 2, 3, 5, 8, 13, 21],
).unwrap();
let buffer: Vec<u8> = bincode::serialize(&expected).unwrap();
let actual = bincode::deserialize(&buffer).unwrap();
assert_eq!(expected, actual);
} }
#[test] #[test]
fn encode_binary_root_node_0() { fn encode_binary_root_node_0() {
let expected: Vec<u8> = vec![0]; let expected: Vec<u8> = vec![0];
let oid = ObjectIdentifier::build( let oid = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]).unwrap();
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
).unwrap();
let actual: Vec<u8> = oid.into(); let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -529,11 +520,7 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_binary_root_node_1() { fn encode_binary_root_node_1() {
let expected: Vec<u8> = vec![40]; let expected: Vec<u8> = vec![40];
let oid = ObjectIdentifier::build( let oid = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]).unwrap();
ObjectIdentifierRoot::Iso,
0x00,
vec![],
).unwrap();
let actual: Vec<u8> = oid.into(); let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -541,11 +528,8 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_binary_root_node_2() { fn encode_binary_root_node_2() {
let expected: Vec<u8> = vec![80]; let expected: Vec<u8> = vec![80];
let oid = ObjectIdentifier::build( let oid =
ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]).unwrap();
0x00,
vec![],
).unwrap();
let actual: Vec<u8> = oid.into(); let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -557,7 +541,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::ItuT, ObjectIdentifierRoot::ItuT,
0x01, 0x01,
vec![1, 2, 3, 5, 8, 13, 21], vec![1, 2, 3, 5, 8, 13, 21],
).unwrap(); )
.unwrap();
let actual: Vec<u8> = oid.into(); let actual: Vec<u8> = oid.into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -572,7 +557,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifierRoot::JointIsoItuT,
39, 39,
vec![42, 2501, 65535, 2147483647, 1235, 2352], vec![42, 2501, 65535, 2147483647, 1235, 2352],
).unwrap(); )
.unwrap();
let actual: Vec<u8> = (oid).into(); let actual: Vec<u8> = (oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -580,11 +566,7 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_string_root_node_0() { fn encode_string_root_node_0() {
let expected = "0.0"; let expected = "0.0";
let oid = ObjectIdentifier::build( let oid = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]).unwrap();
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
).unwrap();
let actual: String = (oid).into(); let actual: String = (oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -592,11 +574,7 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_string_root_node_1() { fn encode_string_root_node_1() {
let expected = "1.0"; let expected = "1.0";
let oid = ObjectIdentifier::build( let oid = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]).unwrap();
ObjectIdentifierRoot::Iso,
0x00,
vec![],
).unwrap();
let actual: String = (&oid).into(); let actual: String = (&oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -604,11 +582,8 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_string_root_node_2() { fn encode_string_root_node_2() {
let expected = "2.0"; let expected = "2.0";
let oid = ObjectIdentifier::build( let oid =
ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]).unwrap();
0x00,
vec![],
).unwrap();
let actual: String = (&oid).into(); let actual: String = (&oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -620,7 +595,8 @@ pub(crate) mod tests {
ObjectIdentifierRoot::ItuT, ObjectIdentifierRoot::ItuT,
0x01, 0x01,
vec![1, 2, 3, 5, 8, 13, 21], vec![1, 2, 3, 5, 8, 13, 21],
).unwrap(); )
.unwrap();
let actual: String = (&oid).into(); let actual: String = (&oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -632,40 +608,29 @@ pub(crate) mod tests {
ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifierRoot::JointIsoItuT,
39, 39,
vec![42, 2501, 65535, 2147483647, 1235, 2352], vec![42, 2501, 65535, 2147483647, 1235, 2352],
).unwrap(); )
.unwrap();
let actual: String = (&oid).into(); let actual: String = (&oid).into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_binary_root_node_0() { fn parse_binary_root_node_0() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]);
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
);
let actual = vec![0x00].try_into(); let actual = vec![0x00].try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_binary_root_node_1() { fn parse_binary_root_node_1() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]);
ObjectIdentifierRoot::Iso,
0x00,
vec![],
);
let actual = vec![40].try_into(); let actual = vec![40].try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_binary_root_node_2() { fn parse_binary_root_node_2() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]);
ObjectIdentifierRoot::JointIsoItuT,
0x00,
vec![],
);
let actual = vec![80].try_into(); let actual = vec![80].try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -698,33 +663,21 @@ pub(crate) mod tests {
#[test] #[test]
fn parse_string_root_node_0() { fn parse_string_root_node_0() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::ItuT, 0x00, vec![]);
ObjectIdentifierRoot::ItuT,
0x00,
vec![],
);
let actual = "0.0".try_into(); let actual = "0.0".try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_string_root_node_1() { fn parse_string_root_node_1() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 0x00, vec![]);
ObjectIdentifierRoot::Iso,
0x00,
vec![],
);
let actual = "1.0".try_into(); let actual = "1.0".try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn parse_string_root_node_2() { fn parse_string_root_node_2() {
let expected = ObjectIdentifier::build( let expected = ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, 0x00, vec![]);
ObjectIdentifierRoot::JointIsoItuT,
0x00,
vec![],
);
let actual = "2.0".try_into(); let actual = "2.0".try_into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -852,10 +805,11 @@ pub(crate) mod tests {
#[test] #[test]
fn parse_string_large_children_ok() { fn parse_string_large_children_ok() {
let expected = let expected = ObjectIdentifier::build(
ObjectIdentifier::build(ObjectIdentifierRoot::JointIsoItuT, ObjectIdentifierRoot::JointIsoItuT,
25, 25,
vec![190754093376743485973207716749546715206, vec![
190754093376743485973207716749546715206,
255822649272987943607843257596365752308, 255822649272987943607843257596365752308,
15843412533224453995377625663329542022, 15843412533224453995377625663329542022,
6457999595881951503805148772927347934, 6457999595881951503805148772927347934,
@ -863,7 +817,9 @@ pub(crate) mod tests {
195548685662657784196186957311035194990, 195548685662657784196186957311035194990,
233020488258340943072303499291936117654, 233020488258340943072303499291936117654,
193307160423854019916786016773068715190, 193307160423854019916786016773068715190,
]).unwrap(); ],
)
.unwrap();
let actual = "2.25.190754093376743485973207716749546715206.\ let actual = "2.25.190754093376743485973207716749546715206.\
255822649272987943607843257596365752308.\ 255822649272987943607843257596365752308.\
15843412533224453995377625663329542022.\ 15843412533224453995377625663329542022.\
@ -871,18 +827,17 @@ pub(crate) mod tests {
19545192863105095042881850060069531734.\ 19545192863105095042881850060069531734.\
195548685662657784196186957311035194990.\ 195548685662657784196186957311035194990.\
233020488258340943072303499291936117654.\ 233020488258340943072303499291936117654.\
193307160423854019916786016773068715190".try_into().unwrap(); 193307160423854019916786016773068715190"
.try_into()
.unwrap();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
#[test] #[test]
fn encode_to_string() { fn encode_to_string() {
let expected = String::from("1.2.3.4"); let expected = String::from("1.2.3.4");
let actual: String = ObjectIdentifier::build( let actual: String = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 2, vec![3, 4])
ObjectIdentifierRoot::Iso, .unwrap()
2,
vec![3, 4],
).unwrap()
.into(); .into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }
@ -890,11 +845,8 @@ pub(crate) mod tests {
#[test] #[test]
fn encode_to_bytes() { fn encode_to_bytes() {
let expected = vec![0x2A, 0x03, 0x04]; let expected = vec![0x2A, 0x03, 0x04];
let actual: Vec<u8> = ObjectIdentifier::build( let actual: Vec<u8> = ObjectIdentifier::build(ObjectIdentifierRoot::Iso, 2, vec![3, 4])
ObjectIdentifierRoot::Iso, .unwrap()
2,
vec![3, 4],
).unwrap()
.into(); .into();
assert_eq!(expected, actual); assert_eq!(expected, actual);
} }

View File

@ -1,11 +1,10 @@
use uuid::Uuid;
use api::general_capnp::u_u_i_d::{Builder, Reader}; use api::general_capnp::u_u_i_d::{Builder, Reader};
use uuid::Uuid;
pub fn uuid_to_api(uuid: Uuid, mut builder: Builder) { pub fn uuid_to_api(uuid: Uuid, mut builder: Builder) {
let [a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p] let [a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p] = uuid.as_u128().to_ne_bytes();
= uuid.as_u128().to_ne_bytes(); let lower = u64::from_ne_bytes([a, b, c, d, e, f, g, h]);
let lower = u64::from_ne_bytes([a,b,c,d,e,f,g,h]); let upper = u64::from_ne_bytes([i, j, k, l, m, n, o, p]);
let upper = u64::from_ne_bytes([i,j,k,l,m,n,o,p]);
builder.set_uuid0(lower); builder.set_uuid0(lower);
builder.set_uuid1(upper); builder.set_uuid1(upper);
} }
@ -13,8 +12,8 @@ pub fn uuid_to_api(uuid: Uuid, mut builder: Builder) {
pub fn api_to_uuid(reader: Reader) -> Uuid { pub fn api_to_uuid(reader: Reader) -> Uuid {
let lower: u64 = reader.reborrow().get_uuid0(); let lower: u64 = reader.reborrow().get_uuid0();
let upper: u64 = reader.get_uuid1(); let upper: u64 = reader.get_uuid1();
let [a,b,c,d,e,f,g,h] = lower.to_ne_bytes(); let [a, b, c, d, e, f, g, h] = lower.to_ne_bytes();
let [i,j,k,l,m,n,o,p] = upper.to_ne_bytes(); let [i, j, k, l, m, n, o, p] = upper.to_ne_bytes();
let num = u128::from_ne_bytes([a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p]); let num = u128::from_ne_bytes([a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p]);
Uuid::from_u128(num) Uuid::from_u128(num)
} }

View File

@ -27,7 +27,6 @@ impl<const N: usize> VarUInt<N> {
pub const fn into_bytes(self) -> [u8; N] { pub const fn into_bytes(self) -> [u8; N] {
self.bytes self.bytes
} }
} }
impl<const N: usize> Default for VarUInt<N> { impl<const N: usize> Default for VarUInt<N> {
@ -52,7 +51,7 @@ macro_rules! convert_from {
let bytes = this.as_mut_bytes(); let bytes = this.as_mut_bytes();
let mut more = 0u8; let mut more = 0u8;
let mut idx: usize = bytes.len()-1; let mut idx: usize = bytes.len() - 1;
while num > 0x7f { while num > 0x7f {
bytes[idx] = ((num & 0x7f) as u8 | more); bytes[idx] = ((num & 0x7f) as u8 | more);
@ -65,7 +64,7 @@ macro_rules! convert_from {
this.offset = idx; this.offset = idx;
this this
} }
} };
} }
macro_rules! convert_into { macro_rules! convert_into {
@ -84,7 +83,7 @@ macro_rules! convert_into {
let mut shift = 0; let mut shift = 0;
for neg in 1..=len { for neg in 1..=len {
let idx = len-neg; let idx = len - neg;
let val = (bytes[idx] & 0x7f) as $x; let val = (bytes[idx] & 0x7f) as $x;
let shifted = val << shift; let shifted = val << shift;
out |= shifted; out |= shifted;
@ -93,7 +92,7 @@ macro_rules! convert_into {
out out
} }
} };
} }
macro_rules! impl_convert_from_to { macro_rules! impl_convert_from_to {
@ -105,7 +104,7 @@ macro_rules! impl_convert_from_to {
impl Into<$num> for VarUInt<$req> { impl Into<$num> for VarUInt<$req> {
convert_into! { $num } convert_into! { $num }
} }
} };
} }
impl_convert_from_to!(u8, 2, VarU8); impl_convert_from_to!(u8, 2, VarU8);
@ -123,8 +122,9 @@ type VarUsize = VarU32;
type VarUsize = VarU16; type VarUsize = VarU16;
impl<T, const N: usize> From<&T> for VarUInt<N> impl<T, const N: usize> From<&T> for VarUInt<N>
where T: Copy, where
VarUInt<N>: From<T> T: Copy,
VarUInt<N>: From<T>,
{ {
fn from(t: &T) -> Self { fn from(t: &T) -> Self {
(*t).into() (*t).into()

View File

@ -1,12 +1,9 @@
use clap::{Arg, Command}; use clap::{Arg, Command};
use diflouroborane::{config, Diflouroborane}; use diflouroborane::{config, Diflouroborane};
use std::str::FromStr; use std::str::FromStr;
use std::{env, io, io::Write, path::PathBuf}; use std::{env, io, io::Write, path::PathBuf};
use nix::NixPath; use nix::NixPath;
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
@ -125,12 +122,19 @@ fn main() -> anyhow::Result<()> {
unimplemented!() unimplemented!()
} else if matches.is_present("load") { } else if matches.is_present("load") {
let bffh = Diflouroborane::new(config)?; let bffh = Diflouroborane::new(config)?;
if bffh.users.load_file(matches.value_of("load").unwrap()).is_ok() { if bffh
.users
.load_file(matches.value_of("load").unwrap())
.is_ok()
{
tracing::info!("loaded users from {}", matches.value_of("load").unwrap()); tracing::info!("loaded users from {}", matches.value_of("load").unwrap());
} else { } else {
tracing::error!("failed to load users from {}", matches.value_of("load").unwrap()); tracing::error!(
"failed to load users from {}",
matches.value_of("load").unwrap()
);
} }
return Ok(()) return Ok(());
} else { } else {
let keylog = matches.value_of("keylog"); let keylog = matches.value_of("keylog");
// When passed an empty string (i.e no value) take the value from the env // When passed an empty string (i.e no value) take the value from the env

View File

@ -4,48 +4,57 @@ fn main() {
println!(">>> Building version number..."); println!(">>> Building version number...");
let rustc = std::env::var("RUSTC").unwrap(); let rustc = std::env::var("RUSTC").unwrap();
let out = Command::new(rustc).arg("--version") let out = Command::new(rustc)
.arg("--version")
.output() .output()
.expect("failed to run `rustc --version`"); .expect("failed to run `rustc --version`");
let rustc_version = String::from_utf8(out.stdout) let rustc_version =
.expect("rustc --version returned invalid UTF-8"); String::from_utf8(out.stdout).expect("rustc --version returned invalid UTF-8");
let rustc_version = rustc_version.trim(); let rustc_version = rustc_version.trim();
println!("cargo:rustc-env=CARGO_RUSTC_VERSION={}", rustc_version); println!("cargo:rustc-env=CARGO_RUSTC_VERSION={}", rustc_version);
println!("cargo:rerun-if-env-changed=BFFHD_BUILD_TAGGED_RELEASE"); println!("cargo:rerun-if-env-changed=BFFHD_BUILD_TAGGED_RELEASE");
let tagged_release = option_env!("BFFHD_BUILD_TAGGED_RELEASE") == Some("1"); let tagged_release = option_env!("BFFHD_BUILD_TAGGED_RELEASE") == Some("1");
let version_string = if tagged_release { let version_string = if tagged_release {
format!("{version} [{rustc}]", format!(
"{version} [{rustc}]",
version = env!("CARGO_PKG_VERSION"), version = env!("CARGO_PKG_VERSION"),
rustc = rustc_version) rustc = rustc_version
)
} else { } else {
// Build version number using the current git commit id // Build version number using the current git commit id
let out = Command::new("git").arg("rev-list") let out = Command::new("git")
.arg("rev-list")
.args(["HEAD", "-1"]) .args(["HEAD", "-1"])
.output() .output()
.expect("failed to run `git rev-list HEAD -1`"); .expect("failed to run `git rev-list HEAD -1`");
let owned_gitrev = String::from_utf8(out.stdout) let owned_gitrev =
.expect("git rev-list output was not valid UTF8"); String::from_utf8(out.stdout).expect("git rev-list output was not valid UTF8");
let gitrev = owned_gitrev.trim(); let gitrev = owned_gitrev.trim();
let abbrev = match gitrev.len(){ let abbrev = match gitrev.len() {
0 => "unknown", 0 => "unknown",
_ => &gitrev[0..9], _ => &gitrev[0..9],
}; };
let out = Command::new("git").arg("log") let out = Command::new("git")
.arg("log")
.args(["-1", "--format=%as"]) .args(["-1", "--format=%as"])
.output() .output()
.expect("failed to run `git log -1 --format=\"format:%as\"`"); .expect("failed to run `git log -1 --format=\"format:%as\"`");
let commit_date = String::from_utf8(out.stdout) let commit_date = String::from_utf8(out.stdout).expect("git log output was not valid UTF8");
.expect("git log output was not valid UTF8");
let commit_date = commit_date.trim(); let commit_date = commit_date.trim();
format!("{version} ({gitrev} {date}) [{rustc}]", format!(
version=env!("CARGO_PKG_VERSION"), "{version} ({gitrev} {date}) [{rustc}]",
gitrev=abbrev, version = env!("CARGO_PKG_VERSION"),
date=commit_date, gitrev = abbrev,
rustc=rustc_version) date = commit_date,
rustc = rustc_version
)
}; };
println!("cargo:rustc-env=BFFHD_VERSION_STRING={}", version_string); println!("cargo:rustc-env=BFFHD_VERSION_STRING={}", version_string);
println!("cargo:rustc-env=BFFHD_RELEASE_STRING=\"BFFH {}\"", version_string); println!(
"cargo:rustc-env=BFFHD_RELEASE_STRING=\"BFFH {}\"",
version_string
);
} }

View File

@ -1,12 +0,0 @@
[package]
name = "dummy"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
crate-type = ["cdylib"]
[dependencies]
sdk = { path = "../sdk" }

View File

@ -1,29 +0,0 @@
use sdk::BoxFuture;
use sdk::initiators::{Initiator, InitiatorError, ResourceID, UpdateSink};
#[sdk::module]
struct Dummy {
a: u32,
b: u32,
c: u32,
d: u32,
}
impl Initiator for Dummy {
fn start_for(&mut self, machine: ResourceID) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> {
todo!()
}
fn run(&mut self, request: &mut UpdateSink) -> BoxFuture<'static, Result<(), Box<dyn InitiatorError>>> {
todo!()
}
}
#[cfg(test)]
mod tests {
#[test]
fn it_works() {
let result = 2 + 2;
assert_eq!(result, 4);
}
}

View File

@ -1,10 +1,10 @@
use proc_macro::TokenStream; use proc_macro::TokenStream;
use std::sync::Mutex;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{braced, parse_macro_input, Field, Ident, Token, Visibility, Type}; use std::sync::Mutex;
use syn::parse::{Parse, ParseStream}; use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated; use syn::punctuated::Punctuated;
use syn::token::Brace; use syn::token::Brace;
use syn::{braced, parse_macro_input, Field, Ident, Token, Type, Visibility};
mod keywords { mod keywords {
syn::custom_keyword!(initiator); syn::custom_keyword!(initiator);
@ -32,8 +32,10 @@ impl Parse for ModuleAttrs {
} else if lookahead.peek(keywords::sensor) { } else if lookahead.peek(keywords::sensor) {
Ok(ModuleAttrs::Sensor) Ok(ModuleAttrs::Sensor)
} else { } else {
Err(input.error("Module type must be empty or one of \"initiator\", \"actor\", or \ Err(input.error(
\"sensor\"")) "Module type must be empty or one of \"initiator\", \"actor\", or \
\"sensor\"",
))
} }
} }
} }

View File

@ -1,10 +1 @@
pub use diflouroborane::{
initiators::{
UpdateSink,
UpdateError,
Initiator,
InitiatorError,
},
resource::claim::ResourceID,
};

View File

@ -1,5 +1,4 @@
#[forbid(private_in_public)] #[forbid(private_in_public)]
pub use sdk_proc::module; pub use sdk_proc::module;
pub use futures_util::future::BoxFuture; pub use futures_util::future::BoxFuture;

View File

@ -1,19 +1,19 @@
use executor::prelude::*;
use criterion::{black_box, criterion_group, criterion_main, Criterion}; use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::prelude::*;
fn increment(b: &mut Criterion) { fn increment(b: &mut Criterion) {
let mut sum = 0; let mut sum = 0;
let executor = Executor::new(); let executor = Executor::new();
b.bench_function("Executor::run", |b| b.iter(|| { b.bench_function("Executor::run", |b| {
executor.run( b.iter(|| {
async { executor.run(async {
(0..10_000_000).for_each(|_| { (0..10_000_000).for_each(|_| {
sum += 1; sum += 1;
}); });
}, });
); })
})); });
black_box(sum); black_box(sum);
} }

View File

@ -1,8 +1,8 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::load_balancer; use executor::load_balancer;
use executor::prelude::*; use executor::prelude::*;
use futures_timer::Delay; use futures_timer::Delay;
use std::time::Duration; use std::time::Duration;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
#[cfg(feature = "tokio-runtime")] #[cfg(feature = "tokio-runtime")]
mod benches { mod benches {
@ -27,7 +27,6 @@ mod benches {
pub fn spawn_single(b: &mut Criterion) { pub fn spawn_single(b: &mut Criterion) {
_spawn_single(b); _spawn_single(b);
} }
} }
criterion_group!(spawn, benches::spawn_lot, benches::spawn_single); criterion_group!(spawn, benches::spawn_lot, benches::spawn_single);
@ -36,29 +35,29 @@ criterion_main!(spawn);
// Benchmark for a 10K burst task spawn // Benchmark for a 10K burst task spawn
fn _spawn_lot(b: &mut Criterion) { fn _spawn_lot(b: &mut Criterion) {
let executor = Executor::new(); let executor = Executor::new();
b.bench_function("spawn_lot", |b| b.iter(|| { b.bench_function("spawn_lot", |b| {
b.iter(|| {
let _ = (0..10_000) let _ = (0..10_000)
.map(|_| { .map(|_| {
executor.spawn( executor.spawn(async {
async {
let duration = Duration::from_millis(1); let duration = Duration::from_millis(1);
Delay::new(duration).await; Delay::new(duration).await;
}, })
)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
})); })
});
} }
// Benchmark for a single task spawn // Benchmark for a single task spawn
fn _spawn_single(b: &mut Criterion) { fn _spawn_single(b: &mut Criterion) {
let executor = Executor::new(); let executor = Executor::new();
b.bench_function("spawn single", |b| b.iter(|| { b.bench_function("spawn single", |b| {
executor.spawn( b.iter(|| {
async { executor.spawn(async {
let duration = Duration::from_millis(1); let duration = Duration::from_millis(1);
Delay::new(duration).await; Delay::new(duration).await;
}, });
); })
})); });
} }

View File

@ -1,7 +1,7 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use executor::load_balancer::{core_count, get_cores, stats, SmpStats}; use executor::load_balancer::{core_count, get_cores, stats, SmpStats};
use executor::placement; use executor::placement;
use std::thread; use std::thread;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
fn stress_stats<S: SmpStats + Sync + Send>(stats: &'static S) { fn stress_stats<S: SmpStats + Sync + Send>(stats: &'static S) {
let mut handles = Vec::with_capacity(*core_count()); let mut handles = Vec::with_capacity(*core_count());
@ -27,9 +27,11 @@ fn stress_stats<S: SmpStats + Sync + Send>(stats: &'static S) {
// 158,278 ns/iter (+/- 117,103) // 158,278 ns/iter (+/- 117,103)
fn lockless_stats_bench(b: &mut Criterion) { fn lockless_stats_bench(b: &mut Criterion) {
b.bench_function("stress_stats", |b| b.iter(|| { b.bench_function("stress_stats", |b| {
b.iter(|| {
stress_stats(stats()); stress_stats(stats());
})); })
});
} }
fn lockless_stats_bad_load(b: &mut Criterion) { fn lockless_stats_bad_load(b: &mut Criterion) {
@ -45,9 +47,11 @@ fn lockless_stats_bad_load(b: &mut Criterion) {
} }
} }
b.bench_function("get_sorted_load", |b| b.iter(|| { b.bench_function("get_sorted_load", |b| {
b.iter(|| {
let _sorted_load = stats.get_sorted_load(); let _sorted_load = stats.get_sorted_load();
})); })
});
} }
fn lockless_stats_good_load(b: &mut Criterion) { fn lockless_stats_good_load(b: &mut Criterion) {
@ -59,11 +63,17 @@ fn lockless_stats_good_load(b: &mut Criterion) {
stats.store_load(i, i); stats.store_load(i, i);
} }
b.bench_function("get_sorted_load", |b| b.iter(|| { b.bench_function("get_sorted_load", |b| {
b.iter(|| {
let _sorted_load = stats.get_sorted_load(); let _sorted_load = stats.get_sorted_load();
})); })
});
} }
criterion_group!(stats_bench, lockless_stats_bench, lockless_stats_bad_load, criterion_group!(
lockless_stats_good_load); stats_bench,
lockless_stats_bench,
lockless_stats_bad_load,
lockless_stats_good_load
);
criterion_main!(stats_bench); criterion_main!(stats_bench);

View File

@ -1,12 +1,12 @@
use executor::pool;
use executor::prelude::*;
use futures_util::{stream::FuturesUnordered, Stream};
use futures_util::{FutureExt, StreamExt};
use lightproc::prelude::RecoverableHandle;
use std::io::Write; use std::io::Write;
use std::panic::resume_unwind; use std::panic::resume_unwind;
use std::rc::Rc; use std::rc::Rc;
use std::time::Duration; use std::time::Duration;
use futures_util::{stream::FuturesUnordered, Stream};
use futures_util::{FutureExt, StreamExt};
use executor::pool;
use executor::prelude::*;
use lightproc::prelude::RecoverableHandle;
fn main() { fn main() {
tracing_subscriber::fmt() tracing_subscriber::fmt()
@ -24,9 +24,9 @@ fn main() {
let executor = Executor::new(); let executor = Executor::new();
let mut handles: FuturesUnordered<RecoverableHandle<usize>> = (0..2000).map(|n| { let mut handles: FuturesUnordered<RecoverableHandle<usize>> = (0..2000)
executor.spawn( .map(|n| {
async move { executor.spawn(async move {
let m: u64 = rand::random::<u64>() % 200; let m: u64 = rand::random::<u64>() % 200;
tracing::debug!("Will sleep {} * 1 ms", m); tracing::debug!("Will sleep {} * 1 ms", m);
// simulate some really heavy load. // simulate some really heavy load.
@ -34,9 +34,9 @@ fn main() {
async_std::task::sleep(Duration::from_millis(1)).await; async_std::task::sleep(Duration::from_millis(1)).await;
} }
return n; return n;
}, })
) })
}).collect(); .collect();
//let handle = handles.fuse().all(|opt| async move { opt.is_some() }); //let handle = handles.fuse().all(|opt| async move { opt.is_some() });
/* Futures passed to `spawn` need to be `Send` so this won't work: /* Futures passed to `spawn` need to be `Send` so this won't work:
@ -58,12 +58,12 @@ fn main() {
// However, you can't pass it a future outright but have to hand it a generator creating the // However, you can't pass it a future outright but have to hand it a generator creating the
// future on the correct thread. // future on the correct thread.
let fut = async { let fut = async {
let local_futs: FuturesUnordered<_> = (0..200).map(|ref n| { let local_futs: FuturesUnordered<_> = (0..200)
.map(|ref n| {
let n = *n; let n = *n;
let exe = executor.clone(); let exe = executor.clone();
async move { async move {
exe.spawn( exe.spawn(async {
async {
let tid = std::thread::current().id(); let tid = std::thread::current().id();
tracing::info!("spawn_local({}) is on thread {:?}", n, tid); tracing::info!("spawn_local({}) is on thread {:?}", n, tid);
exe.spawn_local(async move { exe.spawn_local(async move {
@ -86,10 +86,11 @@ fn main() {
*rc *rc
}) })
})
.await
} }
).await })
} .collect();
}).collect();
local_futs local_futs
}; };
@ -108,12 +109,10 @@ fn main() {
async_std::task::sleep(Duration::from_secs(20)).await; async_std::task::sleep(Duration::from_secs(20)).await;
tracing::info!("This is taking too long."); tracing::info!("This is taking too long.");
}; };
executor.run( executor.run(async {
async {
let res = futures_util::select! { let res = futures_util::select! {
_ = a.fuse() => {}, _ = a.fuse() => {},
_ = b.fuse() => {}, _ = b.fuse() => {},
}; };
}, });
);
} }

View File

@ -28,10 +28,10 @@
#![forbid(unused_import_braces)] #![forbid(unused_import_braces)]
pub mod load_balancer; pub mod load_balancer;
pub mod manage;
pub mod placement; pub mod placement;
pub mod pool; pub mod pool;
pub mod run; pub mod run;
pub mod manage;
mod thread_manager; mod thread_manager;
mod worker; mod worker;

View File

@ -1,6 +1,2 @@
/// View and Manage the current processes of this executor /// View and Manage the current processes of this executor
pub struct Manager { pub struct Manager {}
}

View File

@ -7,19 +7,19 @@
//! [`spawn`]: crate::pool::spawn //! [`spawn`]: crate::pool::spawn
//! [`Worker`]: crate::run_queue::Worker //! [`Worker`]: crate::run_queue::Worker
use std::cell::Cell; use crate::run::block;
use crate::thread_manager::{ThreadManager, DynamicRunner}; use crate::thread_manager::{DynamicRunner, ThreadManager};
use crate::worker::{Sleeper, WorkerThread};
use crossbeam_deque::{Injector, Stealer};
use lightproc::lightproc::LightProc; use lightproc::lightproc::LightProc;
use lightproc::recoverable_handle::RecoverableHandle; use lightproc::recoverable_handle::RecoverableHandle;
use std::cell::Cell;
use std::future::Future; use std::future::Future;
use std::iter::Iterator; use std::iter::Iterator;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use crossbeam_deque::{Injector, Stealer};
use crate::run::block;
use crate::worker::{Sleeper, WorkerThread};
#[derive(Debug)] #[derive(Debug)]
struct Spooler<'a> { struct Spooler<'a> {
@ -31,10 +31,13 @@ struct Spooler<'a> {
impl Spooler<'_> { impl Spooler<'_> {
pub fn new() -> Self { pub fn new() -> Self {
let spool = Arc::new(Injector::new()); let spool = Arc::new(Injector::new());
let threads = Box::leak(Box::new( let threads = Box::leak(Box::new(ThreadManager::new(2, AsyncRunner, spool.clone())));
ThreadManager::new(2, AsyncRunner, spool.clone())));
threads.initialize(); threads.initialize();
Self { spool, threads, _marker: PhantomData } Self {
spool,
threads,
_marker: PhantomData,
}
} }
} }
@ -53,9 +56,7 @@ impl<'a, 'executor: 'a> Executor<'executor> {
fn schedule(&self) -> impl Fn(LightProc) + 'a { fn schedule(&self) -> impl Fn(LightProc) + 'a {
let task_queue = self.spooler.spool.clone(); let task_queue = self.spooler.spool.clone();
move |lightproc: LightProc| { move |lightproc: LightProc| task_queue.push(lightproc)
task_queue.push(lightproc)
}
} }
/// ///
@ -98,8 +99,7 @@ impl<'a, 'executor: 'a> Executor<'executor> {
F: Future<Output = R> + Send + 'a, F: Future<Output = R> + Send + 'a,
R: Send + 'a, R: Send + 'a,
{ {
let (task, handle) = let (task, handle) = LightProc::recoverable(future, self.schedule());
LightProc::recoverable(future, self.schedule());
task.schedule(); task.schedule();
handle handle
} }
@ -109,8 +109,7 @@ impl<'a, 'executor: 'a> Executor<'executor> {
F: Future<Output = R> + 'a, F: Future<Output = R> + 'a,
R: Send + 'a, R: Send + 'a,
{ {
let (task, handle) = let (task, handle) = LightProc::recoverable(future, schedule_local());
LightProc::recoverable(future, schedule_local());
task.schedule(); task.schedule();
handle handle
} }
@ -174,17 +173,20 @@ impl DynamicRunner for AsyncRunner {
sleeper sleeper
} }
fn run_static<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>, park_timeout: Duration) -> ! { fn run_static<'b>(
fences: impl Iterator<Item = &'b Stealer<LightProc>>,
park_timeout: Duration,
) -> ! {
let worker = get_worker(); let worker = get_worker();
worker.run_timeout(fences, park_timeout) worker.run_timeout(fences, park_timeout)
} }
fn run_dynamic<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) -> ! { fn run_dynamic<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) -> ! {
let worker = get_worker(); let worker = get_worker();
worker.run(fences) worker.run(fences)
} }
fn run_standalone<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) { fn run_standalone<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) {
let worker = get_worker(); let worker = get_worker();
worker.run_once(fences) worker.run_once(fences)
} }
@ -196,10 +198,9 @@ thread_local! {
fn get_worker() -> &'static WorkerThread<'static, LightProc> { fn get_worker() -> &'static WorkerThread<'static, LightProc> {
WORKER.with(|cell| { WORKER.with(|cell| {
let worker = unsafe { let worker = unsafe { &*cell.as_ptr() as &'static Option<WorkerThread<_>> };
&*cell.as_ptr() as &'static Option<WorkerThread<_>> worker
}; .as_ref()
worker.as_ref()
.expect("AsyncRunner running outside Executor context") .expect("AsyncRunner running outside Executor context")
}) })
} }

View File

@ -45,13 +45,18 @@
//! Throughput hogs determined by a combination of job in / job out frequency and current scheduler task assignment frequency. //! Throughput hogs determined by a combination of job in / job out frequency and current scheduler task assignment frequency.
//! Threshold of EMA difference is eluded by machine epsilon for floating point arithmetic errors. //! Threshold of EMA difference is eluded by machine epsilon for floating point arithmetic errors.
use crate::worker::Sleeper;
use crate::{load_balancer, placement}; use crate::{load_balancer, placement};
use core::fmt; use core::fmt;
use crossbeam_channel::bounded;
use crossbeam_deque::{Injector, Stealer};
use crossbeam_queue::ArrayQueue; use crossbeam_queue::ArrayQueue;
use fmt::{Debug, Formatter}; use fmt::{Debug, Formatter};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use lightproc::lightproc::LightProc;
use placement::CoreId; use placement::CoreId;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use std::{ use std::{
sync::{ sync::{
@ -60,12 +65,7 @@ use std::{
}, },
thread, thread,
}; };
use std::sync::{Arc, RwLock};
use crossbeam_channel::bounded;
use crossbeam_deque::{Injector, Stealer};
use tracing::{debug, trace}; use tracing::{debug, trace};
use lightproc::lightproc::LightProc;
use crate::worker::Sleeper;
/// The default thread park timeout before checking for new tasks. /// The default thread park timeout before checking for new tasks.
const THREAD_PARK_TIMEOUT: Duration = Duration::from_millis(1); const THREAD_PARK_TIMEOUT: Duration = Duration::from_millis(1);
@ -113,10 +113,12 @@ lazy_static! {
pub trait DynamicRunner { pub trait DynamicRunner {
fn setup(task_queue: Arc<Injector<LightProc>>) -> Sleeper<LightProc>; fn setup(task_queue: Arc<Injector<LightProc>>) -> Sleeper<LightProc>;
fn run_static<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>, fn run_static<'b>(
park_timeout: Duration) -> !; fences: impl Iterator<Item = &'b Stealer<LightProc>>,
fn run_dynamic<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>) -> !; park_timeout: Duration,
fn run_standalone<'b>(fences: impl Iterator<Item=&'b Stealer<LightProc>>); ) -> !;
fn run_dynamic<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>) -> !;
fn run_standalone<'b>(fences: impl Iterator<Item = &'b Stealer<LightProc>>);
} }
/// The `ThreadManager` is creates and destroys worker threads depending on demand according to /// The `ThreadManager` is creates and destroys worker threads depending on demand according to
@ -183,11 +185,14 @@ impl<Runner: Debug> Debug for ThreadManager<Runner> {
} }
fmt.debug_struct("DynamicPoolManager") fmt.debug_struct("DynamicPoolManager")
.field("thread pool", &ThreadCount( .field(
"thread pool",
&ThreadCount(
&self.static_threads, &self.static_threads,
&self.dynamic_threads, &self.dynamic_threads,
&self.parked_threads.len(), &self.parked_threads.len(),
)) ),
)
.field("runner", &self.runner) .field("runner", &self.runner)
.field("last_frequency", &self.last_frequency) .field("last_frequency", &self.last_frequency)
.finish() .finish()
@ -195,7 +200,11 @@ impl<Runner: Debug> Debug for ThreadManager<Runner> {
} }
impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> { impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
pub fn new(static_threads: usize, runner: Runner, task_queue: Arc<Injector<LightProc>>) -> Self { pub fn new(
static_threads: usize,
runner: Runner,
task_queue: Arc<Injector<LightProc>>,
) -> Self {
let dynamic_threads = 1.max(num_cpus::get().checked_sub(static_threads).unwrap_or(0)); let dynamic_threads = 1.max(num_cpus::get().checked_sub(static_threads).unwrap_or(0));
let parked_threads = ArrayQueue::new(1.max(static_threads + dynamic_threads)); let parked_threads = ArrayQueue::new(1.max(static_threads + dynamic_threads));
let fences = Arc::new(RwLock::new(Vec::new())); let fences = Arc::new(RwLock::new(Vec::new()));
@ -252,7 +261,10 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
}); });
// Dynamic thread manager that will allow us to unpark threads when needed // Dynamic thread manager that will allow us to unpark threads when needed
debug!("spooling up {} dynamic worker threads", self.dynamic_threads); debug!(
"spooling up {} dynamic worker threads",
self.dynamic_threads
);
(0..self.dynamic_threads).for_each(|_| { (0..self.dynamic_threads).for_each(|_| {
let tx = tx.clone(); let tx = tx.clone();
let fencelock = fencelock.clone(); let fencelock = fencelock.clone();
@ -302,10 +314,11 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
/// Provision threads takes a number of threads that need to be made available. /// Provision threads takes a number of threads that need to be made available.
/// It will try to unpark threads from the dynamic pool, and spawn more threads if needs be. /// It will try to unpark threads from the dynamic pool, and spawn more threads if needs be.
pub fn provision_threads(&'static self, pub fn provision_threads(
&'static self,
n: usize, n: usize,
fencelock: &Arc<RwLock<Vec<Stealer<LightProc>>>>) fencelock: &Arc<RwLock<Vec<Stealer<LightProc>>>>,
{ ) {
let rem = self.unpark_thread(n); let rem = self.unpark_thread(n);
if rem != 0 { if rem != 0 {
debug!("no more threads to unpark, spawning {} new threads", rem); debug!("no more threads to unpark, spawning {} new threads", rem);
@ -391,7 +404,5 @@ impl<Runner: DynamicRunner + Sync + Send> ThreadManager<Runner> {
/// on the request rate. /// on the request rate.
/// ///
/// It uses frequency based calculation to define work. Utilizing average processing rate. /// It uses frequency based calculation to define work. Utilizing average processing rate.
fn scale_pool(&'static self) { fn scale_pool(&'static self) {}
}
} }

View File

@ -1,10 +1,10 @@
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
use crossbeam_deque::{Injector, Steal, Stealer, Worker}; use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use crossbeam_queue::SegQueue; use crossbeam_queue::SegQueue;
use crossbeam_utils::sync::{Parker, Unparker}; use crossbeam_utils::sync::{Parker, Unparker};
use lightproc::prelude::LightProc; use lightproc::prelude::LightProc;
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;
pub trait Runnable { pub trait Runnable {
fn run(self); fn run(self);
@ -61,8 +61,14 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
let unparker = parker.unparker().clone(); let unparker = parker.unparker().clone();
( (
Self { task_queue, tasks, local_tasks, parker, _marker }, Self {
Sleeper { stealer, unparker } task_queue,
tasks,
local_tasks,
parker,
_marker,
},
Sleeper { stealer, unparker },
) )
} }
@ -71,10 +77,8 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
} }
/// Run this worker thread "forever" (i.e. until the thread panics or is otherwise killed) /// Run this worker thread "forever" (i.e. until the thread panics or is otherwise killed)
pub fn run(&self, fences: impl Iterator<Item=&'a Stealer<T>>) -> ! { pub fn run(&self, fences: impl Iterator<Item = &'a Stealer<T>>) -> ! {
let fences: Vec<Stealer<T>> = fences let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
.map(|stealer| stealer.clone())
.collect();
loop { loop {
self.run_inner(&fences); self.run_inner(&fences);
@ -82,10 +86,12 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
} }
} }
pub fn run_timeout(&self, fences: impl Iterator<Item=&'a Stealer<T>>, timeout: Duration) -> ! { pub fn run_timeout(
let fences: Vec<Stealer<T>> = fences &self,
.map(|stealer| stealer.clone()) fences: impl Iterator<Item = &'a Stealer<T>>,
.collect(); timeout: Duration,
) -> ! {
let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
loop { loop {
self.run_inner(&fences); self.run_inner(&fences);
@ -93,10 +99,8 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
} }
} }
pub fn run_once(&self, fences: impl Iterator<Item=&'a Stealer<T>>) { pub fn run_once(&self, fences: impl Iterator<Item = &'a Stealer<T>>) {
let fences: Vec<Stealer<T>> = fences let fences: Vec<Stealer<T>> = fences.map(|stealer| stealer.clone()).collect();
.map(|stealer| stealer.clone())
.collect();
self.run_inner(fences); self.run_inner(fences);
} }
@ -123,17 +127,19 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
Steal::Success(task) => { Steal::Success(task) => {
task.run(); task.run();
continue 'work; continue 'work;
}, }
// If there is no more work to steal from the global queue, try other // If there is no more work to steal from the global queue, try other
// workers next // workers next
Steal::Empty => break, Steal::Empty => break,
// If a race condition occurred try again with backoff // If a race condition occurred try again with backoff
Steal::Retry => for _ in 0..(1 << i) { Steal::Retry => {
for _ in 0..(1 << i) {
core::hint::spin_loop(); core::hint::spin_loop();
i += 1; i += 1;
}, }
}
} }
} }
@ -145,7 +151,7 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
Steal::Success(task) => { Steal::Success(task) => {
task.run(); task.run();
continue 'work; continue 'work;
}, }
// If no other worker has work to do we're done once again. // If no other worker has work to do we're done once again.
Steal::Empty => break, Steal::Empty => break,
@ -169,6 +175,6 @@ impl<'a, T: Runnable + 'a> WorkerThread<'a, T> {
} }
#[inline(always)] #[inline(always)]
fn select_fence<'a, T>(fences: impl Iterator<Item=&'a Stealer<T>>) -> Option<&'a Stealer<T>> { fn select_fence<'a, T>(fences: impl Iterator<Item = &'a Stealer<T>>) -> Option<&'a Stealer<T>> {
fences.max_by_key(|fence| fence.len()) fences.max_by_key(|fence| fence.len())
} }

View File

@ -1,38 +0,0 @@
use std::io::Write;
use executor::run::run;
use std::thread;
use std::time::Duration;
use executor::prelude::{ProcStack, spawn};
#[cfg(feature = "tokio-runtime")]
mod tokio_tests {
#[tokio::test]
async fn test_run_blocking() {
super::run_test()
}
}
#[cfg(not(feature = "tokio-runtime"))]
mod no_tokio_tests {
#[test]
fn test_run_blocking() {
super::run_test()
}
}
fn run_test() {
let handle = spawn(
async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
//42
},
);
let output = run(handle, ProcStack {});
println!("{:?}", output);
std::io::stdout().flush();
assert!(output.is_some());
std::thread::sleep(Duration::from_millis(200));
}

View File

@ -1,149 +0,0 @@
use executor::blocking;
use executor::run::run;
use futures_util::future::join_all;
use lightproc::recoverable_handle::RecoverableHandle;
use std::thread;
use std::time::Duration;
use std::time::Instant;
use executor::prelude::ProcStack;
// Test for slow joins without task bursts during joins.
#[test]
#[ignore]
fn slow_join() {
let thread_join_time_max = 11_000;
let start = Instant::now();
// Send an initial batch of million bursts.
let handles = (0..1_000_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
},
)
})
.collect::<Vec<RecoverableHandle<()>>>();
run(join_all(handles), ProcStack {});
// Let them join to see how it behaves under different workloads.
let duration = Duration::from_millis(thread_join_time_max);
thread::sleep(duration);
// Spawn yet another batch of work on top of it
let handles = (0..10_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
},
)
})
.collect::<Vec<RecoverableHandle<()>>>();
run(join_all(handles), ProcStack {});
// Slow joins shouldn't cause internal slow down
let elapsed = start.elapsed().as_millis() - thread_join_time_max as u128;
println!("Slow task join. Monotonic exec time: {:?} ns", elapsed);
// Previous implementation is around this threshold.
}
// Test for slow joins with task burst.
#[test]
#[ignore]
fn slow_join_interrupted() {
let thread_join_time_max = 2_000;
let start = Instant::now();
// Send an initial batch of million bursts.
let handles = (0..1_000_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(1);
thread::sleep(duration);
},
)
})
.collect::<Vec<RecoverableHandle<()>>>();
run(join_all(handles), ProcStack {});
// Let them join to see how it behaves under different workloads.
// This time join under the time window.
let duration = Duration::from_millis(thread_join_time_max);
thread::sleep(duration);
// Spawn yet another batch of work on top of it
let handles = (0..10_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
},
)
})
.collect::<Vec<RecoverableHandle<()>>>();
run(join_all(handles), ProcStack {});
// Slow joins shouldn't cause internal slow down
let elapsed = start.elapsed().as_millis() - thread_join_time_max as u128;
println!("Slow task join. Monotonic exec time: {:?} ns", elapsed);
// Previous implementation is around this threshold.
}
// This test is expensive but it proves that longhauling tasks are working in adaptive thread pool.
// Thread pool which spawns on-demand will panic with this test.
#[test]
#[ignore]
fn longhauling_task_join() {
let thread_join_time_max = 11_000;
let start = Instant::now();
// First batch of overhauling tasks
let _ = (0..100_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(1000);
thread::sleep(duration);
},
)
})
.collect::<Vec<RecoverableHandle<()>>>();
// Let them join to see how it behaves under different workloads.
let duration = Duration::from_millis(thread_join_time_max);
thread::sleep(duration);
// Send yet another medium sized batch to see how it scales.
let handles = (0..10_000)
.map(|_| {
blocking::spawn_blocking(
async {
let duration = Duration::from_millis(100);
thread::sleep(duration);
},
)
})
.collect::<Vec<RecoverableHandle<()>>>();
run(join_all(handles), ProcStack {});
// Slow joins shouldn't cause internal slow down
let elapsed = start.elapsed().as_millis() - thread_join_time_max as u128;
println!(
"Long-hauling task join. Monotonic exec time: {:?} ns",
elapsed
);
// Previous implementation will panic when this test is running.
}

View File

@ -1,11 +1,11 @@
use std::any::Any;
use std::fmt::Debug;
use std::ops::Deref;
use crossbeam::channel::{unbounded, Sender}; use crossbeam::channel::{unbounded, Sender};
use futures_executor as executor; use futures_executor as executor;
use lazy_static::lazy_static; use lazy_static::lazy_static;
use lightproc::prelude::*; use lightproc::prelude::*;
use std::any::Any;
use std::fmt::Debug;
use std::future::Future; use std::future::Future;
use std::ops::Deref;
use std::thread; use std::thread;
fn spawn_on_thread<F, R>(future: F) -> RecoverableHandle<R> fn spawn_on_thread<F, R>(future: F) -> RecoverableHandle<R>
@ -30,20 +30,17 @@ where
} }
let schedule = |t| (QUEUE.deref()).send(t).unwrap(); let schedule = |t| (QUEUE.deref()).send(t).unwrap();
let (proc, handle) = LightProc::recoverable( let (proc, handle) = LightProc::recoverable(future, schedule);
future,
schedule
);
let handle = handle let handle = handle.on_panic(
.on_panic(|err: Box<dyn Any + Send>| { |err: Box<dyn Any + Send>| match err.downcast::<&'static str>() {
match err.downcast::<&'static str>() {
Ok(reason) => println!("Future panicked: {}", &reason), Ok(reason) => println!("Future panicked: {}", &reason),
Err(err) => Err(err) => println!(
println!("Future panicked with a non-text reason of typeid {:?}", "Future panicked with a non-text reason of typeid {:?}",
err.type_id()), err.type_id()
} ),
}); },
);
proc.schedule(); proc.schedule();

View File

@ -14,15 +14,10 @@ where
{ {
let (sender, receiver) = channel::unbounded(); let (sender, receiver) = channel::unbounded();
let future = async move { let future = async move { fut.await };
fut.await
};
let schedule = move |t| sender.send(t).unwrap(); let schedule = move |t| sender.send(t).unwrap();
let (proc, handle) = LightProc::build( let (proc, handle) = LightProc::build(future, schedule);
future,
schedule,
);
proc.schedule(); proc.schedule();

View File

@ -77,7 +77,8 @@ impl LightProc {
/// }); /// });
/// ``` /// ```
pub fn recoverable<'a, F, R, S>(future: F, schedule: S) -> (Self, RecoverableHandle<R>) pub fn recoverable<'a, F, R, S>(future: F, schedule: S) -> (Self, RecoverableHandle<R>)
where F: Future<Output=R> + 'a, where
F: Future<Output = R> + 'a,
R: 'a, R: 'a,
S: Fn(LightProc) + 'a, S: Fn(LightProc) + 'a,
{ {
@ -115,7 +116,8 @@ impl LightProc {
/// ); /// );
/// ``` /// ```
pub fn build<'a, F, R, S>(future: F, schedule: S) -> (Self, ProcHandle<R>) pub fn build<'a, F, R, S>(future: F, schedule: S) -> (Self, ProcHandle<R>)
where F: Future<Output=R> + 'a, where
F: Future<Output = R> + 'a,
R: 'a, R: 'a,
S: Fn(LightProc) + 'a, S: Fn(LightProc) + 'a,
{ {

View File

@ -44,12 +44,10 @@ impl ProcData {
let (flags, references) = state.parts(); let (flags, references) = state.parts();
let new = State::new(flags | CLOSED, references); let new = State::new(flags | CLOSED, references);
// Mark the proc as closed. // Mark the proc as closed.
match self.state.compare_exchange_weak( match self
state, .state
new, .compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire)
Ordering::AcqRel, {
Ordering::Acquire,
) {
Ok(_) => { Ok(_) => {
// Notify the awaiter that the proc has been closed. // Notify the awaiter that the proc has been closed.
if state.is_awaiter() { if state.is_awaiter() {
@ -117,7 +115,8 @@ impl ProcData {
// Release the lock. If we've cleared the awaiter, then also unset the awaiter flag. // Release the lock. If we've cleared the awaiter, then also unset the awaiter flag.
if new_is_none { if new_is_none {
self.state.fetch_and((!LOCKED & !AWAITER).into(), Ordering::Release); self.state
.fetch_and((!LOCKED & !AWAITER).into(), Ordering::Release);
} else { } else {
self.state.fetch_and((!LOCKED).into(), Ordering::Release); self.state.fetch_and((!LOCKED).into(), Ordering::Release);
} }
@ -142,9 +141,7 @@ impl Debug for ProcData {
.field("ref_count", &state.get_refcount()) .field("ref_count", &state.get_refcount())
.finish() .finish()
} else { } else {
fmt.debug_struct("ProcData") fmt.debug_struct("ProcData").field("state", &state).finish()
.field("state", &state)
.finish()
} }
} }
} }

View File

@ -273,9 +273,7 @@ where
let raw = Self::from_ptr(ptr); let raw = Self::from_ptr(ptr);
// Decrement the reference count. // Decrement the reference count.
let new = (*raw.pdata) let new = (*raw.pdata).state.fetch_sub(1, Ordering::AcqRel);
.state
.fetch_sub(1, Ordering::AcqRel);
let new = new.set_refcount(new.get_refcount().saturating_sub(1)); let new = new.set_refcount(new.get_refcount().saturating_sub(1));
// If this was the last reference to the proc and the `ProcHandle` has been dropped as // If this was the last reference to the proc and the `ProcHandle` has been dropped as
@ -444,13 +442,12 @@ where
// was woken and then clean up its resources. // was woken and then clean up its resources.
let (flags, references) = state.parts(); let (flags, references) = state.parts();
let flags = if state.is_closed() { let flags = if state.is_closed() {
flags & !( RUNNING | SCHEDULED ) flags & !(RUNNING | SCHEDULED)
} else { } else {
flags & !RUNNING flags & !RUNNING
}; };
let new = State::new(flags, references); let new = State::new(flags, references);
// Mark the proc as not running. // Mark the proc as not running.
match (*raw.pdata).state.compare_exchange_weak( match (*raw.pdata).state.compare_exchange_weak(
state, state,
@ -502,7 +499,7 @@ impl<'a, F, R, S> Copy for RawProc<'a, F, R, S> {}
/// A guard that closes the proc if polling its future panics. /// A guard that closes the proc if polling its future panics.
struct Guard<'a, F, R, S>(RawProc<'a, F, R, S>) struct Guard<'a, F, R, S>(RawProc<'a, F, R, S>)
where where
F: Future<Output = R> + 'a, F: Future<Output = R> + 'a,
R: 'a, R: 'a,
S: Fn(LightProc) + 'a; S: Fn(LightProc) + 'a;

View File

@ -1,9 +1,9 @@
//! //!
//! Handle for recoverable process //! Handle for recoverable process
use std::any::Any;
use crate::proc_data::ProcData; use crate::proc_data::ProcData;
use crate::proc_handle::ProcHandle; use crate::proc_handle::ProcHandle;
use crate::state::State; use crate::state::State;
use std::any::Any;
use std::fmt::{self, Debug, Formatter}; use std::fmt::{self, Debug, Formatter};
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
@ -80,12 +80,12 @@ impl<R> RecoverableHandle<R> {
/// }); /// });
/// ``` /// ```
pub fn on_panic<F>(mut self, callback: F) -> Self pub fn on_panic<F>(mut self, callback: F) -> Self
where F: FnOnce(Box<dyn Any + Send>) + Send + Sync + 'static, where
F: FnOnce(Box<dyn Any + Send>) + Send + Sync + 'static,
{ {
self.panicked = Some(Box::new(callback)); self.panicked = Some(Box::new(callback));
self self
} }
} }
impl<R> Future for RecoverableHandle<R> { impl<R> Future for RecoverableHandle<R> {
@ -102,7 +102,7 @@ impl<R> Future for RecoverableHandle<R> {
} }
Poll::Ready(None) Poll::Ready(None)
}, }
} }
} }
} }

View File

@ -73,25 +73,22 @@ bitflags::bitflags! {
#[repr(packed)] #[repr(packed)]
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct State { pub struct State {
bytes: [u8; 8] bytes: [u8; 8],
} }
impl State { impl State {
#[inline(always)] #[inline(always)]
pub const fn new(flags: StateFlags, references: u32) -> Self { pub const fn new(flags: StateFlags, references: u32) -> Self {
let [a,b,c,d] = references.to_ne_bytes(); let [a, b, c, d] = references.to_ne_bytes();
let [e,f,g,h] = flags.bits.to_ne_bytes(); let [e, f, g, h] = flags.bits.to_ne_bytes();
Self::from_bytes([a,b,c,d,e,f,g,h]) Self::from_bytes([a, b, c, d, e, f, g, h])
} }
#[inline(always)] #[inline(always)]
pub const fn parts(self: Self) -> (StateFlags, u32) { pub const fn parts(self: Self) -> (StateFlags, u32) {
let [a,b,c,d,e,f,g,h] = self.bytes; let [a, b, c, d, e, f, g, h] = self.bytes;
let refcount = u32::from_ne_bytes([a,b,c,d]); let refcount = u32::from_ne_bytes([a, b, c, d]);
let state = unsafe { let state = unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e, f, g, h])) };
StateFlags::from_bits_unchecked(u32::from_ne_bytes([e,f,g,h]))
};
(state, refcount) (state, refcount)
} }
@ -101,8 +98,8 @@ impl State {
/// Note that the reference counter only tracks the `LightProc` and `Waker`s. The `ProcHandle` is /// Note that the reference counter only tracks the `LightProc` and `Waker`s. The `ProcHandle` is
/// tracked separately by the `HANDLE` flag. /// tracked separately by the `HANDLE` flag.
pub const fn get_refcount(self) -> u32 { pub const fn get_refcount(self) -> u32 {
let [a,b,c,d,_,_,_,_] = self.bytes; let [a, b, c, d, _, _, _, _] = self.bytes;
u32::from_ne_bytes([a,b,c,d]) u32::from_ne_bytes([a, b, c, d])
} }
#[inline(always)] #[inline(always)]
@ -116,7 +113,7 @@ impl State {
#[inline(always)] #[inline(always)]
pub const fn get_flags(self) -> StateFlags { pub const fn get_flags(self) -> StateFlags {
let [_, _, _, _, e, f, g, h] = self.bytes; let [_, _, _, _, e, f, g, h] = self.bytes;
unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e,f,g,h])) } unsafe { StateFlags::from_bits_unchecked(u32::from_ne_bytes([e, f, g, h])) }
} }
#[inline(always)] #[inline(always)]
@ -207,10 +204,10 @@ impl AtomicState {
current: State, current: State,
new: State, new: State,
success: Ordering, success: Ordering,
failure: Ordering failure: Ordering,
) -> Result<State, State> ) -> Result<State, State> {
{ self.inner
self.inner.compare_exchange(current.into_u64(), new.into_u64(), success, failure) .compare_exchange(current.into_u64(), new.into_u64(), success, failure)
.map(|u| State::from_u64(u)) .map(|u| State::from_u64(u))
.map_err(|u| State::from_u64(u)) .map_err(|u| State::from_u64(u))
} }
@ -220,37 +217,37 @@ impl AtomicState {
current: State, current: State,
new: State, new: State,
success: Ordering, success: Ordering,
failure: Ordering failure: Ordering,
) -> Result<State, State> ) -> Result<State, State> {
{ self.inner
self.inner.compare_exchange_weak(current.into_u64(), new.into_u64(), success, failure) .compare_exchange_weak(current.into_u64(), new.into_u64(), success, failure)
.map(|u| State::from_u64(u)) .map(|u| State::from_u64(u))
.map_err(|u| State::from_u64(u)) .map_err(|u| State::from_u64(u))
} }
pub fn fetch_or(&self, val: StateFlags, order: Ordering) -> State { pub fn fetch_or(&self, val: StateFlags, order: Ordering) -> State {
let [a,b,c,d] = val.bits.to_ne_bytes(); let [a, b, c, d] = val.bits.to_ne_bytes();
let store = u64::from_ne_bytes([0,0,0,0,a,b,c,d]); let store = u64::from_ne_bytes([0, 0, 0, 0, a, b, c, d]);
State::from_u64(self.inner.fetch_or(store, order)) State::from_u64(self.inner.fetch_or(store, order))
} }
pub fn fetch_and(&self, val: StateFlags, order: Ordering) -> State { pub fn fetch_and(&self, val: StateFlags, order: Ordering) -> State {
let [a,b,c,d] = val.bits.to_ne_bytes(); let [a, b, c, d] = val.bits.to_ne_bytes();
let store = u64::from_ne_bytes([!0,!0,!0,!0,a,b,c,d]); let store = u64::from_ne_bytes([!0, !0, !0, !0, a, b, c, d]);
State::from_u64(self.inner.fetch_and(store, order)) State::from_u64(self.inner.fetch_and(store, order))
} }
// FIXME: Do this properly // FIXME: Do this properly
pub fn fetch_add(&self, val: u32, order: Ordering) -> State { pub fn fetch_add(&self, val: u32, order: Ordering) -> State {
let [a,b,c,d] = val.to_ne_bytes(); let [a, b, c, d] = val.to_ne_bytes();
let store = u64::from_ne_bytes([a,b,c,d,0,0,0,0]); let store = u64::from_ne_bytes([a, b, c, d, 0, 0, 0, 0]);
State::from_u64(self.inner.fetch_add(store, order)) State::from_u64(self.inner.fetch_add(store, order))
} }
// FIXME: Do this properly // FIXME: Do this properly
pub fn fetch_sub(&self, val: u32, order: Ordering) -> State { pub fn fetch_sub(&self, val: u32, order: Ordering) -> State {
let [a,b,c,d] = val.to_ne_bytes(); let [a, b, c, d] = val.to_ne_bytes();
let store = u64::from_ne_bytes([a,b,c,d,0,0,0,0]); let store = u64::from_ne_bytes([a, b, c, d, 0, 0, 0, 0]);
State::from_u64(self.inner.fetch_sub(store, order)) State::from_u64(self.inner.fetch_sub(store, order))
} }
} }

3
tools/git-pre-commit-hook Executable file
View File

@ -0,0 +1,3 @@
#!/usr/bin/env bash
cargo fmt --all

12
tools/git-pre-push-hook Executable file
View File

@ -0,0 +1,12 @@
#!/usr/bin/env bash
echo -e "Checking code formatting:\n=========================\n\n" 1>&2
cargo fmt --check
if [[ $? -ne 0 ]]
then
o=$?
echo -e "\n\nRun \`cargo fmt --all\` before pushing please." 1>&2
exit $o
fi