summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/11915.misc1
-rw-r--r--synapse/appservice/__init__.py133
-rw-r--r--synapse/handlers/appservice.py4
-rw-r--r--synapse/handlers/directory.py6
-rw-r--r--synapse/handlers/receipts.py2
-rw-r--r--synapse/handlers/typing.py4
-rw-r--r--tests/appservice/test_appservice.py45
-rw-r--r--tests/handlers/test_appservice.py56
8 files changed, 175 insertions, 76 deletions
diff --git a/changelog.d/11915.misc b/changelog.d/11915.misc
new file mode 100644
index 0000000000..e3cef1511e
--- /dev/null
+++ b/changelog.d/11915.misc
@@ -0,0 +1 @@
+Simplify the `ApplicationService` class' set of public methods related to interest checking.
\ No newline at end of file
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 4d3f8e4923..07ec95f1d6 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -175,27 +175,14 @@ class ApplicationService:
             return namespace.exclusive
         return False
 
-    async def _matches_user(self, event: EventBase, store: "DataStore") -> bool:
-        if self.is_interested_in_user(event.sender):
-            return True
-
-        # also check m.room.member state key
-        if event.type == EventTypes.Member and self.is_interested_in_user(
-            event.state_key
-        ):
-            return True
-
-        does_match = await self.matches_user_in_member_list(event.room_id, store)
-        return does_match
-
     @cached(num_args=1, cache_context=True)
-    async def matches_user_in_member_list(
+    async def _matches_user_in_member_list(
         self,
         room_id: str,
         store: "DataStore",
         cache_context: _CacheContext,
     ) -> bool:
-        """Check if this service is interested a room based upon it's membership
+        """Check if this service is interested a room based upon its membership
 
         Args:
             room_id: The room to check.
@@ -214,47 +201,110 @@ class ApplicationService:
                 return True
         return False
 
-    def _matches_room_id(self, event: EventBase) -> bool:
-        if hasattr(event, "room_id"):
-            return self.is_interested_in_room(event.room_id)
-        return False
+    def is_interested_in_user(
+        self,
+        user_id: str,
+    ) -> bool:
+        """
+        Returns whether the application is interested in a given user ID.
+
+        The appservice is considered to be interested in a user if either: the
+        user ID is in the appservice's user namespace, or if the user is the
+        appservice's configured sender_localpart.
+
+        Args:
+            user_id: The ID of the user to check.
+
+        Returns:
+            True if the application service is interested in the user, False if not.
+        """
+        return (
+            # User is the appservice's sender_localpart user
+            user_id == self.sender
+            # User is in the appservice's user namespace
+            or self.is_user_in_namespace(user_id)
+        )
+
+    @cached(num_args=1, cache_context=True)
+    async def is_interested_in_room(
+        self,
+        room_id: str,
+        store: "DataStore",
+        cache_context: _CacheContext,
+    ) -> bool:
+        """
+        Returns whether the application service is interested in a given room ID.
+
+        The appservice is considered to be interested in the room if either: the ID or one
+        of the aliases of the room is in the appservice's room ID or alias namespace
+        respectively, or if one of the members of the room fall into the appservice's user
+        namespace.
 
-    async def _matches_aliases(self, event: EventBase, store: "DataStore") -> bool:
-        alias_list = await store.get_aliases_for_room(event.room_id)
+        Args:
+            room_id: The ID of the room to check.
+            store: The homeserver's datastore class.
+
+        Returns:
+            True if the application service is interested in the room, False if not.
+        """
+        # Check if we have interest in this room ID
+        if self.is_room_id_in_namespace(room_id):
+            return True
+
+        # likewise with the room's aliases (if it has any)
+        alias_list = await store.get_aliases_for_room(room_id)
         for alias in alias_list:
-            if self.is_interested_in_alias(alias):
+            if self.is_room_alias_in_namespace(alias):
                 return True
 
-        return False
+        # And finally, perform an expensive check on whether any of the
+        # users in the room match the appservice's user namespace
+        return await self._matches_user_in_member_list(
+            room_id, store, on_invalidate=cache_context.invalidate
+        )
 
-    async def is_interested(self, event: EventBase, store: "DataStore") -> bool:
+    @cached(num_args=1, cache_context=True)
+    async def is_interested_in_event(
+        self,
+        event_id: str,
+        event: EventBase,
+        store: "DataStore",
+        cache_context: _CacheContext,
+    ) -> bool:
         """Check if this service is interested in this event.
 
         Args:
+            event_id: The ID of the event to check. This is purely used for simplifying the
+                caching of calls to this method.
             event: The event to check.
             store: The datastore to query.
 
         Returns:
-            True if this service would like to know about this event.
+            True if this service would like to know about this event, otherwise False.
         """
-        # Do cheap checks first
-        if self._matches_room_id(event):
+        # Check if we're interested in this event's sender by namespace (or if they're the
+        # sender_localpart user)
+        if self.is_interested_in_user(event.sender):
             return True
 
-        # This will check the namespaces first before
-        # checking the store, so should be run before _matches_aliases
-        if await self._matches_user(event, store):
+        # additionally, if this is a membership event, perform the same checks on
+        # the user it references
+        if event.type == EventTypes.Member and self.is_interested_in_user(
+            event.state_key
+        ):
             return True
 
-        # This will check the store, so should be run last
-        if await self._matches_aliases(event, store):
+        # This will check the datastore, so should be run last
+        if await self.is_interested_in_room(
+            event.room_id, store, on_invalidate=cache_context.invalidate
+        ):
             return True
 
         return False
 
-    @cached(num_args=1)
+    @cached(num_args=1, cache_context=True)
     async def is_interested_in_presence(
-        self, user_id: UserID, store: "DataStore"
+        self, user_id: UserID, store: "DataStore", cache_context: _CacheContext
     ) -> bool:
         """Check if this service is interested a user's presence
 
@@ -272,20 +322,19 @@ class ApplicationService:
 
         # Then find out if the appservice is interested in any of those rooms
         for room_id in room_ids:
-            if await self.matches_user_in_member_list(room_id, store):
+            if await self.is_interested_in_room(
+                room_id, store, on_invalidate=cache_context.invalidate
+            ):
                 return True
         return False
 
-    def is_interested_in_user(self, user_id: str) -> bool:
-        return (
-            bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
-            or user_id == self.sender
-        )
+    def is_user_in_namespace(self, user_id: str) -> bool:
+        return bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
 
-    def is_interested_in_alias(self, alias: str) -> bool:
+    def is_room_alias_in_namespace(self, alias: str) -> bool:
         return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
 
-    def is_interested_in_room(self, room_id: str) -> bool:
+    def is_room_id_in_namespace(self, room_id: str) -> bool:
         return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
 
     def is_exclusive_user(self, user_id: str) -> bool:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index e6461cc3c9..bd913e524e 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -571,7 +571,7 @@ class ApplicationServicesHandler:
         room_alias_str = room_alias.to_string()
         services = self.store.get_app_services()
         alias_query_services = [
-            s for s in services if (s.is_interested_in_alias(room_alias_str))
+            s for s in services if (s.is_room_alias_in_namespace(room_alias_str))
         ]
         for alias_service in alias_query_services:
             is_known_alias = await self.appservice_api.query_alias(
@@ -660,7 +660,7 @@ class ApplicationServicesHandler:
         # inside of a list comprehension anymore.
         interested_list = []
         for s in services:
-            if await s.is_interested(event, self.store):
+            if await s.is_interested_in_event(event.event_id, event, self.store):
                 interested_list.append(s)
 
         return interested_list
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index b7064c6624..33d827a45b 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -119,7 +119,7 @@ class DirectoryHandler:
 
         service = requester.app_service
         if service:
-            if not service.is_interested_in_alias(room_alias_str):
+            if not service.is_room_alias_in_namespace(room_alias_str):
                 raise SynapseError(
                     400,
                     "This application service has not reserved this kind of alias.",
@@ -221,7 +221,7 @@ class DirectoryHandler:
     async def delete_appservice_association(
         self, service: ApplicationService, room_alias: RoomAlias
     ) -> None:
-        if not service.is_interested_in_alias(room_alias.to_string()):
+        if not service.is_room_alias_in_namespace(room_alias.to_string()):
             raise SynapseError(
                 400,
                 "This application service has not reserved this kind of alias",
@@ -376,7 +376,7 @@ class DirectoryHandler:
         # non-exclusive locks on the alias (or there are no interested services)
         services = self.store.get_app_services()
         interested_services = [
-            s for s in services if s.is_interested_in_alias(alias.to_string())
+            s for s in services if s.is_room_alias_in_namespace(alias.to_string())
         ]
 
         for service in interested_services:
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index b4132c353a..6250bb3bdf 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -269,7 +269,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         # Then filter down to rooms that the AS can read
         events = []
         for room_id, event in rooms_to_events.items():
-            if not await service.matches_user_in_member_list(room_id, self.store):
+            if not await service.is_interested_in_room(room_id, self.store):
                 continue
 
             events.append(event)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 843c68eb0f..3b89126528 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -486,9 +486,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
                 if handler._room_serials[room_id] <= from_key:
                     continue
 
-                if not await service.matches_user_in_member_list(
-                    room_id, self._main_store
-                ):
+                if not await service.is_interested_in_room(room_id, self._main_store):
                     continue
 
                 events.append(self._make_event_for(room_id))
diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 9bd6275e92..edc584d0cf 100644
--- a/tests/appservice/test_appservice.py
+++ b/tests/appservice/test_appservice.py
@@ -36,7 +36,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
             hostname="matrix.org",  # only used by get_groups_for_user
         )
         self.event = Mock(
-            type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
+            event_id="$abc:xyz",
+            type="m.something",
+            room_id="!foo:bar",
+            sender="@someone:somewhere",
         )
 
         self.store = Mock()
@@ -50,7 +53,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -62,7 +67,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -76,7 +83,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -90,7 +99,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -104,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -121,7 +134,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -174,7 +189,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -191,7 +208,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -207,7 +226,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(self.event, self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
@@ -225,7 +246,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.assertTrue(
             (
                 yield defer.ensureDeferred(
-                    self.service.is_interested(event=self.event, store=self.store)
+                    self.service.is_interested_in_event(
+                        self.event.event_id, self.event, self.store
+                    )
                 )
             )
         )
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 072e6bbcdd..cead9f90df 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -59,11 +59,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.event_source = hs.get_event_sources()
 
     def test_notify_interested_services(self):
-        interested_service = self._mkservice(is_interested=True)
+        interested_service = self._mkservice(is_interested_in_event=True)
         services = [
-            self._mkservice(is_interested=False),
+            self._mkservice(is_interested_in_event=False),
             interested_service,
-            self._mkservice(is_interested=False),
+            self._mkservice(is_interested_in_event=False),
         ]
 
         self.mock_as_api.query_user.return_value = make_awaitable(True)
@@ -85,7 +85,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def test_query_user_exists_unknown_user(self):
         user_id = "@someone:anywhere"
-        services = [self._mkservice(is_interested=True)]
+        services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
         self.mock_store.get_user_by_id.return_value = make_awaitable(None)
@@ -102,7 +102,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def test_query_user_exists_known_user(self):
         user_id = "@someone:anywhere"
-        services = [self._mkservice(is_interested=True)]
+        services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
         self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
@@ -127,11 +127,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
         room_id = "!alpha:bet"
         servers = ["aperture"]
-        interested_service = self._mkservice_alias(is_interested_in_alias=True)
+        interested_service = self._mkservice_alias(is_room_alias_in_namespace=True)
         services = [
-            self._mkservice_alias(is_interested_in_alias=False),
+            self._mkservice_alias(is_room_alias_in_namespace=False),
             interested_service,
-            self._mkservice_alias(is_interested_in_alias=False),
+            self._mkservice_alias(is_room_alias_in_namespace=False),
         ]
 
         self.mock_as_api.query_alias.return_value = make_awaitable(True)
@@ -275,7 +275,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         to be pushed out to interested appservices, and that the stream ID is
         updated accordingly.
         """
-        interested_service = self._mkservice(is_interested=True)
+        interested_service = self._mkservice(is_interested_in_event=True)
         services = [interested_service]
         self.mock_store.get_app_services.return_value = services
         self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
@@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         Test sending out of order ephemeral events to the appservice handler
         are ignored.
         """
-        interested_service = self._mkservice(is_interested=True)
+        interested_service = self._mkservice(is_interested_in_event=True)
         services = [interested_service]
 
         self.mock_store.get_app_services.return_value = services
@@ -325,17 +325,45 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             interested_service, ephemeral=[]
         )
 
-    def _mkservice(self, is_interested, protocols=None):
+    def _mkservice(
+        self, is_interested_in_event: bool, protocols: Optional[Iterable] = None
+    ) -> Mock:
+        """
+        Create a new mock representing an ApplicationService.
+
+        Args:
+            is_interested_in_event: Whether this application service will be considered
+                interested in all events.
+            protocols: The third-party protocols that this application service claims to
+                support.
+
+        Returns:
+            A mock representing the ApplicationService.
+        """
         service = Mock()
-        service.is_interested.return_value = make_awaitable(is_interested)
+        service.is_interested_in_event.return_value = make_awaitable(
+            is_interested_in_event
+        )
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         service.protocols = protocols
         return service
 
-    def _mkservice_alias(self, is_interested_in_alias):
+    def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock:
+        """
+        Create a new mock representing an ApplicationService that is or is not interested
+        any given room aliase.
+
+        Args:
+            is_room_alias_in_namespace: If true, the application service will be interested
+                in all room aliases that are queried against it. If false, the application
+                service will not be interested in any room aliases.
+
+        Returns:
+            A mock representing the ApplicationService.
+        """
         service = Mock()
-        service.is_interested_in_alias.return_value = is_interested_in_alias
+        service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         return service