summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/logging/test_opentracing.py2
-rw-r--r--tests/rest/client/test_rooms.py92
-rw-r--r--tests/server.py1
-rw-r--r--tests/storage/databases/main/test_room.py69
-rw-r--r--tests/storage/test_event_push_actions.py2
-rw-r--r--tests/test_state.py4
-rw-r--r--tests/utils.py4
7 files changed, 166 insertions, 8 deletions
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index e430941d27..40148d503c 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -50,7 +50,7 @@ class LogContextScopeManagerTestCase(TestCase):
         # global variables that power opentracing. We create our own tracer instance
         # and test with it.
 
-        scope_manager = LogContextScopeManager({})
+        scope_manager = LogContextScopeManager()
         config = jaeger_client.config.Config(
             config={}, service_name="test", scope_manager=scope_manager
         )
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 35c59ee9e0..1ccd96a207 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
 """Tests REST events for /rooms paths."""
 
 import json
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
@@ -33,7 +33,9 @@ from synapse.api.constants import (
     EventContentFields,
     EventTypes,
     Membership,
+    PublicRoomsFilterFields,
     RelationTypes,
+    RoomTypes,
 )
 from synapse.api.errors import Codes, HttpResponseException
 from synapse.handlers.pagination import PurgeStatus
@@ -1858,6 +1860,90 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
 
+class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+
+        config = self.default_config()
+        config["allow_public_rooms_without_auth"] = True
+        config["experimental_features"] = {"msc3827_enabled": True}
+        self.hs = self.setup_test_homeserver(config=config)
+        self.url = b"/_matrix/client/r0/publicRooms"
+
+        return self.hs
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        user = self.register_user("alice", "pass")
+        self.token = self.login(user, "pass")
+
+        # Create a room
+        self.helper.create_room_as(
+            user,
+            is_public=True,
+            extra_content={"visibility": "public"},
+            tok=self.token,
+        )
+        # Create a space
+        self.helper.create_room_as(
+            user,
+            is_public=True,
+            extra_content={
+                "visibility": "public",
+                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+            },
+            tok=self.token,
+        )
+
+    def make_public_rooms_request(
+        self, room_types: Union[List[Union[str, None]], None]
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        channel = self.make_request(
+            "POST",
+            self.url,
+            {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
+            self.token,
+        )
+        chunk = channel.json_body["chunk"]
+        count = channel.json_body["total_room_count_estimate"]
+
+        self.assertEqual(len(chunk), count)
+
+        return chunk, count
+
+    def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
+        chunk, count = self.make_public_rooms_request(None)
+
+        self.assertEqual(count, 2)
+
+    def test_returns_only_rooms_based_on_filter(self) -> None:
+        chunk, count = self.make_public_rooms_request([None])
+
+        self.assertEqual(count, 1)
+        self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None)
+
+    def test_returns_only_space_based_on_filter(self) -> None:
+        chunk, count = self.make_public_rooms_request(["m.space"])
+
+        self.assertEqual(count, 1)
+        self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space")
+
+    def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
+        chunk, count = self.make_public_rooms_request(["m.space", None])
+
+        self.assertEqual(count, 2)
+
+    def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
+        chunk, count = self.make_public_rooms_request([])
+
+        self.assertEqual(count, 2)
+
+
 class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
     """Test that we correctly fallback to local filtering if a remote server
     doesn't support search.
@@ -1882,7 +1968,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
         "Simple test for searching rooms over federation"
         self.federation_client.get_public_rooms.return_value = make_awaitable({})  # type: ignore[attr-defined]
 
-        search_filter = {"generic_search_term": "foobar"}
+        search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
 
         channel = self.make_request(
             "POST",
@@ -1911,7 +1997,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
             make_awaitable({}),
         )
 
-        search_filter = {"generic_search_term": "foobar"}
+        search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
 
         channel = self.make_request(
             "POST",
diff --git a/tests/server.py b/tests/server.py
index b9f465971f..ce017ca0f6 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -830,7 +830,6 @@ def setup_test_homeserver(
 
     # Mock TLS
     hs.tls_server_context_factory = Mock()
-    hs.tls_client_options_factory = Mock()
 
     hs.setup()
     if homeserver_to_use == TestHomeServer:
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 9abd0cb446..1edb619630 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+
+from synapse.api.constants import RoomTypes
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.storage.databases.main.room import _BackgroundUpdates
@@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
             )
         )
         self.assertEqual(room_creator_after, self.user_id)
+
+    def test_background_add_room_type_column(self):
+        """Test that the background update to populate the `room_type` column in
+        `room_stats_state` works properly.
+        """
+
+        # Create a room without a type
+        room_id = self._generate_room()
+
+        # Get event_id of the m.room.create event
+        event_id = self.get_success(
+            self.store.db_pool.simple_select_one_onecol(
+                table="current_state_events",
+                keyvalues={
+                    "room_id": room_id,
+                    "type": "m.room.create",
+                },
+                retcol="event_id",
+            )
+        )
+
+        # Fake a room creation event with a room type
+        event = {
+            "content": {
+                "creator": "@user:server.org",
+                "room_version": "9",
+                "type": RoomTypes.SPACE,
+            },
+            "type": "m.room.create",
+        }
+        self.get_success(
+            self.store.db_pool.simple_update(
+                table="event_json",
+                keyvalues={"event_id": event_id},
+                updatevalues={"json": json.dumps(event)},
+                desc="test",
+            )
+        )
+
+        # Insert and run the background update
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+                    "progress_json": "{}",
+                },
+            )
+        )
+
+        # ... and tell the DataStore that it hasn't finished all updates yet
+        self.store.db_pool.updates._all_done = False
+
+        # Now let's actually drive the updates to completion
+        self.wait_for_background_updates()
+
+        # Make sure the background update filled in the room type
+        room_type_after = self.get_success(
+            self.store.db_pool.simple_select_one_onecol(
+                table="room_stats_state",
+                keyvalues={"room_id": room_id},
+                retcol="room_type",
+                allow_none=True,
+            )
+        )
+        self.assertEqual(room_type_after, RoomTypes.SPACE)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 852b663387..e68126777f 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -86,6 +86,8 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
             event.internal_metadata.is_outlier.return_value = False
             event.depth = stream
 
+            self.store._events_stream_cache.entity_has_changed(room_id, stream)
+
             self.get_success(
                 self.store.db_pool.simple_insert(
                     table="events",
diff --git a/tests/test_state.py b/tests/test_state.py
index b005dd8d0f..7b3f52f68e 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -131,7 +131,9 @@ class _DummyStore:
     async def get_room_version_id(self, room_id):
         return RoomVersions.V1.identifier
 
-    async def get_state_group_for_events(self, event_ids):
+    async def get_state_group_for_events(
+        self, event_ids, await_full_state: bool = True
+    ):
         res = {}
         for event in event_ids:
             res[event] = self._event_to_state_group[event]
diff --git a/tests/utils.py b/tests/utils.py
index cabb2c0dec..aca6a0083b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -64,7 +64,7 @@ def setupdb():
             password=POSTGRES_PASSWORD,
             dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
         )
-        db_conn.autocommit = True
+        db_engine.attempt_to_set_autocommit(db_conn, autocommit=True)
         cur = db_conn.cursor()
         cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
         cur.execute(
@@ -94,7 +94,7 @@ def setupdb():
                 password=POSTGRES_PASSWORD,
                 dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
             )
-            db_conn.autocommit = True
+            db_engine.attempt_to_set_autocommit(db_conn, autocommit=True)
             cur = db_conn.cursor()
             cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
             cur.close()