diff --git a/sqlaload/schema.py b/sqlaload/schema.py index 911ee85..44a4ba6 100644 --- a/sqlaload/schema.py +++ b/sqlaload/schema.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from collections import defaultdict +from threading import RLock from sqlalchemy import create_engine from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean @@ -9,6 +10,7 @@ from sqlalchemy.sql import and_, expression from migrate.versioning.util import construct_engine log = logging.getLogger(__name__) +lock = RLock() TABLES = defaultdict(dict) INDEXES = dict() @@ -16,8 +18,8 @@ INDEXES = dict() def connect(url): """ Create an engine for the given database URL. """ kw = {} - if url.startswith('postgres'): - kw['pool_size'] = 1 + #if url.startswith('postgres'): + # kw['pool_size'] = 10 engine = create_engine(url, **kw) engine = construct_engine(engine) meta = MetaData() @@ -26,44 +28,49 @@ def connect(url): return engine def create_table(engine, table_name): - log.debug("Creating table: %s on %r" % (table_name, engine)) - table = Table(table_name, engine._metadata) - col = Column('id', Integer, primary_key=True) - table.append_column(col) - table.create(engine) - TABLES[engine][table_name] = table - return table + with lock: + log.debug("Creating table: %s on %r" % (table_name, engine)) + table = Table(table_name, engine._metadata) + col = Column('id', Integer, primary_key=True) + table.append_column(col) + table.create(engine) + TABLES[engine][table_name] = table + return table def load_table(engine, table_name): - log.debug("Loading table: %s on %r" % (table_name, engine)) - table = Table(table_name, engine._metadata, autoload=True) - TABLES[engine][table_name] = table - return table + with lock: + log.debug("Loading table: %s on %r" % (table_name, engine)) + table = Table(table_name, engine._metadata, autoload=True) + TABLES[engine][table_name] = table + return table def get_table(engine, table_name): # Accept Connection objects here if hasattr(engine, 'engine'): engine = engine.engine - - if table_name in TABLES[engine]: - return TABLES[engine][table_name] - if engine.has_table(table_name): - return load_table(engine, table_name) - else: - return create_table(engine, table_name) + + with lock: + if table_name in TABLES[engine]: + return TABLES[engine][table_name] + if engine.has_table(table_name): + return load_table(engine, table_name) + else: + return create_table(engine, table_name) def drop_table(engine, table_name): # Accept Connection objects here if hasattr(engine, 'engine'): engine = engine.engine - if table_name in TABLES[engine]: - table = TABLES[engine][table_name] - elif engine.has_table(table_name): - table = Table(table_name, engine._metadata) - else: - return - table.drop(engine) - TABLES[engine].pop(table_name, None) + + with lock: + if table_name in TABLES[engine]: + table = TABLES[engine][table_name] + elif engine.has_table(table_name): + table = Table(table_name, engine._metadata) + else: + return + table.drop(engine) + TABLES[engine].pop(table_name, None) def _guess_type(sample): if isinstance(sample, bool): @@ -94,22 +101,24 @@ def _args_to_clause(table, args): return and_(*clauses) def create_column(engine, table, name, type): - if name not in table.columns.keys(): - col = Column(name, type) - col.create(table, connection=engine) + with lock: + if name not in table.columns.keys(): + col = Column(name, type) + col.create(table, connection=engine) def create_index(engine, table, columns, name=None): - if not name: - sig = abs(hash('||'.join(columns))) - name = 'ix_%s_%s' % (table.name, sig) - if name in INDEXES: - return INDEXES[name] - try: - columns = [table.c[c] for c in columns] - idx = Index(name, *columns) - idx.create(engine) - except: - idx = None - INDEXES[name] = idx - return idx + with lock: + if not name: + sig = abs(hash('||'.join(columns))) + name = 'ix_%s_%s' % (table.name, sig) + if name in INDEXES: + return INDEXES[name] + try: + columns = [table.c[c] for c in columns] + idx = Index(name, *columns) + idx.create(engine) + except: + idx = None + INDEXES[name] = idx + return idx