Move temporary database creation code into a class.

Signed-off-by: brian m. carlson <sandals@crustytoothpaste.net>
This commit is contained in:
brian m. carlson 2014-12-12 21:17:26 +00:00
parent e8e4546870
commit 7fdde37c27
No known key found for this signature in database
GPG key ID: BF535D811F52F68B

View file

@ -59,60 +59,56 @@ class TestDatabaseAccessors(unittest.TestCase):
DatabaseVersion.preferred().serialization()) DatabaseVersion.preferred().serialization())
class TestDatabaseIntegrity(unittest.TestCase): class TemporaryDatabase:
def create_temp_db(self): def __init__(self, schema_contents=""):
ddir = tempfile.TemporaryDirectory() self.ddir = tempfile.TemporaryDirectory()
with open(ddir.name + "/schema", "w") as fp: with open(self.ddir.name + "/schema", "w") as fp:
fp.write("fmt:0:newfol schema file:\ntxn:git\n") fp.write("fmt:3:newfol schema file:\n" + schema_contents)
db = Database.load(ddir.name)
return (ddir, db)
@property
def db(self):
return Database.load(self.ddir.name)
def __del__(self):
self.ddir.cleanup()
class TestDatabaseIntegrity(unittest.TestCase):
def test_version(self): def test_version(self):
ddir, db = self.create_temp_db() tdb = TemporaryDatabase()
self.assertEqual(Database.read_version(ddir.name), self.assertEqual(Database.read_version(tdb.ddir.name),
DatabaseVersion()) DatabaseVersion())
db.store() tdb.db.store()
db.upgrade() tdb.db.upgrade()
self.assertEqual(Database.read_version(ddir.name), self.assertEqual(Database.read_version(tdb.ddir.name),
DatabaseVersion.preferred()) DatabaseVersion.preferred())
ddir.cleanup()
def test_validate(self): def test_validate(self):
ddir, db = self.create_temp_db() tdb = TemporaryDatabase()
db.store() tdb.db.store()
db.validate() tdb.db.validate()
ddir.cleanup()
def test_validate_strict(self): def test_validate_strict(self):
ddir, db = self.create_temp_db() tdb = TemporaryDatabase()
db.store() tdb.db.store()
db.validate(strict=True) tdb.db.validate(strict=True)
ddir.cleanup()
def test_repair_doesnt_raise(self): def test_repair_doesnt_raise(self):
ddir, db = self.create_temp_db() tdb = TemporaryDatabase()
db.store() tdb.db.store()
db.repair() tdb.db.repair()
db.validate(strict=True) tdb.db.validate(strict=True)
ddir.cleanup()
def test_upgrade_records(self): def test_upgrade_records(self):
ddir, db = self.create_temp_db() tdb = TemporaryDatabase()
db.store() tdb.db.store()
db.upgrade_records() tdb.db.upgrade_records()
ddir.cleanup()
class TestDatabaseUpgrades(unittest.TestCase): class TestDatabaseUpgrades(unittest.TestCase):
def create_temp_db(self):
ddir = tempfile.TemporaryDirectory()
with open(ddir.name + "/schema", "w") as fp:
fp.write("fmt:0:newfol schema file:\ntxn:git\n")
db = Database.load(ddir.name)
return (ddir, db)
def do_upgrade_test(self, version, pattern): def do_upgrade_test(self, version, pattern):
ddir, db = self.create_temp_db() tdb = TemporaryDatabase("txn:git\n")
ddir, db = tdb.ddir, tdb.db
if not isinstance(version, DatabaseVersion): if not isinstance(version, DatabaseVersion):
version = DatabaseVersion(version) version = DatabaseVersion(version)
self.assertEqual(Database.read_version(ddir.name), self.assertEqual(Database.read_version(ddir.name),
@ -141,60 +137,49 @@ class TestDatabaseUpgrades(unittest.TestCase):
class TestMultipleTransactions(unittest.TestCase): class TestMultipleTransactions(unittest.TestCase):
def test_multiple_types(self): def test_multiple_types(self):
ddir = tempfile.TemporaryDirectory() tdb = TemporaryDatabase("txn:git:hash\n")
with open(ddir.name + "/schema", "w") as fp: ddir = tdb.ddir
fp.write("fmt:0:newfol schema file:\ntxn:git:hash\n")
db = Database.load(ddir.name) db = Database.load(ddir.name)
db.records()[:] = [Record([1, 2, 3])] db.records()[:] = [Record([1, 2, 3])]
db.store() db.store()
self.assertEqual(set(db.schema().transaction_types()), self.assertEqual(set(db.schema().transaction_types()),
set(["git", "hash"])) set(["git", "hash"]))
ddir.cleanup()
class TestExtraSchemaConfig(unittest.TestCase): class TestExtraSchemaConfig(unittest.TestCase):
def test_existing_config_file(self): def test_existing_config_file(self):
ddir1 = tempfile.TemporaryDirectory() tdb = TemporaryDatabase()
ddir2 = tempfile.TemporaryDirectory() ddir2 = tempfile.TemporaryDirectory()
config = "%s/config" % ddir2.name config = "%s/config" % ddir2.name
with open(ddir1.name + "/schema", "w") as fp:
fp.write("fmt:3:newfol schema file:\n")
with open(config, "w") as fp: with open(config, "w") as fp:
fp.write("fmt:3:newfol config file:\ntxn:git:hash\n") fp.write("fmt:3:newfol config file:\ntxn:git:hash\n")
db = Database.load(ddir1.name, extra_config=[config]) db = Database.load(tdb.ddir.name, extra_config=[config])
db.records()[:] = [Record([1, 2, 3])] db.records()[:] = [Record([1, 2, 3])]
db.store() db.store()
self.assertEqual(set(db.schema().transaction_types()), self.assertEqual(set(db.schema().transaction_types()),
set(["git", "hash"])) set(["git", "hash"]))
ddir1.cleanup()
ddir2.cleanup() ddir2.cleanup()
def test_missing_config_file(self): def test_missing_config_file(self):
ddir1 = tempfile.TemporaryDirectory() tdb = TemporaryDatabase()
config = "%s/config" % ddir1.name config = "%s/config" % tdb.ddir.name
with open(ddir1.name + "/schema", "w") as fp: db = Database.load(tdb.ddir.name, extra_config=[config])
fp.write("fmt:3:newfol schema file:\n")
db = Database.load(ddir1.name, extra_config=[config])
db.records()[:] = [Record([1, 2, 3])] db.records()[:] = [Record([1, 2, 3])]
db.store() db.store()
# Ensure no exception is raised. # Ensure no exception is raised.
ddir1.cleanup()
class TestExecutionAllowed(unittest.TestCase): class TestExecutionAllowed(unittest.TestCase):
def do_test(self, expected, schema, configv): def do_test(self, expected, schema, configv):
ddir1 = tempfile.TemporaryDirectory() tdb = TemporaryDatabase(schema)
ddir2 = tempfile.TemporaryDirectory() ddir2 = tempfile.TemporaryDirectory()
config = "%s/config" % ddir2.name config = "%s/config" % ddir2.name
with open(ddir1.name + "/schema", "w") as fp:
fp.write("fmt:3:newfol schema file:\n" + schema)
with open(config, "w") as fp: with open(config, "w") as fp:
fp.write("fmt:3:newfol config file:\n" + configv) fp.write("fmt:3:newfol config file:\n" + configv)
db = Database.load(ddir1.name, extra_config=[config]) db = Database.load(tdb.ddir.name, extra_config=[config])
db.records()[:] = [Record([1, 2, 3])] db.records()[:] = [Record([1, 2, 3])]
db.store() db.store()
self.assertEqual(db.schema().execution_allowed(), expected) self.assertEqual(db.schema().execution_allowed(), expected)
ddir1.cleanup()
ddir2.cleanup() ddir2.cleanup()
def test_true_if_only_true_schema(self): def test_true_if_only_true_schema(self):
@ -211,15 +196,9 @@ class TestExecutionAllowed(unittest.TestCase):
class TestDatabaseFiltering(unittest.TestCase): class TestDatabaseFiltering(unittest.TestCase):
def create_temp_db(self):
ddir = tempfile.TemporaryDirectory()
with open(ddir.name + "/schema", "w") as fp:
fp.write("fmt:0:newfol schema file:\ntxn:git\n")
db = Database.load(ddir.name)
return (ddir, db)
def test_filtering(self): def test_filtering(self):
ddir, db = self.create_temp_db() tdb = TemporaryDatabase("txn:git\n")
db = tdb.db
records = [ records = [
Record(["a", "b", "c"]), Record(["a", "b", "c"]),
Record([1, 2, 3]), Record([1, 2, 3]),
@ -237,7 +216,6 @@ class TestDatabaseFiltering(unittest.TestCase):
selected = db.records(has_only_numbers) selected = db.records(has_only_numbers)
self.assertEqual(type(selected), list) self.assertEqual(type(selected), list)
self.assertEqual(set(selected), set(records[1:])) self.assertEqual(set(selected), set(records[1:]))
ddir.cleanup()
class TestSingleton(unittest.TestCase): class TestSingleton(unittest.TestCase):