224 lines
7.5 KiB
Python
Executable file
224 lines
7.5 KiB
Python
Executable file
#!/usr/bin/python3
|
|
|
|
from newfol.database import DatabaseVersion, Database, Schema, Singleton
|
|
from newfol.filemanip import Record
|
|
import tempfile
|
|
import unittest
|
|
|
|
|
|
class TestDatabaseVersion(unittest.TestCase):
    """Tests for the per-field accessors of DatabaseVersion."""

    def test_default_serialization(self):
        """A default-constructed version serializes as CSV."""
        self.assertEqual(DatabaseVersion().serialization(), "csv")

    def test_preferred_serialization(self):
        """The preferred version serializes as JSON."""
        self.assertEqual(DatabaseVersion.preferred().serialization(), "json")

    def test_preferred_serialization_version(self):
        """The preferred version uses serialization version 2."""
        self.assertEqual(DatabaseVersion.preferred().serialization_version(),
                         2)

    def test_preferred_compression(self):
        """The preferred version reports no compression (None)."""
        # assertIsNone checks identity with None; assertEqual(x, None) would
        # also accept any object whose __eq__ happens to compare equal to None.
        self.assertIsNone(DatabaseVersion.preferred().compression())

    def test_preferred_compression_version(self):
        """The preferred version uses compression version 0."""
        self.assertEqual(DatabaseVersion.preferred().compression_version(), 0)

    def test_preferred_record_version(self):
        """The preferred version uses record version 3."""
        self.assertEqual(DatabaseVersion.preferred().record_version(), 3)

    def test_preferred_version_version(self):
        """The preferred version uses version-field version 0."""
        self.assertEqual(DatabaseVersion.preferred().version_version(), 0)
|
|
|
|
|
|
class TestDatabaseAccessors(unittest.TestCase):
    """Tests for the simple accessors of a Database built in memory."""

    def test_location(self):
        """location() returns exactly the path passed to the constructor."""
        for location in ("/tmp/foo", "/bar", "/home/quux/.newfol"):
            # subTest reports each failing path separately and keeps checking
            # the remaining ones instead of stopping at the first failure.
            with self.subTest(location=location):
                obj = Database(location, 2, Schema(), [])
                self.assertEqual(obj.location(), location)

    def test_records(self):
        """records() returns the records given to the constructor."""
        obj = Database("/tmp/foo", 2, Schema(), [Record([1, 2, 3])])
        records = obj.records()
        self.assertEqual(len(records), 1)
        self.assertEqual(len(records[0].fields), 3)

    def test_version(self):
        """version() returns the version given to the constructor."""
        # A default-constructed DatabaseVersion is expected to compare
        # equal to 0.
        obj = Database("/tmp/foo", DatabaseVersion(), Schema(), [])
        self.assertEqual(obj.version(), 0)

        obj = Database("/tmp/foo", DatabaseVersion.preferred(), Schema(), [])
        self.assertEqual(obj.version(), DatabaseVersion.preferred())

    def test_serialization(self):
        """serialization() matches the serialization of the db's version."""
        obj = Database("/tmp/foo", DatabaseVersion(), Schema(), [])
        self.assertEqual(obj.serialization(),
                         DatabaseVersion(0).serialization())

        obj = Database("/tmp/foo", DatabaseVersion.preferred(), Schema(), [])
        self.assertEqual(obj.serialization(),
                         DatabaseVersion.preferred().serialization())
|
|
|
|
|
|
class TestDatabaseIntegrity(unittest.TestCase):
    """Version, validation, repair and upgrade checks on an on-disk db."""

    def create_temp_db(self):
        """Create a temporary database directory and load it.

        The directory receives a minimal schema file (format 0, ``txn:git``).
        Its removal is registered with addCleanup, so it is deleted when the
        test finishes, whether the test passes or fails.

        Returns:
            A ``(TemporaryDirectory, Database)`` tuple.
        """
        ddir = tempfile.TemporaryDirectory()
        # Register removal right away so the directory is cleaned up even if
        # loading the database or a later assertion raises.
        self.addCleanup(ddir.cleanup)
        with open(ddir.name + "/schema", "w") as fp:
            fp.write("fmt:0:newfol schema file:\ntxn:git\n")
        db = Database.load(ddir.name)
        return (ddir, db)

    def test_version(self):
        """A new db reads as the default version and upgrades to preferred."""
        ddir, db = self.create_temp_db()
        self.assertEqual(Database.read_version(ddir.name),
                         DatabaseVersion())
        db.store()
        db.upgrade()
        self.assertEqual(Database.read_version(ddir.name),
                         DatabaseVersion.preferred())

    def test_validate(self):
        """A freshly stored db passes validation."""
        _, db = self.create_temp_db()
        db.store()
        db.validate()

    def test_validate_strict(self):
        """A freshly stored db passes strict validation."""
        _, db = self.create_temp_db()
        db.store()
        db.validate(strict=True)

    def test_repair_doesnt_raise(self):
        """repair() on a stored db succeeds and leaves it strictly valid."""
        _, db = self.create_temp_db()
        db.store()
        db.repair()
        db.validate(strict=True)

    def test_upgrade_records(self):
        """upgrade_records() on a freshly stored db does not raise."""
        _, db = self.create_temp_db()
        db.store()
        db.upgrade_records()
|
|
|
|
|
|
class TestDatabaseUpgrades(unittest.TestCase):
    """Upgrade a fresh database to specific on-disk format versions."""

    def create_temp_db(self):
        """Create a temporary database directory and load it.

        The directory receives a minimal schema file (format 0, ``txn:git``).
        Its removal is registered with addCleanup, so it is deleted when the
        test finishes, whether the test passes or fails.

        Returns:
            A ``(TemporaryDirectory, Database)`` tuple.
        """
        ddir = tempfile.TemporaryDirectory()
        # Register removal right away so the directory is cleaned up even if
        # loading the database or a later assertion raises.
        self.addCleanup(ddir.cleanup)
        with open(ddir.name + "/schema", "w") as fp:
            fp.write("fmt:0:newfol schema file:\ntxn:git\n")
        db = Database.load(ddir.name)
        return (ddir, db)

    def do_upgrade_test(self, version, pattern):
        """Upgrade a fresh database to *version* and check its data file.

        Args:
            version: a DatabaseVersion, or any value the DatabaseVersion
                constructor accepts (the tests below pass the result of
                bitwise arithmetic on DatabaseVersion.preferred()).
            pattern: bytes that the ``dtb`` file must start with after the
                upgrade.
        """
        ddir, db = self.create_temp_db()
        if not isinstance(version, DatabaseVersion):
            version = DatabaseVersion(version)
        # A freshly loaded database starts at the default version.
        self.assertEqual(Database.read_version(ddir.name),
                         DatabaseVersion())
        db.store()
        db.upgrade(version=version)
        self.assertEqual(Database.read_version(ddir.name),
                         version)
        # Read only as many bytes as the expected prefix.
        with open(ddir.name + "/dtb", "rb") as fp:
            data = fp.read(len(pattern))
        self.assertEqual(data, pattern)

    def test_upgrade_to_xz(self):
        """Setting the 0x00ff0000 field to 1 produces an xz data file."""
        # b"\xfd7zXZ\x00" is the magic number of the xz container format.
        version = (DatabaseVersion.preferred() & ~0x00ff0000) | 0x00010000
        self.do_upgrade_test(version, b"\xfd7zXZ\x00")

    def test_upgrade_to_json(self):
        """Setting the 0x0000ff00 field to 2 produces a JSON array file."""
        version = (DatabaseVersion.preferred() & ~0x0000ff00) | 0x00000200
        self.do_upgrade_test(version, b"[")

    def test_upgrade_to_pickle(self):
        """Setting the 0x0000ff00 field to 1 produces a pickle data file."""
        # b"(lp0\n" is how a protocol-0 pickle of a list begins.
        version = (DatabaseVersion.preferred() & ~0x0000ff00) | 0x00000100
        self.do_upgrade_test(version, b"(lp0\n")
|
|
|
|
|
|
class TestMultipleTransactions(unittest.TestCase):
    """A schema may declare more than one transaction type."""

    def test_multiple_types(self):
        """Every type on the schema's ``txn:`` line is reported."""
        ddir = tempfile.TemporaryDirectory()
        # Register removal immediately so a failure below cannot leak it.
        self.addCleanup(ddir.cleanup)
        with open(ddir.name + "/schema", "w") as fp:
            fp.write("fmt:0:newfol schema file:\ntxn:git:hash\n")
        db = Database.load(ddir.name)
        # Slice assignment mutates the list returned by records(); this
        # assumes it is the database's own list rather than a copy.
        db.records()[:] = [Record([1, 2, 3])]
        db.store()
        self.assertEqual(set(db.schema().transaction_types()),
                         {"git", "hash"})
|
|
|
|
|
|
class TestExtraSchemaConfig(unittest.TestCase):
    """Loading a database with extra config files passed via extra_config."""

    def test_existing_config_file(self):
        """Transaction types from an extra config file take effect."""
        ddir1 = tempfile.TemporaryDirectory()
        # Register each removal as soon as the directory exists so neither
        # leaks if a later step raises.
        self.addCleanup(ddir1.cleanup)
        ddir2 = tempfile.TemporaryDirectory()
        self.addCleanup(ddir2.cleanup)
        config = "%s/config" % ddir2.name
        # The schema itself declares no transaction types...
        with open(ddir1.name + "/schema", "w") as fp:
            fp.write("fmt:3:newfol schema file:\n")
        # ...they come only from the extra config file.
        with open(config, "w") as fp:
            fp.write("fmt:3:newfol config file:\ntxn:git:hash\n")
        db = Database.load(ddir1.name, extra_config=[config])
        db.records()[:] = [Record([1, 2, 3])]
        db.store()
        self.assertEqual(set(db.schema().transaction_types()),
                         {"git", "hash"})

    def test_missing_config_file(self):
        """A nonexistent extra config path does not cause an error."""
        ddir1 = tempfile.TemporaryDirectory()
        self.addCleanup(ddir1.cleanup)
        # This path is never created.
        config = "%s/config" % ddir1.name
        with open(ddir1.name + "/schema", "w") as fp:
            fp.write("fmt:3:newfol schema file:\n")
        db = Database.load(ddir1.name, extra_config=[config])
        db.records()[:] = [Record([1, 2, 3])]
        db.store()
        # Ensure no exception is raised.
|
|
|
|
|
|
class TestDatabaseFiltering(unittest.TestCase):
    """Selecting records with a predicate passed to records()."""

    def create_temp_db(self):
        """Create a temporary database directory and load it.

        The directory receives a minimal schema file (format 0, ``txn:git``).
        Its removal is registered with addCleanup, so it is deleted when the
        test finishes, whether the test passes or fails.

        Returns:
            A ``(TemporaryDirectory, Database)`` tuple.
        """
        ddir = tempfile.TemporaryDirectory()
        # Register removal right away so the directory is cleaned up even if
        # loading the database or a later assertion raises.
        self.addCleanup(ddir.cleanup)
        with open(ddir.name + "/schema", "w") as fp:
            fp.write("fmt:0:newfol schema file:\ntxn:git\n")
        db = Database.load(ddir.name)
        return (ddir, db)

    def test_filtering(self):
        """records(pred) returns a plain list of only the matching records."""
        _, db = self.create_temp_db()
        records = [
            Record(["a", "b", "c"]),
            Record([1, 2, 3]),
            Record([7, 8, 9]),
        ]
        db.records().extend(records)

        def has_only_numbers(rec):
            """Return True if every field of *rec* is accepted by int()."""
            for field in rec.fields:
                try:
                    int(field)
                # int() signals "not a number" with ValueError (e.g. "a") or
                # TypeError (e.g. None); anything else is a real error and
                # should fail the test instead of being swallowed.
                except (TypeError, ValueError):
                    return False
            return True

        selected = db.records(has_only_numbers)
        self.assertIs(type(selected), list)
        self.assertEqual(set(selected), set(records[1:]))
|
|
|
|
|
|
class TestSingleton(unittest.TestCase):
|
|
def test_singleton(self):
|
|
def TestClass(metaclass=Singleton):
|
|
pass
|
|
a = TestClass()
|
|
b = TestClass()
|
|
self.assertIs(a, b)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in it via
    # unittest's command-line entry point.
    unittest.main()
|