diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 7072853..ca84604 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -55,7 +55,7 @@ class Database(object): self.url = url self.row_type = row_type self.ensure_schema = ensure_schema - self.reflect_views = reflect_views + self._tables = {} @property def executable(self): @@ -75,15 +75,8 @@ class Database(object): """Return a SQLAlchemy schema cache object.""" return MetaData(schema=self.schema, bind=self.executable) - def _acquire(self): - self.lock.acquire() - - def _release(self): - self.lock.release() - def begin(self): - """ - Enter a transaction explicitly. + """Enter a transaction explicitly. No data will be written until the transaction has been committed. """ @@ -92,8 +85,7 @@ class Database(object): self.local.tx.append(self.executable.begin()) def commit(self): - """ - Commit the current transaction. + """Commit the current transaction. Make all statements executed since the transaction was begun permanent. """ @@ -133,9 +125,12 @@ class Database(object): inspector = Inspector.from_engine(self.executable) return inspector.get_table_names(schema=self.schema) - def __contains__(self, member): + def __contains__(self, table_name): """Check if the given table name exists in the database.""" - return member in self.tables + try: + return self._valid_table_name(table_name) in self.tables + except ValueError: + return False def _valid_table_name(self, table_name): """Check if the table name is obviously invalid.""" @@ -146,15 +141,12 @@ class Database(object): def create_table(self, table_name, primary_id=None, primary_type=None): """Create a new table. - The new table will automatically have an `id` column unless specified via - optional parameter primary_id, which will be used as the primary key of the - table. Automatic id is set to be an auto-incrementing integer, while the - type of custom primary_id can be a - String or an Integer as specified with primary_type flag. The default - length of String is 255. The caller can specify the length. - The caller will be responsible for the uniqueness of manual primary_id. - - This custom id feature is only available via direct create_table call. + Either loads a table or creates it if it doesn't exist yet. You can + define the name and type of the primary key field, if a new table is to + be created. The default is to create an auto-incrementing integer, + `id`. You can also set the primary key to be a string or big integer. + The caller will be responsible for the uniqueness of `primary_id` if + it is defined as a text type. Returns a :py:class:`Table ` instance. :: @@ -177,11 +169,9 @@ class Database(object): assert not isinstance(primary_type, six.string_types), \ 'Text-based primary_type support is dropped, use db.types.' table_name = self._valid_table_name(table_name) - table = self._reflect_table(table_name) - if table is not None: - return Table(self, table) - self._acquire() - try: + with self.lock: + if table_name in self: + return self.load_table(table_name) log.debug("Creating table: %s" % (table_name)) table = SQLATable(table_name, self.metadata, schema=self.schema) if primary_id is not False: @@ -194,9 +184,8 @@ class Database(object): autoincrement=autoincrement) table.append_column(col) table.create(self.executable, checkfirst=True) - return Table(self, table) - finally: - self._release() + self._tables[table_name] = Table(self, table) + return self._tables[table_name] def load_table(self, table_name): """Load a table. @@ -211,38 +200,37 @@ class Database(object): table = db.load_table('population') """ table_name = self._valid_table_name(table_name) + if table_name in self._tables: + return self._tables.get(table_name) log.debug("Loading table: %s", table_name) - table = self._reflect_table(table_name) - if table is not None: - return Table(self, table) + with self.lock: + table = self._reflect_table(table_name) + if table is not None: + self._tables[table_name] = Table(self, table) + return self._tables[table_name] def _reflect_table(self, table_name): """Reload a table schema from the database.""" table_name = self._valid_table_name(table_name) try: - return SQLATable(table_name, - self.metadata, - schema=self.schema, - autoload=True, - autoload_with=self.executable) + table = SQLATable(table_name, + self.metadata, + schema=self.schema, + autoload=True, + autoload_with=self.executable) + return table except NoSuchTableError: return None def get_table(self, table_name, primary_id=None, primary_type=None): - """ - Smart wrapper around *load_table* and *create_table*. + """Load or create a table. - Either loads a table or creates it if it doesn't exist yet. - For short-hand to create a table with custom id and type using [], where - table_name, primary_id, and primary_type are specified as a tuple - - Returns a :py:class:`Table ` instance. + This is now the same as `create_table`. :: table = db.get_table('population') # you can also use the short-hand syntax: table = db['population'] - """ return self.create_table(table_name, primary_id, primary_type) @@ -250,7 +238,7 @@ class Database(object): """Get a given table.""" return self.get_table(table_name) - def query(self, query, **kw): + def query(self, query, **kwargs): """Run a statement on the database directly. Allows for the execution of arbitrary read/write queries. A query can either be @@ -269,8 +257,9 @@ class Database(object): """ if isinstance(query, six.string_types): query = text(query) - return ResultIter(self.executable.execute(query, **kw), - row_type=self.row_type) + _step = kwargs.pop('_step', 5000) + return ResultIter(self.executable.execute(query, **kwargs), + row_type=self.row_type, step=_step) def __repr__(self): """Text representation contains the URL.""" diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 0f57dc3..649a17c 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -22,13 +22,16 @@ class Table(object): """Initialise the table from database schema.""" self.indexes = dict((i.name, i) for i in table.indexes) self.database = database + self.name = table.name self.table = table self._is_dropped = False @property def columns(self): """Get a listing of all columns that exist in the table.""" - return list(self.table.columns.keys()) + if self.table is None: + return [] + return self.table.columns.keys() def drop(self): """ @@ -40,16 +43,15 @@ class Table(object): sure to get a fresh instance from the :py:class:`Database `. """ self._check_dropped() - self.database._acquire() - try: + with self.database.lock: self.table.drop(self.database.executable, checkfirst=True) + # self.database._tables.pop(self.name, None) self.table = None - finally: - self.database._release() def _check_dropped(self): + # self.table = self.database._reflect_table(self.name) if self.table is None: - raise DatasetException('the table has been dropped. this object should not be used again.') + raise DatasetException('The table has been dropped.') def _prune_row(self, row): """Remove keys from row not in column set.""" @@ -165,10 +167,10 @@ class Table(object): they will be created based on the settings of ``ensure`` and ``types``, matching the behavior of :py:meth:`insert() `. """ + self._check_dropped() # check whether keys arg is a string and format as a list if not isinstance(keys, (list, tuple)): keys = [keys] - self._check_dropped() if not keys or len(keys) == len(row): return False clause = [(u, row.get(u)) for u in keys] @@ -297,8 +299,7 @@ class Table(object): table.create_column('created_at', db.types.datetime) """ self._check_dropped() - self.database._acquire() - try: + with self.database.lock: name = normalize_column_name(name) if name in self.columns: log.debug("Column exists: %s" % name) @@ -310,8 +311,6 @@ class Table(object): self.table.schema ) self.table = self.database._reflect_table(self.table.name) - finally: - self.database._release() def create_column_by_example(self, name, value): """ @@ -336,19 +335,16 @@ class Table(object): if self.database.engine.dialect.name == 'sqlite': raise NotImplementedError("SQLite does not support dropping columns.") self._check_dropped() - if name not in self.table.columns.keys(): + if name not in self.columns: log.debug("Column does not exist: %s", name) return - self.database._acquire() - try: + with self.database.lock: self.database.op.drop_column( self.table.name, name, self.table.schema ) self.table = self.database.update_table(self.table.name) - finally: - self.database._release() def create_index(self, columns, name=None, **kw): """ @@ -360,51 +356,30 @@ class Table(object): table.create_index(['name', 'country']) """ self._check_dropped() - if not name: - sig = '||'.join(columns) + with self.database.lock: + if not name: + sig = '||'.join(columns) + + # This is a work-around for a bug in <=0.6.1 which would + # create indexes based on hash() rather than a proper hash. + key = abs(hash(sig)) + name = 'ix_%s_%s' % (self.table.name, key) + if name in self.indexes: + return self.indexes[name] + + key = sha1(sig.encode('utf-8')).hexdigest()[:16] + name = 'ix_%s_%s' % (self.table.name, key) - # This is a work-around for a bug in <=0.6.1 which would create - # indexes based on hash() rather than a proper hash. - key = abs(hash(sig)) - name = 'ix_%s_%s' % (self.table.name, key) if name in self.indexes: return self.indexes[name] - - key = sha1(sig.encode('utf-8')).hexdigest()[:16] - name = 'ix_%s_%s' % (self.table.name, key) - - if name in self.indexes: - return self.indexes[name] - try: - self.database._acquire() - columns = [self.table.c[c] for c in columns] - idx = Index(name, *columns, **kw) - idx.create(self.database.executable) - except: - idx = None - finally: - self.database._release() - self.indexes[name] = idx - return idx - - def find_one(self, *args, **kwargs): - """ - Get a single result from the table. - - Works just like :py:meth:`find() ` but returns one result, or None. - - :: - - row = table.find_one(country='United States') - """ - kwargs['_limit'] = 1 - kwargs['_step'] = None - resiter = self.find(*args, **kwargs) - try: - for row in resiter: - return row - finally: - resiter.close() + try: + columns = [self.table.c[c] for c in columns] + idx = Index(name, *columns, **kw) + idx.create(self.database.executable) + except: + idx = None + self.indexes[name] = idx + return idx def _args_to_order_by(self, order_by): if not isinstance(order_by, (list, tuple)): @@ -467,6 +442,24 @@ class Table(object): return ResultIter(self.database.executable.execute(query), row_type=self.database.row_type, step=_step) + def find_one(self, *args, **kwargs): + """Get a single result from the table. + + Works just like :py:meth:`find() ` but returns one + result, or None. + :: + + row = table.find_one(country='United States') + """ + kwargs['_limit'] = 1 + kwargs['_step'] = None + resiter = self.find(*args, **kwargs) + try: + for row in resiter: + return row + finally: + resiter.close() + def count(self, *_clauses, **kwargs): """Return the count of results for the given filter set.""" # NOTE: this does not have support for limit and offset since I can't diff --git a/test/test_persistence.py b/test/test_persistence.py index 99e73ff..86fd85f 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -346,6 +346,12 @@ class TableTestCase(unittest.TestCase): self.db['weather'].drop() assert 'weather' not in self.db + def test_table_drop_then_create(self): + assert 'weather' in self.db + self.db['weather'].drop() + assert 'weather' not in self.db + self.db['weather'].insert({'foo': 'bar'}) + def test_columns(self): cols = self.tbl.columns assert len(list(cols)) == 4, 'column count mismatch'