diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index cb7ebc31e7..1ccb63c7be 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -404,7 +404,7 @@ _DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
def serialize_event(
- e: Union[JsonDict, EventBase],
+ e: EventBase,
time_now_ms: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@@ -420,10 +420,6 @@ def serialize_event(
The serialized event dictionary.
"""
- # FIXME(erikj): To handle the case of presence events and the like
- if not isinstance(e, EventBase):
- return e
-
time_now_ms = int(time_now_ms)
# Should this strip out None's?
@@ -531,7 +527,7 @@ class EventClientSerializer:
async def serialize_event(
self,
- event: Union[JsonDict, EventBase],
+ event: EventBase,
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@@ -549,10 +545,6 @@ class EventClientSerializer:
Returns:
The serialized event
"""
- # To handle the case of presence events and the like
- if not isinstance(event, EventBase):
- return event
-
serialized_event = serialize_event(event, time_now, config=config)
new_unsigned = {}
@@ -656,7 +648,7 @@ class EventClientSerializer:
async def serialize_events(
self,
- events: Iterable[Union[JsonDict, EventBase]],
+ events: Iterable[EventBase],
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 36404d9c78..aa4d3f2e9e 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -20,7 +20,7 @@
import logging
import random
-from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional, cast
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
from synapse.api.errors import AuthError, SynapseError
@@ -29,7 +29,7 @@ from synapse.events.utils import SerializeEventConfig
from synapse.handlers.presence import format_user_presence_state
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, UserID
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -93,49 +93,54 @@ class EventStreamHandler:
is_guest=requester.is_guest,
explicit_room_id=room_id,
)
- events = stream_result.events
+ events_by_source = stream_result.events_by_source
time_now = self.clock.time_msec()
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
- to_add: List[JsonDict] = []
- for event in events:
- if not isinstance(event, EventBase):
+ to_return: List[JsonDict] = []
+ for keyname, source_events in events_by_source.items():
+ if keyname != StreamKeyType.ROOM:
+ e = cast(List[JsonDict], source_events)
+ to_return.extend(e)
continue
- if event.type == EventTypes.Member:
- if event.membership != Membership.JOIN:
- continue
- # Send down presence.
- if event.state_key == requester.user.to_string():
- # Send down presence for everyone in the room.
- users: Iterable[str] = await self.store.get_users_in_room(
- event.room_id
+
+ events = cast(List[EventBase], source_events)
+
+ serialized_events = await self._event_serializer.serialize_events(
+ events,
+ time_now,
+ config=SerializeEventConfig(
+ as_client_event=as_client_event, requester=requester
+ ),
+ )
+ to_return.extend(serialized_events)
+
+ for event in events:
+ if event.type == EventTypes.Member:
+ if event.membership != Membership.JOIN:
+ continue
+ # Send down presence.
+ if event.state_key == requester.user.to_string():
+ # Send down presence for everyone in the room.
+ users: Iterable[str] = await self.store.get_users_in_room(
+ event.room_id
+ )
+ else:
+ users = [event.state_key]
+
+ states = await presence_handler.get_states(users)
+ to_return.extend(
+ {
+ "type": EduTypes.PRESENCE,
+ "content": format_user_presence_state(state, time_now),
+ }
+ for state in states
)
- else:
- users = [event.state_key]
-
- states = await presence_handler.get_states(users)
- to_add.extend(
- {
- "type": EduTypes.PRESENCE,
- "content": format_user_presence_state(state, time_now),
- }
- for state in states
- )
-
- events.extend(to_add)
-
- chunks = await self._event_serializer.serialize_events(
- events,
- time_now,
- config=SerializeEventConfig(
- as_client_event=as_client_event, requester=requester
- ),
- )
chunk = {
- "chunk": chunks,
+ "chunk": to_return,
"start": await stream_result.start_token.to_string(self.store),
"end": await stream_result.end_token.to_string(self.store),
}
diff --git a/synapse/notifier.py b/synapse/notifier.py
index dec47add7e..4213fd7c6f 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -198,12 +198,12 @@ class _NotifierUserStream:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventStreamResult:
- events: List[Union[JsonDict, EventBase]]
+ events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]]
start_token: StreamToken
end_token: StreamToken
def __bool__(self) -> bool:
- return bool(self.events)
+ return any(bool(e) for e in self.events_by_source.values())
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -694,12 +694,12 @@ class Notifier:
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if after_token == before_token:
- return EventStreamResult([], from_token, from_token)
+ return EventStreamResult({}, from_token, from_token)
# The events fetched from each source are a JsonDict, EventBase, or
# UserPresenceState, but see below for UserPresenceState being
# converted to JsonDict.
- events: List[Union[JsonDict, EventBase]] = []
+ events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]] = {}
end_token = from_token
for keyname, source in self.event_sources.sources.get_sources():
@@ -734,10 +734,12 @@ class Notifier:
for event in new_events
]
- events.extend(new_events)
+ if new_events:
+ events_by_source.setdefault(keyname, []).extend(new_events)
+
end_token = end_token.copy_and_replace(keyname, new_key)
- return EventStreamResult(events, from_token, end_token)
+ return EventStreamResult(events_by_source, from_token, end_token)
user_id_for_stream = user.to_string()
if is_peeking:
|