diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index 6f300b8e11..5569ccef8a 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -23,6 +23,7 @@ from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersions
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
+from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -296,3 +297,58 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0][0]["user_agent"], "user_agent")
self.assertGreater(args[0][0]["last_seen"], 0)
self.assertNotIn("access_token", args[0][0])
+
+ def test_account_data(self) -> None:
+ """Tests that user account data get exported."""
+ # add account data
+ self.get_success(
+ self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1})
+ )
+ self.get_success(
+ self._store.add_account_data_to_room(
+ self.user2, "test_room", "m.per_room", {"b": 2}
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ # two calls, one call for user data and one call for room data
+ writer.write_account_data.assert_called()
+
+ args = writer.write_account_data.call_args_list[0][0]
+ self.assertEqual(args[0], "global")
+ self.assertEqual(args[1]["m.global"]["a"], 1)
+
+ args = writer.write_account_data.call_args_list[1][0]
+ self.assertEqual(args[0], "test_room")
+ self.assertEqual(args[1]["m.per_room"]["b"], 2)
+
+ def test_media_ids(self) -> None:
+ """Tests that media's metadata get exported."""
+
+ self.get_success(
+ self._store.store_local_media(
+ media_id="media_1",
+ media_type="image/png",
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ media_length=50,
+ user_id=UserID.from_string(self.user2),
+ )
+ )
+
+ writer = Mock()
+
+ self.get_success(self.admin_handler.export_user_data(self.user2, writer))
+
+ writer.write_media_id.assert_called_once()
+
+ args = writer.write_media_id.call_args[0]
+ self.assertEqual(args[0], "media_1")
+ self.assertEqual(args[1]["media_id"], "media_1")
+ self.assertEqual(args[1]["media_length"], 50)
+ self.assertGreater(args[1]["created_ts"], 0)
+ self.assertIsNone(args[1]["upload_name"])
+ self.assertIsNone(args[1]["last_access_ts"])
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 179c96adf2..2e838e6572 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@@ -81,7 +81,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def _create_duplicate_event(
self, txn_id: str
- ) -> Tuple[EventBase, EventContext, Optional[dict]]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
@@ -109,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random"
- event1, context, _ = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context, _ = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success(
self.handler.handle_new_client_event(
@@ -121,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id)
- event2, context, _ = self._create_duplicate_event(txn_id)
+ event2, unpersisted_context, _ = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works,
# so we want to make sure we test with different events.
@@ -142,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right
# thing.
- event3, context, _ = self._create_duplicate_event(txn_id)
+ event3, unpersisted_context, _ = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event3))
+
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
@@ -156,7 +160,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right
# thing.
- event4, context, _ = self._create_duplicate_event(txn_id)
+ event4, unpersisted_context, _ = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event4))
+
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
@@ -176,8 +182,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
- event1, context1, _ = self._create_duplicate_event(txn_id)
- event2, context2, _ = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context1, _ = self._create_duplicate_event(txn_id)
+ context1 = self.get_success(unpersisted_context1.persist(event1))
+ event2, unpersisted_context2, _ = self._create_duplicate_event(txn_id)
+ context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with
self.assertNotEqual(event1.event_id, event2.event_id)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index ad42a7183e..161ff0a6c1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -62,7 +62,7 @@ class TestSpamChecker:
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class DenyAll(TestSpamChecker):
@@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker:
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
- pass
+ return RegistrationBehaviour.ALLOW
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
@@ -507,7 +507,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter)
- event, context, _ = self.get_success(
+
+ event, unpersisted_context, _ = self.get_success(
event_creation_handler.create_event(
requester,
{
@@ -519,6 +520,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_creation_handler.handle_new_client_event(
requester, events_and_context=[(event, context)]
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
index 137deab138..d6f43a98fc 100644
--- a/tests/handlers/test_sso.py
+++ b/tests/handlers/test_sso.py
@@ -113,7 +113,6 @@ async def mock_get_file(
headers: Optional[RawHeaders] = None,
is_allowed_content_type: Optional[Callable[[str], bool]] = None,
) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
-
fake_response = FakeResponse(code=404)
if url == "http://my.server/me.png":
fake_response = FakeResponse(
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index f1a50c5bcb..d11ded6c5b 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -31,7 +31,6 @@ EXPT_NUM_STATE_EVTS_IN_FRESH_PRIVATE_ROOM = 6
class StatsRoomTests(unittest.HomeserverTestCase):
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9bb9c1afc6..c5746005b5 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -192,6 +192,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.helper.join(room, self.appservice.sender, tok=self.appservice.token)
self._check_only_one_user_in_directory(user, room)
+ def test_search_term_with_colon_in_it_does_not_raise(self) -> None:
+ """
+ Regression test: Test that search terms with colons in them are acceptable.
+ """
+ u1 = self.register_user("user1", "pass")
+ self.get_success(self.handler.search_users(u1, "haha:paamayim-nekudotayim", 10))
+
def test_user_not_in_users_table(self) -> None:
"""Unclear how it happens, but on matrix.org we've seen join events
for users who aren't in the users table. Test that we don't fall over
|