From 0be9100c0f2ce4ac0f75de57355a9c8d98e7fcf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laura=20Kl=C3=BCnder?= Date: Sun, 18 Jun 2017 01:04:07 +0200 Subject: [PATCH] use is_created_pk everywhere --- src/c3nav/editor/models.py | 11 +++++---- src/c3nav/editor/wrappers.py | 44 +++++++++++++++++------------------- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/src/c3nav/editor/models.py b/src/c3nav/editor/models.py index 5dd8cf23..4f03c636 100644 --- a/src/c3nav/editor/models.py +++ b/src/c3nav/editor/models.py @@ -13,7 +13,7 @@ from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ungettext_lazy -from c3nav.editor.wrappers import ModelInstanceWrapper, ModelWrapper +from c3nav.editor.wrappers import ModelInstanceWrapper, ModelWrapper, is_created_pk class ChangeSet(models.Model): @@ -98,8 +98,9 @@ class ChangeSet(models.Model): return r def get_created_object(self, model, pk, author=None, get_foreign_objects=False): - if isinstance(pk, str): - pk = int(pk[1:]) + if is_created_pk(pk): + pk = pk[1:] + pk = int(pk) self.parse_changes() if issubclass(model, ModelWrapper): model = model._obj @@ -118,7 +119,7 @@ class ChangeSet(models.Model): if isinstance(class_value, ForwardManyToOneDescriptor): field = class_value.field setattr(obj, field.attname, value) - if isinstance(pk, str): + if is_created_pk(pk): setattr(obj, class_value.cache_name, self.get_created_object(field.model, value)) elif get_foreign_objects or True: setattr(obj, class_value.cache_name, self.wrap(field.related_model.objects.get(pk=value))) @@ -340,7 +341,7 @@ class Change(models.Model): if not isinstance(value, ModelInstanceWrapper): value = self.changeset.wrap(value) - if isinstance(value.pk, str): + if is_created_pk(value.pk): if value._changeset.id != self.changeset.pk: raise ValueError('value is a Change instance but belongs to a different changeset.') self.model_class = type(value._obj) diff --git a/src/c3nav/editor/wrappers.py b/src/c3nav/editor/wrappers.py index 7cd24416..33191a15 100644 --- a/src/c3nav/editor/wrappers.py +++ b/src/c3nav/editor/wrappers.py @@ -11,6 +11,10 @@ from django.db.models.query_utils import DeferredAttribute from django.utils.functional import cached_property +def is_created_pk(pk): + return isinstance(pk, str) and pk.startswith('c') and pk[1:].isnumeric() + + class BaseWrapper: _not_wrapped = ('_changeset', '_author', '_obj', '_created_pks', '_result', '_extra', '_initial_values') _allowed_callables = ('', ) @@ -131,7 +135,7 @@ class ModelInstanceWrapper(BaseWrapper): if field.name in updates: value_pk = updates[field.name] class_value = getattr(type(self._obj), field.name, None) - if isinstance(value_pk, str): + if is_created_pk(value_pk): obj = self._wrap_model(field.model).get(pk=value_pk) setattr(self._obj, class_value.cache_name, obj) setattr(self._obj, field.attname, obj.pk) @@ -166,11 +170,9 @@ class ModelInstanceWrapper(BaseWrapper): cls_name = self._obj.__class__.__name__ if self.pk is None: return '<%s (unsaved) with Changeset #%d>' % (cls_name, self._changeset.pk) - elif isinstance(self.pk, int): - return '<%s #%d (existing) with Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) - elif isinstance(self.pk, str): + elif is_created_pk(self.pk): return '<%s #%s (created) from Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) - raise TypeError + return '<%s #%d (existing) with Changeset #%d>' % (cls_name, self.pk, self._changeset.pk) def save(self, author=None): if author is None: @@ -320,10 +322,6 @@ class BaseQueryWrapper(BaseWrapper): created_pks.add(pk) return (q & ~Q(pk__in=remove_pks)) | Q(pk__in=add_pks), created_pks - @staticmethod - def is_created_pk(pk): - return isinstance(pk, str) and pk.startswith('c') and pk[1:].isnumeric() - def _filter_kwarg(self, filter_name, filter_value): # print(filter_name, '=', filter_value, sep='') @@ -338,12 +336,12 @@ class BaseQueryWrapper(BaseWrapper): if field_name == 'pk' or field_name == self._obj.model._meta.pk.name: if not segments: - if self.is_created_pk(filter_value): + if is_created_pk(filter_value): return Q(pk__in=()), set([int(filter_value[1:])]) return q, set() elif segments == ['in']: - return (Q(pk__in=tuple(pk for pk in filter_value if not self.is_created_pk(pk))), - set(int(pk[1:]) for pk in filter_value if self.is_created_pk(pk))) + return (Q(pk__in=tuple(pk for pk in filter_value if not is_created_pk(pk))), + set(int(pk[1:]) for pk in filter_value if is_created_pk(pk))) if isinstance(class_value, ForwardManyToOneDescriptor): if not segments: @@ -365,7 +363,7 @@ class BaseQueryWrapper(BaseWrapper): filter_type = 'pk' if filter_type == 'pk' and segments == ['in']: - q = Q(**{field_name+'__pk__in': tuple(pk for pk in filter_value if not self.is_created_pk(pk))}) + q = Q(**{field_name+'__pk__in': tuple(pk for pk in filter_value if not is_created_pk(pk))}) filter_value = tuple(str(pk) for pk in filter_value) return self._filter_values(q, field_name, lambda val: str(val) in filter_value) @@ -373,7 +371,7 @@ class BaseQueryWrapper(BaseWrapper): raise NotImplementedError if filter_type == 'pk': - if self.is_created_pk(filter_value): + if is_created_pk(filter_value): q = Q(pk__in=()) filter_value = str(filter_value) return self._filter_values(q, field_name, lambda val: str(val) == filter_value) @@ -410,7 +408,7 @@ class BaseQueryWrapper(BaseWrapper): # field_name would be "spaces" model = class_value.field.model # space filter_value = set(filter_value) # space pks - filter_value_existing = set(pk for pk in filter_value if not self.is_created_pk(pk)) + filter_value_existing = set(pk for pk in filter_value if not is_created_pk(pk)) rel_name = class_value.field.name # get spaces that we are interested about that had groups added or removed @@ -430,10 +428,10 @@ class BaseQueryWrapper(BaseWrapper): r_added_pks = reduce(operator.or_, m2m_added.values(), set()) # lookup existing groups that were added to any of the spaces - q |= Q(pk__in=tuple(pk for pk in r_added_pks if not self.is_created_pk(pk))) + q |= Q(pk__in=tuple(pk for pk in r_added_pks if not is_created_pk(pk))) # get created groups that were added to any of the spaces - created_pks = set(int(pk[1:]) for pk in r_added_pks if self.is_created_pk(pk)) + created_pks = set(int(pk[1:]) for pk in r_added_pks if is_created_pk(pk)) return q, created_pks @@ -450,14 +448,14 @@ class BaseQueryWrapper(BaseWrapper): remove_pks = get_changeset_m2m(self._changeset.m2m_removed) add_pks = get_changeset_m2m(self._changeset.m2m_added) - if self.is_created_pk(filter_value): + if is_created_pk(filter_value): pks = add_pks - return (Q(pk__in=(pk for pk in pks if not self.is_created_pk(pk))), - set(int(pk[1:]) for pk in pks if self.is_created_pk(pk))) + return (Q(pk__in=(pk for pk in pks if not is_created_pk(pk))), + set(int(pk[1:]) for pk in pks if is_created_pk(pk))) - return (((q & ~Q(pk__in=(pk for pk in remove_pks if not self.is_created_pk(pk)))) | - Q(pk__in=(pk for pk in add_pks if not self.is_created_pk(pk)))), - set(int(pk[1:]) for pk in add_pks if self.is_created_pk(pk))) + return (((q & ~Q(pk__in=(pk for pk in remove_pks if not is_created_pk(pk)))) | + Q(pk__in=(pk for pk in add_pks if not is_created_pk(pk)))), + set(int(pk[1:]) for pk in add_pks if is_created_pk(pk))) raise NotImplementedError