import datetime
from warnings import warn

from django.db.models.functions.datetime import Extract as ExtractDate
from django.db.models.functions.datetime import ExtractYear
from django.db.models.lookups import Lookup
from django.db.models.query import QuerySet
from django.db.models.sql.where import NothingNode, WhereNode

from wagtail.search.index import class_is_indexed, get_indexed_models
from wagtail.search.query import MATCH_ALL, PlainText


class FilterError(Exception):
    pass


class FieldError(Exception):
    def __init__(self, *args, field_name=None, **kwargs):
        # Name of the model field that caused the error, kept on the
        # exception so callers can report which field was rejected.
        self.field_name = field_name
        super().__init__(*args, **kwargs)


class SearchFieldError(FieldError):
    pass


class FilterFieldError(FieldError):
    pass


class OrderByFieldError(FieldError):
    pass


class BaseSearchQueryCompiler:
    """
    Translate a Django queryset plus a search query into backend terms.

    This base class walks the queryset's WHERE tree and ORDER BY clause and
    validates that every field involved is declared in the model's
    ``search_fields``. Backend subclasses implement ``_process_lookup``,
    ``_process_match_none`` and ``_connect_filters`` to build the actual
    backend-specific filter objects.
    """

    DEFAULT_OPERATOR = "or"

    def __init__(
        self,
        queryset,
        query,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        self.queryset = queryset
        if query is None:
            warn(
                "Querying `None` is deprecated, use `MATCH_ALL` instead.",
                DeprecationWarning,
            )
            query = MATCH_ALL
        elif isinstance(query, str):
            # Plain strings are wrapped in PlainText; ``operator`` is only
            # used for this case.
            query = PlainText(query, operator=operator or self.DEFAULT_OPERATOR)
        self.query = query
        self.fields = fields
        self.order_by_relevance = order_by_relevance

    def _get_filterable_field(self, field_attname):
        """
        Return the filterable search field whose attname is ``field_attname``,
        or ``None`` if the model declares no such field.
        """
        model = self.queryset.model
        # Built as a dict so that, if several declared fields share an
        # attname, the last one declared wins.
        filterable_fields = {
            field.get_attname(model): field
            for field in model.get_filterable_search_fields()
        }
        return filterable_fields.get(field_attname)

    def _process_lookup(self, field, lookup, value):
        raise NotImplementedError

    def _process_match_none(self):
        raise NotImplementedError

    def _connect_filters(self, filters, connector, negated):
        raise NotImplementedError

    def _process_filter(self, field_attname, lookup, value, check_only=False):
        """
        Validate ``field_attname`` and, unless ``check_only``, build a filter.

        Raises:
            FilterFieldError: the field is not a declared filterable field.
            FilterError: the backend's ``_process_lookup`` returned ``None``
                (i.e. it does not recognise ``lookup``).

        Returns the backend filter, or ``None`` when ``check_only`` is true.
        """
        field = self._get_filterable_field(field_attname)
        if field is None:
            raise FilterFieldError(
                f'Cannot filter search results with field "{field_attname}". '
                f"Please add index.FilterField('{field_attname}') to "
                f"{self.queryset.model.__name__}.search_fields.",
                field_name=field_attname,
            )

        if check_only:
            return None

        result = self._process_lookup(field, lookup, value)
        if result is None:
            raise FilterError(
                "Could not apply filter on search results: "
                f'"{field_attname}__{lookup} = {value!s}". '
                f'Lookup "{lookup}" not recognised.'
            )
        return result

    def _get_filters_from_lookup_node(self, where_node, check_only=False):
        """
        Translate a single ``Lookup`` leaf of the WHERE tree into a filter.

        ``ExtractYear`` (``__year``) lookups are rewritten as date-range
        filters on the underlying field; any other ``Extract`` transform, or
        a year lookup other than exact/gt/gte/lt/lte, raises FilterError.
        Returns ``None`` for ``*_ptr_id`` fields and whenever ``check_only``
        is true.
        """
        lhs = where_node.lhs
        lookup = where_node.lookup_name

        if not isinstance(lhs, ExtractDate):
            field_attname = lhs.target.attname
            value = where_node.rhs
        else:
            # Only ExtractYear is supported among Extract transforms.
            is_year = isinstance(lhs, ExtractYear)
            if is_year:
                field_attname = lhs.lhs.target.attname
            if not is_year or lookup not in ("gte", "gt", "lte", "lt", "exact"):
                raise FilterError(
                    "Cannot apply filter on search results: "
                    f'"{lhs.lookup_name}" queries are not supported.'
                )

            # Evaluated only for supported lookups, as before.
            year = int(where_node.rhs)

            if lookup == "exact":
                # year(date) == Y  ->  Jan 1st of Y <= date < Jan 1st of Y+1
                lower = self._process_filter(
                    field_attname,
                    "gte",
                    datetime.date(year, 1, 1),
                    check_only=check_only,
                )
                upper = self._process_filter(
                    field_attname,
                    "lt",
                    datetime.date(year + 1, 1, 1),
                    check_only=check_only,
                )
                if check_only:
                    return None
                return self._connect_filters([lower, upper], "AND", False)

            if lookup == "gte":
                # year(date) >= Y  ->  date >= Jan 1st of Y
                value = datetime.date(year, 1, 1)
            elif lookup == "gt":
                # year(date) > Y  ->  date >= Jan 1st of Y+1
                lookup, value = "gte", datetime.date(year + 1, 1, 1)
            elif lookup == "lte":
                # year(date) <= Y  ->  date < Jan 1st of Y+1
                lookup, value = "lt", datetime.date(year + 1, 1, 1)
            else:  # "lt"
                # year(date) < Y  ->  date < Jan 1st of Y
                value = datetime.date(year, 1, 1)

        # Ignore pointer fields that show up in specific page type queries
        if field_attname.endswith("_ptr_id"):
            return None

        return self._process_filter(
            field_attname, lookup, value, check_only=check_only
        )

    def _get_filters_from_where_node(self, where_node, check_only=False):
        """
        Recursively translate a Django WHERE tree node into backend filters.

        Handles ``Lookup`` leaves, ``NothingNode`` (match nothing) and
        ``WhereNode`` branches; anything else raises FilterError. Returns
        ``None`` for a ``WhereNode`` when ``check_only`` is true.
        """
        if isinstance(where_node, Lookup):
            return self._get_filters_from_lookup_node(
                where_node, check_only=check_only
            )

        if isinstance(where_node, NothingNode):
            return self._process_match_none()

        if isinstance(where_node, WhereNode):
            # NOTE(review): check_only is not forwarded to children, so
            # check() still runs _process_lookup on every leaf (surfacing
            # unrecognised-lookup FilterErrors eagerly). Forwarding it would
            # defer those errors to result evaluation; confirm intent first.
            child_filters = [
                self._get_filters_from_where_node(child)
                for child in where_node.children
            ]
            if check_only:
                return None
            # Falsy children (e.g. skipped *_ptr_id lookups) are dropped.
            return self._connect_filters(
                [child_filter for child_filter in child_filters if child_filter],
                where_node.connector,
                where_node.negated,
            )

        raise FilterError(
            "Could not apply filter on search results: Unknown where node: "
            f"{type(where_node)}"
        )

    def _get_filters_from_queryset(self):
        return self._get_filters_from_where_node(self.queryset.query.where)

    def _get_order_by(self):
        """
        Yield ``(reverse, field)`` pairs for the queryset's ORDER BY clause.

        Yields nothing when ordering by relevance. A leading ``-`` on a field
        name sets ``reverse``. Raises OrderByFieldError for any ordering field
        that is not a declared filterable field.
        """
        if self.order_by_relevance:
            return

        for field_name in self.queryset.query.order_by:
            reverse = field_name.startswith("-")
            if reverse:
                field_name = field_name[1:]

            field = self._get_filterable_field(field_name)
            if field is None:
                raise OrderByFieldError(
                    f'Cannot sort search results with field "{field_name}". '
                    f"Please add index.FilterField('{field_name}') to "
                    f"{self.queryset.model.__name__}.search_fields.",
                    field_name=field_name,
                )

            yield reverse, field

    def check(self):
        """
        Validate the search, filter and ordering fields up front.

        Raises SearchFieldError, FilterFieldError (or FilterError) and
        OrderByFieldError respectively when a field is not indexed.
        """
        # Check search fields
        if self.fields:
            model = self.queryset.model
            allowed_fields = {
                field.field_name for field in model.get_searchable_search_fields()
            }
            for field_name in self.fields:
                if field_name not in allowed_fields:
                    raise SearchFieldError(
                        f'Cannot search with field "{field_name}". '
                        f"Please add index.SearchField('{field_name}') to "
                        f"{model.__name__}.search_fields.",
                        field_name=field_name,
                    )

        # Check where clause
        # Raises FilterFieldError if an unindexed field is being filtered on
        self._get_filters_from_where_node(self.queryset.query.where, check_only=True)

        # Check order by
        # Raises OrderByFieldError if an unindexed field is being used to order by
        list(self._get_order_by())


class BaseSearchResults:
    """
    Lazily evaluated, sliceable search results.

    The search runs on first access to ``results()``; slicing returns a
    clone with narrowed ``start``/``stop`` limits. Subclasses implement
    ``_do_search`` and ``_do_count``.
    """

    supports_facet = False

    def __init__(self, backend, query_compiler, prefetch_related=None):
        self.backend = backend
        self.query_compiler = query_compiler
        self.prefetch_related = prefetch_related
        self.start = 0
        self.stop = None
        self._results_cache = None
        self._count_cache = None
        self._score_field = None
        # Attach the model to mimic a QuerySet so that we can inspect it after
        # doing a search, e.g. to get the model's name in a paginator.
        # The query_compiler may be None, e.g. when using EmptySearchResults.
        self.model = query_compiler.queryset.model if query_compiler else None

    def _set_limits(self, start=None, stop=None):
        # Limits are relative to the current window and never widen it.
        if stop is not None:
            if self.stop is not None:
                self.stop = min(self.stop, self.start + stop)
            else:
                self.stop = self.start + stop

        if start is not None:
            if self.stop is not None:
                self.start = min(self.stop, self.start + start)
            else:
                self.start = self.start + start

    def _clone(self):
        klass = self.__class__
        new = klass(
            self.backend, self.query_compiler, prefetch_related=self.prefetch_related
        )
        new.start = self.start
        new.stop = self.stop
        new._score_field = self._score_field
        return new

    def _do_search(self):
        raise NotImplementedError

    def _do_count(self):
        raise NotImplementedError

    def results(self):
        if self._results_cache is None:
            self._results_cache = list(self._do_search())
        return self._results_cache

    def count(self):
        if self._count_cache is None:
            if self._results_cache is not None:
                self._count_cache = len(self._results_cache)
            else:
                self._count_cache = self._do_count()
        return self._count_cache

    def __getitem__(self, key):
        new = self._clone()

        if isinstance(key, slice):
            # Set limits
            start = int(key.start) if key.start is not None else None
            stop = int(key.stop) if key.stop is not None else None
            new._set_limits(start, stop)

            # Copy results cache
            if self._results_cache is not None:
                new._results_cache = self._results_cache[key]

            return new

        if self._results_cache is not None:
            return self._results_cache[key]

        # Fetch just the one requested result.
        new.start = self.start + key
        new.stop = self.start + key + 1
        return list(new)[0]

    def __iter__(self):
        return iter(self.results())

    def __len__(self):
        return len(self.results())

    def __repr__(self):
        data = list(self[:21])
        if len(data) > 20:
            data[-1] = "...(remaining elements truncated)..."
        # Previously ``"" % data``, which silently produced an empty repr.
        return f"<SearchResults {data!r}>"

    def annotate_score(self, field_name):
        clone = self._clone()
        clone._score_field = field_name
        return clone

    def facet(self, field_name):
        raise NotImplementedError("This search backend does not support faceting")


class EmptySearchResults(BaseSearchResults):
    """Search results that always contain nothing and count as zero."""

    def __init__(self):
        super().__init__(None, None)

    def _clone(self):
        return self.__class__()

    def _do_search(self):
        return []

    def _do_count(self):
        return 0


class NullIndex:
    """
    Index class that provides do-nothing implementations of the indexing operations required by BaseSearchBackend.
    Use this for search backends that do not maintain an index, such as the database backend.
    """

    def add_model(self, model):
        pass

    def refresh(self):
        pass

    def add_item(self, item):
        pass

    def add_items(self, model, items):
        pass

    def delete_item(self, item):
        pass


class BaseSearchBackend:
    """
    Base class for search backends.

    Subclasses supply ``query_compiler_class`` and ``results_class`` (and
    optionally ``autocomplete_query_compiler_class``); indexing operations
    are delegated to the index returned by ``get_index_for_model``.
    """

    query_compiler_class = None
    autocomplete_query_compiler_class = None
    results_class = None
    rebuilder_class = None
    catch_indexing_errors = False

    def __init__(self, params):
        pass

    def get_index_for_model(self, model):
        return NullIndex()

    def get_rebuilder(self):
        return None

    def reset_index(self):
        raise NotImplementedError

    def add_type(self, model):
        self.get_index_for_model(model).add_model(model)

    def refresh_index(self):
        # Several models may share one index; refresh each index once.
        # A list (equality-based membership) is used rather than a set,
        # since index objects are not assumed to be hashable.
        refreshed_indexes = []
        for model in get_indexed_models():
            index = self.get_index_for_model(model)
            if index not in refreshed_indexes:
                index.refresh()
                refreshed_indexes.append(index)

    def add(self, obj):
        self.get_index_for_model(type(obj)).add_item(obj)

    def add_bulk(self, model, obj_list):
        self.get_index_for_model(model).add_items(model, obj_list)

    def delete(self, obj):
        self.get_index_for_model(type(obj)).delete_item(obj)

    def _search(self, query_compiler_class, query, model_or_queryset, **kwargs):
        """
        Build and validate a query compiler, then wrap it in ``results_class``.

        Returns EmptySearchResults when the model is not indexed or the query
        is the empty string. Field validation errors from the compiler's
        ``check()`` propagate to the caller.
        """
        # Find model/queryset
        if isinstance(model_or_queryset, QuerySet):
            model = model_or_queryset.model
            queryset = model_or_queryset
        else:
            model = model_or_queryset
            queryset = model_or_queryset.objects.all()

        # Model must be a class that is in the index
        if not class_is_indexed(model):
            return EmptySearchResults()

        # An empty query string has nothing to search for
        if query == "":
            return EmptySearchResults()

        # Search
        search_query_compiler = query_compiler_class(queryset, query, **kwargs)

        # Check the query
        search_query_compiler.check()

        return self.results_class(self, search_query_compiler)

    def search(
        self,
        query,
        model_or_queryset,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        return self._search(
            self.query_compiler_class,
            query,
            model_or_queryset,
            fields=fields,
            operator=operator,
            order_by_relevance=order_by_relevance,
        )

    def autocomplete(
        self,
        query,
        model_or_queryset,
        fields=None,
        operator=None,
        order_by_relevance=True,
    ):
        if self.autocomplete_query_compiler_class is None:
            raise NotImplementedError(
                "This search backend does not support the autocomplete API"
            )

        return self._search(
            self.autocomplete_query_compiler_class,
            query,
            model_or_queryset,
            fields=fields,
            operator=operator,
            order_by_relevance=order_by_relevance,
        )


def get_model_root(model):
    """
    This function finds the root model for any given model. The root model is
    the highest concrete model that it descends from. If the model doesn't
    descend from another concrete model then the model is its own root model so
    it is returned.

    At each step the first entry of ``model._meta.parents`` is followed.

    Examples:
    >>> get_model_root(wagtailcore.Page)
    wagtailcore.Page

    >>> get_model_root(myapp.HomePage)
    wagtailcore.Page

    >>> get_model_root(wagtailimages.Image)
    wagtailimages.Image
    """
    while model._meta.parents:
        model = next(iter(model._meta.parents))
    return model