diff --git a/changelog.d/12913.misc b/changelog.d/12913.misc
new file mode 100644
index 0000000000..a2bc940557
--- /dev/null
+++ b/changelog.d/12913.misc
@@ -0,0 +1 @@
+Rename storage classes.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7a91544119..b700cbbfa1 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -22,7 +22,7 @@ from synapse.events import EventBase
from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
- from synapse.storage import Storage
+ from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter
@@ -84,7 +84,7 @@ class EventContext:
incomplete state.
"""
- _storage: "Storage"
+ _storage: "StorageControllers"
rejected: Union[Literal[False], str] = False
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
@@ -97,7 +97,7 @@ class EventContext:
@staticmethod
def with_state(
- storage: "Storage",
+ storage: "StorageControllers",
state_group: Optional[int],
state_group_before_event: Optional[int],
state_delta_due_to_event: Optional[StateMap[str]],
@@ -117,7 +117,7 @@ class EventContext:
@staticmethod
def for_outlier(
- storage: "Storage",
+ storage: "StorageControllers",
) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage)
@@ -147,7 +147,7 @@ class EventContext:
}
@staticmethod
- def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
+ def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext":
"""Converts a dict that was produced by `serialize` back into a
EventContext.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5b227b85fd..3ecede22d9 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -109,7 +109,6 @@ class FederationServer(FederationBase):
super().__init__(hs)
self.handler = hs.get_federation_handler()
- self.storage = hs.get_storage()
self._spam_checker = hs.get_spam_checker()
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 50e34743b7..d4fe7df533 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
class AdminHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
async def get_whois(self, user: UserID) -> JsonDict:
connections = []
@@ -197,7 +197,9 @@ class AdminHandler:
from_key = events[-1].internal_metadata.after
- events = await filter_events_for_client(self.storage, user_id, events)
+ events = await filter_events_for_client(
+ self._storage_controllers, user_id, events
+ )
writer.write_events(room_id, events)
@@ -233,7 +235,9 @@ class AdminHandler:
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
- state = await self.state_storage.get_state_for_event(event_id)
+ state = await self._state_storage_controller.get_state_for_event(
+ event_id
+ )
writer.write_state(room_id, event_id, state)
return writer.finished()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 2a56473dc6..72faf2ee38 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -71,7 +71,7 @@ class DeviceWorkerHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
- self.state_storage = hs.get_storage().state
+ self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
@@ -204,7 +204,7 @@ class DeviceWorkerHandler:
continue
# mapping from event_id -> state_dict
- prev_state_ids = await self.state_storage.get_state_ids_for_events(
+ prev_state_ids = await self._state_storage.get_state_ids_for_events(
event_ids
)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index cb7e0ca7a8..ac13340d3a 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -139,7 +139,7 @@ class EventStreamHandler:
class EventHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
async def get_event(
self,
@@ -177,7 +177,7 @@ class EventHandler:
is_peeking = user.to_string() not in users
filtered = await filter_events_for_client(
- self.storage, user.to_string(), [event], is_peeking=is_peeking
+ self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c8233270d7..80ee7e7b4e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -125,8 +125,8 @@ class FederationHandler:
self.hs = hs
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@@ -324,7 +324,7 @@ class FederationHandler:
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
- self.storage,
+ self._storage_controllers,
self.server_name,
events_to_check,
redact=False,
@@ -660,7 +660,7 @@ class FederationHandler:
# in the invitee's sync stream. It is stripped out for all other local users.
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
- context = EventContext.for_outlier(self.storage)
+ context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -849,7 +849,7 @@ class FederationHandler:
)
)
- context = EventContext.for_outlier(self.storage)
+ context = EventContext.for_outlier(self._storage_controllers)
await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -878,7 +878,7 @@ class FederationHandler:
await self.federation_client.send_leave(host_list, event)
- context = EventContext.for_outlier(self.storage)
+ context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
event.room_id, [(event, context)]
)
@@ -1027,7 +1027,7 @@ class FederationHandler:
if event.internal_metadata.outlier:
raise NotFoundError("State not known at event %s" % (event_id,))
- state_groups = await self.state_storage.get_state_groups_ids(
+ state_groups = await self._state_storage_controller.get_state_groups_ids(
room_id, [event_id]
)
@@ -1078,7 +1078,9 @@ class FederationHandler:
],
)
- events = await filter_events_for_server(self.storage, origin, events)
+ events = await filter_events_for_server(
+ self._storage_controllers, origin, events
+ )
return events
@@ -1109,7 +1111,9 @@ class FederationHandler:
if not in_room:
raise AuthError(403, "Host not in room.")
- events = await filter_events_for_server(self.storage, origin, [event])
+ events = await filter_events_for_server(
+ self._storage_controllers, origin, [event]
+ )
event = events[0]
return event
else:
@@ -1138,7 +1142,7 @@ class FederationHandler:
)
missing_events = await filter_events_for_server(
- self.storage, origin, missing_events
+ self._storage_controllers, origin, missing_events
)
return missing_events
@@ -1480,9 +1484,11 @@ class FederationHandler:
# clear the lazy-loading flag.
logger.info("Updating current state for %s", room_id)
assert (
- self.storage.persistence is not None
+ self._storage_controllers.persistence is not None
), "TODO(faster_joins): support for workers"
- await self.storage.persistence.update_current_state(room_id)
+ await self._storage_controllers.persistence.update_current_state(
+ room_id
+ )
logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index a1361af272..b908674529 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -98,8 +98,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
- self._storage = hs.get_storage()
- self._state_storage = self._storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._state_handler = hs.get_state_handler()
self._event_creation_handler = hs.get_event_creation_handler()
@@ -535,7 +535,9 @@ class FederationEventHandler:
)
return
await self._store.update_state_for_partial_state_event(event, context)
- self._state_storage.notify_event_un_partial_stated(event.event_id)
+ self._state_storage_controller.notify_event_un_partial_stated(
+ event.event_id
+ )
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
@@ -835,7 +837,9 @@ class FederationEventHandler:
try:
# Get the state of the events we know about
- ours = await self._state_storage.get_state_groups_ids(room_id, seen)
+ ours = await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen
+ )
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
@@ -1436,7 +1440,7 @@ class FederationEventHandler:
# we're not bothering about room state, so flag the event as an outlier.
event.internal_metadata.outlier = True
- context = EventContext.for_outlier(self._storage)
+ context = EventContext.for_outlier(self._storage_controllers)
try:
validate_event_for_room_version(room_version_obj, event)
check_auth_rules_for_event(room_version_obj, event, auth)
@@ -1613,7 +1617,7 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets_d = await self._state_storage.get_state_groups_ids(
+ state_sets_d = await self._state_storage_controller.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[StateMap[str]] = list(state_sets_d.values())
@@ -1885,7 +1889,7 @@ class FederationEventHandler:
# create a new state group as a delta from the existing one.
prev_group = context.state_group
- state_group = await self._state_storage.store_state_group(
+ state_group = await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -1894,7 +1898,7 @@ class FederationEventHandler:
)
return EventContext.with_state(
- storage=self._storage,
+ storage=self._storage_controllers,
state_group=state_group,
state_group_before_event=context.state_group_before_event,
state_delta_due_to_event=state_updates,
@@ -1984,11 +1988,14 @@ class FederationEventHandler:
)
return result["max_stream_id"]
else:
- assert self._storage.persistence
+ assert self._storage_controllers.persistence
# Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs.
- events, max_stream_token = await self._storage.persistence.persist_events(
+ (
+ events,
+ max_stream_token,
+ ) = await self._storage_controllers.persistence.persist_events(
event_and_contexts, backfilled=backfilled
)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index fbdbeeedfd..d2b489e816 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -67,8 +67,8 @@ class InitialSyncHandler:
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
async def snapshot_all_rooms(
self,
@@ -198,7 +198,8 @@ class InitialSyncHandler:
event.stream_ordering,
)
deferred_room_state = run_in_background(
- self.state_storage.get_state_for_events, [event.event_id]
+ self._state_storage_controller.get_state_for_events,
+ [event.event_id],
).addCallback(
lambda states: cast(StateMap[EventBase], states[event.event_id])
)
@@ -218,7 +219,7 @@ class InitialSyncHandler:
).addErrback(unwrapFirstError)
messages = await filter_events_for_client(
- self.storage, user_id, messages
+ self._storage_controllers, user_id, messages
)
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
@@ -355,7 +356,9 @@ class InitialSyncHandler:
member_event_id: str,
is_peeking: bool,
) -> JsonDict:
- room_state = await self.state_storage.get_state_for_event(member_event_id)
+ room_state = await self._state_storage_controller.get_state_for_event(
+ member_event_id
+ )
limit = pagin_config.limit if pagin_config else None
if limit is None:
@@ -369,7 +372,7 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
- self.storage, user_id, messages, is_peeking=is_peeking
+ self._storage_controllers, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
@@ -474,7 +477,7 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
- self.storage, user_id, messages, is_peeking=is_peeking
+ self._storage_controllers, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 38b71a2c96..f377769071 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -84,8 +84,8 @@ class MessageHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages
@@ -132,7 +132,7 @@ class MessageHandler:
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
- room_state = await self.state_storage.get_state_for_events(
+ room_state = await self._state_storage_controller.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@@ -193,7 +193,7 @@ class MessageHandler:
# check whether the user is in the room at that time to determine
# whether they should be treated as peeking.
- state_map = await self.state_storage.get_state_for_event(
+ state_map = await self._state_storage_controller.get_state_for_event(
last_event.event_id,
StateFilter.from_types([(EventTypes.Member, user_id)]),
)
@@ -206,7 +206,7 @@ class MessageHandler:
is_peeking = not joined
visible_events = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
user_id,
[last_event],
filter_send_to_client=False,
@@ -214,8 +214,10 @@ class MessageHandler:
)
if visible_events:
- room_state_events = await self.state_storage.get_state_for_events(
- [last_event.event_id], state_filter=state_filter
+ room_state_events = (
+ await self._state_storage_controller.get_state_for_events(
+ [last_event.event_id], state_filter=state_filter
+ )
)
room_state: Mapping[Any, EventBase] = room_state_events[
last_event.event_id
@@ -244,8 +246,10 @@ class MessageHandler:
assert (
membership_event_id is not None
), "check_user_in_room_or_world_readable returned invalid data"
- room_state_events = await self.state_storage.get_state_for_events(
- [membership_event_id], state_filter=state_filter
+ room_state_events = (
+ await self._state_storage_controller.get_state_for_events(
+ [membership_event_id], state_filter=state_filter
+ )
)
room_state = room_state_events[membership_event_id]
@@ -402,7 +406,7 @@ class EventCreationHandler:
self.auth = hs.get_auth()
self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
@@ -1032,7 +1036,7 @@ class EventCreationHandler:
# after it is created
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
- context = EventContext.for_outlier(self.storage)
+ context = EventContext.for_outlier(self._storage_controllers)
elif (
event.type == EventTypes.MSC2716_INSERTION
and state_event_ids
@@ -1445,7 +1449,7 @@ class EventCreationHandler:
"""
extra_users = extra_users or []
- assert self.storage.persistence is not None
+ assert self._storage_controllers.persistence is not None
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
@@ -1679,7 +1683,7 @@ class EventCreationHandler:
event,
event_pos,
max_stream_token,
- ) = await self.storage.persistence.persist_event(
+ ) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 35afe6b855..6262a35822 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -129,8 +129,8 @@ class PaginationHandler:
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -352,7 +352,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self.pagination_lock.write(room_id):
- await self.storage.purge_events.purge_history(
+ await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] complete")
@@ -414,7 +414,7 @@ class PaginationHandler:
if joined:
raise SynapseError(400, "Users are still joined to this room")
- await self.storage.purge_events.purge_room(room_id)
+ await self._storage_controllers.purge_events.purge_room(room_id)
async def get_messages(
self,
@@ -529,7 +529,10 @@ class PaginationHandler:
events = await event_filter.filter(events)
events = await filter_events_for_client(
- self.storage, user_id, events, is_peeking=(member_event_id is None)
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
)
# if after the filter applied there are no more events
@@ -550,7 +553,7 @@ class PaginationHandler:
(EventTypes.Member, event.sender) for event in events
)
- state_ids = await self.state_storage.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
@@ -664,7 +667,7 @@ class PaginationHandler:
400, "Users are still joined to this room"
)
- await self.storage.purge_events.purge_room(room_id)
+ await self._storage_controllers.purge_events.purge_room(room_id)
logger.info("complete")
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index ab7e54857d..9a1cc11bb3 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -69,7 +69,7 @@ class BundledAggregations:
class RelationsHandler:
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
- self._storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
@@ -143,7 +143,10 @@ class RelationsHandler:
)
events = await filter_events_for_client(
- self._storage, user_id, events, is_peeking=(member_event_id is None)
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
)
now = self._clock.time_msec()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e2775b34f1..5c91d33f58 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1192,8 +1192,8 @@ class RoomContextHandler:
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self._relations_handler = hs.get_relations_handler()
async def get_event_context(
@@ -1236,7 +1236,10 @@ class RoomContextHandler:
if use_admin_priviledge:
return events
return await filter_events_for_client(
- self.storage, user.to_string(), events, is_peeking=is_peeking
+ self._storage_controllers,
+ user.to_string(),
+ events,
+ is_peeking=is_peeking,
)
event = await self.store.get_event(
@@ -1293,7 +1296,7 @@ class RoomContextHandler:
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = await self.state_storage.get_state_for_events(
+ state = await self._state_storage_controller.get_state_for_events(
[last_event_id], state_filter=state_filter
)
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 7ce32f2e9c..1414e575d6 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -17,7 +17,7 @@ class RoomBatchHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
- self.state_storage = hs.get_storage().state
+ self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -141,7 +141,7 @@ class RoomBatchHandler:
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
assert most_recent_event_id is not None
- prev_state_map = await self.state_storage.get_state_ids_for_event(
+ prev_state_map = await self._state_storage_controller.get_state_ids_for_event(
most_recent_event_id
)
# List of state event ID's
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index e02c915248..659f99f7e2 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -55,8 +55,8 @@ class SearchHandler:
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
self.auth = hs.get_auth()
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
@@ -460,7 +460,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
+ self._storage_controllers, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@@ -559,7 +559,7 @@ class SearchHandler:
filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client(
- self.storage, user.to_string(), filtered_events
+ self._storage_controllers, user.to_string(), filtered_events
)
room_events.extend(events)
@@ -644,11 +644,11 @@ class SearchHandler:
)
events_before = await filter_events_for_client(
- self.storage, user.to_string(), res.events_before
+ self._storage_controllers, user.to_string(), res.events_before
)
events_after = await filter_events_for_client(
- self.storage, user.to_string(), res.events_after
+ self._storage_controllers, user.to_string(), res.events_after
)
context: JsonDict = {
@@ -677,7 +677,7 @@ class SearchHandler:
[(EventTypes.Member, sender) for sender in senders]
)
- state = await self.state_storage.get_state_for_event(
+ state = await self._state_storage_controller.get_state_for_event(
last_event_id, state_filter
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c5c538e0c3..b5859dcb28 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -238,8 +238,8 @@ class SyncHandler:
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
- self.storage = hs.get_storage()
- self.state_storage = self.storage.state
+ self._storage_controllers = hs.get_storage_controllers()
+ self._state_storage_controller = self._storage_controllers.state
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
@@ -512,7 +512,7 @@ class SyncHandler:
current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@@ -580,7 +580,7 @@ class SyncHandler:
current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@@ -630,7 +630,7 @@ class SyncHandler:
event: event of interest
state_filter: The state filter used to fetch state from the database.
"""
- state_ids = await self.state_storage.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
event.event_id, state_filter=state_filter or StateFilter.all()
)
if event.is_state():
@@ -710,7 +710,7 @@ class SyncHandler:
return None
last_event = last_events[-1]
- state_ids = await self.state_storage.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
last_event.event_id,
state_filter=StateFilter.from_types(
[(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -889,13 +889,15 @@ class SyncHandler:
if full_state:
if batch:
current_state_ids = (
- await self.state_storage.get_state_ids_for_event(
+ await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
)
- state_ids = await self.state_storage.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter
+ state_ids = (
+ await self._state_storage_controller.get_state_ids_for_event(
+ batch.events[0].event_id, state_filter=state_filter
+ )
)
else:
@@ -915,7 +917,7 @@ class SyncHandler:
elif batch.limited:
if batch:
state_at_timeline_start = (
- await self.state_storage.get_state_ids_for_event(
+ await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
)
@@ -950,7 +952,7 @@ class SyncHandler:
if batch:
current_state_ids = (
- await self.state_storage.get_state_ids_for_event(
+ await self._state_storage_controller.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter
)
)
@@ -982,7 +984,7 @@ class SyncHandler:
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
- state_ids = await self.state_storage.get_state_ids_for_event(
+ state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c2b66eec62..1100434b3f 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -221,7 +221,7 @@ class Notifier:
self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
self.hs = hs
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
@@ -623,7 +623,7 @@ class Notifier:
if name == "room":
new_events = await filter_events_for_client(
- self.storage,
+ self._storage_controllers,
user.to_string(),
new_events,
is_peeking=is_peeking,
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index d5603596c0..e96fb45e9f 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -65,7 +65,7 @@ class HttpPusher(Pusher):
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
super().__init__(hs, pusher_config)
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
self.app_display_name = pusher_config.app_display_name
self.device_display_name = pusher_config.device_display_name
self.pushkey_ts = pusher_config.ts
@@ -343,7 +343,9 @@ class HttpPusher(Pusher):
}
return d
- ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
+ ctx = await push_tools.get_context_for_event(
+ self._storage_controllers, event, self.user_id
+ )
d = {
"notification": {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 84124af965..63aefd07f5 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -114,10 +114,10 @@ class Mailer:
self.send_email_handler = hs.get_send_email_handler()
self.store = self.hs.get_datastores().main
- self.state_storage = self.hs.get_storage().state
+ self._state_storage_controller = self.hs.get_storage_controllers().state
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.app_name = app_name
self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
@@ -456,7 +456,7 @@ class Mailer:
}
the_events = await filter_events_for_client(
- self.storage, user_id, results.events_before
+ self._storage_controllers, user_id, results.events_before
)
the_events.append(notif_event)
@@ -494,7 +494,7 @@ class Mailer:
)
else:
# Attempt to check the historical state for the room.
- historical_state = await self.state_storage.get_state_for_event(
+ historical_state = await self._state_storage_controller.get_state_for_event(
event.event_id, StateFilter.from_types((type_state_key,))
)
sender_state_event = historical_state.get(type_state_key)
@@ -767,8 +767,10 @@ class Mailer:
member_event_ids.append(sender_state_event_id)
else:
# Attempt to check the historical state for the room.
- historical_state = await self.state_storage.get_state_for_event(
- event_id, StateFilter.from_types((type_state_key,))
+ historical_state = (
+ await self._state_storage_controller.get_state_for_event(
+ event_id, StateFilter.from_types((type_state_key,))
+ )
)
sender_state_event = historical_state.get(type_state_key)
if sender_state_event:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a1bf5b20dd..8397229ccb 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -16,7 +16,7 @@ from typing import Dict
from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
-from synapse.storage import Storage
+from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
@@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
async def get_context_for_event(
- storage: Storage, ev: EventBase, user_id: str
+ storage: StorageControllers, ev: EventBase, user_id: str
) -> Dict[str, str]:
ctx = {}
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 3e7300b4a1..eed29cd597 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
super().__init__(hs)
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self.federation_event_handler = hs.get_federation_event_handler()
@@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event.internal_metadata.outlier = event_payload["outlier"]
context = EventContext.deserialize(
- self.storage, event_payload["context"]
+ self._storage_controllers, event_payload["context"]
)
event_and_contexts.append((event, context))
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index ce78176836..c2b2588ea5 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
@staticmethod
@@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event.internal_metadata.outlier = content["outlier"]
requester = Requester.deserialize(self.store, content["requester"])
- context = EventContext.deserialize(self.storage, content["context"])
+ context = EventContext.deserialize(
+ self._storage_controllers, content["context"]
+ )
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
diff --git a/synapse/server.py b/synapse/server.py
index 3fd23aaf52..a66ec228db 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -123,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import Databases, Storage
+from synapse.storage import Databases
+from synapse.storage.controllers import StorageControllers
from synapse.streams.events import EventSources
from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
@@ -729,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta):
return PasswordPolicyHandler(self)
@cache_in_self
- def get_storage(self) -> Storage:
- return Storage(self, self.get_datastores())
+ def get_storage_controllers(self) -> StorageControllers:
+ return StorageControllers(self, self.get_datastores())
@cache_in_self
def get_replication_streamer(self) -> ReplicationStreamer:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9c9d946f38..bf09f5128a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -127,10 +127,10 @@ class StateHandler:
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- self.state_storage = hs.get_storage().state
+ self._state_storage_controller = hs.get_storage_controllers().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
- self._storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
@overload
async def get_current_state(
@@ -337,12 +337,14 @@ class StateHandler:
#
if not state_group_before_event:
- state_group_before_event = await self.state_storage.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- current_state_ids=state_ids_before_event,
+ state_group_before_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ )
)
# Assign the new state group to the cached state entry.
@@ -359,7 +361,7 @@ class StateHandler:
if not event.is_state():
return EventContext.with_state(
- storage=self._storage,
+ storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
state_group=state_group_before_event,
state_delta_due_to_event={},
@@ -382,16 +384,18 @@ class StateHandler:
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
- state_group_after_event = await self.state_storage.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=state_ids_after_event,
+ state_group_after_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ current_state_ids=state_ids_after_event,
+ )
)
return EventContext.with_state(
- storage=self._storage,
+ storage=self._storage_controllers,
state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
state_delta_due_to_event=delta_ids,
@@ -416,7 +420,9 @@ class StateHandler:
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
- state_groups = await self.state_storage.get_state_group_for_events(event_ids)
+ state_groups = await self._state_storage_controller.get_state_group_for_events(
+ event_ids
+ )
state_group_ids = state_groups.values()
@@ -424,8 +430,13 @@ class StateHandler:
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
- state = await self.state_storage.get_state_for_groups(state_group_ids_set)
- prev_group, delta_ids = await self.state_storage.get_state_group_delta(
+ state = await self._state_storage_controller.get_state_for_groups(
+ state_group_ids_set
+ )
+ (
+ prev_group,
+ delta_ids,
+ ) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)
return _StateCacheEntry(
@@ -439,7 +450,7 @@ class StateHandler:
room_version = await self.store.get_room_version_id(room_id)
- state_to_resolve = await self.state_storage.get_state_for_groups(
+ state_to_resolve = await self._state_storage_controller.get_state_for_groups(
state_group_ids_set
)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 105e4e1fec..bac21ecf9c 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple
databases). The `DatabasePool` class represents connections to a single physical
database. The `databases` are classes that talk directly to a `DatabasePool`
-instance and have associated schemas, background updates, etc. On top of those
-there are classes that provide high level interfaces that combine calls to
-multiple `databases`.
+instance and have associated schemas, background updates, etc.
+
+On top of the databases are the StorageControllers, located in the
+`synapse.storage.controllers` module. These classes provide high level
+interfaces that combine calls to multiple `databases`. They are bundled into the
+`StorageControllers` singleton for ease of use, and exposed via
+`HomeServer.get_storage_controllers()`.
There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
-from typing import TYPE_CHECKING
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
-from synapse.storage.persist_events import EventsPersistenceStorage
-from synapse.storage.purge_events import PurgeEventsStorage
-from synapse.storage.state import StateGroupStorage
-
-if TYPE_CHECKING:
- from synapse.server import HomeServer
-
__all__ = ["Databases", "DataStore"]
-
-
-class Storage:
- """The high level interfaces for talking to various storage layers."""
-
- def __init__(self, hs: "HomeServer", stores: Databases):
- # We include the main data store here mainly so that we don't have to
- # rewrite all the existing code to split it into high vs low level
- # interfaces.
- self.main = stores.main
-
- self.purge_events = PurgeEventsStorage(hs, stores)
- self.state = StateGroupStorage(hs, stores)
-
- self.persistence = None
- if stores.persist_events:
- self.persistence = EventsPersistenceStorage(hs, stores)
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
new file mode 100644
index 0000000000..992261d07b
--- /dev/null
+++ b/synapse/storage/controllers/__init__.py
@@ -0,0 +1,46 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from synapse.storage.controllers.persist_events import (
+ EventsPersistenceStorageController,
+)
+from synapse.storage.controllers.purge_events import PurgeEventsStorageController
+from synapse.storage.controllers.state import StateGroupStorageController
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main import DataStore
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+__all__ = ["Databases", "DataStore"]
+
+
+class StorageControllers:
+ """The high level interfaces for talking to various storage controller layers."""
+
+ def __init__(self, hs: "HomeServer", stores: Databases):
+ # We include the main data store here mainly so that we don't have to
+ # rewrite all the existing code to split it into high vs low level
+ # interfaces.
+ self.main = stores.main
+
+ self.purge_events = PurgeEventsStorageController(hs, stores)
+ self.state = StateGroupStorageController(hs, stores)
+
+ self.persistence = None
+ if stores.persist_events:
+ self.persistence = EventsPersistenceStorageController(hs, stores)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/controllers/persist_events.py
index a21dea91c8..ef8c135b12 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -272,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
pass
-class EventsPersistenceStorage:
+class EventsPersistenceStorageController:
"""High level interface for handling persisting newly received events.
Takes care of batching up events by room, and calculating the necessary
diff --git a/synapse/storage/purge_events.py b/synapse/storage/controllers/purge_events.py
index 30669beb7c..9ca50d6a09 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/controllers/purge_events.py
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class PurgeEventsStorage:
+class PurgeEventsStorageController:
"""High level interface for purging rooms and event history."""
def __init__(self, hs: "HomeServer", stores: Databases):
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
new file mode 100644
index 0000000000..0f09953086
--- /dev/null
+++ b/synapse/storage/controllers/state.py
@@ -0,0 +1,351 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+)
+
+from synapse.events import EventBase
+from synapse.storage.state import StateFilter
+from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.databases import Databases
+
+logger = logging.getLogger(__name__)
+
+
+class StateGroupStorageController:
+ """High level interface to fetching state for event."""
+
+ def __init__(self, hs: "HomeServer", stores: "Databases"):
+ self._is_mine_id = hs.is_mine_id
+ self.stores = stores
+ self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
+
+ def notify_event_un_partial_stated(self, event_id: str) -> None:
+ self._partial_state_events_tracker.notify_un_partial_stated(event_id)
+
+ async def get_state_group_delta(
+ self, state_group: int
+ ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Args:
+ state_group: The state group used to retrieve state deltas.
+
+ Returns:
+ A tuple of the previous group and a state map of the event IDs which
+ make up the delta between the old and new state groups.
+ """
+
+ state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+ return state_group_delta.prev_group, state_group_delta.delta_ids
+
+ async def get_state_groups_ids(
+ self, _room_id: str, event_ids: Collection[str]
+ ) -> Dict[int, MutableStateMap[str]]:
+ """Get the event IDs of all the state for the state groups for the given events
+
+ Args:
+ _room_id: id of the room for these events
+ event_ids: ids of the events
+
+ Returns:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ if not event_ids:
+ return {}
+
+ event_to_groups = await self.get_state_group_for_events(event_ids)
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(groups)
+
+ return group_to_state
+
+ async def get_state_ids_for_group(
+ self, state_group: int, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[str]:
+ """Get the event IDs of all the state in the given state group
+
+ Args:
+ state_group: A state group for which we want to get the state IDs.
+ state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
+
+ Returns:
+ Resolves to a map of (type, state_key) -> event_id
+ """
+ group_to_state = await self.get_state_for_groups((state_group,), state_filter)
+
+ return group_to_state[state_group]
+
+ async def get_state_groups(
+ self, room_id: str, event_ids: Collection[str]
+ ) -> Dict[int, List[EventBase]]:
+ """Get the state groups for the given list of event_ids
+
+ Args:
+ room_id: ID of the room for these events.
+ event_ids: The event IDs to retrieve state for.
+
+ Returns:
+ dict of state_group_id -> list of state events.
+ """
+ if not event_ids:
+ return {}
+
+ group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
+
+ state_event_map = await self.stores.main.get_events(
+ [
+ ev_id
+ for group_ids in group_to_ids.values()
+ for ev_id in group_ids.values()
+ ],
+ get_prev_content=False,
+ )
+
+ return {
+ group: [
+ state_event_map[v]
+ for v in event_id_map.values()
+ if v in state_event_map
+ ]
+ for group, event_id_map in group_to_ids.items()
+ }
+
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ) -> Awaitable[Dict[int, StateMap[str]]]:
+ """Returns the state groups for a given set of groups, filtering on
+ types of state events.
+
+ Args:
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
+ from the database.
+
+ Returns:
+ Dict of state group to state map.
+ """
+
+ return self.stores.state._get_state_groups_from_groups(groups, state_filter)
+
+ async def get_state_for_events(
+ self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
+ ) -> Dict[str, StateMap[EventBase]]:
+ """Given a list of event_ids and type tuples, return a list of state
+ dicts for each event.
+
+ Args:
+ event_ids: The events to fetch the state of.
+ state_filter: The state filter used to fetch state.
+
+ Returns:
+ A dict of (event_id) -> (type, state_key) -> [state_events]
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ await_full_state = True
+ if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ await_full_state = False
+
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ state_event_map = await self.stores.main.get_events(
+ [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
+ get_prev_content=False,
+ )
+
+ event_to_state = {
+ event_id: {
+ k: state_event_map[v]
+ for k, v in group_to_state[group].items()
+ if v in state_event_map
+ }
+ for event_id, group in event_to_groups.items()
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ async def get_state_ids_for_events(
+ self,
+ event_ids: Collection[str],
+ state_filter: Optional[StateFilter] = None,
+ ) -> Dict[str, StateMap[str]]:
+ """
+ Get the state dicts corresponding to a list of events, containing the event_ids
+ of the state events (as opposed to the events themselves)
+
+ Args:
+ event_ids: events whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from event_id -> (type, state_key) -> event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
+ """
+ await_full_state = True
+ if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ await_full_state = False
+
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
+
+ groups = set(event_to_groups.values())
+ group_to_state = await self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ event_to_state = {
+ event_id: group_to_state[group]
+ for event_id, group in event_to_groups.items()
+ }
+
+ return {event: event_to_state[event] for event in event_ids}
+
+ async def get_state_for_event(
+ self, event_id: str, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[EventBase]:
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from (type, state_key) -> state_event
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
+ """
+ state_map = await self.get_state_for_events(
+ [event_id], state_filter or StateFilter.all()
+ )
+ return state_map[event_id]
+
+ async def get_state_ids_for_event(
+ self, event_id: str, state_filter: Optional[StateFilter] = None
+ ) -> StateMap[str]:
+ """
+ Get the state dict corresponding to a particular event
+
+ Args:
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
+
+ Returns:
+ A dict from (type, state_key) -> state_event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
+ """
+ state_map = await self.get_state_ids_for_events(
+ [event_id], state_filter or StateFilter.all()
+ )
+ return state_map[event_id]
+
+ def get_state_for_groups(
+ self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
+ ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups: list of state groups for which we want to get the state.
+ state_filter: The state filter used to fetch state.
+ from the database.
+
+ Returns:
+ Dict of state group to state map.
+ """
+ return self.stores.state._get_state_for_groups(
+ groups, state_filter or StateFilter.all()
+ )
+
+ async def get_state_group_for_events(
+ self,
+ event_ids: Collection[str],
+ await_full_state: bool = True,
+ ) -> Mapping[str, int]:
+ """Returns mapping event_id -> state_group
+
+ Args:
+ event_ids: events to get state groups for
+ await_full_state: if true, will block if we do not yet have complete
+ state at these events.
+ """
+ if await_full_state:
+ await self._partial_state_events_tracker.await_full_state(event_ids)
+
+ return await self.stores.main._get_state_group_for_events(event_ids)
+
+ async def store_state_group(
+ self,
+ event_id: str,
+ room_id: str,
+ prev_group: Optional[int],
+ delta_ids: Optional[StateMap[str]],
+ current_state_ids: StateMap[str],
+ ) -> int:
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id: The event ID for which the state was calculated.
+ room_id: ID of the room for which the state was calculated.
+ prev_group: A previous state group for the room, optional.
+ delta_ids: The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids: The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ The state group ID
+ """
+ return await self.stores.state.store_state_group(
+ event_id, room_id, prev_group, delta_ids, current_state_ids
+ )
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ab630953ac..96aaffb53c 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -15,7 +15,6 @@
import logging
from typing import (
TYPE_CHECKING,
- Awaitable,
Callable,
Collection,
Dict,
@@ -32,15 +31,11 @@ import attr
from frozendict import frozendict
from synapse.api.constants import EventTypes
-from synapse.events import EventBase
-from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.types import MutableStateMap, StateKey, StateMap
if TYPE_CHECKING:
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
- from synapse.server import HomeServer
- from synapse.storage.databases import Databases
logger = logging.getLogger(__name__)
@@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
)
_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
-
-
-class StateGroupStorage:
- """High level interface to fetching state for event."""
-
- def __init__(self, hs: "HomeServer", stores: "Databases"):
- self._is_mine_id = hs.is_mine_id
- self.stores = stores
- self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
-
- def notify_event_un_partial_stated(self, event_id: str) -> None:
- self._partial_state_events_tracker.notify_un_partial_stated(event_id)
-
- async def get_state_group_delta(
- self, state_group: int
- ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
- """Given a state group try to return a previous group and a delta between
- the old and the new.
-
- Args:
- state_group: The state group used to retrieve state deltas.
-
- Returns:
- A tuple of the previous group and a state map of the event IDs which
- make up the delta between the old and new state groups.
- """
-
- state_group_delta = await self.stores.state.get_state_group_delta(state_group)
- return state_group_delta.prev_group, state_group_delta.delta_ids
-
- async def get_state_groups_ids(
- self, _room_id: str, event_ids: Collection[str]
- ) -> Dict[int, MutableStateMap[str]]:
- """Get the event IDs of all the state for the state groups for the given events
-
- Args:
- _room_id: id of the room for these events
- event_ids: ids of the events
-
- Returns:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
-
- Raises:
- RuntimeError if we don't have a state group for one or more of the events
- (ie they are outliers or unknown)
- """
- if not event_ids:
- return {}
-
- event_to_groups = await self.get_state_group_for_events(event_ids)
-
- groups = set(event_to_groups.values())
- group_to_state = await self.stores.state._get_state_for_groups(groups)
-
- return group_to_state
-
- async def get_state_ids_for_group(
- self, state_group: int, state_filter: Optional[StateFilter] = None
- ) -> StateMap[str]:
- """Get the event IDs of all the state in the given state group
-
- Args:
- state_group: A state group for which we want to get the state IDs.
- state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
-
- Returns:
- Resolves to a map of (type, state_key) -> event_id
- """
- group_to_state = await self.get_state_for_groups((state_group,), state_filter)
-
- return group_to_state[state_group]
-
- async def get_state_groups(
- self, room_id: str, event_ids: Collection[str]
- ) -> Dict[int, List[EventBase]]:
- """Get the state groups for the given list of event_ids
-
- Args:
- room_id: ID of the room for these events.
- event_ids: The event IDs to retrieve state for.
-
- Returns:
- dict of state_group_id -> list of state events.
- """
- if not event_ids:
- return {}
-
- group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
-
- state_event_map = await self.stores.main.get_events(
- [
- ev_id
- for group_ids in group_to_ids.values()
- for ev_id in group_ids.values()
- ],
- get_prev_content=False,
- )
-
- return {
- group: [
- state_event_map[v]
- for v in event_id_map.values()
- if v in state_event_map
- ]
- for group, event_id_map in group_to_ids.items()
- }
-
- def _get_state_groups_from_groups(
- self, groups: List[int], state_filter: StateFilter
- ) -> Awaitable[Dict[int, StateMap[str]]]:
- """Returns the state groups for a given set of groups, filtering on
- types of state events.
-
- Args:
- groups: list of state group IDs to query
- state_filter: The state filter used to fetch state
- from the database.
-
- Returns:
- Dict of state group to state map.
- """
-
- return self.stores.state._get_state_groups_from_groups(groups, state_filter)
-
- async def get_state_for_events(
- self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
- ) -> Dict[str, StateMap[EventBase]]:
- """Given a list of event_ids and type tuples, return a list of state
- dicts for each event.
-
- Args:
- event_ids: The events to fetch the state of.
- state_filter: The state filter used to fetch state.
-
- Returns:
- A dict of (event_id) -> (type, state_key) -> [state_events]
-
- Raises:
- RuntimeError if we don't have a state group for one or more of the events
- (ie they are outliers or unknown)
- """
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
- await_full_state = False
-
- event_to_groups = await self.get_state_group_for_events(
- event_ids, await_full_state=await_full_state
- )
-
- groups = set(event_to_groups.values())
- group_to_state = await self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
- )
-
- state_event_map = await self.stores.main.get_events(
- [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
- get_prev_content=False,
- )
-
- event_to_state = {
- event_id: {
- k: state_event_map[v]
- for k, v in group_to_state[group].items()
- if v in state_event_map
- }
- for event_id, group in event_to_groups.items()
- }
-
- return {event: event_to_state[event] for event in event_ids}
-
- async def get_state_ids_for_events(
- self,
- event_ids: Collection[str],
- state_filter: Optional[StateFilter] = None,
- ) -> Dict[str, StateMap[str]]:
- """
- Get the state dicts corresponding to a list of events, containing the event_ids
- of the state events (as opposed to the events themselves)
-
- Args:
- event_ids: events whose state should be returned
- state_filter: The state filter used to fetch state from the database.
-
- Returns:
- A dict from event_id -> (type, state_key) -> event_id
-
- Raises:
- RuntimeError if we don't have a state group for one or more of the events
- (ie they are outliers or unknown)
- """
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
- await_full_state = False
-
- event_to_groups = await self.get_state_group_for_events(
- event_ids, await_full_state=await_full_state
- )
-
- groups = set(event_to_groups.values())
- group_to_state = await self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
- )
-
- event_to_state = {
- event_id: group_to_state[group]
- for event_id, group in event_to_groups.items()
- }
-
- return {event: event_to_state[event] for event in event_ids}
-
- async def get_state_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
- ) -> StateMap[EventBase]:
- """
- Get the state dict corresponding to a particular event
-
- Args:
- event_id: event whose state should be returned
- state_filter: The state filter used to fetch state from the database.
-
- Returns:
- A dict from (type, state_key) -> state_event
-
- Raises:
- RuntimeError if we don't have a state group for the event (ie it is an
- outlier or is unknown)
- """
- state_map = await self.get_state_for_events(
- [event_id], state_filter or StateFilter.all()
- )
- return state_map[event_id]
-
- async def get_state_ids_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
- ) -> StateMap[str]:
- """
- Get the state dict corresponding to a particular event
-
- Args:
- event_id: event whose state should be returned
- state_filter: The state filter used to fetch state from the database.
-
- Returns:
- A dict from (type, state_key) -> state_event_id
-
- Raises:
- RuntimeError if we don't have a state group for the event (ie it is an
- outlier or is unknown)
- """
- state_map = await self.get_state_ids_for_events(
- [event_id], state_filter or StateFilter.all()
- )
- return state_map[event_id]
-
- def get_state_for_groups(
- self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
- ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
- """Gets the state at each of a list of state groups, optionally
- filtering by type/state_key
-
- Args:
- groups: list of state groups for which we want to get the state.
- state_filter: The state filter used to fetch state.
- from the database.
-
- Returns:
- Dict of state group to state map.
- """
- return self.stores.state._get_state_for_groups(
- groups, state_filter or StateFilter.all()
- )
-
- async def get_state_group_for_events(
- self,
- event_ids: Collection[str],
- await_full_state: bool = True,
- ) -> Mapping[str, int]:
- """Returns mapping event_id -> state_group
-
- Args:
- event_ids: events to get state groups for
- await_full_state: if true, will block if we do not yet have complete
- state at these events.
- """
- if await_full_state:
- await self._partial_state_events_tracker.await_full_state(event_ids)
-
- return await self.stores.main._get_state_group_for_events(event_ids)
-
- async def store_state_group(
- self,
- event_id: str,
- room_id: str,
- prev_group: Optional[int],
- delta_ids: Optional[StateMap[str]],
- current_state_ids: StateMap[str],
- ) -> int:
- """Store a new set of state, returning a newly assigned state group.
-
- Args:
- event_id: The event ID for which the state was calculated.
- room_id: ID of the room for which the state was calculated.
- prev_group: A previous state group for the room, optional.
- delta_ids: The delta between state at `prev_group` and
- `current_state_ids`, if `prev_group` was given. Same format as
- `current_state_ids`.
- current_state_ids: The state to store. Map of (type, state_key)
- to event_id.
-
- Returns:
- The state group ID
- """
- return await self.stores.state.store_state_group(
- event_id, room_id, prev_group, delta_ids, current_state_ids
- )
diff --git a/synapse/visibility.py b/synapse/visibility.py
index da4af02796..97548c14e3 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -20,7 +20,7 @@ from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.utils import prune_event
-from synapse.storage import Storage
+from synapse.storage.controllers import StorageControllers
from synapse.storage.state import StateFilter
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
@@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, ""
async def filter_events_for_client(
- storage: Storage,
+ storage: StorageControllers,
user_id: str,
events: List[EventBase],
is_peeking: bool = False,
@@ -268,7 +268,7 @@ async def filter_events_for_client(
async def filter_events_for_server(
- storage: Storage,
+ storage: StorageControllers,
server_name: str,
events: List[EventBase],
redact: bool = True,
@@ -360,7 +360,7 @@ async def filter_events_for_server(
async def _event_to_history_vis(
- storage: Storage, events: Collection[EventBase]
+ storage: StorageControllers, events: Collection[EventBase]
) -> Dict[str, str]:
"""Get the history visibility at each of the given events
@@ -407,7 +407,7 @@ async def _event_to_history_vis(
async def _event_to_memberships(
- storage: Storage, events: Collection[EventBase], server_name: str
+ storage: StorageControllers, events: Collection[EventBase], server_name: str
) -> Dict[str, StateMap[EventBase]]:
"""Get the remote membership list at each of the given events
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index defbc68c18..8ddce83b83 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
@@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store))
- d_context = EventContext.deserialize(self.storage, serialized)
+ d_context = EventContext.deserialize(self._storage_controllers, serialized)
self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ec00900621..500c9ccfbc 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastores().main
- self.state_storage = hs.get_storage().state
+ self.state_storage_controller = hs.get_storage_controllers().state
self._event_auth_handler = hs.get_event_auth_handler()
return hs
@@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# mapping from (type, state_key) -> state_event_id
assert most_recent_prev_event_id is not None
prev_state_map = self.get_success(
- self.state_storage.get_state_ids_for_event(most_recent_prev_event_id)
+ self.state_storage_controller.get_state_ids_for_event(
+ most_recent_prev_event_id
+ )
)
# List of state event ID's
prev_state_ids = list(prev_state_map.values())
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index e64b28f28b..1d5b2492c0 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
) -> None:
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
- state_storage = self.hs.get_storage().state
+ state_storage_controller = self.hs.get_storage_controllers().state
# create the room
user_id = self.register_user("kermit", "test")
@@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
- persistence = self.hs.get_storage().persistence
+ persistence = self.hs.get_storage_controllers().persistence
self.get_success(
persistence.persist_event(
- prev_event, EventContext.for_outlier(self.hs.get_storage())
+ prev_event,
+ EventContext.for_outlier(self.hs.get_storage_controllers()),
)
)
else:
@@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
# check that the state at that event is as expected
state = self.get_success(
- state_storage.get_state_ids_for_event(pulled_event.event_id)
+ state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
)
expected_state = {
(e.type, e.state_key): e.event_id for e in state_at_prev_event
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index f4f7ab4845..44da96c792 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.handler = self.hs.get_event_creation_handler()
- self.persist_event_storage = self.hs.get_storage().persistence
+ self._persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
@@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self._persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
return memberEvent, memberEventContext
@@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
- self.persist_event_storage.persist_event(event3, context)
+ self._persist_event_storage_controller.persist_event(event3, context)
)
# Assert that the returned values match those from the initial event
@@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
- self.persist_event_storage.persist_events([(event3, context)])
+ self._persist_event_storage_controller.persist_events([(event3, context)])
)
ret_event4 = events[0]
@@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id)
events, _ = self.get_success(
- self.persist_event_storage.persist_events(
+ self._persist_event_storage_controller.persist_events(
[(event1, context1), (event2, context2)]
)
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 4d658d29ca..a68c2ffd45 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.hs.get_storage().persistence.persist_event(event, context)
+ self.hs.get_storage_controllers().persistence.persist_event(event, context)
)
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 85be79d19d..c5705256e6 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
self.master_store = hs.get_datastores().main
self.slaved_store = self.worker_hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 297a9e77f8..6d3d4afe52 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
msg, msgctx = self.build_event()
self.get_success(
- self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
+ self._storage_controllers.persistence.persist_events(
+ [(j2, j2ctx), (msg, msgctx)]
+ )
)
self.replicate()
@@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if backfill:
self.get_success(
- self.storage.persistence.persist_events(
+ self._storage_controllers.persistence.persist_events(
[(event, context)], backfilled=True
)
)
else:
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index 5bbbd5fbcb..19f57115a1 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage = self.hs.get_storage().persistence
+ self.persist_event_storage_controller = (
+ self.hs.get_storage_controllers().persistence
+ )
# Create a test user
self.ourUser = UserID.from_string(OUR_USER_ID)
@@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
# Join the second user to the second room
@@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
)
)
self.get_success(
- self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+ self.persist_event_storage_controller.persist_event(
+ memberEvent, memberEventContext
+ )
)
def test_return_empty_with_no_data(self):
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0cdf1dec40..0d44102237 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
- storage = self.hs.get_storage()
+ storage_controllers = self.hs.get_storage_controllers()
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
event_creation_handler.create_new_client_event(builder)
)
- self.get_success(storage.persistence.persist_event(event, context))
+ self.get_success(storage_controllers.persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 2cd7a9e6c5..ac9c113354 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
We do this by setting a very long time between purge jobs.
"""
store = self.hs.get_datastores().main
- storage = self.hs.get_storage()
+ storage_controllers = self.hs.get_storage_controllers()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Send a first event, which should be filtered out at the end of the test.
@@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
- filter_events_for_client(storage, self.user_id, events)
+ filter_events_for_client(storage_controllers, self.user_id, events)
)
# We should only get one event back.
diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py
index 41a1bf6d89..1b7ee08ab2 100644
--- a/tests/rest/client/test_room_batch.py
+++ b/tests/rest/client/test_room_batch.py
@@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.clock = clock
- self.storage = hs.get_storage()
+ self._storage_controllers = hs.get_storage_controllers()
self.virtual_user_id, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
@@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
# Fetch the state_groups
state_group_map = self.get_success(
- self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
+ self._storage_controllers.state.get_state_groups_ids(
+ room_id, historical_event_ids
+ )
)
# We expect all of the historical events to be using the same state_group
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c7661e7186..a0ce077a99 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
# We need to persist the events to the events and state_events
# tables.
persist_events_store._store_event_txn(
- txn, [(e, EventContext(self.hs.get_storage())) for e in events]
+ txn,
+ [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index aaa3189b16..a76718e8f9 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
self.register_user("user", "pass")
@@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
context = self.get_success(
self.state.compute_event_context(event, state_ids_before_event=state)
)
- self.get_success(self.persistence.persist_event(event, context))
+ self.get_success(self._persistence.persist_event(event, context))
def assert_extremities(self, expected_extremities):
"""Assert the current extremities for the room"""
@@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
)
- self.get_success(self.persistence.persist_event(remote_event_2, context))
+ self.get_success(self._persistence.persist_event(remote_event_2, context))
# Check that we haven't dropped the old extremity.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
- self.persistence = self.hs.get_storage().persistence
+ self._persistence = self.hs.get_storage_controllers().persistence
self.store = self.hs.get_datastores().main
def test_remote_user_rooms_cache_invalidated(self):
@@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_rooms_for_user` to add the remote user to the cache
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
@@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
)
context = self.get_success(self.state.compute_event_context(remote_event_1))
- self.get_success(self.persistence.persist_event(remote_event_1, context))
+ self.get_success(self._persistence.persist_event(remote_event_1, context))
# Call `get_users_in_room` to add the remote user to the cache
users = self.get_success(self.store.get_users_in_room(room_id))
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 08cc60237e..92cd0dfc05 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id)
self.store = hs.get_datastores().main
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
def test_purge_history(self):
"""
@@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
self.get_success(
- self.storage.purge_events.purge_history(self.room_id, token_str, True)
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, token_str, True
+ )
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
@@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
# Purge everything before this topological token
f = self.get_failure(
- self.storage.purge_events.purge_history(self.room_id, event, True),
+ self._storage_controllers.purge_events.purge_history(
+ self.room_id, event, True
+ ),
SynapseError,
)
self.assertIn("greater than forward", f.value.args[0])
@@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase):
self.assertIsNotNone(create_event)
# Purge everything before this topological token
- self.get_success(self.storage.purge_events.purge_room(self.room_id))
+ self.get_success(
+ self._storage_controllers.purge_events.purge_room(self.room_id)
+ )
# The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index d8d17ef379..6c4e63b77c 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage = hs.get_storage_controllers()
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(self._storage.persistence.persist_event(event, context))
return event
@@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(self.storage.persistence.persist_event(event_1, context_1))
+ self.get_success(self._storage.persistence.persist_event(event_1, context_1))
event_2, context_2 = self.get_success(
self.event_creation_handler.create_new_client_event(
@@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
)
)
- self.get_success(self.storage.persistence.persist_event(event_2, context_2))
+ self.get_success(self._storage.persistence.persist_event(event_2, context_2))
# fetch one of the redactions
fetched = self.get_success(self.store.get_event(redaction_event_id1))
@@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.storage.persistence.persist_event(redaction_event, context)
+ self._storage.persistence.persist_event(redaction_event, context)
)
# Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 5b011e18cd..d497a19f63 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self._storage = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs):
self.get_success(
- self.storage.persistence.persist_event(
+ self._storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 8dfc1e1db9..e747c6b50e 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
prev_state_map = self.get_success(
- self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
+ self.hs.get_storage_controllers().state.get_state_ids_for_event(
+ prev_event_ids[0]
+ )
)
event_dict = {
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index f88f1c55fc..8043bdbde2 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class StateStoreTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastores().main
- self.storage = hs.get_storage()
+ self.storage = hs.get_storage_controllers()
self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
diff --git a/tests/test_state.py b/tests/test_state.py
index 84694d368d..95f81bebae 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -179,12 +179,12 @@ class Graph:
class StateTestCase(unittest.TestCase):
def setUp(self):
self.dummy_store = _DummyStore()
- storage = Mock(main=self.dummy_store, state=self.dummy_store)
+ storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
hs = Mock(
spec_set=[
"config",
"get_datastores",
- "get_storage",
+ "get_storage_controllers",
"get_auth",
"get_state_handler",
"get_clock",
@@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
- hs.get_storage.return_value = storage
+ hs.get_storage_controllers.return_value = storage_controllers
self.state = StateHandler(hs)
self.event_id = 0
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index c654e36ee4..8027c7a856 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -70,7 +70,7 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- persistence = hs.get_storage().persistence
+ persistence = hs.get_storage_controllers().persistence
assert persistence is not None
await persistence.persist_event(event, context)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 7a9b01ef9d..f338af6c36 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
super(FilterEventsForServerTestCase, self).setUp()
self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory()
- self.storage = self.hs.get_storage()
+ self._storage_controllers = self.hs.get_storage_controllers()
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
events_to_filter.append(evt)
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
outlier = self._inject_outlier()
self.assertEqual(
self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier]
+ )
),
[outlier],
)
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success(
- filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "remote_hs", [outlier, evt]
+ )
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
self.assertEqual(filtered[0], outlier)
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... but other servers should only be able to see the outlier (the other should
# be redacted)
filtered = self.get_success(
- filter_events_for_server(self.storage, "other_server", [outlier, evt])
+ filter_events_for_server(
+ self._storage_controllers, "other_server", [outlier, evt]
+ )
)
self.assertEqual(filtered[0], outlier)
self.assertEqual(filtered[1].event_id, evt.event_id)
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
- filter_events_for_server(self.storage, "test_server", events_to_filter)
+ filter_events_for_server(
+ self._storage_controllers, "test_server", events_to_filter
+ )
)
for i in range(0, len(events_to_filter)):
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event, context = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_room_member(
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_message(
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.event_creation_handler.create_new_client_event(builder)
)
- self.get_success(self.storage.persistence.persist_event(event, context))
+ self.get_success(
+ self._storage_controllers.persistence.persist_event(event, context)
+ )
return event
def _inject_outlier(self) -> EventBase:
@@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
event.internal_metadata.outlier = True
self.get_success(
- self.storage.persistence.persist_event(
- event, EventContext.for_outlier(self.storage)
+ self._storage_controllers.persistence.persist_event(
+ event, EventContext.for_outlier(self._storage_controllers)
)
)
return event
@@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@user:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@user:test",
+ [invite_event, reject_event],
)
),
[invite_event, reject_event],
@@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_client(
- self.hs.get_storage(), "@other:test", [invite_event, reject_event]
+ self.hs.get_storage_controllers(),
+ "@other:test",
+ [invite_event, reject_event],
)
),
[],
diff --git a/tests/utils.py b/tests/utils.py
index d4ba3a9b99..3059c453d5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -264,7 +264,7 @@ class MockClock:
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room"""
- persistence_store = hs.get_storage().persistence
+ persistence_store = hs.get_storage_controllers().persistence
store = hs.get_datastores().main
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
|