diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index cf81bcf52c..d5ac66a6ed 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -32,6 +32,7 @@ from synapse.events.utils import (
PowerLevelsContent,
SerializeEventConfig,
_split_field,
+ clone_event,
copy_and_fixup_power_levels_contents,
maybe_upsert_event_field,
prune_event,
@@ -611,6 +612,29 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
)
+class CloneEventTestCase(stdlib_unittest.TestCase):
+ def test_unsigned_is_copied(self) -> None:
+ original = make_event_from_dict(
+ {
+ "type": "A",
+ "event_id": "$test:domain",
+ "unsigned": {"a": 1, "b": 2},
+ },
+ RoomVersions.V1,
+ {"txn_id": "txn"},
+ )
+ original.internal_metadata.stream_ordering = 1234
+ self.assertEqual(original.internal_metadata.stream_ordering, 1234)
+
+ cloned = clone_event(original)
+ cloned.unsigned["b"] = 3
+
+ self.assertEqual(original.unsigned, {"a": 1, "b": 2})
+ self.assertEqual(cloned.unsigned, {"a": 1, "b": 3})
+ self.assertEqual(cloned.internal_metadata.stream_ordering, 1234)
+ self.assertEqual(cloned.internal_metadata.txn_id, "txn")
+
+
class SerializeEventTestCase(stdlib_unittest.TestCase):
def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict:
return serialize_event(
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 36684c2c91..88261450b1 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -67,6 +67,23 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
+ def test_failed_edu_causes_500(self) -> None:
+ """If the EDU handler fails, /send should return a 500."""
+
+ async def failing_handler(_origin: str, _content: JsonDict) -> None:
+ raise Exception("bleh")
+
+ self.hs.get_federation_registry().register_edu_handler(
+ "FAIL_EDU_TYPE", failing_handler
+ )
+
+ channel = self.make_signed_federation_request(
+ "PUT",
+ "/_matrix/federation/v1/send/txn",
+ {"edus": [{"edu_type": "FAIL_EDU_TYPE", "content": {}}]},
+ )
+ self.assertEqual(500, channel.code, channel.result)
+
class ServerACLsTestCase(unittest.TestCase):
def test_blocked_server(self) -> None:
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 190b79bf26..0237369998 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -59,7 +59,14 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/send/txn_id_1234/",
content={
"edus": [
- {"edu_type": EduTypes.DEVICE_LIST_UPDATE, "content": {"foo": "bar"}}
+ {
+ "edu_type": EduTypes.DEVICE_LIST_UPDATE,
+ "content": {
+ "device_id": "QBUAZIFURK",
+ "stream_id": 0,
+ "user_id": "@user:id",
+ },
+ },
],
"pdus": [],
},
diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py
index c1d88f0176..c2015774a1 100644
--- a/tests/rest/admin/test_federation.py
+++ b/tests/rest/admin/test_federation.py
@@ -778,20 +778,81 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase):
self.assertEqual(number_rooms, len(channel.json_body["rooms"]))
self._check_fields(channel.json_body["rooms"])
- def _create_destination_rooms(self, number_rooms: int) -> None:
- """Create a number rooms for destination
+ def test_room_filtering(self) -> None:
+ """Tests that rooms are correctly filtered"""
+
+ # Create two rooms on the homeserver. Each has a different remote homeserver
+ # participating in it.
+ other_destination = "other.destination.org"
+ room_ids_self_dest = self._create_destination_rooms(2, destination=self.dest)
+ room_ids_other_dest = self._create_destination_rooms(
+ 1, destination=other_destination
+ )
+
+ # Ask for the rooms that `self.dest` is participating in.
+ channel = self.make_request("GET", self.url, access_token=self.admin_user_tok)
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Verify that we received only the rooms that `self.dest` is participating in.
+ # This assertion method name is a bit misleading. It does check that both lists
+ # contain the same items, and the same counts.
+ self.assertCountEqual(
+ [r["room_id"] for r in channel.json_body["rooms"]], room_ids_self_dest
+ )
+ self.assertEqual(channel.json_body["total"], len(room_ids_self_dest))
+
+ # Ask for the rooms that `other_destination` is participating in.
+ channel = self.make_request(
+ "GET",
+ self.url.replace(self.dest, other_destination),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # Verify that we received only the rooms that `other_destination` is
+ # participating in.
+ self.assertCountEqual(
+ [r["room_id"] for r in channel.json_body["rooms"]], room_ids_other_dest
+ )
+ self.assertEqual(channel.json_body["total"], len(room_ids_other_dest))
+
+ def _create_destination_rooms(
+ self,
+ number_rooms: int,
+ destination: Optional[str] = None,
+ ) -> List[str]:
+ """
+ Create the given number of rooms. The given `destination` homeserver will
+ be recorded as a participant.
Args:
number_rooms: Number of rooms to be created
+ destination: The domain of the homeserver that will be considered
+ as a participant in the rooms.
+
+ Returns:
+ The IDs of the rooms that have been created.
"""
+ room_ids = []
+
+ # If no destination was provided, default to `self.dest`.
+ if destination is None:
+ destination = self.dest
+
for _ in range(number_rooms):
room_id = self.helper.create_room_as(
self.admin_user, tok=self.admin_user_tok
)
+ room_ids.append(room_id)
+
self.get_success(
- self.store.store_destination_rooms_entries((self.dest,), room_id, 1234)
+ self.store.store_destination_rooms_entries(
+ (destination,), room_id, 1234
+ )
)
+ return room_ids
+
def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that the expected room attributes are present in content
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index 3a1f150082..3fb77fd9dd 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -20,7 +20,17 @@
#
import time
import urllib.parse
-from typing import Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import (
+ Any,
+ BinaryIO,
+ Callable,
+ Collection,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -34,8 +44,9 @@ import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
+from synapse.http.client import RawHeaders
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client import account, devices, login, logout, profile, register
from synapse.rest.client.account import WhoamiRestServlet
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
@@ -48,6 +59,7 @@ from tests.handlers.test_saml import has_saml2
from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
@@ -1421,7 +1433,19 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
class UsernamePickerTestCase(HomeserverTestCase):
"""Tests for the username picker flow of SSO login"""
- servlets = [login.register_servlets]
+ servlets = [
+ login.register_servlets,
+ profile.register_servlets,
+ account.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock(spec=["get_file"])
+ self.http_client.get_file.side_effect = mock_get_file
+ hs = self.setup_test_homeserver(
+ proxied_blocklisted_http_client=self.http_client
+ )
+ return hs
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
@@ -1430,7 +1454,11 @@ class UsernamePickerTestCase(HomeserverTestCase):
config["oidc_config"] = {}
config["oidc_config"].update(TEST_OIDC_CONFIG)
config["oidc_config"]["user_mapping_provider"] = {
- "config": {"display_name_template": "{{ user.displayname }}"}
+ "config": {
+ "display_name_template": "{{ user.displayname }}",
+ "email_template": "{{ user.email }}",
+ "picture_template": "{{ user.picture }}",
+ }
}
# whitelist this client URI so we redirect straight to it rather than
@@ -1443,15 +1471,22 @@ class UsernamePickerTestCase(HomeserverTestCase):
d.update(build_synapse_client_resource_tree(self.hs))
return d
- def test_username_picker(self) -> None:
- """Test the happy path of a username picker flow."""
-
- fake_oidc_server = self.helper.fake_oidc_server()
-
+ def proceed_to_username_picker_page(
+ self,
+ fake_oidc_server: FakeOidcServer,
+ displayname: str,
+ email: str,
+ picture: str,
+ ) -> Tuple[str, str]:
# do the start of the login flow
channel, _ = self.helper.auth_via_oidc(
fake_oidc_server,
- {"sub": "tester", "displayname": "Jonny"},
+ {
+ "sub": "tester",
+ "displayname": displayname,
+ "picture": picture,
+ "email": email,
+ },
TEST_CLIENT_REDIRECT_URL,
)
@@ -1478,16 +1513,132 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
- self.assertEqual(session.display_name, "Jonny")
+ self.assertEqual(session.display_name, displayname)
+ self.assertEqual(session.emails, [email])
+ self.assertEqual(session.avatar_url, picture)
self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
# the expiry time should be about 15 minutes away
expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
+ return picker_url, session_id
+
+ def test_username_picker_use_displayname_avatar_and_email(self) -> None:
+ """Test the happy path of a username picker flow with using displayname, avatar and email."""
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+
+ mxid = "@bobby:test"
+ displayname = "Jonny"
+ email = "bobby@test.com"
+ picture = "mxc://test/avatar_url"
+
+ picker_url, session_id = self.proceed_to_username_picker_page(
+ fake_oidc_server, displayname, email, picture
+ )
+
+ # Now, submit a username to the username picker, which should serve a redirect
+ # to the completion page.
+ # Also specify that we should use the provided displayname, avatar and email.
+ content = urlencode(
+ {
+ b"username": b"bobby",
+ b"use_display_name": b"true",
+ b"use_avatar": b"true",
+ b"use_email": email,
+ }
+ ).encode("utf8")
+ chan = self.make_request(
+ "POST",
+ path=picker_url,
+ content=content,
+ content_is_form=True,
+ custom_headers=[
+ ("Cookie", "username_mapping_session=" + session_id),
+ # old versions of twisted don't do form-parsing without a valid
+ # content-length header.
+ ("Content-Length", str(len(content))),
+ ],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+ assert location_headers
+
+ # send a request to the completion page, which should 302 to the client redirectUrl
+ chan = self.make_request(
+ "GET",
+ path=location_headers[0],
+ custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
+ )
+ self.assertEqual(chan.code, 302, chan.result)
+ location_headers = chan.headers.getRawHeaders("Location")
+ assert location_headers
+
+ # ensure that the returned location matches the requested redirect URL
+ path, query = location_headers[0].split("?", 1)
+ self.assertEqual(path, "https://x")
+
+ # it will have url-encoded the params properly, so we'll have to parse them
+ params = urllib.parse.parse_qsl(
+ query, keep_blank_values=True, strict_parsing=True, errors="strict"
+ )
+ self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+ self.assertEqual(params[2][0], "loginToken")
+
+ # fish the login token out of the returned redirect uri
+ login_token = params[2][1]
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ chan = self.make_request(
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], mxid)
+
+ # ensure the displayname and avatar from the OIDC response have been configured for the user.
+ channel = self.make_request(
+ "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertIn("mxc://test", channel.json_body["avatar_url"])
+ self.assertEqual(displayname, channel.json_body["displayname"])
+
+ # ensure the email from the OIDC response has been configured for the user.
+ channel = self.make_request(
+ "GET", "/account/3pid", access_token=chan.json_body["access_token"]
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(email, channel.json_body["threepids"][0]["address"])
+
+ def test_username_picker_dont_use_displayname_avatar_or_email(self) -> None:
+ """Test the happy path of a username picker flow without using displayname, avatar or email."""
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+
+ mxid = "@bobby:test"
+ displayname = "Jonny"
+ email = "bobby@test.com"
+ picture = "mxc://test/avatar_url"
+ username = "bobby"
+
+ picker_url, session_id = self.proceed_to_username_picker_page(
+ fake_oidc_server, displayname, email, picture
+ )
+
# Now, submit a username to the username picker, which should serve a redirect
- # to the completion page
- content = urlencode({b"username": b"bobby"}).encode("utf8")
+ # to the completion page.
+ # Also specify that we should not use the provided displayname, avatar or email.
+ content = urlencode(
+ {
+ b"username": username,
+ b"use_display_name": b"false",
+ b"use_avatar": b"false",
+ }
+ ).encode("utf8")
chan = self.make_request(
"POST",
path=picker_url,
@@ -1536,4 +1687,29 @@ class UsernamePickerTestCase(HomeserverTestCase):
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
- self.assertEqual(chan.json_body["user_id"], "@bobby:test")
+ self.assertEqual(chan.json_body["user_id"], mxid)
+
+ # ensure the displayname and avatar from the OIDC response have not been configured for the user.
+ channel = self.make_request(
+ "GET", "/profile/" + mxid, access_token=chan.json_body["access_token"]
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertNotIn("avatar_url", channel.json_body)
+ self.assertEqual(username, channel.json_body["displayname"])
+
+ # ensure the email from the OIDC response has not been configured for the user.
+ channel = self.make_request(
+ "GET", "/account/3pid", access_token=chan.json_body["access_token"]
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertListEqual([], channel.json_body["threepids"])
+
+
+async def mock_get_file(
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ is_allowed_content_type: Optional[Callable[[str], bool]] = None,
+) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
+ return 0, {b"Content-Type": [b"image/png"]}, "", 200
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 09a5d64349..ceae40498e 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -163,7 +163,12 @@ class RetentionTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(2, len(events), "events retrieved from database")
filtered_events = self.get_success(
- filter_events_for_client(storage_controllers, self.user_id, events)
+ filter_events_for_client(
+ storage_controllers,
+ self.user_id,
+ events,
+ msc4115_membership_on_events=True,
+ )
)
# We should only get one event back.
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index b796163dcb..d398cead1c 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -48,7 +48,16 @@ from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
-from synapse.rest.client import account, directory, login, profile, register, room, sync
+from synapse.rest.client import (
+ account,
+ directory,
+ knock,
+ login,
+ profile,
+ register,
+ room,
+ sync,
+)
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util import Clock
@@ -733,7 +742,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(32, channel.resource_usage.db_txn_count)
+ self.assertEqual(33, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -746,7 +755,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(34, channel.resource_usage.db_txn_count)
+ self.assertEqual(35, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -1154,6 +1163,7 @@ class RoomJoinTestCase(RoomBase):
admin.register_servlets,
login.register_servlets,
room.register_servlets,
+ knock.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -1167,6 +1177,8 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
+ self.store = hs.get_datastores().main
+
def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
@@ -1317,6 +1329,57 @@ class RoomJoinTestCase(RoomBase):
expect_additional_fields=return_value[1],
)
+ def test_suspended_user_cannot_join_room(self) -> None:
+ # set the user as suspended
+ self.get_success(self.store.set_user_suspended_status(self.user2, True))
+
+ channel = self.make_request(
+ "POST", f"/join/{self.room1}", access_token=self.tok2
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+ )
+
+ channel = self.make_request(
+ "POST", f"/rooms/{self.room1}/join", access_token=self.tok2
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+ )
+
+ def test_suspended_user_cannot_knock_on_room(self) -> None:
+ # set the user as suspended
+ self.get_success(self.store.set_user_suspended_status(self.user2, True))
+
+ channel = self.make_request(
+ "POST",
+ f"/_matrix/client/v3/knock/{self.room1}",
+ access_token=self.tok2,
+ content={},
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(
+ channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+ )
+
+ def test_suspended_user_cannot_invite_to_room(self) -> None:
+ # set the user as suspended
+ self.get_success(self.store.set_user_suspended_status(self.user1, True))
+
+ # first user invites second user
+ channel = self.make_request(
+ "POST",
+ f"/rooms/{self.room1}/invite",
+ access_token=self.tok1,
+ content={"user_id": self.user2},
+ )
+ self.assertEqual(
+ channel.json_body["errcode"], "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED"
+ )
+
class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 505465d529..14e3871dc1 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -43,7 +43,6 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.assertEqual(
UserInfo(
- # TODO(paul): Surely this field should be 'user_id', not 'name'
user_id=UserID.from_string(self.user_id),
is_admin=False,
is_guest=False,
@@ -57,6 +56,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
locked=False,
is_shadow_banned=False,
approved=True,
+ suspended=False,
),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index 1eab89f140..340642b7e7 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -71,17 +71,16 @@ class EventSearchInsertionTest(HomeserverTestCase):
store.search_msgs([room_id], "hi bob", ["content.body"])
)
self.assertEqual(result.get("count"), 1)
- if isinstance(store.database_engine, PostgresEngine):
- self.assertIn("hi", result.get("highlights"))
- self.assertIn("bob", result.get("highlights"))
+ self.assertIn("hi", result.get("highlights"))
+ self.assertIn("bob", result.get("highlights"))
# Check that search works for an unrelated message
result = self.get_success(
store.search_msgs([room_id], "another", ["content.body"])
)
self.assertEqual(result.get("count"), 1)
- if isinstance(store.database_engine, PostgresEngine):
- self.assertIn("another", result.get("highlights"))
+
+ self.assertIn("another", result.get("highlights"))
# Check that search works for a search term that overlaps with the message
# containing a null byte and an unrelated message.
@@ -90,8 +89,8 @@ class EventSearchInsertionTest(HomeserverTestCase):
result = self.get_success(
store.search_msgs([room_id], "hi alice", ["content.body"])
)
- if isinstance(store.database_engine, PostgresEngine):
- self.assertIn("alice", result.get("highlights"))
+
+ self.assertIn("alice", result.get("highlights"))
def test_non_string(self) -> None:
"""Test that non-string `value`s are not inserted into `event_search`.
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e51f72d65f..3e2100eab4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -21,13 +21,19 @@ import logging
from typing import Optional
from unittest.mock import patch
+from synapse.api.constants import EventUnsignedContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
-from synapse.types import JsonDict, create_requester
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import create_requester
from synapse.visibility import filter_events_for_client, filter_events_for_server
from tests import unittest
+from tests.test_utils.event_injection import inject_event, inject_member_event
+from tests.unittest import HomeserverTestCase
from tests.utils import create_room
logger = logging.getLogger(__name__)
@@ -56,15 +62,31 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
#
# before we do that, we persist some other events to act as state.
- self._inject_visibility("@admin:hs", "joined")
+ self.get_success(
+ inject_visibility_event(self.hs, TEST_ROOM_ID, "@admin:hs", "joined")
+ )
for i in range(10):
- self._inject_room_member("@resident%i:hs" % i)
+ self.get_success(
+ inject_member_event(
+ self.hs,
+ TEST_ROOM_ID,
+ "@resident%i:hs" % i,
+ "join",
+ )
+ )
events_to_filter = []
for i in range(10):
- user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
- evt = self._inject_room_member(user, extra_content={"a": "b"})
+ evt = self.get_success(
+ inject_member_event(
+ self.hs,
+ TEST_ROOM_ID,
+ "@user%i:%s" % (i, "test_server" if i == 5 else "other_server"),
+ "join",
+ extra_content={"a": "b"},
+ )
+ )
events_to_filter.append(evt)
filtered = self.get_success(
@@ -90,8 +112,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def test_filter_outlier(self) -> None:
# outlier events must be returned, for the good of the collective federation
- self._inject_room_member("@resident:remote_hs")
- self._inject_visibility("@resident:remote_hs", "joined")
+ self.get_success(
+ inject_member_event(
+ self.hs,
+ TEST_ROOM_ID,
+ "@resident:remote_hs",
+ "join",
+ )
+ )
+ self.get_success(
+ inject_visibility_event(
+ self.hs, TEST_ROOM_ID, "@resident:remote_hs", "joined"
+ )
+ )
outlier = self._inject_outlier()
self.assertEqual(
@@ -110,7 +143,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
)
# it should also work when there are other events in the list
- evt = self._inject_message("@unerased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+ )
filtered = self.get_success(
filter_events_for_server(
@@ -150,19 +185,34 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# change in the middle of them.
events_to_filter = []
- evt = self._inject_message("@unerased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+ )
events_to_filter.append(evt)
- evt = self._inject_message("@erased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
+ )
events_to_filter.append(evt)
- evt = self._inject_room_member("@joiner:remote_hs")
+ evt = self.get_success(
+ inject_member_event(
+ self.hs,
+ TEST_ROOM_ID,
+ "@joiner:remote_hs",
+ "join",
+ )
+ )
events_to_filter.append(evt)
- evt = self._inject_message("@unerased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+ )
events_to_filter.append(evt)
- evt = self._inject_message("@erased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
+ )
events_to_filter.append(evt)
# the erasey user gets erased
@@ -200,76 +250,6 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
for i in (1, 4):
self.assertNotIn("body", filtered[i].content)
- def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
- content = {"history_visibility": visibility}
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": "m.room.history_visibility",
- "sender": user_id,
- "state_key": "",
- "room_id": TEST_ROOM_ID,
- "content": content,
- },
- )
-
- event, unpersisted_context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
- context = self.get_success(unpersisted_context.persist(event))
- self.get_success(self._persistence.persist_event(event, context))
- return event
-
- def _inject_room_member(
- self,
- user_id: str,
- membership: str = "join",
- extra_content: Optional[JsonDict] = None,
- ) -> EventBase:
- content = {"membership": membership}
- content.update(extra_content or {})
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": "m.room.member",
- "sender": user_id,
- "state_key": user_id,
- "room_id": TEST_ROOM_ID,
- "content": content,
- },
- )
-
- event, unpersisted_context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
- context = self.get_success(unpersisted_context.persist(event))
-
- self.get_success(self._persistence.persist_event(event, context))
- return event
-
- def _inject_message(
- self, user_id: str, content: Optional[JsonDict] = None
- ) -> EventBase:
- if content is None:
- content = {"body": "testytest", "msgtype": "m.text"}
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": "m.room.message",
- "sender": user_id,
- "room_id": TEST_ROOM_ID,
- "content": content,
- },
- )
-
- event, unpersisted_context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
- context = self.get_success(unpersisted_context.persist(event))
-
- self.get_success(self._persistence.persist_event(event, context))
- return event
-
def _inject_outlier(self) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -292,7 +272,122 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
return event
-class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
+class FilterEventsForClientTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def test_joined_history_visibility(self) -> None:
+ # User joins and leaves room. Should be able to see the join and leave,
+ # and messages sent between the two, but not before or after.
+
+ self.register_user("resident", "p1")
+ resident_token = self.login("resident", "p1")
+ room_id = self.helper.create_room_as("resident", tok=resident_token)
+
+ self.get_success(
+ inject_visibility_event(self.hs, room_id, "@resident:test", "joined")
+ )
+ before_event = self.get_success(
+ inject_message_event(self.hs, room_id, "@resident:test", body="before")
+ )
+ join_event = self.get_success(
+ inject_member_event(self.hs, room_id, "@joiner:test", "join")
+ )
+ during_event = self.get_success(
+ inject_message_event(self.hs, room_id, "@resident:test", body="during")
+ )
+ leave_event = self.get_success(
+ inject_member_event(self.hs, room_id, "@joiner:test", "leave")
+ )
+ after_event = self.get_success(
+ inject_message_event(self.hs, room_id, "@resident:test", body="after")
+ )
+
+ # We have to reload the events from the db, to ensure that prev_content is
+ # populated.
+ events_to_filter = [
+ self.get_success(
+ self.hs.get_storage_controllers().main.get_event(
+ e.event_id,
+ get_prev_content=True,
+ )
+ )
+ for e in [
+ before_event,
+ join_event,
+ during_event,
+ leave_event,
+ after_event,
+ ]
+ ]
+
+ # Now run the events through the filter, and check that we can see the events
+ # we expect, and that the membership prop is as expected.
+ #
+ # We deliberately do the queries for both users upfront; this simulates
+ # concurrent queries on the server, and helps ensure that we aren't
+ # accidentally serving the same event object (with the same unsigned.membership
+ # property) to both users.
+ joiner_filtered_events = self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage_controllers(),
+ "@joiner:test",
+ events_to_filter,
+ msc4115_membership_on_events=True,
+ )
+ )
+ resident_filtered_events = self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage_controllers(),
+ "@resident:test",
+ events_to_filter,
+ msc4115_membership_on_events=True,
+ )
+ )
+
+ # The joiner should be able to seem the join and leave,
+ # and messages sent between the two, but not before or after.
+ self.assertEqual(
+ [e.event_id for e in [join_event, during_event, leave_event]],
+ [e.event_id for e in joiner_filtered_events],
+ )
+ self.assertEqual(
+ ["join", "join", "leave"],
+ [
+ e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+ for e in joiner_filtered_events
+ ],
+ )
+
+ # The resident user should see all the events.
+ self.assertEqual(
+ [
+ e.event_id
+ for e in [
+ before_event,
+ join_event,
+ during_event,
+ leave_event,
+ after_event,
+ ]
+ ],
+ [e.event_id for e in resident_filtered_events],
+ )
+ self.assertEqual(
+ ["join", "join", "join", "join", "join"],
+ [
+ e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+ for e in resident_filtered_events
+ ],
+ )
+
+
+class FilterEventsOutOfBandEventsForClientTestCase(
+ unittest.FederatingHomeserverTestCase
+):
def test_out_of_band_invite_rejection(self) -> None:
# this is where we have received an invite event over federation, and then
# rejected it.
@@ -341,15 +436,24 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
)
# the invited user should be able to see both the invite and the rejection
+ filtered_events = self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage_controllers(),
+ "@user:test",
+ [invite_event, reject_event],
+ msc4115_membership_on_events=True,
+ )
+ )
self.assertEqual(
- self.get_success(
- filter_events_for_client(
- self.hs.get_storage_controllers(),
- "@user:test",
- [invite_event, reject_event],
- )
- ),
- [invite_event, reject_event],
+ [e.event_id for e in filtered_events],
+ [e.event_id for e in [invite_event, reject_event]],
+ )
+ self.assertEqual(
+ ["invite", "leave"],
+ [
+ e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+ for e in filtered_events
+ ],
)
# other users should see neither
@@ -359,7 +463,39 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.hs.get_storage_controllers(),
"@other:test",
[invite_event, reject_event],
+ msc4115_membership_on_events=True,
)
),
[],
)
+
+
+async def inject_visibility_event(
+ hs: HomeServer,
+ room_id: str,
+ sender: str,
+ visibility: str,
+) -> EventBase:
+ return await inject_event(
+ hs,
+ type="m.room.history_visibility",
+ sender=sender,
+ state_key="",
+ room_id=room_id,
+ content={"history_visibility": visibility},
+ )
+
+
+async def inject_message_event(
+ hs: HomeServer,
+ room_id: str,
+ sender: str,
+ body: Optional[str] = "testytest",
+) -> EventBase:
+ return await inject_event(
+ hs,
+ type="m.room.message",
+ sender=sender,
+ room_id=room_id,
+ content={"body": body, "msgtype": "m.text"},
+ )
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 3df053493b..5d38718a50 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -1,3 +1,5 @@
+from parameterized import parameterized
+
from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests import unittest
@@ -161,7 +163,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertFalse(cache.has_any_entity_changed(2))
self.assertFalse(cache.has_any_entity_changed(3))
- def test_get_entities_changed(self) -> None:
+ @parameterized.expand([(0,), (1000000000,)])
+ def test_get_entities_changed(self, perf_factor: int) -> None:
"""
StreamChangeCache.get_entities_changed will return the entities in the
given list that have changed since the provided stream ID. If the
@@ -178,7 +181,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# get the ones after that point.
self.assertEqual(
cache.get_entities_changed(
- ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
+ ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
+ stream_pos=2,
+ _perf_factor=perf_factor,
),
{"bar@baz.net", "user@elsewhere.org"},
)
@@ -195,6 +200,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"not@here.website",
],
stream_pos=2,
+ _perf_factor=perf_factor,
),
{"bar@baz.net", "user@elsewhere.org"},
)
@@ -210,6 +216,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"not@here.website",
],
stream_pos=0,
+ _perf_factor=perf_factor,
),
{"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"},
)
@@ -217,7 +224,11 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# Query a subset of the entries mid-way through the stream. We should
# only get back the subset.
self.assertEqual(
- cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
+ cache.get_entities_changed(
+ ["bar@baz.net"],
+ stream_pos=2,
+ _perf_factor=perf_factor,
+ ),
{"bar@baz.net"},
)
|