diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da1acea275..845f683358 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
+from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
@@ -59,7 +60,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
from synapse.types import (
MutableStateMap,
PersistedEventPosition,
@@ -70,6 +70,7 @@ from synapse.types import (
UserID,
create_requester,
)
+from synapse.types.state import StateFilter
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
@@ -877,6 +878,36 @@ class EventCreationHandler:
return prev_event
return None
+ async def get_event_from_transaction(
+ self,
+ requester: Requester,
+ txn_id: str,
+ room_id: str,
+ ) -> Optional[EventBase]:
+ """For the given transaction ID and room ID, check if there is a matching event.
+ If so, fetch it and return it.
+
+ Args:
+ requester: The requester making the request in the context of which we want
+ to fetch the event.
+ txn_id: The transaction ID.
+ room_id: The room ID.
+
+ Returns:
+ An event if one could be found, None otherwise.
+ """
+ if requester.access_token_id:
+ existing_event_id = await self.store.get_event_id_from_transaction_id(
+ room_id,
+ requester.user.to_string(),
+ requester.access_token_id,
+ txn_id,
+ )
+ if existing_event_id:
+ return await self.store.get_event(existing_event_id)
+
+ return None
+
async def create_and_send_nonmember_event(
self,
requester: Requester,
@@ -956,18 +987,17 @@ class EventCreationHandler:
# extremities to pile up, which in turn leads to state resolution
# taking longer.
async with self.limiter.queue(event_dict["room_id"]):
- if txn_id and requester.access_token_id:
- existing_event_id = await self.store.get_event_id_from_transaction_id(
- event_dict["room_id"],
- requester.user.to_string(),
- requester.access_token_id,
- txn_id,
+ if txn_id:
+ event = await self.get_event_from_transaction(
+ requester, txn_id, event_dict["room_id"]
)
- if existing_event_id:
- event = await self.store.get_event(existing_event_id)
+ if event:
# we know it was persisted, so must have a stream ordering
assert event.internal_metadata.stream_ordering
- return event, event.internal_metadata.stream_ordering
+ return (
+ event,
+ event.internal_metadata.stream_ordering,
+ )
event, context = await self.create_event(
requester,
@@ -1106,11 +1136,13 @@ class EventCreationHandler:
)
state_events = await self.store.get_events_as_list(state_event_ids)
# Create a StateMap[str]
- state_map = {(e.type, e.state_key): e.event_id for e in state_events}
+ current_state_ids = {
+ (e.type, e.state_key): e.event_id for e in state_events
+ }
# Actually strip down and only use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
- current_state_ids=state_map,
+ current_state_ids=current_state_ids,
for_verification=False,
)
@@ -1360,8 +1392,16 @@ class EventCreationHandler:
else:
try:
validate_event_for_room_version(event)
+ # If we are persisting a batch of events the event(s) needed to auth the
+ # current event may be part of the batch and will not be in the DB yet
+ event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+ batched_auth_events = {}
+ for event_id in event.auth_event_ids():
+ auth_event = event_id_to_event.get(event_id)
+ if auth_event:
+ batched_auth_events[event_id] = auth_event
await self._event_auth_handler.check_auth_rules_from_context(
- event, context
+ event, batched_auth_events
)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
@@ -1390,7 +1430,7 @@ class EventCreationHandler:
extra_users=extra_users,
),
run_in_background(
- self.cache_joined_hosts_for_event, event, context
+ self.cache_joined_hosts_for_events, events_and_context
).addErrback(
log_failure, "cache_joined_hosts_for_event failed"
),
@@ -1425,17 +1465,9 @@ class EventCreationHandler:
a room that has been un-partial stated.
"""
- for event, context in events_and_context:
- # Skip push notification actions for historical messages
- # because we don't want to notify people about old history back in time.
- # The historical messages also do not have the proper `context.current_state_ids`
- # and `state_groups` because they have `prev_events` that aren't persisted yet
- # (historical messages persisted in reverse-chronological order).
- if not event.internal_metadata.is_historical():
- with opentracing.start_active_span("calculate_push_actions"):
- await self._bulk_push_rule_evaluator.action_for_event_by_user(
- event, context
- )
+ await self._bulk_push_rule_evaluator.action_for_events_by_user(
+ events_and_context
+ )
try:
# If we're a worker we need to hit out to the master.
@@ -1491,62 +1523,65 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
- async def cache_joined_hosts_for_event(
- self, event: EventBase, context: EventContext
+ async def cache_joined_hosts_for_events(
+ self, events_and_context: List[Tuple[EventBase, EventContext]]
) -> None:
- """Precalculate the joined hosts at the event, when using Redis, so that
+ """Precalculate the joined hosts at each of the given events, when using Redis, so that
external federation senders don't have to recalculate it themselves.
"""
- if not self._external_cache.is_enabled():
- return
-
- # If external cache is enabled we should always have this.
- assert self._external_cache_joined_hosts_updates is not None
+ for event, _ in events_and_context:
+ if not self._external_cache.is_enabled():
+ return
- # We actually store two mappings, event ID -> prev state group,
- # state group -> joined hosts, which is much more space efficient
- # than event ID -> joined hosts.
- #
- # Note: We have to cache event ID -> prev state group, as we don't
- # store that in the DB.
- #
- # Note: We set the state group -> joined hosts cache if it hasn't been
- # set for a while, so that the expiry time is reset.
+ # If external cache is enabled we should always have this.
+ assert self._external_cache_joined_hosts_updates is not None
- state_entry = await self.state.resolve_state_groups_for_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We set the state group -> joined hosts cache if it hasn't been
+ # set for a while, so that the expiry time is reset.
- if state_entry.state_group:
- await self._external_cache.set(
- "event_to_prev_state_group",
- event.event_id,
- state_entry.state_group,
- expiry_ms=60 * 60 * 1000,
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
)
- if state_entry.state_group in self._external_cache_joined_hosts_updates:
- return
+ if state_entry.state_group:
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
- state = await state_entry.get_state(
- self._storage_controllers.state, StateFilter.all()
- )
- with opentracing.start_active_span("get_joined_hosts"):
- joined_hosts = await self.store.get_joined_hosts(
- event.room_id, state, state_entry
+ if state_entry.state_group in self._external_cache_joined_hosts_updates:
+ return
+
+ state = await state_entry.get_state(
+ self._storage_controllers.state, StateFilter.all()
)
+ with opentracing.start_active_span("get_joined_hosts"):
+ joined_hosts = await self.store.get_joined_hosts(
+ event.room_id, state, state_entry
+ )
- # Note that the expiry times must be larger than the expiry time in
- # _external_cache_joined_hosts_updates.
- await self._external_cache.set(
- "get_joined_hosts",
- str(state_entry.state_group),
- list(joined_hosts),
- expiry_ms=60 * 60 * 1000,
- )
+ # Note that the expiry times must be larger than the expiry time in
+ # _external_cache_joined_hosts_updates.
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
- self._external_cache_joined_hosts_updates[state_entry.state_group] = None
+ self._external_cache_joined_hosts_updates[
+ state_entry.state_group
+ ] = None
async def _validate_canonical_alias(
self,
@@ -1705,12 +1740,15 @@ class EventCreationHandler:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
- event.unsigned[
- "invite_room_state"
- ] = await self.store.get_stripped_room_state_from_event_context(
- context,
- self.room_prejoin_state_types,
- membership_user_id=event.sender,
+ maybe_upsert_event_field(
+ event,
+ event.unsigned,
+ "invite_room_state",
+ await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ membership_user_id=event.sender,
+ ),
)
invitee = UserID.from_string(event.state_key)
@@ -1728,11 +1766,14 @@ class EventCreationHandler:
event.signatures.update(returned_invite.signatures)
if event.content["membership"] == Membership.KNOCK:
- event.unsigned[
- "knock_room_state"
- ] = await self.store.get_stripped_room_state_from_event_context(
- context,
- self.room_prejoin_state_types,
+ maybe_upsert_event_field(
+ event,
+ event.unsigned,
+ "knock_room_state",
+ await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ ),
)
if event.type == EventTypes.Redaction:
|