diff options
Diffstat (limited to 'django/db/models')
-rw-r--r-- | django/db/models/__init__.py | 25 | ||||
-rw-r--r-- | django/db/models/base.py | 114 | ||||
-rw-r--r-- | django/db/models/fields/__init__.py | 205 | ||||
-rw-r--r-- | django/db/models/fields/generic.py | 259 | ||||
-rw-r--r-- | django/db/models/fields/related.py | 55 | ||||
-rw-r--r-- | django/db/models/loading.py | 2 | ||||
-rw-r--r-- | django/db/models/manager.py | 8 | ||||
-rw-r--r-- | django/db/models/manipulators.py | 10 | ||||
-rw-r--r-- | django/db/models/options.py | 5 | ||||
-rw-r--r-- | django/db/models/query.py | 229 | ||||
-rw-r--r-- | django/db/models/related.py | 17 |
11 files changed, 497 insertions, 432 deletions
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 0308dd047a..6c3abb6b59 100644 --- a/django/db/models/__init__.py +++ b/django/db/models/__init__.py @@ -8,7 +8,6 @@ from django.db.models.manager import Manager from django.db.models.base import Model, AdminOptions from django.db.models.fields import * from django.db.models.fields.related import ForeignKey, OneToOneField, ManyToManyField, ManyToOneRel, ManyToManyRel, OneToOneRel, TABULAR, STACKED -from django.db.models.fields.generic import GenericRelation, GenericRel, GenericForeignKey from django.db.models import signals from django.utils.functional import curry from django.utils.text import capfirst @@ -27,27 +26,3 @@ def permalink(func): viewname = bits[0] return reverse(bits[0], None, *bits[1:3]) return inner - -class LazyDate(object): - """ - Use in limit_choices_to to compare the field to dates calculated at run time - instead of when the model is loaded. For example:: - - ... limit_choices_to = {'date__gt' : models.LazyDate(days=-3)} ... - - which will limit the choices to dates greater than three days ago. - """ - def __init__(self, **kwargs): - self.delta = datetime.timedelta(**kwargs) - - def __str__(self): - return str(self.__get_value__()) - - def __repr__(self): - return "<LazyDate: %s>" % self.delta - - def __get_value__(self): - return (datetime.datetime.now() + self.delta).date() - - def __getattr__(self, attr): - return getattr(self.__get_value__(), attr) diff --git a/django/db/models/base.py b/django/db/models/base.py index 70569a2561..a8e6303e1c 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -13,6 +13,7 @@ from django.dispatch import dispatcher from django.utils.datastructures import SortedDict from django.utils.functional import curry from django.conf import settings +from itertools import izip import types import sys import os @@ -21,8 +22,13 @@ class ModelBase(type): "Metaclass for all models" def __new__(cls, name, bases, attrs): # If this isn't a subclass of Model, don't do anything special. - if not bases or bases == (object,): - return type.__new__(cls, name, bases, attrs) + try: + if not filter(lambda b: issubclass(b, Model), bases): + return super(ModelBase, cls).__new__(cls, name, bases, attrs) + except NameError: + # 'Model' isn't defined yet, meaning we're looking at Django's own + # Model class, defined below. + return super(ModelBase, cls).__new__(cls, name, bases, attrs) # Create the class. new_class = type.__new__(cls, name, bases, {'__module__': attrs.pop('__module__')}) @@ -36,11 +42,11 @@ class ModelBase(type): new_class._meta.parents.append(base) new_class._meta.parents.extend(base._meta.parents) - model_module = sys.modules[new_class.__module__] if getattr(new_class._meta, 'app_label', None) is None: # Figure out the app_label by looking one level up. # For 'django.contrib.sites.models', this would be 'sites'. + model_module = sys.modules[new_class.__module__] new_class._meta.app_label = model_module.__name__.split('.')[-2] # Bail out early if we have already created this class. @@ -63,7 +69,7 @@ class ModelBase(type): if getattr(new_class._meta, 'row_level_permissions', False): from django.contrib.auth.models import RowLevelPermission - gen_rel = django.db.models.GenericRelation(RowLevelPermission, object_id_field="model_id", content_type_field="model_ct") + gen_rel = django.contrib.contenttypes.generic.GenericRelation(RowLevelPermission, object_id_field="model_id", content_type_field="model_ct") new_class.add_to_class("row_level_permissions", gen_rel) new_class._prepare() @@ -95,41 +101,74 @@ class Model(object): def __init__(self, *args, **kwargs): dispatcher.send(signal=signals.pre_init, sender=self.__class__, args=args, kwargs=kwargs) - for f in self._meta.fields: - if isinstance(f.rel, ManyToOneRel): - try: - # Assume object instance was passed in. - rel_obj = kwargs.pop(f.name) - except KeyError: + + # There is a rather weird disparity here; if kwargs, it's set, then args + # overrides it. It should be one or the other; don't duplicate the work + # The reason for the kwargs check is that standard iterator passes in by + # args, and nstantiation for iteration is 33% faster. + args_len = len(args) + if args_len > len(self._meta.fields): + # Daft, but matches old exception sans the err msg. + raise IndexError("Number of args exceeds number of fields") + + fields_iter = iter(self._meta.fields) + if not kwargs: + # The ordering of the izip calls matter - izip throws StopIteration + # when an iter throws it. So if the first iter throws it, the second + # is *not* consumed. We rely on this, so don't change the order + # without changing the logic. + for val, field in izip(args, fields_iter): + setattr(self, field.attname, val) + else: + # Slower, kwargs-ready version. + for val, field in izip(args, fields_iter): + setattr(self, field.attname, val) + kwargs.pop(field.name, None) + # Maintain compatibility with existing calls. + if isinstance(field.rel, ManyToOneRel): + kwargs.pop(field.attname, None) + + # Now we're left with the unprocessed fields that *must* come from + # keywords, or default. + + for field in fields_iter: + if kwargs: + if isinstance(field.rel, ManyToOneRel): try: - # Object instance wasn't passed in -- must be an ID. - val = kwargs.pop(f.attname) + # Assume object instance was passed in. + rel_obj = kwargs.pop(field.name) except KeyError: - val = f.get_default() - else: - # Object instance was passed in. - # Special case: You can pass in "None" for related objects if it's allowed. - if rel_obj is None and f.null: - val = None - else: try: - val = getattr(rel_obj, f.rel.get_related_field().attname) - except AttributeError: - raise TypeError, "Invalid value: %r should be a %s instance, not a %s" % (f.name, f.rel.to, type(rel_obj)) - setattr(self, f.attname, val) + # Object instance wasn't passed in -- must be an ID. + val = kwargs.pop(field.attname) + except KeyError: + val = field.get_default() + else: + # Object instance was passed in. Special case: You can + # pass in "None" for related objects if it's allowed. + if rel_obj is None and field.null: + val = None + else: + try: + val = getattr(rel_obj, field.rel.get_related_field().attname) + except AttributeError: + raise TypeError("Invalid value: %r should be a %s instance, not a %s" % + (field.name, field.rel.to, type(rel_obj))) + else: + val = kwargs.pop(field.attname, field.get_default()) else: - val = kwargs.pop(f.attname, f.get_default()) - setattr(self, f.attname, val) - for prop in kwargs.keys(): - try: - if isinstance(getattr(self.__class__, prop), property): - setattr(self, prop, kwargs.pop(prop)) - except AttributeError: - pass + val = field.get_default() + setattr(self, field.attname, val) + if kwargs: - raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0] - for i, arg in enumerate(args): - setattr(self, self._meta.fields[i].attname, arg) + for prop in kwargs.keys(): + try: + if isinstance(getattr(self.__class__, prop), property): + setattr(self, prop, kwargs.pop(prop)) + except AttributeError: + pass + if kwargs: + raise TypeError, "'%s' is an invalid keyword argument for this function" % kwargs.keys()[0] dispatcher.send(signal=signals.post_init, sender=self.__class__, instance=self) def add_to_class(cls, name, value): @@ -327,7 +366,7 @@ class Model(object): def _get_FIELD_size(self, field): return os.path.getsize(self._get_FIELD_filename(field)) - def _save_FIELD_file(self, field, filename, raw_contents): + def _save_FIELD_file(self, field, filename, raw_contents, save=True): directory = field.get_directory_name() try: # Create the date-based directory if it doesn't exist. os.makedirs(os.path.join(settings.MEDIA_ROOT, directory)) @@ -362,8 +401,9 @@ class Model(object): if field.height_field: setattr(self, field.height_field, height) - # Save the object, because it has changed. - self.save() + # Save the object because it has changed unless save is False + if save: + self.save() _save_FIELD_file.alters_data = True diff --git a/django/db/models/fields/__init__.py b/django/db/models/fields/__init__.py index fe317ac24f..136ce31b8b 100644 --- a/django/db/models/fields/__init__.py +++ b/django/db/models/fields/__init__.py @@ -10,6 +10,10 @@ from django.utils.itercompat import tee from django.utils.text import capfirst from django.utils.translation import gettext, gettext_lazy import datetime, os, time +try: + import decimal +except ImportError: + from django.utils import _decimal as decimal # for Python 2.3 class NOT_PROVIDED: pass @@ -67,7 +71,7 @@ class Field(object): def __init__(self, verbose_name=None, name=None, primary_key=False, maxlength=None, unique=False, blank=False, null=False, db_index=False, - core=False, rel=None, default=NOT_PROVIDED, editable=True, + core=False, rel=None, default=NOT_PROVIDED, editable=True, serialize=True, prepopulate_from=None, unique_for_date=None, unique_for_month=None, unique_for_year=None, validator_list=None, choices=None, radio_admin=None, help_text='', db_column=None): @@ -78,6 +82,7 @@ class Field(object): self.blank, self.null = blank, null self.core, self.rel, self.default = core, rel, default self.editable = editable + self.serialize = serialize self.validator_list = validator_list or [] self.prepopulate_from = prepopulate_from self.unique_for_date, self.unique_for_month = unique_for_date, unique_for_month @@ -164,7 +169,7 @@ class Field(object): def get_db_prep_lookup(self, lookup_type, value): "Returns field's value prepared for database lookup." - if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'year', 'month', 'day', 'search'): + if lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte', 'month', 'day', 'search'): return [value] elif lookup_type in ('range', 'in'): return value @@ -178,7 +183,13 @@ class Field(object): return ["%%%s" % prep_for_like_query(value)] elif lookup_type == 'isnull': return [] - raise TypeError, "Field has invalid lookup: %s" % lookup_type + elif lookup_type == 'year': + try: + value = int(value) + except ValueError: + raise ValueError("The __year lookup type requires an integer argument") + return ['%s-01-01 00:00:00' % value, '%s-12-31 23:59:59.999999' % value] + raise TypeError("Field has invalid lookup: %s" % lookup_type) def has_default(self): "Returns a boolean of whether this field has a default value." @@ -334,10 +345,17 @@ class Field(object): return self._choices choices = property(_get_choices) - def formfield(self): + def formfield(self, form_class=forms.CharField, **kwargs): "Returns a django.newforms.Field instance for this database Field." - # TODO: This is just a temporary default during development. - return forms.CharField(required=not self.blank, label=capfirst(self.verbose_name)) + defaults = {'required': not self.blank, 'label': capfirst(self.verbose_name), 'help_text': self.help_text} + if self.choices: + defaults['widget'] = forms.Select(choices=self.get_choices()) + defaults.update(kwargs) + return form_class(**defaults) + + def value_from_object(self, obj): + "Returns the value of this field in the given model instance." + return getattr(obj, self.attname) class AutoField(Field): empty_strings_allowed = False @@ -375,7 +393,7 @@ class AutoField(Field): super(AutoField, self).contribute_to_class(cls, name) cls._meta.has_auto_field = True - def formfield(self): + def formfield(self, **kwargs): return None class BooleanField(Field): @@ -392,8 +410,10 @@ class BooleanField(Field): def get_manipulator_field_objs(self): return [oldforms.CheckboxField] - def formfield(self): - return forms.BooleanField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.BooleanField} + defaults.update(kwargs) + return super(BooleanField, self).formfield(**defaults) class CharField(Field): def get_manipulator_field_objs(self): @@ -409,8 +429,10 @@ class CharField(Field): raise validators.ValidationError, gettext_lazy("This field cannot be null.") return str(value) - def formfield(self): - return forms.CharField(max_length=self.maxlength, required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'max_length': self.maxlength} + defaults.update(kwargs) + return super(CharField, self).formfield(**defaults) # TODO: Maybe move this into contrib, because it's specialized. class CommaSeparatedIntegerField(CharField): @@ -428,6 +450,8 @@ class DateField(Field): Field.__init__(self, verbose_name, name, **kwargs) def to_python(self, value): + if value is None: + return value if isinstance(value, datetime.datetime): return value.date() if isinstance(value, datetime.date): @@ -479,15 +503,19 @@ class DateField(Field): def get_manipulator_field_objs(self): return [oldforms.DateField] - def flatten_data(self, follow, obj = None): + def flatten_data(self, follow, obj=None): val = self._get_val_from_obj(obj) return {self.attname: (val is not None and val.strftime("%Y-%m-%d") or '')} - def formfield(self): - return forms.DateField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.DateField} + defaults.update(kwargs) + return super(DateField, self).formfield(**defaults) class DateTimeField(DateField): def to_python(self, value): + if value is None: + return value if isinstance(value, datetime.datetime): return value if isinstance(value, datetime.date): @@ -544,8 +572,69 @@ class DateTimeField(DateField): return {date_field: (val is not None and val.strftime("%Y-%m-%d") or ''), time_field: (val is not None and val.strftime("%H:%M:%S") or '')} - def formfield(self): - return forms.DateTimeField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.DateTimeField} + defaults.update(kwargs) + return super(DateTimeField, self).formfield(**defaults) + +class DecimalField(Field): + empty_strings_allowed = False + def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs): + self.max_digits, self.decimal_places = max_digits, decimal_places + Field.__init__(self, verbose_name, name, **kwargs) + + def to_python(self, value): + if value is None: + return value + try: + return decimal.Decimal(value) + except decimal.InvalidOperation: + raise validators.ValidationError, gettext("This value must be a decimal number.") + + def _format(self, value): + if isinstance(value, basestring): + return value + else: + return self.format_number(value) + + def format_number(self, value): + """ + Formats a number into a string with the requisite number of digits and + decimal places. + """ + num_chars = self.max_digits + # Allow for a decimal point + if self.decimal_places > 0: + num_chars += 1 + # Allow for a minus sign + if value < 0: + num_chars += 1 + + return "%.*f" % (self.decimal_places, value) + + def get_db_prep_save(self, value): + if value is not None: + value = self._format(value) + return super(DecimalField, self).get_db_prep_save(value) + + def get_db_prep_lookup(self, lookup_type, value): + if lookup_type == 'range': + value = [self._format(v) for v in value] + else: + value = self._format(value) + return super(DecimalField, self).get_db_prep_lookup(lookup_type, value) + + def get_manipulator_field_objs(self): + return [curry(oldforms.DecimalField, max_digits=self.max_digits, decimal_places=self.decimal_places)] + + def formfield(self, **kwargs): + defaults = { + 'max_digits': self.max_digits, + 'decimal_places': self.decimal_places, + 'form_class': forms.DecimalField, + } + defaults.update(kwargs) + return super(DecimalField, self).formfield(**defaults) class EmailField(CharField): def __init__(self, *args, **kwargs): @@ -561,8 +650,10 @@ class EmailField(CharField): def validate(self, field_data, all_data): validators.isValidEmail(field_data, all_data) - def formfield(self): - return forms.EmailField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.EmailField} + defaults.update(kwargs) + return super(EmailField, self).formfield(**defaults) class FileField(Field): def __init__(self, verbose_name=None, name=None, upload_to='', **kwargs): @@ -610,7 +701,7 @@ class FileField(Field): setattr(cls, 'get_%s_filename' % self.name, curry(cls._get_FIELD_filename, field=self)) setattr(cls, 'get_%s_url' % self.name, curry(cls._get_FIELD_url, field=self)) setattr(cls, 'get_%s_size' % self.name, curry(cls._get_FIELD_size, field=self)) - setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents: instance._save_FIELD_file(self, filename, raw_contents)) + setattr(cls, 'save_%s_file' % self.name, lambda instance, filename, raw_contents, save=True: instance._save_FIELD_file(self, filename, raw_contents, save)) dispatcher.connect(self.delete_file, signal=signals.post_delete, sender=cls) def delete_file(self, instance): @@ -628,14 +719,14 @@ class FileField(Field): def get_manipulator_field_names(self, name_prefix): return [name_prefix + self.name + '_file', name_prefix + self.name] - def save_file(self, new_data, new_object, original_object, change, rel): + def save_file(self, new_data, new_object, original_object, change, rel, save=True): upload_field_name = self.get_manipulator_field_names('')[0] if new_data.get(upload_field_name, False): func = getattr(new_object, 'save_%s_file' % self.name) if rel: - func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"]) + func(new_data[upload_field_name][0]["filename"], new_data[upload_field_name][0]["content"], save) else: - func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"]) + func(new_data[upload_field_name]["filename"], new_data[upload_field_name]["content"], save) def get_directory_name(self): return os.path.normpath(datetime.datetime.now().strftime(self.upload_to)) @@ -655,12 +746,14 @@ class FilePathField(Field): class FloatField(Field): empty_strings_allowed = False - def __init__(self, verbose_name=None, name=None, max_digits=None, decimal_places=None, **kwargs): - self.max_digits, self.decimal_places = max_digits, decimal_places - Field.__init__(self, verbose_name, name, **kwargs) def get_manipulator_field_objs(self): - return [curry(oldforms.FloatField, max_digits=self.max_digits, decimal_places=self.decimal_places)] + return [oldforms.FloatField] + + def formfield(self, **kwargs): + defaults = {'form_class': forms.FloatField} + defaults.update(kwargs) + return super(FloatField, self).formfield(**defaults) class ImageField(FileField): def __init__(self, verbose_name=None, name=None, width_field=None, height_field=None, **kwargs): @@ -679,12 +772,12 @@ class ImageField(FileField): if not self.height_field: setattr(cls, 'get_%s_height' % self.name, curry(cls._get_FIELD_height, field=self)) - def save_file(self, new_data, new_object, original_object, change, rel): - FileField.save_file(self, new_data, new_object, original_object, change, rel) + def save_file(self, new_data, new_object, original_object, change, rel, save=True): + FileField.save_file(self, new_data, new_object, original_object, change, rel, save) # If the image has height and/or width field(s) and they haven't # changed, set the width and/or height field(s) back to their original # values. - if change and (self.width_field or self.height_field): + if change and (self.width_field or self.height_field) and save: if self.width_field: setattr(new_object, self.width_field, getattr(original_object, self.width_field)) if self.height_field: @@ -696,8 +789,10 @@ class IntegerField(Field): def get_manipulator_field_objs(self): return [oldforms.IntegerField] - def formfield(self): - return forms.IntegerField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.IntegerField} + defaults.update(kwargs) + return super(IntegerField, self).formfield(**defaults) class IPAddressField(Field): def __init__(self, *args, **kwargs): @@ -715,6 +810,13 @@ class NullBooleanField(Field): kwargs['null'] = True Field.__init__(self, *args, **kwargs) + def to_python(self, value): + if value in (None, True, False): return value + if value in ('None'): return None + if value in ('t', 'True', '1'): return True + if value in ('f', 'False', '0'): return False + raise validators.ValidationError, gettext("This value must be either None, True or False.") + def get_manipulator_field_objs(self): return [oldforms.NullBooleanField] @@ -725,6 +827,12 @@ class PhoneNumberField(IntegerField): def validate(self, field_data, all_data): validators.isValidPhone(field_data, all_data) + def formfield(self, **kwargs): + from django.contrib.localflavor.us.forms import USPhoneNumberField + defaults = {'form_class': USPhoneNumberField} + defaults.update(kwargs) + return super(PhoneNumberField, self).formfield(**defaults) + class PositiveIntegerField(IntegerField): def get_manipulator_field_objs(self): return [oldforms.PositiveIntegerField] @@ -738,7 +846,7 @@ class SlugField(Field): kwargs['maxlength'] = kwargs.get('maxlength', 50) kwargs.setdefault('validator_list', []).append(validators.isSlug) # Set db_index=True unless it's been set manually. - if not kwargs.has_key('db_index'): + if 'db_index' not in kwargs: kwargs['db_index'] = True Field.__init__(self, *args, **kwargs) @@ -753,6 +861,11 @@ class TextField(Field): def get_manipulator_field_objs(self): return [oldforms.LargeTextField] + def formfield(self, **kwargs): + defaults = {'widget': forms.Textarea} + defaults.update(kwargs) + return super(TextField, self).formfield(**defaults) + class TimeField(Field): empty_strings_allowed = False def __init__(self, verbose_name=None, name=None, auto_now=False, auto_now_add=False, **kwargs): @@ -781,7 +894,7 @@ class TimeField(Field): if value is not None: # MySQL will throw a warning if microseconds are given, because it # doesn't support microseconds. - if settings.DATABASE_ENGINE == 'mysql': + if settings.DATABASE_ENGINE == 'mysql' and hasattr(value, 'microsecond'): value = value.replace(microsecond=0) value = str(value) return Field.get_db_prep_save(self, value) @@ -793,26 +906,40 @@ class TimeField(Field): val = self._get_val_from_obj(obj) return {self.attname: (val is not None and val.strftime("%H:%M:%S") or '')} - def formfield(self): - return forms.TimeField(required=not self.blank, label=capfirst(self.verbose_name)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.TimeField} + defaults.update(kwargs) + return super(TimeField, self).formfield(**defaults) -class URLField(Field): +class URLField(CharField): def __init__(self, verbose_name=None, name=None, verify_exists=True, **kwargs): + kwargs['maxlength'] = kwargs.get('maxlength', 200) if verify_exists: kwargs.setdefault('validator_list', []).append(validators.isExistingURL) self.verify_exists = verify_exists - Field.__init__(self, verbose_name, name, **kwargs) + CharField.__init__(self, verbose_name, name, **kwargs) def get_manipulator_field_objs(self): return [oldforms.URLField] - def formfield(self): - return forms.URLField(required=not self.blank, verify_exists=self.verify_exists, label=capfirst(self.verbose_name)) + def get_internal_type(self): + return "CharField" + + def formfield(self, **kwargs): + defaults = {'form_class': forms.URLField, 'verify_exists': self.verify_exists} + defaults.update(kwargs) + return super(URLField, self).formfield(**defaults) class USStateField(Field): def get_manipulator_field_objs(self): return [oldforms.USStateField] + def formfield(self, **kwargs): + from django.contrib.localflavor.us.forms import USStateSelect + defaults = {'widget': USStateSelect} + defaults.update(kwargs) + return super(USStateField, self).formfield(**defaults) + class XMLField(TextField): def __init__(self, verbose_name=None, name=None, schema_path=None, **kwargs): self.schema_path = schema_path diff --git a/django/db/models/fields/generic.py b/django/db/models/fields/generic.py deleted file mode 100644 index 1ad8346e42..0000000000 --- a/django/db/models/fields/generic.py +++ /dev/null @@ -1,259 +0,0 @@ -""" -Classes allowing "generic" relations through ContentType and object-id fields. -""" - -from django import oldforms -from django.core.exceptions import ObjectDoesNotExist -from django.db import backend -from django.db.models import signals -from django.db.models.fields.related import RelatedField, Field, ManyToManyRel -from django.db.models.loading import get_model -from django.dispatch import dispatcher -from django.utils.functional import curry - -class GenericForeignKey(object): - """ - Provides a generic relation to any object through content-type/object-id - fields. - """ - - def __init__(self, ct_field="content_type", fk_field="object_id"): - self.ct_field = ct_field - self.fk_field = fk_field - - def contribute_to_class(self, cls, name): - # Make sure the fields exist (these raise FieldDoesNotExist, - # which is a fine error to raise here) - self.name = name - self.model = cls - self.cache_attr = "_%s_cache" % name - - # For some reason I don't totally understand, using weakrefs here doesn't work. - dispatcher.connect(self.instance_pre_init, signal=signals.pre_init, sender=cls, weak=False) - - # Connect myself as the descriptor for this field - setattr(cls, name, self) - - def instance_pre_init(self, signal, sender, args, kwargs): - # Handle initalizing an object with the generic FK instaed of - # content-type/object-id fields. - if kwargs.has_key(self.name): - value = kwargs.pop(self.name) - kwargs[self.ct_field] = self.get_content_type(value) - kwargs[self.fk_field] = value._get_pk_val() - - def get_content_type(self, obj): - # Convenience function using get_model avoids a circular import when using this model - ContentType = get_model("contenttypes", "contenttype") - return ContentType.objects.get_for_model(obj) - - def __get__(self, instance, instance_type=None): - if instance is None: - raise AttributeError, "%s must be accessed via instance" % self.name - - try: - return getattr(instance, self.cache_attr) - except AttributeError: - rel_obj = None - ct = getattr(instance, self.ct_field) - if ct: - try: - rel_obj = ct.get_object_for_this_type(pk=getattr(instance, self.fk_field)) - except ObjectDoesNotExist: - pass - setattr(instance, self.cache_attr, rel_obj) - return rel_obj - - def __set__(self, instance, value): - if instance is None: - raise AttributeError, "%s must be accessed via instance" % self.related.opts.object_name - - ct = None - fk = None - if value is not None: - ct = self.get_content_type(value) - fk = value._get_pk_val() - - setattr(instance, self.ct_field, ct) - setattr(instance, self.fk_field, fk) - setattr(instance, self.cache_attr, value) - -class GenericRelation(RelatedField, Field): - """Provides an accessor to generic related objects (i.e. comments)""" - - def __init__(self, to, **kwargs): - kwargs['verbose_name'] = kwargs.get('verbose_name', None) - kwargs['rel'] = GenericRel(to, - related_name=kwargs.pop('related_name', None), - limit_choices_to=kwargs.pop('limit_choices_to', None), - symmetrical=kwargs.pop('symmetrical', True)) - - # Override content-type/object-id field names on the related class - self.object_id_field_name = kwargs.pop("object_id_field", "object_id") - self.content_type_field_name = kwargs.pop("content_type_field", "content_type") - - kwargs['blank'] = True - kwargs['editable'] = False - Field.__init__(self, **kwargs) - - def get_manipulator_field_objs(self): - choices = self.get_choices_default() - return [curry(oldforms.SelectMultipleField, size=min(max(len(choices), 5), 15), choices=choices)] - - def get_choices_default(self): - return Field.get_choices(self, include_blank=False) - - def flatten_data(self, follow, obj = None): - new_data = {} - if obj: - instance_ids = [instance._get_pk_val() for instance in getattr(obj, self.name).all()] - new_data[self.name] = instance_ids - return new_data - - def m2m_db_table(self): - return self.rel.to._meta.db_table - - def m2m_column_name(self): - return self.object_id_field_name - - def m2m_reverse_name(self): - return self.object_id_field_name - - def contribute_to_class(self, cls, name): - super(GenericRelation, self).contribute_to_class(cls, name) - - # Save a reference to which model this class is on for future use - self.model = cls - - # Add the descriptor for the m2m relation - setattr(cls, self.name, ReverseGenericRelatedObjectsDescriptor(self)) - - def contribute_to_related_class(self, cls, related): - pass - - def set_attributes_from_rel(self): - pass - - def get_internal_type(self): - return "ManyToManyField" - -class ReverseGenericRelatedObjectsDescriptor(object): - """ - This class provides the functionality that makes the related-object - managers available as attributes on a model class, for fields that have - multiple "remote" values and have a GenericRelation defined in their model - (rather than having another model pointed *at* them). In the example - "article.publications", the publications attribute is a - ReverseGenericRelatedObjectsDescriptor instance. - """ - def __init__(self, field): - self.field = field - - def __get__(self, instance, instance_type=None): - if instance is None: - raise AttributeError, "Manager must be accessed via instance" - - # This import is done here to avoid circular import importing this module - from django.contrib.contenttypes.models import ContentType - - # Dynamically create a class that subclasses the related model's - # default manager. - rel_model = self.field.rel.to - superclass = rel_model._default_manager.__class__ - RelatedManager = create_generic_related_manager(superclass) - - manager = RelatedManager( - model = rel_model, - instance = instance, - symmetrical = (self.field.rel.symmetrical and instance.__class__ == rel_model), - join_table = backend.quote_name(self.field.m2m_db_table()), - source_col_name = backend.quote_name(self.field.m2m_column_name()), - target_col_name = backend.quote_name(self.field.m2m_reverse_name()), - content_type = ContentType.objects.get_for_model(self.field.model), - content_type_field_name = self.field.content_type_field_name, - object_id_field_name = self.field.object_id_field_name - ) - - return manager - - def __set__(self, instance, value): - if instance is None: - raise AttributeError, "Manager must be accessed via instance" - - manager = self.__get__(instance) - manager.clear() - for obj in value: - manager.add(obj) - -def create_generic_related_manager(superclass): - """ - Factory function for a manager that subclasses 'superclass' (which is a - Manager) and adds behavior for generic related objects. - """ - - class GenericRelatedObjectManager(superclass): - def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None, - join_table=None, source_col_name=None, target_col_name=None, content_type=None, - content_type_field_name=None, object_id_field_name=None): - - super(GenericRelatedObjectManager, self).__init__() - self.core_filters = core_filters or {} - self.model = model - self.content_type = content_type - self.symmetrical = symmetrical - self.instance = instance - self.join_table = join_table - self.join_table = model._meta.db_table - self.source_col_name = source_col_name - self.target_col_name = target_col_name - self.content_type_field_name = content_type_field_name - self.object_id_field_name = object_id_field_name - self.pk_val = self.instance._get_pk_val() - - def get_query_set(self): - query = { - '%s__pk' % self.content_type_field_name : self.content_type.id, - '%s__exact' % self.object_id_field_name : self.pk_val, - } - return superclass.get_query_set(self).filter(**query) - - def add(self, *objs): - for obj in objs: - setattr(obj, self.content_type_field_name, self.content_type) - setattr(obj, self.object_id_field_name, self.pk_val) - obj.save() - add.alters_data = True - - def remove(self, *objs): - for obj in objs: - obj.delete() - remove.alters_data = True - - def clear(self): - for obj in self.all(): - obj.delete() - clear.alters_data = True - - def create(self, **kwargs): - kwargs[self.content_type_field_name] = self.content_type - kwargs[self.object_id_field_name] = self.pk_val - obj = self.model(**kwargs) - obj.save() - return obj - create.alters_data = True - - return GenericRelatedObjectManager - -class GenericRel(ManyToManyRel): - def __init__(self, to, related_name=None, limit_choices_to=None, symmetrical=True): - self.to = to - self.num_in_admin = 0 - self.related_name = related_name - self.filter_interface = None - self.limit_choices_to = limit_choices_to or {} - self.edit_inline = False - self.raw_id_admin = False - self.symmetrical = symmetrical - self.multiple = True - assert not (self.raw_id_admin and self.filter_interface), \ - "Generic relations may not use both raw_id_admin and filter_interface" diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 797ef05be1..0739d0461a 100644 --- a/django/db/models/fields/related.py +++ b/django/db/models/fields/related.py @@ -2,10 +2,12 @@ from django.db import backend, transaction from django.db.models import signals, get_model from django.db.models.fields import AutoField, Field, IntegerField, get_ul_class from django.db.models.related import RelatedObject +from django.utils.text import capfirst from django.utils.translation import gettext_lazy, string_concat, ngettext from django.utils.functional import curry from django.core import validators from django import oldforms +from django import newforms as forms from django.dispatch import dispatcher # For Python 2.3 @@ -314,18 +316,20 @@ def create_many_related_manager(superclass): # join_table: name of the m2m link table # source_col_name: the PK colname in join_table for the source object # target_col_name: the PK colname in join_table for the target object - # *objs - objects to add + # *objs - objects to add. Either object instances, or primary keys of object instances. from django.db import connection # If there aren't any objects, there is nothing to do. if objs: # Check that all the objects are of the right type + new_ids = set() for obj in objs: - if not isinstance(obj, self.model): - raise ValueError, "objects to add() must be %s instances" % self.model._meta.object_name + if isinstance(obj, self.model): + new_ids.add(obj._get_pk_val()) + else: + new_ids.add(obj) # Add the newly created or already existing objects to the join table. # First find out which items are already added, to avoid adding them twice - new_ids = set([obj._get_pk_val() for obj in objs]) cursor = connection.cursor() cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \ (target_col_name, self.join_table, source_col_name, @@ -352,14 +356,16 @@ def create_many_related_manager(superclass): # If there aren't any objects, there is nothing to do. if objs: # Check that all the objects are of the right type + old_ids = set() for obj in objs: - if not isinstance(obj, self.model): - raise ValueError, "objects to remove() must be %s instances" % self.model._meta.object_name + if isinstance(obj, self.model): + old_ids.add(obj._get_pk_val()) + else: + old_ids.add(obj) # Remove the specified objects from the join table - old_ids = set([obj._get_pk_val() for obj in objs]) cursor = connection.cursor() cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s IN (%s)" % \ - (self.join_table, source_col_name, + (self.join_table, source_col_name, target_col_name, ",".join(['%s'] * len(old_ids))), [self._pk_val] + list(old_ids)) transaction.commit_unless_managed() @@ -468,7 +474,7 @@ class ForeignKey(RelatedField, Field): to_field = to_field or to._meta.pk.name kwargs['verbose_name'] = kwargs.get('verbose_name', '') - if kwargs.has_key('edit_inline_type'): + if 'edit_inline_type' in kwargs: import warnings warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.") kwargs['edit_inline'] = kwargs.pop('edit_inline_type') @@ -546,6 +552,11 @@ class ForeignKey(RelatedField, Field): def contribute_to_related_class(self, cls, related): setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) + def formfield(self, **kwargs): + defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()} + defaults.update(kwargs) + return super(ForeignKey, self).formfield(**defaults) + class OneToOneField(RelatedField, IntegerField): def __init__(self, to, to_field=None, **kwargs): try: @@ -556,7 +567,7 @@ class OneToOneField(RelatedField, IntegerField): to_field = to_field or to._meta.pk.name kwargs['verbose_name'] = kwargs.get('verbose_name', '') - if kwargs.has_key('edit_inline_type'): + if 'edit_inline_type' in kwargs: import warnings warnings.warn("edit_inline_type is deprecated. Use edit_inline instead.") kwargs['edit_inline'] = kwargs.pop('edit_inline_type') @@ -607,6 +618,11 @@ class OneToOneField(RelatedField, IntegerField): if not cls._meta.one_to_one_field: cls._meta.one_to_one_field = self + def formfield(self, **kwargs): + defaults = {'form_class': forms.ModelChoiceField, 'queryset': self.rel.to._default_manager.all()} + defaults.update(kwargs) + return super(OneToOneField, self).formfield(**defaults) + class ManyToManyField(RelatedField, Field): def __init__(self, to, **kwargs): kwargs['verbose_name'] = kwargs.get('verbose_name', None) @@ -617,6 +633,7 @@ class ManyToManyField(RelatedField, Field): limit_choices_to=kwargs.pop('limit_choices_to', None), raw_id_admin=kwargs.pop('raw_id_admin', False), symmetrical=kwargs.pop('symmetrical', True)) + self.db_table = kwargs.pop('db_table', None) if kwargs["rel"].raw_id_admin: kwargs.setdefault("validator_list", []).append(self.isValidIDList) Field.__init__(self, **kwargs) @@ -639,7 +656,10 @@ class ManyToManyField(RelatedField, Field): def _get_m2m_db_table(self, opts): "Function that can be curried to provide the m2m table name for this relation" - return '%s_%s' % (opts.db_table, self.name) + if self.db_table: + return self.db_table + else: + return '%s_%s' % (opts.db_table, self.name) def _get_m2m_column_name(self, related): "Function that can be curried to provide the source column name for the m2m table" @@ -713,6 +733,19 @@ class ManyToManyField(RelatedField, Field): def set_attributes_from_rel(self): pass + def value_from_object(self, obj): + "Returns the value of this field in the given model instance." + return getattr(obj, self.attname).all() + + def formfield(self, **kwargs): + defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.all()} + defaults.update(kwargs) + # If initial is passed in, it's a list of related objects, but the + # MultipleChoiceField takes a list of IDs. + if defaults.get('initial') is not None: + defaults['initial'] = [i._get_pk_val() for i in defaults['initial']] + return super(ManyToManyField, self).formfield(**defaults) + class ManyToOneRel(object): def __init__(self, to, field_name, num_in_admin=3, min_num_in_admin=None, max_num_in_admin=None, num_extra_on_change=1, edit_inline=False, diff --git a/django/db/models/loading.py b/django/db/models/loading.py index f4aff2438b..224f5e8451 100644 --- a/django/db/models/loading.py +++ b/django/db/models/loading.py @@ -103,7 +103,7 @@ def register_models(app_label, *models): # in the _app_models dictionary model_name = model._meta.object_name.lower() model_dict = _app_models.setdefault(app_label, {}) - if model_dict.has_key(model_name): + if model_name in model_dict: # The same model may be imported via different paths (e.g. # appname.models and project.appname.models). We use the source # filename as a means to detect identity. diff --git a/django/db/models/manager.py b/django/db/models/manager.py index 6005874516..b60eed262a 100644 --- a/django/db/models/manager.py +++ b/django/db/models/manager.py @@ -1,4 +1,4 @@ -from django.db.models.query import QuerySet +from django.db.models.query import QuerySet, EmptyQuerySet from django.dispatch import dispatcher from django.db.models import signals from django.db.models.fields import FieldDoesNotExist @@ -41,12 +41,18 @@ class Manager(object): ####################### # PROXIES TO QUERYSET # ####################### + + def get_empty_query_set(self): + return EmptyQuerySet(self.model) def get_query_set(self): """Returns a new QuerySet object. Subclasses can override this method to easily customise the behaviour of the Manager. """ return QuerySet(self.model) + + def none(self): + return self.get_empty_query_set() def all(self): return self.get_query_set() diff --git a/django/db/models/manipulators.py b/django/db/models/manipulators.py index e9dfa7037c..d5fc5f725e 100644 --- a/django/db/models/manipulators.py +++ b/django/db/models/manipulators.py @@ -96,14 +96,16 @@ class AutomaticManipulator(oldforms.Manipulator): if self.change: params[self.opts.pk.attname] = self.obj_key - # First, save the basic object itself. + # First, create the basic object itself. new_object = self.model(**params) - new_object.save() - # Now that the object's been saved, save any uploaded files. + # Now that the object's been created, save any uploaded files. for f in self.opts.fields: if isinstance(f, FileField): - f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False) + f.save_file(new_data, new_object, self.change and self.original_object or None, self.change, rel=False, save=False) + + # Now save the object + new_object.save() # Calculate which primary fields have changed. if self.change: diff --git a/django/db/models/options.py b/django/db/models/options.py index ee253ff451..556168e7d0 100644 --- a/django/db/models/options.py +++ b/django/db/models/options.py @@ -85,6 +85,7 @@ class Options(object): self.fields.insert(bisect(self.fields, field), field) if not self.pk and field.primary_key: self.pk = field + field.serialize = False def __repr__(self): return '<Options for %s>' % self.object_name @@ -140,7 +141,7 @@ class Options(object): def get_follow(self, override=None): follow = {} for f in self.fields + self.many_to_many + self.get_all_related_objects(): - if override and override.has_key(f.name): + if override and f.name in override: child_override = override[f.name] else: child_override = None @@ -182,7 +183,7 @@ class Options(object): # TODO: follow if not hasattr(self, '_field_types'): self._field_types = {} - if not self._field_types.has_key(field_type): + if field_type not in self._field_types: try: # First check self.fields. for f in self.fields: diff --git a/django/db/models/query.py b/django/db/models/query.py index 53ed63ae5b..a6e702be18 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1,8 +1,9 @@ from django.db import backend, connection, transaction from django.db.models.fields import DateField, FieldDoesNotExist -from django.db.models import signals +from django.db.models import signals, loading from django.dispatch import dispatcher from django.utils.datastructures import SortedDict +from django.contrib.contenttypes import generic import operator import re @@ -25,6 +26,9 @@ QUERY_TERMS = ( # Larger values are slightly faster at the expense of more storage space. GET_ITERATOR_CHUNK_SIZE = 100 +class EmptyResultSet(Exception): + pass + #################### # HELPER FUNCTIONS # #################### @@ -80,6 +84,7 @@ class QuerySet(object): self._filters = Q() self._order_by = None # Ordering, e.g. ('date', '-name'). If None, use model's ordering. self._select_related = False # Whether to fill cache for related objects. + self._max_related_depth = 0 # Maximum "depth" for select_related self._distinct = False # Whether the query should use SELECT DISTINCT. self._select = {} # Dictionary of attname -> SQL. self._where = [] # List of extra WHERE clauses to use. @@ -104,6 +109,8 @@ class QuerySet(object): def __getitem__(self, k): "Retrieve an item or slice from the set of results." + if not isinstance(k, (slice, int)): + raise TypeError assert (not isinstance(k, slice) and (k >= 0)) \ or (isinstance(k, slice) and (k.start is None or k.start >= 0) and (k.stop is None or k.stop >= 0)), \ "Negative indexing is not supported." @@ -163,12 +170,16 @@ class QuerySet(object): def iterator(self): "Performs the SELECT database lookup of this QuerySet." + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration + # self._select is a dictionary, and dictionaries' key order is # undefined, so we convert it to a list of tuples. extra_select = self._select.items() cursor = connection.cursor() - select, sql, params = self._get_sql_clause() cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) fill_cache = self._select_related index_end = len(self.model._meta.fields) @@ -178,7 +189,8 @@ class QuerySet(object): raise StopIteration for row in rows: if fill_cache: - obj, index_end = get_cached_row(self.model, row, 0) + obj, index_end = get_cached_row(klass=self.model, row=row, + index_start=0, max_depth=self._max_related_depth) else: obj = self.model(*row[:index_end]) for i, k in enumerate(extra_select): @@ -186,13 +198,31 @@ class QuerySet(object): yield obj def count(self): - "Performs a SELECT COUNT() and returns the number of records as an integer." + """ + Performs a SELECT COUNT() and returns the number of records as an + integer. + + If the queryset is already cached (i.e. self._result_cache is set) this + simply returns the length of the cached results set to avoid multiple + SELECT COUNT(*) calls. + """ + if self._result_cache is not None: + return len(self._result_cache) + counter = self._clone() counter._order_by = () + counter._select_related = False + + offset = counter._offset + limit = counter._limit counter._offset = None counter._limit = None - counter._select_related = False - select, sql, params = counter._get_sql_clause() + + try: + select, sql, params = counter._get_sql_clause() + except EmptyResultSet: + return 0 + cursor = connection.cursor() if self._distinct: id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table), @@ -200,7 +230,16 @@ class QuerySet(object): cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params) else: cursor.execute("SELECT COUNT(*)" + sql, params) - return cursor.fetchone()[0] + count = cursor.fetchone()[0] + + # Apply any offset and limit constraints manually, since using LIMIT or + # OFFSET in SQL doesn't change the output of COUNT. + if offset: + count = max(0, count - offset) + if limit: + count = min(limit, count) + + return count def get(self, *args, **kwargs): "Performs the SELECT and returns a single object matching the given keyword arguments." @@ -359,9 +398,9 @@ class QuerySet(object): else: return self._filter_or_exclude(None, **filter_obj) - def select_related(self, true_or_false=True): + def select_related(self, true_or_false=True, depth=0): "Returns a new QuerySet instance with '_select_related' modified." - return self._clone(_select_related=true_or_false) + return self._clone(_select_related=true_or_false, _max_related_depth=depth) def order_by(self, *field_names): "Returns a new QuerySet instance with the ordering changed." @@ -395,6 +434,7 @@ class QuerySet(object): c._filters = self._filters c._order_by = self._order_by c._select_related = self._select_related + c._max_related_depth = self._max_related_depth c._distinct = self._distinct c._select = self._select.copy() c._where = self._where[:] @@ -448,7 +488,10 @@ class QuerySet(object): # Add additional tables and WHERE clauses based on select_related. if self._select_related: - fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table]) + fill_table_cache(opts, select, tables, where, + old_prefix=opts.db_table, + cache_tables_seen=[opts.db_table], + max_depth=self._max_related_depth) # Add any additional SELECTs. if self._select: @@ -509,22 +552,42 @@ class QuerySet(object): return select, " ".join(sql), params class ValuesQuerySet(QuerySet): - def iterator(self): - # select_related and select aren't supported in values(). + def __init__(self, *args, **kwargs): + super(ValuesQuerySet, self).__init__(*args, **kwargs) + # select_related isn't supported in values(). self._select_related = False - self._select = {} + + def iterator(self): + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration # self._fields is a list of field names to fetch. if self._fields: - columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + #columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + if not self._select: + columns = [self.model._meta.get_field(f, many_to_many=False).column for f in self._fields] + else: + columns = [] + for f in self._fields: + if f in [field.name for field in self.model._meta.fields]: + columns.append( self.model._meta.get_field(f, many_to_many=False).column ) + elif not self._select.has_key( f ): + raise FieldDoesNotExist, '%s has no field named %r' % ( self.model._meta.object_name, f ) + field_names = self._fields else: # Default to all fields. columns = [f.column for f in self.model._meta.fields] field_names = [f.attname for f in self.model._meta.fields] - cursor = connection.cursor() - select, sql, params = self._get_sql_clause() select = ['%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(c)) for c in columns] + + # Add any additional SELECTs. + if self._select: + select.extend(['(%s) AS %s' % (quote_only_if_word(s[1]), backend.quote_name(s[0])) for s in self._select.items()]) + + cursor = connection.cursor() cursor.execute("SELECT " + (self._distinct and "DISTINCT " or "") + ",".join(select) + sql, params) while 1: rows = cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE) @@ -545,7 +608,12 @@ class DateQuerySet(QuerySet): if self._field.null: self._where.append('%s.%s IS NOT NULL' % \ (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))) - select, sql, params = self._get_sql_clause() + + try: + select, sql, params = self._get_sql_clause() + except EmptyResultSet: + raise StopIteration + sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \ (backend.get_date_trunc_sql(self._kind, '%s.%s' % (backend.quote_name(self.model._meta.db_table), backend.quote_name(self._field.column))), sql, self._order) @@ -563,6 +631,25 @@ class DateQuerySet(QuerySet): c._order = self._order return c +class EmptyQuerySet(QuerySet): + def __init__(self, model=None): + super(EmptyQuerySet, self).__init__(model) + self._result_cache = [] + + def count(self): + return 0 + + def delete(self): + pass + + def _clone(self, klass=None, **kwargs): + c = super(EmptyQuerySet, self)._clone(klass, **kwargs) + c._result_cache = [] + return c + + def _get_sql_clause(self): + raise EmptyResultSet + class QOperator(object): "Base class for QAnd and QOr" def __init__(self, *args): @@ -571,10 +658,14 @@ class QOperator(object): def get_sql(self, opts): joins, where, params = SortedDict(), [], [] for val in self.args: - joins2, where2, params2 = val.get_sql(opts) - joins.update(joins2) - where.extend(where2) - params.extend(params2) + try: + joins2, where2, params2 = val.get_sql(opts) + joins.update(joins2) + where.extend(where2) + params.extend(params2) + except EmptyResultSet: + if not isinstance(self, QOr): + raise EmptyResultSet if where: return joins, ['(%s)' % self.operator.join(where)], params return joins, [], params @@ -628,8 +719,11 @@ class QNot(Q): self.q = q def get_sql(self, opts): - joins, where, params = self.q.get_sql(opts) - where2 = ['(NOT (%s))' % " AND ".join(where)] + try: + joins, where, params = self.q.get_sql(opts) + where2 = ['(NOT (%s))' % " AND ".join(where)] + except EmptyResultSet: + return SortedDict(), [], [] return joins, where2, params def get_where_clause(lookup_type, table_prefix, field_name, value): @@ -641,10 +735,14 @@ def get_where_clause(lookup_type, table_prefix, field_name, value): except KeyError: pass if lookup_type == 'in': - return '%s%s IN (%s)' % (table_prefix, field_name, ','.join(['%s' for v in value])) - elif lookup_type == 'range': + in_string = ','.join(['%s' for id in value]) + if in_string: + return '%s%s IN (%s)' % (table_prefix, field_name, in_string) + else: + raise EmptyResultSet + elif lookup_type in ('range', 'year'): return '%s%s BETWEEN %%s AND %%s' % (table_prefix, field_name) - elif lookup_type in ('year', 'month', 'day'): + elif lookup_type in ('month', 'day'): return "%s = %%s" % backend.get_date_extract_sql(lookup_type, table_prefix + field_name) elif lookup_type == 'isnull': return "%s%s IS %sNULL" % (table_prefix, field_name, (not value and 'NOT ' or '')) @@ -652,21 +750,33 @@ def get_where_clause(lookup_type, table_prefix, field_name, value): return backend.get_fulltext_search_sql(table_prefix + field_name) raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type) -def get_cached_row(klass, row, index_start): - "Helper function that recursively returns an object with cache filled" +def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0): + """Helper function that recursively returns an object with cache filled""" + + # If we've got a max_depth set and we've exceeded that depth, bail now. + if max_depth and cur_depth > max_depth: + return None + index_end = index_start + len(klass._meta.fields) obj = klass(*row[index_start:index_end]) for f in klass._meta.fields: if f.rel and not f.null: - rel_obj, index_end = get_cached_row(f.rel.to, row, index_end) - setattr(obj, f.get_cache_name(), rel_obj) + cached_row = get_cached_row(f.rel.to, row, index_end, max_depth, cur_depth+1) + if cached_row: + rel_obj, index_end = cached_row + setattr(obj, f.get_cache_name(), rel_obj) return obj, index_end -def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen): +def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, cur_depth=0): """ Helper function that recursively populates the select, tables and where (in place) for select_related queries. """ + + # If we've got a max_depth set and we've exceeded that depth, bail now. + if max_depth and cur_depth > max_depth: + return None + qn = backend.quote_name for f in opts.fields: if f.rel and not f.null: @@ -681,12 +791,12 @@ def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen) where.append('%s.%s = %s.%s' % \ (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column))) select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields]) - fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen) + fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, cur_depth+1) def parse_lookup(kwarg_items, opts): # Helper function that handles converting API kwargs # (e.g. "name__exact": "tom") to SQL. - # Returns a tuple of (tables, joins, where, params). + # Returns a tuple of (joins, where, params). # 'joins' is a sorted dictionary describing the tables that must be joined # to complete the query. The dictionary is sorted because creation order @@ -725,12 +835,14 @@ def parse_lookup(kwarg_items, opts): if len(path) < 1: raise TypeError, "Cannot parse keyword query %r" % kwarg - + if value is None: # Interpret '__exact=None' as the sql '= NULL'; otherwise, reject # all uses of None as a query value. if lookup_type != 'exact': raise ValueError, "Cannot use None as a query value" + elif callable(value): + value = value() joins2, where2, params2 = lookup_inner(path, lookup_type, value, opts, opts.db_table, None) joins.update(joins2) @@ -755,6 +867,13 @@ def find_field(name, field_list, related_query): return None return matches[0] +def field_choices(field_list, related_query): + if related_query: + choices = [f.field.related_query_name() for f in field_list] + else: + choices = [f.name for f in field_list] + return choices + def lookup_inner(path, lookup_type, value, opts, table, column): qn = backend.quote_name joins, where, params = SortedDict(), [], [] @@ -827,13 +946,23 @@ def lookup_inner(path, lookup_type, value, opts, table, column): new_opts = field.rel.to._meta new_column = new_opts.pk.column join_column = field.column - - raise FieldFound + raise FieldFound + elif path: + # For regular fields, if there are still items on the path, + # an error has been made. We munge "name" so that the error + # properly identifies the cause of the problem. + name += LOOKUP_SEPARATOR + path[0] + else: + raise FieldFound except FieldFound: # Match found, loop has been shortcut. pass else: # No match found. - raise TypeError, "Cannot resolve keyword '%s' into field" % name + choices = field_choices(current_opts.many_to_many, False) + \ + field_choices(current_opts.get_all_related_many_to_many_objects(), True) + \ + field_choices(current_opts.get_all_related_objects(), True) + \ + field_choices(current_opts.fields, False) + raise TypeError, "Cannot resolve keyword '%s' into field. Choices are: %s" % (name, ", ".join(choices)) # Check whether an intermediate join is required between current_table # and new_table. @@ -926,18 +1055,26 @@ def delete_objects(seen_objs): pk_list = [pk for pk,instance in seen_objs[cls]] for related in cls._meta.get_all_related_many_to_many_objects(): - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ - (qn(related.field.m2m_db_table()), - qn(related.field.m2m_reverse_name()), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + if not isinstance(related.field, generic.GenericRelation): + for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): + cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + (qn(related.field.m2m_db_table()), + qn(related.field.m2m_reverse_name()), + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) for f in cls._meta.many_to_many: + if isinstance(f, generic.GenericRelation): + from django.contrib.contenttypes.models import ContentType + query_extra = 'AND %s=%%s' % f.rel.to._meta.get_field(f.content_type_field_name).column + args_extra = [ContentType.objects.get_for_model(cls).id] + else: + query_extra = '' + args_extra = [] for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - cursor.execute("DELETE FROM %s WHERE %s IN (%s)" % \ + cursor.execute(("DELETE FROM %s WHERE %s IN (%s)" % \ (qn(f.m2m_db_table()), qn(f.m2m_column_name()), - ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]])), - pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]) + ','.join(['%s' for pk in pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE]]))) + query_extra, + pk_list[offset:offset+GET_ITERATOR_CHUNK_SIZE] + args_extra) for field in cls._meta.fields: if field.rel and field.null and field.rel.to in seen_objs: for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): diff --git a/django/db/models/related.py b/django/db/models/related.py index ac1ec50ca2..2c1dc5c516 100644 --- a/django/db/models/related.py +++ b/django/db/models/related.py @@ -1,7 +1,7 @@ class BoundRelatedObject(object): def __init__(self, related_object, field_mapping, original): self.relation = related_object - self.field_mappings = field_mapping[related_object.opts.module_name] + self.field_mappings = field_mapping[related_object.name] def template_name(self): raise NotImplementedError @@ -16,7 +16,7 @@ class RelatedObject(object): self.opts = model._meta self.field = field self.edit_inline = field.rel.edit_inline - self.name = self.opts.module_name + self.name = '%s:%s' % (self.opts.app_label, self.opts.module_name) self.var_name = self.opts.object_name.lower() def flatten_data(self, follow, obj=None): @@ -68,7 +68,10 @@ class RelatedObject(object): # object return [attr] else: - return [None] * self.field.rel.num_in_admin + if self.field.rel.min_num_in_admin: + return [None] * max(self.field.rel.num_in_admin, self.field.rel.min_num_in_admin) + else: + return [None] * self.field.rel.num_in_admin def get_db_prep_lookup(self, lookup_type, value): # Defer to the actual field definition for db prep @@ -101,12 +104,12 @@ class RelatedObject(object): attr = getattr(manipulator.original_object, self.get_accessor_name()) count = attr.count() count += self.field.rel.num_extra_on_change - if self.field.rel.min_num_in_admin: - count = max(count, self.field.rel.min_num_in_admin) - if self.field.rel.max_num_in_admin: - count = min(count, self.field.rel.max_num_in_admin) else: count = self.field.rel.num_in_admin + if self.field.rel.min_num_in_admin: + count = max(count, self.field.rel.min_num_in_admin) + if self.field.rel.max_num_in_admin: + count = min(count, self.field.rel.max_num_in_admin) else: count = 1 |