diff --git a/changelog.d/14675.misc b/changelog.d/14675.misc
new file mode 100644
index 0000000000..bc1ac1c82a
--- /dev/null
+++ b/changelog.d/14675.misc
@@ -0,0 +1 @@
+Add a class UnpersistedEventContext to allow for the batching up of storing state groups.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6eaef8b57a..e0d82ad81c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple
import attr
@@ -26,8 +27,51 @@ if TYPE_CHECKING:
from synapse.types.state import StateFilter
+class UnpersistedEventContextBase(ABC):
+ """
+ This is a base class for EventContext and UnpersistedEventContext, objects which
+ hold information relevant to storing an associated event. Note that an
+ UnpersistedEventContexts must be converted into an EventContext before it is
+ suitable to send to the db with its associated event.
+
+ Attributes:
+ _storage: storage controllers for interfacing with the database
+ app_service: If the associated event is being sent by a (local) application service, that
+ app service.
+ """
+
+ def __init__(self, storage_controller: "StorageControllers"):
+ self._storage: "StorageControllers" = storage_controller
+ self.app_service: Optional[ApplicationService] = None
+
+ @abstractmethod
+ async def persist(
+ self,
+ event: EventBase,
+ ) -> "EventContext":
+ """
+ A method to convert an UnpersistedEventContext to an EventContext, suitable for
+ sending to the database with the associated event.
+ """
+ pass
+
+ @abstractmethod
+ async def get_prev_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> StateMap[str]:
+ """
+ Gets the room state at the event (ie not including the event if the event is a
+ state event).
+
+ Args:
+ state_filter: specifies the type of state event to fetch from DB, example:
+ EventTypes.JoinRules
+ """
+ pass
+
+
@attr.s(slots=True, auto_attribs=True)
-class EventContext:
+class EventContext(UnpersistedEventContextBase):
"""
Holds information relevant to persisting an event
@@ -77,9 +121,6 @@ class EventContext:
delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
and ``state_group``.
- app_service: If this event is being sent by a (local) application service, that
- app service.
-
partial_state: if True, we may be storing this event with a temporary,
incomplete state.
"""
@@ -122,6 +163,9 @@ class EventContext:
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage)
+ async def persist(self, event: EventBase) -> "EventContext":
+ return self
+
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -254,6 +298,128 @@ class EventContext:
)
+@attr.s(slots=True, auto_attribs=True)
+class UnpersistedEventContext(UnpersistedEventContextBase):
+ """
+ The event context holds information about the state groups for an event. It is important
+ to remember that an event technically has two state groups: the state group before the
+ event, and the state group after the event. If the event is not a state event, the state
+ group will not change (ie the state group before the event will be the same as the state
+ group after the event), but if it is a state event the state group before the event
+ will differ from the state group after the event.
+ This is a version of an EventContext before the new state group (if any) has been
+ computed and stored. It contains information about the state before the event (which
+ also may be the information after the event, if the event is not a state event). The
+ UnpersistedEventContext must be converted into an EventContext by calling the method
+ 'persist' on it before it is suitable to be sent to the DB for processing.
+
+ state_group_after_event:
+ The state group after the event. This will always be None until it is persisted.
+ If the event is not a state event, this will be the same as
+ state_group_before_event.
+
+ state_group_before_event:
+ The ID of the state group representing the state of the room before this event.
+
+ state_delta_due_to_event:
+ If the event is a state event, then this is the delta of the state between
+ `state_group` and `state_group_before_event`
+
+ prev_group_for_state_group_before_event:
+ If it is known, ``state_group_before_event``'s previous state group.
+
+ delta_ids_to_state_group_before_event:
+ If ``prev_group_for_state_group_before_event`` is not None, the state delta
+ between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``.
+
+ partial_state:
+ Whether the event has partial state.
+
+ state_map_before_event:
+ A map of the state before the event, i.e. the state at `state_group_before_event`
+ """
+
+ _storage: "StorageControllers"
+ state_group_before_event: Optional[int]
+ state_group_after_event: Optional[int]
+ state_delta_due_to_event: Optional[dict]
+ prev_group_for_state_group_before_event: Optional[int]
+ delta_ids_to_state_group_before_event: Optional[StateMap[str]]
+ partial_state: bool
+ state_map_before_event: Optional[StateMap[str]] = None
+
+ async def get_prev_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> StateMap[str]:
+ """
+ Gets the room state map, excluding this event.
+
+ Args:
+ state_filter: specifies the type of state event to fetch from DB
+
+ Returns:
+ Maps a (type, state_key) to the event ID of the state event matching
+ this tuple.
+ """
+ if self.state_map_before_event:
+ return self.state_map_before_event
+
+ assert self.state_group_before_event is not None
+ return await self._storage.state.get_state_ids_for_group(
+ self.state_group_before_event, state_filter
+ )
+
+ async def persist(self, event: EventBase) -> EventContext:
+ """
+ Creates a full `EventContext` for the event, persisting any referenced state that
+ has not yet been persisted.
+
+ Args:
+ event: event that the EventContext is associated with.
+
+ Returns: An EventContext suitable for sending to the database with the event
+ for persisting
+ """
+ assert self.partial_state is not None
+
+ # If we have a full set of state for before the event but don't have a state
+ # group for that state, we need to get one
+ if self.state_group_before_event is None:
+ assert self.state_map_before_event
+ state_group_before_event = await self._storage.state.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=self.prev_group_for_state_group_before_event,
+ delta_ids=self.delta_ids_to_state_group_before_event,
+ current_state_ids=self.state_map_before_event,
+ )
+ self.state_group_before_event = state_group_before_event
+
+ # if the event isn't a state event the state group doesn't change
+ if not self.state_delta_due_to_event:
+ state_group_after_event = self.state_group_before_event
+
+ # otherwise if it is a state event we need to get a state group for it
+ else:
+ state_group_after_event = await self._storage.state.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=self.state_group_before_event,
+ delta_ids=self.state_delta_due_to_event,
+ current_state_ids=None,
+ )
+
+ return EventContext.with_state(
+ storage=self._storage,
+ state_group=state_group_after_event,
+ state_group_before_event=self.state_group_before_event,
+ state_delta_due_to_event=self.state_delta_due_to_event,
+ partial_state=self.partial_state,
+ prev_group=self.state_group_before_event,
+ delta_ids=self.state_delta_due_to_event,
+ )
+
+
def _encode_state_dict(
state_dict: Optional[StateMap[str]],
) -> Optional[List[Tuple[str, str, str]]]:
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 72ab696898..97c61cc258 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError
from synapse.api.errors import ModuleFailedException, SynapseError
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import UnpersistedEventContextBase
from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@@ -231,7 +231,9 @@ class ThirdPartyEventRules:
self._on_threepid_bind_callbacks.append(on_threepid_bind)
async def check_event_allowed(
- self, event: EventBase, context: EventContext
+ self,
+ event: EventBase,
+ context: UnpersistedEventContextBase,
) -> Tuple[bool, Optional[dict]]:
"""Check if a provided event should be allowed in the given context.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7f64130e0a..43ed4a3dd1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -56,7 +56,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.http.servlet import assert_params_in_dict
@@ -990,7 +990,10 @@ class FederationHandler:
)
try:
- event, context = await self.event_creation_handler.create_new_client_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
except SynapseError as e:
@@ -998,7 +1001,9 @@ class FederationHandler:
raise
# Ensure the user can even join the room.
- await self._federation_event_handler.check_join_restrictions(context, event)
+ await self._federation_event_handler.check_join_restrictions(
+ unpersisted_context, event
+ )
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
@@ -1178,7 +1183,7 @@ class FederationHandler:
},
)
- event, context = await self.event_creation_handler.create_new_client_event(
+ event, _ = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1228,12 +1233,13 @@ class FederationHandler:
},
)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(builder=builder)
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
- event, context
+ event, unpersisted_context
)
if not event_allowed:
logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1412,20 @@ class FederationHandler:
try:
(
event,
- context,
+ unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
+ (
+ event,
+ unpersisted_context,
+ ) = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, unpersisted_context
)
+ context = await unpersisted_context.persist(event)
+
EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
@@ -1483,14 +1494,19 @@ class FederationHandler:
try:
(
event,
- context,
+ unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
+ (
+ event,
+ unpersisted_context,
+ ) = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, unpersisted_context
)
+ context = await unpersisted_context.persist(event)
+
try:
validate_event_for_room_version(event)
await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1538,8 @@ class FederationHandler:
room_version_obj: RoomVersion,
event_dict: JsonDict,
event: EventBase,
- context: EventContext,
- ) -> Tuple[EventBase, EventContext]:
+ context: UnpersistedEventContextBase,
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1573,14 @@ class FederationHandler:
room_version_obj, event_dict
)
EventValidator().validate_builder(builder)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
+
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(builder=builder)
+
EventValidator().validate_new(event, self.config)
- return event, context
+ return event, unpersisted_context
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e037acbca2..3561f2f1de 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -58,7 +58,7 @@ from synapse.event_auth import (
validate_event_for_room_version,
)
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import (
@@ -426,7 +426,9 @@ class FederationEventHandler:
return event, context
async def check_join_restrictions(
- self, context: EventContext, event: EventBase
+ self,
+ context: UnpersistedEventContextBase,
+ event: EventBase,
) -> None:
"""Check that restrictions in restricted join rules are matched
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5f6da2943f..3e30f52e4d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -48,7 +48,7 @@ from synapse.api.urls import ConsentURIBuilder
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
@@ -708,7 +708,7 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
- event, context = await self.create_new_client_event(
+ event, unpersisted_context = await self.create_new_client_event(
builder=builder,
requester=requester,
allow_no_prev_events=allow_no_prev_events,
@@ -721,6 +721,8 @@ class EventCreationHandler:
current_state_group=current_state_group,
)
+ context = await unpersisted_context.persist(event)
+
# In an ideal world we wouldn't need the second part of this condition. However,
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this
# behaviour. Another reason is that this code is also evaluated each time a new
@@ -1083,13 +1085,14 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event for a local client. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
the event using the parameters state_map and current_state_group, thus these parameters
must be provided in this case if for_batch is True. The subsequently created event
and context are suitable for being batched up and bulk persisted to the database
- with other similarly created events.
+ with other similarly created events. Note that this returns an UnpersistedEventContext,
+ which must be converted to an EventContext before it can be sent to the DB.
Args:
builder:
@@ -1131,7 +1134,7 @@ class EventCreationHandler:
batch persisting
Returns:
- Tuple of created event, context
+ Tuple of created event, UnpersistedEventContext
"""
# Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
@@ -1192,9 +1195,16 @@ class EventCreationHandler:
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
)
- context = await self.state.compute_event_context_for_batched(
- event, state_map, current_state_group
+
+ context: UnpersistedEventContextBase = (
+ await self.state.calculate_context_info(
+ event,
+ state_ids_before_event=state_map,
+ partial_state=False,
+ state_group_before_event=current_state_group,
+ )
)
+
else:
event = await builder.build(
prev_event_ids=prev_event_ids,
@@ -1244,16 +1254,17 @@ class EventCreationHandler:
state_map_for_event[(data.event_type, data.state_key)] = state_id
- context = await self.state.compute_event_context(
+ # TODO(faster_joins): check how MSC2716 works and whether we can have
+ # partial state here
+ # https://github.com/matrix-org/synapse/issues/13003
+ context = await self.state.calculate_context_info(
event,
state_ids_before_event=state_map_for_event,
- # TODO(faster_joins): check how MSC2716 works and whether we can have
- # partial state here
- # https://github.com/matrix-org/synapse/issues/13003
partial_state=False,
)
+
else:
- context = await self.state.compute_event_context(event)
+ context = await self.state.calculate_context_info(event)
if requester:
context.app_service = requester.app_service
@@ -2082,9 +2093,9 @@ class EventCreationHandler:
async def _rebuild_event_after_third_party_rules(
self, third_party_result: dict, original_event: EventBase
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
# the third_party_event_rules want to replace the event.
- # we do some basic checks, and then return the replacement event and context.
+ # we do some basic checks, and then return the replacement event.
# Construct a new EventBuilder and validate it, which helps with the
# rest of these checks.
@@ -2138,5 +2149,6 @@ class EventCreationHandler:
# we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update.
- context = await self.state.compute_event_context(event)
+ context = await self.state.calculate_context_info(event)
+
return event, context
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdfb46ab82..e877e6f1a1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import (
+ EventContext,
+ UnpersistedEventContext,
+ UnpersistedEventContextBase,
+)
from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
@@ -262,31 +266,31 @@ class StateHandler:
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_hosts(room_id, state, entry)
- async def compute_event_context(
+ async def calculate_context_info(
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: Optional[bool] = None,
- ) -> EventContext:
- """Build an EventContext structure for a non-outlier event.
-
- (for an outlier, call EventContext.for_outlier directly)
-
- This works out what the current state should be for the event, and
- generates a new state group if necessary.
-
- Args:
- event:
- state_ids_before_event: The event ids of the state before the event if
- it can't be calculated from existing events. This is normally
- only specified when receiving an event from federation where we
- don't have the prev events, e.g. when backfilling.
- partial_state:
- `True` if `state_ids_before_event` is partial and omits non-critical
- membership events.
- `False` if `state_ids_before_event` is the full state.
- `None` when `state_ids_before_event` is not provided. In this case, the
- flag will be calculated based on `event`'s prev events.
+ state_group_before_event: Optional[int] = None,
+ ) -> UnpersistedEventContextBase:
+ """
+ Calulates the contents of an unpersisted event context, other than the current
+ state group (which is either provided or calculated when the event context is persisted)
+
+ state_ids_before_event:
+ The event ids of the full state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling or when the event
+ is being created for batch persisting.
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
+ state_group_before_event:
+ the current state group at the time of event, if known
Returns:
The event context.
@@ -294,7 +298,6 @@ class StateHandler:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
"""
-
assert not event.internal_metadata.is_outlier()
#
@@ -306,17 +309,6 @@ class StateHandler:
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
- # .. though we need to get a state group for it.
- state_group_before_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=None,
- delta_ids=None,
- current_state_ids=state_ids_before_event,
- )
- )
-
# the partial_state flag must be provided
assert partial_state is not None
else:
@@ -345,6 +337,7 @@ class StateHandler:
logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
# complete state here.
+
entry = await self.resolve_state_groups_for_events(
event.room_id,
event.prev_event_ids(),
@@ -383,18 +376,19 @@ class StateHandler:
#
if not event.is_state():
- return EventContext.with_state(
+ return UnpersistedEventContext(
storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
- state_group=state_group_before_event,
+ state_group_after_event=state_group_before_event,
state_delta_due_to_event={},
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
+ prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+ delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
+ state_map_before_event=state_ids_before_event,
)
#
- # otherwise, we'll need to create a new state group for after the event
+ # otherwise, we'll need to set up creating a new state group for after the event
#
key = (event.type, event.state_key)
@@ -412,88 +406,60 @@ class StateHandler:
delta_ids = {key: event.event_id}
- state_group_after_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=None,
- )
- )
-
- return EventContext.with_state(
+ return UnpersistedEventContext(
storage=self._storage_controllers,
- state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
+ state_group_after_event=None,
state_delta_due_to_event=delta_ids,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
+ prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+ delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
+ state_map_before_event=state_ids_before_event,
)
- async def compute_event_context_for_batched(
+ async def compute_event_context(
self,
event: EventBase,
- state_ids_before_event: StateMap[str],
- current_state_group: int,
+ state_ids_before_event: Optional[StateMap[str]] = None,
+ partial_state: Optional[bool] = None,
) -> EventContext:
- """
- Generate an event context for an event that has not yet been persisted to the
- database. Intended for use with events that are created to be persisted in a batch.
- Args:
- event: the event the context is being computed for
- state_ids_before_event: a state map consisting of the state ids of the events
- created prior to this event.
- current_state_group: the current state group before the event.
- """
- state_group_before_event_prev_group = None
- deltas_to_state_group_before_event = None
-
- state_group_before_event = current_state_group
-
- # if the event is not state, we are set
- if not event.is_state():
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group_before_event=state_group_before_event,
- state_group=state_group_before_event,
- state_delta_due_to_event={},
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- partial_state=False,
- )
+ """Build an EventContext structure for a non-outlier event.
- # otherwise, we'll need to create a new state group for after the event
- key = (event.type, event.state_key)
+ (for an outlier, call EventContext.for_outlier directly)
- if state_ids_before_event is not None:
- replaces = state_ids_before_event.get(key)
+ This works out what the current state should be for the event, and
+ generates a new state group if necessary.
- if replaces and replaces != event.event_id:
- event.unsigned["replaces_state"] = replaces
+ Args:
+ event:
+ state_ids_before_event: The event ids of the state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling.
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
+ entry:
+ A state cache entry for the resolved state across the prev events. We may
+ have already calculated this, so if it's available pass it in
+ Returns:
+ The event context.
- delta_ids = {key: event.event_id}
+ Raises:
+ RuntimeError if `state_ids_before_event` is not provided and one or more
+ prev events are missing or outliers.
+ """
- state_group_after_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=None,
- )
+ unpersisted_context = await self.calculate_context_info(
+ event=event,
+ state_ids_before_event=state_ids_before_event,
+ partial_state=partial_state,
)
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group=state_group_after_event,
- state_group_before_event=state_group_before_event,
- state_delta_due_to_event=delta_ids,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- partial_state=False,
- )
+ return await unpersisted_context.persist(event)
@measure_func()
async def resolve_state_groups_for_events(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 75fc5a17a4..e9be5fb504 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -949,10 +949,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(
self.hs.get_storage_controllers().persistence.persist_event(event, context)
)
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 5c1ced355f..b50406e129 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2934,10 +2934,12 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(storage_controllers.persistence.persist_event(event, context))
# Now get rooms
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
self.get_success(self._persistence.persist_event(event, context))
return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
- event_1, context_1 = self.get_success(
+ event_1, unpersisted_context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
+ context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
self.get_success(self._persistence.persist_event(event_1, context_1))
- event_2, context_2 = self.get_success(
+ event_2, unpersisted_context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
+
+ context_2 = self.get_success(unpersisted_context_2.persist(event_2))
self.get_success(self._persistence.persist_event(event_2, context_2))
# fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- redaction_event, context = self.get_success(
+ redaction_event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(redaction_event))
+
self.get_success(self._persistence.persist_event(redaction_event, context))
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..f730b888f7 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
+
assert self.storage.persistence is not None
self.get_success(self.storage.persistence.persist_event(event, context))
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 1a50c2acf1..a6330ed840 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -92,8 +92,13 @@ async def create_event(
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- event, context = await hs.get_event_creation_handler().create_new_client_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
+ context = await unpersisted_context.persist(event)
+
return event, context
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 875e37988f..36d6b37aa4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -175,9 +175,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
)
@@ -202,9 +203,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
@@ -226,9 +228,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self._storage_controllers.persistence.persist_event(event, context)
diff --git a/tests/utils.py b/tests/utils.py
index d76bf9716a..15fabbc2d0 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -335,6 +335,9 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
},
)
- event, context = await event_creation_handler.create_new_client_event(builder)
+ event, unpersisted_context = await event_creation_handler.create_new_client_event(
+ builder
+ )
+ context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context)
|