diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index cf98b0ab48..33ffef521b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,8 +45,14 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+ SynapseTags,
+ active_span,
+ set_tag,
+ start_active_span_follows_from,
+ trace,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
@@ -198,9 +204,8 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
process to to so, calling the per_item_callback for each item.
Args:
- room_id (str):
- task (_EventPersistQueueTask): A _PersistEventsTask or
- _UpdateCurrentStateTask to process.
+ room_id:
+ task: A _PersistEventsTask or _UpdateCurrentStateTask to process.
Returns:
the result returned by the `_per_item_callback` passed to
@@ -223,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue.append(end_item)
# also add our active opentracing span to the item so that we get a link back
- span = opentracing.active_span()
+ span = active_span()
if span:
end_item.parent_opentracing_span_contexts.append(span.context)
@@ -234,7 +239,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
res = await make_deferred_yieldable(end_item.deferred.observe())
# add another opentracing span which links to the persist trace.
- with opentracing.start_active_span_follows_from(
+ with start_active_span_follows_from(
f"{task.name}_complete", (end_item.opentracing_span_context,)
):
pass
@@ -266,7 +271,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
- with opentracing.start_active_span_follows_from(
+ with start_active_span_follows_from(
item.task.name,
item.parent_opentracing_span_contexts,
inherit_force_tracing=True,
@@ -355,7 +360,7 @@ class EventsPersistenceStorageController:
f"Found an unexpected task type in event persistence queue: {task}"
)
- @opentracing.trace
+ @trace
async def persist_events(
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -380,9 +385,21 @@ class EventsPersistenceStorageController:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
+ event_ids: List[str] = []
partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
+ event_ids.append(event.event_id)
+
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+ str(event_ids),
+ )
+ set_tag(
+ SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+ str(len(event_ids)),
+ )
+ set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
async def enqueue(
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
@@ -405,20 +422,22 @@ class EventsPersistenceStorageController:
for d in ret_vals:
replaced_events.update(d)
- events = []
+ persisted_events = []
for event, _ in events_and_contexts:
existing_event_id = replaced_events.get(event.event_id)
if existing_event_id:
- events.append(await self.main_store.get_event(existing_event_id))
+ persisted_events.append(
+ await self.main_store.get_event(existing_event_id)
+ )
else:
- events.append(event)
+ persisted_events.append(event)
return (
- events,
+ persisted_events,
self.main_store.get_room_max_token(),
)
- @opentracing.trace
+ @trace
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
@@ -580,11 +599,6 @@ class EventsPersistenceStorageController:
# room
state_delta_for_room: Dict[str, DeltaState] = {}
- # Set of remote users which were in rooms the server has left. We
- # should check if we still share any rooms and if not we mark their
- # device lists as stale.
- potentially_left_users: Set[str] = set()
-
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
@@ -698,13 +712,9 @@ class EventsPersistenceStorageController:
room_id,
ev_ctx_rm,
delta,
- current_state,
- potentially_left_users,
)
if not is_still_joined:
logger.info("Server no longer in room %s", room_id)
- latest_event_ids = set()
- current_state = {}
delta.no_longer_in_room = True
state_delta_for_room[room_id] = delta
@@ -717,8 +727,6 @@ class EventsPersistenceStorageController:
inhibit_local_membership_updates=backfilled,
)
- await self._handle_potentially_left_users(potentially_left_users)
-
return replaced_events
async def _calculate_new_extremities(
@@ -1094,8 +1102,6 @@ class EventsPersistenceStorageController:
room_id: str,
ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState,
- current_state: Optional[StateMap[str]],
- potentially_left_users: Set[str],
) -> bool:
"""Check if the server will still be joined after the given events have
been persised.
@@ -1105,11 +1111,6 @@ class EventsPersistenceStorageController:
ev_ctx_rm
delta: The delta of current state between what is in the database
and what the new current state will be.
- current_state: The new current state if it already been calculated,
- otherwise None.
- potentially_left_users: If the server has left the room, then joined
- remote users will be added to this set to indicate that the
- server may no longer be sharing a room with them.
"""
if not any(
@@ -1163,45 +1164,4 @@ class EventsPersistenceStorageController:
):
return True
- # The server will leave the room, so we go and find out which remote
- # users will still be joined when we leave.
- if current_state is None:
- current_state = await self.main_store.get_partial_current_state_ids(room_id)
- current_state = dict(current_state)
- for key in delta.to_delete:
- current_state.pop(key, None)
-
- current_state.update(delta.to_insert)
-
- remote_event_ids = [
- event_id
- for (
- typ,
- state_key,
- ), event_id in current_state.items()
- if typ == EventTypes.Member and not self.is_mine_id(state_key)
- ]
- members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
- potentially_left_users.update(
- member.user_id
- for member in members.values()
- if member and member.membership == Membership.JOIN
- )
-
return False
-
- async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
- """Given a set of remote users check if the server still shares a room with
- them. If not then mark those users' device cache as stale.
- """
-
- if not user_ids:
- return
-
- joined_users = await self.main_store.get_users_server_still_shares_room_with(
- user_ids
- )
- left_users = user_ids - joined_users
-
- for user_id in left_users:
- await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 1e35046e07..2b31ce54bb 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -29,12 +29,15 @@ from typing import (
from synapse.api.constants import EventTypes
from synapse.events import EventBase
+from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
from synapse.types import MutableStateMap, StateMap
+from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -179,6 +182,7 @@ class StateStorageController:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
+ @trace
async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
@@ -225,10 +229,14 @@ class StateStorageController:
return {event: event_to_state[event] for event in event_ids}
+ @trace
+ @tag_args
+ @cancellable
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -237,6 +245,9 @@ class StateStorageController:
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at these events and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from event_id -> (type, state_key) -> event_id
@@ -245,8 +256,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ if (
+ await_full_state
+ and state_filter
+ and not state_filter.must_await_full_state(self._is_mine_id)
+ ):
+ # Full state is not required if the state filter is restrictive enough.
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
@@ -287,8 +302,12 @@ class StateStorageController:
)
return state_map[event_id]
+ @trace
async def get_state_ids_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
@@ -296,6 +315,9 @@ class StateStorageController:
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from (type, state_key) -> state_event_id
@@ -305,7 +327,9 @@ class StateStorageController:
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
- [event_id], state_filter or StateFilter.all()
+ [event_id],
+ state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
return state_map[event_id]
@@ -327,6 +351,9 @@ class StateStorageController:
groups, state_filter or StateFilter.all()
)
+ @trace
+ @tag_args
+ @cancellable
async def get_state_group_for_events(
self,
event_ids: Collection[str],
@@ -375,10 +402,12 @@ class StateStorageController:
event_id, room_id, prev_group, delta_ids, current_state_ids
)
+ @cancellable
async def get_current_state_ids(
self,
room_id: str,
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
on_invalidate: Optional[Callable[[], None]] = None,
) -> StateMap[str]:
"""Get the current state event ids for a room based on the
@@ -391,13 +420,17 @@ class StateStorageController:
room_id: The room to get the state IDs of. state_filter: The state
filter used to fetch state from the
database.
+ await_full_state: if true, will block if we do not yet have complete
+ state for the room.
on_invalidate: Callback for when the `get_current_state_ids` cache
for the room gets invalidated.
Returns:
The current state of the room.
"""
- if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
+ if await_full_state and (
+ not state_filter or state_filter.must_await_full_state(self._is_mine_id)
+ ):
await self._partial_state_room_tracker.await_full_state(room_id)
if state_filter and not state_filter.is_full():
@@ -468,6 +501,7 @@ class StateStorageController:
prev_stream_id, max_stream_id
)
+ @trace
async def get_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
@@ -496,8 +530,67 @@ class StateStorageController:
return state_map.get(key)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
- """Get current hosts in room based on current state."""
+ """Get current hosts in room based on current state.
+
+ Blocks until we have full state for the given room. This only happens for rooms
+ with partial state.
+ """
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
+
+ async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
+ """Get current hosts in room based on current state.
+
+ Blocks until we have full state for the given room. This only happens for rooms
+ with partial state.
+
+ Returns:
+ A list of hosts in the room, sorted by longest in the room first. (aka.
+ sorted by join with the lowest depth first).
+ """
+
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_current_hosts_in_room_ordered(room_id)
+
+ async def get_current_hosts_in_room_or_partial_state_approximation(
+ self, room_id: str
+ ) -> Collection[str]:
+ """Get approximation of current hosts in room based on current state.
+
+ For rooms with full state, this is equivalent to `get_current_hosts_in_room`,
+ with the same order of results.
+
+ For rooms with partial state, no blocking occurs. Instead, the list of hosts
+ in the room at the time of joining is combined with the list of hosts which
+ joined the room afterwards. The returned list may include hosts that are not
+ actually in the room and exclude hosts that are in the room, since we may
+ calculate state incorrectly during the partial state phase. The order of results
+ is arbitrary for rooms with partial state.
+ """
+ # We have to read this list first to mitigate races with un-partial stating.
+ # This will be empty for rooms with full state.
+ hosts_at_join = await self.stores.main.get_partial_state_servers_at_join(
+ room_id
+ )
+
+ hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)
+
+ hosts = set(hosts_at_join)
+ hosts.update(hosts_from_state)
+
+ return hosts
+
+ async def get_users_in_room_with_profiles(
+ self, room_id: str
+ ) -> Dict[str, ProfileInfo]:
+ """
+ Get the current users in the room with their profiles.
+ If the room is currently partial-stated, this will block until the room has
+ full state.
+ """
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_users_in_room_with_profiles(room_id)
|