diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 20ca061..20a1519 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -7,7 +7,7 @@ from migrate.versioning.util import construct_engine from sqlalchemy.pool import NullPool from sqlalchemy.schema import MetaData, Column, Index from sqlalchemy.schema import Table as SQLATable -from sqlalchemy import Integer +from sqlalchemy import Integer, String from dataset.persistence.table import Table from dataset.persistence.util import ResultIter @@ -97,10 +97,16 @@ class Database(object): return list(set(self.metadata.tables.keys() + self._tables.keys())) - def create_table(self, table_name): + def create_table(self, table_name, primary_id=None, is_integer=False): """ - Creates a new table. The new table will automatically have an `id` column, which is - set to be an auto-incrementing integer as the primary key of the table. + Creates a new table. The new table will automatically have an `id` column + unless specified via optional parameter primary_id, which will be used + as the primary key of the table. Automatic id is set to be an + auto-incrementing integer, while the type of custom primary_id can be a + string (i.e. VARCHAR(255)) by default or an integer via integer flag. + The caller will be responsible for the uniqueness of manually specified primary_id. + + This custom id feature is only available via direct create_table call. Returns a :py:class:`Table ` instance. :: @@ -111,7 +117,14 @@ class Database(object): try: log.debug("Creating table: %s on %r" % (table_name, self.engine)) table = SQLATable(table_name, self.metadata) - col = Column('id', Integer, primary_key=True) + if not primary_id: + col = Column('id', Integer, primary_key=True) + else: + if is_integer: + col = Column(primary_id, Integer, primary_key=True, autoincrement=False) + else: + col = Column(primary_id, String(255), primary_key=True) + table.append_column(col) table.create(self.engine) self._tables[table_name] = table diff --git a/test/test_persistence.py b/test/test_persistence.py index 893ceb3..80614b8 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -5,13 +5,14 @@ from datetime import datetime from dataset import connect from dataset.util import DatasetException from sample_data import TEST_DATA +from sqlalchemy.exc import IntegrityError class DatabaseTestCase(unittest.TestCase): def setUp(self): os.environ['DATABASE_URL'] = 'sqlite:///:memory:' - self.db = connect() + self.db = connect('sqlite:///:memory:') self.tbl = self.db['weather'] for row in TEST_DATA: self.tbl.insert(row) @@ -32,6 +33,31 @@ class DatabaseTestCase(unittest.TestCase): assert len(table.table.columns) == 1, table.table.columns assert 'id' in table.table.c, table.table.c + def test_create_table_custom_id1(self): + pid = "string_id" + table = self.db.create_table("foo2", primary_id = pid) + assert table.table.exists() + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({ + 'string_id': 'foobar'}) + assert table.find_one(string_id = 'foobar')[0] == 'foobar' + + def test_create_table_custom_id2(self): + pid = "int_id" + table = self.db.create_table("foo3", primary_id = pid, is_integer=True) + assert table.table.exists() + assert len(table.table.columns) == 1, table.table.columns + assert pid in table.table.c, table.table.c + + table.insert({'int_id': 123}) + table.insert({'int_id': 124}) + assert table.find_one(int_id = 123)[0] == 123 + assert table.find_one(int_id = 124)[0] == 124 + with self.assertRaises(IntegrityError): + table.insert({'int_id': 123}) + def test_load_table(self): tbl = self.db.load_table('weather') assert tbl.table == self.tbl.table