summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py2
-rw-r--r--synapse/storage/_base.py331
-rw-r--r--synapse/storage/engines/postgres.py2
-rw-r--r--synapse/storage/engines/sqlite3.py2
-rw-r--r--synapse/storage/event_federation.py176
-rw-r--r--synapse/storage/events.py487
-rw-r--r--synapse/storage/presence.py35
-rw-r--r--synapse/storage/push_rule.py32
-rw-r--r--synapse/storage/roommember.py39
-rw-r--r--synapse/storage/schema/delta/19/event_index.sql19
-rw-r--r--synapse/storage/state.py28
-rw-r--r--synapse/storage/stream.py181
-rw-r--r--synapse/storage/util/id_generators.py12
13 files changed, 873 insertions, 473 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7cb91a0be9..75af44d787 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 18
+SCHEMA_VERSION = 19
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c8c76e58fe..39884c2afe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,10 +15,8 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.lrucache import LruCache
 import synapse.metrics
 
@@ -27,8 +25,8 @@ from util.id_generators import IdGenerator, StreamIdGenerator
 from twisted.internet import defer
 
 from collections import namedtuple, OrderedDict
+
 import functools
-import simplejson as json
 import sys
 import time
 import threading
@@ -48,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
 
 sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
 sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
-sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
 
 caches_by_name = {}
 cache_counter = metrics.register_cache(
@@ -307,6 +304,12 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
+        self._event_fetch_lock = threading.Condition()
+        self._event_fetch_list = []
+        self._event_fetch_ongoing = 0
+
+        self._pending_ds = []
+
         self.database_engine = hs.database_engine
 
         self._stream_id_gen = StreamIdGenerator()
@@ -315,6 +318,7 @@ class SQLBaseStore(object):
         self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
         self._pushers_id_gen = IdGenerator("pushers", "id", self)
         self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
 
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
@@ -345,6 +349,75 @@ class SQLBaseStore(object):
 
         self._clock.looping_call(loop, 10000)
 
+    def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
+        start = time.time() * 1000
+        txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so lets stop it from
+        # growing really large.
+        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+        name = "%s-%x" % (desc, txn_id, )
+
+        transaction_logger.debug("[TXN START] {%s}", name)
+
+        try:
+            i = 0
+            N = 5
+            while True:
+                try:
+                    txn = conn.cursor()
+                    txn = LoggingTransaction(
+                        txn, name, self.database_engine, after_callbacks
+                    )
+                    r = func(txn, *args, **kwargs)
+                    conn.commit()
+                    return r
+                except self.database_engine.module.OperationalError as e:
+                    # This can happen if the database disappears mid
+                    # transaction.
+                    logger.warn(
+                        "[TXN OPERROR] {%s} %s %d/%d",
+                        name, e, i, N
+                    )
+                    if i < N:
+                        i += 1
+                        try:
+                            conn.rollback()
+                        except self.database_engine.module.Error as e1:
+                            logger.warn(
+                                "[TXN EROLL] {%s} %s",
+                                name, e1,
+                            )
+                        continue
+                    raise
+                except self.database_engine.module.DatabaseError as e:
+                    if self.database_engine.is_deadlock(e):
+                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        if i < N:
+                            i += 1
+                            try:
+                                conn.rollback()
+                            except self.database_engine.module.Error as e1:
+                                logger.warn(
+                                    "[TXN EROLL] {%s} %s",
+                                    name, e1,
+                                )
+                            continue
+                    raise
+        except Exception as e:
+            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            raise
+        finally:
+            end = time.time() * 1000
+            duration = end - start
+
+            transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+
+            self._current_txn_total_time += duration
+            self._txn_perf_counters.update(desc, start, end)
+            sql_txn_timer.inc_by(duration, desc)
+
     @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
         """Wraps the .runInteraction() method on the underlying db_pool."""
@@ -356,82 +429,50 @@ class SQLBaseStore(object):
 
         def inner_func(conn, *args, **kwargs):
             with LoggingContext("runInteraction") as context:
+                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
+
                 if self.database_engine.is_connection_closed(conn):
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
                 current_context.copy_to(context)
-                start = time.time() * 1000
-                txn_id = self._TXN_ID
+                return self._new_transaction(
+                    conn, desc, after_callbacks, func, *args, **kwargs
+                )
 
-                # We don't really need these to be unique, so lets stop it from
-                # growing really large.
-                self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
 
-                name = "%s-%x" % (desc, txn_id, )
+        for after_callback, after_args in after_callbacks:
+            after_callback(*after_args)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def runWithConnection(self, func, *args, **kwargs):
+        """Wraps the .runWithConnection() method on the underlying db_pool."""
+        current_context = LoggingContext.current_context()
+
+        start_time = time.time() * 1000
 
+        def inner_func(conn, *args, **kwargs):
+            with LoggingContext("runWithConnection") as context:
                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
-                transaction_logger.debug("[TXN START] {%s}", name)
-                try:
-                    i = 0
-                    N = 5
-                    while True:
-                        try:
-                            txn = conn.cursor()
-                            txn = LoggingTransaction(
-                                txn, name, self.database_engine, after_callbacks
-                            )
-                            return func(txn, *args, **kwargs)
-                        except self.database_engine.module.OperationalError as e:
-                            # This can happen if the database disappears mid
-                            # transaction.
-                            logger.warn(
-                                "[TXN OPERROR] {%s} %s %d/%d",
-                                name, e, i, N
-                            )
-                            if i < N:
-                                i += 1
-                                try:
-                                    conn.rollback()
-                                except self.database_engine.module.Error as e1:
-                                    logger.warn(
-                                        "[TXN EROLL] {%s} %s",
-                                        name, e1,
-                                    )
-                                continue
-                        except self.database_engine.module.DatabaseError as e:
-                            if self.database_engine.is_deadlock(e):
-                                logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
-                                if i < N:
-                                    i += 1
-                                    try:
-                                        conn.rollback()
-                                    except self.database_engine.module.Error as e1:
-                                        logger.warn(
-                                            "[TXN EROLL] {%s} %s",
-                                            name, e1,
-                                        )
-                                    continue
-                            raise
-                except Exception as e:
-                    logger.debug("[TXN FAIL] {%s} %s", name, e)
-                    raise
-                finally:
-                    end = time.time() * 1000
-                    duration = end - start
 
-                    transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+                if self.database_engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
 
-                    self._current_txn_total_time += duration
-                    self._txn_perf_counters.update(desc, start, end)
-                    sql_txn_timer.inc_by(duration, desc)
+                current_context.copy_to(context)
+
+                return func(conn, *args, **kwargs)
+
+        result = yield preserve_context_over_fn(
+            self._db_pool.runWithConnection,
+            inner_func, *args, **kwargs
+        )
 
-        with PreserveLoggingContext():
-            result = yield self._db_pool.runWithConnection(
-                inner_func, *args, **kwargs
-            )
-        for after_callback, after_args in after_callbacks:
-            after_callback(*after_args)
         defer.returnValue(result)
 
     def cursor_to_dict(self, cursor):
@@ -871,158 +912,6 @@ class SQLBaseStore(object):
 
         return self.runInteraction("_simple_max_id", func)
 
-    def _get_events(self, event_ids, check_redacted=True,
-                    get_prev_content=False):
-        return self.runInteraction(
-            "_get_events", self._get_events_txn, event_ids,
-            check_redacted=check_redacted, get_prev_content=get_prev_content,
-        )
-
-    def _get_events_txn(self, txn, event_ids, check_redacted=True,
-                        get_prev_content=False):
-        if not event_ids:
-            return []
-
-        events = [
-            self._get_event_txn(
-                txn, event_id,
-                check_redacted=check_redacted,
-                get_prev_content=get_prev_content
-            )
-            for event_id in event_ids
-        ]
-
-        return [e for e in events if e]
-
-    def _invalidate_get_event_cache(self, event_id):
-        for check_redacted in (False, True):
-            for get_prev_content in (False, True):
-                self._get_event_cache.invalidate(event_id, check_redacted,
-                                                 get_prev_content)
-
-    def _get_event_txn(self, txn, event_id, check_redacted=True,
-                       get_prev_content=False, allow_rejected=False):
-
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
-        try:
-            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
-
-            if allow_rejected or not ret.rejected_reason:
-                return ret
-            else:
-                return None
-        except KeyError:
-            pass
-        finally:
-            start_time = update_counter("event_cache", start_time)
-
-        sql = (
-            "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
-            "FROM event_json as e "
-            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
-            "LEFT JOIN rejections as rej on rej.event_id = e.event_id  "
-            "WHERE e.event_id = ? "
-            "LIMIT 1 "
-        )
-
-        txn.execute(sql, (event_id,))
-
-        res = txn.fetchone()
-
-        if not res:
-            return None
-
-        internal_metadata, js, redacted, rejected_reason = res
-
-        start_time = update_counter("select_event", start_time)
-
-        result = self._get_event_from_row_txn(
-            txn, internal_metadata, js, redacted,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            rejected_reason=rejected_reason,
-        )
-        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
-
-        if allow_rejected or not rejected_reason:
-            return result
-        else:
-            return None
-
-    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
-                                check_redacted=True, get_prev_content=False,
-                                rejected_reason=None):
-
-        start_time = time.time() * 1000
-
-        def update_counter(desc, last_time):
-            curr_time = self._get_event_counters.update(desc, last_time)
-            sql_getevents_timer.inc_by(curr_time - last_time, desc)
-            return curr_time
-
-        d = json.loads(js)
-        start_time = update_counter("decode_json", start_time)
-
-        internal_metadata = json.loads(internal_metadata)
-        start_time = update_counter("decode_internal", start_time)
-
-        ev = FrozenEvent(
-            d,
-            internal_metadata_dict=internal_metadata,
-            rejected_reason=rejected_reason,
-        )
-        start_time = update_counter("build_frozen_event", start_time)
-
-        if check_redacted and redacted:
-            ev = prune_event(ev)
-
-            ev.unsigned["redacted_by"] = redacted
-            # Get the redaction event.
-
-            because = self._get_event_txn(
-                txn,
-                redacted,
-                check_redacted=False
-            )
-
-            if because:
-                ev.unsigned["redacted_because"] = because
-            start_time = update_counter("redact_event", start_time)
-
-        if get_prev_content and "replaces_state" in ev.unsigned:
-            prev = self._get_event_txn(
-                txn,
-                ev.unsigned["replaces_state"],
-                get_prev_content=False,
-            )
-            if prev:
-                ev.unsigned["prev_content"] = prev.get_dict()["content"]
-            start_time = update_counter("get_prev_content", start_time)
-
-        return ev
-
-    def _parse_events(self, rows):
-        return self.runInteraction(
-            "_parse_events", self._parse_events_txn, rows
-        )
-
-    def _parse_events_txn(self, txn, rows):
-        event_ids = [r["event_id"] for r in rows]
-
-        return self._get_events_txn(txn, event_ids)
-
-    def _has_been_redacted_txn(self, txn, event):
-        sql = "SELECT event_id FROM redactions WHERE redacts = ?"
-        txn.execute(sql, (event.event_id,))
-        result = txn.fetchone()
-        return result[0] if result else None
-
     def get_next_stream_id(self):
         with self._next_stream_id_lock:
             i = self._next_stream_id
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a323028546..4a855ffd56 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
 
 
 class PostgresEngine(object):
+    single_threaded = False
+
     def __init__(self, database_module):
         self.module = database_module
         self.module.extensions.register_type(self.module.extensions.UNICODE)
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index ff13d8006a..d18e2808d1 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
 
 
 class Sqlite3Engine(object):
+    single_threaded = True
+
     def __init__(self, database_module):
         self.module = database_module
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 74b4e23590..1ba073884b 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -13,10 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 from ._base import SQLBaseStore, cached
 from syutil.base64util import encode_base64
 
 import logging
+from Queue import PriorityQueue, Empty
 
 
 logger = logging.getLogger(__name__)
@@ -33,16 +36,7 @@ class EventFederationStore(SQLBaseStore):
     """
 
     def get_auth_chain(self, event_ids):
-        return self.runInteraction(
-            "get_auth_chain",
-            self._get_auth_chain_txn,
-            event_ids
-        )
-
-    def _get_auth_chain_txn(self, txn, event_ids):
-        results = self._get_auth_chain_ids_txn(txn, event_ids)
-
-        return self._get_events_txn(txn, results)
+        return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
 
     def get_auth_chain_ids(self, event_ids):
         return self.runInteraction(
@@ -79,6 +73,28 @@ class EventFederationStore(SQLBaseStore):
             room_id,
         )
 
+    def get_oldest_events_with_depth_in_room(self, room_id):
+        return self.runInteraction(
+            "get_oldest_events_with_depth_in_room",
+            self.get_oldest_events_with_depth_in_room_txn,
+            room_id,
+        )
+
+    def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
+        sql = (
+            "SELECT b.event_id, MAX(e.depth) FROM events as e"
+            " INNER JOIN event_edges as g"
+            " ON g.event_id = e.event_id AND g.room_id = e.room_id"
+            " INNER JOIN event_backward_extremities as b"
+            " ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
+            " WHERE b.room_id = ? AND g.is_state is ?"
+            " GROUP BY b.event_id"
+        )
+
+        txn.execute(sql, (room_id, False,))
+
+        return dict(txn.fetchall())
+
     def _get_oldest_events_in_room_txn(self, txn, room_id):
         return self._simple_select_onecol_txn(
             txn,
@@ -247,11 +263,13 @@ class EventFederationStore(SQLBaseStore):
         do_insert = depth < min_depth if min_depth else True
 
         if do_insert:
-            self._simple_insert_txn(
+            self._simple_upsert_txn(
                 txn,
                 table="room_depth",
-                values={
+                keyvalues={
                     "room_id": room_id,
+                },
+                values={
                     "min_depth": depth,
                 },
             )
@@ -306,31 +324,28 @@ class EventFederationStore(SQLBaseStore):
 
                 txn.execute(query, (event_id, room_id))
 
-            # Insert all the prev_events as a backwards thing, they'll get
-            # deleted in a second if they're incorrect anyway.
-            self._simple_insert_many_txn(
-                txn,
-                table="event_backward_extremities",
-                values=[
-                    {
-                        "event_id": e_id,
-                        "room_id": room_id,
-                    }
-                    for e_id, _ in prev_events
-                ],
+            query = (
+                "INSERT INTO event_backward_extremities (event_id, room_id)"
+                " SELECT ?, ? WHERE NOT EXISTS ("
+                " SELECT 1 FROM event_backward_extremities"
+                " WHERE event_id = ? AND room_id = ?"
+                " )"
+                " AND NOT EXISTS ("
+                " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+                " AND outlier = ?"
+                " )"
             )
 
-            # Also delete from the backwards extremities table all ones that
-            # reference events that we have already seen
+            txn.executemany(query, [
+                (e_id, room_id, e_id, room_id, e_id, room_id, False)
+                for e_id, _ in prev_events
+            ])
+
             query = (
-                "DELETE FROM event_backward_extremities WHERE EXISTS ("
-                "SELECT 1 FROM events "
-                "WHERE "
-                "event_backward_extremities.event_id = events.event_id "
-                "AND not events.outlier "
-                ")"
+                "DELETE FROM event_backward_extremities"
+                " WHERE event_id = ? AND room_id = ?"
             )
-            txn.execute(query)
+            txn.execute(query, (event_id, room_id))
 
             txn.call_after(
                 self.get_latest_event_ids_in_room.invalidate, room_id
@@ -349,6 +364,10 @@ class EventFederationStore(SQLBaseStore):
         return self.runInteraction(
             "get_backfill_events",
             self._get_backfill_events, room_id, event_list, limit
+        ).addCallback(
+            self._get_events
+        ).addCallback(
+            lambda l: sorted(l, key=lambda e: -e.depth)
         )
 
     def _get_backfill_events(self, txn, room_id, event_list, limit):
@@ -357,54 +376,75 @@ class EventFederationStore(SQLBaseStore):
             room_id, repr(event_list), limit
         )
 
-        event_results = event_list
+        event_results = set()
 
-        front = event_list
+        # We want to make sure that we do a breadth-first, "depth" ordered
+        # search.
 
         query = (
-            "SELECT prev_event_id FROM event_edges "
-            "WHERE room_id = ? AND event_id = ? "
-            "LIMIT ?"
+            "SELECT depth, prev_event_id FROM event_edges"
+            " INNER JOIN events"
+            " ON prev_event_id = events.event_id"
+            " AND event_edges.room_id = events.room_id"
+            " WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
+            " AND event_edges.is_state = ?"
+            " LIMIT ?"
         )
 
-        # We iterate through all event_ids in `front` to select their previous
-        # events. These are dumped in `new_front`.
-        # We continue until we reach the limit *or* new_front is empty (i.e.,
-        # we've run out of things to select
-        while front and len(event_results) < limit:
+        queue = PriorityQueue()
 
-            new_front = []
-            for event_id in front:
-                logger.debug(
-                    "_backfill_interaction: id=%s",
-                    event_id
-                )
+        for event_id in event_list:
+            depth = self._simple_select_one_onecol_txn(
+                txn,
+                table="events",
+                keyvalues={
+                    "event_id": event_id,
+                },
+                retcol="depth"
+            )
 
-                txn.execute(
-                    query,
-                    (room_id, event_id, limit - len(event_results))
-                )
+            queue.put((-depth, event_id))
 
-                for row in txn.fetchall():
-                    logger.debug(
-                        "_backfill_interaction: got id=%s",
-                        *row
-                    )
-                    new_front.append(row[0])
+        while not queue.empty() and len(event_results) < limit:
+            try:
+                _, event_id = queue.get_nowait()
+            except Empty:
+                break
 
-            front = new_front
-            event_results += new_front
+            if event_id in event_results:
+                continue
+
+            event_results.add(event_id)
+
+            txn.execute(
+                query,
+                (room_id, event_id, False, limit - len(event_results))
+            )
+
+            for row in txn.fetchall():
+                if row[1] not in event_results:
+                    queue.put((-row[0], row[1]))
 
-        return self._get_events_txn(txn, event_results)
+        return event_results
 
+    @defer.inlineCallbacks
     def get_missing_events(self, room_id, earliest_events, latest_events,
                            limit, min_depth):
-        return self.runInteraction(
+        ids = yield self.runInteraction(
             "get_missing_events",
             self._get_missing_events,
             room_id, earliest_events, latest_events, limit, min_depth
         )
 
+        events = yield self._get_events(ids)
+
+        events = sorted(
+            [ev for ev in events if ev.depth >= min_depth],
+            key=lambda e: e.depth,
+        )
+
+        defer.returnValue(events[:limit])
+
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
                             limit, min_depth):
 
@@ -436,14 +476,7 @@ class EventFederationStore(SQLBaseStore):
             front = new_front
             event_results |= new_front
 
-        events = self._get_events_txn(txn, event_results)
-
-        events = sorted(
-            [ev for ev in events if ev.depth >= min_depth],
-            key=lambda e: e.depth,
-        )
-
-        return events[:limit]
+        return event_results
 
     def clean_room_for_join(self, room_id):
         return self.runInteraction(
@@ -456,3 +489,4 @@ class EventFederationStore(SQLBaseStore):
         query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
 
         txn.execute(query, (room_id,))
+        txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 1304219e86..d2a010bd88 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -15,20 +15,36 @@
 
 from _base import SQLBaseStore, _RollbackButIsFineException
 
-from twisted.internet import defer
+from twisted.internet import defer, reactor
 
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import preserve_context_over_deferred
 from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 from synapse.crypto.event_signing import compute_event_reference_hash
 
 from syutil.base64util import decode_base64
 from syutil.jsonutil import encode_canonical_json
+from contextlib import contextmanager
 
 import logging
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
 
+# These values are used in the `_enqueue_events` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thin air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
+
+
 class EventsStore(SQLBaseStore):
     @defer.inlineCallbacks
     @log_function
@@ -41,20 +57,32 @@ class EventsStore(SQLBaseStore):
             self.min_token -= 1
             stream_ordering = self.min_token
 
+        if stream_ordering is None:
+            stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+        else:
+            @contextmanager
+            def stream_ordering_manager():
+                yield stream_ordering
+            stream_ordering_manager = stream_ordering_manager()
+
         try:
-            yield self.runInteraction(
-                "persist_event",
-                self._persist_event_txn,
-                event=event,
-                context=context,
-                backfilled=backfilled,
-                stream_ordering=stream_ordering,
-                is_new_state=is_new_state,
-                current_state=current_state,
-            )
+            with stream_ordering_manager as stream_ordering:
+                yield self.runInteraction(
+                    "persist_event",
+                    self._persist_event_txn,
+                    event=event,
+                    context=context,
+                    backfilled=backfilled,
+                    stream_ordering=stream_ordering,
+                    is_new_state=is_new_state,
+                    current_state=current_state,
+                )
         except _RollbackButIsFineException:
             pass
 
+        max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+        defer.returnValue((stream_ordering, max_persisted_id))
+
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
                   get_prev_content=False, allow_rejected=False,
@@ -74,18 +102,17 @@ class EventsStore(SQLBaseStore):
         Returns:
             Deferred : A FrozenEvent.
         """
-        event = yield self.runInteraction(
-            "get_event", self._get_event_txn,
-            event_id,
+        events = yield self._get_events(
+            [event_id],
             check_redacted=check_redacted,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
 
-        if not event and not allow_none:
+        if not events and not allow_none:
             raise RuntimeError("Could not find event %s" % (event_id,))
 
-        defer.returnValue(event)
+        defer.returnValue(events[0] if events else None)
 
     @log_function
     def _persist_event_txn(self, txn, event, context, backfilled,
@@ -95,15 +122,6 @@ class EventsStore(SQLBaseStore):
         # Remove the any existing cache entries for the event_id
         txn.call_after(self._invalidate_get_event_cache, event.event_id)
 
-        if stream_ordering is None:
-            with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
-                return self._persist_event_txn(
-                    txn, event, context, backfilled,
-                    stream_ordering=stream_ordering,
-                    is_new_state=is_new_state,
-                    current_state=current_state,
-                )
-
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
@@ -134,19 +152,17 @@ class EventsStore(SQLBaseStore):
         outlier = event.internal_metadata.is_outlier()
 
         if not outlier:
-            self._store_state_groups_txn(txn, event, context)
-
             self._update_min_depth_for_room_txn(
                 txn,
                 event.room_id,
                 event.depth
             )
 
-        have_persisted = self._simple_select_one_onecol_txn(
+        have_persisted = self._simple_select_one_txn(
             txn,
-            table="event_json",
+            table="events",
             keyvalues={"event_id": event.event_id},
-            retcol="event_id",
+            retcols=["event_id", "outlier"],
             allow_none=True,
         )
 
@@ -161,7 +177,9 @@ class EventsStore(SQLBaseStore):
         # if we are persisting an event that we had persisted as an outlier,
         # but is no longer one.
         if have_persisted:
-            if not outlier:
+            if not outlier and have_persisted["outlier"]:
+                self._store_state_groups_txn(txn, event, context)
+
                 sql = (
                     "UPDATE event_json SET internal_metadata = ?"
                     " WHERE event_id = ?"
@@ -181,6 +199,9 @@ class EventsStore(SQLBaseStore):
                 )
             return
 
+        if not outlier:
+            self._store_state_groups_txn(txn, event, context)
+
         self._handle_prev_events(
             txn,
             outlier=outlier,
@@ -400,3 +421,407 @@ class EventsStore(SQLBaseStore):
         return self.runInteraction(
             "have_events", f,
         )
+
+    @defer.inlineCallbacks
+    def _get_events(self, event_ids, check_redacted=True,
+                    get_prev_content=False, allow_rejected=False):
+        if not event_ids:
+            defer.returnValue([])
+
+        event_map = self._get_events_from_cache(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        missing_events_ids = [e for e in event_ids if e not in event_map]
+
+        if not missing_events_ids:
+            defer.returnValue([
+                event_map[e_id] for e_id in event_ids
+                if e_id in event_map and event_map[e_id]
+            ])
+
+        missing_events = yield self._enqueue_events(
+            missing_events_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        event_map.update(missing_events)
+
+        defer.returnValue([
+            event_map[e_id] for e_id in event_ids
+            if e_id in event_map and event_map[e_id]
+        ])
+
+    def _get_events_txn(self, txn, event_ids, check_redacted=True,
+                        get_prev_content=False, allow_rejected=False):
+        if not event_ids:
+            return []
+
+        event_map = self._get_events_from_cache(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        missing_events_ids = [e for e in event_ids if e not in event_map]
+
+        if not missing_events_ids:
+            return [
+                event_map[e_id] for e_id in event_ids
+                if e_id in event_map and event_map[e_id]
+            ]
+
+        missing_events = self._fetch_events_txn(
+            txn,
+            missing_events_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        event_map.update(missing_events)
+
+        return [
+            event_map[e_id] for e_id in event_ids
+            if e_id in event_map and event_map[e_id]
+        ]
+
+    def _invalidate_get_event_cache(self, event_id):
+        for check_redacted in (False, True):
+            for get_prev_content in (False, True):
+                self._get_event_cache.invalidate(event_id, check_redacted,
+                                                 get_prev_content)
+
+    def _get_event_txn(self, txn, event_id, check_redacted=True,
+                       get_prev_content=False, allow_rejected=False):
+
+        events = self._get_events_txn(
+            txn, [event_id],
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        return events[0] if events else None
+
+    def _get_events_from_cache(self, events, check_redacted, get_prev_content,
+                               allow_rejected):
+        event_map = {}
+
+        for event_id in events:
+            try:
+                ret = self._get_event_cache.get(
+                    event_id, check_redacted, get_prev_content
+                )
+
+                if allow_rejected or not ret.rejected_reason:
+                    event_map[event_id] = ret
+                else:
+                    event_map[event_id] = None
+            except KeyError:
+                pass
+
+        return event_map
+
+    def _do_fetch(self, conn):
+        """Takes a database connection and waits for requests for events from
+        the _event_fetch_list queue.
+        """
+        event_list = []
+        i = 0
+        while True:
+            try:
+                with self._event_fetch_lock:
+                    event_list = self._event_fetch_list
+                    self._event_fetch_list = []
+
+                    if not event_list:
+                        single_threaded = self.database_engine.single_threaded
+                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+                            self._event_fetch_ongoing -= 1
+                            return
+                        else:
+                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                            i += 1
+                            continue
+                    i = 0
+
+                event_id_lists = zip(*event_list)[0]
+                event_ids = [
+                    item for sublist in event_id_lists for item in sublist
+                ]
+
+                rows = self._new_transaction(
+                    conn, "do_fetch", [], self._fetch_event_rows, event_ids
+                )
+
+                row_dict = {
+                    r["event_id"]: r
+                    for r in rows
+                }
+
+                # We only want to resolve deferreds from the main thread
+                def fire(lst, res):
+                    for ids, d in lst:
+                        if not d.called:
+                            try:
+                                d.callback([
+                                    res[i]
+                                    for i in ids
+                                    if i in res
+                                ])
+                            except:
+                                logger.exception("Failed to callback")
+                reactor.callFromThread(fire, event_list, row_dict)
+            except Exception as e:
+                logger.exception("do_fetch")
+
+                # We only want to resolve deferreds from the main thread
+                def fire(evs):
+                    for _, d in evs:
+                        if not d.called:
+                            d.errback(e)
+
+                if event_list:
+                    reactor.callFromThread(fire, event_list)
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events, check_redacted=True,
+                        get_prev_content=False, allow_rejected=False):
+        """Fetches events from the database using the _event_fetch_list. This
+        allows batch and bulk fetching of events - it allows us to fetch events
+        without having to create a new transaction for each request for events.
+        """
+        if not events:
+            defer.returnValue({})
+
+        events_d = defer.Deferred()
+        with self._event_fetch_lock:
+            self._event_fetch_list.append(
+                (events, events_d)
+            )
+
+            self._event_fetch_lock.notify()
+
+            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+                self._event_fetch_ongoing += 1
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            self.runWithConnection(
+                self._do_fetch
+            )
+
+        rows = yield preserve_context_over_deferred(events_d)
+
+        if not allow_rejected:
+            rows[:] = [r for r in rows if not r["rejects"]]
+
+        res = yield defer.gatherResults(
+            [
+                self._get_event_from_row(
+                    row["internal_metadata"], row["json"], row["redacts"],
+                    check_redacted=check_redacted,
+                    get_prev_content=get_prev_content,
+                    rejected_reason=row["rejects"],
+                )
+                for row in rows
+            ],
+            consumeErrors=True
+        )
+
+        defer.returnValue({
+            e.event_id: e
+            for e in res if e
+        })
+
+    def _fetch_event_rows(self, txn, events):
+        rows = []
+        N = 200
+        for i in range(1 + len(events) / N):
+            evs = events[i*N:(i + 1)*N]
+            if not evs:
+                break
+
+            sql = (
+                "SELECT "
+                " e.event_id as event_id, "
+                " e.internal_metadata,"
+                " e.json,"
+                " r.redacts as redacts,"
+                " rej.event_id as rejects "
+                " FROM event_json as e"
+                " LEFT JOIN rejections as rej USING (event_id)"
+                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+                " WHERE e.event_id IN (%s)"
+            ) % (",".join(["?"]*len(evs)),)
+
+            txn.execute(sql, evs)
+            rows.extend(self.cursor_to_dict(txn))
+
+        return rows
+
+    def _fetch_events_txn(self, txn, events, check_redacted=True,
+                          get_prev_content=False, allow_rejected=False):
+        if not events:
+            return {}
+
+        rows = self._fetch_event_rows(
+            txn, events,
+        )
+
+        if not allow_rejected:
+            rows[:] = [r for r in rows if not r["rejects"]]
+
+        res = [
+            self._get_event_from_row_txn(
+                txn,
+                row["internal_metadata"], row["json"], row["redacts"],
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+                rejected_reason=row["rejects"],
+            )
+            for row in rows
+        ]
+
+        return {
+            r.event_id: r
+            for r in res
+        }
+
+    @defer.inlineCallbacks
+    def _get_event_from_row(self, internal_metadata, js, redacted,
+                            check_redacted=True, get_prev_content=False,
+                            rejected_reason=None):
+        d = json.loads(js)
+        internal_metadata = json.loads(internal_metadata)
+
+        if rejected_reason:
+            rejected_reason = yield self._simple_select_one_onecol(
+                table="rejections",
+                keyvalues={"event_id": rejected_reason},
+                retcol="reason",
+                desc="_get_event_from_row",
+            )
+
+        ev = FrozenEvent(
+            d,
+            internal_metadata_dict=internal_metadata,
+            rejected_reason=rejected_reason,
+        )
+
+        if check_redacted and redacted:
+            ev = prune_event(ev)
+
+            redaction_id = yield self._simple_select_one_onecol(
+                table="redactions",
+                keyvalues={"redacts": ev.event_id},
+                retcol="event_id",
+                desc="_get_event_from_row",
+            )
+
+            ev.unsigned["redacted_by"] = redaction_id
+            # Get the redaction event.
+
+            because = yield self.get_event(
+                redaction_id,
+                check_redacted=False
+            )
+
+            if because:
+                ev.unsigned["redacted_because"] = because
+
+        if get_prev_content and "replaces_state" in ev.unsigned:
+            prev = yield self.get_event(
+                ev.unsigned["replaces_state"],
+                get_prev_content=False,
+            )
+            if prev:
+                ev.unsigned["prev_content"] = prev.get_dict()["content"]
+
+        self._get_event_cache.prefill(
+            ev.event_id, check_redacted, get_prev_content, ev
+        )
+
+        defer.returnValue(ev)
+
+    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
+                                check_redacted=True, get_prev_content=False,
+                                rejected_reason=None):
+        d = json.loads(js)
+        internal_metadata = json.loads(internal_metadata)
+
+        if rejected_reason:
+            rejected_reason = self._simple_select_one_onecol_txn(
+                txn,
+                table="rejections",
+                keyvalues={"event_id": rejected_reason},
+                retcol="reason",
+            )
+
+        ev = FrozenEvent(
+            d,
+            internal_metadata_dict=internal_metadata,
+            rejected_reason=rejected_reason,
+        )
+
+        if check_redacted and redacted:
+            ev = prune_event(ev)
+
+            redaction_id = self._simple_select_one_onecol_txn(
+                txn,
+                table="redactions",
+                keyvalues={"redacts": ev.event_id},
+                retcol="event_id",
+            )
+
+            ev.unsigned["redacted_by"] = redaction_id
+            # Get the redaction event.
+
+            because = self._get_event_txn(
+                txn,
+                redaction_id,
+                check_redacted=False
+            )
+
+            if because:
+                ev.unsigned["redacted_because"] = because
+
+        if get_prev_content and "replaces_state" in ev.unsigned:
+            prev = self._get_event_txn(
+                txn,
+                ev.unsigned["replaces_state"],
+                get_prev_content=False,
+            )
+            if prev:
+                ev.unsigned["prev_content"] = prev.get_dict()["content"]
+
+        self._get_event_cache.prefill(
+            ev.event_id, check_redacted, get_prev_content, ev
+        )
+
+        return ev
+
+    def _parse_events(self, rows):
+        return self.runInteraction(
+            "_parse_events", self._parse_events_txn, rows
+        )
+
+    def _parse_events_txn(self, txn, rows):
+        event_ids = [r["event_id"] for r in rows]
+
+        return self._get_events_txn(txn, event_ids)
+
+    def _has_been_redacted_txn(self, txn, event):
+        sql = "SELECT event_id FROM redactions WHERE redacts = ?"
+        txn.execute(sql, (event.event_id,))
+        result = txn.fetchone()
+        return result[0] if result else None
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 22ec94bc16..fefcf6bce0 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
+
+from twisted.internet import defer
 
 
 class PresenceStore(SQLBaseStore):
@@ -87,31 +89,48 @@ class PresenceStore(SQLBaseStore):
             desc="add_presence_list_pending",
         )
 
+    @defer.inlineCallbacks
     def set_presence_list_accepted(self, observer_localpart, observed_userid):
-        return self._simple_update_one(
+        result = yield self._simple_update_one(
             table="presence_list",
             keyvalues={"user_id": observer_localpart,
                        "observed_user_id": observed_userid},
             updatevalues={"accepted": True},
             desc="set_presence_list_accepted",
         )
+        self.get_presence_list_accepted.invalidate(observer_localpart)
+        defer.returnValue(result)
 
     def get_presence_list(self, observer_localpart, accepted=None):
-        keyvalues = {"user_id": observer_localpart}
-        if accepted is not None:
-            keyvalues["accepted"] = accepted
+        if accepted:
+            return self.get_presence_list_accepted(observer_localpart)
+        else:
+            keyvalues = {"user_id": observer_localpart}
+            if accepted is not None:
+                keyvalues["accepted"] = accepted
 
+            return self._simple_select_list(
+                table="presence_list",
+                keyvalues=keyvalues,
+                retcols=["observed_user_id", "accepted"],
+                desc="get_presence_list",
+            )
+
+    @cached()
+    def get_presence_list_accepted(self, observer_localpart):
         return self._simple_select_list(
             table="presence_list",
-            keyvalues=keyvalues,
+            keyvalues={"user_id": observer_localpart, "accepted": True},
             retcols=["observed_user_id", "accepted"],
-            desc="get_presence_list",
+            desc="get_presence_list_accepted",
         )
 
+    @defer.inlineCallbacks
     def del_presence_list(self, observer_localpart, observed_userid):
-        return self._simple_delete_one(
+        yield self._simple_delete_one(
             table="presence_list",
             keyvalues={"user_id": observer_localpart,
                        "observed_user_id": observed_userid},
             desc="del_presence_list",
         )
+        self.get_presence_list_accepted.invalidate(observer_localpart)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 88ee21b089..4cac118d17 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
 
 
 class PushRuleStore(SQLBaseStore):
+    @cached()
     @defer.inlineCallbacks
     def get_push_rules_for_user(self, user_name):
         rows = yield self._simple_select_list(
@@ -31,6 +32,7 @@ class PushRuleStore(SQLBaseStore):
                 "user_name": user_name,
             },
             retcols=PushRuleTable.fields,
+            desc="get_push_rules_enabled_for_user",
         )
 
         rows.sort(
@@ -151,6 +153,10 @@ class PushRuleStore(SQLBaseStore):
             txn.execute(sql, (user_name, priority_class, new_rule_priority))
 
         txn.call_after(
+            self.get_push_rules_for_user.invalidate, user_name
+        )
+
+        txn.call_after(
             self.get_push_rules_enabled_for_user.invalidate, user_name
         )
 
@@ -183,6 +189,9 @@ class PushRuleStore(SQLBaseStore):
         new_rule['priority'] = new_prio
 
         txn.call_after(
+            self.get_push_rules_for_user.invalidate, user_name
+        )
+        txn.call_after(
             self.get_push_rules_enabled_for_user.invalidate, user_name
         )
 
@@ -208,17 +217,34 @@ class PushRuleStore(SQLBaseStore):
             {'user_name': user_name, 'rule_id': rule_id},
             desc="delete_push_rule",
         )
+
+        self.get_push_rules_for_user.invalidate(user_name)
         self.get_push_rules_enabled_for_user.invalidate(user_name)
 
     @defer.inlineCallbacks
     def set_push_rule_enabled(self, user_name, rule_id, enabled):
-        yield self._simple_upsert(
+        ret = yield self.runInteraction(
+            "_set_push_rule_enabled_txn",
+            self._set_push_rule_enabled_txn,
+            user_name, rule_id, enabled
+        )
+        defer.returnValue(ret)
+
+    def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
+        new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
+        self._simple_upsert_txn(
+            txn,
             PushRuleEnableTable.table_name,
             {'user_name': user_name, 'rule_id': rule_id},
             {'enabled': 1 if enabled else 0},
-            desc="set_push_rule_enabled",
+            {'id': new_id},
+        )
+        txn.call_after(
+            self.get_push_rules_for_user.invalidate, user_name
+        )
+        txn.call_after(
+            self.get_push_rules_enabled_for_user.invalidate, user_name
         )
-        self.get_push_rules_enabled_for_user.invalidate(user_name)
 
 
 class RuleNotFoundException(Exception):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 3691eade05..d36a6c18a8 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -77,16 +77,16 @@ class RoomMemberStore(SQLBaseStore):
         Returns:
             Deferred: Results in a MembershipEvent or None.
         """
-        def f(txn):
-            events = self._get_members_events_txn(
-                txn,
-                room_id,
-                user_id=user_id,
-            )
-
-            return events[0] if events else None
-
-        return self.runInteraction("get_room_member", f)
+        return self.runInteraction(
+            "get_room_member",
+            self._get_members_events_txn,
+            room_id,
+            user_id=user_id,
+        ).addCallback(
+            self._get_events
+        ).addCallback(
+            lambda events: events[0] if events else None
+        )
 
     @cached()
     def get_users_in_room(self, room_id):
@@ -112,15 +112,12 @@ class RoomMemberStore(SQLBaseStore):
         Returns:
             list of namedtuples representing the members in this room.
         """
-
-        def f(txn):
-            return self._get_members_events_txn(
-                txn,
-                room_id,
-                membership=membership,
-            )
-
-        return self.runInteraction("get_room_members", f)
+        return self.runInteraction(
+            "get_room_members",
+            self._get_members_events_txn,
+            room_id,
+            membership=membership,
+        ).addCallback(self._get_events)
 
     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         """ Get all the rooms for this user where the membership for this user
@@ -192,14 +189,14 @@ class RoomMemberStore(SQLBaseStore):
         return self.runInteraction(
             "get_members_query", self._get_members_events_txn,
             where_clause, where_values
-        )
+        ).addCallbacks(self._get_events)
 
     def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
         rows = self._get_members_rows_txn(
             txn,
             room_id, membership, user_id,
         )
-        return self._get_events_txn(txn, [r["event_id"] for r in rows])
+        return [r["event_id"] for r in rows]
 
     def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
         where_clause = "c.room_id = ?"
diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/schema/delta/19/event_index.sql
new file mode 100644
index 0000000000..3881fc9897
--- /dev/null
+++ b/synapse/storage/schema/delta/19/event_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE INDEX events_order_topo_stream_room ON events(
+    topological_ordering, stream_ordering, room_id
+);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6df7350552..b24de34f23 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -43,6 +43,7 @@ class StateStore(SQLBaseStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
+    @defer.inlineCallbacks
     def get_state_groups(self, event_ids):
         """ Get the state groups for the given list of event_ids
 
@@ -71,17 +72,29 @@ class StateStore(SQLBaseStore):
                     retcol="event_id",
                 )
 
-                state = self._get_events_txn(txn, state_ids)
-
-                res[group] = state
+                res[group] = state_ids
 
             return res
 
-        return self.runInteraction(
+        states = yield self.runInteraction(
             "get_state_groups",
             f,
         )
 
+        @defer.inlineCallbacks
+        def c(vals):
+            vals[:] = yield self._get_events(vals, get_prev_content=False)
+
+        yield defer.gatherResults(
+            [
+                c(vals)
+                for vals in states.values()
+            ],
+            consumeErrors=True,
+        )
+
+        defer.returnValue(states)
+
     def _store_state_groups_txn(self, txn, event, context):
         if context.current_state is None:
             return
@@ -152,11 +165,12 @@ class StateStore(SQLBaseStore):
                 args = (room_id, )
 
             txn.execute(sql, args)
-            results = self.cursor_to_dict(txn)
+            results = txn.fetchall()
 
-            return self._parse_events_txn(txn, results)
+            return [r[0] for r in results]
 
-        events = yield self.runInteraction("get_current_state", f)
+        event_ids = yield self.runInteraction("get_current_state", f)
+        events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
     @cached(num_args=3)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 280d4ad605..af45fc5619 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,11 +37,9 @@ from twisted.internet import defer
 
 from ._base import SQLBaseStore
 from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function
 
-from collections import namedtuple
-
 import logging
 
 
@@ -55,76 +53,26 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
 
-class _StreamToken(namedtuple("_StreamToken", "topological stream")):
-    """Tokens are positions between events. The token "s1" comes after event 1.
-
-            s0    s1
-            |     |
-        [0] V [1] V [2]
-
-    Tokens can either be a point in the live event stream or a cursor going
-    through historic events.
-
-    When traversing the live event stream events are ordered by when they
-    arrived at the homeserver.
-
-    When traversing historic events the events are ordered by their depth in
-    the event graph "topological_ordering" and then by when they arrived at the
-    homeserver "stream_ordering".
-
-    Live tokens start with an "s" followed by the "stream_ordering" id of the
-    event it comes after. Historic tokens start with a "t" followed by the
-    "topological_ordering" id of the event it comes after, follewed by "-",
-    followed by the "stream_ordering" id of the event it comes after.
-    """
-    __slots__ = []
-
-    @classmethod
-    def parse(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-            if string[0] == 't':
-                parts = string[1:].split('-', 1)
-                return cls(topological=int(parts[0]), stream=int(parts[1]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    @classmethod
-    def parse_stream_token(cls, string):
-        try:
-            if string[0] == 's':
-                return cls(topological=None, stream=int(string[1:]))
-        except:
-            pass
-        raise SynapseError(400, "Invalid token %r" % (string,))
-
-    def __str__(self):
-        if self.topological is not None:
-            return "t%d-%d" % (self.topological, self.stream)
-        else:
-            return "s%d" % (self.stream,)
+def lower_bound(token):
+    if token.topological is None:
+        return "(%d < %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d < %s OR (%d = %s AND %d < %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )
 
-    def lower_bound(self):
-        if self.topological is None:
-            return "(%d < %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d < %s OR (%d = %s AND %d < %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
 
-    def upper_bound(self):
-        if self.topological is None:
-            return "(%d >= %s)" % (self.stream, "stream_ordering")
-        else:
-            return "(%d > %s OR (%d = %s AND %d >= %s))" % (
-                self.topological, "topological_ordering",
-                self.topological, "topological_ordering",
-                self.stream, "stream_ordering",
-            )
+def upper_bound(token):
+    if token.topological is None:
+        return "(%d >= %s)" % (token.stream, "stream_ordering")
+    else:
+        return "(%d > %s OR (%d = %s AND %d >= %s))" % (
+            token.topological, "topological_ordering",
+            token.topological, "topological_ordering",
+            token.stream, "stream_ordering",
+        )
 
 
 class StreamStore(SQLBaseStore):
@@ -139,8 +87,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE
 
         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)
 
         if from_key == to_key:
             defer.returnValue(([], to_key))
@@ -234,8 +182,8 @@ class StreamStore(SQLBaseStore):
             limit = MAX_STREAM_SIZE
 
         # From and to keys should be integers from ordering.
-        from_id = _StreamToken.parse_stream_token(from_key)
-        to_id = _StreamToken.parse_stream_token(to_key)
+        from_id = RoomStreamToken.parse_stream_token(from_key)
+        to_id = RoomStreamToken.parse_stream_token(to_key)
 
         if from_key == to_key:
             return defer.succeed(([], to_key))
@@ -276,7 +224,7 @@ class StreamStore(SQLBaseStore):
 
         return self.runInteraction("get_room_events_stream", f)
 
-    @log_function
+    @defer.inlineCallbacks
     def paginate_room_events(self, room_id, from_key, to_key=None,
                              direction='b', limit=-1,
                              with_feedback=False):
@@ -288,17 +236,17 @@ class StreamStore(SQLBaseStore):
         args = [False, room_id]
         if direction == 'b':
             order = "DESC"
-            bounds = _StreamToken.parse(from_key).upper_bound()
+            bounds = upper_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).lower_bound()
+                    bounds, lower_bound(RoomStreamToken.parse(to_key))
                 )
         else:
             order = "ASC"
-            bounds = _StreamToken.parse(from_key).lower_bound()
+            bounds = lower_bound(RoomStreamToken.parse(from_key))
             if to_key:
                 bounds = "%s AND %s" % (
-                    bounds, _StreamToken.parse(to_key).upper_bound()
+                    bounds, upper_bound(RoomStreamToken.parse(to_key))
                 )
 
         if int(limit) > 0:
@@ -333,28 +281,30 @@ class StreamStore(SQLBaseStore):
                     # when we are going backwards so we subtract one from the
                     # stream part.
                     toke -= 1
-                next_token = str(_StreamToken(topo, toke))
+                next_token = str(RoomStreamToken(topo, toke))
             else:
                 # TODO (erikj): We should work out what to do here instead.
                 next_token = to_key if to_key else from_key
 
-            events = self._get_events_txn(
-                txn,
-                [r["event_id"] for r in rows],
-                get_prev_content=True
-            )
+            return rows, next_token,
 
-            self._set_before_and_after(events, rows)
+        rows, token = yield self.runInteraction("paginate_room_events", f)
+
+        events = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
 
-            return events, next_token,
+        self._set_before_and_after(events, rows)
 
-        return self.runInteraction("paginate_room_events", f)
+        defer.returnValue((events, token))
 
+    @defer.inlineCallbacks
     def get_recent_events_for_room(self, room_id, limit, end_token,
                                    with_feedback=False, from_token=None):
         # TODO (erikj): Handle compressed feedback
 
-        end_token = _StreamToken.parse_stream_token(end_token)
+        end_token = RoomStreamToken.parse_stream_token(end_token)
 
         if from_token is None:
             sql = (
@@ -365,7 +315,7 @@ class StreamStore(SQLBaseStore):
                 " LIMIT ?"
             )
         else:
-            from_token = _StreamToken.parse_stream_token(from_token)
+            from_token = RoomStreamToken.parse_stream_token(from_token)
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
@@ -395,30 +345,49 @@ class StreamStore(SQLBaseStore):
                 # stream part.
                 topo = rows[0]["topological_ordering"]
                 toke = rows[0]["stream_ordering"] - 1
-                start_token = str(_StreamToken(topo, toke))
+                start_token = str(RoomStreamToken(topo, toke))
 
                 token = (start_token, str(end_token))
             else:
                 token = (str(end_token), str(end_token))
 
-            events = self._get_events_txn(
-                txn,
-                [r["event_id"] for r in rows],
-                get_prev_content=True
-            )
-
-            self._set_before_and_after(events, rows)
+            return rows, token
 
-            return events, token
-
-        return self.runInteraction(
+        rows, token = yield self.runInteraction(
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )
 
+        logger.debug("stream before")
+        events = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
+        logger.debug("stream after")
+
+        self._set_before_and_after(events, rows)
+
+        defer.returnValue((events, token))
+
     @defer.inlineCallbacks
-    def get_room_events_max_id(self):
+    def get_room_events_max_id(self, direction='f'):
+        """Get a token for the current end of the room events stream.
+
+        Args:
+            direction: 'b' requests a token of the form
+                "t<topological>-<stream>"; any other value (the default
+                is 'f') requests a stream-only token of the form
+                "s<stream>".
+
+        Returns:
+            A Deferred that fires with the token string.
+        """
         token = yield self._stream_id_gen.get_max_token(self)
-        defer.returnValue("s%d" % (token,))
+        if direction != 'b':
+            defer.returnValue("s%d" % (token,))
+        else:
+            # The topological part is only looked up (one extra database
+            # interaction) when a "t...-..." token is actually requested.
+            topo = yield self.runInteraction(
+                "_get_max_topological_txn", self._get_max_topological_txn
+            )
+            defer.returnValue("t%d-%d" % (topo, token))
+
+    def _get_max_topological_txn(self, txn):
+        """Fetch the largest topological_ordering among non-outlier events.
+
+        Args:
+            txn: The database cursor/transaction to run the query on.
+
+        Returns:
+            The maximum topological_ordering, or 0 when no event matches.
+        """
+        txn.execute(
+            "SELECT MAX(topological_ordering) FROM events"
+            " WHERE outlier = ?",
+            (False,)
+        )
+
+        rows = txn.fetchall()
+        # An aggregate without GROUP BY yields a single row holding NULL
+        # when no event matches, so an empty result set is not the only
+        # "no data" case. Map a NULL maximum to 0 as well, otherwise the
+        # caller's "%d" formatting fails on None.
+        if not rows or rows[0][0] is None:
+            return 0
+        return rows[0][0]
 
     @defer.inlineCallbacks
     def _get_min_token(self):
@@ -439,5 +408,5 @@ class StreamStore(SQLBaseStore):
             stream = row["stream_ordering"]
             topo = event.depth
             internal = event.internal_metadata
-            internal.before = str(_StreamToken(topo, stream - 1))
-            internal.after = str(_StreamToken(topo, stream))
+            internal.before = str(RoomStreamToken(topo, stream - 1))
+            internal.after = str(RoomStreamToken(topo, stream))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e40eb8a8c4..89d1643f10 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -78,14 +78,18 @@ class StreamIdGenerator(object):
         self._current_max = None
         self._unfinished_ids = deque()
 
-    def get_next_txn(self, txn):
+    @defer.inlineCallbacks
+    def get_next(self, store):
         """
         Usage:
-            with stream_id_gen.get_next_txn(txn) as stream_id:
+            with (yield stream_id_gen.get_next(store)) as stream_id:
                 # ... persist event ...
         """
         if not self._current_max:
-            self._get_or_compute_current_max(txn)
+            yield store.runInteraction(
+                "_compute_current_max",
+                self._get_or_compute_current_max,
+            )
 
         with self._lock:
             self._current_max += 1
@@ -101,7 +105,7 @@ class StreamIdGenerator(object):
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        defer.returnValue(manager())
 
     @defer.inlineCallbacks
     def get_max_token(self, store):