From a1355aaa6aa9ab784280ea1146a0394c9cbf6713 Mon Sep 17 00:00:00 2001 From: Gregor Reitzenstein Date: Fri, 10 Sep 2021 21:19:30 +0200 Subject: [PATCH] Fix all tests --- src/db.rs | 82 +++++++++++++++++++++++++++++++++++++++++++----- src/db/access.rs | 5 +++ 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/src/db.rs b/src/db.rs index 8697c85..5f688d5 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::path::PathBuf; use std::str::FromStr; +use std::ops::{Deref, DerefMut}; use slog::Logger; @@ -72,6 +73,10 @@ use lmdb::{ RoTransaction, RwTransaction, WriteFlags, + Cursor, + RoCursor, + RwCursor, + Iter, }; #[derive(Debug, Clone)] @@ -117,6 +122,14 @@ impl DB { txn.del(self.db, key, value.map(AsRef::as_ref)) } + pub fn iter<'txn, C: Cursor<'txn>>(&self, cursor: &'txn mut C) -> Iter<'txn> { + cursor.iter_start() + } + + pub fn open_ro_cursor<'txn, T: Transaction>(&self, txn: &'txn T) -> lmdb::Result> { + txn.open_ro_cursor(self.db) + } + pub fn begin_ro_txn<'env>(&'env self) -> lmdb::Result> { self.env.begin_ro_txn() } @@ -147,23 +160,23 @@ fn bincode_default() -> impl bincode::Options { use std::marker::PhantomData; -pub struct TypedDatabase<'a, A, V: ?Sized> { +pub struct Objectstore<'a, A, V: ?Sized> { pub db: DB, - adapter: A, + adapter: PhantomData, marker: PhantomData<&'a V> } -impl TypedDatabase<'_, A, V> { - pub fn new(db: DB, adapter: A) -> Self { +impl Objectstore<'_, A, V> { + pub fn new(db: DB) -> Self { Self { db: db, - adapter: adapter, + adapter: PhantomData, marker: PhantomData, } } } -impl<'txn, A, V> TypedDatabase<'txn, A, V> +impl<'txn, A, V> Objectstore<'txn, A, V> where A: DatabaseAdapter, V: ?Sized + Serialize + Deserialize<'txn>, { @@ -205,6 +218,11 @@ impl<'txn, A, V> TypedDatabase<'txn, A, V> ) } + pub fn iter(&self, txn: &'txn T) -> StdResult, A::Err> { + let mut cursor = self.db.open_ro_cursor(txn)?; + let iter = cursor.iter_start(); + Ok(ObjectIter::new(cursor, iter)) + } pub fn put(&self, txn: &'txn mut RwTransaction, key: &A::Key, value: &V, flags: lmdb::WriteFlags) -> StdResult<(), A::Err> @@ -228,6 +246,38 @@ impl<'txn, A, V> TypedDatabase<'txn, A, V> } } +pub struct ObjectIter<'txn, A, V: ?Sized> { + cursor: RoCursor<'txn>, + inner: Iter<'txn>, + + adapter: PhantomData, + marker: PhantomData<&'txn V>, +} + +impl<'txn, A, V: ?Sized> ObjectIter<'txn, A, V> { + pub fn new(cursor: RoCursor<'txn>, inner: Iter<'txn>) -> Self { + let marker = PhantomData; + let adapter = PhantomData; + Self { cursor, inner, adapter, marker } + } +} + +impl<'txn, A, V> Iterator for ObjectIter<'txn, A, V> + where A: DatabaseAdapter, + V: ?Sized + Serialize + Deserialize<'txn>, +{ + type Item = StdResult; + + fn next(&mut self) -> Option { + self.inner.next()? + .map_or_else( + |err| Some(Err(err.into())), + |(_, v)| Some(bincode_default().deserialize(v).map_err(|e| e.into())) + ) + } +} + + #[cfg(test)] mod tests { use super::*; @@ -301,7 +351,7 @@ mod tests { } } - type TestDB<'txn> = TypedDatabase<'txn, TestAdapter, &'txn str>; + type TestDB<'txn> = Objectstore<'txn, TestAdapter, &'txn str>; #[test] fn simple_get() { @@ -311,11 +361,15 @@ mod tests { let db = DB::new(e.env.clone(), ldb); let adapter = TestAdapter; - let testdb = TestDB::new(db.clone(), adapter); + let testdb = TestDB::new(db.clone()); let mut val = "value"; let mut txn = db.begin_rw_txn().expect("Failed to being rw txn"); testdb.put(&mut txn, "key", &val, WF::empty()).expect("Failed to insert"); + testdb.put(&mut txn, "key2", &val, WF::empty()).expect("Failed to insert"); + testdb.put(&mut txn, "key3", &val, WF::empty()).expect("Failed to insert"); + testdb.put(&mut txn, "key4", &val, WF::empty()).expect("Failed to insert"); + testdb.put(&mut txn, "key5", &val, WF::empty()).expect("Failed to insert"); txn.commit().expect("commit failed"); { @@ -339,5 +393,17 @@ mod tests { assert!(found); assert_eq!("longer_value", val); } + + { + let txn = db.begin_ro_txn().unwrap(); + let mut it = testdb.iter(&txn).unwrap(); + assert_eq!("longer_value", it.next().unwrap().unwrap()); + let mut i = 0; + while let Some(e) = it.next() { + assert_eq!("value", e.unwrap()); + i += 1; + } + assert_eq!(i, 4) + } } } diff --git a/src/db/access.rs b/src/db/access.rs index a47be69..a3fe9dc 100644 --- a/src/db/access.rs +++ b/src/db/access.rs @@ -549,6 +549,11 @@ mod tests { .expect("Couldn't load the example role defs. Does `examples/roles.toml` exist?"); let expected = vec![ + (RoleIdentifier { name: "anotherrole".to_string(), source: "lmdb".to_string() }, + Role { + parents: vec![], + permissions: vec![], + }), (RoleIdentifier { name: "testrole".to_string(), source: "lmdb".to_string() }, Role { parents: vec![],