diff --git a/dataset/__init__.py b/dataset/__init__.py index d9d4f96..3785251 100644 --- a/dataset/__init__.py +++ b/dataset/__init__.py @@ -4,6 +4,7 @@ import warnings warnings.filterwarnings( 'ignore', 'Unicode type received non-unicode bind param value.') +from dataset.persistence.util import sqlite_datetime_fix from dataset.persistence.database import Database from dataset.persistence.table import Table from dataset.freeze.app import freeze @@ -27,4 +28,8 @@ def connect(url=None, schema=None, reflectMetadata=True): """ if url is None: url = os.environ.get('DATABASE_URL', url) + + if url.startswith("sqlite://"): + sqlite_datetime_fix() + return Database(url, schema=schema, reflectMetadata=reflectMetadata) diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index 019f620..d22faf1 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -1,7 +1,7 @@ -from datetime import datetime +from datetime import datetime, timedelta from inspect import isgenerator -from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean +from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean, types, Table, event def guess_type(sample): @@ -50,3 +50,20 @@ class ResultIter(object): def __iter__(self): return self + + +def sqlite_datetime_fix(): + class SQLiteDateTimeType(types.TypeDecorator): + impl = types.Integer + epoch = datetime(1970, 1, 1, 0, 0, 0) + + def process_bind_param(self, value, dialect): + return (value / 1000 - self.epoch).total_seconds() + + def process_result_value(self, value, dialect): + return self.epoch + timedelta(seconds=value / 1000) + + @event.listens_for(Table, "column_reflect") + def setup_epoch(inspector, table, column_info): + if isinstance(column_info['type'], types.DateTime): + column_info['type'] = SQLiteDateTimeType()