diff --git a/.gitignore b/.gitignore index a97178f5..bf2dd10c 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ dist/ .tox MANIFEST +.idea \ No newline at end of file diff --git a/django_mongodb_engine/base.py b/django_mongodb_engine/base.py index 033008e4..c8666877 100644 --- a/django_mongodb_engine/base.py +++ b/django_mongodb_engine/base.py @@ -5,6 +5,7 @@ import warnings from django.conf import settings +from django.contrib.gis.db.backends.base import BaseSpatialOperations from django.core.exceptions import ImproperlyConfigured from django.db.backends.signals import connection_created from django.db.utils import DatabaseError @@ -40,7 +41,7 @@ class DatabaseFeatures(NonrelDatabaseFeatures): supports_long_model_names = False -class DatabaseOperations(NonrelDatabaseOperations): +class DatabaseOperations(NonrelDatabaseOperations, BaseSpatialOperations): compiler_module = __name__.rsplit('.', 1)[0] + '.compiler' def max_name_length(self): @@ -153,6 +154,12 @@ def _value_from_db(self, value, field, field_kind, db_type): return super(DatabaseOperations, self)._value_from_db( value, field, field_kind, db_type) + def geometry_columns(self): + return None + + def spatial_ref_sys(self): + return None + class DatabaseClient(NonrelDatabaseClient): pass diff --git a/django_mongodb_engine/compiler.py b/django_mongodb_engine/compiler.py index a26c0af1..7fb158f5 100644 --- a/django_mongodb_engine/compiler.py +++ b/django_mongodb_engine/compiler.py @@ -10,6 +10,7 @@ from django.db.utils import DatabaseError, IntegrityError from django.utils.encoding import smart_str from django.utils.tree import Node +from django.contrib.gis.db.models.sql.compiler import GeoSQLCompiler as BaseGeoSQLCompiler from pymongo import ASCENDING, DESCENDING from pymongo.errors import PyMongoError, DuplicateKeyError @@ -64,6 +65,11 @@ def get_selected_fields(query): # Date OPs. 'year': lambda val: {'$gte': val[0], '$lt': val[1]}, + + # Spatial + 'within': lambda val: val, + 'intersects': lambda val: val, + 'near': lambda val: val, } NEGATED_OPERATORS_MAP = { @@ -444,3 +450,7 @@ def execute_update(self, update_spec, multi=True, **kwargs): class SQLDeleteCompiler(NonrelDeleteCompiler, SQLCompiler): pass + + +class GeoSQLCompiler(BaseGeoSQLCompiler, SQLCompiler): + pass \ No newline at end of file diff --git a/django_mongodb_engine/contrib/__init__.py b/django_mongodb_engine/contrib/__init__.py index acd3223b..61058473 100644 --- a/django_mongodb_engine/contrib/__init__.py +++ b/django_mongodb_engine/contrib/__init__.py @@ -1,8 +1,10 @@ import sys +from django.contrib.gis.db.models import GeoManager from django.db import models, connections from django.db.models.query import QuerySet from django.db.models.sql.query import Query as SQLQuery +from django_mongodb_engine.query import MongoGeoQuerySet ON_PYPY = hasattr(sys, 'pypy_version_info') @@ -175,3 +177,8 @@ def distinct(self, *args, **kwargs): database. """ return self.get_query_set().distinct(*args, **kwargs) + + +class GeoMongoDBManager(MongoDBManager, GeoManager): + def get_queryset(self): + return MongoGeoQuerySet(self.model, using=self._db) \ No newline at end of file diff --git a/django_mongodb_engine/fields.py b/django_mongodb_engine/fields.py index 5c815b9d..d54b598b 100644 --- a/django_mongodb_engine/fields.py +++ b/django_mongodb_engine/fields.py @@ -1,4 +1,11 @@ +import json +from django.contrib.gis import forms +from django.contrib.gis.db.models.proxy import GeometryProxy +from django.contrib.gis.geometry.backend import Geometry +from django.contrib.gis.geos import Polygon, LineString, Point from django.db import connections, models +from django.utils import six +from django.utils.translation import ugettext_lazy as _ from gridfs import GridFS from gridfs.errors import NoFile @@ -172,3 +179,163 @@ def _property_get(self, model): [], ['^django_mongodb_engine\.fields\.GridFSString']) except ImportError: pass + + +class MongoGeometryProxy(GeometryProxy): + def __set__(self, obj, value): + if isinstance(value, dict): + value = json.dumps(value) + super(MongoGeometryProxy, self).__set__(obj, value) + + +class GeometryField(models.Field): + """ + The base GIS field -- maps to the OpenGIS Specification Geometry type. + + Based loosely on django.contrib.gis.db.models.fields.GeometryField + """ + + # The OpenGIS Geometry name. + geom_type = 'GEOMETRY' + form_class = forms.GeometryField + + lookup_types = { + 'within': {'operator': '$geoWithin', 'types': [Polygon]}, + 'intersects': {'operator': '$geoIntersects', 'types': [Point, LineString, Polygon]}, + 'near': {'operator': '$near', 'types': [Point]}, + } + + description = _("The base GIS field -- maps to the OpenGIS Specification Geometry type.") + + def __init__(self, verbose_name=None, dim=2, **kwargs): + """ + The initialization function for geometry fields. Takes the following + as keyword arguments: + + dim: + The number of dimensions for this geometry. Defaults to 2. + """ + + # Mongo GeoJSON are WGS84 + self.srid = 4326 + + # Setting the dimension of the geometry field. + self.dim = dim + + # Setting the verbose_name keyword argument with the positional + # first parameter, so this works like normal fields. + kwargs['verbose_name'] = verbose_name + + super(GeometryField, self).__init__(**kwargs) + + def to_python(self, value): + if isinstance(value, Geometry): + return value + elif isinstance(value, dict): + return Geometry(json.dumps(value), self.srid) + elif isinstance(value, (bytes, six.string_types)): + return Geometry(value, self.srid) + raise ValueError('Could not convert to python geometry from value type "%s".' % type(value)) + + def get_prep_value(self, value): + if isinstance(value, Geometry): + return json.loads(value.json) + elif isinstance(value, six.string_types): + return json.loads(value) + raise ValueError('Could not prep geometry from value type "%s".' % type(value)) + + def get_srid(self, geom): + """ + Returns the default SRID for the given geometry, taking into account + the SRID set for the field. For example, if the input geometry + has no SRID, then that of the field will be returned. + """ + gsrid = geom.srid # SRID of given geometry. + if gsrid is None or self.srid == -1 or (gsrid == -1 and self.srid != -1): + return self.srid + else: + return gsrid + + ### Routines overloaded from Field ### + def contribute_to_class(self, cls, name, virtual_only=False): + super(GeometryField, self).contribute_to_class(cls, name, virtual_only) + + # Setup for lazy-instantiated Geometry object. + setattr(cls, self.attname, MongoGeometryProxy(Geometry, self)) + + def db_type(self, connection): + return self.geom_type + + def formfield(self, **kwargs): + defaults = {'form_class': self.form_class, + 'geom_type': self.geom_type, + 'srid': self.srid, + } + defaults.update(kwargs) + if (self.dim > 2 and not 'widget' in kwargs and + not getattr(defaults['form_class'].widget, 'supports_3d', False)): + defaults['widget'] = forms.Textarea + return super(GeometryField, self).formfield(**defaults) + + def get_prep_lookup(self, lookup_type, value): + if lookup_type not in self.lookup_types: + raise ValueError('Unknown lookup type "%s".' % lookup_type) + lookup_info = self.lookup_types[lookup_type] + if not isinstance(value, Geometry): + raise ValueError('Geometry value is of unsupported type "%s".' % type(value)) + if type(value) not in lookup_info['types']: + raise ValueError('"%s" lookup requires a value of geometry type(s) %s.' % + (lookup_type, ','.join([str(ltype) for ltype in lookup_info['types']]))) + geom_query = {'$geometry': json.loads(value.json)} + # some queries may have additional query params; e.g.: + # $near optionally takes $minDistance and $maxDistance + if hasattr(value, 'extra_params'): + geom_query.update(value.extra_params) + return {lookup_info['operator']: geom_query} + + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): + # this was already handled by get_prep_lookup... + return value + + +# The OpenGIS Geometry Type Fields +class PointField(GeometryField): + geom_type = 'POINT' + form_class = forms.PointField + description = _("Point") + + +class LineStringField(GeometryField): + geom_type = 'LINESTRING' + form_class = forms.LineStringField + description = _("Line string") + + +class PolygonField(GeometryField): + geom_type = 'POLYGON' + form_class = forms.PolygonField + description = _("Polygon") + + +class MultiPointField(GeometryField): + geom_type = 'MULTIPOINT' + form_class = forms.MultiPointField + description = _("Multi-point") + + +class MultiLineStringField(GeometryField): + geom_type = 'MULTILINESTRING' + form_class = forms.MultiLineStringField + description = _("Multi-line string") + + +class MultiPolygonField(GeometryField): + geom_type = 'MULTIPOLYGON' + form_class = forms.MultiPolygonField + description = _("Multi polygon") + + +class GeometryCollectionField(GeometryField): + geom_type = 'GEOMETRYCOLLECTION' + form_class = forms.GeometryCollectionField + description = _("Geometry collection") \ No newline at end of file diff --git a/django_mongodb_engine/query.py b/django_mongodb_engine/query.py index 7b1401c4..ef6089db 100644 --- a/django_mongodb_engine/query.py +++ b/django_mongodb_engine/query.py @@ -1,4 +1,6 @@ from warnings import warn +from django.contrib.gis.db.models.query import GeoQuerySet +from django.contrib.gis.db.models.sql import GeoQuery, GeoWhereNode from djangotoolbox.fields import RawField, AbstractIterableField, \ EmbeddedModelField @@ -24,3 +26,15 @@ def as_q(self, field): else: raise TypeError("Can not use A() queries on %s." % field.__class__.__name__) + + +class MongoGeoQuery(GeoQuery): + def __init__(self, model, where=GeoWhereNode): + super(MongoGeoQuery, self).__init__(model, where) + self.query_terms |= set(['near']) + + +class MongoGeoQuerySet(GeoQuerySet): + def __init__(self, model=None, query=None, using=None): + super(MongoGeoQuerySet, self).__init__(model=model, query=query, using=using) + self.query = query or MongoGeoQuery(self.model) \ No newline at end of file diff --git a/tests/gis/__init__.py b/tests/gis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/gis/models.py b/tests/gis/models.py new file mode 100644 index 00000000..7adefb00 --- /dev/null +++ b/tests/gis/models.py @@ -0,0 +1,12 @@ +from django.db import models +from django_mongodb_engine.contrib import GeoMongoDBManager +from django_mongodb_engine.fields import GeometryField + + +class GeometryModel(models.Model): + geom = GeometryField() + + objects = GeoMongoDBManager() + + class MongoMeta: + indexes = [{'fields': [('geom', '2dsphere')]}] \ No newline at end of file diff --git a/tests/gis/tests.py b/tests/gis/tests.py new file mode 100644 index 00000000..a9f4a4d3 --- /dev/null +++ b/tests/gis/tests.py @@ -0,0 +1,130 @@ +from django.contrib.gis.geos import Point, LineString, Polygon, MultiPoint, MultiLineString, MultiPolygon, \ + GeometryCollection +import pymongo +from models import * +from utils import TestCase, get_collection + + +class GeometryTest(TestCase): + point = Point((1, 1)) + line = LineString((1, 1), (2, 2), (3, 3)) + polygon = Polygon(((0, 0), (0, 10), (10, 10), (10, 0), (0, 0)), + ((4, 4), (4, 6), (6, 6), (6, 4), (4, 4))) + multi_point = MultiPoint(Point(11, 11), + Point(12, 12), + Point(13, 13)) + multi_line = MultiLineString(LineString((21, 21), (22, 22), (33, 33)), + LineString((-21, -21), (-22, -22), (-33, -33))) + multi_polygon = MultiPolygon(Polygon(((30, 30), (30, 40), (40, 40), (40, 30), (30, 30))), + Polygon(((50, 50), (50, 60), (60, 60), (60, 50), (50, 50)))) + geom_collection = GeometryCollection(point, line, polygon) + + @classmethod + def setUpClass(cls): + coll = get_collection(GeometryModel) + + GeometryModel.objects.create(geom=cls.point) + GeometryModel.objects.create(geom=cls.line) + GeometryModel.objects.create(geom=cls.polygon) + GeometryModel.objects.create(geom=cls.multi_point) + GeometryModel.objects.create(geom=cls.multi_line) + GeometryModel.objects.create(geom=cls.multi_polygon) + GeometryModel.objects.create(geom=cls.geom_collection) + + # not sure why the tests don't create the index... + coll.ensure_index([('geom', pymongo.GEOSPHERE)]) + + def test_retrieve(self): + all_geoms = [obj.geom for obj in GeometryModel.objects.all()] + self.assertEqual(7, len(all_geoms)) + self.assertIn(self.point, all_geoms) + self.assertIn(self.line, all_geoms) + self.assertIn(self.polygon, all_geoms) + self.assertIn(self.multi_line, all_geoms) + self.assertIn(self.multi_polygon, all_geoms) + self.assertIn(self.geom_collection, all_geoms) + + def test_query_within(self): + # create a box that only contains the point + geoms = [obj.geom for obj in GeometryModel.objects.filter( + geom__within=Polygon(((0, 0), (0, 2), (2, 2), (2, 0), (0, 0))))] + self.assertEqual(1, len(geoms)) + self.assertIn(self.point, geoms) + + # create a box that contains everything + # NOTE: only returns points, lines and polygons! + geoms = [obj.geom for obj in GeometryModel.objects.filter( + geom__within=Polygon(((-20, -20), (-20, 20), (20, 20), (20, -20), (-20, -20))))] + self.assertEqual(5, len(geoms)) + self.assertIn(self.point, geoms) + self.assertIn(self.line, geoms) + self.assertIn(self.polygon, geoms) + + # create a box that contains nothing + geoms = [obj.geom for obj in GeometryModel.objects.filter( + geom__within=Polygon(((-20, -20), (-20, -19), (-19, -19), (-19, -20), (-20, -20))))] + self.assertEqual(0, len(geoms)) + + # try to query on some unsupported objects + with self.assertRaises(ValueError): + GeometryModel.objects.filter(geom__within='a string').first() + + with self.assertRaises(ValueError): + GeometryModel.objects.filter(geom__within=self.point).first() + + with self.assertRaises(ValueError): + GeometryModel.objects.filter(geom__within=self.line).first() + + def test_query_intersects(self): + # intersect with a polygon + geoms = [obj.geom for obj in GeometryModel.objects.filter( + geom__intersects=Polygon(((0, 0), (0, 2), (2, 2), (2, 0), (0, 0))))] + self.assertEqual(4, len(geoms)) + self.assertIn(self.point, geoms) + self.assertIn(self.line, geoms) + self.assertIn(self.polygon, geoms) + self.assertIn(self.geom_collection, geoms) + + # intersect with a line + geoms = [obj.geom for obj in GeometryModel.objects.filter( + geom__intersects=LineString(((3, 3), (2, 2), (1.5, 1.5))))] + self.assertEqual(3, len(geoms)) + self.assertIn(self.line, geoms) + self.assertIn(self.polygon, geoms) + self.assertIn(self.geom_collection, geoms) + + # intersect with a point + geoms = [obj.geom for obj in GeometryModel.objects.filter(geom__intersects=Point((9, 9)))] + self.assertEqual(2, len(geoms)) + self.assertIn(self.polygon, geoms) + self.assertIn(self.geom_collection, geoms) + + # intersection test that returns nothing + geoms = [obj.geom for obj in GeometryModel.objects.filter(geom__intersects=Point((-9, -9)))] + self.assertEqual(0, len(geoms)) + + # try to query on some unsupported objects + with self.assertRaises(ValueError): + GeometryModel.objects.filter(geom__intersects='a string').first() + + def test_query_near(self): + # make sure all are returned + point = Point((0, 0)) + geoms = [obj.geom for obj in GeometryModel.objects.filter(geom__near=point)] + self.assertEqual(7, len(geoms)) + + # restrict the distance + point = Point((0, 0)) + point.extra_params = {'$maxDistance': 1} + geoms = [obj.geom for obj in GeometryModel.objects.filter(geom__near=point)] + self.assertEqual(2, len(geoms)) + + point = Point((1, 1)) + point.extra_params = {'$maxDistance': 1} + geoms = [obj.geom for obj in GeometryModel.objects.filter(geom__near=point)] + self.assertEqual(4, len(geoms)) + + point = Point((0, 0)) + point.extra_params = {'$maxDistance': 1000000, '$minDistance': 100000} + geoms = [obj.geom for obj in GeometryModel.objects.filter(geom__near=point)] + self.assertEqual(2, len(geoms)) diff --git a/tests/gis/utils.py b/tests/gis/utils.py new file mode 120000 index 00000000..50fbc6d8 --- /dev/null +++ b/tests/gis/utils.py @@ -0,0 +1 @@ +../utils.py \ No newline at end of file diff --git a/tests/gis/views.py b/tests/gis/views.py new file mode 100644 index 00000000..e69de29b