summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Eastwood <erice@element.io>2022-08-24 14:13:12 -0500
committerGitHub <noreply@github.com>2022-08-24 14:13:12 -0500
commitd58615c82cec5bd866bedcb33e3e2a5d2a961c44 (patch)
tree0c876b95db99f6c927adc011120f6953c38e7d3e
parentWhen loading current ids, sort by `stream_id` to avoid incorrect overwrite an... (diff)
downloadsynapse-d58615c82cec5bd866bedcb33e3e2a5d2a961c44.tar.xz
Directly lookup local membership instead of getting all members in a room first (`get_users_in_room` mis-use) (#13608)
See https://github.com/matrix-org/synapse/pull/13575#discussion_r953023755
-rw-r--r--changelog.d/13608.misc1
-rw-r--r--synapse/handlers/events.py9
-rw-r--r--synapse/handlers/message.py6
-rw-r--r--synapse/handlers/room.py7
-rw-r--r--synapse/handlers/room_member.py6
-rw-r--r--synapse/server_notices/server_notices_manager.py10
-rw-r--r--synapse/storage/databases/main/roommember.py26
-rw-r--r--tests/rest/client/test_relations.py12
8 files changed, 60 insertions, 17 deletions
diff --git a/changelog.d/13608.misc b/changelog.d/13608.misc
new file mode 100644
index 0000000000..19bcc45e33
--- /dev/null
+++ b/changelog.d/13608.misc
@@ -0,0 +1 @@
+Refactor `get_users_in_room(room_id)` mis-use to lookup single local user with dedicated `check_local_user_in_room(...)` function.
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index ac13340d3a..949b69cb41 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -151,7 +151,7 @@ class EventHandler:
         """Retrieve a single specified event.
 
         Args:
-            user: The user requesting the event
+            user: The local user requesting the event
             room_id: The expected room id. We'll return None if the
                 event's room does not match.
             event_id: The event ID to obtain.
@@ -173,8 +173,11 @@ class EventHandler:
         if not event:
             return None
 
-        users = await self.store.get_users_in_room(event.room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=event.room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         filtered = await filter_events_for_client(
             self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index acd3de06f6..72157d5a36 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -761,8 +761,10 @@ class EventCreationHandler:
     async def _is_server_notices_room(self, room_id: str) -> bool:
         if self.config.servernotices.server_notices_mxid is None:
             return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self.config.servernotices.server_notices_mxid in user_ids
+        is_server_notices_room = await self.store.check_local_user_in_room(
+            user_id=self.config.servernotices.server_notices_mxid, room_id=room_id
+        )
+        return is_server_notices_room
 
     async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
         """Check if a user has accepted the privacy policy
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2bf0ebd025..2fc8264858 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1284,8 +1284,11 @@ class RoomContextHandler:
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = await self.store.get_users_in_room(room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         async def filter_evts(events: List[EventBase]) -> List[EventBase]:
             if use_admin_priviledge:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 65b9a655d4..709682622f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1620,8 +1620,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
     async def _is_server_notice_room(self, room_id: str) -> bool:
         if self._server_notices_mxid is None:
             return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self._server_notices_mxid in user_ids
+        is_server_notices_room = await self.store.check_local_user_in_room(
+            user_id=self._server_notices_mxid, room_id=room_id
+        )
+        return is_server_notices_room
 
 
 class RoomMemberMasterHandler(RoomMemberHandler):
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 70d054a8f4..564e3705c2 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -102,6 +102,10 @@ class ServerNoticesManager:
         Returns:
             The room's ID, or None if no room could be found.
         """
+        # If there is no server notices MXID, then there is no server notices room
+        if self.server_notices_mxid is None:
+            return None
+
         rooms = await self._store.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE, Membership.JOIN]
         )
@@ -111,8 +115,10 @@ class ServerNoticesManager:
             # be joined. This is kinda deliberate, in that if somebody somehow
             # manages to invite the system user to a room, that doesn't make it
             # the server notices room.
-            user_ids = await self._store.get_users_in_room(room.room_id)
-            if len(user_ids) <= 2 and self.server_notices_mxid in user_ids:
+            is_server_notices_room = await self._store.check_local_user_in_room(
+                user_id=self.server_notices_mxid, room_id=room.room_id
+            )
+            if is_server_notices_room:
                 # we found a room which our user shares with the system notice
                 # user
                 return room.room_id
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 046ad3a11c..9e5034b401 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -534,6 +534,32 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_local_users_in_room",
         )
 
+    async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
+        """
+        Check whether a given local user is currently joined to the given room.
+
+        Returns:
+            A boolean indicating whether the user is currently joined to the room
+
+        Raises:
+            Exeption when called with a non-local user to this homeserver
+        """
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'check_local_user_in_room' on "
+                "non-local user %s" % (user_id,),
+            )
+
+        (
+            membership,
+            member_event_id,
+        ) = await self.get_local_current_membership_for_user_in_room(
+            user_id=user_id,
+            room_id=room_id,
+        )
+
+        return membership == Membership.JOIN
+
     async def get_local_current_membership_for_user_in_room(
         self, user_id: str, room_id: str
     ) -> Tuple[Optional[str], Optional[str]]:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index d589f07314..651f4f415d 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -999,7 +999,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations,
             )
 
-        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 6)
+        self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
 
     def test_annotation_to_annotation(self) -> None:
         """Any relation to an annotation should be ignored."""
@@ -1035,7 +1035,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations,
             )
 
-        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 6)
+        self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
 
     def test_thread(self) -> None:
         """
@@ -1080,21 +1080,21 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # The "user" sent the root event and is making queries for the bundled
         # aggregations: they have participated.
-        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
         # The "user2" sent replies in the thread and is making queries for the
         # bundled aggregations: they have participated.
         #
         # Note that this re-uses some cached values, so the total number of
         # queries is much smaller.
         self._test_bundled_aggregations(
-            RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+            RelationTypes.THREAD, _gen_assert(True), 3, access_token=self.user2_token
         )
 
         # A user with no interactions with the thread: they have not participated.
         user3_id, user3_token = self._create_user("charlie")
         self.helper.join(self.room, user=user3_id, tok=user3_token)
         self._test_bundled_aggregations(
-            RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+            RelationTypes.THREAD, _gen_assert(False), 3, access_token=user3_token
         )
 
     def test_thread_with_bundled_aggregations_for_latest(self) -> None:
@@ -1142,7 +1142,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations["latest_event"].get("unsigned"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
 
     def test_nested_thread(self) -> None:
         """