summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/sync.py53
-rw-r--r--synapse/storage/state.py92
2 files changed, 73 insertions, 72 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ddeed27965..1d0f0058a2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -18,7 +18,7 @@ from ._base import BaseHandler
 from synapse.streams.config import PaginationConfig
 from synapse.api.constants import Membership, EventTypes
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import LoggingContext, preserve_fn
 from synapse.util.metrics import Measure
 
 from twisted.internet import defer
@@ -228,10 +228,14 @@ class SyncHandler(BaseHandler):
         invited = []
         archived = []
         deferreds = []
-        for event in room_list:
-            if event.membership == Membership.JOIN:
-                with PreserveLoggingContext(LoggingContext.current_context()):
-                    room_sync_deferred = self.full_state_sync_for_joined_room(
+
+        room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
+        for room_list_chunk in room_list_chunks:
+            for event in room_list_chunk:
+                if event.membership == Membership.JOIN:
+                    room_sync_deferred = preserve_fn(
+                        self.full_state_sync_for_joined_room
+                    )(
                         room_id=event.room_id,
                         sync_config=sync_config,
                         now_token=now_token,
@@ -240,20 +244,21 @@ class SyncHandler(BaseHandler):
                         tags_by_room=tags_by_room,
                         account_data_by_room=account_data_by_room,
                     )
-                room_sync_deferred.addCallback(joined.append)
-                deferreds.append(room_sync_deferred)
-            elif event.membership == Membership.INVITE:
-                invite = yield self.store.get_event(event.event_id)
-                invited.append(InvitedSyncResult(
-                    room_id=event.room_id,
-                    invite=invite,
-                ))
-            elif event.membership in (Membership.LEAVE, Membership.BAN):
-                leave_token = now_token.copy_and_replace(
-                    "room_key", "s%d" % (event.stream_ordering,)
-                )
-                with PreserveLoggingContext(LoggingContext.current_context()):
-                    room_sync_deferred = self.full_state_sync_for_archived_room(
+                    room_sync_deferred.addCallback(joined.append)
+                    deferreds.append(room_sync_deferred)
+                elif event.membership == Membership.INVITE:
+                    invite = yield self.store.get_event(event.event_id)
+                    invited.append(InvitedSyncResult(
+                        room_id=event.room_id,
+                        invite=invite,
+                    ))
+                elif event.membership in (Membership.LEAVE, Membership.BAN):
+                    leave_token = now_token.copy_and_replace(
+                        "room_key", "s%d" % (event.stream_ordering,)
+                    )
+                    room_sync_deferred = preserve_fn(
+                        self.full_state_sync_for_archived_room
+                    )(
                         sync_config=sync_config,
                         room_id=event.room_id,
                         leave_event_id=event.event_id,
@@ -262,12 +267,12 @@ class SyncHandler(BaseHandler):
                         tags_by_room=tags_by_room,
                         account_data_by_room=account_data_by_room,
                     )
-                room_sync_deferred.addCallback(archived.append)
-                deferreds.append(room_sync_deferred)
+                    room_sync_deferred.addCallback(archived.append)
+                    deferreds.append(room_sync_deferred)
 
-        yield defer.gatherResults(
-            deferreds, consumeErrors=True
-        ).addErrback(unwrapFirstError)
+            yield defer.gatherResults(
+                deferreds, consumeErrors=True
+            ).addErrback(unwrapFirstError)
 
         account_data_for_user = sync_config.filter_collection.filter_account_data(
             self.account_data_for_user(account_data)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6c32e8f7b3..372b540002 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -171,41 +171,43 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
-    def _get_state_groups_from_groups(self, groups_and_types):
+    def _get_state_groups_from_groups(self, groups, types):
         """Returns dictionary state_group -> state event ids
-
-        Args:
-            groups_and_types (list): list of 2-tuple (`group`, `types`)
         """
-        def f(txn):
-            results = {}
-            for group, types in groups_and_types:
-                if types is not None:
-                    where_clause = "AND (%s)" % (
-                        " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
-                    )
-                else:
-                    where_clause = ""
-
-                sql = (
-                    "SELECT event_id FROM state_groups_state WHERE"
-                    " state_group = ? %s"
-                ) % (where_clause,)
+        def f(txn, groups):
+            if types is not None:
+                where_clause = "AND (%s)" % (
+                    " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
+                )
+            else:
+                where_clause = ""
 
-                args = [group]
-                if types is not None:
-                    args.extend([i for typ in types for i in typ])
+            sql = (
+                "SELECT state_group, event_id FROM state_groups_state WHERE"
+                " state_group IN (%s) %s" % (
+                    ",".join("?" for _ in groups),
+                    where_clause,
+                )
+            )
 
-                txn.execute(sql, args)
+            args = list(groups)
+            if types is not None:
+                args.extend([i for typ in types for i in typ])
 
-                results[group] = [r[0] for r in txn.fetchall()]
+            txn.execute(sql, args)
+            rows = self.cursor_to_dict(txn)
 
+            results = {}
+            for row in rows:
+                results.setdefault(row["state_group"], []).append(row["event_id"])
             return results
 
-        return self.runInteraction(
-            "_get_state_groups_from_groups",
-            f,
-        )
+        chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
+        for chunk in chunks:
+            return self.runInteraction(
+                "_get_state_groups_from_groups",
+                f, chunk
+            )
 
     @defer.inlineCallbacks
     def get_state_for_events(self, event_ids, types):
@@ -264,26 +266,20 @@ class StateStore(SQLBaseStore):
         )
 
     @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
-                num_args=1)
+                num_args=1, inlineCallbacks=True)
     def _get_state_group_for_events(self, event_ids):
         """Returns mapping event_id -> state_group
         """
-        def f(txn):
-            results = {}
-            for event_id in event_ids:
-                results[event_id] = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="event_to_state_groups",
-                    keyvalues={
-                        "event_id": event_id,
-                    },
-                    retcol="state_group",
-                    allow_none=True,
-                )
-
-            return results
+        rows = yield self._simple_select_many_batch(
+            table="event_to_state_groups",
+            column="event_id",
+            iterable=event_ids,
+            keyvalues={},
+            retcols=("event_id", "state_group",),
+            desc="_get_state_group_for_events",
+        )
 
-        return self.runInteraction("_get_state_group_for_events", f)
+        defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
 
     def _get_some_state_from_cache(self, group, types):
         """Checks if group is in cache. See `_get_state_for_groups`
@@ -355,7 +351,7 @@ class StateStore(SQLBaseStore):
         all events are returned.
         """
         results = {}
-        missing_groups_and_types = []
+        missing_groups = []
         if types is not None:
             for group in set(groups):
                 state_dict, missing_types, got_all = self._get_some_state_from_cache(
@@ -364,7 +360,7 @@ class StateStore(SQLBaseStore):
                 results[group] = state_dict
 
                 if not got_all:
-                    missing_groups_and_types.append((group, missing_types))
+                    missing_groups.append(group)
         else:
             for group in set(groups):
                 state_dict, got_all = self._get_all_state_from_cache(
@@ -373,9 +369,9 @@ class StateStore(SQLBaseStore):
                 results[group] = state_dict
 
                 if not got_all:
-                    missing_groups_and_types.append((group, None))
+                    missing_groups.append(group)
 
-        if not missing_groups_and_types:
+        if not missing_groups:
             defer.returnValue({
                 group: {
                     type_tuple: event
@@ -389,7 +385,7 @@ class StateStore(SQLBaseStore):
         cache_seq_num = self._state_group_cache.sequence
 
         group_state_dict = yield self._get_state_groups_from_groups(
-            missing_groups_and_types
+            missing_groups, types
         )
 
         state_events = yield self._get_events(