summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15542.misc1
-rw-r--r--synapse/api/auth_blocking.py4
-rw-r--r--synapse/crypto/keyring.py4
-rw-r--r--synapse/federation/federation_base.py2
-rw-r--r--synapse/federation/federation_client.py4
-rw-r--r--synapse/federation/federation_server.py3
-rw-r--r--synapse/federation/send_queue.py3
-rw-r--r--synapse/federation/sender/__init__.py11
-rw-r--r--synapse/federation/transport/client.py4
-rw-r--r--synapse/federation/transport/server/_base.py5
-rw-r--r--synapse/handlers/event_auth.py5
-rw-r--r--synapse/handlers/federation.py3
-rw-r--r--synapse/handlers/federation_event.py3
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/handlers/sso.py3
-rw-r--r--synapse/handlers/typing.py3
-rw-r--r--synapse/rest/admin/media.py4
-rw-r--r--synapse/rest/client/room.py4
-rw-r--r--synapse/rest/media/download_resource.py4
-rw-r--r--synapse/rest/media/thumbnail_resource.py4
-rw-r--r--synapse/server.py4
-rw-r--r--synapse/storage/databases/main/room.py2
-rw-r--r--tests/unittest.py16
23 files changed, 64 insertions, 36 deletions
diff --git a/changelog.d/15542.misc b/changelog.d/15542.misc
new file mode 100644
index 0000000000..32e3d678a1
--- /dev/null
+++ b/changelog.d/15542.misc
@@ -0,0 +1 @@
+Factor out an `is_mine_server_name` method.
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 22348d2d86..fcf5b842c6 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -39,7 +39,7 @@ class AuthBlocking:
         self._mau_limits_reserved_threepids = (
             hs.config.server.mau_limits_reserved_threepids
         )
-        self._server_name = hs.hostname
+        self._is_mine_server_name = hs.is_mine_server_name
         self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
 
     async def check_auth_blocking(
@@ -77,7 +77,7 @@ class AuthBlocking:
         if requester:
             if requester.authenticated_entity.startswith("@"):
                 user_id = requester.authenticated_entity
-            elif requester.authenticated_entity == self._server_name:
+            elif self._is_mine_server_name(requester.authenticated_entity):
                 # We never block the server from doing actions on behalf of
                 # users.
                 return
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index afdf6863d6..260aab3241 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -173,7 +173,7 @@ class Keyring:
             process_batch_callback=self._inner_fetch_key_requests,
         )
 
-        self._hostname = hs.hostname
+        self._is_mine_server_name = hs.is_mine_server_name
 
         # build a FetchKeyResult for each of our own keys, to shortcircuit the
         # fetcher.
@@ -277,7 +277,7 @@ class Keyring:
 
         # If we are the originating server, short-circuit the key-fetch for any keys
         # we already have
-        if verify_request.server_name == self._hostname:
+        if self._is_mine_server_name(verify_request.server_name):
             for key_id in verify_request.key_ids:
                 if key_id in self._local_verify_keys:
                     found_keys[key_id] = self._local_verify_keys[key_id]
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 3df975958d..b77022b406 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -49,7 +49,7 @@ class FederationBase:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
 
-        self.server_name = hs.hostname
+        self._is_mine_server_name = hs.is_mine_server_name
         self.keyring = hs.get_keyring()
         self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self.store = hs.get_datastores().main
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 0b2d1a78f7..076b9287c6 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -854,7 +854,7 @@ class FederationClient(FederationBase):
 
         for destination in destinations:
             # We don't want to ask our own server for information we don't have
-            if destination == self.server_name:
+            if self._is_mine_server_name(destination):
                 continue
 
             try:
@@ -1536,7 +1536,7 @@ class FederationClient(FederationBase):
         self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
     ) -> None:
         for destination in destinations:
-            if destination == self.server_name:
+            if self._is_mine_server_name(destination):
                 continue
 
             try:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ca43c7bfc0..c590d8f96f 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -129,6 +129,7 @@ class FederationServer(FederationBase):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
+        self.server_name = hs.hostname
         self.handler = hs.get_federation_handler()
         self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self._federation_event_handler = hs.get_federation_event_handler()
@@ -942,7 +943,7 @@ class FederationServer(FederationBase):
             authorising_server = get_domain_from_id(
                 event.content[EventContentFields.AUTHORISING_USER]
             )
-            if authorising_server != self.server_name:
+            if not self._is_mine_server_name(authorising_server):
                 raise SynapseError(
                     400,
                     f"Cannot authorise request from resident server: {authorising_server}",
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 0b7c81677e..fb448f2155 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -68,6 +68,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
+        self.is_mine_server_name = hs.is_mine_server_name
 
         # We may have multiple federation sender instances, so we need to track
         # their positions separately.
@@ -198,7 +199,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
         key: Optional[Hashable] = None,
     ) -> None:
         """As per FederationSender"""
-        if destination == self.server_name:
+        if self.is_mine_server_name(destination):
             logger.info("Not sending EDU to ourselves")
             return
 
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index edc4b1768c..f3bdc5a4d2 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -362,6 +362,7 @@ class FederationSender(AbstractFederationSender):
 
         self.clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
+        self.is_mine_server_name = hs.is_mine_server_name
 
         self._presence_router: Optional["PresenceRouter"] = None
         self._transaction_manager = TransactionManager(hs)
@@ -766,7 +767,7 @@ class FederationSender(AbstractFederationSender):
         domains = [
             d
             for d in domains_set
-            if d != self.server_name
+            if not self.is_mine_server_name(d)
             and self._federation_shard_config.should_handle(self._instance_name, d)
         ]
         if not domains:
@@ -832,7 +833,7 @@ class FederationSender(AbstractFederationSender):
             assert self.is_mine_id(state.user_id)
 
         for destination in destinations:
-            if destination == self.server_name:
+            if self.is_mine_server_name(destination):
                 continue
             if not self._federation_shard_config.should_handle(
                 self._instance_name, destination
@@ -860,7 +861,7 @@ class FederationSender(AbstractFederationSender):
             content: content of EDU
             key: clobbering key for this edu
         """
-        if destination == self.server_name:
+        if self.is_mine_server_name(destination):
             logger.info("Not sending EDU to ourselves")
             return
 
@@ -897,7 +898,7 @@ class FederationSender(AbstractFederationSender):
             queue.send_edu(edu)
 
     def send_device_messages(self, destination: str, immediate: bool = True) -> None:
-        if destination == self.server_name:
+        if self.is_mine_server_name(destination):
             logger.warning("Not sending device update to ourselves")
             return
 
@@ -919,7 +920,7 @@ class FederationSender(AbstractFederationSender):
         might have come back.
         """
 
-        if destination == self.server_name:
+        if self.is_mine_server_name(destination):
             logger.warning("Not waking up ourselves")
             return
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index bc70b94f68..d2fa9976da 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -58,9 +58,9 @@ class TransportLayerClient:
     """Sends federation HTTP requests to other servers"""
 
     def __init__(self, hs: "HomeServer"):
-        self.server_name = hs.hostname
         self.client = hs.get_federation_http_client()
         self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
+        self._is_mine_server_name = hs.is_mine_server_name
 
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
@@ -235,7 +235,7 @@ class TransportLayerClient:
             transaction.transaction_id,
         )
 
-        if transaction.destination == self.server_name:
+        if self._is_mine_server_name(transaction.destination):
             raise RuntimeError("Transport layer cannot send to itself!")
 
         # FIXME: This is only used by the tests. The actual json sent is
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index cdaf0d5de7..b6e9c58760 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -57,6 +57,7 @@ class Authenticator:
         self._clock = hs.get_clock()
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
+        self._is_mine_server_name = hs.is_mine_server_name
         self.store = hs.get_datastores().main
         self.federation_domain_whitelist = (
             hs.config.federation.federation_domain_whitelist
@@ -100,7 +101,9 @@ class Authenticator:
                 json_request["signatures"].setdefault(origin, {})[key] = sig
 
                 # if the origin_server sent a destination along it needs to match our own server_name
-                if destination is not None and destination != self.server_name:
+                if destination is not None and not self._is_mine_server_name(
+                    destination
+                ):
                     raise AuthenticationError(
                         HTTPStatus.UNAUTHORIZED,
                         "Destination mismatch in auth header",
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 0db0bd7304..3e37c0cbe2 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -29,7 +29,7 @@ from synapse.event_auth import (
 )
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
-from synapse.types import StateMap, StrCollection, get_domain_from_id
+from synapse.types import StateMap, StrCollection
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -47,6 +47,7 @@ class EventAuthHandler:
         self._store = hs.get_datastores().main
         self._state_storage_controller = hs.get_storage_controllers().state
         self._server_name = hs.hostname
+        self._is_mine_id = hs.is_mine_id
 
     async def check_auth_rules_from_context(
         self,
@@ -247,7 +248,7 @@ class EventAuthHandler:
         if not await self.is_user_in_rooms(allowed_rooms, user_id):
             # If this is a remote request, the user might be in an allowed room
             # that we do not know about.
-            if get_domain_from_id(user_id) != self._server_name:
+            if not self._is_mine_id(user_id):
                 for room_id in allowed_rooms:
                     if not await self._store.is_host_joined(room_id, self._server_name):
                         raise SynapseError(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4ad808a5b4..19dec4812f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -141,6 +141,7 @@ class FederationHandler:
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
         self.is_mine_id = hs.is_mine_id
+        self.is_mine_server_name = hs.is_mine_server_name
         self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
         self.event_creation_handler = hs.get_event_creation_handler()
         self.event_builder_factory = hs.get_event_builder_factory()
@@ -453,7 +454,7 @@ class FederationHandler:
 
             for dom in domains:
                 # We don't want to ask our own server for information we don't have
-                if dom == self.server_name:
+                if self.is_mine_server_name(dom):
                     continue
 
                 try:
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index fc15024166..06343d40e4 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -163,6 +163,7 @@ class FederationEventHandler:
         self._notifier = hs.get_notifier()
 
         self._is_mine_id = hs.is_mine_id
+        self._is_mine_server_name = hs.is_mine_server_name
         self._server_name = hs.hostname
         self._instance_name = hs.get_instance_name()
 
@@ -688,7 +689,7 @@ class FederationEventHandler:
         server from invalid events (there is probably no point in trying to
         re-fetch invalid events from every other HS in the room.)
         """
-        if dest == self._server_name:
+        if self._is_mine_server_name(dest):
             raise SynapseError(400, "Can't backfill from self.")
 
         events = await self._federation_client.backfill(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 983b9b66fb..48f9858931 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -59,7 +59,7 @@ class ProfileHandler:
         self.max_avatar_size = hs.config.server.max_avatar_size
         self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes
 
-        self.server_name = hs.config.server.server_name
+        self._is_mine_server_name = hs.is_mine_server_name
 
         self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
 
@@ -309,7 +309,7 @@ class ProfileHandler:
         else:
             server_name = host
 
-        if server_name == self.server_name:
+        if self._is_mine_server_name(server_name):
             media_info = await self.store.get_local_media(media_id)
         else:
             media_info = await self.store.get_cached_remote_media(server_name, media_id)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index c28325323c..92c3742625 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -194,6 +194,7 @@ class SsoHandler:
         self._clock = hs.get_clock()
         self._store = hs.get_datastores().main
         self._server_name = hs.hostname
+        self._is_mine_server_name = hs.is_mine_server_name
         self._registration_handler = hs.get_registration_handler()
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
@@ -802,7 +803,7 @@ class SsoHandler:
             if profile["avatar_url"] is not None:
                 server_name = profile["avatar_url"].split("/")[-2]
                 media_id = profile["avatar_url"].split("/")[-1]
-                if server_name == self._server_name:
+                if self._is_mine_server_name(server_name):
                     media = await self._media_repo.store.get_local_media(media_id)
                     if media is not None and upload_name == media["upload_name"]:
                         logger.info("skipping saving the user avatar")
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 39ae44ea95..7aeae5319c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -68,6 +68,7 @@ class FollowerTypingHandler:
         self.server_name = hs.config.server.server_name
         self.clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
+        self.is_mine_server_name = hs.is_mine_server_name
 
         self.federation = None
         if hs.should_send_federation():
@@ -153,7 +154,7 @@ class FollowerTypingHandler:
                 member.room_id
             )
             for domain in hosts:
-                if domain != self.server_name:
+                if not self.is_mine_server_name(domain):
                     logger.debug("sending typing update to %s", domain)
                     self.federation.build_and_send_edu(
                         destination=domain,
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c134ccfb3d..b7637dff0b 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -258,7 +258,7 @@ class DeleteMediaByID(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.auth = hs.get_auth()
-        self.server_name = hs.hostname
+        self._is_mine_server_name = hs.is_mine_server_name
         self.media_repository = hs.get_media_repository()
 
     async def on_DELETE(
@@ -266,7 +266,7 @@ class DeleteMediaByID(RestServlet):
     ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        if self.server_name != server_name:
+        if not self._is_mine_server_name(server_name):
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
 
         if await self.store.get_local_media(media_id) is None:
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 7699cc8d1b..951bd033f5 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -501,7 +501,7 @@ class PublicRoomListRestServlet(RestServlet):
             limit = None
 
         handler = self.hs.get_room_list_handler()
-        if server and server != self.hs.config.server.server_name:
+        if server and not self.hs.is_mine_server_name(server):
             # Ensure the server is valid.
             try:
                 parse_and_validate_server_name(server)
@@ -551,7 +551,7 @@ class PublicRoomListRestServlet(RestServlet):
             limit = None
 
         handler = self.hs.get_room_list_handler()
-        if server and server != self.hs.config.server.server_name:
+        if server and not self.hs.is_mine_server_name(server):
             # Ensure the server is valid.
             try:
                 parse_and_validate_server_name(server)
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 8f270cf4cc..3c618ef60a 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -37,7 +37,7 @@ class DownloadResource(DirectServeJsonResource):
     def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
         super().__init__()
         self.media_repo = media_repo
-        self.server_name = hs.hostname
+        self._is_mine_server_name = hs.is_mine_server_name
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
         set_cors_headers(request)
@@ -59,7 +59,7 @@ class DownloadResource(DirectServeJsonResource):
             b"no-referrer",
         )
         server_name, media_id, name = parse_media_id(request)
-        if server_name == self.server_name:
+        if self._is_mine_server_name(server_name):
             await self.media_repo.get_local_media(request, media_id, name)
         else:
             allow_remote = parse_boolean(request, "allow_remote", default=True)
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 4ee2a0dbda..a6396fb05a 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -59,7 +59,7 @@ class ThumbnailResource(DirectServeJsonResource):
         self.media_repo = media_repo
         self.media_storage = media_storage
         self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
-        self.server_name = hs.hostname
+        self._is_mine_server_name = hs.is_mine_server_name
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
         set_cors_headers(request)
@@ -71,7 +71,7 @@ class ThumbnailResource(DirectServeJsonResource):
         # TODO Parse the Accept header to get an prioritised list of thumbnail types.
         m_type = "image/png"
 
-        if server_name == self.server_name:
+        if self._is_mine_server_name(server_name):
             if self.dynamic_thumbnails:
                 await self._select_or_generate_local_thumbnail(
                     request, media_id, width, height, method, m_type
diff --git a/synapse/server.py b/synapse/server.py
index c557c60482..fd29c28173 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -377,6 +377,10 @@ class HomeServer(metaclass=abc.ABCMeta):
             return False
         return localpart_hostname[1] == self.hostname
 
+    def is_mine_server_name(self, server_name: str) -> bool:
+        """Determines whether a server name refers to this homeserver."""
+        return server_name == self.hostname
+
     @cache_in_self
     def get_clock(self) -> Clock:
         return Clock(self._reactor)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index dd7dbb6901..ca8be8c80d 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -996,7 +996,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 If it is `None` media will be removed from quarantine
         """
         logger.info("Quarantining media: %s/%s", server_name, media_id)
-        is_local = server_name == self.config.server.server_name
+        is_local = self.hs.is_mine_server_name(server_name)
 
         def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
             local_mxcs = [media_id] if is_local else []
diff --git a/tests/unittest.py b/tests/unittest.py
index ee2f78ab01..b6fdf69635 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -566,7 +566,9 @@ class HomeserverTestCase(TestCase):
             client_ip,
         )
 
-    def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
+    def setup_test_homeserver(
+        self, name: Optional[str] = None, **kwargs: Any
+    ) -> HomeServer:
         """
         Set up the test homeserver, meant to be called by the overridable
         make_homeserver. It automatically passes through the test class's
@@ -585,15 +587,25 @@ class HomeserverTestCase(TestCase):
         else:
             config = kwargs["config"]
 
+        # The server name can be specified using either the `name` argument or a config
+        # override. The `name` argument takes precedence over any config overrides.
+        if name is not None:
+            config["server_name"] = name
+
         # Parse the config from a config dict into a HomeServerConfig
         config_obj = make_homeserver_config_obj(config)
         kwargs["config"] = config_obj
 
+        # The server name in the config is now `name`, if provided, or the `server_name`
+        # from a config override, or the default of "test". Whichever it is, we
+        # construct a homeserver with a matching name.
+        kwargs["name"] = config_obj.server.server_name
+
         async def run_bg_updates() -> None:
             with LoggingContext("run_bg_updates"):
                 self.get_success(stor.db_pool.updates.run_background_updates(False))
 
-        hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+        hs = setup_test_homeserver(self.addCleanup, **kwargs)
         stor = hs.get_datastores().main
 
         # Run the database background updates, when running against "master".