diff --git a/src/access.rs b/src/access.rs index 1e721de..9dfec84 100644 --- a/src/access.rs +++ b/src/access.rs @@ -1,31 +1,107 @@ //! Access control logic //! +use std::collections::HashSet; + +use flexbuffers; +use serde::{Serialize, Deserialize}; + use slog::Logger; +use lmdb::{Transaction, RoTransaction, RwTransaction}; use crate::config::Config; +use crate::error::Result; +type UserIdentifier = u64; +type RoleIdentifier = u64; +type PermIdentifier = u64; pub struct PermissionsProvider { log: Logger, + roledb: lmdb::Database, + permdb: lmdb::Database, + userdb: lmdb::Database, } impl PermissionsProvider { - pub fn new(log: Logger) -> Self { - Self { log } + pub fn new(log: Logger, roledb: lmdb::Database, permdb: lmdb::Database, userdb: lmdb::Database) -> Self { + Self { log, roledb, permdb, userdb } + } + + /// Check if a given user has the given permission + #[allow(unused)] + pub fn check(&self, txn: &T, userID: UserIdentifier, permID: PermIdentifier) -> Result { + if let Some(user) = self.get_user(txn, userID)? { + // Tally all roles. Makes dependent roles easier + let mut roles = HashSet::new(); + for roleID in user.roles { + self.tally_role(txn, &mut roles, roleID)?; + } + + // Iter all unique role->permissions we've found and early return on match. + // TODO: Change this for negative permissions? + for role in roles.iter() { + for perm in role.permissions.iter() { + if permID == *perm { + return Ok(true); + } + } + } + } + + return Ok(false); + } + + fn tally_role(&self, txn: &T, roles: &mut HashSet, roleID: RoleIdentifier) -> Result<()> { + if let Some(role) = self.get_role(txn, roleID)? { + // Only check and tally parents of a role at the role itself if it's the first time we + // see it + if !roles.contains(&role) { + for parent in role.parents.iter() { + self.tally_role(txn, roles, *parent)?; + } + + roles.insert(role); + } + } + + Ok(()) + } + + fn get_role<'txn, T: Transaction>(&self, txn: &'txn T, roleID: RoleIdentifier) -> Result> { + match txn.get(self.roledb, &roleID.to_ne_bytes()) { + Ok(bytes) => { + Ok(Some(flexbuffers::from_slice(bytes)?)) + }, + Err(lmdb::Error::NotFound) => { Ok(None) }, + Err(e) => { Err(e.into()) } + } + } + + fn get_user(&self, txn: &T, userID: UserIdentifier) -> Result> { + match txn.get(self.userdb, &userID.to_ne_bytes()) { + Ok(bytes) => { + Ok(Some(flexbuffers::from_slice(bytes)?)) + }, + Err(lmdb::Error::NotFound) => { Ok(None) }, + Err(e) => { Err(e.into()) } + } } } /// This line documents init pub fn init(log: Logger, config: &Config, env: &lmdb::Environment) -> std::result::Result { - return Ok(PermissionsProvider::new(log)); + let mut flags = lmdb::DatabaseFlags::empty(); + flags.set(lmdb::DatabaseFlags::INTEGER_KEY, true); + let roledb = env.create_db(Some("role"), flags)?; + let permdb = env.create_db(Some("perm"), flags)?; + let userdb = env.create_db(Some("user"), flags)?; + return Ok(PermissionsProvider::new(log, roledb, permdb, userdb)); } -type RoleIdentifier = u64; -type PermIdentifier = u64; - /// A Person, from the Authorization perspective -struct Person { +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +struct User { name: String, /// A Person has N ≥ 0 roles. @@ -45,6 +121,7 @@ struct Person { /// of a machine; if later on a similar enough machine is put to use the administrator can just add /// the permission for that machine to an already existing role instead of manually having to /// assign to all users. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] struct Role { name: String, @@ -61,6 +138,7 @@ struct Role { /// /// Permissions are rather simple flags. A person can have or not have a permission, dictated by /// its roles and the permissions assigned to those roles. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] struct Permission { name: String, } diff --git a/src/error.rs b/src/error.rs index a746d3a..e7c0dcd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,6 +13,8 @@ pub enum Error { Boxed(Box), Capnp(capnp::Error), LMDB(lmdb::Error), + FlexbuffersDe(flexbuffers::DeserializationError), + FlexbuffersSer(flexbuffers::SerializationError), } impl fmt::Display for Error { @@ -39,6 +41,12 @@ impl fmt::Display for Error { Error::LMDB(e) => { write!(f, "LMDB Error: {}", e) }, + Error::FlexbuffersDe(e) => { + write!(f, "Flexbuffers decoding error: {}", e) + }, + Error::FlexbuffersSer(e) => { + write!(f, "Flexbuffers encoding error: {}", e) + }, } } } @@ -85,4 +93,16 @@ impl From for Error { } } +impl From for Error { + fn from(e: flexbuffers::DeserializationError) -> Error { + Error::FlexbuffersDe(e) + } +} + +impl From for Error { + fn from(e: flexbuffers::SerializationError) -> Error { + Error::FlexbuffersSer(e) + } +} + pub type Result = std::result::Result;