summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2024-01-13 11:18:48 +0000
committerErik Johnston <erik@matrix.org>2024-01-15 21:27:03 +0000
commitc836cb988ec61b12b7cccc92776e1c473c589151 (patch)
tree0dc3bad2a9fafdbb9818d633fbd2a3f1c7adf89b
parentBump service-identity from 23.1.0 to 24.1.0 (#16816) (diff)
downloadsynapse-github/erikj/better_events_typing.tar.xz
-rw-r--r--synapse/events/utils.py14
-rw-r--r--synapse/handlers/events.py77
-rw-r--r--synapse/notifier.py14
3 files changed, 52 insertions, 53 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index cb7ebc31e7..1ccb63c7be 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -404,7 +404,7 @@ _DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
 
 
 def serialize_event(
-    e: Union[JsonDict, EventBase],
+    e: EventBase,
     time_now_ms: int,
     *,
     config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@@ -420,10 +420,6 @@ def serialize_event(
         The serialized event dictionary.
     """
 
-    # FIXME(erikj): To handle the case of presence events and the like
-    if not isinstance(e, EventBase):
-        return e
-
     time_now_ms = int(time_now_ms)
 
     # Should this strip out None's?
@@ -531,7 +527,7 @@ class EventClientSerializer:
 
     async def serialize_event(
         self,
-        event: Union[JsonDict, EventBase],
+        event: EventBase,
         time_now: int,
         *,
         config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@@ -549,10 +545,6 @@ class EventClientSerializer:
         Returns:
             The serialized event
         """
-        # To handle the case of presence events and the like
-        if not isinstance(event, EventBase):
-            return event
-
         serialized_event = serialize_event(event, time_now, config=config)
 
         new_unsigned = {}
@@ -656,7 +648,7 @@ class EventClientSerializer:
 
     async def serialize_events(
         self,
-        events: Iterable[Union[JsonDict, EventBase]],
+        events: Iterable[EventBase],
         time_now: int,
         *,
         config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 36404d9c78..aa4d3f2e9e 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -20,7 +20,7 @@
 
 import logging
 import random
-from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional, cast
 
 from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
 from synapse.api.errors import AuthError, SynapseError
@@ -29,7 +29,7 @@ from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.presence import format_user_presence_state
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, UserID
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -93,49 +93,54 @@ class EventStreamHandler:
                 is_guest=requester.is_guest,
                 explicit_room_id=room_id,
             )
-            events = stream_result.events
+            events_by_source = stream_result.events_by_source
 
             time_now = self.clock.time_msec()
 
             # When the user joins a new room, or another user joins a currently
             # joined room, we need to send down presence for those users.
-            to_add: List[JsonDict] = []
-            for event in events:
-                if not isinstance(event, EventBase):
+            to_return: List[JsonDict] = []
+            for keyname, source_events in events_by_source.items():
+                if keyname != StreamKeyType.ROOM:
+                    e = cast(List[JsonDict], source_events)
+                    to_return.extend(e)
                     continue
-                if event.type == EventTypes.Member:
-                    if event.membership != Membership.JOIN:
-                        continue
-                    # Send down presence.
-                    if event.state_key == requester.user.to_string():
-                        # Send down presence for everyone in the room.
-                        users: Iterable[str] = await self.store.get_users_in_room(
-                            event.room_id
+
+                events = cast(List[EventBase], source_events)
+
+                serialized_events = await self._event_serializer.serialize_events(
+                    events,
+                    time_now,
+                    config=SerializeEventConfig(
+                        as_client_event=as_client_event, requester=requester
+                    ),
+                )
+                to_return.extend(serialized_events)
+
+                for event in events:
+                    if event.type == EventTypes.Member:
+                        if event.membership != Membership.JOIN:
+                            continue
+                        # Send down presence.
+                        if event.state_key == requester.user.to_string():
+                            # Send down presence for everyone in the room.
+                            users: Iterable[str] = await self.store.get_users_in_room(
+                                event.room_id
+                            )
+                        else:
+                            users = [event.state_key]
+
+                        states = await presence_handler.get_states(users)
+                        to_return.extend(
+                            {
+                                "type": EduTypes.PRESENCE,
+                                "content": format_user_presence_state(state, time_now),
+                            }
+                            for state in states
                         )
-                    else:
-                        users = [event.state_key]
-
-                    states = await presence_handler.get_states(users)
-                    to_add.extend(
-                        {
-                            "type": EduTypes.PRESENCE,
-                            "content": format_user_presence_state(state, time_now),
-                        }
-                        for state in states
-                    )
-
-            events.extend(to_add)
-
-            chunks = await self._event_serializer.serialize_events(
-                events,
-                time_now,
-                config=SerializeEventConfig(
-                    as_client_event=as_client_event, requester=requester
-                ),
-            )
 
             chunk = {
-                "chunk": chunks,
+                "chunk": to_return,
                 "start": await stream_result.start_token.to_string(self.store),
                 "end": await stream_result.end_token.to_string(self.store),
             }
diff --git a/synapse/notifier.py b/synapse/notifier.py
index dec47add7e..4213fd7c6f 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -198,12 +198,12 @@ class _NotifierUserStream:
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class EventStreamResult:
-    events: List[Union[JsonDict, EventBase]]
+    events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]]
     start_token: StreamToken
     end_token: StreamToken
 
     def __bool__(self) -> bool:
-        return bool(self.events)
+        return any(bool(e) for e in self.events_by_source.values())
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -694,12 +694,12 @@ class Notifier:
             before_token: StreamToken, after_token: StreamToken
         ) -> EventStreamResult:
             if after_token == before_token:
-                return EventStreamResult([], from_token, from_token)
+                return EventStreamResult({}, from_token, from_token)
 
             # The events fetched from each source are a JsonDict, EventBase, or
             # UserPresenceState, but see below for UserPresenceState being
             # converted to JsonDict.
-            events: List[Union[JsonDict, EventBase]] = []
+            events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]] = {}
             end_token = from_token
 
             for keyname, source in self.event_sources.sources.get_sources():
@@ -734,10 +734,12 @@ class Notifier:
                         for event in new_events
                     ]
 
-                events.extend(new_events)
+                if new_events:
+                    events_by_source.setdefault(keyname, []).extend(new_events)
+
                 end_token = end_token.copy_and_replace(keyname, new_key)
 
-            return EventStreamResult(events, from_token, end_token)
+            return EventStreamResult(events_by_source, from_token, end_token)
 
         user_id_for_stream = user.to_string()
         if is_peeking: