summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6320.bugfix1
-rw-r--r--synapse/events/snapshot.py25
-rw-r--r--synapse/handlers/federation.py1
-rw-r--r--synapse/state/__init__.py172
-rw-r--r--synapse/storage/data_stores/main/state.py2
-rw-r--r--tests/handlers/test_federation.py126
-rw-r--r--tests/test_state.py61
7 files changed, 285 insertions, 103 deletions
diff --git a/changelog.d/6320.bugfix b/changelog.d/6320.bugfix
new file mode 100644
index 0000000000..2c3fad5655
--- /dev/null
+++ b/changelog.d/6320.bugfix
@@ -0,0 +1 @@
+Fix bug which casued rejected events to be persisted with the wrong room state.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 0f3c5989cb..64e898f40c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -48,10 +48,21 @@ class EventContext:
             Note that this is a private attribute: it should be accessed via
             the ``state_group`` property.
 
+        state_group_before_event: The ID of the state group representing the state
+            of the room before this event.
+
+            If this is a non-state event, this will be the same as ``state_group``. If
+            it's a state event, it will be the same as ``prev_group``.
+
+            If ``state_group`` is None (ie, the event is an outlier),
+            ``state_group_before_event`` will always also be ``None``.
+
         prev_group: If it is known, ``state_group``'s prev_group. Note that this being
             None does not necessarily mean that ``state_group`` does not have
             a prev_group!
 
+            If the event is a state event, this is normally the same as ``prev_group``.
+
             If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
             will always also be ``None``.
 
@@ -77,7 +88,8 @@ class EventContext:
             ``get_current_state_ids``. _AsyncEventContext impl calculates this
             on-demand: it will be None until that happens.
 
-        _prev_state_ids: The room state map, excluding this event. For a non-state
+        _prev_state_ids: The room state map, excluding this event - ie, the state
+            in ``state_group_before_event``. For a non-state
             event, this will be the same as _current_state_events.
 
             Note that it is a completely different thing to prev_group!
@@ -92,6 +104,7 @@ class EventContext:
 
     rejected = attr.ib(default=False, type=Union[bool, str])
     _state_group = attr.ib(default=None, type=Optional[int])
+    state_group_before_event = attr.ib(default=None, type=Optional[int])
     prev_group = attr.ib(default=None, type=Optional[int])
     delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
     app_service = attr.ib(default=None, type=Optional[ApplicationService])
@@ -103,12 +116,18 @@ class EventContext:
 
     @staticmethod
     def with_state(
-        state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
+        state_group,
+        state_group_before_event,
+        current_state_ids,
+        prev_state_ids,
+        prev_group=None,
+        delta_ids=None,
     ):
         return EventContext(
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
             state_group=state_group,
+            state_group_before_event=state_group_before_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
         )
@@ -140,6 +159,7 @@ class EventContext:
             "event_type": event.type,
             "event_state_key": event.state_key if event.is_state() else None,
             "state_group": self._state_group,
+            "state_group_before_event": self.state_group_before_event,
             "rejected": self.rejected,
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
@@ -165,6 +185,7 @@ class EventContext:
             event_type=input["event_type"],
             event_state_key=input["event_state_key"],
             state_group=input["state_group"],
+            state_group_before_event=input["state_group_before_event"],
             prev_group=input["prev_group"],
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7a0d132a3e..108bf40b0a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2287,6 +2287,7 @@ class FederationHandler(BaseHandler):
 
         return EventContext.with_state(
             state_group=state_group,
+            state_group_before_event=context.state_group_before_event,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
             prev_group=prev_group,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 2c04ab1854..139beef8ed 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,6 +16,7 @@
 
 import logging
 from collections import namedtuple
+from typing import Iterable, Optional
 
 from six import iteritems, itervalues
 
@@ -27,6 +28,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
+from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.utils import log_function
 from synapse.state import v1, v2
@@ -212,15 +214,17 @@ class StateHandler(object):
         return joined_hosts
 
     @defer.inlineCallbacks
-    def compute_event_context(self, event, old_state=None):
+    def compute_event_context(
+        self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
+    ):
         """Build an EventContext structure for the event.
 
         This works out what the current state should be for the event, and
         generates a new state group if necessary.
 
         Args:
-            event (synapse.events.EventBase):
-            old_state (dict|None): The state at the event if it can't be
+            event:
+            old_state: The state at the event if it can't be
                 calculated from existing events. This is normally only specified
                 when receiving an event from federation where we don't have the
                 prev events for, e.g. when backfilling.
@@ -251,113 +255,103 @@ class StateHandler(object):
             # group for it.
             context = EventContext.with_state(
                 state_group=None,
+                state_group_before_event=None,
                 current_state_ids=current_state_ids,
                 prev_state_ids=prev_state_ids,
             )
 
             return context
 
+        #
+        # first of all, figure out the state before the event
+        #
+
         if old_state:
-            # We already have the state, so we don't need to calculate it.
-            # Let's just correctly fill out the context and create a
-            # new state group for it.
-
-            prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
-
-            if event.is_state():
-                key = (event.type, event.state_key)
-                if key in prev_state_ids:
-                    replaces = prev_state_ids[key]
-                    if replaces != event.event_id:  # Paranoia check
-                        event.unsigned["replaces_state"] = replaces
-                current_state_ids = dict(prev_state_ids)
-                current_state_ids[key] = event.event_id
-            else:
-                current_state_ids = prev_state_ids
+            # if we're given the state before the event, then we use that
+            state_ids_before_event = {
+                (s.type, s.state_key): s.event_id for s in old_state
+            }
+            state_group_before_event = None
+            state_group_before_event_prev_group = None
+            deltas_to_state_group_before_event = None
 
-            state_group = yield self.state_store.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=None,
-                delta_ids=None,
-                current_state_ids=current_state_ids,
-            )
+        else:
+            # otherwise, we'll need to resolve the state across the prev_events.
+            logger.debug("calling resolve_state_groups from compute_event_context")
 
-            context = EventContext.with_state(
-                state_group=state_group,
-                current_state_ids=current_state_ids,
-                prev_state_ids=prev_state_ids,
+            entry = yield self.resolve_state_groups_for_events(
+                event.room_id, event.prev_event_ids()
             )
 
-            return context
+            state_ids_before_event = entry.state
+            state_group_before_event = entry.state_group
+            state_group_before_event_prev_group = entry.prev_group
+            deltas_to_state_group_before_event = entry.delta_ids
 
-        logger.debug("calling resolve_state_groups from compute_event_context")
+        #
+        # make sure that we have a state group at that point. If it's not a state event,
+        # that will be the state group for the new event. If it *is* a state event,
+        # it might get rejected (in which case we'll need to persist it with the
+        # previous state group)
+        #
 
-        entry = yield self.resolve_state_groups_for_events(
-            event.room_id, event.prev_event_ids()
-        )
+        if not state_group_before_event:
+            state_group_before_event = yield self.state_store.store_state_group(
+                event.event_id,
+                event.room_id,
+                prev_group=state_group_before_event_prev_group,
+                delta_ids=deltas_to_state_group_before_event,
+                current_state_ids=state_ids_before_event,
+            )
 
-        prev_state_ids = entry.state
-        prev_group = None
-        delta_ids = None
+            # XXX: can we update the state cache entry for the new state group? or
+            # could we set a flag on resolve_state_groups_for_events to tell it to
+            # always make a state group?
+
+        #
+        # now if it's not a state event, we're done
+        #
+
+        if not event.is_state():
+            return EventContext.with_state(
+                state_group_before_event=state_group_before_event,
+                state_group=state_group_before_event,
+                current_state_ids=state_ids_before_event,
+                prev_state_ids=state_ids_before_event,
+                prev_group=state_group_before_event_prev_group,
+                delta_ids=deltas_to_state_group_before_event,
+            )
 
-        if event.is_state():
-            # If this is a state event then we need to create a new state
-            # group for the state after this event.
+        #
+        # otherwise, we'll need to create a new state group for after the event
+        #
 
-            key = (event.type, event.state_key)
-            if key in prev_state_ids:
-                replaces = prev_state_ids[key]
+        key = (event.type, event.state_key)
+        if key in state_ids_before_event:
+            replaces = state_ids_before_event[key]
+            if replaces != event.event_id:
                 event.unsigned["replaces_state"] = replaces
 
-            current_state_ids = dict(prev_state_ids)
-            current_state_ids[key] = event.event_id
-
-            if entry.state_group:
-                # If the state at the event has a state group assigned then
-                # we can use that as the prev group
-                prev_group = entry.state_group
-                delta_ids = {key: event.event_id}
-            elif entry.prev_group:
-                # If the state at the event only has a prev group, then we can
-                # use that as a prev group too.
-                prev_group = entry.prev_group
-                delta_ids = dict(entry.delta_ids)
-                delta_ids[key] = event.event_id
-
-            state_group = yield self.state_store.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=prev_group,
-                delta_ids=delta_ids,
-                current_state_ids=current_state_ids,
-            )
-        else:
-            current_state_ids = prev_state_ids
-            prev_group = entry.prev_group
-            delta_ids = entry.delta_ids
-
-            if entry.state_group is None:
-                entry.state_group = yield self.state_store.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=entry.prev_group,
-                    delta_ids=entry.delta_ids,
-                    current_state_ids=current_state_ids,
-                )
-                entry.state_id = entry.state_group
-
-            state_group = entry.state_group
-
-        context = EventContext.with_state(
-            state_group=state_group,
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
-            prev_group=prev_group,
+        state_ids_after_event = dict(state_ids_before_event)
+        state_ids_after_event[key] = event.event_id
+        delta_ids = {key: event.event_id}
+
+        state_group_after_event = yield self.state_store.store_state_group(
+            event.event_id,
+            event.room_id,
+            prev_group=state_group_before_event,
             delta_ids=delta_ids,
+            current_state_ids=state_ids_after_event,
         )
 
-        return context
+        return EventContext.with_state(
+            state_group=state_group_after_event,
+            state_group_before_event=state_group_before_event,
+            current_state_ids=state_ids_after_event,
+            prev_state_ids=state_ids_before_event,
+            prev_group=state_group_before_event,
+            delta_ids=delta_ids,
+        )
 
     @measure_func()
     @defer.inlineCallbacks
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 3132848034..9e1541988e 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -1231,7 +1231,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
             # if the event was rejected, just give it the same state as its
             # predecessor.
             if context.rejected:
-                state_groups[event.event_id] = context.prev_group
+                state_groups[event.event_id] = context.state_group_before_event
                 continue
 
             state_groups[event.event_id] = context.state_group
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d56220f403..b4d92cf732 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,13 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError, Codes
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.logging.context import LoggingContext, run_in_background
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
 
 from tests import unittest
 
+logger = logging.getLogger(__name__)
+
 
 class FederationTestCase(unittest.HomeserverTestCase):
     servlets = [
@@ -79,3 +85,123 @@ class FederationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(failure.code, 403, failure)
         self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
         self.assertEqual(failure.msg, "You are not invited to this room.")
+
+    def test_rejected_message_event_state(self):
+        """
+        Check that we store the state group correctly for rejected non-state events.
+
+        Regression test for #6289.
+        """
+        OTHER_SERVER = "otherserver"
+        OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        # pretend that another server has joined
+        join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+        # check the state group
+        sg = self.successResultOf(
+            self.store._get_state_group_for_event(join_event.event_id)
+        )
+
+        # build and send an event which will be rejected
+        ev = event_from_pdu_json(
+            {
+                "type": EventTypes.Message,
+                "content": {},
+                "room_id": room_id,
+                "sender": "@yetanotheruser:" + OTHER_SERVER,
+                "depth": join_event["depth"] + 1,
+                "prev_events": [join_event.event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            join_event.format_version,
+        )
+
+        with LoggingContext(request="send_rejected"):
+            d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+        self.get_success(d)
+
+        # that should have been rejected
+        e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+        self.assertIsNotNone(e.rejected_reason)
+
+        # ... and the state group should be the same as before
+        sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+        self.assertEqual(sg, sg2)
+
+    def test_rejected_state_event_state(self):
+        """
+        Check that we store the state group correctly for rejected state events.
+
+        Regression test for #6289.
+        """
+        OTHER_SERVER = "otherserver"
+        OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        # pretend that another server has joined
+        join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+        # check the state group
+        sg = self.successResultOf(
+            self.store._get_state_group_for_event(join_event.event_id)
+        )
+
+        # build and send an event which will be rejected
+        ev = event_from_pdu_json(
+            {
+                "type": "org.matrix.test",
+                "state_key": "test_key",
+                "content": {},
+                "room_id": room_id,
+                "sender": "@yetanotheruser:" + OTHER_SERVER,
+                "depth": join_event["depth"] + 1,
+                "prev_events": [join_event.event_id],
+                "auth_events": [],
+                "origin_server_ts": self.clock.time_msec(),
+            },
+            join_event.format_version,
+        )
+
+        with LoggingContext(request="send_rejected"):
+            d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+        self.get_success(d)
+
+        # that should have been rejected
+        e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+        self.assertIsNotNone(e.rejected_reason)
+
+        # ... and the state group should be the same as before
+        sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+        self.assertEqual(sg, sg2)
+
+    def _build_and_send_join_event(self, other_server, other_user, room_id):
+        join_event = self.get_success(
+            self.handler.on_make_join_request(other_server, room_id, other_user)
+        )
+        # the auth code requires that a signature exists, but doesn't check that
+        # signature... go figure.
+        join_event.signatures[other_server] = {"x": "y"}
+        with LoggingContext(request="send_join"):
+            d = run_in_background(
+                self.handler.on_send_join_request, other_server, join_event
+            )
+        self.get_success(d)
+
+        # sanity-check: the room should show that the new user is a member
+        r = self.get_success(self.store.get_current_state_ids(room_id))
+        self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
+
+        return join_event
diff --git a/tests/test_state.py b/tests/test_state.py
index 38246555bd..176535947a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -21,6 +21,7 @@ from synapse.api.auth import Auth
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
 from synapse.state import StateHandler, StateResolutionHandler
 
 from tests import unittest
@@ -198,16 +199,22 @@ class StateTestCase(unittest.TestCase):
 
         self.store.register_events(graph.walk())
 
-        context_store = {}
+        context_store = {}  # type: dict[str, EventContext]
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        ctx_c = context_store["C"]
+        ctx_d = context_store["D"]
+
+        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
         self.assertEqual(2, len(prev_state_ids))
 
+        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
         graph = Graph(
@@ -241,12 +248,19 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        # C ends up winning the resolution between B and C
+
+        ctx_c = context_store["C"]
+        ctx_d = context_store["D"]
 
+        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
         self.assertSetEqual(
             {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
         )
 
+        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
     @defer.inlineCallbacks
     def test_branch_have_banned_conflict(self):
         graph = Graph(
@@ -292,11 +306,18 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+        # C ends up winning the resolution between C and D because bans win over other
+        # changes
+
+        ctx_c = context_store["C"]
+        ctx_e = context_store["E"]
 
+        prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
         self.assertSetEqual(
             {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
         )
+        self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
+        self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
 
     @defer.inlineCallbacks
     def test_branch_have_perms_conflict(self):
@@ -360,12 +381,20 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        # B ends up winning the resolution between B and C because power levels
+        # win over other changes.
 
+        ctx_b = context_store["B"]
+        ctx_d = context_store["D"]
+
+        prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
         )
 
+        self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
     def _add_depths(self, nodes, edges):
         def _get_depth(ev):
             node = nodes[ev]
@@ -390,13 +419,16 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        self.assertEqual(
-            set(e.event_id for e in old_state), set(current_state_ids.values())
+        current_state_ids = yield context.get_current_state_ids(self.store)
+        self.assertCountEqual(
+            (e.event_id for e in old_state), current_state_ids.values()
         )
 
-        self.assertIsNotNone(context.state_group)
+        self.assertIsNotNone(context.state_group_before_event)
+        self.assertEqual(context.state_group_before_event, context.state_group)
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
@@ -411,11 +443,18 @@ class StateTestCase(unittest.TestCase):
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
         prev_state_ids = yield context.get_prev_state_ids(self.store)
+        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        self.assertEqual(
-            set(e.event_id for e in old_state), set(prev_state_ids.values())
+        current_state_ids = yield context.get_current_state_ids(self.store)
+        self.assertCountEqual(
+            (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
 
+        self.assertIsNotNone(context.state_group_before_event)
+        self.assertNotEqual(context.state_group_before_event, context.state_group)
+        self.assertEqual(context.state_group_before_event, context.prev_group)
+        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
+
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
         prev_event_id = "prev_event_id"