Diffstat (limited to 'synapse/handlers/sync.py')
-rw-r--r--  synapse/handlers/sync.py  111
1 file changed, 62 insertions, 49 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6bdb24baff..e2ddb628ff 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -16,9 +16,7 @@
 
 import itertools
 import logging
-from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
-
-from six import iteritems, itervalues
+from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter
@@ -33,6 +31,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import (
     Collection,
     JsonDict,
+    MutableStateMap,
     RoomStreamToken,
     StateMap,
     StreamToken,
@@ -45,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.metrics import Measure, measure_func
 from synapse.visibility import filter_events_for_client
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Debug logger for https://github.com/matrix-org/synapse/issues/4422
@@ -96,7 +98,12 @@ class TimelineBatch:
     __bool__ = __nonzero__  # python3
 
 
-@attr.s(slots=True, frozen=True)
+# We can't freeze this class, because we need to update it after it's instantiated to
+# update its unread count. This is because we calculate the unread count for a room only
+# if there are updates for it, which we check after the instance has been created.
+# This should not be a big deal because we update the notification counts afterwards as
+# well anyway.
+@attr.s(slots=True)
 class JoinedSyncResult:
     room_id = attr.ib(type=str)
     timeline = attr.ib(type=TimelineBatch)
@@ -105,6 +112,7 @@ class JoinedSyncResult:
     account_data = attr.ib(type=List[JsonDict])
     unread_notifications = attr.ib(type=JsonDict)
     summary = attr.ib(type=Optional[JsonDict])
+    unread_count = attr.ib(type=int)
 
     def __nonzero__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -238,8 +246,8 @@ class SyncResult:
     __bool__ = __nonzero__  # python3
 
 
-class SyncHandler(object):
-    def __init__(self, hs):
+class SyncHandler:
+    def __init__(self, hs: "HomeServer"):
         self.hs_config = hs.config
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
@@ -285,6 +293,7 @@ class SyncHandler(object):
             timeout,
             full_state,
         )
+        logger.debug("Returning sync response for %s", user_id)
         return res
 
     async def _wait_for_sync_for_user(
@@ -390,7 +399,7 @@ class SyncHandler(object):
                 # result returned by the event source is poor form (it might cache
                 # the object)
                 room_id = event["room_id"]
-                event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
+                event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                 ephemeral_by_room.setdefault(room_id, []).append(event_copy)
 
             receipt_key = since_token.receipt_key if since_token else "0"
@@ -408,7 +417,7 @@ class SyncHandler(object):
             for event in receipts:
                 room_id = event["room_id"]
                 # exclude room id, as above
-                event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
+                event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                 ephemeral_by_room.setdefault(room_id, []).append(event_copy)
 
         return now_token, ephemeral_by_room
@@ -422,10 +431,6 @@ class SyncHandler(object):
         potential_recents: Optional[List[EventBase]] = None,
         newly_joined_room: bool = False,
     ) -> TimelineBatch:
-        """
-        Returns:
-            a Deferred TimelineBatch
-        """
         with Measure(self.clock, "load_filtered_recents"):
             timeline_limit = sync_config.filter_collection.timeline_limit()
             block_all_timeline = (
@@ -454,7 +459,7 @@ class SyncHandler(object):
                     current_state_ids_map = await self.state.get_current_state_ids(
                         room_id
                     )
-                    current_state_ids = frozenset(itervalues(current_state_ids_map))
+                    current_state_ids = frozenset(current_state_ids_map.values())
 
                 recents = await filter_events_for_client(
                     self.storage,
@@ -509,7 +514,7 @@ class SyncHandler(object):
                     current_state_ids_map = await self.state.get_current_state_ids(
                         room_id
                     )
-                    current_state_ids = frozenset(itervalues(current_state_ids_map))
+                    current_state_ids = frozenset(current_state_ids_map.values())
 
                 loaded_recents = await filter_events_for_client(
                     self.storage,
@@ -593,7 +598,7 @@ class SyncHandler(object):
         room_id: str,
         sync_config: SyncConfig,
         batch: TimelineBatch,
-        state: StateMap[EventBase],
+        state: MutableStateMap[EventBase],
         now_token: StreamToken,
     ) -> Optional[JsonDict]:
         """ Works out a room summary block for this room, summarising the number
@@ -715,9 +720,8 @@ class SyncHandler(object):
         ]
 
         missing_hero_state = await self.store.get_events(missing_hero_event_ids)
-        missing_hero_state = missing_hero_state.values()
 
-        for s in missing_hero_state:
+        for s in missing_hero_state.values():
             cache.set(s.state_key, s.event_id)
             state[(EventTypes.Member, s.state_key)] = s
 
@@ -741,7 +745,7 @@ class SyncHandler(object):
         since_token: Optional[StreamToken],
         now_token: StreamToken,
         full_state: bool,
-    ) -> StateMap[EventBase]:
+    ) -> MutableStateMap[EventBase]:
         """ Works out the difference in state between the start of the timeline
         and the previous sync.
 
@@ -909,7 +913,7 @@ class SyncHandler(object):
                     logger.debug("filtering state from %r...", state_ids)
                     state_ids = {
                         t: event_id
-                        for t, event_id in iteritems(state_ids)
+                        for t, event_id in state_ids.items()
                         if cache.get(t[1]) != event_id
                     }
                     logger.debug("...to %r", state_ids)
@@ -935,7 +939,7 @@ class SyncHandler(object):
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> Optional[Dict[str, str]]:
+    ) -> Dict[str, int]:
         with Measure(self.clock, "unread_notifs_for_room_id"):
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
@@ -943,15 +947,10 @@ class SyncHandler(object):
                 receipt_type="m.read",
             )
 
-            if last_unread_event_id:
-                notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
-                    room_id, sync_config.user.to_string(), last_unread_event_id
-                )
-                return notifs
-
-        # There is no new information in this period, so your notification
-        # count is whatever it was last time.
-        return None
+            notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
+                room_id, sync_config.user.to_string(), last_unread_event_id
+            )
+            return notifs
 
     async def generate_sync_result(
         self,
@@ -965,7 +964,7 @@ class SyncHandler(object):
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
         # Always use the `now_token` in `SyncResultBuilder`
-        now_token = await self.event_sources.get_current_token()
+        now_token = self.event_sources.get_current_token()
 
         logger.debug(
             "Calculating sync response for %r between %s and %s",
@@ -992,10 +991,14 @@ class SyncHandler(object):
             joined_room_ids=joined_room_ids,
         )
 
+        logger.debug("Fetching account data")
+
         account_data_by_room = await self._generate_sync_entry_for_account_data(
             sync_result_builder
         )
 
+        logger.debug("Fetching room data")
+
         res = await self._generate_sync_entry_for_rooms(
             sync_result_builder, account_data_by_room
         )
@@ -1006,10 +1009,12 @@ class SyncHandler(object):
             since_token is None and sync_config.filter_collection.blocks_all_presence()
         )
         if self.hs_config.use_presence and not block_all_presence_data:
+            logger.debug("Fetching presence data")
             await self._generate_sync_entry_for_presence(
                 sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
             )
 
+        logger.debug("Fetching to-device data")
         await self._generate_sync_entry_for_to_device(sync_result_builder)
 
         device_lists = await self._generate_sync_entry_for_device_list(
@@ -1020,6 +1025,7 @@ class SyncHandler(object):
             newly_left_users=newly_left_users,
         )
 
+        logger.debug("Fetching OTK data")
         device_id = sync_config.device_id
         one_time_key_counts = {}  # type: JsonDict
         if device_id:
@@ -1027,6 +1033,7 @@ class SyncHandler(object):
                 user_id, device_id
             )
 
+        logger.debug("Fetching group data")
         await self._generate_sync_entry_for_groups(sync_result_builder)
 
         # debug for https://github.com/matrix-org/synapse/issues/4422
@@ -1037,6 +1044,7 @@ class SyncHandler(object):
                     "Sync result for newly joined room %s: %r", room_id, joined_room
                 )
 
+        logger.debug("Sync response calculation complete")
         return SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
@@ -1409,8 +1417,9 @@ class SyncHandler(object):
         newly_joined_rooms = room_changes.newly_joined_rooms
         newly_left_rooms = room_changes.newly_left_rooms
 
-        def handle_room_entries(room_entry):
-            return self._generate_room_entry(
+        async def handle_room_entries(room_entry):
+            logger.debug("Generating room entry for %s", room_entry.room_id)
+            res = await self._generate_room_entry(
                 sync_result_builder,
                 ignored_users,
                 room_entry,
@@ -1419,6 +1428,8 @@ class SyncHandler(object):
                 account_data=account_data_by_room.get(room_entry.room_id, {}),
                 always_include=sync_result_builder.full_state,
             )
+            logger.debug("Generated room entry for %s", room_entry.room_id)
+            return res
 
         await concurrently_execute(handle_room_entries, room_entries, 10)
 
@@ -1430,7 +1441,7 @@ class SyncHandler(object):
         if since_token:
             for joined_sync in sync_result_builder.joined:
                 it = itertools.chain(
-                    joined_sync.timeline.events, itervalues(joined_sync.state)
+                    joined_sync.timeline.events, joined_sync.state.values()
                 )
                 for event in it:
                     if event.type == EventTypes.Member:
@@ -1505,7 +1516,7 @@ class SyncHandler(object):
         newly_left_rooms = []
         room_entries = []
         invited = []
-        for room_id, events in iteritems(mem_change_events_by_room_id):
+        for room_id, events in mem_change_events_by_room_id.items():
             logger.debug(
                 "Membership changes in %s: [%s]",
                 room_id,
@@ -1762,7 +1773,7 @@ class SyncHandler(object):
         ignored_users: Set[str],
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
-        tags: Optional[List[JsonDict]],
+        tags: Optional[Dict[str, Dict[str, Any]]],
         account_data: Dict[str, JsonDict],
         always_include: bool = False,
     ):
@@ -1878,7 +1889,7 @@ class SyncHandler(object):
             )
 
         if room_builder.rtype == "joined":
-            unread_notifications = {}  # type: Dict[str, str]
+            unread_notifications = {}  # type: Dict[str, int]
             room_sync = JoinedSyncResult(
                 room_id=room_id,
                 timeline=batch,
@@ -1887,14 +1898,16 @@ class SyncHandler(object):
                 account_data=account_data_events,
                 unread_notifications=unread_notifications,
                 summary=summary,
+                unread_count=0,
             )
 
             if room_sync or always_include:
                 notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                if notifs is not None:
-                    unread_notifications["notification_count"] = notifs["notify_count"]
-                    unread_notifications["highlight_count"] = notifs["highlight_count"]
+                unread_notifications["notification_count"] = notifs["notify_count"]
+                unread_notifications["highlight_count"] = notifs["highlight_count"]
+
+                room_sync.unread_count = notifs["unread_count"]
 
                 sync_result_builder.joined.append(room_sync)
 
@@ -1993,17 +2006,17 @@ def _calculate_state(
     event_id_to_key = {
         e: key
         for key, e in itertools.chain(
-            iteritems(timeline_contains),
-            iteritems(previous),
-            iteritems(timeline_start),
-            iteritems(current),
+            timeline_contains.items(),
+            previous.items(),
+            timeline_start.items(),
+            current.items(),
         )
     }
 
-    c_ids = set(itervalues(current))
-    ts_ids = set(itervalues(timeline_start))
-    p_ids = set(itervalues(previous))
-    tc_ids = set(itervalues(timeline_contains))
+    c_ids = set(current.values())
+    ts_ids = set(timeline_start.values())
+    p_ids = set(previous.values())
+    tc_ids = set(timeline_contains.values())
 
     # If we are lazyloading room members, we explicitly add the membership events
     # for the senders in the timeline into the state block returned by /sync,
@@ -2017,7 +2030,7 @@ def _calculate_state(
 
     if lazy_load_members:
         p_ids.difference_update(
-            e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member
+            e for t, e in timeline_start.items() if t[0] == EventTypes.Member
         )
 
     state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
@@ -2062,7 +2075,7 @@ class SyncResultBuilder:
 
 
 @attr.s
-class RoomSyncResultBuilder(object):
+class RoomSyncResultBuilder:
     """Stores information needed to create either a `JoinedSyncResult` or
     `ArchivedSyncResult`.
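
As a minimal sketch of the pattern this diff applies to JoinedSyncResult (illustrative names only, not Synapse code): the attrs class is left unfrozen so that a value computed later in the sync pipeline, the room's unread count, can be written back onto an already-constructed instance.

import attr

@attr.s(slots=True)  # intentionally not frozen: unread_count is filled in after construction
class RoomResultSketch:
    room_id = attr.ib(type=str)
    unread_count = attr.ib(type=int, default=0)

room = RoomResultSketch(room_id="!room:example.org")
# ... later, once push actions for the room have been counted ...
room.unread_count = 3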