summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py14
-rw-r--r--synapse/api/ratelimiting.py2
-rw-r--r--synapse/config/logger.py3
-rw-r--r--synapse/crypto/keyclient.py8
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server.py6
-rw-r--r--synapse/http/request_metrics.py13
-rw-r--r--synapse/http/servlet.py56
-rw-r--r--synapse/http/site.py12
-rw-r--r--synapse/metrics/background_process_metrics.py26
-rw-r--r--synapse/rest/client/transactions.py2
-rw-r--r--synapse/rest/consent/consent_resource.py15
-rw-r--r--synapse/rest/media/v1/upload_resource.py10
-rw-r--r--synapse/secrets.py2
-rw-r--r--synapse/storage/state.py158
-rw-r--r--synapse/util/logcontext.py12
-rw-r--r--synapse/util/logutils.py10
-rw-r--r--synapse/util/stringutils.py15
-rw-r--r--synapse/util/versionstring.py12
20 files changed, 302 insertions, 78 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index a14d578e36..252c49ca82 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,4 +17,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.33.2"
+__version__ = "0.33.3rc2"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 022211e34e..6502a6be7b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -211,7 +211,7 @@ class Auth(object):
             user_agent = request.requestHeaders.getRawHeaders(
                 b"User-Agent",
                 default=[b""]
-            )[0]
+            )[0].decode('ascii', 'surrogateescape')
             if user and access_token and ip_addr:
                 yield self.store.insert_client_ip(
                     user_id=user.to_string(),
@@ -682,7 +682,7 @@ class Auth(object):
         Returns:
             bool: False if no access_token was given, True otherwise.
         """
-        query_params = request.args.get("access_token")
+        query_params = request.args.get(b"access_token")
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
         return bool(query_params) or bool(auth_headers)
 
@@ -698,7 +698,7 @@ class Auth(object):
                 401 since some of the old clients depended on auth errors returning
                 403.
         Returns:
-            str: The access_token
+            unicode: The access_token
         Raises:
             AuthError: If there isn't an access_token in the request.
         """
@@ -720,9 +720,9 @@ class Auth(object):
                     "Too many Authorization headers.",
                     errcode=Codes.MISSING_TOKEN,
                 )
-            parts = auth_headers[0].split(" ")
-            if parts[0] == "Bearer" and len(parts) == 2:
-                return parts[1]
+            parts = auth_headers[0].split(b" ")
+            if parts[0] == b"Bearer" and len(parts) == 2:
+                return parts[1].decode('ascii')
             else:
                 raise AuthError(
                     token_not_found_http_status,
@@ -738,7 +738,7 @@ class Auth(object):
                     errcode=Codes.MISSING_TOKEN
                 )
 
-            return query_params[0]
+            return query_params[0].decode('ascii')
 
     @defer.inlineCallbacks
     def check_in_room_or_world_readable(self, room_id, user_id):
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 06cc8d90b8..3bb5b3da37 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -72,7 +72,7 @@ class Ratelimiter(object):
         return allowed, time_allowed
 
     def prune_message_counts(self, time_now_s):
-        for user_id in self.message_counts.keys():
+        for user_id in list(self.message_counts.keys()):
             message_count, time_start, msg_rate_hz = (
                 self.message_counts[user_id]
             )
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index cfc20dcccf..3f187adfc8 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -168,7 +168,8 @@ def setup_logging(config, use_worker_options=False):
         if log_file:
             # TODO: Customisable file size / backup count
             handler = logging.handlers.RotatingFileHandler(
-                log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
+                log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
+                encoding='utf8'
             )
 
             def sighup(signum, stack):
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index c20a32096a..e94400b8e2 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -18,7 +18,9 @@ import logging
 from canonicaljson import json
 
 from twisted.internet import defer, reactor
+from twisted.internet.error import ConnectError
 from twisted.internet.protocol import Factory
+from twisted.names.error import DomainError
 from twisted.web.http import HTTPClient
 
 from synapse.http.endpoint import matrix_federation_endpoint
@@ -47,12 +49,14 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
                 server_response, server_certificate = yield protocol.remote_key
                 defer.returnValue((server_response, server_certificate))
         except SynapseKeyClientError as e:
-            logger.exception("Error getting key for %r" % (server_name,))
+            logger.warn("Error getting key for %r: %s", server_name, e)
             if e.status.startswith("4"):
                 # Don't retry for 4xx responses.
                 raise IOError("Cannot get key for %r" % server_name)
+        except (ConnectError, DomainError) as e:
+            logger.warn("Error getting key for %r: %s", server_name, e)
         except Exception as e:
-            logger.exception(e)
+            logger.exception("Error getting key for %r", server_name)
     raise IOError("Cannot get key for %r" % server_name)
 
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index b4fbe2c9d5..1054441ca5 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -106,7 +106,7 @@ class TransportLayerClient(object):
             dest (str)
             room_id (str)
             event_tuples (list)
-            limt (int)
+            limit (int)
 
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 77969a4f38..7a993fd1cf 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -261,10 +261,10 @@ class BaseFederationServlet(object):
             except NoAuthenticationError:
                 origin = None
                 if self.REQUIRE_AUTH:
-                    logger.exception("authenticate_request failed")
+                    logger.warn("authenticate_request failed: missing authentication")
                     raise
-            except Exception:
-                logger.exception("authenticate_request failed")
+            except Exception as e:
+                logger.warn("authenticate_request failed: %s", e)
                 raise
 
             if origin:
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 588e280571..72c2654678 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+import threading
 
 from prometheus_client.core import Counter, Histogram
 
@@ -111,6 +112,9 @@ in_flight_requests_db_sched_duration = Counter(
 # The set of all in flight requests, set[RequestMetrics]
 _in_flight_requests = set()
 
+# Protects the _in_flight_requests set from concurrent accesss
+_in_flight_requests_lock = threading.Lock()
+
 
 def _get_in_flight_counts():
     """Returns a count of all in flight requests by (method, server_name)
@@ -120,7 +124,8 @@ def _get_in_flight_counts():
     """
     # Cast to a list to prevent it changing while the Prometheus
     # thread is collecting metrics
-    reqs = list(_in_flight_requests)
+    with _in_flight_requests_lock:
+        reqs = list(_in_flight_requests)
 
     for rm in reqs:
         rm.update_metrics()
@@ -154,10 +159,12 @@ class RequestMetrics(object):
         # to the "in flight" metrics.
         self._request_stats = self.start_context.get_resource_usage()
 
-        _in_flight_requests.add(self)
+        with _in_flight_requests_lock:
+            _in_flight_requests.add(self)
 
     def stop(self, time_sec, request):
-        _in_flight_requests.discard(self)
+        with _in_flight_requests_lock:
+            _in_flight_requests.discard(self)
 
         context = LoggingContext.current_context()
 
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 69f7085291..a1e4b88e6d 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -29,7 +29,7 @@ def parse_integer(request, name, default=None, required=False):
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
+        name (bytes/unicode): the name of the query parameter.
         default (int|None): value to use if the parameter is absent, defaults
             to None.
         required (bool): whether to raise a 400 SynapseError if the
@@ -46,6 +46,10 @@ def parse_integer(request, name, default=None, required=False):
 
 
 def parse_integer_from_args(args, name, default=None, required=False):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         try:
             return int(args[name][0])
@@ -65,7 +69,7 @@ def parse_boolean(request, name, default=None, required=False):
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
+        name (bytes/unicode): the name of the query parameter.
         default (bool|None): value to use if the parameter is absent, defaults
             to None.
         required (bool): whether to raise a 400 SynapseError if the
@@ -83,11 +87,15 @@ def parse_boolean(request, name, default=None, required=False):
 
 
 def parse_boolean_from_args(args, name, default=None, required=False):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         try:
             return {
-                "true": True,
-                "false": False,
+                b"true": True,
+                b"false": False,
             }[args[name][0]]
         except Exception:
             message = (
@@ -104,21 +112,29 @@ def parse_boolean_from_args(args, name, default=None, required=False):
 
 
 def parse_string(request, name, default=None, required=False,
-                 allowed_values=None, param_type="string"):
-    """Parse a string parameter from the request query string.
+                 allowed_values=None, param_type="string", encoding='ascii'):
+    """
+    Parse a string parameter from the request query string.
+
+    If encoding is not None, the content of the query param will be
+    decoded to Unicode using the encoding, otherwise it will be encoded
 
     Args:
         request: the twisted HTTP request.
-        name (str): the name of the query parameter.
-        default (str|None): value to use if the parameter is absent, defaults
-            to None.
+        name (bytes/unicode): the name of the query parameter.
+        default (bytes/unicode|None): value to use if the parameter is absent,
+            defaults to None. Must be bytes if encoding is None.
         required (bool): whether to raise a 400 SynapseError if the
             parameter is absent, defaults to False.
-        allowed_values (list[str]): List of allowed values for the string,
-            or None if any value is allowed, defaults to None
+        allowed_values (list[bytes/unicode]): List of allowed values for the
+            string, or None if any value is allowed, defaults to None. Must be
+            the same type as name, if given.
+        encoding: The encoding to decode the name to, and decode the string
+            content with.
 
     Returns:
-        str|None: A string value or the default.
+        bytes/unicode|None: A string value or the default. Unicode if encoding
+        was given, bytes otherwise.
 
     Raises:
         SynapseError if the parameter is absent and required, or if the
@@ -126,14 +142,22 @@ def parse_string(request, name, default=None, required=False,
             is not one of those allowed values.
     """
     return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type,
+        request.args, name, default, required, allowed_values, param_type, encoding
     )
 
 
 def parse_string_from_args(args, name, default=None, required=False,
-                           allowed_values=None, param_type="string"):
+                           allowed_values=None, param_type="string", encoding='ascii'):
+
+    if not isinstance(name, bytes):
+        name = name.encode('ascii')
+
     if name in args:
         value = args[name][0]
+
+        if encoding:
+            value = value.decode(encoding)
+
         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (
                 name, ", ".join(repr(v) for v in allowed_values)
@@ -146,6 +170,10 @@ def parse_string_from_args(args, name, default=None, required=False,
             message = "Missing %s query parameter %r" % (param_type, name)
             raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
+
+            if encoding and isinstance(default, bytes):
+                return default.decode(encoding)
+
             return default
 
 
diff --git a/synapse/http/site.py b/synapse/http/site.py
index f5a8f78406..88ed3714f9 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -182,7 +182,7 @@ class SynapseRequest(Request):
         # the client disconnects.
         with PreserveLoggingContext(self.logcontext):
             logger.warn(
-                "Error processing request: %s %s", reason.type, reason.value,
+                "Error processing request %r: %s %s", self, reason.type, reason.value,
             )
 
             if not self._is_processing:
@@ -219,6 +219,12 @@ class SynapseRequest(Request):
         """Log the completion of this request and update the metrics
         """
 
+        if self.logcontext is None:
+            # this can happen if the connection closed before we read the
+            # headers (so render was never called). In that case we'll already
+            # have logged a warning, so just bail out.
+            return
+
         usage = self.logcontext.get_resource_usage()
 
         if self._processing_finished_time is None:
@@ -235,7 +241,7 @@ class SynapseRequest(Request):
         # need to decode as it could be raw utf-8 bytes
         # from a IDN servname in an auth header
         authenticated_entity = self.authenticated_entity
-        if authenticated_entity is not None:
+        if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
             authenticated_entity = authenticated_entity.decode("utf-8", "replace")
 
         # ...or could be raw utf-8 bytes in the User-Agent header.
@@ -328,7 +334,7 @@ class SynapseSite(Site):
         proxied = config.get("x_forwarded", False)
         self.requestFactory = SynapseRequestFactory(self, proxied)
         self.access_logger = logging.getLogger(logger_name)
-        self.server_version_string = server_version_string
+        self.server_version_string = server_version_string.encode('ascii')
 
     def log(self, request):
         pass
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index ce678d5f75..167167be0a 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import threading
+
 import six
 
 from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
@@ -78,6 +80,9 @@ _background_process_counts = dict()  # type: dict[str, int]
 # of process descriptions that no longer have any active processes.
 _background_processes = dict()  # type: dict[str, set[_BackgroundProcess]]
 
+# A lock that covers the above dicts
+_bg_metrics_lock = threading.Lock()
+
 
 class _Collector(object):
     """A custom metrics collector for the background process metrics.
@@ -92,7 +97,11 @@ class _Collector(object):
             labels=["name"],
         )
 
-        for desc, processes in six.iteritems(_background_processes):
+        # We copy the dict so that it doesn't change from underneath us
+        with _bg_metrics_lock:
+            _background_processes_copy = dict(_background_processes)
+
+        for desc, processes in six.iteritems(_background_processes_copy):
             background_process_in_flight_count.add_metric(
                 (desc,), len(processes),
             )
@@ -167,19 +176,26 @@ def run_as_background_process(desc, func, *args, **kwargs):
     """
     @defer.inlineCallbacks
     def run():
-        count = _background_process_counts.get(desc, 0)
-        _background_process_counts[desc] = count + 1
+        with _bg_metrics_lock:
+            count = _background_process_counts.get(desc, 0)
+            _background_process_counts[desc] = count + 1
+
         _background_process_start_count.labels(desc).inc()
 
         with LoggingContext(desc) as context:
             context.request = "%s-%i" % (desc, count)
             proc = _BackgroundProcess(desc, context)
-            _background_processes.setdefault(desc, set()).add(proc)
+
+            with _bg_metrics_lock:
+                _background_processes.setdefault(desc, set()).add(proc)
+
             try:
                 yield func(*args, **kwargs)
             finally:
                 proc.update_metrics()
-                _background_processes[desc].remove(proc)
+
+                with _bg_metrics_lock:
+                    _background_processes[desc].remove(proc)
 
     with PreserveLoggingContext():
         return run()
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 511e96ab00..48c17f1b6d 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -53,7 +53,7 @@ class HttpTransactionCache(object):
             str: A transaction key
         """
         token = self.auth.get_access_token_from_request(request)
-        return request.path + "/" + token
+        return request.path.decode('utf8') + "/" + token
 
     def fetch_or_execute_request(self, request, fn, *args, **kwargs):
         """A helper function for fetch_or_execute which extracts
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 147ff7d79b..7362e1858d 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -140,7 +140,7 @@ class ConsentResource(Resource):
         version = parse_string(request, "v",
                                default=self._default_consent_version)
         username = parse_string(request, "u", required=True)
-        userhmac = parse_string(request, "h", required=True)
+        userhmac = parse_string(request, "h", required=True, encoding=None)
 
         self._check_hash(username, userhmac)
 
@@ -175,7 +175,7 @@ class ConsentResource(Resource):
         """
         version = parse_string(request, "v", required=True)
         username = parse_string(request, "u", required=True)
-        userhmac = parse_string(request, "h", required=True)
+        userhmac = parse_string(request, "h", required=True, encoding=None)
 
         self._check_hash(username, userhmac)
 
@@ -210,9 +210,18 @@ class ConsentResource(Resource):
         finish_request(request)
 
     def _check_hash(self, userid, userhmac):
+        """
+        Args:
+            userid (unicode):
+            userhmac (bytes):
+
+        Raises:
+              SynapseError if the hash doesn't match
+
+        """
         want_mac = hmac.new(
             key=self._hmac_secret,
-            msg=userid,
+            msg=userid.encode('utf-8'),
             digestmod=sha256,
         ).hexdigest()
 
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 9b22d204a6..c1240e1963 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -55,7 +55,7 @@ class UploadResource(Resource):
         requester = yield self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
-        content_length = request.getHeader("Content-Length")
+        content_length = request.getHeader(b"Content-Length").decode('ascii')
         if content_length is None:
             raise SynapseError(
                 msg="Request must specify a Content-Length", code=400
@@ -66,10 +66,10 @@ class UploadResource(Resource):
                 code=413,
             )
 
-        upload_name = parse_string(request, "filename")
+        upload_name = parse_string(request, b"filename", encoding=None)
         if upload_name:
             try:
-                upload_name = upload_name.decode('UTF-8')
+                upload_name = upload_name.decode('utf8')
             except UnicodeDecodeError:
                 raise SynapseError(
                     msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
@@ -78,8 +78,8 @@ class UploadResource(Resource):
 
         headers = request.requestHeaders
 
-        if headers.hasHeader("Content-Type"):
-            media_type = headers.getRawHeaders(b"Content-Type")[0]
+        if headers.hasHeader(b"Content-Type"):
+            media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
         else:
             raise SynapseError(
                 msg="Upload request missing 'Content-Type'",
diff --git a/synapse/secrets.py b/synapse/secrets.py
index f05e9ea535..f6280f951c 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -38,4 +38,4 @@ else:
             return os.urandom(nbytes)
 
         def token_hex(self, nbytes=32):
-            return binascii.hexlify(self.token_bytes(nbytes))
+            return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dd03c4168b..4b971efdba 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     def __init__(self, db_conn, hs):
         super(StateGroupWorkerStore, self).__init__(db_conn, hs)
 
+        # Originally the state store used a single DictionaryCache to cache the
+        # event IDs for the state types in a given state group to avoid hammering
+        # on the state_group* tables.
+        #
+        # The point of using a DictionaryCache is that it can cache a subset
+        # of the state events for a given state group (i.e. a subset of the keys for a
+        # given dict which is an entry in the cache for a given state group ID).
+        #
+        # However, this poses problems when performing complicated queries
+        # on the store - for instance: "give me all the state for this group, but
+        # limit members to this subset of users", as DictionaryCache's API isn't
+        # rich enough to say "please cache any of these fields, apart from this subset".
+        # This is problematic when lazy loading members, which requires this behaviour,
+        # as without it the cache has no choice but to speculatively load all
+        # state events for the group, which negates the efficiency being sought.
+        #
+        # Rather than overcomplicating DictionaryCache's API, we instead split the
+        # state_group_cache into two halves - one for tracking non-member events,
+        # and the other for tracking member_events.  This means that lazy loading
+        # queries can be made in a cache-friendly manner by querying both caches
+        # separately and then merging the result.  So for the example above, you
+        # would query the members cache for a specific subset of state keys
+        # (which DictionaryCache will handle efficiently and fine) and the non-members
+        # cache for all state (which DictionaryCache will similarly handle fine)
+        # and then just merge the results together.
+        #
+        # We size the non-members cache to be smaller than the members cache as the
+        # vast majority of state in Matrix (today) is member events.
+
         self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
+            "*stateGroupCache*",
+            # TODO: this hasn't been tuned yet
+            50000 * get_cache_factor_for("stateGroupCache")
+        )
+        self._state_group_members_cache = DictionaryCache(
+            "*stateGroupMembersCache*",
+            500000 * get_cache_factor_for("stateGroupMembersCache")
         )
 
     @defer.inlineCallbacks
@@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         })
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, types):
+    def _get_state_groups_from_groups(self, groups, types, members=None):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
@@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             types (Iterable[str, str|None]|None): list of 2-tuples of the form
                 (`type`, `state_key`), where a `state_key` of `None` matches all
                 state_keys for the `type`. If None, all types are returned.
+            members (bool|None): If not None, then, in addition to any filtering
+                implied by types, the results are also filtered to only include
+                member events (if True), or to exclude member events (if False)
 
         Returns:
             dictionary state_group -> (dict of (type, state_key) -> event id)
@@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         for chunk in chunks:
             res = yield self.runInteraction(
                 "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn, chunk, types,
+                self._get_state_groups_from_groups_txn, chunk, types, members,
             )
             results.update(res)
 
         defer.returnValue(results)
 
     def _get_state_groups_from_groups_txn(
-        self, txn, groups, types=None,
+        self, txn, groups, types=None, members=None,
     ):
         results = {group: {} for group in groups}
 
@@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 %s
             """)
 
+            if members is True:
+                sql += " AND type = '%s'" % (EventTypes.Member,)
+            elif members is False:
+                sql += " AND type <> '%s'" % (EventTypes.Member,)
+
             # Turns out that postgres doesn't like doing a list of OR's and
             # is about 1000x slower, so we just issue a query for each specific
             # type seperately.
@@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             else:
                 where_clause = ""
 
+            if members is True:
+                where_clause += " AND type = '%s'" % EventTypes.Member
+            elif members is False:
+                where_clause += " AND type <> '%s'" % EventTypes.Member
+
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
@@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
 
-    def _get_some_state_from_cache(self, group, types, filtered_types=None):
+    def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Args:
+            cache(DictionaryCache): the state group cache to use
             group(int): The state group to lookup
             types(list[str, str|None]): List of 2-tuples of the form
                 (`type`, `state_key`), where a `state_key` of `None` matches all
@@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         requests state from the cache, if False we need to query the DB for the
         missing state.
         """
-        is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
+        is_all, known_absent, state_dict_ids = cache.get(group)
 
         type_to_key = {}
 
-        # tracks whether any of ourrequested types are missing from the cache
+        # tracks whether any of our requested types are missing from the cache
         missing_types = False
 
         for typ, state_key in types:
@@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             if include(k[0], k[1])
         }, got_all
 
-    def _get_all_state_from_cache(self, group):
+    def _get_all_state_from_cache(self, cache, group):
         """Checks if group is in cache. See `_get_state_for_groups`
 
         Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
@@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         cache, if False we need to query the DB for the missing state.
 
         Args:
+            cache(DictionaryCache): the state group cache to use
             group: The state group to lookup
         """
-        is_all, _, state_dict_ids = self._state_group_cache.get(group)
+        is_all, _, state_dict_ids = cache.get(group)
 
         return state_dict_ids, is_all
 
@@ -685,6 +735,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             Deferred[dict[int, dict[(type, state_key), EventBase]]]
                 a dictionary mapping from state group to state dictionary.
         """
+        if types is not None:
+            non_member_types = [t for t in types if t[0] != EventTypes.Member]
+
+            if filtered_types is not None and EventTypes.Member not in filtered_types:
+                # we want all of the membership events
+                member_types = None
+            else:
+                member_types = [t for t in types if t[0] == EventTypes.Member]
+
+        else:
+            non_member_types = None
+            member_types = None
+
+        non_member_state = yield self._get_state_for_groups_using_cache(
+            groups, self._state_group_cache, non_member_types, filtered_types,
+        )
+        # XXX: we could skip this entirely if member_types is []
+        member_state = yield self._get_state_for_groups_using_cache(
+            # we set filtered_types=None as member_state only ever contain members.
+            groups, self._state_group_members_cache, member_types, None,
+        )
+
+        state = non_member_state
+        for group in groups:
+            state[group].update(member_state[group])
+
+        defer.returnValue(state)
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups_using_cache(
+        self, groups, cache, types=None, filtered_types=None
+    ):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key, querying from a specific cache.
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            cache (DictionaryCache): the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the specific
+                members state cache.
+            types (None|iterable[(str, None|str)]):
+                indicates the state type/keys required. If None, the whole
+                state is fetched and returned.
+
+                Otherwise, each entry should be a `(type, state_key)` tuple to
+                include in the response. A `state_key` of None is a wildcard
+                meaning that we require all state with that type.
+            filtered_types(list[str]|None): Only apply filtering via `types` to this
+                list of event types.  Other types of events are returned unfiltered.
+                If None, `types` filtering is applied to all events.
+
+        Returns:
+            Deferred[dict[int, dict[(type, state_key), EventBase]]]
+                a dictionary mapping from state group to state dictionary.
+        """
         if types:
             types = frozenset(types)
         results = {}
@@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         if types is not None:
             for group in set(groups):
                 state_dict_ids, got_all = self._get_some_state_from_cache(
-                    group, types, filtered_types
+                    cache, group, types, filtered_types
                 )
                 results[group] = state_dict_ids
 
@@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         else:
             for group in set(groups):
                 state_dict_ids, got_all = self._get_all_state_from_cache(
-                    group
+                    cache, group
                 )
 
                 results[group] = state_dict_ids
@@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     missing_groups.append(group)
 
         if missing_groups:
-            # Okay, so we have some missing_types, lets fetch them.
-            cache_seq_num = self._state_group_cache.sequence
+            # Okay, so we have some missing_types, let's fetch them.
+            cache_seq_num = cache.sequence
 
             # the DictionaryCache knows if it has *all* the state, but
             # does not know if it has all of the keys of a particular type,
@@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 types_to_fetch = types
 
             group_to_state_dict = yield self._get_state_groups_from_groups(
-                missing_groups, types_to_fetch
+                missing_groups, types_to_fetch, cache == self._state_group_members_cache,
             )
 
             for group, group_state_dict in iteritems(group_to_state_dict):
@@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
                 # update the cache with all the things we fetched from the
                 # database.
-                self._state_group_cache.update(
+                cache.update(
                     cache_seq_num,
                     key=group,
                     value=group_state_dict,
@@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                     ],
                 )
 
-            # Prefill the state group cache with this group.
+            # Prefill the state group caches with this group.
             # It's fine to use the sequence like this as the state group map
             # is immutable. (If the map wasn't immutable then this prefill could
             # race with another update)
+
+            current_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] == EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_members_cache.update,
+                self._state_group_members_cache.sequence,
+                key=state_group,
+                value=dict(current_member_state_ids),
+            )
+
+            current_non_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] != EventTypes.Member
+            }
             txn.call_after(
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
                 key=state_group,
-                value=dict(current_state_ids),
+                value=dict(current_non_member_state_ids),
             )
 
             return state_group
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 07e83fadda..a0c2d37610 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -385,7 +385,13 @@ class LoggingContextFilter(logging.Filter):
         context = LoggingContext.current_context()
         for key, value in self.defaults.items():
             setattr(record, key, value)
-        context.copy_to(record)
+
+        # context should never be None, but if it somehow ends up being, then
+        # we end up in a death spiral of infinite loops, so let's check, for
+        # robustness' sake.
+        if context is not None:
+            context.copy_to(record)
+
         return True
 
 
@@ -396,7 +402,9 @@ class PreserveLoggingContext(object):
 
     __slots__ = ["current_context", "new_context", "has_parent"]
 
-    def __init__(self, new_context=LoggingContext.sentinel):
+    def __init__(self, new_context=None):
+        if new_context is None:
+            new_context = LoggingContext.sentinel
         self.new_context = new_context
 
     def __enter__(self):
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 62a00189cc..ef31458226 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -20,6 +20,8 @@ import time
 from functools import wraps
 from inspect import getcallargs
 
+from six import PY3
+
 _TIME_FUNC_ID = 0
 
 
@@ -28,8 +30,12 @@ def _log_debug_as_f(f, msg, msg_args):
     logger = logging.getLogger(name)
 
     if logger.isEnabledFor(logging.DEBUG):
-        lineno = f.func_code.co_firstlineno
-        pathname = f.func_code.co_filename
+        if PY3:
+            lineno = f.__code__.co_firstlineno
+            pathname = f.__code__.co_filename
+        else:
+            lineno = f.func_code.co_firstlineno
+            pathname = f.func_code.co_filename
 
         record = logging.LogRecord(
             name=name,
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 43d9db67ec..6f318c6a29 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,6 +16,7 @@
 import random
 import string
 
+from six import PY3
 from six.moves import range
 
 _string_with_symbols = (
@@ -34,6 +35,17 @@ def random_string_with_symbols(length):
 
 
 def is_ascii(s):
+
+    if PY3:
+        if isinstance(s, bytes):
+            try:
+                s.decode('ascii').encode('ascii')
+            except UnicodeDecodeError:
+                return False
+            except UnicodeEncodeError:
+                return False
+            return True
+
     try:
         s.encode("ascii")
     except UnicodeEncodeError:
@@ -49,6 +61,9 @@ def to_ascii(s):
 
     If given None then will return None.
     """
+    if PY3:
+        return s
+
     if s is None:
         return None
 
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 1fbcd41115..3baba3225a 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -30,7 +30,7 @@ def get_version_string(module):
                 ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
                 stderr=null,
                 cwd=cwd,
-            ).strip()
+            ).strip().decode('ascii')
             git_branch = "b=" + git_branch
         except subprocess.CalledProcessError:
             git_branch = ""
@@ -40,7 +40,7 @@ def get_version_string(module):
                 ['git', 'describe', '--exact-match'],
                 stderr=null,
                 cwd=cwd,
-            ).strip()
+            ).strip().decode('ascii')
             git_tag = "t=" + git_tag
         except subprocess.CalledProcessError:
             git_tag = ""
@@ -50,7 +50,7 @@ def get_version_string(module):
                 ['git', 'rev-parse', '--short', 'HEAD'],
                 stderr=null,
                 cwd=cwd,
-            ).strip()
+            ).strip().decode('ascii')
         except subprocess.CalledProcessError:
             git_commit = ""
 
@@ -60,7 +60,7 @@ def get_version_string(module):
                 ['git', 'describe', '--dirty=' + dirty_string],
                 stderr=null,
                 cwd=cwd,
-            ).strip().endswith(dirty_string)
+            ).strip().decode('ascii').endswith(dirty_string)
 
             git_dirty = "dirty" if is_dirty else ""
         except subprocess.CalledProcessError:
@@ -77,8 +77,8 @@ def get_version_string(module):
                 "%s (%s)" % (
                     module.__version__, git_version,
                 )
-            ).encode("ascii")
+            )
     except Exception as e:
         logger.info("Failed to check for git repository: %s", e)
 
-    return module.__version__.encode("ascii")
+    return module.__version__