summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8671.misc1
-rw-r--r--synapse/handlers/message.py35
-rw-r--r--synapse/storage/databases/main/events_worker.py54
3 files changed, 61 insertions, 29 deletions
diff --git a/changelog.d/8671.misc b/changelog.d/8671.misc
new file mode 100644
index 0000000000..bef8dc425a
--- /dev/null
+++ b/changelog.d/8671.misc
@@ -0,0 +1 @@
+Abstract some invite-related code in preparation for landing knocking.
\ No newline at end of file
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index f1b4d35182..4ead75ec3a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1100,34 +1100,13 @@ class EventCreationHandler:
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.INVITE:
-
-                def is_inviter_member_event(e):
-                    return e.type == EventTypes.Member and e.sender == event.sender
-
-                current_state_ids = await context.get_current_state_ids()
-
-                # We know this event is not an outlier, so this must be
-                # non-None.
-                assert current_state_ids is not None
-
-                state_to_include_ids = [
-                    e_id
-                    for k, e_id in current_state_ids.items()
-                    if k[0] in self.room_invite_state_types
-                    or k == (EventTypes.Member, event.sender)
-                ]
-
-                state_to_include = await self.store.get_events(state_to_include_ids)
-
-                event.unsigned["invite_room_state"] = [
-                    {
-                        "type": e.type,
-                        "state_key": e.state_key,
-                        "content": e.content,
-                        "sender": e.sender,
-                    }
-                    for e in state_to_include.values()
-                ]
+                event.unsigned[
+                    "invite_room_state"
+                ] = await self.store.get_stripped_room_state_from_event_context(
+                    context,
+                    self.room_invite_state_types,
+                    membership_user_id=event.sender,
+                )
 
                 invitee = UserID.from_string(event.state_key)
                 if not self.hs.is_mine(invitee):
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6e7f16f39c..cd1f31aa62 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -31,6 +31,7 @@ from synapse.api.room_versions import (
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import (
@@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, JsonDict, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
@@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_map
 
+    async def get_stripped_room_state_from_event_context(
+        self,
+        context: EventContext,
+        state_types_to_include: List[EventTypes],
+        membership_user_id: Optional[str],
+    ) -> List[JsonDict]:
+        """
+        Retrieve the stripped state from a room, given an event context to retrieve state
+        from as well as the state types to include. Optionally, include the membership
+        events from a specific user.
+
+        "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
+        are included from each state event.
+
+        Args:
+            context: The event context to retrieve state of the room from.
+            state_types_to_include: The type of state events to include.
+            membership_user_id: An optional user ID to include the stripped membership state
+                events of. This is useful when generating the stripped state of a room for
+                invites. We want to send membership events of the inviter, so that the
+                invitee can display the inviter's profile information if the room lacks any.
+
+        Returns:
+            A list of dictionaries, each representing a stripped state event from the room.
+        """
+        current_state_ids = await context.get_current_state_ids()
+
+        # We know this event is not an outlier, so this must be
+        # non-None.
+        assert current_state_ids is not None
+
+        # The state to include
+        state_to_include_ids = [
+            e_id
+            for k, e_id in current_state_ids.items()
+            if k[0] in state_types_to_include
+            or (membership_user_id and k == (EventTypes.Member, membership_user_id))
+        ]
+
+        state_to_include = await self.get_events(state_to_include_ids)
+
+        return [
+            {
+                "type": e.type,
+                "state_key": e.state_key,
+                "content": e.content,
+                "sender": e.sender,
+            }
+            for e in state_to_include.values()
+        ]
+
     def _do_fetch(self, conn):
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.