Fix the GenericForeignKey implementation

This commit is contained in:
Jessica Tallon 2015-04-28 19:53:25 +02:00
parent bfe1e8ce88
commit 6185a4b9e6

View File

@ -25,7 +25,7 @@ import datetime
from sqlalchemy import Column, Integer, Unicode, UnicodeText, DateTime, \
Boolean, ForeignKey, UniqueConstraint, PrimaryKeyConstraint, \
SmallInteger, Date
SmallInteger, Date, types
from sqlalchemy.orm import relationship, backref, with_polymorphic, validates, \
class_mapper
from sqlalchemy.orm.collections import attribute_mapped_collection
@ -69,7 +69,7 @@ class GenericModelReference(Base):
return None
model = self._get_model_from_type(self.model_type)
return model.query.filter_by(id=self.obj_pk)
return model.query.filter_by(id=self.obj_pk).first()
def set_object(self, obj):
model = obj.__class__
@ -100,38 +100,30 @@ class GenericModelReference(Base):
def _get_model_from_type(self, model_type):
""" Gets a model from a tablename (model type) """
if getattr(self, "_TYPE_MAP", None) is None:
if getattr(self.__class__, "_TYPE_MAP", None) is None:
# We want to build on the class (not the instance) a map of all the
# models by the table name (type) for easy lookup, this is done on
# the class so it can be shared between all instances
# to prevent circular imports do import here
self._TYPE_MAP = dict(((m.__tablename__, m) for m in MODELS))
setattr(self.__class__._TYPE_MAP, self._TYPE_MAP)
setattr(self.__class__, "_TYPE_MAP", self._TYPE_MAP)
return self._TYPE_MAP[model_type]
return self.__class__._TYPE_MAP[model_type]
class GenericForeignKey(ForeignKey):
class GenericForeignKey(types.TypeDecorator):
def __init__(self, *args, **kwargs):
super(GenericForeignKey, self).__init__(
GenericModelReference.id,
*args,
**kwargs
)
impl = Integer
def __get__(self, *args, **kwargs):
def process_result_value(self, value, *args, **kwargs):
""" Looks up GenericModelReference and model for field """
# Find the value of the foreign key.
ref = super(self, GenericForeignKey).__get__(*args, **kwargs)
# If this hasn't been set yet return None
if ref is None:
if value is None:
return None
# Look up the GenericModelReference for this.
gmr = GenericModelReference.query.filter_by(id=ref).first()
gmr = GenericModelReference.query.filter_by(id=value).first()
# If it's set to something invalid (i.e. no GMR exists return None)
if gmr is None:
@ -140,6 +132,30 @@ class GenericForeignKey(ForeignKey):
# Ask the GMR for the corresponding model
return gmr.get_object()
def process_bind_param(self, value, *args, **kwargs):
""" Save the foreign key """
if value is None:
return None
# Is there one for this already.
model = type(value)
pk = getattr(value, "id")
gmr = GenericModelReference.query.filter_by(id=pk).first()
if gmr is None:
# We need to create one
gmr = GenericModelReference(
obj_pk=pk,
model_type=model.__tablename__
)
gmr.save()
return gmr.id
def _set_parent_with_dispatch(self, parent):
self.parent = parent
class Location(Base):
""" Represents a physical location """
@ -1431,11 +1447,9 @@ class Activity(Base, ActivityMixin):
generator = Column(Integer,
ForeignKey("core__generators.id"),
nullable=True)
object = Column(Integer,
GenericForeignKey(),
object = Column(GenericForeignKey(),
nullable=False)
target = Column(Integer,
GenericForeignKey(),
target = Column(GenericForeignKey(),
nullable=True)
get_actor = relationship(User,