"""
Custom SQLAlchemy column types
"""
import collections.abc
import json
import sqlalchemy
# Public names re-exported by ``from <module> import *``: only the SET type.
__all__ = ('SET',)
class SET(sqlalchemy.types.TypeDecorator):
    """
    Emulates a Python `set` type using an `sqlalchemy.types.JSON` back end

    A set object is an unordered collection of distinct objects. SQL
    does not natively support guaranteed distinct collections, but
    this constraint can be enforced by the
    `sqlalchemy.types.TypeEngine` when reading from and writing to the
    database, and using `sqlalchemy.event` hooks when assigning to an
    ORM attribute. E.g.:

    .. testsetup::

        from sqlalchemy import create_engine
        from sqlalchemy.orm import sessionmaker
        from sqlalchemy.ext.declarative import declarative_base

        engine = create_engine('sqlite:///:memory:')
        Session = sessionmaker(bind=engine)
        session = Session()
        DeclarativeBase = declarative_base()

    .. testcode::

        from sqlalchemy import Column, Integer
        from fictive.sqlalchemy.types import SET

        class Model(DeclarativeBase):
            __tablename__ = 'model'
            id = Column(Integer, primary_key=True)
            set_column = Column(SET)

        DeclarativeBase.metadata.create_all(engine)

    .. doctest::

        >>> instance = Model(set_column=[1, 2, 2, 3, 4, 3])
        >>> session.add(instance)
        >>> session.commit()
        >>> instance.set_column
        {1, 2, 3, 4}

    .. testcleanup::

        DeclarativeBase.metadata.drop_all(engine)
        session.close()
        engine.dispose()

    Elements are stored as a JSON array, so they must be hashable (they
    pass through `set`) and JSON-serialisable; the order of the stored
    array carries no meaning.
    """

    impl = sqlalchemy.types.JSON

    # SET defines no constructor state of its own, so its instances are
    # interchangeable for SQLAlchemy's compiled-statement cache. Without
    # this flag SQLAlchemy 1.4+ warns and disables caching for the type.
    # NOTE(review): keyword arguments forwarded to the JSON impl (e.g.
    # ``SET(none_as_null=True)``) do not take part in the cache key --
    # confirm no caller relies on per-column impl options.
    cache_ok = True

    @property
    def python_type(self):  # pragma: no cover
        """The Python type produced by `process_result_value`: `set`."""
        return set

    def process_bind_param(self, value, dialect):
        """
        remove duplicate elements & convert to a JSON-compatible type (`list`)

        :param value: an iterable of hashable, JSON-serialisable elements,
            or `None`
        :param dialect: unused
        :returns: a `list` of the distinct elements of *value* in arbitrary
            order, or `None` when *value* is `None`
        """
        if value is None:
            return None
        # Anything that is not already a set (lists, tuples, generators...)
        # is de-duplicated here before being serialised as a JSON array.
        if not isinstance(value, collections.abc.Set):
            value = set(value)
        return list(value)

    def process_result_value(self, value, dialect):
        """
        convert retrieved JSON value (a `list`) to a `set`

        :param value: the decoded JSON array, or `None`
        :param dialect: unused
        :returns: a `set` of the array's elements, or `None`
        """
        if value is None:
            return None
        return set(value)

    def process_literal_param(self, value, dialect):  # pragma: no cover
        """
        normalise a value that is rendered inline rather than bound

        SQLAlchemy hands this method the Python-side value (a `set`,
        `list`, ...), so it is normalised exactly like a bound parameter.
        Previously the value was always passed to `json.loads`, which raised
        `TypeError` for anything but a string. A `str`/`bytes` value is
        still decoded as JSON first, preserving that path for callers that
        pass pre-encoded JSON.

        :returns: a `list` of distinct elements, or `None`
        """
        if isinstance(value, (str, bytes, bytearray)):
            value = json.loads(value)
        return self.process_bind_param(value, dialect)
@sqlalchemy.event.listens_for(object, 'attribute_instrument')
def _receive_attribute_instrument(cls, key, inst):
    # pylint: disable=unused-argument
    """
    Converts non-`None` value assigned to a `SET` attribute to be a `set`

    Because the hook targets `object`, it fires for attributes of every
    instrumented class. Each attribute whose mapped property exposes
    ``columns`` gets a ``"set"`` listener (``retval=True``). On every
    assignment, that listener inspects the type of the property's first
    column. If the type is a `SET`, the assigned value is round-tripped
    through `SET.process_bind_param` and `SET.process_result_value`, so
    the instance ends up holding a de-duplicated `set` (`None` stays
    `None`). Any other value is returned untouched.
    """
    if hasattr(inst.property, 'columns'):

        def _coerce_assigned_value(instance, value, oldvalue, initiator):
            # The column type is looked up at assignment time, not when
            # the listener is registered.
            sql_type = inst.property.columns[0].type
            if isinstance(sql_type, SET):
                as_json_list = sql_type.process_bind_param(value, None)
                return sql_type.process_result_value(as_json_list, None)
            return value

        sqlalchemy.event.listen(
            inst, "set", _coerce_assigned_value, retval=True)