#!/usr/bin/python3
"""Unit tests for newfol's Database, DatabaseVersion, Schema and Singleton."""

from newfol.exception import SchemaError
from newfol.database import DatabaseVersion, Database, Schema, Singleton
from newfol.filemanip import Record
import tempfile
import unittest


class TestDatabaseVersion(unittest.TestCase):
    """Check the components reported by default and preferred versions."""

    def test_default_serialization(self):
        self.assertEqual(DatabaseVersion().serialization(), "csv")

    def test_preferred_serialization(self):
        self.assertEqual(DatabaseVersion.preferred().serialization(), "json")

    def test_preferred_serialization_version(self):
        self.assertEqual(
            DatabaseVersion.preferred().serialization_version(), 2)

    def test_preferred_compression(self):
        self.assertEqual(DatabaseVersion.preferred().compression(), None)

    def test_preferred_compression_version(self):
        self.assertEqual(DatabaseVersion.preferred().compression_version(), 0)

    def test_preferred_record_version(self):
        self.assertEqual(DatabaseVersion.preferred().record_version(), 3)

    def test_preferred_version_version(self):
        self.assertEqual(DatabaseVersion.preferred().version_version(), 0)


class TestDatabaseAccessors(unittest.TestCase):
    """Check that Database accessors return what the constructor was given."""

    def test_location(self):
        for location in ("/tmp/foo", "/bar", "/home/quux/.newfol"):
            obj = Database(location, 2, Schema(), [])
            self.assertEqual(obj.location(), location)

    def test_records(self):
        obj = Database("/tmp/foo", 2, Schema(), [Record([1, 2, 3])])
        records = obj.records()
        self.assertEqual(len(records), 1)
        self.assertEqual(len(records[0].fields), 3)

    def test_version(self):
        obj = Database("/tmp/foo", DatabaseVersion(), Schema(), [])
        self.assertEqual(obj.version(), 0)
        obj = Database("/tmp/foo", DatabaseVersion.preferred(), Schema(), [])
        self.assertEqual(obj.version(), DatabaseVersion.preferred())

    def test_serialization(self):
        obj = Database("/tmp/foo", DatabaseVersion(), Schema(), [])
        self.assertEqual(obj.serialization(),
                         DatabaseVersion(0).serialization())
        obj = Database("/tmp/foo", DatabaseVersion.preferred(), Schema(), [])
        self.assertEqual(obj.serialization(),
                         DatabaseVersion.preferred().serialization())


class TemporaryDatabase:
    """A throwaway database directory containing only a schema file.

    The schema file always starts with the "fmt:3:newfol schema file:" header
    line; *schema_contents* is appended verbatim after it.
    """

    def __init__(self, schema_contents=""):
        self.ddir = tempfile.TemporaryDirectory()
        with open(self.ddir.name + "/schema", "w") as fp:
            fp.write("fmt:3:newfol schema file:\n" + schema_contents)

    @property
    def db(self):
        """Load and return a Database from the directory.

        Every access calls Database.load() again, so each access yields a
        freshly loaded object rather than a cached one.
        """
        return Database.load(self.ddir.name)

    def __del__(self):
        # Remove the temporary directory once the fixture is discarded.
        self.ddir.cleanup()


class TestDatabaseIntegrity(unittest.TestCase):
    """Exercise store/upgrade/validate/repair on a freshly created database."""

    def test_version(self):
        tdb = TemporaryDatabase()
        # Before anything is stored, the on-disk version reads as the default.
        self.assertEqual(Database.read_version(tdb.ddir.name),
                         DatabaseVersion())
        tdb.db.store()
        tdb.db.upgrade()
        self.assertEqual(Database.read_version(tdb.ddir.name),
                         DatabaseVersion.preferred())

    def test_validate(self):
        tdb = TemporaryDatabase()
        tdb.db.store()
        tdb.db.validate()

    def test_validate_strict(self):
        tdb = TemporaryDatabase()
        tdb.db.store()
        tdb.db.validate(strict=True)

    def test_repair_doesnt_raise(self):
        tdb = TemporaryDatabase()
        tdb.db.store()
        tdb.db.repair()
        tdb.db.validate(strict=True)

    def test_upgrade_records(self):
        tdb = TemporaryDatabase()
        tdb.db.store()
        tdb.db.upgrade_records()


class TestDatabaseUpgrades(unittest.TestCase):
    """Upgrade a database to a given version and inspect the stored bytes."""

    def do_upgrade_test(self, version, pattern):
        """Upgrade a new database to *version* and check the "dtb" prefix.

        *version* may be a DatabaseVersion or a value accepted by its
        constructor.  *pattern* is the bytes prefix the stored "dtb" file is
        expected to start with after the upgrade.
        """
        tdb = TemporaryDatabase("txn:git\n")
        ddir, db = tdb.ddir, tdb.db
        if not isinstance(version, DatabaseVersion):
            version = DatabaseVersion(version)
        self.assertEqual(Database.read_version(ddir.name), DatabaseVersion())
        db.store()
        db.upgrade(version=version)
        self.assertEqual(Database.read_version(ddir.name), version)
        with open(ddir.name + "/dtb", "rb") as fp:
            data = fp.read(len(pattern))
        self.assertEqual(data, pattern)
        ddir.cleanup()

    # NOTE(review): the masks below suggest bits 16-23 of the version select
    # the compression and bits 8-15 the serialization -- confirm against
    # DatabaseVersion.

    def test_upgrade_to_xz(self):
        version = (DatabaseVersion.preferred() & ~0x00ff0000) | 0x00010000
        self.do_upgrade_test(version, b"\xfd7zXZ\x00")

    def test_upgrade_to_json(self):
        version = (DatabaseVersion.preferred() & ~0x0000ff00) | 0x00000200
        self.do_upgrade_test(version, b"[")

    def test_upgrade_to_pickle(self):
        version = (DatabaseVersion.preferred() & ~0x0000ff00) | 0x00000100
        self.do_upgrade_test(version, b"(lp0\n")


class TestMultipleTransactions(unittest.TestCase):
    """A single "txn" schema line may name several transaction types."""

    def test_multiple_types(self):
        tdb = TemporaryDatabase("txn:git:hash\n")
        ddir = tdb.ddir
        db = Database.load(ddir.name)
        db.records()[:] = [Record([1, 2, 3])]
        db.store()
        self.assertEqual(set(db.schema().transaction_types()),
                         {"git", "hash"})


class TestExtraSchemaConfig(unittest.TestCase):
    """Load a database together with an extra configuration file."""

    def test_existing_config_file(self):
        tdb = TemporaryDatabase()
        # The context manager removes the directory even if an assertion
        # below fails.
        with tempfile.TemporaryDirectory() as ddir2:
            config = "%s/config" % ddir2
            with open(config, "w") as fp:
                fp.write("fmt:3:newfol config file:\ntxn:git:hash\n")
            db = Database.load(tdb.ddir.name, extra_config=[config])
            db.records()[:] = [Record([1, 2, 3])]
            db.store()
            self.assertEqual(set(db.schema().transaction_types()),
                             {"git", "hash"})

    def test_missing_config_file(self):
        tdb = TemporaryDatabase()
        # This path is never created, so the extra config file is absent.
        config = "%s/config" % tdb.ddir.name
        db = Database.load(tdb.ddir.name, extra_config=[config])
        db.records()[:] = [Record([1, 2, 3])]
        db.store()
        # Ensure no exception is raised.


class TestSchemaColumns(unittest.TestCase):
    """Invalid "col" values in the schema must raise SchemaError on load."""

    def check_invalid(self, value):
        """Assert that loading a schema with "col:<value>" raises."""
        tdb = TemporaryDatabase("col:%s\n" % value)
        # Only the load itself is expected to raise; building the fixture
        # just writes the schema file.
        with self.assertRaises(SchemaError):
            tdb.db

    def test_non_integral(self):
        self.check_invalid(3.5)

    def test_negative(self):
        self.check_invalid(-1)

    def test_zero(self):
        self.check_invalid(0)


class TestExecutionAllowed(unittest.TestCase):
    """Check how "exe" settings in schema and config files combine."""

    def do_test(self, expected, schema, configv):
        """Assert execution_allowed() for the given schema/config contents.

        *schema* is appended to the schema file header and *configv* to the
        config file header; *expected* is the value execution_allowed() must
        return after loading both.
        """
        tdb = TemporaryDatabase(schema)
        # The context manager removes the directory even if the assertion
        # below fails.
        with tempfile.TemporaryDirectory() as ddir2:
            config = "%s/config" % ddir2
            with open(config, "w") as fp:
                fp.write("fmt:3:newfol config file:\n" + configv)
            db = Database.load(tdb.ddir.name, extra_config=[config])
            db.records()[:] = [Record([1, 2, 3])]
            db.store()
            self.assertEqual(db.schema().execution_allowed(), expected)

    def test_true_if_only_true_schema(self):
        self.do_test(True, "exe:yes\n", "")

    def test_true_if_only_true_config(self):
        self.do_test(True, "", "exe:yes\n")

    def test_false_if_schema_false(self):
        self.do_test(False, "exe:no\n", "exe:yes\n")

    def test_false_if_config_false(self):
        self.do_test(False, "exe:yes\n", "exe:no\n")


class TestDatabaseFiltering(unittest.TestCase):
    """Database.records() accepts a predicate selecting a subset."""

    def test_filtering(self):
        tdb = TemporaryDatabase("txn:git\n")
        db = tdb.db
        records = [
            Record(["a", "b", "c"]),
            Record([1, 2, 3]),
            Record([7, 8, 9]),
        ]
        db.records().extend(records)

        def has_only_numbers(rec):
            """Return True if every field of *rec* converts with int()."""
            for field in rec.fields:
                try:
                    int(field)
                except (TypeError, ValueError):
                    # int() rejects non-numeric strings and unsupported
                    # types; anything else should propagate.
                    return False
            return True

        selected = db.records(has_only_numbers)
        self.assertEqual(type(selected), list)
        self.assertEqual(set(selected), set(records[1:]))


class TestSingleton(unittest.TestCase):
    """Classes using the Singleton metaclass yield one shared instance."""

    def test_singleton(self):
        # This must be a class statement: written as a function, both calls
        # would return None and the identity check would pass without ever
        # exercising Singleton.
        class TestClass(metaclass=Singleton):
            pass

        a = TestClass()
        b = TestClass()
        self.assertIs(a, b)


if __name__ == '__main__':
    unittest.main()