summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/app/homeserver.py4
-rw-r--r--synapse/events/snapshot.py159
-rw-r--r--synapse/handlers/federation.py7
-rw-r--r--synapse/handlers/register.py50
-rw-r--r--synapse/replication/http/register.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py21
-rw-r--r--synapse/state/__init__.py175
-rw-r--r--synapse/storage/data_stores/main/registration.py8
-rw-r--r--synapse/storage/data_stores/main/state.py2
-rw-r--r--synapse/util/caches/descriptors.py48
11 files changed, 321 insertions, 157 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 8587ffa76f..ec16f54a49 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -36,7 +36,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.5.0"
+__version__ = "1.5.1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 00a7f8330e..73e2c29d06 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -636,7 +636,7 @@ def run(hs):
 
     if hs.config.report_stats:
         logger.info("Scheduling stats reporting for 3 hour intervals")
-        clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000, hs, stats)
+        clock.looping_call(start_phone_stats_home, 3 * 60 * 60 * 1000)
 
         # We need to defer this init for the cases that we daemonize
         # otherwise the process ID we get is that of the non-daemon process
@@ -644,7 +644,7 @@ def run(hs):
 
         # We wait 5 minutes to send the first set of stats as the server can
         # be quite busy the first few minutes
-        clock.call_later(5 * 60, start_phone_stats_home, hs, stats)
+        clock.call_later(5 * 60, start_phone_stats_home)
 
     _base.start_reactor(
         "synapse-homeserver",
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a269de5482..64e898f40c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, Optional, Tuple, Union
+
 from six import iteritems
 
 import attr
@@ -19,54 +21,113 @@ from frozendict import frozendict
 
 from twisted.internet import defer
 
+from synapse.appservice import ApplicationService
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 
 
 @attr.s(slots=True)
 class EventContext:
     """
+    Holds information relevant to persisting an event
+
     Attributes:
-        state_group (int|None): state group id, if the state has been stored
-            as a state group. This is usually only None if e.g. the event is
-            an outlier.
-        rejected (bool|str): A rejection reason if the event was rejected, else
-            False
-
-        prev_group (int): Previously persisted state group. ``None`` for an
-            outlier.
-        delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
-            (type, state_key) -> event_id. ``None`` for an outlier.
-
-        app_service: FIXME
-
-        _current_state_ids (dict[(str, str), str]|None):
-            The current state map including the current event. None if outlier
-            or we haven't fetched the state from DB yet.
+        rejected: A rejection reason if the event was rejected, else False
+
+        _state_group: The ID of the state group for this event. Note that state events
+            are persisted with a state group which includes the new event, so this is
+            effectively the state *after* the event in question.
+
+            For a *rejected* state event, where the state of the rejected event is
+            ignored, this state_group should never make it into the
+            event_to_state_groups table. Indeed, inspecting this value for a rejected
+            state event is almost certainly incorrect.
+
+            For an outlier, where we don't have the state at the event, this will be
+            None.
+
+            Note that this is a private attribute: it should be accessed via
+            the ``state_group`` property.
+
+        state_group_before_event: The ID of the state group representing the state
+            of the room before this event.
+
+            If this is a non-state event, this will be the same as ``state_group``. If
+            it's a state event, it will be the same as ``prev_group``.
+
+            If ``state_group`` is None (ie, the event is an outlier),
+            ``state_group_before_event`` will always also be ``None``.
+
+        prev_group: If it is known, ``state_group``'s prev_group. Note that this being
+            None does not necessarily mean that ``state_group`` does not have
+            a prev_group!
+
+            If the event is a state event, this is normally the same as ``prev_group``.
+
+            If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
+            will always also be ``None``.
+
+            Note that this *not* (necessarily) the state group associated with
+            ``_prev_state_ids``.
+
+        delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
+            and ``state_group``.
+
+        app_service: If this event is being sent by a (local) application service, that
+            app service.
+
+        _current_state_ids: The room state map, including this event - ie, the state
+            in ``state_group``.
+
             (type, state_key) -> event_id
 
-        _prev_state_ids (dict[(str, str), str]|None):
-            The current state map excluding the current event. None if outlier
-            or we haven't fetched the state from DB yet.
+            FIXME: what is this for an outlier? it seems ill-defined. It seems like
+            it could be either {}, or the state we were given by the remote
+            server, depending on $THINGS
+
+            Note that this is a private attribute: it should be accessed via
+            ``get_current_state_ids``. _AsyncEventContext impl calculates this
+            on-demand: it will be None until that happens.
+
+        _prev_state_ids: The room state map, excluding this event - ie, the state
+            in ``state_group_before_event``. For a non-state
+            event, this will be the same as _current_state_events.
+
+            Note that it is a completely different thing to prev_group!
+
             (type, state_key) -> event_id
+
+            FIXME: again, what is this for an outlier?
+
+            As with _current_state_ids, this is a private attribute. It should be
+            accessed via get_prev_state_ids.
     """
 
-    state_group = attr.ib(default=None)
-    rejected = attr.ib(default=False)
-    prev_group = attr.ib(default=None)
-    delta_ids = attr.ib(default=None)
-    app_service = attr.ib(default=None)
+    rejected = attr.ib(default=False, type=Union[bool, str])
+    _state_group = attr.ib(default=None, type=Optional[int])
+    state_group_before_event = attr.ib(default=None, type=Optional[int])
+    prev_group = attr.ib(default=None, type=Optional[int])
+    delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
+    app_service = attr.ib(default=None, type=Optional[ApplicationService])
 
-    _prev_state_ids = attr.ib(default=None)
-    _current_state_ids = attr.ib(default=None)
+    _current_state_ids = attr.ib(
+        default=None, type=Optional[Dict[Tuple[str, str], str]]
+    )
+    _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
 
     @staticmethod
     def with_state(
-        state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
+        state_group,
+        state_group_before_event,
+        current_state_ids,
+        prev_state_ids,
+        prev_group=None,
+        delta_ids=None,
     ):
         return EventContext(
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
             state_group=state_group,
+            state_group_before_event=state_group_before_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
         )
@@ -97,7 +158,8 @@ class EventContext:
             "prev_state_id": prev_state_id,
             "event_type": event.type,
             "event_state_key": event.state_key if event.is_state() else None,
-            "state_group": self.state_group,
+            "state_group": self._state_group,
+            "state_group_before_event": self.state_group_before_event,
             "rejected": self.rejected,
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
@@ -123,6 +185,7 @@ class EventContext:
             event_type=input["event_type"],
             event_state_key=input["event_state_key"],
             state_group=input["state_group"],
+            state_group_before_event=input["state_group_before_event"],
             prev_group=input["prev_group"],
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
@@ -134,22 +197,52 @@ class EventContext:
 
         return context
 
+    @property
+    def state_group(self) -> Optional[int]:
+        """The ID of the state group for this event.
+
+        Note that state events are persisted with a state group which includes the new
+        event, so this is effectively the state *after* the event in question.
+
+        For an outlier, where we don't have the state at the event, this will be None.
+
+        It is an error to access this for a rejected event, since rejected state should
+        not make it into the room state. Accessing this property will raise an exception
+        if ``rejected`` is set.
+        """
+        if self.rejected:
+            raise RuntimeError("Attempt to access state_group of rejected event")
+
+        return self._state_group
+
     @defer.inlineCallbacks
     def get_current_state_ids(self, store):
-        """Gets the current state IDs
+        """
+        Gets the room state map, including this event - ie, the state in ``state_group``
+
+        It is an error to access this for a rejected event, since rejected state should
+        not make it into the room state. This method will raise an exception if
+        ``rejected`` is set.
 
         Returns:
             Deferred[dict[(str, str), str]|None]: Returns None if state_group
                 is None, which happens when the associated event is an outlier.
+
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
+        if self.rejected:
+            raise RuntimeError("Attempt to access state_ids of rejected event")
+
         yield self._ensure_fetched(store)
         return self._current_state_ids
 
     @defer.inlineCallbacks
     def get_prev_state_ids(self, store):
-        """Gets the prev state IDs
+        """
+        Gets the room state map, excluding this event.
+
+        For a non-state event, this will be the same as get_current_state_ids().
 
         Returns:
             Deferred[dict[(str, str), str]|None]: Returns None if state_group
@@ -163,11 +256,17 @@ class EventContext:
     def get_cached_current_state_ids(self):
         """Gets the current state IDs if we have them already cached.
 
+        It is an error to access this for a rejected event, since rejected state should
+        not make it into the room state. This method will raise an exception if
+        ``rejected`` is set.
+
         Returns:
             dict[(str, str), str]|None: Returns None if we haven't cached the
             state or if state_group is None, which happens when the associated
             event is an outlier.
         """
+        if self.rejected:
+            raise RuntimeError("Attempt to access state_ids of rejected event")
 
         return self._current_state_ids
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8cafcfdab0..05dd8d2671 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1688,7 +1688,11 @@ class FederationHandler(BaseHandler):
         # hack around with a try/finally instead.
         success = False
         try:
-            if not event.internal_metadata.is_outlier() and not backfilled:
+            if (
+                not event.internal_metadata.is_outlier()
+                and not backfilled
+                and not context.rejected
+            ):
                 yield self.action_generator.handle_push_actions_for_event(
                     event, context
                 )
@@ -2276,6 +2280,7 @@ class FederationHandler(BaseHandler):
 
         return EventContext.with_state(
             state_group=state_group,
+            state_group_before_event=context.state_group_before_event,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
             prev_group=prev_group,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cff6b0d375..235f11c322 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -24,7 +24,6 @@ from synapse.api.errors import (
     AuthError,
     Codes,
     ConsentNotGivenError,
-    LimitExceededError,
     RegistrationError,
     SynapseError,
 )
@@ -168,6 +167,7 @@ class RegistrationHandler(BaseHandler):
         Raises:
             RegistrationError if there was a problem registering.
         """
+        yield self.check_registration_ratelimit(address)
 
         yield self.auth.check_auth_blocking(threepid=threepid)
         password_hash = None
@@ -217,8 +217,13 @@ class RegistrationHandler(BaseHandler):
 
         else:
             # autogen a sequential user ID
+            fail_count = 0
             user = None
             while not user:
+                # Fail after being unable to find a suitable ID a few times
+                if fail_count > 10:
+                    raise SynapseError(500, "Unable to find a suitable guest user ID")
+
                 localpart = yield self._generate_user_id()
                 user = UserID(localpart, self.hs.hostname)
                 user_id = user.to_string()
@@ -233,10 +238,14 @@ class RegistrationHandler(BaseHandler):
                         create_profile_with_displayname=default_display_name,
                         address=address,
                     )
+
+                    # Successfully registered
+                    break
                 except SynapseError:
                     # if user id is taken, just generate another
                     user = None
                     user_id = None
+                    fail_count += 1
 
         if not self.hs.config.user_consent_at_registration:
             yield self._auto_join_rooms(user_id)
@@ -414,6 +423,29 @@ class RegistrationHandler(BaseHandler):
             ratelimit=False,
         )
 
+    def check_registration_ratelimit(self, address):
+        """A simple helper method to check whether the registration rate limit has been hit
+        for a given IP address
+
+        Args:
+            address (str|None): the IP address used to perform the registration. If this is
+                None, no ratelimiting will be performed.
+
+        Raises:
+            LimitExceededError: If the rate limit has been exceeded.
+        """
+        if not address:
+            return
+
+        time_now = self.clock.time()
+
+        self.ratelimiter.ratelimit(
+            address,
+            time_now_s=time_now,
+            rate_hz=self.hs.config.rc_registration.per_second,
+            burst_count=self.hs.config.rc_registration.burst_count,
+        )
+
     def register_with_store(
         self,
         user_id,
@@ -446,22 +478,6 @@ class RegistrationHandler(BaseHandler):
         Returns:
             Deferred
         """
-        # Don't rate limit for app services
-        if appservice_id is None and address is not None:
-            time_now = self.clock.time()
-
-            allowed, time_allowed = self.ratelimiter.can_do_action(
-                address,
-                time_now_s=time_now,
-                rate_hz=self.hs.config.rc_registration.per_second,
-                burst_count=self.hs.config.rc_registration.burst_count,
-            )
-
-            if not allowed:
-                raise LimitExceededError(
-                    retry_after_ms=int(1000 * (time_allowed - time_now))
-                )
-
         if self.hs.config.worker_app:
             return self._register_client(
                 user_id=user_id,
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 915cfb9430..0c4aca1291 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -75,6 +75,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
+        self.registration_handler.check_registration_ratelimit(content["address"])
+
         await self.registration_handler.register_with_store(
             user_id=user_id,
             password_hash=content["password_hash"],
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 531d923f76..15c15a12f5 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -56,6 +56,9 @@ logger = logging.getLogger(__name__)
 _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
 _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
 
+OG_TAG_NAME_MAXLEN = 50
+OG_TAG_VALUE_MAXLEN = 1000
+
 
 class PreviewUrlResource(DirectServeResource):
     isLeaf = True
@@ -171,7 +174,7 @@ class PreviewUrlResource(DirectServeResource):
             ts (int):
 
         Returns:
-            Deferred[str]: json-encoded og data
+            Deferred[bytes]: json-encoded og data
         """
         # check the URL cache in the DB (which will also provide us with
         # historical previews, if we have any)
@@ -272,6 +275,18 @@ class PreviewUrlResource(DirectServeResource):
             logger.warning("Failed to find any OG data in %s", url)
             og = {}
 
+        # filter out any stupidly long values
+        keys_to_remove = []
+        for k, v in og.items():
+            # values can be numeric as well as strings, hence the cast to str
+            if len(k) > OG_TAG_NAME_MAXLEN or len(str(v)) > OG_TAG_VALUE_MAXLEN:
+                logger.warning(
+                    "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN]
+                )
+                keys_to_remove.append(k)
+        for k in keys_to_remove:
+            del og[k]
+
         logger.debug("Calculated OG for %s as %s", url, og)
 
         jsonog = json.dumps(og)
@@ -506,6 +521,10 @@ def _calc_og(tree, media_uri):
     og = {}
     for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
         if "content" in tag.attrib:
+            # if we've got more than 50 tags, someone is taking the piss
+            if len(og) >= 50:
+                logger.warning("Skipping OG for page with too many 'og:' tags")
+                return {}
             og[tag.attrib["property"]] = tag.attrib["content"]
 
     # TODO: grab article: meta tags too, e.g.:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4e91eb66fe..139beef8ed 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,6 +16,7 @@
 
 import logging
 from collections import namedtuple
+from typing import Iterable, Optional
 
 from six import iteritems, itervalues
 
@@ -27,6 +28,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
+from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.utils import log_function
 from synapse.state import v1, v2
@@ -212,15 +214,17 @@ class StateHandler(object):
         return joined_hosts
 
     @defer.inlineCallbacks
-    def compute_event_context(self, event, old_state=None):
+    def compute_event_context(
+        self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
+    ):
         """Build an EventContext structure for the event.
 
         This works out what the current state should be for the event, and
         generates a new state group if necessary.
 
         Args:
-            event (synapse.events.EventBase):
-            old_state (dict|None): The state at the event if it can't be
+            event:
+            old_state: The state at the event if it can't be
                 calculated from existing events. This is normally only specified
                 when receiving an event from federation where we don't have the
                 prev events for, e.g. when backfilling.
@@ -232,6 +236,9 @@ class StateHandler(object):
             # If this is an outlier, then we know it shouldn't have any current
             # state. Certainly store.get_current_state won't return any, and
             # persisting the event won't store the state group.
+
+            # FIXME: why do we populate current_state_ids? I thought the point was
+            # that we weren't supposed to have any state for outliers?
             if old_state:
                 prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
                 if event.is_state():
@@ -248,113 +255,103 @@ class StateHandler(object):
             # group for it.
             context = EventContext.with_state(
                 state_group=None,
+                state_group_before_event=None,
                 current_state_ids=current_state_ids,
                 prev_state_ids=prev_state_ids,
             )
 
             return context
 
+        #
+        # first of all, figure out the state before the event
+        #
+
         if old_state:
-            # We already have the state, so we don't need to calculate it.
-            # Let's just correctly fill out the context and create a
-            # new state group for it.
-
-            prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
-
-            if event.is_state():
-                key = (event.type, event.state_key)
-                if key in prev_state_ids:
-                    replaces = prev_state_ids[key]
-                    if replaces != event.event_id:  # Paranoia check
-                        event.unsigned["replaces_state"] = replaces
-                current_state_ids = dict(prev_state_ids)
-                current_state_ids[key] = event.event_id
-            else:
-                current_state_ids = prev_state_ids
+            # if we're given the state before the event, then we use that
+            state_ids_before_event = {
+                (s.type, s.state_key): s.event_id for s in old_state
+            }
+            state_group_before_event = None
+            state_group_before_event_prev_group = None
+            deltas_to_state_group_before_event = None
 
-            state_group = yield self.state_store.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=None,
-                delta_ids=None,
-                current_state_ids=current_state_ids,
-            )
+        else:
+            # otherwise, we'll need to resolve the state across the prev_events.
+            logger.debug("calling resolve_state_groups from compute_event_context")
 
-            context = EventContext.with_state(
-                state_group=state_group,
-                current_state_ids=current_state_ids,
-                prev_state_ids=prev_state_ids,
+            entry = yield self.resolve_state_groups_for_events(
+                event.room_id, event.prev_event_ids()
             )
 
-            return context
+            state_ids_before_event = entry.state
+            state_group_before_event = entry.state_group
+            state_group_before_event_prev_group = entry.prev_group
+            deltas_to_state_group_before_event = entry.delta_ids
 
-        logger.debug("calling resolve_state_groups from compute_event_context")
+        #
+        # make sure that we have a state group at that point. If it's not a state event,
+        # that will be the state group for the new event. If it *is* a state event,
+        # it might get rejected (in which case we'll need to persist it with the
+        # previous state group)
+        #
 
-        entry = yield self.resolve_state_groups_for_events(
-            event.room_id, event.prev_event_ids()
-        )
+        if not state_group_before_event:
+            state_group_before_event = yield self.state_store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=state_group_before_event_prev_group,
+                delta_ids=deltas_to_state_group_before_event,
+                current_state_ids=state_ids_before_event,
+            )
 
-        prev_state_ids = entry.state
-        prev_group = None
-        delta_ids = None
+            # XXX: can we update the state cache entry for the new state group? or
+            # could we set a flag on resolve_state_groups_for_events to tell it to
+            # always make a state group?
+
+        #
+        # now if it's not a state event, we're done
+        #
+
+        if not event.is_state():
+            return EventContext.with_state(
+                state_group_before_event=state_group_before_event,
+                state_group=state_group_before_event,
+                current_state_ids=state_ids_before_event,
+                prev_state_ids=state_ids_before_event,
+                prev_group=state_group_before_event_prev_group,
+                delta_ids=deltas_to_state_group_before_event,
+            )
 
-        if event.is_state():
-            # If this is a state event then we need to create a new state
-            # group for the state after this event.
+        #
+        # otherwise, we'll need to create a new state group for after the event
+        #
 
-            key = (event.type, event.state_key)
-            if key in prev_state_ids:
-                replaces = prev_state_ids[key]
+        key = (event.type, event.state_key)
+        if key in state_ids_before_event:
+            replaces = state_ids_before_event[key]
+            if replaces != event.event_id:
                 event.unsigned["replaces_state"] = replaces
 
-            current_state_ids = dict(prev_state_ids)
-            current_state_ids[key] = event.event_id
-
-            if entry.state_group:
-                # If the state at the event has a state group assigned then
-                # we can use that as the prev group
-                prev_group = entry.state_group
-                delta_ids = {key: event.event_id}
-            elif entry.prev_group:
-                # If the state at the event only has a prev group, then we can
-                # use that as a prev group too.
-                prev_group = entry.prev_group
-                delta_ids = dict(entry.delta_ids)
-                delta_ids[key] = event.event_id
-
-            state_group = yield self.state_store.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=prev_group,
-                delta_ids=delta_ids,
-                current_state_ids=current_state_ids,
-            )
-        else:
-            current_state_ids = prev_state_ids
-            prev_group = entry.prev_group
-            delta_ids = entry.delta_ids
-
-            if entry.state_group is None:
-                entry.state_group = yield self.state_store.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=entry.prev_group,
-                    delta_ids=entry.delta_ids,
-                    current_state_ids=current_state_ids,
-                )
-                entry.state_id = entry.state_group
-
-            state_group = entry.state_group
-
-        context = EventContext.with_state(
-            state_group=state_group,
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
-            prev_group=prev_group,
+        state_ids_after_event = dict(state_ids_before_event)
+        state_ids_after_event[key] = event.event_id
+        delta_ids = {key: event.event_id}
+
+        state_group_after_event = yield self.state_store.store_state_group(
+            event.event_id,
+            event.room_id,
+            prev_group=state_group_before_event,
             delta_ids=delta_ids,
+            current_state_ids=state_ids_after_event,
         )
 
-        return context
+        return EventContext.with_state(
+            state_group=state_group_after_event,
+            state_group_before_event=state_group_before_event,
+            current_state_ids=state_ids_after_event,
+            prev_state_ids=state_ids_before_event,
+            prev_group=state_group_before_event,
+            delta_ids=delta_ids,
+        )
 
     @measure_func()
     @defer.inlineCallbacks
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index f70d41ecab..ee1b2b2bbf 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -488,14 +488,14 @@ class RegistrationWorkerStore(SQLBaseStore):
         we can. Unfortunately, it's possible some of them are already taken by
         existing users, and there may be gaps in the already taken range. This
         function returns the start of the first allocatable gap. This is to
-        avoid the case of ID 10000000 being pre-allocated, so us wasting the
-        first (and shortest) many generated user IDs.
+        avoid the case of ID 1000 being pre-allocated and starting at 1001 while
+        0-999 are available.
         """
 
         def _find_next_generated_user_id(txn):
-            # We bound between '@1' and '@a' to avoid pulling the entire table
+            # We bound between '@0' and '@a' to avoid pulling the entire table
             # out.
-            txn.execute("SELECT name FROM users WHERE '@1' <= name AND name < '@a'")
+            txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
 
             regex = re.compile(r"^@(\d+):")
 
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index a41cac7b36..5c293e0fa3 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -1235,7 +1235,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
             # if the event was rejected, just give it the same state as its
             # predecessor.
             if context.rejected:
-                state_groups[event.event_id] = context.prev_group
+                state_groups[event.event_id] = context.state_group_before_event
                 continue
 
             state_groups[event.event_id] = context.state_group
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 0e8da27f53..84f5ae22c3 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,8 +17,8 @@ import functools
 import inspect
 import logging
 import threading
-from collections import namedtuple
-from typing import Any, cast
+from typing import Any, Tuple, Union, cast
+from weakref import WeakValueDictionary
 
 from six import itervalues
 
@@ -38,6 +38,8 @@ from . import register_cache
 
 logger = logging.getLogger(__name__)
 
+CacheKey = Union[Tuple, Any]
+
 
 class _CachedFunction(Protocol):
     invalidate = None  # type: Any
@@ -430,7 +432,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
             if self.add_cache_context:
-                kwargs["cache_context"] = _CacheContext(cache, cache_key)
+                kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
 
             try:
                 cached_result_d = cache.get(cache_key, callback=invalidate_callback)
@@ -624,14 +626,38 @@ class CacheListDescriptor(_CacheDescriptorBase):
         return wrapped
 
 
-class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
-    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
-    # which namedtuple does for us (i.e. two _CacheContext are the same if
-    # their caches and keys match). This is important in particular to
-    # dedupe when we add callbacks to lru cache nodes, otherwise the number
-    # of callbacks would grow.
-    def invalidate(self):
-        self.cache.invalidate(self.key)
+class _CacheContext:
+    """Holds cache information from the cached function higher in the calling order.
+
+    Can be used to invalidate the higher level cache entry if something changes
+    on a lower level.
+    """
+
+    _cache_context_objects = (
+        WeakValueDictionary()
+    )  # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+
+    def __init__(self, cache, cache_key):  # type: (Cache, CacheKey) -> None
+        self._cache = cache
+        self._cache_key = cache_key
+
+    def invalidate(self):  # type: () -> None
+        """Invalidates the cache entry referred to by the context."""
+        self._cache.invalidate(self._cache_key)
+
+    @classmethod
+    def get_instance(cls, cache, cache_key):  # type: (Cache, CacheKey) -> _CacheContext
+        """Returns an instance constructed with the given arguments.
+
+        A new instance is only created if none already exists.
+        """
+
+        # We make sure there are no identical _CacheContext instances. This is
+        # important in particular to dedupe when we add callbacks to lru cache
+        # nodes, otherwise the number of callbacks would grow.
+        return cls._cache_context_objects.setdefault(
+            (cache, cache_key), cls(cache, cache_key)
+        )
 
 
 def cached(