summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-03-23 16:55:29 +0000
committerErik Johnston <erik@matrix.org>2016-03-23 16:55:29 +0000
commit7a3815b372552e516cc3619c4ea143f610358206 (patch)
tree8ff286640cbc715ae82f7c13fcadfed218373070 /synapse/storage
parentEnglish (diff)
parentMerge pull request #666 from matrix-org/erikj/intern (diff)
downloadsynapse-7a3815b372552e516cc3619c4ea143f610358206.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.14.0
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py5
-rw-r--r--synapse/storage/_base.py11
-rw-r--r--synapse/storage/directory.py2
-rw-r--r--synapse/storage/event_push_actions.py2
-rw-r--r--synapse/storage/events.py77
-rw-r--r--synapse/storage/pusher.py63
-rw-r--r--synapse/storage/receipts.py21
-rw-r--r--synapse/storage/room.py8
-rw-r--r--synapse/storage/roommember.py34
-rw-r--r--synapse/storage/schema/delta/30/deleted_pushers.sql25
-rw-r--r--synapse/storage/schema/delta/30/public_rooms.sql23
-rw-r--r--synapse/storage/state.py128
-rw-r--r--synapse/storage/stream.py33
-rw-r--r--synapse/storage/util/id_generators.py7
14 files changed, 262 insertions, 177 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 168eb27b03..250ba536ea 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -119,12 +119,15 @@ class DataStore(RoomMemberStore, RoomStore,
         self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
         self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
-        self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._push_rules_stream_id_gen = ChainedIdGenerator(
             self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
         )
+        self._pushers_id_gen = StreamIdGenerator(
+            db_conn, "pushers", "id",
+            extra_tables=[("deleted_pushers", "stream_id")],
+        )
 
         events_max = self._stream_id_gen.get_max_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7dc67ecd57..b75b79df36 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,6 +18,7 @@ from synapse.api.errors import StoreError
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
+from synapse.util.caches import intern_dict
 import synapse.metrics
 
 
@@ -26,6 +27,10 @@ from twisted.internet import defer
 import sys
 import time
 import threading
+import os
+
+
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
 logger = logging.getLogger(__name__)
@@ -163,7 +168,9 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
-        self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000)
+        self._state_group_cache = DictionaryCache(
+            "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
+        )
 
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
@@ -344,7 +351,7 @@ class SQLBaseStore(object):
         """
         col_headers = list(column[0] for column in cursor.description)
         results = list(
-            dict(zip(col_headers, row)) for row in cursor.fetchall()
+            intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall()
         )
         return results
 
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 012a0b414a..ef231a04dc 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -155,7 +155,7 @@ class DirectoryStore(SQLBaseStore):
 
         return room_id
 
-    @cached()
+    @cached(max_entries=5000)
     def get_aliases_for_room(self, room_id):
         return self._simple_select_onecol(
             "room_aliases",
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 5820539a92..dc5830450a 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -49,7 +49,7 @@ class EventPushActionsStore(SQLBaseStore):
             )
         self._simple_insert_many_txn(txn, "event_push_actions", values)
 
-    @cachedInlineCallbacks(num_args=3, lru=True, tree=True)
+    @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
             self, room_id, user_id, last_read_event_id
     ):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 552e7ca35b..5233430028 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -101,30 +101,16 @@ class EventsStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, context, backfilled=False,
+    def persist_event(self, event, context,
                       is_new_state=True, current_state=None):
-        stream_ordering = None
-        if backfilled:
-            self.min_stream_token -= 1
-            stream_ordering = self.min_stream_token
-
-        if stream_ordering is None:
-            stream_ordering_manager = self._stream_id_gen.get_next()
-        else:
-            @contextmanager
-            def stream_ordering_manager():
-                yield stream_ordering
-            stream_ordering_manager = stream_ordering_manager()
-
         try:
-            with stream_ordering_manager as stream_ordering:
+            with self._stream_id_gen.get_next() as stream_ordering:
                 event.internal_metadata.stream_ordering = stream_ordering
                 yield self.runInteraction(
                     "persist_event",
                     self._persist_event_txn,
                     event=event,
                     context=context,
-                    backfilled=backfilled,
                     is_new_state=is_new_state,
                     current_state=current_state,
                 )
@@ -165,13 +151,38 @@ class EventsStore(SQLBaseStore):
 
         defer.returnValue(events[0] if events else None)
 
+    @defer.inlineCallbacks
+    def get_events(self, event_ids, check_redacted=True,
+                   get_prev_content=False, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred : Dict from event_id to event.
+        """
+        events = yield self._get_events(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        defer.returnValue({e.event_id: e for e in events})
+
     @log_function
-    def _persist_event_txn(self, txn, event, context, backfilled,
+    def _persist_event_txn(self, txn, event, context,
                            is_new_state=True, current_state=None):
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
-            txn.call_after(self.get_current_state_for_key.invalidate_all)
+            txn.call_after(self._get_current_state_for_key.invalidate_all)
             txn.call_after(self.get_rooms_for_user.invalidate_all)
             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
@@ -198,7 +209,7 @@ class EventsStore(SQLBaseStore):
         return self._persist_events_txn(
             txn,
             [(event, context)],
-            backfilled=backfilled,
+            backfilled=False,
             is_new_state=is_new_state,
         )
 
@@ -455,7 +466,7 @@ class EventsStore(SQLBaseStore):
             for event, _ in state_events_and_contexts:
                 if not context.rejected:
                     txn.call_after(
-                        self.get_current_state_for_key.invalidate,
+                        self._get_current_state_for_key.invalidate,
                         (event.room_id, event.type, event.state_key,)
                     )
 
@@ -526,6 +537,9 @@ class EventsStore(SQLBaseStore):
         if not event_ids:
             defer.returnValue([])
 
+        event_id_list = event_ids
+        event_ids = set(event_ids)
+
         event_map = self._get_events_from_cache(
             event_ids,
             check_redacted=check_redacted,
@@ -535,23 +549,18 @@ class EventsStore(SQLBaseStore):
 
         missing_events_ids = [e for e in event_ids if e not in event_map]
 
-        if not missing_events_ids:
-            defer.returnValue([
-                event_map[e_id] for e_id in event_ids
-                if e_id in event_map and event_map[e_id]
-            ])
-
-        missing_events = yield self._enqueue_events(
-            missing_events_ids,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
-        )
+        if missing_events_ids:
+            missing_events = yield self._enqueue_events(
+                missing_events_ids,
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+                allow_rejected=allow_rejected,
+            )
 
-        event_map.update(missing_events)
+            event_map.update(missing_events)
 
         defer.returnValue([
-            event_map[e_id] for e_id in event_ids
+            event_map[e_id] for e_id in event_id_list
             if e_id in event_map and event_map[e_id]
         ])
 
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 7693ab9082..87b2ac5773 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -16,8 +16,6 @@
 from ._base import SQLBaseStore
 from twisted.internet import defer
 
-from synapse.api.errors import StoreError
-
 from canonicaljson import encode_canonical_json
 
 import logging
@@ -79,12 +77,41 @@ class PusherStore(SQLBaseStore):
         rows = yield self.runInteraction("get_all_pushers", get_pushers)
         defer.returnValue(rows)
 
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_max_token()
+
+    def get_all_updated_pushers(self, last_id, current_id, limit):
+        def get_all_updated_pushers_txn(txn):
+            sql = (
+                "SELECT id, user_name, access_token, profile_tag, kind,"
+                " app_id, app_display_name, device_display_name, pushkey, ts,"
+                " lang, data"
+                " FROM pushers"
+                " WHERE ? < id AND id <= ?"
+                " ORDER BY id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            updated = txn.fetchall()
+
+            sql = (
+                "SELECT stream_id, user_id, app_id, pushkey"
+                " FROM deleted_pushers"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            deleted = txn.fetchall()
+
+            return (updated, deleted)
+        return self.runInteraction(
+            "get_all_updated_pushers", get_all_updated_pushers_txn
+        )
+
     @defer.inlineCallbacks
     def add_pusher(self, user_id, access_token, kind, app_id,
                    app_display_name, device_display_name,
                    pushkey, pushkey_ts, lang, data, profile_tag=""):
-        try:
-            next_id = self._pushers_id_gen.get_next()
+        with self._pushers_id_gen.get_next() as stream_id:
             yield self._simple_upsert(
                 "pushers",
                 dict(
@@ -101,23 +128,29 @@ class PusherStore(SQLBaseStore):
                     lang=lang,
                     data=encode_canonical_json(data),
                     profile_tag=profile_tag,
-                ),
-                insertion_values=dict(
-                    id=next_id,
+                    id=stream_id,
                 ),
                 desc="add_pusher",
             )
-        except Exception as e:
-            logger.error("create_pusher with failed: %s", e)
-            raise StoreError(500, "Problem creating pusher.")
 
     @defer.inlineCallbacks
     def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
-        yield self._simple_delete_one(
-            "pushers",
-            {"app_id": app_id, "pushkey": pushkey, 'user_name': user_id},
-            desc="delete_pusher_by_app_id_pushkey_user_id",
-        )
+        def delete_pusher_txn(txn, stream_id):
+            self._simple_delete_one_txn(
+                txn,
+                "pushers",
+                {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
+            )
+            self._simple_upsert_txn(
+                txn,
+                "deleted_pushers",
+                {"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
+                {"stream_id": stream_id},
+            )
+        with self._pushers_id_gen.get_next() as stream_id:
+            yield self.runInteraction(
+                "delete_pusher", delete_pusher_txn, stream_id
+            )
 
     @defer.inlineCallbacks
     def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index dbc074d6b5..6b9d848eaa 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -62,18 +62,17 @@ class ReceiptsStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=2)
     def get_receipts_for_user(self, user_id, receipt_type):
-        def f(txn):
-            sql = (
-                "SELECT room_id,event_id "
-                "FROM receipts_linearized "
-                "WHERE user_id = ? AND receipt_type = ? "
-            )
-            txn.execute(sql, (user_id, receipt_type))
-            return txn.fetchall()
+        rows = yield self._simple_select_list(
+            table="receipts_linearized",
+            keyvalues={
+                "user_id": user_id,
+                "receipt_type": receipt_type,
+            },
+            retcols=("room_id", "event_id"),
+            desc="get_receipts_for_user",
+        )
 
-        defer.returnValue(dict(
-            (yield self.runInteraction("get_receipts_for_user", f))
-        ))
+        defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
 
     @defer.inlineCallbacks
     def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 46ab38a313..9be977f387 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -77,6 +77,14 @@ class RoomStore(SQLBaseStore):
             allow_none=True,
         )
 
+    def set_room_is_public(self, room_id, is_public):
+        return self._simple_update_one(
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            updatevalues={"is_public": is_public},
+            desc="set_room_is_public",
+        )
+
     def get_public_room_ids(self):
         return self._simple_select_onecol(
             table="rooms",
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 3065b0c1a5..430b49c12e 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -115,19 +115,17 @@ class RoomMemberStore(SQLBaseStore):
         ).addCallback(self._get_events)
 
     @cached()
-    def get_invites_for_user(self, user_id):
-        """ Get all the invite events for a user
+    def get_invited_rooms_for_user(self, user_id):
+        """ Get all the rooms the user is invited to
         Args:
             user_id (str): The user ID.
         Returns:
-            A deferred list of event objects.
+            A deferred list of RoomsForUser.
         """
 
         return self.get_rooms_for_user_where_membership_is(
             user_id, [Membership.INVITE]
-        ).addCallback(lambda invites: self._get_events([
-            invite.event_id for invite in invites
-        ]))
+        )
 
     def get_leave_and_ban_events_for_user(self, user_id):
         """ Get all the leave events for a user
@@ -252,30 +250,6 @@ class RoomMemberStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def user_rooms_intersect(self, user_id_list):
-        """ Checks whether all the users whose IDs are given in a list share a
-        room.
-
-        This is a "hot path" function that's called a lot, e.g. by presence for
-        generating the event stream. As such, it is implemented locally by
-        wrapping logic around heavily-cached database queries.
-        """
-        if len(user_id_list) < 2:
-            defer.returnValue(True)
-
-        deferreds = [self.get_rooms_for_user(u) for u in user_id_list]
-
-        results = yield defer.DeferredList(deferreds, consumeErrors=True)
-
-        # A list of sets of strings giving room IDs for each user
-        room_id_lists = [set([r.room_id for r in result[1]]) for result in results]
-
-        # There isn't a setintersection(*list_of_sets)
-        ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0
-
-        defer.returnValue(ret)
-
-    @defer.inlineCallbacks
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
         def f(txn):
diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/schema/delta/30/deleted_pushers.sql
new file mode 100644
index 0000000000..712c454aa1
--- /dev/null
+++ b/synapse/storage/schema/delta/30/deleted_pushers.sql
@@ -0,0 +1,25 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS deleted_pushers(
+    stream_id BIGINT NOT NULL,
+    app_id TEXT NOT NULL,
+    pushkey TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    /* We only track the most recent delete for each app_id, pushkey and user_id. */
+    UNIQUE (app_id, pushkey, user_id)
+);
+
+CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/schema/delta/30/public_rooms.sql
new file mode 100644
index 0000000000..f09db4faa6
--- /dev/null
+++ b/synapse/storage/schema/delta/30/public_rooms.sql
@@ -0,0 +1,23 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+/* This release removes the restriction that published rooms must have an alias,
+ * so we go back and ensure the only 'public' rooms are ones with an alias.
+ * We use (1 = 0) and (1 = 1) so that it works in both postgres and sqlite
+ */
+UPDATE rooms SET is_public = (1 = 0) WHERE is_public = (1 = 1) AND room_id not in (
+    SELECT room_id FROM room_aliases
+);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 8ed8a21b0a..02cefdff26 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,9 +14,8 @@
 # limitations under the License.
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import (
-    cached, cachedInlineCallbacks, cachedList
-)
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches import intern_string
 
 from twisted.internet import defer
 
@@ -155,8 +154,14 @@ class StateStore(SQLBaseStore):
         events = yield self._get_events(event_ids, get_prev_content=False)
         defer.returnValue(events)
 
-    @cachedInlineCallbacks(num_args=3)
+    @defer.inlineCallbacks
     def get_current_state_for_key(self, room_id, event_type, state_key):
+        event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
+        events = yield self._get_events(event_ids, get_prev_content=False)
+        defer.returnValue(events)
+
+    @cached(num_args=3)
+    def _get_current_state_for_key(self, room_id, event_type, state_key):
         def f(txn):
             sql = (
                 "SELECT event_id FROM current_state_events"
@@ -167,12 +172,10 @@ class StateStore(SQLBaseStore):
             txn.execute(sql, args)
             results = txn.fetchall()
             return [r[0] for r in results]
-        event_ids = yield self.runInteraction("get_current_state_for_key", f)
-        events = yield self._get_events(event_ids, get_prev_content=False)
-        defer.returnValue(events)
+        return self.runInteraction("get_current_state_for_key", f)
 
     def _get_state_groups_from_groups(self, groups, types):
-        """Returns dictionary state_group -> state event ids
+        """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
         """
         def f(txn, groups):
             if types is not None:
@@ -183,7 +186,8 @@ class StateStore(SQLBaseStore):
                 where_clause = ""
 
             sql = (
-                "SELECT state_group, event_id FROM state_groups_state WHERE"
+                "SELECT state_group, event_id, type, state_key"
+                " FROM state_groups_state WHERE"
                 " state_group IN (%s) %s" % (
                     ",".join("?" for _ in groups),
                     where_clause,
@@ -199,7 +203,8 @@ class StateStore(SQLBaseStore):
 
             results = {}
             for row in rows:
-                results.setdefault(row["state_group"], []).append(row["event_id"])
+                key = (row["type"], row["state_key"])
+                results.setdefault(row["state_group"], {})[key] = row["event_id"]
             return results
 
         chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
@@ -296,7 +301,7 @@ class StateStore(SQLBaseStore):
                 where a `state_key` of `None` matches all state_keys for the
                 `type`.
         """
-        is_all, state_dict = self._state_group_cache.get(group)
+        is_all, state_dict_ids = self._state_group_cache.get(group)
 
         type_to_key = {}
         missing_types = set()
@@ -308,7 +313,7 @@ class StateStore(SQLBaseStore):
                 if type_to_key.get(typ, object()) is not None:
                     type_to_key.setdefault(typ, set()).add(state_key)
 
-                if (typ, state_key) not in state_dict:
+                if (typ, state_key) not in state_dict_ids:
                     missing_types.add((typ, state_key))
 
         sentinel = object()
@@ -326,7 +331,7 @@ class StateStore(SQLBaseStore):
         got_all = not (missing_types or types is None)
 
         return {
-            k: v for k, v in state_dict.items()
+            k: v for k, v in state_dict_ids.items()
             if include(k[0], k[1])
         }, missing_types, got_all
 
@@ -340,8 +345,9 @@ class StateStore(SQLBaseStore):
         Args:
             group: The state group to lookup
         """
-        is_all, state_dict = self._state_group_cache.get(group)
-        return state_dict, is_all
+        is_all, state_dict_ids = self._state_group_cache.get(group)
+
+        return state_dict_ids, is_all
 
     @defer.inlineCallbacks
     def _get_state_for_groups(self, groups, types=None):
@@ -354,84 +360,72 @@ class StateStore(SQLBaseStore):
         missing_groups = []
         if types is not None:
             for group in set(groups):
-                state_dict, missing_types, got_all = self._get_some_state_from_cache(
+                state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
                     group, types
                 )
-                results[group] = state_dict
+                results[group] = state_dict_ids
 
                 if not got_all:
                     missing_groups.append(group)
         else:
             for group in set(groups):
-                state_dict, got_all = self._get_all_state_from_cache(
+                state_dict_ids, got_all = self._get_all_state_from_cache(
                     group
                 )
-                results[group] = state_dict
+
+                results[group] = state_dict_ids
 
                 if not got_all:
                     missing_groups.append(group)
 
-        if not missing_groups:
-            defer.returnValue({
-                group: {
-                    type_tuple: event
-                    for type_tuple, event in state.items()
-                    if event
-                }
-                for group, state in results.items()
-            })
+        if missing_groups:
+            # Okay, so we have some missing_types, lets fetch them.
+            cache_seq_num = self._state_group_cache.sequence
 
-        # Okay, so we have some missing_types, lets fetch them.
-        cache_seq_num = self._state_group_cache.sequence
+            group_to_state_dict = yield self._get_state_groups_from_groups(
+                missing_groups, types
+            )
 
-        group_state_dict = yield self._get_state_groups_from_groups(
-            missing_groups, types
-        )
+            # Now we want to update the cache with all the things we fetched
+            # from the database.
+            for group, group_state_dict in group_to_state_dict.items():
+                if types:
+                    # We delibrately put key -> None mappings into the cache to
+                    # cache absence of the key, on the assumption that if we've
+                    # explicitly asked for some types then we will probably ask
+                    # for them again.
+                    state_dict = {
+                        (intern_string(etype), intern_string(state_key)): None
+                        for (etype, state_key) in types
+                    }
+                    state_dict.update(results[group])
+                    results[group] = state_dict
+                else:
+                    state_dict = results[group]
+
+                state_dict.update(group_state_dict)
+
+                self._state_group_cache.update(
+                    cache_seq_num,
+                    key=group,
+                    value=state_dict,
+                    full=(types is None),
+                )
 
         state_events = yield self._get_events(
-            [e_id for l in group_state_dict.values() for e_id in l],
+            [ev_id for sd in results.values() for ev_id in sd.values()],
             get_prev_content=False
         )
 
         state_events = {e.event_id: e for e in state_events}
 
-        # Now we want to update the cache with all the things we fetched
-        # from the database.
-        for group, state_ids in group_state_dict.items():
-            if types:
-                # We delibrately put key -> None mappings into the cache to
-                # cache absence of the key, on the assumption that if we've
-                # explicitly asked for some types then we will probably ask
-                # for them again.
-                state_dict = {key: None for key in types}
-                state_dict.update(results[group])
-                results[group] = state_dict
-            else:
-                state_dict = results[group]
-
-            for event_id in state_ids:
-                try:
-                    state_event = state_events[event_id]
-                    state_dict[(state_event.type, state_event.state_key)] = state_event
-                except KeyError:
-                    # Hmm. So we do don't have that state event? Interesting.
-                    logger.warn(
-                        "Can't find state event %r for state group %r",
-                        event_id, group,
-                    )
-
-            self._state_group_cache.update(
-                cache_seq_num,
-                key=group,
-                value=state_dict,
-                full=(types is None),
-            )
-
         # Remove all the entries with None values. The None values were just
         # used for bookkeeping in the cache.
         for group, state_dict in results.items():
             results[group] = {
-                key: event for key, event in state_dict.items() if event
+                key: state_events[event_id]
+                for key, event_id in state_dict.items()
+                if event_id and event_id in state_events
             }
 
         defer.returnValue(results)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 7f4a827528..cf84938be5 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -36,7 +36,7 @@ what sort order was used:
 from twisted.internet import defer
 
 from ._base import SQLBaseStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
 from synapse.util.logcontext import preserve_fn
@@ -465,9 +465,25 @@ class StreamStore(SQLBaseStore):
 
         defer.returnValue((events, token))
 
-    @cachedInlineCallbacks(num_args=4)
+    @defer.inlineCallbacks
     def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
+        rows, token = yield self.get_recent_event_ids_for_room(
+            room_id, limit, end_token, from_token
+        )
+
+        logger.debug("stream before")
+        events = yield self._get_events(
+            [r["event_id"] for r in rows],
+            get_prev_content=True
+        )
+        logger.debug("stream after")
+
+        self._set_before_and_after(events, rows)
+
+        defer.returnValue((events, token))
 
+    @cached(num_args=4)
+    def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None):
         end_token = RoomStreamToken.parse_stream_token(end_token)
 
         if from_token is None:
@@ -517,21 +533,10 @@ class StreamStore(SQLBaseStore):
 
             return rows, token
 
-        rows, token = yield self.runInteraction(
+        return self.runInteraction(
             "get_recent_events_for_room", get_recent_events_for_room_txn
         )
 
-        logger.debug("stream before")
-        events = yield self._get_events(
-            [r["event_id"] for r in rows],
-            get_prev_content=True
-        )
-        logger.debug("stream after")
-
-        self._set_before_and_after(events, rows)
-
-        defer.returnValue((events, token))
-
     @defer.inlineCallbacks
     def get_room_events_max_id(self, direction='f'):
         token = yield self._stream_id_gen.get_max_token()
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 610ddad423..a02dfc7d58 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -49,9 +49,14 @@ class StreamIdGenerator(object):
         with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
-    def __init__(self, db_conn, table, column):
+    def __init__(self, db_conn, table, column, extra_tables=[]):
         self._lock = threading.Lock()
         self._current_max = _load_max_id(db_conn, table, column)
+        for table, column in extra_tables:
+            self._current_max = max(
+                self._current_max,
+                _load_max_id(db_conn, table, column)
+            )
         self._unfinished_ids = deque()
 
     def get_next(self):