angrybeanie_wagtail/env/lib/python3.12/site-packages/wagtail/search/backends/base.py

535 lines
17 KiB
Python
Raw Normal View History

2025-07-25 21:32:16 +10:00
import datetime
from warnings import warn
from django.db.models.functions.datetime import Extract as ExtractDate
from django.db.models.functions.datetime import ExtractYear
from django.db.models.lookups import Lookup
from django.db.models.query import QuerySet
from django.db.models.sql.where import NothingNode, WhereNode
from wagtail.search.index import class_is_indexed, get_indexed_models
from wagtail.search.query import MATCH_ALL, PlainText
class FilterError(Exception):
    """Raised when a queryset filter cannot be applied to search results."""
class FieldError(Exception):
    """
    Base class for errors that concern one specific field.

    The offending field's name is exposed as ``field_name`` (``None`` when
    it was not supplied) so that callers can inspect it programmatically
    instead of parsing the message.
    """

    def __init__(self, *args, field_name=None, **kwargs):
        # Positional and remaining keyword arguments go to Exception unchanged.
        super().__init__(*args, **kwargs)
        self.field_name = field_name
class SearchFieldError(FieldError):
    """Raised when a search is restricted to a field that is not searchable."""
class FilterFieldError(FieldError):
    """Raised when search results are filtered on a non-filterable field."""
class OrderByFieldError(FieldError):
    """Raised when search results are ordered by a non-filterable field."""
class BaseSearchQueryCompiler:
    """
    Base class that turns a queryset and a search query into backend filters.

    The compiler walks the queryset's ``WHERE`` tree and its ``order_by``
    clause and hands each piece to hooks that concrete backends implement:
    ``_process_lookup``, ``_process_match_none`` and ``_connect_filters``.
    Those hooks raise ``NotImplementedError`` here.
    """

    # Operator used for plain string queries when the caller passes none.
    DEFAULT_OPERATOR = "or"

    def __init__(
        self,
        queryset,
        query,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        """
        Store the queryset and normalise ``query``.

        ``None`` is deprecated and replaced by ``MATCH_ALL`` (with a
        ``DeprecationWarning``); a ``str`` is wrapped in ``PlainText`` using
        ``operator`` or ``DEFAULT_OPERATOR``; anything else is kept as given.
        """
        self.queryset = queryset

        if query is None:
            warn(
                "Querying `None` is deprecated, use `MATCH_ALL` instead.",
                DeprecationWarning,
            )
            query = MATCH_ALL
        elif isinstance(query, str):
            query = PlainText(query, operator=operator or self.DEFAULT_OPERATOR)

        self.query = query
        self.fields = fields
        self.order_by_relevance = order_by_relevance

    def _get_filterable_field(self, field_attname):
        """Return the filterable search field for ``field_attname``, or None."""
        model = self.queryset.model
        fields_by_attname = {
            field.get_attname(model): field
            for field in model.get_filterable_search_fields()
        }
        return fields_by_attname.get(field_attname)

    def _process_lookup(self, field, lookup, value):
        """Backend hook: build the filter for one ``field``/``lookup``/``value``."""
        raise NotImplementedError

    def _process_match_none(self):
        """Backend hook: build the filter used for a ``NothingNode``."""
        raise NotImplementedError

    def _connect_filters(self, filters, connector, negated):
        """Backend hook: combine ``filters`` with ``connector``, optionally negated."""
        raise NotImplementedError

    def _process_filter(self, field_attname, lookup, value, check_only=False):
        """
        Validate a single filter and, unless ``check_only``, build it.

        Raises ``FilterFieldError`` when ``field_attname`` is not a filterable
        search field, and ``FilterError`` when ``_process_lookup`` returns
        ``None`` for the lookup. Returns ``None`` when ``check_only`` is true.
        """
        field = self._get_filterable_field(field_attname)
        if field is None:
            raise FilterFieldError(
                f'Cannot filter search results with field "{field_attname}". '
                f"Please add index.FilterField('{field_attname}') to "
                f"{self.queryset.model.__name__}.search_fields.",
                field_name=field_attname,
            )

        if check_only:
            return None

        result = self._process_lookup(field, lookup, value)
        if result is None:
            raise FilterError(
                "Could not apply filter on search results: "
                f'"{field_attname}__{lookup} = {value!s}". '
                f'Lookup "{lookup}" not recognised.'
            )
        return result

    def _get_filters_from_where_node(self, where_node, check_only=False):
        """
        Convert a node of the queryset's ``WHERE`` tree into backend filters.

        Handles ``Lookup`` leaves (including ``__year`` extracts, which are
        rewritten as date range comparisons), ``NothingNode`` and nested
        ``WhereNode`` groups. Raises ``FilterError`` for unsupported date
        extracts, unsupported year lookups and unknown node types.
        """
        if isinstance(where_node, Lookup):
            if isinstance(where_node.lhs, ExtractDate):
                # Of the date extracts, only ``__year`` can be rewritten as a range.
                if not isinstance(where_node.lhs, ExtractYear):
                    raise FilterError(
                        'Cannot apply filter on search results: "'
                        f'{where_node.lhs.lookup_name}" queries are not supported.'
                    )

                field_attname = where_node.lhs.lhs.target.attname
                lookup = where_node.lookup_name

                if lookup == "gte":
                    # year(date) >= value  ->  date >= Jan 1st of that year
                    value = datetime.date(int(where_node.rhs), 1, 1)
                elif lookup == "gt":
                    # year(date) > value  ->  date >= Jan 1st of the next year
                    value = datetime.date(int(where_node.rhs) + 1, 1, 1)
                    lookup = "gte"
                elif lookup == "lte":
                    # year(date) <= value  ->  date < Jan 1st of the next year
                    value = datetime.date(int(where_node.rhs) + 1, 1, 1)
                    lookup = "lt"
                elif lookup == "lt":
                    # year(date) < value  ->  date < Jan 1st of that year
                    value = datetime.date(int(where_node.rhs), 1, 1)
                elif lookup == "exact":
                    # year(date) == value  ->
                    # date >= Jan 1st of that year AND date < Jan 1st of the next
                    lower_bound = self._process_filter(
                        field_attname,
                        "gte",
                        datetime.date(int(where_node.rhs), 1, 1),
                        check_only=check_only,
                    )
                    upper_bound = self._process_filter(
                        field_attname,
                        "lt",
                        datetime.date(int(where_node.rhs) + 1, 1, 1),
                        check_only=check_only,
                    )
                    if check_only:
                        return None
                    return self._connect_filters(
                        [lower_bound, upper_bound], "AND", False
                    )
                else:
                    raise FilterError(
                        'Cannot apply filter on search results: "'
                        f'{where_node.lhs.lookup_name}" queries are not supported.'
                    )
            else:
                field_attname = where_node.lhs.target.attname
                lookup = where_node.lookup_name
                value = where_node.rhs

            # Ignore pointer fields that show up in specific page type queries
            if field_attname.endswith("_ptr_id"):
                return None

            return self._process_filter(
                field_attname, lookup, value, check_only=check_only
            )

        if isinstance(where_node, NothingNode):
            return self._process_match_none()

        if isinstance(where_node, WhereNode):
            connector = where_node.connector

            # NOTE(review): ``check_only`` is not forwarded to children, so
            # child filters are fully built even during ``check()``. Kept
            # as-is because forwarding it would change when FilterError
            # surfaces - confirm whether that is intended.
            child_filters = [
                self._get_filters_from_where_node(child)
                for child in where_node.children
            ]

            if not check_only:
                # Drop children that produced no filter (e.g. pointer fields).
                child_filters = [
                    child_filter for child_filter in child_filters if child_filter
                ]
                return self._connect_filters(
                    child_filters, connector, where_node.negated
                )
            return None

        raise FilterError(
            "Could not apply filter on search results: Unknown where node: "
            f"{type(where_node)!s}"
        )

    def _get_filters_from_queryset(self):
        """Build filters from the whole ``WHERE`` clause of the queryset."""
        return self._get_filters_from_where_node(self.queryset.query.where)

    def _get_order_by(self):
        """
        Yield ``(reverse, field)`` pairs for the queryset's ordering.

        Yields nothing when ordering by relevance. Raises
        ``OrderByFieldError`` for an ordering field that is not a filterable
        search field.
        """
        if self.order_by_relevance:
            return

        for field_name in self.queryset.query.order_by:
            # A leading "-" marks descending order.
            reverse = field_name.startswith("-")
            if reverse:
                field_name = field_name[1:]

            field = self._get_filterable_field(field_name)
            if field is None:
                raise OrderByFieldError(
                    f'Cannot sort search results with field "{field_name}". '
                    f"Please add index.FilterField('{field_name}') to "
                    f"{self.queryset.model.__name__}.search_fields.",
                    field_name=field_name,
                )

            yield reverse, field

    def check(self):
        """
        Validate search fields, filters and ordering against the index.

        Raises ``SearchFieldError``, ``FilterFieldError`` or
        ``OrderByFieldError`` when a field is used that the model has not
        declared in its search fields.
        """
        # Check search fields
        if self.fields:
            allowed_fields = {
                field.field_name
                for field in self.queryset.model.get_searchable_search_fields()
            }
            for field_name in self.fields:
                if field_name not in allowed_fields:
                    raise SearchFieldError(
                        f'Cannot search with field "{field_name}". '
                        f"Please add index.SearchField('{field_name}') to "
                        f"{self.queryset.model.__name__}.search_fields.",
                        field_name=field_name,
                    )

        # Check where clause
        # Raises FilterFieldError if an unindexed field is being filtered on
        self._get_filters_from_where_node(self.queryset.query.where, check_only=True)

        # Check order by
        # Raises OrderByFieldError if an unindexed field is being used to order by
        list(self._get_order_by())
class BaseSearchResults:
    """
    Lazy, sliceable container for the results of a search.

    Nothing is fetched until the results are iterated, counted or indexed.
    Slicing returns a clone with narrower ``start``/``stop`` limits.
    Concrete backends implement ``_do_search`` and ``_do_count``, which
    raise ``NotImplementedError`` here.
    """

    supports_facet = False

    def __init__(self, backend, query_compiler, prefetch_related=None):
        self.backend = backend
        self.query_compiler = query_compiler
        self.prefetch_related = prefetch_related
        # Window into the full result list; ``stop`` of None means unbounded.
        self.start = 0
        self.stop = None
        self._results_cache = None
        self._count_cache = None
        self._score_field = None

        # Attach the model to mimic a QuerySet so that we can inspect it after
        # doing a search, e.g. to get the model's name in a paginator.
        # The query_compiler may be None, e.g. when using EmptySearchResults.
        self.model = query_compiler.queryset.model if query_compiler else None

    def _set_limits(self, start=None, stop=None):
        """
        Narrow the current window by ``start``/``stop``.

        Both values are relative to the current ``self.start``, and the new
        window is clamped so it never extends past the existing ``self.stop``.
        """
        if stop is not None:
            if self.stop is not None:
                self.stop = min(self.stop, self.start + stop)
            else:
                self.stop = self.start + stop

        if start is not None:
            if self.stop is not None:
                self.start = min(self.stop, self.start + start)
            else:
                self.start = self.start + start

    def _clone(self):
        """Return a copy with the same limits and score field but no caches."""
        klass = self.__class__
        new = klass(
            self.backend, self.query_compiler, prefetch_related=self.prefetch_related
        )
        new.start = self.start
        new.stop = self.stop
        new._score_field = self._score_field
        return new

    def _do_search(self):
        """Backend hook: return an iterable of results for the current window."""
        raise NotImplementedError

    def _do_count(self):
        """Backend hook: return the number of results for the current window."""
        raise NotImplementedError

    def results(self):
        """Return the list of results, running the search on first use."""
        if self._results_cache is None:
            self._results_cache = list(self._do_search())
        return self._results_cache

    def count(self):
        """Return the number of results, reusing cached results when present."""
        if self._count_cache is None:
            if self._results_cache is not None:
                self._count_cache = len(self._results_cache)
            else:
                self._count_cache = self._do_count()
        return self._count_cache

    def __getitem__(self, key):
        """
        Return a narrowed clone for a slice, or a single result for an index.

        When results are already cached the lookup is served from the cache.
        Otherwise an integer index is resolved inside the current window:
        an index past the window raises ``IndexError`` and a negative index
        raises ``ValueError``.
        """
        new = self._clone()

        if isinstance(key, slice):
            # Set limits
            start = int(key.start) if key.start is not None else None
            stop = int(key.stop) if key.stop is not None else None
            new._set_limits(start, stop)

            # Copy results cache
            if self._results_cache is not None:
                new._results_cache = self._results_cache[key]

            return new

        # Return a single item
        if self._results_cache is not None:
            return self._results_cache[key]

        if key < 0:
            # Without the full result list there is no end to count back
            # from; offsetting ``start`` would fetch an unrelated item.
            raise ValueError("Negative indexing is not supported.")

        # Go through _set_limits so the index is clamped to the current
        # window instead of reaching past ``self.stop``.
        new._set_limits(key, key + 1)
        if new.start >= new.stop:
            raise IndexError("Search result index out of range")
        return list(new)[0]

    def __iter__(self):
        return iter(self.results())

    def __len__(self):
        return len(self.results())

    def __repr__(self):
        # Fetch one more than is shown so truncation can be detected.
        data = list(self[:21])
        if len(data) > 20:
            data[-1] = "...(remaining elements truncated)..."
        return "<SearchResults %r>" % data

    def annotate_score(self, field_name):
        """Return a clone whose score field is set to ``field_name``."""
        clone = self._clone()
        clone._score_field = field_name
        return clone

    def facet(self, field_name):
        """Faceting is not supported by this base class."""
        raise NotImplementedError("This search backend does not support faceting")
class EmptySearchResults(BaseSearchResults):
    """Search results that never contain anything and need no backend."""

    def __init__(self):
        # No backend and no query compiler are involved.
        super().__init__(None, None)

    def _clone(self):
        """Return a fresh, equally empty instance."""
        return self.__class__()

    def _do_search(self):
        """Return an empty result list."""
        return []

    def _do_count(self):
        """Return a count of zero."""
        return 0
class NullIndex:
    """
    Index class that provides do-nothing implementations of the indexing operations required by
    BaseSearchBackend. Use this for search backends that do not maintain an index, such as the
    database backend.
    """

    def add_model(self, model):
        """Do nothing."""

    def refresh(self):
        """Do nothing."""

    def add_item(self, item):
        """Do nothing."""

    def add_items(self, model, items):
        """Do nothing."""

    def delete_item(self, item):
        """Do nothing."""
class BaseSearchBackend:
    """
    Base class for search backends.

    Subclasses set the ``*_class`` attributes and override the index
    methods as needed; by default every model maps to a ``NullIndex``.
    """

    query_compiler_class = None
    autocomplete_query_compiler_class = None
    results_class = None
    rebuilder_class = None
    catch_indexing_errors = False

    def __init__(self, params):
        # The base class ignores ``params``.
        pass

    def get_index_for_model(self, model):
        """Return the index object for ``model`` (a new ``NullIndex`` here)."""
        return NullIndex()

    def get_rebuilder(self):
        """Return the index rebuilder; the base class has none."""
        return None

    def reset_index(self):
        """Not implemented by the base class."""
        raise NotImplementedError

    def add_type(self, model):
        """Add ``model`` to its index."""
        index = self.get_index_for_model(model)
        index.add_model(model)

    def refresh_index(self):
        """Refresh each distinct index used by the indexed models once."""
        already_refreshed = []
        for model in get_indexed_models():
            index = self.get_index_for_model(model)
            if index in already_refreshed:
                continue
            index.refresh()
            already_refreshed.append(index)

    def add(self, obj):
        """Add a single object to the index of its type."""
        index = self.get_index_for_model(type(obj))
        index.add_item(obj)

    def add_bulk(self, model, obj_list):
        """Add several objects of ``model`` to its index."""
        index = self.get_index_for_model(model)
        index.add_items(model, obj_list)

    def delete(self, obj):
        """Remove a single object from the index of its type."""
        index = self.get_index_for_model(type(obj))
        index.delete_item(obj)

    def _search(self, query_compiler_class, query, model_or_queryset, **kwargs):
        """
        Run ``query`` against a model or queryset with the given compiler class.

        Returns ``EmptySearchResults`` when the model is not indexed or the
        query is an empty string; otherwise checks the compiled query and
        returns an instance of ``results_class``.
        """
        # Accept either a queryset or a model class.
        if isinstance(model_or_queryset, QuerySet):
            queryset = model_or_queryset
            model = queryset.model
        else:
            model = model_or_queryset
            queryset = model.objects.all()

        # The model must be indexed and the query string must not be empty.
        if not class_is_indexed(model) or query == "":
            return EmptySearchResults()

        compiler = query_compiler_class(queryset, query, **kwargs)
        # Raises if a field is used that is not in the index.
        compiler.check()

        return self.results_class(self, compiler)

    def search(
        self,
        query,
        model_or_queryset,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        """Search using ``query_compiler_class``."""
        options = {
            "fields": fields,
            "operator": operator,
            "order_by_relevance": order_by_relevance,
        }
        return self._search(
            self.query_compiler_class, query, model_or_queryset, **options
        )

    def autocomplete(
        self,
        query,
        model_or_queryset,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        """
        Search using ``autocomplete_query_compiler_class``.

        Raises ``NotImplementedError`` when the backend defines no such class.
        """
        compiler_class = self.autocomplete_query_compiler_class
        if compiler_class is None:
            raise NotImplementedError(
                "This search backend does not support the autocomplete API"
            )

        options = {
            "fields": fields,
            "operator": operator,
            "order_by_relevance": order_by_relevance,
        }
        return self._search(compiler_class, query, model_or_queryset, **options)
def get_model_root(model):
    """
    Return the root model of ``model``.

    The root is found by repeatedly following the first entry of
    ``_meta.parents`` until a model with no parents is reached. A model
    without parents is its own root and is returned unchanged.

    Examples:
    >>> get_model_root(wagtailcore.Page)
    wagtailcore.Page
    >>> get_model_root(myapp.HomePage)
    wagtailcore.Page
    >>> get_model_root(wagtailimages.Image)
    wagtailimages.Image
    """
    root = model
    while root._meta.parents:
        # Step up to the first listed parent model.
        root, _ = next(iter(root._meta.parents.items()))
    return root