summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-02-19 10:16:46 +0000
committerErik Johnston <erik@matrix.org>2020-02-19 10:16:46 +0000
commit93a0751302696c748a38f59bbc8e207c4cdbe10d (patch)
tree43383fc3f1910774a894e2339253fd810bac87bc
parentMerge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes (diff)
parentIncrease DB/CPU perf of `_is_server_still_joined` check. (#6936) (diff)
downloadsynapse-93a0751302696c748a38f59bbc8e207c4cdbe10d.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
-rw-r--r--changelog.d/6872.misc1
-rw-r--r--changelog.d/6907.doc1
-rw-r--r--changelog.d/6936.misc1
-rw-r--r--changelog.d/6939.feature1
-rw-r--r--changelog.d/6945.bugfix1
-rw-r--r--changelog.d/6947.misc1
-rw-r--r--docs/sample_config.yaml5
-rw-r--r--synapse/config/tls.py19
-rw-r--r--synapse/events/__init__.py199
-rw-r--r--synapse/handlers/acme.py16
-rw-r--r--synapse/handlers/directory.py17
-rw-r--r--synapse/handlers/pagination.py2
-rw-r--r--synapse/rest/client/v1/room.py23
-rw-r--r--synapse/storage/data_stores/main/event_federation.py41
-rw-r--r--synapse/storage/data_stores/main/roommember.py31
-rw-r--r--synapse/storage/persist_events.py43
-rw-r--r--tests/rest/client/v1/test_rooms.py70
-rw-r--r--tests/storage/test_redaction.py2
-rw-r--r--tests/unittest.py28
19 files changed, 374 insertions, 128 deletions
diff --git a/changelog.d/6872.misc b/changelog.d/6872.misc
new file mode 100644
index 0000000000..215a0c82c3
--- /dev/null
+++ b/changelog.d/6872.misc
@@ -0,0 +1 @@
+Refactor _EventInternalMetadata object to improve type safety.
diff --git a/changelog.d/6907.doc b/changelog.d/6907.doc
new file mode 100644
index 0000000000..be0e698af8
--- /dev/null
+++ b/changelog.d/6907.doc
@@ -0,0 +1 @@
+Update Synapse's documentation to warn about the deprecation of ACME v1.
diff --git a/changelog.d/6936.misc b/changelog.d/6936.misc
new file mode 100644
index 0000000000..9400725017
--- /dev/null
+++ b/changelog.d/6936.misc
@@ -0,0 +1 @@
+Increase DB/CPU perf of `_is_server_still_joined` check.
diff --git a/changelog.d/6939.feature b/changelog.d/6939.feature
new file mode 100644
index 0000000000..40fe7fc9a9
--- /dev/null
+++ b/changelog.d/6939.feature
@@ -0,0 +1 @@
+Implement `GET /_matrix/client/r0/rooms/{roomId}/aliases` endpoint as per [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432).
diff --git a/changelog.d/6945.bugfix b/changelog.d/6945.bugfix
new file mode 100644
index 0000000000..8561be16a4
--- /dev/null
+++ b/changelog.d/6945.bugfix
@@ -0,0 +1 @@
+Fix errors from logging in the purge jobs related to the message retention policies support.
diff --git a/changelog.d/6947.misc b/changelog.d/6947.misc
new file mode 100644
index 0000000000..6d00e58654
--- /dev/null
+++ b/changelog.d/6947.misc
@@ -0,0 +1 @@
+Increase perf of `get_auth_chain_ids` used in state res v2.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 93236daddc..8a036071e1 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -476,6 +476,11 @@ retention:
 # ACME support: This will configure Synapse to request a valid TLS certificate
 # for your configured `server_name` via Let's Encrypt.
 #
+# Note that ACME v1 is now deprecated, and Synapse currently doesn't support
+# ACME v2. This means that this feature currently won't work with installs set
+# up after November 2019. For more info, and alternative solutions, see
+# https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
+#
 # Note that provisioning a certificate in this way requires port 80 to be
 # routed to Synapse so that it can complete the http-01 ACME challenge.
 # By default, if you enable ACME support, Synapse will attempt to listen on
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 2514b0713d..97a12d51f6 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -32,6 +32,17 @@ from synapse.util import glob_to_regex
 
 logger = logging.getLogger(__name__)
 
+ACME_SUPPORT_ENABLED_WARN = """\
+This server uses Synapse's built-in ACME support. Note that ACME v1 has been
+deprecated by Let's Encrypt, and that Synapse doesn't currently support ACME v2,
+which means that this feature will not work with Synapse installs set up after
+November 2019, and that it may stop working on June 2020 for installs set up
+before that date.
+
+For more info and alternative solutions, see
+https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
+--------------------------------------------------------------------------------"""
+
 
 class TlsConfig(Config):
     section = "tls"
@@ -44,6 +55,9 @@ class TlsConfig(Config):
 
         self.acme_enabled = acme_config.get("enabled", False)
 
+        if self.acme_enabled:
+            logger.warning(ACME_SUPPORT_ENABLED_WARN)
+
         # hyperlink complains on py2 if this is not a Unicode
         self.acme_url = six.text_type(
             acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory")
@@ -362,6 +376,11 @@ class TlsConfig(Config):
         # ACME support: This will configure Synapse to request a valid TLS certificate
         # for your configured `server_name` via Let's Encrypt.
         #
+        # Note that ACME v1 is now deprecated, and Synapse currently doesn't support
+        # ACME v2. This means that this feature currently won't work with installs set
+        # up after November 2019. For more info, and alternative solutions, see
+        # https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
+        #
         # Note that provisioning a certificate in this way requires port 80 to be
         # routed to Synapse so that it can complete the http-01 ACME challenge.
         # By default, if you enable ACME support, Synapse will attempt to listen on
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index a842661a90..7307116556 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -37,34 +38,115 @@ from synapse.util.frozenutils import freeze
 USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
 
 
+class DictProperty:
+    """An object property which delegates to the `_dict` within its parent object."""
+
+    __slots__ = ["key"]
+
+    def __init__(self, key: str):
+        self.key = key
+
+    def __get__(self, instance, owner=None):
+        # if the property is accessed as a class property rather than an instance
+        # property, return the property itself rather than the value
+        if instance is None:
+            return self
+        try:
+            return instance._dict[self.key]
+        except KeyError as e1:
+            # We want this to look like a regular attribute error (mostly so that
+            # hasattr() works correctly), so we convert the KeyError into an
+            # AttributeError.
+            #
+            # To exclude the KeyError from the traceback, we explicitly
+            # 'raise from e1.__context__' (which is better than 'raise from None',
+            # becuase that would omit any *earlier* exceptions).
+            #
+            raise AttributeError(
+                "'%s' has no '%s' property" % (type(instance), self.key)
+            ) from e1.__context__
+
+    def __set__(self, instance, v):
+        instance._dict[self.key] = v
+
+    def __delete__(self, instance):
+        try:
+            del instance._dict[self.key]
+        except KeyError as e1:
+            raise AttributeError(
+                "'%s' has no '%s' property" % (type(instance), self.key)
+            ) from e1.__context__
+
+
+class DefaultDictProperty(DictProperty):
+    """An extension of DictProperty which provides a default if the property is
+    not present in the parent's _dict.
+
+    Note that this means that hasattr() on the property always returns True.
+    """
+
+    __slots__ = ["default"]
+
+    def __init__(self, key, default):
+        super().__init__(key)
+        self.default = default
+
+    def __get__(self, instance, owner=None):
+        if instance is None:
+            return self
+        return instance._dict.get(self.key, self.default)
+
+
 class _EventInternalMetadata(object):
-    def __init__(self, internal_metadata_dict):
-        self.__dict__ = dict(internal_metadata_dict)
+    __slots__ = ["_dict"]
+
+    def __init__(self, internal_metadata_dict: JsonDict):
+        # we have to copy the dict, because it turns out that the same dict is
+        # reused. TODO: fix that
+        self._dict = dict(internal_metadata_dict)
+
+    outlier = DictProperty("outlier")  # type: bool
+    out_of_band_membership = DictProperty("out_of_band_membership")  # type: bool
+    send_on_behalf_of = DictProperty("send_on_behalf_of")  # type: str
+    recheck_redaction = DictProperty("recheck_redaction")  # type: bool
+    soft_failed = DictProperty("soft_failed")  # type: bool
+    proactively_send = DictProperty("proactively_send")  # type: bool
+    redacted = DictProperty("redacted")  # type: bool
+    txn_id = DictProperty("txn_id")  # type: str
+    token_id = DictProperty("token_id")  # type: str
+    stream_ordering = DictProperty("stream_ordering")  # type: int
+
+    # XXX: These are set by StreamWorkerStore._set_before_and_after.
+    # I'm pretty sure that these are never persisted to the database, so shouldn't
+    # be here
+    before = DictProperty("before")  # type: str
+    after = DictProperty("after")  # type: str
+    order = DictProperty("order")  # type: int
 
-    def get_dict(self):
-        return dict(self.__dict__)
+    def get_dict(self) -> JsonDict:
+        return dict(self._dict)
 
-    def is_outlier(self):
-        return getattr(self, "outlier", False)
+    def is_outlier(self) -> bool:
+        return self._dict.get("outlier", False)
 
-    def is_out_of_band_membership(self):
+    def is_out_of_band_membership(self) -> bool:
         """Whether this is an out of band membership, like an invite or an invite
         rejection. This is needed as those events are marked as outliers, but
         they still need to be processed as if they're new events (e.g. updating
         invite state in the database, relaying to clients, etc).
         """
-        return getattr(self, "out_of_band_membership", False)
+        return self._dict.get("out_of_band_membership", False)
 
-    def get_send_on_behalf_of(self):
+    def get_send_on_behalf_of(self) -> Optional[str]:
         """Whether this server should send the event on behalf of another server.
         This is used by the federation "send_join" API to forward the initial join
         event for a server in the room.
 
         returns a str with the name of the server this event is sent on behalf of.
         """
-        return getattr(self, "send_on_behalf_of", None)
+        return self._dict.get("send_on_behalf_of")
 
-    def need_to_check_redaction(self):
+    def need_to_check_redaction(self) -> bool:
         """Whether the redaction event needs to be rechecked when fetching
         from the database.
 
@@ -77,9 +159,9 @@ class _EventInternalMetadata(object):
         Returns:
             bool
         """
-        return getattr(self, "recheck_redaction", False)
+        return self._dict.get("recheck_redaction", False)
 
-    def is_soft_failed(self):
+    def is_soft_failed(self) -> bool:
         """Whether the event has been soft failed.
 
         Soft failed events should be handled as usual, except:
@@ -91,7 +173,7 @@ class _EventInternalMetadata(object):
         Returns:
             bool
         """
-        return getattr(self, "soft_failed", False)
+        return self._dict.get("soft_failed", False)
 
     def should_proactively_send(self):
         """Whether the event, if ours, should be sent to other clients and
@@ -103,7 +185,7 @@ class _EventInternalMetadata(object):
         Returns:
             bool
         """
-        return getattr(self, "proactively_send", True)
+        return self._dict.get("proactively_send", True)
 
     def is_redacted(self):
         """Whether the event has been redacted.
@@ -114,52 +196,7 @@ class _EventInternalMetadata(object):
         Returns:
             bool
         """
-        return getattr(self, "redacted", False)
-
-
-_SENTINEL = object()
-
-
-def _event_dict_property(key, default=_SENTINEL):
-    """Creates a new property for the given key that delegates access to
-    `self._event_dict`.
-
-    The default is used if the key is missing from the `_event_dict`, if given,
-    otherwise an AttributeError will be raised.
-
-    Note: If a default is given then `hasattr` will always return true.
-    """
-
-    # We want to be able to use hasattr with the event dict properties.
-    # However, (on python3) hasattr expects AttributeError to be raised. Hence,
-    # we need to transform the KeyError into an AttributeError
-
-    def getter_raises(self):
-        try:
-            return self._event_dict[key]
-        except KeyError:
-            raise AttributeError(key)
-
-    def getter_default(self):
-        return self._event_dict.get(key, default)
-
-    def setter(self, v):
-        try:
-            self._event_dict[key] = v
-        except KeyError:
-            raise AttributeError(key)
-
-    def delete(self):
-        try:
-            del self._event_dict[key]
-        except KeyError:
-            raise AttributeError(key)
-
-    if default is _SENTINEL:
-        # No default given, so use the getter that raises
-        return property(getter_raises, setter, delete)
-    else:
-        return property(getter_default, setter, delete)
+        return self._dict.get("redacted", False)
 
 
 class EventBase(object):
@@ -175,23 +212,23 @@ class EventBase(object):
         self.unsigned = unsigned
         self.rejected_reason = rejected_reason
 
-        self._event_dict = event_dict
+        self._dict = event_dict
 
         self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
 
-    auth_events = _event_dict_property("auth_events")
-    depth = _event_dict_property("depth")
-    content = _event_dict_property("content")
-    hashes = _event_dict_property("hashes")
-    origin = _event_dict_property("origin")
-    origin_server_ts = _event_dict_property("origin_server_ts")
-    prev_events = _event_dict_property("prev_events")
-    redacts = _event_dict_property("redacts", None)
-    room_id = _event_dict_property("room_id")
-    sender = _event_dict_property("sender")
-    state_key = _event_dict_property("state_key")
-    type = _event_dict_property("type")
-    user_id = _event_dict_property("sender")
+    auth_events = DictProperty("auth_events")
+    depth = DictProperty("depth")
+    content = DictProperty("content")
+    hashes = DictProperty("hashes")
+    origin = DictProperty("origin")
+    origin_server_ts = DictProperty("origin_server_ts")
+    prev_events = DictProperty("prev_events")
+    redacts = DefaultDictProperty("redacts", None)
+    room_id = DictProperty("room_id")
+    sender = DictProperty("sender")
+    state_key = DictProperty("state_key")
+    type = DictProperty("type")
+    user_id = DictProperty("sender")
 
     @property
     def event_id(self) -> str:
@@ -205,13 +242,13 @@ class EventBase(object):
         return hasattr(self, "state_key") and self.state_key is not None
 
     def get_dict(self) -> JsonDict:
-        d = dict(self._event_dict)
+        d = dict(self._dict)
         d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
 
         return d
 
     def get(self, key, default=None):
-        return self._event_dict.get(key, default)
+        return self._dict.get(key, default)
 
     def get_internal_metadata_dict(self):
         return self.internal_metadata.get_dict()
@@ -233,16 +270,16 @@ class EventBase(object):
         raise AttributeError("Unrecognized attribute %s" % (instance,))
 
     def __getitem__(self, field):
-        return self._event_dict[field]
+        return self._dict[field]
 
     def __contains__(self, field):
-        return field in self._event_dict
+        return field in self._dict
 
     def items(self):
-        return list(self._event_dict.items())
+        return list(self._dict.items())
 
     def keys(self):
-        return six.iterkeys(self._event_dict)
+        return six.iterkeys(self._dict)
 
     def prev_event_ids(self):
         """Returns the list of prev event IDs. The order matches the order
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 46ac73106d..250faa997b 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -25,6 +25,15 @@ from synapse.app import check_bind_error
 
 logger = logging.getLogger(__name__)
 
+ACME_REGISTER_FAIL_ERROR = """
+--------------------------------------------------------------------------------
+Failed to register with the ACME provider. This is likely happening because the install
+is new, and ACME v1 has been deprecated by Let's Encrypt and is disabled for installs set
+up after November 2019.
+At the moment, Synapse doesn't support ACME v2. For more info and alternative solution,
+check out https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
+--------------------------------------------------------------------------------"""
+
 
 class AcmeHandler(object):
     def __init__(self, hs):
@@ -71,7 +80,12 @@ class AcmeHandler(object):
         # want it to control where we save the certificates, we have to reach in
         # and trigger the registration machinery ourselves.
         self._issuer._registered = False
-        yield self._issuer._ensure_registered()
+
+        try:
+            yield self._issuer._ensure_registered()
+        except Exception:
+            logger.error(ACME_REGISTER_FAIL_ERROR)
+            raise
 
     @defer.inlineCallbacks
     def provision_certificate(self):
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index f718388884..3f8c792149 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -16,6 +16,7 @@
 
 import logging
 import string
+from typing import List
 
 from twisted.internet import defer
 
@@ -28,7 +29,7 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
-from synapse.types import RoomAlias, UserID, get_domain_from_id
+from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
 
 from ._base import BaseHandler
 
@@ -452,3 +453,17 @@ class DirectoryHandler(BaseHandler):
         yield self.store.set_room_is_public_appservice(
             room_id, appservice_id, network_id, visibility == "public"
         )
+
+    async def get_aliases_for_room(
+        self, requester: Requester, room_id: str
+    ) -> List[str]:
+        """
+        Get a list of the aliases that currently point to this room on this server
+        """
+        # allow access to server admins and current members of the room
+        is_admin = await self.auth.is_server_admin(requester.user)
+        if not is_admin:
+            await self.auth.check_joined_room(room_id, requester.user.to_string())
+
+        aliases = await self.store.get_aliases_for_room(room_id)
+        return aliases
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index caf841a643..9bf6d39668 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -133,7 +133,7 @@ class PaginationHandler(object):
             include_null = False
 
         logger.info(
-            "[purge] Running purge job for %d < max_lifetime <= %d (include NULLs = %s)",
+            "[purge] Running purge job for %s < max_lifetime <= %s (include NULLs = %s)",
             min_ms,
             max_ms,
             include_null,
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 6f31584c51..143dc738c6 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -45,6 +45,10 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
 
+MYPY = False
+if MYPY:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 
@@ -843,6 +847,24 @@ class RoomTypingRestServlet(RestServlet):
         return 200, {}
 
 
+class RoomAliasListServlet(RestServlet):
+    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/aliases", unstable=False)
+
+    def __init__(self, hs: "synapse.server.HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.directory_handler = hs.get_handlers().directory_handler
+
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
+
+        alias_list = await self.directory_handler.get_aliases_for_room(
+            requester, room_id
+        )
+
+        return 200, {"aliases": alias_list}
+
+
 class SearchRestServlet(RestServlet):
     PATTERNS = client_patterns("/search$", v1=True)
 
@@ -931,6 +953,7 @@ def register_servlets(hs, http_server):
     JoinedRoomsRestServlet(hs).register(http_server)
     RoomEventServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
+    RoomAliasListServlet(hs).register(http_server)
 
 
 def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 1746f40adf..dcc375b840 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -62,32 +62,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         )
 
     def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+        if include_given:
+            results = set(event_ids)
+        else:
+            results = set()
+
         if isinstance(self.database_engine, PostgresEngine):
             # For efficiency we make the database do this if we can.
-            sql = """
-                WITH RECURSIVE auth_chain(event_id) AS (
-                    SELECT auth_id FROM event_auth WHERE event_id = ANY(?)
-                    UNION
-                    SELECT auth_id FROM event_auth
-                    INNER JOIN auth_chain USING (event_id)
-                )
-                SELECT event_id FROM auth_chain
-            """
-            txn.execute(sql, (list(event_ids),))
-
-            results = set(event_id for event_id, in txn)
 
-            if include_given:
-                results.update(event_ids)
+            # We need to be a little careful with querying large amounts at
+            # once, for some reason postgres really doesn't like it. We do this
+            # by only asking for auth chain of 500 events at a time.
+            event_ids = list(event_ids)
+            chunks = [event_ids[x : x + 500] for x in range(0, len(event_ids), 500)]
+            for chunk in chunks:
+                sql = """
+                    WITH RECURSIVE auth_chain(event_id) AS (
+                        SELECT auth_id FROM event_auth WHERE event_id = ANY(?)
+                        UNION
+                        SELECT auth_id FROM event_auth
+                        INNER JOIN auth_chain USING (event_id)
+                    )
+                    SELECT event_id FROM auth_chain
+                """
+                txn.execute(sql, (chunk,))
+
+                results.update(event_id for event_id, in txn)
 
             return list(results)
 
         # Database doesn't necessarily support recursive CTE, so we fall
         # back to do doing it manually.
-        if include_given:
-            results = set(event_ids)
-        else:
-            results = set()
 
         base_sql = "SELECT auth_id FROM event_auth WHERE "
 
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 042289f0e0..d5ced05701 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -868,6 +868,37 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_membership_from_event_ids",
         )
 
+    async def is_local_host_in_room_ignoring_users(
+        self, room_id: str, ignore_users: Collection[str]
+    ) -> bool:
+        """Check if there are any local users, excluding those in the given
+        list, in the room.
+        """
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "user_id", ignore_users
+        )
+
+        sql = """
+            SELECT 1 FROM local_current_membership
+            WHERE
+                room_id = ? AND membership = ?
+                AND NOT (%s)
+                LIMIT 1
+        """ % (
+            clause,
+        )
+
+        def _is_local_host_in_room_ignoring_users_txn(txn):
+            txn.execute(sql, (room_id, Membership.JOIN, *args))
+
+            return bool(txn.fetchone())
+
+        return await self.db.runInteraction(
+            "is_local_host_in_room_ignoring_users",
+            _is_local_host_in_room_ignoring_users_txn,
+        )
+
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index a5370ed527..b950550f23 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -727,6 +727,7 @@ class EventsPersistenceStorage(object):
 
         # Check if any of the given events are a local join that appear in the
         # current state
+        events_to_check = []  # Event IDs that aren't an event we're persisting
         for (typ, state_key), event_id in delta.to_insert.items():
             if typ != EventTypes.Member or not self.is_mine_id(state_key):
                 continue
@@ -736,8 +737,33 @@ class EventsPersistenceStorage(object):
                     if event.membership == Membership.JOIN:
                         return True
 
-        # There's been a change of membership but we don't have a local join
-        # event in the new events, so we need to check the full state.
+            # The event is not in `ev_ctx_rm`, so we need to pull it out of
+            # the DB.
+            events_to_check.append(event_id)
+
+        # Check if any of the changes that we don't have events for are joins.
+        if events_to_check:
+            rows = await self.main_store.get_membership_from_event_ids(events_to_check)
+            is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+            if is_still_joined:
+                return True
+
+        # None of the new state events are local joins, so we check the database
+        # to see if there are any other local users in the room. We ignore users
+        # whose state has changed as we've already their new state above.
+        users_to_ignore = [
+            state_key
+            for _, state_key in itertools.chain(delta.to_insert, delta.to_delete)
+            if self.is_mine_id(state_key)
+        ]
+
+        if await self.main_store.is_local_host_in_room_ignoring_users(
+            room_id, users_to_ignore
+        ):
+            return True
+
+        # The server will leave the room, so we go and find out which remote
+        # users will still be joined when we leave.
         if current_state is None:
             current_state = await self.main_store.get_current_state_ids(room_id)
             current_state = dict(current_state)
@@ -746,19 +772,6 @@ class EventsPersistenceStorage(object):
 
             current_state.update(delta.to_insert)
 
-        event_ids = [
-            event_id
-            for (typ, state_key,), event_id in current_state.items()
-            if typ == EventTypes.Member and self.is_mine_id(state_key)
-        ]
-
-        rows = await self.main_store.get_membership_from_event_ids(event_ids)
-        is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
-        if is_still_joined:
-            return True
-
-        # The server will leave the room, so we go and find out which remote
-        # users will still be joined when we leave.
         remote_event_ids = [
             event_id
             for (typ, state_key,), event_id in current_state.items()
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index fb681a1db9..fb08a45d27 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -28,8 +28,9 @@ from twisted.internet import defer
 import synapse.rest.admin
 from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.handlers.pagination import PurgeStatus
-from synapse.rest.client.v1 import login, profile, room
+from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import account
+from synapse.types import JsonDict, RoomAlias
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -1726,3 +1727,70 @@ class ContextTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(events_after), 2, events_after)
         self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
         self.assertEqual(events_after[1].get("content"), {}, events_after[1])
+
+
+class DirectoryTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        directory.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.room_owner = self.register_user("room_owner", "test")
+        self.room_owner_tok = self.login("room_owner", "test")
+
+        self.room_id = self.helper.create_room_as(
+            self.room_owner, tok=self.room_owner_tok
+        )
+
+    def test_no_aliases(self):
+        res = self._get_aliases(self.room_owner_tok)
+        self.assertEqual(res["aliases"], [])
+
+    def test_not_in_room(self):
+        self.register_user("user", "test")
+        user_tok = self.login("user", "test")
+        res = self._get_aliases(user_tok, expected_code=403)
+        self.assertEqual(res["errcode"], "M_FORBIDDEN")
+
+    def test_with_aliases(self):
+        alias1 = self._random_alias()
+        alias2 = self._random_alias()
+
+        self._set_alias_via_directory(alias1)
+        self._set_alias_via_directory(alias2)
+
+        res = self._get_aliases(self.room_owner_tok)
+        self.assertEqual(set(res["aliases"]), {alias1, alias2})
+
+    def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
+        """Calls the endpoint under test. returns the json response object."""
+        request, channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,),
+            access_token=access_token,
+        )
+        self.render(request)
+        self.assertEqual(channel.code, expected_code, channel.result)
+        res = channel.json_body
+        self.assertIsInstance(res, dict)
+        if expected_code == 200:
+            self.assertIsInstance(res["aliases"], list)
+        return res
+
+    def _random_alias(self) -> str:
+        return RoomAlias(random_string(5), self.hs.hostname).to_string()
+
+    def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+        url = "/_matrix/client/r0/directory/room/" + alias
+        data = {"room_id": self.room_id}
+        request_data = json.dumps(data)
+
+        request, channel = self.make_request(
+            "PUT", url, request_data, access_token=self.room_owner_tok
+        )
+        self.render(request)
+        self.assertEqual(channel.code, expected_code, channel.result)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b9ee6ec1ec..db3667dc43 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -240,7 +240,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 built_event = yield self._base_builder.build(prev_event_ids)
 
                 built_event._event_id = self._event_id
-                built_event._event_dict["event_id"] = self._event_id
+                built_event._dict["event_id"] = self._event_id
                 assert built_event.event_id == self._event_id
 
                 return built_event
diff --git a/tests/unittest.py b/tests/unittest.py
index 98bf27d39c..8816a4d152 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -21,6 +21,7 @@ import hmac
 import inspect
 import logging
 import time
+from typing import Optional, Tuple, Type, TypeVar, Union
 
 from mock import Mock
 
@@ -42,7 +43,13 @@ from synapse.server import HomeServer
 from synapse.types import Requester, UserID, create_requester
 from synapse.util.ratelimitutils import FederationRateLimiter
 
-from tests.server import get_clock, make_request, render, setup_test_homeserver
+from tests.server import (
+    FakeChannel,
+    get_clock,
+    make_request,
+    render,
+    setup_test_homeserver,
+)
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
 
@@ -71,6 +78,9 @@ def around(target):
     return _around
 
 
+T = TypeVar("T")
+
+
 class TestCase(unittest.TestCase):
     """A subclass of twisted.trial's TestCase which looks for 'loglevel'
     attributes on both itself and its individual test methods, to override the
@@ -334,14 +344,14 @@ class HomeserverTestCase(TestCase):
 
     def make_request(
         self,
-        method,
-        path,
-        content=b"",
-        access_token=None,
-        request=SynapseRequest,
-        shorthand=True,
-        federation_auth_origin=None,
-    ):
+        method: Union[bytes, str],
+        path: Union[bytes, str],
+        content: Union[bytes, dict] = b"",
+        access_token: Optional[str] = None,
+        request: Type[T] = SynapseRequest,
+        shorthand: bool = True,
+        federation_auth_origin: str = None,
+    ) -> Tuple[T, FakeChannel]:
         """
         Create a SynapseRequest at the path using the method and containing the
         given content.