summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py19
-rw-r--r--synapse/handlers/sync.py50
2 files changed, 35 insertions, 34 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 5c7617de44..46abb8ec51 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -53,16 +53,23 @@ class BaseHandler(object):
         self.event_builder_factory = hs.get_event_builder_factory()
 
     @defer.inlineCallbacks
-    def _filter_events_for_clients(self, users, events):
+    def _filter_events_for_clients(self, user_tuples, events):
         """ Returns dict of user_id -> list of events that user is allowed to
         see.
         """
-        event_id_to_state = yield self.store.get_state_for_events(
-            frozenset(e.event_id for e in events),
-            types=(
+        # If there is only one user, just get the state for that one user,
+        # otherwise just get all the state.
+        if len(user_tuples) == 1:
+            types = (
                 (EventTypes.RoomHistoryVisibility, ""),
-                (EventTypes.Member, None),
+                (EventTypes.Member, user_tuples[0][0]),
             )
+        else:
+            types = None
+
+        event_id_to_state = yield self.store.get_state_for_events(
+            frozenset(e.event_id for e in events),
+            types=types
         )
 
         forgotten = yield defer.gatherResults([
@@ -122,7 +129,7 @@ class BaseHandler(object):
                 for event in events
                 if allowed(event, user_id, is_guest)
             ]
-            for user_id, is_guest in users
+            for user_id, is_guest in user_tuples
         })
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d2864977b0..aca200c1e7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -54,8 +54,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
     "state",             # dict[(str, str), FrozenEvent]
     "ephemeral",
     "account_data",
-    "unread_notification_count",
-    "unread_highlight_count",
+    "unread_notifications",
 ])):
     __slots__ = []
 
@@ -294,11 +293,10 @@ class SyncHandler(BaseHandler):
             room_id, sync_config, ephemeral_by_room
         )
 
-        notif_count = None
-        highlight_count = None
+        unread_notifications = {}
         if notifs is not None:
-            notif_count = len(notifs)
-            highlight_count = len([
+            unread_notifications["notification_count"] = len(notifs)
+            unread_notifications["highlight_count"] = len([
                 1 for notif in notifs if _action_has_highlight(notif["actions"])
             ])
 
@@ -312,8 +310,7 @@ class SyncHandler(BaseHandler):
             account_data=self.account_data_for_room(
                 room_id, tags_by_room, account_data_by_room
             ),
-            unread_notification_count=notif_count,
-            unread_highlight_count=highlight_count,
+            unread_notifications=unread_notifications,
         ))
 
     def account_data_for_user(self, account_data):
@@ -533,18 +530,6 @@ class SyncHandler(BaseHandler):
                 else:
                     prev_batch = now_token
 
-                notifs = yield self.unread_notifs_for_room_id(
-                    room_id, sync_config, all_ephemeral_by_room
-                )
-
-                notif_count = None
-                highlight_count = None
-                if notifs is not None:
-                    notif_count = len(notifs)
-                    highlight_count = len([
-                        1 for notif in notifs if _action_has_highlight(notif["actions"])
-                    ])
-
                 just_joined = yield self.check_joined_room(sync_config, state)
                 if just_joined:
                     logger.debug("User has just joined %s: needs full state",
@@ -565,12 +550,23 @@ class SyncHandler(BaseHandler):
                     account_data=self.account_data_for_room(
                         room_id, tags_by_room, account_data_by_room
                     ),
-                    unread_notification_count=notif_count,
-                    unread_highlight_count=highlight_count,
+                    unread_notifications={},
                 )
                 logger.debug("Result for room %s: %r", room_id, room_sync)
 
                 if room_sync:
+                    notifs = yield self.unread_notifs_for_room_id(
+                        room_id, sync_config, all_ephemeral_by_room
+                    )
+
+                    if notifs is not None:
+                        notif_dict = room_sync.unread_notifications
+                        notif_dict["notification_count"] = len(notifs)
+                        notif_dict["highlight_count"] = len([
+                            1 for notif in notifs
+                            if _action_has_highlight(notif["actions"])
+                        ])
+
                     joined.append(room_sync)
 
         else:
@@ -708,11 +704,10 @@ class SyncHandler(BaseHandler):
             room_id, sync_config, all_ephemeral_by_room
         )
 
-        notif_count = None
-        highlight_count = None
+        unread_notifications = {}
         if notifs is not None:
-            notif_count = len(notifs)
-            highlight_count = len([
+            unread_notifications["notification_count"] = len(notifs)
+            unread_notifications["highlight_count"] = len([
                 1 for notif in notifs if _action_has_highlight(notif["actions"])
             ])
 
@@ -724,8 +719,7 @@ class SyncHandler(BaseHandler):
             account_data=self.account_data_for_room(
                 room_id, tags_by_room, account_data_by_room
             ),
-            unread_notification_count=notif_count,
-            unread_highlight_count=highlight_count,
+            unread_notifications=unread_notifications,
         )
 
         logger.debug("Room sync: %r", room_sync)