summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py49
1 files changed, 44 insertions, 5 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..040b126a27 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Optional, Tuple, Type, TypeVar, Union
+from typing import Optional, Tuple, Type, TypeVar, Union, overload
 
 from mock import Mock, patch
 
@@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase):
         # create a site to wrap the resource.
         self.site = SynapseSite(
             logger_name="synapse.access.http.fake",
-            site_tag="test",
+            site_tag=self.hs.config.server.server_name,
             config=self.hs.config.server.listeners[0],
             resource=self.resource,
             server_version_string="1",
@@ -254,17 +254,24 @@ class HomeserverTestCase(TestCase):
         if hasattr(self, "user_id"):
             if self.hijack_auth:
 
+                # We need a valid token ID to satisfy foreign key constraints.
+                token_id = self.get_success(
+                    self.hs.get_datastore().add_access_token_to_user(
+                        self.helper.auth_user_id, "some_fake_token", None, None,
+                    )
+                )
+
                 async def get_user_by_access_token(token=None, allow_guest=False):
                     return {
                         "user": UserID.from_string(self.helper.auth_user_id),
-                        "token_id": 1,
+                        "token_id": token_id,
                         "is_guest": False,
                     }
 
                 async def get_user_by_req(request, allow_guest=False, rights="access"):
                     return create_requester(
                         UserID.from_string(self.helper.auth_user_id),
-                        1,
+                        token_id,
                         False,
                         False,
                         None,
@@ -357,6 +364,36 @@ class HomeserverTestCase(TestCase):
         Function to optionally be overridden in subclasses.
         """
 
+    # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
+    # when the `request` arg isn't given, so we define an explicit override to
+    # cover that case.
+    @overload
+    def make_request(
+        self,
+        method: Union[bytes, str],
+        path: Union[bytes, str],
+        content: Union[bytes, dict] = b"",
+        access_token: Optional[str] = None,
+        shorthand: bool = True,
+        federation_auth_origin: str = None,
+        content_is_form: bool = False,
+    ) -> Tuple[SynapseRequest, FakeChannel]:
+        ...
+
+    @overload
+    def make_request(
+        self,
+        method: Union[bytes, str],
+        path: Union[bytes, str],
+        content: Union[bytes, dict] = b"",
+        access_token: Optional[str] = None,
+        request: Type[T] = SynapseRequest,
+        shorthand: bool = True,
+        federation_auth_origin: str = None,
+        content_is_form: bool = False,
+    ) -> Tuple[T, FakeChannel]:
+        ...
+
     def make_request(
         self,
         method: Union[bytes, str],
@@ -608,7 +645,9 @@ class HomeserverTestCase(TestCase):
         if soft_failed:
             event.internal_metadata.soft_failed = True
 
-        self.get_success(event_creator.send_nonmember_event(requester, event, context))
+        self.get_success(
+            event_creator.handle_new_client_event(requester, event, context)
+        )
 
         return event.event_id