534 lines
17 KiB
Python
534 lines
17 KiB
Python
import datetime
|
|
from warnings import warn
|
|
|
|
from django.db.models.functions.datetime import Extract as ExtractDate
|
|
from django.db.models.functions.datetime import ExtractYear
|
|
from django.db.models.lookups import Lookup
|
|
from django.db.models.query import QuerySet
|
|
from django.db.models.sql.where import NothingNode, WhereNode
|
|
|
|
from wagtail.search.index import class_is_indexed, get_indexed_models
|
|
from wagtail.search.query import MATCH_ALL, PlainText
|
|
|
|
|
|
class FilterError(Exception):
    """Raised when a queryset filter cannot be applied to search results."""
|
|
|
|
|
|
class FieldError(Exception):
    """
    Base class for errors about a field that is not configured for search.

    :param field_name: name of the offending field, stored on ``self.field_name``.
    All other positional/keyword arguments are passed to ``Exception``.
    """

    def __init__(self, *args, field_name=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.field_name = field_name
|
|
|
|
|
|
class SearchFieldError(FieldError):
    """Raised when a search is restricted to a field that is not an ``index.SearchField``."""
|
|
|
|
|
|
class FilterFieldError(FieldError):
    """Raised when search results are filtered on a field that is not an ``index.FilterField``."""
|
|
|
|
|
|
class OrderByFieldError(FieldError):
    """Raised when search results are ordered by a field that is not an ``index.FilterField``."""
|
|
|
|
|
|
class BaseSearchQueryCompiler:
    """
    Translate a Django queryset plus a search query into backend filters/ordering.

    This base class walks the queryset's WHERE tree and ORDER BY list and
    validates that every field involved is declared in the model's
    ``search_fields``. Backends subclass it and implement ``_process_lookup``,
    ``_process_match_none`` and ``_connect_filters``.
    """

    DEFAULT_OPERATOR = "or"

    def __init__(
        self,
        queryset,
        query,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        """
        :param queryset: queryset whose filters and ordering apply to the search.
        :param query: a search query object, a plain string (wrapped in
            ``PlainText``), or ``None`` (deprecated; treated as ``MATCH_ALL``).
        :param fields: optional field names the search is restricted to.
        :param operator: operator for the ``PlainText`` built from a string
            query; falls back to ``DEFAULT_OPERATOR``.
        :param order_by_relevance: when true, the queryset's ``order_by`` is
            ignored (see ``_get_order_by``).
        """
        self.queryset = queryset

        if query is None:
            warn(
                "Querying `None` is deprecated, use `MATCH_ALL` instead.",
                DeprecationWarning,
            )
            query = MATCH_ALL
        elif isinstance(query, str):
            query = PlainText(query, operator=operator or self.DEFAULT_OPERATOR)

        self.query = query
        self.fields = fields
        self.order_by_relevance = order_by_relevance

    def _get_filterable_field(self, field_attname):
        """
        Return the filterable search field whose attname is ``field_attname``,
        or ``None`` if the model declares no such field.
        """
        filterable_fields = {
            search_field.get_attname(self.queryset.model): search_field
            for search_field in self.queryset.model.get_filterable_search_fields()
        }
        return filterable_fields.get(field_attname)

    def _process_lookup(self, field, lookup, value):
        """
        Translate one ``field``/``lookup``/``value`` triple into a backend filter.

        Must return ``None`` for an unsupported lookup; ``_process_filter``
        turns that into a ``FilterError``.
        """
        raise NotImplementedError

    def _process_match_none(self):
        """Return a backend filter for Django's ``NothingNode`` (matches nothing)."""
        raise NotImplementedError

    def _connect_filters(self, filters, connector, negated):
        """Combine ``filters`` with ``connector`` (e.g. ``"AND"``), negated if ``negated``."""
        raise NotImplementedError

    def _process_filter(self, field_attname, lookup, value, check_only=False):
        """
        Validate ``field_attname`` and, unless ``check_only``, build its filter.

        :returns: the result of ``_process_lookup``, or ``None`` when
            ``check_only`` is true.
        :raises FilterFieldError: if the field is not a filterable search field.
        :raises FilterError: if ``_process_lookup`` returns ``None``.
        """
        field = self._get_filterable_field(field_attname)

        if field is None:
            raise FilterFieldError(
                'Cannot filter search results with field "'
                + field_attname
                + "\". Please add index.FilterField('"
                + field_attname
                + "') to "
                + self.queryset.model.__name__
                + ".search_fields.",
                field_name=field_attname,
            )

        if check_only:
            return None

        result = self._process_lookup(field, lookup, value)

        if result is None:
            raise FilterError(
                'Could not apply filter on search results: "'
                + field_attname
                + "__"
                + lookup
                + " = "
                + str(value)
                + '". Lookup "'
                + lookup
                + '" not recognised.'
            )

        return result

    def _get_filters_from_where_node(self, where_node, check_only=False):
        """
        Recursively convert a Django WHERE tree into backend filters.

        :param where_node: a ``Lookup`` (leaf), ``NothingNode`` or ``WhereNode``.
        :param check_only: when true, leaf lookups are only validated.
        :returns: the backend filter for ``where_node``. Returns ``None`` for
            ignored ``*_ptr_id`` lookups, and in check-only mode except for a
            ``NothingNode``, which always goes through ``_process_match_none``.
        :raises FilterError: for unsupported date extracts or lookups, or for an
            unknown node type.
        :raises FilterFieldError: if a filtered field is not filterable.
        """
        if isinstance(where_node, Lookup):
            if isinstance(where_node.lhs, ExtractDate):
                # Only ``__year`` extracts can be rewritten as date-range filters.
                if not isinstance(where_node.lhs, ExtractYear):
                    raise FilterError(
                        'Cannot apply filter on search results: "'
                        + where_node.lhs.lookup_name
                        + '" queries are not supported.'
                    )

                field_attname = where_node.lhs.lhs.target.attname
                lookup = where_node.lookup_name

                if lookup not in ("gte", "gt", "lte", "lt", "exact"):
                    raise FilterError(
                        'Cannot apply filter on search results: "'
                        + where_node.lhs.lookup_name
                        + "__"
                        + lookup
                        + '" queries are not supported.'
                    )

                year = int(where_node.rhs)

                if lookup == "exact":
                    # year(date) == Y  <=>  Jan 1st Y <= date < Jan 1st Y+1
                    filter1 = self._process_filter(
                        field_attname,
                        "gte",
                        datetime.date(year, 1, 1),
                        check_only=check_only,
                    )
                    filter2 = self._process_filter(
                        field_attname,
                        "lt",
                        datetime.date(year + 1, 1, 1),
                        check_only=check_only,
                    )
                    if check_only:
                        return None
                    return self._connect_filters([filter1, filter2], "AND", False)

                if lookup == "gte":
                    # year(date) >= Y  <=>  date >= Jan 1st Y
                    value = datetime.date(year, 1, 1)
                elif lookup == "gt":
                    # year(date) > Y  <=>  date >= Jan 1st Y+1
                    value = datetime.date(year + 1, 1, 1)
                    lookup = "gte"
                elif lookup == "lte":
                    # year(date) <= Y  <=>  date < Jan 1st Y+1
                    value = datetime.date(year + 1, 1, 1)
                    lookup = "lt"
                else:
                    # "lt": year(date) < Y  <=>  date < Jan 1st Y
                    value = datetime.date(year, 1, 1)
            else:
                field_attname = where_node.lhs.target.attname
                lookup = where_node.lookup_name
                value = where_node.rhs

            # Ignore pointer fields that show up in specific page type queries
            if field_attname.endswith("_ptr_id"):
                return None

            return self._process_filter(
                field_attname, lookup, value, check_only=check_only
            )

        if isinstance(where_node, NothingNode):
            return self._process_match_none()

        if isinstance(where_node, WhereNode):
            # NOTE(review): ``check_only`` is intentionally not forwarded here, so
            # ``check()`` fully processes every child lookup. This surfaces
            # unrecognised lookups at search() time. Kept to preserve behavior.
            child_filters = [
                self._get_filters_from_where_node(child)
                for child in where_node.children
            ]

            if check_only:
                return None

            return self._connect_filters(
                [child_filter for child_filter in child_filters if child_filter],
                where_node.connector,
                where_node.negated,
            )

        raise FilterError(
            "Could not apply filter on search results: Unknown where node: "
            + str(type(where_node))
        )

    def _get_filters_from_queryset(self):
        """Return the backend filter for the queryset's whole WHERE clause."""
        return self._get_filters_from_where_node(self.queryset.query.where)

    def _get_order_by(self):
        """
        Yield ``(reverse, field)`` for each entry of the queryset's ``order_by``.

        Yields nothing when ``order_by_relevance`` is true. A leading ``-``
        sets ``reverse``.

        :raises OrderByFieldError: if an ordering field is not filterable.
        """
        if self.order_by_relevance:
            return

        for field_name in self.queryset.query.order_by:
            reverse = field_name.startswith("-")
            if reverse:
                field_name = field_name[1:]

            field = self._get_filterable_field(field_name)

            if field is None:
                raise OrderByFieldError(
                    'Cannot sort search results with field "'
                    + field_name
                    + "\". Please add index.FilterField('"
                    + field_name
                    + "') to "
                    + self.queryset.model.__name__
                    + ".search_fields.",
                    field_name=field_name,
                )

            yield reverse, field

    def check(self):
        """
        Validate the search fields, filters and ordering against ``search_fields``.

        :raises SearchFieldError: if a requested search field is not searchable.
        :raises FilterFieldError: if an unindexed field is filtered on.
        :raises FilterError: if a filter cannot be applied.
        :raises OrderByFieldError: if an unindexed field is used for ordering.
        """
        if self.fields:
            allowed_fields = {
                search_field.field_name
                for search_field in self.queryset.model.get_searchable_search_fields()
            }

            for field_name in self.fields:
                if field_name not in allowed_fields:
                    raise SearchFieldError(
                        'Cannot search with field "'
                        + field_name
                        + "\". Please add index.SearchField('"
                        + field_name
                        + "') to "
                        + self.queryset.model.__name__
                        + ".search_fields.",
                        field_name=field_name,
                    )

        # Raises FilterFieldError if an unindexed field is being filtered on
        self._get_filters_from_where_node(self.queryset.query.where, check_only=True)

        # Raises OrderByFieldError if an unindexed field is being used to order by
        list(self._get_order_by())
|
|
|
|
|
|
class BaseSearchResults:
    """
    Lazily evaluated, sliceable search results.

    Slicing returns a clone with a narrowed ``start``/``stop`` window. The
    backend is only queried (via ``_do_search`` / ``_do_count``, implemented
    by subclasses) when results are iterated, measured, counted or indexed.
    """

    supports_facet = False

    def __init__(self, backend, query_compiler, prefetch_related=None):
        self.backend = backend
        self.query_compiler = query_compiler
        self.prefetch_related = prefetch_related
        # Result window; ``stop`` of None means "no upper limit".
        self.start = 0
        self.stop = None
        self._results_cache = None
        self._count_cache = None
        self._score_field = None
        # Attach the model to mimic a QuerySet so that we can inspect it after
        # doing a search, e.g. to get the model's name in a paginator.
        # The query_compiler may be None, e.g. when using EmptySearchResults.
        self.model = query_compiler.queryset.model if query_compiler else None

    def _set_limits(self, start=None, stop=None):
        """
        Narrow the window with ``start``/``stop`` offsets.

        Both offsets are relative to the current ``self.start``. The new stop
        never exceeds an existing stop, and the new start is capped at it.
        """
        if stop is not None:
            if self.stop is not None:
                self.stop = min(self.stop, self.start + stop)
            else:
                self.stop = self.start + stop

        if start is not None:
            if self.stop is not None:
                self.start = min(self.stop, self.start + start)
            else:
                self.start = self.start + start

    def _clone(self):
        """Return a copy sharing backend/compiler and window, without cached results."""
        klass = self.__class__
        new = klass(
            self.backend, self.query_compiler, prefetch_related=self.prefetch_related
        )
        new.start = self.start
        new.stop = self.stop
        new._score_field = self._score_field
        return new

    def _do_search(self):
        """Fetch the results for the current window (implemented by backends)."""
        raise NotImplementedError

    def _do_count(self):
        """Count the results for the current window (implemented by backends)."""
        raise NotImplementedError

    def results(self):
        """Return the list of results, fetching and caching them on first use."""
        if self._results_cache is None:
            self._results_cache = list(self._do_search())
        return self._results_cache

    def count(self):
        """
        Return the number of results, cached after the first call.

        Uses the length of already-fetched results when available, otherwise
        ``_do_count()``.
        """
        if self._count_cache is None:
            if self._results_cache is not None:
                self._count_cache = len(self._results_cache)
            else:
                self._count_cache = self._do_count()
        return self._count_cache

    def __getitem__(self, key):
        """
        Slice or index the results.

        A slice returns a new results object with a narrowed window. An
        integer returns a single result and raises ``IndexError`` when it
        falls outside the current window.
        """
        new = self._clone()

        if isinstance(key, slice):
            # Set limits
            start = int(key.start) if key.start is not None else None
            stop = int(key.stop) if key.stop is not None else None
            new._set_limits(start, stop)

            # Copy results cache
            if self._results_cache is not None:
                new._results_cache = self._results_cache[key]

            return new

        if self._results_cache is not None:
            return self._results_cache[key]

        # Fetch a one-item window. Clamp it to any existing upper limit so that
        # indexing past the end of a sliced result set yields an empty window
        # (and hence IndexError) instead of a result from beyond the slice.
        new.start = self.start + key
        new.stop = self.start + key + 1
        if self.stop is not None:
            new.start = min(new.start, self.stop)
            new.stop = min(new.stop, self.stop)
        return list(new)[0]

    def __iter__(self):
        return iter(self.results())

    def __len__(self):
        return len(self.results())

    def __repr__(self):
        # Show at most 20 results; a 21st signals that there are more.
        data = list(self[:21])
        if len(data) > 20:
            data[-1] = "...(remaining elements truncated)..."
        return "<SearchResults %r>" % data

    def annotate_score(self, field_name):
        """Return a clone that records ``field_name`` as the score annotation field."""
        clone = self._clone()
        clone._score_field = field_name
        return clone

    def facet(self, field_name):
        """Faceting is unsupported by default; backends override this."""
        raise NotImplementedError("This search backend does not support faceting")
|
|
|
|
|
|
class EmptySearchResults(BaseSearchResults):
    """Search results that contain nothing and never consult a backend."""

    def __init__(self):
        # No backend and no query compiler: ``model`` ends up as None.
        super().__init__(None, None)

    def _clone(self):
        return type(self)()

    def _do_search(self):
        empty = []
        return empty

    def _do_count(self):
        nothing = 0
        return nothing
|
|
|
|
|
|
class NullIndex:
    """
    Index whose indexing operations required by BaseSearchBackend all do nothing.

    Suitable for search backends that keep no separate index, such as the
    database backend.
    """

    def add_model(self, model):
        """No-op: there is no index to register ``model`` with."""
        return None

    def refresh(self):
        """No-op: there is no index to refresh."""
        return None

    def add_item(self, item):
        """No-op: ``item`` is not indexed."""
        return None

    def add_items(self, model, items):
        """No-op: ``items`` are not indexed."""
        return None

    def delete_item(self, item):
        """No-op: there is no index entry to delete."""
        return None
|
|
|
|
|
|
class BaseSearchBackend:
    """
    Base class for search backends.

    Concrete backends set the ``*_class`` attributes and may override
    ``get_index_for_model``. By default that returns a ``NullIndex``, so every
    indexing operation is a no-op.
    """

    query_compiler_class = None
    autocomplete_query_compiler_class = None
    results_class = None
    rebuilder_class = None
    catch_indexing_errors = False

    def __init__(self, params):
        pass

    def get_index_for_model(self, model):
        """Return the index used for ``model`` (a do-nothing ``NullIndex`` here)."""
        return NullIndex()

    def get_rebuilder(self):
        """Return an index rebuilder; the base backend has none."""
        return None

    def reset_index(self):
        raise NotImplementedError

    def add_type(self, model):
        """Register ``model`` with its index."""
        model_index = self.get_index_for_model(model)
        model_index.add_model(model)

    def refresh_index(self):
        """Refresh each distinct index used by the indexed models exactly once."""
        already_refreshed = []
        for indexed_model in get_indexed_models():
            model_index = self.get_index_for_model(indexed_model)
            if model_index in already_refreshed:
                continue
            model_index.refresh()
            already_refreshed.append(model_index)

    def add(self, obj):
        """Index a single object in the index for its class."""
        model_index = self.get_index_for_model(type(obj))
        model_index.add_item(obj)

    def add_bulk(self, model, obj_list):
        """Index every object in ``obj_list`` as instances of ``model``."""
        model_index = self.get_index_for_model(model)
        model_index.add_items(model, obj_list)

    def delete(self, obj):
        """Remove a single object from the index for its class."""
        model_index = self.get_index_for_model(type(obj))
        model_index.delete_item(obj)

    def _search(self, query_compiler_class, query, model_or_queryset, **kwargs):
        """
        Build, validate and wrap a search using ``query_compiler_class``.

        Accepts either a model class or a queryset. Returns
        ``EmptySearchResults`` for unindexed models or an empty query string;
        otherwise returns a ``results_class`` instance wrapping the compiler.
        """
        if isinstance(model_or_queryset, QuerySet):
            queryset = model_or_queryset
            model = queryset.model
        else:
            model = model_or_queryset
            queryset = model.objects.all()

        # Nothing to search: the model isn't indexed, or the query string is empty.
        if not class_is_indexed(model) or query == "":
            return EmptySearchResults()

        compiler = query_compiler_class(queryset, query, **kwargs)

        # Validate fields/filters/ordering before any results are fetched.
        compiler.check()

        return self.results_class(self, compiler)

    def search(
        self,
        query,
        model_or_queryset,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        """Run a full-text search with ``query_compiler_class``."""
        compiler_class = self.query_compiler_class
        return self._search(
            compiler_class,
            query,
            model_or_queryset,
            fields=fields,
            operator=operator,
            order_by_relevance=order_by_relevance,
        )

    def autocomplete(
        self,
        query,
        model_or_queryset,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        """
        Run an autocomplete search with ``autocomplete_query_compiler_class``.

        :raises NotImplementedError: if the backend has no autocomplete compiler.
        """
        compiler_class = self.autocomplete_query_compiler_class
        if compiler_class is None:
            raise NotImplementedError(
                "This search backend does not support the autocomplete API"
            )

        return self._search(
            compiler_class,
            query,
            model_or_queryset,
            fields=fields,
            operator=operator,
            order_by_relevance=order_by_relevance,
        )
|
|
|
|
|
|
def get_model_root(model):
    """
    Return the root model of ``model``: the highest concrete model it descends from.

    Walks up the first entry of ``_meta.parents`` until a model with no
    concrete parents is reached. A model with no concrete parent is its own
    root.

    Examples:

        >>> get_model_root(wagtailcore.Page)
        wagtailcore.Page

        >>> get_model_root(myapp.HomePage)
        wagtailcore.Page

        >>> get_model_root(wagtailimages.Image)
        wagtailimages.Image
    """
    root = model
    while root._meta.parents:
        # Follow the first declared concrete parent.
        root = next(iter(root._meta.parents))
    return root
|