diff --git a/dataset/__init__.py b/dataset/__init__.py index 2a4c530..4441aee 100644 --- a/dataset/__init__.py +++ b/dataset/__init__.py @@ -11,7 +11,7 @@ from dataset.freeze.app import freeze __all__ = ['Database', 'Table', 'freeze', 'connect'] -def connect(url=None, reflectMetadata=True): +def connect(url=None, schema=None, reflectMetadata=True): """ Opens a new connection to a database. *url* can be any valid `SQLAlchemy engine URL`_. If *url* is not defined it will try to use *DATABASE_URL* from environment variable. @@ -26,4 +26,4 @@ def connect(url=None, reflectMetadata=True): """ if url is None: url = os.environ.get('DATABASE_URL', url) - return Database(url, reflectMetadata) + return Database(url, schema=schema, reflectMetadata=reflectMetadata) diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 31c83ee..369c45f 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -1,5 +1,6 @@ import logging import threading +from urlparse import parse_qs from sqlalchemy import create_engine from migrate.versioning.util import construct_engine @@ -17,16 +18,25 @@ log = logging.getLogger(__name__) class Database(object): - def __init__(self, url, reflectMetadata=True): + def __init__(self, url, schema=None, reflectMetadata=True): kw = {} if url.startswith('postgres'): kw['poolclass'] = NullPool - engine = create_engine(url, **kw) self.lock = threading.RLock() self.local = threading.local() + if '?' in url: + url, query = url.split('?', 1) + query = parse_qs(query) + if schema is None: + # le pop + schema_qs = query.pop('schema', query.pop('searchpath', [])) + if len(schema_qs): + schema = schema_qs.pop() + self.schema = schema + engine = create_engine(url, **kw) self.url = url self.engine = construct_engine(engine) - self.metadata = MetaData() + self.metadata = MetaData(schema=schema) self.metadata.bind = self.engine if reflectMetadata: self.metadata.reflect(self.engine) @@ -146,7 +156,7 @@ class Database(object): return Table(self, self._tables[table_name]) self._acquire() try: - if self.engine.has_table(table_name): + if self.engine.has_table(table_name, schema=self.schema): return self.load_table(table_name) else: return self.create_table(table_name) @@ -172,4 +182,3 @@ class Database(object): def __repr__(self): return '' % self.url - diff --git a/setup.py b/setup.py index abb218d..ef548fe 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name='dataset', - version='0.3.7', + version='0.3.8', description="Toolkit for Python-based data processing.", long_description="", classifiers=[