summary refs log tree commit diff
diff options
context:
space:
mode:
authorNeil Johnson <neil@matrix.org>2018-11-02 15:23:36 +0000
committerNeil Johnson <neil@matrix.org>2018-11-02 15:23:36 +0000
commit1d0a5ab3c55c2f8377f73324c4fffd482d688b6b (patch)
treeecf06d475590f3267f2e5a470a0fb4281be911dd
parentMerge branch 'develop' of github.com:matrix-org/synapse into neilj/create_sup... (diff)
parentwip tests to filter out support user (diff)
downloadsynapse-1d0a5ab3c55c2f8377f73324c4fffd482d688b6b.tar.xz
Merge branch 'neilj/create_support_user' of github.com:matrix-org/synapse into neilj/create_support_user
-rw-r--r--.travis.yml33
-rw-r--r--README.rst13
-rw-r--r--changelog.d/3975.feature1
-rw-r--r--changelog.d/4011.misc1
-rw-r--r--changelog.d/4051.feature1
-rw-r--r--changelog.d/4072.misc1
-rw-r--r--changelog.d/4081.bugfix2
-rw-r--r--changelog.d/4089.feature1
-rw-r--r--docker/conf/homeserver.yaml4
-rwxr-xr-xsynapse/app/homeserver.py8
-rw-r--r--synapse/config/homeserver.py3
-rw-r--r--synapse/config/registration.py15
-rw-r--r--synapse/config/room_directory.py102
-rw-r--r--synapse/federation/federation_server.py16
-rw-r--r--synapse/handlers/directory.py9
-rw-r--r--synapse/handlers/initial_sync.py4
-rw-r--r--synapse/handlers/message.py20
-rw-r--r--synapse/handlers/pagination.py15
-rw-r--r--synapse/handlers/register.py30
-rw-r--r--synapse/handlers/room.py22
-rw-r--r--synapse/handlers/sync.py97
-rw-r--r--synapse/rest/client/v1/room.py3
-rw-r--r--synapse/server.py1
-rw-r--r--synapse/storage/directory.py2
-rw-r--r--synapse/storage/events.py2
-rw-r--r--synapse/storage/monthly_active_users.py72
-rw-r--r--synapse/storage/registration.py35
-rw-r--r--synapse/storage/state.py845
-rw-r--r--synapse/storage/user_directory.py47
-rw-r--r--synapse/util/__init__.py25
-rw-r--r--synapse/visibility.py15
-rw-r--r--tests/config/test_room_directory.py67
-rw-r--r--tests/handlers/test_directory.py48
-rw-r--r--tests/handlers/test_register.py75
-rw-r--r--tests/storage/test_monthly_active_users.py10
-rw-r--r--tests/storage/test_state.py175
-rw-r--r--tests/storage/test_user_directory.py39
-rw-r--r--tests/utils.py1
-rw-r--r--tox.ini5
39 files changed, 1271 insertions, 594 deletions
diff --git a/.travis.yml b/.travis.yml
index 197dec2bc9..fd41841c77 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,8 +1,20 @@
 sudo: false
 language: python
 
-# tell travis to cache ~/.cache/pip
-cache: pip
+cache:
+  directories:
+    # we only bother to cache the wheels; parts of the http cache get
+    # invalidated every build (because they get served with a max-age of 600
+    # seconds), which means that we end up re-uploading the whole cache for
+    # every build, which is time-consuming In any case, it's not obvious that
+    # downloading the cache from S3 would be much faster than downloading the
+    # originals from pypi.
+    #
+    - $HOME/.cache/pip/wheels
+
+# don't clone the whole repo history, one commit will do
+git:
+  depth: 1
 
 # only build branches we care about (PRs are built seperately)
 branches:
@@ -11,10 +23,6 @@ branches:
     - develop
     - /^release-v/
 
-before_script:
-  - git remote set-branches --add origin develop
-  - git fetch origin develop
-
 matrix:
   fast_finish: true
   include:
@@ -22,7 +30,7 @@ matrix:
     env: TOX_ENV=packaging
 
   - python: 3.6
-    env: TOX_ENV=pep8
+    env: TOX_ENV="pep8,check_isort"
 
   - python: 2.7
     env: TOX_ENV=py27
@@ -46,11 +54,14 @@ matrix:
     services:
       - postgresql
 
-  - python: 3.6
-    env: TOX_ENV=check_isort
-
-  - python: 3.6
+  - # we only need to check for the newsfragment if it's a PR build
+    if: type = pull_request
+    python: 3.6
     env: TOX_ENV=check-newsfragment
+    script:
+      - git remote set-branches --add origin develop
+      - git fetch origin develop
+      - tox -e $TOX_ENV
 
 install:
   - pip install tox
diff --git a/README.rst b/README.rst
index 456a3d9d43..9165db8319 100644
--- a/README.rst
+++ b/README.rst
@@ -657,7 +657,8 @@ Using a reverse proxy with Synapse
 
 It is recommended to put a reverse proxy such as
 `nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
-`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or
+`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
+`Caddy <https://caddyserver.com/docs/proxy>`_ or
 `HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
 doing so is that it means that you can expose the default https port (443) to
 Matrix clients without needing to run Synapse with root privileges.
@@ -688,7 +689,15 @@ so an example nginx configuration might look like::
       }
   }
 
-and an example apache configuration may look like::
+an example Caddy configuration might look like::
+
+    matrix.example.com {
+      proxy /_matrix http://localhost:8008 {
+        transparent
+      }
+    }
+
+and an example Apache configuration might look like::
 
     <VirtualHost *:443>
         SSLEngine on
diff --git a/changelog.d/3975.feature b/changelog.d/3975.feature
new file mode 100644
index 0000000000..162f30a532
--- /dev/null
+++ b/changelog.d/3975.feature
@@ -0,0 +1 @@
+Servers with auto-join rooms will now automatically create those rooms when the first user registers
diff --git a/changelog.d/4011.misc b/changelog.d/4011.misc
new file mode 100644
index 0000000000..ad7768c4cd
--- /dev/null
+++ b/changelog.d/4011.misc
@@ -0,0 +1 @@
+Reduce database load when fetching state groups
diff --git a/changelog.d/4051.feature b/changelog.d/4051.feature
new file mode 100644
index 0000000000..9c1b3a72a0
--- /dev/null
+++ b/changelog.d/4051.feature
@@ -0,0 +1 @@
+Add config option to control alias creation
diff --git a/changelog.d/4072.misc b/changelog.d/4072.misc
new file mode 100644
index 0000000000..9d7279fd2b
--- /dev/null
+++ b/changelog.d/4072.misc
@@ -0,0 +1 @@
+The README now contains example for the Caddy web server. Contributed by steamp0rt.
diff --git a/changelog.d/4081.bugfix b/changelog.d/4081.bugfix
new file mode 100644
index 0000000000..cfe4b3e9d9
--- /dev/null
+++ b/changelog.d/4081.bugfix
@@ -0,0 +1,2 @@
+Fix race condition where config defined reserved users were not being added to
+the monthly active user list prior to the homeserver reactor firing up
diff --git a/changelog.d/4089.feature b/changelog.d/4089.feature
new file mode 100644
index 0000000000..62c9d839bb
--- /dev/null
+++ b/changelog.d/4089.feature
@@ -0,0 +1 @@
+ Configure Docker image to listen on both ipv4 and ipv6.
diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index a38b929f50..1b0f655d26 100644
--- a/docker/conf/homeserver.yaml
+++ b/docker/conf/homeserver.yaml
@@ -21,7 +21,7 @@ listeners:
   {% if not SYNAPSE_NO_TLS %}
   -
     port: 8448
-    bind_addresses: ['0.0.0.0']
+    bind_addresses: ['::']
     type: http
     tls: true
     x_forwarded: false
@@ -34,7 +34,7 @@ listeners:
 
   - port: 8008
     tls: false
-    bind_addresses: ['0.0.0.0']
+    bind_addresses: ['::']
     type: http
     x_forwarded: false
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 7a49448965..7315b57d28 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -553,14 +553,6 @@ def run(hs):
             generate_monthly_active_users,
         )
 
-    # XXX is this really supposed to be a background process? it looks
-    # like it needs to complete before some of the other stuff runs.
-    run_as_background_process(
-        "initialise_reserved_users",
-        hs.get_datastore().initialise_reserved_users,
-        hs.config.mau_limits_reserved_threepids,
-    )
-
     start_generate_monthly_active_users()
     if hs.config.limit_usage_by_mau:
         clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index b8d5690f2b..10dd40159f 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -31,6 +31,7 @@ from .push import PushConfig
 from .ratelimiting import RatelimitConfig
 from .registration import RegistrationConfig
 from .repository import ContentRepositoryConfig
+from .room_directory import RoomDirectoryConfig
 from .saml2 import SAML2Config
 from .server import ServerConfig
 from .server_notices_config import ServerNoticesConfig
@@ -49,7 +50,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        WorkerConfig, PasswordAuthProviderConfig, PushConfig,
                        SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
                        ConsentConfig,
-                       ServerNoticesConfig,
+                       ServerNoticesConfig, RoomDirectoryConfig,
                        ):
     pass
 
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 0fb964eb67..7480ed5145 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -15,10 +15,10 @@
 
 from distutils.util import strtobool
 
+from synapse.config._base import Config, ConfigError
+from synapse.types import RoomAlias
 from synapse.util.stringutils import random_string_with_symbols
 
-from ._base import Config
-
 
 class RegistrationConfig(Config):
 
@@ -44,6 +44,10 @@ class RegistrationConfig(Config):
         )
 
         self.auto_join_rooms = config.get("auto_join_rooms", [])
+        for room_alias in self.auto_join_rooms:
+            if not RoomAlias.is_valid(room_alias):
+                raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
+        self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
 
     def default_config(self, **kwargs):
         registration_shared_secret = random_string_with_symbols(50)
@@ -98,6 +102,13 @@ class RegistrationConfig(Config):
         # to these rooms
         #auto_join_rooms:
         #    - "#example:example.com"
+
+        # Where auto_join_rooms are specified, setting this flag ensures that the
+        # the rooms exist by creating them when the first user on the
+        # homeserver registers.
+        # Setting to false means that if the rooms are not manually created,
+        # users cannot be auto-joined since they do not exist.
+        autocreate_auto_join_rooms: true
         """ % locals()
 
     def add_arguments(self, parser):
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
new file mode 100644
index 0000000000..9da13ab11b
--- /dev/null
+++ b/synapse/config/room_directory.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util import glob_to_regex
+
+from ._base import Config, ConfigError
+
+
+class RoomDirectoryConfig(Config):
+    def read_config(self, config):
+        alias_creation_rules = config["alias_creation_rules"]
+
+        self._alias_creation_rules = [
+            _AliasRule(rule)
+            for rule in alias_creation_rules
+        ]
+
+    def default_config(self, config_dir_path, server_name, **kwargs):
+        return """
+        # The `alias_creation` option controls who's allowed to create aliases
+        # on this server.
+        #
+        # The format of this option is a list of rules that contain globs that
+        # match against user_id and the new alias (fully qualified with server
+        # name). The action in the first rule that matches is taken, which can
+        # currently either be "allow" or "deny".
+        #
+        # If no rules match the request is denied.
+        alias_creation_rules:
+            - user_id: "*"
+              alias: "*"
+              action: allow
+        """
+
+    def is_alias_creation_allowed(self, user_id, alias):
+        """Checks if the given user is allowed to create the given alias
+
+        Args:
+            user_id (str)
+            alias (str)
+
+        Returns:
+            boolean: True if user is allowed to crate the alias
+        """
+        for rule in self._alias_creation_rules:
+            if rule.matches(user_id, alias):
+                return rule.action == "allow"
+
+        return False
+
+
+class _AliasRule(object):
+    def __init__(self, rule):
+        action = rule["action"]
+        user_id = rule["user_id"]
+        alias = rule["alias"]
+
+        if action in ("allow", "deny"):
+            self.action = action
+        else:
+            raise ConfigError(
+                "alias_creation_rules rules can only have action of 'allow'"
+                " or 'deny'"
+            )
+
+        try:
+            self._user_id_regex = glob_to_regex(user_id)
+            self._alias_regex = glob_to_regex(alias)
+        except Exception as e:
+            raise ConfigError("Failed to parse glob into regex: %s", e)
+
+    def matches(self, user_id, alias):
+        """Tests if this rule matches the given user_id and alias.
+
+        Args:
+            user_id (str)
+            alias (str)
+
+        Returns:
+            boolean
+        """
+
+        # Note: The regexes are anchored at both ends
+        if not self._user_id_regex.match(user_id):
+            return False
+
+        if not self._alias_regex.match(alias):
+            return False
+
+        return True
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index af0107a46e..0f9302a6a8 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import re
 
 import six
 from six import iteritems
@@ -44,6 +43,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.types import get_domain_from_id
+from synapse.util import glob_to_regex
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.logcontext import nested_logging_context
@@ -729,22 +729,10 @@ def _acl_entry_matches(server_name, acl_entry):
     if not isinstance(acl_entry, six.string_types):
         logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
         return False
-    regex = _glob_to_regex(acl_entry)
+    regex = glob_to_regex(acl_entry)
     return regex.match(server_name)
 
 
-def _glob_to_regex(glob):
-    res = ''
-    for c in glob:
-        if c == '*':
-            res = res + '.*'
-        elif c == '?':
-            res = res + '.'
-        else:
-            res = res + re.escape(c)
-    return re.compile(res + "\\Z", re.IGNORECASE)
-
-
 class FederationHandlerRegistry(object):
     """Allows classes to register themselves as handlers for a given EDU or
     query type for incoming federation traffic.
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 02f12f6645..7d67bf803a 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -43,6 +43,7 @@ class DirectoryHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.appservice_handler = hs.get_application_service_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.config = hs.config
 
         self.federation = hs.get_federation_client()
         hs.get_federation_registry().register_query_handler(
@@ -111,6 +112,14 @@ class DirectoryHandler(BaseHandler):
                     403, "This user is not permitted to create this alias",
                 )
 
+            if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()):
+                # Lets just return a generic message, as there may be all sorts of
+                # reasons why we said no. TODO: Allow configurable error messages
+                # per alias creation rule?
+                raise SynapseError(
+                    403, "Not allowed to create alias",
+                )
+
             can_create = yield self.can_modify_alias(
                 room_alias,
                 user_id=user_id
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e009395207..563bb3cea3 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler):
                     room_end_token = "s%d" % (event.stream_ordering,)
                     deferred_room_state = run_in_background(
                         self.store.get_state_for_events,
-                        [event.event_id], None,
+                        [event.event_id],
                     )
                     deferred_room_state.addCallback(
                         lambda states: states[event.event_id]
@@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler):
     def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
                                   membership, member_event_id, is_peeking):
         room_state = yield self.store.get_state_for_events(
-            [member_event_id], None
+            [member_event_id],
         )
 
         room_state = room_state[member_event_id]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 6c4fcfb10a..969e588e73 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -35,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
@@ -80,7 +81,7 @@ class MessageHandler(object):
         elif membership == Membership.LEAVE:
             key = (event_type, state_key)
             room_state = yield self.store.get_state_for_events(
-                [membership_event_id], [key]
+                [membership_event_id], StateFilter.from_types([key])
             )
             data = room_state[membership_event_id].get(key)
 
@@ -88,7 +89,7 @@ class MessageHandler(object):
 
     @defer.inlineCallbacks
     def get_state_events(
-        self, user_id, room_id, types=None, filtered_types=None,
+        self, user_id, room_id, state_filter=StateFilter.all(),
         at_token=None, is_guest=False,
     ):
         """Retrieve all state events for a given room. If the user is
@@ -100,13 +101,8 @@ class MessageHandler(object):
         Args:
             user_id(str): The user requesting state events.
             room_id(str): The room ID to get all state events from.
-            types(list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-                May be None, which matches any key.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
             at_token(StreamToken|None): the stream token of the at which we are requesting
                 the stats. If the user is not allowed to view the state as of that
                 stream token, we raise a 403 SynapseError. If None, returns the current
@@ -139,7 +135,7 @@ class MessageHandler(object):
             event = last_events[0]
             if visible_events:
                 room_state = yield self.store.get_state_for_events(
-                    [event.event_id], types, filtered_types=filtered_types,
+                    [event.event_id], state_filter=state_filter,
                 )
                 room_state = room_state[event.event_id]
             else:
@@ -158,12 +154,12 @@ class MessageHandler(object):
 
             if membership == Membership.JOIN:
                 state_ids = yield self.store.get_filtered_current_state_ids(
-                    room_id, types, filtered_types=filtered_types,
+                    room_id, state_filter=state_filter,
                 )
                 room_state = yield self.store.get_events(state_ids.values())
             elif membership == Membership.LEAVE:
                 room_state = yield self.store.get_state_for_events(
-                    [membership_event_id], types, filtered_types=filtered_types,
+                    [membership_event_id], state_filter=state_filter,
                 )
                 room_state = room_state[membership_event_id]
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index a155b6e938..43f81bd607 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -21,6 +21,7 @@ from twisted.python.failure import Failure
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
+from synapse.storage.state import StateFilter
 from synapse.types import RoomStreamToken
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.logcontext import run_in_background
@@ -255,16 +256,14 @@ class PaginationHandler(object):
         if event_filter and event_filter.lazy_load_members():
             # TODO: remove redundant members
 
-            types = [
-                (EventTypes.Member, state_key)
-                for state_key in set(
-                    event.sender  # FIXME: we also care about invite targets etc.
-                    for event in events
-                )
-            ]
+            # FIXME: we also care about invite targets etc.
+            state_filter = StateFilter.from_types(
+                (EventTypes.Member, event.sender)
+                for event in events
+            )
 
             state_ids = yield self.store.get_state_ids_for_event(
-                events[0].event_id, types=types,
+                events[0].event_id, state_filter=state_filter,
             )
 
             if state_ids:
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index da914c46ff..e9d7b25a36 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -50,6 +50,7 @@ class RegistrationHandler(BaseHandler):
         self._auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
+        self.room_creation_handler = self.hs.get_room_creation_handler()
         self.captcha_client = CaptchaServerHttpClient(hs)
 
         self._next_generated_user_id = None
@@ -220,9 +221,36 @@ class RegistrationHandler(BaseHandler):
 
         # auto-join the user to any rooms we're supposed to dump them into
         fake_requester = create_requester(user_id)
+
+        # try to create the room if we're the first user on the server
+        should_auto_create_rooms = False
+        if self.hs.config.autocreate_auto_join_rooms:
+            count = yield self.store.count_all_users()
+            should_auto_create_rooms = count == 1
+
         for r in self.hs.config.auto_join_rooms:
             try:
-                yield self._join_user_to_room(fake_requester, r)
+                if should_auto_create_rooms:
+                    room_alias = RoomAlias.from_string(r)
+                    if self.hs.hostname != room_alias.domain:
+                        logger.warning(
+                            'Cannot create room alias %s, '
+                            'it does not match server domain',
+                            r,
+                        )
+                    else:
+                        # create room expects the localpart of the room alias
+                        room_alias_localpart = room_alias.localpart
+                        yield self.room_creation_handler.create_room(
+                            fake_requester,
+                            config={
+                                "preset": "public_chat",
+                                "room_alias_name": room_alias_localpart
+                            },
+                            ratelimit=False,
+                        )
+                else:
+                    yield self._join_user_to_room(fake_requester, r)
             except Exception as e:
                 logger.error("Failed to join new user to %r: %r", r, e)
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ab1571b27b..3ba92bdb4c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -33,6 +33,7 @@ from synapse.api.constants import (
     RoomCreationPreset,
 )
 from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
+from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
 from synapse.util import stringutils
 from synapse.visibility import filter_events_for_client
@@ -489,23 +490,24 @@ class RoomContextHandler(object):
         else:
             last_event_id = event_id
 
-        types = None
-        filtered_types = None
         if event_filter and event_filter.lazy_load_members():
-            members = set(ev.sender for ev in itertools.chain(
-                results["events_before"],
-                (results["event"],),
-                results["events_after"],
-            ))
-            filtered_types = [EventTypes.Member]
-            types = [(EventTypes.Member, member) for member in members]
+            state_filter = StateFilter.from_lazy_load_member_list(
+                ev.sender
+                for ev in itertools.chain(
+                    results["events_before"],
+                    (results["event"],),
+                    results["events_after"],
+                )
+            )
+        else:
+            state_filter = StateFilter.all()
 
         # XXX: why do we return the state as of the last event rather than the
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687
 
         state = yield self.store.get_state_for_events(
-            [last_event_id], types, filtered_types=filtered_types,
+            [last_event_id], state_filter=state_filter,
         )
         results["state"] = list(state[last_event_id].values())
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 351892a94f..09739f2862 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -27,6 +27,7 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.roommember import MemberSummary
+from synapse.storage.state import StateFilter
 from synapse.types import RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -469,25 +470,20 @@ class SyncHandler(object):
         ))
 
     @defer.inlineCallbacks
-    def get_state_after_event(self, event, types=None, filtered_types=None):
+    def get_state_after_event(self, event, state_filter=StateFilter.all()):
         """
         Get the room state after the given event
 
         Args:
             event(synapse.events.EventBase): event of interest
-            types(list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-                May be None, which matches any key.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
             A Deferred map from ((type, state_key)->Event)
         """
         state_ids = yield self.store.get_state_ids_for_event(
-            event.event_id, types, filtered_types=filtered_types,
+            event.event_id, state_filter=state_filter,
         )
         if event.is_state():
             state_ids = state_ids.copy()
@@ -495,18 +491,14 @@ class SyncHandler(object):
         defer.returnValue(state_ids)
 
     @defer.inlineCallbacks
-    def get_state_at(self, room_id, stream_position, types=None, filtered_types=None):
+    def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
         """ Get the room state at a particular stream position
 
         Args:
             room_id(str): room for which to get state
             stream_position(StreamToken): point at which to get state
-            types(list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
             A Deferred map from ((type, state_key)->Event)
@@ -522,7 +514,7 @@ class SyncHandler(object):
         if last_events:
             last_event = last_events[-1]
             state = yield self.get_state_after_event(
-                last_event, types, filtered_types=filtered_types,
+                last_event, state_filter=state_filter,
             )
 
         else:
@@ -563,10 +555,11 @@ class SyncHandler(object):
 
         last_event = last_events[-1]
         state_ids = yield self.store.get_state_ids_for_event(
-            last_event.event_id, [
+            last_event.event_id,
+            state_filter=StateFilter.from_types([
                 (EventTypes.Name, ''),
                 (EventTypes.CanonicalAlias, ''),
-            ]
+            ]),
         )
 
         # this is heavily cached, thus: fast.
@@ -717,8 +710,7 @@ class SyncHandler(object):
 
         with Measure(self.clock, "compute_state_delta"):
 
-            types = None
-            filtered_types = None
+            members_to_fetch = None
 
             lazy_load_members = sync_config.filter_collection.lazy_load_members()
             include_redundant_members = (
@@ -729,16 +721,21 @@ class SyncHandler(object):
                 # We only request state for the members needed to display the
                 # timeline:
 
-                types = [
-                    (EventTypes.Member, state_key)
-                    for state_key in set(
-                        event.sender  # FIXME: we also care about invite targets etc.
-                        for event in batch.events
-                    )
-                ]
+                members_to_fetch = set(
+                    event.sender  # FIXME: we also care about invite targets etc.
+                    for event in batch.events
+                )
+
+                if full_state:
+                    # always make sure we LL ourselves so we know we're in the room
+                    # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
+                    # We only need apply this on full state syncs given we disabled
+                    # LL for incr syncs in #3840.
+                    members_to_fetch.add(sync_config.user.to_string())
 
-                # only apply the filtering to room members
-                filtered_types = [EventTypes.Member]
+                state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+            else:
+                state_filter = StateFilter.all()
 
             timeline_state = {
                 (event.type, event.state_key): event.event_id
@@ -746,28 +743,19 @@ class SyncHandler(object):
             }
 
             if full_state:
-                if lazy_load_members:
-                    # always make sure we LL ourselves so we know we're in the room
-                    # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
-                    # We only need apply this on full state syncs given we disabled
-                    # LL for incr syncs in #3840.
-                    types.append((EventTypes.Member, sync_config.user.to_string()))
-
                 if batch:
                     current_state_ids = yield self.store.get_state_ids_for_event(
-                        batch.events[-1].event_id, types=types,
-                        filtered_types=filtered_types,
+                        batch.events[-1].event_id, state_filter=state_filter,
                     )
 
                     state_ids = yield self.store.get_state_ids_for_event(
-                        batch.events[0].event_id, types=types,
-                        filtered_types=filtered_types,
+                        batch.events[0].event_id, state_filter=state_filter,
                     )
 
                 else:
                     current_state_ids = yield self.get_state_at(
-                        room_id, stream_position=now_token, types=types,
-                        filtered_types=filtered_types,
+                        room_id, stream_position=now_token,
+                        state_filter=state_filter,
                     )
 
                     state_ids = current_state_ids
@@ -781,8 +769,7 @@ class SyncHandler(object):
                 )
             elif batch.limited:
                 state_at_timeline_start = yield self.store.get_state_ids_for_event(
-                    batch.events[0].event_id, types=types,
-                    filtered_types=filtered_types,
+                    batch.events[0].event_id, state_filter=state_filter,
                 )
 
                 # for now, we disable LL for gappy syncs - see
@@ -797,17 +784,15 @@ class SyncHandler(object):
                 # members to just be ones which were timeline senders, which then ensures
                 # all of the rest get included in the state block (if we need to know
                 # about them).
-                types = None
-                filtered_types = None
+                state_filter = StateFilter.all()
 
                 state_at_previous_sync = yield self.get_state_at(
-                    room_id, stream_position=since_token, types=types,
-                    filtered_types=filtered_types,
+                    room_id, stream_position=since_token,
+                    state_filter=state_filter,
                 )
 
                 current_state_ids = yield self.store.get_state_ids_for_event(
-                    batch.events[-1].event_id, types=types,
-                    filtered_types=filtered_types,
+                    batch.events[-1].event_id, state_filter=state_filter,
                 )
 
                 state_ids = _calculate_state(
@@ -821,7 +806,7 @@ class SyncHandler(object):
             else:
                 state_ids = {}
                 if lazy_load_members:
-                    if types and batch.events:
+                    if members_to_fetch and batch.events:
                         # We're returning an incremental sync, with no
                         # "gap" since the previous sync, so normally there would be
                         # no state to return.
@@ -831,8 +816,12 @@ class SyncHandler(object):
                         # timeline here, and then dedupe any redundant ones below.
 
                         state_ids = yield self.store.get_state_ids_for_event(
-                            batch.events[0].event_id, types=types,
-                            filtered_types=None,  # we only want members!
+                            batch.events[0].event_id,
+                            # we only want members!
+                            state_filter=StateFilter.from_types(
+                                (EventTypes.Member, member)
+                                for member in members_to_fetch
+                            ),
                         )
 
             if lazy_load_members and not include_redundant_members:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 663934efd0..fcfe7857f6 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -33,6 +33,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
 
@@ -409,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
             room_id=room_id,
             user_id=requester.user.to_string(),
             at_token=at_token,
-            types=[(EventTypes.Member, None)],
+            state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
         )
 
         chunk = []
diff --git a/synapse/server.py b/synapse/server.py
index 3e9d3d8256..cf6b872cbd 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -207,6 +207,7 @@ class HomeServer(object):
         logger.info("Setting up.")
         with self.get_db_conn() as conn:
             self.datastore = self.DATASTORE_CLASS(conn, self)
+            conn.commit()
         logger.info("Finished setting up.")
 
     def get_reactor(self):
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index cfb687cb53..61a029a53c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -90,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore):
 class DirectoryStore(DirectoryWorkerStore):
     @defer.inlineCallbacks
     def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
-        """ Creates an associatin between  a room alias and room_id/servers
+        """ Creates an association between a room alias and room_id/servers
 
         Args:
             room_alias (RoomAlias)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index c780f55277..8881b009df 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -2089,7 +2089,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
         for sg in remaining_state_groups:
             logger.info("[purge] de-delta-ing remaining state group %s", sg)
             curr_state = self._get_state_groups_from_groups_txn(
-                txn, [sg], types=None
+                txn, [sg],
             )
             curr_state = curr_state[sg]
 
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 569e08cc54..5e34aa91ac 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -33,20 +33,29 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         self._clock = hs.get_clock()
         self.hs = hs
         self.reserved_users = ()
-        self.support_user = None
+        # Do not add more reserved users than the total allowable number
+        self._initialise_reserved_users(
+            dbconn.cursor(),
+            hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value],
+        )
 
-    @defer.inlineCallbacks
-    def initialise_reserved_users(self, threepids):
-        store = self.hs.get_datastore()
+    def _initialise_reserved_users(self, txn, threepids):
+        """Ensures that reserved threepids are accounted for in the MAU table, should
+        be called on start up.
+
+        Args:
+            txn (cursor):
+            threepids (list[dict]): List of threepid dicts to reserve
+        """
         reserved_user_list = []
 
-        # Do not add more reserved users than the total allowable number
-        for tp in threepids[:self.hs.config.max_mau_value]:
-            user_id = yield store.get_user_id_by_threepid(
+        for tp in threepids:
+            user_id = self.get_user_id_by_threepid_txn(
+                txn,
                 tp["medium"], tp["address"]
             )
             if user_id:
-                yield self.upsert_monthly_active_user(user_id)
+                self.upsert_monthly_active_user_txn(txn, user_id)
                 reserved_user_list.append(user_id)
             else:
                 logger.warning(
@@ -56,8 +65,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def reap_monthly_active_users(self):
-        """
-        Cleans out monthly active user table to ensure that no stale
+        """Cleans out monthly active user table to ensure that no stale
         entries exist.
 
         Returns:
@@ -166,12 +174,37 @@ class MonthlyActiveUsersStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def upsert_monthly_active_user(self, user_id):
+        """Updates or inserts the user into the monthly active user table, which
+        is used to track the current MAU usage of the server
+
+        Args:
+            user_id (str): user to add/update
         """
-            Updates or inserts monthly active user member
-            Arguments:
-                user_id (str): user to add/update
-            Deferred[bool]: True if a new entry was created, False if an
-                existing one was updated.
+        is_insert = yield self.runInteraction(
+            "upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
+            user_id
+        )
+
+        if is_insert:
+            self.user_last_seen_monthly_active.invalidate((user_id,))
+            self.get_monthly_active_count.invalidate(())
+
+    def upsert_monthly_active_user_txn(self, txn, user_id):
+        """Updates or inserts monthly active user member
+
+        Note that, after calling this method, it will generally be necessary
+        to invalidate the caches on user_last_seen_monthly_active and
+        get_monthly_active_count. We can't do that here, because we are running
+        in a database thread rather than the main thread, and we can't call
+        txn.call_after because txn may not be a LoggingTransaction.
+
+        Args:
+            txn (cursor):
+            user_id (str): user to add/update
+
+        Returns:
+            bool: True if a new entry was created, False if an
+            existing one was updated.
         """
         # Support user never to be included in MAU stats
         if user_id is self.hs.config.support_user:
@@ -180,8 +213,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         # never be a big table and alternative approaches (batching multiple
         # upserts into a single txn) introduced a lot of extra complexity.
         # See https://github.com/matrix-org/synapse/issues/3854 for more
-        is_insert = yield self._simple_upsert(
-            desc="upsert_monthly_active_user",
+        is_insert = self._simple_upsert_txn(
+            txn,
             table="monthly_active_users",
             keyvalues={
                 "user_id": user_id,
@@ -190,9 +223,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
                 "timestamp": int(self._clock.time_msec()),
             },
         )
-        if is_insert:
-            self.user_last_seen_monthly_active.invalidate((user_id,))
-            self.get_monthly_active_count.invalidate(())
+
+        return is_insert
 
     @cached(num_args=1)
     def user_last_seen_monthly_active(self, user_id):
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 2dd14aba1c..80d76bf9d7 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -474,17 +474,44 @@ class RegistrationStore(RegistrationWorkerStore,
 
     @defer.inlineCallbacks
     def get_user_id_by_threepid(self, medium, address):
-        ret = yield self._simple_select_one(
+        """Returns user id from threepid
+
+        Args:
+            medium (str): threepid medium e.g. email
+            address (str): threepid address e.g. me@example.com
+
+        Returns:
+            Deferred[str|None]: user id or None if no user id/threepid mapping exists
+        """
+        user_id = yield self.runInteraction(
+            "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
+            medium, address
+        )
+        defer.returnValue(user_id)
+
+    def get_user_id_by_threepid_txn(self, txn, medium, address):
+        """Returns user id from threepid
+
+        Args:
+            txn (cursor):
+            medium (str): threepid medium e.g. email
+            address (str): threepid address e.g. me@example.com
+
+        Returns:
+            str|None: user id or None if no user id/threepid mapping exists
+        """
+        ret = self._simple_select_one_txn(
+            txn,
             "user_threepids",
             {
                 "medium": medium,
                 "address": address
             },
-            ['user_id'], True, 'get_user_id_by_threepid'
+            ['user_id'], True
         )
         if ret:
-            defer.returnValue(ret['user_id'])
-        defer.returnValue(None)
+            return ret['user_id']
+        return None
 
     def user_delete_threepid(self, user_id, medium, address):
         return self._simple_delete(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 3f4cbd61c4..ef65929bb2 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -19,6 +19,8 @@ from collections import namedtuple
 from six import iteritems, itervalues
 from six.moves import range
 
+import attr
+
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
@@ -48,6 +50,318 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
         return len(self.delta_ids) if self.delta_ids else 0
 
 
+@attr.s(slots=True)
+class StateFilter(object):
+    """A filter used when querying for state.
+
+    Attributes:
+        types (dict[str, set[str]|None]): Map from type to set of state keys (or
+            None). This specifies which state_keys for the given type to fetch
+            from the DB. If None then all events with that type are fetched. If
+            the set is empty then no events with that type are fetched.
+        include_others (bool): Whether to fetch events with types that do not
+            appear in `types`.
+    """
+
+    types = attr.ib()
+    include_others = attr.ib(default=False)
+
+    def __attrs_post_init__(self):
+        # If `include_others` is set we canonicalise the filter by removing
+        # wildcards from the types dictionary
+        if self.include_others:
+            self.types = {
+                k: v for k, v in iteritems(self.types)
+                if v is not None
+            }
+
+    @staticmethod
+    def all():
+        """Creates a filter that fetches everything.
+
+        Returns:
+            StateFilter
+        """
+        return StateFilter(types={}, include_others=True)
+
+    @staticmethod
+    def none():
+        """Creates a filter that fetches nothing.
+
+        Returns:
+            StateFilter
+        """
+        return StateFilter(types={}, include_others=False)
+
+    @staticmethod
+    def from_types(types):
+        """Creates a filter that only fetches the given types
+
+        Args:
+            types (Iterable[tuple[str, str|None]]): A list of type and state
+                keys to fetch. A state_key of None fetches everything for
+                that type
+
+        Returns:
+            StateFilter
+        """
+        type_dict = {}
+        for typ, s in types:
+            if typ in type_dict:
+                if type_dict[typ] is None:
+                    continue
+
+            if s is None:
+                type_dict[typ] = None
+                continue
+
+            type_dict.setdefault(typ, set()).add(s)
+
+        return StateFilter(types=type_dict)
+
+    @staticmethod
+    def from_lazy_load_member_list(members):
+        """Creates a filter that returns all non-member events, plus the member
+        events for the given users
+
+        Args:
+            members (iterable[str]): Set of user IDs
+
+        Returns:
+            StateFilter
+        """
+        return StateFilter(
+            types={EventTypes.Member: set(members)},
+            include_others=True,
+        )
+
+    def return_expanded(self):
+        """Creates a new StateFilter where type wild cards have been removed
+        (except for memberships). The returned filter is a superset of the
+        current one, i.e. anything that passes the current filter will pass
+        the returned filter.
+
+        This helps the caching as the DictionaryCache knows if it has *all* the
+        state, but does not know if it has all of the keys of a particular type,
+        which makes wildcard lookups expensive unless we have a complete cache.
+        Hence, if we are doing a wildcard lookup, populate the cache fully so
+        that we can do an efficient lookup next time.
+
+        Note that since we have two caches, one for membership events and one for
+        other events, we can be a bit more clever than simply returning
+        `StateFilter.all()` if `has_wildcards()` is True.
+
+        We return a StateFilter where:
+            1. the list of membership events to return is the same
+            2. if there is a wildcard that matches non-member events we
+               return all non-member events
+
+        Returns:
+            StateFilter
+        """
+
+        if self.is_full():
+            # If we're going to return everything then there's nothing to do
+            return self
+
+        if not self.has_wildcards():
+            # If there are no wild cards, there's nothing to do
+            return self
+
+        if EventTypes.Member in self.types:
+            get_all_members = self.types[EventTypes.Member] is None
+        else:
+            get_all_members = self.include_others
+
+        has_non_member_wildcard = self.include_others or any(
+            state_keys is None
+            for t, state_keys in iteritems(self.types)
+            if t != EventTypes.Member
+        )
+
+        if not has_non_member_wildcard:
+            # If there are no non-member wild cards we can just return ourselves
+            return self
+
+        if get_all_members:
+            # We want to return everything.
+            return StateFilter.all()
+        else:
+            # We want to return all non-members, but only particular
+            # memberships
+            return StateFilter(
+                types={EventTypes.Member: self.types[EventTypes.Member]},
+                include_others=True,
+            )
+
+    def make_sql_filter_clause(self):
+        """Converts the filter to an SQL clause.
+
+        For example:
+
+            f = StateFilter.from_types([("m.room.create", "")])
+            clause, args = f.make_sql_filter_clause()
+            clause == "(type = ? AND state_key = ?)"
+            args == ['m.room.create', '']
+
+
+        Returns:
+            tuple[str, list]: The SQL string (may be empty) and arguments. An
+            empty SQL string is returned when the filter matches everything
+            (i.e. is "full").
+        """
+
+        where_clause = ""
+        where_args = []
+
+        if self.is_full():
+            return where_clause, where_args
+
+        if not self.include_others and not self.types:
+            # i.e. this is an empty filter, so we need to return a clause that
+            # will match nothing
+            return "1 = 2", []
+
+        # First we build up a lost of clauses for each type/state_key combo
+        clauses = []
+        for etype, state_keys in iteritems(self.types):
+            if state_keys is None:
+                clauses.append("(type = ?)")
+                where_args.append(etype)
+                continue
+
+            for state_key in state_keys:
+                clauses.append("(type = ? AND state_key = ?)")
+                where_args.extend((etype, state_key))
+
+        # This will match anything that appears in `self.types`
+        where_clause = " OR ".join(clauses)
+
+        # If we want to include stuff that's not in the types dict then we add
+        # a `OR type NOT IN (...)` clause to the end.
+        if self.include_others:
+            if where_clause:
+                where_clause += " OR "
+
+            where_clause += "type NOT IN (%s)" % (
+                ",".join(["?"] * len(self.types)),
+            )
+            where_args.extend(self.types)
+
+        return where_clause, where_args
+
+    def max_entries_returned(self):
+        """Returns the maximum number of entries this filter will return if
+        known, otherwise returns None.
+
+        For example a simple state filter asking for `("m.room.create", "")`
+        will return 1, whereas the default state filter will return None.
+
+        This is used to bail out early if the right number of entries have been
+        fetched.
+        """
+        if self.has_wildcards():
+            return None
+
+        return len(self.concrete_types())
+
+    def filter_state(self, state_dict):
+        """Returns the state filtered with by this StateFilter
+
+        Args:
+            state (dict[tuple[str, str], Any]): The state map to filter
+
+        Returns:
+            dict[tuple[str, str], Any]: The filtered state map
+        """
+        if self.is_full():
+            return dict(state_dict)
+
+        filtered_state = {}
+        for k, v in iteritems(state_dict):
+            typ, state_key = k
+            if typ in self.types:
+                state_keys = self.types[typ]
+                if state_keys is None or state_key in state_keys:
+                    filtered_state[k] = v
+            elif self.include_others:
+                filtered_state[k] = v
+
+        return filtered_state
+
+    def is_full(self):
+        """Whether this filter fetches everything or not
+
+        Returns:
+            bool
+        """
+        return self.include_others and not self.types
+
+    def has_wildcards(self):
+        """Whether the filter includes wildcards or is attempting to fetch
+        specific state.
+
+        Returns:
+            bool
+        """
+
+        return (
+            self.include_others
+            or any(
+                state_keys is None
+                for state_keys in itervalues(self.types)
+            )
+        )
+
+    def concrete_types(self):
+        """Returns a list of concrete type/state_keys (i.e. not None) that
+        will be fetched. This will be a complete list if `has_wildcards`
+        returns False, but otherwise will be a subset (or even empty).
+
+        Returns:
+            list[tuple[str,str]]
+        """
+        return [
+            (t, s)
+            for t, state_keys in iteritems(self.types)
+            if state_keys is not None
+            for s in state_keys
+        ]
+
+    def get_member_split(self):
+        """Return the filter split into two: one which assumes it's exclusively
+        matching against member state, and one which assumes it's matching
+        against non member state.
+
+        This is useful due to the returned filters giving correct results for
+        `is_full()`, `has_wildcards()`, etc, when operating against maps that
+        either exclusively contain member events or only contain non-member
+        events. (Which is the case when dealing with the member vs non-member
+        state caches).
+
+        Returns:
+            tuple[StateFilter, StateFilter]: The member and non member filters
+        """
+
+        if EventTypes.Member in self.types:
+            state_keys = self.types[EventTypes.Member]
+            if state_keys is None:
+                member_filter = StateFilter.all()
+            else:
+                member_filter = StateFilter({EventTypes.Member: state_keys})
+        elif self.include_others:
+            member_filter = StateFilter.all()
+        else:
+            member_filter = StateFilter.none()
+
+        non_member_filter = StateFilter(
+            types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
+            include_others=self.include_others,
+        )
+
+        return member_filter, non_member_filter
+
+
 # this inherits from EventsWorkerStore because it calls self.get_events
 class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """The parts of StateGroupStore that can be called from workers.
@@ -152,61 +466,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+    def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
+
         Args:
             room_id (str)
-            types (list[(Str, (Str|None))]): List of (type, state_key) tuples
-                which are used to filter the state fetched. `state_key` may be
-                None, which matches any `state_key`
-            filtered_types (list[Str]|None): List of types to apply the above filter to.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
         Returns:
-            deferred: dict of (type, state_key) -> event
+            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
+            event ID.
         """
 
-        include_other_types = False if filtered_types is None else True
-
         def _get_filtered_current_state_ids_txn(txn):
             results = {}
-            sql = """SELECT type, state_key, event_id FROM current_state_events
-                     WHERE room_id = ? %s"""
-            # Turns out that postgres doesn't like doing a list of OR's and
-            # is about 1000x slower, so we just issue a query for each specific
-            # type seperately.
-            if types:
-                clause_to_args = [
-                    (
-                        "AND type = ? AND state_key = ?",
-                        (etype, state_key)
-                    ) if state_key is not None else (
-                        "AND type = ?",
-                        (etype,)
-                    )
-                    for etype, state_key in types
-                ]
-
-                if include_other_types:
-                    unique_types = set(filtered_types)
-                    clause_to_args.append(
-                        (
-                            "AND type <> ? " * len(unique_types),
-                            list(unique_types)
-                        )
-                    )
-            else:
-                # If types is None we fetch all the state, and so just use an
-                # empty where clause with no extra args.
-                clause_to_args = [("", [])]
-            for where_clause, where_args in clause_to_args:
-                args = [room_id]
-                args.extend(where_args)
-                txn.execute(sql % (where_clause,), args)
-                for row in txn:
-                    typ, state_key, event_id = row
-                    key = (intern_string(typ), intern_string(state_key))
-                    results[key] = event_id
+            sql = """
+                SELECT type, state_key, event_id FROM current_state_events
+                WHERE room_id = ?
+            """
+
+            where_clause, where_args = state_filter.make_sql_filter_clause()
+
+            if where_clause:
+                sql += " AND (%s)" % (where_clause,)
+
+            args = [room_id]
+            args.extend(where_args)
+            txn.execute(sql, args)
+            for row in txn:
+                typ, state_key, event_id = row
+                key = (intern_string(typ), intern_string(state_key))
+                results[key] = event_id
+
             return results
 
         return self.runInteraction(
@@ -322,20 +616,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         })
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, types, members=None):
+    def _get_state_groups_from_groups(self, groups, state_filter):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
         Args:
             groups(list[int]): list of state group IDs to query
-            types (Iterable[str, str|None]|None): list of 2-tuples of the form
-                (`type`, `state_key`), where a `state_key` of `None` matches all
-                state_keys for the `type`. If None, all types are returned.
-            members (bool|None): If not None, then, in addition to any filtering
-                implied by types, the results are also filtered to only include
-                member events (if True), or to exclude member events (if False)
-
-        Returns:
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
         Returns:
             Deferred[dict[int, dict[tuple[str, str], str]]]:
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -346,19 +634,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn, chunk, types, members,
+                self._get_state_groups_from_groups_txn, chunk, state_filter,
             )
             results.update(res)
 
         defer.returnValue(results)
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, types=None, members=None,
+        self, txn, groups, state_filter=StateFilter.all(),
     ):
         results = {group: {} for group in groups}
 
-        if types is not None:
-            types = list(set(types))  # deduplicate types list
+        where_clause, where_args = state_filter.make_sql_filter_clause()
+
+        # Unless the filter clause is empty, we're going to append it after an
+        # existing where clause
+        if where_clause:
+            where_clause = " AND (%s)" % (where_clause,)
 
         if isinstance(self.database_engine, PostgresEngine):
             # Temporarily disable sequential scans in this transaction. This is
@@ -374,79 +666,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             # group for the given type, state_key.
             # This may return multiple rows per (type, state_key), but last_value
             # should be the same.
-            sql = ("""
+            sql = """
                 WITH RECURSIVE state(state_group) AS (
                     VALUES(?::bigint)
                     UNION ALL
                     SELECT prev_state_group FROM state_group_edges e, state s
                     WHERE s.state_group = e.state_group
                 )
-                SELECT type, state_key, last_value(event_id) OVER (
+                SELECT DISTINCT type, state_key, last_value(event_id) OVER (
                     PARTITION BY type, state_key ORDER BY state_group ASC
                     ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                 ) AS event_id FROM state_groups_state
                 WHERE state_group IN (
                     SELECT state_group FROM state
                 )
-                %s
-            """)
+            """
 
-            if members is True:
-                sql += " AND type = '%s'" % (EventTypes.Member,)
-            elif members is False:
-                sql += " AND type <> '%s'" % (EventTypes.Member,)
-
-            # Turns out that postgres doesn't like doing a list of OR's and
-            # is about 1000x slower, so we just issue a query for each specific
-            # type seperately.
-            if types is not None:
-                clause_to_args = [
-                    (
-                        "AND type = ? AND state_key = ?",
-                        (etype, state_key)
-                    ) if state_key is not None else (
-                        "AND type = ?",
-                        (etype,)
-                    )
-                    for etype, state_key in types
-                ]
-            else:
-                # If types is None we fetch all the state, and so just use an
-                # empty where clause with no extra args.
-                clause_to_args = [("", [])]
-
-            for where_clause, where_args in clause_to_args:
-                for group in groups:
-                    args = [group]
-                    args.extend(where_args)
+            for group in groups:
+                args = [group]
+                args.extend(where_args)
 
-                    txn.execute(sql % (where_clause,), args)
-                    for row in txn:
-                        typ, state_key, event_id = row
-                        key = (typ, state_key)
-                        results[group][key] = event_id
+                txn.execute(sql + where_clause, args)
+                for row in txn:
+                    typ, state_key, event_id = row
+                    key = (typ, state_key)
+                    results[group][key] = event_id
         else:
-            where_args = []
-            where_clauses = []
-            wildcard_types = False
-            if types is not None:
-                for typ in types:
-                    if typ[1] is None:
-                        where_clauses.append("(type = ?)")
-                        where_args.append(typ[0])
-                        wildcard_types = True
-                    else:
-                        where_clauses.append("(type = ? AND state_key = ?)")
-                        where_args.extend([typ[0], typ[1]])
-
-                where_clause = "AND (%s)" % (" OR ".join(where_clauses))
-            else:
-                where_clause = ""
-
-            if members is True:
-                where_clause += " AND type = '%s'" % EventTypes.Member
-            elif members is False:
-                where_clause += " AND type <> '%s'" % EventTypes.Member
+            max_entries_returned = state_filter.max_entries_returned()
 
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
@@ -460,12 +706,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     # without the right indices (which we can't add until
                     # after we finish deduping state, which requires this func)
                     args = [next_group]
-                    if types:
-                        args.extend(where_args)
+                    args.extend(where_args)
 
                     txn.execute(
                         "SELECT type, state_key, event_id FROM state_groups_state"
-                        " WHERE state_group = ? %s" % (where_clause,),
+                        " WHERE state_group = ? " + where_clause,
                         args
                     )
                     results[group].update(
@@ -481,9 +726,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     # wildcards (i.e. Nones) in which case we have to do an exhaustive
                     # search
                     if (
-                        types is not None and
-                        not wildcard_types and
-                        len(results[group]) == len(types)
+                        max_entries_returned is not None and
+                        len(results[group]) == max_entries_returned
                     ):
                         break
 
@@ -498,20 +742,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return results
 
     @defer.inlineCallbacks
-    def get_state_for_events(self, event_ids, types, filtered_types=None):
+    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
         """Given a list of event_ids and type tuples, return a list of state
-        dicts for each event. The state dicts will only have the type/state_keys
-        that are in the `types` list.
+        dicts for each event.
 
         Args:
             event_ids (list[string])
-            types (list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-                May be None, which matches any key.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
             deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
@@ -521,7 +759,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
+        group_to_state = yield self._get_state_for_groups(groups, state_filter)
 
         state_event_map = yield self.get_events(
             [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@@ -540,20 +778,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
+    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
         of the state events (as opposed to the events themselves)
 
         Args:
             event_ids(list(str)): events whose state should be returned
-            types(list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-                May be None, which matches any key.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
             A deferred dict from event_id -> (type, state_key) -> event_id
@@ -563,7 +796,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
         groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
+        group_to_state = yield self._get_state_for_groups(groups, state_filter)
 
         event_to_state = {
             event_id: group_to_state[group]
@@ -573,45 +806,35 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         defer.returnValue({event: event_to_state[event] for event in event_ids})
 
     @defer.inlineCallbacks
-    def get_state_for_event(self, event_id, types=None, filtered_types=None):
+    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
         """
         Get the state dict corresponding to a particular event
 
         Args:
             event_id(str): event whose state should be returned
-            types(list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-                May be None, which matches any key.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_for_events([event_id], types, filtered_types)
+        state_map = yield self.get_state_for_events([event_id], state_filter)
         defer.returnValue(state_map[event_id])
 
     @defer.inlineCallbacks
-    def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
+    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
         """
         Get the state dict corresponding to a particular event
 
         Args:
             event_id(str): event whose state should be returned
-            types(list[(str, str|None)]|None): List of (type, state_key) tuples
-                which are used to filter the state fetched. If `state_key` is None,
-                all events are returned of the given type.
-                May be None, which matches any key.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
+        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
         defer.returnValue(state_map[event_id])
 
     @cached(max_entries=50000)
@@ -642,18 +865,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
 
-    def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
+    def _get_state_for_group_using_cache(self, cache, group, state_filter):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Args:
             cache(DictionaryCache): the state group cache to use
             group(int): The state group to lookup
-            types(list[str, str|None]): List of 2-tuples of the form
-                (`type`, `state_key`), where a `state_key` of `None` matches all
-                state_keys for the `type`.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns 2-tuple (`state_dict`, `got_all`).
         `got_all` is a bool indicating if we successfully retrieved all
@@ -662,124 +881,102 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         """
         is_all, known_absent, state_dict_ids = cache.get(group)
 
-        type_to_key = {}
+        if is_all or state_filter.is_full():
+            # Either we have everything or want everything, either way
+            # `is_all` tells us whether we've gotten everything.
+            return state_filter.filter_state(state_dict_ids), is_all
 
         # tracks whether any of our requested types are missing from the cache
         missing_types = False
 
-        for typ, state_key in types:
-            key = (typ, state_key)
-
-            if (
-                state_key is None or
-                (filtered_types is not None and typ not in filtered_types)
-            ):
-                type_to_key[typ] = None
-                # we mark the type as missing from the cache because
-                # when the cache was populated it might have been done with a
-                # restricted set of state_keys, so the wildcard will not work
-                # and the cache may be incomplete.
-                missing_types = True
-            else:
-                if type_to_key.get(typ, object()) is not None:
-                    type_to_key.setdefault(typ, set()).add(state_key)
-
+        if state_filter.has_wildcards():
+            # We don't know if we fetched all the state keys for the types in
+            # the filter that are wildcards, so we have to assume that we may
+            # have missed some.
+            missing_types = True
+        else:
+            # There aren't any wild cards, so `concrete_types()` returns the
+            # complete list of event types we're wanting.
+            for key in state_filter.concrete_types():
                 if key not in state_dict_ids and key not in known_absent:
                     missing_types = True
+                    break
 
-        sentinel = object()
-
-        def include(typ, state_key):
-            valid_state_keys = type_to_key.get(typ, sentinel)
-            if valid_state_keys is sentinel:
-                return filtered_types is not None and typ not in filtered_types
-            if valid_state_keys is None:
-                return True
-            if state_key in valid_state_keys:
-                return True
-            return False
-
-        got_all = is_all
-        if not got_all:
-            # the cache is incomplete. We may still have got all the results we need, if
-            # we don't have any wildcards in the match list.
-            if not missing_types and filtered_types is None:
-                got_all = True
-
-        return {
-            k: v for k, v in iteritems(state_dict_ids)
-            if include(k[0], k[1])
-        }, got_all
-
-    def _get_all_state_from_cache(self, cache, group):
-        """Checks if group is in cache. See `_get_state_for_groups`
-
-        Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
-        indicating if we successfully retrieved all requests state from the
-        cache, if False we need to query the DB for the missing state.
-
-        Args:
-            cache(DictionaryCache): the state group cache to use
-            group: The state group to lookup
-        """
-        is_all, _, state_dict_ids = cache.get(group)
-
-        return state_dict_ids, is_all
+        return state_filter.filter_state(state_dict_ids), not missing_types
 
     @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, types=None, filtered_types=None):
+    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
         Args:
             groups (iterable[int]): list of state groups for which we want
                 to get the state.
-            types (None|iterable[(str, None|str)]):
-                indicates the state type/keys required. If None, the whole
-                state is fetched and returned.
-
-                Otherwise, each entry should be a `(type, state_key)` tuple to
-                include in the response. A `state_key` of None is a wildcard
-                meaning that we require all state with that type.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
-
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
         Returns:
             Deferred[dict[int, dict[tuple[str, str], str]]]:
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
-        if types is not None:
-            non_member_types = [t for t in types if t[0] != EventTypes.Member]
 
-            if filtered_types is not None and EventTypes.Member not in filtered_types:
-                # we want all of the membership events
-                member_types = None
-            else:
-                member_types = [t for t in types if t[0] == EventTypes.Member]
-
-        else:
-            non_member_types = None
-            member_types = None
+        member_filter, non_member_filter = state_filter.get_member_split()
 
-        non_member_state = yield self._get_state_for_groups_using_cache(
-            groups, self._state_group_cache, non_member_types, filtered_types,
+        # Now we look them up in the member and non-member caches
+        non_member_state, incomplete_groups_nm, = (
+            yield self._get_state_for_groups_using_cache(
+                groups, self._state_group_cache,
+                state_filter=non_member_filter,
+            )
         )
-        # XXX: we could skip this entirely if member_types is []
-        member_state = yield self._get_state_for_groups_using_cache(
-            # we set filtered_types=None as member_state only ever contain members.
-            groups, self._state_group_members_cache, member_types, None,
+
+        member_state, incomplete_groups_m, = (
+            yield self._get_state_for_groups_using_cache(
+                groups, self._state_group_members_cache,
+                state_filter=member_filter,
+            )
         )
 
-        state = non_member_state
+        state = dict(non_member_state)
         for group in groups:
             state[group].update(member_state[group])
 
+        # Now fetch any missing groups from the database
+
+        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+        if not incomplete_groups:
+            defer.returnValue(state)
+
+        cache_sequence_nm = self._state_group_cache.sequence
+        cache_sequence_m = self._state_group_members_cache.sequence
+
+        # Help the cache hit ratio by expanding the filter a bit
+        db_state_filter = state_filter.return_expanded()
+
+        group_to_state_dict = yield self._get_state_groups_from_groups(
+            list(incomplete_groups),
+            state_filter=db_state_filter,
+        )
+
+        # Now lets update the caches
+        self._insert_into_cache(
+            group_to_state_dict,
+            db_state_filter,
+            cache_seq_num_members=cache_sequence_m,
+            cache_seq_num_non_members=cache_sequence_nm,
+        )
+
+        # And finally update the result dict, by filtering out any extra
+        # stuff we pulled out of the database.
+        for group, group_state_dict in iteritems(group_to_state_dict):
+            # We just replace any existing entries, as we will have loaded
+            # everything we need from the database anyway.
+            state[group] = state_filter.filter_state(group_state_dict)
+
         defer.returnValue(state)
 
-    @defer.inlineCallbacks
     def _get_state_for_groups_using_cache(
-        self, groups, cache, types=None, filtered_types=None
+        self, groups, cache, state_filter,
     ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
@@ -790,89 +987,85 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             cache (DictionaryCache): the cache of group ids to state dicts which
                 we will pass through - either the normal state cache or the specific
                 members state cache.
-            types (None|iterable[(str, None|str)]):
-                indicates the state type/keys required. If None, the whole
-                state is fetched and returned.
-
-                Otherwise, each entry should be a `(type, state_key)` tuple to
-                include in the response. A `state_key` of None is a wildcard
-                meaning that we require all state with that type.
-            filtered_types(list[str]|None): Only apply filtering via `types` to this
-                list of event types.  Other types of events are returned unfiltered.
-                If None, `types` filtering is applied to all events.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
 
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
+            dict of state_group_id -> (dict of (type, state_key) -> event id)
+            of entries in the cache, and the state group ids either missing
+            from the cache or incomplete.
         """
-        if types:
-            types = frozenset(types)
         results = {}
-        missing_groups = []
-        if types is not None:
-            for group in set(groups):
-                state_dict_ids, got_all = self._get_some_state_from_cache(
-                    cache, group, types, filtered_types
-                )
-                results[group] = state_dict_ids
+        incomplete_groups = set()
+        for group in set(groups):
+            state_dict_ids, got_all = self._get_state_for_group_using_cache(
+                cache, group, state_filter
+            )
+            results[group] = state_dict_ids
 
-                if not got_all:
-                    missing_groups.append(group)
-        else:
-            for group in set(groups):
-                state_dict_ids, got_all = self._get_all_state_from_cache(
-                    cache, group
-                )
+            if not got_all:
+                incomplete_groups.add(group)
 
-                results[group] = state_dict_ids
+        return results, incomplete_groups
 
-                if not got_all:
-                    missing_groups.append(group)
+    def _insert_into_cache(self, group_to_state_dict, state_filter,
+                           cache_seq_num_members, cache_seq_num_non_members):
+        """Inserts results from querying the database into the relevant cache.
 
-        if missing_groups:
-            # Okay, so we have some missing_types, let's fetch them.
-            cache_seq_num = cache.sequence
+        Args:
+            group_to_state_dict (dict): The new entries pulled from database.
+                Map from state group to state dict
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+            cache_seq_num_members (int): Sequence number of member cache since
+                last lookup in cache
+            cache_seq_num_non_members (int): Sequence number of member cache since
+                last lookup in cache
+        """
 
-            # the DictionaryCache knows if it has *all* the state, but
-            # does not know if it has all of the keys of a particular type,
-            # which makes wildcard lookups expensive unless we have a complete
-            # cache. Hence, if we are doing a wildcard lookup, populate the
-            # cache fully so that we can do an efficient lookup next time.
+        # We need to work out which types we've fetched from the DB for the
+        # member vs non-member caches. This should be as accurate as possible,
+        # but can be an underestimate (e.g. when we have wild cards)
 
-            if filtered_types or (types and any(k is None for (t, k) in types)):
-                types_to_fetch = None
-            else:
-                types_to_fetch = types
+        member_filter, non_member_filter = state_filter.get_member_split()
+        if member_filter.is_full():
+            # We fetched all member events
+            member_types = None
+        else:
+            # `concrete_types()` will only return a subset when there are wild
+            # cards in the filter, but that's fine.
+            member_types = member_filter.concrete_types()
 
-            group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types_to_fetch, cache == self._state_group_members_cache,
-            )
+        if non_member_filter.is_full():
+            # We fetched all non member events
+            non_member_types = None
+        else:
+            non_member_types = non_member_filter.concrete_types()
+
+        for group, group_state_dict in iteritems(group_to_state_dict):
+            state_dict_members = {}
+            state_dict_non_members = {}
 
-            for group, group_state_dict in iteritems(group_to_state_dict):
-                state_dict = results[group]
-
-                # update the result, filtering by `types`.
-                if types:
-                    for k, v in iteritems(group_state_dict):
-                        (typ, _) = k
-                        if (
-                            (k in types or (typ, None) in types) or
-                            (filtered_types and typ not in filtered_types)
-                        ):
-                            state_dict[k] = v
+            for k, v in iteritems(group_state_dict):
+                if k[0] == EventTypes.Member:
+                    state_dict_members[k] = v
                 else:
-                    state_dict.update(group_state_dict)
-
-                # update the cache with all the things we fetched from the
-                # database.
-                cache.update(
-                    cache_seq_num,
-                    key=group,
-                    value=group_state_dict,
-                    fetched_keys=types_to_fetch,
-                )
+                    state_dict_non_members[k] = v
 
-        defer.returnValue(results)
+            self._state_group_members_cache.update(
+                cache_seq_num_members,
+                key=group,
+                value=state_dict_members,
+                fetched_keys=member_types,
+            )
+
+            self._state_group_cache.update(
+                cache_seq_num_non_members,
+                key=group,
+                value=state_dict_non_members,
+                fetched_keys=non_member_types,
+            )
 
     def store_state_group(self, event_id, room_id, prev_group, delta_ids,
                           current_state_ids):
@@ -1181,12 +1374,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
                         continue
 
                     prev_state = self._get_state_groups_from_groups_txn(
-                        txn, [prev_group], types=None
+                        txn, [prev_group],
                     )
                     prev_state = prev_state[prev_group]
 
                     curr_state = self._get_state_groups_from_groups_txn(
-                        txn, [state_group], types=None
+                        txn, [state_group],
                     )
                     curr_state = curr_state[state_group]
 
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index b8981cdb6d..cd25e07719 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -64,6 +64,11 @@ class UserDirectoryStore(SQLBaseStore):
                 or publically joinable
             user_ids (list(str)): Users to add
         """
+
+        support_id = self.hs.config.support_user_id
+        if support_id in user_ids:
+            user_ids.remove(support_id)
+
         yield self._simple_insert_many(
             table="users_in_public_rooms",
             values=[
@@ -86,8 +91,7 @@ class UserDirectoryStore(SQLBaseStore):
             users_with_profile (dict): Users to add to directory in the form of
                 mapping of user_id -> ProfileInfo
         """
-        # TODO Filter out support user
-
+        users_with_profile.pop(self.hs.config.support_user_id, None)
 
         if isinstance(self.database_engine, PostgresEngine):
             # We weight the loclpart most highly, then display name and finally
@@ -149,16 +153,19 @@ class UserDirectoryStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def update_user_in_user_dir(self, user_id, room_id):
-        yield self._simple_update_one(
-            table="user_directory",
-            keyvalues={"user_id": user_id},
-            updatevalues={"room_id": room_id},
-            desc="update_user_in_user_dir",
-        )
-        self.get_user_in_directory.invalidate((user_id,))
+        if user_id is not self.hs.config.support_user_id:
+            yield self._simple_update_one(
+                table="user_directory",
+                keyvalues={"user_id": user_id},
+                updatevalues={"room_id": room_id},
+                desc="update_user_in_user_dir",
+            )
+            self.get_user_in_directory.invalidate((user_id,))
 
     def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id):
         def _update_profile_in_user_dir_txn(txn):
+            if user_is is self.hs.config.support_user_id:
+                return
             new_entry = self._simple_upsert_txn(
                 txn,
                 table="user_directory",
@@ -216,20 +223,20 @@ class UserDirectoryStore(SQLBaseStore):
                 raise Exception("Unrecognized database engine")
 
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
-
         return self.runInteraction(
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )
 
     @defer.inlineCallbacks
     def update_user_in_public_user_list(self, user_id, room_id):
-        yield self._simple_update_one(
-            table="users_in_public_rooms",
-            keyvalues={"user_id": user_id},
-            updatevalues={"room_id": room_id},
-            desc="update_user_in_public_user_list",
-        )
-        self.get_user_in_public_room.invalidate((user_id,))
+        if user_is is not self.hs.config.support_user_id:
+            yield self._simple_update_one(
+                table="users_in_public_rooms",
+                keyvalues={"user_id": user_id},
+                updatevalues={"room_id": room_id},
+                desc="update_user_in_public_user_list",
+            )
+            self.get_user_in_public_room.invalidate((user_id,))
 
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
@@ -332,7 +339,7 @@ class UserDirectoryStore(SQLBaseStore):
         rows = yield self._execute("get_all_local_users", None, sql)
         defer.returnValue([name for name, in rows])
 
-    def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
+    def add_users_who_share_room(self, room_id, share_private, user_id_tuples_x):
         """Insert entries into the users_who_share_rooms table. The first
         user should be a local user.
 
@@ -342,6 +349,9 @@ class UserDirectoryStore(SQLBaseStore):
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
         def _add_users_who_share_room_txn(txn):
+            support_user = self.hs.config.support_user_id
+            user_id_tuples = filter(lambda x: support_user not in x, user_id_tuples_x)
+
             self._simple_insert_many_txn(
                 txn,
                 table="users_who_share_rooms",
@@ -378,6 +388,7 @@ class UserDirectoryStore(SQLBaseStore):
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
         def _update_users_who_share_room_txn(txn):
+
             sql = """
                 UPDATE users_who_share_rooms
                 SET room_id = ?, share_private = ?
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9a8fae0497..0ae7e2ef3b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+import re
 from itertools import islice
 
 import attr
@@ -138,3 +139,27 @@ def log_failure(failure, msg, consumeErrors=True):
 
     if not consumeErrors:
         return failure
+
+
+def glob_to_regex(glob):
+    """Converts a glob to a compiled regex object.
+
+    The regex is anchored at the beginning and end of the string.
+
+    Args:
+        glob (str)
+
+    Returns:
+        re.RegexObject
+    """
+    res = ''
+    for c in glob:
+        if c == '*':
+            res = res + '.*'
+        elif c == '?':
+            res = res + '.'
+        else:
+            res = res + re.escape(c)
+
+    # \A anchors at start of string, \Z at end of string
+    return re.compile(r"\A" + res + r"\Z", re.IGNORECASE)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 43f48196be..0281a7c919 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events.utils import prune_event
+from synapse.storage.state import StateFilter
 from synapse.types import get_domain_from_id
 
 logger = logging.getLogger(__name__)
@@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
     )
     event_id_to_state = yield store.get_state_for_events(
         frozenset(e.event_id for e in events),
-        types=types,
+        state_filter=StateFilter.from_types(types),
     )
 
     ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
@@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events):
     # need to check membership (as we know the server is in the room).
     event_to_state_ids = yield store.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
-        types=(
-            (EventTypes.RoomHistoryVisibility, ""),
+        state_filter=StateFilter.from_types(
+            types=((EventTypes.RoomHistoryVisibility, ""),),
         )
     )
 
@@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events):
     # of the history vis and membership state at those events.
     event_to_state_ids = yield store.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
-        types=(
-            (EventTypes.RoomHistoryVisibility, ""),
-            (EventTypes.Member, None),
+        state_filter=StateFilter.from_types(
+            types=(
+                (EventTypes.RoomHistoryVisibility, ""),
+                (EventTypes.Member, None),
+            ),
         )
     )
 
diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py
new file mode 100644
index 0000000000..f37a17d618
--- /dev/null
+++ b/tests/config/test_room_directory.py
@@ -0,0 +1,67 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import yaml
+
+from synapse.config.room_directory import RoomDirectoryConfig
+
+from tests import unittest
+
+
+class RoomDirectoryConfigTestCase(unittest.TestCase):
+    def test_alias_creation_acl(self):
+        config = yaml.load("""
+        alias_creation_rules:
+            - user_id: "*bob*"
+              alias: "*"
+              action: "deny"
+            - user_id: "*"
+              alias: "#unofficial_*"
+              action: "allow"
+            - user_id: "@foo*:example.com"
+              alias: "*"
+              action: "allow"
+            - user_id: "@gah:example.com"
+              alias: "#goo:example.com"
+              action: "allow"
+        """)
+
+        rd_config = RoomDirectoryConfig()
+        rd_config.read_config(config)
+
+        self.assertFalse(rd_config.is_alias_creation_allowed(
+            user_id="@bob:example.com",
+            alias="#test:example.com",
+        ))
+
+        self.assertTrue(rd_config.is_alias_creation_allowed(
+            user_id="@test:example.com",
+            alias="#unofficial_st:example.com",
+        ))
+
+        self.assertTrue(rd_config.is_alias_creation_allowed(
+            user_id="@foobar:example.com",
+            alias="#test:example.com",
+        ))
+
+        self.assertTrue(rd_config.is_alias_creation_allowed(
+            user_id="@gah:example.com",
+            alias="#goo:example.com",
+        ))
+
+        self.assertFalse(rd_config.is_alias_creation_allowed(
+            user_id="@test:example.com",
+            alias="#test:example.com",
+        ))
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ec7355688b..8ae6556c0a 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -18,7 +18,9 @@ from mock import Mock
 
 from twisted.internet import defer
 
+from synapse.config.room_directory import RoomDirectoryConfig
 from synapse.handlers.directory import DirectoryHandler
+from synapse.rest.client.v1 import directory, room
 from synapse.types import RoomAlias
 
 from tests import unittest
@@ -102,3 +104,49 @@ class DirectoryTestCase(unittest.TestCase):
         )
 
         self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
+
+
+class TestCreateAliasACL(unittest.HomeserverTestCase):
+    user_id = "@test:test"
+
+    servlets = [directory.register_servlets, room.register_servlets]
+
+    def prepare(self, hs, reactor, clock):
+        # We cheekily override the config to add custom alias creation rules
+        config = {}
+        config["alias_creation_rules"] = [
+            {
+                "user_id": "*",
+                "alias": "#unofficial_*",
+                "action": "allow",
+            }
+        ]
+
+        rd_config = RoomDirectoryConfig()
+        rd_config.read_config(config)
+
+        self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed
+
+        return hs
+
+    def test_denied(self):
+        room_id = self.helper.create_room_as(self.user_id)
+
+        request, channel = self.make_request(
+            "PUT",
+            b"directory/room/%23test%3Atest",
+            ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+        )
+        self.render(request)
+        self.assertEquals(403, channel.code, channel.result)
+
+    def test_allowed(self):
+        room_id = self.helper.create_room_as(self.user_id)
+
+        request, channel = self.make_request(
+            "PUT",
+            b"directory/room/%23unofficial_test%3Atest",
+            ('{"room_id":"%s"}' % (room_id,)).encode('ascii'),
+        )
+        self.render(request)
+        self.assertEquals(200, channel.code, channel.result)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 7b4ade3dfb..3e9a190727 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import ResourceLimitError
 from synapse.handlers.register import RegistrationHandler
-from synapse.types import UserID, create_requester
+from synapse.types import RoomAlias, UserID, create_requester
 
 from tests.utils import setup_test_homeserver
 
@@ -41,30 +41,27 @@ class RegistrationTestCase(unittest.TestCase):
         self.mock_captcha_client = Mock()
         self.hs = yield setup_test_homeserver(
             self.addCleanup,
-            handlers=None,
-            http_client=None,
             expire_access_token=True,
-            profile_handler=Mock(),
         )
         self.macaroon_generator = Mock(
             generate_access_token=Mock(return_value='secret')
         )
         self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
-        self.hs.handlers = RegistrationHandlers(self.hs)
         self.handler = self.hs.get_handlers().registration_handler
         self.store = self.hs.get_datastore()
         self.hs.config.max_mau_value = 50
         self.lots_of_users = 100
         self.small_number_of_users = 1
 
+        self.requester = create_requester("@requester:test")
+
     @defer.inlineCallbacks
     def test_user_is_created_and_logged_in_if_doesnt_exist(self):
-        local_part = "someone"
-        display_name = "someone"
-        user_id = "@someone:test"
-        requester = create_requester("@as:test")
+        frank = UserID.from_string("@frank:test")
+        user_id = frank.to_string()
+        requester = create_requester(user_id)
         result_user_id, result_token = yield self.handler.get_or_create_user(
-            requester, local_part, display_name
+            requester, frank.localpart, "Frankie"
         )
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
@@ -78,12 +75,11 @@ class RegistrationTestCase(unittest.TestCase):
             token="jkv;g498752-43gj['eamb!-5",
             password_hash=None,
         )
-        local_part = "frank"
-        display_name = "Frank"
-        user_id = "@frank:test"
-        requester = create_requester("@as:test")
+        local_part = frank.localpart
+        user_id = frank.to_string()
+        requester = create_requester(user_id)
         result_user_id, result_token = yield self.handler.get_or_create_user(
-            requester, local_part, display_name
+            requester, local_part, None
         )
         self.assertEquals(result_user_id, user_id)
         self.assertEquals(result_token, 'secret')
@@ -92,7 +88,7 @@ class RegistrationTestCase(unittest.TestCase):
     def test_mau_limits_when_disabled(self):
         self.hs.config.limit_usage_by_mau = False
         # Ensure does not throw exception
-        yield self.handler.get_or_create_user("requester", 'a', "display_name")
+        yield self.handler.get_or_create_user(self.requester, 'a', "display_name")
 
     @defer.inlineCallbacks
     def test_get_or_create_user_mau_not_blocked(self):
@@ -101,7 +97,7 @@ class RegistrationTestCase(unittest.TestCase):
             return_value=defer.succeed(self.hs.config.max_mau_value - 1)
         )
         # Ensure does not throw exception
-        yield self.handler.get_or_create_user("@user:server", 'c', "User")
+        yield self.handler.get_or_create_user(self.requester, 'c', "User")
 
     @defer.inlineCallbacks
     def test_get_or_create_user_mau_blocked(self):
@@ -110,13 +106,13 @@ class RegistrationTestCase(unittest.TestCase):
             return_value=defer.succeed(self.lots_of_users)
         )
         with self.assertRaises(ResourceLimitError):
-            yield self.handler.get_or_create_user("requester", 'b', "display_name")
+            yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
 
         self.store.get_monthly_active_count = Mock(
             return_value=defer.succeed(self.hs.config.max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
-            yield self.handler.get_or_create_user("requester", 'b', "display_name")
+            yield self.handler.get_or_create_user(self.requester, 'b', "display_name")
 
     @defer.inlineCallbacks
     def test_register_mau_blocked(self):
@@ -147,3 +143,44 @@ class RegistrationTestCase(unittest.TestCase):
         )
         with self.assertRaises(ResourceLimitError):
             yield self.handler.register_saml2(localpart="local_part")
+
+    @defer.inlineCallbacks
+    def test_auto_create_auto_join_rooms(self):
+        room_alias_str = "#room:test"
+        self.hs.config.auto_join_rooms = [room_alias_str]
+        res = yield self.handler.register(localpart='jeff')
+        rooms = yield self.store.get_rooms_for_user(res[0])
+
+        directory_handler = self.hs.get_handlers().directory_handler
+        room_alias = RoomAlias.from_string(room_alias_str)
+        room_id = yield directory_handler.get_association(room_alias)
+
+        self.assertTrue(room_id['room_id'] in rooms)
+        self.assertEqual(len(rooms), 1)
+
+    @defer.inlineCallbacks
+    def test_auto_create_auto_join_rooms_with_no_rooms(self):
+        self.hs.config.auto_join_rooms = []
+        frank = UserID.from_string("@frank:test")
+        res = yield self.handler.register(frank.localpart)
+        self.assertEqual(res[0], frank.to_string())
+        rooms = yield self.store.get_rooms_for_user(res[0])
+        self.assertEqual(len(rooms), 0)
+
+    @defer.inlineCallbacks
+    def test_auto_create_auto_join_where_room_is_another_domain(self):
+        self.hs.config.auto_join_rooms = ["#room:another"]
+        frank = UserID.from_string("@frank:test")
+        res = yield self.handler.register(frank.localpart)
+        self.assertEqual(res[0], frank.to_string())
+        rooms = yield self.store.get_rooms_for_user(res[0])
+        self.assertEqual(len(rooms), 0)
+
+    @defer.inlineCallbacks
+    def test_auto_create_auto_join_where_auto_create_is_false(self):
+        self.hs.config.autocreate_auto_join_rooms = False
+        room_alias_str = "#room:test"
+        self.hs.config.auto_join_rooms = [room_alias_str]
+        res = yield self.handler.register(localpart='jeff')
+        rooms = yield self.store.get_rooms_for_user(res[0])
+        self.assertEqual(len(rooms), 0)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 1d9119a8de..79ce092cba 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -54,7 +54,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
         now = int(self.hs.get_clock().time_msec())
         self.store.user_add_threepid(user1, "email", user1_email, now, now)
         self.store.user_add_threepid(user2, "email", user2_email, now, now)
-        self.store.initialise_reserved_users(threepids)
+
+        self.store.runInteraction(
+            "initialise", self.store._initialise_reserved_users, threepids
+        )
         self.pump()
 
         active_count = self.store.get_monthly_active_count()
@@ -201,7 +204,10 @@ class MonthlyActiveUsersTestCase(HomeserverTestCase):
             {'medium': 'email', 'address': user2_email},
         ]
         self.hs.config.mau_limits_reserved_threepids = threepids
-        self.store.initialise_reserved_users(threepids)
+        self.store.runInteraction(
+            "initialise", self.store._initialise_reserved_users, threepids
+        )
+
         self.pump()
         count = self.store.get_registered_reserved_users_count()
         self.assertEquals(self.get_success(count), 0)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index b9c5b39d59..086a39d834 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -18,6 +18,7 @@ import logging
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
+from synapse.storage.state import StateFilter
 from synapse.types import RoomID, UserID
 
 import tests.unittest
@@ -148,7 +149,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we get the full state as of the final event
         state = yield self.store.get_state_for_event(
-            e5.event_id, None, filtered_types=None
+            e5.event_id,
         )
 
         self.assertIsNotNone(e4)
@@ -166,33 +167,35 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can filter to the m.room.name event (with a '' state key)
         state = yield self.store.get_state_for_event(
-            e5.event_id, [(EventTypes.Name, '')], filtered_types=None
+            e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
         state = yield self.store.get_state_for_event(
-            e5.event_id, [(EventTypes.Name, None)], filtered_types=None
+            e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
         state = yield self.store.get_state_for_event(
-            e5.event_id, [(EventTypes.Member, None)], filtered_types=None
+            e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
         )
 
         self.assertStateMapEqual(
             {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
         )
 
-        # check we can use filtered_types to grab a specific room member
-        # without filtering out the other event types
+        # check we can grab a specific room member without filtering out the
+        # other event types
         state = yield self.store.get_state_for_event(
             e5.event_id,
-            [(EventTypes.Member, self.u_alice.to_string())],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: {self.u_alice.to_string()}},
+                include_others=True,
+            )
         )
 
         self.assertStateMapEqual(
@@ -204,10 +207,12 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state,
         )
 
-        # check that types=[], filtered_types=[EventTypes.Member]
-        # doesn't return all members
+        # check that we can grab everything except members
         state = yield self.store.get_state_for_event(
-            e5.event_id, [], filtered_types=[EventTypes.Member]
+            e5.event_id, state_filter=StateFilter(
+                types={EventTypes.Member: set()},
+                include_others=True,
+            ),
         )
 
         self.assertStateMapEqual(
@@ -215,16 +220,21 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         #######################################################
-        # _get_some_state_from_cache tests against a full cache
+        # _get_state_for_group_using_cache tests against a full cache
         #######################################################
 
         room_id = self.room.to_string()
         group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
         group = list(group_ids.keys())[0]
 
-        # test _get_some_state_from_cache correctly filters out members with types=[]
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
-            self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
+        # test _get_state_for_group_using_cache correctly filters out members
+        # with types=[]
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+            self.store._state_group_cache, group,
+            state_filter=StateFilter(
+                types={EventTypes.Member: set()},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
@@ -236,22 +246,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_members_cache,
             group,
-            [],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: set()},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({}, state_dict)
 
-        # test _get_some_state_from_cache correctly filters in members with wildcard types
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        # test _get_state_for_group_using_cache correctly filters in members
+        # with wildcard types
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_cache,
             group,
-            [(EventTypes.Member, None)],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: None},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
@@ -263,11 +278,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_members_cache,
             group,
-            [(EventTypes.Member, None)],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: None},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
@@ -280,12 +297,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        # test _get_some_state_from_cache correctly filters in members with specific types
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        # test _get_state_for_group_using_cache correctly filters in members
+        # with specific types
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_cache,
             group,
-            [(EventTypes.Member, e5.state_key)],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: {e5.state_key}},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
@@ -297,23 +317,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_members_cache,
             group,
-            [(EventTypes.Member, e5.state_key)],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: {e5.state_key}},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
 
-        # test _get_some_state_from_cache correctly filters in members with specific types
-        # and no filtered_types
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        # test _get_state_for_group_using_cache correctly filters in members
+        # with specific types
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_members_cache,
             group,
-            [(EventTypes.Member, e5.state_key)],
-            filtered_types=None,
+            state_filter=StateFilter(
+                types={EventTypes.Member: {e5.state_key}},
+                include_others=False,
+            ),
         )
 
         self.assertEqual(is_all, True)
@@ -357,42 +381,54 @@ class StateStoreTestCase(tests.unittest.TestCase):
         ############################################
         # test that things work with a partial cache
 
-        # test _get_some_state_from_cache correctly filters out members with types=[]
+        # test _get_state_for_group_using_cache correctly filters out members
+        # with types=[]
         room_id = self.room.to_string()
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
-            self.store._state_group_cache, group, [], filtered_types=[EventTypes.Member]
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
+            self.store._state_group_cache, group,
+            state_filter=StateFilter(
+                types={EventTypes.Member: set()},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
         room_id = self.room.to_string()
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_members_cache,
             group,
-            [],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: set()},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({}, state_dict)
 
-        # test _get_some_state_from_cache correctly filters in members wildcard types
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        # test _get_state_for_group_using_cache correctly filters in members
+        # wildcard types
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_cache,
             group,
-            [(EventTypes.Member, None)],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: None},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_members_cache,
             group,
-            [(EventTypes.Member, None)],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: None},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
@@ -404,44 +440,53 @@ class StateStoreTestCase(tests.unittest.TestCase):
             state_dict,
         )
 
-        # test _get_some_state_from_cache correctly filters in members with specific types
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        # test _get_state_for_group_using_cache correctly filters in members
+        # with specific types
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_cache,
             group,
-            [(EventTypes.Member, e5.state_key)],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: {e5.state_key}},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, False)
         self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_members_cache,
             group,
-            [(EventTypes.Member, e5.state_key)],
-            filtered_types=[EventTypes.Member],
+            state_filter=StateFilter(
+                types={EventTypes.Member: {e5.state_key}},
+                include_others=True,
+            ),
         )
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
 
-        # test _get_some_state_from_cache correctly filters in members with specific types
-        # and no filtered_types
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        # test _get_state_for_group_using_cache correctly filters in members
+        # with specific types
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_cache,
             group,
-            [(EventTypes.Member, e5.state_key)],
-            filtered_types=None,
+            state_filter=StateFilter(
+                types={EventTypes.Member: {e5.state_key}},
+                include_others=False,
+            ),
         )
 
         self.assertEqual(is_all, False)
         self.assertDictEqual({}, state_dict)
 
-        (state_dict, is_all) = yield self.store._get_some_state_from_cache(
+        (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
             self.store._state_group_members_cache,
             group,
-            [(EventTypes.Member, e5.state_key)],
-            filtered_types=None,
+            state_filter=StateFilter(
+                types={EventTypes.Member: {e5.state_key}},
+                include_others=False,
+            ),
         )
 
         self.assertEqual(is_all, True)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 0dde1ab2fe..12f64de691 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -75,3 +75,42 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
             )
         finally:
             self.hs.config.user_directory_search_all_users = False
+
+    @defer.inlineCallbacks
+    def test_cannot_add_support_user_to_directory(self):
+        self.hs.config.user_directory_search_all_users = True
+        self.hs.config.support_user_id = "@support:test"
+        SUPPORT_USER = self.hs.config.support_user_id
+        yield self.store.add_profiles_to_user_dir(
+            "!room:id",
+            {SUPPORT_USER: ProfileInfo(None, "support")},
+        )
+        yield self.store.add_users_to_public_room("!room:id", [SUPPORT_USER])
+        yield self.store.add_users_who_share_room(
+            "!room:id", False, ((ALICE, SUPPORT_USER),)
+        )
+
+        r = yield self.store.search_user_dir(ALICE, "support", 10)
+        self.assertFalse(r["limited"])
+        self.assertEqual(0, len(r["results"]))
+
+        # add_users_who_share_room
+        # add_users_to_public_room
+        # add_profiles_to_user_dir
+        # update_user_in_user_dir
+        # update_profile_in_user_dir
+        # update_user_in_public_user_list
+
+        # yield self.store.add_profiles_to_user_dir(
+        #     "!room:id",
+        #     {SUPPORT_USER: ProfileInfo(None, "support")},
+        # )
+        # yield self.store.add_profiles_to_user_dir(SUPPORT_USER,
+        #
+        #
+        #
+        # yield self.store.add_users_to_public_room("!room:id", [SUPPORT_USER])
+        #
+        # yield self.store.add_users_who_share_room(
+        #     "!room:id", False, ((ALICE, SUPPORT_USER), (BOB, SUPPORT_USER))
+        # )
diff --git a/tests/utils.py b/tests/utils.py
index f4ade88929..806b499449 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -124,6 +124,7 @@ def default_config(name):
     config.user_consent_server_notice_content = None
     config.block_events_without_consent_error = None
     config.media_storage_providers = []
+    config.autocreate_auto_join_rooms = True
     config.auto_join_rooms = []
     config.limit_usage_by_mau = False
     config.hs_disabled = False
diff --git a/tox.ini b/tox.ini
index 04d2f721bf..9de5a5704a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,7 +3,6 @@ envlist = packaging, py27, py36, pep8, check_isort
 
 [base]
 deps =
-    coverage
     Twisted>=17.1
     mock
     python-subunit
@@ -26,9 +25,7 @@ passenv = *
 
 commands =
     /usr/bin/find "{toxinidir}" -name '*.pyc' -delete
-    coverage run {env:COVERAGE_OPTS:} --source="{toxinidir}/synapse" \
-        "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
-    {env:DUMP_COVERAGE_COMMAND:coverage report -m}
+    "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}
 
 [testenv:py27]