summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/registration.py8
-rw-r--r--synapse/crypto/keyring.py9
-rw-r--r--synapse/handlers/device.py13
-rw-r--r--synapse/storage/databases/main/cache.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py14
-rw-r--r--synapse/util/caches/lrucache.py59
6 files changed, 92 insertions, 13 deletions
diff --git a/synapse/config/registration.py b/synapse/config/registration.py

index 9e2b1f3de1..3fe0f050cd 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -237,6 +237,14 @@ class RegistrationConfig(Config): self.inhibit_user_in_use_error = config.get("inhibit_user_in_use_error", False) + # List of user IDs not to send out device list updates for when they + # register new devices. This is useful to handle bot accounts. + # + # Note: This will still send out device list updates if the device is + # later updated, e.g. end to end keys are added. + dont_notify_new_devices_for = config.get("dont_notify_new_devices_for", []) + self.dont_notify_new_devices_for = frozenset(dont_notify_new_devices_for) + def generate_config_section( self, generate_secrets: bool = False, **kwargs: Any ) -> str: diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 1e7e5f70fe..8c301e077c 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py
@@ -839,11 +839,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher): Map from server_name -> key_id -> FetchKeyResult """ - results = {} + # We only need to do one request per server. + servers_to_fetch = {k.server_name for k in keys_to_fetch} - async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None: - server_name = key_to_fetch_item.server_name + results = {} + async def get_keys(server_name: str) -> None: try: keys = await self.get_server_verify_keys_v2_direct(server_name) results[server_name] = keys @@ -852,7 +853,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys from %s", server_name) - await yieldable_gather_results(get_keys, keys_to_fetch) + await yieldable_gather_results(get_keys, servers_to_fetch) return results async def get_server_verify_keys_v2_direct( diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 9062fac91a..67953a3ed9 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -429,6 +429,10 @@ class DeviceHandler(DeviceWorkerHandler): self._storage_controllers = hs.get_storage_controllers() self.db_pool = hs.get_datastores().main.db_pool + self._dont_notify_new_devices_for = ( + hs.config.registration.dont_notify_new_devices_for + ) + self.device_list_updater = DeviceListUpdater(hs, self) federation_registry = hs.get_federation_registry() @@ -505,6 +509,9 @@ class DeviceHandler(DeviceWorkerHandler): self._check_device_name_length(initial_device_display_name) + # Check if we should send out device lists updates for this new device. + notify = user_id not in self._dont_notify_new_devices_for + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -514,7 +521,8 @@ class DeviceHandler(DeviceWorkerHandler): auth_provider_session_id=auth_provider_session_id, ) if new_device: - await self.notify_device_update(user_id, [device_id]) + if notify: + await self.notify_device_update(user_id, [device_id]) return device_id # if the device id is not specified, we'll autogen one, but loop a few @@ -530,7 +538,8 @@ class DeviceHandler(DeviceWorkerHandler): auth_provider_session_id=auth_provider_session_id, ) if new_device: - await self.notify_device_update(user_id, [new_device_id]) + if notify: + await self.notify_device_update(user_id, [new_device_id]) return new_device_id attempts += 1 diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 7314d87404..bfd492d95d 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -373,7 +373,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): deleted. """ - self._invalidate_local_get_event_cache_all() # type: ignore[attr-defined] + self._invalidate_local_get_event_cache_room_id(room_id) # type: ignore[attr-defined] self._attempt_to_invalidate_cache("have_seen_event", (room_id,)) self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 1fd458b510..9c3775bb7c 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -268,6 +268,8 @@ class EventsWorkerStore(SQLBaseStore): ] = AsyncLruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, + # `extra_index_cb` Returns a tuple as that is the key type + extra_index_cb=lambda _, v: (v.event.room_id,), ) # Map from event ID to a deferred that will result in a map from event @@ -782,9 +784,9 @@ class EventsWorkerStore(SQLBaseStore): if missing_events_ids: - async def get_missing_events_from_cache_or_db() -> Dict[ - str, EventCacheEntry - ]: + async def get_missing_events_from_cache_or_db() -> ( + Dict[str, EventCacheEntry] + ): """Fetches the events in `missing_event_ids` from the database. Also creates entries in `self._current_event_fetches` to allow @@ -910,12 +912,12 @@ class EventsWorkerStore(SQLBaseStore): self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) - def _invalidate_local_get_event_cache_all(self) -> None: - """Clears the in-memory get event caches. + def _invalidate_local_get_event_cache_room_id(self, room_id: str) -> None: + """Clears the in-memory get event caches for a room. Used when we purge room history. """ - self._get_event_cache.clear() + self._get_event_cache.invalidate_on_extra_index_local((room_id,)) self._event_ref.clear() self._current_event_fetches.clear() diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 6e8c1e84ac..a1b4f5b6a7 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py
@@ -35,6 +35,7 @@ from typing import ( Iterable, List, Optional, + Set, Tuple, Type, TypeVar, @@ -386,6 +387,7 @@ class LruCache(Generic[KT, VT]): apply_cache_factor_from_config: bool = True, clock: Optional[Clock] = None, prune_unread_entries: bool = True, + extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, ): """ Args: @@ -416,6 +418,20 @@ class LruCache(Generic[KT, VT]): prune_unread_entries: If True, cache entries that haven't been read recently will be evicted from the cache in the background. Set to False to opt-out of this behaviour. + + extra_index_cb: If provided, the cache keeps a second index from a + (different) key to a cache entry based on the return value of + the callback. This can then be used to invalidate entries based + on the second type of key. + + For example, for the event cache this would be a callback that + maps an event to its room ID, allowing invalidation of all + events in a given room. + + Note: Though the two types of key have the same type, they are + in different namespaces. + + Note: The new key does not have to be unique. """ # Default `clock` to something sensible. Note that we rename it to # `real_clock` so that mypy doesn't think its still `Optional`. @@ -463,6 +479,8 @@ class LruCache(Generic[KT, VT]): lock = threading.Lock() + extra_index: Dict[KT, Set[KT]] = {} + def evict() -> None: while cache_len() > self.max_size: # Get the last node in the list (i.e. the oldest node). @@ -521,6 +539,11 @@ class LruCache(Generic[KT, VT]): if size_callback: cached_cache_len[0] += size_callback(node.value) + if extra_index_cb: + index_key = extra_index_cb(node.key, node.value) + mapped_keys = extra_index.setdefault(index_key, set()) + mapped_keys.add(node.key) + if caches.TRACK_MEMORY_USAGE and metrics: metrics.inc_memory_usage(node.memory) @@ -537,6 +560,14 @@ class LruCache(Generic[KT, VT]): node.run_and_clear_callbacks() + if extra_index_cb: + index_key = extra_index_cb(node.key, node.value) + mapped_keys = extra_index.get(index_key) + if mapped_keys is not None: + mapped_keys.discard(node.key) + if not mapped_keys: + extra_index.pop(index_key, None) + if caches.TRACK_MEMORY_USAGE and metrics: metrics.dec_memory_usage(node.memory) @@ -748,6 +779,8 @@ class LruCache(Generic[KT, VT]): if size_callback: cached_cache_len[0] = 0 + extra_index.clear() + if caches.TRACK_MEMORY_USAGE and metrics: metrics.clear_memory_usage() @@ -755,6 +788,28 @@ class LruCache(Generic[KT, VT]): def cache_contains(key: KT) -> bool: return key in cache + @synchronized + def cache_invalidate_on_extra_index(index_key: KT) -> None: + """Invalidates all entries that match the given extra index key. + + This can only be called when `extra_index_cb` was specified. + """ + + assert extra_index_cb is not None + + keys = extra_index.pop(index_key, None) + if not keys: + return + + for key in keys: + node = cache.pop(key, None) + if not node: + continue + + evicted_len = delete_node(node) + if metrics: + metrics.inc_evictions(EvictionReason.invalidation, evicted_len) + # make sure that we clear out any excess entries after we get resized. self._on_resize = evict @@ -771,6 +826,7 @@ class LruCache(Generic[KT, VT]): self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear + self.invalidate_on_extra_index = cache_invalidate_on_extra_index def __getitem__(self, key: KT) -> VT: result = self.get(key, _Sentinel.sentinel) @@ -864,6 +920,9 @@ class AsyncLruCache(Generic[KT, VT]): # This method should invalidate any external cache and then invalidate the LruCache. return self._lru_cache.invalidate(key) + def invalidate_on_extra_index_local(self, index_key: KT) -> None: + self._lru_cache.invalidate_on_extra_index(index_key) + def invalidate_local(self, key: KT) -> None: """Remove an entry from the local cache