summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/_base.py256
-rw-r--r--synapse/storage/schema/delta/19/event_index.sql19
-rw-r--r--synapse/storage/state.py21
-rw-r--r--synapse/storage/stream.py22
5 files changed, 287 insertions, 33 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7cb91a0be9..75af44d787 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 18
+SCHEMA_VERSION = 19
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 81052409b7..46a1c07460 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -28,6 +28,7 @@ from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
 import functools
+import itertools
 import simplejson as json
 import sys
 import time
@@ -867,28 +868,37 @@ class SQLBaseStore(object):
 
         return self.runInteraction("_simple_max_id", func)
 
+    @defer.inlineCallbacks
     def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False):
-        return self.runInteraction(
-            "_get_events", self._get_events_txn, event_ids,
-            check_redacted=check_redacted, get_prev_content=get_prev_content,
+                    get_prev_content=False, desc="_get_events"):
+    N = 50  # Only fetch 50 events at a time.
+
+        ds = [
+            self._fetch_events(
+                event_ids[i*N:(i+1)*N],
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+            )
+            for i in range(1 + len(event_ids) / N)
+        ]
+
+        res = yield defer.gatherResults(ds, consumeErrors=True)
+
+        defer.returnValue(
+            list(itertools.chain(*res))
         )
 
     def _get_events_txn(self, txn, event_ids, check_redacted=True,
                         get_prev_content=False):
-        if not event_ids:
-            return []
-
-        events = [
-            self._get_event_txn(
-                txn, event_id,
+        N = 50  # Only fetch 50 events at a time.
+        return list(itertools.chain(*[
+            self._fetch_events_txn(
+                txn, event_ids[i*N:(i+1)*N],
                 check_redacted=check_redacted,
-                get_prev_content=get_prev_content
+                get_prev_content=get_prev_content,
             )
-            for event_id in event_ids
-        ]
-
-        return [e for e in events if e]
+            for i in range(1 + len(event_ids) / N)
+        ]))
 
     def _invalidate_get_event_cache(self, event_id):
         for check_redacted in (False, True):
@@ -919,10 +929,10 @@ class SQLBaseStore(object):
             start_time = update_counter("event_cache", start_time)
 
         sql = (
-            "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
+            "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
             "FROM event_json as e "
+            "LEFT JOIN rejections as rej USING (event_id) "
             "LEFT JOIN redactions as r ON e.event_id = r.redacts "
-            "LEFT JOIN rejections as rej on rej.event_id = e.event_id  "
             "WHERE e.event_id = ? "
             "LIMIT 1 "
         )
@@ -951,6 +961,199 @@ class SQLBaseStore(object):
         else:
             return None
 
+    def _fetch_events_txn(self, txn, events, check_redacted=True,
+                          get_prev_content=False, allow_rejected=False):
+        if not events:
+            return []
+
+        event_map = {}
+
+        for event_id in events:
+            try:
+                ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
+
+                if allow_rejected or not ret.rejected_reason:
+                    event_map[event_id] = ret
+                else:
+                    event_map[event_id] = None
+            except KeyError:
+                pass
+
+        missing_events = [
+            e for e in events
+            if e not in event_map
+        ]
+
+        if missing_events:
+            sql = (
+                "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
+                " FROM event_json as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE e.event_id IN (%s)"
+            ) % (",".join(["?"]*len(missing_events)),)
+
+            txn.execute(sql, missing_events)
+            rows = txn.fetchall()
+
+            res = [
+                self._get_event_from_row_txn(
+                    txn, row[0], row[1], row[2],
+                    check_redacted=check_redacted,
+                    get_prev_content=get_prev_content,
+                    rejected_reason=row[3],
+                )
+                for row in rows
+            ]
+
+            event_map.update({
+                e.event_id: e
+                for e in res if e
+            })
+
+            for e in res:
+                self._get_event_cache.prefill(
+                    e.event_id, check_redacted, get_prev_content, e
+                )
+
+        return [
+            event_map[e_id] for e_id in events
+            if e_id in event_map and event_map[e_id]
+        ]
+
+    @defer.inlineCallbacks
+    def _fetch_events(self, events, check_redacted=True,
+                      get_prev_content=False, allow_rejected=False):
+        if not events:
+            defer.returnValue([])
+
+        event_map = {}
+
+        for event_id in events:
+            try:
+                ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
+
+                if allow_rejected or not ret.rejected_reason:
+                    event_map[event_id] = ret
+                else:
+                    event_map[event_id] = None
+            except KeyError:
+                pass
+
+        missing_events = [
+            e for e in events
+            if e not in event_map
+        ]
+
+        if missing_events:
+            sql = (
+                "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
+                " FROM event_json as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE e.event_id IN (%s)"
+            ) % (",".join(["?"]*len(missing_events)),)
+
+            rows = yield self._execute(
+                "_fetch_events",
+                None,
+                sql,
+                *missing_events
+            )
+
+            res_ds = [
+                self._get_event_from_row(
+                    row[0], row[1], row[2],
+                    check_redacted=check_redacted,
+                    get_prev_content=get_prev_content,
+                    rejected_reason=row[3],
+                )
+                for row in rows
+            ]
+
+            res = yield defer.gatherResults(res_ds, consumeErrors=True)
+
+            event_map.update({
+                e.event_id: e
+                for e in res if e
+            })
+
+            for e in res:
+                self._get_event_cache.prefill(
+                    e.event_id, check_redacted, get_prev_content, e
+                )
+
+        defer.returnValue([
+            event_map[e_id] for e_id in events
+            if e_id in event_map and event_map[e_id]
+        ])
+
+    @defer.inlineCallbacks
+    def _get_event_from_row(self, internal_metadata, js, redacted,
+                            check_redacted=True, get_prev_content=False,
+                            rejected_reason=None):
+
+        start_time = time.time() * 1000
+
+        def update_counter(desc, last_time):
+            curr_time = self._get_event_counters.update(desc, last_time)
+            sql_getevents_timer.inc_by(curr_time - last_time, desc)
+            return curr_time
+
+        d = json.loads(js)
+        start_time = update_counter("decode_json", start_time)
+
+        internal_metadata = json.loads(internal_metadata)
+        start_time = update_counter("decode_internal", start_time)
+
+        if rejected_reason:
+            rejected_reason = yield self._simple_select_one_onecol(
+                desc="_get_event_from_row",
+                table="rejections",
+                keyvalues={"event_id": rejected_reason},
+                retcol="reason",
+            )
+
+        ev = FrozenEvent(
+            d,
+            internal_metadata_dict=internal_metadata,
+            rejected_reason=rejected_reason,
+        )
+        start_time = update_counter("build_frozen_event", start_time)
+
+        if check_redacted and redacted:
+            ev = prune_event(ev)
+
+            redaction_id = yield self._simple_select_one_onecol(
+                desc="_get_event_from_row",
+                table="redactions",
+                keyvalues={"redacts": ev.event_id},
+                retcol="event_id",
+            )
+
+            ev.unsigned["redacted_by"] = redaction_id
+            # Get the redaction event.
+
+            because = yield self.get_event_txn(
+                redaction_id,
+                check_redacted=False
+            )
+
+            if because:
+                ev.unsigned["redacted_because"] = because
+            start_time = update_counter("redact_event", start_time)
+
+        if get_prev_content and "replaces_state" in ev.unsigned:
+            prev = yield self.get_event(
+                ev.unsigned["replaces_state"],
+                get_prev_content=False,
+            )
+            if prev:
+                ev.unsigned["prev_content"] = prev.get_dict()["content"]
+            start_time = update_counter("get_prev_content", start_time)
+
+        defer.returnValue(ev)
+
     def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
                                 check_redacted=True, get_prev_content=False,
                                 rejected_reason=None):
@@ -968,6 +1171,14 @@ class SQLBaseStore(object):
         internal_metadata = json.loads(internal_metadata)
         start_time = update_counter("decode_internal", start_time)
 
+        if rejected_reason:
+            rejected_reason = self._simple_select_one_onecol_txn(
+                txn,
+                table="rejections",
+                keyvalues={"event_id": rejected_reason},
+                retcol="reason",
+            )
+
         ev = FrozenEvent(
             d,
             internal_metadata_dict=internal_metadata,
@@ -978,12 +1189,19 @@ class SQLBaseStore(object):
         if check_redacted and redacted:
             ev = prune_event(ev)
 
-            ev.unsigned["redacted_by"] = redacted
+            redaction_id = self._simple_select_one_onecol_txn(
+                txn,
+                table="redactions",
+                keyvalues={"redacts": ev.event_id},
+                retcol="event_id",
+            )
+
+            ev.unsigned["redacted_by"] = redaction_id
             # Get the redaction event.
 
             because = self._get_event_txn(
                 txn,
-                redacted,
+                redaction_id,
                 check_redacted=False
             )
 
diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/schema/delta/19/event_index.sql
new file mode 100644
index 0000000000..f3792817bb
--- /dev/null
+++ b/synapse/storage/schema/delta/19/event_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE INDEX events_order_topo_stream_room ON events(
+    topological_ordering, stream_ordering, room_id
+);
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dbc0e49c1f..483b316e9f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -43,6 +43,7 @@ class StateStore(SQLBaseStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
+    @defer.inlineCallbacks
     def get_state_groups(self, event_ids):
         """ Get the state groups for the given list of event_ids
 
@@ -71,17 +72,31 @@ class StateStore(SQLBaseStore):
                     retcol="event_id",
                 )
 
-                state = self._get_events_txn(txn, state_ids)
+                # state = self._get_events_txn(txn, state_ids)
 
-                res[group] = state
+                res[group] = state_ids
 
             return res
 
-        return self.runInteraction(
+        states = yield self.runInteraction(
             "get_state_groups",
             f,
         )
 
+        @defer.inlineCallbacks
+        def c(vals):
+            vals[:] = yield self._fetch_events(vals, get_prev_content=False)
+
+        yield defer.gatherResults(
+            [
+                c(vals)
+                for vals in states.values()
+            ],
+            consumeErrors=True,
+        )
+
+        defer.returnValue(states)
+
     def _store_state_groups_txn(self, txn, event, context):
         if context.current_state is None:
             return
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 8045e17fd7..db9c2f0389 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -298,6 +298,7 @@ class StreamStore(SQLBaseStore):
 
         return self.runInteraction("paginate_room_events", f)
 
+    @defer.inlineCallbacks
     def get_recent_events_for_room(self, room_id, limit, end_token,
                                    with_feedback=False, from_token=None):
         # TODO (erikj): Handle compressed feedback
@@ -349,20 +350,21 @@ class StreamStore(SQLBaseStore):
             else:
                 token = (str(end_token), str(end_token))
 
-            events = self._get_events_txn(
-                txn,
-                [r["event_id"] for r in rows],
-                get_prev_content=True
-            )
+            return rows, token
 
-            self._set_before_and_after(events, rows)
-
-            return events, token
-
-        return self.runInteraction(
+        rows, token = yield self.runInteraction(
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )
 
+        events = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
+
+        self._set_before_and_after(events, rows)
+
+        defer.returnValue((events, token))
+
     @defer.inlineCallbacks
     def get_room_events_max_id(self, direction='f'):
         token = yield self._stream_id_gen.get_max_token(self)