angrybeanie_wagtail/env/lib/python3.12/site-packages/wagtail/tests/test_hooks.py
2025-07-25 21:32:16 +10:00

201 lines
6.8 KiB
Python

from unittest import mock
from django.contrib.sessions.middleware import SessionMiddleware
from django.http import HttpResponse
from django.test import RequestFactory, TestCase
from wagtail import hooks
from wagtail.models import Page, PageViewRestriction
from wagtail.test.utils import WagtailTestUtils
from wagtail.views import serve, serve_chain
from wagtail.wagtail_hooks import check_view_restrictions
def test_hook():
    """No-op hook; TestLoginView registers it under ``test_hook_name``."""
class TestLoginView(WagtailTestUtils, TestCase):
    """Check that ``hooks.get_hooks`` orders hooks by their ``order`` value.

    ``test_hook`` is registered once for the whole class under
    ``test_hook_name`` with the default order. Each test temporarily adds a
    second hook that must sort before or after it.
    """

    fixtures = ["test.json"]

    @classmethod
    def setUpClass(cls):
        # Django's SimpleTestCase/TestCase perform class-wide initialisation
        # (including fixture loading) in setUpClass; overriding it without
        # calling super() silently skips that work.
        super().setUpClass()
        # Register only after setup succeeded: unittest does not call
        # tearDownClass when setUpClass raises, so registering earlier could
        # leak the hook into other test classes.
        hooks.register("test_hook_name", test_hook)

    @classmethod
    def tearDownClass(cls):
        # Drop the class-level hook first so it cannot leak into other test
        # classes, and always hand back to Django's class teardown.
        try:
            del hooks._hooks["test_hook_name"]
        finally:
            super().tearDownClass()

    def test_before_hook(self):
        def before_hook():
            pass

        # A negative order must sort ahead of test_hook's default order.
        with self.register_hook("test_hook_name", before_hook, order=-1):
            hook_fns = hooks.get_hooks("test_hook_name")
            self.assertEqual(hook_fns, [before_hook, test_hook])

    def test_after_hook(self):
        def after_hook():
            pass

        # A positive order must sort after test_hook's default order.
        with self.register_hook("test_hook_name", after_hook, order=1):
            hook_fns = hooks.get_hooks("test_hook_name")
            self.assertEqual(hook_fns, [test_hook, after_hook])
class TestServeHooks(WagtailTestUtils, TestCase):
    """Tests for the ``on_serve_page`` hook chain driven by ``wagtail.views.serve``.

    An ``on_serve_page`` hook receives the next callable in the chain
    (``next_fn``) and returns a wrapper taking ``(page, request, *args, **kwargs)``.
    The wrapper may call ``next_fn`` or return a response of its own.
    """

    fixtures = ["test.json"]

    def setUp(self):
        self.page = Page.objects.get(id=2)
        self.request = RequestFactory().get("/test/")
        # RequestFactory requests carry no session; run the session middleware
        # by hand so code touching ``request.session`` works.
        middleware = SessionMiddleware(lambda x: None)
        middleware.process_request(self.request)
        self.request.session.save()

    @staticmethod
    def _make_recording_hook(calls, marker):
        """Return an ``on_serve_page`` hook that appends *marker* to *calls*.

        The hook's wrapper records the marker, then delegates unchanged to the
        next function in the chain.
        """

        def hook(next_fn):
            def wrapper(page, request, *args, **kwargs):
                calls.append(marker)
                return next_fn(page, request, *args, **kwargs)

            return wrapper

        return hook

    def test_serve_chain_order(self):
        """Hooks run in the order they were registered."""
        order_calls = []
        hook_1 = self._make_recording_hook(order_calls, 1)
        hook_2 = self._make_recording_hook(order_calls, 2)
        hook_3 = self._make_recording_hook(order_calls, 3)
        with self.register_hook("on_serve_page", hook_1):
            with self.register_hook("on_serve_page", hook_2):
                with self.register_hook("on_serve_page", hook_3):
                    serve(self.request, self.page.url)
        self.assertEqual(order_calls, [1, 2, 3])

    def test_serve_chain_modification(self):
        """A hook can post-process the response returned further down the chain."""

        def hook_modifier(next_fn):
            def wrapper(page, request, *args, **kwargs):
                response = next_fn(page, request, *args, **kwargs)
                response.content = b"Modified content"
                return response

            return wrapper

        with self.register_hook("on_serve_page", hook_modifier):
            response = serve(self.request, self.page.url)
        self.assertEqual(response.content, b"Modified content")

    def test_serve_chain_halt_execution(self):
        """A hook that never calls ``next_fn`` short-circuits the chain."""

        def hook_halt(next_fn):
            def wrapper(page, request, *args, **kwargs):
                return HttpResponse("Halted")

            return wrapper

        with self.register_hook("on_serve_page", hook_halt):
            response = serve(self.request, self.page.url)
        self.assertEqual(response.content, b"Halted")

    def test_serve_chain_view_restriction(self):
        """``check_view_restrictions`` renders the password form for a protected page."""
        restriction = PageViewRestriction.objects.create(
            page=self.page,
            restriction_type=PageViewRestriction.PASSWORD,
            password="password",
        )
        # Runs even when an assertion below fails.
        self.addCleanup(restriction.delete)
        with self.register_hook("on_serve_page", check_view_restrictions):
            response = self.client.get(self.page.url)
        self.assertEqual(response.status_code, 200)
        self.assertTemplateUsed(response, "wagtailcore/password_required.html")

    def test_serve_always_called_last(self):
        """``serve_chain`` runs exactly once, after every registered hook."""
        calls = []
        hook_1 = self._make_recording_hook(calls, 1)
        hook_2 = self._make_recording_hook(calls, 2)
        hook_3 = self._make_recording_hook(calls, 3)
        original_serve_chain = serve_chain

        def mock_serve_chain(page, request, *args, **kwargs):
            # Log into the same list as the hooks so the assertion checks the
            # relative position of serve_chain, not merely that it ran.
            calls.append("serve")
            return original_serve_chain(page, request, *args, **kwargs)

        with mock.patch("wagtail.views.serve_chain", mock_serve_chain):
            with self.register_hook("on_serve_page", hook_1):
                with self.register_hook("on_serve_page", hook_2):
                    with self.register_hook("on_serve_page", hook_3):
                        serve(self.request, self.page.url)
        self.assertEqual(calls, [1, 2, 3, "serve"])

    def test_check_view_restrictions_receives_correct_parameters(self):
        """Each hook wrapper receives the routed page, the request and the route args."""
        received_params = []

        def hook_spy(next_fn):
            def wrapper(page, request, *args, **kwargs):
                received_params.append(
                    {"page": page, "request": request, "args": args, "kwargs": kwargs}
                )
                return next_fn(page, request, *args, **kwargs)

            return wrapper

        with self.register_hook("on_serve_page", hook_spy):
            # Sanity check that the path routes to a page before serving it.
            self.assertIsNotNone(
                Page.route_for_request(self.request, self.page.url),
                "route_result should not be None",
            )
            serve(self.request, self.page.url)

        self.assertEqual(len(received_params), 1)
        params = received_params[0]
        self.assertIsNotNone(params["page"], "Hook received None as page")
        self.assertIsNotNone(params["request"], "Hook received None as request")
        self.assertEqual(params["page"], self.page)
        self.assertEqual(params["request"], self.request)
        # The routed args and kwargs reach the wrapper positionally, so they
        # arrive packed in *args and **kwargs stays empty.
        self.assertEqual(params["args"], ([], {}))
        self.assertEqual(params["kwargs"], {})