From 73cf63784b90ea194eb867aafe3f39203b7ae029 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Oct 2019 16:14:16 +0100 Subject: Add DataStores and Storage classes. --- synapse/storage/__init__.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) (limited to 'synapse/storage/__init__.py') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a249ecd219..a58187a76f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -27,7 +27,24 @@ data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from synapse.storage.data_stores.main import DataStore # noqa: F401 +from synapse.storage.data_stores import DataStores +from synapse.storage.data_stores.main import DataStore +from synapse.storage.persist_events import EventsPersistenceStore + +__all__ = ["DataStores", "DataStore"] + + +class Storage(object): + """The high level interfaces for talking to various storage layers. + """ + + def __init__(self, hs, stores: DataStores): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.persistence = EventsPersistenceStore(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): -- cgit 1.5.1 From a8d16f6c00d5adb204af5fa73ffaea40eea4b632 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 13:33:38 +0000 Subject: Review comments --- synapse/server.py | 5 ++--- synapse/storage/__init__.py | 4 ++-- synapse/storage/data_stores/main/events.py | 19 ++++++++++++++----- synapse/storage/persist_events.py | 13 ++++++++++--- tests/crypto/test_keyring.py | 4 ++-- tests/test_federation.py | 11 +++++++---- 6 files changed, 37 insertions(+), 19 deletions(-) (limited to 'synapse/storage/__init__.py') diff --git a/synapse/server.py b/synapse/server.py index 54a7f4aa5f..0b81af646c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -226,7 +226,6 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.datastore = None self.datastores = None # Other kwargs are explicit dependencies @@ -236,8 +235,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - self.datastore = self.DATASTORE_CLASS(conn, self) - self.datastores = DataStores(self.datastore, conn, self) + datastore = self.DATASTORE_CLASS(conn, self) + self.datastores = DataStores(datastore, conn, self) conn.commit() logger.info("Finished setting up.") diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a58187a76f..a6429d17ed 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -29,7 +29,7 @@ stored in `synapse.storage.schema`. from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main import DataStore -from synapse.storage.persist_events import EventsPersistenceStore +from synapse.storage.persist_events import EventsPersistenceStorage __all__ = ["DataStores", "DataStore"] @@ -44,7 +44,7 @@ class Storage(object): # interfaces. 
self.main = stores.main - self.persistence = EventsPersistenceStore(hs, stores) + self.persistence = EventsPersistenceStorage(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 6304531cd5..813f34528c 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -146,7 +146,7 @@ class EventsStore( @_retry_on_integrity_error @defer.inlineCallbacks - def _persist_events( + def _persist_events_and_state_updates( self, events_and_contexts, current_state_for_room, @@ -155,18 +155,27 @@ class EventsStore( backfilled=False, delete_existing=False, ): - """Persist events to db + """Persist a set of events alongside updates to the current state and + forward extremities tables. Args: events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): + current_state_for_room (dict[str, dict]): Map from room_id to the + current state of the room based on forward extremities + state_delta_for_room (dict[str, tuple]): Map from room_id to tuple + of `(to_delete, to_insert)` where to_delete is a list + of type/state keys to remove from current state, and to_insert + is a map (type,key)->event_id giving the state delta in each + room. + new_forward_extremities (dict[str, list[str]]): Map from room_id + to list of event IDs that are the new forward extremities of + the room. + backfilled (bool) delete_existing (bool): Returns: Deferred: resolves when the events have been persisted """ - if not events_and_contexts: - return # We want to calculate the stream orderings as late as possible, as # we only notify after all events with a lesser stream ordering have diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 9a63953d4d..cf66225574 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -171,7 +171,13 @@ class _EventPeristenceQueue(object): pass -class EventsPersistenceStore(object): +class EventsPersistenceStorage(object): + """High level interface for handling persisting newly received events. + + Takes care of batching up events by room, and calculating the necessary + current state and forward extremity changes. + """ + def __init__(self, hs, stores: DataStores): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same @@ -257,7 +263,8 @@ class EventsPersistenceStore(object): def _persist_events( self, events_and_contexts, backfilled=False, delete_existing=False ): - """Persist events to db + """Calculates the change to current state and forward extremities, and + persists the given events and with those updates. 
Args: events_and_contexts (list[(EventBase, EventContext)]): @@ -399,7 +406,7 @@ class EventsPersistenceStore(object): if current_state is not None: current_state_for_room[room_id] = current_state - yield self.main_store._persist_events( + yield self.main_store._persist_events_and_state_updates( chunk, current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index c4f0bbd3dd..8efd39c7f7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -178,7 +178,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))], @@ -209,7 +209,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): ) key1 = signedjson.key.generate_signing_key(1) - r = self.hs.datastore.store_server_verify_keys( + r = self.hs.get_datastore().store_server_verify_keys( "server9", time.time() * 1000, [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))], diff --git a/tests/test_federation.py b/tests/test_federation.py index a73f18f88e..d1acb16f30 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -36,7 +36,8 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] @@ -75,7 +76,8 @@ class MessageAcceptTests(unittest.TestCase): self.assertEqual( self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0], "$join:test.serv", @@ -97,7 +99,8 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, + self.room_id, ) )[0] @@ -137,6 +140,6 @@ class MessageAcceptTests(unittest.TestCase): # Make sure the invalid event isn't there extrem = maybeDeferred( - self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id ) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") -- cgit 1.5.1 From 5db03535d587f9a3dc460ad9c015ab6a35ff5730 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 30 Oct 2019 14:07:48 +0000 Subject: Add StateGroupStorage interface --- synapse/storage/__init__.py | 2 + synapse/storage/persist_events.py | 2 +- synapse/storage/state.py | 232 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 1 deletion(-) (limited to 'synapse/storage/__init__.py') diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a6429d17ed..0a1a8cc1e5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,6 +30,7 @@ stored in `synapse.storage.schema`. 
from synapse.storage.data_stores import DataStores from synapse.storage.data_stores.main import DataStore from synapse.storage.persist_events import EventsPersistenceStorage +from synapse.storage.state import StateGroupStorage __all__ = ["DataStores", "DataStore"] @@ -45,6 +46,7 @@ class Storage(object): self.main = stores.main self.persistence = EventsPersistenceStorage(hs, stores) + self.state = StateGroupStorage(hs, stores) def are_all_users_on_domain(txn, database_engine, domain): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index cf66225574..61dfca8fdc 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -550,7 +550,7 @@ class EventsPersistenceStorage(object): if missing_event_ids: # Now pull out the state groups for any missing events from DB - event_to_groups = yield self.state_store._get_state_group_for_events( + event_to_groups = yield self.main_store._get_state_group_for_events( missing_event_ids ) event_id_to_state_group.update(event_to_groups) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index a2df8fa827..b382a06dcc 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -19,6 +19,8 @@ from six import iteritems, itervalues import attr +from twisted.internet import defer + from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -322,3 +324,233 @@ class StateFilter(object): ) return member_filter, non_member_filter + + +class StateGroupStorage(object): + """High level interface to fetching state for event. + """ + + def __init__(self, hs, stores): + self.stores = stores + + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + return self.stores.main.get_state_group_delta(state_group) + + @defer.inlineCallbacks + def get_state_groups_ids(self, _room_id, event_ids): + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id (str): id of the room for these events + event_ids (iterable[str]): ids of the events + + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + if not event_ids: + return {} + + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups(groups) + + return group_to_state + + @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the event IDs of all the state in the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + return group_to_state[state_group] + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + Returns: + Deferred[dict[int, list[EventBase]]]: + dict of state_group_id -> list of state events. 
+ """ + if not event_ids: + return {} + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.stores.main.get_events( + [ + ev_id + for group_ids in itervalues(group_to_ids) + for ev_id in itervalues(group_ids) + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in itervalues(event_id_map) + if v in state_event_map + ] + for group, event_id_map in iteritems(group_to_ids) + } + + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + return self.stores.main._get_state_groups_from_groups(groups, state_filter) + + @defer.inlineCallbacks + def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + Args: + event_ids (list[string]) + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + """ + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups( + groups, state_filter + ) + + state_event_map = yield self.stores.main.get_events( + [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in iteritems(group_to_state[group]) + if v in state_event_map + } + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids(list(str)): events whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from event_id -> (type, state_key) -> event_id + """ + event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + + groups = set(itervalues(event_to_groups)) + group_to_state = yield self.stores.main._get_state_for_groups( + groups, state_filter + ) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in iteritems(event_to_groups) + } + + return {event: event_to_state[event] for event in event_ids} + + @defer.inlineCallbacks + def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. 
+ + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_for_events([event_id], state_filter) + return state_map[event_id] + + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], state_filter) + return state_map[event_id] + + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + return self.stores.main._get_state_for_groups(groups, state_filter) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + return self.stores.main.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids + ) -- cgit 1.5.1
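
The three commits above replace the single monolithic `DataStore` entry point with a layered arrangement: `DataStores` holds the individual data stores, while the new `Storage` class exposes the high-level interfaces (`main`, `persistence`, and now `state`). Below is a rough usage sketch, not part of the patch; it assumes the `hs.datastores` attribute set up in the `server.py` hunk, and it assumes `EventsPersistenceStorage` grows a public `persist_events()` entry point, which this diff does not show.

from twisted.internet import defer

from synapse.storage import Storage


@defer.inlineCallbacks
def example(hs):
    # hs.datastores is populated in HomeServer.setup() (see the server.py hunk).
    # In practice the HomeServer would be expected to build the Storage object
    # itself rather than callers constructing one directly.
    storage = Storage(hs, hs.datastores)

    # Existing callers keep talking to the monolithic main data store.
    main_store = storage.main  # noqa: F841

    # High-level state access now goes through StateGroupStorage: this returns
    # a map of (type, state_key) -> event_id for the state at the given event.
    state_ids = yield storage.state.get_state_ids_for_event("$some_event_id")

    # Event persistence (batching per room, current-state and forward-extremity
    # updates) is owned by EventsPersistenceStorage. The public entry point is
    # assumed to mirror the old EventsStore method, e.g.:
    #   yield storage.persistence.persist_events(events_and_contexts)

    return state_ids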
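
The expanded docstring on `_persist_events_and_state_updates` (second commit) describes a couple of nested argument shapes that are easier to see with concrete values. The room and event IDs below are made-up placeholders, purely illustrating the documented structure.

# state_delta_for_room: room_id -> (to_delete, to_insert), where to_delete is a
# list of (type, state_key) pairs removed from current state and to_insert maps
# (type, state_key) -> event_id for entries added to current state.
state_delta_for_room = {
    "!room:example.com": (
        [("m.room.member", "@alice:example.com")],
        {("m.room.member", "@bob:example.com"): "$bob_join_event"},
    )
}

# new_forward_extremities: room_id -> list of event IDs that become the new
# forward extremities of the room after this batch is persisted.
new_forward_extremities = {
    "!room:example.com": ["$latest_event_in_room"],
}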
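
All of the new `StateGroupStorage` read methods accept an optional `StateFilter`, so callers can avoid loading the full state dict. A hedged sketch of restricting a fetch to membership events follows; it assumes the existing `StateFilter.from_types()` helper (defined earlier in `synapse/storage/state.py` but not shown in this diff) is unchanged by the patch.

from twisted.internet import defer

from synapse.api.constants import EventTypes
from synapse.storage.state import StateFilter


@defer.inlineCallbacks
def get_membership_events(storage, event_id):
    # Only fetch m.room.member entries from the state at this event, rather
    # than the full state dict.
    state_filter = StateFilter.from_types([(EventTypes.Member, None)])

    state = yield storage.state.get_state_for_events(
        [event_id], state_filter=state_filter
    )
    # get_state_for_events returns event_id -> (type, state_key) -> event.
    return state[event_id]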