summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/admin.py7
-rw-r--r--synapse/handlers/device.py3
-rw-r--r--synapse/handlers/events.py6
-rw-r--r--synapse/handlers/federation.py19
-rw-r--r--synapse/handlers/initial_sync.py14
-rw-r--r--synapse/handlers/message.py10
-rw-r--r--synapse/handlers/pagination.py6
-rw-r--r--synapse/handlers/room.py6
-rw-r--r--synapse/handlers/search.py12
-rw-r--r--synapse/handlers/sync.py20
-rw-r--r--synapse/notifier.py6
-rw-r--r--synapse/push/httppusher.py3
-rw-r--r--synapse/push/mailer.py3
-rw-r--r--synapse/push/push_tools.py9
-rw-r--r--synapse/state/__init__.py13
-rw-r--r--synapse/visibility.py30
-rw-r--r--tests/storage/test_state.py150
-rw-r--r--tests/test_state.py3
-rw-r--r--tests/test_visibility.py11
19 files changed, 216 insertions, 115 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 1a87b58838..6407d56f8e 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -30,6 +30,9 @@ class AdminHandler(BaseHandler):
     def __init__(self, hs):
         super(AdminHandler, self).__init__(hs)
 
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
+
     @defer.inlineCallbacks
     def get_whois(self, user):
         connections = []
@@ -205,7 +208,7 @@ class AdminHandler(BaseHandler):
 
                 from_key = events[-1].internal_metadata.after
 
-                events = yield filter_events_for_client(self.store, user_id, events)
+                events = yield filter_events_for_client(self.storage, user_id, events)
 
                 writer.write_events(room_id, events)
 
@@ -241,7 +244,7 @@ class AdminHandler(BaseHandler):
             for event_id in extremities:
                 if not event_to_unseen_prevs[event_id]:
                     continue
-                state = yield self.store.get_state_for_event(event_id)
+                state = yield self.state_store.get_state_for_event(event_id)
                 writer.write_state(room_id, event_id, state)
 
         return writer.finished()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5f23ee4488..b3fd7e6249 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -46,6 +46,7 @@ class DeviceWorkerHandler(BaseHandler):
 
         self.hs = hs
         self.state = hs.get_state_handler()
+        self.state_store = hs.get_storage().state
         self._auth_handler = hs.get_auth_handler()
 
     @trace
@@ -178,7 +179,7 @@ class DeviceWorkerHandler(BaseHandler):
                 continue
 
             # mapping from event_id -> state_dict
-            prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+            prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
 
             # Check if we've joined the room? If so we just blindly add all the users to
             # the "possibly changed" users.
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 5e748687e3..45fe13c62f 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -147,6 +147,10 @@ class EventStreamHandler(BaseHandler):
 
 
 class EventHandler(BaseHandler):
+    def __init__(self, hs):
+        super(EventHandler, self).__init__(hs)
+        self.storage = hs.get_storage()
+
     @defer.inlineCallbacks
     def get_event(self, user, room_id, event_id):
         """Retrieve a single specified event.
@@ -172,7 +176,7 @@ class EventHandler(BaseHandler):
         is_peeking = user.to_string() not in users
 
         filtered = yield filter_events_for_client(
-            self.store, user.to_string(), [event], is_peeking=is_peeking
+            self.storage, user.to_string(), [event], is_peeking=is_peeking
         )
 
         if not filtered:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 08276fdebf..4d9e33346d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -110,6 +110,7 @@ class FederationHandler(BaseHandler):
 
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
+        self.state_store = self.storage.state
         self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
@@ -325,7 +326,7 @@ class FederationHandler(BaseHandler):
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.store.get_state_groups_ids(room_id, seen)
+                    ours = yield self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
                     state_maps = list(
@@ -889,7 +890,7 @@ class FederationHandler(BaseHandler):
         # We set `check_history_visibility_only` as we might otherwise get false
         # positives from users having been erased.
         filtered_extremities = yield filter_events_for_server(
-            self.store,
+            self.storage,
             self.server_name,
             list(extremities_events.values()),
             redact=False,
@@ -1550,7 +1551,7 @@ class FederationHandler(BaseHandler):
             event_id, allow_none=False, check_room_id=room_id
         )
 
-        state_groups = yield self.store.get_state_groups(room_id, [event_id])
+        state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
 
         if state_groups:
             _, state = list(iteritems(state_groups)).pop()
@@ -1579,7 +1580,7 @@ class FederationHandler(BaseHandler):
             event_id, allow_none=False, check_room_id=room_id
         )
 
-        state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
+        state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
 
         if state_groups:
             _, state = list(state_groups.items()).pop()
@@ -1607,7 +1608,7 @@ class FederationHandler(BaseHandler):
 
         events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
 
-        events = yield filter_events_for_server(self.store, origin, events)
+        events = yield filter_events_for_server(self.storage, origin, events)
 
         return events
 
@@ -1637,7 +1638,7 @@ class FederationHandler(BaseHandler):
             if not in_room:
                 raise AuthError(403, "Host not in room.")
 
-            events = yield filter_events_for_server(self.store, origin, [event])
+            events = yield filter_events_for_server(self.storage, origin, [event])
             event = events[0]
             return event
         else:
@@ -1903,7 +1904,7 @@ class FederationHandler(BaseHandler):
                 # given state at the event. This should correctly handle cases
                 # like bans, especially with state res v2.
 
-                state_sets = yield self.store.get_state_groups(
+                state_sets = yield self.state_store.get_state_groups(
                     event.room_id, extrem_ids
                 )
                 state_sets = list(state_sets.values())
@@ -1994,7 +1995,7 @@ class FederationHandler(BaseHandler):
         )
 
         missing_events = yield filter_events_for_server(
-            self.store, origin, missing_events
+            self.storage, origin, missing_events
         )
 
         return missing_events
@@ -2235,7 +2236,7 @@ class FederationHandler(BaseHandler):
 
         # create a new state group as a delta from the existing one.
         prev_group = context.state_group
-        state_group = yield self.store.store_state_group(
+        state_group = yield self.state_store.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=prev_group,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index f991efeee3..49c9e031f9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -43,6 +43,8 @@ class InitialSyncHandler(BaseHandler):
         self.validator = EventValidator()
         self.snapshot_cache = SnapshotCache()
         self._event_serializer = hs.get_event_client_serializer()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
 
     def snapshot_all_rooms(
         self,
@@ -169,7 +171,7 @@ class InitialSyncHandler(BaseHandler):
                 elif event.membership == Membership.LEAVE:
                     room_end_token = "s%d" % (event.stream_ordering,)
                     deferred_room_state = run_in_background(
-                        self.store.get_state_for_events, [event.event_id]
+                        self.state_store.get_state_for_events, [event.event_id]
                     )
                     deferred_room_state.addCallback(
                         lambda states: states[event.event_id]
@@ -189,7 +191,9 @@ class InitialSyncHandler(BaseHandler):
                     )
                 ).addErrback(unwrapFirstError)
 
-                messages = yield filter_events_for_client(self.store, user_id, messages)
+                messages = yield filter_events_for_client(
+                    self.storage, user_id, messages
+                )
 
                 start_token = now_token.copy_and_replace("room_key", token)
                 end_token = now_token.copy_and_replace("room_key", room_end_token)
@@ -307,7 +311,7 @@ class InitialSyncHandler(BaseHandler):
     def _room_initial_sync_parted(
         self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
     ):
-        room_state = yield self.store.get_state_for_events([member_event_id])
+        room_state = yield self.state_store.get_state_for_events([member_event_id])
 
         room_state = room_state[member_event_id]
 
@@ -322,7 +326,7 @@ class InitialSyncHandler(BaseHandler):
         )
 
         messages = yield filter_events_for_client(
-            self.store, user_id, messages, is_peeking=is_peeking
+            self.storage, user_id, messages, is_peeking=is_peeking
         )
 
         start_token = StreamToken.START.copy_and_replace("room_key", token)
@@ -414,7 +418,7 @@ class InitialSyncHandler(BaseHandler):
         )
 
         messages = yield filter_events_for_client(
-            self.store, user_id, messages, is_peeking=is_peeking
+            self.storage, user_id, messages, is_peeking=is_peeking
         )
 
         start_token = now_token.copy_and_replace("room_key", token)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7908a2d52c..6e2a360262 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -59,6 +59,8 @@ class MessageHandler(object):
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
         self._event_serializer = hs.get_event_client_serializer()
 
     @defer.inlineCallbacks
@@ -82,7 +84,7 @@ class MessageHandler(object):
             data = yield self.state.get_current_state(room_id, event_type, state_key)
         elif membership == Membership.LEAVE:
             key = (event_type, state_key)
-            room_state = yield self.store.get_state_for_events(
+            room_state = yield self.state_store.get_state_for_events(
                 [membership_event_id], StateFilter.from_types([key])
             )
             data = room_state[membership_event_id].get(key)
@@ -135,12 +137,12 @@ class MessageHandler(object):
                 raise NotFoundError("Can't find event for token %s" % (at_token,))
 
             visible_events = yield filter_events_for_client(
-                self.store, user_id, last_events
+                self.storage, user_id, last_events
             )
 
             event = last_events[0]
             if visible_events:
-                room_state = yield self.store.get_state_for_events(
+                room_state = yield self.state_store.get_state_for_events(
                     [event.event_id], state_filter=state_filter
                 )
                 room_state = room_state[event.event_id]
@@ -161,7 +163,7 @@ class MessageHandler(object):
                 )
                 room_state = yield self.store.get_events(state_ids.values())
             elif membership == Membership.LEAVE:
-                room_state = yield self.store.get_state_for_events(
+                room_state = yield self.state_store.get_state_for_events(
                     [membership_event_id], state_filter=state_filter
                 )
                 room_state = room_state[membership_event_id]
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5744f4579d..b7185fe7a0 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -69,6 +69,8 @@ class PaginationHandler(object):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
         self.clock = hs.get_clock()
         self._server_name = hs.hostname
 
@@ -255,7 +257,7 @@ class PaginationHandler(object):
                 events = event_filter.filter(events)
 
             events = yield filter_events_for_client(
-                self.store, user_id, events, is_peeking=(member_event_id is None)
+                self.storage, user_id, events, is_peeking=(member_event_id is None)
             )
 
         if not events:
@@ -274,7 +276,7 @@ class PaginationHandler(object):
                 (EventTypes.Member, event.sender) for event in events
             )
 
-            state_ids = yield self.store.get_state_ids_for_event(
+            state_ids = yield self.state_store.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2816bd8f87..84bad39815 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -822,6 +822,8 @@ class RoomContextHandler(object):
     def __init__(self, hs):
         self.hs = hs
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
 
     @defer.inlineCallbacks
     def get_event_context(self, user, room_id, event_id, limit, event_filter):
@@ -848,7 +850,7 @@ class RoomContextHandler(object):
 
         def filter_evts(events):
             return filter_events_for_client(
-                self.store, user.to_string(), events, is_peeking=is_peeking
+                self.storage, user.to_string(), events, is_peeking=is_peeking
             )
 
         event = yield self.store.get_event(
@@ -890,7 +892,7 @@ class RoomContextHandler(object):
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687
 
-        state = yield self.store.get_state_for_events(
+        state = yield self.state_store.get_state_for_events(
             [last_event_id], state_filter=state_filter
         )
         results["state"] = list(state[last_event_id].values())
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index cd5e90bacb..f4d8a60774 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -35,6 +35,8 @@ class SearchHandler(BaseHandler):
     def __init__(self, hs):
         super(SearchHandler, self).__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
 
     @defer.inlineCallbacks
     def get_old_rooms_from_upgraded_room(self, room_id):
@@ -221,7 +223,7 @@ class SearchHandler(BaseHandler):
             filtered_events = search_filter.filter([r["event"] for r in results])
 
             events = yield filter_events_for_client(
-                self.store, user.to_string(), filtered_events
+                self.storage, user.to_string(), filtered_events
             )
 
             events.sort(key=lambda e: -rank_map[e.event_id])
@@ -271,7 +273,7 @@ class SearchHandler(BaseHandler):
                 filtered_events = search_filter.filter([r["event"] for r in results])
 
                 events = yield filter_events_for_client(
-                    self.store, user.to_string(), filtered_events
+                    self.storage, user.to_string(), filtered_events
                 )
 
                 room_events.extend(events)
@@ -340,11 +342,11 @@ class SearchHandler(BaseHandler):
                 )
 
                 res["events_before"] = yield filter_events_for_client(
-                    self.store, user.to_string(), res["events_before"]
+                    self.storage, user.to_string(), res["events_before"]
                 )
 
                 res["events_after"] = yield filter_events_for_client(
-                    self.store, user.to_string(), res["events_after"]
+                    self.storage, user.to_string(), res["events_after"]
                 )
 
                 res["start"] = now_token.copy_and_replace(
@@ -372,7 +374,7 @@ class SearchHandler(BaseHandler):
                         [(EventTypes.Member, sender) for sender in senders]
                     )
 
-                    state = yield self.store.get_state_for_event(
+                    state = yield self.state_store.get_state_for_event(
                         last_event_id, state_filter
                     )
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d99160e9d7..43a082dcda 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -230,6 +230,8 @@ class SyncHandler(object):
         self.response_cache = ResponseCache(hs, "sync")
         self.state = hs.get_state_handler()
         self.auth = hs.get_auth()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
 
         # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
         self.lazy_loaded_members_cache = ExpiringCache(
@@ -417,7 +419,7 @@ class SyncHandler(object):
                     current_state_ids = frozenset(itervalues(current_state_ids))
 
                 recents = yield filter_events_for_client(
-                    self.store,
+                    self.storage,
                     sync_config.user.to_string(),
                     recents,
                     always_include_ids=current_state_ids,
@@ -470,7 +472,7 @@ class SyncHandler(object):
                     current_state_ids = frozenset(itervalues(current_state_ids))
 
                 loaded_recents = yield filter_events_for_client(
-                    self.store,
+                    self.storage,
                     sync_config.user.to_string(),
                     loaded_recents,
                     always_include_ids=current_state_ids,
@@ -509,7 +511,7 @@ class SyncHandler(object):
         Returns:
             A Deferred map from ((type, state_key)->Event)
         """
-        state_ids = yield self.store.get_state_ids_for_event(
+        state_ids = yield self.state_store.get_state_ids_for_event(
             event.event_id, state_filter=state_filter
         )
         if event.is_state():
@@ -580,7 +582,7 @@ class SyncHandler(object):
             return None
 
         last_event = last_events[-1]
-        state_ids = yield self.store.get_state_ids_for_event(
+        state_ids = yield self.state_store.get_state_ids_for_event(
             last_event.event_id,
             state_filter=StateFilter.from_types(
                 [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
@@ -757,11 +759,11 @@ class SyncHandler(object):
 
             if full_state:
                 if batch:
-                    current_state_ids = yield self.store.get_state_ids_for_event(
+                    current_state_ids = yield self.state_store.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
 
-                    state_ids = yield self.store.get_state_ids_for_event(
+                    state_ids = yield self.state_store.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
 
@@ -781,7 +783,7 @@ class SyncHandler(object):
                 )
             elif batch.limited:
                 if batch:
-                    state_at_timeline_start = yield self.store.get_state_ids_for_event(
+                    state_at_timeline_start = yield self.state_store.get_state_ids_for_event(
                         batch.events[0].event_id, state_filter=state_filter
                     )
                 else:
@@ -810,7 +812,7 @@ class SyncHandler(object):
                 )
 
                 if batch:
-                    current_state_ids = yield self.store.get_state_ids_for_event(
+                    current_state_ids = yield self.state_store.get_state_ids_for_event(
                         batch.events[-1].event_id, state_filter=state_filter
                     )
                 else:
@@ -841,7 +843,7 @@ class SyncHandler(object):
                         # So we fish out all the member events corresponding to the
                         # timeline here, and then dedupe any redundant ones below.
 
-                        state_ids = yield self.store.get_state_ids_for_event(
+                        state_ids = yield self.state_store.get_state_ids_for_event(
                             batch.events[0].event_id,
                             # we only want members!
                             state_filter=StateFilter.from_types(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 4e091314e6..af161a81d7 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -159,6 +159,7 @@ class Notifier(object):
         self.room_to_user_streams = {}
 
         self.hs = hs
+        self.storage = hs.get_storage()
         self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastore()
         self.pending_new_room_events = []
@@ -425,7 +426,10 @@ class Notifier(object):
 
                 if name == "room":
                     new_events = yield filter_events_for_client(
-                        self.store, user.to_string(), new_events, is_peeking=is_peeking
+                        self.storage,
+                        user.to_string(),
+                        new_events,
+                        is_peeking=is_peeking,
                     )
                 elif name == "presence":
                     now = self.clock.time_msec()
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 6299587808..36e26032a1 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -64,6 +64,7 @@ class HttpPusher(object):
     def __init__(self, hs, pusherdict):
         self.hs = hs
         self.store = self.hs.get_datastore()
+        self.storage = self.hs.get_storage()
         self.clock = self.hs.get_clock()
         self.state_handler = self.hs.get_state_handler()
         self.user_id = pusherdict["user_name"]
@@ -329,7 +330,7 @@ class HttpPusher(object):
             return d
 
         ctx = yield push_tools.get_context_for_event(
-            self.store, self.state_handler, event, self.user_id
+            self.storage, self.state_handler, event, self.user_id
         )
 
         d = {
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 5b16ab4ae8..1d15a06a58 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -119,6 +119,7 @@ class Mailer(object):
         self.store = self.hs.get_datastore()
         self.macaroon_gen = self.hs.get_macaroon_generator()
         self.state_handler = self.hs.get_state_handler()
+        self.storage = hs.get_storage()
         self.app_name = app_name
 
         logger.info("Created Mailer for app_name %s" % app_name)
@@ -389,7 +390,7 @@ class Mailer(object):
         }
 
         the_events = yield filter_events_for_client(
-            self.store, user_id, results["events_before"]
+            self.storage, user_id, results["events_before"]
         )
         the_events.append(notif_event)
 
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index a54051a726..de5c101a58 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -16,6 +16,7 @@
 from twisted.internet import defer
 
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
+from synapse.storage import Storage
 
 
 @defer.inlineCallbacks
@@ -43,22 +44,22 @@ def get_badge_count(store, user_id):
 
 
 @defer.inlineCallbacks
-def get_context_for_event(store, state_handler, ev, user_id):
+def get_context_for_event(storage: Storage, state_handler, ev, user_id):
     ctx = {}
 
-    room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
+    room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
 
     # we no longer bother setting room_alias, and make room_name the
     # human-readable name instead, be that m.room.name, an alias or
     # a list of people in the room
     name = yield calculate_room_name(
-        store, room_state_ids, user_id, fallback_to_single_member=False
+        storage.main, room_state_ids, user_id, fallback_to_single_member=False
     )
     if name:
         ctx["name"] = name
 
     sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
-    sender_state_event = yield store.get_event(sender_state_event_id)
+    sender_state_event = yield storage.main.get_event(sender_state_event_id)
     ctx["sender_display_name"] = name_from_member_event(sender_state_event)
 
     return ctx
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index dc9f5a9008..4e91eb66fe 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -103,6 +103,7 @@ class StateHandler(object):
     def __init__(self, hs):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
+        self.state_store = hs.get_storage().state
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
@@ -271,7 +272,7 @@ class StateHandler(object):
             else:
                 current_state_ids = prev_state_ids
 
-            state_group = yield self.store.store_state_group(
+            state_group = yield self.state_store.store_state_group(
                 event.event_id,
                 event.room_id,
                 prev_group=None,
@@ -321,7 +322,7 @@ class StateHandler(object):
                 delta_ids = dict(entry.delta_ids)
                 delta_ids[key] = event.event_id
 
-            state_group = yield self.store.store_state_group(
+            state_group = yield self.state_store.store_state_group(
                 event.event_id,
                 event.room_id,
                 prev_group=prev_group,
@@ -334,7 +335,7 @@ class StateHandler(object):
             delta_ids = entry.delta_ids
 
             if entry.state_group is None:
-                entry.state_group = yield self.store.store_state_group(
+                entry.state_group = yield self.state_store.store_state_group(
                     event.event_id,
                     event.room_id,
                     prev_group=entry.prev_group,
@@ -376,14 +377,16 @@ class StateHandler(object):
         # map from state group id to the state in that state group (where
         # 'state' is a map from state key to event id)
         # dict[int, dict[(str, str), str]]
-        state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
+        state_groups_ids = yield self.state_store.get_state_groups_ids(
+            room_id, event_ids
+        )
 
         if len(state_groups_ids) == 0:
             return _StateCacheEntry(state={}, state_group=None)
         elif len(state_groups_ids) == 1:
             name, state_list = list(state_groups_ids.items()).pop()
 
-            prev_group, delta_ids = yield self.store.get_state_group_delta(name)
+            prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
 
             return _StateCacheEntry(
                 state=state_list,
diff --git a/synapse/visibility.py b/synapse/visibility.py
index bf0f1eebd8..8c843febd8 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.utils import prune_event
+from synapse.storage import Storage
 from synapse.storage.state import StateFilter
 from synapse.types import get_domain_from_id
 
@@ -43,14 +44,13 @@ MEMBERSHIP_PRIORITY = (
 
 @defer.inlineCallbacks
 def filter_events_for_client(
-    store, user_id, events, is_peeking=False, always_include_ids=frozenset()
+    storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset()
 ):
     """
     Check which events a user is allowed to see
 
     Args:
-        store (synapse.storage.DataStore): our datastore (can also be a worker
-            store)
+        storage
         user_id(str): user id to be checked
         events(list[synapse.events.EventBase]): sequence of events to be checked
         is_peeking(bool): should be True if:
@@ -68,12 +68,12 @@ def filter_events_for_client(
     events = list(e for e in events if not e.internal_metadata.is_soft_failed())
 
     types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
-    event_id_to_state = yield store.get_state_for_events(
+    event_id_to_state = yield storage.state.get_state_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(types),
     )
 
-    ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
+    ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
         "m.ignored_user_list", user_id
     )
 
@@ -84,7 +84,7 @@ def filter_events_for_client(
         else []
     )
 
-    erased_senders = yield store.are_users_erased((e.sender for e in events))
+    erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
 
     def allowed(event):
         """
@@ -213,13 +213,17 @@ def filter_events_for_client(
 
 @defer.inlineCallbacks
 def filter_events_for_server(
-    store, server_name, events, redact=True, check_history_visibility_only=False
+    storage: Storage,
+    server_name,
+    events,
+    redact=True,
+    check_history_visibility_only=False,
 ):
     """Filter a list of events based on whether given server is allowed to
     see them.
 
     Args:
-        store (DataStore)
+        storage
         server_name (str)
         events (iterable[FrozenEvent])
         redact (bool): Whether to return a redacted version of the event, or
@@ -274,7 +278,7 @@ def filter_events_for_server(
     # Lets check to see if all the events have a history visibility
     # of "shared" or "world_readable". If thats the case then we don't
     # need to check membership (as we know the server is in the room).
-    event_to_state_ids = yield store.get_state_ids_for_events(
+    event_to_state_ids = yield storage.state.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(
             types=((EventTypes.RoomHistoryVisibility, ""),)
@@ -292,14 +296,14 @@ def filter_events_for_server(
     if not visibility_ids:
         all_open = True
     else:
-        event_map = yield store.get_events(visibility_ids)
+        event_map = yield storage.main.get_events(visibility_ids)
         all_open = all(
             e.content.get("history_visibility") in (None, "shared", "world_readable")
             for e in itervalues(event_map)
         )
 
     if not check_history_visibility_only:
-        erased_senders = yield store.are_users_erased((e.sender for e in events))
+        erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
     else:
         # We don't want to check whether users are erased, which is equivalent
         # to no users having been erased.
@@ -328,7 +332,7 @@ def filter_events_for_server(
 
     # first, for each event we're wanting to return, get the event_ids
     # of the history vis and membership state at those events.
-    event_to_state_ids = yield store.get_state_ids_for_events(
+    event_to_state_ids = yield storage.state.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
         state_filter=StateFilter.from_types(
             types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@@ -358,7 +362,7 @@ def filter_events_for_server(
             return False
         return state_key[idx + 1 :] == server_name
 
-    event_map = yield store.get_events(
+    event_map = yield storage.main.get_events(
         [
             e_id
             for e_id, key in iteritems(event_id_to_state_key)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index d573a3e07b..43200654f1 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -35,6 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         self.store = hs.get_datastore()
         self.storage = hs.get_storage()
+        self.state_datastore = self.store
         self.event_builder_factory = hs.get_event_builder_factory()
         self.event_creation_handler = hs.get_event_creation_handler()
 
@@ -83,7 +84,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.store.get_state_groups_ids(
+        state_group_map = yield self.storage.state.get_state_groups_ids(
             self.room, [e2.event_id]
         )
         self.assertEqual(len(state_group_map), 1)
@@ -102,7 +103,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
+        state_group_map = yield self.storage.state.get_state_groups(
+            self.room, [e2.event_id]
+        )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
 
@@ -142,7 +145,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.store.get_state_for_event(e5.event_id)
+        state = yield self.storage.state.get_state_for_event(e5.event_id)
 
         self.assertIsNotNone(e4)
 
@@ -158,21 +161,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
         )
 
@@ -182,7 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id,
             state_filter=StateFilter(
                 types={EventTypes.Member: {self.u_alice.to_string()}},
@@ -200,7 +203,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield self.store.get_state_for_event(
+        state = yield self.storage.state.get_state_for_event(
             e5.event_id,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -216,13 +219,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
+        group_ids = yield self.storage.state.get_state_groups_ids(
+            room_id, [e5.event_id]
+        )
         group = list(group_ids.keys())[0]
 
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -238,8 +246,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -251,8 +262,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with wildcard types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -268,8 +282,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -288,8 +305,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -305,8 +325,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -318,8 +341,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -332,9 +358,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
         # deliberately remove e2 (room name) from the _state_group_cache
 
-        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
-            group
-        )
+        (
+            is_all,
+            known_absent,
+            state_dict_ids,
+        ) = self.state_datastore._state_group_cache.get(group)
 
         self.assertEqual(is_all, True)
         self.assertEqual(known_absent, set())
@@ -347,18 +375,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         state_dict_ids.pop((e2.type, e2.state_key))
-        self.store._state_group_cache.invalidate(group)
-        self.store._state_group_cache.update(
-            sequence=self.store._state_group_cache.sequence,
+        self.state_datastore._state_group_cache.invalidate(group)
+        self.state_datastore._state_group_cache.update(
+            sequence=self.state_datastore._state_group_cache.sequence,
             key=group,
             value=state_dict_ids,
             # list fetched keys so it knows it's partial
             fetched_keys=((e1.type, e1.state_key),),
         )
 
-        (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
-            group
-        )
+        (
+            is_all,
+            known_absent,
+            state_dict_ids,
+        ) = self.state_datastore._state_group_cache.get(group)
 
         self.assertEqual(is_all, False)
         self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
@@ -370,8 +400,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         # test _get_state_for_group_using_cache correctly filters out members
         # with types=[]
         room_id = self.room.to_string()
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -382,8 +415,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
         room_id = self.room.to_string()
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: set()}, include_others=True
@@ -395,8 +431,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # wildcard types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -406,8 +445,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: None}, include_others=True
@@ -425,8 +467,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -436,8 +481,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=True
@@ -449,8 +497,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # test _get_state_for_group_using_cache correctly filters in members
         # with specific types
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
@@ -460,8 +511,11 @@ class StateStoreTestCase(tests.unittest.TestCase):
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
-            self.store._state_group_members_cache,
+        (
+            state_dict,
+            is_all,
+        ) = yield self.state_datastore._get_state_for_group_using_cache(
+            self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
                 types={EventTypes.Member: {e5.state_key}}, include_others=False
diff --git a/tests/test_state.py b/tests/test_state.py
index 610ec9fb46..38246555bd 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -158,10 +158,12 @@ class Graph(object):
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = StateGroupStore()
+        storage = Mock(main=self.store, state=self.store)
         hs = Mock(
             spec_set=[
                 "config",
                 "get_datastore",
+                "get_storage",
                 "get_auth",
                 "get_state_handler",
                 "get_clock",
@@ -174,6 +176,7 @@ class StateTestCase(unittest.TestCase):
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
         hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+        hs.get_storage.return_value = storage
 
         self.state = StateHandler(hs)
         self.event_id = 0
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 6ae1ea9b04..f7381b2885 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 import logging
 
+from mock import Mock
+
 from twisted.internet import defer
 from twisted.internet.defer import succeed
 
@@ -63,7 +65,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             events_to_filter.append(evt)
 
         filtered = yield filter_events_for_server(
-            self.store, "test_server", events_to_filter
+            self.storage, "test_server", events_to_filter
         )
 
         # the result should be 5 redacted events, and 5 unredacted events.
@@ -101,7 +103,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
         # ... and the filtering happens.
         filtered = yield filter_events_for_server(
-            self.store, "test_server", events_to_filter
+            self.storage, "test_server", events_to_filter
         )
 
         for i in range(0, len(events_to_filter)):
@@ -258,6 +260,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
 
         logger.info("Starting filtering")
         start = time.time()
+
+        storage = Mock()
+        storage.main = test_store
+        storage.state = test_store
+
         filtered = yield filter_events_for_server(
             test_store, "test_server", events_to_filter
         )