diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d56220f403..b4d92cf732 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes
+from synapse.federation.federation_base import event_from_pdu_json
+from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
+logger = logging.getLogger(__name__)
+
class FederationTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -79,3 +85,123 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.code, 403, failure)
self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
self.assertEqual(failure.msg, "You are not invited to this room.")
+
+ def test_rejected_message_event_state(self):
+ """
+ Check that we store the state group correctly for rejected non-state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ join_event.format_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def test_rejected_state_event_state(self):
+ """
+ Check that we store the state group correctly for rejected state events.
+
+ Regression test for #6289.
+ """
+ OTHER_SERVER = "otherserver"
+ OTHER_USER = "@otheruser:" + OTHER_SERVER
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ # pretend that another server has joined
+ join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
+
+ # check the state group
+ sg = self.successResultOf(
+ self.store._get_state_group_for_event(join_event.event_id)
+ )
+
+ # build and send an event which will be rejected
+ ev = event_from_pdu_json(
+ {
+ "type": "org.matrix.test",
+ "state_key": "test_key",
+ "content": {},
+ "room_id": room_id,
+ "sender": "@yetanotheruser:" + OTHER_SERVER,
+ "depth": join_event["depth"] + 1,
+ "prev_events": [join_event.event_id],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ join_event.format_version,
+ )
+
+ with LoggingContext(request="send_rejected"):
+ d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
+ self.get_success(d)
+
+ # that should have been rejected
+ e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
+ self.assertIsNotNone(e.rejected_reason)
+
+ # ... and the state group should be the same as before
+ sg2 = self.successResultOf(self.store._get_state_group_for_event(ev.event_id))
+
+ self.assertEqual(sg, sg2)
+
+ def _build_and_send_join_event(self, other_server, other_user, room_id):
+ join_event = self.get_success(
+ self.handler.on_make_join_request(other_server, room_id, other_user)
+ )
+ # the auth code requires that a signature exists, but doesn't check that
+ # signature... go figure.
+ join_event.signatures[other_server] = {"x": "y"}
+ with LoggingContext(request="send_join"):
+ d = run_in_background(
+ self.handler.on_send_join_request, other_server, join_event
+ )
+ self.get_success(d)
+
+ # sanity-check: the room should show that the new user is a member
+ r = self.get_success(self.store.get_current_state_ids(room_id))
+ self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
+
+ return join_event
diff --git a/tests/test_state.py b/tests/test_state.py
index 38246555bd..176535947a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -21,6 +21,7 @@ from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler
from tests import unittest
@@ -198,16 +199,22 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(graph.walk())
- context_store = {}
+ context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
context = yield self.state.compute_event_context(event)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertEqual(2, len(prev_state_ids))
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
@@ -241,12 +248,19 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between B and C
+
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
)
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
graph = Graph(
@@ -292,11 +306,18 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between C and D because bans win over other
+ # changes
+
+ ctx_c = context_store["C"]
+ ctx_e = context_store["E"]
+ prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
)
+ self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
+ self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@defer.inlineCallbacks
def test_branch_have_perms_conflict(self):
@@ -360,12 +381,20 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # B ends up winning the resolution between B and C because power levels
+ # win over other changes.
+ ctx_b = context_store["B"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
)
+ self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
def _add_depths(self, nodes, edges):
def _get_depth(ev):
node = nodes[ev]
@@ -390,13 +419,16 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(current_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ self.assertCountEqual(
+ (e.event_id for e in old_state), current_state_ids.values()
)
- self.assertIsNotNone(context.state_group)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertEqual(context.state_group_before_event, context.state_group)
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
@@ -411,11 +443,18 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
prev_state_ids = yield context.get_prev_state_ids(self.store)
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(prev_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids(self.store)
+ self.assertCountEqual(
+ (e.event_id for e in old_state + [event]), current_state_ids.values()
)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertNotEqual(context.state_group_before_event, context.state_group)
+ self.assertEqual(context.state_group_before_event, context.prev_group)
+ self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
+
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id"
|