From 30867fad21d887cec4e6ecb0b027748541a4fbbc Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Mon, 1 Apr 2013 18:05:41 +0200 Subject: [PATCH] Refactor more stuff up to an object. --- dataset/persistence/database.py | 38 +++++++++++++++++++++-- dataset/persistence/table.py | 14 +++++++++ dataset/persistence/util.py | 16 ++++++++++ dataset/schema.py | 55 --------------------------------- 4 files changed, 65 insertions(+), 58 deletions(-) create mode 100644 dataset/persistence/table.py diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index c7a8814..8f34f3b 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -4,6 +4,12 @@ from threading import RLock from sqlalchemy import create_engine from migrate.versioning.util import construct_engine from sqlalchemy.pool import NullPool +from sqlalchemy.schema import MetaData, Column, Index +from sqlalchemy.schema import Table as SQLATable +from sqlalchemy import Integer + +from dataset.persistence.table import Table + log = logging.getLogger(__name__) @@ -23,13 +29,39 @@ class Database(object): self.tables = {} self.indexes = {} + def create_table(table_name): + with self.lock: + log.debug("Creating table: %s on %r" % (table_name, self.engine)) + table = SQLATable(table_name, self.metadata) + col = Column('id', Integer, primary_key=True) + table.append_column(col) + table.create(self.engine) + self.tables[table_name] = table + return Table(self, table) + + def load_table(self, table_name): + with self.lock: + log.debug("Loading table: %s on %r" % (table_name, self)) + table = SQLATable(table_name, self.metadata, autoload=True) + self.tables[table_name] = table + return Table(self, table) + + def get_table(self, table_name): + with self.lock: + if table_name in self.tables: + return Table(self, self.tables[table_name]) + if self.engine.has_table(table_name): + return load_table(engine, table_name) + else: + return create_table(engine, table_name) + + def __getitem__(self, table_name): + return self.get_table(table_name) + @classmethod def connect(self, url): return Database(url) - def __repr__(self): return '' % self.url - - diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py new file mode 100644 index 0000000..ded753e --- /dev/null +++ b/dataset/persistence/table.py @@ -0,0 +1,14 @@ + + +class Table(object): + + def __init__(self, database, table): + self.database = database + self.table = table + + def drop(self): + with self.database.lock: + self.database.tables.pop(self.table.name, None) + self.table.drop(engine) + + diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index dacdc60..9287549 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -1,3 +1,19 @@ +from datetime import datetime + +from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean + + +def guess_type(sample): + if isinstance(sample, bool): + return Boolean + elif isinstance(sample, int): + return Integer + elif isinstance(sample, float): + return Float + elif isinstance(sample, datetime): + return DateTime + return UnicodeText + def resultiter(rp): """ SQLAlchemy ResultProxies are not iterable to get a diff --git a/dataset/schema.py b/dataset/schema.py index 40b5d19..0ade8ed 100644 --- a/dataset/schema.py +++ b/dataset/schema.py @@ -9,64 +9,9 @@ from sqlalchemy.sql import and_, expression log = logging.getLogger(__name__) -def create_table(engine, table_name): - with lock: - log.debug("Creating table: %s on %r" % (table_name, engine)) - table = Table(table_name, engine._metadata) - col = Column('id', Integer, primary_key=True) - table.append_column(col) - table.create(engine) - engine._tables[table_name] = table - return table -def load_table(engine, table_name): - with lock: - log.debug("Loading table: %s on %r" % (table_name, engine)) - table = Table(table_name, engine._metadata, autoload=True) - engine._tables[table_name] = table - return table -def get_table(engine, table_name): - if isinstance(table_name, Table): - return table_name - # Accept Connection objects here - if hasattr(engine, 'engine'): - engine = engine.engine - - with lock: - if table_name in engine._tables: - return engine._tables[table_name] - if engine.has_table(table_name): - return load_table(engine, table_name) - else: - return create_table(engine, table_name) - -def drop_table(engine, table_name): - # Accept Connection objects here - if hasattr(engine, 'engine'): - engine = engine.engine - - with lock: - if table_name in engine._tables: - table = engine._tables[table_name] - elif engine.has_table(table_name): - table = Table(table_name, engine._metadata) - else: - return - table.drop(engine) - engine._tables.pop(table_name, None) - -def _guess_type(sample): - if isinstance(sample, bool): - return Boolean - elif isinstance(sample, int): - return Integer - elif isinstance(sample, float): - return Float - elif isinstance(sample, datetime): - return DateTime - return UnicodeText def _ensure_columns(engine, table, row, types={}): columns = set(row.keys()) - set(table.columns.keys())