summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py67
-rw-r--r--synapse/storage/databases/main/filtering.py12
-rw-r--r--synapse/storage/databases/main/profile.py12
-rw-r--r--synapse/storage/databases/main/push_rule.py1
-rw-r--r--synapse/storage/databases/main/relations.py30
6 files changed, 111 insertions, 15 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py

index a67fdb3c22..f677d048aa 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -1941,6 +1941,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id, stream_ids[-1], ) + txn.call_after( + self._get_e2e_device_keys_for_federation_query_inner.invalidate, + (user_id,), + ) min_stream_id = stream_ids[0] diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4bc391f213..91ae9c457d 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -16,6 +16,7 @@ import abc from typing import ( TYPE_CHECKING, + Any, Collection, Dict, Iterable, @@ -39,6 +40,7 @@ from synapse.appservice import ( TransactionUnusedFallbackKeys, ) from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.replication.tcp.streams._base import DeviceListsStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import ( DatabasePool, @@ -104,6 +106,23 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker self.hs.config.federation.allow_device_name_lookup_over_federation ) + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == DeviceListsStream.NAME: + for row in rows: + assert isinstance(row, DeviceListsStream.DeviceListsStreamRow) + if row.entity.startswith("@"): + self._get_e2e_device_keys_for_federation_query_inner.invalidate( + (row.entity,) + ) + + super().process_replication_rows(stream_name, instance_name, token, rows) + async def get_e2e_device_keys_for_federation_query( self, user_id: str ) -> Tuple[int, List[JsonDict]]: @@ -114,6 +133,50 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker """ now_stream_id = self.get_device_stream_token() + # We need to be careful with the caching here, as we need to always + # return *all* persisted devices, however there may be a lag between a + # new device being persisted and the cache being invalidated. + cached_results = ( + self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate( + user_id, None + ) + ) + if cached_results is not None: + # Check that there have been no new devices added by another worker + # after the cache. This should be quick as there should be few rows + # with a higher stream ordering. + # + # Note that we invalidate based on the device stream, so we only + # have to check for potential invalidations after the + # `now_stream_id`. + sql = """ + SELECT user_id FROM device_lists_stream + WHERE stream_id >= ? AND user_id = ? + """ + rows = await self.db_pool.execute( + "get_e2e_device_keys_for_federation_query_check", + None, + sql, + now_stream_id, + user_id, + ) + if not rows: + # No new rows, so cache is still valid. + return now_stream_id, cached_results + + # There has, so let's invalidate the cache and run the query. + self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,)) + + results = await self._get_e2e_device_keys_for_federation_query_inner(user_id) + + return now_stream_id, results + + @cached(iterable=True) + async def _get_e2e_device_keys_for_federation_query_inner( + self, user_id: str + ) -> List[JsonDict]: + """Get all devices (with any device keys) for a user""" + devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) if devices: @@ -134,9 +197,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker results.append(result) - return now_stream_id, results + return results - return now_stream_id, [] + return [] @trace @cancellable diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index f777777cbf..fff417f9e3 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py
@@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore): @cached(num_args=2) async def get_user_filter( - self, user_localpart: str, filter_id: Union[int, str] + self, user_id: UserID, filter_id: Union[int, str] ) -> JsonDict: # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. @@ -156,7 +156,7 @@ class FilteringWorkerStore(SQLBaseStore): def_json = await self.db_pool.simple_select_one_onecol( table="user_filters", - keyvalues={"user_id": user_localpart, "filter_id": filter_id}, + keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id}, retcol="filter_json", allow_none=False, desc="get_user_filter", @@ -172,15 +172,15 @@ class FilteringWorkerStore(SQLBaseStore): def _do_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT filter_id FROM user_filters " - "WHERE user_id = ? AND filter_json = ?" + "WHERE full_user_id = ? AND filter_json = ?" ) - txn.execute(sql, (user_id.localpart, bytearray(def_json))) + txn.execute(sql, (user_id.to_string(), bytearray(def_json))) filter_id_response = txn.fetchone() if filter_id_response is not None: return filter_id_response[0] - sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" - txn.execute(sql, (user_id.localpart,)) + sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?" + txn.execute(sql, (user_id.to_string(),)) max_id = cast(Tuple[Optional[int]], txn.fetchone())[0] if max_id is None: filter_id = 0 diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 21d54c7a7a..3ba9cc8853 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -137,11 +137,11 @@ class ProfileWorkerStore(SQLBaseStore): return 50 - async def get_profileinfo(self, user_localpart: str) -> ProfileInfo: + async def get_profileinfo(self, user_id: UserID) -> ProfileInfo: try: profile = await self.db_pool.simple_select_one( table="profiles", - keyvalues={"user_id": user_localpart}, + keyvalues={"full_user_id": user_id.to_string()}, retcols=("displayname", "avatar_url"), desc="get_profileinfo", ) @@ -156,18 +156,18 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: + async def get_profile_displayname(self, user_id: UserID) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", - keyvalues={"user_id": user_localpart}, + keyvalues={"full_user_id": user_id.to_string()}, retcol="displayname", desc="get_profile_displayname", ) - async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: + async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", - keyvalues={"user_id": user_localpart}, + keyvalues={"full_user_id": user_id.to_string()}, retcol="avatar_url", desc="get_profile_avatar_url", ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9f862f00c1..e098ceea3c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -88,7 +88,6 @@ def _load_rules( msc1767_enabled=experimental_config.msc1767_enabled, msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, - msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions, msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4a6c6c724d..96908f14ba 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -365,6 +365,36 @@ class RelationsWorkerStore(SQLBaseStore): func=get_all_relation_ids_for_event_with_types_txn, ) + async def get_all_relations_for_event( + self, + event_id: str, + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event. + + Args: + event_id: The event for which to look for related events. + + Returns: + A list of the IDs of the events that relate to the given event. + """ + + def get_all_relation_ids_for_event_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_list_txn( + txn=txn, + table="event_relations", + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event", + func=get_all_relation_ids_for_event_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event.