summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xsynapse/app/homeserver.py33
-rw-r--r--synapse/handlers/sync.py346
-rw-r--r--synapse/server.py31
-rw-r--r--synapse/storage/__init__.py45
-rw-r--r--synapse/storage/_base.py49
-rw-r--r--synapse/storage/events.py17
-rw-r--r--synapse/storage/receipts.py67
-rw-r--r--synapse/storage/stream.py151
-rw-r--r--synapse/storage/tags.py7
-rw-r--r--synapse/storage/util/id_generators.py36
-rw-r--r--synapse/util/caches/room_change_cache.py86
-rw-r--r--tests/storage/test_appservice.py6
-rw-r--r--tests/storage/test_registration.py3
-rw-r--r--tests/utils.py8
14 files changed, 507 insertions, 378 deletions
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index a83d3f394f..56a34bd50b 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -254,6 +254,17 @@ class SynapseHomeServer(HomeServer):
         except IncorrectDatabaseSetup as e:
             quit_with_error(e.message)
 
+    def get_db_conn(self):
+        db_conn = self.database_engine.module.connect(
+            **{
+                k: v for k, v in self.db_config.get("args", {}).items()
+                if not k.startswith("cp_")
+            }
+        )
+
+        self.database_engine.on_new_connection(db_conn)
+        return db_conn
+
 
 def quit_with_error(error_string):
     message_lines = error_string.split("\n")
@@ -390,13 +401,7 @@ def setup(config_options):
     logger.info("Preparing database: %s...", config.database_config['name'])
 
     try:
-        db_conn = database_engine.module.connect(
-            **{
-                k: v for k, v in config.database_config.get("args", {}).items()
-                if not k.startswith("cp_")
-            }
-        )
-
+        db_conn = hs.get_db_conn()
         database_engine.prepare_database(db_conn)
         hs.run_startup_checks(db_conn, database_engine)
 
@@ -411,13 +416,17 @@ def setup(config_options):
 
     logger.info("Database prepared in %s.", config.database_config['name'])
 
+    hs.setup()
     hs.start_listening()
 
-    hs.get_pusherpool().start()
-    hs.get_state_handler().start_caching()
-    hs.get_datastore().start_profiling()
-    hs.get_datastore().start_doing_background_updates()
-    hs.get_replication_layer().start_get_pdu_cache()
+    def start():
+        hs.get_pusherpool().start()
+        hs.get_state_handler().start_caching()
+        hs.get_datastore().start_profiling()
+        hs.get_datastore().start_doing_background_updates()
+        hs.get_replication_layer().start_get_pdu_cache()
+
+    reactor.callWhenRunning(start)
 
     return hs
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 328c049b03..f5e20d6a6e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -72,7 +72,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
         )
 
 
-class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
+class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
     "room_id",            # str
     "timeline",           # TimelineBatch
     "state",              # dict[(str, str), FrozenEvent]
@@ -429,44 +429,20 @@ class SyncHandler(BaseHandler):
 
         defer.returnValue((now_token, ephemeral_by_room))
 
-    @defer.inlineCallbacks
     def full_state_sync_for_archived_room(self, room_id, sync_config,
                                           leave_event_id, leave_token,
                                           timeline_since_token, tags_by_room,
                                           account_data_by_room):
         """Sync a room for a client which is starting without any state
         Returns:
-            A Deferred JoinedSyncResult.
+            A Deferred ArchivedSyncResult.
         """
 
-        batch = yield self.load_filtered_recents(
-            room_id, sync_config, leave_token, since_token=timeline_since_token
-        )
-
-        leave_state = yield self.store.get_state_for_event(leave_event_id)
-
-        leave_state = {
-            (e.type, e.state_key): e
-            for e in sync_config.filter_collection.filter_room_state(
-                leave_state.values()
-            )
-        }
-
-        account_data = self.account_data_for_room(
-            room_id, tags_by_room, account_data_by_room
-        )
-
-        account_data = sync_config.filter_collection.filter_room_account_data(
-            account_data
+        return self.incremental_sync_for_archived_room(
+            sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room,
+            account_data_by_room, full_state=True, leave_token=leave_token,
         )
 
-        defer.returnValue(ArchivedSyncResult(
-            room_id=room_id,
-            timeline=batch,
-            state=leave_state,
-            account_data=account_data,
-        ))
-
     @defer.inlineCallbacks
     def incremental_sync_with_gap(self, sync_config, since_token):
         """ Get the incremental delta needed to bring the client up to
@@ -512,154 +488,127 @@ class SyncHandler(BaseHandler):
                 sync_config.user
             )
 
-        timeline_limit = sync_config.filter_collection.timeline_limit()
+        user_id = sync_config.user.to_string()
 
-        room_events, _ = yield self.store.get_room_events_stream(
-            sync_config.user.to_string(),
-            from_key=since_token.room_key,
-            to_key=now_token.room_key,
-            limit=timeline_limit + 1,
-        )
+        timeline_limit = sync_config.filter_collection.timeline_limit()
 
         tags_by_room = yield self.store.get_updated_tags(
-            sync_config.user.to_string(),
+            user_id,
             since_token.account_data_key,
         )
 
         account_data, account_data_by_room = (
             yield self.store.get_updated_account_data_for_user(
-                sync_config.user.to_string(),
+                user_id,
                 since_token.account_data_key,
             )
         )
 
-        joined = []
-        archived = []
-        if len(room_events) <= timeline_limit:
-            # There is no gap in any of the rooms. Therefore we can just
-            # partition the new events by room and return them.
-            logger.debug("Got %i events for incremental sync - not limited",
-                         len(room_events))
-
-            invite_events = []
-            leave_events = []
-            events_by_room_id = {}
-            for event in room_events:
-                events_by_room_id.setdefault(event.room_id, []).append(event)
-                if event.room_id not in joined_room_ids:
-                    if (event.type == EventTypes.Member
-                            and event.state_key == sync_config.user.to_string()):
-                        if event.membership == Membership.INVITE:
-                            invite_events.append(event)
-                        elif event.membership in (Membership.LEAVE, Membership.BAN):
-                            leave_events.append(event)
-
-            for room_id in joined_room_ids:
-                recents = events_by_room_id.get(room_id, [])
-                logger.debug("Events for room %s: %r", room_id, recents)
-                state = {
-                    (event.type, event.state_key): event
-                    for event in recents if event.is_state()}
-                limited = False
-
-                if recents:
-                    prev_batch = now_token.copy_and_replace(
-                        "room_key", recents[0].internal_metadata.before
-                    )
-                else:
-                    prev_batch = now_token
-
-                just_joined = yield self.check_joined_room(sync_config, state)
-                if just_joined:
-                    logger.debug("User has just joined %s: needs full state",
-                                 room_id)
-                    state = yield self.get_state_at(room_id, now_token)
-                    # the timeline is inherently limited if we've just joined
-                    limited = True
-
-                recents = sync_config.filter_collection.filter_room_timeline(recents)
-
-                state = {
-                    (e.type, e.state_key): e
-                    for e in sync_config.filter_collection.filter_room_state(
-                        state.values()
-                    )
-                }
-
-                acc_data = self.account_data_for_room(
-                    room_id, tags_by_room, account_data_by_room
-                )
+        # Get a list of membership change events that have happened.
+        rooms_changed = yield self.store.get_room_changes_for_user(
+            user_id, since_token.room_key, now_token.room_key
+        )
 
-                acc_data = sync_config.filter_collection.filter_room_account_data(
-                    acc_data
-                )
+        mem_change_events_by_room_id = {}
+        for event in rooms_changed:
+            mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
 
-                ephemeral = sync_config.filter_collection.filter_room_ephemeral(
-                    ephemeral_by_room.get(room_id, [])
-                )
+        newly_joined_rooms = []
+        archived = []
+        invited = []
+        for room_id, events in mem_change_events_by_room_id.items():
+            non_joins = [e for e in events if e.membership != Membership.JOIN]
+            has_join = len(non_joins) != len(events)
+
+            # We want to figure out if we joined the room at some point since
+            # the last sync (even if we have since left). This is to make sure
+            # we do send down the room, and with full state, where necessary
+            if room_id in joined_room_ids or has_join:
+                old_state = yield self.get_state_at(room_id, since_token)
+                old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
+                if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
+                        newly_joined_rooms.append(room_id)
+
+                if room_id in joined_room_ids:
+                    continue
+
+            if not non_joins:
+                continue
 
-                room_sync = JoinedSyncResult(
-                    room_id=room_id,
-                    timeline=TimelineBatch(
-                        events=recents,
-                        prev_batch=prev_batch,
-                        limited=limited,
-                    ),
-                    state=state,
-                    ephemeral=ephemeral,
-                    account_data=acc_data,
-                    unread_notifications={},
+            # Only bother if we're still currently invited
+            should_invite = non_joins[-1].membership == Membership.INVITE
+            if should_invite:
+                room_sync = InvitedSyncResult(room_id, invite=non_joins[-1])
+                if room_sync:
+                    invited.append(room_sync)
+
+            # Always include leave/ban events. Just take the last one.
+            # TODO: How do we handle ban -> leave in same batch?
+            leave_events = [
+                e for e in non_joins
+                if e.membership in (Membership.LEAVE, Membership.BAN)
+            ]
+
+            if leave_events:
+                leave_event = leave_events[-1]
+                room_sync = yield self.incremental_sync_for_archived_room(
+                    sync_config, room_id, leave_event.event_id, since_token,
+                    tags_by_room, account_data_by_room,
+                    full_state=room_id in newly_joined_rooms
                 )
-                logger.debug("Result for room %s: %r", room_id, room_sync)
-
                 if room_sync:
-                    notifs = yield self.unread_notifs_for_room_id(
-                        room_id, sync_config, all_ephemeral_by_room
-                    )
+                    archived.append(room_sync)
 
-                    if notifs is not None:
-                        notif_dict = room_sync.unread_notifications
-                        notif_dict["notification_count"] = len(notifs)
-                        notif_dict["highlight_count"] = len([
-                            1 for notif in notifs
-                            if _action_has_highlight(notif["actions"])
-                        ])
+        # Get all events for rooms we're currently joined to.
+        room_to_events = yield self.store.get_room_events_stream_for_rooms(
+            room_ids=joined_room_ids,
+            from_key=since_token.room_key,
+            to_key=now_token.room_key,
+            limit=timeline_limit + 1,
+        )
 
-                    joined.append(room_sync)
+        joined = []
+        # We loop through all room ids, even if there are no new events, in case
+        # there are non room events taht we need to notify about.
+        for room_id in joined_room_ids:
+            room_entry = room_to_events.get(room_id, None)
 
-        else:
-            logger.debug("Got %i events for incremental sync - hit limit",
-                         len(room_events))
+            if room_entry:
+                events, start_key = room_entry
 
-            invite_events = yield self.store.get_invites_for_user(
-                sync_config.user.to_string()
-            )
+                prev_batch_token = now_token.copy_and_replace("room_key", start_key)
 
-            leave_events = yield self.store.get_leave_and_ban_events_for_user(
-                sync_config.user.to_string()
-            )
+                newly_joined_room = room_id in newly_joined_rooms
+                full_state = newly_joined_room
 
-            for room_id in joined_room_ids:
-                room_sync = yield self.incremental_sync_with_gap_for_room(
-                    room_id, sync_config, since_token, now_token,
-                    ephemeral_by_room, tags_by_room, account_data_by_room,
-                    all_ephemeral_by_room=all_ephemeral_by_room,
+                batch = yield self.load_filtered_recents(
+                    room_id, sync_config, prev_batch_token,
+                    since_token=since_token,
+                    recents=events,
+                    newly_joined_room=newly_joined_room,
                 )
-                if room_sync:
-                    joined.append(room_sync)
+            else:
+                batch = TimelineBatch(
+                    events=[],
+                    prev_batch=since_token,
+                    limited=False,
+                )
+                full_state = False
 
-        for leave_event in leave_events:
-            room_sync = yield self.incremental_sync_for_archived_room(
-                sync_config, leave_event, since_token, tags_by_room,
-                account_data_by_room
+            room_sync = yield self.incremental_sync_with_gap_for_room(
+                room_id=room_id,
+                sync_config=sync_config,
+                since_token=since_token,
+                now_token=now_token,
+                ephemeral_by_room=ephemeral_by_room,
+                tags_by_room=tags_by_room,
+                account_data_by_room=account_data_by_room,
+                all_ephemeral_by_room=all_ephemeral_by_room,
+                batch=batch,
+                full_state=full_state,
             )
             if room_sync:
-                archived.append(room_sync)
-
-        invited = [
-            InvitedSyncResult(room_id=event.room_id, invite=event)
-            for event in invite_events
-        ]
+                joined.append(room_sync)
 
         account_data_for_user = sync_config.filter_collection.filter_account_data(
             self.account_data_for_user(account_data)
@@ -680,12 +629,10 @@ class SyncHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def load_filtered_recents(self, room_id, sync_config, now_token,
-                              since_token=None):
+                              since_token=None, recents=None, newly_joined_room=False):
         """
         :returns a Deferred TimelineBatch
         """
-        limited = True
-        recents = []
         filtering_factor = 2
         timeline_limit = sync_config.filter_collection.timeline_limit()
         load_limit = max(timeline_limit * filtering_factor, 100)
@@ -693,15 +640,29 @@ class SyncHandler(BaseHandler):
         room_key = now_token.room_key
         end_key = room_key
 
+        limited = recents is None or newly_joined_room or timeline_limit < len(recents)
+
+        if recents is not None:
+            recents = sync_config.filter_collection.filter_room_timeline(recents)
+            recents = yield self._filter_events_for_client(
+                sync_config.user.to_string(),
+                recents,
+                is_peeking=sync_config.is_guest,
+            )
+        else:
+            recents = []
+
+        since_key = None
+        if since_token and not newly_joined_room:
+            since_key = since_token.room_key
+
         while limited and len(recents) < timeline_limit and max_repeat:
-            events, keys = yield self.store.get_recent_events_for_room(
+            events, end_key = yield self.store.get_room_events_stream_for_room(
                 room_id,
                 limit=load_limit + 1,
-                from_token=since_token.room_key if since_token else None,
-                end_token=end_key,
+                from_key=since_key,
+                to_key=end_key,
             )
-            room_key, _ = keys
-            end_key = "s" + room_key.split('-')[-1]
             loaded_recents = sync_config.filter_collection.filter_room_timeline(events)
             loaded_recents = yield self._filter_events_for_client(
                 sync_config.user.to_string(),
@@ -710,8 +671,10 @@ class SyncHandler(BaseHandler):
             )
             loaded_recents.extend(recents)
             recents = loaded_recents
+
             if len(events) <= load_limit:
                 limited = False
+                break
             max_repeat -= 1
 
         if len(recents) > timeline_limit:
@@ -724,7 +687,9 @@ class SyncHandler(BaseHandler):
         )
 
         defer.returnValue(TimelineBatch(
-            events=recents, prev_batch=prev_batch_token, limited=limited
+            events=recents,
+            prev_batch=prev_batch_token,
+            limited=limited or newly_joined_room
         ))
 
     @defer.inlineCallbacks
@@ -732,24 +697,8 @@ class SyncHandler(BaseHandler):
                                            since_token, now_token,
                                            ephemeral_by_room, tags_by_room,
                                            account_data_by_room,
-                                           all_ephemeral_by_room):
-        """ Get the incremental delta needed to bring the client up to date for
-        the room. Gives the client the most recent events and the changes to
-        state.
-        Returns:
-            A Deferred JoinedSyncResult
-        """
-        logger.debug("Doing incremental sync for room %s between %s and %s",
-                     room_id, since_token, now_token)
-
-        # TODO(mjark): Check for redactions we might have missed.
-
-        batch = yield self.load_filtered_recents(
-            room_id, sync_config, now_token, since_token,
-        )
-
-        logger.debug("Recents %r", batch)
-
+                                           all_ephemeral_by_room,
+                                           batch, full_state=False):
         if batch.limited:
             current_state = yield self.get_state_at(room_id, now_token)
 
@@ -814,43 +763,48 @@ class SyncHandler(BaseHandler):
         defer.returnValue(room_sync)
 
     @defer.inlineCallbacks
-    def incremental_sync_for_archived_room(self, sync_config, leave_event,
+    def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id,
                                            since_token, tags_by_room,
-                                           account_data_by_room):
+                                           account_data_by_room, full_state,
+                                           leave_token=None):
         """ Get the incremental delta needed to bring the client up to date for
         the archived room.
         Returns:
             A Deferred ArchivedSyncResult
         """
 
-        stream_token = yield self.store.get_stream_token_for_event(
-            leave_event.event_id
-        )
+        if not leave_token:
+            stream_token = yield self.store.get_stream_token_for_event(
+                leave_event_id
+            )
 
-        leave_token = since_token.copy_and_replace("room_key", stream_token)
+            leave_token = since_token.copy_and_replace("room_key", stream_token)
 
-        if since_token.is_after(leave_token):
+        if since_token and since_token.is_after(leave_token):
             defer.returnValue(None)
 
         batch = yield self.load_filtered_recents(
-            leave_event.room_id, sync_config, leave_token, since_token,
+            room_id, sync_config, leave_token, since_token,
         )
 
         logger.debug("Recents %r", batch)
 
         state_events_at_leave = yield self.store.get_state_for_event(
-            leave_event.event_id
+            leave_event_id
         )
 
-        state_at_previous_sync = yield self.get_state_at(
-            leave_event.room_id, stream_position=since_token
-        )
+        if not full_state:
+            state_at_previous_sync = yield self.get_state_at(
+                room_id, stream_position=since_token
+            )
 
-        state_events_delta = yield self.compute_state_delta(
-            since_token=since_token,
-            previous_state=state_at_previous_sync,
-            current_state=state_events_at_leave,
-        )
+            state_events_delta = yield self.compute_state_delta(
+                since_token=since_token,
+                previous_state=state_at_previous_sync,
+                current_state=state_events_at_leave,
+            )
+        else:
+            state_events_delta = state_events_at_leave
 
         state_events_delta = {
             (e.type, e.state_key): e
@@ -860,7 +814,7 @@ class SyncHandler(BaseHandler):
         }
 
         account_data = self.account_data_for_room(
-            leave_event.room_id, tags_by_room, account_data_by_room
+            room_id, tags_by_room, account_data_by_room
         )
 
         account_data = sync_config.filter_collection.filter_room_account_data(
@@ -868,7 +822,7 @@ class SyncHandler(BaseHandler):
         )
 
         room_sync = ArchivedSyncResult(
-            room_id=leave_event.room_id,
+            room_id=room_id,
             timeline=batch,
             state=state_events_delta,
             account_data=account_data,
diff --git a/synapse/server.py b/synapse/server.py
index a59e46ca2d..e013a349c9 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -28,7 +28,7 @@ from synapse.notifier import Notifier
 from synapse.api.auth import Auth
 from synapse.handlers import Handlers
 from synapse.state import StateHandler
-from synapse.storage import DataStore
+from synapse.storage import get_datastore
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
 from synapse.streams.events import EventSources
@@ -40,6 +40,11 @@ from synapse.api.filtering import Filtering
 
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 
+import logging
+
+
+logger = logging.getLogger(__name__)
+
 
 class HomeServer(object):
     """A basic homeserver object without lazy component builders.
@@ -102,10 +107,19 @@ class HomeServer(object):
         self.hostname = hostname
         self._building = {}
 
+        self.clock = Clock()
+        self.distributor = Distributor()
+        self.ratelimiter = Ratelimiter()
+
         # Other kwargs are explicit dependencies
         for depname in kwargs:
             setattr(self, depname, kwargs[depname])
 
+    def setup(self):
+        logger.info("Setting up.")
+        self.datastore = get_datastore(self)
+        logger.info("Finished setting up.")
+
     def get_ip_from_request(self, request):
         # X-Forwarded-For is handled by our custom request type.
         return request.getClientIP()
@@ -116,15 +130,9 @@ class HomeServer(object):
     def is_mine_id(self, string):
         return string.split(":", 1)[1] == self.hostname
 
-    def build_clock(self):
-        return Clock()
-
     def build_replication_layer(self):
         return initialize_http_replication(self)
 
-    def build_datastore(self):
-        return DataStore(self)
-
     def build_handlers(self):
         return Handlers(self)
 
@@ -135,10 +143,9 @@ class HomeServer(object):
         return Auth(self)
 
     def build_http_client_context_factory(self):
-        config = self.get_config()
         return (
             InsecureInterceptableContextFactory()
-            if config.use_insecure_ssl_client_just_for_testing_do_not_use
+            if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
             else BrowserLikePolicyForHTTPS()
         )
 
@@ -157,15 +164,9 @@ class HomeServer(object):
     def build_state_handler(self):
         return StateHandler(self)
 
-    def build_distributor(self):
-        return Distributor()
-
     def build_event_sources(self):
         return EventSources(self)
 
-    def build_ratelimiter(self):
-        return Ratelimiter()
-
     def build_keyring(self):
         return Keyring(self)
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7a3f6c4662..c8cab45f77 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -46,6 +46,9 @@ from .tags import TagsStore
 from .account_data import AccountDataStore
 
 
+from util.id_generators import IdGenerator, StreamIdGenerator
+
+
 import logging
 
 
@@ -58,6 +61,22 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120*1000
 
 
+def get_datastore(hs):
+    logger.info("getting called!")
+
+    conn = hs.get_db_conn()
+    try:
+        cur = conn.cursor()
+        cur.execute("SELECT MIN(stream_ordering) FROM events",)
+        rows = cur.fetchall()
+        min_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
+        min_token = min(min_token, -1)
+
+        return DataStore(conn, hs, min_token)
+    finally:
+        conn.close()
+
+
 class DataStore(RoomMemberStore, RoomStore,
                 RegistrationStore, StreamStore, ProfileStore,
                 PresenceStore, TransactionStore,
@@ -79,18 +98,36 @@ class DataStore(RoomMemberStore, RoomStore,
                 EventPushActionsStore
                 ):
 
-    def __init__(self, hs):
-        super(DataStore, self).__init__(hs)
+    def __init__(self, db_conn, hs, min_stream_token):
         self.hs = hs
 
-        self.min_token_deferred = self._get_min_token()
-        self.min_token = None
+        self.min_stream_token = min_stream_token
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
         )
 
+        self._stream_id_gen = StreamIdGenerator(
+            db_conn, "events", "stream_ordering"
+        )
+        self._receipts_id_gen = StreamIdGenerator(
+            db_conn, "receipts_linearized", "stream_id"
+        )
+        self._account_data_id_gen = StreamIdGenerator(
+            db_conn, "account_data_max_stream_id", "stream_id"
+        )
+
+        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
+        self._pushers_id_gen = IdGenerator("pushers", "id", self)
+        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+
+        super(DataStore, self).__init__(hs)
+
     @defer.inlineCallbacks
     def insert_client_ip(self, user, access_token, ip, user_agent):
         now = int(self._clock.time_msec())
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 90d7aee94a..5e77320540 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,13 +15,11 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
 import synapse.metrics
 
-from util.id_generators import IdGenerator, StreamIdGenerator
 
 from twisted.internet import defer
 
@@ -175,16 +173,6 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
-        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
-        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
-        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
-        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
-        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
-        self._pushers_id_gen = IdGenerator("pushers", "id", self)
-        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
-        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
-        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
-
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
 
@@ -345,7 +333,8 @@ class SQLBaseStore(object):
 
         defer.returnValue(result)
 
-    def cursor_to_dict(self, cursor):
+    @staticmethod
+    def cursor_to_dict(cursor):
         """Converts a SQL cursor into an list of dicts.
 
         Args:
@@ -402,8 +391,8 @@ class SQLBaseStore(object):
             if not or_ignore:
                 raise
 
-    @log_function
-    def _simple_insert_txn(self, txn, table, values):
+    @staticmethod
+    def _simple_insert_txn(txn, table, values):
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -414,7 +403,8 @@ class SQLBaseStore(object):
 
         txn.execute(sql, vals)
 
-    def _simple_insert_many_txn(self, txn, table, values):
+    @staticmethod
+    def _simple_insert_many_txn(txn, table, values):
         if not values:
             return
 
@@ -537,9 +527,10 @@ class SQLBaseStore(object):
             table, keyvalues, retcol, allow_none=allow_none,
         )
 
-    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+    @classmethod
+    def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
                                       allow_none=False):
-        ret = self._simple_select_onecol_txn(
+        ret = cls._simple_select_onecol_txn(
             txn,
             table=table,
             keyvalues=keyvalues,
@@ -554,7 +545,8 @@ class SQLBaseStore(object):
             else:
                 raise StoreError(404, "No row found")
 
-    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+    @staticmethod
+    def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
         sql = (
             "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
         ) % {
@@ -603,7 +595,8 @@ class SQLBaseStore(object):
             table, keyvalues, retcols
         )
 
-    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+    @classmethod
+    def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -627,7 +620,7 @@ class SQLBaseStore(object):
             )
             txn.execute(sql)
 
-        return self.cursor_to_dict(txn)
+        return cls.cursor_to_dict(txn)
 
     @defer.inlineCallbacks
     def _simple_select_many_batch(self, table, column, iterable, retcols,
@@ -662,7 +655,8 @@ class SQLBaseStore(object):
 
         defer.returnValue(results)
 
-    def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+    @classmethod
+    def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -699,7 +693,7 @@ class SQLBaseStore(object):
             )
 
         txn.execute(sql, values)
-        return self.cursor_to_dict(txn)
+        return cls.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            desc="_simple_update_one"):
@@ -726,7 +720,8 @@ class SQLBaseStore(object):
             table, keyvalues, updatevalues,
         )
 
-    def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+    @staticmethod
+    def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
         update_sql = "UPDATE %s SET %s WHERE %s" % (
             table,
             ", ".join("%s = ?" % (k,) for k in updatevalues),
@@ -743,7 +738,8 @@ class SQLBaseStore(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched")
 
-    def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+    @staticmethod
+    def _simple_select_one_txn(txn, table, keyvalues, retcols,
                                allow_none=False):
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
@@ -784,7 +780,8 @@ class SQLBaseStore(object):
                 raise StoreError(500, "more than one row matched")
         return self.runInteraction(desc, func)
 
-    def _simple_delete_txn(self, txn, table, keyvalues):
+    @staticmethod
+    def _simple_delete_txn(txn, table, keyvalues):
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ba368a3eca..80187722ea 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
             return
 
         if backfilled:
-            if not self.min_token_deferred.called:
-                yield self.min_token_deferred
-            start = self.min_token - 1
-            self.min_token -= len(events_and_contexts) + 1
-            stream_orderings = range(start, self.min_token, -1)
+            start = self.min_stream_token - 1
+            self.min_stream_token -= len(events_and_contexts) + 1
+            stream_orderings = range(start, self.min_stream_token, -1)
 
             @contextmanager
             def stream_ordering_manager():
@@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
                       is_new_state=True, current_state=None):
         stream_ordering = None
         if backfilled:
-            if not self.min_token_deferred.called:
-                yield self.min_token_deferred
-            self.min_token -= 1
-            stream_ordering = self.min_token
+            self.min_stream_token -= 1
+            stream_ordering = self.min_stream_token
 
         if stream_ordering is None:
             stream_ordering_manager = yield self._stream_id_gen.get_next(self)
@@ -132,6 +128,9 @@ class EventsStore(SQLBaseStore):
                     is_new_state=is_new_state,
                     current_state=current_state,
                 )
+                self._events_stream_cache.room_has_changed(
+                    None, event.room_id, stream_ordering
+                )
         except _RollbackButIsFineException:
             pass
 
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c4232bdc65..7118368d97 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -15,11 +15,10 @@
 
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches.room_change_cache import RoomStreamChangeCache
 
 from twisted.internet import defer
 
-from blist import sorteddict
 import logging
 import ujson as json
 
@@ -31,7 +30,9 @@ class ReceiptsStore(SQLBaseStore):
     def __init__(self, hs):
         super(ReceiptsStore, self).__init__(hs)
 
-        self._receipts_stream_cache = _RoomStreamChangeCache()
+        self._receipts_stream_cache = RoomStreamChangeCache(
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
+        )
 
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
@@ -368,63 +369,3 @@ class ReceiptsStore(SQLBaseStore):
                 "data": json.dumps(data),
             }
         )
-
-
-class _RoomStreamChangeCache(object):
-    """Keeps track of the stream_id of the latest change in rooms.
-
-    Given a list of rooms and stream key, it will give a subset of rooms that
-    may have changed since that key. If the key is too old then the cache
-    will simply return all rooms.
-    """
-    def __init__(self, size_of_cache=10000):
-        self._size_of_cache = size_of_cache
-        self._room_to_key = {}
-        self._cache = sorteddict()
-        self._earliest_key = None
-        self.name = "ReceiptsRoomChangeCache"
-        caches_by_name[self.name] = self._cache
-
-    @defer.inlineCallbacks
-    def get_rooms_changed(self, store, room_ids, key):
-        """Returns subset of room ids that have had new receipts since the
-        given key. If the key is too old it will just return the given list.
-        """
-        if key > (yield self._get_earliest_key(store)):
-            keys = self._cache.keys()
-            i = keys.bisect_right(key)
-
-            result = set(
-                self._cache[k] for k in keys[i:]
-            ).intersection(room_ids)
-
-            cache_counter.inc_hits(self.name)
-        else:
-            result = room_ids
-            cache_counter.inc_misses(self.name)
-
-        defer.returnValue(result)
-
-    @defer.inlineCallbacks
-    def room_has_changed(self, store, room_id, key):
-        """Informs the cache that the room has been changed at the given key.
-        """
-        if key > (yield self._get_earliest_key(store)):
-            old_key = self._room_to_key.get(room_id, None)
-            if old_key:
-                key = max(key, old_key)
-                self._cache.pop(old_key, None)
-            self._cache[key] = room_id
-
-            while len(self._cache) > self._size_of_cache:
-                k, r = self._cache.popitem()
-                self._earliest_key = max(k, self._earliest_key)
-                self._room_to_key.pop(r, None)
-
-    @defer.inlineCallbacks
-    def _get_earliest_key(self, store):
-        if self._earliest_key is None:
-            self._earliest_key = yield store.get_max_receipt_stream_id()
-            self._earliest_key = int(self._earliest_key)
-
-        defer.returnValue(self._earliest_key)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 02b1913e26..0b22251790 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,6 +37,7 @@ from twisted.internet import defer
 
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.room_change_cache import RoomStreamChangeCache
 from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function
@@ -77,6 +78,12 @@ def upper_bound(token):
 
 
 class StreamStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(StreamStore, self).__init__(hs)
+
+        self._events_stream_cache = RoomStreamChangeCache(
+            "EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
+        )
 
     @defer.inlineCallbacks
     def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@@ -157,6 +164,134 @@ class StreamStore(SQLBaseStore):
         results = yield self.runInteraction("get_appservice_room_stream", f)
         defer.returnValue(results)
 
+    @defer.inlineCallbacks
+    def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0):
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+
+        room_ids = yield self._events_stream_cache.get_rooms_changed(
+            self, room_ids, from_id
+        )
+
+        if not room_ids:
+            defer.returnValue({})
+
+        results = {}
+        room_ids = list(room_ids)
+        for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
+            res = yield defer.gatherResults([
+                self.get_room_events_stream_for_room(
+                    room_id, from_key, to_key, limit
+                ).addCallback(lambda r, rm: (rm, r), room_id)
+                for room_id in room_ids
+            ])
+            results.update(dict(res))
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
+    def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0):
+        if from_key is not None:
+            from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        else:
+            from_id = None
+        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+        if from_key == to_key:
+            defer.returnValue(([], from_key))
+
+        has_changed = yield self._events_stream_cache.get_room_has_changed(
+            room_id, from_id
+        )
+
+        if not has_changed:
+            defer.returnValue(([], from_key))
+
+        def f(txn):
+            if from_id is not None:
+                sql = (
+                    "SELECT event_id, stream_ordering FROM events WHERE"
+                    " room_id = ?"
+                    " AND not outlier"
+                    " AND stream_ordering > ? AND stream_ordering <= ?"
+                    " ORDER BY stream_ordering DESC LIMIT ?"
+                )
+                txn.execute(sql, (room_id, from_id, to_id, limit))
+            else:
+                sql = (
+                    "SELECT event_id, stream_ordering FROM events WHERE"
+                    " room_id = ?"
+                    " AND not outlier"
+                    " AND stream_ordering <= ?"
+                    " ORDER BY stream_ordering DESC LIMIT ?"
+                )
+                txn.execute(sql, (room_id, to_id, limit))
+
+            rows = self.cursor_to_dict(txn)
+
+            ret = self._get_events_txn(
+                txn,
+                [r["event_id"] for r in rows],
+                get_prev_content=True
+            )
+
+            ret.reverse()
+
+            self._set_before_and_after(ret, rows)
+
+            if rows:
+                key = "s%d" % min(r["stream_ordering"] for r in rows)
+            else:
+                # Assume we didn't get anything because there was nothing to
+                # get.
+                key = from_key
+
+            return ret, key
+        res = yield self.runInteraction("get_room_events_stream_for_room", f)
+        defer.returnValue(res)
+
+    def get_room_changes_for_user(self, user_id, from_key, to_key):
+        if from_key is not None:
+            from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        else:
+            from_id = None
+        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+        if from_key == to_key:
+            return defer.succeed([])
+
+        def f(txn):
+            if from_id is not None:
+                sql = (
+                    "SELECT m.event_id, stream_ordering FROM events AS e,"
+                    " room_memberships AS m"
+                    " WHERE e.event_id = m.event_id"
+                    " AND m.user_id = ?"
+                    " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+                    " ORDER BY e.stream_ordering ASC"
+                )
+                txn.execute(sql, (user_id, from_id, to_id,))
+            else:
+                sql = (
+                    "SELECT m.event_id, stream_ordering FROM events AS e,"
+                    " room_memberships AS m"
+                    " WHERE e.event_id = m.event_id"
+                    " AND m.user_id = ?"
+                    " AND stream_ordering <= ?"
+                    " ORDER BY stream_ordering ASC"
+                )
+                txn.execute(sql, (user_id, to_id,))
+            rows = self.cursor_to_dict(txn)
+
+            ret = self._get_events_txn(
+                txn,
+                [r["event_id"] for r in rows],
+                get_prev_content=True
+            )
+
+            return ret
+
+        return self.runInteraction("get_room_changes_for_user", f)
+
     @log_function
     def get_room_events_stream(
         self,
@@ -174,7 +309,8 @@ class StreamStore(SQLBaseStore):
                 "SELECT c.room_id FROM history_visibility AS h"
                 " INNER JOIN current_state_events AS c"
                 " ON h.event_id = c.event_id"
-                " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
+                " WHERE c.room_id IN (%s)"
+                " AND h.history_visibility = 'world_readable'" % (
                     ",".join(map(lambda _: "?", room_ids))
                 )
             )
@@ -444,19 +580,6 @@ class StreamStore(SQLBaseStore):
         rows = txn.fetchall()
         return rows[0][0] if rows else 0
 
-    @defer.inlineCallbacks
-    def _get_min_token(self):
-        row = yield self._execute(
-            "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
-        )
-
-        self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
-        self.min_token = min(self.min_token, -1)
-
-        logger.debug("min_token is: %s", self.min_token)
-
-        defer.returnValue(self.min_token)
-
     @staticmethod
     def _set_before_and_after(events, rows):
         for event, row in zip(events, rows):
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index ed9c91e5ea..4c39e07cbd 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -16,7 +16,6 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached
 from twisted.internet import defer
-from .util.id_generators import StreamIdGenerator
 
 import ujson as json
 import logging
@@ -25,12 +24,6 @@ logger = logging.getLogger(__name__)
 
 
 class TagsStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(TagsStore, self).__init__(hs)
-
-        self._account_data_id_gen = StreamIdGenerator(
-            "account_data_max_stream_id", "stream_id"
-        )
 
     def get_max_account_data_stream_id(self):
         """Get the current max stream id for the private user data stream
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f58bf7fd2c..5c522f4ab9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,28 +72,24 @@ class StreamIdGenerator(object):
         with stream_id_gen.get_next_txn(txn) as stream_id:
             # ... persist event ...
     """
-    def __init__(self, table, column):
+    def __init__(self, db_conn, table, column):
         self.table = table
         self.column = column
 
         self._lock = threading.Lock()
 
-        self._current_max = None
+        cur = db_conn.cursor()
+        self._current_max = self._get_or_compute_current_max(cur)
+        cur.close()
+
         self._unfinished_ids = deque()
 
-    @defer.inlineCallbacks
     def get_next(self, store):
         """
         Usage:
             with yield stream_id_gen.get_next as stream_id:
                 # ... persist event ...
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             self._current_max += 1
             next_id = self._current_max
@@ -108,21 +104,14 @@ class StreamIdGenerator(object):
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        defer.returnValue(manager())
+        return manager()
 
-    @defer.inlineCallbacks
     def get_next_mult(self, store, n):
         """
         Usage:
             with yield stream_id_gen.get_next(store, n) as stream_ids:
                 # ... persist events ...
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             next_ids = range(self._current_max + 1, self._current_max + n + 1)
             self._current_max += n
@@ -139,24 +128,17 @@ class StreamIdGenerator(object):
                     for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
 
-        defer.returnValue(manager())
+        return manager()
 
-    @defer.inlineCallbacks
     def get_max_token(self, store):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             if self._unfinished_ids:
-                defer.returnValue(self._unfinished_ids[0] - 1)
+                return self._unfinished_ids[0] - 1
 
-            defer.returnValue(self._current_max)
+            return self._current_max
 
     def _get_or_compute_current_max(self, txn):
         with self._lock:
diff --git a/synapse/util/caches/room_change_cache.py b/synapse/util/caches/room_change_cache.py
new file mode 100644
index 0000000000..3a873c9c30
--- /dev/null
+++ b/synapse/util/caches/room_change_cache.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches import cache_counter, caches_by_name
+
+
+from blist import sorteddict
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class RoomStreamChangeCache(object):
+    """Keeps track of the stream_id of the latest change in rooms.
+
+    Given a list of rooms and stream key, it will give a subset of rooms that
+    may have changed since that key. If the key is too old then the cache
+    will simply return all rooms.
+    """
+    def __init__(self, name, current_key, size_of_cache=10000):
+        self._size_of_cache = size_of_cache
+        self._room_to_key = {}
+        self._cache = sorteddict()
+        self._earliest_known_key = current_key
+        self.name = name
+        caches_by_name[self.name] = self._cache
+
+    def get_room_has_changed(self, room_id, key):
+        if key <= self._earliest_known_key:
+            return True
+
+        room_key = self._room_to_key.get(room_id, None)
+        if room_key is None:
+            return True
+
+        if key < room_key:
+            return True
+
+        return False
+
+    def get_rooms_changed(self, store, room_ids, key):
+        """Returns subset of room ids that have had new things since the
+        given key. If the key is too old it will just return the given list.
+        """
+        if key > self._earliest_known_key:
+            keys = self._cache.keys()
+            i = keys.bisect_right(key)
+
+            result = set(
+                self._cache[k] for k in keys[i:]
+            ).intersection(room_ids)
+
+            cache_counter.inc_hits(self.name)
+        else:
+            result = room_ids
+            cache_counter.inc_misses(self.name)
+
+        return result
+
+    def room_has_changed(self, store, room_id, key):
+        """Informs the cache that the room has been changed at the given key.
+        """
+        if key > self._earliest_known_key:
+            old_key = self._room_to_key.get(room_id, None)
+            if old_key:
+                key = max(key, old_key)
+                self._cache.pop(old_key, None)
+            self._cache[key] = room_id
+
+            while len(self._cache) > self._size_of_cache:
+                k, r = self._cache.popitem()
+                self._earliest_key = max(k, self._earliest_key)
+                self._room_to_key.pop(r, None)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 5abecdf6e0..ed8af10d87 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -439,7 +439,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f2 = self._write_config(suffix="2")
 
         config = Mock(app_service_config_files=[f1, f2])
-        hs = yield setup_test_homeserver(config=config)
+        hs = yield setup_test_homeserver(config=config, datastore=Mock())
 
         ApplicationServiceStore(hs)
 
@@ -449,7 +449,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f2 = self._write_config(id="id", suffix="2")
 
         config = Mock(app_service_config_files=[f1, f2])
-        hs = yield setup_test_homeserver(config=config)
+        hs = yield setup_test_homeserver(config=config, datastore=Mock())
 
         with self.assertRaises(ConfigError) as cm:
             ApplicationServiceStore(hs)
@@ -465,7 +465,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f2 = self._write_config(as_token="as_token", suffix="2")
 
         config = Mock(app_service_config_files=[f1, f2])
-        hs = yield setup_test_homeserver(config=config)
+        hs = yield setup_test_homeserver(config=config, datastore=Mock())
 
         with self.assertRaises(ConfigError) as cm:
             ApplicationServiceStore(hs)
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index a35efcc71e..7b3b4c13bc 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -18,7 +18,6 @@ from tests import unittest
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.storage.registration import RegistrationStore
 from synapse.util import stringutils
 
 from tests.utils import setup_test_homeserver
@@ -31,7 +30,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
         hs = yield setup_test_homeserver()
         self.db_pool = hs.get_db_pool()
 
-        self.store = RegistrationStore(hs)
+        self.store = hs.get_datastore()
 
         self.user_id = "@my-user:test"
         self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz",
diff --git a/tests/utils.py b/tests/utils.py
index d75d492cb5..43cc2b30cd 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -60,8 +60,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
             name, db_pool=db_pool, config=config,
             version_string="Synapse/tests",
             database_engine=create_engine("sqlite3"),
+            get_db_conn=db_pool.get_db_conn,
             **kargs
         )
+        hs.setup()
     else:
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
@@ -280,6 +282,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
             lambda conn: prepare_database(conn, engine)
         )
 
+    def get_db_conn(self):
+        conn = self.connect()
+        engine = create_engine("sqlite3")
+        prepare_database(conn, engine)
+        return conn
+
 
 class MemoryDataStore(object):