Merge pull request #321 from benfasoli/master

Adds support for JSON and JSONB columns
This commit is contained in:
Friedrich Lindenberg 2020-04-06 22:44:48 +02:00 committed by GitHub
commit 507b0a5a40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 9 deletions

View File

@ -45,9 +45,9 @@ class Database(object):
if len(schema_qs):
schema = schema_qs.pop()
self.types = Types()
self.schema = schema
self.engine = create_engine(url, **engine_kwargs)
self.types = Types(self.engine.dialect.name)
self.url = url
self.row_type = row_type
self.ensure_schema = ensure_schema

View File

@ -1,7 +1,7 @@
from datetime import datetime, date
from sqlalchemy import Integer, UnicodeText, Float, BigInteger
from sqlalchemy import Boolean, Date, DateTime, Unicode
from sqlalchemy import Boolean, Date, DateTime, Unicode, JSON
from sqlalchemy.types import TypeEngine
@ -16,22 +16,37 @@ class Types(object):
date = Date
datetime = DateTime
def guess(cls, sample):
def __init__(self, dialect = None):
self._dialect = dialect
@property
def json(self):
if self._dialect is not None and self._dialect == 'postgresql':
from sqlalchemy.dialects.postgresql import JSONB
return JSONB
return JSON
def guess(self, sample, dialect = None):
"""Given a single sample, guess the column type for the field.
If the sample is an instance of an SQLAlchemy type, the type will be
used instead.
"""
if dialect is not None:
self._dialect = dialect
if isinstance(sample, TypeEngine):
return sample
if isinstance(sample, bool):
return cls.boolean
return self.boolean
elif isinstance(sample, int):
return cls.bigint
return self.bigint
elif isinstance(sample, float):
return cls.float
return self.float
elif isinstance(sample, datetime):
return cls.datetime
return self.datetime
elif isinstance(sample, date):
return cls.date
return cls.text
return self.date
elif isinstance(sample, dict):
return self.json
return self.text

View File

@ -192,6 +192,20 @@ class TableTestCase(unittest.TestCase):
)
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_insert_json(self):
last_id = self.tbl.insert({
'date': datetime(2011, 1, 2),
'temperature': -10,
'place': 'Berlin',
'info': {
'currency': 'EUR',
'language': 'German',
'population': 3292365
}
})
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
assert self.tbl.find_one(id=last_id)['place'] == 'Berlin'
def test_upsert(self):
self.tbl.upsert({
'date': datetime(2011, 1, 2),