diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index f43ce66483..d87fe9d62c 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -5,10 +5,13 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.rest.client.login
import synapse.rest.client.room
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import Codes, LimitExceededError, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import FrozenEventV3
+from synapse.federation.federation_base import (
+ event_from_pdu_json,
+)
from synapse.federation.federation_client import SendJoinResult
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
@@ -453,3 +456,165 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
new_count = rows[0][0]
self.assertEqual(initial_count, new_count)
+
+
+class TestInviteFiltering(FederatingHomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+ self.fed_handler = hs.get_federation_handler()
+ self.store = hs.get_datastores().main
+
+ # Create three users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_token = self.login("bob", "pass")
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+ def test_misc4155_block_invite_local(self) -> None:
+ """Test that MSC4155 will block a user from being invited to a room"""
+ room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {
+ "blocked_users": [self.alice],
+ },
+ )
+ )
+
+ f = self.get_failure(
+ self.handler.update_membership(
+ requester=create_requester(self.alice),
+ target=UserID.from_string(self.bob),
+ room_id=room_id,
+ action=Membership.INVITE,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
+
+ @override_config({"experimental_features": {"msc4155_enabled": False}})
+ def test_msc4155_disabled_allow_invite_local(self) -> None:
+ """Test that MSC4155 will block a user from being invited to a room"""
+ room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {
+ "blocked_users": [self.alice],
+ },
+ )
+ )
+
+ self.get_success(
+ self.handler.update_membership(
+ requester=create_requester(self.alice),
+ target=UserID.from_string(self.bob),
+ room_id=room_id,
+ action=Membership.INVITE,
+ ),
+ )
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+ def test_msc4155_block_invite_remote(self) -> None:
+ """Test that MSC4155 will block a remote user from being invited to a room"""
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {"blocked_users": [remote_user]},
+ )
+ )
+
+ room_id = self.helper.create_room_as(
+ room_creator=self.alice, tok=self.alice_token
+ )
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": self.bob,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ f = self.get_failure(
+ self.fed_handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
+
+ @override_config({"experimental_features": {"msc4155_enabled": True}})
+ def test_msc4155_block_invite_remote_server(self) -> None:
+ """Test that MSC4155 will block a remote server's user from being invited to a room"""
+ # A remote user who sends the invite
+ remote_server = "otherserver"
+ remote_user = "@otheruser:" + remote_server
+
+ self.get_success(
+ self.store.add_account_data_for_user(
+ self.bob,
+ AccountDataTypes.MSC4155_INVITE_PERMISSION_CONFIG,
+ {"blocked_servers": [remote_server]},
+ )
+ )
+
+ room_id = self.helper.create_room_as(
+ room_creator=self.alice, tok=self.alice_token
+ )
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ invite_event = event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": remote_user,
+ "state_key": self.bob,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ f = self.get_failure(
+ self.fed_handler.on_invite_request(
+ remote_server,
+ invite_event,
+ invite_event.room_version,
+ ),
+ SynapseError,
+ ).value
+ self.assertEqual(f.code, 403)
+ self.assertEqual(f.errcode, "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED")
diff --git a/tests/storage/test_invite_rule.py b/tests/storage/test_invite_rule.py
new file mode 100644
index 0000000000..38c97ecaa3
--- /dev/null
+++ b/tests/storage/test_invite_rule.py
@@ -0,0 +1,167 @@
+from synapse.storage.invite_rule import InviteRule, InviteRulesConfig
+from synapse.types import UserID
+
+from tests import unittest
+
+regular_user = UserID.from_string("@test:example.org")
+allowed_user = UserID.from_string("@allowed:allow.example.org")
+blocked_user = UserID.from_string("@blocked:block.example.org")
+ignored_user = UserID.from_string("@ignored:ignore.example.org")
+
+
+class InviteFilterTestCase(unittest.TestCase):
+ def test_empty(self) -> None:
+ """Permit by default"""
+ config = InviteRulesConfig(None)
+ self.assertEqual(
+ config.get_invite_rule(regular_user.to_string()), InviteRule.ALLOW
+ )
+
+ def test_ignore_invalid(self) -> None:
+ """Invalid strings are ignored"""
+ config = InviteRulesConfig({"blocked_users": ["not a user"]})
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.ALLOW
+ )
+
+ def test_user_blocked(self) -> None:
+ """Permit all, except explicitly blocked users"""
+ config = InviteRulesConfig({"blocked_users": [blocked_user.to_string()]})
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.BLOCK
+ )
+ self.assertEqual(
+ config.get_invite_rule(regular_user.to_string()), InviteRule.ALLOW
+ )
+
+ def test_user_ignored(self) -> None:
+ """Permit all, except explicitly ignored users"""
+ config = InviteRulesConfig({"ignored_users": [ignored_user.to_string()]})
+ self.assertEqual(
+ config.get_invite_rule(ignored_user.to_string()), InviteRule.IGNORE
+ )
+ self.assertEqual(
+ config.get_invite_rule(regular_user.to_string()), InviteRule.ALLOW
+ )
+
+ def test_user_precedence(self) -> None:
+ """Always take allowed over ignored, ignored over blocked, and then block."""
+ config = InviteRulesConfig(
+ {
+ "allowed_users": [allowed_user.to_string()],
+ "ignored_users": [allowed_user.to_string(), ignored_user.to_string()],
+ "blocked_users": [
+ allowed_user.to_string(),
+ ignored_user.to_string(),
+ blocked_user.to_string(),
+ ],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(ignored_user.to_string()), InviteRule.IGNORE
+ )
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.BLOCK
+ )
+
+ def test_server_blocked(self) -> None:
+ """Block all users on the server except those allowed."""
+ user_on_same_server = UserID("blocked", allowed_user.domain)
+ config = InviteRulesConfig(
+ {
+ "allowed_users": [allowed_user.to_string()],
+ "blocked_servers": [allowed_user.domain],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(user_on_same_server.to_string()), InviteRule.BLOCK
+ )
+
+ def test_server_ignored(self) -> None:
+ """Ignore all users on the server except those allowed."""
+ user_on_same_server = UserID("ignored", allowed_user.domain)
+ config = InviteRulesConfig(
+ {
+ "allowed_users": [allowed_user.to_string()],
+ "ignored_servers": [allowed_user.domain],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(user_on_same_server.to_string()), InviteRule.IGNORE
+ )
+
+ def test_server_allow(self) -> None:
+ """Allow all from a server except explictly blocked or ignored users."""
+ blocked_user_on_same_server = UserID("blocked", allowed_user.domain)
+ ignored_user_on_same_server = UserID("ignored", allowed_user.domain)
+ allowed_user_on_same_server = UserID("another", allowed_user.domain)
+ config = InviteRulesConfig(
+ {
+ "ignored_users": [ignored_user_on_same_server.to_string()],
+ "blocked_users": [blocked_user_on_same_server.to_string()],
+ "allowed_servers": [allowed_user.to_string()],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user_on_same_server.to_string()),
+ InviteRule.ALLOW,
+ )
+ self.assertEqual(
+ config.get_invite_rule(blocked_user_on_same_server.to_string()),
+ InviteRule.BLOCK,
+ )
+ self.assertEqual(
+ config.get_invite_rule(ignored_user_on_same_server.to_string()),
+ InviteRule.IGNORE,
+ )
+
+ def test_server_precedence(self) -> None:
+ """Always take allowed over ignored, ignored over blocked, and then block."""
+ config = InviteRulesConfig(
+ {
+ "allowed_servers": [allowed_user.domain],
+ "ignored_servers": [allowed_user.domain, ignored_user.domain],
+ "blocked_servers": [
+ allowed_user.domain,
+ ignored_user.domain,
+ blocked_user.domain,
+ ],
+ }
+ )
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.ALLOW
+ )
+ self.assertEqual(
+ config.get_invite_rule(ignored_user.to_string()), InviteRule.IGNORE
+ )
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.BLOCK
+ )
+
+ def test_server_glob(self) -> None:
+ """Test that glob patterns match"""
+ config = InviteRulesConfig({"blocked_servers": ["*.example.org"]})
+ self.assertEqual(
+ config.get_invite_rule(allowed_user.to_string()), InviteRule.BLOCK
+ )
+ self.assertEqual(
+ config.get_invite_rule(ignored_user.to_string()), InviteRule.BLOCK
+ )
+ self.assertEqual(
+ config.get_invite_rule(blocked_user.to_string()), InviteRule.BLOCK
+ )
+ self.assertEqual(
+ config.get_invite_rule(regular_user.to_string()), InviteRule.ALLOW
+ )
|