newfol/test/testdatabase.py
brian m. carlson 63345a969d
Prefer JSON over pickle.
This allows for better diffing, and it also makes it possible to load
untrusted data from the dtb file if that becomes necessary.

Signed-off-by: brian m. carlson <sandals@crustytoothpaste.net>
2014-02-14 02:58:47 +00:00

167 lines
6.2 KiB
Python
Executable file

#!/usr/bin/python3
import os
import os.path
import pwd
import socket
import subprocess
import tempfile
import unittest

from newfol.database import DatabaseVersion, Database, Schema, Singleton
from newfol.filemanip import Record
class TestDatabaseVersion(unittest.TestCase):
    """Check each component accessor of DatabaseVersion.

    One test covers the default-constructed version; the rest pin every
    field of DatabaseVersion.preferred().
    """

    def test_default_serialization(self):
        # A default-constructed version is expected to report CSV.
        default_version = DatabaseVersion()
        self.assertEqual(default_version.serialization(), "csv")

    def test_preferred_serialization(self):
        # The preferred on-disk serialization is JSON.
        preferred = DatabaseVersion.preferred()
        self.assertEqual(preferred.serialization(), "json")

    def test_preferred_serialization_version(self):
        preferred = DatabaseVersion.preferred()
        self.assertEqual(preferred.serialization_version(), 2)

    def test_preferred_compression(self):
        # The preferred version applies no compression at all.
        preferred = DatabaseVersion.preferred()
        self.assertEqual(preferred.compression(), None)

    def test_preferred_compression_version(self):
        preferred = DatabaseVersion.preferred()
        self.assertEqual(preferred.compression_version(), 0)

    def test_preferred_record_version(self):
        preferred = DatabaseVersion.preferred()
        self.assertEqual(preferred.record_version(), 3)

    def test_preferred_version_version(self):
        preferred = DatabaseVersion.preferred()
        self.assertEqual(preferred.version_version(), 0)
class TestDatabaseAccessors(unittest.TestCase):
    """Exercise the simple accessor methods of an in-memory Database."""

    def test_location(self):
        # Whatever location is passed to the constructor comes back as-is.
        candidate_paths = ("/tmp/foo", "/bar", "/home/quux/.newfol")
        for path in candidate_paths:
            database = Database(path, 2, Schema(), [])
            self.assertEqual(database.location(), path)

    def test_records(self):
        # A single three-field record is stored and returned intact.
        seed = [Record([1, 2, 3])]
        database = Database("/tmp/foo", 2, Schema(), seed)
        stored = database.records()
        self.assertEqual(len(stored), 1)
        first = stored[0]
        self.assertEqual(len(first.fields), 3)

    def test_version(self):
        # The default version compares equal to 0; the preferred version
        # is returned unchanged.
        default_db = Database("/tmp/foo", DatabaseVersion(), Schema(), [])
        self.assertEqual(default_db.version(), 0)
        preferred_db = Database("/tmp/foo", DatabaseVersion.preferred(),
                                Schema(), [])
        self.assertEqual(preferred_db.version(), DatabaseVersion.preferred())

    def test_serialization(self):
        # Database.serialization() delegates to its version's serialization.
        default_db = Database("/tmp/foo", DatabaseVersion(), Schema(), [])
        expected_default = DatabaseVersion(0).serialization()
        self.assertEqual(default_db.serialization(), expected_default)
        preferred_db = Database("/tmp/foo", DatabaseVersion.preferred(),
                                Schema(), [])
        expected_preferred = DatabaseVersion.preferred().serialization()
        self.assertEqual(preferred_db.serialization(), expected_preferred)
class TestDatabaseIntegrity(unittest.TestCase):
    """Store a fresh git-backed database and run integrity operations on it."""

    def create_temp_db(self):
        """Create an empty git-backed database in a new temporary directory.

        The directory is registered with ``addCleanup`` so it is removed even
        when the calling test fails part-way through (previously each test
        called ``ddir.cleanup()`` on its last line, which was skipped on
        failure).

        Returns:
            A ``(TemporaryDirectory, Database)`` tuple.
        """
        ddir = tempfile.TemporaryDirectory()
        self.addCleanup(ddir.cleanup)
        # Minimal schema: header line plus git transactions.
        with open(os.path.join(ddir.name, "schema"), "w") as fp:
            fp.write("fmt:0:newfol schema file:\ntxn:git\n")
        db = Database.load(ddir.name)
        return (ddir, db)

    def test_version(self):
        """A new database reads as the default version until upgraded."""
        ddir, db = self.create_temp_db()
        self.assertEqual(Database.read_version(ddir.name), DatabaseVersion())
        db.store()
        # upgrade() with no arguments is expected to reach the preferred
        # version.
        db.upgrade()
        self.assertEqual(Database.read_version(ddir.name),
                         DatabaseVersion.preferred())

    def test_validate(self):
        """A freshly stored database validates."""
        _, db = self.create_temp_db()
        db.store()
        db.validate()

    def test_validate_strict(self):
        """A freshly stored database also passes strict validation."""
        _, db = self.create_temp_db()
        db.store()
        db.validate(strict=True)

    def test_repair_doesnt_raise(self):
        """repair() succeeds and leaves the database strictly valid."""
        _, db = self.create_temp_db()
        db.store()
        db.repair()
        db.validate(strict=True)

    def test_upgrade_records(self):
        """upgrade_records() runs without raising on a fresh database."""
        _, db = self.create_temp_db()
        db.store()
        db.upgrade_records()
class TestDatabaseUpgrades(unittest.TestCase):
    """Upgrade a fresh database to specific versions and check the dtb bytes."""

    def create_temp_db(self):
        """Create an empty git-backed database in a new temporary directory.

        The directory is removed through ``addCleanup``, so it does not leak
        when an assertion in the calling test fails.

        Returns:
            A ``(TemporaryDirectory, Database)`` tuple.
        """
        ddir = tempfile.TemporaryDirectory()
        self.addCleanup(ddir.cleanup)
        with open(os.path.join(ddir.name, "schema"), "w") as fp:
            fp.write("fmt:0:newfol schema file:\ntxn:git\n")
        db = Database.load(ddir.name)
        return (ddir, db)

    def do_upgrade_test(self, version, pattern):
        """Upgrade a fresh database to *version* and check the dtb header.

        Args:
            version: a DatabaseVersion, or a value accepted by the
                DatabaseVersion constructor (e.g. the int produced by the
                bit arithmetic in the tests below).
            pattern: bytes the upgraded ``dtb`` file must start with.
        """
        ddir, db = self.create_temp_db()
        if not isinstance(version, DatabaseVersion):
            version = DatabaseVersion(version)
        self.assertEqual(Database.read_version(ddir.name), DatabaseVersion())
        db.store()
        db.upgrade(version=version)
        self.assertEqual(Database.read_version(ddir.name), version)
        # Read inside ``with`` so the file is closed even if the comparison
        # below fails.
        with open(os.path.join(ddir.name, "dtb"), "rb") as fp:
            data = fp.read(len(pattern))
        self.assertEqual(data, pattern)

    def test_upgrade_to_xz(self):
        # Set the byte masked by 0x00ff0000 to 1; the expected xz magic
        # shows this byte selects compression.
        version = (DatabaseVersion.preferred() & ~0x00ff0000) | 0x00010000
        self.do_upgrade_test(version, b"\xfd7zXZ\x00")

    def test_upgrade_to_json(self):
        # Set the byte masked by 0x0000ff00 to 2 (JSON serialization); the
        # file then begins with a JSON array.  (This test was previously
        # defined twice; the duplicate silently shadowed the first.)
        version = (DatabaseVersion.preferred() & ~0x0000ff00) | 0x00000200
        self.do_upgrade_test(version, b"[")

    def test_upgrade_to_pickle(self):
        # Serialization byte 1 selects pickle; b"(lp0\n" is the start of a
        # protocol-0 pickled list.
        version = (DatabaseVersion.preferred() & ~0x0000ff00) | 0x00000100
        self.do_upgrade_test(version, b"(lp0\n")
class TestGitTransactions(unittest.TestCase):
    """Check that storing a git-backed database produces a proper commit."""

    def setUp(self):
        """Create a git-backed database in a temporary directory per test."""
        self.ddir = tempfile.TemporaryDirectory()
        # Remove the directory after each test instead of leaving it to be
        # cleaned up implicitly when the object is garbage-collected.
        self.addCleanup(self.ddir.cleanup)
        with open(os.path.join(self.ddir.name, "schema"), "w") as fp:
            fp.write("fmt:0:newfol schema file:\ntxn:git\n")
        self.db = Database.load(self.ddir.name)

    def test_has_git_dir(self):
        """store() initializes a git repository in the database directory."""
        self.db.store()
        self.assertTrue(os.path.isdir(os.path.join(self.ddir.name, ".git")))

    def get_format_data(self, fmt):
        """Return the latest commit in the test repo formatted with *fmt*.

        Runs ``git log --pretty=format:<fmt> -n1`` with ``cwd`` set to the
        database directory.  This avoids ``os.chdir``, which changes the
        working directory for the whole process.  The arguments are passed
        as a list, so *fmt* never goes through a shell.

        Args:
            fmt: a git pretty-format placeholder string such as ``"%cn"``.

        Returns:
            The decoded stdout of git.

        Raises:
            subprocess.CalledProcessError: if git exits with a non-zero status.
        """
        return subprocess.check_output(
            ["git", "log", "--pretty=format:" + fmt, "-n1"],
            cwd=self.ddir.name,
            universal_newlines=True,
        )

    def test_correct_committer_name(self):
        """Commits are attributed to the committer name "newfol"."""
        self.db.store()
        self.assertEqual(self.get_format_data("%cn"), "newfol")

    def test_correct_committer_email(self):
        """The committer email is <login name>@<fully qualified host name>."""
        self.db.store()
        user = pwd.getpwuid(os.getuid()).pw_name
        fqdn = socket.getfqdn()
        self.assertEqual(self.get_format_data("%ce"), user + "@" + fqdn)
class TestSingleton(unittest.TestCase):
    """Check that classes using the Singleton metaclass yield one instance."""

    def test_singleton(self):
        """Two instantiations of a Singleton-metaclass class are identical."""
        # This must be a *class* whose metaclass is Singleton.  It was
        # previously written with ``def``, which made TestClass an ordinary
        # function returning None, so ``assertIs(None, None)`` passed
        # without exercising Singleton at all.
        class TestClass(metaclass=Singleton):
            pass

        a = TestClass()
        b = TestClass()
        # Guard against the vacuous None-is-None pass described above.
        self.assertIsInstance(a, TestClass)
        self.assertIs(a, b)
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()