diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index e430941d27..40148d503c 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -50,7 +50,7 @@ class LogContextScopeManagerTestCase(TestCase):
# global variables that power opentracing. We create our own tracer instance
# and test with it.
- scope_manager = LogContextScopeManager({})
+ scope_manager = LogContextScopeManager()
config = jaeger_client.config.Config(
config={}, service_name="test", scope_manager=scope_manager
)
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index ca6af9417b..230dc76f72 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1579,8 +1579,8 @@ class RoomTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
- self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
- self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
+ self.assertEqual(room_id, channel.json_body["rooms"][0].get("room_id"))
+ self.assertEqual("ж", channel.json_body["rooms"][0].get("name"))
def test_single_room(self) -> None:
"""Test that a single room can be requested correctly"""
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 0d44102237..e32aaadb98 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1488,7 +1488,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
if channel.code != HTTPStatus.OK:
raise HttpResponseException(
- channel.code, channel.result["reason"], channel.json_body
+ channel.code, channel.result["reason"], channel.result["body"]
)
# Set monthly active users to the limit
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a43a137273..1f9b65351e 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -949,7 +949,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
client_secret: str,
next_link: Optional[str] = None,
expect_code: int = 200,
- ) -> str:
+ ) -> Optional[str]:
"""Request a validation token to add an email address to a user's account
Args:
@@ -959,7 +959,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
expect_code: Expected return code of the call
Returns:
- The ID of the new threepid validation session
+ The ID of the new threepid validation session, or None if the response
+ did not contain a session ID.
"""
body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
if next_link:
diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py
index 29bed0e872..8de5a342ae 100644
--- a/tests/rest/client/test_profile.py
+++ b/tests/rest/client/test_profile.py
@@ -153,18 +153,22 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
- def _get_displayname(self, name: Optional[str] = None) -> str:
+ def _get_displayname(self, name: Optional[str] = None) -> Optional[str]:
channel = self.make_request(
"GET", "/profile/%s/displayname" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
- return channel.json_body["displayname"]
+ # FIXME: If a user has no displayname set, Synapse returns 200 and omits a
+ # displayname from the response. This contradicts the spec, see #13137.
+ return channel.json_body.get("displayname")
- def _get_avatar_url(self, name: Optional[str] = None) -> str:
+ def _get_avatar_url(self, name: Optional[str] = None) -> Optional[str]:
channel = self.make_request(
"GET", "/profile/%s/avatar_url" % (name or self.owner,)
)
self.assertEqual(channel.code, 200, channel.result)
+ # FIXME: If a user has no avatar set, Synapse returns 200 and omits an
+ # avatar_url from the response. This contradicts the spec, see #13137.
return channel.json_body.get("avatar_url")
@unittest.override_config({"max_avatar_size": 50})
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index aa84906548..ad03eee17b 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -800,7 +800,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
)
expected_event_ids.append(channel.json_body["event_id"])
- prev_token = ""
+ prev_token: Optional[str] = ""
found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 35c59ee9e0..1ccd96a207 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
"""Tests REST events for /rooms paths."""
import json
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from unittest.mock import Mock, call
from urllib import parse as urlparse
@@ -33,7 +33,9 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
Membership,
+ PublicRoomsFilterFields,
RelationTypes,
+ RoomTypes,
)
from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
@@ -1858,6 +1860,90 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
+class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+
+ config = self.default_config()
+ config["allow_public_rooms_without_auth"] = True
+ config["experimental_features"] = {"msc3827_enabled": True}
+ self.hs = self.setup_test_homeserver(config=config)
+ self.url = b"/_matrix/client/r0/publicRooms"
+
+ return self.hs
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ user = self.register_user("alice", "pass")
+ self.token = self.login(user, "pass")
+
+ # Create a room
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=self.token,
+ )
+ # Create a space
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={
+ "visibility": "public",
+ "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+ },
+ tok=self.token,
+ )
+
+ def make_public_rooms_request(
+ self, room_types: Union[List[Union[str, None]], None]
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ channel = self.make_request(
+ "POST",
+ self.url,
+ {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
+ self.token,
+ )
+ chunk = channel.json_body["chunk"]
+ count = channel.json_body["total_room_count_estimate"]
+
+ self.assertEqual(len(chunk), count)
+
+ return chunk, count
+
+ def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(None)
+
+ self.assertEqual(count, 2)
+
+ def test_returns_only_rooms_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request([None])
+
+ self.assertEqual(count, 1)
+ self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None)
+
+ def test_returns_only_space_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(["m.space"])
+
+ self.assertEqual(count, 1)
+ self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space")
+
+ def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
+ chunk, count = self.make_public_rooms_request(["m.space", None])
+
+ self.assertEqual(count, 2)
+
+ def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
+ chunk, count = self.make_public_rooms_request([])
+
+ self.assertEqual(count, 2)
+
+
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"""Test that we correctly fallback to local filtering if a remote server
doesn't support search.
@@ -1882,7 +1968,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
- search_filter = {"generic_search_term": "foobar"}
+ search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
channel = self.make_request(
"POST",
@@ -1911,7 +1997,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
make_awaitable({}),
)
- search_filter = {"generic_search_term": "foobar"}
+ search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
channel = self.make_request(
"POST",
diff --git a/tests/server.py b/tests/server.py
index b9f465971f..df3f1564c9 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -43,6 +43,7 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
+ IConsumer,
IHostnameResolver,
IProtocol,
IPullProducer,
@@ -53,11 +54,7 @@ from twisted.internet.interfaces import (
ITransport,
)
from twisted.python.failure import Failure
-from twisted.test.proto_helpers import (
- AccumulatingProtocol,
- MemoryReactor,
- MemoryReactorClock,
-)
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
@@ -96,6 +93,7 @@ class TimedOutException(Exception):
"""
+@implementer(IConsumer)
@attr.s(auto_attribs=True)
class FakeChannel:
"""
@@ -104,7 +102,7 @@ class FakeChannel:
"""
site: Union[Site, "FakeSite"]
- _reactor: MemoryReactor
+ _reactor: MemoryReactorClock
result: dict = attr.Factory(dict)
_ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
@@ -122,7 +120,7 @@ class FakeChannel:
self._request = request
@property
- def json_body(self):
+ def json_body(self) -> JsonDict:
return json.loads(self.text_body)
@property
@@ -140,7 +138,7 @@ class FakeChannel:
return self.result.get("done", False)
@property
- def code(self):
+ def code(self) -> int:
if not self.result:
raise Exception("No result yet.")
return int(self.result["code"])
@@ -160,7 +158,7 @@ class FakeChannel:
self.result["reason"] = reason
self.result["headers"] = headers
- def write(self, content):
+ def write(self, content: bytes) -> None:
assert isinstance(content, bytes), "Should be bytes! " + repr(content)
if "body" not in self.result:
@@ -168,11 +166,16 @@ class FakeChannel:
self.result["body"] += content
- def registerProducer(self, producer, streaming):
+ # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
+ def registerProducer( # type: ignore[override]
+ self,
+ producer: Union[IPullProducer, IPushProducer],
+ streaming: bool,
+ ) -> None:
self._producer = producer
self.producerStreaming = streaming
- def _produce():
+ def _produce() -> None:
if self._producer:
self._producer.resumeProducing()
self._reactor.callLater(0.1, _produce)
@@ -180,31 +183,32 @@ class FakeChannel:
if not streaming:
self._reactor.callLater(0.0, _produce)
- def unregisterProducer(self):
+ def unregisterProducer(self) -> None:
if self._producer is None:
return
self._producer = None
- def requestDone(self, _self):
+ def requestDone(self, _self: Request) -> None:
self.result["done"] = True
if isinstance(_self, SynapseRequest):
+ assert _self.logcontext is not None
self.resource_usage = _self.logcontext.get_resource_usage()
- def getPeer(self):
+ def getPeer(self) -> IAddress:
# We give an address so that getClientAddress/getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address("TCP", self._ip, 3423)
- def getHost(self):
+ def getHost(self) -> IAddress:
# this is called by Request.__init__ to configure Request.host.
return address.IPv4Address("TCP", "127.0.0.1", 8888)
- def isSecure(self):
+ def isSecure(self) -> bool:
return False
@property
- def transport(self):
+ def transport(self) -> "FakeChannel":
return self
def await_result(self, timeout_ms: int = 1000) -> None:
@@ -830,7 +834,6 @@ def setup_test_homeserver(
# Mock TLS
hs.tls_server_context_factory = Mock()
- hs.tls_client_options_factory = Mock()
hs.setup()
if homeserver_to_use == TestHomeServer:
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 9abd0cb446..1edb619630 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
+
+from synapse.api.constants import RoomTypes
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.storage.databases.main.room import _BackgroundUpdates
@@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
)
)
self.assertEqual(room_creator_after, self.user_id)
+
+ def test_background_add_room_type_column(self):
+ """Test that the background update to populate the `room_type` column in
+ `room_stats_state` works properly.
+ """
+
+ # Create a room without a type
+ room_id = self._generate_room()
+
+ # Get event_id of the m.room.create event
+ event_id = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="current_state_events",
+ keyvalues={
+ "room_id": room_id,
+ "type": "m.room.create",
+ },
+ retcol="event_id",
+ )
+ )
+
+ # Fake a room creation event with a room type
+ event = {
+ "content": {
+ "creator": "@user:server.org",
+ "room_version": "9",
+ "type": RoomTypes.SPACE,
+ },
+ "type": "m.room.create",
+ }
+ self.get_success(
+ self.store.db_pool.simple_update(
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": json.dumps(event)},
+ desc="test",
+ )
+ )
+
+ # Insert and run the background update
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {
+ "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ "progress_json": "{}",
+ },
+ )
+ )
+
+ # ... and tell the DataStore that it hasn't finished all updates yet
+ self.store.db_pool.updates._all_done = False
+
+ # Now let's actually drive the updates to completion
+ self.wait_for_background_updates()
+
+ # Make sure the background update filled in the room type
+ room_type_after = self.get_success(
+ self.store.db_pool.simple_select_one_onecol(
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ retcol="room_type",
+ allow_none=True,
+ )
+ )
+ self.assertEqual(room_type_after, RoomTypes.SPACE)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 852b663387..e68126777f 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -86,6 +86,8 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
event.internal_metadata.is_outlier.return_value = False
event.depth = stream
+ self.store._events_stream_cache.entity_has_changed(room_id, stream)
+
self.get_success(
self.store.db_pool.simple_insert(
table="events",
diff --git a/tests/test_state.py b/tests/test_state.py
index b005dd8d0f..7b3f52f68e 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -131,7 +131,9 @@ class _DummyStore:
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
- async def get_state_group_for_events(self, event_ids):
+ async def get_state_group_for_events(
+ self, event_ids, await_full_state: bool = True
+ ):
res = {}
for event in event_ids:
res[event] = self._event_to_state_group[event]
diff --git a/tests/utils.py b/tests/utils.py
index cabb2c0dec..aca6a0083b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -64,7 +64,7 @@ def setupdb():
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
- db_conn.autocommit = True
+ db_engine.attempt_to_set_autocommit(db_conn, autocommit=True)
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.execute(
@@ -94,7 +94,7 @@ def setupdb():
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
- db_conn.autocommit = True
+ db_engine.attempt_to_set_autocommit(db_conn, autocommit=True)
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.close()
|