diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 15 | ||||
-rw-r--r-- | synapse/storage/_base.py | 13 | ||||
-rw-r--r-- | synapse/storage/deviceinbox.py | 35 | ||||
-rw-r--r-- | synapse/storage/devices.py | 474 | ||||
-rw-r--r-- | synapse/storage/end_to_end_keys.py | 91 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 78 | ||||
-rw-r--r-- | synapse/storage/events.py | 397 | ||||
-rw-r--r-- | synapse/storage/prepare_database.py | 2 | ||||
-rw-r--r-- | synapse/storage/registration.py | 11 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 30 | ||||
-rw-r--r-- | synapse/storage/schema/delta/40/current_state_idx.sql | 17 | ||||
-rw-r--r-- | synapse/storage/schema/delta/40/device_inbox.sql | 21 | ||||
-rw-r--r-- | synapse/storage/schema/delta/40/device_list_streams.sql | 59 | ||||
-rw-r--r-- | synapse/storage/state.py | 62 | ||||
-rw-r--r-- | synapse/storage/stream.py | 14 |
15 files changed, 1016 insertions, 303 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index fe936b3e62..b9968debe5 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore, self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, "device_lists_stream", "stream_id", + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") @@ -189,7 +192,8 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "device_inbox", entity_column="user_id", stream_column="stream_id", - max_value=max_device_inbox_id + max_value=max_device_inbox_id, + limit=1000, ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", min_device_inbox_id, @@ -202,12 +206,21 @@ class DataStore(RoomMemberStore, RoomStore, entity_column="destination", stream_column="stream_id", max_value=max_device_inbox_id, + limit=1000, ) self._device_federation_outbox_stream_cache = StreamChangeCache( "DeviceFederationOutboxStreamChangeCache", min_device_outbox_id, prefilled_cache=device_outbox_prefill, ) + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b62c459d8b..05374682fd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -169,7 +169,7 @@ class SQLBaseStore(object): max_entries=hs.config.event_cache_size) self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) self._event_fetch_lock = threading.Condition() @@ -387,6 +387,10 @@ class SQLBaseStore(object): Args: table : string giving the table name values : dict of new column names and values for them + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True """ try: yield self.runInteraction( @@ -398,6 +402,8 @@ class SQLBaseStore(object): # a cursor after we receive an error from the db. if not or_ignore: raise + defer.returnValue(False) + defer.returnValue(True) @staticmethod def _simple_insert_txn(txn, table, values): @@ -838,18 +844,19 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) def _get_cache_dict(self, db_conn, table, entity_column, stream_column, - max_value): + max_value, limit=100000): # Fetch a mapping of room_id -> max stream position for "recent" rooms. # It doesn't really matter how many we get, the StreamChangeCache will # do the right thing to ensure it respects the max size of cache. sql = ( "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - 100000" + " WHERE %(stream)s > ? - %(limit)s" " GROUP BY %(entity)s" ) % { "table": table, "entity": entity_column, "stream": stream_column, + "limit": limit, } sql = self.database_engine.convert_param_style(sql) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 2821eb89c9..bde3b5cbbc 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -18,13 +18,29 @@ import ujson from twisted.internet import defer -from ._base import SQLBaseStore +from .background_updates import BackgroundUpdateStore logger = logging.getLogger(__name__) -class DeviceInboxStore(SQLBaseStore): +class DeviceInboxStore(BackgroundUpdateStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, hs): + super(DeviceInboxStore, self).__init__(hs) + + self.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, + self._background_drop_index_device_inbox, + ) @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, @@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + + @defer.inlineCallbacks + def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute( + "DROP INDEX IF EXISTS device_inbox_stream_id" + ) + txn.close() + + yield self.runWithConnection(reindex_txn) + + yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + defer.returnValue(1) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 17920d4480..8e17800364 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import ujson as json from twisted.internet import defer @@ -23,27 +24,29 @@ logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): + def __init__(self, hs): + super(DeviceStore, self).__init__(hs) + + self._clock.looping_call( + self._prune_old_outbound_device_pokes, 60 * 60 * 1000 + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, - initial_device_display_name, - ignore_if_known=True): + initial_device_display_name): """Ensure the given device is known; add it to the store if not Args: user_id (str): id of user associated with the device device_id (str): id of device initial_device_display_name (str): initial displayname of the - device - ignore_if_known (bool): ignore integrity errors which mean the - device is already known + device. Ignored if device exists. Returns: - defer.Deferred - Raises: - StoreError: if ignore_if_known is False and the device was already - known + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. """ try: - yield self._simple_insert( + inserted = yield self._simple_insert( "devices", values={ "user_id": user_id, @@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore): "display_name": initial_device_display_name }, desc="store_device", - or_ignore=ignore_if_known, + or_ignore=True, ) + defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" " display_name=%s(%r) failed: %s", @@ -139,3 +143,451 @@ class DeviceStore(SQLBaseStore): ) defer.returnValue({d["device_id"]: d for d in devices}) + + def get_device_list_last_stream_id_for_remote(self, user_id): + """Get the last stream_id we got for a user. May be None if we haven't + got any information for them. + """ + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_remote_extremity", + allow_none=True, + ) + + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + return self._simple_delete( + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + desc="mark_remote_user_device_list_as_unsubscribed", + ) + + def update_remote_device_list_cache_entry(self, user_id, device_id, content, + stream_id): + """Updates a single user's device in the cache. + """ + return self.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, device_id, content, stream_id, + ) + + def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, + content, stream_id): + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def update_remote_device_list_cache(self, user_id, devices, stream_id): + """Replace the cache of the remote user's devices. + """ + return self.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, devices, stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, + stream_id): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ] + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def get_devices_by_remote(self, destination, from_stream_id): + """Get stream of updates to send to remote servers + + Returns: + (now_stream_id, [ { updates }, .. ]) + """ + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_federation_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + return (now_stream_id, []) + + return self.runInteraction( + "get_devices_by_remote", self._get_devices_by_remote_txn, + destination, from_stream_id, now_stream_id, + ) + + def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, + now_stream_id): + sql = """ + SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + GROUP BY user_id, device_id + """ + txn.execute( + sql, (destination, from_stream_id, now_stream_id, False) + ) + rows = txn.fetchall() + + if not rows: + return (now_stream_id, []) + + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in rows} + devices = self._get_e2e_device_keys_txn( + txn, query_map.keys(), include_all_devices=True + ) + + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND stream_id <= ? + """ + + results = [] + for user_id, user_devices in devices.iteritems(): + # The prev_id for the first row is always the last row before + # `from_stream_id` + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + rows = txn.fetchall() + prev_id = rows[0][0] + for device_id, device in user_devices.iteritems(): + stream_id = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_id] if prev_id else [], + "stream_id": stream_id, + } + + prev_id = stream_id + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return (now_stream_id, results) + + def get_user_devices_from_cache(self, query_list): + """Get the devices (and keys if any) for remote users from the cache. + + Args: + query_list(list): List of (user_id, device_ids), if device_ids is + falsey then return all device ids for that user. + + Returns: + (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is + a set of user_ids and results_map is a mapping of + user_id -> device_id -> device_info + """ + return self.runInteraction( + "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, + query_list, + ) + + def _get_user_devices_from_cache_txn(self, txn, query_list): + user_ids = {user_id for user_id, _ in query_list} + + user_ids_in_cache = set() + for user_id in user_ids: + stream_ids = self._simple_select_onecol_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + retcol="stream_id", + ) + if stream_ids: + user_ids_in_cache.add(user_id) + + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + content = self._simple_select_one_onecol_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + ) + results.setdefault(user_id, {})[device_id] = json.loads(content) + else: + devices = self._simple_select_list_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + ) + results[user_id] = { + device["device_id"]: json.loads(device["content"]) + for device in devices + } + user_ids_in_cache.discard(user_id) + + return user_ids_not_in_cache, results + + def get_devices_with_keys_by_user(self, user_id): + """Get all devices (with any device keys) for a user + + Returns: + (stream_id, devices) + """ + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.iteritems(): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + """Mark that updates have successfully been sent to the destination. + """ + return self.runInteraction( + "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, + destination, stream_id, + ) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + # First we DELETE all rows such that only the latest row for each + # (destination, user_id is left. We do this by selecting first and + # deleting. + sql = """ + SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + GROUP BY user_id + HAVING count(*) > 1 + """ + txn.execute(sql, (destination, stream_id,)) + rows = txn.fetchall() + + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND stream_id < ? + """ + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows) + ) + + # Mark everything that is left as sent + sql = """ + UPDATE device_lists_outbound_pokes SET sent = ? + WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (True, destination, stream_id,)) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + """Get set of users whose devices have changed since `from_key`. + """ + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row[0] for row in rows)) + + def get_all_device_list_changes_for_remotes(self, from_key): + """Return a list of `(stream_id, user_id, destination)` which is the + combined list of changes to devices, and which destinations need to be + poked. `destination` may be None if no destinations need to be poked. + """ + sql = """ + SELECT stream_id, user_id, destination FROM device_lists_stream + LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + WHERE stream_id > ? + """ + return self._execute( + "get_users_and_hosts_device_list", None, + sql, from_key, + ) + + @defer.inlineCallbacks + def add_device_change_to_streams(self, user_id, device_ids, hosts): + """Persist that a user's devices have been updated, and which hosts + (if any) should be poked. + """ + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_device_change_to_streams", self._add_device_change_txn, + user_id, device_ids, hosts, stream_id, + ) + defer.returnValue(stream_id) + + def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): + now = self._clock.time_msec() + + txn.call_after( + self._device_list_stream_cache.entity_has_changed, + user_id, stream_id, + ) + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, stream_id, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_stream", + values=[ + { + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + for device_id in device_ids + ] + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_outbound_pokes", + values=[ + { + "destination": destination, + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + "sent": False, + "ts": now, + } + for destination in hosts + for device_id in device_ids + ] + ) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() + + def _prune_old_outbound_device_pokes(self): + """Delete old entries out of the device_lists_outbound_pokes to ensure + that we don't fill up due to dead servers. We keep one entry per + (destination, user_id) tuple to ensure that the prev_ids remain correct + if the server does come back. + """ + yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + + def _prune_txn(txn): + select_sql = """ + SELECT destination, user_id, max(stream_id) as stream_id + FROM device_lists_outbound_pokes + GROUP BY destination, user_id + HAVING min(ts) < ? AND count(*) > 1 + """ + + txn.execute(select_sql, (yesterday,)) + rows = txn.fetchall() + + if not rows: + return + + delete_sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + """ + + txn.executemany( + delete_sql, + ( + (yesterday, row[0], row[1], row[2]) + for row in rows + ) + ) + + logger.info("Pruned %d device list outbound pokes", txn.rowcount) + + return self.runInteraction( + "_prune_old_outbound_device_pokes", _prune_txn + ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 385d607056..2040e022fa 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,74 +12,111 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections +from twisted.internet import defer -import twisted.internet.defer +from canonicaljson import encode_canonical_json +import ujson as json from ._base import SQLBaseStore class EndToEndKeyStore(SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): - return self._simple_upsert( - table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "ts_added_ms": time_now, - "key_json": json_bytes, - } + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + def _set_e2e_device_keys_txn(txn): + old_key_json = self._simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="key_json", + allow_none=True, + ) + + new_key_json = encode_canonical_json(device_keys) + if old_key_json == new_key_json: + return False + + self._simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": new_key_json, + } + ) + + return True + + return self.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn ) - def get_e2e_device_keys(self, query_list): + @defer.inlineCallbacks + def get_e2e_device_keys(self, query_list, include_all_devices=False): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. + include_all_devices (bool): whether to include entries for devices + that don't have device keys Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name". """ if not query_list: - return {} + defer.returnValue({}) - return self.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + results = yield self.runInteraction( + "get_e2e_device_keys", self._get_e2e_device_keys_txn, + query_list, include_all_devices, ) - def _get_e2e_device_keys_txn(self, txn, query_list): + for user_id, device_keys in results.iteritems(): + for device_id, device_info in device_keys.iteritems(): + device_info["keys"] = json.loads(device_info.pop("key_json")) + + defer.returnValue(results) + + def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): query_clauses = [] query_params = [] for (user_id, device_id) in query_list: - query_clause = "k.user_id = ?" + query_clause = "user_id = ?" query_params.append(user_id) if device_id: - query_clause += " AND k.device_id = ?" + query_clause += " AND device_id = ?" query_params.append(device_id) query_clauses.append(query_clause) sql = ( - "SELECT k.user_id, k.device_id, " + "SELECT user_id, device_id, " " d.display_name AS device_display_name, " " k.key_json" - " FROM e2e_device_keys_json k" - " LEFT JOIN devices d ON d.user_id = k.user_id" - " AND d.device_id = k.device_id" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" " WHERE %s" ) % ( + "LEFT" if include_all_devices else "INNER", " OR ".join("(" + q + ")" for q in query_clauses) ) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) - result = collections.defaultdict(dict) + result = {} for row in rows: - result[row["user_id"]][row["device_id"]] = row + result.setdefault(row["user_id"], {})[row["device_id"]] = row return result @@ -152,7 +189,7 @@ class EndToEndKeyStore(SQLBaseStore): "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @twisted.internet.defer.inlineCallbacks + @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): yield self._simple_delete( table="e2e_device_keys_json", diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 53feaa1960..ee88c61954 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore): room_id, ) - @cached() + @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): return self._simple_select_onecol( table="event_forward_extremities", @@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore): ], ) - self._update_extremeties(txn, events) + self._update_backward_extremeties(txn, events) - def _update_extremeties(self, txn, events): - """Updates the event_*_extremities tables based on the new/updated + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated events being persisted. This is called for new events *and* for events that were outliers, but - are are now being persisted as non-outliers. + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. """ events_by_room = {} for ev in events: events_by_room.setdefault(ev.room_id, []).append(ev) - for room_id, room_events in events_by_room.items(): - prevs = [ - e_id for ev in room_events for e_id, _ in ev.prev_events - if not ev.internal_metadata.is_outlier() - ] - if prevs: - txn.execute( - "DELETE FROM event_forward_extremities" - " WHERE room_id = ?" - " AND event_id in (%s)" % ( - ",".join(["?"] * len(prevs)), - ), - [room_id] + prevs, - ) - - query = ( - "INSERT INTO event_forward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_edges WHERE prev_event_id = ?" - " )" - ) - - txn.executemany( - query, - [ - (ev.event_id, ev.room_id, ev.event_id) for ev in events - if not ev.internal_metadata.is_outlier() - ] - ) - - # We now insert into stream_ordering_to_exterm a mapping from room_id, - # new stream_ordering to new forward extremeties in the room. - # This allows us to later efficiently look up the forward extremeties - # for a room before a given stream_ordering - max_stream_ord = max( - ev.internal_metadata.stream_ordering for ev in events - ) - new_extrem = {} - for room_id in events_by_room: - event_ids = self._simple_select_onecol_txn( - txn, - table="event_forward_extremities", - keyvalues={"room_id": room_id}, - retcol="event_id", - ) - new_extrem[room_id] = event_ids - - self._simple_insert_many_txn( - txn, - table="stream_ordering_to_exterm", - values=[ - { - "room_id": room_id, - "event_id": event_id, - "stream_ordering": max_stream_ord, - } - for room_id, extrem_evs in new_extrem.items() - for event_id in extrem_evs - ] - ) - query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" @@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore): ] ) - for room_id in events_by_room: - txn.call_after( - self.get_latest_event_ids_in_room.invalidate, (room_id,) - ) - def get_forward_extremeties_for_room(self, room_id, stream_ordering): # We want to make the cache more effective, so we clamp to the last # change before the given ordering. diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 04dbdac3f8..8659f605a5 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore, _RollbackButIsFineException +from ._base import SQLBaseStore from twisted.internet import defer, reactor @@ -27,6 +27,8 @@ from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError +from synapse.state import resolve_events +from synapse.util.caches.descriptors import cached from canonicaljson import encode_canonical_json from collections import deque, namedtuple, OrderedDict @@ -71,22 +73,19 @@ class _EventPeristenceQueue(object): """ _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( - "events_and_contexts", "current_state", "backfilled", "deferred", + "events_and_contexts", "backfilled", "deferred", )) def __init__(self): self._event_persist_queues = {} self._currently_persisting_rooms = set() - def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): + def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: end_item = queue[-1] - if end_item.current_state or current_state: - # We perist events with current_state set to True one at a time - pass if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) return end_item.deferred.observe() @@ -96,7 +95,6 @@ class _EventPeristenceQueue(object): queue.append(self._EventPersistQueueItem( events_and_contexts=events_and_contexts, backfilled=backfilled, - current_state=current_state, deferred=deferred, )) @@ -216,7 +214,6 @@ class EventsStore(SQLBaseStore): d = preserve_fn(self._event_persist_queue.add_to_queue)( room_id, evs_ctxs, backfilled=backfilled, - current_state=None, ) deferreds.append(d) @@ -229,11 +226,10 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function - def persist_event(self, event, context, current_state=None, backfilled=False): + def persist_event(self, event, context, backfilled=False): deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled, - current_state=current_state, ) self._maybe_start_persisting(event.room_id) @@ -246,21 +242,10 @@ class EventsStore(SQLBaseStore): def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks def persisting_queue(item): - if item.current_state: - for event, context in item.events_and_contexts: - # There should only ever be one item in - # events_and_contexts when current_state is - # not None - yield self._persist_event( - event, context, - current_state=item.current_state, - backfilled=item.backfilled, - ) - else: - yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, - ) + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @@ -294,35 +279,183 @@ class EventsStore(SQLBaseStore): for chunk in chunks: # We can't easily parallelize these since different chunks # might contain the same event. :( + + # NB: Assumes that we are only persisting events for one room + # at a time. + new_forward_extremeties = {} + current_state_for_room = {} + if not backfilled: + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) + + for room_id, ev_ctx_rm in events_by_room.items(): + # Work out new extremities by recursively adding and removing + # the new events. + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremeties( + room_id, [ev for ev, _ in ev_ctx_rm] + ) + + if new_latest_event_ids == set(latest_event_ids): + # No change in extremities, so no change in state + continue + + new_forward_extremeties[room_id] = new_latest_event_ids + + state = yield self._calculate_state_delta( + room_id, ev_ctx_rm, new_latest_event_ids + ) + if state: + current_state_for_room[room_id] = state + yield self.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=chunk, backfilled=backfilled, delete_existing=delete_existing, + current_state_for_room=current_state_for_room, + new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) - @_retry_on_integrity_error @defer.inlineCallbacks - @log_function - def _persist_event(self, event, context, current_state=None, backfilled=False, - delete_existing=False): - try: - with self._stream_id_gen.get_next() as stream_ordering: - event.internal_metadata.stream_ordering = stream_ordering - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - current_state=current_state, - backfilled=backfilled, - delete_existing=delete_existing, - ) - persist_event_counter.inc() - except _RollbackButIsFineException: - pass + def _calculate_new_extremeties(self, room_id, events): + """Calculates the new forward extremeties for a room given events to + persist. + + Assumes that we are only persisting events for one room at a time. + """ + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = set(latest_event_ids) + # First, add all the new events to the list + new_latest_event_ids.update( + event.event_id for event in events + if not event.internal_metadata.is_outlier() + ) + # Now remove all events that are referenced by the to-be-added events + new_latest_event_ids.difference_update( + e_id + for event in events + for e_id, _ in event.prev_events + if not event.internal_metadata.is_outlier() + ) + + # And finally remove any events that are referenced by previously added + # events. + rows = yield self._simple_select_many_batch( + table="event_edges", + column="prev_event_id", + iterable=list(new_latest_event_ids), + retcols=["prev_event_id"], + keyvalues={ + "room_id": room_id, + "is_state": False, + }, + desc="_calculate_new_extremeties", + ) + + new_latest_event_ids.difference_update( + row["prev_event_id"] for row in rows + ) + + defer.returnValue(new_latest_event_ids) + + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 2-tuple (to_delete, to_insert) where both are state dicts, i.e. + (type, state_key) -> event_id. `to_delete` are the entries to + first be deleted from current_state_events, `to_insert` are entries + to insert. + May return None if there are no changes to be applied. + """ + # Now we need to work out the different state sets for + # each state extremities + state_sets = [] + missing_event_ids = [] + was_updated = False + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding, + # and then use the current state from that + for ev, ctx in events_context: + if event_id == ev.event_id: + if ctx.current_state_ids is None: + raise Exception("Unknown current state") + state_sets.append(ctx.current_state_ids) + if ctx.delta_ids or hasattr(ev, "state_key"): + was_updated = True + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + was_updated = True + missing_event_ids.append(event_id) + + if missing_event_ids: + # Now pull out the state for any missing events from DB + event_to_groups = yield self._get_state_group_for_events( + missing_event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) + + state_sets.extend(group_to_state.values()) + + if not new_latest_event_ids: + current_state = {} + elif was_updated: + current_state = yield resolve_events( + state_sets, + state_map_factory=lambda ev_ids: self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ), + ) + else: + return + + existing_state_rows = yield self._simple_select_list( + table="current_state_events", + keyvalues={"room_id": room_id}, + retcols=["event_id", "type", "state_key"], + desc="_calculate_state_delta", + ) + + existing_events = set(row["event_id"] for row in existing_state_rows) + new_events = set(ev_id for ev_id in current_state.itervalues()) + changed_events = existing_events ^ new_events + + if not changed_events: + return + + to_delete = { + (row["type"], row["state_key"]): row["event_id"] + for row in existing_state_rows + if row["event_id"] in changed_events + } + events_to_insert = (new_events - existing_events) + to_insert = { + key: ev_id for key, ev_id in current_state.iteritems() + if ev_id in events_to_insert + } + + defer.returnValue((to_delete, to_insert)) @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, @@ -381,52 +514,9 @@ class EventsStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @log_function - def _persist_event_txn(self, txn, event, context, current_state, backfilled=False, - delete_existing=False): - # We purposefully do this first since if we include a `current_state` - # key, we *want* to update the `current_state_events` table - if current_state: - txn.call_after(self._get_current_state_for_key.invalidate_all) - txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - stream_order = event.internal_metadata.stream_ordering - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": stream_order} - ) - - self._simple_delete_txn( - txn, - table="current_state_events", - keyvalues={"room_id": event.room_id}, - ) - - for s in current_state: - self._simple_insert_txn( - txn, - "current_state_events", - { - "event_id": s.event_id, - "room_id": s.room_id, - "type": s.type, - "state_key": s.state_key, - } - ) - - return self._persist_events_txn( - txn, - [(event, context)], - backfilled=backfilled, - delete_existing=delete_existing, - ) - - @log_function def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False): + delete_existing=False, current_state_for_room={}, + new_forward_extremeties={}): """Insert some number of room events into the necessary database tables. Rejected events are only inserted into the events table, the events_json table, @@ -436,6 +526,93 @@ class EventsStore(SQLBaseStore): If delete_existing is True then existing events will be purged from the database before insertion. This is useful when retrying due to IntegrityError. """ + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + for room_id, current_state_tuple in current_state_for_room.iteritems(): + to_delete, to_insert = current_state_tuple + txn.executemany( + "DELETE FROM current_state_events WHERE event_id = ?", + [(ev_id,) for ev_id in to_delete.itervalues()], + ) + + self._simple_insert_many_txn( + txn, + table="current_state_events", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + } + for key, ev_id in to_insert.iteritems() + ], + ) + + # Invalidate the various caches + + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invlidate the caches for all + # those users. + members_changed = set( + state_key for ev_type, state_key in to_delete.iterkeys() + if ev_type == EventTypes.Member + ) + members_changed.update( + state_key for ev_type, state_key in to_insert.iterkeys() + if ev_type == EventTypes.Member + ) + + for member in members_changed: + self._invalidate_cache_and_stream( + txn, self.get_rooms_for_user, (member,) + ) + + self._invalidate_cache_and_stream( + txn, self.get_users_in_room, (room_id,) + ) + + for room_id, new_extrem in new_forward_extremeties.items(): + self._simple_delete_txn( + txn, + table="event_forward_extremities", + keyvalues={"room_id": room_id}, + ) + txn.call_after( + self.get_latest_event_ids_in_room.invalidate, (room_id,) + ) + + self._simple_insert_many_txn( + txn, + table="event_forward_extremities", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + } + for room_id, new_extrem in new_forward_extremeties.items() + for ev_id in new_extrem + ], + ) + # We now insert into stream_ordering_to_exterm a mapping from room_id, + # new stream_ordering to new forward extremeties in the room. + # This allows us to later efficiently look up the forward extremeties + # for a room before a given stream_ordering + self._simple_insert_many_txn( + txn, + table="stream_ordering_to_exterm", + values=[ + { + "room_id": room_id, + "event_id": event_id, + "stream_ordering": max_stream_order, + } + for room_id, new_extrem in new_forward_extremeties.items() + for event_id in new_extrem + ] + ) + # Ensure that we don't have the same event twice. # Pick the earliest non-outlier if there is one, else the earliest one. new_events_and_contexts = OrderedDict() @@ -550,7 +727,7 @@ class EventsStore(SQLBaseStore): # Update the event_backward_extremities table now that this # event isn't an outlier any more. - self._update_extremeties(txn, [event]) + self._update_backward_extremeties(txn, [event]) events_and_contexts = [ ec for ec in events_and_contexts if ec[0] not in to_remove @@ -798,29 +975,6 @@ class EventsStore(SQLBaseStore): # to update the current state table return - for event, _ in state_events_and_contexts: - if event.internal_metadata.is_outlier(): - # Outlier events shouldn't clobber the current state. - continue - - txn.call_after( - self._get_current_state_for_key.invalidate, - (event.room_id, event.type, event.state_key,) - ) - - self._simple_upsert_txn( - txn, - "current_state_events", - keyvalues={ - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - values={ - "event_id": event.event_id, - } - ) - return def _add_to_cache(self, txn, events_and_contexts): @@ -1084,10 +1238,10 @@ class EventsStore(SQLBaseStore): self._do_fetch ) - logger.info("Loading %d events", len(events)) + logger.debug("Loading %d events", len(events)) with PreserveLoggingContext(): rows = yield events_d - logger.info("Loaded %d events (%d rows)", len(events), len(rows)) + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] @@ -1418,6 +1572,7 @@ class EventsStore(SQLBaseStore): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): """Get all the new events that have arrived at the server either as @@ -1450,15 +1605,6 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT event_stream_ordering FROM current_state_resets" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering ASC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - state_resets = txn.fetchall() - - sql = ( "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" @@ -1469,7 +1615,6 @@ class EventsStore(SQLBaseStore): forward_ex_outliers = txn.fetchall() else: new_forward_events = [] - state_resets = [] forward_ex_outliers = [] sql = ( @@ -1509,7 +1654,6 @@ class EventsStore(SQLBaseStore): return AllNewEventsResult( new_forward_events, new_backfill_events, forward_ex_outliers, backward_ex_outliers, - state_resets, ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) @@ -1735,5 +1879,4 @@ class EventsStore(SQLBaseStore): AllNewEventsResult = namedtuple("AllNewEventsResult", [ "new_forward_events", "new_backfill_events", "forward_ex_outliers", "backward_ex_outliers", - "state_resets" ]) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e46ae6502e..b357f22be7 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 39 +SCHEMA_VERSION = 40 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 983a8ec52b..26be6060c3 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -413,6 +413,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): desc="user_delete_threepids", ) + def user_delete_threepid(self, user_id, medium, address): + return self._simple_delete( + "user_threepids", + keyvalues={ + "user_id": user_id, + "medium": medium, + "address": address, + }, + desc="user_delete_threepids", + ) + @defer.inlineCallbacks def count_all_users(self): """Counts all users registered on the homeserver.""" diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5d18037c7c..545d3d3a99 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) - txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering @@ -131,7 +129,7 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - @cached(max_entries=5000) + @cached(max_entries=500000, iterable=True) def get_users_in_room(self, room_id): def f(txn): @@ -220,7 +218,7 @@ class RoomMemberStore(SQLBaseStore): " ON e.event_id = c.event_id" " AND m.room_id = c.room_id" " AND m.user_id = c.state_key" - " WHERE %s" + " WHERE c.type = 'm.room.member' AND %s" ) % (where_clause,) txn.execute(sql, args) @@ -266,7 +264,7 @@ class RoomMemberStore(SQLBaseStore): " ON m.event_id = c.event_id " " AND m.room_id = c.room_id " " AND m.user_id = c.state_key" - " WHERE %(where)s" + " WHERE c.type = 'm.room.member' AND %(where)s" ) % { "where": where_clause, } @@ -276,12 +274,29 @@ class RoomMemberStore(SQLBaseStore): return rows - @cached(max_entries=5000) + @cached(max_entries=500000, iterable=True) def get_rooms_for_user(self, user_id): return self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ + rooms = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate, + ) + + user_who_share_room = set() + for room in rooms: + user_ids = yield self.get_users_in_room( + room.room_id, on_invalidate=cache_context.invalidate, + ) + user_who_share_room.update(user_ids) + + defer.returnValue(user_who_share_room) + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -390,7 +405,8 @@ class RoomMemberStore(SQLBaseStore): room_id, state_group, state_ids, ) - @cachedInlineCallbacks(num_args=2, cache_context=True) + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=100000) def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, cache_context, event=None): # We don't use `state_group`, it's there so that we can cache based diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/schema/delta/40/current_state_idx.sql new file mode 100644 index 0000000000..7ffa189f39 --- /dev/null +++ b/synapse/storage/schema/delta/40/current_state_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('current_state_members_idx', '{}'); diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/schema/delta/40/device_inbox.sql new file mode 100644 index 0000000000..b9fe1f0480 --- /dev/null +++ b/synapse/storage/schema/delta/40/device_inbox.sql @@ -0,0 +1,21 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- turn the pre-fill startup query into a index-only scan on postgresql. +INSERT into background_updates (update_name, progress_json) + VALUES ('device_inbox_stream_index', '{}'); + +INSERT into background_updates (update_name, progress_json, depends_on) + VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index'); diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql new file mode 100644 index 0000000000..54841b3843 --- /dev/null +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -0,0 +1,59 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Cache of remote devices. +CREATE TABLE device_lists_remote_cache ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); + + +-- The last update we got for a user. Empty if we're not receiving updates for +-- that user. +CREATE TABLE device_lists_remote_extremeties ( + user_id TEXT NOT NULL, + stream_id TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); + + +-- Stream of device lists updates. Includes both local and remotes +CREATE TABLE device_lists_stream ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); + + +-- The stream of updates to send to other servers. We keep at least one row +-- per user that was sent so that the prev_id for any new updates can be +-- calculated +CREATE TABLE device_lists_outbound_pokes ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + sent BOOLEAN NOT NULL, + ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers +); + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7f466c40ac..1b3800eb6a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -49,6 +49,7 @@ class StateStore(SQLBaseStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" def __init__(self, hs): super(StateStore, self).__init__(hs) @@ -60,6 +61,13 @@ class StateStore(SQLBaseStore): self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state, ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): @@ -232,59 +240,7 @@ class StateStore(SQLBaseStore): return count - @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type and state_key is not None: - result = yield self.get_current_state_for_key( - room_id, event_type, state_key - ) - defer.returnValue(result) - - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? " - ) - - if event_type and state_key is not None: - sql += " AND type = ? AND state_key = ? " - args = (room_id, event_type, state_key) - elif event_type: - sql += " AND type = ?" - args = (room_id, event_type) - else: - args = (room_id, ) - - txn.execute(sql, args) - results = txn.fetchall() - - return [r[0] for r in results] - - event_ids = yield self.runInteraction("get_current_state", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @defer.inlineCallbacks - def get_current_state_for_key(self, room_id, event_type, state_key): - event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @cached(num_args=3) - def _get_current_state_for_key(self, room_id, event_type, state_key): - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?" - ) - - args = (room_id, event_type, state_key) - txn.execute(sql, args) - results = txn.fetchall() - return [r[0] for r in results] - return self.runInteraction("get_current_state_for_key", f) - - @cached(num_args=2, max_entries=1000) + @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 2dc24951c4..200d124632 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) + def get_rooms_that_changed(self, room_ids, from_key): + """Given a list of rooms and a token, return rooms where there may have + been changes. + + Args: + room_ids (list) + from_key (str): The room_key portion of a StreamToken + """ + from_key = RoomStreamToken.parse_stream_token(from_key).stream + return set( + room_id for room_id in room_ids + if self._events_stream_cache.has_entity_changed(room_id, from_key) + ) + @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): |