create table with custom primary_id

This commit is contained in:
Yi Xie 2013-09-08 11:35:43 -04:00
parent 2d27a75c1a
commit bfd75360be
2 changed files with 45 additions and 6 deletions

View File

@ -7,7 +7,7 @@ from migrate.versioning.util import construct_engine
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from sqlalchemy.schema import MetaData, Column, Index from sqlalchemy.schema import MetaData, Column, Index
from sqlalchemy.schema import Table as SQLATable from sqlalchemy.schema import Table as SQLATable
from sqlalchemy import Integer from sqlalchemy import Integer, String
from dataset.persistence.table import Table from dataset.persistence.table import Table
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter
@ -97,10 +97,16 @@ class Database(object):
return list(set(self.metadata.tables.keys() + return list(set(self.metadata.tables.keys() +
self._tables.keys())) self._tables.keys()))
def create_table(self, table_name): def create_table(self, table_name, primary_id=None, is_integer=False):
""" """
Creates a new table. The new table will automatically have an `id` column, which is Creates a new table. The new table will automatically have an `id` column
set to be an auto-incrementing integer as the primary key of the table. unless specified via optional parameter primary_id, which will be used
as the primary key of the table. Automatic id is set to be an
auto-incrementing integer, while the type of custom primary_id can be a
string (i.e. VARCHAR(255)) by default or an integer via integer flag.
The caller will be responsible for the uniqueness of manually specified primary_id.
This custom id feature is only available via direct create_table call.
Returns a :py:class:`Table <dataset.Table>` instance. Returns a :py:class:`Table <dataset.Table>` instance.
:: ::
@ -111,7 +117,14 @@ class Database(object):
try: try:
log.debug("Creating table: %s on %r" % (table_name, self.engine)) log.debug("Creating table: %s on %r" % (table_name, self.engine))
table = SQLATable(table_name, self.metadata) table = SQLATable(table_name, self.metadata)
col = Column('id', Integer, primary_key=True) if not primary_id:
col = Column('id', Integer, primary_key=True)
else:
if is_integer:
col = Column(primary_id, Integer, primary_key=True, autoincrement=False)
else:
col = Column(primary_id, String(255), primary_key=True)
table.append_column(col) table.append_column(col)
table.create(self.engine) table.create(self.engine)
self._tables[table_name] = table self._tables[table_name] = table

View File

@ -5,13 +5,14 @@ from datetime import datetime
from dataset import connect from dataset import connect
from dataset.util import DatasetException from dataset.util import DatasetException
from sample_data import TEST_DATA from sample_data import TEST_DATA
from sqlalchemy.exc import IntegrityError
class DatabaseTestCase(unittest.TestCase): class DatabaseTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
os.environ['DATABASE_URL'] = 'sqlite:///:memory:' os.environ['DATABASE_URL'] = 'sqlite:///:memory:'
self.db = connect() self.db = connect('sqlite:///:memory:')
self.tbl = self.db['weather'] self.tbl = self.db['weather']
for row in TEST_DATA: for row in TEST_DATA:
self.tbl.insert(row) self.tbl.insert(row)
@ -32,6 +33,31 @@ class DatabaseTestCase(unittest.TestCase):
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert 'id' in table.table.c, table.table.c assert 'id' in table.table.c, table.table.c
def test_create_table_custom_id1(self):
pid = "string_id"
table = self.db.create_table("foo2", primary_id = pid)
assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
table.insert({
'string_id': 'foobar'})
assert table.find_one(string_id = 'foobar')[0] == 'foobar'
def test_create_table_custom_id2(self):
pid = "int_id"
table = self.db.create_table("foo3", primary_id = pid, is_integer=True)
assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
table.insert({'int_id': 123})
table.insert({'int_id': 124})
assert table.find_one(int_id = 123)[0] == 123
assert table.find_one(int_id = 124)[0] == 124
with self.assertRaises(IntegrityError):
table.insert({'int_id': 123})
def test_load_table(self): def test_load_table(self):
tbl = self.db.load_table('weather') tbl = self.db.load_table('weather')
assert tbl.table == self.tbl.table assert tbl.table == self.tbl.table