diff options
Diffstat (limited to 'tests/gis_tests/inspectapp/tests.py')
-rw-r--r-- | tests/gis_tests/inspectapp/tests.py | 176 |
1 files changed, 176 insertions, 0 deletions
diff --git a/tests/gis_tests/inspectapp/tests.py b/tests/gis_tests/inspectapp/tests.py new file mode 100644 index 0000000000..19d5aa054c --- /dev/null +++ b/tests/gis_tests/inspectapp/tests.py @@ -0,0 +1,176 @@ +from __future__ import unicode_literals + +import os +import re +from unittest import skipUnless + +from django.contrib.gis.gdal import HAS_GDAL +from django.core.management import call_command +from django.db import connection, connections +from django.test import TestCase, skipUnlessDBFeature +from django.utils.six import StringIO + +from ..test_data import TEST_DATA + +if HAS_GDAL: + from django.contrib.gis.gdal import Driver, GDALException + from django.contrib.gis.utils.ogrinspect import ogrinspect + + from .models import AllOGRFields + + +@skipUnless(HAS_GDAL, "InspectDbTests needs GDAL support") +@skipUnlessDBFeature("gis_enabled") +class InspectDbTests(TestCase): + def test_geom_columns(self): + """ + Test the geo-enabled inspectdb command. + """ + out = StringIO() + call_command('inspectdb', + table_name_filter=lambda tn: tn.startswith('inspectapp_'), + stdout=out) + output = out.getvalue() + if connection.features.supports_geometry_field_introspection: + self.assertIn('geom = models.PolygonField()', output) + self.assertIn('point = models.PointField()', output) + else: + self.assertIn('geom = models.GeometryField(', output) + self.assertIn('point = models.GeometryField(', output) + self.assertIn('objects = models.GeoManager()', output) + + +@skipUnless(HAS_GDAL, "OGRInspectTest needs GDAL support") +@skipUnlessDBFeature("gis_enabled") +class OGRInspectTest(TestCase): + maxDiff = 1024 + + def test_poly(self): + shp_file = os.path.join(TEST_DATA, 'test_poly', 'test_poly.shp') + model_def = ogrinspect(shp_file, 'MyModel') + + expected = [ + '# This is an auto-generated Django model module created by ogrinspect.', + 'from django.contrib.gis.db import models', + '', + 'class MyModel(models.Model):', + ' float = models.FloatField()', + ' int = models.FloatField()', + ' str = models.CharField(max_length=80)', + ' geom = models.PolygonField(srid=-1)', + ' objects = models.GeoManager()', + ] + + self.assertEqual(model_def, '\n'.join(expected)) + + def test_date_field(self): + shp_file = os.path.join(TEST_DATA, 'cities', 'cities.shp') + model_def = ogrinspect(shp_file, 'City') + + expected = [ + '# This is an auto-generated Django model module created by ogrinspect.', + 'from django.contrib.gis.db import models', + '', + 'class City(models.Model):', + ' name = models.CharField(max_length=80)', + ' population = models.FloatField()', + ' density = models.FloatField()', + ' created = models.DateField()', + ' geom = models.PointField(srid=-1)', + ' objects = models.GeoManager()', + ] + + self.assertEqual(model_def, '\n'.join(expected)) + + def test_time_field(self): + # Getting the database identifier used by OGR, if None returned + # GDAL does not have the support compiled in. + ogr_db = get_ogr_db_string() + if not ogr_db: + self.skipTest("Unable to setup an OGR connection to your database") + + try: + # Writing shapefiles via GDAL currently does not support writing OGRTime + # fields, so we need to actually use a database + model_def = ogrinspect(ogr_db, 'Measurement', + layer_key=AllOGRFields._meta.db_table, + decimal=['f_decimal']) + except GDALException: + self.skipTest("Unable to setup an OGR connection to your database") + + self.assertTrue(model_def.startswith( + '# This is an auto-generated Django model module created by ogrinspect.\n' + 'from django.contrib.gis.db import models\n' + '\n' + 'class Measurement(models.Model):\n' + )) + + # The ordering of model fields might vary depending on several factors (version of GDAL, etc.) + self.assertIn(' f_decimal = models.DecimalField(max_digits=0, decimal_places=0)', model_def) + self.assertIn(' f_int = models.IntegerField()', model_def) + self.assertIn(' f_datetime = models.DateTimeField()', model_def) + self.assertIn(' f_time = models.TimeField()', model_def) + self.assertIn(' f_float = models.FloatField()', model_def) + self.assertIn(' f_char = models.CharField(max_length=10)', model_def) + self.assertIn(' f_date = models.DateField()', model_def) + + self.assertIsNotNone(re.search( + r' geom = models.PolygonField\(([^\)])*\)\n' # Some backends may have srid=-1 + r' objects = models.GeoManager\(\)', model_def)) + + def test_management_command(self): + shp_file = os.path.join(TEST_DATA, 'cities', 'cities.shp') + out = StringIO() + call_command('ogrinspect', shp_file, 'City', stdout=out) + output = out.getvalue() + self.assertIn('class City(models.Model):', output) + + +def get_ogr_db_string(): + """ + Construct the DB string that GDAL will use to inspect the database. + GDAL will create its own connection to the database, so we re-use the + connection settings from the Django test. + """ + db = connections.databases['default'] + + # Map from the django backend into the OGR driver name and database identifier + # http://www.gdal.org/ogr/ogr_formats.html + # + # TODO: Support Oracle (OCI). + drivers = { + 'django.contrib.gis.db.backends.postgis': ('PostgreSQL', "PG:dbname='%(db_name)s'", ' '), + 'django.contrib.gis.db.backends.mysql': ('MySQL', 'MYSQL:"%(db_name)s"', ','), + 'django.contrib.gis.db.backends.spatialite': ('SQLite', '%(db_name)s', '') + } + + db_engine = db['ENGINE'] + if db_engine not in drivers: + return None + + drv_name, db_str, param_sep = drivers[db_engine] + + # Ensure that GDAL library has driver support for the database. + try: + Driver(drv_name) + except: + return None + + # SQLite/Spatialite in-memory databases + if db['NAME'] == ":memory:": + return None + + # Build the params of the OGR database connection string + params = [db_str % {'db_name': db['NAME']}] + + def add(key, template): + value = db.get(key, None) + # Don't add the parameter if it is not in django's settings + if value: + params.append(template % value) + add('HOST', "host='%s'") + add('PORT', "port='%s'") + add('USER', "user='%s'") + add('PASSWORD', "password='%s'") + + return param_sep.join(params) |