diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e0d82ad81c..a91a5d1e3c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -23,6 +23,7 @@ from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage.controllers import StorageControllers
+ from synapse.storage.databases import StateGroupDataStore
from synapse.storage.databases.main import DataStore
from synapse.types.state import StateFilter
@@ -348,6 +349,54 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
partial_state: bool
state_map_before_event: Optional[StateMap[str]] = None
+ @classmethod
+ async def batch_persist_unpersisted_contexts(
+ cls,
+ events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
+ room_id: str,
+ last_known_state_group: int,
+ datastore: "StateGroupDataStore",
+ ) -> List[Tuple[EventBase, EventContext]]:
+ """
+ Takes a list of events and their associated unpersisted contexts and persists
+ the unpersisted contexts, returning a list of events and persisted contexts.
+ Note that all the events must be in a linear chain (i.e. a <- b <- c).
+
+ Args:
+ events_and_context: A list of events and their unpersisted contexts
+ room_id: the room_id for the events
+ last_known_state_group: the last persisted state group
+ datastore: the state group datastore, used to store the generated state deltas
+ """
+ amended_events_and_context = await datastore.store_state_deltas_for_batched(
+ events_and_context, room_id, last_known_state_group
+ )
+
+ events_and_persisted_context = []
+ for event, unpersisted_context in amended_events_and_context:
+ if event.is_state():
+ context = EventContext(
+ storage=unpersisted_context._storage,
+ state_group=unpersisted_context.state_group_after_event,
+ state_group_before_event=unpersisted_context.state_group_before_event,
+ state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+ partial_state=unpersisted_context.partial_state,
+ prev_group=unpersisted_context.state_group_before_event,
+ delta_ids=unpersisted_context.state_delta_due_to_event,
+ )
+ else:
+ context = EventContext(
+ storage=unpersisted_context._storage,
+ state_group=unpersisted_context.state_group_after_event,
+ state_group_before_event=unpersisted_context.state_group_before_event,
+ state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+ partial_state=unpersisted_context.partial_state,
+ prev_group=unpersisted_context.prev_group_for_state_group_before_event,
+ delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
+ )
+ events_and_persisted_context.append((event, context))
+ return events_and_persisted_context
+
async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:
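
Illustrative aside, not part of the patch: batch_persist_unpersisted_contexts relies on the batched events forming a linear chain, so each state event gets a fresh state group whose previous group is the group after the preceding event, while non-state events simply reuse the group they were created under. The standalone sketch below models that assignment; FakeEvent, assign_state_groups and the plain integer counter are assumptions for illustration (Synapse allocates group ids from a database sequence).

# Standalone sketch of state-group assignment for a linear chain of batched
# events. Not Synapse code: names and the simple counter are illustrative.
from dataclasses import dataclass
from itertools import count
from typing import List, Tuple


@dataclass
class FakeEvent:
    event_id: str
    is_state: bool


def assign_state_groups(
    events: List[FakeEvent], last_known_state_group: int
) -> List[Tuple[str, int, int]]:
    """Return (event_id, group_before_event, group_after_event) per event."""
    new_ids = count(last_known_state_group + 1)
    before = last_known_state_group
    out = []
    for ev in events:
        # Only state events create a new state group; non-state events leave
        # the room state, and therefore the state group, unchanged.
        after = next(new_ids) if ev.is_state else before
        out.append((ev.event_id, before, after))
        before = after
    return out


print(assign_state_groups(
    [FakeEvent("$a", True), FakeEvent("$b", False), FakeEvent("$c", True)],
    last_known_state_group=7,
))
# -> [('$a', 7, 8), ('$b', 8, 8), ('$c', 8, 9)]
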
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index aa90d0000d..e433d6b01f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -574,7 +574,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
@@ -721,8 +721,6 @@ class EventCreationHandler:
current_state_group=current_state_group,
)
- context = await unpersisted_context.persist(event)
-
# In an ideal world we wouldn't need the second part of this condition. However,
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this
# behaviour. Another reason is that this code is also evaluated each time a new
@@ -739,7 +737,7 @@ class EventCreationHandler:
assert state_map is not None
prev_event_id = state_map.get((EventTypes.Member, event.sender))
else:
- prev_state_ids = await context.get_prev_state_ids(
+ prev_state_ids = await unpersisted_context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@@ -764,8 +762,7 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
-
- return event, context
+ return event, unpersisted_context
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@@ -1005,7 +1002,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.create_event(
+ event, unpersisted_context = await self.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1016,6 +1013,7 @@ class EventCreationHandler:
historical=historical,
depth=depth,
)
+ context = await unpersisted_context.persist(event)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
event.sender,
@@ -1190,7 +1188,6 @@ class EventCreationHandler:
if for_batch:
assert prev_event_ids is not None
assert state_map is not None
- assert current_state_group is not None
auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
@@ -2046,7 +2043,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.create_event(
+ event, unpersisted_context = await self.create_event(
requester,
{
"type": EventTypes.Dummy,
@@ -2055,6 +2052,7 @@ class EventCreationHandler:
"sender": user_id,
},
)
+ context = await unpersisted_context.persist(event)
event.internal_metadata.proactively_send = False
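
Illustrative aside, not part of the patch: since create_event now returns an UnpersistedEventContextBase, each call site decides when to turn it into a full EventContext by awaiting persist(event), as the hunks above do. Below is a minimal sketch of the single-event caller pattern; create_and_persist_event and the handler argument are assumptions for illustration.

# Hypothetical helper showing the new caller pattern; `handler` stands in for
# any object with an event_creation_handler, as in the hunks above.
async def create_and_persist_event(handler, requester, event_dict):
    # create_event now returns the event plus an *unpersisted* context ...
    event, unpersisted_context = await handler.event_creation_handler.create_event(
        requester, event_dict
    )
    # ... which is only persisted once a full EventContext is actually needed.
    context = await unpersisted_context.persist(event)
    return event, context
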
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a26ec02284..b1784638f4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -51,6 +51,7 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext
from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
@@ -211,7 +212,7 @@ class RoomCreationHandler:
# the required power level to send the tombstone event.
(
tombstone_event,
- tombstone_context,
+ tombstone_unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
{
@@ -225,6 +226,9 @@ class RoomCreationHandler:
},
},
)
+ tombstone_context = await tombstone_unpersisted_context.persist(
+ tombstone_event
+ )
validate_event_for_room_version(tombstone_event)
await self._event_auth_handler.check_auth_rules_from_context(
tombstone_event
@@ -1092,7 +1096,7 @@ class RoomCreationHandler:
content: JsonDict,
for_batch: bool,
**kwargs: Any,
- ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+ ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
"""
Creates an event and associated event context.
Args:
@@ -1111,20 +1115,23 @@ class RoomCreationHandler:
event_dict = create_event_dict(etype, content, **kwargs)
- new_event, new_context = await self.event_creation_handler.create_event(
+ (
+ new_event,
+ new_unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
creator,
event_dict,
prev_event_ids=prev_event,
depth=depth,
state_map=state_map,
for_batch=for_batch,
- current_state_group=current_state_group,
)
+
depth += 1
prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id
- return new_event, new_context
+ return new_event, new_unpersisted_context
try:
config = self._presets_dict[preset_config]
@@ -1134,10 +1141,10 @@ class RoomCreationHandler:
)
creation_content.update({"creator": creator_id})
- creation_event, creation_context = await create_event(
+ creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False
)
-
+ creation_context = await unpersisted_creation_context.persist(creation_event)
logger.debug("Sending %s in new room", EventTypes.Member)
ev = await self.event_creation_handler.handle_new_client_event(
requester=creator,
@@ -1181,7 +1188,6 @@ class RoomCreationHandler:
power_event, power_context = await create_event(
EventTypes.PowerLevels, pl_content, True
)
- current_state_group = power_context._state_group
events_to_send.append((power_event, power_context))
else:
power_level_content: JsonDict = {
@@ -1230,14 +1236,12 @@ class RoomCreationHandler:
power_level_content,
True,
)
- current_state_group = pl_context._state_group
events_to_send.append((pl_event, pl_context))
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
)
- current_state_group = room_alias_context._state_group
events_to_send.append((room_alias_event, room_alias_context))
if (EventTypes.JoinRules, "") not in initial_state:
@@ -1246,7 +1250,6 @@ class RoomCreationHandler:
{"join_rule": config["join_rules"]},
True,
)
- current_state_group = join_rules_context._state_group
events_to_send.append((join_rules_event, join_rules_context))
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
@@ -1255,7 +1258,6 @@ class RoomCreationHandler:
{"history_visibility": config["history_visibility"]},
True,
)
- current_state_group = visibility_context._state_group
events_to_send.append((visibility_event, visibility_context))
if config["guest_can_join"]:
@@ -1265,14 +1267,12 @@ class RoomCreationHandler:
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True,
)
- current_state_group = guest_access_context._state_group
events_to_send.append((guest_access_event, guest_access_context))
for (etype, state_key), content in initial_state.items():
event, context = await create_event(
etype, content, True, state_key=state_key
)
- current_state_group = context._state_group
events_to_send.append((event, context))
if config["encrypted"]:
@@ -1284,9 +1284,16 @@ class RoomCreationHandler:
)
events_to_send.append((encryption_event, encryption_context))
+ datastore = self.hs.get_datastores().state
+ events_and_context = (
+ await UnpersistedEventContext.batch_persist_unpersisted_contexts(
+ events_to_send, room_id, current_state_group, datastore
+ )
+ )
+
last_event = await self.event_creation_handler.handle_new_client_event(
creator,
- events_to_send,
+ events_and_context,
ignore_shadow_ban=True,
ratelimit=False,
)
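
Illustrative aside, not part of the patch: during room creation the contexts for the initial events are now left unpersisted while the chain is built, and their state deltas are persisted in a single call before the batch is sent. The condensed sketch below mirrors the hunk above; persist_and_send and its parameters are assumptions for illustration.

# Condensed sketch of the batched flow: events_to_send holds
# (event, unpersisted_context) pairs built up for a linear chain, and
# current_state_group is the last persisted state group before the batch.
from synapse.events.snapshot import UnpersistedEventContext


async def persist_and_send(
    hs, creator, events_to_send, room_id, current_state_group, event_creation_handler
):
    datastore = hs.get_datastores().state
    events_and_context = (
        await UnpersistedEventContext.batch_persist_unpersisted_contexts(
            events_to_send, room_id, current_state_group, datastore
        )
    )
    # All contexts are now persisted, so the whole batch can be sent at once.
    return await event_creation_handler.handle_new_client_event(
        creator, events_and_context, ignore_shadow_ban=True, ratelimit=False
    )
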
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 5d4ca0e2d2..bf9df60218 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,7 @@ class RoomBatchHandler:
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
- event, context = await self.event_creation_handler.create_event(
+ event, unpersisted_context = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service
),
@@ -345,7 +345,7 @@ class RoomBatchHandler:
historical=True,
depth=inherited_depth,
)
-
+ context = await unpersisted_context.persist(event)
assert context._state_group
# Normally this is done when persisting the event but we have to
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a965c7ec76..de7476f300 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.event_creation_handler.create_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier=outlier,
historical=historical,
)
-
+ context = await unpersisted_context.persist(event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.event_creation_handler.create_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
auth_event_ids=auth_event_ids,
outlier=True,
)
+ context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True
result_event = (
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 89b1faa6c8..bf4cdfdf29 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
import attr
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -401,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
+ async def store_state_deltas_for_batched(
+ self,
+ events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
+ room_id: str,
+ prev_group: int,
+ ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+ """Generate and store state deltas for a group of events and contexts created to be
+ batch persisted. Note that all the events must be in a linear chain (i.e. a <- b <- c).
+
+ Args:
+ events_and_context: the events to generate and store state groups for,
+ and their associated contexts
+ room_id: the id of the room the events were created for
+ prev_group: the state group of the last event persisted before the batched events
+ were created
+ """
+
+ def insert_deltas_group_txn(
+ txn: LoggingTransaction,
+ events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
+ prev_group: int,
+ ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+ """Generate and store state groups for the provided events and contexts.
+
+ Requires that we have the state as a delta from the last persisted state group.
+
+ Returns:
+ The given list of events and their contexts, with state groups assigned
+ """
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ num_state_groups = sum(
+ 1 for event, _ in events_and_context if event.is_state()
+ )
+
+ state_groups = self._state_group_seq_gen.get_next_mult_txn(
+ txn, num_state_groups
+ )
+
+ sg_before = prev_group
+ state_group_iter = iter(state_groups)
+ for event, context in events_and_context:
+ if not event.is_state():
+ context.state_group_after_event = sg_before
+ context.state_group_before_event = sg_before
+ continue
+
+ sg_after = next(state_group_iter)
+ context.state_group_after_event = sg_after
+ context.state_group_before_event = sg_before
+ context.state_delta_due_to_event = {
+ (event.type, event.state_key): event.event_id
+ }
+ sg_before = sg_after
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups",
+ keys=("id", "room_id", "event_id"),
+ values=[
+ (context.state_group_after_event, room_id, event.event_id)
+ for event, context in events_and_context
+ if event.is_state()
+ ],
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_group_edges",
+ keys=("state_group", "prev_state_group"),
+ values=[
+ (
+ context.state_group_after_event,
+ context.state_group_before_event,
+ )
+ for event, context in events_and_context
+ if event.is_state()
+ ],
+ )
+
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ keys=("state_group", "room_id", "type", "state_key", "event_id"),
+ values=[
+ (
+ context.state_group_after_event,
+ room_id,
+ key[0],
+ key[1],
+ state_id,
+ )
+ for event, context in events_and_context
+ if context.state_delta_due_to_event is not None
+ for key, state_id in context.state_delta_due_to_event.items()
+ ],
+ )
+ return events_and_context
+
+ return await self.db_pool.runInteraction(
+ "store_state_deltas_for_batched.insert_deltas_group",
+ insert_deltas_group_txn,
+ events_and_context,
+ prev_group,
+ )
+
async def store_state_group(
self,
event_id: str,
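
Illustrative aside, not part of the patch: for every state event in the batch, insert_deltas_group_txn writes one row into each of state_groups (the new group), state_group_edges (a link back to the previous group) and state_groups_state (the single-entry delta the event contributes). Below is a toy illustration of the rows produced for a chain of two state events following prev_group 10; all room, event and group ids are invented.

# Toy, in-memory illustration (not Synapse code) of the rows written by
# insert_deltas_group_txn for two state events batched after prev_group=10.
prev_group = 10

state_groups_rows = [
    # (id, room_id, event_id): one row per *state* event in the batch
    (11, "!room:example.org", "$power_levels"),
    (12, "!room:example.org", "$join_rules"),
]

state_group_edges_rows = [
    # (state_group, prev_state_group): each new group points at the group
    # that was current just before its event
    (11, prev_group),
    (12, 11),
]

state_groups_state_rows = [
    # (state_group, room_id, type, state_key, event_id): the one-entry delta
    # recorded in state_delta_due_to_event for each state event
    (11, "!room:example.org", "m.room.power_levels", "", "$power_levels"),
    (12, "!room:example.org", "m.room.join_rules", "", "$join_rules"),
]
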