summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/logging/test_opentracing.py2
-rw-r--r--tests/rest/admin/test_room.py4
-rw-r--r--tests/rest/admin/test_user.py2
-rw-r--r--tests/rest/client/test_account.py5
-rw-r--r--tests/rest/client/test_profile.py10
-rw-r--r--tests/rest/client/test_relations.py2
-rw-r--r--tests/rest/client/test_rooms.py92
-rw-r--r--tests/server.py39
-rw-r--r--tests/storage/databases/main/test_room.py69
-rw-r--r--tests/storage/test_event_push_actions.py2
-rw-r--r--tests/test_state.py4
-rw-r--r--tests/utils.py4
12 files changed, 201 insertions, 34 deletions
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py

index e430941d27..40148d503c 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py
@@ -50,7 +50,7 @@ class LogContextScopeManagerTestCase(TestCase): # global variables that power opentracing. We create our own tracer instance # and test with it. - scope_manager = LogContextScopeManager({}) + scope_manager = LogContextScopeManager() config = jaeger_client.config.Config( config={}, service_name="test", scope_manager=scope_manager ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ca6af9417b..230dc76f72 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py
@@ -1579,8 +1579,8 @@ class RoomTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) - self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) - self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) + self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id")) + self.assertEqual("ж", channel.json_body["rooms"][0].get("name")) def test_single_room(self) -> None: """Test that a single room can be requested correctly""" diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0d44102237..e32aaadb98 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -1488,7 +1488,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): if channel.code != HTTPStatus.OK: raise HttpResponseException( - channel.code, channel.result["reason"], channel.json_body + channel.code, channel.result["reason"], channel.result["body"] ) # Set monthly active users to the limit diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a43a137273..1f9b65351e 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py
@@ -949,7 +949,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): client_secret: str, next_link: Optional[str] = None, expect_code: int = 200, - ) -> str: + ) -> Optional[str]: """Request a validation token to add an email address to a user's account Args: @@ -959,7 +959,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): expect_code: Expected return code of the call Returns: - The ID of the new threepid validation session + The ID of the new threepid validation session, or None if the response + did not contain a session ID. """ body = {"client_secret": client_secret, "email": email, "send_attempt": 1} if next_link: diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 29bed0e872..8de5a342ae 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py
@@ -153,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def _get_displayname(self, name: Optional[str] = None) -> str: + def _get_displayname(self, name: Optional[str] = None) -> Optional[str]: channel = self.make_request( "GET", "/profile/%s/displayname" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) - return channel.json_body["displayname"] + # FIXME: If a user has no displayname set, Synapse returns 200 and omits a + # displayname from the response. This contradicts the spec, see #13137. + return channel.json_body.get("displayname") - def _get_avatar_url(self, name: Optional[str] = None) -> str: + def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]: channel = self.make_request( "GET", "/profile/%s/avatar_url" % (name or self.owner,) ) self.assertEqual(channel.code, 200, channel.result) + # FIXME: If a user has no avatar set, Synapse returns 200 and omits an + # avatar_url from the response. This contradicts the spec, see #13137. return channel.json_body.get("avatar_url") @unittest.override_config({"max_avatar_size": 50}) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index aa84906548..ad03eee17b 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py
@@ -800,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): ) expected_event_ids.append(channel.json_body["event_id"]) - prev_token = "" + prev_token: Optional[str] = "" found_event_ids: List[str] = [] for _ in range(20): from_token = "" diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 35c59ee9e0..1ccd96a207 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@ """Tests REST events for /rooms paths.""" import json -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from unittest.mock import Mock, call from urllib import parse as urlparse @@ -33,7 +33,9 @@ from synapse.api.constants import ( EventContentFields, EventTypes, Membership, + PublicRoomsFilterFields, RelationTypes, + RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus @@ -1858,6 +1860,90 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) +class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + config["experimental_features"] = {"msc3827_enabled": True} + self.hs = self.setup_test_homeserver(config=config) + self.url = b"/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + user = self.register_user("alice", "pass") + self.token = self.login(user, "pass") + + # Create a room + self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=self.token, + ) + # Create a space + self.helper.create_room_as( + user, + is_public=True, + extra_content={ + "visibility": "public", + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}, + }, + tok=self.token, + ) + + def make_public_rooms_request( + self, room_types: Union[List[Union[str, None]], None] + ) -> Tuple[List[Dict[str, Any]], int]: + channel = self.make_request( + "POST", + self.url, + {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}, + self.token, + ) + chunk = channel.json_body["chunk"] + count = channel.json_body["total_room_count_estimate"] + + self.assertEqual(len(chunk), count) + + return chunk, count + + def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: + chunk, count = self.make_public_rooms_request(None) + + self.assertEqual(count, 2) + + def test_returns_only_rooms_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request([None]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None) + + def test_returns_only_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space"]) + + self.assertEqual(count, 1) + self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space") + + def test_returns_both_rooms_and_space_based_on_filter(self) -> None: + chunk, count = self.make_public_rooms_request(["m.space", None]) + + self.assertEqual(count, 2) + + def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: + chunk, count = self.make_public_rooms_request([]) + + self.assertEqual(count, 2) + + class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): """Test that we correctly fallback to local filtering if a remote server doesn't support search. @@ -1882,7 +1968,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): "Simple test for searching rooms over federation" self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", @@ -1911,7 +1997,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): make_awaitable({}), ) - search_filter = {"generic_search_term": "foobar"} + search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} channel = self.make_request( "POST", diff --git a/tests/server.py b/tests/server.py
index b9f465971f..df3f1564c9 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -43,6 +43,7 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( IAddress, + IConsumer, IHostnameResolver, IProtocol, IPullProducer, @@ -53,11 +54,7 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import ( - AccumulatingProtocol, - MemoryReactor, - MemoryReactorClock, -) +from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site @@ -96,6 +93,7 @@ class TimedOutException(Exception): """ +@implementer(IConsumer) @attr.s(auto_attribs=True) class FakeChannel: """ @@ -104,7 +102,7 @@ class FakeChannel: """ site: Union[Site, "FakeSite"] - _reactor: MemoryReactor + _reactor: MemoryReactorClock result: dict = attr.Factory(dict) _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None @@ -122,7 +120,7 @@ class FakeChannel: self._request = request @property - def json_body(self): + def json_body(self) -> JsonDict: return json.loads(self.text_body) @property @@ -140,7 +138,7 @@ class FakeChannel: return self.result.get("done", False) @property - def code(self): + def code(self) -> int: if not self.result: raise Exception("No result yet.") return int(self.result["code"]) @@ -160,7 +158,7 @@ class FakeChannel: self.result["reason"] = reason self.result["headers"] = headers - def write(self, content): + def write(self, content: bytes) -> None: assert isinstance(content, bytes), "Should be bytes! " + repr(content) if "body" not in self.result: @@ -168,11 +166,16 @@ class FakeChannel: self.result["body"] += content - def registerProducer(self, producer, streaming): + # Type ignore: mypy doesn't like the fact that producer isn't an IProducer. + def registerProducer( # type: ignore[override] + self, + producer: Union[IPullProducer, IPushProducer], + streaming: bool, + ) -> None: self._producer = producer self.producerStreaming = streaming - def _produce(): + def _produce() -> None: if self._producer: self._producer.resumeProducing() self._reactor.callLater(0.1, _produce) @@ -180,31 +183,32 @@ class FakeChannel: if not streaming: self._reactor.callLater(0.0, _produce) - def unregisterProducer(self): + def unregisterProducer(self) -> None: if self._producer is None: return self._producer = None - def requestDone(self, _self): + def requestDone(self, _self: Request) -> None: self.result["done"] = True if isinstance(_self, SynapseRequest): + assert _self.logcontext is not None self.resource_usage = _self.logcontext.get_resource_usage() - def getPeer(self): + def getPeer(self) -> IAddress: # We give an address so that getClientAddress/getClientIP returns a non null entry, # causing us to record the MAU return address.IPv4Address("TCP", self._ip, 3423) - def getHost(self): + def getHost(self) -> IAddress: # this is called by Request.__init__ to configure Request.host. return address.IPv4Address("TCP", "127.0.0.1", 8888) - def isSecure(self): + def isSecure(self) -> bool: return False @property - def transport(self): + def transport(self) -> "FakeChannel": return self def await_result(self, timeout_ms: int = 1000) -> None: @@ -830,7 +834,6 @@ def setup_test_homeserver( # Mock TLS hs.tls_server_context_factory = Mock() - hs.tls_client_options_factory = Mock() hs.setup() if homeserver_to_use == TestHomeServer: diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 9abd0cb446..1edb619630 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py
@@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json + +from synapse.api.constants import RoomTypes from synapse.rest import admin from synapse.rest.client import login, room from synapse.storage.databases.main.room import _BackgroundUpdates @@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): ) ) self.assertEqual(room_creator_after, self.user_id) + + def test_background_add_room_type_column(self): + """Test that the background update to populate the `room_type` column in + `room_stats_state` works properly. + """ + + # Create a room without a type + room_id = self._generate_room() + + # Get event_id of the m.room.create event + event_id = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="current_state_events", + keyvalues={ + "room_id": room_id, + "type": "m.room.create", + }, + retcol="event_id", + ) + ) + + # Fake a room creation event with a room type + event = { + "content": { + "creator": "@user:server.org", + "room_version": "9", + "type": RoomTypes.SPACE, + }, + "type": "m.room.create", + } + self.get_success( + self.store.db_pool.simple_update( + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": json.dumps(event)}, + desc="test", + ) + ) + + # Insert and run the background update + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN, + "progress_json": "{}", + }, + ) + ) + + # ... and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() + + # Make sure the background update filled in the room type + room_type_after = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcol="room_type", + allow_none=True, + ) + ) + self.assertEqual(room_type_after, RoomTypes.SPACE) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 852b663387..e68126777f 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py
@@ -86,6 +86,8 @@ class EventPushActionsStoreTestCase(HomeserverTestCase): event.internal_metadata.is_outlier.return_value = False event.depth = stream + self.store._events_stream_cache.entity_has_changed(room_id, stream) + self.get_success( self.store.db_pool.simple_insert( table="events", diff --git a/tests/test_state.py b/tests/test_state.py
index b005dd8d0f..7b3f52f68e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py
@@ -131,7 +131,9 @@ class _DummyStore: async def get_room_version_id(self, room_id): return RoomVersions.V1.identifier - async def get_state_group_for_events(self, event_ids): + async def get_state_group_for_events( + self, event_ids, await_full_state: bool = True + ): res = {} for event in event_ids: res[event] = self._event_to_state_group[event] diff --git a/tests/utils.py b/tests/utils.py
index cabb2c0dec..aca6a0083b 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -64,7 +64,7 @@ def setupdb(): password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) - db_conn.autocommit = True + db_engine.attempt_to_set_autocommit(db_conn, autocommit=True) cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) cur.execute( @@ -94,7 +94,7 @@ def setupdb(): password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) - db_conn.autocommit = True + db_engine.attempt_to_set_autocommit(db_conn, autocommit=True) cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) cur.close()