diff options
Diffstat (limited to 'django/contrib/gis/db/models')
-rw-r--r-- | django/contrib/gis/db/models/fields.py | 121 | ||||
-rw-r--r-- | django/contrib/gis/db/models/lookups.py | 69 |
2 files changed, 155 insertions, 35 deletions
diff --git a/django/contrib/gis/db/models/fields.py b/django/contrib/gis/db/models/fields.py index 8452a11c5f..91a12623c7 100644 --- a/django/contrib/gis/db/models/fields.py +++ b/django/contrib/gis/db/models/fields.py @@ -1,7 +1,10 @@ from django.contrib.gis import forms -from django.contrib.gis.db.models.lookups import gis_lookups +from django.contrib.gis.db.models.lookups import ( + RasterBandTransform, gis_lookups, +) from django.contrib.gis.db.models.proxy import SpatialProxy from django.contrib.gis.gdal import HAS_GDAL +from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.geometry.backend import Geometry, GeometryException from django.core.exceptions import ImproperlyConfigured from django.db.models.expressions import Expression @@ -157,6 +160,82 @@ class BaseSpatialField(Field): """ return connection.ops.get_geom_placeholder(self, value, compiler) + def get_srid(self, obj): + """ + Return the default SRID for the given geometry or raster, taking into + account the SRID set for the field. For example, if the input geometry + or raster doesn't have an SRID, then the SRID of the field will be + returned. + """ + srid = obj.srid # SRID of given geometry. + if srid is None or self.srid == -1 or (srid == -1 and self.srid != -1): + return self.srid + else: + return srid + + def get_db_prep_save(self, value, connection): + """ + Prepare the value for saving in the database. + """ + if not value: + return None + else: + return connection.ops.Adapter(self.get_prep_value(value)) + + def get_prep_value(self, value): + """ + Spatial lookup values are either a parameter that is (or may be + converted to) a geometry or raster, or a sequence of lookup values + that begins with a geometry or raster. This routine sets up the + geometry or raster value properly and preserves any other lookup + parameters. + """ + from django.contrib.gis.gdal import GDALRaster + + value = super(BaseSpatialField, self).get_prep_value(value) + # For IsValid lookups, boolean values are allowed. + if isinstance(value, (Expression, bool)): + return value + elif isinstance(value, (tuple, list)): + obj = value[0] + seq_value = True + else: + obj = value + seq_value = False + + # When the input is not a geometry or raster, attempt to construct one + # from the given string input. + if isinstance(obj, (Geometry, GDALRaster)): + pass + elif isinstance(obj, (bytes, six.string_types)) or hasattr(obj, '__geo_interface__'): + try: + obj = Geometry(obj) + except (GeometryException, GDALException): + try: + obj = GDALRaster(obj) + except GDALException: + raise ValueError("Couldn't create spatial object from lookup value '%s'." % obj) + elif isinstance(obj, dict): + try: + obj = GDALRaster(obj) + except GDALException: + raise ValueError("Couldn't create spatial object from lookup value '%s'." % obj) + else: + raise ValueError('Cannot use object with type %s for a spatial lookup parameter.' % type(obj).__name__) + + # Assigning the SRID value. + obj.srid = self.get_srid(obj) + + if seq_value: + lookup_val = [obj] + lookup_val.extend(value[1:]) + return tuple(lookup_val) + else: + return obj + +for klass in gis_lookups.values(): + BaseSpatialField.register_lookup(klass) + class GeometryField(GeoSelectFormatMixin, BaseSpatialField): """ @@ -224,6 +303,8 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField): value properly, and preserve any other lookup parameters before returning to the caller. """ + from django.contrib.gis.gdal import GDALRaster + value = super(GeometryField, self).get_prep_value(value) if isinstance(value, (Expression, bool)): return value @@ -236,7 +317,7 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField): # When the input is not a GEOS geometry, attempt to construct one # from the given string input. - if isinstance(geom, Geometry): + if isinstance(geom, (Geometry, GDALRaster)): pass elif isinstance(geom, (bytes, six.string_types)) or hasattr(geom, '__geo_interface__'): try: @@ -265,18 +346,6 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField): value.srid = self.srid return value - def get_srid(self, geom): - """ - Returns the default SRID for the given geometry, taking into account - the SRID set for the field. For example, if the input geometry - has no SRID, then that of the field will be returned. - """ - gsrid = geom.srid # SRID of given geometry. - if gsrid is None or self.srid == -1 or (gsrid == -1 and self.srid != -1): - return self.srid - else: - return gsrid - # ### Routines overloaded from Field ### def contribute_to_class(self, cls, name, **kwargs): super(GeometryField, self).contribute_to_class(cls, name, **kwargs) @@ -316,17 +385,6 @@ class GeometryField(GeoSelectFormatMixin, BaseSpatialField): params = [connection.ops.Adapter(value)] return params - def get_db_prep_save(self, value, connection): - "Prepares the value for saving in the database." - if not value: - return None - else: - return connection.ops.Adapter(self.get_prep_value(value)) - - -for klass in gis_lookups.values(): - GeometryField.register_lookup(klass) - # The OpenGIS Geometry Type Fields class PointField(GeometryField): @@ -387,6 +445,7 @@ class RasterField(BaseSpatialField): description = _("Raster Field") geom_type = 'RASTER' + geography = False def __init__(self, *args, **kwargs): if not HAS_GDAL: @@ -421,3 +480,15 @@ class RasterField(BaseSpatialField): # delays the instantiation of the objects to the moment of evaluation # of the raster attribute. setattr(cls, self.attname, SpatialProxy(GDALRaster, self)) + + def get_transform(self, name): + try: + band_index = int(name) + return type( + 'SpecificRasterBandTransform', + (RasterBandTransform, ), + {'band_index': band_index} + ) + except ValueError: + pass + return super(RasterField, self).get_transform(name) diff --git a/django/contrib/gis/db/models/lookups.py b/django/contrib/gis/db/models/lookups.py index 158fb5cfa1..a5bf428aa4 100644 --- a/django/contrib/gis/db/models/lookups.py +++ b/django/contrib/gis/db/models/lookups.py @@ -5,16 +5,23 @@ import re from django.core.exceptions import FieldDoesNotExist from django.db.models.constants import LOOKUP_SEP from django.db.models.expressions import Col, Expression -from django.db.models.lookups import BuiltinLookup, Lookup +from django.db.models.lookups import BuiltinLookup, Lookup, Transform from django.utils import six gis_lookups = {} +class RasterBandTransform(Transform): + def as_sql(self, compiler, connection): + return compiler.compile(self.lhs) + + class GISLookup(Lookup): sql_template = None transform_func = None distance = False + band_rhs = None + band_lhs = None def __init__(self, *args, **kwargs): super(GISLookup, self).__init__(*args, **kwargs) @@ -28,10 +35,10 @@ class GISLookup(Lookup): 'point, 'the_geom', or a related lookup on a geographic field like 'address__point'. - If a GeometryField exists according to the given lookup on the model - options, it will be returned. Otherwise returns None. + If a BaseSpatialField exists according to the given lookup on the model + options, it will be returned. Otherwise return None. """ - from django.contrib.gis.db.models.fields import GeometryField + from django.contrib.gis.db.models.fields import BaseSpatialField # This takes into account the situation where the lookup is a # lookup to a related geographic field, e.g., 'address__point'. field_list = lookup.split(LOOKUP_SEP) @@ -55,11 +62,34 @@ class GISLookup(Lookup): return False # Finally, make sure we got a Geographic field and return. - if isinstance(geo_fld, GeometryField): + if isinstance(geo_fld, BaseSpatialField): return geo_fld else: return False + def process_band_indices(self, only_lhs=False): + """ + Extract the lhs band index from the band transform class and the rhs + band index from the input tuple. + """ + # PostGIS band indices are 1-based, so the band index needs to be + # increased to be consistent with the GDALRaster band indices. + if only_lhs: + self.band_rhs = 1 + self.band_lhs = self.lhs.band_index + 1 + return + + if isinstance(self.lhs, RasterBandTransform): + self.band_lhs = self.lhs.band_index + 1 + else: + self.band_lhs = 1 + + self.band_rhs = self.rhs[1] + if len(self.rhs) == 1: + self.rhs = self.rhs[0] + else: + self.rhs = (self.rhs[0], ) + self.rhs[2:] + def get_db_prep_lookup(self, value, connection): # get_db_prep_lookup is called by process_rhs from super class if isinstance(value, (tuple, list)): @@ -70,10 +100,9 @@ class GISLookup(Lookup): return ('%s', params) def process_rhs(self, compiler, connection): - rhs, rhs_params = super(GISLookup, self).process_rhs(compiler, connection) if hasattr(self.rhs, '_as_sql'): # If rhs is some QuerySet, don't touch it - return rhs, rhs_params + return super(GISLookup, self).process_rhs(compiler, connection) geom = self.rhs if isinstance(self.rhs, Col): @@ -85,9 +114,19 @@ class GISLookup(Lookup): raise ValueError('No geographic field found in expression.') self.rhs.srid = geo_fld.srid elif isinstance(self.rhs, Expression): - raise ValueError('Complex expressions not supported for GeometryField') + raise ValueError('Complex expressions not supported for spatial fields.') elif isinstance(self.rhs, (list, tuple)): geom = self.rhs[0] + # Check if a band index was passed in the query argument. + if ((len(self.rhs) == 2 and not self.lookup_name == 'relate') or + (len(self.rhs) == 3 and self.lookup_name == 'relate')): + self.process_band_indices() + elif len(self.rhs) > 2: + raise ValueError('Tuple too long for lookup %s.' % self.lookup_name) + elif isinstance(self.lhs, RasterBandTransform): + self.process_band_indices(only_lhs=True) + + rhs, rhs_params = super(GISLookup, self).process_rhs(compiler, connection) rhs = connection.ops.get_geom_placeholder(self.lhs.output_field, geom, compiler) return rhs, rhs_params @@ -274,6 +313,8 @@ class IsValidLookup(BuiltinLookup): lookup_name = 'isvalid' def as_sql(self, compiler, connection): + if self.lhs.field.geom_type == 'RASTER': + raise ValueError('The isvalid lookup is only available on geometry fields.') gis_op = connection.ops.gis_operators[self.lookup_name] sql, params = self.process_lhs(compiler, connection) sql = '%(func)s(%(lhs)s)' % {'func': gis_op.func, 'lhs': sql} @@ -323,9 +364,17 @@ class DistanceLookupBase(GISLookup): sql_template = '%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s' def process_rhs(self, compiler, connection): - if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 3: - raise ValueError("2 or 3-element tuple required for '%s' lookup." % self.lookup_name) + if not isinstance(self.rhs, (tuple, list)) or not 2 <= len(self.rhs) <= 4: + raise ValueError("2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name) + elif len(self.rhs) == 4 and not self.rhs[3] == 'spheroid': + raise ValueError("For 4-element tuples the last argument must be the 'speroid' directive.") + + # Check if the second parameter is a band index. + if len(self.rhs) > 2 and not self.rhs[2] == 'spheroid': + self.process_band_indices() + params = [connection.ops.Adapter(self.rhs[0])] + # Getting the distance parameter in the units of the field. dist_param = self.rhs[1] if hasattr(dist_param, 'resolve_expression'): |