summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/data_stores/__init__.py45
-rw-r--r--synapse/storage/data_stores/main/client_ips.py10
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py17
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py217
-rw-r--r--synapse/storage/data_stores/main/events.py157
-rw-r--r--synapse/storage/data_stores/main/events_worker.py97
-rw-r--r--synapse/storage/data_stores/main/push_rule.py2
-rw-r--r--synapse/storage/data_stores/main/pusher.py25
-rw-r--r--synapse/storage/data_stores/main/roommember.py2
-rw-r--r--synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql1
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql20
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql29
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres52
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite6
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql1
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.md13
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.txt19
-rw-r--r--synapse/storage/data_stores/main/search.py19
-rw-r--r--synapse/storage/data_stores/main/state.py955
-rw-r--r--synapse/storage/data_stores/state/__init__.py16
-rw-r--r--synapse/storage/data_stores/state/bg_updates.py374
-rw-r--r--synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql (renamed from synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/30/state_stream.sql (renamed from synapse/storage/data_stores/main/schema/delta/30/state_stream.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql19
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql (renamed from synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/state.sql (renamed from synapse/storage/data_stores/main/schema/delta/35/state.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql (renamed from synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py (renamed from synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py)0
-rw-r--r--synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql17
-rw-r--r--synapse/storage/data_stores/state/schema/full_schemas/54/full.sql37
-rw-r--r--synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres21
-rw-r--r--synapse/storage/data_stores/state/store.py640
-rw-r--r--synapse/storage/database.py45
-rw-r--r--synapse/storage/engines/sqlite.py6
-rw-r--r--synapse/storage/persist_events.py2
-rw-r--r--synapse/storage/prepare_database.py30
-rw-r--r--synapse/storage/purge_events.py4
-rw-r--r--synapse/storage/state.py14
39 files changed, 1633 insertions, 1281 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py

index b7637b5dc0..88546ad614 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py
@@ -40,7 +40,7 @@ class SQLBaseStore(object): def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine self.db = database self.rand = random.SystemRandom() diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cafedd5c0d..d20df5f076 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py
@@ -13,24 +13,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import Database +import logging + +from synapse.storage.data_stores.state import StateGroupDataStore +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +logger = logging.getLogger(__name__) + class DataStores(object): """The various data stores. These are low level interfaces to physical databases. + + Attributes: + main (DataStore) """ - def __init__(self, main_store_class, db_conn, hs): + def __init__(self, main_store_class, hs): # Note we pass in the main store class here as workers use a different main # store. - database = Database(hs) - # Check that db is correctly configured. - database.engine.check_database(db_conn.cursor()) + self.databases = [] + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn.cursor()) + prepare_database( + db_conn, engine, hs.config, data_stores=database_config.data_stores, + ) + + database = Database(hs, database_config, engine) + + if "main" in database_config.data_stores: + logger.info("Starting 'main' data store") + self.main = main_store_class(database, db_conn, hs) + + if "state" in database_config.data_stores: + logger.info("Starting 'state' data store") + self.state = StateGroupDataStore(database, db_conn, hs) + + db_conn.commit() - prepare_database(db_conn, database.engine, config=hs.config) + self.databases.append(database) - self.main = main_store_class(database, db_conn, hs) + logger.info("Database %r prepared", db_name) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..13f4c9c72e 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py
@@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.db.is_running(): return to_update = self._batch_row_update @@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.db.simple_upsert_txn( + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - values={ + updatevalues={ "user_agent": user_agent, "last_seen": last_seen, "ip": ip, }, - lock=False, ) except Exception as e: # Failed to upsert, log and continue diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 85cfa16850..0613b49f4a 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -358,21 +358,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): - # Compatible method of performing an upsert - sql = "SELECT stream_id FROM device_max_stream_id" - - txn.execute(sql) - rows = txn.fetchone() - if rows: - db_stream_id = rows[0] - if db_stream_id < stream_id: - # Insert the new stream_id - sql = "UPDATE device_max_stream_id SET stream_id = ?" - else: - # No rows, perform an insert - sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)" - - txn.execute(sql, (stream_id,)) + sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" + txn.execute(sql, (stream_id, stream_id)) local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 38cd0ca9b8..e551606f9d 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,15 +14,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List + from six import iteritems from canonicaljson import encode_canonical_json, json +from twisted.enterprise.adbapi import Connection from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList class EndToEndKeyWorkerStore(SQLBaseStore): @@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Args: txn (twisted.enterprise.adbapi.Connection): db connection user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' + key_type (str): the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on @@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """Returns a user's cross-signing key. Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on the self-signing key will be included in the result @@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk( + self, user_ids: List[str] + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + + """ + return self.db.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: List[str], + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + + """ + result = {} + + batch_size = 100 + chunks = [ + user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT k.user_id, k.keytype, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE k.user_id IN (%s) + """ % ( + ",".join("?" for u in user_chunk), + ) + query_params = [] + query_params.extend(user_chunk) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + for row in rows: + user_id = row["user_id"] + key_type = row["keytype"] + key = json.loads(row["keydata"]) + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + if user_info is None: + continue + for key_type, key in user_info.items(): + device_id = None + for k in key["keys"].values(): + device_id = k + devices[(user_id, device_id)] = key_type + + device_list = list(devices) + + # split into batches + batch_size = 100 + chunks = [ + device_list[i : i + batch_size] + for i in range(0, len(device_list), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for d in devices + ) + ) + query_params = [from_user_id] + for item in devices: + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. + user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks + def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: str = None + ) -> defer.Deferred: + """Returns the cross-signing keys for a set of users. + + Args: + user_ids (list[str]): the users whose keys are being requested + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to + key data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + """ + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their @@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): }, ) + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 998bba1aad..58f35d7f56 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py
@@ -1757,163 +1757,6 @@ class EventsStore( return state_groups - def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: - """Deletes no longer referenced state groups and de-deltas any state - groups that reference them. - - Args: - room_id: The room the state groups belong to (must all be in the - same room). - state_groups_to_delete (Collection[int]): Set of all state groups - to delete. - """ - - return self.db.runInteraction( - "purge_unreferenced_state_groups", - self._purge_unreferenced_state_groups, - room_id, - state_groups_to_delete, - ) - - def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - rows = self.db.simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=state_groups_to_delete, - keyvalues={}, - retcols=("state_group",), - ) - - remaining_state_groups = set( - row["state_group"] - for row in rows - if row["state_group"] not in state_groups_to_delete - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self.db.simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self.db.simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): - """Fetch the previous groups of the given state groups. - - Args: - state_groups (Iterable[int]) - - Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. - """ - - rows = yield self.db.simple_select_many_batch( - table="state_group_edges", - column="prev_state_group", - iterable=state_groups, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - desc="get_previous_state_groups", - ) - - return {row["state_group"]: row["prev_state_group"] for row in rows} - - def purge_room_state(self, room_id, state_groups_to_delete): - """Deletes all record of a room from state tables - - Args: - room_id (str): - state_groups_to_delete (list[int]): State groups to delete - """ - - return self.db.runInteraction( - "purge_room_state", - self._purge_room_state_txn, - room_id, - state_groups_to_delete, - ) - - def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups_state", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_group_edges", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups", - column="id", - iterable=state_groups_to_delete, - keyvalues={}, - ) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 9ee117ce0f..2c9142814c 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools import logging import threading from collections import namedtuple +from typing import List, Optional from canonicaljson import json +from constantly import NamedConstant, Names from twisted.internet import defer @@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) @@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_event( self, - event_id, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - check_room_id=None, + event_id: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, ): """Get an event from the database by event_id. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_id: The event_id of the event to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if + allow_rejected: If True return rejected events. + allow_none: If True, return None if no event found, if False throw a NotFoundError - check_room_id (str|None): if not None, check the room of the found event. + check_room_id: if not None, check the room of the found event. If there is a mismatch, behave as per allow_none. Returns: @@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore): events = yield self.get_events_as_list( [event_id], - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True return rejected events. Returns: Deferred : Dict from event_id to event. """ events = yield self.get_events_as_list( event_ids, - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events_as_list( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database and return in a list in the same order as given by `event_ids` arg. Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True, return rejected events. Returns: Deferred[list[EventBase]]: List of events fetched from the database. The @@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore): # Update the cache to save doing the checks again. entry.event.internal_metadata.recheck_redaction = False - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event events.append(event) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 5ba13aa973..e2673ae073 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py
@@ -244,7 +244,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index f07309ef09..6b03233262 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,8 +15,7 @@ # limitations under the License. import logging - -import six +from typing import Iterable, Iterator from canonicaljson import encode_canonical_json, json @@ -27,21 +26,16 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows): + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + """JSON-decode the data in the rows returned from the `pushers` table + + Drops any rows whose data cannot be decoded + """ for r in rows: dataJson = r["data"] - r["data"] = None try: - if isinstance(dataJson, db_binary_type): - dataJson = str(dataJson).decode("UTF8") - r["data"] = json.loads(dataJson) except Exception as e: logger.warning( @@ -50,12 +44,9 @@ class PusherWorkerStore(SQLBaseStore): dataJson, e.args[0], ) - pass - - if isinstance(r["pushkey"], db_binary_type): - r["pushkey"] = str(r["pushkey"]).decode("UTF8") + continue - return rows + yield r @defer.inlineCallbacks def user_has_pusher(self, user_id): diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 92e3b9c512..70ff5751b6 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py
@@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
index 4219cdd06a..2de50d408c 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
@@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql new file mode 100644
index 0000000000..c2f557fde9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
@@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This line already existed in deltas/35/device_stream_id but was not included in the +-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist +INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS ( + SELECT * from device_max_stream_id +); \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql new file mode 100644
index 0000000000..4f24c1405d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
@@ -0,0 +1,29 @@ +/* Copyright 2019 Werner Sembach + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made. +DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 4ad2929f32..889a9a0ce4 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -975,40 +975,6 @@ CREATE TABLE state_events ( -CREATE TABLE state_group_edges ( - state_group bigint NOT NULL, - prev_state_group bigint NOT NULL -); - - - -CREATE SEQUENCE state_group_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - - -CREATE TABLE state_groups ( - id bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE state_groups_state ( - state_group bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text NOT NULL -); - - - CREATE TABLE stats_stream_pos ( lock character(1) DEFAULT 'X'::bpchar NOT NULL, stream_id bigint, @@ -1482,12 +1448,6 @@ ALTER TABLE ONLY state_events ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); - -ALTER TABLE ONLY state_groups - ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); - - - ALTER TABLE ONLY stats_stream_pos ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); @@ -1928,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); -CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group); - - - -CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); - - - -CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); - - - CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index bad33291e7..a0411ede7e 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id); CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); CREATE INDEX room_depth_room ON room_depth(room_id); -CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); @@ -120,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); -CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL ); -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); @@ -254,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); CREATE INDEX users_creation_ts ON users (creation_ts); CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); -CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key); CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
index c265fd20e2..91d21b2921 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
@@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales INSERT INTO user_directory_stream_pos (stream_id) VALUES (0); INSERT INTO stats_stream_pos (stream_id) VALUES (0); INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); +-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md new file mode 100644
index 0000000000..bbd3f18604 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@ +# Building full schema dumps + +These schemas need to be made from a database that has had all background updates run. + +To do so, use `scripts-dev/make_full_schema.sh`. This will produce +`full.sql.postgres ` and `full.sql.sqlite` files. + +Ensure postgres is installed and your user has the ability to run bash commands +such as `createdb`. + +``` +./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ +``` diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt deleted file mode 100644
index d3f6401344..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt +++ /dev/null
@@ -1,19 +0,0 @@ -Building full schema dumps -========================== - -These schemas need to be made from a database that has had all background updates run. - -Postgres --------- - -$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres - -SQLite ------- - -$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite - -After ------ - -Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas". \ No newline at end of file diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 4eec2fae5e..47ebb8a214 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} @@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 9ef7b48c74..0dc39f139c 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py
@@ -17,8 +17,7 @@ import logging from collections import namedtuple from typing import Iterable, Tuple -from six import iteritems, itervalues -from six.moves import range +from six import iteritems from twisted.internet import defer @@ -29,11 +28,9 @@ from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter -from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -55,207 +52,14 @@ class _GetStateGroupDelta( return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud - updates. - """ - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? " + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - - # this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore( - EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore -): +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, database: Database, db_conn, hs): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. - # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. - - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), - ) - @defer.inlineCallbacks def get_room_version(self, room_id): """Get the room_version of a given room @@ -278,7 +82,7 @@ class StateGroupWorkerStore( @defer.inlineCallbacks def get_room_predecessor(self, room_id): - """Get the predecessor room of an upgraded room if one exists. + """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: @@ -291,14 +95,22 @@ class StateGroupWorkerStore( * room_id (str): The room ID of the predecessor room * event_id (str): The ID of the tombstone event in the predecessor room + None if a predecessor key is not found, or is not a dictionary. + Raises: - NotFoundError if the room is unknown + NotFoundError if the given room is unknown """ # Retrieve the room's create event create_event = yield self.get_create_event_for_room(room_id) - # Return predecessor if present - return create_event.content.get("predecessor", None) + # Retrieve the predecessor key of the create event + predecessor = create_event.content.get("predecessor", None) + + # Ensure the key is a dictionary + if not isinstance(predecessor, dict): + return None + + return predecessor @defer.inlineCallbacks def get_create_event_for_room(self, room_id): @@ -318,7 +130,7 @@ class StateGroupWorkerStore( # If we can't find the create event, assume we've hit a dead end if not create_id: - raise NotFoundError("Unknown room %s" % (room_id)) + raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return create_event = yield self.get_event(create_id) @@ -423,229 +235,6 @@ class StateGroupWorkerStore( return event.content.get("canonical_alias") - @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): - """Given a state group try to return a previous group and a delta between - the old and the new. - - Returns: - (prev_group, delta_ids), where both may be None. - """ - - def _get_state_group_delta_txn(txn): - prev_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self.db.simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.db.runInteraction( - "get_state_group_delta", _get_state_group_delta_txn - ) - - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events - - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - if not event_ids: - return {} - - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups) - - return group_to_state - - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): - """Get the event IDs of all the state in the given state group - - Args: - state_group (int) - - Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = yield self._get_state_for_groups((state_group,)) - - return group_to_state[state_group] - - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) - - state_event_map = yield self.get_events( - [ - ev_id - for group_ids in itervalues(group_to_ids) - for ev_id in itervalues(group_ids) - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - } - - @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = yield self.db.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - - return results - - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - state_event_map = yield self.get_events( - [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in iteritems(group_to_state[group]) - if v in state_event_map - } - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from event_id -> (type, state_key) -> event_id - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_for_events([event_id], state_filter) - return state_map[event_id] - - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) - return state_map[event_id] - @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): return self.db.simple_select_one_onecol( @@ -676,329 +265,6 @@ class StateGroupWorkerStore( return {row["event_id"]: row["state_group"] for row in rows} - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. - """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. - for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. - """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in iteritems(group_state_dict): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) - - def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) - to event_id. - - Returns: - Deferred[int]: The state group ID - """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self.database_engine.get_next_state_group_id(txn) - - self.db.simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self.db.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_ids) - ], - ) - else: - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(current_state_ids) - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.db.runInteraction("store_state_group", _store_state_group_txn) - @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): """Check if the state groups are referenced by events. @@ -1023,22 +289,14 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): +class MainStateBackgroundUpdateStore(SQLBaseStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" def __init__(self, database: Database, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) + super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", @@ -1053,181 +311,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["state_group"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - - BATCH_SIZE_SCALE_FACTOR = 100 - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self.db.execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? AND room_id = ?", - (state_group, room_id), - ) - (prev_group,) = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. - continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in iteritems(curr_state) - if prev_state.get(key, None) != value - } - - self.db.simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self.db.simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_state) - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self.db.updates._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.db.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn - ) - - if finished: - yield self.db.updates._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.db.runWithConnection(reindex_txn) - - yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 - - -class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): +class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ Keeps track of the state at a given event. This is done by the concept of `state groups`. Every event is a assigned diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py new file mode 100644
index 0000000000..86e09f6229 --- /dev/null +++ b/synapse/storage/data_stores/state/__init__.py
@@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py new file mode 100644
index 0000000000..e8edaf9f7b --- /dev/null +++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.db.updates.register_background_index_update( + self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, + index_name="state_groups_room_id_idx", + table="state_groups", + columns=["room_id"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self.db.execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in iteritems(curr_state) + if prev_state.get(key, None) != value + } + + self.db.simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self.db.simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_state) + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self.db.updates._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.db.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self.db.updates._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.db.runWithConnection(reindex_txn) + + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + return 1 diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644 --- a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql +++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql +++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql new file mode 100644
index 0000000000..1450313bfa --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
@@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
index 33980d02f0..33980d02f0 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/state.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
index 9fd1ccf6f7..9fd1ccf6f7 100644 --- a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py +++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql new file mode 100644
index 0000000000..7916ef18b2 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
@@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql new file mode 100644
index 0000000000..35f97d6b3d --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
@@ -0,0 +1,37 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres new file mode 100644
index 0000000000..fcd926c9fb --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
@@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py new file mode 100644
index 0000000000..d53695f238 --- /dev/null +++ b/synapse/storage/data_stores/state/store.py
@@ -0,0 +1,640 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple + +from six import iteritems +from six.moves import range + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database +from synapse.storage.state import StateFilter +from synapse.util.caches import get_cache_factor_for +from synapse.util.caches.descriptors import cached +from synapse.util.caches.dictionary_cache import DictionaryCache + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): + """A data store for fetching/storing state groups. + """ + + def __init__(self, database: Database, db_conn, hs): + super(StateGroupDataStore, self).__init__(database, db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache"), + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache"), + ) + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self.db.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) + + @defer.inlineCallbacks + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = yield self.db.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. + """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self.db.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self.db.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.db.runInteraction("store_state_group", _store_state_group_txn) + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.db.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self.db.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = set( + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self.db.simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self.db.simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + @defer.inlineCallbacks + def get_previous_state_groups(self, state_groups): + """Fetch the previous groups of the given state groups. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[dict[int, int]]: mapping from state group to previous + state group. + """ + + rows = yield self.db.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.db.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range from prometheus_client import Histogram +from twisted.enterprise import adbapi from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater @@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { } +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn(db_config: DatabaseConnectionConfig, engine): + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -218,10 +251,11 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() + self._database_config = database_config + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) @@ -234,7 +268,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.engine = hs.database_engine + self.engine = engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -255,6 +289,11 @@ class Database(object): self._check_safe_to_upsert, ) + def is_running(self): + """Is the database pool currently running + """ + return self._db_pool.running + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index cbc74cd302..df039a072d 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py
@@ -16,8 +16,6 @@ import struct import threading -from synapse.storage.prepare_database import prepare_database - class Sqlite3Engine(object): single_threaded = True @@ -62,6 +60,10 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): + + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + if self._is_in_memory: # In memory databases need to be rebuilt each time. Ideally we'd # reuse the same connection as we do when starting up, but that diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index fa03ca9ff7..1ed44925fc 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py
@@ -183,7 +183,7 @@ class EventsPersistenceStorage(object): # so we use separate variables here even though they point to the same # store for now. self.main_store = stores.main - self.state_store = stores.main + self.state_store = stores.state self._clock = hs.get_clock() self.is_mine_id = hs.is_mine_id diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 731e1c9d9c..e70026b80a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py
@@ -18,6 +18,7 @@ import imp import logging import os import re +from collections import Counter import attr @@ -41,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config): +def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -54,11 +55,10 @@ def prepare_database(db_conn, database_engine, config): config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already + data_stores (list[str]): The name of the data stores that will be used + with this database. Defaults to all data stores. """ - # For now we only have the one datastore. - data_stores = ["main"] - try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) @@ -70,7 +70,10 @@ def prepare_database(db_conn, database_engine, config): if user_version != SCHEMA_VERSION: # If we don't pass in a config file then we are expecting to # have already upgraded the DB. - raise UpgradeDatabaseException("Database needs to be upgraded") + raise UpgradeDatabaseException( + "Expected database schema version %i but got %i" + % (SCHEMA_VERSION, user_version) + ) else: _upgrade_existing_database( cur, @@ -313,6 +316,9 @@ def _upgrade_existing_database( ) ) + # Used to check if we have any duplicate file names + file_name_counter = Counter() + # Now find which directories have anything of interest. directory_entries = [] for directory in directories: @@ -323,6 +329,9 @@ def _upgrade_existing_database( _DirectoryListing(file_name, os.path.join(directory, file_name)) for file_name in file_names ) + + for file_name in file_names: + file_name_counter[file_name] += 1 except FileNotFoundError: # Data stores can have empty entries for a given version delta. pass @@ -331,6 +340,17 @@ def _upgrade_existing_database( "Could not open delta dir for version %d: %s" % (v, directory) ) + duplicates = set( + file_name for file_name, count in file_name_counter.items() if count > 1 + ) + if duplicates: + # We don't support using the same file name in the same delta version. + raise PrepareDatabaseException( + "Found multiple delta files with the same name in v%d: %s", + v, + duplicates, + ) + # We sort to ensure that we apply the delta files in a consistent # order (to avoid bugs caused by inconsistent directory listing order) directory_entries.sort() diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index a368182034..d6a7bd7834 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py
@@ -58,7 +58,7 @@ class PurgeEventsStorage(object): sg_to_delete = yield self._find_unreferenced_groups(state_groups) - yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete) + yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) @defer.inlineCallbacks def _find_unreferenced_groups(self, state_groups): @@ -102,7 +102,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.main.get_previous_state_groups(current_search) + edges = yield self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 3735846899..cbeb586014 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py
@@ -342,7 +342,7 @@ class StateGroupStorage(object): (prev_group, delta_ids) """ - return self.stores.main.get_state_group_delta(state_group) + return self.stores.state.get_state_group_delta(state_group) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -362,7 +362,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups(groups) + group_to_state = yield self.stores.state._get_state_for_groups(groups) return group_to_state @@ -423,7 +423,7 @@ class StateGroupStorage(object): dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_groups_from_groups(groups, state_filter) + return self.stores.state._get_state_groups_from_groups(groups, state_filter) @defer.inlineCallbacks def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): @@ -439,7 +439,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -476,7 +476,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -532,7 +532,7 @@ class StateGroupStorage(object): Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_for_groups(groups, state_filter) + return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids @@ -552,6 +552,6 @@ class StateGroupStorage(object): Returns: Deferred[int]: The state group ID """ - return self.stores.main.store_state_group( + return self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids )