diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4f76b7a743..00718d7f2d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1143,10 +1143,14 @@ class SyncHandler(object):
user_id
)
- tracked_users = set(users_who_share_room)
-
- # Always tell the user about their own devices
- tracked_users.add(user_id)
+ # Always tell the user about their own devices. We check as the user
+ # ID is almost certainly already included (unless they're not in any
+ # rooms) and taking a copy of the set is relatively expensive.
+ if user_id not in users_who_share_room:
+ users_who_share_room = set(users_who_share_room)
+ users_who_share_room.add(user_id)
+
+ tracked_users = users_who_share_room
# Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = await self.store.get_users_whose_devices_changed(
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index ee3a2ab031..03f5141e6c 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -541,8 +541,8 @@ class DeviceWorkerStore(SQLBaseStore):
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = list(
- self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+ to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
)
if not to_check:
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 38dc3f501e..e54f80d76e 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -14,12 +14,13 @@
# limitations under the License.
import logging
-from typing import Dict, Iterable, List, Mapping, Optional, Set
+from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
from six import integer_types
from sortedcontainers import SortedDict
+from synapse.types import Collection
from synapse.util import caches
logger = logging.getLogger(__name__)
@@ -85,8 +86,8 @@ class StreamChangeCache:
return False
def get_entities_changed(
- self, entities: Iterable[EntityType], stream_pos: int
- ) -> Set[EntityType]:
+ self, entities: Collection[EntityType], stream_pos: int
+ ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
"""
Returns subset of entities that have had new things since the given
position. Entities unknown to the cache will be returned. If the
@@ -94,7 +95,17 @@ class StreamChangeCache:
"""
changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None:
- result = set(changed_entities).intersection(entities)
+ # We now do an intersection, trying to do so in the most efficient
+ # way possible (some of these sets are *large*). First check in the
+ # given iterable is already set that we can reuse, otherwise we
+ # create a set of the *smallest* of the two iterables and call
+ # `intersection(..)` on it (this can be twice as fast as the reverse).
+ if isinstance(entities, (set, frozenset)):
+ result = entities.intersection(changed_entities)
+ elif len(changed_entities) < len(entities):
+ result = set(changed_entities).intersection(entities)
+ else:
+ result = set(entities).intersection(changed_entities)
self.metrics.inc_hits()
else:
result = set(entities)
|