|
| 1 | +import threading |
| 2 | + |
| 3 | +from asgiref.sync import async_to_sync |
| 4 | + |
1 | 5 | from django.contrib.sessions.middleware import SessionMiddleware
|
| 6 | +from django.db import connection |
| 7 | +from django.http.request import HttpRequest |
| 8 | +from django.http.response import HttpResponse |
2 | 9 | from django.middleware.cache import (
|
3 | 10 | CacheMiddleware, FetchFromCacheMiddleware, UpdateCacheMiddleware,
|
4 | 11 | )
|
5 | 12 | from django.middleware.common import CommonMiddleware
|
6 | 13 | from django.middleware.security import SecurityMiddleware
|
7 | 14 | from django.test import SimpleTestCase
|
8 |
| -from django.utils.deprecation import RemovedInDjango40Warning |
| 15 | +from django.utils.deprecation import MiddlewareMixin, RemovedInDjango40Warning |
9 | 16 |
|
10 | 17 |
|
11 | 18 | class MiddlewareMixinTests(SimpleTestCase):
|
@@ -37,3 +44,37 @@ def test_subclass_deprecation(self):
|
37 | 44 | with self.subTest(middleware=middleware):
|
38 | 45 | with self.assertRaisesMessage(RemovedInDjango40Warning, self.msg):
|
39 | 46 | middleware()
|
| 47 | + |
| 48 | + def test_sync_to_async_uses_base_thread_and_connection(self): |
| 49 | + """ |
| 50 | + The process_request() and process_response() hooks must be called with |
| 51 | + the sync_to_async thread_sensitive flag enabled, so that database |
| 52 | + operations use the correct thread and connection. |
| 53 | + """ |
| 54 | + def request_lifecycle(): |
| 55 | + """Fake request_started/request_finished.""" |
| 56 | + return (threading.get_ident(), id(connection)) |
| 57 | + |
| 58 | + async def get_response(self): |
| 59 | + return HttpResponse() |
| 60 | + |
| 61 | + class SimpleMiddleWare(MiddlewareMixin): |
| 62 | + def process_request(self, request): |
| 63 | + request.thread_and_connection = request_lifecycle() |
| 64 | + |
| 65 | + def process_response(self, request, response): |
| 66 | + response.thread_and_connection = request_lifecycle() |
| 67 | + return response |
| 68 | + |
| 69 | + threads_and_connections = [] |
| 70 | + threads_and_connections.append(request_lifecycle()) |
| 71 | + |
| 72 | + request = HttpRequest() |
| 73 | + response = async_to_sync(SimpleMiddleWare(get_response))(request) |
| 74 | + threads_and_connections.append(request.thread_and_connection) |
| 75 | + threads_and_connections.append(response.thread_and_connection) |
| 76 | + |
| 77 | + threads_and_connections.append(request_lifecycle()) |
| 78 | + |
| 79 | + self.assertEqual(len(threads_and_connections), 4) |
| 80 | + self.assertEqual(len(set(threads_and_connections)), 1) |
0 commit comments