diff --git a/src/c3nav/mapdata/fields.py b/src/c3nav/mapdata/fields.py index d1cbd30c..7985aa96 100644 --- a/src/c3nav/mapdata/fields.py +++ b/src/c3nav/mapdata/fields.py @@ -5,6 +5,7 @@ import typing from django.core.exceptions import ValidationError from django.core.validators import RegexValidator from django.db import models +from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ from shapely import validation from shapely.geometry import LineString, MultiPolygon, Point, Polygon, mapping, shape @@ -68,15 +69,19 @@ class GeometryField(models.TextField): self._validate_geomtype(geometry) return geometry + @cached_property + def classes(self): + return { + 'polygon': (Polygon, ), + 'multipolygon': (Polygon, MultiPolygon), + 'linestring': (LineString, ), + 'point': (Point, ) + }[self.geomtype] + def _validate_geomtype(self, value, exception: typing.Type[Exception]=ValidationError): - if self.geomtype == 'polygon' and not isinstance(value, Polygon): - raise exception('Expected Polygon instance, got %s instead.' % repr(value)) - if self.geomtype == 'multipolygon' and not isinstance(value, (Polygon, MultiPolygon)): - raise exception('Expected Polygon or MultiPolygon instance, got %s instead.' % repr(value)) - elif self.geomtype == 'linestring' and not isinstance(value, LineString): - raise exception('Expected LineString instance, got %s instead.' % repr(value)) - elif self.geomtype == 'point' and not isinstance(value, Point): - raise exception('Expected Point instance, got %s instead.' % repr(value)) + if not isinstance(value, self.classes): + raise exception('Expected %s instance, got %s instead.' % (' or '.join(c.__name__ for c in self.classes), + repr(value))) def get_final_value(self, value, as_json=False): json_value = format_geojson(mapping(value))