summary refs log tree commit diff
path: root/synapse/handlers/presence.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/presence.py')
-rw-r--r--synapse/handlers/presence.py163
1 files changed, 102 insertions, 61 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 40794187b1..670c1d353f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -146,6 +146,10 @@ class PresenceHandler(BaseHandler):
         self._user_cachemap = {}
         self._user_cachemap_latest_serial = 0
 
+        # map room_ids to the latest presence serial for a member of that
+        # room
+        self._room_serials = {}
+
         metrics.register_callback(
             "userCachemap:size",
             lambda: len(self._user_cachemap),
@@ -297,13 +301,34 @@ class PresenceHandler(BaseHandler):
 
         self.changed_presencelike_data(user, {"last_active": now})
 
+    def get_joined_rooms_for_user(self, user):
+        """Get the list of rooms a user is joined to.
+
+        Args:
+            user(UserID): The user.
+        Returns:
+            A Deferred of a list of room id strings.
+        """
+        rm_handler = self.homeserver.get_handlers().room_member_handler
+        return rm_handler.get_joined_rooms_for_user(user)
+
+    def get_joined_users_for_room_id(self, room_id):
+        rm_handler = self.homeserver.get_handlers().room_member_handler
+        return rm_handler.get_room_members(room_id)
+
+    @defer.inlineCallbacks
     def changed_presencelike_data(self, user, state):
-        statuscache = self._get_or_make_usercache(user)
+        """Updates the presence state of a local user.
 
+        Args:
+            user(UserID): The user being updated.
+            state(dict): The new presence state for the user.
+        Returns:
+            A Deferred
+        """
         self._user_cachemap_latest_serial += 1
-        statuscache.update(state, serial=self._user_cachemap_latest_serial)
-
-        return self.push_presence(user, statuscache=statuscache)
+        statuscache = yield self.update_presence_cache(user, state)
+        yield self.push_presence(user, statuscache=statuscache)
 
     @log_function
     def started_user_eventstream(self, user):
@@ -326,13 +351,12 @@ class PresenceHandler(BaseHandler):
             room_id(str): The room id the user joined.
         """
         if self.hs.is_mine(user):
-            statuscache = self._get_or_make_usercache(user)
-
             # No actual update but we need to bump the serial anyway for the
             # event source
             self._user_cachemap_latest_serial += 1
-            statuscache.update({}, serial=self._user_cachemap_latest_serial)
-
+            statuscache = yield self.update_presence_cache(
+                user, room_ids=[room_id]
+            )
             self.push_update_to_local_and_remote(
                 observed_user=user,
                 room_ids=[room_id],
@@ -340,16 +364,17 @@ class PresenceHandler(BaseHandler):
             )
 
         # We also want to tell them about current presence of people.
-        rm_handler = self.homeserver.get_handlers().room_member_handler
-        curr_users = yield rm_handler.get_room_members(room_id)
+        curr_users = yield self.get_joined_users_for_room_id(room_id)
 
         for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
-            statuscache = self._get_or_offline_usercache(local_user)
-            statuscache.update({}, serial=self._user_cachemap_latest_serial)
+            statuscache = yield self.update_presence_cache(
+                local_user, room_ids=[room_id], add_to_cache=False
+            )
+
             self.push_update_to_local_and_remote(
                 observed_user=local_user,
                 users_to_push=[user],
-                statuscache=self._get_or_offline_usercache(local_user),
+                statuscache=statuscache,
             )
 
     @defer.inlineCallbacks
@@ -546,8 +571,7 @@ class PresenceHandler(BaseHandler):
 
             # Also include people in all my rooms
 
-            rm_handler = self.homeserver.get_handlers().room_member_handler
-            room_ids = yield rm_handler.get_joined_rooms_for_user(user)
+            room_ids = yield self.get_joined_rooms_for_user(user)
 
         if state is None:
             state = yield self.store.get_presence_state(user.localpart)
@@ -747,8 +771,7 @@ class PresenceHandler(BaseHandler):
         # and also user is informed of server-forced pushes
         localusers.add(user)
 
-        rm_handler = self.homeserver.get_handlers().room_member_handler
-        room_ids = yield rm_handler.get_joined_rooms_for_user(user)
+        room_ids = yield self.get_joined_rooms_for_user(user)
 
         if not localusers and not room_ids:
             defer.returnValue(None)
@@ -793,8 +816,7 @@ class PresenceHandler(BaseHandler):
                     " | %d interested local observers %r", len(observers), observers
                 )
 
-            rm_handler = self.homeserver.get_handlers().room_member_handler
-            room_ids = yield rm_handler.get_joined_rooms_for_user(user)
+            room_ids = yield self.get_joined_rooms_for_user(user)
             if room_ids:
                 logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
 
@@ -813,10 +835,8 @@ class PresenceHandler(BaseHandler):
                     self.clock.time_msec() - state.pop("last_active_ago")
                 )
 
-            statuscache = self._get_or_make_usercache(user)
-
             self._user_cachemap_latest_serial += 1
-            statuscache.update(state, serial=self._user_cachemap_latest_serial)
+            yield self.update_presence_cache(user, state, room_ids=room_ids)
 
             if not observers and not room_ids:
                 logger.debug(" | no interested observers or room IDs")
@@ -875,6 +895,35 @@ class PresenceHandler(BaseHandler):
         yield defer.DeferredList(deferreds, consumeErrors=True)
 
     @defer.inlineCallbacks
+    def update_presence_cache(self, user, state={}, room_ids=None,
+                              add_to_cache=True):
+        """Update the presence cache for a user with a new state and bump the
+        serial to the latest value.
+
+        Args:
+            user(UserID): The user being updated
+            state(dict): The presence state being updated
+            room_ids(None or list of str): A list of room_ids to update. If
+                room_ids is None then fetch the list of room_ids the user is
+                joined to.
+            add_to_cache: Whether to add an entry to the presence cache if the
+                user isn't already in the cache.
+        Returns:
+            A Deferred UserPresenceCache for the user being updated.
+        """
+        if room_ids is None:
+            room_ids = yield self.get_joined_rooms_for_user(user)
+
+        for room_id in room_ids:
+            self._room_serials[room_id] = self._user_cachemap_latest_serial
+        if add_to_cache:
+            statuscache = self._get_or_make_usercache(user)
+        else:
+            statuscache = self._get_or_offline_usercache(user)
+        statuscache.update(state, serial=self._user_cachemap_latest_serial)
+        defer.returnValue(statuscache)
+
+    @defer.inlineCallbacks
     def push_update_to_local_and_remote(self, observed_user, statuscache,
                                         users_to_push=[], room_ids=[],
                                         remote_domains=[]):
@@ -997,38 +1046,10 @@ class PresenceEventSource(object):
         self.clock = hs.get_clock()
 
     @defer.inlineCallbacks
-    def is_visible(self, observer_user, observed_user):
-        if observer_user == observed_user:
-            defer.returnValue(True)
-
-        presence = self.hs.get_handlers().presence_handler
-
-        if (yield presence.store.user_rooms_intersect(
-                [u.to_string() for u in observer_user, observed_user])):
-            defer.returnValue(True)
-
-        if self.hs.is_mine(observed_user):
-            pushmap = presence._local_pushmap
-
-            defer.returnValue(
-                observed_user.localpart in pushmap and
-                observer_user in pushmap[observed_user.localpart]
-            )
-        else:
-            recvmap = presence._remote_recvmap
-
-            defer.returnValue(
-                observed_user in recvmap and
-                observer_user in recvmap[observed_user]
-            )
-
-    @defer.inlineCallbacks
     @log_function
     def get_new_events_for_user(self, user, from_key, limit):
         from_key = int(from_key)
 
-        observer_user = user
-
         presence = self.hs.get_handlers().presence_handler
         cachemap = presence._user_cachemap
 
@@ -1037,17 +1058,27 @@ class PresenceEventSource(object):
         clock = self.clock
         latest_serial = 0
 
+        user_ids_to_check = {user}
+        presence_list = yield presence.store.get_presence_list(
+            user.localpart, accepted=True
+        )
+        if presence_list is not None:
+            user_ids_to_check |= set(
+                UserID.from_string(p["observed_user_id"]) for p in presence_list
+            )
+        room_ids = yield presence.get_joined_rooms_for_user(user)
+        for room_id in set(room_ids) & set(presence._room_serials):
+            if presence._room_serials[room_id] > from_key:
+                joined = yield presence.get_joined_users_for_room_id(room_id)
+                user_ids_to_check |= set(joined)
+
         updates = []
-        # TODO(paul): use a DeferredList ? How to limit concurrency.
-        for observed_user in cachemap.keys():
+        for observed_user in user_ids_to_check & set(cachemap):
             cached = cachemap[observed_user]
 
             if cached.serial <= from_key or cached.serial > max_serial:
                 continue
 
-            if not (yield self.is_visible(observer_user, observed_user)):
-                continue
-
             latest_serial = max(cached.serial, latest_serial)
             updates.append(cached.make_event(user=observed_user, clock=clock))
 
@@ -1084,8 +1115,6 @@ class PresenceEventSource(object):
     def get_pagination_rows(self, user, pagination_config, key):
         # TODO (erikj): Does this make sense? Ordering?
 
-        observer_user = user
-
         from_key = int(pagination_config.from_key)
 
         if pagination_config.to_key:
@@ -1096,14 +1125,26 @@ class PresenceEventSource(object):
         presence = self.hs.get_handlers().presence_handler
         cachemap = presence._user_cachemap
 
+        user_ids_to_check = {user}
+        presence_list = yield presence.store.get_presence_list(
+            user.localpart, accepted=True
+        )
+        if presence_list is not None:
+            user_ids_to_check |= set(
+                UserID.from_string(p["observed_user_id"]) for p in presence_list
+            )
+        room_ids = yield presence.get_joined_rooms_for_user(user)
+        for room_id in set(room_ids) & set(presence._room_serials):
+            if presence._room_serials[room_id] >= from_key:
+                joined = yield presence.get_joined_users_for_room_id(room_id)
+                user_ids_to_check |= set(joined)
+
         updates = []
-        # TODO(paul): use a DeferredList ? How to limit concurrency.
-        for observed_user in cachemap.keys():
+        for observed_user in user_ids_to_check & set(cachemap):
             if not (to_key < cachemap[observed_user].serial <= from_key):
                 continue
 
-            if (yield self.is_visible(observer_user, observed_user)):
-                updates.append((observed_user, cachemap[observed_user]))
+            updates.append((observed_user, cachemap[observed_user]))
 
         # TODO(paul): limit