summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorNeil Johnson <neil@matrix.org>2018-08-01 17:49:41 +0100
committerNeil Johnson <neil@matrix.org>2018-08-01 17:49:41 +0100
commitd766f26de978f3ffc7989f3bdc586f3f87ac28c1 (patch)
treec8043934dfa079b3af4de4c942917c10df1237e0 /synapse
parentclean up (diff)
parentMerge pull request #3630 from matrix-org/neilj/mau_sign_in_log_in_limits (diff)
downloadsynapse-d766f26de978f3ffc7989f3bdc586f3f87ac28c1.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into neilj/mau_tracker
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py4
-rw-r--r--synapse/api/errors.py1
-rwxr-xr-xsynapse/app/homeserver.py19
-rw-r--r--synapse/config/server.py8
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/federation/send_queue.py63
-rw-r--r--synapse/federation/transaction_queue.py32
-rw-r--r--synapse/federation/transport/server.py5
-rw-r--r--synapse/federation/units.py1
-rw-r--r--synapse/handlers/auth.py48
-rw-r--r--synapse/handlers/federation.py28
-rw-r--r--synapse/handlers/register.py21
-rw-r--r--synapse/http/server.py35
-rw-r--r--synapse/http/servlet.py10
-rw-r--r--synapse/rest/client/v1/admin.py22
-rw-r--r--synapse/rest/client/v1/directory.py4
-rw-r--r--synapse/rest/client/v2_alpha/register.py12
-rw-r--r--synapse/rest/media/v1/media_storage.py2
-rw-r--r--synapse/state.py2
-rw-r--r--synapse/storage/__init__.py28
-rw-r--r--synapse/storage/appservice.py2
-rw-r--r--synapse/storage/event_federation.py2
-rw-r--r--synapse/storage/events.py20
-rw-r--r--synapse/storage/roommember.py2
-rw-r--r--synapse/storage/signatures.py2
-rw-r--r--synapse/storage/stream.py2
-rw-r--r--synapse/types.py2
-rw-r--r--synapse/util/caches/descriptors.py131
-rw-r--r--synapse/util/frozenutils.py6
29 files changed, 282 insertions, 236 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 073229b4c4..5bbbe8e2e7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -252,10 +252,10 @@ class Auth(object):
             if ip_address not in app_service.ip_range_whitelist:
                 defer.returnValue((None, None))
 
-        if "user_id" not in request.args:
+        if b"user_id" not in request.args:
             defer.returnValue((app_service.sender, app_service))
 
-        user_id = request.args["user_id"][0]
+        user_id = request.args[b"user_id"][0].decode('utf8')
         if app_service.sender == user_id:
             defer.returnValue((app_service.sender, app_service))
 
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 6074df292f..14f5540280 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -55,6 +55,7 @@ class Codes(object):
     SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
     CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
     CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
+    MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
 
 
 class CodeMessageException(RuntimeError):
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 79772fa61a..5f0ca51ac7 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -20,6 +20,8 @@ import sys
 
 from six import iteritems
 
+from prometheus_client import Gauge
+
 from twisted.application import service
 from twisted.internet import defer, reactor
 from twisted.web.resource import EncodingResourceWrapper, NoResource
@@ -301,6 +303,11 @@ class SynapseHomeServer(HomeServer):
             quit_with_error(e.message)
 
 
+# Gauges to expose monthly active user control metrics
+current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
+max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
+
+
 def setup(config_options):
     """
     Args:
@@ -516,6 +523,18 @@ def run(hs):
         MonthlyActiveUsersStore(hs).reap_monthly_active_users, 1000 * 60 * 60
     )
 
+    @defer.inlineCallbacks
+    def generate_monthly_active_users():
+        count = 0
+        if hs.config.limit_usage_by_mau:
+            count = yield hs.get_datastore().count_monthly_users()
+        current_mau_gauge.set(float(count))
+        max_mau_value_gauge.set(float(hs.config.max_mau_value))
+
+    generate_monthly_active_users()
+    if hs.config.limit_usage_by_mau:
+        clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
+
     if hs.config.report_stats:
         logger.info("Scheduling stats reporting for 3 hour intervals")
         clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 18102656b0..a8014e9c50 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -67,6 +67,14 @@ class ServerConfig(Config):
             "block_non_admin_invites", False,
         )
 
+        # Options to control access by tracking MAU
+        self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
+        if self.limit_usage_by_mau:
+            self.max_mau_value = config.get(
+                "max_mau_value", 0,
+            )
+        else:
+            self.max_mau_value = 0
         # FIXME: federation_domain_whitelist needs sytests
         self.federation_domain_whitelist = None
         federation_domain_whitelist = config.get(
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index e501251b6e..657935d1ac 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -207,10 +207,6 @@ class FederationServer(FederationBase):
                     edu.content
                 )
 
-        pdu_failures = getattr(transaction, "pdu_failures", [])
-        for fail in pdu_failures:
-            logger.info("Got failure %r", fail)
-
         response = {
             "pdus": pdu_results,
         }
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5157c3860d..0bb468385d 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -62,8 +62,6 @@ class FederationRemoteSendQueue(object):
 
         self.edus = SortedDict()  # stream position -> Edu
 
-        self.failures = SortedDict()  # stream position -> (destination, Failure)
-
         self.device_messages = SortedDict()  # stream position -> destination
 
         self.pos = 1
@@ -79,7 +77,7 @@ class FederationRemoteSendQueue(object):
 
         for queue_name in [
             "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
-            "edus", "failures", "device_messages", "pos_time",
+            "edus", "device_messages", "pos_time",
         ]:
             register(queue_name, getattr(self, queue_name))
 
@@ -149,12 +147,6 @@ class FederationRemoteSendQueue(object):
             for key in keys[:i]:
                 del self.edus[key]
 
-            # Delete things out of failure map
-            keys = self.failures.keys()
-            i = self.failures.bisect_left(position_to_delete)
-            for key in keys[:i]:
-                del self.failures[key]
-
             # Delete things out of device map
             keys = self.device_messages.keys()
             i = self.device_messages.bisect_left(position_to_delete)
@@ -204,13 +196,6 @@ class FederationRemoteSendQueue(object):
 
         self.notifier.on_new_replication_data()
 
-    def send_failure(self, failure, destination):
-        """As per TransactionQueue"""
-        pos = self._next_pos()
-
-        self.failures[pos] = (destination, str(failure))
-        self.notifier.on_new_replication_data()
-
     def send_device_messages(self, destination):
         """As per TransactionQueue"""
         pos = self._next_pos()
@@ -285,17 +270,6 @@ class FederationRemoteSendQueue(object):
         for (pos, edu) in edus:
             rows.append((pos, EduRow(edu)))
 
-        # Fetch changed failures
-        i = self.failures.bisect_right(from_token)
-        j = self.failures.bisect_right(to_token) + 1
-        failures = self.failures.items()[i:j]
-
-        for (pos, (destination, failure)) in failures:
-            rows.append((pos, FailureRow(
-                destination=destination,
-                failure=failure,
-            )))
-
         # Fetch changed device messages
         i = self.device_messages.bisect_right(from_token)
         j = self.device_messages.bisect_right(to_token) + 1
@@ -417,34 +391,6 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
 
-class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
-    "destination",  # str
-    "failure",
-))):
-    """Streams failures to a remote server. Failures are issued when there was
-    something wrong with a transaction the remote sent us, e.g. it included
-    an event that was invalid.
-    """
-
-    TypeId = "f"
-
-    @staticmethod
-    def from_data(data):
-        return FailureRow(
-            destination=data["destination"],
-            failure=data["failure"],
-        )
-
-    def to_data(self):
-        return {
-            "destination": self.destination,
-            "failure": self.failure,
-        }
-
-    def add_to_buffer(self, buff):
-        buff.failures.setdefault(self.destination, []).append(self.failure)
-
-
 class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
     "destination",  # str
 ))):
@@ -471,7 +417,6 @@ TypeToRow = {
         PresenceRow,
         KeyedEduRow,
         EduRow,
-        FailureRow,
         DeviceRow,
     )
 }
@@ -481,7 +426,6 @@ ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
     "presence",  # list(UserPresenceState)
     "keyed_edus",  # dict of destination -> { key -> Edu }
     "edus",  # dict of destination -> [Edu]
-    "failures",  # dict of destination -> [failures]
     "device_destinations",  # set of destinations
 ))
 
@@ -503,7 +447,6 @@ def process_rows_for_federation(transaction_queue, rows):
         presence=[],
         keyed_edus={},
         edus={},
-        failures={},
         device_destinations=set(),
     )
 
@@ -532,9 +475,5 @@ def process_rows_for_federation(transaction_queue, rows):
                 edu.destination, edu.edu_type, edu.content, key=None,
             )
 
-    for destination, failure_list in iteritems(buff.failures):
-        for failure in failure_list:
-            transaction_queue.send_failure(destination, failure)
-
     for destination in buff.device_destinations:
         transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 6996d6b695..78f9d40a3a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -116,9 +116,6 @@ class TransactionQueue(object):
             ),
         )
 
-        # destination -> list of tuple(failure, deferred)
-        self.pending_failures_by_dest = {}
-
         # destination -> stream_id of last successfully sent to-device message.
         # NB: may be a long or an int.
         self.last_device_stream_id_by_dest = {}
@@ -382,19 +379,6 @@ class TransactionQueue(object):
 
         self._attempt_new_transaction(destination)
 
-    def send_failure(self, failure, destination):
-        if destination == self.server_name or destination == "localhost":
-            return
-
-        if not self.can_send_to(destination):
-            return
-
-        self.pending_failures_by_dest.setdefault(
-            destination, []
-        ).append(failure)
-
-        self._attempt_new_transaction(destination)
-
     def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
             return
@@ -469,7 +453,6 @@ class TransactionQueue(object):
                 pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
                 pending_edus = self.pending_edus_by_dest.pop(destination, [])
                 pending_presence = self.pending_presence_by_dest.pop(destination, {})
-                pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
                 pending_edus.extend(
                     self.pending_edus_keyed_by_dest.pop(destination, {}).values()
@@ -497,7 +480,7 @@ class TransactionQueue(object):
                     logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
                                  destination, len(pending_pdus))
 
-                if not pending_pdus and not pending_edus and not pending_failures:
+                if not pending_pdus and not pending_edus:
                     logger.debug("TX [%s] Nothing to send", destination)
                     self.last_device_stream_id_by_dest[destination] = (
                         device_stream_id
@@ -507,7 +490,7 @@ class TransactionQueue(object):
                 # END CRITICAL SECTION
 
                 success = yield self._send_new_transaction(
-                    destination, pending_pdus, pending_edus, pending_failures,
+                    destination, pending_pdus, pending_edus,
                 )
                 if success:
                     sent_transactions_counter.inc()
@@ -584,14 +567,12 @@ class TransactionQueue(object):
 
     @measure_func("_send_new_transaction")
     @defer.inlineCallbacks
-    def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures):
+    def _send_new_transaction(self, destination, pending_pdus, pending_edus):
 
         # Sort based on the order field
         pending_pdus.sort(key=lambda t: t[1])
         pdus = [x[0] for x in pending_pdus]
         edus = pending_edus
-        failures = [x.get_dict() for x in pending_failures]
 
         success = True
 
@@ -601,11 +582,10 @@ class TransactionQueue(object):
 
         logger.debug(
             "TX [%s] {%s} Attempting new transaction"
-            " (pdus: %d, edus: %d, failures: %d)",
+            " (pdus: %d, edus: %d)",
             destination, txn_id,
             len(pdus),
             len(edus),
-            len(failures)
         )
 
         logger.debug("TX [%s] Persisting transaction...", destination)
@@ -617,7 +597,6 @@ class TransactionQueue(object):
             destination=destination,
             pdus=pdus,
             edus=edus,
-            pdu_failures=failures,
         )
 
         self._next_txn_id += 1
@@ -627,12 +606,11 @@ class TransactionQueue(object):
         logger.debug("TX [%s] Persisted transaction", destination)
         logger.info(
             "TX [%s] {%s} Sending transaction [%s],"
-            " (PDUs: %d, EDUs: %d, failures: %d)",
+            " (PDUs: %d, EDUs: %d)",
             destination, txn_id,
             transaction.transaction_id,
             len(pdus),
             len(edus),
-            len(failures),
         )
 
         # Actually send the transaction
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 8574898f0c..eae5f2b427 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -165,7 +165,7 @@ def _parse_auth_header(header_bytes):
         param_dict = dict(kv.split("=") for kv in params)
 
         def strip_quotes(value):
-            if value.startswith(b"\""):
+            if value.startswith("\""):
                 return value[1:-1]
             else:
                 return value
@@ -283,11 +283,10 @@ class FederationSendServlet(BaseFederationServlet):
             )
 
             logger.info(
-                "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+                "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
                 transaction_id, origin,
                 len(transaction_data.get("pdus", [])),
                 len(transaction_data.get("edus", [])),
-                len(transaction_data.get("failures", [])),
             )
 
             # We should ideally be getting this from the security layer.
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index bb1b3b13f7..c5ab14314e 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -73,7 +73,6 @@ class Transaction(JsonEncodedObject):
         "previous_ids",
         "pdus",
         "edus",
-        "pdu_failures",
     ]
 
     internal_keys = [
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 402e44cdef..184eef09d0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+import unicodedata
 
 import attr
 import bcrypt
@@ -519,6 +520,7 @@ class AuthHandler(BaseHandler):
         """
         logger.info("Logging in user %s on device %s", user_id, device_id)
         access_token = yield self.issue_access_token(user_id, device_id)
+        yield self._check_mau_limits()
 
         # the device *should* have been registered before we got here; however,
         # it's possible we raced against a DELETE operation. The thing we
@@ -626,6 +628,7 @@ class AuthHandler(BaseHandler):
         # special case to check for "password" for the check_password interface
         # for the auth providers
         password = login_submission.get("password")
+
         if login_type == LoginType.PASSWORD:
             if not self._password_enabled:
                 raise SynapseError(400, "Password login has been disabled.")
@@ -707,9 +710,10 @@ class AuthHandler(BaseHandler):
         multiple inexact matches.
 
         Args:
-            user_id (str): complete @user:id
+            user_id (unicode): complete @user:id
+            password (unicode): the provided password
         Returns:
-            (str) the canonical_user_id, or None if unknown user / bad password
+            (unicode) the canonical_user_id, or None if unknown user / bad password
         """
         lookupres = yield self._find_user_id_and_pwd_hash(user_id)
         if not lookupres:
@@ -728,15 +732,18 @@ class AuthHandler(BaseHandler):
                                                   device_id)
         defer.returnValue(access_token)
 
+    @defer.inlineCallbacks
     def validate_short_term_login_token_and_get_user_id(self, login_token):
+        yield self._check_mau_limits()
         auth_api = self.hs.get_auth()
+        user_id = None
         try:
             macaroon = pymacaroons.Macaroon.deserialize(login_token)
             user_id = auth_api.get_user_id_from_macaroon(macaroon)
             auth_api.validate_macaroon(macaroon, "login", True, user_id)
-            return user_id
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+        defer.returnValue(user_id)
 
     @defer.inlineCallbacks
     def delete_access_token(self, access_token):
@@ -849,14 +856,19 @@ class AuthHandler(BaseHandler):
         """Computes a secure hash of password.
 
         Args:
-            password (str): Password to hash.
+            password (unicode): Password to hash.
 
         Returns:
-            Deferred(str): Hashed password.
+            Deferred(unicode): Hashed password.
         """
         def _do_hash():
-            return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
-                                 bcrypt.gensalt(self.bcrypt_rounds))
+            # Normalise the Unicode in the password
+            pw = unicodedata.normalize("NFKC", password)
+
+            return bcrypt.hashpw(
+                pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
+                bcrypt.gensalt(self.bcrypt_rounds),
+            ).decode('ascii')
 
         return make_deferred_yieldable(
             threads.deferToThreadPool(
@@ -868,16 +880,19 @@ class AuthHandler(BaseHandler):
         """Validates that self.hash(password) == stored_hash.
 
         Args:
-            password (str): Password to hash.
-            stored_hash (str): Expected hash value.
+            password (unicode): Password to hash.
+            stored_hash (unicode): Expected hash value.
 
         Returns:
             Deferred(bool): Whether self.hash(password) == stored_hash.
         """
 
         def _do_validate_hash():
+            # Normalise the Unicode in the password
+            pw = unicodedata.normalize("NFKC", password)
+
             return bcrypt.checkpw(
-                password.encode('utf8') + self.hs.config.password_pepper,
+                pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
                 stored_hash.encode('utf8')
             )
 
@@ -892,6 +907,19 @@ class AuthHandler(BaseHandler):
         else:
             return defer.succeed(False)
 
+    @defer.inlineCallbacks
+    def _check_mau_limits(self):
+        """
+        Ensure that if mau blocking is enabled that invalid users cannot
+        log in.
+        """
+        if self.hs.config.limit_usage_by_mau is True:
+            current_mau = yield self.store.count_monthly_users()
+            if current_mau >= self.hs.config.max_mau_value:
+                raise AuthError(
+                    403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
+                )
+
 
 @attr.s
 class MacaroonGenerator(object):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 49068c06d9..91d8def08b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
         self.hs = hs
 
         self.store = hs.get_datastore()
-        self.replication_layer = hs.get_federation_client()
+        self.federation_client = hs.get_federation_client()
         self.state_handler = hs.get_state_handler()
         self.server_name = hs.hostname
         self.keyring = hs.get_keyring()
@@ -255,7 +255,7 @@ class FederationHandler(BaseHandler):
                     # know about
                     for p in prevs - seen:
                         state, got_auth_chain = (
-                            yield self.replication_layer.get_state_for_room(
+                            yield self.federation_client.get_state_for_room(
                                 origin, pdu.room_id, p
                             )
                         )
@@ -338,7 +338,7 @@ class FederationHandler(BaseHandler):
         #
         # see https://github.com/matrix-org/synapse/pull/1744
 
-        missing_events = yield self.replication_layer.get_missing_events(
+        missing_events = yield self.federation_client.get_missing_events(
             origin,
             pdu.room_id,
             earliest_events_ids=list(latest),
@@ -522,7 +522,7 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        events = yield self.replication_layer.backfill(
+        events = yield self.federation_client.backfill(
             dest,
             room_id,
             limit=limit,
@@ -570,7 +570,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self.replication_layer.get_state_for_room(
+            state, auth = yield self.federation_client.get_state_for_room(
                 destination=dest,
                 room_id=room_id,
                 event_id=e_id
@@ -612,7 +612,7 @@ class FederationHandler(BaseHandler):
                 results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
                     [
                         logcontext.run_in_background(
-                            self.replication_layer.get_pdu,
+                            self.federation_client.get_pdu,
                             [dest],
                             event_id,
                             outlier=True,
@@ -893,7 +893,7 @@ class FederationHandler(BaseHandler):
 
         Invites must be signed by the invitee's server before distribution.
         """
-        pdu = yield self.replication_layer.send_invite(
+        pdu = yield self.federation_client.send_invite(
             destination=target_host,
             room_id=event.room_id,
             event_id=event.event_id,
@@ -955,7 +955,7 @@ class FederationHandler(BaseHandler):
                 target_hosts.insert(0, origin)
             except ValueError:
                 pass
-            ret = yield self.replication_layer.send_join(target_hosts, event)
+            ret = yield self.federation_client.send_join(target_hosts, event)
 
             origin = ret["origin"]
             state = ret["state"]
@@ -1211,7 +1211,7 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.replication_layer.send_leave(
+        yield self.federation_client.send_leave(
             target_hosts,
             event
         )
@@ -1234,7 +1234,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
                                content={},):
-        origin, pdu = yield self.replication_layer.make_membership_event(
+        origin, pdu = yield self.federation_client.make_membership_event(
             target_hosts,
             room_id,
             user_id,
@@ -1567,7 +1567,7 @@ class FederationHandler(BaseHandler):
                     missing_auth_events.add(e_id)
 
         for e_id in missing_auth_events:
-            m_ev = yield self.replication_layer.get_pdu(
+            m_ev = yield self.federation_client.get_pdu(
                 [origin],
                 e_id,
                 outlier=True,
@@ -1777,7 +1777,7 @@ class FederationHandler(BaseHandler):
             logger.info("Missing auth: %s", missing_auth)
             # If we don't have all the auth events, we need to get them.
             try:
-                remote_auth_chain = yield self.replication_layer.get_event_auth(
+                remote_auth_chain = yield self.federation_client.get_event_auth(
                     origin, event.room_id, event.event_id
                 )
 
@@ -1893,7 +1893,7 @@ class FederationHandler(BaseHandler):
 
                 try:
                     # 2. Get remote difference.
-                    result = yield self.replication_layer.query_auth(
+                    result = yield self.federation_client.query_auth(
                         origin,
                         event.room_id,
                         event.event_id,
@@ -2192,7 +2192,7 @@ class FederationHandler(BaseHandler):
             yield member_handler.send_membership_event(None, event, context)
         else:
             destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
-            yield self.replication_layer.forward_third_party_invite(
+            yield self.federation_client.forward_third_party_invite(
                 destinations,
                 room_id,
                 event_dict,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7caff0cbc8..289704b241 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler):
             hs (synapse.server.HomeServer):
         """
         super(RegistrationHandler, self).__init__(hs)
-
+        self.hs = hs
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
@@ -131,7 +131,7 @@ class RegistrationHandler(BaseHandler):
         Args:
             localpart : The local part of the user ID to register. If None,
               one will be generated.
-            password (str) : The password to assign to this user so they can
+            password (unicode) : The password to assign to this user so they can
               login again. This can be None which means they cannot login again
               via a password (e.g. the user is an application service user).
             generate_token (bool): Whether a new access token should be
@@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler):
         Raises:
             RegistrationError if there was a problem registering.
         """
+        yield self._check_mau_limits()
         password_hash = None
         if password:
             password_hash = yield self.auth_handler().hash(password)
@@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler):
                 400,
                 "User ID can only contain characters a-z, 0-9, or '=_-./'",
             )
+        yield self._check_mau_limits()
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
 
@@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler):
         """
         if localpart is None:
             raise SynapseError(400, "Request must include user id")
-
+        yield self._check_mau_limits()
         need_register = True
 
         try:
@@ -531,3 +533,16 @@ class RegistrationHandler(BaseHandler):
             remote_room_hosts=remote_room_hosts,
             action="join",
         )
+
+    @defer.inlineCallbacks
+    def _check_mau_limits(self):
+        """
+        Do not accept registrations if monthly active user limits exceeded
+         and limiting is enabled
+        """
+        if self.hs.config.limit_usage_by_mau is True:
+            current_mau = yield self.store.count_monthly_users()
+            if current_mau >= self.hs.config.max_mau_value:
+                raise RegistrationError(
+                    403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
+                )
diff --git a/synapse/http/server.py b/synapse/http/server.py
index c70fdbdfd2..1940c1c4f4 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -13,12 +13,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import cgi
 import collections
 import logging
-import urllib
 
-from six.moves import http_client
+from six import PY3
+from six.moves import http_client, urllib
 
 from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
 
@@ -264,6 +265,7 @@ class JsonResource(HttpServer, resource.Resource):
         self.hs = hs
 
     def register_paths(self, method, path_patterns, callback):
+        method = method.encode("utf-8")  # method is bytes on py3
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
             self.path_regexs.setdefault(method, []).append(
@@ -296,8 +298,19 @@ class JsonResource(HttpServer, resource.Resource):
         # here. If it throws an exception, that is handled by the wrapper
         # installed by @request_handler.
 
+        def _unquote(s):
+            if PY3:
+                # On Python 3, unquote is unicode -> unicode
+                return urllib.parse.unquote(s)
+            else:
+                # On Python 2, unquote is bytes -> bytes We need to encode the
+                # URL again (as it was decoded by _get_handler_for request), as
+                # ASCII because it's a URL, and then decode it to get the UTF-8
+                # characters that were quoted.
+                return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
+
         kwargs = intern_dict({
-            name: urllib.unquote(value).decode("UTF-8") if value else value
+            name: _unquote(value) if value else value
             for name, value in group_dict.items()
         })
 
@@ -313,9 +326,9 @@ class JsonResource(HttpServer, resource.Resource):
             request (twisted.web.http.Request):
 
         Returns:
-            Tuple[Callable, dict[str, str]]: callback method, and the dict
-                mapping keys to path components as specified in the handler's
-                path match regexp.
+            Tuple[Callable, dict[unicode, unicode]]: callback method, and the
+                dict mapping keys to path components as specified in the
+                handler's path match regexp.
 
                 The callback will normally be a method registered via
                 register_paths, so will return (possibly via Deferred) either
@@ -327,7 +340,7 @@ class JsonResource(HttpServer, resource.Resource):
         # Loop through all the registered callbacks to check if the method
         # and path regex match
         for path_entry in self.path_regexs.get(request.method, []):
-            m = path_entry.pattern.match(request.path)
+            m = path_entry.pattern.match(request.path.decode('ascii'))
             if m:
                 # We found a match!
                 return path_entry.callback, m.groupdict()
@@ -383,7 +396,7 @@ class RootRedirect(resource.Resource):
         self.url = path
 
     def render_GET(self, request):
-        return redirectTo(self.url, request)
+        return redirectTo(self.url.encode('ascii'), request)
 
     def getChild(self, name, request):
         if len(name) == 0:
@@ -404,12 +417,14 @@ def respond_with_json(request, code, json_object, send_cors=False,
         return
 
     if pretty_print:
-        json_bytes = encode_pretty_printed_json(json_object) + "\n"
+        json_bytes = (encode_pretty_printed_json(json_object) + "\n"
+                      ).encode("utf-8")
     else:
         if canonical_json or synapse.events.USE_FROZEN_DICTS:
+            # canonicaljson already encodes to bytes
             json_bytes = encode_canonical_json(json_object)
         else:
-            json_bytes = json.dumps(json_object)
+            json_bytes = json.dumps(json_object).encode("utf-8")
 
     return respond_with_json_bytes(
         request, code, json_bytes,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 882816dc8f..69f7085291 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -171,8 +171,16 @@ def parse_json_value_from_request(request, allow_empty_body=False):
     if not content_bytes and allow_empty_body:
         return None
 
+    # Decode to Unicode so that simplejson will return Unicode strings on
+    # Python 2
     try:
-        content = json.loads(content_bytes)
+        content_unicode = content_bytes.decode('utf8')
+    except UnicodeDecodeError:
+        logger.warn("Unable to decode UTF-8")
+        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+
+    try:
+        content = json.loads(content_unicode)
     except Exception as e:
         logger.warn("Unable to parse JSON: %s", e)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 99f6c6e3c3..80d625eecc 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -18,6 +18,7 @@ import hashlib
 import hmac
 import logging
 
+from six import text_type
 from six.moves import http_client
 
 from twisted.internet import defer
@@ -131,7 +132,10 @@ class UserRegisterServlet(ClientV1RestServlet):
                 400, "username must be specified", errcode=Codes.BAD_JSON,
             )
         else:
-            if (not isinstance(body['username'], str) or len(body['username']) > 512):
+            if (
+                not isinstance(body['username'], text_type)
+                or len(body['username']) > 512
+            ):
                 raise SynapseError(400, "Invalid username")
 
             username = body["username"].encode("utf-8")
@@ -143,7 +147,10 @@ class UserRegisterServlet(ClientV1RestServlet):
                 400, "password must be specified", errcode=Codes.BAD_JSON,
             )
         else:
-            if (not isinstance(body['password'], str) or len(body['password']) > 512):
+            if (
+                not isinstance(body['password'], text_type)
+                or len(body['password']) > 512
+            ):
                 raise SynapseError(400, "Invalid password")
 
             password = body["password"].encode("utf-8")
@@ -166,17 +173,18 @@ class UserRegisterServlet(ClientV1RestServlet):
         want_mac.update(b"admin" if admin else b"notadmin")
         want_mac = want_mac.hexdigest()
 
-        if not hmac.compare_digest(want_mac, got_mac):
-            raise SynapseError(
-                403, "HMAC incorrect",
-            )
+        if not hmac.compare_digest(want_mac, got_mac.encode('ascii')):
+            raise SynapseError(403, "HMAC incorrect")
 
         # Reuse the parts of RegisterRestServlet to reduce code duplication
         from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+
         register = RegisterRestServlet(self.hs)
 
         (user_id, _) = yield register.registration_handler.register(
-            localpart=username.lower(), password=password, admin=bool(admin),
+            localpart=body['username'].lower(),
+            password=body["password"],
+            admin=bool(admin),
             generate_token=False,
         )
 
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 69dcd618cb..97733f3026 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -18,7 +18,7 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.types import RoomAlias
 
@@ -159,7 +159,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
     def on_GET(self, request, room_id):
         room = yield self.store.get_room(room_id)
         if room is None:
-            raise SynapseError(400, "Unknown room")
+            raise NotFoundError("Unknown room")
 
         defer.returnValue((200, {
             "visibility": "public" if room["is_public"] else "private"
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index d6cf915d86..2f64155d13 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -193,15 +193,15 @@ class RegisterRestServlet(RestServlet):
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
 
-        kind = "user"
-        if "kind" in request.args:
-            kind = request.args["kind"][0]
+        kind = b"user"
+        if b"kind" in request.args:
+            kind = request.args[b"kind"][0]
 
-        if kind == "guest":
+        if kind == b"guest":
             ret = yield self._do_guest_registration(body)
             defer.returnValue(ret)
             return
-        elif kind != "user":
+        elif kind != b"user":
             raise UnrecognizedRequestError(
                 "Do not understand membership kind: %s" % (kind,)
             )
@@ -389,8 +389,8 @@ class RegisterRestServlet(RestServlet):
             assert_params_in_dict(params, ["password"])
 
             desired_username = params.get("username", None)
-            new_password = params.get("password", None)
             guest_access_token = params.get("guest_access_token", None)
+            new_password = params.get("password", None)
 
             if desired_username is not None:
                 desired_username = desired_username.lower()
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index b25993fcb5..a6189224ee 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -177,7 +177,7 @@ class MediaStorage(object):
             if res:
                 with res:
                     consumer = BackgroundFileConsumer(
-                        open(local_path, "w"), self.hs.get_reactor())
+                        open(local_path, "wb"), self.hs.get_reactor())
                     yield res.write_to_consumer(consumer)
                     yield consumer.wait()
                 defer.returnValue(local_path)
diff --git a/synapse/state.py b/synapse/state.py
index 033f55d967..e1092b97a9 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -577,7 +577,7 @@ def _make_state_cache_entry(
 
 def _ordered_events(events):
     def key_func(e):
-        return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
+        return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
 
     return sorted(events, key=key_func)
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index ba88a54979..134e4a80f1 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -66,6 +66,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 PresenceStore, TransactionStore,
                 DirectoryStore, KeyStore, StateStore, SignatureStore,
                 ApplicationServiceStore,
+                EventsStore,
                 EventFederationStore,
                 MediaRepositoryStore,
                 RejectionsStore,
@@ -73,7 +74,6 @@ class DataStore(RoomMemberStore, RoomStore,
                 PusherStore,
                 PushRuleStore,
                 ApplicationServiceTransactionStore,
-                EventsStore,
                 ReceiptsStore,
                 EndToEndKeyStore,
                 SearchStore,
@@ -94,6 +94,7 @@ class DataStore(RoomMemberStore, RoomStore,
         self._clock = hs.get_clock()
         self.database_engine = hs.database_engine
 
+        self.db_conn = db_conn
         self._stream_id_gen = StreamIdGenerator(
             db_conn, "events", "stream_ordering",
             extra_tables=[("local_invites", "stream_id")]
@@ -266,6 +267,31 @@ class DataStore(RoomMemberStore, RoomStore,
 
         return self.runInteraction("count_users", _count_users)
 
+    def count_monthly_users(self):
+        """Counts the number of users who used this homeserver in the last 30 days
+
+        This method should be refactored with count_daily_users - the only
+        reason not to is waiting on definition of mau
+
+        Returns:
+            Defered[int]
+        """
+        def _count_monthly_users(txn):
+            thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT user_id FROM user_ips
+                    WHERE last_seen > ?
+                    GROUP BY user_id
+                ) u
+            """
+
+            txn.execute(sql, (thirty_days_ago,))
+            count, = txn.fetchone()
+            return count
+
+        return self.runInteraction("count_monthly_users", _count_monthly_users)
+
     def count_r30_users(self):
         """
         Counts the number of 30 day retained users, defined as:-
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 9f12b360bc..31248d5e06 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
 
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 
 from ._base import SQLBaseStore
 
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 5d3ee90017..8bd35df119 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
 from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.storage.signatures import SignatureWorkerStore
 from synapse.util.caches.descriptors import cached
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 2f482af3a1..61223da1a5 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -34,6 +34,8 @@ from synapse.api.errors import SynapseError
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.event_federation import EventFederationStore
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util.async import ObservableDeferred
@@ -65,7 +67,13 @@ state_delta_reuse_delta_counter = Counter(
 
 
 def encode_json(json_object):
-    return frozendict_json_encoder.encode(json_object)
+    """
+    Encode a Python object as JSON and return it in a Unicode string.
+    """
+    out = frozendict_json_encoder.encode(json_object)
+    if isinstance(out, bytes):
+        out = out.decode('utf8')
+    return out
 
 
 class _EventPeristenceQueue(object):
@@ -193,7 +201,9 @@ def _retry_on_integrity_error(func):
     return f
 
 
-class EventsStore(EventsWorkerStore):
+# inherits from EventFederationStore so that we can call _update_backward_extremities
+# and _handle_mult_prev_events (though arguably those could both be moved in here)
+class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
 
@@ -1054,7 +1064,7 @@ class EventsStore(EventsWorkerStore):
 
                 metadata_json = encode_json(
                     event.internal_metadata.get_dict()
-                ).decode("UTF-8")
+                )
 
                 sql = (
                     "UPDATE event_json SET internal_metadata = ?"
@@ -1168,8 +1178,8 @@ class EventsStore(EventsWorkerStore):
                     "room_id": event.room_id,
                     "internal_metadata": encode_json(
                         event.internal_metadata.get_dict()
-                    ).decode("UTF-8"),
-                    "json": encode_json(event_dict(event)).decode("UTF-8"),
+                    ),
+                    "json": encode_json(event_dict(event)),
                 }
                 for event, _ in events_and_contexts
             ],
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 027bf8c85e..10dce21cea 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -24,7 +24,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import get_domain_from_id
 from synapse.util.async import Linearizer
 from synapse.util.caches import intern_string
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 470212aa2a..5623391f6e 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -74,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
             txn (cursor):
             event_id (str): Id for the Event.
         Returns:
-            A dict of algorithm -> hash.
+            A dict[unicode, bytes] of algorithm -> hash.
         """
         query = (
             "SELECT algorithm, hash"
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 25d0097b58..b9f2b74ac6 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -43,7 +43,7 @@ from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.events import EventsWorkerStore
+from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background
diff --git a/synapse/types.py b/synapse/types.py
index 08f058f714..41afb27a74 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -137,7 +137,7 @@ class DomainSpecificString(
     @classmethod
     def from_string(cls, s):
         """Parse the string given by 's' into a structure object."""
-        if len(s) < 1 or s[0] != cls.SIGIL:
+        if len(s) < 1 or s[0:1] != cls.SIGIL:
             raise SynapseError(400, "Expected %s string to start with '%s'" % (
                 cls.__name__, cls.SIGIL,
             ))
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index f8a07df6b8..861c24809c 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -473,105 +473,101 @@ class CacheListDescriptor(_CacheDescriptorBase):
 
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
-            # If we're passed a cache_context then we'll want to call its invalidate()
-            # whenever we are invalidated
+            # If we're passed a cache_context then we'll want to call its
+            # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
 
             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
             list_args = arg_dict[self.list_name]
 
-            # cached is a dict arg -> deferred, where deferred results in a
-            # 2-tuple (`arg`, `result`)
             results = {}
-            cached_defers = {}
-            missing = []
+
+            def update_results_dict(res, arg):
+                results[arg] = res
+
+            # list of deferreds to wait for
+            cached_defers = []
+
+            missing = set()
 
             # If the cache takes a single arg then that is used as the key,
             # otherwise a tuple is used.
             if num_args == 1:
-                def cache_get(arg):
-                    return cache.get(arg, callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    return arg
             else:
-                key = list(keyargs)
+                keylist = list(keyargs)
 
-                def cache_get(arg):
-                    key[self.list_pos] = arg
-                    return cache.get(tuple(key), callback=invalidate_callback)
+                def arg_to_cache_key(arg):
+                    keylist[self.list_pos] = arg
+                    return tuple(keylist)
 
             for arg in list_args:
                 try:
-                    res = cache_get(arg)
-
+                    res = cache.get(arg_to_cache_key(arg),
+                                    callback=invalidate_callback)
                     if not isinstance(res, ObservableDeferred):
                         results[arg] = res
                     elif not res.has_succeeded():
                         res = res.observe()
-                        res.addCallback(lambda r, arg: (arg, r), arg)
-                        cached_defers[arg] = res
+                        res.addCallback(update_results_dict, arg)
+                        cached_defers.append(res)
                     else:
                         results[arg] = res.get_result()
                 except KeyError:
-                    missing.append(arg)
+                    missing.add(arg)
 
             if missing:
+                # we need an observable deferred for each entry in the list,
+                # which we put in the cache. Each deferred resolves with the
+                # relevant result for that key.
+                deferreds_map = {}
+                for arg in missing:
+                    deferred = defer.Deferred()
+                    deferreds_map[arg] = deferred
+                    key = arg_to_cache_key(arg)
+                    observable = ObservableDeferred(deferred)
+                    cache.set(key, observable, callback=invalidate_callback)
+
+                def complete_all(res):
+                    # the wrapped function has completed. It returns a
+                    # a dict. We can now resolve the observable deferreds in
+                    # the cache and update our own result map.
+                    for e in missing:
+                        val = res.get(e, None)
+                        deferreds_map[e].callback(val)
+                        results[e] = val
+
+                def errback(f):
+                    # the wrapped function has failed. Invalidate any cache
+                    # entries we're supposed to be populating, and fail
+                    # their deferreds.
+                    for e in missing:
+                        key = arg_to_cache_key(e)
+                        cache.invalidate(key)
+                        deferreds_map[e].errback(f)
+
+                    # return the failure, to propagate to our caller.
+                    return f
+
                 args_to_call = dict(arg_dict)
-                args_to_call[self.list_name] = missing
+                args_to_call[self.list_name] = list(missing)
 
-                ret_d = defer.maybeDeferred(
+                cached_defers.append(defer.maybeDeferred(
                     logcontext.preserve_fn(self.function_to_call),
                     **args_to_call
-                )
-
-                ret_d = ObservableDeferred(ret_d)
-
-                # We need to create deferreds for each arg in the list so that
-                # we can insert the new deferred into the cache.
-                for arg in missing:
-                    observer = ret_d.observe()
-                    observer.addCallback(lambda r, arg: r.get(arg, None), arg)
-
-                    observer = ObservableDeferred(observer)
-
-                    if num_args == 1:
-                        cache.set(
-                            arg, observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, arg)
-                    else:
-                        key = list(keyargs)
-                        key[self.list_pos] = arg
-                        cache.set(
-                            tuple(key), observer,
-                            callback=invalidate_callback
-                        )
-
-                        def invalidate(f, key):
-                            cache.invalidate(key)
-                            return f
-                        observer.addErrback(invalidate, tuple(key))
-
-                    res = observer.observe()
-                    res.addCallback(lambda r, arg: (arg, r), arg)
-
-                    cached_defers[arg] = res
+                ).addCallbacks(complete_all, errback))
 
             if cached_defers:
-                def update_results_dict(res):
-                    results.update(res)
-                    return results
-
-                return logcontext.make_deferred_yieldable(defer.gatherResults(
-                    list(cached_defers.values()),
+                d = defer.gatherResults(
+                    cached_defers,
                     consumeErrors=True,
-                ).addCallback(update_results_dict).addErrback(
+                ).addCallbacks(
+                    lambda _: results,
                     unwrapFirstError
-                ))
+                )
+                return logcontext.make_deferred_yieldable(d)
             else:
                 return results
 
@@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
     cache.
 
     Args:
-        cache (Cache): The underlying cache to use.
+        cached_method_name (str): The name of the single-item lookup method.
+            This is only used to find the cache to use.
         list_name (str): The name of the argument that is the list to use to
             do batch lookups in the cache.
         num_args (int): Number of arguments to use as the key in the cache
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 581c6052ac..014edea971 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from six import string_types
+from six import binary_type, text_type
 
 from canonicaljson import json
 from frozendict import frozendict
@@ -26,7 +26,7 @@ def freeze(o):
     if isinstance(o, frozendict):
         return o
 
-    if isinstance(o, string_types):
+    if isinstance(o, (binary_type, text_type)):
         return o
 
     try:
@@ -41,7 +41,7 @@ def unfreeze(o):
     if isinstance(o, (dict, frozendict)):
         return dict({k: unfreeze(v) for k, v in o.items()})
 
-    if isinstance(o, string_types):
+    if isinstance(o, (binary_type, text_type)):
         return o
 
     try: