summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/events.py167
-rw-r--r--synapse/util/__init__.py8
3 files changed, 139 insertions, 40 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a8f8989e38..c20ff3a572 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -299,6 +299,10 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
+        self._event_fetch_lock = threading.Lock()
+        self._event_fetch_list = []
+        self._event_fetch_ongoing = False
+
         self.database_engine = hs.database_engine
 
         self._stream_id_gen = StreamIdGenerator()
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 0aa4e0d445..be88328ce5 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -15,7 +15,7 @@
 
 from _base import SQLBaseStore, _RollbackButIsFineException
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
@@ -89,18 +89,17 @@ class EventsStore(SQLBaseStore):
         Returns:
             Deferred : A FrozenEvent.
         """
-        event = yield self.runInteraction(
-            "get_event", self._get_event_txn,
-            event_id,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
+        events = yield self._get_events(
+            [event_id],
+            check_redacted=True,
+            get_prev_content=False,
+            allow_rejected=False,
         )
 
-        if not event and not allow_none:
+        if not events and not allow_none:
             raise RuntimeError("Could not find event %s" % (event_id,))
 
-        defer.returnValue(event)
+        defer.returnValue(events[0] if events else None)
 
     @log_function
     def _persist_event_txn(self, txn, event, context, backfilled,
@@ -420,13 +419,21 @@ class EventsStore(SQLBaseStore):
                 if e_id in event_map and event_map[e_id]
             ])
 
-        missing_events = yield self._fetch_events(
-            txn,
-            missing_events_ids,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
-        )
+        if not txn:
+            missing_events = yield self._enqueue_events(
+                missing_events_ids,
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+                allow_rejected=allow_rejected,
+            )
+        else:
+            missing_events = self._fetch_events_txn(
+                txn,
+                missing_events_ids,
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+                allow_rejected=allow_rejected,
+            )
 
         event_map.update(missing_events)
 
@@ -492,11 +499,82 @@ class EventsStore(SQLBaseStore):
         ))
 
     @defer.inlineCallbacks
-    def _fetch_events(self, txn, events, check_redacted=True,
-                      get_prev_content=False, allow_rejected=False):
+    def _enqueue_events(self, events, check_redacted=True,
+                        get_prev_content=False, allow_rejected=False):
         if not events:
             defer.returnValue({})
 
+        def do_fetch(txn):
+            event_list = []
+            try:
+                with self._event_fetch_lock:
+                    event_list = self._event_fetch_list
+                    self._event_fetch_list = []
+
+                    if not event_list:
+                        return
+
+                event_id_lists = zip(*event_list)[0]
+                event_ids = [
+                    item for sublist in event_id_lists for item in sublist
+                ]
+                rows = self._fetch_event_rows(txn, event_ids)
+
+                row_dict = {
+                    r["event_id"]: r
+                    for r in rows
+                }
+
+                for ids, d in event_list:
+                    d.callback(
+                        [
+                            row_dict[i] for i in ids
+                            if i in row_dict
+                        ]
+                    )
+            except Exception as e:
+                for _, d in event_list:
+                    try:
+                        reactor.callFromThread(d.errback, e)
+                    except:
+                        pass
+            finally:
+                with self._event_fetch_lock:
+                    self._event_fetch_ongoing = False
+
+        def cb(rows):
+            return defer.gatherResults([
+                self._get_event_from_row(
+                    None,
+                    row["internal_metadata"], row["json"], row["redacts"],
+                    check_redacted=check_redacted,
+                    get_prev_content=get_prev_content,
+                    rejected_reason=row["rejects"],
+                )
+                for row in rows
+            ])
+
+        d = defer.Deferred()
+        d.addCallback(cb)
+        with self._event_fetch_lock:
+            self._event_fetch_list.append(
+                (events, d)
+            )
+
+            if not self._event_fetch_ongoing:
+                self.runInteraction(
+                    "do_fetch",
+                    do_fetch
+                )
+
+        res = yield d
+
+        defer.returnValue({
+            e.event_id: e
+            for e in res if e
+        })
+
+    def _fetch_event_rows(self, txn, events):
         rows = []
         N = 200
         for i in range(1 + len(events) / N):
@@ -505,43 +583,56 @@ class EventsStore(SQLBaseStore):
                 break
 
             sql = (
-                "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
+                "SELECT "
+                " e.event_id as event_id, "
+                " e.internal_metadata,"
+                " e.json,"
+                " r.redacts as redacts,"
+                " rej.event_id as rejects "
                 " FROM event_json as e"
                 " LEFT JOIN rejections as rej USING (event_id)"
                 " LEFT JOIN redactions as r ON e.event_id = r.redacts"
                 " WHERE e.event_id IN (%s)"
             ) % (",".join(["?"]*len(evs)),)
 
-            if txn:
-                txn.execute(sql, evs)
-                rows.extend(txn.fetchall())
-            else:
-                res = yield self._execute("_fetch_events", None, sql, *evs)
-                rows.extend(res)
+            txn.execute(sql, evs)
+            rows.extend(self.cursor_to_dict(txn))
+
+        return rows
+
+    @defer.inlineCallbacks
+    def _fetch_events(self, txn, events, check_redacted=True,
+                     get_prev_content=False, allow_rejected=False):
+        if not events:
+            defer.returnValue({})
+
+        if txn:
+            rows = self._fetch_event_rows(
+                txn, events,
+            )
+        else:
+            rows = yield self.runInteraction(
+                self._fetch_event_rows,
+                events,
+            )
 
         res = yield defer.gatherResults(
             [
                 defer.maybeDeferred(
                     self._get_event_from_row,
                     txn,
-                    row[0], row[1], row[2],
+                    row["internal_metadata"], row["json"], row["redacts"],
                     check_redacted=check_redacted,
                     get_prev_content=get_prev_content,
-                    rejected_reason=row[3],
+                    rejected_reason=row["rejects"],
                 )
                 for row in rows
-            ],
-            consumeErrors=True,
+            ]
         )
 
-        for e in res:
-            self._get_event_cache.prefill(
-                e.event_id, check_redacted, get_prev_content, e
-            )
-
         defer.returnValue({
-            e.event_id: e
-            for e in res if e
+            r.event_id: r
+            for r in res
         })
 
     @defer.inlineCallbacks
@@ -611,6 +702,10 @@ class EventsStore(SQLBaseStore):
             if prev:
                 ev.unsigned["prev_content"] = prev.get_dict()["content"]
 
+        self._get_event_cache.prefill(
+            ev.event_id, check_redacted, get_prev_content, ev
+        )
+
         defer.returnValue(ev)
 
     def _parse_events(self, rows):
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index b9afb3364d..260714ccc2 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -80,16 +80,16 @@ class Clock(object):
     def stop_looping_call(self, loop):
         loop.stop()
 
-    def call_later(self, delay, callback):
+    def call_later(self, delay, callback, *args, **kwargs):
         current_context = LoggingContext.current_context()
 
-        def wrapped_callback():
+        def wrapped_callback(*args, **kwargs):
             with PreserveLoggingContext():
                 LoggingContext.thread_local.current_context = current_context
-                callback()
+                callback(*args, **kwargs)
 
         with PreserveLoggingContext():
-            return reactor.callLater(delay, wrapped_callback)
+            return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
 
     def cancel_call_later(self, timer):
         timer.cancel()