diff --git a/dataset/database.py b/dataset/database.py index 3155b5a..87d85be 100644 --- a/dataset/database.py +++ b/dataset/database.py @@ -45,9 +45,9 @@ class Database(object): if len(schema_qs): schema = schema_qs.pop() - self.types = Types() self.schema = schema self.engine = create_engine(url, **engine_kwargs) + self.types = Types(self.engine.dialect.name) self.url = url self.row_type = row_type self.ensure_schema = ensure_schema diff --git a/dataset/types.py b/dataset/types.py index 46f5550..904ffa9 100644 --- a/dataset/types.py +++ b/dataset/types.py @@ -1,7 +1,7 @@ from datetime import datetime, date from sqlalchemy import Integer, UnicodeText, Float, BigInteger -from sqlalchemy import Boolean, Date, DateTime, Unicode +from sqlalchemy import Boolean, Date, DateTime, Unicode, JSON from sqlalchemy.types import TypeEngine @@ -16,22 +16,37 @@ class Types(object): date = Date datetime = DateTime - def guess(cls, sample): + def __init__(self, dialect = None): + self._dialect = dialect + + @property + def json(self): + if self._dialect is not None and self._dialect == 'postgresql': + from sqlalchemy.dialects.postgresql import JSONB + return JSONB + return JSON + + def guess(self, sample, dialect = None): """Given a single sample, guess the column type for the field. If the sample is an instance of an SQLAlchemy type, the type will be used instead. """ + if dialect is not None: + self._dialect = dialect + if isinstance(sample, TypeEngine): return sample if isinstance(sample, bool): - return cls.boolean + return self.boolean elif isinstance(sample, int): - return cls.bigint + return self.bigint elif isinstance(sample, float): - return cls.float + return self.float elif isinstance(sample, datetime): - return cls.datetime + return self.datetime elif isinstance(sample, date): - return cls.date - return cls.text + return self.date + elif isinstance(sample, dict): + return self.json + return self.text diff --git a/test/test_dataset.py b/test/test_dataset.py index e5cecd6..fb21952 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -192,6 +192,20 @@ class TableTestCase(unittest.TestCase): ) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) + def test_insert_json(self): + last_id = self.tbl.insert({ + 'date': datetime(2011, 1, 2), + 'temperature': -10, + 'place': 'Berlin', + 'info': { + 'currency': 'EUR', + 'language': 'German', + 'population': 3292365 + } + }) + assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) + assert self.tbl.find_one(id=last_id)['place'] == 'Berlin' + def test_upsert(self): self.tbl.upsert({ 'date': datetime(2011, 1, 2),