Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/data_stores/main/__init__.py | 5
-rw-r--r-- | synapse/storage/data_stores/main/cache.py | 44
-rw-r--r-- | synapse/storage/data_stores/main/deviceinbox.py | 88
-rw-r--r-- | synapse/storage/data_stores/main/devices.py | 211
-rw-r--r-- | synapse/storage/data_stores/main/directory.py | 26
-rw-r--r-- | synapse/storage/data_stores/main/e2e_room_keys.py | 3
-rw-r--r-- | synapse/storage/data_stores/main/end_to_end_keys.py | 14
-rw-r--r-- | synapse/storage/data_stores/main/events.py | 114
-rw-r--r-- | synapse/storage/data_stores/main/events_worker.py | 118
-rw-r--r-- | synapse/storage/data_stores/main/media_repository.py | 4
-rw-r--r-- | synapse/storage/data_stores/main/presence.py | 23
-rw-r--r-- | synapse/storage/data_stores/main/push_rule.py | 1
-rw-r--r-- | synapse/storage/data_stores/main/room.py | 40
-rw-r--r-- | synapse/storage/database.py | 11
14 files changed, 396 insertions, 306 deletions
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index acca079f23..649e835303 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -144,7 +144,10 @@ class DataStore( db_conn, "device_lists_stream", "stream_id", - extra_tables=[("user_signature_stream", "stream_id")], + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id" diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index d4c44dcc75..4dc5da3fe8 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -32,7 +32,29 @@ logger = logging.getLogger(__name__) CURRENT_STATE_CACHE_NAME = "cs_cache_fake" -class CacheInvalidationStore(SQLBaseStore): +class CacheInvalidationWorkerStore(SQLBaseStore): + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + +class CacheInvalidationStore(CacheInvalidationWorkerStore): async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore): }, ) - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" 
- ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_caches", get_all_updated_caches_txn - ) - def get_cache_stream_token(self): if self._cache_id_gen: return self._cache_id_gen.get_current_token() diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 0613b49f4a..9a1178fb39 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) + def get_all_new_device_messages(self, last_pos, current_pos, limit): + """ + Args: + last_pos(int): + current_pos(int): + limit(int): + Returns: + A deferred list of rows from the device inbox + """ + if last_pos == current_pos: + return defer.succeed([]) + + def get_all_new_device_messages_txn(txn): + # We limit like this as we might have multiple rows per stream_id, and + # we want to make sure we always get all entries for any stream_id + # we return. + upper_pos = min(current_pos, last_pos + limit) + sql = ( + "SELECT max(stream_id), user_id" + " FROM device_inbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY user_id" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows = txn.fetchall() + + sql = ( + "SELECT max(stream_id), destination" + " FROM device_federation_outbox" + " WHERE ? < stream_id AND stream_id <= ?" + " GROUP BY destination" + ) + txn.execute(sql, (last_pos, upper_pos)) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() + + return rows + + return self.db.runInteraction( + "get_all_new_device_messages", get_all_new_device_messages_txn + ) + class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" @@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((user_id, device_id, stream_id, message_json)) txn.executemany(sql, rows) - - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ - Args: - last_pos(int): - current_pos(int): - limit(int): - Returns: - A deferred list of rows from the device inbox - """ - if last_pos == current_pos: - return defer.succeed([]) - - def get_all_new_device_messages_txn(txn): - # We limit like this as we might have multiple rows per stream_id, and - # we want to make sure we always get all entries for any stream_id - # we return. - upper_pos = min(current_pos, last_pos + limit) - sql = ( - "SELECT max(stream_id), user_id" - " FROM device_inbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY user_id" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() - - sql = ( - "SELECT max(stream_id), destination" - " FROM device_federation_outbox" - " WHERE ? < stream_id AND stream_id <= ?" - " GROUP BY destination" - ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) - - # Order by ascending stream ordering - rows.sort() - - return rows - - return self.db.runInteraction( - "get_all_new_device_messages", get_all_new_device_messages_txn - ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 8af5f7de54..20995e1b78 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging +from typing import List, Tuple from six import iteritems @@ -31,7 +32,7 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage.database import Database, LoggingTransaction from synapse.types import Collection, get_verify_key_from_cross_signing_key from synapse.util.caches.descriptors import ( Cache, @@ -40,6 +41,7 @@ from synapse.util.caches.descriptors import ( cachedList, ) from synapse.util.iterutils import batch_iter +from synapse.util.stringutils import shortstr logger = logging.getLogger(__name__) @@ -112,23 +114,13 @@ class DeviceWorkerStore(SQLBaseStore): if not has_changed: return now_stream_id, [] - # We retrieve n+1 devices from the list of outbound pokes where n is - # our outbound device update limit. We then check if the very last - # device has the same stream_id as the second-to-last device. If so, - # then we ignore all devices with that stream_id and only send the - # devices with a lower stream_id. - # - # If when culling the list we end up with no devices afterwards, we - # consider the device update to be too large, and simply skip the - # stream_id; the rationale being that such a large device list update - # is likely an error. updates = yield self.db.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, from_stream_id, now_stream_id, - limit + 1, + limit, ) # Return an empty list if there are no updates @@ -166,14 +158,6 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - # if we have exceeded the limit, we need to exclude any results with the - # same stream_id as the last row. - if len(updates) > limit: - stream_id_cutoff = updates[-1][2] - now_stream_id = stream_id_cutoff - 1 - else: - stream_id_cutoff = None - # Perform the equivalent of a GROUP BY # # Iterate through the updates list and copy non-duplicate @@ -192,10 +176,6 @@ class DeviceWorkerStore(SQLBaseStore): query_map = {} cross_signing_keys_by_user = {} for user_id, device_id, update_stream_id, update_context in updates: - if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff: - # Stop processing updates - break - if ( user_id in master_key_by_user and device_id == master_key_by_user[user_id]["device_id"] @@ -218,17 +198,6 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) - # If we didn't find any updates with a stream_id lower than the cutoff, it - # means that there are more than limit updates all of which have the same - # steam_id. - - # That should only happen if a client is spamming the server with new - # devices, in which case E2E isn't going to work well anyway. We'll just - # skip that stream_id and return an empty list, and continue with the next - # stream_id next time. - if not query_map and not cross_signing_keys_by_user: - return stream_id_cutoff, [] - results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) @@ -611,22 +580,33 @@ class DeviceWorkerStore(SQLBaseStore): else: return set() - def get_all_device_list_changes_for_remotes(self, from_key, to_key): - """Return a list of `(stream_id, user_id, destination)` which is the - combined list of changes to devices, and which destinations need to be - poked. 
`destination` may be None if no destinations need to be poked. + async def get_all_device_list_changes_for_remotes( + self, from_key: int, to_key: int, limit: int, + ) -> List[Tuple[int, str]]: + """Return a list of `(stream_id, entity)` which is the combined list of + changes to devices and which destinations need to be poked. Entity is + either a user ID (starting with '@') or a remote destination. """ - # We do a group by here as there can be a large number of duplicate - # entries, since we throw away device IDs. + + # This query Does The Right Thing where it'll correctly apply the + # bounds to the inner queries. sql = """ - SELECT MAX(stream_id) AS stream_id, user_id, destination - FROM device_lists_stream - LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + SELECT stream_id, entity FROM ( + SELECT stream_id, user_id AS entity FROM device_lists_stream + UNION ALL + SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes + ) AS e WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id, destination + LIMIT ? """ - return self.db.execute( - "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key + + return await self.db.execute( + "get_all_device_list_changes_for_remotes", + None, + sql, + from_key, + to_key, + limit, ) @cached(max_entries=10000) @@ -1021,29 +1001,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ - with self._device_list_id_gen.get_next() as stream_id: + if not device_ids: + return + + with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + yield self.db.runInteraction( + "add_device_change_to_stream", + self._add_device_change_to_stream_txn, + user_id, + device_ids, + stream_ids, + ) + + if not hosts: + return stream_ids[-1] + + context = get_active_span_text_map() + with self._device_list_id_gen.get_next_mult( + len(hosts) * len(device_ids) + ) as stream_ids: yield self.db.runInteraction( - "add_device_change_to_streams", - self._add_device_change_txn, + "add_device_outbound_poke_to_stream", + self._add_device_outbound_poke_to_stream_txn, user_id, device_ids, hosts, - stream_id, + stream_ids, + context, ) - return stream_id - def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): - now = self._clock.time_msec() + return stream_ids[-1] + def _add_device_change_to_stream_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + stream_ids: List[str], + ): txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_id + self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], ) - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_id, - ) + + min_stream_id = stream_ids[0] # Delete older entries in the table, as we really only care about # when the latest change happened. @@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? 
""", - [(user_id, device_id, stream_id) for device_id in device_ids], + [(user_id, device_id, min_stream_id) for device_id in device_ids], ) self.db.simple_insert_many_txn( @@ -1060,11 +1060,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_stream", values=[ {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} - for device_id in device_ids + for stream_id, device_id in zip(stream_ids, device_ids) ], ) - context = get_active_span_text_map() + def _add_device_outbound_poke_to_stream_txn( + self, txn, user_id, device_ids, hosts, stream_ids, context, + ): + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) + + now = self._clock.time_msec() + next_stream_id = iter(stream_ids) self.db.simple_insert_many_txn( txn, @@ -1072,7 +1083,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=[ { "destination": destination, - "stream_id": stream_id, + "stream_id": next(next_stream_id), "user_id": user_id, "device_id": device_id, "sent": False, @@ -1086,18 +1097,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ], ) - def _prune_old_outbound_device_pokes(self): + def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): """Delete old entries out of the device_lists_outbound_pokes to ensure - that we don't fill up due to dead servers. We keep one entry per - (destination, user_id) tuple to ensure that the prev_ids remain correct - if the server does come back. + that we don't fill up due to dead servers. + + Normally, we try to send device updates as a delta since a previous known point: + this is done by setting the prev_id in the m.device_list_update EDU. However, + for that to work, we have to have a complete record of each change to + each device, which can add up to quite a lot of data. + + An alternative mechanism is that, if the remote server sees that it has missed + an entry in the stream_id sequence for a given user, it will request a full + list of that user's devices. Hence, we can reduce the amount of data we have to + store (and transmit in some future transaction), by clearing almost everything + for a given destination out of the database, and having the remote server + resync. + + All we need to do is make sure we keep at least one row for each + (user, destination) pair, to remind us to send a m.device_list_update EDU for + that user when the destination comes back. It doesn't matter which device + we keep. """ - yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + yesterday = self._clock.time_msec() - prune_age def _prune_txn(txn): + # look for (user, destination) pairs which have an update older than + # the cutoff. + # + # For each pair, we also need to know the most recent stream_id, and + # an arbitrary device_id at that stream_id. select_sql = """ - SELECT destination, user_id, max(stream_id) as stream_id - FROM device_lists_outbound_pokes + SELECT + dlop1.destination, + dlop1.user_id, + MAX(dlop1.stream_id) AS stream_id, + (SELECT MIN(dlop2.device_id) AS device_id FROM + device_lists_outbound_pokes dlop2 + WHERE dlop2.destination = dlop1.destination AND + dlop2.user_id=dlop1.user_id AND + dlop2.stream_id=MAX(dlop1.stream_id) + ) + FROM device_lists_outbound_pokes dlop1 GROUP BY destination, user_id HAVING min(ts) < ? 
AND count(*) > 1 """ @@ -1108,14 +1148,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not rows: return + logger.info( + "Pruning old outbound device list updates for %i users/destinations: %s", + len(rows), + shortstr((row[0], row[1]) for row in rows), + ) + + # we want to keep the update with the highest stream_id for each user. + # + # there might be more than one update (with different device_ids) with the + # same stream_id, so we also delete all but one rows with the max stream id. delete_sql = """ DELETE FROM device_lists_outbound_pokes - WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + WHERE destination = ? AND user_id = ? AND ( + stream_id < ? OR + (stream_id = ? AND device_id != ?) + ) """ - - txn.executemany( - delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows) - ) + count = 0 + for (destination, user_id, stream_id, device_id) in rows: + txn.execute( + delete_sql, (destination, user_id, stream_id, stream_id, device_id) + ) + count += txn.rowcount # Since we've deleted unsent deltas, we need to remove the entry # of last successful sent so that the prev_ids are correctly set. @@ -1125,7 +1180,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): """ txn.executemany(sql, ((row[0], row[1]) for row in rows)) - logger.info("Pruned %d device list outbound pokes", txn.rowcount) + logger.info("Pruned %d device list outbound pokes", count) return run_as_background_process( "prune_old_outbound_device_pokes", diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index c9e7de7d12..e1d1bc3e05 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -14,6 +14,7 @@ # limitations under the License. from collections import namedtuple +from typing import Optional from twisted.internet import defer @@ -159,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room(self, old_room_id, new_room_id, creator): + def update_aliases_for_room( + self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, + ): + """Repoint all of the aliases for a given room, to a different room. + + Args: + old_room_id: + new_room_id: + creator: The user to record as the creator of the new mapping. + If None, the creator will be left unchanged. + """ + def _update_aliases_for_room_txn(txn): - sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" - txn.execute(sql, (new_room_id, creator, old_room_id)) + update_creator_sql = "" + sql_params = (new_room_id, old_room_id) + if creator: + update_creator_sql = ", creator = ?" + sql_params = (new_room_id, creator, old_room_id) + + sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" 
% ( + update_creator_sql, + ) + txn.execute(sql, sql_params) self._invalidate_cache_and_stream( txn, self.get_aliases_for_room, (old_room_id,) ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 84594cf0a9..23f4570c4b 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -146,7 +146,8 @@ class EndToEndRoomKeyStore(SQLBaseStore): room_entry["sessions"][row["session_id"]] = { "first_message_index": row["first_message_index"], "forwarded_count": row["forwarded_count"], - "is_verified": row["is_verified"], + # is_verified must be returned to the client as a boolean + "is_verified": bool(row["is_verified"]), "session_data": json.loads(row["session_data"]), } diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 001a53f9b4..bcf746b7ef 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return result - def get_all_user_signature_changes_for_remotes(self, from_key, to_key): + def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their device with their user-signing key, which is not published to other @@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Deferred[list[(int,str)]] a list of `(stream_id, user_id)` """ sql = """ - SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id + SELECT stream_id, from_user_id AS user_id FROM user_signature_stream WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id + ORDER BY stream_id ASC + LIMIT ? """ return self.db.execute( - "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key + "get_all_user_signature_changes_for_remotes", + None, + sql, + from_key, + to_key, + limit, ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index d593ef47b8..e71c23541d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1267,104 +1267,6 @@ class EventsStore( ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - - def get_current_events_token(self): - """The current maximum token that events have reached""" - return self._stream_id_gen.get_current_token() - - def get_all_new_forward_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_forward_event_rows(txn): - sql = ( - "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < stream_ordering AND stream_ordering <= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" 
- ) - txn.execute(sql, (last_id, current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (last_id, upper_bound)) - new_event_updates.extend(txn) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_forward_event_rows", get_all_new_forward_event_rows - ) - - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_new_backfill_event_rows(txn): - sql = ( - "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > stream_ordering AND stream_ordering >= ?" - " ORDER BY stream_ordering ASC" - " LIMIT ?" - ) - txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() - - if len(new_event_updates) == limit: - upper_bound = new_event_updates[-1][0] - else: - upper_bound = current_id - - sql = ( - "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" - " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" - " LEFT JOIN redactions USING (event_id)" - " LEFT JOIN state_events USING (event_id)" - " LEFT JOIN event_relations USING (event_id)" - " WHERE ? > event_stream_ordering" - " AND event_stream_ordering >= ?" - " ORDER BY event_stream_ordering DESC" - ) - txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) - - return new_event_updates - - return self.db.runInteraction( - "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows - ) - @cached(num_args=5, max_entries=10) def get_all_new_events( self, @@ -1850,22 +1752,6 @@ class EventsStore( return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_all_updated_current_state_deltas(self, from_token, to_token, limit): - def get_all_updated_current_state_deltas_txn(txn): - sql = """ - SELECT stream_id, room_id, type, state_key, event_id - FROM current_state_delta_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC LIMIT ? 
- """ - txn.execute(sql, (from_token, to_token, limit)) - return txn.fetchall() - - return self.db.runInteraction( - "get_all_updated_current_state_deltas", - get_all_updated_current_state_deltas_txn, - ) - def insert_labels_for_event_txn( self, txn, event_id, labels, room_id, topological_ordering ): diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index ca237c6f12..16ea8948b1 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -35,7 +35,7 @@ from synapse.api.room_versions import ( ) from synapse.events import make_event_from_dict from synapse.events.utils import prune_event -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.database import Database @@ -409,7 +409,7 @@ class EventsWorkerStore(SQLBaseStore): missing_events_ids = [e for e in event_ids if e not in event_entry_map] if missing_events_ids: - log_ctx = LoggingContext.current_context() + log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) # Note that _get_events_from_db is also responsible for turning db rows @@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore): complexity_v1 = round(state_events / 500, 2) return {"v1": complexity_v1} + + def get_current_backfill_token(self): + """The current minimum token that backfilled events have reached""" + return -self._backfill_id_gen.get_current_token() + + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" 
+ " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts, relates_to_id" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " LEFT JOIN event_relations USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + + return self.db.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? 
+ """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, + ) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 80ca36dedf..cf195f8aa6 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -340,7 +340,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "get_expired_url_cache", _get_expired_url_cache_txn ) - def delete_url_cache(self, media_ids): + async def delete_url_cache(self, media_ids): if len(media_ids) == 0: return @@ -349,7 +349,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) + return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) def get_url_cache_media_before(self, before_ts): sql = ( diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index 604c8b7ddd..dab31e0c2d 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore): "status_msg": state.status_msg, "currently_active": state.currently_active, } - for state in presence_states + for stream_id, state in zip(stream_orderings, presence_states) ], ) @@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore): ) txn.execute(sql + clause, [stream_id] + list(args)) - def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates(self, last_id, current_id, limit): if last_id == current_id: return defer.succeed([]) def get_all_presence_updates_txn(txn): - sql = ( - "SELECT stream_id, user_id, state, last_active_ts," - " last_federation_update_ts, last_user_sync_ts, status_msg," - " currently_active" - " FROM presence_stream" - " WHERE ? < stream_id AND stream_id <= ?" - ) - txn.execute(sql, (last_id, current_id)) + sql = """ + SELECT stream_id, user_id, state, last_active_ts, + last_federation_update_ts, last_user_sync_ts, + status_msg, + currently_active + FROM presence_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? 
+ """ + txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() return self.db.runInteraction( diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 62ac88d9f2..46f9bda773 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -41,6 +41,7 @@ def _load_rules(rawrules, enabled_map): rule = dict(rawrule) rule["conditions"] = json.loads(rawrule["conditions"]) rule["actions"] = json.loads(rawrule["actions"]) + rule["default"] = False ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index e6c10c6316..aaebe427d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore): return total_media_quarantined + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = """ + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (prev_id, current_id, limit)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) + + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) + class RoomBackgroundUpdateStore(SQLBaseStore): REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" @@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() - - if prev_id == current_id: - return defer.succeed([]) - - return self.db.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - @defer.inlineCallbacks def block_room(self, room_id, user_id): """Marks the room as blocked. Can be called multiple times. 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e61595336c..715c0346dd 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -32,6 +32,7 @@ from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import ( LoggingContext, LoggingContextOrSentinel, + current_context, make_deferred_yieldable, ) from synapse.metrics.background_process_metrics import run_as_background_process @@ -483,7 +484,7 @@ class Database(object): end = monotonic_time() duration = end - start - LoggingContext.current_context().add_database_transaction(duration) + current_context().add_database_transaction(duration) transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) @@ -510,7 +511,7 @@ class Database(object): after_callbacks = [] # type: List[_CallbackListEntry] exception_callbacks = [] # type: List[_CallbackListEntry] - if LoggingContext.current_context() == LoggingContext.sentinel: + if not current_context(): logger.warning("Starting db txn '%s' from sentinel context", desc) try: @@ -547,10 +548,8 @@ class Database(object): Returns: Deferred: The result of func """ - parent_context = ( - LoggingContext.current_context() - ) # type: Optional[LoggingContextOrSentinel] - if parent_context == LoggingContext.sentinel: + parent_context = current_context() # type: Optional[LoggingContextOrSentinel] + if not parent_context: logger.warning( "Starting db connection from sentinel context: metrics will be lost" ) |
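
Reviewer note: the common thread in these changes is that each replication stream fetcher now takes an explicit limit and applies it in SQL, so callers can page through updates in bounded batches rather than fetching everything between two tokens at once. As a rough illustration only (the pager function and all of its names are hypothetical, not part of this patch), a consumer of the updated get_all_user_signature_changes_for_remotes might drive it like this, relying on the ORDER BY stream_id ASC added above:

async def catch_up_user_signature_changes(store, last_id, current_id, batch_size=100):
    # Illustrative sketch only: `store` is assumed to expose the updated
    # get_all_user_signature_changes_for_remotes(from_key, to_key, limit),
    # which returns (stream_id, user_id) rows in ascending stream_id order.
    while last_id < current_id:
        rows = await store.get_all_user_signature_changes_for_remotes(
            last_id, current_id, batch_size
        )
        if not rows:
            # nothing further below current_id
            break

        for stream_id, user_id in rows:
            # e.g. queue a signature-change poke for user_id here
            pass

        # rows are ordered by stream_id, so the last row tells us how far we
        # have advanced; use it as the lower bound for the next batch.
        last_id = rows[-1][0]

    return last_id

The same looping pattern applies to the other fetchers touched here (device messages, presence, public rooms, forward and backfill event rows): each query is now bounded by LIMIT, and the caller is expected to iterate until it catches up to the current token.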