diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a67fdb3c22..f677d048aa 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1941,6 +1941,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id,
stream_ids[-1],
)
+ txn.call_after(
+ self._get_e2e_device_keys_for_federation_query_inner.invalidate,
+ (user_id,),
+ )
min_stream_id = stream_ids[0]
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4bc391f213..91ae9c457d 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -16,6 +16,7 @@
import abc
from typing import (
TYPE_CHECKING,
+ Any,
Collection,
Dict,
Iterable,
@@ -39,6 +40,7 @@ from synapse.appservice import (
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
@@ -104,6 +106,23 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
self.hs.config.federation.allow_device_name_lookup_over_federation
)
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
+ if stream_name == DeviceListsStream.NAME:
+ for row in rows:
+ assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
+ if row.entity.startswith("@"):
+ self._get_e2e_device_keys_for_federation_query_inner.invalidate(
+ (row.entity,)
+ )
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
@@ -114,6 +133,50 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
"""
now_stream_id = self.get_device_stream_token()
+ # We need to be careful with the caching here, as we need to always
+ # return *all* persisted devices, however there may be a lag between a
+ # new device being persisted and the cache being invalidated.
+ cached_results = (
+ self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate(
+ user_id, None
+ )
+ )
+ if cached_results is not None:
+ # Check that no new devices have been added by another worker
+ # since the cache entry was populated. This should be quick, as
+ # there should be few rows with a higher stream ordering.
+ #
+ # Note that we invalidate based on the device stream, so we only
+ # have to check for potential invalidations after the
+ # `now_stream_id`.
+ sql = """
+ SELECT user_id FROM device_lists_stream
+ WHERE stream_id >= ? AND user_id = ?
+ """
+ rows = await self.db_pool.execute(
+ "get_e2e_device_keys_for_federation_query_check",
+ None,
+ sql,
+ now_stream_id,
+ user_id,
+ )
+ if not rows:
+ # No new rows, so cache is still valid.
+ return now_stream_id, cached_results
+
+ # There have been new rows, so invalidate the cache and run the query.
+ self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,))
+
+ results = await self._get_e2e_device_keys_for_federation_query_inner(user_id)
+
+ return now_stream_id, results
+
+ @cached(iterable=True)
+ async def _get_e2e_device_keys_for_federation_query_inner(
+ self, user_id: str
+ ) -> List[JsonDict]:
+ """Get all devices (with any device keys) for a user"""
+
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
if devices:
@@ -134,9 +197,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
results.append(result)
- return now_stream_id, results
+ return results
- return now_stream_id, []
+ return []
@trace
@cancellable
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index f777777cbf..fff417f9e3 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore):
@cached(num_args=2)
async def get_user_filter(
- self, user_localpart: str, filter_id: Union[int, str]
+ self, user_id: UserID, filter_id: Union[int, str]
) -> JsonDict:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
@@ -156,7 +156,7 @@ class FilteringWorkerStore(SQLBaseStore):
def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
- keyvalues={"user_id": user_localpart, "filter_id": filter_id},
+ keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
@@ -172,15 +172,15 @@ class FilteringWorkerStore(SQLBaseStore):
def _do_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT filter_id FROM user_filters "
- "WHERE user_id = ? AND filter_json = ?"
+ "WHERE full_user_id = ? AND filter_json = ?"
)
- txn.execute(sql, (user_id.localpart, bytearray(def_json)))
+ txn.execute(sql, (user_id.to_string(), bytearray(def_json)))
filter_id_response = txn.fetchone()
if filter_id_response is not None:
return filter_id_response[0]
- sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
- txn.execute(sql, (user_id.localpart,))
+ sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?"
+ txn.execute(sql, (user_id.to_string(),))
max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
if max_id is None:
filter_id = 0
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 21d54c7a7a..3ba9cc8853 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -137,11 +137,11 @@ class ProfileWorkerStore(SQLBaseStore):
return 50
- async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
+ async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
- keyvalues={"user_id": user_localpart},
+ keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
)
@@ -156,18 +156,18 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
- async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
+ async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
- keyvalues={"user_id": user_localpart},
+ keyvalues={"full_user_id": user_id.to_string()},
retcol="displayname",
desc="get_profile_displayname",
)
- async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
+ async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
- keyvalues={"user_id": user_localpart},
+ keyvalues={"full_user_id": user_id.to_string()},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9f862f00c1..e098ceea3c 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -88,7 +88,6 @@ def _load_rules(
msc1767_enabled=experimental_config.msc1767_enabled,
msc3664_enabled=experimental_config.msc3664_enabled,
msc3381_polls_enabled=experimental_config.msc3381_polls_enabled,
- msc3952_intentional_mentions=experimental_config.msc3952_intentional_mentions,
msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs,
)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4a6c6c724d..96908f14ba 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -365,6 +365,36 @@ class RelationsWorkerStore(SQLBaseStore):
func=get_all_relation_ids_for_event_with_types_txn,
)
+ async def get_all_relations_for_event(
+ self,
+ event_id: str,
+ ) -> List[str]:
+ """Get the event IDs of all events that have a relation to the given event.
+
+ Args:
+ event_id: The event for which to look for related events.
+
+ Returns:
+ A list of the IDs of the events that relate to the given event.
+ """
+
+ def get_all_relation_ids_for_event_txn(
+ txn: LoggingTransaction,
+ ) -> List[str]:
+ rows = self.db_pool.simple_select_list_txn(
+ txn=txn,
+ table="event_relations",
+ keyvalues={"relates_to_id": event_id},
+ retcols=["event_id"],
+ )
+
+ return [row["event_id"] for row in rows]
+
+ return await self.db_pool.runInteraction(
+ desc="get_all_relation_ids_for_event",
+ func=get_all_relation_ids_for_event_txn,
+ )
+
async def event_includes_relation(self, event_id: str) -> bool:
"""Check if the given event relates to another event.
|