diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
index 1a295ea7ce..5698e3e062 100644
--- a/synapse/third_party_rules/access_rules.py
+++ b/synapse/third_party_rules/access_rules.py
@@ -265,7 +265,7 @@ class RoomAccessRules(object):
# Make sure we don't apply "direct" if the room has more than two members.
if new_rule == ACCESS_RULE_DIRECT:
existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
- state_events,
+ state_events, event
)
if len(existing_members) > 2 or len(threepid_tokens) > 1:
@@ -356,7 +356,7 @@ class RoomAccessRules(object):
"""
# Get the room memberships and 3PID invite tokens from the room's state.
existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
- state_events,
+ state_events, event
)
# There should never be more than one 3PID invite in the room state: if the second
@@ -494,13 +494,14 @@ class RoomAccessRules(object):
return join_rule_event.content.get("join_rule")
@staticmethod
- def _get_members_and_tokens_from_state(state_events):
+ def _get_members_and_tokens_from_state(state_events, event):
"""Retrieves from a list of state events the list of users that have a
m.room.member event in the room, and the tokens of 3PID invites in the room.
Args:
state_events (dict[tuple[event type, state key], EventBase]): The set of state
events.
+ event (EventBase): The event being checked.
Returns:
existing_members (list[str]): List of targets of the m.room.member events in
the state.
@@ -509,13 +510,24 @@ class RoomAccessRules(object):
"""
existing_members = []
threepid_invite_tokens = []
- for key, event in state_events.items():
+ for key, state_event in state_events.items():
if key[0] == EventTypes.Member:
- existing_members.append(event.state_key)
+ existing_members.append(state_event.state_key)
if key[0] == EventTypes.ThirdPartyInvite:
- threepid_invite_tokens.append(event.state_key)
+ threepid_invite_tokens.append(state_event.state_key)
- return existing_members, threepid_invite_tokens
+ # If the event is a state event, there already is an event with the same state key
+ # in the room's state, then the event is updating an existing event from the
+ # room's state, in which case we need to remove the entry from the list in order
+ # to avoid conflicts.
+ if event.is_state():
+ def filter_out_event(state_key):
+ return event.state_key != state_key
+
+ existing_members = filter(filter_out_event, existing_members)
+ threepid_invite_tokens = filter(filter_out_event, threepid_invite_tokens)
+
+ return list(existing_members), list(threepid_invite_tokens)
@staticmethod
def _is_invite_from_threepid(invite, threepid_invite_token):
diff --git a/tests/rest/client/test_room_access_rules.py b/tests/rest/client/test_room_access_rules.py
index 7e23add6b7..bb164c1e5e 100644
--- a/tests/rest/client/test_room_access_rules.py
+++ b/tests/rest/client/test_room_access_rules.py
@@ -483,6 +483,44 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
expected_code=403,
)
+ def test_revoke_3pid_invite_direct(self):
+ """Tests that revoking a 3PID invite doesn't cause the room access rules module to
+ confuse the revokation as a new 3PID invite.
+ """
+ invite_token = "sometoken"
+
+ invite_body = {
+ "display_name": "ker...@exa...",
+ "public_keys": [
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA"
+ },
+ {
+ "key_validity_url": "https://validity_url",
+ "public_key": "4_9nzEeDwR5N9s51jPodBiLnqH43A2_g2InVT137t9I"
+ }
+ ],
+ "key_validity_url": "https://validity_url",
+ "public_key": "ta8IQ0u1sp44HVpxYi7dFOdS/bfwDjcy4xLFlfY5KOA"
+ }
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body=invite_body,
+ tok=self.tok,
+ )
+
+ self.send_state_with_state_key(
+ room_id=self.direct_rooms[1],
+ event_type=EventTypes.ThirdPartyInvite,
+ state_key=invite_token,
+ body={},
+ tok=self.tok,
+ )
+
def create_room(
self, direct=False, rule=None, preset=RoomCreationPreset.TRUSTED_PRIVATE_CHAT,
initial_state=None, expected_code=200,
@@ -574,3 +612,20 @@ class RoomAccessTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
+
+ def send_state_with_state_key(
+ self, room_id, event_type, state_key, body, tok, expect_code=200
+ ):
+ path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
+ room_id, event_type, state_key
+ )
+
+ request, channel = self.make_request(
+ "PUT", path, json.dumps(body), access_token=tok
+ )
+ self.render(request)
+
+ self.assertEqual(channel.code, expect_code, channel.result)
+
+ return channel.json_body
+
|