558 lines
17 KiB
Python
558 lines
17 KiB
Python
from __future__ import unicode_literals
|
|
|
|
import re
|
|
|
|
from django.core.exceptions import FieldDoesNotExist
|
|
from django.db.models import Model, Q, prefetch_related_objects
|
|
|
|
from modelcluster.utils import NullRelationshipValueEncountered, extract_field_value, get_model_field, sort_by_fields
|
|
|
|
|
|
# Constructor for test functions that determine whether an object passes some boolean condition
|
|
def test_exact(model, attribute_name, value):
    """
    Build a predicate implementing the ``exact`` lookup.

    The returned function takes a model instance and reports whether the value
    found at ``attribute_name`` equals ``value``. Three comparison strategies
    are used, depending on what ``value`` is:

    * an unsaved model instance (``pk`` is None): match by object identity;
    * a saved model instance: match by primary key, provided one object's
      class is a subclass of (or the same as) the other's;
    * anything else: convert ``value`` with the field's ``to_python`` and
      compare with ``==``.

    In every case the predicate returns False when ``extract_field_value``
    raises ``NullRelationshipValueEncountered``.
    """
    if isinstance(value, Model):
        if value.pk is None:
            # comparing against an unsaved model, so objects need to match by reference
            def _test(obj):
                try:
                    other_value = extract_field_value(obj, attribute_name)
                except NullRelationshipValueEncountered:
                    return False
                return other_value is value

            return _test

        # comparing against a saved model; objects need to match by type and ID.
        # Additionally, where model inheritance is involved, we need to treat it as a
        # positive match if one is a subclass of the other
        def _test(obj):
            try:
                other_value = extract_field_value(obj, attribute_name)
            except NullRelationshipValueEncountered:
                return False
            if other_value is None:
                # presumably an unset (nullable) relation; it cannot equal a
                # saved instance, and dereferencing .pk on None would raise
                return False
            return value.pk == other_value.pk and (
                isinstance(value, other_value.__class__)
                or isinstance(other_value, value.__class__)
            )

        return _test

    field = get_model_field(model, attribute_name)
    # convert value to the correct python type for this field
    typed_value = field.to_python(value)

    # just a plain Python value = do a normal equality check
    def _test(obj):
        try:
            other_value = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        return other_value == typed_value

    return _test
|
|
|
|
|
|
def test_iexact(model, attribute_name, match_value):
    """
    Build a predicate implementing the ``iexact`` (case-insensitive equality) lookup.

    ``match_value`` is first converted with the field's ``to_python``. If the
    converted value is None, the predicate matches objects whose extracted
    value is None; otherwise both sides are upper-cased before comparison.
    The predicate returns False when ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``.
    """
    target = get_model_field(model, attribute_name).to_python(match_value)

    if target is None:
        def _matches(obj):
            try:
                actual = extract_field_value(obj, attribute_name)
            except NullRelationshipValueEncountered:
                return False
            return actual is None

        return _matches

    # upper-case the target once, up front, rather than on every call
    target_upper = target.upper()

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual.upper() == target_upper

    return _matches
|
|
|
|
|
|
def test_contains(model, attribute_name, value):
    """
    Build a predicate implementing the ``contains`` lookup.

    ``value`` is converted with the field's ``to_python``; the predicate then
    checks ``converted in extracted_value``. None values, and objects for which
    ``extract_field_value`` raises ``NullRelationshipValueEncountered``, never match.
    """
    needle = get_model_field(model, attribute_name).to_python(value)

    def _matches(obj):
        try:
            haystack = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if haystack is None:
            return False
        return needle in haystack

    return _matches
|
|
|
|
|
|
def test_icontains(model, attribute_name, value):
    """
    Build a predicate implementing the ``icontains`` (case-insensitive containment) lookup.

    ``value`` is converted with the field's ``to_python`` and upper-cased once;
    the predicate upper-cases each extracted value before the ``in`` check.
    None values, and objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match.
    """
    converted = get_model_field(model, attribute_name).to_python(value)
    needle = converted.upper()

    def _matches(obj):
        try:
            haystack = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if haystack is None:
            return False
        return needle in haystack.upper()

    return _matches
|
|
|
|
|
|
def test_lt(model, attribute_name, value):
    """
    Build a predicate implementing the ``lt`` lookup (extracted value < ``value``).

    ``value`` is converted with the field's ``to_python`` before comparison.
    None values, and objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match.
    """
    threshold = get_model_field(model, attribute_name).to_python(value)

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual < threshold

    return _matches
|
|
|
|
|
|
def test_lte(model, attribute_name, value):
    """
    Build a predicate implementing the ``lte`` lookup (extracted value <= ``value``).

    ``value`` is converted with the field's ``to_python`` before comparison.
    None values, and objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match.
    """
    threshold = get_model_field(model, attribute_name).to_python(value)

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual <= threshold

    return _matches
|
|
|
|
|
|
def test_gt(model, attribute_name, value):
    """
    Build a predicate implementing the ``gt`` lookup (extracted value > ``value``).

    ``value`` is converted with the field's ``to_python`` before comparison.
    None values, and objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match.
    """
    threshold = get_model_field(model, attribute_name).to_python(value)

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual > threshold

    return _matches
|
|
|
|
|
|
def test_gte(model, attribute_name, value):
    """
    Build a predicate implementing the ``gte`` lookup (extracted value >= ``value``).

    ``value`` is converted with the field's ``to_python`` before comparison.
    None values, and objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match.
    """
    threshold = get_model_field(model, attribute_name).to_python(value)

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual >= threshold

    return _matches
|
|
|
|
|
|
def test_in(model, attribute_name, value_list):
    """
    Build a predicate implementing the ``in`` lookup.

    Every item of ``value_list`` is converted with the field's ``to_python``
    and collected into a set; the predicate reports whether the extracted
    value is a member of that set. Objects for which ``extract_field_value``
    raises ``NullRelationshipValueEncountered`` never match.
    """
    field = get_model_field(model, attribute_name)
    candidates = {field.to_python(item) for item in value_list}

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        return actual in candidates

    return _matches
|
|
|
|
|
|
def test_startswith(model, attribute_name, value):
    """
    Build a predicate implementing the ``startswith`` lookup.

    ``value`` is converted with the field's ``to_python``; the predicate calls
    ``startswith`` on each extracted value. None values, and objects for which
    ``extract_field_value`` raises ``NullRelationshipValueEncountered``, never match.
    """
    prefix = get_model_field(model, attribute_name).to_python(value)

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual.startswith(prefix)

    return _matches
|
|
|
|
|
|
def test_istartswith(model, attribute_name, value):
    """
    Build a predicate implementing the ``istartswith`` (case-insensitive prefix) lookup.

    ``value`` is converted with the field's ``to_python`` and upper-cased once;
    each extracted value is upper-cased before the ``startswith`` check.
    None values, and objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match.
    """
    converted = get_model_field(model, attribute_name).to_python(value)
    prefix = converted.upper()

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual.upper().startswith(prefix)

    return _matches
|
|
|
|
|
|
def test_endswith(model, attribute_name, value):
    """
    Build a predicate implementing the ``endswith`` lookup.

    ``value`` is converted with the field's ``to_python``; the predicate calls
    ``endswith`` on each extracted value. None values, and objects for which
    ``extract_field_value`` raises ``NullRelationshipValueEncountered``, never match.
    """
    suffix = get_model_field(model, attribute_name).to_python(value)

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual.endswith(suffix)

    return _matches
|
|
|
|
|
|
def test_iendswith(model, attribute_name, value):
    """
    Build a predicate implementing the ``iendswith`` (case-insensitive suffix) lookup.

    ``value`` is converted with the field's ``to_python`` and upper-cased once;
    each extracted value is upper-cased before the ``endswith`` check.
    None values, and objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match.
    """
    converted = get_model_field(model, attribute_name).to_python(value)
    suffix = converted.upper()

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual.upper().endswith(suffix)

    return _matches
|
|
|
|
|
|
def test_range(model, attribute_name, range_val):
    """
    Build a predicate implementing the ``range`` lookup (inclusive at both ends).

    ``range_val[0]`` and ``range_val[1]`` are converted with the field's
    ``to_python`` and used as the lower and upper bounds. None values, and
    objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match.
    """
    field = get_model_field(model, attribute_name)
    lower = field.to_python(range_val[0])
    upper = field.to_python(range_val[1])

    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if actual is None:
            return False
        return actual >= lower and actual <= upper

    return _matches
|
|
|
|
|
|
def test_isnull(model, attribute_name, sense):
    """
    Build a predicate implementing the ``isnull`` lookup.

    With a truthy ``sense`` the predicate matches objects whose extracted value
    is None; with a falsy ``sense`` it matches objects whose value is not None.
    ``model`` is accepted for signature parity with the other constructors but
    is not used.
    """
    def _matches(obj):
        try:
            actual = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            # NOTE(review): this is a non-match regardless of ``sense``; confirm
            # that is intended for isnull=True lookups that span a relation
            return False
        is_null = actual is None
        return is_null if sense else not is_null

    return _matches
|
|
|
|
|
|
def test_regex(model, attribute_name, regex_string):
    """
    Build a predicate implementing the ``regex`` (case-sensitive) lookup.

    ``regex_string`` is compiled once; the predicate reports whether the
    pattern is found anywhere in the extracted value (``search`` semantics).
    None values, and objects for which ``extract_field_value`` raises
    ``NullRelationshipValueEncountered``, never match. ``model`` is accepted
    for signature parity with the other constructors but is not used.

    The predicate always returns a real bool, so callers may safely count or
    sum the results.
    """
    regex = re.compile(regex_string)

    def _test(obj):
        try:
            val = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if val is None:
            return False
        # search() yields a Match object or None; normalise to a bool
        return regex.search(val) is not None

    return _test
|
|
|
|
|
|
def test_iregex(model, attribute_name, regex_string):
    """
    Build a predicate implementing the ``iregex`` (case-insensitive) lookup.

    ``regex_string`` is compiled once with ``re.I``; the predicate reports
    whether the pattern is found anywhere in the extracted value (``search``
    semantics). None values, and objects for which ``extract_field_value``
    raises ``NullRelationshipValueEncountered``, never match. ``model`` is
    accepted for signature parity with the other constructors but is not used.

    The predicate always returns a real bool, so callers may safely count or
    sum the results.
    """
    regex = re.compile(regex_string, re.I)

    def _test(obj):
        try:
            val = extract_field_value(obj, attribute_name)
        except NullRelationshipValueEncountered:
            return False
        if val is None:
            return False
        # search() yields a Match object or None; normalise to a bool
        return regex.search(val) is not None

    return _test
|
|
|
|
|
|
# Maps each supported lookup suffix (the final ``__xxx`` part of a filter
# keyword) to the constructor that builds the matching predicate.
FILTER_EXPRESSION_TOKENS = dict((
    ('exact', test_exact),
    ('iexact', test_iexact),
    ('contains', test_contains),
    ('icontains', test_icontains),
    ('lt', test_lt),
    ('lte', test_lte),
    ('gt', test_gt),
    ('gte', test_gte),
    ('in', test_in),
    ('startswith', test_startswith),
    ('istartswith', test_istartswith),
    ('endswith', test_endswith),
    ('iendswith', test_iendswith),
    ('range', test_range),
    ('isnull', test_isnull),
    ('regex', test_regex),
    ('iregex', test_iregex),
))
|
|
|
|
|
|
def _build_test_function_from_filter(model, key_clauses, val):
    """
    Translate a filter kwarg rule (e.g. foo__bar__exact=123) into a function which
    can take a model instance and return a boolean indicating whether it passes
    the rule.

    ``key_clauses`` is the keyword already split on ``__``. If the full path
    resolves to a model field, the rule is an implicit ``exact`` test. If it
    does not and the final clause is a known lookup name, that clause is popped
    from ``key_clauses`` (mutating the list) and selects the predicate
    constructor. Anything else falls back to ``exact``.
    """
    try:
        get_model_field(model, "__".join(key_clauses))
    except FieldDoesNotExist:
        # it is safe to assume the last clause indicates the type of test
        names_a_field = False
    else:
        names_a_field = True

    constructor = test_exact
    if not names_a_field:
        lookup_name = key_clauses[-1]
        if lookup_name in FILTER_EXPRESSION_TOKENS:
            key_clauses.pop()
            constructor = FILTER_EXPRESSION_TOKENS[lookup_name]

    # recombine the remaining items to be interpreted
    # by get_model_field() and extract_field_value()
    return constructor(model, "__".join(key_clauses), val)
|
|
|
|
|
|
class FakeQuerySetIterable(object):
    """Base class for the iterables that turn a queryset's results into output rows."""

    def __init__(self, queryset):
        # subclasses read the source objects and field selections from here
        self.queryset = queryset
|
|
|
|
|
|
class ModelIterable(FakeQuerySetIterable):
    """Yields the queryset's result objects unchanged, in their stored order."""

    def __iter__(self):
        for obj in self.queryset.results:
            yield obj
|
|
|
|
|
|
class DictIterable(FakeQuerySetIterable):
    """
    Yields one dict per result object, keyed by field name.

    The fields are the queryset's ``dict_fields`` or, when that is empty, the
    names of all fields in the model's ``_meta.fields``.
    """

    def __iter__(self):
        queryset = self.queryset
        field_names = queryset.dict_fields
        if not field_names:
            field_names = [field.name for field in queryset.model._meta.fields]

        for obj in queryset.results:
            row = {}
            for field_name in field_names:
                # ask for pk-only values and suppress missing-field and
                # null-relationship errors rather than letting them propagate
                row[field_name] = extract_field_value(
                    obj,
                    field_name,
                    pk_only=True,
                    suppress_fielddoesnotexist=True,
                    suppress_nullrelationshipvalueencountered=True,
                )
            yield row
|
|
|
|
|
|
class ValuesListIterable(FakeQuerySetIterable):
    """
    Yields one tuple of field values per result object.

    The fields are the queryset's ``tuple_fields`` or, when that is empty, the
    names of all fields in the model's ``_meta.fields``.
    """

    def __iter__(self):
        queryset = self.queryset
        field_names = queryset.tuple_fields
        if not field_names:
            field_names = [field.name for field in queryset.model._meta.fields]

        for obj in queryset.results:
            # ask for pk-only values and suppress missing-field and
            # null-relationship errors rather than letting them propagate
            yield tuple(
                extract_field_value(
                    obj,
                    field_name,
                    pk_only=True,
                    suppress_fielddoesnotexist=True,
                    suppress_nullrelationshipvalueencountered=True,
                )
                for field_name in field_names
            )
|
|
|
|
|
|
class FlatValuesListIterable(FakeQuerySetIterable):
    """
    Yields a single bare field value per result object.

    The field is the first entry of the queryset's ``tuple_fields``. When no
    field was named, the first field in the model's ``_meta.fields`` is used,
    so an empty selection does not raise IndexError.
    """

    def __iter__(self):
        tuple_fields = self.queryset.tuple_fields
        if tuple_fields:
            field_name = tuple_fields[0]
        else:
            # no explicit field: fall back to the model's first field rather
            # than failing on an empty selection
            field_name = self.queryset.model._meta.fields[0].name

        for obj in self.queryset.results:
            # ask for pk-only values and suppress missing-field and
            # null-relationship errors rather than letting them propagate
            yield extract_field_value(
                obj,
                field_name,
                pk_only=True,
                suppress_fielddoesnotexist=True,
                suppress_nullrelationshipvalueencountered=True,
            )
|
|
|
|
|
|
class FakeQuerySet(object):
    """
    An in-memory stand-in for a Django QuerySet.

    The objects are held in ``self.results``; filtering, ordering and the
    ``values`` / ``values_list`` projections are evaluated in Python against
    that sequence instead of a database. Methods that narrow or reshape the
    results return a new FakeQuerySet (see ``get_clone``); methods that have no
    meaning without a database return ``self`` unchanged.
    """

    def __init__(self, model, results):
        self.model = model
        self.results = results
        # field selections used by values() and values_list() respectively
        self.dict_fields = []
        self.tuple_fields = []
        # class used by __iter__ to turn results into output rows
        self.iterable_class = ModelIterable

    def all(self):
        return self

    def get_clone(self, results=None):
        """
        Return a new FakeQuerySet carrying over this one's field selections and
        iterable class, using ``results`` if given or this queryset's own
        results otherwise.
        """
        new = FakeQuerySet(self.model, results if results is not None else self.results)
        new.dict_fields = self.dict_fields
        new.tuple_fields = self.tuple_fields
        new.iterable_class = self.iterable_class
        return new

    def resolve_q_object(self, q_object):
        """
        Compile a Q object (including nested Q children) into a single
        predicate taking a model instance and returning a bool.
        """
        connector = q_object.connector
        child_tests = []

        for child in q_object.children:
            if isinstance(child, Q):
                child_tests.append(self.resolve_q_object(child))
            else:
                key_clauses, val = child
                child_tests.append(
                    _build_test_function_from_filter(self.model, key_clauses.split('__'), val)
                )

        def test_inner(obj):
            if connector == Q.AND:
                result = all(check(obj) for check in child_tests)
            elif connector == Q.OR:
                result = any(check(obj) for check in child_tests)
            else:
                # any other connector: exactly one child must pass. Results are
                # coerced to bool so that predicates returning None or other
                # non-numeric values can be counted safely.
                result = sum(bool(check(obj)) for check in child_tests) == 1
            if q_object.negated:
                return not result
            return result

        return test_inner

    def _get_filters(self, *args, **kwargs):
        # a list of test functions; objects must pass all tests to be included
        # in the filtered list
        filters = [self.resolve_q_object(q_object) for q_object in args]

        for key, val in kwargs.items():
            filters.append(
                _build_test_function_from_filter(self.model, key.split('__'), val)
            )

        return filters

    def filter(self, *args, **kwargs):
        """Return a clone holding only the objects that pass every given condition."""
        filters = self._get_filters(*args, **kwargs)

        return self.get_clone(results=[
            obj for obj in self.results
            if all(test(obj) for test in filters)
        ])

    def exclude(self, *args, **kwargs):
        """Return a clone without the objects that pass every given condition."""
        filters = self._get_filters(*args, **kwargs)

        return self.get_clone(results=[
            obj for obj in self.results
            if not all(test(obj) for test in filters)
        ])

    def get(self, *args, **kwargs):
        """
        Return the single item matching the given conditions.

        Raises the model's ``DoesNotExist`` when nothing matches and its
        ``MultipleObjectsReturned`` when more than one object matches.
        """
        clone = self.filter(*args, **kwargs)
        result_count = clone.count()

        if result_count == 0:
            raise self.model.DoesNotExist("%s matching query does not exist." % self.model._meta.object_name)
        if result_count > 1:
            raise self.model.MultipleObjectsReturned(
                "get() returned more than one %s -- it returned %s!" % (self.model._meta.object_name, result_count)
            )
        # go through the clone's iterator so values()/values_list() shaping applies
        return next(iter(clone), None)

    def count(self):
        return len(self.results)

    def exists(self):
        return bool(self.results)

    def first(self):
        """Return the first item as produced by iteration, or None when empty."""
        return next(iter(self), None)

    def last(self):
        """Return the last item as produced by iteration, or None when empty."""
        if not self.results:
            return None
        # iterate a one-item clone so values()/values_list() shaping applies
        clone = self.get_clone(results=[self.results[-1]])
        return next(iter(clone), None)

    def select_related(self, *args):
        # has no meaningful effect on non-db querysets
        return self

    def prefetch_related(self, *args):
        prefetch_related_objects(self.results, *args)
        return self

    def only(self, *args):
        # has no meaningful effect on non-db querysets
        return self

    def defer(self, *args):
        # has no meaningful effect on non-db querysets
        return self

    def values(self, *fields):
        """Return a clone that iterates as dicts restricted to ``fields``."""
        clone = self.get_clone()
        clone.dict_fields = fields
        # Ensure all 'fields' are available model fields
        for f in fields:
            get_model_field(self.model, f)
        clone.iterable_class = DictIterable
        return clone

    def values_list(self, *fields, flat=None):
        """
        Return a clone that iterates as tuples of ``fields``, or as bare values
        when ``flat`` is truthy.

        Raises TypeError when ``flat`` is combined with more than one field.
        """
        clone = self.get_clone()
        clone.tuple_fields = fields
        # Ensure all 'fields' are available model fields
        for f in fields:
            get_model_field(self.model, f)
        if flat:
            if len(fields) > 1:
                raise TypeError("'flat' is not valid when values_list is called with more than one field.")
            clone.iterable_class = FlatValuesListIterable
        else:
            clone.iterable_class = ValuesListIterable
        return clone

    def order_by(self, *fields):
        """Return a clone whose results are a copy of this queryset's, sorted by ``fields``."""
        # copy first so sorting does not disturb this queryset's results
        clone = self.get_clone(results=self.results[:])
        sort_by_fields(clone.results, fields)
        return clone

    def distinct(self, *fields):
        """
        Return a clone keeping only the first object seen for each distinct
        combination of ``fields``. With no fields given, all non-primary-key
        fields of the model are compared. Values are compared by their ``str()``.
        """
        unique_results = []
        if not fields:
            fields = [field.name for field in self.model._meta.fields if not field.primary_key]
        seen_keys = set()
        for result in self.results:
            key = tuple(str(extract_field_value(result, field)) for field in fields)
            if key not in seen_keys:
                seen_keys.add(key)
                unique_results.append(result)
        return self.get_clone(results=unique_results)

    # a standard QuerySet will store the results in _result_cache on running the query;
    # this is effectively the same as self.results on a FakeQuerySet, and so we'll make
    # _result_cache an alias of self.results for the benefit of Django internals that
    # exploit it
    def _get_result_cache(self):
        return self.results

    def _set_result_cache(self, val):
        self.results = list(val)

    _result_cache = property(_get_result_cache, _set_result_cache)

    def __getitem__(self, k):
        return self.results[k]

    def __iter__(self):
        iterator = self.iterable_class(self)
        yield from iterator

    def __bool__(self):
        return bool(self.results)

    # legacy (Python 2) spelling of the truthiness hook, kept for any caller
    # that invokes it by name
    __nonzero__ = __bool__

    def __repr__(self):
        return repr(list(self))

    def __len__(self):
        return len(self.results)

    ordered = True  # results are returned in a consistent order
|