diff --git a/mediagoblin/db/models.py b/mediagoblin/db/models.py index 101e7cee..10e0c33f 100644 --- a/mediagoblin/db/models.py +++ b/mediagoblin/db/models.py @@ -20,7 +20,6 @@ TODO: indexes on foreignkeys, where useful. import logging import datetime -import sys from sqlalchemy import Column, Integer, Unicode, UnicodeText, DateTime, \ Boolean, ForeignKey, UniqueConstraint, PrimaryKeyConstraint, \ @@ -32,9 +31,10 @@ from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.util import memoized_property from mediagoblin.db.extratypes import PathTupleWithSlashes, JSONEncoded -from mediagoblin.db.base import Base, DictReadAttrProxy, Session +from mediagoblin.db.base import Base, DictReadAttrProxy from mediagoblin.db.mixin import UserMixin, MediaEntryMixin, MediaCommentMixin, CollectionMixin, CollectionItemMixin from mediagoblin.tools.files import delete_media_files +from mediagoblin.tools.common import import_component # It's actually kind of annoying how sqlalchemy-migrate does this, if # I understand it right, but whatever. Anyway, don't remove this :P @@ -165,7 +165,6 @@ class MediaEntry(Base, MediaEntryMixin): collections = association_proxy("collections_helper", "in_collection") ## TODO - # media_data # fail_error def get_comments(self, ascending=False): @@ -195,40 +194,31 @@ class MediaEntry(Base, MediaEntryMixin): if media is not None: return media.url_for_self(urlgen) - #@memoized_property @property def media_data(self): - session = Session() - - return session.query(self.media_data_table).filter_by( - media_entry=self.id).first() + return getattr(self, self.media_data_ref) def media_data_init(self, **kwargs): """ Initialize or update the contents of a media entry's media_data row """ - session = Session() + media_data = self.media_data - media_data = session.query(self.media_data_table).filter_by( - media_entry=self.id).first() - - # No media data, so actually add a new one if media_data is None: - media_data = self.media_data_table( - media_entry=self.id, - **kwargs) - session.add(media_data) - # Update old media data + # Get the correct table: + table = import_component(self.media_type + '.models:DATA_MODEL') + # No media data, so actually add a new one + media_data = table(**kwargs) + # Get the relationship set up. + media_data.get_media_entry = self else: + # Update old media data for field, value in kwargs.iteritems(): setattr(media_data, field, value) @memoized_property - def media_data_table(self): - # TODO: memoize this - models_module = self.media_type + '.models' - __import__(models_module) - return sys.modules[models_module].DATA_MODEL + def media_data_ref(self): + return import_component(self.media_type + '.models:BACKREF_NAME') def __repr__(self): safe_title = self.title.encode('ascii', 'replace') diff --git a/mediagoblin/media_types/ascii/models.py b/mediagoblin/media_types/ascii/models.py index 3416993c..c7505292 100644 --- a/mediagoblin/media_types/ascii/models.py +++ b/mediagoblin/media_types/ascii/models.py @@ -32,7 +32,8 @@ class AsciiData(Base): media_entry = Column(Integer, ForeignKey('core__media_entries.id'), primary_key=True) get_media_entry = relationship("MediaEntry", - backref=backref(BACKREF_NAME, cascade="all, delete-orphan")) + backref=backref(BACKREF_NAME, uselist=False, + cascade="all, delete-orphan")) DATA_MODEL = AsciiData diff --git a/mediagoblin/media_types/audio/models.py b/mediagoblin/media_types/audio/models.py index 368ab1eb..d01367d5 100644 --- a/mediagoblin/media_types/audio/models.py +++ b/mediagoblin/media_types/audio/models.py @@ -32,7 +32,8 @@ class AudioData(Base): media_entry = Column(Integer, ForeignKey('core__media_entries.id'), primary_key=True) get_media_entry = relationship("MediaEntry", - backref=backref(BACKREF_NAME, cascade="all, delete-orphan")) + backref=backref(BACKREF_NAME, uselist=False, + cascade="all, delete-orphan")) DATA_MODEL = AudioData diff --git a/mediagoblin/media_types/image/models.py b/mediagoblin/media_types/image/models.py index 63d80aa8..b2ea3960 100644 --- a/mediagoblin/media_types/image/models.py +++ b/mediagoblin/media_types/image/models.py @@ -33,7 +33,8 @@ class ImageData(Base): media_entry = Column(Integer, ForeignKey('core__media_entries.id'), primary_key=True) get_media_entry = relationship("MediaEntry", - backref=backref(BACKREF_NAME, cascade="all, delete-orphan")) + backref=backref(BACKREF_NAME, uselist=False, + cascade="all, delete-orphan")) width = Column(Integer) height = Column(Integer) diff --git a/mediagoblin/media_types/stl/models.py b/mediagoblin/media_types/stl/models.py index 17091f0e..ff50e9c0 100644 --- a/mediagoblin/media_types/stl/models.py +++ b/mediagoblin/media_types/stl/models.py @@ -32,7 +32,8 @@ class StlData(Base): media_entry = Column(Integer, ForeignKey('core__media_entries.id'), primary_key=True) get_media_entry = relationship("MediaEntry", - backref=backref(BACKREF_NAME, cascade="all, delete-orphan")) + backref=backref(BACKREF_NAME, uselist=False, + cascade="all, delete-orphan")) center_x = Column(Float) center_y = Column(Float) diff --git a/mediagoblin/media_types/video/models.py b/mediagoblin/media_types/video/models.py index 645ef4d3..a771352c 100644 --- a/mediagoblin/media_types/video/models.py +++ b/mediagoblin/media_types/video/models.py @@ -32,7 +32,8 @@ class VideoData(Base): media_entry = Column(Integer, ForeignKey('core__media_entries.id'), primary_key=True) get_media_entry = relationship("MediaEntry", - backref=backref(BACKREF_NAME, cascade="all, delete-orphan")) + backref=backref(BACKREF_NAME, uselist=False, + cascade="all, delete-orphan")) width = Column(SmallInteger) height = Column(SmallInteger) diff --git a/mediagoblin/tests/test_modelmethods.py b/mediagoblin/tests/test_modelmethods.py index c1064d3a..7719bd97 100644 --- a/mediagoblin/tests/test_modelmethods.py +++ b/mediagoblin/tests/test_modelmethods.py @@ -17,7 +17,9 @@ # Maybe not every model needs a test, but some models have special # methods, and so it makes sense to test them here. +from nose.tools import assert_equal +from mediagoblin.db.base import Session from mediagoblin.db.models import MediaEntry from mediagoblin.tests.tools import get_app, \ @@ -128,3 +130,18 @@ class TestMediaEntrySlugs(object): u"@!#?@!", save=False) qbert_entry.generate_slug() assert qbert_entry.slug is None + + +def test_media_data_init(): + Session.rollback() + Session.remove() + media = MediaEntry() + media.media_type = u"mediagoblin.media_types.image" + assert media.media_data is None + media.media_data_init() + assert media.media_data is not None + obj_in_session = 0 + for obj in Session(): + obj_in_session += 1 + print repr(obj) + assert_equal(obj_in_session, 0) diff --git a/mediagoblin/tests/test_submission.py b/mediagoblin/tests/test_submission.py index 00f1ed3d..fc3d8c83 100644 --- a/mediagoblin/tests/test_submission.py +++ b/mediagoblin/tests/test_submission.py @@ -27,6 +27,7 @@ from pkg_resources import resource_filename from mediagoblin.tests.tools import get_app, \ fixture_add_user from mediagoblin import mg_globals +from mediagoblin.db.models import MediaEntry from mediagoblin.tools import template from mediagoblin.media_types.image import MEDIA_MANAGER as img_MEDIA_MANAGER @@ -40,6 +41,7 @@ EVIL_FILE = resource('evil') EVIL_JPG = resource('evil.jpg') EVIL_PNG = resource('evil.png') BIG_BLUE = resource('bigblue.png') +from .test_exif import GPS_JPG GOOD_TAG_STRING = u'yin,yang' BAD_TAG_STRING = unicode('rage,' + 'f' * 26 + 'u' * 26) @@ -122,7 +124,7 @@ class TestSubmission: self.check_normal_upload(u'Normal upload 2', GOOD_PNG) def check_media(self, request, find_data, count=None): - media = request.db.MediaEntry.find(find_data) + media = MediaEntry.find(find_data) if count is not None: assert_equal(media.count(), count) if count == 0: @@ -265,6 +267,11 @@ class TestSubmission: # ------------------------------------------- self.check_false_image(u'Malicious Upload 3', EVIL_PNG) + def test_media_data(self): + self.check_normal_upload(u"With GPS data", GPS_JPG) + media = self.check_media(None, {"title": u"With GPS data"}, 1) + assert_equal(media.media_data.gps_latitude, 59.336666666666666) + def test_processing(self): data = {'title': u'Big Blue'} response, request = self.do_post(data, *REQUEST_CONTEXT, do_follow=True,