user get_submodels everywhere
This commit is contained in:
parent
bd7dcc647e
commit
03d40add64
6 changed files with 14 additions and 45 deletions
|
@ -1 +0,0 @@
|
|||
default_app_config = 'c3nav.mapdata.apps.MapdataConfig'
|
|
@ -12,10 +12,11 @@ from rest_framework.response import Response
|
|||
from rest_framework.viewsets import GenericViewSet, ReadOnlyModelViewSet
|
||||
|
||||
from c3nav.mapdata.models import Building, Door, Hole, LocationGroup, Source, Space
|
||||
from c3nav.mapdata.models.geometry.level import LEVEL_MODELS
|
||||
from c3nav.mapdata.models.geometry.space import SPACE_MODELS, Area, Column, LineObstacle, Obstacle, Point, Stair
|
||||
from c3nav.mapdata.models.geometry.level import LevelGeometryMixin
|
||||
from c3nav.mapdata.models.geometry.space import Area, Column, LineObstacle, Obstacle, Point, SpaceGeometryMixin, Stair
|
||||
from c3nav.mapdata.models.level import Level
|
||||
from c3nav.mapdata.models.locations import LOCATION_MODELS, Location, LocationRedirect, LocationSlug, SpecificLocation
|
||||
from c3nav.mapdata.models.locations import Location, LocationRedirect, LocationSlug, SpecificLocation
|
||||
from c3nav.mapdata.utils.models import get_submodels
|
||||
|
||||
|
||||
def optimize_query(qs):
|
||||
|
@ -28,7 +29,7 @@ class MapdataViewSet(ReadOnlyModelViewSet):
|
|||
def list(self, request, *args, **kwargs):
|
||||
qs = optimize_query(self.get_queryset())
|
||||
geometry = ('geometry' in request.GET)
|
||||
if qs.model in LEVEL_MODELS and 'level' in request.GET:
|
||||
if issubclass(qs.model, LevelGeometryMixin) and 'level' in request.GET:
|
||||
if not request.GET['level'].isdigit():
|
||||
raise ValidationError(detail={'detail': _('%s is not an integer.') % 'level'})
|
||||
try:
|
||||
|
@ -36,7 +37,7 @@ class MapdataViewSet(ReadOnlyModelViewSet):
|
|||
except Level.DoesNotExist:
|
||||
raise NotFound(detail=_('level not found.'))
|
||||
qs = qs.filter(level=level)
|
||||
if qs.model in SPACE_MODELS and 'space' in request.GET:
|
||||
if issubclass(qs.model, SpaceGeometryMixin) and 'space' in request.GET:
|
||||
if not request.GET['space'].isdigit():
|
||||
raise ValidationError(detail={'detail': _('%s is not an integer.') % 'space'})
|
||||
try:
|
||||
|
@ -73,7 +74,7 @@ class LevelViewSet(MapdataViewSet):
|
|||
|
||||
@list_route(methods=['get'])
|
||||
def geometrytypes(self, request):
|
||||
return self.list_types(LEVEL_MODELS)
|
||||
return self.list_types(get_submodels(LevelGeometryMixin))
|
||||
|
||||
@detail_route(methods=['get'])
|
||||
def svg(self, requests, pk=None):
|
||||
|
@ -93,7 +94,7 @@ class SpaceViewSet(MapdataViewSet):
|
|||
|
||||
@list_route(methods=['get'])
|
||||
def geometrytypes(self, request):
|
||||
return self.list_types(SPACE_MODELS)
|
||||
return self.list_types(get_submodels(SpaceGeometryMixin))
|
||||
|
||||
|
||||
class DoorViewSet(MapdataViewSet):
|
||||
|
@ -152,7 +153,7 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet):
|
|||
|
||||
def list(self, request, *args, **kwargs):
|
||||
queryset = sorted(chain(*(optimize_query(model.objects.filter(Q(can_search=True) | Q(can_describe=True)))
|
||||
for model in LOCATION_MODELS)), key=lambda obj: obj.id)
|
||||
for model in get_submodels(Location))), key=lambda obj: obj.id)
|
||||
return Response([obj.serialize(include_type=True, detailed='detailed' in request.GET) for obj in queryset])
|
||||
|
||||
def retrieve(self, request, slug=None, *args, **kwargs):
|
||||
|
@ -168,7 +169,7 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet):
|
|||
|
||||
@list_route(methods=['get'])
|
||||
def types(self, request):
|
||||
return MapdataViewSet.list_types(LOCATION_MODELS, geomtype=False)
|
||||
return MapdataViewSet.list_types(get_submodels(Location), geomtype=False)
|
||||
|
||||
@list_route(methods=['get'])
|
||||
def redirects(self, request):
|
||||
|
@ -178,7 +179,7 @@ class LocationViewSet(RetrieveModelMixin, GenericViewSet):
|
|||
def search(self, request):
|
||||
# todo: implement caching here
|
||||
results = sorted(chain(*(optimize_query(model.objects.filter(can_search=True))
|
||||
for model in LOCATION_MODELS)), key=lambda obj: obj.id)
|
||||
for model in get_submodels(Location))), key=lambda obj: obj.id)
|
||||
search = request.GET.get('s')
|
||||
if not search:
|
||||
return Response([obj.serialize(include_type=True, detailed='detailed' in request.GET) for obj in results])
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
from django.apps import AppConfig
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from c3nav.mapdata.utils.models import get_submodels
|
||||
|
||||
|
||||
class MapdataConfig(AppConfig):
|
||||
name = 'c3nav.mapdata'
|
||||
|
||||
def ready(self):
|
||||
from c3nav.mapdata.models.geometry.base import GeometryMixin, GEOMETRY_MODELS
|
||||
for cls in get_submodels(GeometryMixin):
|
||||
GEOMETRY_MODELS[cls.__name__] = cls
|
||||
try:
|
||||
cls._meta.get_field('geometry')
|
||||
except FieldDoesNotExist:
|
||||
raise TypeError(_('Model %s has GeometryMixin as base class but has no geometry field.') % cls)
|
||||
|
||||
from c3nav.mapdata.models.locations import Location, LOCATION_MODELS
|
||||
LOCATION_MODELS.extend(get_submodels(Location))
|
||||
|
||||
from c3nav.mapdata.models.geometry.level import LevelGeometryMixin, LEVEL_MODELS
|
||||
LEVEL_MODELS.extend(get_submodels(LevelGeometryMixin))
|
||||
|
||||
from c3nav.mapdata.models.geometry.space import SpaceGeometryMixin, SPACE_MODELS
|
||||
SPACE_MODELS.extend(get_submodels(SpaceGeometryMixin))
|
|
@ -5,8 +5,6 @@ from shapely.geometry import Point, mapping
|
|||
from c3nav.mapdata.models.base import SerializableMixin
|
||||
from c3nav.mapdata.utils.json import format_geojson
|
||||
|
||||
GEOMETRY_MODELS = OrderedDict()
|
||||
|
||||
|
||||
class GeometryMixin(SerializableMixin):
|
||||
"""
|
||||
|
|
|
@ -7,8 +7,7 @@ from django.utils.translation import get_language
|
|||
|
||||
from c3nav.mapdata.fields import JSONField
|
||||
from c3nav.mapdata.models.base import SerializableMixin
|
||||
|
||||
LOCATION_MODELS = []
|
||||
from c3nav.mapdata.utils.models import get_submodels
|
||||
|
||||
|
||||
class LocationSlugManager(models.Manager):
|
||||
|
@ -16,7 +15,7 @@ class LocationSlugManager(models.Manager):
|
|||
result = super().get_queryset()
|
||||
if self.model == LocationSlug:
|
||||
result = result.select_related(*(model._meta.default_related_name
|
||||
for model in LOCATION_MODELS+[LocationRedirect]))
|
||||
for model in get_submodels(Location)+[LocationRedirect]))
|
||||
return result
|
||||
|
||||
|
||||
|
@ -35,7 +34,7 @@ class LocationSlug(SerializableMixin, models.Model):
|
|||
|
||||
def get_child(self):
|
||||
# todo: cache this
|
||||
for model in LOCATION_MODELS+[LocationRedirect]:
|
||||
for model in get_submodels(Location)+[LocationRedirect]:
|
||||
with suppress(AttributeError):
|
||||
return getattr(self, model._meta.default_related_name)
|
||||
return None
|
||||
|
|
|
@ -3,7 +3,6 @@ import typing
|
|||
from celery import chain
|
||||
from django.db import models
|
||||
|
||||
|
||||
_submodels_by_model = {}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue