summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2021-08-31 11:40:27 +0100
committerBrendan Abolivier <babolivier@matrix.org>2021-08-31 11:40:27 +0100
commit9de3991b9e83dacea66db18c56b63fd0f988d7f7 (patch)
tree6106408e1ef29343ec0079553e226a8ceed74c7b /tests
parentFix formatting (diff)
parentUpdate v1.32.0 changelog. It's m.login.application_service, not plural (diff)
downloadsynapse-9de3991b9e83dacea66db18c56b63fd0f988d7f7.tar.xz
Merge tag 'v1.32.0' into babolivier/dinsic_1.41.0
Synapse 1.32.0 (2021-04-20)
===========================

**Note:** This release requires Python 3.6+ and Postgres 9.6+ or SQLite 3.22+.

This release removes the deprecated `GET /_synapse/admin/v1/users/<user_id>` admin API. Please use the [v2 API](https://github.com/matrix-org/synapse/blob/develop/docs/admin_api/user_admin_api.rst#query-user-account) instead, which has improved capabilities.

This release requires Application Services to use type `m.login.application_service` when registering users via the `/_matrix/client/r0/register` endpoint to comply with the spec. Please ensure your Application Services are up to date.

Bugfixes
--------

- Fix the log lines of nested logging contexts. Broke in 1.32.0rc1. ([\#9829](https://github.com/matrix-org/synapse/issues/9829))

Synapse 1.32.0rc1 (2021-04-13)
==============================

Features
--------

- Add a Synapse module for routing presence updates between users. ([\#9491](https://github.com/matrix-org/synapse/issues/9491))
- Add an admin API to manage ratelimit for a specific user. ([\#9648](https://github.com/matrix-org/synapse/issues/9648))
- Include request information in structured logging output. ([\#9654](https://github.com/matrix-org/synapse/issues/9654))
- Add `order_by` to the admin API `GET /_synapse/admin/v2/users`. Contributed by @dklimpel. ([\#9691](https://github.com/matrix-org/synapse/issues/9691))
- Replace the `room_invite_state_types` configuration setting with `room_prejoin_state`. ([\#9700](https://github.com/matrix-org/synapse/issues/9700))
- Add experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership. ([\#9717](https://github.com/matrix-org/synapse/issues/9717), [\#9735](https://github.com/matrix-org/synapse/issues/9735))
- Update experimental support for Spaces: include `m.room.create` in the room state sent with room-invites. ([\#9710](https://github.com/matrix-org/synapse/issues/9710))
- Synapse now requires Python 3.6 or later. It also requires Postgres 9.6 or later or SQLite 3.22 or later. ([\#9766](https://github.com/matrix-org/synapse/issues/9766))

Bugfixes
--------

- Prevent `synapse_forward_extremities` and `synapse_excess_extremity_events` Prometheus metrics from initially reporting zero-values after startup. ([\#8926](https://github.com/matrix-org/synapse/issues/8926))
- Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. ([\#9711](https://github.com/matrix-org/synapse/issues/9711))
- Fix longstanding bug which caused `duplicate key value violates unique constraint "remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"` errors. ([\#9725](https://github.com/matrix-org/synapse/issues/9725))
- Fix bug where sharded federation senders could get stuck repeatedly querying the DB in a loop, using lots of CPU. ([\#9770](https://github.com/matrix-org/synapse/issues/9770))
- Fix duplicate logging of exceptions thrown during federation transaction processing. ([\#9780](https://github.com/matrix-org/synapse/issues/9780))

Updates to the Docker image
---------------------------

- Move opencontainers labels to the final Docker image such that users can inspect them. ([\#9765](https://github.com/matrix-org/synapse/issues/9765))

Improved Documentation
----------------------

- Make the `allowed_local_3pids` regex example in the sample config stricter. ([\#9719](https://github.com/matrix-org/synapse/issues/9719))

Deprecations and Removals
-------------------------

- Remove old admin API `GET /_synapse/admin/v1/users/<user_id>`. ([\#9401](https://github.com/matrix-org/synapse/issues/9401))
- Make `/_matrix/client/r0/register` expect a type of `m.login.application_service` when an Application Service registers a user, to align with [the relevant spec](https://spec.matrix.org/unstable/application-service-api/#server-admin-style-permissions). ([\#9548](https://github.com/matrix-org/synapse/issues/9548))

Internal Changes
----------------

- Replace deprecated `imp` module with successor `importlib`. Contributed by Cristina Muñoz. ([\#9718](https://github.com/matrix-org/synapse/issues/9718))
- Experiment with GitHub Actions for CI. ([\#9661](https://github.com/matrix-org/synapse/issues/9661))
- Introduce flake8-bugbear to the test suite and fix some of its lint violations. ([\#9682](https://github.com/matrix-org/synapse/issues/9682))
- Update `scripts-dev/complement.sh` to use a local checkout of Complement, allow running a subset of tests and have it use Synapse's Complement test blacklist. ([\#9685](https://github.com/matrix-org/synapse/issues/9685))
- Improve Jaeger tracing for `to_device` messages. ([\#9686](https://github.com/matrix-org/synapse/issues/9686))
- Add release helper script for automating part of the Synapse release process. ([\#9713](https://github.com/matrix-org/synapse/issues/9713))
- Add type hints to expiring cache. ([\#9730](https://github.com/matrix-org/synapse/issues/9730))
- Convert various testcases to `HomeserverTestCase`. ([\#9736](https://github.com/matrix-org/synapse/issues/9736))
- Start linting mypy with `no_implicit_optional`. ([\#9742](https://github.com/matrix-org/synapse/issues/9742))
- Add missing type hints to federation handler and server. ([\#9743](https://github.com/matrix-org/synapse/issues/9743))
- Check that a `ConfigError` is raised, rather than simply `Exception`, when appropriate in homeserver config file generation tests. ([\#9753](https://github.com/matrix-org/synapse/issues/9753))
- Fix incompatibility with `tox` 2.5. ([\#9769](https://github.com/matrix-org/synapse/issues/9769))
- Enable Complement tests for [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946): Spaces Summary API. ([\#9771](https://github.com/matrix-org/synapse/issues/9771))
- Use mock from the standard library instead of a separate package. ([\#9772](https://github.com/matrix-org/synapse/issues/9772))
- Update Black configuration to target Python 3.6. ([\#9781](https://github.com/matrix-org/synapse/issues/9781))
- Add option to skip unit tests when building Debian packages. ([\#9793](https://github.com/matrix-org/synapse/issues/9793))
Diffstat (limited to 'tests')
-rw-r--r--tests/api/test_auth.py2
-rw-r--r--tests/api/test_ratelimiting.py168
-rw-r--r--tests/app/test_openid_listener.py2
-rw-r--r--tests/appservice/test_appservice.py3
-rw-r--r--tests/appservice/test_scheduler.py2
-rw-r--r--tests/config/test_load.py5
-rw-r--r--tests/crypto/test_keyring.py26
-rw-r--r--tests/events/test_presence_router.py386
-rw-r--r--tests/federation/test_complexity.py2
-rw-r--r--tests/federation/test_federation_catch_up.py3
-rw-r--r--tests/federation/test_federation_sender.py3
-rw-r--r--tests/handlers/test_admin.py3
-rw-r--r--tests/handlers/test_appservice.py2
-rw-r--r--tests/handlers/test_auth.py2
-rw-r--r--tests/handlers/test_cas.py2
-rw-r--r--tests/handlers/test_directory.py2
-rw-r--r--tests/handlers/test_e2e_keys.py2
-rw-r--r--tests/handlers/test_e2e_room_keys.py3
-rw-r--r--tests/handlers/test_oidc.py3
-rw-r--r--tests/handlers/test_password_providers.py3
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py2
-rw-r--r--tests/handlers/test_register.py2
-rw-r--r--tests/handlers/test_saml.py3
-rw-r--r--tests/handlers/test_sync.py21
-rw-r--r--tests/handlers/test_typing.py3
-rw-r--r--tests/handlers/test_user_directory.py2
-rw-r--r--tests/http/federation/test_matrix_federation_agent.py18
-rw-r--r--tests/http/federation/test_srv_resolver.py2
-rw-r--r--tests/http/test_client.py3
-rw-r--r--tests/http/test_fedclient.py2
-rw-r--r--tests/http/test_servlet.py3
-rw-r--r--tests/http/test_simple_client.py2
-rw-r--r--tests/logging/test_terse_json.py71
-rw-r--r--tests/module_api/test_api.py178
-rw-r--r--tests/push/test_http.py2
-rw-r--r--tests/replication/_base.py4
-rw-r--r--tests/replication/slave/storage/_base.py2
-rw-r--r--tests/replication/slave/storage/test_events.py18
-rw-r--r--tests/replication/tcp/streams/test_receipts.py2
-rw-r--r--tests/replication/tcp/streams/test_typing.py2
-rw-r--r--tests/replication/test_federation_ack.py2
-rw-r--r--tests/replication/test_federation_sender_shard.py3
-rw-r--r--tests/replication/test_pusher_shard.py3
-rw-r--r--tests/replication/test_sharded_event_persister.py3
-rw-r--r--tests/rest/admin/test_admin.py3
-rw-r--r--tests/rest/admin/test_room.py3
-rw-r--r--tests/rest/admin/test_user.py408
-rw-r--r--tests/rest/client/test_retention.py2
-rw-r--r--tests/rest/client/test_shadow_banned.py2
-rw-r--r--tests/rest/client/test_third_party_rules.py3
-rw-r--r--tests/rest/client/test_transactions.py2
-rw-r--r--tests/rest/client/v1/test_events.py2
-rw-r--r--tests/rest/client/v1/test_login.py3
-rw-r--r--tests/rest/client/v1/test_presence.py2
-rw-r--r--tests/rest/client/v1/test_rooms.py8
-rw-r--r--tests/rest/client/v1/test_typing.py2
-rw-r--r--tests/rest/client/v1/utils.py17
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py7
-rw-r--r--tests/rest/client/v2_alpha/test_register.py31
-rw-r--r--tests/rest/client/v2_alpha/test_relations.py5
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py3
-rw-r--r--tests/rest/media/v1/test_media_storage.py3
-rw-r--r--tests/rest/media/v1/test_url_preview.py3
-rw-r--r--tests/scripts/test_new_matrix_user.py2
-rw-r--r--tests/server_notices/test_resource_limits_server_notices.py2
-rw-r--r--tests/storage/test_appservice.py3
-rw-r--r--tests/storage/test_background_update.py2
-rw-r--r--tests/storage/test_base.py3
-rw-r--r--tests/storage/test_cleanup_extrems.py4
-rw-r--r--tests/storage/test_client_ips.py6
-rw-r--r--tests/storage/test_database.py13
-rw-r--r--tests/storage/test_devices.py80
-rw-r--r--tests/storage/test_directory.py44
-rw-r--r--tests/storage/test_end_to_end_keys.py59
-rw-r--r--tests/storage/test_event_push_actions.py135
-rw-r--r--tests/storage/test_id_generators.py14
-rw-r--r--tests/storage/test_monthly_active_users.py2
-rw-r--r--tests/storage/test_profile.py41
-rw-r--r--tests/storage/test_redaction.py22
-rw-r--r--tests/storage/test_registration.py108
-rw-r--r--tests/storage/test_room.py61
-rw-r--r--tests/storage/test_state.py145
-rw-r--r--tests/storage/test_user_directory.py86
-rw-r--r--tests/test_distributor.py2
-rw-r--r--tests/test_event_auth.py246
-rw-r--r--tests/test_federation.py4
-rw-r--r--tests/test_mau.py23
-rw-r--r--tests/test_phone_home.py3
-rw-r--r--tests/test_state.py10
-rw-r--r--tests/test_terms_auth.py3
-rw-r--r--tests/test_utils/__init__.py3
-rw-r--r--tests/test_utils/event_injection.py6
-rw-r--r--tests/test_visibility.py10
-rw-r--r--tests/unittest.py5
-rw-r--r--tests/util/caches/test_descriptors.py16
-rw-r--r--tests/util/caches/test_ttlcache.py2
-rw-r--r--tests/util/test_file_consumer.py3
-rw-r--r--tests/util/test_logcontext.py35
-rw-r--r--tests/util/test_lrucache.py2
-rw-r--r--tests/util/test_ratelimitutils.py6
-rw-r--r--tests/utils.py6
102 files changed, 1887 insertions, 818 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py

index 34f72ae795..28d77f0ca2 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock import pymacaroons diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester from tests import unittest -class TestRatelimiter(unittest.TestCase): +class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) - self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) - self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) - self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) - - def test_allowed_user_via_can_requester_do_action(self): - user_requester = create_requester("@user:example.com") - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) def test_allowed_via_ratelimit(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=0) + self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0)) # Should raise with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key="test_id", _time_now_s=5) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=5) + ) self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=10) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=10) + ) def test_allowed_via_can_do_action_and_overriding_parameters(self): """Test that we can override options of can_do_action that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=0, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=0, + ) ) self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=1, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=1, + ) ) self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, rate_hz=10.0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0) ) self.assertTrue(allowed) self.assertEqual(1.1, time_allowed) # Similarly if we allow a burst of 10 actions - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, burst_count=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10) ) self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) @@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase): fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - limiter.ratelimit(key=("test_id",), _time_now_s=0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=0) + ) # Second attempt, 1s later, will fail with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1) + ) self.assertEqual(context.exception.retry_after_ms, 9000) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0) + ) # Similarly if we allow a burst of 10 actions - limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) + ) def test_pruning(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - limiter.can_do_action(key="test_id_1", _time_now_s=0) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_1", _time_now_s=0) + ) self.assertIn("test_id_1", limiter.actions) - limiter.can_do_action(key="test_id_2", _time_now_s=10) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_2", _time_now_s=10) + ) self.assertNotIn("test_id_1", limiter.actions) + + def test_db_user_override(self): + """Test that users that have ratelimiting disabled in the DB aren't + ratelimited. + """ + store = self.hs.get_datastore() + + user_id = "@user:test" + requester = create_requester(user_id) + + self.get_success( + store.db_pool.simple_insert( + table="ratelimit_override", + values={ + "user_id": user_id, + "messages_per_second": None, + "burst_count": None, + }, + desc="test_db_user_override", + ) + ) + + limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + for _ in range(20): + self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index 467033e201..33a37fe35e 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, patch +from unittest.mock import Mock, patch from parameterized import parameterized diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py
index 0bffeb1150..03a7440eec 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py
@@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re - -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py
index 97f8cad0dd..3c27d797fb 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index 734a9983e8..c109425671 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py
@@ -20,6 +20,7 @@ from io import StringIO import yaml +from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig from tests import unittest @@ -35,9 +36,9 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_load_fails_if_server_name_missing(self): self.generate_config_and_remove_lines_containing("server_name") - with self.assertRaises(Exception): + with self.assertRaises(ConfigError): HomeServerConfig.load_config("", ["-c", self.file]) - with self.assertRaises(Exception): + with self.assertRaises(ConfigError): HomeServerConfig.load_or_generate_config("", ["-c", self.file]) def test_generates_and_loads_macaroon_secret_key(self): diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 30fcc4c1bf..a56063315b 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py
@@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from unittest.mock import Mock -from mock import Mock - +import attr import canonicaljson import signedjson.key import signedjson.sign @@ -68,6 +68,11 @@ class MockPerspectiveServer: signedjson.sign.sign_json(res, self.server_name, self.key) +@attr.s(slots=True) +class FakeRequest: + id = attr.ib() + + @logcontext_clean class KeyringTestCase(unittest.HomeserverTestCase): def check_context(self, val, expected): @@ -89,7 +94,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): first_lookup_deferred = Deferred() async def first_lookup_fetch(keys_to_fetch): - self.assertEquals(current_context().request, "context_11") + self.assertEquals(current_context().request.id, "context_11") self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) await make_deferred_yieldable(first_lookup_deferred) @@ -102,9 +107,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): mock_fetcher.get_keys.side_effect = first_lookup_fetch async def first_lookup(): - with LoggingContext("context_11") as context_11: - context_11.request = "context_11" - + with LoggingContext("context_11", request=FakeRequest("context_11")): res_deferreds = kr.verify_json_objects_for_server( [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] ) @@ -130,7 +133,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # should block rather than start a second call async def second_lookup_fetch(keys_to_fetch): - self.assertEquals(current_context().request, "context_12") + self.assertEquals(current_context().request.id, "context_12") return { "server10": { get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) @@ -142,9 +145,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): second_lookup_state = [0] async def second_lookup(): - with LoggingContext("context_12") as context_12: - context_12.request = "context_12" - + with LoggingContext("context_12", request=FakeRequest("context_12")): res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1, 0, "test")] ) @@ -589,10 +590,7 @@ def get_key_id(key): @defer.inlineCallbacks def run_in_context(f, *args, **kwargs): - with LoggingContext("testctx") as ctx: - # we set the "request" prop to make it easier to follow what's going on in the - # logs. - ctx.request = "testctx" + with LoggingContext("testctx"): rv = yield f(*args, **kwargs) return rv diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py new file mode 100644
index 0000000000..c996ecc221 --- /dev/null +++ b/tests/events/test_presence_router.py
@@ -0,0 +1,386 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from unittest.mock import Mock + +import attr + +from synapse.api.constants import EduTypes +from synapse.events.presence_router import PresenceRouter +from synapse.federation.units import Transaction +from synapse.handlers.presence import UserPresenceState +from synapse.module_api import ModuleApi +from synapse.rest import admin +from synapse.rest.client.v1 import login, presence, room +from synapse.types import JsonDict, StreamToken, create_requester + +from tests.handlers.test_sync import generate_sync_config +from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config + + +@attr.s +class PresenceRouterTestConfig: + users_who_should_receive_all_presence = attr.ib(type=List[str], default=[]) + + +class PresenceRouterTestModule: + def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi): + self._config = config + self._module_api = module_api + + async def get_users_for_states( + self, state_updates: Iterable[UserPresenceState] + ) -> Dict[str, Set[UserPresenceState]]: + users_to_state = { + user_id: set(state_updates) + for user_id in self._config.users_who_should_receive_all_presence + } + return users_to_state + + async def get_interested_users( + self, user_id: str + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + if user_id in self._config.users_who_should_receive_all_presence: + return PresenceRouter.ALL_USERS + + return set() + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterTestConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. + """ + # Initialise a typed config object + config = PresenceRouterTestConfig() + + config.users_who_should_receive_all_presence = config_dict.get( + "users_who_should_receive_all_presence" + ) + + return config + + +class PresenceRouterTestCase(FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + presence.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) + + def prepare(self, reactor, clock, homeserver): + self.sync_handler = self.hs.get_sync_handler() + self.module_api = homeserver.get_module_api() + + @override_config( + { + "presence": { + "presence_router": { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler:test", + ] + }, + } + }, + "send_federation": True, + } + ) + def test_receiving_all_presence(self): + """Test that a user that does not share a room with another other can receive + presence for them, due to presence routing. + """ + # Create a user who should receive all presence of others + self.presence_receiving_user_id = self.register_user( + "presence_gobbler", "monkey" + ) + self.presence_receiving_user_tok = self.login("presence_gobbler", "monkey") + + # And two users who should not have any special routing + self.other_user_one_id = self.register_user("other_user_one", "monkey") + self.other_user_one_tok = self.login("other_user_one", "monkey") + self.other_user_two_id = self.register_user("other_user_two", "monkey") + self.other_user_two_tok = self.login("other_user_two", "monkey") + + # Put the other two users in a room with each other + room_id = self.helper.create_room_as( + self.other_user_one_id, tok=self.other_user_one_tok + ) + + self.helper.invite( + room_id, + self.other_user_one_id, + self.other_user_two_id, + tok=self.other_user_one_tok, + ) + self.helper.join(room_id, self.other_user_two_id, tok=self.other_user_two_tok) + # User one sends some presence + send_presence_update( + self, + self.other_user_one_id, + self.other_user_one_tok, + "online", + "boop", + ) + + # Check that the presence receiving user gets user one's presence when syncing + presence_updates, sync_token = sync_presence( + self, self.presence_receiving_user_id + ) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.other_user_one_id) + self.assertEqual(presence_update.state, "online") + self.assertEqual(presence_update.status_msg, "boop") + + # Have all three users send presence + send_presence_update( + self, + self.other_user_one_id, + self.other_user_one_tok, + "online", + "user_one", + ) + send_presence_update( + self, + self.other_user_two_id, + self.other_user_two_tok, + "online", + "user_two", + ) + send_presence_update( + self, + self.presence_receiving_user_id, + self.presence_receiving_user_tok, + "online", + "presence_gobbler", + ) + + # Check that the presence receiving user gets everyone's presence + presence_updates, _ = sync_presence( + self, self.presence_receiving_user_id, sync_token + ) + self.assertEqual(len(presence_updates), 3) + + # But that User One only get itself and User Two's presence + presence_updates, _ = sync_presence(self, self.other_user_one_id) + self.assertEqual(len(presence_updates), 2) + + found = False + for update in presence_updates: + if update.user_id == self.other_user_two_id: + self.assertEqual(update.state, "online") + self.assertEqual(update.status_msg, "user_two") + found = True + + self.assertTrue(found) + + @override_config( + { + "presence": { + "presence_router": { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler1:test", + "@presence_gobbler2:test", + "@far_away_person:island", + ] + }, + } + }, + "send_federation": True, + } + ) + def test_send_local_online_presence_to_with_module(self): + """Tests that send_local_presence_to_users sends local online presence to a set + of specified local and remote users, with a custom PresenceRouter module enabled. + """ + # Create a user who will send presence updates + self.other_user_id = self.register_user("other_user", "monkey") + self.other_user_tok = self.login("other_user", "monkey") + + # And another two users that will also send out presence updates, as well as receive + # theirs and everyone else's + self.presence_receiving_user_one_id = self.register_user( + "presence_gobbler1", "monkey" + ) + self.presence_receiving_user_one_tok = self.login("presence_gobbler1", "monkey") + self.presence_receiving_user_two_id = self.register_user( + "presence_gobbler2", "monkey" + ) + self.presence_receiving_user_two_tok = self.login("presence_gobbler2", "monkey") + + # Have all three users send some presence updates + send_presence_update( + self, + self.other_user_id, + self.other_user_tok, + "online", + "I'm online!", + ) + send_presence_update( + self, + self.presence_receiving_user_one_id, + self.presence_receiving_user_one_tok, + "online", + "I'm also online!", + ) + send_presence_update( + self, + self.presence_receiving_user_two_id, + self.presence_receiving_user_two_tok, + "unavailable", + "I'm in a meeting!", + ) + + # Mark each presence-receiving user for receiving all user presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiving_user_one_id, + self.presence_receiving_user_two_id, + ] + ) + ) + + # Perform a sync for each user + + # The other user should only receive their own presence + presence_updates, _ = sync_presence(self, self.other_user_id) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.other_user_id) + self.assertEqual(presence_update.state, "online") + self.assertEqual(presence_update.status_msg, "I'm online!") + + # Whereas both presence receiving users should receive everyone's presence updates + presence_updates, _ = sync_presence(self, self.presence_receiving_user_one_id) + self.assertEqual(len(presence_updates), 3) + presence_updates, _ = sync_presence(self, self.presence_receiving_user_two_id) + self.assertEqual(len(presence_updates), 3) + + # Test that sending to a remote user works + remote_user_id = "@far_away_person:island" + + # Note that due to the remote user being in our module's + # users_who_should_receive_all_presence config, they would have + # received user presence updates already. + # + # Thus we reset the mock, and try sending all online local user + # presence again + self.hs.get_federation_transport_client().send_transaction.reset_mock() + + # Broadcast local user online presence + self.get_success( + self.module_api.send_local_online_presence_to([remote_user_id]) + ) + + # Check that the expected presence updates were sent + expected_users = [ + self.other_user_id, + self.presence_receiving_user_one_id, + self.presence_receiving_user_two_id, + ] + + calls = ( + self.hs.get_federation_transport_client().send_transaction.call_args_list + ) + for call in calls: + call_args = call[0] + federation_transaction = call_args[0] # type: Transaction + + # Get the sent EDUs in this transaction + edus = federation_transaction.get_dict()["edus"] + + for edu in edus: + # Make sure we're only checking presence-type EDUs + if edu["edu_type"] != EduTypes.Presence: + continue + + # EDUs can contain multiple presence updates + for presence_update in edu["content"]["push"]: + # Check for presence updates that contain the user IDs we're after + expected_users.remove(presence_update["user_id"]) + + # Ensure that no offline states are being sent out + self.assertNotEqual(presence_update["presence"], "offline") + + self.assertEqual(len(expected_users), 0) + + +def send_presence_update( + testcase: TestCase, + user_id: str, + access_token: str, + presence_state: str, + status_message: Optional[str] = None, +) -> JsonDict: + # Build the presence body + body = {"presence": presence_state} + if status_message: + body["status_msg"] = status_message + + # Update the user's presence state + channel = testcase.make_request( + "PUT", "/presence/%s/status" % (user_id,), body, access_token=access_token + ) + testcase.assertEqual(channel.code, 200) + + return channel.json_body + + +def sync_presence( + testcase: TestCase, + user_id: str, + since_token: Optional[StreamToken] = None, +) -> Tuple[List[UserPresenceState], StreamToken]: + """Perform a sync request for the given user and return the user presence updates + they've received, as well as the next_batch token. + + This method assumes testcase.sync_handler points to the homeserver's sync handler. + + Args: + testcase: The testcase that is currently being run. + user_id: The ID of the user to generate a sync response for. + since_token: An optional token to indicate from at what point to sync from. + + Returns: + A tuple containing a list of presence updates, and the sync response's + next_batch token. + """ + requester = create_requester(user_id) + sync_config = generate_sync_config(requester.user.to_string()) + sync_result = testcase.get_success( + testcase.sync_handler.wait_for_sync_for_user( + requester, sync_config, since_token + ) + ) + + return sync_result.presence, sync_result.next_batch diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 8186b8ca01..701fa8379f 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError from synapse.rest import admin diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 95eac6a5a3..802c5ad299 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py
@@ -1,6 +1,5 @@ from typing import List, Tuple - -from mock import Mock +from unittest.mock import Mock from synapse.api.constants import EventTypes from synapse.events import EventBase diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ecc3faa572..deb12433cf 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py
@@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional - -from mock import Mock +from unittest.mock import Mock from signedjson import key, sign from signedjson.types import BaseKey, SigningKey diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index a01fdd0839..32669ae9ce 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py
@@ -14,8 +14,7 @@ # limitations under the License. from collections import Counter - -from mock import Mock +from unittest.mock import Mock import synapse.api.errors import synapse.handlers.admin diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index d5d3fdd99a..6e325b24ce 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index c9f889b511..321c5ba045 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock import pymacaroons diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 7975af243c..0444b26798 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py
@@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from synapse.handlers.cas_handler import CasResponse diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index fadec16e13..a8d0cf6603 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py
@@ -14,7 +14,7 @@ # limitations under the License. -from mock import Mock +from unittest.mock import Mock import synapse import synapse.api.errors diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 5e86c5e56b..6915ac0205 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py
@@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import mock +from unittest import mock from signedjson import key as key, sign as sign diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index d7498aa51a..07893302ec 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py
@@ -16,8 +16,7 @@ # limitations under the License. import copy - -import mock +from unittest import mock from synapse.api.errors import SynapseError diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index c7796fb837..8702ee70e0 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -14,10 +14,9 @@ # limitations under the License. import json import os +from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse -from mock import ANY, Mock, patch - import pymacaroons from synapse.handlers.sso import MappingException diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index a98a65ae67..e28e4159eb 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py
@@ -16,8 +16,7 @@ """Tests for the password_auth_provider interface""" from typing import Any, Type, Union - -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 77330f59a9..9f16cc65fc 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py
@@ -14,7 +14,7 @@ # limitations under the License. -from mock import Mock, call +from unittest.mock import Mock, call from signedjson.key import generate_signing_key diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index cbbe7280c7..60f2458c98 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock import synapse.types from synapse.api.errors import AuthError, SynapseError diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 00a0bc5274..c30b414d99 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from synapse.api.auth import Auth from synapse.api.constants import UserTypes diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 30efd43b40..8cfc184fef 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py
@@ -13,8 +13,7 @@ # limitations under the License. from typing import Optional - -from mock import Mock +from unittest.mock import Mock import attr diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e62586142e..8e950f25c5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py
@@ -37,7 +37,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:test" user_id2 = "@user2:test" - sync_config = self._generate_sync_config(user_id1) + sync_config = generate_sync_config(user_id1) requester = create_requester(user_id1) self.reactor.advance(100) # So we get not 0 time @@ -60,7 +60,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = False - sync_config = self._generate_sync_config(user_id2) + sync_config = generate_sync_config(user_id2) requester = create_requester(user_id2) e = self.get_failure( @@ -69,11 +69,12 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) - def _generate_sync_config(self, user_id): - return SyncConfig( - user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), - filter_collection=DEFAULT_FILTER_COLLECTION, - is_guest=False, - request_key="request_key", - device_id="device_id", - ) + +def generate_sync_config(user_id: str) -> SyncConfig: + return SyncConfig( + user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), + filter_collection=DEFAULT_FILTER_COLLECTION, + is_guest=False, + request_key="request_key", + device_id="device_id", + ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 24e7138196..9fa231a37a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py
@@ -16,8 +16,7 @@ import json from typing import Dict - -from mock import ANY, Mock, call +from unittest.mock import ANY, Mock, call from twisted.internet import defer from twisted.web.resource import Resource diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index dbe68bb058..67a8e49945 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 3972abb038..e6b20799e5 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from mock import Mock +from typing import Optional +from unittest.mock import Mock import treq from netaddr import IPSet @@ -180,7 +180,11 @@ class MatrixFederationAgentTests(unittest.TestCase): _check_logcontext(context) def _handle_well_known_connection( - self, client_factory, expected_sni, content, response_headers={} + self, + client_factory, + expected_sni, + content, + response_headers: Optional[dict] = None, ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -202,10 +206,12 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual( request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] ) - self._send_well_known_response(request, content, headers=response_headers) + self._send_well_known_response(request, content, headers=response_headers or {}) return well_known_server - def _send_well_known_response(self, request, content, headers={}): + def _send_well_known_response( + self, request, content, headers: Optional[dict] = None + ): """Check that an incoming request looks like a valid .well-known request, and send back the response. """ @@ -213,7 +219,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(request.path, b"/.well-known/matrix/server") self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) # send back a response - for k, v in headers.items(): + for k, v in (headers or {}).items(): request.setHeader(k, v) request.write(content) request.finish() diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py
index fee2985d35..466ce722d9 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer from twisted.internet.defer import Deferred diff --git a/tests/http/test_client.py b/tests/http/test_client.py
index 0ce181a51e..7e2f2a01cc 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py
@@ -13,8 +13,7 @@ # limitations under the License. from io import BytesIO - -from mock import Mock +from unittest.mock import Mock from netaddr import IPSet diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py
index 9c52c8fdca..21c1297171 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from netaddr import IPSet from parameterized import parameterized diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index 45089158ce..f979c96f7c 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py
@@ -14,8 +14,7 @@ # limitations under the License. import json from io import BytesIO - -from mock import Mock +from unittest.mock import Mock from synapse.api.errors import SynapseError from synapse.http.servlet import ( diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py
index a1cf0862d4..cc4cae320d 100644 --- a/tests/http/test_simple_client.py +++ b/tests/http/test_simple_client.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from netaddr import IPSet diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 48a74e2eee..ecf873e2ab 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py
@@ -12,15 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import logging -from io import StringIO +from io import BytesIO, StringIO +from unittest.mock import Mock, patch + +from twisted.web.server import Request +from synapse.http.site import SynapseRequest from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter from synapse.logging.context import LoggingContext, LoggingContextFilter from tests.logging import LoggerCleanupMixin +from tests.server import FakeChannel from tests.unittest import TestCase @@ -120,7 +124,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): handler.addFilter(LoggingContextFilter()) logger = self.get_logger(handler) - with LoggingContext(request="test"): + with LoggingContext("name"): logger.info("Hello there, %s!", "wally") log = self.get_log_line() @@ -134,4 +138,63 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): ] self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") - self.assertEqual(log["request"], "test") + self.assertEqual(log["request"], "name") + + def test_with_request_context(self): + """ + Information from the logging context request should be added to the JSON response. + """ + handler = logging.StreamHandler(self.output) + handler.setFormatter(JsonFormatter()) + handler.addFilter(LoggingContextFilter()) + logger = self.get_logger(handler) + + # A full request isn't needed here. + site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"]) + site.site_tag = "test-site" + site.server_version_string = "Server v1" + request = SynapseRequest(FakeChannel(site, None)) + # Call requestReceived to finish instantiating the object. + request.content = BytesIO() + # Partially skip some of the internal processing of SynapseRequest. + request._started_processing = Mock() + request.request_metrics = Mock(spec=["name"]) + with patch.object(Request, "render"): + request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") + + # Also set the requester to ensure the processing works. + request.requester = "@foo:test" + + with LoggingContext( + request.get_request_id(), parent_context=request.logcontext + ): + logger.info("Hello there, %s!", "wally") + + log = self.get_log_line() + + # The terse logger includes additional request information, if possible. + expected_log_keys = [ + "log", + "level", + "namespace", + "request", + "ip_address", + "site_tag", + "requester", + "authenticated_entity", + "method", + "url", + "protocol", + "user_agent", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") + self.assertTrue(log["request"].startswith("POST-")) + self.assertEqual(log["ip_address"], "127.0.0.1") + self.assertEqual(log["site_tag"], "test-site") + self.assertEqual(log["requester"], "@foo:test") + self.assertEqual(log["authenticated_entity"], "@foo:test") + self.assertEqual(log["method"], "POST") + self.assertEqual(log["url"], "/_matrix/client/versions") + self.assertEqual(log["protocol"], "1.1") + self.assertEqual(log["user_agent"], "") diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index edacd1b566..349f93560e 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -12,27 +12,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock +from synapse.api.constants import EduTypes from synapse.events import EventBase +from synapse.federation.units import Transaction +from synapse.handlers.presence import UserPresenceState from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client.v1 import login, presence, room from synapse.types import create_requester -from tests.unittest import HomeserverTestCase +from tests.events.test_presence_router import send_presence_update, sync_presence +from tests.test_utils.event_injection import inject_member_event +from tests.unittest import FederatingHomeserverTestCase, override_config -class ModuleApiTestCase(HomeserverTestCase): +class ModuleApiTestCase(FederatingHomeserverTestCase): servlets = [ admin.register_servlets, login.register_servlets, room.register_servlets, + presence.register_servlets, ] def prepare(self, reactor, clock, homeserver): self.store = homeserver.get_datastore() self.module_api = homeserver.get_module_api() self.event_creation_handler = homeserver.get_event_creation_handler() + self.sync_handler = homeserver.get_sync_handler() + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver( + federation_transport_client=Mock(spec=["send_transaction"]), + ) def test_can_register_user(self): """Tests that an external module can register a user""" @@ -205,3 +217,161 @@ class ModuleApiTestCase(HomeserverTestCase): ) ) self.assertFalse(is_in_public_rooms) + + # The ability to send federation is required by send_local_online_presence_to. + @override_config({"send_federation": True}) + def test_send_local_online_presence_to(self): + """Tests that send_local_presence_to_users sends local online presence to local users.""" + # Create a user who will send presence updates + self.presence_receiver_id = self.register_user("presence_receiver", "monkey") + self.presence_receiver_tok = self.login("presence_receiver", "monkey") + + # And another user that will send presence updates out + self.presence_sender_id = self.register_user("presence_sender", "monkey") + self.presence_sender_tok = self.login("presence_sender", "monkey") + + # Put them in a room together so they will receive each other's presence updates + room_id = self.helper.create_room_as( + self.presence_receiver_id, + tok=self.presence_receiver_tok, + ) + self.helper.join(room_id, self.presence_sender_id, tok=self.presence_sender_tok) + + # Presence sender comes online + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "online", + "I'm online!", + ) + + # Presence receiver should have received it + presence_updates, sync_token = sync_presence(self, self.presence_receiver_id) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.presence_sender_id) + self.assertEqual(presence_update.state, "online") + + # Syncing again should result in no presence updates + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 0) + + # Trigger sending local online presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiver_id, + ] + ) + ) + + # Presence receiver should have received online presence again + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 1) + + presence_update = presence_updates[0] # type: UserPresenceState + self.assertEqual(presence_update.user_id, self.presence_sender_id) + self.assertEqual(presence_update.state, "online") + + # Presence sender goes offline + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "offline", + "I slink back into the darkness.", + ) + + # Trigger sending local online presence + self.get_success( + self.module_api.send_local_online_presence_to( + [ + self.presence_receiver_id, + ] + ) + ) + + # Presence receiver should *not* have received offline state + presence_updates, sync_token = sync_presence( + self, self.presence_receiver_id, sync_token + ) + self.assertEqual(len(presence_updates), 0) + + @override_config({"send_federation": True}) + def test_send_local_online_presence_to_federation(self): + """Tests that send_local_presence_to_users sends local online presence to remote users.""" + # Create a user who will send presence updates + self.presence_sender_id = self.register_user("presence_sender", "monkey") + self.presence_sender_tok = self.login("presence_sender", "monkey") + + # And a room they're a part of + room_id = self.helper.create_room_as( + self.presence_sender_id, + tok=self.presence_sender_tok, + ) + + # Mark them as online + send_presence_update( + self, + self.presence_sender_id, + self.presence_sender_tok, + "online", + "I'm online!", + ) + + # Make up a remote user to send presence to + remote_user_id = "@far_away_person:island" + + # Create a join membership event for the remote user into the room. + # This allows presence information to flow from one user to the other. + self.get_success( + inject_member_event( + self.hs, + room_id, + sender=remote_user_id, + target=remote_user_id, + membership="join", + ) + ) + + # The remote user would have received the existing room members' presence + # when they joined the room. + # + # Thus we reset the mock, and try sending online local user + # presence again + self.hs.get_federation_transport_client().send_transaction.reset_mock() + + # Broadcast local user online presence + self.get_success( + self.module_api.send_local_online_presence_to([remote_user_id]) + ) + + # Check that a presence update was sent as part of a federation transaction + found_update = False + calls = ( + self.hs.get_federation_transport_client().send_transaction.call_args_list + ) + for call in calls: + call_args = call[0] + federation_transaction = call_args[0] # type: Transaction + + # Get the sent EDUs in this transaction + edus = federation_transaction.get_dict()["edus"] + + for edu in edus: + # Make sure we're only checking presence-type EDUs + if edu["edu_type"] != EduTypes.Presence: + continue + + # EDUs can contain multiple presence updates + for presence_update in edu["content"]["push"]: + if presence_update["user_id"] == self.presence_sender_id: + found_update = True + + self.assertTrue(found_update) diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index a3b304d316..f590e8d21c 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet.defer import Deferred diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 1d4a592862..aff19d9fb3 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): return resource def make_worker_hs( - self, worker_app: str, extra_config: dict = {}, **kwargs + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs ) -> HomeServer: """Make a new worker HS instance, correctly connecting replcation stream to the master HS. @@ -283,7 +283,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config = self._get_worker_hs_config() config["worker_app"] = worker_app - config.update(extra_config) + config.update(extra_config or {}) worker_hs = self.setup_test_homeserver( homeserver_to_use=GenericWorkerServer, diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 56497b8476..83e89383f6 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from tests.replication._base import BaseStreamTestCase diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 0ceb0f935c..db80a0bdbd 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py
@@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Iterable, Optional from canonicaljson import encode_canonical_json @@ -332,15 +333,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): room_id=ROOM_ID, type="m.room.message", key=None, - internal={}, + internal: Optional[dict] = None, depth=None, - prev_events=[], - auth_events=[], - prev_state=[], + prev_events: Optional[list] = None, + auth_events: Optional[list] = None, + prev_state: Optional[list] = None, redacts=None, - push_actions=[], - **content + push_actions: Iterable = frozenset(), + **content, ): + prev_events = prev_events or [] + auth_events = auth_events or [] + prev_state = prev_state or [] if depth is None: depth = self.event_id @@ -369,7 +373,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if redacts is not None: event_dict["redacts"] = redacts - event = make_event_from_dict(event_dict, internal_metadata_dict=internal) + event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {}) self.event_id += 1 state_handler = self.hs.get_state_handler() diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 56b062ecc1..7d848e41ff 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py
@@ -15,7 +15,7 @@ # type: ignore -from mock import Mock +from unittest.mock import Mock from synapse.replication.tcp.streams._base import ReceiptsStream diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index ca49d4dd3a..4a0b342264 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from synapse.handlers.typing import RoomMember from synapse.replication.tcp.streams import TypingStream diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 0d9e3bb11d..44ad5eec57 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import mock +from unittest import mock from synapse.app.generic_worker import GenericWorkerServer from synapse.replication.tcp.commands import FederationAckCommand diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 2f2d117858..8ca595c3ee 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py
@@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from mock import Mock +from unittest.mock import Mock from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index ab2988a6ba..1f12bde1aa 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py
@@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index c9b773fbd2..6c2e1674cb 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py
@@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from mock import patch +from unittest.mock import patch from synapse.api.room_versions import RoomVersion from synapse.rest import admin diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py
index 057e27372e..4abcbe3f55 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py
@@ -17,8 +17,7 @@ import json import os import urllib.parse from binascii import unhexlify - -from mock import Mock +from unittest.mock import Mock from twisted.internet.defer import Deferred diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index b55160b70a..85f77c0a65 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py
@@ -16,8 +16,7 @@ import json import urllib.parse from typing import List, Optional - -from mock import Mock +from unittest.mock import Mock import synapse.rest.admin from synapse.api.constants import EventTypes, Membership diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 79a05b519b..a7b600a1d4 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py
@@ -19,8 +19,7 @@ import json import urllib.parse from binascii import unhexlify from typing import List, Optional - -from mock import Mock +from unittest.mock import Mock import synapse.rest.admin from synapse.api.constants import UserTypes @@ -28,7 +27,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from tests import unittest from tests.server import FakeSite, make_request @@ -467,6 +466,8 @@ class UsersListTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users" def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -634,6 +635,26 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # unkown order_by + channel = self.make_request( + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + def test_limit(self): """ Testing list of users with limit @@ -759,6 +780,103 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z") + user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y") + + # Modify user + self.get_success(self.store.set_user_deactivated_status(user1, True)) + self.get_success(self.store.set_shadow_banned(UserID.from_string(user1), True)) + + # Set avatar URL to all users, that no user has a NULL value to avoid + # different sort order between SQlite and PostreSQL + self.get_success(self.store.set_profile_avatar_url("user1", "mxc://url3")) + self.get_success(self.store.set_profile_avatar_url("user2", "mxc://url2")) + self.get_success(self.store.set_profile_avatar_url("admin", "mxc://url1")) + + # order by default (name) + self._order_test([self.admin_user, user1, user2], None) + self._order_test([self.admin_user, user1, user2], None, "f") + self._order_test([user2, user1, self.admin_user], None, "b") + + # order by name + self._order_test([self.admin_user, user1, user2], "name") + self._order_test([self.admin_user, user1, user2], "name", "f") + self._order_test([user2, user1, self.admin_user], "name", "b") + + # order by displayname + self._order_test([user2, user1, self.admin_user], "displayname") + self._order_test([user2, user1, self.admin_user], "displayname", "f") + self._order_test([self.admin_user, user1, user2], "displayname", "b") + + # order by is_guest + # like sort by ascending name, as no guest user here + self._order_test([self.admin_user, user1, user2], "is_guest") + self._order_test([self.admin_user, user1, user2], "is_guest", "f") + self._order_test([self.admin_user, user1, user2], "is_guest", "b") + + # order by admin + self._order_test([user1, user2, self.admin_user], "admin") + self._order_test([user1, user2, self.admin_user], "admin", "f") + self._order_test([self.admin_user, user1, user2], "admin", "b") + + # order by deactivated + self._order_test([self.admin_user, user2, user1], "deactivated") + self._order_test([self.admin_user, user2, user1], "deactivated", "f") + self._order_test([user1, self.admin_user, user2], "deactivated", "b") + + # order by user_type + # like sort by ascending name, as no special user type here + self._order_test([self.admin_user, user1, user2], "user_type") + self._order_test([self.admin_user, user1, user2], "user_type", "f") + self._order_test([self.admin_user, user1, user2], "is_guest", "b") + + # order by shadow_banned + self._order_test([self.admin_user, user2, user1], "shadow_banned") + self._order_test([self.admin_user, user2, user1], "shadow_banned", "f") + self._order_test([user1, self.admin_user, user2], "shadow_banned", "b") + + # order by avatar_url + self._order_test([self.admin_user, user2, user1], "avatar_url") + self._order_test([self.admin_user, user2, user1], "avatar_url", "f") + self._order_test([user1, user2, self.admin_user], "avatar_url", "b") + + def _order_test( + self, + expected_user_list: List[str], + order_by: Optional[str], + dir: Optional[str] = None, + ): + """Request the list of users in a certain order. Assert that order is what + we expect + Args: + expected_user_list: The list of user_id in the order we expect to get + back from the server + order_by: The type of ordering to give the server + dir: The direction of ordering to give the server + """ + + url = self.url + "?deactivated=true&" + if order_by is not None: + url += "order_by=%s&" % (order_by,) + if dir is not None and dir in ("b", "f"): + url += "dir=%s" % (dir,) + channel = self.make_request( + "GET", + url.encode("ascii"), + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], len(expected_user_list)) + + returned_order = [row["name"] for row in channel.json_body["users"]] + self.assertEqual(expected_user_list, returned_order) + self._check_fields(channel.json_body["users"]) + def _check_fields(self, content: JsonDict): """Checks that the expected user attributes are present in content Args: @@ -2908,3 +3026,287 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): # Ensure the user is shadow-banned (and the cache was cleared). result = self.get_success(self.store.get_user_by_access_token(other_user_token)) self.assertTrue(result.shadow_banned) + + +class RateLimitTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = ( + "/_synapse/admin/v1/users/%s/override_ratelimit" + % urllib.parse.quote(self.other_user) + ) + + def test_no_auth(self): + """ + Try to get information of a user without authentication. + """ + channel = self.make_request("GET", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + channel = self.make_request("DELETE", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request( + "GET", + self.url, + access_token=other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + channel = self.make_request( + "POST", + self.url, + access_token=other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + channel = self.make_request( + "DELETE", + self.url, + access_token=other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + channel = self.make_request( + "POST", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = ( + "/_synapse/admin/v1/users/@unknown_person:unknown_domain/override_ratelimit" + ) + + channel = self.make_request( + "GET", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + channel = self.make_request( + "POST", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + "Only local users can be ratelimited", channel.json_body["error"] + ) + + channel = self.make_request( + "DELETE", + url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + "Only local users can be ratelimited", channel.json_body["error"] + ) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + # messages_per_second is a string + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"messages_per_second": "string"}, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # messages_per_second is negative + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"messages_per_second": -1}, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # burst_count is a string + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"burst_count": "string"}, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + # burst_count is negative + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"burst_count": -1}, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_return_zero_when_null(self): + """ + If values in database are `null` API should return an int `0` + """ + + self.get_success( + self.store.db_pool.simple_upsert( + table="ratelimit_override", + keyvalues={"user_id": self.other_user}, + values={ + "messages_per_second": None, + "burst_count": None, + }, + ) + ) + + # request status + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(0, channel.json_body["messages_per_second"]) + self.assertEqual(0, channel.json_body["burst_count"]) + + def test_success(self): + """ + Rate-limiting (set/update/delete) should succeed for an admin. + """ + # request status + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertNotIn("messages_per_second", channel.json_body) + self.assertNotIn("burst_count", channel.json_body) + + # set ratelimit + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"messages_per_second": 10, "burst_count": 11}, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(10, channel.json_body["messages_per_second"]) + self.assertEqual(11, channel.json_body["burst_count"]) + + # update ratelimit + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"messages_per_second": 20, "burst_count": 21}, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(20, channel.json_body["messages_per_second"]) + self.assertEqual(21, channel.json_body["burst_count"]) + + # request status + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(20, channel.json_body["messages_per_second"]) + self.assertEqual(21, channel.json_body["burst_count"]) + + # delete ratelimit + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertNotIn("messages_per_second", channel.json_body) + self.assertNotIn("burst_count", channel.json_body) + + # request status + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertNotIn("messages_per_second", channel.json_body) + self.assertNotIn("burst_count", channel.json_body) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index b8285f3240..be1211dbce 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from synapse.api.constants import EventTypes from synapse.rest import admin diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d2cce44032..288ee12888 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, patch +from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventTypes diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index bf39014277..a7ebe0c3e9 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py
@@ -14,8 +14,7 @@ # limitations under the License. import threading from typing import Dict - -from mock import Mock +from unittest.mock import Mock from synapse.events import EventBase from synapse.module_api import ModuleApi diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 171632e195..3b5747cb12 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py
@@ -1,4 +1,4 @@ -from mock import Mock, call +from unittest.mock import Mock, call from twisted.internet import defer, reactor diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 2ae896db1e..87a18d2cb9 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py
@@ -15,7 +15,7 @@ """ Tests REST events for /events paths.""" -from mock import Mock +from unittest.mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import events, login, room diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 988821b16f..c7b79ab8a7 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py
@@ -16,10 +16,9 @@ import time import urllib.parse from typing import Any, Dict, List, Optional, Union +from unittest.mock import Mock from urllib.parse import urlencode -from mock import Mock - import pymacaroons from twisted.web.resource import Resource diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 94a5154834..c136827f79 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index ed65f645fc..4df20c90fd 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py
@@ -19,10 +19,10 @@ """Tests REST events for /rooms paths.""" import json +from typing import Iterable +from unittest.mock import Mock from urllib import parse as urlparse -from mock import Mock - import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus @@ -207,7 +207,9 @@ class RoomPermissionsTestCase(RoomBase): ) self.assertEquals(403, channel.code, msg=channel.result["body"]) - def _test_get_membership(self, room=None, members=[], expect_code=None): + def _test_get_membership( + self, room=None, members: Iterable = frozenset(), expect_code=None + ): for member in members: path = "/rooms/%s/state/m.room.member/%s" % (room, member) channel = self.make_request("GET", path) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 329dbd06de..0b8f565121 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py
@@ -16,7 +16,7 @@ """Tests REST events for /rooms paths.""" -from mock import Mock +from unittest.mock import Mock from synapse.rest.client.v1 import room from synapse.types import UserID diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 946740aa5d..a6a292b20c 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py
@@ -21,8 +21,7 @@ import re import time import urllib.parse from typing import Any, Dict, Mapping, MutableMapping, Optional - -from mock import patch +from unittest.mock import patch import attr @@ -132,7 +131,7 @@ class RestHelper: src: str, targ: str, membership: str, - extra_data: dict = {}, + extra_data: Optional[dict] = None, tok: Optional[str] = None, expect_code: int = 200, ) -> None: @@ -156,7 +155,7 @@ class RestHelper: path = path + "?access_token=%s" % tok data = {"membership": membership} - data.update(extra_data) + data.update(extra_data or {}) channel = make_request( self.hs.get_reactor(), @@ -187,7 +186,13 @@ class RestHelper: ) def send_event( - self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200 + self, + room_id, + type, + content: Optional[dict] = None, + txn_id=None, + tok=None, + expect_code=200, ): if txn_id is None: txn_id = "m%s" % (str(time.time())) @@ -201,7 +206,7 @@ class RestHelper: self.site, "PUT", path, - json.dumps(content).encode("utf8"), + json.dumps(content or {}).encode("utf8"), ) assert ( diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 9734a2159a..ed433d9333 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Optional, Union from twisted.internet.defer import succeed @@ -74,7 +74,10 @@ class FallbackAuthTests(unittest.HomeserverTestCase): return channel def recaptcha( - self, session: str, expected_post_response: int, post_session: str = None + self, + session: str, + expected_post_response: int, + post_session: Optional[str] = None, ) -> None: """Get and respond to a fallback recaptcha. Returns the second request.""" if post_session is None: diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 2d4ce871eb..41e52c701f 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import datetime import json import os @@ -28,7 +27,7 @@ import pkg_resources from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import LoginType +from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout @@ -65,7 +64,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastore().services_cache.append(appservice) - request_data = json.dumps({"username": "as_user_kermit"}) + request_data = json.dumps( + {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data @@ -75,9 +76,31 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) + def test_POST_appservice_registration_no_type(self): + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps({"username": "as_user_kermit"}) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"400", channel.result) + def test_POST_appservice_registration_invalid(self): self.appservice = None # no application service exists - request_data = json.dumps({"username": "kermit"}) + request_data = json.dumps( + {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index e7bb5583fc..21ee436b91 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -16,6 +16,7 @@ import itertools import json import urllib +from typing import Optional from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin @@ -681,7 +682,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): relation_type, event_type, key=None, - content={}, + content: Optional[dict] = None, access_token=None, parent_id=None, ): @@ -713,7 +714,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, original_id, relation_type, event_type, query), - json.dumps(content).encode("utf-8"), + json.dumps(content or {}).encode("utf-8"), access_token=access_token, ) return channel diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 9d0d0ef414..eb8687ce68 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -14,8 +14,7 @@ # limitations under the License. import urllib.parse from io import BytesIO, StringIO - -from mock import Mock +from unittest.mock import Mock import signedjson.key from canonicaljson import encode_canonical_json diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 9f77125fd4..375f0b7977 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py
@@ -18,10 +18,9 @@ import tempfile from binascii import unhexlify from io import BytesIO from typing import Optional +from unittest.mock import Mock from urllib import parse -from mock import Mock - import attr from parameterized import parameterized_class from PIL import Image as Image diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index 6968502433..9067463e54 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py
@@ -15,8 +15,7 @@ import json import os import re - -from mock import patch +from unittest.mock import patch from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address diff --git a/tests/scripts/test_new_matrix_user.py b/tests/scripts/test_new_matrix_user.py
index 6f56893f5e..885b95a51f 100644 --- a/tests/scripts/test_new_matrix_user.py +++ b/tests/scripts/test_new_matrix_user.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from synapse._scripts.register_new_matrix_user import request_registration diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index d40d65b06a..450b4ec710 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 1ce29af5fd..e755a4db62 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py
@@ -15,8 +15,7 @@ import json import os import tempfile - -from mock import Mock +from unittest.mock import Mock import yaml diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 1b4fae0bb5..069db0edc4 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py
@@ -1,4 +1,4 @@ -from mock import Mock +from unittest.mock import Mock from synapse.storage.background_updates import BackgroundUpdater diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index eac7e4dcd2..54e9e7f6fe 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py
@@ -15,8 +15,7 @@ from collections import OrderedDict - -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 7791138688..b02fb32ced 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py
@@ -14,9 +14,7 @@ # limitations under the License. import os.path -from unittest.mock import patch - -from mock import Mock +from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventTypes diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 34e6526097..f7f75320ba 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock import synapse.rest.admin from synapse.http.site import XForwardedForRequest @@ -390,7 +390,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, + synapse.rest.admin.register_servlets, login.register_servlets, ] @@ -434,7 +434,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): self.reactor, self.site, "GET", - "/_synapse/admin/v1/users/" + self.user_id, + "/_synapse/admin/v2/users/" + self.user_id, access_token=access_token, custom_headers=headers1.items(), **make_request_args, diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py
index 5a77c84962..a906d30e73 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py
@@ -36,17 +36,6 @@ def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: class TupleComparisonClauseTestCase(unittest.TestCase): def test_native_tuple_comparison(self): - db_engine = _stub_db_engine(supports_tuple_comparison=True) - clause, args = make_tuple_comparison_clause(db_engine, [("a", 1), ("b", 2)]) + clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) - - def test_emulated_tuple_comparison(self): - db_engine = _stub_db_engine(supports_tuple_comparison=False) - clause, args = make_tuple_comparison_clause( - db_engine, [("a", 1), ("b", 2), ("c", 3)] - ) - self.assertEqual( - clause, "(a >= ? AND (a > ? OR (b >= ? AND (b > ? OR c > ?))))" - ) - self.assertEqual(args, [1, 1, 2, 2, 3]) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index dabc1c5f09..ef4cf8d0f1 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,32 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - import synapse.api.errors -import tests.unittest -import tests.utils - - -class DeviceStoreTestCase(tests.unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.store = None # type: synapse.storage.DataStore +from tests.unittest import HomeserverTestCase - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class DeviceStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_store_new_device(self): - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device_id", "display_name") ) - res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) + res = self.get_success(self.store.get_device("user_id", "device_id")) self.assertDictContainsSubset( { "user_id": "user_id", @@ -48,19 +37,18 @@ class DeviceStoreTestCase(tests.unittest.TestCase): res, ) - @defer.inlineCallbacks def test_get_devices_by_user(self): - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device2", "display_name 2") ) - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id2", "device3", "display_name 3") ) - res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id")) + res = self.get_success(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) self.assertDictContainsSubset( { @@ -79,43 +67,41 @@ class DeviceStoreTestCase(tests.unittest.TestCase): res["device2"], ) - @defer.inlineCallbacks def test_count_devices_by_users(self): - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device1", "display_name 1") ) - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device2", "display_name 2") ) - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id2", "device3", "display_name 3") ) - res = yield defer.ensureDeferred(self.store.count_devices_by_users()) + res = self.get_success(self.store.count_devices_by_users()) self.assertEqual(0, res) - res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"])) + res = self.get_success(self.store.count_devices_by_users(["unknown"])) self.assertEqual(0, res) - res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"])) + res = self.get_success(self.store.count_devices_by_users(["user_id"])) self.assertEqual(2, res) - res = yield defer.ensureDeferred( + res = self.get_success( self.store.count_devices_by_users(["user_id", "user_id2"]) ) self.assertEqual(3, res) - @defer.inlineCallbacks def test_get_device_updates_by_remote(self): device_ids = ["device_id1", "device_id2"] # Add two device updates with a single stream_id - yield defer.ensureDeferred( + self.get_success( self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) ) # Get all device updates ever meant for this remote - now_stream_id, device_updates = yield defer.ensureDeferred( + now_stream_id, device_updates = self.get_success( self.store.get_device_updates_by_remote("somehost", -1, limit=100) ) @@ -131,37 +117,35 @@ class DeviceStoreTestCase(tests.unittest.TestCase): } self.assertEqual(received_device_ids, set(expected_device_ids)) - @defer.inlineCallbacks def test_update_device(self): - yield defer.ensureDeferred( + self.get_success( self.store.store_device("user_id", "device_id", "display_name 1") ) - res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) + res = self.get_success(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do a no-op first - yield defer.ensureDeferred(self.store.update_device("user_id", "device_id")) - res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) + self.get_success(self.store.update_device("user_id", "device_id")) + res = self.get_success(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 1", res["display_name"]) # do the update - yield defer.ensureDeferred( + self.get_success( self.store.update_device( "user_id", "device_id", new_display_name="display_name 2" ) ) # check it worked - res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id")) + res = self.get_success(self.store.get_device("user_id", "device_id")) self.assertEqual("display_name 2", res["display_name"]) - @defer.inlineCallbacks def test_update_unknown_device(self): - with self.assertRaises(synapse.api.errors.StoreError) as cm: - yield defer.ensureDeferred( - self.store.update_device( - "user_id", "unknown_device_id", new_display_name="display_name 2" - ) - ) - self.assertEqual(404, cm.exception.code) + exc = self.get_failure( + self.store.update_device( + "user_id", "unknown_device_id", new_display_name="display_name 2" + ), + synapse.api.errors.StoreError, + ) + self.assertEqual(404, exc.value.code) diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index da93ca3980..0db233fd68 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,28 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.types import RoomAlias, RoomID -from tests import unittest -from tests.utils import setup_test_homeserver +from tests.unittest import HomeserverTestCase -class DirectoryStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) - +class DirectoryStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.room = RoomID.from_string("!abcde:test") self.alias = RoomAlias.from_string("#my-room:test") - @defer.inlineCallbacks def test_room_to_alias(self): - yield defer.ensureDeferred( + self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) @@ -42,16 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase): self.assertEquals( ["#my-room:test"], - ( - yield defer.ensureDeferred( - self.store.get_aliases_for_room(self.room.to_string()) - ) - ), + (self.get_success(self.store.get_aliases_for_room(self.room.to_string()))), ) - @defer.inlineCallbacks def test_alias_to_room(self): - yield defer.ensureDeferred( + self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) @@ -59,28 +46,19 @@ class DirectoryStoreTestCase(unittest.TestCase): self.assertObjectHasAttributes( {"room_id": self.room.to_string(), "servers": ["test"]}, - ( - yield defer.ensureDeferred( - self.store.get_association_from_room_alias(self.alias) - ) - ), + (self.get_success(self.store.get_association_from_room_alias(self.alias))), ) - @defer.inlineCallbacks def test_delete_alias(self): - yield defer.ensureDeferred( + self.get_success( self.store.create_room_alias_association( room_alias=self.alias, room_id=self.room.to_string(), servers=["test"] ) ) - room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias)) + room_id = self.get_success(self.store.delete_room_alias(self.alias)) self.assertEqual(self.room.to_string(), room_id) self.assertIsNone( - ( - yield defer.ensureDeferred( - self.store.get_association_from_room_alias(self.alias) - ) - ) + (self.get_success(self.store.get_association_from_room_alias(self.alias))) ) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 3fc4bb13b6..1e54b940fd 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,30 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from tests.unittest import HomeserverTestCase -import tests.unittest -import tests.utils - -class EndToEndKeyStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EndToEndKeyStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks def test_key_without_device_name(self): now = 1470174257070 json = {"key": "value"} - yield defer.ensureDeferred(self.store.store_device("user", "device", None)) + self.get_success(self.store.store_device("user", "device", None)) - yield defer.ensureDeferred( - self.store.set_e2e_device_keys("user", "device", now, json) - ) + self.get_success(self.store.set_e2e_device_keys("user", "device", now, json)) - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) @@ -44,38 +36,32 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): dev = res["user"]["device"] self.assertDictContainsSubset(json, dev) - @defer.inlineCallbacks def test_reupload_key(self): now = 1470174257070 json = {"key": "value"} - yield defer.ensureDeferred(self.store.store_device("user", "device", None)) + self.get_success(self.store.store_device("user", "device", None)) - changed = yield defer.ensureDeferred( + changed = self.get_success( self.store.set_e2e_device_keys("user", "device", now, json) ) self.assertTrue(changed) # If we try to upload the same key then we should be told nothing # changed - changed = yield defer.ensureDeferred( + changed = self.get_success( self.store.set_e2e_device_keys("user", "device", now, json) ) self.assertFalse(changed) - @defer.inlineCallbacks def test_get_key_with_device_name(self): now = 1470174257070 json = {"key": "value"} - yield defer.ensureDeferred( - self.store.set_e2e_device_keys("user", "device", now, json) - ) - yield defer.ensureDeferred( - self.store.store_device("user", "device", "display_name") - ) + self.get_success(self.store.set_e2e_device_keys("user", "device", now, json)) + self.get_success(self.store.store_device("user", "device", "display_name")) - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_device_keys_for_cs_api((("user", "device"),)) ) self.assertIn("user", res) @@ -85,29 +71,28 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase): {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev ) - @defer.inlineCallbacks def test_multiple_devices(self): now = 1470174257070 - yield defer.ensureDeferred(self.store.store_device("user1", "device1", None)) - yield defer.ensureDeferred(self.store.store_device("user1", "device2", None)) - yield defer.ensureDeferred(self.store.store_device("user2", "device1", None)) - yield defer.ensureDeferred(self.store.store_device("user2", "device2", None)) + self.get_success(self.store.store_device("user1", "device1", None)) + self.get_success(self.store.store_device("user1", "device2", None)) + self.get_success(self.store.store_device("user2", "device1", None)) + self.get_success(self.store.store_device("user2", "device2", None)) - yield defer.ensureDeferred( + self.get_success( self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"}) ) - yield defer.ensureDeferred( + self.get_success( self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"}) ) - res = yield defer.ensureDeferred( + res = self.get_success( self.store.get_e2e_device_keys_for_cs_api( (("user1", "device1"), ("user2", "device2")) ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 485f1ee033..0289942f88 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2016 OpenMarket Ltd +# Copyright 2016-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock -from twisted.internet import defer - -import tests.unittest -import tests.utils +from tests.unittest import HomeserverTestCase USER_ID = "@user:example.com" @@ -30,37 +27,31 @@ HIGHLIGHT = [ ] -class EventPushActionsStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) +class EventPushActionsStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.persist_events_store = hs.get_datastores().persist_events - @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield defer.ensureDeferred( + self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_http( USER_ID, 0, 1000, 20 ) ) - @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield defer.ensureDeferred( + self.get_success( self.store.get_unread_push_actions_for_user_in_range_for_email( USER_ID, 0, 1000, 20 ) ) - @defer.inlineCallbacks def test_count_aggregation(self): room_id = "!foo:example.com" user_id = "@user1235:example.com" - @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield defer.ensureDeferred( + counts = self.get_success( self.store.db_pool.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) @@ -74,7 +65,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): }, ) - @defer.inlineCallbacks def _inject_actions(stream, action): event = Mock() event.room_id = room_id @@ -82,14 +72,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - yield defer.ensureDeferred( + self.get_success( self.store.add_push_actions_to_staging( event.event_id, {user_id: action}, False, ) ) - yield defer.ensureDeferred( + self.get_success( self.store.db_pool.runInteraction( "", self.persist_events_store._set_push_actions_for_event_and_users_txn, @@ -99,14 +89,14 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) def _rotate(stream): - return defer.ensureDeferred( + self.get_success( self.store.db_pool.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) ) def _mark_read(stream, depth): - return defer.ensureDeferred( + self.get_success( self.store.db_pool.runInteraction( "", self.store._remove_old_push_actions_before_txn, @@ -116,49 +106,48 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) ) - yield _assert_counts(0, 0) - yield _inject_actions(1, PlAIN_NOTIF) - yield _assert_counts(1, 0) - yield _rotate(2) - yield _assert_counts(1, 0) + _assert_counts(0, 0) + _inject_actions(1, PlAIN_NOTIF) + _assert_counts(1, 0) + _rotate(2) + _assert_counts(1, 0) - yield _inject_actions(3, PlAIN_NOTIF) - yield _assert_counts(2, 0) - yield _rotate(4) - yield _assert_counts(2, 0) + _inject_actions(3, PlAIN_NOTIF) + _assert_counts(2, 0) + _rotate(4) + _assert_counts(2, 0) - yield _inject_actions(5, PlAIN_NOTIF) - yield _mark_read(3, 3) - yield _assert_counts(1, 0) + _inject_actions(5, PlAIN_NOTIF) + _mark_read(3, 3) + _assert_counts(1, 0) - yield _mark_read(5, 5) - yield _assert_counts(0, 0) + _mark_read(5, 5) + _assert_counts(0, 0) - yield _inject_actions(6, PlAIN_NOTIF) - yield _rotate(7) + _inject_actions(6, PlAIN_NOTIF) + _rotate(7) - yield defer.ensureDeferred( + self.get_success( self.store.db_pool.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) ) - yield _assert_counts(1, 0) + _assert_counts(1, 0) - yield _mark_read(7, 7) - yield _assert_counts(0, 0) + _mark_read(7, 7) + _assert_counts(0, 0) - yield _inject_actions(8, HIGHLIGHT) - yield _assert_counts(1, 1) - yield _rotate(9) - yield _assert_counts(1, 1) - yield _rotate(10) - yield _assert_counts(1, 1) + _inject_actions(8, HIGHLIGHT) + _assert_counts(1, 1) + _rotate(9) + _assert_counts(1, 1) + _rotate(10) + _assert_counts(1, 1) - @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return defer.ensureDeferred( + self.get_success( self.store.db_pool.simple_insert( "events", { @@ -177,24 +166,16 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) # start with the base case where there are no events in the table - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(11) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(11)) self.assertEqual(r, 0) # now with one event - yield add_event(2, 10) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(9) - ) + add_event(2, 10) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(9)) self.assertEqual(r, 2) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(10) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(10)) self.assertEqual(r, 2) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(11) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(11)) self.assertEqual(r, 3) # add a bunch of dummy events to the events table @@ -205,39 +186,27 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): (10, 130), (20, 140), ): - yield add_event(stream_ordering, ts) + add_event(stream_ordering, ts) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(110) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(110)) self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r) # 4 and 5 are both after 120: we want 4 rather than 5 - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(120) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(120)) self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(129) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(129)) self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r) # check we can get the last event - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(140) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(140)) self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r) # off the end - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(160) - ) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(160)) self.assertEqual(r, 21) # check we can find an event at ordering zero - yield add_event(0, 5) - r = yield defer.ensureDeferred( - self.store.find_first_stream_ordering_after_ts(1) - ) + add_event(0, 5) + r = self.get_success(self.store.find_first_stream_ordering_after_ts(1)) self.assertEqual(r, 0) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index aad6bc907e..6c389fe9ac 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional + from synapse.storage.database import DatabasePool from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -43,7 +45,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -53,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], ) return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) @@ -476,7 +478,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -486,7 +488,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): instance_name=instance_name, tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], positive=False, ) @@ -612,7 +614,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ) def _create_id_generator( - self, instance_name="master", writers=["master"] + self, instance_name="master", writers: Optional[List[str]] = None ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( @@ -625,7 +627,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): ("foobar2", "instance_name", "stream_id"), ], sequence_name="foobar_seq", - writers=writers, + writers=writers or ["master"], ) return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 5858c7fcc4..47556791f4 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet import defer diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index b7dde51224..c6256fce86 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,59 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.types import UserID from tests import unittest -from tests.utils import setup_test_homeserver - -class ProfileStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) +class ProfileStoreTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") - @defer.inlineCallbacks def test_displayname(self): - yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) + self.get_success(self.store.create_profile(self.u_frank.localpart)) - yield defer.ensureDeferred( - self.store.set_profile_displayname(self.u_frank.localpart, "Frank", 1) + self.get_success( + self.store.set_profile_displayname(self.u_frank.localpart, "Frank") ) self.assertEquals( "Frank", ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.u_frank.localpart) ) ), ) # test set to None - yield defer.ensureDeferred( - self.store.set_profile_displayname(self.u_frank.localpart, None, 2) + self.get_success( + self.store.set_profile_displayname(self.u_frank.localpart, None) ) self.assertIsNone( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_displayname(self.u_frank.localpart) ) ) ) - @defer.inlineCallbacks def test_avatar_url(self): - yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart)) + self.get_success(self.store.create_profile(self.u_frank.localpart)) - yield defer.ensureDeferred( + self.get_success( self.store.set_profile_avatar_url( self.u_frank.localpart, "http://my.site/here", 1 ) @@ -74,20 +65,20 @@ class ProfileStoreTestCase(unittest.TestCase): self.assertEquals( "http://my.site/here", ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_avatar_url(self.u_frank.localpart) ) ), ) # test set to None - yield defer.ensureDeferred( - self.store.set_profile_avatar_url(self.u_frank.localpart, None, 2) + self.get_success( + self.store.set_profile_avatar_url(self.u_frank.localpart, None) ) self.assertIsNone( ( - yield defer.ensureDeferred( + self.get_success( self.store.get_profile_avatar_url(self.u_frank.localpart) ) ) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b2a0e60856..2d2f58903c 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py
@@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.types import RoomID, UserID @@ -50,10 +48,15 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.depth = 1 def inject_room_member( - self, room, user, membership, replaces_state=None, extra_content={} + self, + room, + user, + membership, + replaces_state=None, + extra_content: Optional[dict] = None, ): content = {"membership": membership} - content.update(extra_content) + content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { @@ -230,10 +233,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): self._base_builder = base_builder self._event_id = event_id - @defer.inlineCallbacks - def build(self, prev_event_ids, auth_event_ids): - built_event = yield defer.ensureDeferred( - self._base_builder.build(prev_event_ids, auth_event_ids) + async def build(self, prev_event_ids, auth_event_ids): + built_event = await self._base_builder.build( + prev_event_ids, auth_event_ids ) built_event._event_id = self._event_id diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4eb41c46e8..c82cf15bc2 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,21 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError -from tests import unittest -from tests.utils import setup_test_homeserver - +from tests.unittest import HomeserverTestCase -class RegistrationStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) +class RegistrationStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.user_id = "@my-user:test" @@ -35,9 +28,8 @@ class RegistrationStoreTestCase(unittest.TestCase): self.pwhash = "{xx1}123456789" self.device_id = "akgjhdjklgshg" - @defer.inlineCallbacks def test_register(self): - yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) + self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.assertEquals( { @@ -49,93 +41,81 @@ class RegistrationStoreTestCase(unittest.TestCase): "consent_version": None, "consent_server_notice_sent": None, "appservice_id": None, - "creation_ts": 1000, + "creation_ts": 0, "user_type": None, "deactivated": 0, "shadow_banned": 0, }, - (yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))), + (self.get_success(self.store.get_user_by_id(self.user_id))), ) - @defer.inlineCallbacks def test_add_tokens(self): - yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) - yield defer.ensureDeferred( + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + self.get_success( self.store.add_access_token_to_user( self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) ) - result = yield defer.ensureDeferred( - self.store.get_user_by_access_token(self.tokens[1]) - ) + result = self.get_success(self.store.get_user_by_access_token(self.tokens[1])) self.assertEqual(result.user_id, self.user_id) self.assertEqual(result.device_id, self.device_id) self.assertIsNotNone(result.token_id) - @defer.inlineCallbacks def test_user_delete_access_tokens(self): # add some tokens - yield defer.ensureDeferred(self.store.register_user(self.user_id, self.pwhash)) - yield defer.ensureDeferred( + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + self.get_success( self.store.add_access_token_to_user( self.user_id, self.tokens[0], device_id=None, valid_until_ms=None ) ) - yield defer.ensureDeferred( + self.get_success( self.store.add_access_token_to_user( self.user_id, self.tokens[1], self.device_id, valid_until_ms=None ) ) # now delete some - yield defer.ensureDeferred( + self.get_success( self.store.user_delete_access_tokens(self.user_id, device_id=self.device_id) ) # check they were deleted - user = yield defer.ensureDeferred( - self.store.get_user_by_access_token(self.tokens[1]) - ) + user = self.get_success(self.store.get_user_by_access_token(self.tokens[1])) self.assertIsNone(user, "access token was not deleted by device_id") # check the one not associated with the device was not deleted - user = yield defer.ensureDeferred( - self.store.get_user_by_access_token(self.tokens[0]) - ) + user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) self.assertEqual(self.user_id, user.user_id) # now delete the rest - yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) + self.get_success(self.store.user_delete_access_tokens(self.user_id)) - user = yield defer.ensureDeferred( - self.store.get_user_by_access_token(self.tokens[0]) - ) + user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) self.assertIsNone(user, "access token was not deleted without device_id") - @defer.inlineCallbacks def test_is_support_user(self): TEST_USER = "@test:test" SUPPORT_USER = "@support:test" - res = yield defer.ensureDeferred(self.store.is_support_user(None)) + res = self.get_success(self.store.is_support_user(None)) self.assertFalse(res) - yield defer.ensureDeferred( + self.get_success( self.store.register_user(user_id=TEST_USER, password_hash=None) ) - res = yield defer.ensureDeferred(self.store.is_support_user(TEST_USER)) + res = self.get_success(self.store.is_support_user(TEST_USER)) self.assertFalse(res) - yield defer.ensureDeferred( + self.get_success( self.store.register_user( user_id=SUPPORT_USER, password_hash=None, user_type=UserTypes.SUPPORT ) ) - res = yield defer.ensureDeferred(self.store.is_support_user(SUPPORT_USER)) + res = self.get_success(self.store.is_support_user(SUPPORT_USER)) self.assertTrue(res) - @defer.inlineCallbacks def test_3pid_inhibit_invalid_validation_session_error(self): """Tests that enabling the configuration option to inhibit 3PID errors on /requestToken also inhibits validation errors caused by an unknown session ID. @@ -143,30 +123,28 @@ class RegistrationStoreTestCase(unittest.TestCase): # Check that, with the config setting set to false (the default value), a # validation error is caused by the unknown session ID. - try: - yield defer.ensureDeferred( - self.store.validate_threepid_session( - "fake_sid", - "fake_client_secret", - "fake_token", - 0, - ) - ) - except ThreepidValidationError as e: - self.assertEquals(e.msg, "Unknown session_id", e) + e = self.get_failure( + self.store.validate_threepid_session( + "fake_sid", + "fake_client_secret", + "fake_token", + 0, + ), + ThreepidValidationError, + ) + self.assertEquals(e.value.msg, "Unknown session_id", e) # Set the config setting to true. self.store._ignore_unknown_session_error = True # Check that now the validation error is caused by the token not matching. - try: - yield defer.ensureDeferred( - self.store.validate_threepid_session( - "fake_sid", - "fake_client_secret", - "fake_token", - 0, - ) - ) - except ThreepidValidationError as e: - self.assertEquals(e.msg, "Validation token not found or has expired", e) + e = self.get_failure( + self.store.validate_threepid_session( + "fake_sid", + "fake_client_secret", + "fake_token", + 0, + ), + ThreepidValidationError, + ) + self.assertEquals(e.value.msg, "Validation token not found or has expired", e) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index bc8400f240..0089d33c93 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,22 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.types import RoomAlias, RoomID, UserID -from tests import unittest -from tests.utils import setup_test_homeserver - +from tests.unittest import HomeserverTestCase -class RoomStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield setup_test_homeserver(self.addCleanup) +class RoomStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): # We can't test RoomStore on its own without the DirectoryStore, for # management of the 'room_aliases' table self.store = hs.get_datastore() @@ -37,7 +30,7 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield defer.ensureDeferred( + self.get_success( self.store.store_room( self.room.to_string(), room_creator_user_id=self.u_creator.to_string(), @@ -46,7 +39,6 @@ class RoomStoreTestCase(unittest.TestCase): ) ) - @defer.inlineCallbacks def test_get_room(self): self.assertDictContainsSubset( { @@ -54,16 +46,12 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "is_public": True, }, - (yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))), + (self.get_success(self.store.get_room(self.room.to_string()))), ) - @defer.inlineCallbacks def test_get_room_unknown_room(self): - self.assertIsNone( - (yield defer.ensureDeferred(self.store.get_room("!uknown:test"))) - ) + self.assertIsNone((self.get_success(self.store.get_room("!uknown:test")))) - @defer.inlineCallbacks def test_get_room_with_stats(self): self.assertDictContainsSubset( { @@ -71,29 +59,17 @@ class RoomStoreTestCase(unittest.TestCase): "creator": self.u_creator.to_string(), "public": True, }, - ( - yield defer.ensureDeferred( - self.store.get_room_with_stats(self.room.to_string()) - ) - ), + (self.get_success(self.store.get_room_with_stats(self.room.to_string()))), ) - @defer.inlineCallbacks def test_get_room_with_stats_unknown_room(self): self.assertIsNone( - ( - yield defer.ensureDeferred( - self.store.get_room_with_stats("!uknown:test") - ) - ), + (self.get_success(self.store.get_room_with_stats("!uknown:test"))), ) -class RoomEventsStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = setup_test_homeserver(self.addCleanup) - +class RoomEventsStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastore() @@ -102,7 +78,7 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield defer.ensureDeferred( + self.get_success( self.store.store_room( self.room.to_string(), room_creator_user_id="@creator:text", @@ -111,23 +87,21 @@ class RoomEventsStoreTestCase(unittest.TestCase): ) ) - @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield defer.ensureDeferred( + self.get_success( self.storage.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) ) - @defer.inlineCallbacks def STALE_test_room_name(self): name = "A-Room-Name" - yield self.inject_room_event( + self.inject_room_event( etype=EventTypes.Name, name=name, content={"name": name}, depth=1 ) - state = yield defer.ensureDeferred( + state = self.get_success( self.store.get_current_state(room_id=self.room.to_string()) ) @@ -137,15 +111,14 @@ class RoomEventsStoreTestCase(unittest.TestCase): state[0], ) - @defer.inlineCallbacks def STALE_test_room_topic(self): topic = "A place for things" - yield self.inject_room_event( + self.inject_room_event( etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1 ) - state = yield defer.ensureDeferred( + state = self.get_success( self.store.get_current_state(room_id=self.room.to_string()) ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 2471f1267d..f06b452fa9 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,24 +15,18 @@ import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.storage.state import StateFilter from synapse.types import RoomID, UserID -import tests.unittest -import tests.utils +from tests.unittest import HomeserverTestCase logger = logging.getLogger(__name__) -class StateStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - hs = yield tests.utils.setup_test_homeserver(self.addCleanup) - +class StateStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_datastore = self.storage.state.stores.state @@ -44,7 +38,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield defer.ensureDeferred( + self.get_success( self.store.store_room( self.room.to_string(), room_creator_user_id="@creator:text", @@ -53,7 +47,6 @@ class StateStoreTestCase(tests.unittest.TestCase): ) ) - @defer.inlineCallbacks def inject_state_event(self, room, sender, typ, state_key, content): builder = self.event_builder_factory.for_room_version( RoomVersions.V1, @@ -66,13 +59,11 @@ class StateStoreTestCase(tests.unittest.TestCase): }, ) - event, context = yield defer.ensureDeferred( + event, context = self.get_success( self.event_creation_handler.create_new_client_event(builder) ) - yield defer.ensureDeferred( - self.storage.persistence.persist_event(event, context) - ) + self.get_success(self.storage.persistence.persist_event(event, context)) return event @@ -82,16 +73,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(s1[t].event_id, s2[t].event_id) self.assertEqual(len(s1), len(s2)) - @defer.inlineCallbacks def test_get_state_groups_ids(self): - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield defer.ensureDeferred( + state_group_map = self.get_success( self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) @@ -101,16 +89,13 @@ class StateStoreTestCase(tests.unittest.TestCase): {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id}, ) - @defer.inlineCallbacks def test_get_state_groups(self): - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield defer.ensureDeferred( + state_group_map = self.get_success( self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) @@ -118,32 +103,29 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) - @defer.inlineCallbacks def test_get_state_for_event(self): # this defaults to a linear DAG as each new injection defaults to whatever # forward extremities are currently in the DB for this room. - e1 = yield self.inject_state_event( - self.room, self.u_alice, EventTypes.Create, "", {} - ) - e2 = yield self.inject_state_event( + e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {}) + e2 = self.inject_state_event( self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - e3 = yield self.inject_state_event( + e3 = self.inject_state_event( self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), {"membership": Membership.JOIN}, ) - e4 = yield self.inject_state_event( + e4 = self.inject_state_event( self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), {"membership": Membership.JOIN}, ) - e5 = yield self.inject_state_event( + e5 = self.inject_state_event( self.room, self.u_bob, EventTypes.Member, @@ -152,9 +134,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield defer.ensureDeferred( - self.storage.state.get_state_for_event(e5.event_id) - ) + state = self.get_success(self.storage.state.get_state_for_event(e5.event_id)) self.assertIsNotNone(e4) @@ -170,7 +150,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) ) @@ -179,7 +159,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) ) @@ -188,7 +168,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) ) @@ -200,7 +180,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( @@ -220,7 +200,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield defer.ensureDeferred( + state = self.get_success( self.storage.state.get_state_for_event( e5.event_id, state_filter=StateFilter( @@ -238,17 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield defer.ensureDeferred( + group_ids = self.get_success( self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] # test _get_state_for_group_using_cache correctly filters out members # with types=[] - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -265,10 +242,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -281,10 +255,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with wildcard types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -301,10 +272,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -324,10 +292,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -344,10 +309,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_dict, ) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -360,10 +322,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -413,10 +372,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] room_id = self.room.to_string() - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -428,10 +384,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) room_id = self.room.to_string() - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -444,10 +397,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # wildcard types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -458,10 +408,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -480,10 +427,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -494,10 +438,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( @@ -510,10 +451,7 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters in members # with specific types - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_cache, group, state_filter=StateFilter( @@ -524,10 +462,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.assertEqual(is_all, False) self.assertDictEqual({}, state_dict) - ( - state_dict, - is_all, - ) = yield self.state_datastore._get_state_for_group_using_cache( + (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache( self.state_datastore._state_group_members_cache, group, state_filter=StateFilter( diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a6f63f4aaf..019c5b7b14 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - -from tests import unittest -from tests.utils import setup_test_homeserver +from tests.unittest import HomeserverTestCase, override_config ALICE = "@alice:a" BOB = "@bob:b" @@ -25,73 +22,52 @@ BOBBY = "@bobby:a" BELA = "@somenickname:a" -class UserDirectoryStoreTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = self.hs.get_datastore() +class UserDirectoryStoreTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. - yield defer.ensureDeferred( - self.store.update_profile_in_user_dir(ALICE, "alice", None) - ) - yield defer.ensureDeferred( - self.store.update_profile_in_user_dir(BOB, "bob", None) - ) - yield defer.ensureDeferred( - self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - ) - yield defer.ensureDeferred( - self.store.update_profile_in_user_dir(BELA, "Bela", None) - ) - yield defer.ensureDeferred( - self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) - ) + self.get_success(self.store.update_profile_in_user_dir(ALICE, "alice", None)) + self.get_success(self.store.update_profile_in_user_dir(BOB, "bob", None)) + self.get_success(self.store.update_profile_in_user_dir(BOBBY, "bobby", None)) + self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None)) + self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB))) - @defer.inlineCallbacks def test_search_user_dir(self): # normally when alice searches the directory she should just find # bob because bobby doesn't share a room with her. - r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) + r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10)) self.assertFalse(r["limited"]) self.assertEqual(1, len(r["results"])) self.assertDictEqual( r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None} ) - @defer.inlineCallbacks + @override_config({"user_directory": {"search_all_users": True}}) def test_search_user_dir_all_users(self): - self.hs.config.user_directory_search_all_users = True - try: - r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10)) - self.assertFalse(r["limited"]) - self.assertEqual(2, len(r["results"])) - self.assertDictEqual( - r["results"][0], - {"user_id": BOB, "display_name": "bob", "avatar_url": None}, - ) - self.assertDictEqual( - r["results"][1], - {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None}, - ) - finally: - self.hs.config.user_directory_search_all_users = False + r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(2, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BOB, "display_name": "bob", "avatar_url": None}, + ) + self.assertDictEqual( + r["results"][1], + {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None}, + ) - @defer.inlineCallbacks + @override_config({"user_directory": {"search_all_users": True}}) def test_search_user_dir_stop_words(self): """Tests that a user can look up another user by searching for the start if its display name even if that name happens to be a common English word that would usually be ignored in full text searches. """ - self.hs.config.user_directory_search_all_users = True - try: - r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "be", 10)) - self.assertFalse(r["limited"]) - self.assertEqual(1, len(r["results"])) - self.assertDictEqual( - r["results"][0], - {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, - ) - finally: - self.hs.config.user_directory_search_all_users = False + r = self.get_success(self.store.search_user_dir(ALICE, "be", 10)) + self.assertFalse(r["limited"]) + self.assertEqual(1, len(r["results"])) + self.assertDictEqual( + r["results"][0], + {"user_id": BELA, "display_name": "Bela", "avatar_url": None}, + ) diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index b57f36e6ac..6a6cf709f6 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, patch +from unittest.mock import Mock, patch from synapse.util.distributor import Distributor diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index 3f2691ee6b..b5f18344dc 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py
@@ -207,6 +207,226 @@ class EventAuthTestCase(unittest.TestCase): do_sig_check=False, ) + def test_join_rules_public(self): + """ + Test joining a public room. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "public"), + } + + # Check join. + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left can re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + def test_join_rules_invite(self): + """ + Test joining an invite only room. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "invite"), + } + + # A join without an invite is rejected. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left cannot re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + def test_join_rules_msc3083_restricted(self): + """ + Test joining a restricted room from MSC3083. + + This is pretty much the same test as public. + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"), + } + + # Older room versions don't understand this join rule + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.V6, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # Check join. + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC3083, + _member_event(pleb, "join", sender=creator), + auth_events, + do_sig_check=False, + ) + + # Banned should be rejected. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user who left can re-join. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can send a join if they're in the room. + auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # A user can accept an invite. + auth_events[("m.room.member", pleb)] = _member_event( + pleb, "invite", sender=creator + ) + event_auth.check( + RoomVersions.MSC3083, + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + # helpers for making events @@ -225,19 +445,24 @@ def _create_event(user_id): ) -def _join_event(user_id): +def _member_event(user_id, membership, sender=None): return make_event_from_dict( { "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), "type": "m.room.member", - "sender": user_id, + "sender": sender or user_id, "state_key": user_id, - "content": {"membership": "join"}, + "content": {"membership": membership}, + "prev_events": [], } ) +def _join_event(user_id): + return _member_event(user_id, "join") + + def _power_levels_event(sender, content): return make_event_from_dict( { @@ -277,6 +502,21 @@ def _random_state_event(sender): ) +def _join_rules_event(sender, join_rule): + return make_event_from_dict( + { + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.join_rules", + "sender": sender, + "state_key": "", + "content": { + "join_rule": join_rule, + }, + } + ) + + event_count = 0 diff --git a/tests/test_federation.py b/tests/test_federation.py
index fc9aab32d0..382cedbd5d 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from twisted.internet.defer import succeed @@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - with LoggingContext(): + with LoggingContext("test-context"): failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True diff --git a/tests/test_mau.py b/tests/test_mau.py
index 75d28a42df..7d92a16a8d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py
@@ -15,9 +15,7 @@ """Tests REST events for /rooms paths.""" -import json - -from synapse.api.constants import LoginType +from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService from synapse.rest.client.v2_alpha import register, sync @@ -113,7 +111,7 @@ class TestMauLimit(unittest.HomeserverTestCase): ) ) - self.create_user("as_kermit4", token=as_token) + self.create_user("as_kermit4", token=as_token, appservice=True) def test_allowed_after_a_month_mau(self): # Create and sync so that the MAU counts get updated @@ -232,14 +230,15 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) - def create_user(self, localpart, token=None): - request_data = json.dumps( - { - "username": localpart, - "password": "monkey", - "auth": {"type": LoginType.DUMMY}, - } - ) + def create_user(self, localpart, token=None, appservice=False): + request_data = { + "username": localpart, + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + } + + if appservice: + request_data["type"] = APP_SERVICE_REGISTRATION_TYPE channel = self.make_request( "POST", diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py
index e7aed092c2..0f800a075b 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py
@@ -14,8 +14,7 @@ # limitations under the License. import resource - -import mock +from unittest import mock from synapse.app.phone_stats_home import phone_stats_home diff --git a/tests/test_state.py b/tests/test_state.py
index 6227a3ba95..0d626f49f6 100644 --- a/tests/test_state.py +++ b/tests/test_state.py
@@ -12,8 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from mock import Mock +from typing import List, Optional +from unittest.mock import Mock from twisted.internet import defer @@ -37,8 +37,8 @@ def create_event( state_key=None, depth=2, event_id=None, - prev_events=[], - **kwargs + prev_events: Optional[List[str]] = None, + **kwargs, ): global _next_event_id @@ -58,7 +58,7 @@ def create_event( "sender": "@user_id:example.com", "room_id": "!room_id:example.com", "depth": depth, - "prev_events": prev_events, + "prev_events": prev_events or [], } if state_key is not None: diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index a743cdc3a9..0df480db9f 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py
@@ -13,8 +13,7 @@ # limitations under the License. import json - -from mock import Mock +from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactorClock diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 43898d8142..b557ffd692 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py
@@ -21,8 +21,7 @@ import sys import warnings from asyncio import Future from typing import Any, Awaitable, Callable, TypeVar - -from mock import Mock +from unittest.mock import Mock import attr diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index c3c4a93e1f..3dfbf8f8a9 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py
@@ -33,7 +33,7 @@ async def inject_member_event( membership: str, target: Optional[str] = None, extra_content: Optional[dict] = None, - **kwargs + **kwargs, ) -> EventBase: """Inject a membership event into a room.""" if target is None: @@ -58,7 +58,7 @@ async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs + **kwargs, ) -> EventBase: """Inject a generic event into a room @@ -83,7 +83,7 @@ async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, prev_event_ids: Optional[List[str]] = None, - **kwargs + **kwargs, ) -> Tuple[EventBase, EventContext]: if room_version is None: room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"]) diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 510b630114..e502ac197e 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py
@@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging - -from mock import Mock +from typing import Optional +from unittest.mock import Mock from twisted.internet import defer from twisted.internet.defer import succeed @@ -147,9 +147,11 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): return event @defer.inlineCallbacks - def inject_room_member(self, user_id, membership="join", extra_content={}): + def inject_room_member( + self, user_id, membership="join", extra_content: Optional[dict] = None + ): content = {"membership": membership} - content.update(extra_content) + content.update(extra_content or {}) builder = self.event_builder_factory.for_room_version( RoomVersions.V1, { diff --git a/tests/unittest.py b/tests/unittest.py
index 58a4daa1ec..92764434bd 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -21,8 +21,7 @@ import inspect import logging import time from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union - -from mock import Mock, patch +from unittest.mock import Mock, patch from canonicaljson import json @@ -471,7 +470,7 @@ class HomeserverTestCase(TestCase): kwargs["config"] = config_obj async def run_bg_updates(): - with LoggingContext("run_bg_updates", request="run_bg_updates-1"): + with LoggingContext("run_bg_updates"): while not await stor.db_pool.updates.has_completed_background_updates(): await stor.db_pool.updates.do_next_background_update(1) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index afb11b9caf..8c082e7432 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py
@@ -15,8 +15,7 @@ # limitations under the License. import logging from typing import Set - -import mock +from unittest import mock from twisted.internet import defer, reactor @@ -232,8 +231,7 @@ class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup(): - with LoggingContext() as c1: - c1.name = "c1" + with LoggingContext("c1") as c1: r = yield obj.fn(1) self.assertEqual(current_context(), c1) return r @@ -275,8 +273,7 @@ class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def do_lookup(): - with LoggingContext() as c1: - c1.name = "c1" + with LoggingContext("c1") as c1: try: d = obj.fn(1) self.assertEqual( @@ -661,14 +658,13 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList("fn", "args1") async def list_fn(self, args1, arg2): - assert current_context().request == "c1" + assert current_context().name == "c1" # we want this to behave like an asynchronous function await run_on_reactor() - assert current_context().request == "c1" + assert current_context().name == "c1" return self.mock(args1, arg2) - with LoggingContext() as c1: - c1.request = "c1" + with LoggingContext("c1") as c1: obj = Cls() obj.mock.return_value = {10: "fish", 20: "chips"} d1 = obj.list_fn([10, 20], 2) diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py
index 816795c136..23018081e5 100644 --- a/tests/util/caches/test_ttlcache.py +++ b/tests/util/caches/test_ttlcache.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock +from unittest.mock import Mock from synapse.util.caches.ttlcache import TTLCache diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index 2012263184..d1372f6bc2 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py
@@ -16,8 +16,7 @@ import threading from io import StringIO - -from mock import NonCallableMock +from unittest.mock import NonCallableMock from twisted.internet import defer, reactor diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 58ee918f65..5d9c4665aa 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py
@@ -17,11 +17,10 @@ from .. import unittest class LoggingContextTestCase(unittest.TestCase): def _check_test_key(self, value): - self.assertEquals(current_context().request, value) + self.assertEquals(current_context().name, value) def test_with_context(self): - with LoggingContext() as context_one: - context_one.request = "test" + with LoggingContext("test"): self._check_test_key("test") @defer.inlineCallbacks @@ -30,15 +29,13 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def competing_callback(): - with LoggingContext() as competing_context: - competing_context.request = "competing" + with LoggingContext("competing"): yield clock.sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) - with LoggingContext() as context_one: - context_one.request = "one" + with LoggingContext("one"): yield clock.sleep(0) self._check_test_key("one") @@ -47,9 +44,7 @@ class LoggingContextTestCase(unittest.TestCase): callback_completed = [False] - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): # fire off function, but don't wait on it. d2 = run_in_background(function) @@ -133,9 +128,7 @@ class LoggingContextTestCase(unittest.TestCase): sentinel_context = current_context() - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable self.assertIs(current_context(), sentinel_context) @@ -149,9 +142,7 @@ class LoggingContextTestCase(unittest.TestCase): def test_make_deferred_yieldable_with_chained_deferreds(self): sentinel_context = current_context() - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): d1 = make_deferred_yieldable(_chained_deferred_function()) # make sure that the context was reset by make_deferred_yieldable self.assertIs(current_context(), sentinel_context) @@ -166,9 +157,7 @@ class LoggingContextTestCase(unittest.TestCase): """Check that make_deferred_yieldable does the right thing when its argument isn't actually a deferred""" - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): d1 = make_deferred_yieldable("bum") self._check_test_key("one") @@ -177,9 +166,9 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") def test_nested_logging_context(self): - with LoggingContext(request="foo"): + with LoggingContext("foo"): nested_context = nested_logging_context(suffix="bar") - self.assertEqual(nested_context.request, "foo-bar") + self.assertEqual(nested_context.name, "foo-bar") @defer.inlineCallbacks def test_make_deferred_yieldable_with_await(self): @@ -193,9 +182,7 @@ class LoggingContextTestCase(unittest.TestCase): sentinel_context = current_context() - with LoggingContext() as context_one: - context_one.request = "one" - + with LoggingContext("one"): d1 = make_deferred_yieldable(blocking_function()) # make sure that the context was reset by make_deferred_yieldable self.assertIs(current_context(), sentinel_context) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py
index a739a6aaaf..ce4f1cc30a 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py
@@ -14,7 +14,7 @@ # limitations under the License. -from mock import Mock +from unittest.mock import Mock from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py
index 4d1aee91d5..3fed55090a 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py
@@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + from synapse.config.homeserver import HomeServerConfig from synapse.util.ratelimitutils import FederationRateLimiter @@ -89,9 +91,9 @@ def _await_resolution(reactor, d): return (reactor.seconds() - start_time) * 1000 -def build_rc_config(settings={}): +def build_rc_config(settings: Optional[dict] = None): config_dict = default_config("test") - config_dict.update(settings) + config_dict.update(settings or {}) config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") return config.rc_federation diff --git a/tests/utils.py b/tests/utils.py
index 5d299f766f..65d7ad58d9 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -21,10 +21,9 @@ import time import uuid import warnings from typing import Type +from unittest.mock import Mock, patch from urllib import parse as urlparse -from mock import Mock, patch - from twisted.internet import defer from synapse.api.constants import EventTypes @@ -122,7 +121,6 @@ def default_config(name, parse=False): "enable_registration_captcha": False, "macaroon_secret_key": "not even a little secret", "trusted_third_party_id_servers": [], - "room_invite_state_types": [], "password_providers": [], "worker_replication_url": "", "worker_app": None, @@ -198,7 +196,7 @@ def setup_test_homeserver( config=None, reactor=None, homeserver_to_use: Type[HomeServer] = TestHomeServer, - **kwargs + **kwargs, ): """ Setup a homeserver suitable for running tests against. Keyword arguments