diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8a7ceb0df2..29ec18e25e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,7 +21,7 @@ import itertools
import logging
from collections.abc import Container
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import attr
from signedjson.key import decode_verify_key_bytes
@@ -69,20 +69,29 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnInviteRestServlet,
)
-from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
-from synapse.state import StateResolutionStore, resolve_events_with_store
-from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
+from synapse.state import StateResolutionStore
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.types import (
+ JsonDict,
+ MutableStateMap,
+ PersistedEventPosition,
+ RoomStreamToken,
+ StateMap,
+ UserID,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
-@attr.s
+@attr.s(slots=True)
class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events
@@ -96,7 +105,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
- auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
+ auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
class FederationHandler(BaseHandler):
@@ -110,8 +119,8 @@ class FederationHandler(BaseHandler):
rooms.
"""
- def __init__(self, hs):
- super(FederationHandler, self).__init__(hs)
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self.hs = hs
@@ -120,11 +129,11 @@ class FederationHandler(BaseHandler):
self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
- self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._message_handler = hs.get_message_handler()
@@ -135,9 +144,6 @@ class FederationHandler(BaseHandler):
self._replication = hs.get_replication_data_handler()
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
- self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
- hs
- )
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
)
@@ -153,8 +159,9 @@ class FederationHandler(BaseHandler):
self._device_list_updater = hs.get_device_handler().device_list_updater
self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
- # When joining a room we need to queue any events for that room up
- self.room_queues = {}
+ # When joining a room we need to queue any events for that room up.
+ # For each room, a list of (pdu, origin) tuples.
+ self.room_queues = {} # type: Dict[str, List[Tuple[EventBase, str]]]
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -279,7 +286,7 @@ class FederationHandler(BaseHandler):
raise Exception(
"Error fetching missing prev_events for %s: %s"
% (event_id, e)
- )
+ ) from e
# Update the set of things we've seen after trying to
# fetch the missing stuff
@@ -383,8 +390,7 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id)
- state_map = await resolve_events_with_store(
- self.clock,
+ state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
@@ -437,11 +443,11 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
return
- latest = await self.store.get_latest_event_ids_in_room(room_id)
+ latest_list = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
- latest = set(latest)
+ latest = set(latest_list)
latest |= seen
logger.info(
@@ -701,31 +707,10 @@ class FederationHandler(BaseHandler):
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
try:
- context = await self._handle_new_event(origin, event, state=state)
+ await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- if event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has acutally
- # joined the room. Don't bother if the user is just
- # changing their profile info.
- newly_joined = True
-
- prev_state_ids = await context.get_prev_state_ids()
-
- prev_state_id = prev_state_ids.get((event.type, event.state_key))
- if prev_state_id:
- prev_state = await self.store.get_event(
- prev_state_id, allow_none=True
- )
- if prev_state and prev_state.membership == Membership.JOIN:
- newly_joined = False
-
- if newly_joined:
- user = UserID.from_string(event.state_key)
- await self.user_joined_room(user, room_id)
-
# For encrypted messages we check that we know about the sending device,
# if we don't then we mark the device cache for that user as stale.
if event.type == EventTypes.Encrypted:
@@ -778,7 +763,7 @@ class FederationHandler(BaseHandler):
# keys across all devices.
current_keys = [
key
- for device in cached_devices
+ for device in cached_devices.values()
for key in device.get("keys", {}).get("keys", {}).values()
]
@@ -836,6 +821,9 @@ class FederationHandler(BaseHandler):
dest, room_id, limit=limit, extremities=extremities
)
+ if not events:
+ return []
+
# ideally we'd sanity check the events here for excess prev_events etc,
# but it's hard to reject events at this point without completely
# breaking backfill in the same way that it is currently broken by
@@ -920,7 +908,8 @@ class FederationHandler(BaseHandler):
)
)
- await self._handle_new_events(dest, ev_infos, backfilled=True)
+ if ev_infos:
+ await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -940,15 +929,26 @@ class FederationHandler(BaseHandler):
return events
- async def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(
+ self, room_id: str, current_depth: int, limit: int
+ ) -> bool:
"""Checks the database to see if we should backfill before paginating,
and if so do.
+
+ Args:
+ room_id
+ current_depth: The depth from which we're paginating from. This is
+ used to decide if we should backfill and what extremities to
+ use.
+ limit: The number of events that the pagination request will
+ return. This is used as part of the heuristic to decide if we
+ should back paginate.
"""
extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
- return
+ return False
# We only want to paginate if we can actually see the events we'll get,
# as otherwise we'll just spend a lot of resources to get redacted
@@ -1001,16 +1001,54 @@ class FederationHandler(BaseHandler):
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
max_depth = sorted_extremeties_tuple[0][1]
+ # If we're approaching an extremity we trigger a backfill, otherwise we
+ # no-op.
+ #
+ # We chose twice the limit here as then clients paginating backwards
+ # will send pagination requests that trigger backfill at least twice
+ # using the most recent extremity before it gets removed (see below). We
+ # chose more than one times the limit in case of failure, but choosing a
+ # much larger factor will result in triggering a backfill request much
+ # earlier than necessary.
+ if current_depth - 2 * limit > max_depth:
+ logger.debug(
+ "Not backfilling as we don't need to. %d < %d - 2 * %d",
+ max_depth,
+ current_depth,
+ limit,
+ )
+ return False
+
+ logger.debug(
+ "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
+ room_id,
+ current_depth,
+ max_depth,
+ sorted_extremeties_tuple,
+ )
+
+ # We ignore extremities that have a greater depth than our current depth
+ # as:
+ # 1. we don't really care about getting events that have happened
+ # before our current position; and
+ # 2. we have likely previously tried and failed to backfill from that
+ # extremity, so to avoid getting "stuck" requesting the same
+ # backfill repeatedly we drop those extremities.
+ filtered_sorted_extremeties_tuple = [
+ t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
+ ]
+
+ # However, we need to check that the filtered extremities are non-empty.
+ # If they are empty then either we can a) bail or b) still attempt to
+ # backill. We opt to try backfilling anyway just in case we do get
+ # relevant events.
+ if filtered_sorted_extremeties_tuple:
+ sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
+
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
- if current_depth > max_depth:
- logger.debug(
- "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
- )
- return
-
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
@@ -1213,7 +1251,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
- destination, event_infos,
+ destination, room_id, event_infos,
)
def _sanity_check_event(self, ev):
@@ -1360,15 +1398,15 @@ class FederationHandler(BaseHandler):
)
max_stream_id = await self._persist_auth_tree(
- origin, auth_chain, state, event, room_version_obj
+ origin, room_id, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
- #
- # TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
- self.config.worker.writers.events, "events", max_stream_id
+ self.config.worker.events_shard_config.get_instance(room_id),
+ "events",
+ max_stream_id,
)
# Check whether this room is the result of an upgrade of a room we already know
@@ -1547,11 +1585,6 @@ class FederationHandler(BaseHandler):
event.signatures,
)
- if event.type == EventTypes.Member:
- if event.content["membership"] == Membership.JOIN:
- user = UserID.from_string(event.state_key)
- await self.user_joined_room(user, event.room_id)
-
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
@@ -1629,7 +1662,7 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
- await self.persist_events_and_notify([(event, context)])
+ await self.persist_events_and_notify(event.room_id, [(event, context)])
return event
@@ -1656,7 +1689,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify([(event, context)])
+ stream_id = await self.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
return event, stream_id
@@ -1787,9 +1822,7 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
@@ -1815,9 +1848,7 @@ class FederationHandler(BaseHandler):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -1887,8 +1918,8 @@ class FederationHandler(BaseHandler):
else:
return None
- def get_min_depth_for_context(self, context):
- return self.store.get_min_depth(context)
+ async def get_min_depth_for_context(self, context):
+ return await self.store.get_min_depth(context)
async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False
@@ -1908,7 +1939,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
- [(event, context)], backfilled=backfilled
+ event.room_id, [(event, context)], backfilled=backfilled
)
except Exception:
run_in_background(
@@ -1921,6 +1952,7 @@ class FederationHandler(BaseHandler):
async def _handle_new_events(
self,
origin: str,
+ room_id: str,
event_infos: Iterable[_NewEventInfo],
backfilled: bool = False,
) -> None:
@@ -1952,6 +1984,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
+ room_id,
[
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
@@ -1962,6 +1995,7 @@ class FederationHandler(BaseHandler):
async def _persist_auth_tree(
self,
origin: str,
+ room_id: str,
auth_events: List[EventBase],
state: List[EventBase],
event: EventBase,
@@ -1976,6 +2010,7 @@ class FederationHandler(BaseHandler):
Args:
origin: Where the events came from
+ room_id,
auth_events
state
event
@@ -2050,31 +2085,34 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
await self.persist_events_and_notify(
+ room_id,
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
- ]
+ ],
)
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
- return await self.persist_events_and_notify([(event, new_event_context)])
+ return await self.persist_events_and_notify(
+ room_id, [(event, new_event_context)]
+ )
async def _prep_event(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
- auth_events: Optional[StateMap[EventBase]],
+ auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = await self.auth.compute_auth_events(
+ auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
@@ -2117,8 +2155,8 @@ class FederationHandler(BaseHandler):
if backfilled or event.internal_metadata.is_outlier():
return
- extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
- extrem_ids = set(extrem_ids)
+ extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+ extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids:
@@ -2143,15 +2181,17 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets = await self.state_store.get_state_groups(
+ state_sets_d = await self.state_store.get_state_groups(
event.room_id, extrem_ids
)
- state_sets = list(state_sets.values())
+ state_sets = list(state_sets_d.values()) # type: List[Iterable[EventBase]]
state_sets.append(state)
- current_state_ids = await self.state_handler.resolve_events(
+ current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
- current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+ current_state_ids = {
+ k: e.event_id for k, e in current_states.items()
+ } # type: StateMap[str]
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2163,11 +2203,13 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
- current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
+ current_state_ids_list = [
+ e for k, e in current_state_ids.items() if k in auth_types
+ ]
- current_auth_events = await self.store.get_events(current_state_ids)
+ auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
- (e.type, e.state_key): e for e in current_auth_events.values()
+ (e.type, e.state_key): e for e in auth_events_map.values()
}
try:
@@ -2183,9 +2225,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- event = await self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id
- )
+ event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
@@ -2237,7 +2277,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
- auth_events: StateMap[EventBase],
+ auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""
@@ -2288,7 +2328,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
- auth_events: StateMap[EventBase],
+ auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""Helper for do_auth. See there for docs.
@@ -2480,7 +2520,7 @@ class FederationHandler(BaseHandler):
}
current_state_ids = await context.get_current_state_ids()
- current_state_ids = dict(current_state_ids)
+ current_state_ids = dict(current_state_ids) # type: ignore
current_state_ids.update(state_updates)
@@ -2909,6 +2949,7 @@ class FederationHandler(BaseHandler):
async def persist_events_and_notify(
self,
+ room_id: str,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> int:
@@ -2916,20 +2957,26 @@ class FederationHandler(BaseHandler):
necessary.
Args:
- event_and_contexts:
+ room_id: The room ID of events being persisted.
+ event_and_contexts: Sequence of events with their associated
+ context that should be persisted. All events must belong to
+ the same room.
backfilled: Whether these events are a result of
backfilling or not
"""
- if self.config.worker.writers.events != self._instance_name:
+ instance = self.config.worker.events_shard_config.get_instance(room_id)
+ if instance != self._instance_name:
result = await self._send_events(
- instance_name=self.config.worker.writers.events,
+ instance_name=instance,
store=self.store,
+ room_id=room_id,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)
return result["max_stream_id"]
else:
- max_stream_id = await self.storage.persistence.persist_events(
+ assert self.storage.persistence
+ max_stream_token = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
@@ -2940,12 +2987,12 @@ class FederationHandler(BaseHandler):
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
- await self._notify_persisted_event(event, max_stream_id)
+ await self._notify_persisted_event(event, max_stream_token)
- return max_stream_id
+ return max_stream_token.stream
async def _notify_persisted_event(
- self, event: EventBase, max_stream_id: int
+ self, event: EventBase, max_stream_token: RoomStreamToken
) -> None:
"""Checks to see if notifier/pushers should be notified about the
event or not.
@@ -2971,13 +3018,13 @@ class FederationHandler(BaseHandler):
elif event.internal_metadata.is_outlier():
return
- event_stream_id = event.internal_metadata.stream_ordering
+ event_pos = PersistedEventPosition(
+ self._instance_name, event.internal_metadata.stream_ordering
+ )
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id, extra_users=extra_users
+ event, event_pos, max_stream_token, extra_users=extra_users
)
- await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
-
async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
@@ -2990,16 +3037,6 @@ class FederationHandler(BaseHandler):
else:
await self.store.clean_room_for_join(room_id)
- async def user_joined_room(self, user: UserID, room_id: str) -> None:
- """Called when a new user has joined the room
- """
- if self.config.worker_app:
- await self._notify_user_membership_change(
- room_id=room_id, user_id=user.to_string(), change="joined"
- )
- else:
- user_joined_room(self.distributor, user, room_id)
-
async def get_room_complexity(
self, remote_room_hosts: List[str], room_id: str
) -> Optional[dict]:
|