summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-05-05 17:07:59 +0100
committerErik Johnston <erik@matrix.org>2020-05-05 17:40:29 +0100
commitf9073893af82eec64b594dbcaef37c407a291c52 (patch)
tree6b57fce16886aef99c6e796900822290866a38df /synapse
parentWorkaround for assertion errors from db_query_to_update_function (#7378) (diff)
downloadsynapse-f9073893af82eec64b594dbcaef37c407a291c52.tar.xz
Speed up fetching device lists changes in sync.
Currently we copy `users_who_share_room` needlessly about three times,
which is expensive when the set is large (which it can easily be).
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sync.py12
-rw-r--r--synapse/storage/data_stores/main/devices.py4
-rw-r--r--synapse/util/caches/stream_change_cache.py19
3 files changed, 25 insertions, 10 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4f76b7a743..00718d7f2d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1143,10 +1143,14 @@ class SyncHandler(object):
                 user_id
             )
 
-            tracked_users = set(users_who_share_room)
-
-            # Always tell the user about their own devices
-            tracked_users.add(user_id)
+            # Always tell the user about their own devices. We check as the user
+            # ID is almost certainly already included (unless they're not in any
+            # rooms) and taking a copy of the set is relatively expensive.
+            if user_id not in users_who_share_room:
+                users_who_share_room = set(users_who_share_room)
+                users_who_share_room.add(user_id)
+
+            tracked_users = users_who_share_room
 
             # Step 1a, check for changes in devices of users we share a room with
             users_that_have_changed = await self.store.get_users_whose_devices_changed(
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index ee3a2ab031..03f5141e6c 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -541,8 +541,8 @@ class DeviceWorkerStore(SQLBaseStore):
 
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        to_check = list(
-            self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+        to_check = self._device_list_stream_cache.get_entities_changed(
+            user_ids, from_key
         )
 
         if not to_check:
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 38dc3f501e..e54f80d76e 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Iterable, List, Mapping, Optional, Set
+from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
 
 from six import integer_types
 
 from sortedcontainers import SortedDict
 
+from synapse.types import Collection
 from synapse.util import caches
 
 logger = logging.getLogger(__name__)
@@ -85,8 +86,8 @@ class StreamChangeCache:
         return False
 
     def get_entities_changed(
-        self, entities: Iterable[EntityType], stream_pos: int
-    ) -> Set[EntityType]:
+        self, entities: Collection[EntityType], stream_pos: int
+    ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
         """
         Returns subset of entities that have had new things since the given
         position.  Entities unknown to the cache will be returned.  If the
@@ -94,7 +95,17 @@ class StreamChangeCache:
         """
         changed_entities = self.get_all_entities_changed(stream_pos)
         if changed_entities is not None:
-            result = set(changed_entities).intersection(entities)
+            # We now do an intersection, trying to do so in the most efficient
+            # way possible (some of these sets are *large*). First check in the
+            # given iterable is already set that we can reuse, otherwise we
+            # create a set of the *smallest* of the two iterables and call
+            # `intersection(..)` on it (this can be twice as fast as the reverse).
+            if isinstance(entities, (set, frozenset)):
+                result = entities.intersection(changed_entities)
+            elif len(changed_entities) < len(entities):
+                result = set(changed_entities).intersection(entities)
+            else:
+                result = set(entities).intersection(changed_entities)
             self.metrics.inc_hits()
         else:
             result = set(entities)