summaryrefslogtreecommitdiff
path: root/django/db/models
diff options
context:
space:
mode:
Diffstat (limited to 'django/db/models')
-rw-r--r--django/db/models/__init__.py25
-rw-r--r--django/db/models/base.py114
-rw-r--r--django/db/models/fields/__init__.py205
-rw-r--r--django/db/models/fields/generic.py259
-rw-r--r--django/db/models/fields/related.py55
-rw-r--r--django/db/models/loading.py2
-rw-r--r--django/db/models/manager.py8
-rw-r--r--django/db/models/manipulators.py10
-rw-r--r--django/db/models/options.py5
-rw-r--r--django/db/models/query.py229
-rw-r--r--django/db/models/related.py17
11 files changed, 497 insertions, 432 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
index 0308dd047a..6c3abb6b59 100644
--- a/django/db/models/__init__.py
+++ b/django/db/models/__init__.py
@@ -8,7 +8,6 @@ from django.db.models.manager import Manager
from django.db.models.base import Model, AdminOptions
from django.db.models.fields import *
from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel, TABULAR, STACKED
-from django.db.models.fields.generic import GenericRelation, GenericRel, GenericForeignKey
from django.db.models import signals
from django.utils.functional import curry
from django.utils.text import capfirst
@@ -27,27 +26,3 @@ def permalink(func):
viewname = bits[0]
return reverse(bits[0], None, *bits[1:3])
return inner
-
-class LazyDate(object):
- """
- Use in limit_choices_to to compare the field to dates calculated at run time
- instead of when the model is loaded. For example::
-
- ... limit_choices_to = {'date__gt' : models.LazyDate(days=-3)} ...
-
- which will limit the choices to dates greater than three days ago.
- """
- def __init__(self, **kwargs):
- self.delta = datetime.timedelta(**kwargs)
-
- def __str__(self):
- return str(self.__get_value__())
-
- def __repr__(self):
- return "<LazyDate: %s>" % self.delta
-
- def __get_value__(self):
- return (datetime.datetime.now() + self.delta).date()
-
- def __getattr__(self, attr):
- return getattr(self.__get_value__(), attr)
diff --git a/django/db/models/base.py b/django/db/models/base.py
index 70569a2561..a8e6303e1c 100644
--- a/django/db/models/base.py
+++ b/django/db/models/base.py
@@ -13,6 +13,7 @@ from django.dispatch import dispatcher
from django.utils.datastructures import SortedDict
from django.utils.functional import curry
from django.conf import settings
+from itertools import izip
import types
import sys
import os
@@ -21,8 +22,13 @@ class ModelBase(type):
"Metaclass for all models"
def __new__(cls, name, bases, attrs):
# If this isn't a subclass of Model, don't do anything special.
- if not bases or bases == (object,):
- return type.__new__(cls, name, bases, attrs)
+ try:
+ if not filter(lambda b: issubclass(b, Model), bases):
+ return super(ModelBase, cls).__new__(cls, name, bases, attrs)
+ except NameError:
+ # 'Model' isn't defined yet, meaning we're looking at Django's own
+ # Model class, defined below.
+ return super(ModelBase, cls).__new__(cls, name, bases, attrs)
# Create the class.
new_class = type.__new__(cls, name, bases, {'__module__': attrs.pop('__module__')})
@@ -36,11 +42,11 @@ class ModelBase(type):
new_class._meta.parents.append(base)
new_class._meta.parents.extend(base._meta.parents)
- model_module = sys.modules[new_class.__module__]
if getattr(new_class._meta, 'app_label', None) is None:
# Figure out the app_label by looking one level up.
# For 'django.contrib.sites.models', this would be 'sites'.
+ model_module = sys.modules[new_class.__module__]
new_class._meta.app_label = model_module.__name__.split('.')[-2]
# Bail out early if we have already created this class.
@@ -63,7 +69,7 @@ class ModelBase(type):
if getattr(new_class._meta, 'row_level_permissions', False):
from django.contrib.auth.models import RowLevelPermission
- gen_rel = django.db.models.GenericRelation(RowLevelPermission, object_id_field="model_id", content_type_field="model_ct")
+ gen_rel = django.contrib.contenttypes.generic.GenericRelation(RowLevelPermission, object_id_field="model_id", content_type_field="model_ct")
new_class.add_to_class("row_level_permissions", gen_rel)
new_class._prepare()
@@ -95,41 +101,74 @@ class Model(object):
def __init__(self, *args, **kwargs):
dispatcher.send(signal=signals.pre_init, sender=self.__class__, args=args, kwargs=kwargs)
- for f in self._meta.fields:
- if isinstance(f.rel, ManyToOneRel):
- try:
- # Assume object instance was passed in.
- rel_obj = kwargs.pop(f.name)
- except KeyError:
+
+ # There is a rather weird disparity here; if kwargs, it's set, then args
+ # overrides it. It should be one or the other; don't duplicate the work
+ # The reason for the kwargs check is that standard iterator passes in by
+ # args, and nstantiation for iteration is 33% faster.
+ args_len = len(args)
+ if args_len > len(self._meta.fields):
+ # Daft, but matches old exception sans the err msg.
+ raise IndexError("Number of args exceeds number of fields")
+
+ fields_iter = iter(self._meta.fields)
+ if not kwargs:
+ # The ordering of the izip calls matter - izip throws StopIteration
+ # when an iter throws it. So if the first iter throws it, the second
+ # is *not* consumed. We rely on this, so don't change the order
+ # without changing the logic.
+ for val, field in izip(args, fields_iter):
+ setattr(self, field.attname, val)
+ else:
+ # Slower, kwargs-ready version.
+ for val, field in izip(args, fields_iter):
+ setattr(self, field.attname, val)
+ kwargs.pop(field.name, None)
+ # Maintain compatibility with existing calls.
+ if isinstance(field.rel, ManyToOneRel):
+ kwargs.pop(field.attname, None)
+
+ # Now we're left with the unprocessed fields that *must* come from
+ # keywords, or default.
+
+ for field in fields_iter:
+ if kwargs:
+ if isinstance(field.rel, ManyToOneRel):
try:
- # Object instance wasn't passed in -- must be an ID.
- val = kwargs.pop(f.attname)
+ # Assume object instance was passed in.
+ rel_obj = kwargs.pop(field.name)
except KeyError:
- val = f.get_default()
- else:
- # Object instance was passed in.
- # Special case: You can pass in "None" for related objects if it's allowed.
- if rel_obj is None and f.null:
- val = None
- else:
try:
- val = getattr(rel_obj, f.rel.get_related_field().attname)
- except AttributeError:
- raise TypeError, "Invalid value: %r should be a %s instance, not a %s" % (f.name, f.rel.to, type(rel_obj))
- setattr(self, f.attname, val)
+ # Object instance wasn't passed in -- must be an ID.
+ val = kwargs.pop(field.attname)
+ except KeyError:
+ val = field.get_default()
+ else:
+ # Object instance was passed in. Special case: You can
+ # pass in "None" for related objects if it's allowed.
+ if rel_obj is None and field.null:
+ val = None
+ else:
+ try:
+ val = getattr(rel_obj, field.rel.get_related_field().attname)
+ except AttributeError:
+ raise TypeError("Invalid value: %r should be a %s instance, not a %s" %
+ (field.name, field.rel.to, type(rel_obj)))
+ else:
+ val = kwargs.pop(field.attname, field.get_default())
else:
- val = kwargs.pop(f.attname, f.get_default())
- setattr(self, f.attname, val)
- for prop in kwargs.keys():
- try:
- if isinstance(getattr(self.__class__, prop), property):
- setattr(self, prop, kwargs.pop(prop))
- except AttributeError:
- pass
+ val = field.get_default()
+ setattr(self, field.attname, val)
+
if kwargs:
- raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0]
- for i, arg in enumerate(args):
- setattr(self, self._meta.fields[i].attname, arg)
+ for prop in kwargs.keys():
+ try:
+ if isinstance(getattr(self.__class__, prop), property):
+ setattr(self, prop, kwargs.pop(prop))
+ except AttributeError:
+ pass
+ if kwargs:
+ raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0]
dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self)
def add_to_class(cls, name, value):
@@ -327,7 +366,7 @@ class Model(object):
def _get_FIELD_size(self, field):
return os.path.getsize(self._get_FIELD_filename(field))
- def _save_FIELD_file(self, field, filename, raw_contents):
+ def _save_FIELD_file(self, field, filename, raw_contents, save=True):
directory = field.get_directory_name()
try: # Create the date-based directory if it doesn't exist.
os.makedirs(os.path.join(settings.MEDIA_ROOT, directory))
@@ -362,8 +401,9 @@ class Model(object):
if field.height_field:
setattr(self, field.height_field, height)
- # Save the object, because it has changed.
- self.save()
+ # Save the object because it has changed unless save is False
+ if save:
+ self.save()
_save_FIELD_file.alters_data = True
diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py
index fe317ac24f..136ce31b8b 100644
--- a/django/db/models/fields/__init__.py
+++ b/django/db/models/fields/__init__.py
@@ -10,6 +10,10 @@ from django.utils.itercompat import tee
from django.utils.text import capfirst
from django.utils.translation import gettext, gettext_lazy
import datetime, os, time
+try:
+ import decimal
+except ImportError:
+ from django.utils import _decimal as decimal # for Python 2.3
class NOT_PROVIDED:
pass
@@ -67,7 +71,7 @@ class Field(object):
def __init__(self, verbose_name=None, name=None, primary_key=False,
maxlength=None, unique=False, blank=False, null=False, db_index=False,
- core=False, rel=None, default=NOT_PROVIDED, editable=True,
+ core=False, rel=None, default=NOT_PROVIDED, editable=True, serialize=True,
prepopulate_from=None, unique_for_date=None, unique_for_month=None,
unique_for_year=None, validator_list=None, choices=None, radio_admin=None,
help_text='', db_column=None):
@@ -78,6 +82,7 @@ class Field(object):
self.blank, self.null = blank, null
self.core, self.rel, self.default = core, rel, default
self.editable = editable
+ self.serialize = serialize
self.validator_list = validator_list or []
self.prepopulate_from = prepopulate_from
self.unique_for_date, self.unique_for_month = unique_for_date, unique_for_month
@@ -164,7 +169,7 @@ class Field(object):
def get_db_prep_lookup(self, lookup_type, value):
"Returns field's value prepared for database lookup."
- if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'year', 'month', 'day', 'search'):
+ if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'month', 'day', 'search'):
return [value]
elif lookup_type in ('range', 'in'):
return value
@@ -178,7 +183,13 @@ class Field(object):
return ["%%%s" % prep_for_like_query(value)]
elif lookup_type == 'isnull':
return []
- raise TypeError, "Field has invalid lookup: %s" % lookup_type
+ elif lookup_type == 'year':
+ try:
+ value = int(value)
+ except ValueError:
+ raise ValueError("The __year lookup type requires an integer argument")
+ return ['%s-01-01 00:00:00' % value, '%s-12-31 23:59:59.999999' % value]
+ raise TypeError("Field has invalid lookup: %s" % lookup_type)
def has_default(self):
"Returns a boolean of whether this field has a default value."
@@ -334,10 +345,17 @@ class Field(object):
return self._choices
choices = property(_get_choices)
- def formfield(self):
+ def formfield(self, form_class=forms.CharField, **kwargs):
"Returns a django.newforms.Field instance for this database Field."
- # TODO: This is just a temporary default during development.
- return forms.CharField(required=not self.blank, label=capfirst(self.verbose_name))
+ defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text}
+ if self.choices:
+ defaults['widget'] = forms.Select(choices=self.get_choices())
+ defaults.update(kwargs)
+ return form_class(**defaults)
+
+ def value_from_object(self, obj):
+ "Returns the value of this field in the given model instance."
+ return getattr(obj, self.attname)
class AutoField(Field):
empty_strings_allowed = False
@@ -375,7 +393,7 @@ class AutoField(Field):
super(AutoField, self).contribute_to_class(cls, name)
cls._meta.has_auto_field = True
- def formfield(self):
+ def formfield(self, **kwargs):
return None
class BooleanField(Field):
@@ -392,8 +410,10 @@ class BooleanField(Field):
def get_manipulator_field_objs(self):
return [oldforms.CheckboxField]
- def formfield(self):
- return forms.BooleanField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.BooleanField}
+ defaults.update(kwargs)
+ return super(BooleanField, self).formfield(**defaults)
class CharField(Field):
def get_manipulator_field_objs(self):
@@ -409,8 +429,10 @@ class CharField(Field):
raise validators.ValidationError, gettext_lazy("This field cannot be null.")
return str(value)
- def formfield(self):
- return forms.CharField(max_length=self.maxlength, required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'max_length': self.maxlength}
+ defaults.update(kwargs)
+ return super(CharField, self).formfield(**defaults)
# TODO: Maybe move this into contrib, because it's specialized.
class CommaSeparatedIntegerField(CharField):
@@ -428,6 +450,8 @@ class DateField(Field):
Field.__init__(self, verbose_name, name, **kwargs)
def to_python(self, value):
+ if value is None:
+ return value
if isinstance(value, datetime.datetime):
return value.date()
if isinstance(value, datetime.date):
@@ -479,15 +503,19 @@ class DateField(Field):
def get_manipulator_field_objs(self):
return [oldforms.DateField]
- def flatten_data(self, follow, obj = None):
+ def flatten_data(self, follow, obj=None):
val = self._get_val_from_obj(obj)
return {self.attname: (val is not None and val.strftime("%Y-%m-%d") or '')}
- def formfield(self):
- return forms.DateField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.DateField}
+ defaults.update(kwargs)
+ return super(DateField, self).formfield(**defaults)
class DateTimeField(DateField):
def to_python(self, value):
+ if value is None:
+ return value
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
@@ -544,8 +572,69 @@ class DateTimeField(DateField):
return {date_field: (val is not None and val.strftime("%Y-%m-%d") or ''),
time_field: (val is not None and val.strftime("%H:%M:%S") or '')}
- def formfield(self):
- return forms.DateTimeField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.DateTimeField}
+ defaults.update(kwargs)
+ return super(DateTimeField, self).formfield(**defaults)
+
+class DecimalField(Field):
+ empty_strings_allowed = False
+ def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs):
+ self.max_digits, self.decimal_places = max_digits, decimal_places
+ Field.__init__(self, verbose_name, name, **kwargs)
+
+ def to_python(self, value):
+ if value is None:
+ return value
+ try:
+ return decimal.Decimal(value)
+ except decimal.InvalidOperation:
+ raise validators.ValidationError, gettext("This value must be a decimal number.")
+
+ def _format(self, value):
+ if isinstance(value, basestring):
+ return value
+ else:
+ return self.format_number(value)
+
+ def format_number(self, value):
+ """
+ Formats a number into a string with the requisite number of digits and
+ decimal places.
+ """
+ num_chars = self.max_digits
+ # Allow for a decimal point
+ if self.decimal_places > 0:
+ num_chars += 1
+ # Allow for a minus sign
+ if value < 0:
+ num_chars += 1
+
+ return "%.*f" % (self.decimal_places, value)
+
+ def get_db_prep_save(self, value):
+ if value is not None:
+ value = self._format(value)
+ return super(DecimalField, self).get_db_prep_save(value)
+
+ def get_db_prep_lookup(self, lookup_type, value):
+ if lookup_type == 'range':
+ value = [self._format(v) for v in value]
+ else:
+ value = self._format(value)
+ return super(DecimalField, self).get_db_prep_lookup(lookup_type, value)
+
+ def get_manipulator_field_objs(self):
+ return [curry(oldforms.DecimalField, max_digits=self.max_digits, decimal_places=self.decimal_places)]
+
+ def formfield(self, **kwargs):
+ defaults = {
+ 'max_digits': self.max_digits,
+ 'decimal_places': self.decimal_places,
+ 'form_class': forms.DecimalField,
+ }
+ defaults.update(kwargs)
+ return super(DecimalField, self).formfield(**defaults)
class EmailField(CharField):
def __init__(self, *args, **kwargs):
@@ -561,8 +650,10 @@ class EmailField(CharField):
def validate(self, field_data, all_data):
validators.isValidEmail(field_data, all_data)
- def formfield(self):
- return forms.EmailField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.EmailField}
+ defaults.update(kwargs)
+ return super(EmailField, self).formfield(**defaults)
class FileField(Field):
def __init__(self, verbose_name=None, name=None, upload_to='', **kwargs):
@@ -610,7 +701,7 @@ class FileField(Field):
setattr(cls, 'get_%s_filename' % self.name, curry(cls._get_FIELD_filename, field=self))
setattr(cls, 'get_%s_url' % self.name, curry(cls._get_FIELD_url, field=self))
setattr(cls, 'get_%s_size' % self.name, curry(cls._get_FIELD_size, field=self))
- setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents: instance._save_FIELD_file(self, filename, raw_contents))
+ setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents, save=True: instance._save_FIELD_file(self, filename, raw_contents, save))
dispatcher.connect(self.delete_file, signal=signals.post_delete, sender=cls)
def delete_file(self, instance):
@@ -628,14 +719,14 @@ class FileField(Field):
def get_manipulator_field_names(self, name_prefix):
return [name_prefix + self.name + '_file', name_prefix + self.name]
- def save_file(self, new_data, new_object, original_object, change, rel):
+ def save_file(self, new_data, new_object, original_object, change, rel, save=True):
upload_field_name = self.get_manipulator_field_names('')[0]
if new_data.get(upload_field_name, False):
func = getattr(new_object, 'save_%s_file' % self.name)
if rel:
- func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"])
+ func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"], save)
else:
- func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"])
+ func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"], save)
def get_directory_name(self):
return os.path.normpath(datetime.datetime.now().strftime(self.upload_to))
@@ -655,12 +746,14 @@ class FilePathField(Field):
class FloatField(Field):
empty_strings_allowed = False
- def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs):
- self.max_digits, self.decimal_places = max_digits, decimal_places
- Field.__init__(self, verbose_name, name, **kwargs)
def get_manipulator_field_objs(self):
- return [curry(oldforms.FloatField, max_digits=self.max_digits, decimal_places=self.decimal_places)]
+ return [oldforms.FloatField]
+
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.FloatField}
+ defaults.update(kwargs)
+ return super(FloatField, self).formfield(**defaults)
class ImageField(FileField):
def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs):
@@ -679,12 +772,12 @@ class ImageField(FileField):
if not self.height_field:
setattr(cls, 'get_%s_height' % self.name, curry(cls._get_FIELD_height, field=self))
- def save_file(self, new_data, new_object, original_object, change, rel):
- FileField.save_file(self, new_data, new_object, original_object, change, rel)
+ def save_file(self, new_data, new_object, original_object, change, rel, save=True):
+ FileField.save_file(self, new_data, new_object, original_object, change, rel, save)
# If the image has height and/or width field(s) and they haven't
# changed, set the width and/or height field(s) back to their original
# values.
- if change and (self.width_field or self.height_field):
+ if change and (self.width_field or self.height_field) and save:
if self.width_field:
setattr(new_object, self.width_field, getattr(original_object, self.width_field))
if self.height_field:
@@ -696,8 +789,10 @@ class IntegerField(Field):
def get_manipulator_field_objs(self):
return [oldforms.IntegerField]
- def formfield(self):
- return forms.IntegerField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.IntegerField}
+ defaults.update(kwargs)
+ return super(IntegerField, self).formfield(**defaults)
class IPAddressField(Field):
def __init__(self, *args, **kwargs):
@@ -715,6 +810,13 @@ class NullBooleanField(Field):
kwargs['null'] = True
Field.__init__(self, *args, **kwargs)
+ def to_python(self, value):
+ if value in (None, True, False): return value
+ if value in ('None'): return None
+ if value in ('t', 'True', '1'): return True
+ if value in ('f', 'False', '0'): return False
+ raise validators.ValidationError, gettext("This value must be either None, True or False.")
+
def get_manipulator_field_objs(self):
return [oldforms.NullBooleanField]
@@ -725,6 +827,12 @@ class PhoneNumberField(IntegerField):
def validate(self, field_data, all_data):
validators.isValidPhone(field_data, all_data)
+ def formfield(self, **kwargs):
+ from django.contrib.localflavor.us.forms import USPhoneNumberField
+ defaults = {'form_class': USPhoneNumberField}
+ defaults.update(kwargs)
+ return super(PhoneNumberField, self).formfield(**defaults)
+
class PositiveIntegerField(IntegerField):
def get_manipulator_field_objs(self):
return [oldforms.PositiveIntegerField]
@@ -738,7 +846,7 @@ class SlugField(Field):
kwargs['maxlength'] = kwargs.get('maxlength', 50)
kwargs.setdefault('validator_list', []).append(validators.isSlug)
# Set db_index=True unless it's been set manually.
- if not kwargs.has_key('db_index'):
+ if 'db_index' not in kwargs:
kwargs['db_index'] = True
Field.__init__(self, *args, **kwargs)
@@ -753,6 +861,11 @@ class TextField(Field):
def get_manipulator_field_objs(self):
return [oldforms.LargeTextField]
+ def formfield(self, **kwargs):
+ defaults = {'widget': forms.Textarea}
+ defaults.update(kwargs)
+ return super(TextField, self).formfield(**defaults)
+
class TimeField(Field):
empty_strings_allowed = False
def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs):
@@ -781,7 +894,7 @@ class TimeField(Field):
if value is not None:
# MySQL will throw a warning if microseconds are given, because it
# doesn't support microseconds.
- if settings.DATABASE_ENGINE == 'mysql':
+ if settings.DATABASE_ENGINE == 'mysql' and hasattr(value, 'microsecond'):
value = value.replace(microsecond=0)
value = str(value)
return Field.get_db_prep_save(self, value)
@@ -793,26 +906,40 @@ class TimeField(Field):
val = self._get_val_from_obj(obj)
return {self.attname: (val is not None and val.strftime("%H:%M:%S") or '')}
- def formfield(self):
- return forms.TimeField(required=not self.blank, label=capfirst(self.verbose_name))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.TimeField}
+ defaults.update(kwargs)
+ return super(TimeField, self).formfield(**defaults)
-class URLField(Field):
+class URLField(CharField):
def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs):
+ kwargs['maxlength'] = kwargs.get('maxlength', 200)
if verify_exists:
kwargs.setdefault('validator_list', []).append(validators.isExistingURL)
self.verify_exists = verify_exists
- Field.__init__(self, verbose_name, name, **kwargs)
+ CharField.__init__(self, verbose_name, name, **kwargs)
def get_manipulator_field_objs(self):
return [oldforms.URLField]
- def formfield(self):
- return forms.URLField(required=not self.blank, verify_exists=self.verify_exists, label=capfirst(self.verbose_name))
+ def get_internal_type(self):
+ return "CharField"
+
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.URLField, 'verify_exists': self.verify_exists}
+ defaults.update(kwargs)
+ return super(URLField, self).formfield(**defaults)
class USStateField(Field):
def get_manipulator_field_objs(self):
return [oldforms.USStateField]
+ def formfield(self, **kwargs):
+ from django.contrib.localflavor.us.forms import USStateSelect
+ defaults = {'widget': USStateSelect}
+ defaults.update(kwargs)
+ return super(USStateField, self).formfield(**defaults)
+
class XMLField(TextField):
def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs):
self.schema_path = schema_path
diff --git a/django/db/models/fields/generic.py b/django/db/models/fields/generic.py
deleted file mode 100644
index 1ad8346e42..0000000000
--- a/django/db/models/fields/generic.py
+++ /dev/null
@@ -1,259 +0,0 @@
-"""
-Classes allowing "generic" relations through ContentType and object-id fields.
-"""
-
-from django import oldforms
-from django.core.exceptions import ObjectDoesNotExist
-from django.db import backend
-from django.db.models import signals
-from django.db.models.fields.related import RelatedField, Field, ManyToManyRel
-from django.db.models.loading import get_model
-from django.dispatch import dispatcher
-from django.utils.functional import curry
-
-class GenericForeignKey(object):
- """
- Provides a generic relation to any object through content-type/object-id
- fields.
- """
-
- def __init__(self, ct_field="content_type", fk_field="object_id"):
- self.ct_field = ct_field
- self.fk_field = fk_field
-
- def contribute_to_class(self, cls, name):
- # Make sure the fields exist (these raise FieldDoesNotExist,
- # which is a fine error to raise here)
- self.name = name
- self.model = cls
- self.cache_attr = "_%s_cache" % name
-
- # For some reason I don't totally understand, using weakrefs here doesn't work.
- dispatcher.connect(self.instance_pre_init, signal=signals.pre_init, sender=cls, weak=False)
-
- # Connect myself as the descriptor for this field
- setattr(cls, name, self)
-
- def instance_pre_init(self, signal, sender, args, kwargs):
- # Handle initalizing an object with the generic FK instaed of
- # content-type/object-id fields.
- if kwargs.has_key(self.name):
- value = kwargs.pop(self.name)
- kwargs[self.ct_field] = self.get_content_type(value)
- kwargs[self.fk_field] = value._get_pk_val()
-
- def get_content_type(self, obj):
- # Convenience function using get_model avoids a circular import when using this model
- ContentType = get_model("contenttypes", "contenttype")
- return ContentType.objects.get_for_model(obj)
-
- def __get__(self, instance, instance_type=None):
- if instance is None:
- raise AttributeError, "%s must be accessed via instance" % self.name
-
- try:
- return getattr(instance, self.cache_attr)
- except AttributeError:
- rel_obj = None
- ct = getattr(instance, self.ct_field)
- if ct:
- try:
- rel_obj = ct.get_object_for_this_type(pk=getattr(instance, self.fk_field))
- except ObjectDoesNotExist:
- pass
- setattr(instance, self.cache_attr, rel_obj)
- return rel_obj
-
- def __set__(self, instance, value):
- if instance is None:
- raise AttributeError, "%s must be accessed via instance" % self.related.opts.object_name
-
- ct = None
- fk = None
- if value is not None:
- ct = self.get_content_type(value)
- fk = value._get_pk_val()
-
- setattr(instance, self.ct_field, ct)
- setattr(instance, self.fk_field, fk)
- setattr(instance, self.cache_attr, value)
-
-class GenericRelation(RelatedField, Field):
- """Provides an accessor to generic related objects (i.e. comments)"""
-
- def __init__(self, to, **kwargs):
- kwargs['verbose_name'] = kwargs.get('verbose_name', None)
- kwargs['rel'] = GenericRel(to,
- related_name=kwargs.pop('related_name', None),
- limit_choices_to=kwargs.pop('limit_choices_to', None),
- symmetrical=kwargs.pop('symmetrical', True))
-
- # Override content-type/object-id field names on the related class
- self.object_id_field_name = kwargs.pop("object_id_field", "object_id")
- self.content_type_field_name = kwargs.pop("content_type_field", "content_type")
-
- kwargs['blank'] = True
- kwargs['editable'] = False
- Field.__init__(self, **kwargs)
-
- def get_manipulator_field_objs(self):
- choices = self.get_choices_default()
- return [curry(oldforms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)]
-
- def get_choices_default(self):
- return Field.get_choices(self, include_blank=False)
-
- def flatten_data(self, follow, obj = None):
- new_data = {}
- if obj:
- instance_ids = [instance._get_pk_val() for instance in getattr(obj, self.name).all()]
- new_data[self.name] = instance_ids
- return new_data
-
- def m2m_db_table(self):
- return self.rel.to._meta.db_table
-
- def m2m_column_name(self):
- return self.object_id_field_name
-
- def m2m_reverse_name(self):
- return self.object_id_field_name
-
- def contribute_to_class(self, cls, name):
- super(GenericRelation, self).contribute_to_class(cls, name)
-
- # Save a reference to which model this class is on for future use
- self.model = cls
-
- # Add the descriptor for the m2m relation
- setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self))
-
- def contribute_to_related_class(self, cls, related):
- pass
-
- def set_attributes_from_rel(self):
- pass
-
- def get_internal_type(self):
- return "ManyToManyField"
-
-class ReverseGenericRelatedObjectsDescriptor(object):
- """
- This class provides the functionality that makes the related-object
- managers available as attributes on a model class, for fields that have
- multiple "remote" values and have a GenericRelation defined in their model
- (rather than having another model pointed *at* them). In the example
- "article.publications", the publications attribute is a
- ReverseGenericRelatedObjectsDescriptor instance.
- """
- def __init__(self, field):
- self.field = field
-
- def __get__(self, instance, instance_type=None):
- if instance is None:
- raise AttributeError, "Manager must be accessed via instance"
-
- # This import is done here to avoid circular import importing this module
- from django.contrib.contenttypes.models import ContentType
-
- # Dynamically create a class that subclasses the related model's
- # default manager.
- rel_model = self.field.rel.to
- superclass = rel_model._default_manager.__class__
- RelatedManager = create_generic_related_manager(superclass)
-
- manager = RelatedManager(
- model = rel_model,
- instance = instance,
- symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model),
- join_table = backend.quote_name(self.field.m2m_db_table()),
- source_col_name = backend.quote_name(self.field.m2m_column_name()),
- target_col_name = backend.quote_name(self.field.m2m_reverse_name()),
- content_type = ContentType.objects.get_for_model(self.field.model),
- content_type_field_name = self.field.content_type_field_name,
- object_id_field_name = self.field.object_id_field_name
- )
-
- return manager
-
- def __set__(self, instance, value):
- if instance is None:
- raise AttributeError, "Manager must be accessed via instance"
-
- manager = self.__get__(instance)
- manager.clear()
- for obj in value:
- manager.add(obj)
-
-def create_generic_related_manager(superclass):
- """
- Factory function for a manager that subclasses 'superclass' (which is a
- Manager) and adds behavior for generic related objects.
- """
-
- class GenericRelatedObjectManager(superclass):
- def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None,
- join_table=None, source_col_name=None, target_col_name=None, content_type=None,
- content_type_field_name=None, object_id_field_name=None):
-
- super(GenericRelatedObjectManager, self).__init__()
- self.core_filters = core_filters or {}
- self.model = model
- self.content_type = content_type
- self.symmetrical = symmetrical
- self.instance = instance
- self.join_table = join_table
- self.join_table = model._meta.db_table
- self.source_col_name = source_col_name
- self.target_col_name = target_col_name
- self.content_type_field_name = content_type_field_name
- self.object_id_field_name = object_id_field_name
- self.pk_val = self.instance._get_pk_val()
-
- def get_query_set(self):
- query = {
- '%s__pk' % self.content_type_field_name : self.content_type.id,
- '%s__exact' % self.object_id_field_name : self.pk_val,
- }
- return superclass.get_query_set(self).filter(**query)
-
- def add(self, *objs):
- for obj in objs:
- setattr(obj, self.content_type_field_name, self.content_type)
- setattr(obj, self.object_id_field_name, self.pk_val)
- obj.save()
- add.alters_data = True
-
- def remove(self, *objs):
- for obj in objs:
- obj.delete()
- remove.alters_data = True
-
- def clear(self):
- for obj in self.all():
- obj.delete()
- clear.alters_data = True
-
- def create(self, **kwargs):
- kwargs[self.content_type_field_name] = self.content_type
- kwargs[self.object_id_field_name] = self.pk_val
- obj = self.model(**kwargs)
- obj.save()
- return obj
- create.alters_data = True
-
- return GenericRelatedObjectManager
-
-class GenericRel(ManyToManyRel):
- def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True):
- self.to = to
- self.num_in_admin = 0
- self.related_name = related_name
- self.filter_interface = None
- self.limit_choices_to = limit_choices_to or {}
- self.edit_inline = False
- self.raw_id_admin = False
- self.symmetrical = symmetrical
- self.multiple = True
- assert not (self.raw_id_admin and self.filter_interface), \
- "Generic relations may not use both raw_id_admin and filter_interface"
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
index 797ef05be1..0739d0461a 100644
--- a/django/db/models/fields/related.py
+++ b/django/db/models/fields/related.py
@@ -2,10 +2,12 @@ from django.db import backend, transaction
from django.db.models import signals, get_model
from django.db.models.fields import AutoField, Field, IntegerField, get_ul_class
from django.db.models.related import RelatedObject
+from django.utils.text import capfirst
from django.utils.translation import gettext_lazy, string_concat, ngettext
from django.utils.functional import curry
from django.core import validators
from django import oldforms
+from django import newforms as forms
from django.dispatch import dispatcher
# For Python 2.3
@@ -314,18 +316,20 @@ def create_many_related_manager(superclass):
# join_table: name of the m2m link table
# source_col_name: the PK colname in join_table for the source object
# target_col_name: the PK colname in join_table for the target object
- # *objs - objects to add
+ # *objs - objects to add. Either object instances, or primary keys of object instances.
from django.db import connection
# If there aren't any objects, there is nothing to do.
if objs:
# Check that all the objects are of the right type
+ new_ids = set()
for obj in objs:
- if not isinstance(obj, self.model):
- raise ValueError, "objects to add() must be %s instances" % self.model._meta.object_name
+ if isinstance(obj, self.model):
+ new_ids.add(obj._get_pk_val())
+ else:
+ new_ids.add(obj)
# Add the newly created or already existing objects to the join table.
# First find out which items are already added, to avoid adding them twice
- new_ids = set([obj._get_pk_val() for obj in objs])
cursor = connection.cursor()
cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \
(target_col_name, self.join_table, source_col_name,
@@ -352,14 +356,16 @@ def create_many_related_manager(superclass):
# If there aren't any objects, there is nothing to do.
if objs:
# Check that all the objects are of the right type
+ old_ids = set()
for obj in objs:
- if not isinstance(obj, self.model):
- raise ValueError, "objects to remove() must be %s instances" % self.model._meta.object_name
+ if isinstance(obj, self.model):
+ old_ids.add(obj._get_pk_val())
+ else:
+ old_ids.add(obj)
# Remove the specified objects from the join table
- old_ids = set([obj._get_pk_val() for obj in objs])
cursor = connection.cursor()
cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s IN (%s)" % \
- (self.join_table, source_col_name,
+ (self.join_table, source_col_name,
target_col_name, ",".join(['%s'] * len(old_ids))),
[self._pk_val] + list(old_ids))
transaction.commit_unless_managed()
@@ -468,7 +474,7 @@ class ForeignKey(RelatedField, Field):
to_field = to_field or to._meta.pk.name
kwargs['verbose_name'] = kwargs.get('verbose_name', '')
- if kwargs.has_key('edit_inline_type'):
+ if 'edit_inline_type' in kwargs:
import warnings
warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.")
kwargs['edit_inline'] = kwargs.pop('edit_inline_type')
@@ -546,6 +552,11 @@ class ForeignKey(RelatedField, Field):
def contribute_to_related_class(self, cls, related):
setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related))
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()}
+ defaults.update(kwargs)
+ return super(ForeignKey, self).formfield(**defaults)
+
class OneToOneField(RelatedField, IntegerField):
def __init__(self, to, to_field=None, **kwargs):
try:
@@ -556,7 +567,7 @@ class OneToOneField(RelatedField, IntegerField):
to_field = to_field or to._meta.pk.name
kwargs['verbose_name'] = kwargs.get('verbose_name', '')
- if kwargs.has_key('edit_inline_type'):
+ if 'edit_inline_type' in kwargs:
import warnings
warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.")
kwargs['edit_inline'] = kwargs.pop('edit_inline_type')
@@ -607,6 +618,11 @@ class OneToOneField(RelatedField, IntegerField):
if not cls._meta.one_to_one_field:
cls._meta.one_to_one_field = self
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()}
+ defaults.update(kwargs)
+ return super(OneToOneField, self).formfield(**defaults)
+
class ManyToManyField(RelatedField, Field):
def __init__(self, to, **kwargs):
kwargs['verbose_name'] = kwargs.get('verbose_name', None)
@@ -617,6 +633,7 @@ class ManyToManyField(RelatedField, Field):
limit_choices_to=kwargs.pop('limit_choices_to', None),
raw_id_admin=kwargs.pop('raw_id_admin', False),
symmetrical=kwargs.pop('symmetrical', True))
+ self.db_table = kwargs.pop('db_table', None)
if kwargs["rel"].raw_id_admin:
kwargs.setdefault("validator_list", []).append(self.isValidIDList)
Field.__init__(self, **kwargs)
@@ -639,7 +656,10 @@ class ManyToManyField(RelatedField, Field):
def _get_m2m_db_table(self, opts):
"Function that can be curried to provide the m2m table name for this relation"
- return '%s_%s' % (opts.db_table, self.name)
+ if self.db_table:
+ return self.db_table
+ else:
+ return '%s_%s' % (opts.db_table, self.name)
def _get_m2m_column_name(self, related):
"Function that can be curried to provide the source column name for the m2m table"
@@ -713,6 +733,19 @@ class ManyToManyField(RelatedField, Field):
def set_attributes_from_rel(self):
pass
+ def value_from_object(self, obj):
+ "Returns the value of this field in the given model instance."
+ return getattr(obj, self.attname).all()
+
+ def formfield(self, **kwargs):
+ defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.all()}
+ defaults.update(kwargs)
+ # If initial is passed in, it's a list of related objects, but the
+ # MultipleChoiceField takes a list of IDs.
+ if defaults.get('initial') is not None:
+ defaults['initial'] = [i._get_pk_val() for i in defaults['initial']]
+ return super(ManyToManyField, self).formfield(**defaults)
+
class ManyToOneRel(object):
def __init__(self, to, field_name, num_in_admin=3, min_num_in_admin=None,
max_num_in_admin=None, num_extra_on_change=1, edit_inline=False,
diff --git a/django/db/models/loading.py b/django/db/models/loading.py
index f4aff2438b..224f5e8451 100644
--- a/django/db/models/loading.py
+++ b/django/db/models/loading.py
@@ -103,7 +103,7 @@ def register_models(app_label, *models):
# in the _app_models dictionary
model_name = model._meta.object_name.lower()
model_dict = _app_models.setdefault(app_label, {})
- if model_dict.has_key(model_name):
+ if model_name in model_dict:
# The same model may be imported via different paths (e.g.
# appname.models and project.appname.models). We use the source
# filename as a means to detect identity.
diff --git a/django/db/models/manager.py b/django/db/models/manager.py
index 6005874516..b60eed262a 100644
--- a/django/db/models/manager.py
+++ b/django/db/models/manager.py
@@ -1,4 +1,4 @@
-from django.db.models.query import QuerySet
+from django.db.models.query import QuerySet, EmptyQuerySet
from django.dispatch import dispatcher
from django.db.models import signals
from django.db.models.fields import FieldDoesNotExist
@@ -41,12 +41,18 @@ class Manager(object):
#######################
# PROXIES TO QUERYSET #
#######################
+
+ def get_empty_query_set(self):
+ return EmptyQuerySet(self.model)
def get_query_set(self):
"""Returns a new QuerySet object. Subclasses can override this method
to easily customise the behaviour of the Manager.
"""
return QuerySet(self.model)
+
+ def none(self):
+ return self.get_empty_query_set()
def all(self):
return self.get_query_set()
diff --git a/django/db/models/manipulators.py b/django/db/models/manipulators.py
index e9dfa7037c..d5fc5f725e 100644
--- a/django/db/models/manipulators.py
+++ b/django/db/models/manipulators.py
@@ -96,14 +96,16 @@ class AutomaticManipulator(oldforms.Manipulator):
if self.change:
params[self.opts.pk.attname] = self.obj_key
- # First, save the basic object itself.
+ # First, create the basic object itself.
new_object = self.model(**params)
- new_object.save()
- # Now that the object's been saved, save any uploaded files.
+ # Now that the object's been created, save any uploaded files.
for f in self.opts.fields:
if isinstance(f, FileField):
- f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False)
+ f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False, save=False)
+
+ # Now save the object
+ new_object.save()
# Calculate which primary fields have changed.
if self.change:
diff --git a/django/db/models/options.py b/django/db/models/options.py
index ee253ff451..556168e7d0 100644
--- a/django/db/models/options.py
+++ b/django/db/models/options.py
@@ -85,6 +85,7 @@ class Options(object):
self.fields.insert(bisect(self.fields, field), field)
if not self.pk and field.primary_key:
self.pk = field
+ field.serialize = False
def __repr__(self):
return '<Options for %s>' % self.object_name
@@ -140,7 +141,7 @@ class Options(object):
def get_follow(self, override=None):
follow = {}
for f in self.fields + self.many_to_many + self.get_all_related_objects():
- if override and override.has_key(f.name):
+ if override and f.name in override:
child_override = override[f.name]
else:
child_override = None
@@ -182,7 +183,7 @@ class Options(object):
# TODO: follow
if not hasattr(self, '_field_types'):
self._field_types = {}
- if not self._field_types.has_key(field_type):
+ if field_type not in self._field_types:
try:
# First check self.fields.
for f in self.fields:
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 53ed63ae5b..a6e702be18 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -1,8 +1,9 @@
from django.db import backend, connection, transaction
from django.db.models.fields import DateField, FieldDoesNotExist
-from django.db.models import signals
+from django.db.models import signals, loading
from django.dispatch import dispatcher
from django.utils.datastructures import SortedDict
+from django.contrib.contenttypes import generic
import operator
import re
@@ -25,6 +26,9 @@ QUERY_TERMS = (
# Larger values are slightly faster at the expense of more storage space.
GET_ITERATOR_CHUNK_SIZE = 100
+class EmptyResultSet(Exception):
+ pass
+
####################
# HELPER FUNCTIONS #
####################
@@ -80,6 +84,7 @@ class QuerySet(object):
self._filters = Q()
self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering.
self._select_related = False # Whether to fill cache for related objects.
+ self._max_related_depth = 0 # Maximum "depth" for select_related
self._distinct = False # Whether the query should use SELECT DISTINCT.
self._select = {} # Dictionary of attname -> SQL.
self._where = [] # List of extra WHERE clauses to use.
@@ -104,6 +109,8 @@ class QuerySet(object):
def __getitem__(self, k):
"Retrieve an item or slice from the set of results."
+ if not isinstance(k, (slice, int)):
+ raise TypeError
assert (not isinstance(k, slice) and (k >= 0)) \
or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \
"Negative indexing is not supported."
@@ -163,12 +170,16 @@ class QuerySet(object):
def iterator(self):
"Performs the SELECT database lookup of this QuerySet."
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
# self._select is a dictionary, and dictionaries' key order is
# undefined, so we convert it to a list of tuples.
extra_select = self._select.items()
cursor = connection.cursor()
- select, sql, params = self._get_sql_clause()
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
fill_cache = self._select_related
index_end = len(self.model._meta.fields)
@@ -178,7 +189,8 @@ class QuerySet(object):
raise StopIteration
for row in rows:
if fill_cache:
- obj, index_end = get_cached_row(self.model, row, 0)
+ obj, index_end = get_cached_row(klass=self.model, row=row,
+ index_start=0, max_depth=self._max_related_depth)
else:
obj = self.model(*row[:index_end])
for i, k in enumerate(extra_select):
@@ -186,13 +198,31 @@ class QuerySet(object):
yield obj
def count(self):
- "Performs a SELECT COUNT() and returns the number of records as an integer."
+ """
+ Performs a SELECT COUNT() and returns the number of records as an
+ integer.
+
+ If the queryset is already cached (i.e. self._result_cache is set) this
+ simply returns the length of the cached results set to avoid multiple
+ SELECT COUNT(*) calls.
+ """
+ if self._result_cache is not None:
+ return len(self._result_cache)
+
counter = self._clone()
counter._order_by = ()
+ counter._select_related = False
+
+ offset = counter._offset
+ limit = counter._limit
counter._offset = None
counter._limit = None
- counter._select_related = False
- select, sql, params = counter._get_sql_clause()
+
+ try:
+ select, sql, params = counter._get_sql_clause()
+ except EmptyResultSet:
+ return 0
+
cursor = connection.cursor()
if self._distinct:
id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
@@ -200,7 +230,16 @@ class QuerySet(object):
cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params)
else:
cursor.execute("SELECT COUNT(*)" + sql, params)
- return cursor.fetchone()[0]
+ count = cursor.fetchone()[0]
+
+ # Apply any offset and limit constraints manually, since using LIMIT or
+ # OFFSET in SQL doesn't change the output of COUNT.
+ if offset:
+ count = max(0, count - offset)
+ if limit:
+ count = min(limit, count)
+
+ return count
def get(self, *args, **kwargs):
"Performs the SELECT and returns a single object matching the given keyword arguments."
@@ -359,9 +398,9 @@ class QuerySet(object):
else:
return self._filter_or_exclude(None, **filter_obj)
- def select_related(self, true_or_false=True):
+ def select_related(self, true_or_false=True, depth=0):
"Returns a new QuerySet instance with '_select_related' modified."
- return self._clone(_select_related=true_or_false)
+ return self._clone(_select_related=true_or_false, _max_related_depth=depth)
def order_by(self, *field_names):
"Returns a new QuerySet instance with the ordering changed."
@@ -395,6 +434,7 @@ class QuerySet(object):
c._filters = self._filters
c._order_by = self._order_by
c._select_related = self._select_related
+ c._max_related_depth = self._max_related_depth
c._distinct = self._distinct
c._select = self._select.copy()
c._where = self._where[:]
@@ -448,7 +488,10 @@ class QuerySet(object):
# Add additional tables and WHERE clauses based on select_related.
if self._select_related:
- fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
+ fill_table_cache(opts, select, tables, where,
+ old_prefix=opts.db_table,
+ cache_tables_seen=[opts.db_table],
+ max_depth=self._max_related_depth)
# Add any additional SELECTs.
if self._select:
@@ -509,22 +552,42 @@ class QuerySet(object):
return select, " ".join(sql), params
class ValuesQuerySet(QuerySet):
- def iterator(self):
- # select_related and select aren't supported in values().
+ def __init__(self, *args, **kwargs):
+ super(ValuesQuerySet, self).__init__(*args, **kwargs)
+ # select_related isn't supported in values().
self._select_related = False
- self._select = {}
+
+ def iterator(self):
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
# self._fields is a list of field names to fetch.
if self._fields:
- columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ #columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ if not self._select:
+ columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields]
+ else:
+ columns = []
+ for f in self._fields:
+ if f in [field.name for field in self.model._meta.fields]:
+ columns.append( self.model._meta.get_field(f, many_to_many=False).column )
+ elif not self._select.has_key( f ):
+ raise FieldDoesNotExist, '%s has no field named %r' % ( self.model._meta.object_name, f )
+
field_names = self._fields
else: # Default to all fields.
columns = [f.column for f in self.model._meta.fields]
field_names = [f.attname for f in self.model._meta.fields]
- cursor = connection.cursor()
- select, sql, params = self._get_sql_clause()
select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns]
+
+ # Add any additional SELECTs.
+ if self._select:
+ select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), backend.quote_name(s[0])) for s in self._select.items()])
+
+ cursor = connection.cursor()
cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params)
while 1:
rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)
@@ -545,7 +608,12 @@ class DateQuerySet(QuerySet):
if self._field.null:
self._where.append('%s.%s IS NOT NULL' % \
(backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column)))
- select, sql, params = self._get_sql_clause()
+
+ try:
+ select, sql, params = self._get_sql_clause()
+ except EmptyResultSet:
+ raise StopIteration
+
sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \
(backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table),
backend.quote_name(self._field.column))), sql, self._order)
@@ -563,6 +631,25 @@ class DateQuerySet(QuerySet):
c._order = self._order
return c
+class EmptyQuerySet(QuerySet):
+ def __init__(self, model=None):
+ super(EmptyQuerySet, self).__init__(model)
+ self._result_cache = []
+
+ def count(self):
+ return 0
+
+ def delete(self):
+ pass
+
+ def _clone(self, klass=None, **kwargs):
+ c = super(EmptyQuerySet, self)._clone(klass, **kwargs)
+ c._result_cache = []
+ return c
+
+ def _get_sql_clause(self):
+ raise EmptyResultSet
+
class QOperator(object):
"Base class for QAnd and QOr"
def __init__(self, *args):
@@ -571,10 +658,14 @@ class QOperator(object):
def get_sql(self, opts):
joins, where, params = SortedDict(), [], []
for val in self.args:
- joins2, where2, params2 = val.get_sql(opts)
- joins.update(joins2)
- where.extend(where2)
- params.extend(params2)
+ try:
+ joins2, where2, params2 = val.get_sql(opts)
+ joins.update(joins2)
+ where.extend(where2)
+ params.extend(params2)
+ except EmptyResultSet:
+ if not isinstance(self, QOr):
+ raise EmptyResultSet
if where:
return joins, ['(%s)' % self.operator.join(where)], params
return joins, [], params
@@ -628,8 +719,11 @@ class QNot(Q):
self.q = q
def get_sql(self, opts):
- joins, where, params = self.q.get_sql(opts)
- where2 = ['(NOT (%s))' % " AND ".join(where)]
+ try:
+ joins, where, params = self.q.get_sql(opts)
+ where2 = ['(NOT (%s))' % " AND ".join(where)]
+ except EmptyResultSet:
+ return SortedDict(), [], []
return joins, where2, params
def get_where_clause(lookup_type, table_prefix, field_name, value):
@@ -641,10 +735,14 @@ def get_where_clause(lookup_type, table_prefix, field_name, value):
except KeyError:
pass
if lookup_type == 'in':
- return '%s%s IN (%s)' % (table_prefix, field_name, ','.join(['%s' for v in value]))
- elif lookup_type == 'range':
+ in_string = ','.join(['%s' for id in value])
+ if in_string:
+ return '%s%s IN (%s)' % (table_prefix, field_name, in_string)
+ else:
+ raise EmptyResultSet
+ elif lookup_type in ('range', 'year'):
return '%s%s BETWEEN %%s AND %%s' % (table_prefix, field_name)
- elif lookup_type in ('year', 'month', 'day'):
+ elif lookup_type in ('month', 'day'):
return "%s = %%s" % backend.get_date_extract_sql(lookup_type, table_prefix + field_name)
elif lookup_type == 'isnull':
return "%s%s IS %sNULL" % (table_prefix, field_name, (not value and 'NOT ' or ''))
@@ -652,21 +750,33 @@ def get_where_clause(lookup_type, table_prefix, field_name, value):
return backend.get_fulltext_search_sql(table_prefix + field_name)
raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
-def get_cached_row(klass, row, index_start):
- "Helper function that recursively returns an object with cache filled"
+def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0):
+ """Helper function that recursively returns an object with cache filled"""
+
+ # If we've got a max_depth set and we've exceeded that depth, bail now.
+ if max_depth and cur_depth > max_depth:
+ return None
+
index_end = index_start + len(klass._meta.fields)
obj = klass(*row[index_start:index_end])
for f in klass._meta.fields:
if f.rel and not f.null:
- rel_obj, index_end = get_cached_row(f.rel.to, row, index_end)
- setattr(obj, f.get_cache_name(), rel_obj)
+ cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1)
+ if cached_row:
+ rel_obj, index_end = cached_row
+ setattr(obj, f.get_cache_name(), rel_obj)
return obj, index_end
-def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen):
+def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0):
"""
Helper function that recursively populates the select, tables and where (in
place) for select_related queries.
"""
+
+ # If we've got a max_depth set and we've exceeded that depth, bail now.
+ if max_depth and cur_depth > max_depth:
+ return None
+
qn = backend.quote_name
for f in opts.fields:
if f.rel and not f.null:
@@ -681,12 +791,12 @@ def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen)
where.append('%s.%s = %s.%s' % \
(qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
- fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen)
+ fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1)
def parse_lookup(kwarg_items, opts):
# Helper function that handles converting API kwargs
# (e.g. "name__exact": "tom") to SQL.
- # Returns a tuple of (tables, joins, where, params).
+ # Returns a tuple of (joins, where, params).
# 'joins' is a sorted dictionary describing the tables that must be joined
# to complete the query. The dictionary is sorted because creation order
@@ -725,12 +835,14 @@ def parse_lookup(kwarg_items, opts):
if len(path) < 1:
raise TypeError, "Cannot parse keyword query %r" % kwarg
-
+
if value is None:
# Interpret '__exact=None' as the sql '= NULL'; otherwise, reject
# all uses of None as a query value.
if lookup_type != 'exact':
raise ValueError, "Cannot use None as a query value"
+ elif callable(value):
+ value = value()
joins2, where2, params2 = lookup_inner(path, lookup_type, value, opts, opts.db_table, None)
joins.update(joins2)
@@ -755,6 +867,13 @@ def find_field(name, field_list, related_query):
return None
return matches[0]
+def field_choices(field_list, related_query):
+ if related_query:
+ choices = [f.field.related_query_name() for f in field_list]
+ else:
+ choices = [f.name for f in field_list]
+ return choices
+
def lookup_inner(path, lookup_type, value, opts, table, column):
qn = backend.quote_name
joins, where, params = SortedDict(), [], []
@@ -827,13 +946,23 @@ def lookup_inner(path, lookup_type, value, opts, table, column):
new_opts = field.rel.to._meta
new_column = new_opts.pk.column
join_column = field.column
-
- raise FieldFound
+ raise FieldFound
+ elif path:
+ # For regular fields, if there are still items on the path,
+ # an error has been made. We munge "name" so that the error
+ # properly identifies the cause of the problem.
+ name += LOOKUP_SEPARATOR + path[0]
+ else:
+ raise FieldFound
except FieldFound: # Match found, loop has been shortcut.
pass
else: # No match found.
- raise TypeError, "Cannot resolve keyword '%s' into field" % name
+ choices = field_choices(current_opts.many_to_many, False) + \
+ field_choices(current_opts.get_all_related_many_to_many_objects(), True) + \
+ field_choices(current_opts.get_all_related_objects(), True) + \
+ field_choices(current_opts.fields, False)
+ raise TypeError, "Cannot resolve keyword '%s' into field. Choices are: %s" % (name, ", ".join(choices))
# Check whether an intermediate join is required between current_table
# and new_table.
@@ -926,18 +1055,26 @@ def delete_objects(seen_objs):
pk_list = [pk for pk,instance in seen_objs[cls]]
for related in cls._meta.get_all_related_many_to_many_objects():
- for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
- (qn(related.field.m2m_db_table()),
- qn(related.field.m2m_reverse_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ if not isinstance(related.field, generic.GenericRelation):
+ for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
+ cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
+ (qn(related.field.m2m_db_table()),
+ qn(related.field.m2m_reverse_name()),
+ ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
+ pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
for f in cls._meta.many_to_many:
+ if isinstance(f, generic.GenericRelation):
+ from django.contrib.contenttypes.models import ContentType
+ query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column
+ args_extra = [ContentType.objects.get_for_model(cls).id]
+ else:
+ query_extra = ''
+ args_extra = []
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
- cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \
+ cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \
(qn(f.m2m_db_table()), qn(f.m2m_column_name()),
- ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])),
- pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE])
+ ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra,
+ pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra)
for field in cls._meta.fields:
if field.rel and field.null and field.rel.to in seen_objs:
for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
diff --git a/django/db/models/related.py b/django/db/models/related.py
index ac1ec50ca2..2c1dc5c516 100644
--- a/django/db/models/related.py
+++ b/django/db/models/related.py
@@ -1,7 +1,7 @@
class BoundRelatedObject(object):
def __init__(self, related_object, field_mapping, original):
self.relation = related_object
- self.field_mappings = field_mapping[related_object.opts.module_name]
+ self.field_mappings = field_mapping[related_object.name]
def template_name(self):
raise NotImplementedError
@@ -16,7 +16,7 @@ class RelatedObject(object):
self.opts = model._meta
self.field = field
self.edit_inline = field.rel.edit_inline
- self.name = self.opts.module_name
+ self.name = '%s:%s' % (self.opts.app_label, self.opts.module_name)
self.var_name = self.opts.object_name.lower()
def flatten_data(self, follow, obj=None):
@@ -68,7 +68,10 @@ class RelatedObject(object):
# object
return [attr]
else:
- return [None] * self.field.rel.num_in_admin
+ if self.field.rel.min_num_in_admin:
+ return [None] * max(self.field.rel.num_in_admin, self.field.rel.min_num_in_admin)
+ else:
+ return [None] * self.field.rel.num_in_admin
def get_db_prep_lookup(self, lookup_type, value):
# Defer to the actual field definition for db prep
@@ -101,12 +104,12 @@ class RelatedObject(object):
attr = getattr(manipulator.original_object, self.get_accessor_name())
count = attr.count()
count += self.field.rel.num_extra_on_change
- if self.field.rel.min_num_in_admin:
- count = max(count, self.field.rel.min_num_in_admin)
- if self.field.rel.max_num_in_admin:
- count = min(count, self.field.rel.max_num_in_admin)
else:
count = self.field.rel.num_in_admin
+ if self.field.rel.min_num_in_admin:
+ count = max(count, self.field.rel.min_num_in_admin)
+ if self.field.rel.max_num_in_admin:
+ count = min(count, self.field.rel.max_num_in_admin)
else:
count = 1