use rkyv::ser::Serializer; use rkyv::ser::serializers::AllocSerializer; use thiserror::Error; use crate::db; use crate::db::{AlignedAdapter, ArchivedValue, RawDB, DB}; use lmdb::{DatabaseFlags, Environment, EnvironmentFlags, Transaction, WriteFlags}; use miette::Diagnostic; use std::fmt::Debug; use std::{path::Path, sync::Arc}; use crate::resources::state::State; #[derive(Debug, Clone)] pub struct StateDB { env: Arc, db: DB>, } #[derive(Clone, Debug, PartialEq, Eq, Error, Diagnostic)] pub enum StateDBError { #[error("opening the state db environment failed")] #[diagnostic( code(bffh::db::state::open_env), help("does the parent directory for state_db exist?") )] OpenEnv(#[source] db::Error), #[error("opening the state db failed")] #[diagnostic(code(bffh::db::state::open))] Open(#[source] db::Error), #[error("creating the state db failed")] #[diagnostic(code(bffh::db::state::create))] Create(#[source] db::Error), } impl StateDB { pub fn open_env>(path: P) -> Result, StateDBError> { Environment::new() .set_flags( EnvironmentFlags::WRITE_MAP | EnvironmentFlags::NO_SUB_DIR | EnvironmentFlags::NO_TLS | EnvironmentFlags::NO_READAHEAD, ) .set_max_dbs(8) .open(path.as_ref()) .map(Arc::new) .map_err(|e| StateDBError::OpenEnv(e.into())) } fn new(env: Arc, db: RawDB) -> Self { let db = DB::new(db); Self { env, db } } pub fn open_with_env(env: Arc) -> Result { let db = RawDB::open(&env, Some("state")) .map_err(|e| StateDBError::Open(e.into()))?; Ok(Self::new(env, db)) } pub fn open>(path: P) -> Result { let env = Self::open_env(path)?; Self::open_with_env(env) } pub fn create_with_env(env: Arc) -> Result { let flags = DatabaseFlags::empty(); let db = RawDB::create(&env, Some("state"), flags) .map_err(|e| StateDBError::Create(e.into()))?; Ok(Self::new(env, db)) } pub fn create>(path: P) -> Result { let env = Self::open_env(path)?; Self::create_with_env(env) } pub fn begin_ro_txn(&self) -> Result { self.env.begin_ro_txn().map_err(db::Error::from) } pub fn get(&self, key: impl AsRef<[u8]>) -> Result>, db::Error> { let txn = self.env.begin_ro_txn()?; self.db.get(&txn, &key.as_ref()) } pub fn get_all<'txn, T: Transaction>( &self, txn: &'txn T, ) -> Result)>, db::Error> { self.db.get_all(txn) } pub fn put(&self, key: &impl AsRef<[u8]>, val: &ArchivedValue) -> Result<(), db::Error> { let mut txn = self.env.begin_rw_txn()?; let flags = WriteFlags::empty(); self.db.put(&mut txn, key, val, flags)?; Ok(txn.commit()?) } pub fn load_map(&self, map: &std::collections::HashMap) -> miette::Result<()> { use miette::IntoDiagnostic; let mut txn = self.env.begin_rw_txn().into_diagnostic()?; let flags = WriteFlags::empty(); for (key, val) in map { let mut serializer = AllocSerializer::<1024>::default(); serializer.serialize_value(val).into_diagnostic()?; let serialized = ArchivedValue::new(serializer.into_serializer().into_inner()); self.db.put(&mut txn, &key.as_bytes(), &serialized, flags)?; } txn.commit().into_diagnostic()?; Ok(()) } pub fn dump_map(&self) -> miette::Result> { let mut map = std::collections::HashMap::new(); for (key, val) in self.get_all(&self.begin_ro_txn()?)? { let key_str = core::str::from_utf8(&key).map_err(|_e| miette::Error::msg("state key not UTF8"))?.to_string(); let val_state: State = rkyv::Deserialize::deserialize(val.as_ref(), &mut rkyv::Infallible).unwrap(); map.insert(key_str, val_state); } Ok(map) } } #[cfg(test)] mod tests { use super::*; use std::ops::Deref; }