diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 1 | ||||
-rw-r--r-- | synapse/storage/engines/postgres.py | 6 | ||||
-rw-r--r-- | synapse/storage/engines/sqlite3.py | 19 | ||||
-rw-r--r-- | synapse/storage/events.py | 188 | ||||
-rw-r--r-- | synapse/storage/room.py | 153 | ||||
-rw-r--r-- | synapse/storage/schema/delta/47/state_group_seq.py | 37 | ||||
-rw-r--r-- | synapse/storage/search.py | 13 | ||||
-rw-r--r-- | synapse/storage/state.py | 196 |
8 files changed, 390 insertions, 223 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index d01d46338a..f8fbd02ceb 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore, ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") - self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a6ae79dfad..8a0386c1a4 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -62,3 +62,9 @@ class PostgresEngine(object): def lock_table(self, txn, table): txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) + + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + txn.execute("SELECT nextval('state_group_id_seq')") + return txn.fetchone()[0] diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 755c9a1f07..60f0fa7fb3 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -16,6 +16,7 @@ from synapse.storage.prepare_database import prepare_database import struct +import threading class Sqlite3Engine(object): @@ -24,6 +25,11 @@ class Sqlite3Engine(object): def __init__(self, database_module, database_config): self.module = database_module + # The current max state_group, or None if we haven't looked + # in the DB yet. + self._current_state_group_id = None + self._current_state_group_id_lock = threading.Lock() + def check_database(self, txn): pass @@ -43,6 +49,19 @@ class Sqlite3Engine(object): def lock_table(self, txn, table): return + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + # We do application locking here since if we're using sqlite then + # we are a single process synapse. + with self._current_state_group_id_lock: + if self._current_state_group_id is None: + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + self._current_state_group_id = txn.fetchone()[0] + + self._current_state_group_id += 1 + return self._current_state_group_id + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 33fccfa7a8..86a7c5920d 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -342,8 +342,20 @@ class EventsStore(SQLBaseStore): # NB: Assumes that we are only persisting events for one room # at a time. + + # map room_id->list[event_ids] giving the new forward + # extremities in each room new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where each entry is + # a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + if not backfilled: with Measure(self._clock, "_calculate_state_and_extrem"): # Work out the new "current state" for each room. @@ -386,11 +398,19 @@ class EventsStore(SQLBaseStore): if all_single_prev_not_state: continue - state = yield self._calculate_state_delta( - room_id, ev_ctx_rm, new_latest_event_ids + logger.info( + "Calculating state delta for room %s", room_id, + ) + current_state = yield self._get_new_state_after_events( + ev_ctx_rm, new_latest_event_ids, ) - if state: - current_state_for_room[room_id] = state + if current_state is not None: + current_state_for_room[room_id] = current_state + delta = yield self._calculate_state_delta( + room_id, current_state, + ) + if delta is not None: + state_delta_for_room[room_id] = delta yield self.runInteraction( "persist_events", @@ -398,7 +418,7 @@ class EventsStore(SQLBaseStore): events_and_contexts=chunk, backfilled=backfilled, delete_existing=delete_existing, - current_state_for_room=current_state_for_room, + state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) @@ -415,7 +435,7 @@ class EventsStore(SQLBaseStore): event_counter.inc(event.type, origin_type, origin_entity) - for room_id, (_, _, new_state) in current_state_for_room.iteritems(): + for room_id, new_state in current_state_for_room.iteritems(): self.get_current_state_ids.prefill( (room_id, ), new_state ) @@ -467,20 +487,22 @@ class EventsStore(SQLBaseStore): defer.returnValue(new_latest_event_ids) @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): - """Calculate the new state deltas for a room. + def _get_new_state_after_events(self, events_context, new_latest_event_ids): + """Calculate the current state dict after adding some new events to + a room - Assumes that we are only persisting events for one room at a time. + Args: + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. Returns: - 3-tuple (to_delete, to_insert, new_state) where both are state dicts, - i.e. (type, state_key) -> event_id. `to_delete` are the entries to - first be deleted from current_state_events, `to_insert` are entries - to insert. `new_state` is the full set of state. - May return None if there are no changes to be applied. + Deferred[dict[(str,str), str]|None]: + None if there are no changes to the room state, or + a dict of (type, state_key) -> event_id]. """ - # Now we need to work out the different state sets for - # each state extremities state_sets = [] state_groups = set() missing_event_ids = [] @@ -523,12 +545,12 @@ class EventsStore(SQLBaseStore): state_sets.extend(group_to_state.itervalues()) if not new_latest_event_ids: - current_state = {} + defer.returnValue({}) elif was_updated: if len(state_sets) == 1: # If there is only one state set, then we know what the current # state is. - current_state = state_sets[0] + defer.returnValue(state_sets[0]) else: # We work out the current state by passing the state sets to the # state resolution algorithm. It may ask for some events, including @@ -537,8 +559,7 @@ class EventsStore(SQLBaseStore): # up in the db. logger.info( - "Resolving state for %s with %i state sets", - room_id, len(state_sets), + "Resolving state with %i state sets", len(state_sets), ) events_map = {ev.event_id: ev for ev, _ in events_context} @@ -567,9 +588,22 @@ class EventsStore(SQLBaseStore): state_sets, state_map_factory=get_events, ) + defer.returnValue(current_state) else: return + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 2-tuple (to_delete, to_insert) where both are state dicts, + i.e. (type, state_key) -> event_id. `to_delete` are the entries to + first be deleted from current_state_events, `to_insert` are entries + to insert. + """ existing_state = yield self.get_current_state_ids(room_id) existing_events = set(existing_state.itervalues()) @@ -589,7 +623,7 @@ class EventsStore(SQLBaseStore): if ev_id in events_to_insert } - defer.returnValue((to_delete, to_insert, current_state)) + defer.returnValue((to_delete, to_insert)) @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, @@ -649,7 +683,7 @@ class EventsStore(SQLBaseStore): @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False, current_state_for_room={}, + delete_existing=False, state_delta_for_room={}, new_forward_extremeties={}): """Insert some number of room events into the necessary database tables. @@ -665,7 +699,7 @@ class EventsStore(SQLBaseStore): delete_existing (bool): True to purge existing table rows for the events from the database. This is useful when retrying due to IntegrityError. - current_state_for_room (dict[str, (list[str], list[str])]): + state_delta_for_room (dict[str, (list[str], list[str])]): The current-state delta for each room. For each room, a tuple (to_delete, to_insert), being a list of event ids to be removed from the current state, and a list of event ids to be added to @@ -677,7 +711,7 @@ class EventsStore(SQLBaseStore): """ max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - self._update_current_state_txn(txn, current_state_for_room, max_stream_order) + self._update_current_state_txn(txn, state_delta_for_room, max_stream_order) self._update_forward_extremities_txn( txn, @@ -721,9 +755,8 @@ class EventsStore(SQLBaseStore): events_and_contexts=events_and_contexts, ) - # Insert into the state_groups, state_groups_state, and - # event_to_state_groups tables. - self._store_mult_state_groups_txn(txn, events_and_contexts) + # Insert into event_to_state_groups. + self._store_event_state_mappings_txn(txn, events_and_contexts) # _store_rejected_events_txn filters out any events which were # rejected, and returns the filtered list. @@ -743,7 +776,7 @@ class EventsStore(SQLBaseStore): def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): for room_id, current_state_tuple in state_delta_by_room.iteritems(): - to_delete, to_insert, _ = current_state_tuple + to_delete, to_insert = current_state_tuple txn.executemany( "DELETE FROM current_state_events WHERE event_id = ?", [(ev_id,) for ev_id in to_delete.itervalues()], @@ -958,10 +991,9 @@ class EventsStore(SQLBaseStore): # an outlier in the database. We now have some state at that # so we need to update the state_groups table with that state. - # insert into the state_group, state_groups_state and - # event_to_state_groups tables. + # insert into event_to_state_groups. try: - self._store_mult_state_groups_txn(txn, ((event, context),)) + self._store_event_state_mappings_txn(txn, ((event, context),)) except Exception: logger.exception("") raise @@ -2031,16 +2063,32 @@ class EventsStore(SQLBaseStore): ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) - def delete_old_state(self, room_id, topological_ordering): - return self.runInteraction( - "delete_old_state", - self._delete_old_state_txn, room_id, topological_ordering - ) + def purge_history( + self, room_id, topological_ordering, delete_local_events, + ): + """Deletes room history before a certain point + + Args: + room_id (str): + + topological_ordering (int): + minimum topo ordering to preserve - def _delete_old_state_txn(self, txn, room_id, topological_ordering): - """Deletes old room state + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). """ + return self.runInteraction( + "purge_history", + self._purge_history_txn, room_id, topological_ordering, + delete_local_events, + ) + + def _purge_history_txn( + self, txn, room_id, topological_ordering, delete_local_events, + ): # Tables that should be pruned: # event_auth # event_backward_extremities @@ -2081,7 +2129,7 @@ class EventsStore(SQLBaseStore): 400, "topological_ordering is greater than forward extremeties" ) - logger.debug("[purge] looking for events to delete") + logger.info("[purge] looking for events to delete") txn.execute( "SELECT event_id, state_key FROM events" @@ -2093,16 +2141,16 @@ class EventsStore(SQLBaseStore): to_delete = [ (event_id,) for event_id, state_key in event_rows - if state_key is None and not self.hs.is_mine_id(event_id) + if state_key is None and ( + delete_local_events or not self.hs.is_mine_id(event_id) + ) ] logger.info( - "[purge] found %i events before cutoff, of which %i are remote" - " non-state events to delete", len(event_rows), len(to_delete)) - - for event_id, state_key in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), len(to_delete), + ) - logger.debug("[purge] Finding new backward extremities") + logger.info("[purge] Finding new backward extremities") # We calculate the new entries for the backward extremeties by finding # all events that point to events that are to be purged @@ -2116,7 +2164,7 @@ class EventsStore(SQLBaseStore): ) new_backwards_extrems = txn.fetchall() - logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems) + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) txn.execute( "DELETE FROM event_backward_extremities WHERE room_id = ?", @@ -2132,7 +2180,7 @@ class EventsStore(SQLBaseStore): ] ) - logger.debug("[purge] finding redundant state groups") + logger.info("[purge] finding redundant state groups") # Get all state groups that are only referenced by events that are # to be deleted. @@ -2149,15 +2197,15 @@ class EventsStore(SQLBaseStore): ) state_rows = txn.fetchall() - logger.debug("[purge] found %i redundant state groups", len(state_rows)) + logger.info("[purge] found %i redundant state groups", len(state_rows)) # make a set of the redundant state groups, so that we can look them up # efficiently state_groups_to_delete = set([sg for sg, in state_rows]) # Now we get all the state groups that rely on these state groups - logger.debug("[purge] finding state groups which depend on redundant" - " state groups") + logger.info("[purge] finding state groups which depend on redundant" + " state groups") remaining_state_groups = [] for i in xrange(0, len(state_rows), 100): chunk = [sg for sg, in state_rows[i:i + 100]] @@ -2182,7 +2230,7 @@ class EventsStore(SQLBaseStore): # Now we turn the state groups that reference to-be-deleted state # groups to non delta versions. for sg in remaining_state_groups: - logger.debug("[purge] de-delta-ing remaining state group %s", sg) + logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( txn, [sg], types=None ) @@ -2219,7 +2267,7 @@ class EventsStore(SQLBaseStore): ], ) - logger.debug("[purge] removing redundant state groups") + logger.info("[purge] removing redundant state groups") txn.executemany( "DELETE FROM state_groups_state WHERE state_group = ?", state_rows @@ -2229,18 +2277,15 @@ class EventsStore(SQLBaseStore): state_rows ) - # Delete all non-state - logger.debug("[purge] removing events from event_to_state_groups") + logger.info("[purge] removing events from event_to_state_groups") txn.executemany( "DELETE FROM event_to_state_groups WHERE event_id = ?", [(event_id,) for event_id, _ in event_rows] ) - - logger.debug("[purge] updating room_depth") - txn.execute( - "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (topological_ordering, room_id,) - ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, ( + event_id, + )) # Delete all remote non-state events for table in ( @@ -2258,7 +2303,8 @@ class EventsStore(SQLBaseStore): "event_signatures", "rejections", ): - logger.debug("[purge] removing remote non-state events from %s", table) + logger.info("[purge] removing remote non-state events from %s", + table) txn.executemany( "DELETE FROM %s WHERE event_id = ?" % (table,), @@ -2266,16 +2312,30 @@ class EventsStore(SQLBaseStore): ) # Mark all state and own events as outliers - logger.debug("[purge] marking remaining events as outliers") + logger.info("[purge] marking remaining events as outliers") txn.executemany( "UPDATE events SET outlier = ?" " WHERE event_id = ?", [ (True, event_id,) for event_id, state_key in event_rows - if state_key is not None or self.hs.is_mine_id(event_id) + if state_key is not None or ( + not delete_local_events and self.hs.is_mine_id(event_id) + ) ] ) + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + logger.info("[purge] updating room_depth") + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (topological_ordering, room_id,) + ) + logger.info("[purge] done") @defer.inlineCallbacks diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 0fcfb7f86d..fff6652e05 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -506,73 +506,114 @@ class RoomStore(SearchStore): ) self.is_room_blocked.invalidate((room_id,)) + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines the associated media """ - def _get_media_ids_in_room(txn): - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + total_media_quarantined = 0 - next_token = self.get_current_events_token() + 1 + # Now update all the tables to set the quarantined_by flag - total_media_quarantined = 0 + txn.executemany(""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, ((quarantined_by, media_id) for media_id in local_mxcs)) - while next_token: - sql = """ - SELECT stream_ordering, content FROM events - WHERE room_id = ? - AND stream_ordering < ? - AND contains_url = ? AND outlier = ? - ORDER BY stream_ordering DESC - LIMIT ? + txn.executemany( """ - txn.execute(sql, (room_id, next_token, True, False, 100)) - - next_token = None - local_media_mxcs = [] - remote_media_mxcs = [] - for stream_ordering, content_json in txn: - next_token = stream_ordering - content = json.loads(content_json) - - content_url = content.get("url") - thumbnail_url = content.get("info", {}).get("thumbnail_url") - - for url in (content_url, thumbnail_url): - if not url: - continue - matches = mxc_re.match(url) - if matches: - hostname = matches.group(1) - media_id = matches.group(2) - if hostname == self.hostname: - local_media_mxcs.append(media_id) - else: - remote_media_mxcs.append((hostname, media_id)) - - # Now update all the tables to set the quarantined_by flag - - txn.executemany(""" - UPDATE local_media_repository + UPDATE remote_media_cache SET quarantined_by = ? - WHERE media_id = ? - """, ((quarantined_by, media_id) for media_id in local_media_mxcs)) - - txn.executemany( - """ - UPDATE remote_media_cache - SET quarantined_by = ? - WHERE media_origin = ? AND media_id = ? - """, - ( - (quarantined_by, origin, media_id) - for origin, media_id in remote_media_mxcs - ) + WHERE media_origin = ? AND media_id = ? + """, + ( + (quarantined_by, origin, media_id) + for origin, media_id in remote_mxcs ) + ) - total_media_quarantined += len(local_media_mxcs) - total_media_quarantined += len(remote_media_mxcs) + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) return total_media_quarantined - return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room) + return self.runInteraction( + "quarantine_media_in_room", + _quarantine_media_in_room_txn, + ) + + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room + + Args: + txn (cursor) + room_id (str) + + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + next_token = self.get_current_events_token() + 1 + local_media_mxcs = [] + remote_media_mxcs = [] + + while next_token: + sql = """ + SELECT stream_ordering, content FROM events + WHERE room_id = ? + AND stream_ordering < ? + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? + """ + txn.execute(sql, (room_id, next_token, True, False, 100)) + + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + content = json.loads(content_json) + + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py new file mode 100644 index 0000000000..f6766501d2 --- /dev/null +++ b/synapse/storage/schema/delta/47/state_group_seq.py @@ -0,0 +1,37 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # if we already have some state groups, we want to start making new + # ones with a higher id. + cur.execute("SELECT max(id) FROM state_groups") + row = cur.fetchone() + + if row[0] is None: + start_val = 1 + else: + start_val = row[0] + 1 + + cur.execute( + "CREATE SEQUENCE state_group_id_seq START WITH %s", + (start_val, ), + ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/search.py b/synapse/storage/search.py index eecf778516..7c80fab277 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -12,18 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from collections import namedtuple +import logging +import re import sys +import ujson as json + from twisted.internet import defer from .background_updates import BackgroundUpdateStore from synapse.api.errors import SynapseError from synapse.storage.engines import PostgresEngine, Sqlite3Engine -import logging -import re -import ujson as json - logger = logging.getLogger(__name__) @@ -280,10 +281,10 @@ class SearchStore(BackgroundUpdateStore): if isinstance(self.database_engine, PostgresEngine): sql = ( "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, " - " origin_server_ts)" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" " VALUES (?,?,?,to_tsvector('english', ?),?,?)" ) + args = (( entry.event_id, entry.room_id, entry.key, entry.value, entry.stream_ordering, entry.origin_server_ts, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 360e3e4355..adb48df73e 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupReadStore(SQLBaseStore): - """The read-only parts of StateGroupStore - - None of these functions write to the state tables, so are suitable for - including in the SlavedStores. +class StateGroupWorkerStore(SQLBaseStore): + """The parts of StateGroupStore that can be called from workers. """ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" @@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore): CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" def __init__(self, db_conn, hs): - super(StateGroupReadStore, self).__init__(db_conn, hs) + super(StateGroupWorkerStore, self).__init__(db_conn, hs) self._state_group_cache = DictionaryCache( "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR @@ -549,116 +546,66 @@ class StateGroupReadStore(SQLBaseStore): defer.returnValue(results) + def store_state_group(self, event_id, room_id, prev_group, delta_ids, + current_state_ids): + """Store a new set of state, returning a newly assigned state group. -class StateStore(StateGroupReadStore, BackgroundUpdateStore): - """ Keeps track of the state at a given event. - - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. - - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. - """ - - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, - ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", - ) - - def _have_persisted_state_group_txn(self, txn, state_group): - txn.execute( - "SELECT count(*) FROM state_groups WHERE id = ?", - (state_group,) - ) - row = txn.fetchone() - return row and row[0] - - def _store_mult_state_groups_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. - if context.current_state_ids is None: + Returns: + Deferred[int]: The state group ID + """ + def _store_state_group_txn(txn): + if current_state_ids is None: # AFAIK, this can never happen - logger.error( - "Non-outlier event %s had current_state_ids==None", - event.event_id) - continue + raise Exception("current_state_ids cannot be None") - # if the event was rejected, just give it the same state as its - # predecessor. - if context.rejected: - state_groups[event.event_id] = context.prev_group - continue - - state_groups[event.event_id] = context.state_group - - if self._have_persisted_state_group_txn(txn, context.state_group): - continue + state_group = self.database_engine.get_next_state_group_id(txn) self._simple_insert_txn( txn, table="state_groups", values={ - "id": context.state_group, - "room_id": event.room_id, - "event_id": event.event_id, + "id": state_group, + "room_id": room_id, + "event_id": event_id, }, ) # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. - if context.prev_group: + if prev_group: is_in_db = self._simple_select_one_onecol_txn( txn, table="state_groups", - keyvalues={"id": context.prev_group}, + keyvalues={"id": prev_group}, retcol="id", allow_none=True, ) if not is_in_db: raise Exception( "Trying to persist state with unpersisted prev_group: %r" - % (context.prev_group,) + % (prev_group,) ) potential_hops = self._count_state_group_hops_txn( - txn, context.prev_group + txn, prev_group ) - if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: self._simple_insert_txn( txn, table="state_group_edges", values={ - "state_group": context.state_group, - "prev_state_group": context.prev_group, + "state_group": state_group, + "prev_state_group": prev_group, }, ) @@ -667,13 +614,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): table="state_groups_state", values=[ { - "state_group": context.state_group, - "room_id": event.room_id, + "state_group": state_group, + "room_id": room_id, "type": key[0], "state_key": key[1], "event_id": state_id, } - for key, state_id in context.delta_ids.iteritems() + for key, state_id in delta_ids.iteritems() ], ) else: @@ -682,13 +629,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): table="state_groups_state", values=[ { - "state_group": context.state_group, - "room_id": event.room_id, + "state_group": state_group, + "room_id": room_id, "type": key[0], "state_key": key[1], "event_id": state_id, } - for key, state_id in context.current_state_ids.iteritems() + for key, state_id in current_state_ids.iteritems() ], ) @@ -699,11 +646,71 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): txn.call_after( self._state_group_cache.update, self._state_group_cache.sequence, - key=context.state_group, - value=dict(context.current_state_ids), + key=state_group, + value=dict(current_state_ids), full=True, ) + return state_group + + return self.runInteraction("store_state_group", _store_state_group_txn) + + +class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + def _store_event_state_mappings_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + self._simple_insert_many_txn( txn, table="event_to_state_groups", @@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore): return count - def get_next_state_group(self): - return self._state_groups_id_gen.get_next() - @defer.inlineCallbacks def _background_deduplicate_state(self, progress, batch_size): """This background update will slowly deduplicate state by reencoding |