Update schema to be thread-safe

This commit is contained in:
Friedrich Lindenberg 2012-07-23 15:27:42 +02:00
parent 2c0f88ea74
commit ddcdfea48f

View File

@ -1,6 +1,7 @@
import logging import logging
from datetime import datetime from datetime import datetime
from collections import defaultdict from collections import defaultdict
from threading import RLock
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
@ -9,6 +10,7 @@ from sqlalchemy.sql import and_, expression
from migrate.versioning.util import construct_engine from migrate.versioning.util import construct_engine
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
lock = RLock()
TABLES = defaultdict(dict) TABLES = defaultdict(dict)
INDEXES = dict() INDEXES = dict()
@ -16,8 +18,8 @@ INDEXES = dict()
def connect(url): def connect(url):
""" Create an engine for the given database URL. """ """ Create an engine for the given database URL. """
kw = {} kw = {}
if url.startswith('postgres'): #if url.startswith('postgres'):
kw['pool_size'] = 1 # kw['pool_size'] = 10
engine = create_engine(url, **kw) engine = create_engine(url, **kw)
engine = construct_engine(engine) engine = construct_engine(engine)
meta = MetaData() meta = MetaData()
@ -26,44 +28,49 @@ def connect(url):
return engine return engine
def create_table(engine, table_name): def create_table(engine, table_name):
log.debug("Creating table: %s on %r" % (table_name, engine)) with lock:
table = Table(table_name, engine._metadata) log.debug("Creating table: %s on %r" % (table_name, engine))
col = Column('id', Integer, primary_key=True) table = Table(table_name, engine._metadata)
table.append_column(col) col = Column('id', Integer, primary_key=True)
table.create(engine) table.append_column(col)
TABLES[engine][table_name] = table table.create(engine)
return table TABLES[engine][table_name] = table
return table
def load_table(engine, table_name): def load_table(engine, table_name):
log.debug("Loading table: %s on %r" % (table_name, engine)) with lock:
table = Table(table_name, engine._metadata, autoload=True) log.debug("Loading table: %s on %r" % (table_name, engine))
TABLES[engine][table_name] = table table = Table(table_name, engine._metadata, autoload=True)
return table TABLES[engine][table_name] = table
return table
def get_table(engine, table_name): def get_table(engine, table_name):
# Accept Connection objects here # Accept Connection objects here
if hasattr(engine, 'engine'): if hasattr(engine, 'engine'):
engine = engine.engine engine = engine.engine
if table_name in TABLES[engine]: with lock:
return TABLES[engine][table_name] if table_name in TABLES[engine]:
if engine.has_table(table_name): return TABLES[engine][table_name]
return load_table(engine, table_name) if engine.has_table(table_name):
else: return load_table(engine, table_name)
return create_table(engine, table_name) else:
return create_table(engine, table_name)
def drop_table(engine, table_name): def drop_table(engine, table_name):
# Accept Connection objects here # Accept Connection objects here
if hasattr(engine, 'engine'): if hasattr(engine, 'engine'):
engine = engine.engine engine = engine.engine
if table_name in TABLES[engine]:
table = TABLES[engine][table_name] with lock:
elif engine.has_table(table_name): if table_name in TABLES[engine]:
table = Table(table_name, engine._metadata) table = TABLES[engine][table_name]
else: elif engine.has_table(table_name):
return table = Table(table_name, engine._metadata)
table.drop(engine) else:
TABLES[engine].pop(table_name, None) return
table.drop(engine)
TABLES[engine].pop(table_name, None)
def _guess_type(sample): def _guess_type(sample):
if isinstance(sample, bool): if isinstance(sample, bool):
@ -94,22 +101,24 @@ def _args_to_clause(table, args):
return and_(*clauses) return and_(*clauses)
def create_column(engine, table, name, type): def create_column(engine, table, name, type):
if name not in table.columns.keys(): with lock:
col = Column(name, type) if name not in table.columns.keys():
col.create(table, connection=engine) col = Column(name, type)
col.create(table, connection=engine)
def create_index(engine, table, columns, name=None): def create_index(engine, table, columns, name=None):
if not name: with lock:
sig = abs(hash('||'.join(columns))) if not name:
name = 'ix_%s_%s' % (table.name, sig) sig = abs(hash('||'.join(columns)))
if name in INDEXES: name = 'ix_%s_%s' % (table.name, sig)
return INDEXES[name] if name in INDEXES:
try: return INDEXES[name]
columns = [table.c[c] for c in columns] try:
idx = Index(name, *columns) columns = [table.c[c] for c in columns]
idx.create(engine) idx = Index(name, *columns)
except: idx.create(engine)
idx = None except:
INDEXES[name] = idx idx = None
return idx INDEXES[name] = idx
return idx