diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index dcd02932b2..70c6219b2e 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -273,7 +273,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
spam_checker = self.hs.get_spam_checker()
class AllowAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -286,7 +286,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that filters all users.
class BlockAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 73f469b802..f6e7e5fdaa 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -18,30 +18,35 @@ import logging
from io import StringIO
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
+from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin
from tests.unittest import TestCase
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
+ def setUp(self):
+ self.output = StringIO()
+
+ def get_log_line(self):
+ # One log message, with a single trailing newline.
+ data = self.output.getvalue()
+ logs = data.splitlines()
+ self.assertEqual(len(logs), 1)
+ self.assertEqual(data.count("\n"), 1)
+ return json.loads(logs[0])
+
def test_terse_json_output(self):
"""
The Terse JSON formatter converts log messages to JSON.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -57,9 +62,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
Additional information can be included in the structured logging.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
@@ -67,12 +70,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
)
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -96,26 +94,47 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
The Terse JSON formatter converts log messages to JSON.
"""
- output = StringIO()
-
- handler = logging.StreamHandler(output)
+ handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
- # One log message, with a single trailing newline.
- data = output.getvalue()
- logs = data.splitlines()
- self.assertEqual(len(logs), 1)
- self.assertEqual(data.count("\n"), 1)
- log = json.loads(logs[0])
+ log = self.get_log_line()
+
+ # The terse logger should give us these keys.
+ expected_log_keys = [
+ "log",
+ "level",
+ "namespace",
+ ]
+ self.assertCountEqual(log.keys(), expected_log_keys)
+ self.assertEqual(log["log"], "Hello there, wally!")
+
+ def test_with_context(self):
+ """
+ The logging context should be added to the JSON response.
+ """
+ handler = logging.StreamHandler(self.output)
+ handler.setFormatter(JsonFormatter())
+ handler.addFilter(LoggingContextFilter(request=""))
+ logger = self.get_logger(handler)
+
+ with LoggingContext() as context_one:
+ context_one.request = "test"
+ logger.info("Hello there, %s!", "wally")
+
+ log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
"log",
"level",
"namespace",
+ "request",
+ "scope",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
+ self.assertEqual(log["request"], "test")
+ self.assertIsNone(log["scope"])
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 46933a0493..9c100050d2 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1084,6 +1084,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)
+ self.assertIn("joined_local_devices", channel.json_body)
self.assertIn("version", channel.json_body)
self.assertIn("creator", channel.json_body)
self.assertIn("encryption", channel.json_body)
@@ -1096,6 +1097,39 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id_1, channel.json_body["room_id"])
+ def test_single_room_devices(self):
+ """Test that `joined_local_devices` can be requested correctly"""
+ room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(1, channel.json_body["joined_local_devices"])
+
+ # Have another user join the room
+ user_1 = self.register_user("foo", "pass")
+ user_tok_1 = self.login("foo", "pass")
+ self.helper.join(room_id_1, user_1, tok=user_tok_1)
+
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(2, channel.json_body["joined_local_devices"])
+
+ # leave room
+ self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok)
+ self.helper.leave(room_id_1, user_1, tok=user_tok_1)
+ url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
+ request, channel = self.make_request(
+ "GET", url.encode("ascii"), access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual(0, channel.json_body["joined_local_devices"])
+
def test_room_members(self):
"""Test that room members can be requested correctly"""
# Create two test rooms
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index e67de41c18..55d872f0ee 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -26,6 +26,7 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
+from synapse.rest import admin
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
from synapse.types import JsonDict, RoomAlias, UserID
@@ -625,6 +626,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
servlets = [
+ admin.register_servlets,
profile.register_servlets,
room.register_servlets,
]
@@ -703,6 +705,20 @@ class RoomJoinRatelimitTestCase(RoomBase):
request, channel = self.make_request("POST", path % room_id, {})
self.assertEquals(channel.code, 200)
+ @unittest.override_config(
+ {
+ "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}},
+ "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"],
+ "autocreate_auto_join_rooms": True,
+ },
+ )
+ def test_autojoin_rooms(self):
+ user_id = self.register_user("testuser", "password")
+
+ # Check that the new user successfully joined the four rooms
+ rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
+ self.assertEqual(len(rooms), 4)
+
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index ecb00f4e02..dabc1c5f09 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -80,6 +80,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_count_devices_by_users(self):
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device1", "display_name 1")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device2", "display_name 2")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id2", "device3", "display_name 3")
+ )
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users())
+ self.assertEqual(0, res)
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
+ self.assertEqual(0, res)
+
+ res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
+ self.assertEqual(2, res)
+
+ res = yield defer.ensureDeferred(
+ self.store.count_devices_by_users(["user_id", "user_id2"])
+ )
+ self.assertEqual(3, res)
+
+ @defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]
|