From 40e576e29cf6f06d6b5244c5d1df34cf33b1f556 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 1 May 2019 15:32:38 +0100 Subject: Move admin api impl to its own package It doesn't really belong under rest/client/v1 any more. --- tests/handlers/test_user_directory.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index f1d0aa42b6..32ef83e9c0 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -14,8 +14,9 @@ # limitations under the License. from mock import Mock +import synapse.rest.admin from synapse.api.constants import UserTypes -from synapse.rest.client.v1 import admin, login, room +from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import user_directory from synapse.storage.roommember import ProfileInfo @@ -29,7 +30,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets, room.register_servlets, ] @@ -327,7 +328,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): user_directory.register_servlets, room.register_servlets, login.register_servlets, - admin.register_servlets, + synapse.rest.admin.register_servlets, ] def make_homeserver(self, reactor, clock): -- cgit 1.5.1 From 12f9d51e826058998cb11759e068de8977ddd3d5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 2 May 2019 11:59:16 +0100 Subject: Add admin api for sending server_notices (#5121) --- changelog.d/5121.feature | 1 + docs/admin_api/server_notices.md | 48 ++++++++++++ docs/server_notices.md | 25 ++---- synapse/rest/__init__.py | 4 +- synapse/rest/admin/__init__.py | 17 +++- synapse/rest/admin/server_notice_servlet.py | 100 ++++++++++++++++++++++++ tests/handlers/test_user_directory.py | 4 +- tests/push/test_email.py | 2 +- tests/push/test_http.py | 2 +- tests/rest/admin/test_admin.py | 8 +- tests/rest/client/test_consent.py | 2 +- tests/rest/client/test_identity.py | 2 +- tests/rest/client/v1/test_events.py | 2 +- tests/rest/client/v1/test_login.py | 2 +- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v2_alpha/test_auth.py | 2 +- tests/rest/client/v2_alpha/test_capabilities.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 4 +- tests/rest/client/v2_alpha/test_sync.py | 2 +- tests/server_notices/test_consent.py | 2 +- tests/storage/test_client_ips.py | 5 +- 21 files changed, 196 insertions(+), 42 deletions(-) create mode 100644 changelog.d/5121.feature create mode 100644 docs/admin_api/server_notices.md create mode 100644 synapse/rest/admin/server_notice_servlet.py (limited to 'tests/handlers') diff --git a/changelog.d/5121.feature b/changelog.d/5121.feature new file mode 100644 index 0000000000..54b228680d --- /dev/null +++ b/changelog.d/5121.feature @@ -0,0 +1 @@ +Implement an admin API for sending server notices. Many thanks to @krombel who provided a foundation for this work. diff --git a/docs/admin_api/server_notices.md b/docs/admin_api/server_notices.md new file mode 100644 index 0000000000..5ddd21cfb2 --- /dev/null +++ b/docs/admin_api/server_notices.md @@ -0,0 +1,48 @@ +# Server Notices + +The API to send notices is as follows: + +``` +POST /_synapse/admin/v1/send_server_notice +``` + +or: + +``` +PUT /_synapse/admin/v1/send_server_notice/{txnId} +``` + +You will need to authenticate with an access token for an admin user. + +When using the `PUT` form, retransmissions with the same transaction ID will be +ignored in the same way as with `PUT +/_matrix/client/r0/rooms/{roomId}/send/{eventType}/{txnId}`. + +The request body should look something like the following: + +```json +{ + "user_id": "@target_user:server_name", + "content": { + "msgtype": "m.text", + "body": "This is my message" + } +} +``` + +You can optionally include the following additional parameters: + +* `type`: the type of event. Defaults to `m.room.message`. +* `state_key`: Setting this will result in a state event being sent. + + +Once the notice has been sent, the APU will return the following response: + +```json +{ + "event_id": "" +} +``` + +Note that server notices must be enabled in `homeserver.yaml` before this API +can be used. See [server_notices.md](../server_notices.md) for more information. diff --git a/docs/server_notices.md b/docs/server_notices.md index 58f8776319..950a6608e9 100644 --- a/docs/server_notices.md +++ b/docs/server_notices.md @@ -1,5 +1,4 @@ -Server Notices -============== +# Server Notices 'Server Notices' are a new feature introduced in Synapse 0.30. They provide a channel whereby server administrators can send messages to users on the server. @@ -11,8 +10,7 @@ they may also find a use for features such as "Message of the day". This is a feature specific to Synapse, but it uses standard Matrix communication mechanisms, so should work with any Matrix client. -User experience ---------------- +## User experience When the user is first sent a server notice, they will get an invitation to a room (typically called 'Server Notices', though this is configurable in @@ -29,8 +27,7 @@ levels. Having joined the room, the user can leave the room if they want. Subsequent server notices will then cause a new room to be created. -Synapse configuration ---------------------- +## Synapse configuration Server notices come from a specific user id on the server. Server administrators are free to choose the user id - something like `server` is @@ -58,17 +55,7 @@ room which will be created. `system_mxid_display_name` and `system_mxid_avatar_url` can be used to set the displayname and avatar of the Server Notices user. -Sending notices ---------------- +## Sending notices -As of the current version of synapse, there is no convenient interface for -sending notices (other than the automated ones sent as part of consent -tracking). - -In the meantime, it is possible to test this feature using the manhole. Having -gone into the manhole as described in [manhole.md](manhole.md), a notice can be -sent with something like: - -``` ->>> hs.get_server_notices_manager().send_notice('@user:server.com', {'msgtype':'m.text', 'body':'foo'}) -``` +To send server notices to users you can use the +[admin_api](admin_api/server_notices.md). diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index e8e1bcddea..3a24d31d1b 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -117,4 +117,6 @@ class ClientRestResource(JsonResource): account_validity.register_servlets(hs, client_resource) # moving to /_synapse/admin - synapse.rest.admin.register_servlets(hs, client_resource) + synapse.rest.admin.register_servlets_for_client_rest_resource( + hs, client_resource + ) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index d02f5198b8..0ce89741f0 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -37,6 +37,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.rest.admin._base import assert_requester_is_admin, assert_user_is_admin +from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.types import UserID, create_requester from synapse.util.versionstring import get_version_string @@ -813,16 +814,26 @@ class AccountValidityRenewServlet(RestServlet): } defer.returnValue((200, res)) +######################################################################################## +# +# please don't add more servlets here: this file is already long and unwieldy. Put +# them in separate files within the 'admin' package. +# +######################################################################################## + class AdminRestResource(JsonResource): """The REST resource which gets mounted at /_synapse/admin""" def __init__(self, hs): JsonResource.__init__(self, hs, canonical_json=False) - register_servlets(hs, self) + + register_servlets_for_client_rest_resource(hs, self) + SendServerNoticeServlet(hs).register(self) -def register_servlets(hs, http_server): +def register_servlets_for_client_rest_resource(hs, http_server): + """Register only the servlets which need to be exposed on /_matrix/client/xxx""" WhoisRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) @@ -839,3 +850,5 @@ def register_servlets(hs, http_server): VersionServlet(hs).register(http_server) DeleteGroupAdminRestServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) + # don't add more things here: new servlets should only be exposed on + # /_synapse/admin so should not go here. Instead register them in AdminRestResource. diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py new file mode 100644 index 0000000000..ae5aca9dac --- /dev/null +++ b/synapse/rest/admin/server_notice_servlet.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.rest.admin import assert_requester_is_admin +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import UserID + + +class SendServerNoticeServlet(RestServlet): + """Servlet which will send a server notice to a given user + + POST /_synapse/admin/v1/send_server_notice + { + "user_id": "@target_user:server_name", + "content": { + "msgtype": "m.text", + "body": "This is my message" + } + } + + returns: + + { + "event_id": "$1895723857jgskldgujpious" + } + """ + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.hs = hs + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) + self.snm = hs.get_server_notices_manager() + + def register(self, json_resource): + PATTERN = "^/_synapse/admin/v1/send_server_notice" + json_resource.register_paths( + "POST", + (re.compile(PATTERN + "$"), ), + self.on_POST, + ) + json_resource.register_paths( + "PUT", + (re.compile(PATTERN + "/(?P[^/]*)$",), ), + self.on_PUT, + ) + + @defer.inlineCallbacks + def on_POST(self, request, txn_id=None): + yield assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ("user_id", "content")) + event_type = body.get("type", EventTypes.Message) + state_key = body.get("state_key") + + if not self.snm.is_enabled(): + raise SynapseError(400, "Server notices are not enabled on this server") + + user_id = body["user_id"] + UserID.from_string(user_id) + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Server notices can only be sent to local users") + + event = yield self.snm.send_notice( + user_id=body["user_id"], + type=event_type, + state_key=state_key, + event_content=body["content"], + ) + + defer.returnValue((200, {"event_id": event.event_id})) + + def on_PUT(self, request, txn_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, txn_id, + ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 32ef83e9c0..7dd1a1daf8 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -30,7 +30,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, ] @@ -328,7 +328,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): user_directory.register_servlets, room.register_servlets, login.register_servlets, - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, ] def make_homeserver(self, reactor, clock): diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e29bd18ad7..325ea449ae 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -34,7 +34,7 @@ class EmailPusherTests(HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 3f9f56bb79..13bd2c8688 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -33,7 +33,7 @@ class HTTPPusherTests(HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 42858b5fea..db4cfd8550 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -30,7 +30,7 @@ from tests import unittest class VersionTestCase(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, ] @@ -63,7 +63,7 @@ class VersionTestCase(unittest.HomeserverTestCase): class UserRegisterTestCase(unittest.HomeserverTestCase): - servlets = [synapse.rest.admin.register_servlets] + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] def make_homeserver(self, reactor, clock): @@ -359,7 +359,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): class ShutdownRoomTestCase(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, events.register_servlets, room.register_servlets, @@ -496,7 +496,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): class DeleteGroupTestCase(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, groups.register_servlets, ] diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 36e6c1c67d..5528971190 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -32,7 +32,7 @@ except Exception: class ConsentResourceTestCase(unittest.HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index d4fe0aee7d..2e51ffa418 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -24,7 +24,7 @@ from tests import unittest class IdentityTestCase(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 5cb1c1ae9f..8a9a55a527 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -29,7 +29,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): servlets = [ events.register_servlets, room.register_servlets, - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, ] diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 8d9ef877f6..9ebd91f678 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -11,7 +11,7 @@ LOGIN_URL = b"/_matrix/client/r0/login" class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, ] diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 1a34924f3e..521ac80f9a 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -804,7 +804,7 @@ class RoomMessageListTestCase(RoomBase): class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, ] diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 67021185d0..0ca3c4657b 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -27,7 +27,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ auth.register_servlets, - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, register.register_servlets, ] hijack_auth = False diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index 8134163e20..f3ef977404 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -23,7 +23,7 @@ from tests import unittest class CapabilitiesTestCase(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, capabilities.register_servlets, login.register_servlets, ] diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 4d698af03a..1c3a621d26 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -199,7 +199,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): servlets = [ register.register_servlets, - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, account_validity.register_servlets, @@ -308,7 +308,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): skip = "No Jinja installed" if not load_jinja2_templates else None servlets = [ register.register_servlets, - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, sync.register_servlets, account_validity.register_servlets, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 65fac1d5ce..71895094bd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -73,7 +73,7 @@ class FilterTestCase(unittest.HomeserverTestCase): class SyncTypingTests(unittest.HomeserverTestCase): servlets = [ - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, room.register_servlets, login.register_servlets, sync.register_servlets, diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index e8b8ac5725..e0b4e0eb63 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -23,7 +23,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): servlets = [ sync.register_servlets, - synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, login.register_servlets, room.register_servlets, ] diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index b0f6fd34d8..b62eae7abc 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -206,7 +206,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): class ClientIpAuthTestCase(unittest.HomeserverTestCase): - servlets = [synapse.rest.admin.register_servlets, login.register_servlets] + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver() -- cgit 1.5.1 From b36c82576e3bb7ea72600ecf0e80c904ccf47d1d Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Fri, 10 May 2019 00:12:11 -0500 Subject: Run Black on the tests again (#5170) --- changelog.d/5170.misc | 1 + tests/api/test_filtering.py | 1 - tests/api/test_ratelimiting.py | 6 +- tests/app/test_openid_listener.py | 46 ++-- tests/config/test_generate.py | 8 +- tests/config/test_room_directory.py | 178 +++++++------ tests/config/test_server.py | 1 - tests/config/test_tls.py | 11 +- tests/crypto/test_keyring.py | 3 +- tests/federation/test_federation_sender.py | 113 +++++---- tests/handlers/test_directory.py | 24 +- tests/handlers/test_e2e_room_keys.py | 277 +++++++++++---------- tests/handlers/test_presence.py | 14 +- tests/handlers/test_typing.py | 152 +++++------ tests/handlers/test_user_directory.py | 8 +- tests/http/__init__.py | 6 +- .../federation/test_matrix_federation_agent.py | 212 ++++++---------- tests/http/federation/test_srv_resolver.py | 36 +-- tests/http/test_fedclient.py | 31 +-- tests/patch_inline_callbacks.py | 16 +- tests/replication/slave/storage/_base.py | 15 +- tests/replication/slave/storage/test_events.py | 19 +- tests/replication/tcp/streams/_base.py | 6 +- tests/rest/admin/test_admin.py | 103 +++----- tests/rest/client/test_identity.py | 8 +- tests/rest/client/v1/test_directory.py | 53 ++-- tests/rest/client/v1/test_login.py | 36 +-- tests/rest/client/v1/test_profile.py | 27 +- tests/rest/client/v2_alpha/test_register.py | 82 +++--- tests/rest/media/v1/test_base.py | 14 +- tests/rest/test_well_known.py | 13 +- tests/server.py | 7 +- .../test_resource_limits_server_notices.py | 1 - tests/state/test_v2.py | 204 +++------------ tests/storage/test_background_update.py | 4 +- tests/storage/test_base.py | 5 +- tests/storage/test_end_to_end_keys.py | 1 - tests/storage/test_monthly_active_users.py | 17 +- tests/storage/test_redaction.py | 6 +- tests/storage/test_registration.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/storage/test_state.py | 83 +++--- tests/storage/test_user_directory.py | 4 +- tests/test_event_auth.py | 8 +- tests/test_federation.py | 1 - tests/test_mau.py | 8 +- tests/test_metrics.py | 56 +++-- tests/test_terms_auth.py | 8 +- tests/test_types.py | 6 +- tests/test_utils/logging_setup.py | 4 +- tests/test_visibility.py | 6 +- tests/unittest.py | 13 +- tests/util/test_async_utils.py | 23 +- tests/utils.py | 9 +- 54 files changed, 829 insertions(+), 1169 deletions(-) create mode 100644 changelog.d/5170.misc (limited to 'tests/handlers') diff --git a/changelog.d/5170.misc b/changelog.d/5170.misc new file mode 100644 index 0000000000..7919dac555 --- /dev/null +++ b/changelog.d/5170.misc @@ -0,0 +1 @@ +Run `black` on the tests directory. diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2a7044801a..6ba623de13 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -109,7 +109,6 @@ class FilteringTestCase(unittest.TestCase): "event_format": "client", "event_fields": ["type", "content", "sender"], }, - # a single backslash should be permitted (though it is debatable whether # it should be permitted before anything other than `.`, and what that # actually means) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 30a255d441..dbdd427cac 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -10,19 +10,19 @@ class TestRatelimiter(unittest.TestCase): key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) - self.assertEquals(10., time_allowed) + self.assertEquals(10.0, time_allowed) allowed, time_allowed = limiter.can_do_action( key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 ) self.assertFalse(allowed) - self.assertEquals(10., time_allowed) + self.assertEquals(10.0, time_allowed) allowed, time_allowed = limiter.can_do_action( key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) - self.assertEquals(20., time_allowed) + self.assertEquals(20.0, time_allowed) def test_pruning(self): limiter = Ratelimiter() diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 590abc1e92..48792d1480 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -25,16 +25,18 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=FederationReaderServer, + http_client=None, homeserverToUse=FederationReaderServer ) return hs - @parameterized.expand([ - (["federation"], "auth_fail"), - ([], "no_resource"), - (["openid", "federation"], "auth_fail"), - (["openid"], "auth_fail"), - ]) + @parameterized.expand( + [ + (["federation"], "auth_fail"), + ([], "no_resource"), + (["openid", "federation"], "auth_fail"), + (["openid"], "auth_fail"), + ] + ) def test_openid_listener(self, names, expectation): """ Test different openid listener configurations. @@ -53,17 +55,14 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = ( - site.resource.children[b"_matrix"].children[b"federation"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise request, channel = self.make_request( - "GET", - "/_matrix/federation/v1/openid/userinfo", + "GET", "/_matrix/federation/v1/openid/userinfo" ) self.render(request) @@ -74,16 +73,18 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - http_client=None, homeserverToUse=SynapseHomeServer, + http_client=None, homeserverToUse=SynapseHomeServer ) return hs - @parameterized.expand([ - (["federation"], "auth_fail"), - ([], "no_resource"), - (["openid", "federation"], "auth_fail"), - (["openid"], "auth_fail"), - ]) + @parameterized.expand( + [ + (["federation"], "auth_fail"), + ([], "no_resource"), + (["openid", "federation"], "auth_fail"), + (["openid"], "auth_fail"), + ] + ) def test_openid_listener(self, names, expectation): """ Test different openid listener configurations. @@ -102,17 +103,14 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): # Grab the resource from the site that was told to listen site = self.reactor.tcpServers[0][1] try: - self.resource = ( - site.resource.children[b"_matrix"].children[b"federation"] - ) + self.resource = site.resource.children[b"_matrix"].children[b"federation"] except KeyError: if expectation == "no_resource": return raise request, channel = self.make_request( - "GET", - "/_matrix/federation/v1/openid/userinfo", + "GET", "/_matrix/federation/v1/openid/userinfo" ) self.render(request) diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 795b4c298d..5017cbce85 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -45,13 +45,7 @@ class ConfigGenerationTestCase(unittest.TestCase): ) self.assertSetEqual( - set( - [ - "homeserver.yaml", - "lemurs.win.log.config", - "lemurs.win.signing.key", - ] - ), + set(["homeserver.yaml", "lemurs.win.log.config", "lemurs.win.signing.key"]), set(os.listdir(self.dir)), ) diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py index 47fffcfeb2..0ec10019b3 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py @@ -22,7 +22,8 @@ from tests import unittest class RoomDirectoryConfigTestCase(unittest.TestCase): def test_alias_creation_acl(self): - config = yaml.safe_load(""" + config = yaml.safe_load( + """ alias_creation_rules: - user_id: "*bob*" alias: "*" @@ -38,43 +39,49 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): action: "allow" room_list_publication_rules: [] - """) + """ + ) rd_config = RoomDirectoryConfig() rd_config.read_config(config) - self.assertFalse(rd_config.is_alias_creation_allowed( - user_id="@bob:example.com", - room_id="!test", - alias="#test:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@test:example.com", - room_id="!test", - alias="#unofficial_st:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@foobar:example.com", - room_id="!test", - alias="#test:example.com", - )) - - self.assertTrue(rd_config.is_alias_creation_allowed( - user_id="@gah:example.com", - room_id="!test", - alias="#goo:example.com", - )) - - self.assertFalse(rd_config.is_alias_creation_allowed( - user_id="@test:example.com", - room_id="!test", - alias="#test:example.com", - )) + self.assertFalse( + rd_config.is_alias_creation_allowed( + user_id="@bob:example.com", room_id="!test", alias="#test:example.com" + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@test:example.com", + room_id="!test", + alias="#unofficial_st:example.com", + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@foobar:example.com", + room_id="!test", + alias="#test:example.com", + ) + ) + + self.assertTrue( + rd_config.is_alias_creation_allowed( + user_id="@gah:example.com", room_id="!test", alias="#goo:example.com" + ) + ) + + self.assertFalse( + rd_config.is_alias_creation_allowed( + user_id="@test:example.com", room_id="!test", alias="#test:example.com" + ) + ) def test_room_publish_acl(self): - config = yaml.safe_load(""" + config = yaml.safe_load( + """ alias_creation_rules: [] room_list_publication_rules: @@ -92,55 +99,66 @@ class RoomDirectoryConfigTestCase(unittest.TestCase): action: "allow" - room_id: "!test-deny" action: "deny" - """) + """ + ) rd_config = RoomDirectoryConfig() rd_config.read_config(config) - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@bob:example.com", - room_id="!test", - aliases=["#test:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#unofficial_st:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@foobar:example.com", - room_id="!test", - aliases=[], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@gah:example.com", - room_id="!test", - aliases=["#goo:example.com"], - )) - - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#test:example.com"], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@foobar:example.com", - room_id="!test-deny", - aliases=[], - )) - - self.assertFalse(rd_config.is_publishing_room_allowed( - user_id="@gah:example.com", - room_id="!test-deny", - aliases=[], - )) - - self.assertTrue(rd_config.is_publishing_room_allowed( - user_id="@test:example.com", - room_id="!test", - aliases=["#unofficial_st:example.com", "#blah:example.com"], - )) + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@bob:example.com", + room_id="!test", + aliases=["#test:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#unofficial_st:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@foobar:example.com", room_id="!test", aliases=[] + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@gah:example.com", + room_id="!test", + aliases=["#goo:example.com"], + ) + ) + + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#test:example.com"], + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@foobar:example.com", room_id="!test-deny", aliases=[] + ) + ) + + self.assertFalse( + rd_config.is_publishing_room_allowed( + user_id="@gah:example.com", room_id="!test-deny", aliases=[] + ) + ) + + self.assertTrue( + rd_config.is_publishing_room_allowed( + user_id="@test:example.com", + room_id="!test", + aliases=["#unofficial_st:example.com", "#blah:example.com"], + ) + ) diff --git a/tests/config/test_server.py b/tests/config/test_server.py index f5836d73ac..de64965a60 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -19,7 +19,6 @@ from tests import unittest class ServerConfigTestCase(unittest.TestCase): - def test_is_threepid_reserved(self): user1 = {'medium': 'email', 'address': 'user1@example.com'} user2 = {'medium': 'email', 'address': 'user2@example.com'} diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index c260d3359f..40ca428778 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -26,7 +26,6 @@ class TestConfig(TlsConfig): class TLSConfigTests(TestCase): - def test_warn_self_signed(self): """ Synapse will give a warning when it loads a self-signed certificate. @@ -34,7 +33,8 @@ class TLSConfigTests(TestCase): config_dir = self.mktemp() os.mkdir(config_dir) with open(os.path.join(config_dir, "cert.pem"), 'w') as f: - f.write("""-----BEGIN CERTIFICATE----- + f.write( + """-----BEGIN CERTIFICATE----- MIID6DCCAtACAws9CjANBgkqhkiG9w0BAQUFADCBtzELMAkGA1UEBhMCVFIxDzAN BgNVBAgMBsOHb3J1bTEUMBIGA1UEBwwLQmHFn21ha8OnxLExEjAQBgNVBAMMCWxv Y2FsaG9zdDEcMBoGA1UECgwTVHdpc3RlZCBNYXRyaXggTGFiczEkMCIGA1UECwwb @@ -56,11 +56,12 @@ I8OtG1xGwcok53lyDuuUUDexnK4O5BkjKiVlNPg4HPim5Kuj2hRNFfNt/F2BVIlj iZupikC5MT1LQaRwidkSNxCku1TfAyueiBwhLnFwTmIGNnhuDCutEVAD9kFmcJN2 SznugAcPk4doX2+rL+ila+ThqgPzIkwTUHtnmjI0TI6xsDUlXz5S3UyudrE2Qsfz s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= ------END CERTIFICATE-----""") +-----END CERTIFICATE-----""" + ) config = { "tls_certificate_path": os.path.join(config_dir, "cert.pem"), - "tls_fingerprints": [] + "tls_fingerprints": [], } t = TestConfig() @@ -75,5 +76,5 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= "Self-signed TLS certificates will not be accepted by " "Synapse 1.0. Please either provide a valid certificate, " "or use Synapse's ACME support to provision one." - ) + ), ) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f5bd7a1aa1..3c79d4afe7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -169,7 +169,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.http_client.post_json.return_value = defer.Deferred() res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1, )] + [("server10", json1)] ) res_deferreds_2[0].addBoth(self.check_context, None) yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) @@ -345,6 +345,7 @@ def _verify_json_for_server(keyring, server_name, json_object): """thin wrapper around verify_json_for_server which makes sure it is wrapped with the patched defer.inlineCallbacks. """ + @defer.inlineCallbacks def v(): rv1 = yield keyring.verify_json_for_server(server_name, json_object) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 28e7e27416..7bb106b5f7 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -33,11 +33,15 @@ class FederationSenderTestCases(HomeserverTestCase): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) mock_send_transaction.return_value = defer.succeed({}) sender = self.hs.get_federation_sender() - receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() @@ -46,21 +50,24 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but @@ -68,11 +75,15 @@ class FederationSenderTestCases(HomeserverTestCase): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] - mock_send_transaction = self.hs.get_federation_transport_client().send_transaction + mock_send_transaction = ( + self.hs.get_federation_transport_client().send_transaction + ) mock_send_transaction.return_value = defer.succeed({}) sender = self.hs.get_federation_sender() - receipt = ReadReceipt("room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() @@ -81,25 +92,30 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['event_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['event_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) mock_send_transaction.reset_mock() # send the second RR - receipt = ReadReceipt("room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}) + receipt = ReadReceipt( + "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} + ) self.successResultOf(sender.send_read_receipt(receipt)) self.pump() mock_send_transaction.assert_not_called() @@ -111,18 +127,21 @@ class FederationSenderTestCases(HomeserverTestCase): mock_send_transaction.assert_called_once() json_cb = mock_send_transaction.call_args[0][1] data = json_cb() - self.assertEqual(data['edus'], [ - { - 'edu_type': 'm.receipt', - 'content': { - 'room_id': { - 'm.read': { - 'user_id': { - 'event_ids': ['other_id'], - 'data': {'ts': 1234}, - }, - }, + self.assertEqual( + data['edus'], + [ + { + 'edu_type': 'm.receipt', + 'content': { + 'room_id': { + 'm.read': { + 'user_id': { + 'event_ids': ['other_id'], + 'data': {'ts': 1234}, + } + } + } }, - }, - }, - ]) + } + ], + ) diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 5b2105bc76..917548bb31 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -115,11 +115,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): # We cheekily override the config to add custom alias creation rules config = {} config["alias_creation_rules"] = [ - { - "user_id": "*", - "alias": "#unofficial_*", - "action": "allow", - } + {"user_id": "*", "alias": "#unofficial_*", "action": "allow"} ] config["room_list_publication_rules"] = [] @@ -162,9 +158,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", - b"directory/list/room/%s" % (room_id.encode('ascii'),), - b'{}', + "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -179,10 +173,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = True # Room list is enabled so we should get some results - request, channel = self.make_request( - "GET", - b"publicRooms", - ) + request, channel = self.make_request("GET", b"publicRooms") self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) > 0) @@ -191,10 +182,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.directory_handler.enable_room_list_search = False # Room list disabled so we should get no results - request, channel = self.make_request( - "GET", - b"publicRooms", - ) + request, channel = self.make_request("GET", b"publicRooms") self.render(request) self.assertEquals(200, channel.code, channel.result) self.assertTrue(len(channel.json_body["chunk"]) == 0) @@ -202,9 +190,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): # Room list disabled so we shouldn't be allowed to publish rooms room_id = self.helper.create_room_as(self.user_id) request, channel = self.make_request( - "PUT", - b"directory/list/room/%s" % (room_id.encode('ascii'),), - b'{}', + "PUT", b"directory/list/room/%s" % (room_id.encode('ascii'),), b'{}' ) self.render(request) self.assertEquals(403, channel.code, channel.result) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 1c49bbbc3c..2e72a1dd23 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -36,7 +36,7 @@ room_keys = { "first_message_index": 1, "forwarded_count": 1, "is_verified": False, - "session_data": "SSBBTSBBIEZJU0gK" + "session_data": "SSBBTSBBIEZJU0gK", } } } @@ -47,15 +47,13 @@ room_keys = { class E2eRoomKeysHandlerTestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(E2eRoomKeysHandlerTestCase, self).__init__(*args, **kwargs) - self.hs = None # type: synapse.server.HomeServer + self.hs = None # type: synapse.server.HomeServer self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - self.addCleanup, - handlers=None, - replication_layer=mock.Mock(), + self.addCleanup, handlers=None, replication_layer=mock.Mock() ) self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) self.local_user = "@boris:" + self.hs.hostname @@ -88,67 +86,86 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_create_version(self): """Check that we can create and then retrieve versions. """ - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(res, "1") # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "version": "1", - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") - self.assertDictEqual(res, { - "version": "1", - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "1", + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }, + ) # upload a new one... - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) self.assertEqual(res, "2") # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "version": "2", - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + self.assertDictEqual( + res, + { + "version": "2", + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) @defer.inlineCallbacks def test_update_version(self): """Check that we can update versions. """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") - res = yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version - }) + res = yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) self.assertDictEqual(res, {}) # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) - self.assertDictEqual(res, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": version - }) + self.assertDictEqual( + res, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version, + }, + ) @defer.inlineCallbacks def test_update_missing_version(self): @@ -156,11 +173,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.update_version(self.local_user, "1", { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "1" - }) + yield self.handler.update_version( + self.local_user, + "1", + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -170,29 +191,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 400 if the version in the body is missing or doesn't match """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = None try: - yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data" - }) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) res = None try: - yield self.handler.update_version(self.local_user, version, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "revised_first_version_auth_data", - "version": "incorrect" - }) + yield self.handler.update_version( + self.local_user, + version, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect", + }, + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 400) @@ -223,10 +252,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_version(self): """Check that we can create and then delete versions. """ - res = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + res = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(res, "1") # check we can delete it @@ -255,16 +284,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_get_missing_room_keys(self): """Check we get an empty response from an empty backup """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = yield self.handler.get_room_keys(self.local_user, version) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # TODO: test the locking semantics when uploading room_keys, # although this is probably best done in sytest @@ -275,7 +302,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """ res = None try: - yield self.handler.upload_room_keys(self.local_user, "no_version", room_keys) + yield self.handler.upload_room_keys( + self.local_user, "no_version", room_keys + ) except errors.SynapseError as e: res = e.code self.assertEqual(res, 404) @@ -285,10 +314,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): """Check that we get a 404 on uploading keys when an nonexistent version is specified """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") res = None @@ -304,16 +333,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_wrong_version(self): """Check that we get a 403 on uploading keys for an old version """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "second_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "second_version_auth_data", + }, + ) self.assertEqual(version, "2") res = None @@ -327,10 +359,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_insert(self): """Check that we can insert and retrieve keys for a session """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -340,18 +372,13 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check getting room_keys for a given room res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org" + self.local_user, version, room_id="!abc:matrix.org" ) self.assertDictEqual(res, room_keys) # check getting room_keys for a given session_id res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) self.assertDictEqual(res, room_keys) @@ -359,10 +386,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_upload_room_keys_merge(self): """Check that we can upload a new room_key for an existing session and have it correctly merged""" - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") yield self.handler.upload_room_keys(self.local_user, version, room_keys) @@ -378,7 +405,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "SSBBTSBBIEZJU0gK" + "SSBBTSBBIEZJU0gK", ) # test that marking the session as verified however /does/ replace it @@ -387,8 +414,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "new" + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" ) # test that a session with a higher forwarded_count doesn't replace one @@ -399,8 +425,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.get_room_keys(self.local_user, version) self.assertEqual( - res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], - "new" + res['rooms']['!abc:matrix.org']['sessions']['c0ff33']['session_data'], "new" ) # TODO: check edge cases as well as the common variations here @@ -409,56 +434,36 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): def test_delete_room_keys(self): """Check that we can insert and delete keys for a session """ - version = yield self.handler.create_version(self.local_user, { - "algorithm": "m.megolm_backup.v1", - "auth_data": "first_version_auth_data", - }) + version = yield self.handler.create_version( + self.local_user, + {"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"}, + ) self.assertEqual(version, "1") # check for bulk-delete yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys(self.local_user, version) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per room yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", + self.local_user, version, room_id="!abc:matrix.org" ) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) # check for bulk-delete per session yield self.handler.upload_room_keys(self.local_user, version, room_keys) yield self.handler.delete_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) res = yield self.handler.get_room_keys( - self.local_user, - version, - room_id="!abc:matrix.org", - session_id="c0ff33", + self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" ) - self.assertDictEqual(res, { - "rooms": {} - }) + self.assertDictEqual(res, {"rooms": {}}) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 94c6080e34..f70c6e7d65 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -424,8 +424,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "server", http_client=None, - federation_sender=Mock(), + "server", http_client=None, federation_sender=Mock() ) return hs @@ -457,7 +456,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Mark test2 as online, test will be offline with a last_active of 0 self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -506,13 +505,13 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Mark test as online self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} ) # Mark test2 as online, test will be offline with a last_active of 0. # Note we don't join them to the room yet self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}, + UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} ) # Add servers to the room @@ -541,8 +540,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) self.assertEqual(expected_state.state, PresenceState.ONLINE) self.federation_sender.send_presence_to_destinations.assert_called_once_with( - destinations=set(("server2", "server3")), - states=[expected_state] + destinations=set(("server2", "server3")), states=[expected_state] ) def _add_new_user(self, room_id, user_id): @@ -565,7 +563,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): type=EventTypes.Member, sender=user_id, state_key=user_id, - content={"membership": Membership.JOIN} + content={"membership": Membership.JOIN}, ) prev_event_ids = self.get_success( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5a0b6c201c..cb8b4d2913 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,20 +64,22 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) hs = self.setup_test_homeserver( - datastore=(Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_devices_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - )), + datastore=( + Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + ] + ) + ), notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring, @@ -87,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): # the tests assume that we are starting at unix time 1000 - reactor.pump((1000, )) + reactor.pump((1000,)) mock_notifier = hs.get_notifier() self.on_new_event = mock_notifier.on_new_event @@ -114,6 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + hs.get_auth().check_joined_room = check_joined_room def get_joined_hosts_for_room(room_id): @@ -123,6 +126,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def get_current_users_in_room(room_id): return set(str(u) for u in self.room_members) + hs.get_state_handler().get_current_users_in_room = get_current_users_in_room self.datastore.get_user_directory_stream_pos.return_value = ( @@ -141,21 +145,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=20000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + ) ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -170,12 +169,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=20000, - )) + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=20000 + ) + ) put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( @@ -216,14 +214,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] - ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -247,14 +241,14 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.stopped_typing( - target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.stopped_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID + ) ) + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) + put_json = self.hs.get_http_client().put_json put_json.assert_called_once_with( "farm", @@ -274,18 +268,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], - [ - { - "type": "m.typing", - "room_id": ROOM_ID, - "content": {"user_ids": []}, - } - ], + [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) def test_typing_timeout(self): @@ -293,22 +279,17 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.assertEquals(self.event_source.get_current_key(), 0) - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=10000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 1, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + ) ) + + self.on_new_event.assert_has_calls([call('typing_key', 1, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ @@ -320,45 +301,30 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) - self.reactor.pump([16, ]) + self.reactor.pump([16]) - self.on_new_event.assert_has_calls( - [call('typing_key', 2, rooms=[ROOM_ID])] - ) + self.on_new_event.assert_has_calls([call('typing_key', 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=1 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) self.assertEquals( events[0], - [ - { - "type": "m.typing", - "room_id": ROOM_ID, - "content": {"user_ids": []}, - } - ], + [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], ) # SYN-230 - see if we can still set after timeout - self.successResultOf(self.handler.started_typing( - target_user=U_APPLE, - auth_user=U_APPLE, - room_id=ROOM_ID, - timeout=10000, - )) - - self.on_new_event.assert_has_calls( - [call('typing_key', 3, rooms=[ROOM_ID])] + self.successResultOf( + self.handler.started_typing( + target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID, timeout=10000 + ) ) + + self.on_new_event.assert_has_calls([call('typing_key', 3, rooms=[ROOM_ID])]) self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events( - room_ids=[ROOM_ID], from_key=0 - ) + events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) self.assertEquals( events[0], [ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 7dd1a1daf8..44468f5382 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -352,9 +352,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Assert user directory is not empty request, channel = self.make_request( - "POST", - b"user_directory/search", - b'{"search_term":"user2"}', + "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) @@ -363,9 +361,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): # Disable user directory and check search returns nothing self.config.user_directory_search_enabled = False request, channel = self.make_request( - "POST", - b"user_directory/search", - b'{"search_term":"user2"}', + "POST", b"user_directory/search", b'{"search_term":"user2"}' ) self.render(request) self.assertEquals(200, channel.code, channel.result) diff --git a/tests/http/__init__.py b/tests/http/__init__.py index ee8010f598..851fc0eb33 100644 --- a/tests/http/__init__.py +++ b/tests/http/__init__.py @@ -24,14 +24,12 @@ def get_test_cert_file(): # # openssl req -x509 -newkey rsa:4096 -keyout server.pem -out server.pem -days 36500 \ # -nodes -subj '/CN=testserv' - return os.path.join( - os.path.dirname(__file__), - 'server.pem', - ) + return os.path.join(os.path.dirname(__file__), 'server.pem') class ServerTLSContext(object): """A TLS Context which presents our test cert.""" + def __init__(self): self.filename = get_test_cert_file() diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index e9eb662c4c..7036615041 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -79,12 +79,12 @@ class MatrixFederationAgentTests(TestCase): # stubbing that out here. client_protocol = client_factory.buildProtocol(None) client_protocol.makeConnection( - FakeTransport(server_tls_protocol, self.reactor, client_protocol), + FakeTransport(server_tls_protocol, self.reactor, client_protocol) ) # tell the server tls protocol to send its stuff back to the client, too server_tls_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, server_tls_protocol), + FakeTransport(client_protocol, self.reactor, server_tls_protocol) ) # give the reactor a pump to get the TLS juices flowing. @@ -125,7 +125,7 @@ class MatrixFederationAgentTests(TestCase): _check_logcontext(context) def _handle_well_known_connection( - self, client_factory, expected_sni, content, response_headers={}, + self, client_factory, expected_sni, content, response_headers={} ): """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. @@ -139,8 +139,7 @@ class MatrixFederationAgentTests(TestCase): """ # make the connection for .well-known well_known_server = self._make_connection( - client_factory, - expected_sni=expected_sni, + client_factory, expected_sni=expected_sni ) # check the .well-known request and send a response self.assertEqual(len(well_known_server.requests), 1) @@ -154,17 +153,14 @@ class MatrixFederationAgentTests(TestCase): """ self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/.well-known/matrix/server') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # send back a response for k, v in headers.items(): request.setHeader(k, v) request.write(content) request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) def test_get(self): """ @@ -184,18 +180,14 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b"testserv", - ) + http_server = self._make_connection(client_factory, expected_sni=b"testserv") self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv:8448'] + request.requestHeaders.getRawHeaders(b'host'), [b'testserv:8448'] ) content = request.content.read() self.assertEqual(content, b'') @@ -244,19 +236,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'1.2.3.4'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'1.2.3.4']) # finish the request request.finish() @@ -285,19 +271,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'[::1]'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]']) # finish the request request.finish() @@ -326,19 +306,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 80) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=None, - ) + http_server = self._make_connection(client_factory, expected_sni=None) self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'[::1]:80'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'[::1]:80']) # finish the request request.finish() @@ -377,7 +351,7 @@ class MatrixFederationAgentTests(TestCase): # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # we should fall back to a direct connection @@ -387,19 +361,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -427,13 +395,14 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", + client_factory, + expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target server @@ -444,8 +413,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -453,8 +421,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -490,8 +457,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) redirect_server = self._make_connection( - client_factory, - expected_sni=b"testserv", + client_factory, expected_sni=b"testserv" ) # send a 302 redirect @@ -500,7 +466,7 @@ class MatrixFederationAgentTests(TestCase): request.redirect(b'https://testserv/even_better_known') request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) # now there should be another connection clients = self.reactor.tcpClients @@ -510,8 +476,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) well_known_server = self._make_connection( - client_factory, - expected_sni=b"testserv", + client_factory, expected_sni=b"testserv" ) self.assertEqual(len(well_known_server.requests), 1, "No request after 302") @@ -521,11 +486,11 @@ class MatrixFederationAgentTests(TestCase): request.write(b'{ "m.server": "target-server" }') request.finish() - self.reactor.pump((0.1, )) + self.reactor.pump((0.1,)) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target server @@ -536,8 +501,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -545,8 +509,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -585,12 +548,12 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", content=b'NOT JSON', + client_factory, expected_sni=b"testserv", content=b'NOT JSON' ) # now there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # we should fall back to a direct connection @@ -600,19 +563,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8448) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -635,7 +592,7 @@ class MatrixFederationAgentTests(TestCase): # the request for a .well-known will have failed with a DNS lookup error. self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv", + b"_matrix._tcp.testserv" ) # Make sure treq is trying to connect @@ -646,19 +603,13 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 8443) # make a test server, and wire up the client - http_server = self._make_connection( - client_factory, - expected_sni=b'testserv', - ) + http_server = self._make_connection(client_factory, expected_sni=b'testserv') self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') - self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'testserv'], - ) + self.assertEqual(request.requestHeaders.getRawHeaders(b'host'), [b'testserv']) # finish the request request.finish() @@ -685,17 +636,18 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(port, 443) self.mock_resolver.resolve_service.side_effect = lambda _: [ - Server(host=b"srvtarget", port=8443), + Server(host=b"srvtarget", port=8443) ] self._handle_well_known_connection( - client_factory, expected_sni=b"testserv", + client_factory, + expected_sni=b"testserv", content=b'{ "m.server": "target-server" }', ) # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server", + b"_matrix._tcp.target-server" ) # now we should get a connection to the target of the SRV record @@ -706,8 +658,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'target-server', + client_factory, expected_sni=b'target-server' ) self.assertEqual(len(http_server.requests), 1) @@ -715,8 +666,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'target-server'], + request.requestHeaders.getRawHeaders(b'host'), [b'target-server'] ) # finish the request @@ -757,7 +707,7 @@ class MatrixFederationAgentTests(TestCase): # now there should have been a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com", + b"_matrix._tcp.xn--bcher-kva.com" ) # We should fall back to port 8448 @@ -769,8 +719,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'xn--bcher-kva.com', + client_factory, expected_sni=b'xn--bcher-kva.com' ) self.assertEqual(len(http_server.requests), 1) @@ -778,8 +727,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'xn--bcher-kva.com'], + request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] ) # finish the request @@ -801,7 +749,7 @@ class MatrixFederationAgentTests(TestCase): self.assertNoResult(test_d) self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com", + b"_matrix._tcp.xn--bcher-kva.com" ) # Make sure treq is trying to connect @@ -813,8 +761,7 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, - expected_sni=b'xn--bcher-kva.com', + client_factory, expected_sni=b'xn--bcher-kva.com' ) self.assertEqual(len(http_server.requests), 1) @@ -822,8 +769,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(request.method, b'GET') self.assertEqual(request.path, b'/foo/bar') self.assertEqual( - request.requestHeaders.getRawHeaders(b'host'), - [b'xn--bcher-kva.com'], + request.requestHeaders.getRawHeaders(b'host'), [b'xn--bcher-kva.com'] ) # finish the request @@ -897,67 +843,70 @@ class TestCachePeriodFromHeaders(TestCase): # uppercase self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}), - ), 100, + Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}) + ), + 100, ) # missing value - self.assertIsNone(_cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=, bar']}), - )) + self.assertIsNone( + _cache_period_from_headers(Headers({b'Cache-Control': [b'max-age=, bar']})) + ) # hackernews: bogus due to semicolon - self.assertIsNone(_cache_period_from_headers( - Headers({b'Cache-Control': [b'private; max-age=0']}), - )) + self.assertIsNone( + _cache_period_from_headers( + Headers({b'Cache-Control': [b'private; max-age=0']}) + ) + ) # github self.assertEqual( _cache_period_from_headers( - Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}), - ), 0, + Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}) + ), + 0, ) # google self.assertEqual( _cache_period_from_headers( - Headers({b'cache-control': [b'private, max-age=0']}), - ), 0, + Headers({b'cache-control': [b'private, max-age=0']}) + ), + 0, ) def test_expires(self): self.assertEqual( _cache_period_from_headers( Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}), - time_now=lambda: 1548833700 - ), 33, + time_now=lambda: 1548833700, + ), + 33, ) # cache-control overrides expires self.assertEqual( _cache_period_from_headers( - Headers({ - b'cache-control': [b'max-age=10'], - b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'] - }), - time_now=lambda: 1548833700 - ), 10, + Headers( + { + b'cache-control': [b'max-age=10'], + b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT'], + } + ), + time_now=lambda: 1548833700, + ), + 10, ) # invalid expires means immediate expiry - self.assertEqual( - _cache_period_from_headers( - Headers({b'Expires': [b'0']}), - ), 0, - ) + self.assertEqual(_cache_period_from_headers(Headers({b'Expires': [b'0']})), 0) def _check_logcontext(context): current = LoggingContext.current_context() if current is not context: - raise AssertionError( - "Expected logcontext %s but was %s" % (context, current), - ) + raise AssertionError("Expected logcontext %s but was %s" % (context, current)) def _build_test_server(): @@ -973,7 +922,7 @@ def _build_test_server(): server_factory.log = _log_request server_tls_factory = TLSMemoryBIOFactory( - ServerTLSContext(), isClient=False, wrappedFactory=server_factory, + ServerTLSContext(), isClient=False, wrappedFactory=server_factory ) return server_tls_factory.buildProtocol(None) @@ -987,6 +936,7 @@ def _log_request(request): @implementer(IPolicyForHTTPS) class TrustingTLSPolicyForHTTPS(object): """An IPolicyForHTTPS which doesn't do any certificate verification""" + def creatorForNetloc(self, hostname, port): certificateOptions = OpenSSLCertificateOptions() return ClientTLSOptions(hostname, certificateOptions.getContext()) diff --git a/tests/http/federation/test_srv_resolver.py b/tests/http/federation/test_srv_resolver.py index a872e2441e..034c0db8d2 100644 --- a/tests/http/federation/test_srv_resolver.py +++ b/tests/http/federation/test_srv_resolver.py @@ -68,9 +68,7 @@ class SrvResolverTestCase(unittest.TestCase): dns_client_mock.lookupService.assert_called_once_with(service_name) - result_deferred.callback( - ([answer_srv], None, None) - ) + result_deferred.callback(([answer_srv], None, None)) servers = self.successResultOf(test_d) @@ -112,7 +110,7 @@ class SrvResolverTestCase(unittest.TestCase): cache = {service_name: [entry]} resolver = SrvResolver( - dns_client=dns_client_mock, cache=cache, get_time=clock.time, + dns_client=dns_client_mock, cache=cache, get_time=clock.time ) servers = yield resolver.resolve_service(service_name) @@ -168,11 +166,13 @@ class SrvResolverTestCase(unittest.TestCase): self.assertNoResult(resolve_d) # returning a single "." should make the lookup fail with a ConenctError - lookup_deferred.callback(( - [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], - None, - None, - )) + lookup_deferred.callback( + ( + [dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"."))], + None, + None, + ) + ) self.failureResultOf(resolve_d, ConnectError) @@ -191,14 +191,16 @@ class SrvResolverTestCase(unittest.TestCase): resolve_d = resolver.resolve_service(service_name) self.assertNoResult(resolve_d) - lookup_deferred.callback(( - [ - dns.RRHeader(type=dns.A, payload=dns.Record_A()), - dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), - ], - None, - None, - )) + lookup_deferred.callback( + ( + [ + dns.RRHeader(type=dns.A, payload=dns.Record_A()), + dns.RRHeader(type=dns.SRV, payload=dns.Record_SRV(target=b"host")), + ], + None, + None, + ) + ) servers = self.successResultOf(resolve_d) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index cd8e086f86..279e456614 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -36,9 +36,7 @@ from tests.unittest import HomeserverTestCase def check_logcontext(context): current = LoggingContext.current_context() if current is not context: - raise AssertionError( - "Expected logcontext %s but was %s" % (context, current), - ) + raise AssertionError("Expected logcontext %s but was %s" % (context, current)) class FederationClientTests(HomeserverTestCase): @@ -54,6 +52,7 @@ class FederationClientTests(HomeserverTestCase): """ happy-path test of a GET request """ + @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: @@ -175,8 +174,7 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance( - f.value.inner_exception, - (ConnectingCancelledError, TimeoutError), + f.value.inner_exception, (ConnectingCancelledError, TimeoutError) ) def test_client_connect_no_response(self): @@ -216,9 +214,7 @@ class FederationClientTests(HomeserverTestCase): Once the client gets the headers, _request returns successfully. """ request = MatrixFederationRequest( - method="GET", - destination="testserv:8008", - path="foo/bar", + method="GET", destination="testserv:8008", path="foo/bar" ) d = self.cl._send_request(request, timeout=10000) @@ -258,8 +254,10 @@ class FederationClientTests(HomeserverTestCase): # Send it the HTTP response client.dataReceived( - (b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" - b"Server: Fake\r\n\r\n") + ( + b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n" + b"Server: Fake\r\n\r\n" + ) ) # Push by enough to time it out @@ -274,9 +272,7 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json( - "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, - ) + d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) # Send the request self.pump() @@ -329,9 +325,7 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json( - "testserv:8008", "foo/bar", try_trailing_slash_on_400=True, - ) + d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) # Send the request self.pump() @@ -368,10 +362,7 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json( - "testserv:8008", "foo/bar", timeout=10000, - data={"a": "b"} - ) + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) self.pump() diff --git a/tests/patch_inline_callbacks.py b/tests/patch_inline_callbacks.py index 0f613945c8..ee0add3455 100644 --- a/tests/patch_inline_callbacks.py +++ b/tests/patch_inline_callbacks.py @@ -45,7 +45,9 @@ def do_patch(): except Exception: if LoggingContext.current_context() != start_context: err = "%s changed context from %s to %s on exception" % ( - f, start_context, LoggingContext.current_context() + f, + start_context, + LoggingContext.current_context(), ) print(err, file=sys.stderr) raise Exception(err) @@ -54,7 +56,9 @@ def do_patch(): if not isinstance(res, Deferred) or res.called: if LoggingContext.current_context() != start_context: err = "%s changed context from %s to %s" % ( - f, start_context, LoggingContext.current_context() + f, + start_context, + LoggingContext.current_context(), ) # print the error to stderr because otherwise all we # see in travis-ci is the 500 error @@ -66,9 +70,7 @@ def do_patch(): err = ( "%s returned incomplete deferred in non-sentinel context " "%s (start was %s)" - ) % ( - f, LoggingContext.current_context(), start_context, - ) + ) % (f, LoggingContext.current_context(), start_context) print(err, file=sys.stderr) raise Exception(err) @@ -76,7 +78,9 @@ def do_patch(): if LoggingContext.current_context() != start_context: err = "%s completion of %s changed context from %s to %s" % ( "Failure" if isinstance(r, Failure) else "Success", - f, start_context, LoggingContext.current_context(), + f, + start_context, + LoggingContext.current_context(), ) print(err, file=sys.stderr) raise Exception(err) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 1f72a2a04f..104349cdbd 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -74,21 +74,18 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( master_result, expected_result, - "Expected master result to be %r but was %r" % ( - expected_result, master_result - ), + "Expected master result to be %r but was %r" + % (expected_result, master_result), ) self.assertEqual( slaved_result, expected_result, - "Expected slave result to be %r but was %r" % ( - expected_result, slaved_result - ), + "Expected slave result to be %r but was %r" + % (expected_result, slaved_result), ) self.assertEqual( master_result, slaved_result, - "Slave result %r does not match master result %r" % ( - slaved_result, master_result - ), + "Slave result %r does not match master result %r" + % (slaved_result, master_result), ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 65ecff3bd6..a368117b43 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -234,10 +234,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) msg, msgctx = self.build_event() - self.get_success(self.master_store.persist_events([ - (j2, j2ctx), - (msg, msgctx), - ])) + self.get_success(self.master_store.persist_events([(j2, j2ctx), (msg, msgctx)])) self.replicate() event_source = RoomEventSource(self.hs) @@ -257,15 +254,13 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # # First, we get a list of the rooms we are joined to joined_rooms = self.get_success( - self.slaved_store.get_rooms_for_user_with_stream_ordering( - USER_ID_2, - ), + self.slaved_store.get_rooms_for_user_with_stream_ordering(USER_ID_2) ) # Then, we get a list of the events since the last sync membership_changes = self.get_success( self.slaved_store.get_membership_changes_for_user( - USER_ID_2, prev_token, current_token, + USER_ID_2, prev_token, current_token ) ) @@ -298,9 +293,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.master_store.persist_events([(event, context)], backfilled=True) ) else: - self.get_success( - self.master_store.persist_event(event, context) - ) + self.get_success(self.master_store.persist_event(event, context)) return event @@ -359,9 +352,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) else: state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context( - event - )) + context = self.get_success(state_handler.compute_event_context(event)) self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 38b368a972..ce3835ae6a 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -22,6 +22,7 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(self.hs) @@ -52,6 +53,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class TestReplicationClientHandler(object): """Drop-in for ReplicationClientHandler which just collects RDATA rows""" + def __init__(self): self.received_rdata_rows = [] @@ -69,6 +71,4 @@ class TestReplicationClientHandler(object): def on_rdata(self, stream_name, token, rows): for r in rows: - self.received_rdata_rows.append( - (stream_name, token, r) - ) + self.received_rdata_rows.append((stream_name, token, r)) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index da19a83918..ee5f09041f 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -41,10 +41,10 @@ class VersionTestCase(unittest.HomeserverTestCase): request, channel = self.make_request("GET", self.url, shorthand=False) self.render(request) - self.assertEqual(200, int(channel.result["code"]), - msg=channel.result["body"]) - self.assertEqual({'server_version', 'python_version'}, - set(channel.json_body.keys())) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual( + {'server_version', 'python_version'}, set(channel.json_body.keys()) + ) class UserRegisterTestCase(unittest.HomeserverTestCase): @@ -200,9 +200,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - want_mac.update( - nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin" - ) + want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() body = json.dumps( @@ -330,11 +328,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Invalid user_type - body = json.dumps({ - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid"} + body = json.dumps( + { + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid", + } ) request, channel = self.make_request("POST", self.url, body.encode('utf8')) self.render(request) @@ -357,9 +357,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): hs.config.user_consent_version = "1" consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = ( - "http://example.com" - ) + consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" self.event_creation_handler._consent_uri_builder = consent_uri_builder self.store = hs.get_datastore() @@ -371,9 +369,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): self.other_user_token = self.login("user", "pass") # Mark the admin user as having consented - self.get_success( - self.store.user_set_consent_version(self.admin_user, "1"), - ) + self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not @@ -385,9 +381,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) # Assert one user in room - users_in_room = self.get_success( - self.store.get_users_in_room(room_id), - ) + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([self.other_user], users_in_room) # Enable require consent to send events @@ -395,8 +389,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): # Assert that the user is getting consent error self.helper.send( - room_id, - body="foo", tok=self.other_user_token, expect_code=403, + room_id, body="foo", tok=self.other_user_token, expect_code=403 ) # Test that the admin can still send shutdown @@ -412,9 +405,7 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Assert there is now no longer anyone in the room - users_in_room = self.get_success( - self.store.get_users_in_room(room_id), - ) + users_in_room = self.get_success(self.store.get_users_in_room(room_id)) self.assertEqual([], users_in_room) @unittest.DEBUG @@ -459,24 +450,20 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase): url = "rooms/%s/initialSync" % (room_id,) request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) url = "events?timeout=0&room_id=" + room_id request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) @@ -502,15 +489,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "POST", "/create_group".encode('ascii'), access_token=self.admin_user_tok, - content={ - "localpart": "test", - } + content={"localpart": "test"}, ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) group_id = channel.json_body["group_id"] @@ -520,27 +503,17 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user) request, channel = self.make_request( - "PUT", - url.encode('ascii'), - access_token=self.admin_user_tok, - content={} + "PUT", url.encode('ascii'), access_token=self.admin_user_tok, content={} ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) url = "/groups/%s/self/accept_invite" % (group_id,) request, channel = self.make_request( - "PUT", - url.encode('ascii'), - access_token=self.other_user_token, - content={} + "PUT", url.encode('ascii'), access_token=self.other_user_token, content={} ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check other user knows they're in the group self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok)) @@ -552,15 +525,11 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): "POST", url.encode('ascii'), access_token=self.admin_user_tok, - content={ - "localpart": "test", - } + content={"localpart": "test"}, ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) # Check group returns 404 self._check_group(group_id, expect_code=404) @@ -576,28 +545,22 @@ class DeleteGroupTestCase(unittest.HomeserverTestCase): url = "/groups/%s/profile" % (group_id,) request, channel = self.make_request( - "GET", - url.encode('ascii'), - access_token=self.admin_user_tok, + "GET", url.encode('ascii'), access_token=self.admin_user_tok ) self.render(request) self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"], + expect_code, int(channel.result["code"]), msg=channel.result["body"] ) def _get_groups_user_is_in(self, access_token): """Returns the list of groups the user is in (given their access token) """ request, channel = self.make_request( - "GET", - "/joined_groups".encode('ascii'), - access_token=access_token, + "GET", "/joined_groups".encode('ascii'), access_token=access_token ) self.render(request) - self.assertEqual( - 200, int(channel.result["code"]), msg=channel.result["body"], - ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) return channel.json_body["groups"] diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 2e51ffa418..1a714ff58a 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -44,7 +44,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): tok = self.login("kermit", "monkey") request, channel = self.make_request( - b"POST", "/createRoom", b"{}", access_token=tok, + b"POST", "/createRoom", b"{}", access_token=tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -56,11 +56,9 @@ class IdentityTestCase(unittest.HomeserverTestCase): "address": "test@example.com", } request_data = json.dumps(params) - request_url = ( - "/rooms/%s/invite" % (room_id) - ).encode('ascii') + request_url = ("/rooms/%s/invite" % (room_id)).encode('ascii') request, channel = self.make_request( - b"POST", request_url, request_data, access_token=tok, + b"POST", request_url, request_data, access_token=tok ) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index f63c68e7ed..73c5b44b46 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -45,7 +45,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.room_owner_tok = self.login("room_owner", "test") self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok, + self.room_owner, tok=self.room_owner_tok ) self.user = self.register_user("user", "test") @@ -80,12 +80,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # We use deliberately a localpart under the length threshold so # that we can make sure that the check is done on the whole alias. - data = { - "room_alias_name": random_string(256 - len(self.hs.hostname)), - } + data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} request_data = json.dumps(data) request, channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok, + "POST", url, request_data, access_token=self.user_tok ) self.render(request) self.assertEqual(channel.code, 400, channel.result) @@ -96,51 +94,42 @@ class DirectoryTestCase(unittest.HomeserverTestCase): # Check with an alias of allowed length. There should already be # a test that ensures it works in test_register.py, but let's be # as cautious as possible here. - data = { - "room_alias_name": random_string(5), - } + data = {"room_alias_name": random_string(5)} request_data = json.dumps(data) request, channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok, + "POST", url, request_data, access_token=self.user_tok ) self.render(request) self.assertEqual(channel.code, 200, channel.result) def set_alias_via_state_event(self, expected_code, alias_length=5): - url = ("/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" - % (self.room_id, self.hs.hostname)) - - data = { - "aliases": [ - self.random_alias(alias_length), - ], - } + url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( + self.room_id, + self.hs.hostname, + ) + + data = {"aliases": [self.random_alias(alias_length)]} request_data = json.dumps(data) request, channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok, + "PUT", url, request_data, access_token=self.user_tok ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def set_alias_via_directory(self, expected_code, alias_length=5): url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) - data = { - "room_id": self.room_id, - } + data = {"room_id": self.room_id} request_data = json.dumps(data) request, channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok, + "PUT", url, request_data, access_token=self.user_tok ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) def random_alias(self, length): - return RoomAlias( - random_string(length), - self.hs.hostname, - ).to_string() + return RoomAlias(random_string(length), self.hs.hostname).to_string() def ensure_user_left_room(self): self.ensure_membership("leave") @@ -151,17 +140,9 @@ class DirectoryTestCase(unittest.HomeserverTestCase): def ensure_membership(self, membership): try: if membership == "leave": - self.helper.leave( - room=self.room_id, - user=self.user, - tok=self.user_tok, - ) + self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) if membership == "join": - self.helper.join( - room=self.room_id, - user=self.user, - tok=self.user_tok, - ) + self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) except AssertionError: # We don't care whether the leave request didn't return a 200 (e.g. # if the user isn't already in the room), because we only want to diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 9ebd91f678..0397f91a9e 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -37,10 +37,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit" + str(i), - }, + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } request_data = json.dumps(params) @@ -57,14 +54,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit" + str(i), - }, + "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, "password": "monkey", } request_data = json.dumps(params) @@ -82,10 +76,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } request_data = json.dumps(params) @@ -102,14 +93,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "monkey", } request_data = json.dumps(params) @@ -127,10 +115,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): for i in range(0, 6): params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } request_data = json.dumps(params) @@ -147,14 +132,11 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # than 1min. self.assertTrue(retry_after_ms < 6000) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) params = { "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": "kermit", - }, + "identifier": {"type": "m.id.user", "user": "kermit"}, "password": "notamonkey", } request_data = json.dumps(params) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 7306e61b7c..ed034879cf 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -199,37 +199,24 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): def test_in_shared_room(self): self.ensure_requester_left_room() - self.helper.join( - room=self.room_id, - user=self.requester, - tok=self.requester_tok, - ) + self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) self.try_fetch_profile(200, self.requester_tok) def try_fetch_profile(self, expected_code, access_token=None): - self.request_profile( - expected_code, - access_token=access_token - ) + self.request_profile(expected_code, access_token=access_token) self.request_profile( - expected_code, - url_suffix="/displayname", - access_token=access_token, + expected_code, url_suffix="/displayname", access_token=access_token ) self.request_profile( - expected_code, - url_suffix="/avatar_url", - access_token=access_token, + expected_code, url_suffix="/avatar_url", access_token=access_token ) def request_profile(self, expected_code, url_suffix="", access_token=None): request, channel = self.make_request( - "GET", - self.profile_url + url_suffix, - access_token=access_token, + "GET", self.profile_url + url_suffix, access_token=access_token ) self.render(request) self.assertEqual(channel.code, expected_code, channel.result) @@ -237,9 +224,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): def ensure_requester_left_room(self): try: self.helper.leave( - room=self.room_id, - user=self.requester, - tok=self.requester_tok, + room=self.room_id, user=self.requester, tok=self.requester_tok ) except AssertionError: # We don't care whether the leave request didn't return a 200 (e.g. diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 1c3a621d26..be95dc592d 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -41,11 +41,10 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): as_token = "i_am_an_app_service" appservice = ApplicationService( - as_token, self.hs.config.server_name, + as_token, + self.hs.config.server_name, id="1234", - namespaces={ - "users": [{"regex": r"@as_user.*", "exclusive": True}], - }, + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, ) self.hs.get_datastore().services_cache.append(appservice) @@ -57,10 +56,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) - det_data = { - "user_id": user_id, - "home_server": self.hs.hostname, - } + det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_appservice_registration_invalid(self): @@ -128,10 +124,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) - det_data = { - "home_server": self.hs.hostname, - "device_id": "guest_device", - } + det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @@ -159,7 +152,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -187,7 +180,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): else: self.assertEquals(channel.result["code"], b"200", channel.result) - self.reactor.advance(retry_after_ms / 1000.) + self.reactor.advance(retry_after_ms / 1000.0) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) @@ -221,23 +214,19 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) def test_manual_renewal(self): @@ -253,21 +242,17 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): admin_tok = self.login("admin", "adminpassword") url = "/_matrix/client/unstable/admin/account_validity/validity" - params = { - "user_id": user_id, - } + params = {"user_id": user_id} request_data = json.dumps(params) request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok, + b"POST", url, request_data, access_token=admin_tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -286,20 +271,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): } request_data = json.dumps(params) request, channel = self.make_request( - b"POST", url, request_data, access_token=admin_tok, + b"POST", url, request_data, access_token=admin_tok ) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result, + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -358,10 +341,15 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # We need to manually add an email address otherwise the handler will do # nothing. now = self.hs.clock.time_msec() - self.get_success(self.store.user_add_threepid( - user_id=user_id, medium="email", address="kermit@example.com", - validated_at=now, added_at=now, - )) + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) # Move 6 days forward. This should trigger a renewal email to be sent. self.reactor.advance(datetime.timedelta(days=6).total_seconds()) @@ -379,9 +367,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # our access token should be denied from now, otherwise they should # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - request, channel = self.make_request( - b"GET", "/sync", access_token=tok, - ) + request, channel = self.make_request(b"GET", "/sync", access_token=tok) self.render(request) self.assertEquals(channel.result["code"], b"200", channel.result) @@ -393,13 +379,19 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # We need to manually add an email address otherwise the handler will do # nothing. now = self.hs.clock.time_msec() - self.get_success(self.store.user_add_threepid( - user_id=user_id, medium="email", address="kermit@example.com", - validated_at=now, added_at=now, - )) + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) request, channel = self.make_request( - b"POST", "/_matrix/client/unstable/account_validity/send_mail", + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) self.render(request) diff --git a/tests/rest/media/v1/test_base.py b/tests/rest/media/v1/test_base.py index af8f74eb42..00688a7325 100644 --- a/tests/rest/media/v1/test_base.py +++ b/tests/rest/media/v1/test_base.py @@ -26,20 +26,14 @@ class GetFileNameFromHeadersTests(unittest.TestCase): b'inline; filename="aze%20rty"': u"aze%20rty", b'inline; filename="aze\"rty"': u'aze"rty', b'inline; filename="azer;ty"': u"azer;ty", - b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar", } def tests(self): for hdr, expected in self.TEST_CASES.items(): - res = get_filename_from_headers( - { - b'Content-Disposition': [hdr], - }, - ) + res = get_filename_from_headers({b'Content-Disposition': [hdr]}) self.assertEqual( - res, expected, - "expected output for %s to be %s but was %s" % ( - hdr, expected, res, - ) + res, + expected, + "expected output for %s to be %s but was %s" % (hdr, expected, res), ) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 8d8f03e005..b090bb974c 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -31,27 +31,24 @@ class WellKnownTests(unittest.HomeserverTestCase): self.hs.config.default_identity_server = "https://testis" request, channel = self.make_request( - "GET", - "/.well-known/matrix/client", - shorthand=False, + "GET", "/.well-known/matrix/client", shorthand=False ) self.render(request) self.assertEqual(request.code, 200) self.assertEqual( - channel.json_body, { + channel.json_body, + { "m.homeserver": {"base_url": "https://tesths"}, "m.identity_server": {"base_url": "https://testis"}, - } + }, ) def test_well_known_no_public_baseurl(self): self.hs.config.public_baseurl = None request, channel = self.make_request( - "GET", - "/.well-known/matrix/client", - shorthand=False, + "GET", "/.well-known/matrix/client", shorthand=False ) self.render(request) diff --git a/tests/server.py b/tests/server.py index 8f89f4a83d..fc41345488 100644 --- a/tests/server.py +++ b/tests/server.py @@ -182,7 +182,8 @@ def make_request( if federation_auth_origin is not None: req.requestHeaders.addRawHeader( - b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,) + b"Authorization", + b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,), ) if content: @@ -233,7 +234,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): class FakeResolver(object): def getHostByName(self, name, timeout=None): if name not in lookups: - return fail(DNSLookupError("OH NO: unknown %s" % (name, ))) + return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) self.nameResolver = SimpleResolverComplexifier(FakeResolver()) @@ -454,6 +455,6 @@ class FakeTransport(object): logger.warning("Exception writing to protocol: %s", e) return - self.buffer = self.buffer[len(to_write):] + self.buffer = self.buffer[len(to_write) :] if self.buffer and self.autoflush: self._reactor.callLater(0.0, self.flush) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index be73e718c2..a490b81ed4 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,7 +27,6 @@ from tests import unittest class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): hs_config = self.default_config("test") hs_config.server_notices_mxid = "@server:test" diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index f448b01326..9c5311d916 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -50,6 +50,7 @@ class FakeEvent(object): refer to events. The event_id has node_id as localpart and example.com as domain. """ + def __init__(self, id, sender, type, state_key, content): self.node_id = id self.event_id = EventID(id, "example.com").to_string() @@ -142,24 +143,14 @@ INITIAL_EVENTS = [ content=MEMBERSHIP_CONTENT_JOIN, ), FakeEvent( - id="START", - sender=ZARA, - type=EventTypes.Message, - state_key=None, - content={}, + id="START", sender=ZARA, type=EventTypes.Message, state_key=None, content={} ), FakeEvent( - id="END", - sender=ZARA, - type=EventTypes.Message, - state_key=None, - content={}, + id="END", sender=ZARA, type=EventTypes.Message, state_key=None, content={} ), ] -INITIAL_EDGES = [ - "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE", -] +INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] class StateTestCase(unittest.TestCase): @@ -170,12 +161,7 @@ class StateTestCase(unittest.TestCase): sender=ALICE, type=EventTypes.PowerLevels, state_key="", - content={ - "users": { - ALICE: 100, - BOB: 50, - } - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( id="MA", @@ -196,19 +182,11 @@ class StateTestCase(unittest.TestCase): sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), ] - edges = [ - ["END", "MB", "MA", "PA", "START"], - ["END", "PB", "PA"], - ] + edges = [["END", "MB", "MA", "PA", "START"], ["END", "PB", "PA"]] expected_state_ids = ["PA", "MA", "MB"] @@ -232,10 +210,7 @@ class StateTestCase(unittest.TestCase): ), ] - edges = [ - ["END", "JR", "START"], - ["END", "ME", "START"], - ] + edges = [["END", "JR", "START"], ["END", "ME", "START"]] expected_state_ids = ["JR"] @@ -248,45 +223,25 @@ class StateTestCase(unittest.TestCase): sender=ALICE, type=EventTypes.PowerLevels, state_key="", - content={ - "users": { - ALICE: 100, - BOB: 50, - } - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - CHARLIE: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}}, ), FakeEvent( id="PC", sender=CHARLIE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - CHARLIE: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}}, ), ] - edges = [ - ["END", "PC", "PB", "PA", "START"], - ["END", "PA"], - ] + edges = [["END", "PC", "PB", "PA", "START"], ["END", "PA"]] expected_state_ids = ["PC"] @@ -295,68 +250,38 @@ class StateTestCase(unittest.TestCase): def test_topic_basic(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA1", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA2", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T3", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), ] - edges = [ - ["END", "PA2", "T2", "PA1", "T1", "START"], - ["END", "T3", "PB", "PA1"], - ] + edges = [["END", "PA2", "T2", "PA1", "T1", "START"], ["END", "T3", "PB", "PA1"]] expected_state_ids = ["PA2", "T2"] @@ -365,30 +290,17 @@ class StateTestCase(unittest.TestCase): def test_topic_reset(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="MB", @@ -399,10 +311,7 @@ class StateTestCase(unittest.TestCase): ), ] - edges = [ - ["END", "MB", "T2", "PA", "T1", "START"], - ["END", "T1"], - ] + edges = [["END", "MB", "T2", "PA", "T1", "START"], ["END", "T1"]] expected_state_ids = ["T1", "MB", "PA"] @@ -411,61 +320,34 @@ class StateTestCase(unittest.TestCase): def test_topic(self): events = [ FakeEvent( - id="T1", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA1", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T2", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="PA2", sender=ALICE, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 0, - }, - }, + content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, state_key='', - content={ - "users": { - ALICE: 100, - BOB: 50, - }, - }, + content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( - id="T3", - sender=BOB, - type=EventTypes.Topic, - state_key="", - content={}, + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} ), FakeEvent( id="MZ1", @@ -475,11 +357,7 @@ class StateTestCase(unittest.TestCase): content={}, ), FakeEvent( - id="T4", - sender=ALICE, - type=EventTypes.Topic, - state_key="", - content={}, + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} ), ] @@ -587,13 +465,7 @@ class StateTestCase(unittest.TestCase): class LexicographicalTestCase(unittest.TestCase): def test_simple(self): - graph = { - "l": {"o"}, - "m": {"n", "o"}, - "n": {"o"}, - "o": set(), - "p": {"o"}, - } + graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}} res = list(lexicographical_topological_sort(graph, key=lambda x: x)) @@ -680,7 +552,13 @@ class SimpleParamStateTestCase(unittest.TestCase): self.expected_combined_state = { (e.type, e.state_key): e.event_id - for e in [create_event, alice_member, join_rules, bob_member, charlie_member] + for e in [ + create_event, + alice_member, + join_rules, + bob_member, + charlie_member, + ] } def test_event_map_none(self): @@ -720,11 +598,7 @@ class TestStateResolutionStore(object): Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. """ - return { - eid: self.event_map[eid] - for eid in event_ids - if eid in self.event_map - } + return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} def get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 5568a607c7..fbb9302694 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -9,9 +9,7 @@ from tests.utils import setup_test_homeserver class BackgroundUpdateTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - hs = yield setup_test_homeserver( - self.addCleanup - ) + hs = yield setup_test_homeserver(self.addCleanup) self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index f18db8c384..c778de1f0c 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -56,10 +56,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False hs = TestHomeServer( - "test", - db_pool=self.db_pool, - config=config, - database_engine=fake_engine, + "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) self.datastore = SQLBaseStore(None, hs) diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 11fb8c0c19..cd2bcd4ca3 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -20,7 +20,6 @@ import tests.utils class EndToEndKeyStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index d6569a82bb..f458c03054 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -56,8 +56,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.register(user_id=user1, token="123", password_hash=None) self.store.register(user_id=user2, token="456", password_hash=None) self.store.register( - user_id=user3, token="789", - password_hash=None, user_type=UserTypes.SUPPORT + user_id=user3, token="789", password_hash=None, user_type=UserTypes.SUPPORT ) self.pump() @@ -173,9 +172,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user = Mock() - self.store.is_trial_user = Mock( - return_value=defer.succeed(False) - ) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( return_value=defer.succeed(None) @@ -187,13 +184,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_not_update(self): self.store.upsert_monthly_active_user = Mock() - self.store.is_trial_user = Mock( - return_value=defer.succeed(False) - ) + self.store.is_trial_user = Mock(return_value=defer.succeed(False)) self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed( - self.hs.get_clock().time_msec() - ) + return_value=defer.succeed(self.hs.get_clock().time_msec()) ) self.store.populate_monthly_active_users('user_id') self.pump() @@ -243,7 +236,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): user_id=support_user_id, token="123", password_hash=None, - user_type=UserTypes.SUPPORT + user_type=UserTypes.SUPPORT, ) self.store.upsert_monthly_active_user(support_user_id) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 0fc5019e9f..4823d44dec 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -60,7 +60,7 @@ class RedactionTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -83,7 +83,7 @@ class RedactionTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": {"body": body, "msgtype": u"message"}, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -105,7 +105,7 @@ class RedactionTestCase(unittest.TestCase): "room_id": room.to_string(), "content": {"reason": reason}, "redacts": event_id, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index cb3cc4d2e5..c0e0155bb4 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -116,7 +116,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user_id=SUPPORT_USER, token="456", password_hash=None, - user_type=UserTypes.SUPPORT + user_type=UserTypes.SUPPORT, ) res = yield self.store.is_support_user(SUPPORT_USER) self.assertTrue(res) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 063387863e..73ed943f5a 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -58,7 +58,7 @@ class RoomMemberStoreTestCase(unittest.TestCase): "state_key": user.to_string(), "room_id": room.to_string(), "content": {"membership": membership}, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 78e260a7fa..b6169436de 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -29,7 +29,6 @@ logger = logging.getLogger(__name__) class StateStoreTestCase(tests.unittest.TestCase): - @defer.inlineCallbacks def setUp(self): hs = yield tests.utils.setup_test_homeserver(self.addCleanup) @@ -57,7 +56,7 @@ class StateStoreTestCase(tests.unittest.TestCase): "state_key": state_key, "room_id": room.to_string(), "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -83,15 +82,14 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id]) + state_group_map = yield self.store.get_state_groups_ids( + self.room, [e2.event_id] + ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] self.assertDictEqual( state_map, - { - (EventTypes.Create, ''): e1.event_id, - (EventTypes.Name, ''): e2.event_id, - }, + {(EventTypes.Create, ''): e1.event_id, (EventTypes.Name, ''): e2.event_id}, ) @defer.inlineCallbacks @@ -103,15 +101,11 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"} ) - state_group_map = yield self.store.get_state_groups( - self.room, [e2.event_id]) + state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id]) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] - self.assertEqual( - {ev.event_id for ev in state_list}, - {e1.event_id, e2.event_id}, - ) + self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id}) @defer.inlineCallbacks def test_get_state_for_event(self): @@ -147,9 +141,7 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.store.get_state_for_event( - e5.event_id, - ) + state = yield self.store.get_state_for_event(e5.event_id) self.assertIsNotNone(e4) @@ -194,7 +186,7 @@ class StateStoreTestCase(tests.unittest.TestCase): state_filter=StateFilter( types={EventTypes.Member: {self.u_alice.to_string()}}, include_others=True, - ) + ), ) self.assertStateMapEqual( @@ -208,9 +200,9 @@ class StateStoreTestCase(tests.unittest.TestCase): # check that we can grab everything except members state = yield self.store.get_state_for_event( - e5.event_id, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True ), ) @@ -229,10 +221,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # test _get_state_for_group_using_cache correctly filters out members # with types=[] (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, group, + self.store._state_group_cache, + group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -249,8 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -263,8 +254,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -281,8 +271,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -302,8 +291,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -320,8 +308,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -334,8 +321,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) @@ -384,10 +370,10 @@ class StateStoreTestCase(tests.unittest.TestCase): # with types=[] room_id = self.room.to_string() (state_dict, is_all) = yield self.store._get_state_for_group_using_cache( - self.store._state_group_cache, group, + self.store._state_group_cache, + group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -399,8 +385,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: set()}, - include_others=True, + types={EventTypes.Member: set()}, include_others=True ), ) @@ -413,8 +398,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -425,8 +409,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: None}, - include_others=True, + types={EventTypes.Member: None}, include_others=True ), ) @@ -445,8 +428,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -457,8 +439,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=True, + types={EventTypes.Member: {e5.state_key}}, include_others=True ), ) @@ -471,8 +452,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) @@ -483,8 +463,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store._state_group_members_cache, group, state_filter=StateFilter( - types={EventTypes.Member: {e5.state_key}}, - include_others=False, + types={EventTypes.Member: {e5.state_key}}, include_others=False ), ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index fd3361404f..d7d244ce97 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -36,9 +36,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): yield self.store.update_profile_in_user_dir(ALICE, "alice", None) yield self.store.update_profile_in_user_dir(BOB, "bob", None) yield self.store.update_profile_in_user_dir(BOBBY, "bobby", None) - yield self.store.add_users_in_public_rooms( - "!room:id", (ALICE, BOB) - ) + yield self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)) @defer.inlineCallbacks def test_search_user_dir(self): diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 4c8f87e958..8b2741d277 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -37,7 +37,9 @@ class EventAuthTestCase(unittest.TestCase): # creator should be able to send state event_auth.check( - RoomVersions.V1.identifier, _random_state_event(creator), auth_events, + RoomVersions.V1.identifier, + _random_state_event(creator), + auth_events, do_sig_check=False, ) @@ -82,7 +84,9 @@ class EventAuthTestCase(unittest.TestCase): # king should be able to send state event_auth.check( - RoomVersions.V1.identifier, _random_state_event(king), auth_events, + RoomVersions.V1.identifier, + _random_state_event(king), + auth_events, do_sig_check=False, ) diff --git a/tests/test_federation.py b/tests/test_federation.py index 1a5dc32c88..6a8339b561 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,4 +1,3 @@ - from mock import Mock from twisted.internet.defer import maybeDeferred, succeed diff --git a/tests/test_mau.py b/tests/test_mau.py index 00be1a8c21..1fbe0d51ff 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -33,9 +33,7 @@ class TestMauLimit(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), + "red", http_client=None, federation_client=Mock() ) self.store = self.hs.get_datastore() @@ -210,9 +208,7 @@ class TestMauLimit(unittest.HomeserverTestCase): return access_token def do_sync_for_user(self, token): - request, channel = self.make_request( - "GET", "/sync", access_token=token - ) + request, channel = self.make_request("GET", "/sync", access_token=token) self.render(request) if channel.code != 200: diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 0ff6d0e283..2edbae5c6d 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -44,9 +44,7 @@ def get_sample_labels_value(sample): class TestMauLimit(unittest.TestCase): def test_basic(self): gauge = InFlightGauge( - "test1", "", - labels=["test_label"], - sub_metrics=["foo", "bar"], + "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] ) def handle1(metrics): @@ -59,37 +57,49 @@ class TestMauLimit(unittest.TestCase): gauge.register(("key1",), handle1) - self.assert_dict({ - "test1_total": {("key1",): 1}, - "test1_foo": {("key1",): 2}, - "test1_bar": {("key1",): 5}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 1}, + "test1_foo": {("key1",): 2}, + "test1_bar": {("key1",): 5}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.unregister(("key1",), handle1) - self.assert_dict({ - "test1_total": {("key1",): 0}, - "test1_foo": {("key1",): 0}, - "test1_bar": {("key1",): 0}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 0}, + "test1_foo": {("key1",): 0}, + "test1_bar": {("key1",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.register(("key1",), handle1) gauge.register(("key2",), handle2) - self.assert_dict({ - "test1_total": {("key1",): 1, ("key2",): 1}, - "test1_foo": {("key1",): 2, ("key2",): 3}, - "test1_bar": {("key1",): 5, ("key2",): 7}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 1, ("key2",): 1}, + "test1_foo": {("key1",): 2, ("key2",): 3}, + "test1_bar": {("key1",): 5, ("key2",): 7}, + }, + self.get_metrics_from_gauge(gauge), + ) gauge.unregister(("key2",), handle2) gauge.register(("key1",), handle2) - self.assert_dict({ - "test1_total": {("key1",): 2, ("key2",): 0}, - "test1_foo": {("key1",): 5, ("key2",): 0}, - "test1_bar": {("key1",): 7, ("key2",): 0}, - }, self.get_metrics_from_gauge(gauge)) + self.assert_dict( + { + "test1_total": {("key1",): 2, ("key2",): 0}, + "test1_foo": {("key1",): 5, ("key2",): 0}, + "test1_bar": {("key1",): 7, ("key2",): 0}, + }, + self.get_metrics_from_gauge(gauge), + ) def get_metrics_from_gauge(self, gauge): results = {} diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 0968e86a7b..f412985d2c 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -69,10 +69,10 @@ class TermsTestCase(unittest.HomeserverTestCase): "name": "My Cool Privacy Policy", "url": "https://example.org/_matrix/consent?v=1.0", }, - "version": "1.0" - }, - }, - }, + "version": "1.0", + } + } + } } self.assertIsInstance(channel.json_body["params"], dict) self.assertDictContainsSubset(channel.json_body["params"], expected_params) diff --git a/tests/test_types.py b/tests/test_types.py index d314a7ff58..d83c36559f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -94,8 +94,7 @@ class MapUsernameTestCase(unittest.TestCase): def testSymbols(self): self.assertEqual( - map_username_to_mxid_localpart("test=$?_1234"), - "test=3d=24=3f_1234", + map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234" ) def testLeadingUnderscore(self): @@ -105,6 +104,5 @@ class MapUsernameTestCase(unittest.TestCase): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast") self.assertEqual( - map_username_to_mxid_localpart(u'têst'.encode('utf-8')), - "t=c3=aast", + map_username_to_mxid_localpart(u'têst'.encode('utf-8')), "t=c3=aast" ) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index d0bc8e2112..fde0baee8e 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -22,6 +22,7 @@ from synapse.util.logcontext import LoggingContextFilter class ToTwistedHandler(logging.Handler): """logging handler which sends the logs to the twisted log""" + tx_log = twisted.logger.Logger() def emit(self, record): @@ -41,7 +42,8 @@ def setup_logging(): root_logger = logging.getLogger() log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s" + "%(asctime)s - %(name)s - %(lineno)d - " + "%(levelname)s - %(request)s - %(message)s" ) handler = ToTwistedHandler() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 3bdb500514..6a180ddc32 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -132,7 +132,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "state_key": "", "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -153,7 +153,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "state_key": user_id, "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( @@ -174,7 +174,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): "sender": user_id, "room_id": TEST_ROOM_ID, "content": content, - } + }, ) event, context = yield self.event_creation_handler.create_new_client_event( diff --git a/tests/unittest.py b/tests/unittest.py index 029a88d770..94df8cf47e 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -84,9 +84,8 @@ class TestCase(unittest.TestCase): # all future bets are off. if LoggingContext.current_context() is not LoggingContext.sentinel: self.fail( - "Test starting with non-sentinel logging context %s" % ( - LoggingContext.current_context(), - ) + "Test starting with non-sentinel logging context %s" + % (LoggingContext.current_context(),) ) old_level = logging.getLogger().level @@ -300,7 +299,13 @@ class HomeserverTestCase(TestCase): content = json.dumps(content).encode('utf8') return make_request( - self.reactor, method, path, content, access_token, request, shorthand, + self.reactor, + method, + path, + content, + access_token, + request, + shorthand, federation_auth_origin, ) diff --git a/tests/util/test_async_utils.py b/tests/util/test_async_utils.py index 84dd71e47a..bf85d3b8ec 100644 --- a/tests/util/test_async_utils.py +++ b/tests/util/test_async_utils.py @@ -42,10 +42,10 @@ class TimeoutDeferredTest(TestCase): self.assertNoResult(timing_out_d) self.assertFalse(cancelled[0], "deferred was cancelled prematurely") - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) def test_times_out_when_canceller_throws(self): """Test that we have successfully worked around @@ -59,9 +59,9 @@ class TimeoutDeferredTest(TestCase): self.assertNoResult(timing_out_d) - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) def test_logcontext_is_preserved_on_cancellation(self): blocking_was_cancelled = [False] @@ -80,10 +80,10 @@ class TimeoutDeferredTest(TestCase): # the errbacks should be run in the test logcontext def errback(res, deferred_name): self.assertIs( - LoggingContext.current_context(), context_one, - "errback %s run in unexpected logcontext %s" % ( - deferred_name, LoggingContext.current_context(), - ) + LoggingContext.current_context(), + context_one, + "errback %s run in unexpected logcontext %s" + % (deferred_name, LoggingContext.current_context()), ) return res @@ -94,11 +94,10 @@ class TimeoutDeferredTest(TestCase): self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) timing_out_d.addErrback(errback, "timingout") - self.clock.pump((1.0, )) + self.clock.pump((1.0,)) self.assertTrue( - blocking_was_cancelled[0], - "non-completing deferred was not cancelled", + blocking_was_cancelled[0], "non-completing deferred was not cancelled" ) - self.failureResultOf(timing_out_d, defer.TimeoutError, ) + self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(LoggingContext.current_context(), context_one) diff --git a/tests/utils.py b/tests/utils.py index cb75514851..c2ef4b0bb5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -68,7 +68,9 @@ def setupdb(): # connect to postgres to create the base database. db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True @@ -94,7 +96,9 @@ def setupdb(): def _cleanup(): db_conn = db_engine.module.connect( - user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, ) db_conn.autocommit = True @@ -114,7 +118,6 @@ def default_config(name): "server_name": name, "media_store_path": "media", "uploads_path": "uploads", - # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", -- cgit 1.5.1 From df2ebd75d3abde2bc2262551d8b2fd40c4b4bddf Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Mon, 13 May 2019 15:01:14 -0500 Subject: Migrate all tests to use the dict-based config format instead of hanging items off HomeserverConfig (#5171) --- changelog.d/5171.misc | 1 + synapse/rest/media/v1/storage_provider.py | 1 + tests/handlers/test_register.py | 8 +- tests/handlers/test_user_directory.py | 4 +- .../federation/test_matrix_federation_agent.py | 4 +- tests/push/test_email.py | 36 +++--- tests/push/test_http.py | 2 +- tests/rest/client/test_consent.py | 11 +- tests/rest/client/test_identity.py | 2 +- tests/rest/client/v1/test_directory.py | 2 +- tests/rest/client/v1/test_events.py | 6 +- tests/rest/client/v1/test_profile.py | 2 +- tests/rest/client/v1/test_rooms.py | 2 +- tests/rest/client/v2_alpha/test_auth.py | 6 +- tests/rest/client/v2_alpha/test_register.py | 51 ++++---- tests/rest/media/v1/test_media_storage.py | 17 +-- tests/rest/media/v1/test_url_preview.py | 40 +++---- tests/server.py | 61 +++++----- tests/server_notices/test_consent.py | 32 +++-- .../test_resource_limits_server_notices.py | 9 +- tests/test_state.py | 2 +- tests/unittest.py | 12 +- tests/utils.py | 132 ++++++++++----------- 23 files changed, 240 insertions(+), 203 deletions(-) create mode 100644 changelog.d/5171.misc (limited to 'tests/handlers') diff --git a/changelog.d/5171.misc b/changelog.d/5171.misc new file mode 100644 index 0000000000..d148b03b51 --- /dev/null +++ b/changelog.d/5171.misc @@ -0,0 +1 @@ +Update tests to consistently be configured via the same code that is used when loading from configuration files. diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 5aa03031f6..d90cbfb56a 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -108,6 +108,7 @@ class FileStorageProviderBackend(StorageProvider): """ def __init__(self, hs, config): + self.hs = hs self.cache_directory = hs.config.media_store_path self.base_directory = config diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 017ea0385e..1c253d0579 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -37,8 +37,12 @@ class RegistrationTestCase(unittest.HomeserverTestCase): hs_config = self.default_config("test") # some of the tests rely on us having a user consent version - hs_config.user_consent_version = "test_consent_version" - hs_config.max_mau_value = 50 + hs_config["user_consent"] = { + "version": "test_consent_version", + "template_dir": ".", + } + hs_config["max_mau_value"] = 50 + hs_config["limit_usage_by_mau"] = True hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) return hs diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 44468f5382..9021e647fe 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -37,7 +37,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.update_user_directory = True + config["update_user_directory"] = True return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): @@ -333,7 +333,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.update_user_directory = True + config["update_user_directory"] = True hs = self.setup_test_homeserver(config=config) self.config = hs.config diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 7036615041..ed0ca079d9 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -54,7 +54,9 @@ class MatrixFederationAgentTests(TestCase): self.agent = MatrixFederationAgent( reactor=self.reactor, - tls_client_options_factory=ClientTLSOptionsFactory(default_config("test")), + tls_client_options_factory=ClientTLSOptionsFactory( + default_config("test", parse=True) + ), _well_known_tls_policy=TrustingTLSPolicyForHTTPS(), _srv_resolver=self.mock_resolver, _well_known_cache=self.well_known_cache, diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 325ea449ae..9cdde1a9bd 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -52,22 +52,26 @@ class EmailPusherTests(HomeserverTestCase): return d config = self.default_config() - config.email_enable_notifs = True - config.start_pushers = True - - config.email_template_dir = os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') - ) - config.email_notif_template_html = "notif_mail.html" - config.email_notif_template_text = "notif_mail.txt" - config.email_smtp_host = "127.0.0.1" - config.email_smtp_port = 20 - config.require_transport_security = False - config.email_smtp_user = None - config.email_smtp_pass = None - config.email_app_name = "Matrix" - config.email_notif_from = "test@example.com" - config.email_riot_base_url = None + config["email"] = { + "enable_notifs": True, + "template_dir": os.path.abspath( + pkg_resources.resource_filename('synapse', 'res/templates') + ), + "expiry_template_html": "notice_expiry.html", + "expiry_template_text": "notice_expiry.txt", + "notif_template_html": "notif_mail.html", + "notif_template_text": "notif_mail.txt", + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "app_name": "Matrix", + "notif_from": "test@example.com", + "riot_base_url": None, + } + config["public_baseurl"] = "aaa" + config["start_pushers"] = True hs = self.setup_test_homeserver(config=config, sendmail=sendmail) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index 13bd2c8688..aba618b2be 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -54,7 +54,7 @@ class HTTPPusherTests(HomeserverTestCase): m.post_json_get_json = post_json_get_json config = self.default_config() - config.start_pushers = True + config["start_pushers"] = True hs = self.setup_test_homeserver(config=config, simple_http_client=m) diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 5528971190..88f8f1abdc 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -42,15 +42,18 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.user_consent_version = "1" - config.public_baseurl = "" - config.form_secret = "123abc" + config["public_baseurl"] = "aaaa" + config["form_secret"] = "123abc" # Make some temporary templates... temp_consent_path = self.mktemp() os.mkdir(temp_consent_path) os.mkdir(os.path.join(temp_consent_path, 'en')) - config.user_consent_template_dir = os.path.abspath(temp_consent_path) + + config["user_consent"] = { + "version": "1", + "template_dir": os.path.abspath(temp_consent_path), + } with open(os.path.join(temp_consent_path, "en/1.html"), 'w') as f: f.write("{{version}},{{has_consented}}") diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 1a714ff58a..68949307d9 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -32,7 +32,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.enable_3pid_lookup = False + config["enable_3pid_lookup"] = False self.hs = self.setup_test_homeserver(config=config) return self.hs diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 73c5b44b46..633b7dbda0 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -34,7 +34,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.require_membership_for_aliases = True + config["require_membership_for_aliases"] = True self.hs = self.setup_test_homeserver(config=config) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 8a9a55a527..f340b7e851 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -36,9 +36,9 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.enable_registration_captcha = False - config.enable_registration = True - config.auto_join_rooms = [] + config["enable_registration_captcha"] = False + config["enable_registration"] = True + config["auto_join_rooms"] = [] hs = self.setup_test_homeserver( config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index ed034879cf..769c37ce52 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -171,7 +171,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() - config.require_auth_for_profile_requests = True + config["require_auth_for_profile_requests"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 9b191436cc..6220172cde 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -919,7 +919,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.url = b"/_matrix/client/r0/publicRooms" config = self.default_config() - config.restrict_public_rooms_to_local_users = True + config["restrict_public_rooms_to_local_users"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 0ca3c4657b..ad7d476401 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -36,9 +36,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase): config = self.default_config() - config.enable_registration_captcha = True - config.recaptcha_public_key = "brokencake" - config.registrations_require_3pid = [] + config["enable_registration_captcha"] = True + config["recaptcha_public_key"] = "brokencake" + config["registrations_require_3pid"] = [] hs = self.setup_test_homeserver(config=config) return hs diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index be95dc592d..65685883db 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -201,9 +201,11 @@ class AccountValidityTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() # Test for account expiring after a week. - config.enable_registration = True - config.account_validity.enabled = True - config.account_validity.period = 604800000 # Time in ms for 1 week + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } self.hs = self.setup_test_homeserver(config=config) return self.hs @@ -299,14 +301,17 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() + # Test for account expiring after a week and renewal emails being sent 2 # days before expiry. - config.enable_registration = True - config.account_validity.enabled = True - config.account_validity.renew_by_email_enabled = True - config.account_validity.period = 604800000 # Time in ms for 1 week - config.account_validity.renew_at = 172800000 # Time in ms for 2 days - config.account_validity.renew_email_subject = "Renew your account" + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + "renew_at": 172800000, # Time in ms for 2 days + "renew_by_email_enabled": True, + "renew_email_subject": "Renew your account", + } # Email config. self.email_attempts = [] @@ -315,17 +320,23 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): self.email_attempts.append((args, kwargs)) return - config.email_template_dir = os.path.abspath( - pkg_resources.resource_filename('synapse', 'res/templates') - ) - config.email_expiry_template_html = "notice_expiry.html" - config.email_expiry_template_text = "notice_expiry.txt" - config.email_smtp_host = "127.0.0.1" - config.email_smtp_port = 20 - config.require_transport_security = False - config.email_smtp_user = None - config.email_smtp_pass = None - config.email_notif_from = "test@example.com" + config["email"] = { + "enable_notifs": True, + "template_dir": os.path.abspath( + pkg_resources.resource_filename('synapse', 'res/templates') + ), + "expiry_template_html": "notice_expiry.html", + "expiry_template_text": "notice_expiry.txt", + "notif_template_html": "notif_mail.html", + "notif_template_text": "notif_mail.txt", + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "aaa" self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index ad5e9a612f..1069a44145 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -25,13 +25,11 @@ from six.moves.urllib import parse from twisted.internet import defer, reactor from twisted.internet.defer import Deferred -from synapse.config.repository import MediaStorageProviderConfig from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.util.logcontext import make_deferred_yieldable -from synapse.util.module_loader import load_module from tests import unittest @@ -120,12 +118,14 @@ class MediaRepoTests(unittest.HomeserverTestCase): client.get_file = get_file self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) config = self.default_config() - config.media_store_path = self.storage_path - config.thumbnail_requirements = {} - config.max_image_pixels = 2000000 + config["media_store_path"] = self.media_store_path + config["thumbnail_requirements"] = {} + config["max_image_pixels"] = 2000000 provider_config = { "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", @@ -134,12 +134,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): "store_remote": True, "config": {"directory": self.storage_path}, } - - loaded = list(load_module(provider_config)) + [ - MediaStorageProviderConfig(False, False, False) - ] - - config.media_storage_providers = [loaded] + config["media_storage_providers"] = [provider_config] hs = self.setup_test_homeserver(config=config, http_client=client) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index f696395f3c..1ab0f7293a 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -16,7 +16,6 @@ import os import attr -from netaddr import IPSet from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address @@ -25,9 +24,6 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web._newclient import ResponseDone -from synapse.config.repository import MediaStorageProviderConfig -from synapse.util.module_loader import load_module - from tests import unittest from tests.server import FakeTransport @@ -67,23 +63,23 @@ class URLPreviewTests(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): - self.storage_path = self.mktemp() - os.mkdir(self.storage_path) - config = self.default_config() - config.url_preview_enabled = True - config.max_spider_size = 9999999 - config.url_preview_ip_range_blacklist = IPSet( - ( - "192.168.1.1", - "1.0.0.0/8", - "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", - "2001:800::/21", - ) + config["url_preview_enabled"] = True + config["max_spider_size"] = 9999999 + config["url_preview_ip_range_blacklist"] = ( + "192.168.1.1", + "1.0.0.0/8", + "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "2001:800::/21", ) - config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",)) - config.url_preview_url_blacklist = [] - config.media_store_path = self.storage_path + config["url_preview_ip_range_whitelist"] = ("1.1.1.1",) + config["url_preview_url_blacklist"] = [] + + self.storage_path = self.mktemp() + self.media_store_path = self.mktemp() + os.mkdir(self.storage_path) + os.mkdir(self.media_store_path) + config["media_store_path"] = self.media_store_path provider_config = { "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend", @@ -93,11 +89,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): "config": {"directory": self.storage_path}, } - loaded = list(load_module(provider_config)) + [ - MediaStorageProviderConfig(False, False, False) - ] - - config.media_storage_providers = [loaded] + config["media_storage_providers"] = [provider_config] hs = self.setup_test_homeserver(config=config) diff --git a/tests/server.py b/tests/server.py index fc41345488..c15a47f2a4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -227,6 +227,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ def __init__(self): + self.threadpool = ThreadPool(self) + self._udp = [] lookups = self.lookups = {} @@ -255,6 +257,37 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): self.callLater(0, d.callback, True) return d + def getThreadPool(self): + return self.threadpool + + +class ThreadPool: + """ + Threadless thread pool. + """ + + def __init__(self, reactor): + self._reactor = reactor + + def start(self): + pass + + def stop(self): + pass + + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): + def _(res): + if isinstance(res, Failure): + onResult(False, res) + else: + onResult(True, res) + + d = Deferred() + d.addCallback(lambda x: function(*args, **kwargs)) + d.addBoth(_) + self._reactor.callLater(0, d.callback, True) + return d + def setup_test_homeserver(cleanup_func, *args, **kwargs): """ @@ -290,36 +323,10 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): **kwargs ) - class ThreadPool: - """ - Threadless thread pool. - """ - - def start(self): - pass - - def stop(self): - pass - - def callInThreadWithCallback(self, onResult, function, *args, **kwargs): - def _(res): - if isinstance(res, Failure): - onResult(False, res) - else: - onResult(True, res) - - d = Deferred() - d.addCallback(lambda x: function(*args, **kwargs)) - d.addBoth(_) - clock._reactor.callLater(0, d.callback, True) - return d - - clock.threadpool = ThreadPool() - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction - pool.threadpool = ThreadPool() + pool.threadpool = ThreadPool(clock._reactor) pool.running = True return d diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index e0b4e0eb63..872039c8f1 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import os + import synapse.rest.admin from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import sync @@ -30,20 +33,27 @@ class ConsentNoticesTests(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): + tmpdir = self.mktemp() + os.mkdir(tmpdir) self.consent_notice_message = "consent %(consent_uri)s" config = self.default_config() - config.user_consent_version = "1" - config.user_consent_server_notice_content = { - "msgtype": "m.text", - "body": self.consent_notice_message, + config["user_consent"] = { + "version": "1", + "template_dir": tmpdir, + "server_notice_content": { + "msgtype": "m.text", + "body": self.consent_notice_message, + }, + } + config["public_baseurl"] = "https://example.com/" + config["form_secret"] = "123abc" + + config["server_notices"] = { + "system_mxid_localpart": "notices", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + "room_name": "Server Notices", } - config.public_baseurl = "https://example.com/" - config.form_secret = "123abc" - - config.server_notices_mxid = "@notices:test" - config.server_notices_mxid_display_name = "test display name" - config.server_notices_mxid_avatar_url = None - config.server_notices_room_name = "Server Notices" hs = self.setup_test_homeserver(config=config) diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index a490b81ed4..739ee59ce4 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -29,7 +29,12 @@ from tests import unittest class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs_config = self.default_config("test") - hs_config.server_notices_mxid = "@server:test" + hs_config["server_notices"] = { + "system_mxid_localpart": "server", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + "room_name": "Server Notices", + } hs = self.setup_test_homeserver(config=hs_config, expire_access_token=True) return hs @@ -79,7 +84,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._send_notice.assert_not_called() # Test when mau limiting disabled self.hs.config.hs_disabled = False - self.hs.limit_usage_by_mau = False + self.hs.config.limit_usage_by_mau = False self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() diff --git a/tests/test_state.py b/tests/test_state.py index 5bcc6aaa18..6491a7105a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -168,7 +168,7 @@ class StateTestCase(unittest.TestCase): "get_state_resolution_handler", ] ) - hs.config = default_config("tesths") + hs.config = default_config("tesths", True) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None hs.get_clock.return_value = MockClock() diff --git a/tests/unittest.py b/tests/unittest.py index 94df8cf47e..26204470b1 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -27,6 +27,7 @@ import twisted.logger from twisted.internet.defer import Deferred from twisted.trial import unittest +from synapse.config.homeserver import HomeServerConfig from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.server import HomeServer @@ -245,7 +246,7 @@ class HomeserverTestCase(TestCase): def default_config(self, name="test"): """ - Get a default HomeServer config object. + Get a default HomeServer config dict. Args: name (str): The homeserver name/domain. @@ -335,7 +336,14 @@ class HomeserverTestCase(TestCase): kwargs.update(self._hs_args) if "config" not in kwargs: config = self.default_config() - kwargs["config"] = config + else: + config = kwargs["config"] + + # Parse the config from a config dict into a HomeServerConfig + config_obj = HomeServerConfig() + config_obj.parse_config_dict(config) + kwargs["config"] = config_obj + hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index c2ef4b0bb5..f21074ae28 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -110,7 +110,7 @@ def setupdb(): atexit.register(_cleanup) -def default_config(name): +def default_config(name, parse=False): """ Create a reasonable test config. """ @@ -121,75 +121,69 @@ def default_config(name): # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", + "event_cache_size": 1, + "enable_registration": True, + "enable_registration_captcha": False, + "macaroon_secret_key": "not even a little secret", + "expire_access_token": False, + "trusted_third_party_id_servers": [], + "room_invite_state_types": [], + "password_providers": [], + "worker_replication_url": "", + "worker_app": None, + "email_enable_notifs": False, + "block_non_admin_invites": False, + "federation_domain_whitelist": None, + "federation_rc_reject_limit": 10, + "federation_rc_sleep_limit": 10, + "federation_rc_sleep_delay": 100, + "federation_rc_concurrent": 10, + "filter_timeline_limit": 5000, + "user_directory_search_all_users": False, + "user_consent_server_notice_content": None, + "block_events_without_consent_error": None, + "user_consent_at_registration": False, + "user_consent_policy_name": "Privacy Policy", + "media_storage_providers": [], + "autocreate_auto_join_rooms": True, + "auto_join_rooms": [], + "limit_usage_by_mau": False, + "hs_disabled": False, + "hs_disabled_message": "", + "hs_disabled_limit_type": "", + "max_mau_value": 50, + "mau_trial_days": 0, + "mau_stats_only": False, + "mau_limits_reserved_threepids": [], + "admin_contact": None, + "rc_message": {"per_second": 10000, "burst_count": 10000}, + "rc_registration": {"per_second": 10000, "burst_count": 10000}, + "rc_login": { + "address": {"per_second": 10000, "burst_count": 10000}, + "account": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 10000, "burst_count": 10000}, + }, + "saml2_enabled": False, + "public_baseurl": None, + "default_identity_server": None, + "key_refresh_interval": 24 * 60 * 60 * 1000, + "old_signing_keys": {}, + "tls_fingerprints": [], + "use_frozen_dicts": False, + # We need a sane default_room_version, otherwise attempts to create + # rooms will fail. + "default_room_version": "1", + # disable user directory updates, because they get done in the + # background, which upsets the test runner. + "update_user_directory": False, } - config = HomeServerConfig() - config.parse_config_dict(config_dict) - - # TODO: move this stuff into config_dict or get rid of it - config.event_cache_size = 1 - config.enable_registration = True - config.enable_registration_captcha = False - config.macaroon_secret_key = "not even a little secret" - config.expire_access_token = False - config.trusted_third_party_id_servers = [] - config.room_invite_state_types = [] - config.password_providers = [] - config.worker_replication_url = "" - config.worker_app = None - config.email_enable_notifs = False - config.block_non_admin_invites = False - config.federation_domain_whitelist = None - config.federation_rc_reject_limit = 10 - config.federation_rc_sleep_limit = 10 - config.federation_rc_sleep_delay = 100 - config.federation_rc_concurrent = 10 - config.filter_timeline_limit = 5000 - config.user_directory_search_all_users = False - config.user_consent_server_notice_content = None - config.block_events_without_consent_error = None - config.user_consent_at_registration = False - config.user_consent_policy_name = "Privacy Policy" - config.media_storage_providers = [] - config.autocreate_auto_join_rooms = True - config.auto_join_rooms = [] - config.limit_usage_by_mau = False - config.hs_disabled = False - config.hs_disabled_message = "" - config.hs_disabled_limit_type = "" - config.max_mau_value = 50 - config.mau_trial_days = 0 - config.mau_stats_only = False - config.mau_limits_reserved_threepids = [] - config.admin_contact = None - config.rc_messages_per_second = 10000 - config.rc_message_burst_count = 10000 - config.rc_registration.per_second = 10000 - config.rc_registration.burst_count = 10000 - config.rc_login_address.per_second = 10000 - config.rc_login_address.burst_count = 10000 - config.rc_login_account.per_second = 10000 - config.rc_login_account.burst_count = 10000 - config.rc_login_failed_attempts.per_second = 10000 - config.rc_login_failed_attempts.burst_count = 10000 - config.saml2_enabled = False - config.public_baseurl = None - config.default_identity_server = None - config.key_refresh_interval = 24 * 60 * 60 * 1000 - config.old_signing_keys = {} - config.tls_fingerprints = [] - - config.use_frozen_dicts = False - - # we need a sane default_room_version, otherwise attempts to create rooms will - # fail. - config.default_room_version = "1" - - # disable user directory updates, because they get done in the - # background, which upsets the test runner. - config.update_user_directory = False - - return config + if parse: + config = HomeServerConfig() + config.parse_config_dict(config_dict) + return config + + return config_dict class TestHomeServer(HomeServer): @@ -223,7 +217,7 @@ def setup_test_homeserver( from twisted.internet import reactor if config is None: - config = default_config(name) + config = default_config(name, parse=True) config.ldap_enabled = False -- cgit 1.5.1 From 3787133c9e3fcf0e9b85700418bf03c48ec86ab3 Mon Sep 17 00:00:00 2001 From: ReidAnderson Date: Mon, 20 May 2019 05:20:08 -0500 Subject: Limit UserIds to a length that fits in a state key (#5198) --- changelog.d/5198.bugfix | 1 + synapse/api/constants.py | 3 +++ synapse/handlers/register.py | 11 ++++++++++- tests/handlers/test_register.py | 7 +++++++ 4 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 changelog.d/5198.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/5198.bugfix b/changelog.d/5198.bugfix new file mode 100644 index 0000000000..c6b156f17d --- /dev/null +++ b/changelog.d/5198.bugfix @@ -0,0 +1 @@ +Prevent registration for user ids that are to long to fit into a state key. Contributed by Reid Anderson. \ No newline at end of file diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8547a63535..c7bf95b426 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -23,6 +23,9 @@ MAX_DEPTH = 2**63 - 1 # the maximum length for a room alias is 255 characters MAX_ALIAS_LENGTH = 255 +# the maximum length for a user id is 255 characters +MAX_USERID_LENGTH = 255 + class Membership(object): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a51d11a257..e83ee24f10 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -19,7 +19,7 @@ import logging from twisted.internet import defer from synapse import types -from synapse.api.constants import LoginType +from synapse.api.constants import MAX_USERID_LENGTH, LoginType from synapse.api.errors import ( AuthError, Codes, @@ -123,6 +123,15 @@ class RegistrationHandler(BaseHandler): self.check_user_id_not_appservice_exclusive(user_id) + if len(user_id) > MAX_USERID_LENGTH: + raise SynapseError( + 400, + "User ID may not be longer than %s characters" % ( + MAX_USERID_LENGTH, + ), + Codes.INVALID_USERNAME + ) + users = yield self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1c253d0579..5ffba2ca7a 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -228,3 +228,10 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_register_not_support_user(self): res = self.get_success(self.handler.register(localpart='user')) self.assertFalse(self.store.is_support_user(res[0])) + + def test_invalid_user_id_length(self): + invalid_user_id = "x" * 256 + self.get_failure( + self.handler.register(localpart=invalid_user_id), + SynapseError + ) -- cgit 1.5.1 From 4a30e4acb4ef14431914bd42ad09a51bd81d6c3e Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 21 May 2019 11:36:50 -0500 Subject: Room Statistics (#4338) --- changelog.d/4338.feature | 1 + docs/sample_config.yaml | 16 ++ synapse/api/constants.py | 1 + synapse/config/homeserver.py | 42 ++- synapse/config/stats.py | 60 ++++ synapse/handlers/stats.py | 325 +++++++++++++++++++++ synapse/server.py | 6 + synapse/storage/__init__.py | 2 + synapse/storage/events_worker.py | 24 ++ synapse/storage/roommember.py | 32 +++ synapse/storage/schema/delta/54/stats.sql | 80 ++++++ synapse/storage/state_deltas.py | 12 +- synapse/storage/stats.py | 450 ++++++++++++++++++++++++++++++ tests/handlers/test_stats.py | 251 +++++++++++++++++ tests/rest/client/v1/utils.py | 17 ++ 15 files changed, 1306 insertions(+), 13 deletions(-) create mode 100644 changelog.d/4338.feature create mode 100644 synapse/config/stats.py create mode 100644 synapse/handlers/stats.py create mode 100644 synapse/storage/schema/delta/54/stats.sql create mode 100644 synapse/storage/stats.py create mode 100644 tests/handlers/test_stats.py (limited to 'tests/handlers') diff --git a/changelog.d/4338.feature b/changelog.d/4338.feature new file mode 100644 index 0000000000..01285e965c --- /dev/null +++ b/changelog.d/4338.feature @@ -0,0 +1 @@ +Synapse now more efficiently collates room statistics. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index f658ec8ecd..559fbcdd01 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1153,6 +1153,22 @@ password_config: # + +# Local statistics collection. Used in populating the room directory. +# +# 'bucket_size' controls how large each statistics timeslice is. It can +# be defined in a human readable short form -- e.g. "1d", "1y". +# +# 'retention' controls how long historical statistics will be kept for. +# It can be defined in a human readable short form -- e.g. "1d", "1y". +# +# +#stats: +# enabled: true +# bucket_size: 1d +# retention: 1y + + # Server Notices room configuration # # Uncomment this section to enable a room which can be used to send notices diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 6b347b1749..ee129c8689 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -79,6 +79,7 @@ class EventTypes(object): RoomHistoryVisibility = "m.room.history_visibility" CanonicalAlias = "m.room.canonical_alias" + Encryption = "m.room.encryption" RoomAvatar = "m.room.avatar" RoomEncryption = "m.room.encryption" GuestAccess = "m.room.guest_access" diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 727fdc54d8..5c4fc8ff21 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from .api import ApiConfig from .appservice import AppServiceConfig from .captcha import CaptchaConfig @@ -36,20 +37,41 @@ from .saml2_config import SAML2Config from .server import ServerConfig from .server_notices_config import ServerNoticesConfig from .spam_checker import SpamCheckerConfig +from .stats import StatsConfig from .tls import TlsConfig from .user_directory import UserDirectoryConfig from .voip import VoipConfig from .workers import WorkerConfig -class HomeServerConfig(ServerConfig, TlsConfig, DatabaseConfig, LoggingConfig, - RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, - VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, - AppServiceConfig, KeyConfig, SAML2Config, CasConfig, - JWTConfig, PasswordConfig, EmailConfig, - WorkerConfig, PasswordAuthProviderConfig, PushConfig, - SpamCheckerConfig, GroupsConfig, UserDirectoryConfig, - ConsentConfig, - ServerNoticesConfig, RoomDirectoryConfig, - ): +class HomeServerConfig( + ServerConfig, + TlsConfig, + DatabaseConfig, + LoggingConfig, + RatelimitConfig, + ContentRepositoryConfig, + CaptchaConfig, + VoipConfig, + RegistrationConfig, + MetricsConfig, + ApiConfig, + AppServiceConfig, + KeyConfig, + SAML2Config, + CasConfig, + JWTConfig, + PasswordConfig, + EmailConfig, + WorkerConfig, + PasswordAuthProviderConfig, + PushConfig, + SpamCheckerConfig, + GroupsConfig, + UserDirectoryConfig, + ConsentConfig, + StatsConfig, + ServerNoticesConfig, + RoomDirectoryConfig, +): pass diff --git a/synapse/config/stats.py b/synapse/config/stats.py new file mode 100644 index 0000000000..80fc1b9dd0 --- /dev/null +++ b/synapse/config/stats.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import sys + +from ._base import Config + + +class StatsConfig(Config): + """Stats Configuration + Configuration for the behaviour of synapse's stats engine + """ + + def read_config(self, config): + self.stats_enabled = True + self.stats_bucket_size = 86400 + self.stats_retention = sys.maxsize + stats_config = config.get("stats", None) + if stats_config: + self.stats_enabled = stats_config.get("enabled", self.stats_enabled) + self.stats_bucket_size = ( + self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000 + ) + self.stats_retention = ( + self.parse_duration( + stats_config.get("retention", "%ds" % (sys.maxsize,)) + ) + / 1000 + ) + + def default_config(self, config_dir_path, server_name, **kwargs): + return """ + # Local statistics collection. Used in populating the room directory. + # + # 'bucket_size' controls how large each statistics timeslice is. It can + # be defined in a human readable short form -- e.g. "1d", "1y". + # + # 'retention' controls how long historical statistics will be kept for. + # It can be defined in a human readable short form -- e.g. "1d", "1y". + # + # + #stats: + # enabled: true + # bucket_size: 1d + # retention: 1y + """ diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py new file mode 100644 index 0000000000..0e92b405ba --- /dev/null +++ b/synapse/handlers/stats.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.handlers.state_deltas import StateDeltasHandler +from synapse.metrics import event_processing_positions +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import UserID +from synapse.util.metrics import Measure + +logger = logging.getLogger(__name__) + + +class StatsHandler(StateDeltasHandler): + """Handles keeping the *_stats tables updated with a simple time-series of + information about the users, rooms and media on the server, such that admins + have some idea of who is consuming their resources. + + Heavily derived from UserDirectoryHandler + """ + + def __init__(self, hs): + super(StatsHandler, self).__init__(hs) + self.hs = hs + self.store = hs.get_datastore() + self.state = hs.get_state_handler() + self.server_name = hs.hostname + self.clock = hs.get_clock() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.stats_bucket_size = hs.config.stats_bucket_size + + # The current position in the current_state_delta stream + self.pos = None + + # Guard to ensure we only process deltas one at a time + self._is_processing = False + + if hs.config.stats_enabled: + self.notifier.add_replication_callback(self.notify_new_event) + + # We kick this off so that we don't have to wait for a change before + # we start populating stats + self.clock.call_later(0, self.notify_new_event) + + def notify_new_event(self): + """Called when there may be more deltas to process + """ + if not self.hs.config.stats_enabled: + return + + if self._is_processing: + return + + @defer.inlineCallbacks + def process(): + try: + yield self._unsafe_process() + finally: + self._is_processing = False + + self._is_processing = True + run_as_background_process("stats.notify_new_event", process) + + @defer.inlineCallbacks + def _unsafe_process(self): + # If self.pos is None then means we haven't fetched it from DB + if self.pos is None: + self.pos = yield self.store.get_stats_stream_pos() + + # If still None then the initial background update hasn't happened yet + if self.pos is None: + defer.returnValue(None) + + # Loop round handling deltas until we're up to date + while True: + with Measure(self.clock, "stats_delta"): + deltas = yield self.store.get_current_state_deltas(self.pos) + if not deltas: + return + + logger.info("Handling %d state deltas", len(deltas)) + yield self._handle_deltas(deltas) + + self.pos = deltas[-1]["stream_id"] + yield self.store.update_stats_stream_pos(self.pos) + + event_processing_positions.labels("stats").set(self.pos) + + @defer.inlineCallbacks + def _handle_deltas(self, deltas): + """ + Called with the state deltas to process + """ + for delta in deltas: + typ = delta["type"] + state_key = delta["state_key"] + room_id = delta["room_id"] + event_id = delta["event_id"] + stream_id = delta["stream_id"] + prev_event_id = delta["prev_event_id"] + + logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + + token = yield self.store.get_earliest_token_for_room_stats(room_id) + + # If the earliest token to begin from is larger than our current + # stream ID, skip processing this delta. + if token is not None and token >= stream_id: + logger.debug( + "Ignoring: %s as earlier than this room's initial ingestion event", + event_id, + ) + continue + + if event_id is None and prev_event_id is None: + # Errr... + continue + + event_content = {} + + if event_id is not None: + event_content = (yield self.store.get_event(event_id)).content or {} + + # quantise time to the nearest bucket + now = yield self.store.get_received_ts(event_id) + now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size + + if typ == EventTypes.Member: + # we could use _get_key_change here but it's a bit inefficient + # given we're not testing for a specific result; might as well + # just grab the prev_membership and membership strings and + # compare them. + prev_event_content = {} + if prev_event_id is not None: + prev_event_content = ( + yield self.store.get_event(prev_event_id) + ).content + + membership = event_content.get("membership", Membership.LEAVE) + prev_membership = prev_event_content.get("membership", Membership.LEAVE) + + if prev_membership == membership: + continue + + if prev_membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, "room", room_id, "joined_members", -1 + ) + elif prev_membership == Membership.INVITE: + yield self.store.update_stats_delta( + now, "room", room_id, "invited_members", -1 + ) + elif prev_membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, "room", room_id, "left_members", -1 + ) + elif prev_membership == Membership.BAN: + yield self.store.update_stats_delta( + now, "room", room_id, "banned_members", -1 + ) + else: + err = "%s is not a valid prev_membership" % (repr(prev_membership),) + logger.error(err) + raise ValueError(err) + + if membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, "room", room_id, "joined_members", +1 + ) + elif membership == Membership.INVITE: + yield self.store.update_stats_delta( + now, "room", room_id, "invited_members", +1 + ) + elif membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, "room", room_id, "left_members", +1 + ) + elif membership == Membership.BAN: + yield self.store.update_stats_delta( + now, "room", room_id, "banned_members", +1 + ) + else: + err = "%s is not a valid membership" % (repr(membership),) + logger.error(err) + raise ValueError(err) + + user_id = state_key + if self.is_mine_id(user_id): + # update user_stats as it's one of our users + public = yield self._is_public_room(room_id) + + if membership == Membership.LEAVE: + yield self.store.update_stats_delta( + now, + "user", + user_id, + "public_rooms" if public else "private_rooms", + -1, + ) + elif membership == Membership.JOIN: + yield self.store.update_stats_delta( + now, + "user", + user_id, + "public_rooms" if public else "private_rooms", + +1, + ) + + elif typ == EventTypes.Create: + # Newly created room. Add it with all blank portions. + yield self.store.update_room_state( + room_id, + { + "join_rules": None, + "history_visibility": None, + "encryption": None, + "name": None, + "topic": None, + "avatar": None, + "canonical_alias": None, + }, + ) + + elif typ == EventTypes.JoinRules: + yield self.store.update_room_state( + room_id, {"join_rules": event_content.get("join_rule")} + ) + + is_public = yield self._get_key_change( + prev_event_id, event_id, "join_rule", JoinRules.PUBLIC + ) + if is_public is not None: + yield self.update_public_room_stats(now, room_id, is_public) + + elif typ == EventTypes.RoomHistoryVisibility: + yield self.store.update_room_state( + room_id, + {"history_visibility": event_content.get("history_visibility")}, + ) + + is_public = yield self._get_key_change( + prev_event_id, event_id, "history_visibility", "world_readable" + ) + if is_public is not None: + yield self.update_public_room_stats(now, room_id, is_public) + + elif typ == EventTypes.Encryption: + yield self.store.update_room_state( + room_id, {"encryption": event_content.get("algorithm")} + ) + elif typ == EventTypes.Name: + yield self.store.update_room_state( + room_id, {"name": event_content.get("name")} + ) + elif typ == EventTypes.Topic: + yield self.store.update_room_state( + room_id, {"topic": event_content.get("topic")} + ) + elif typ == EventTypes.RoomAvatar: + yield self.store.update_room_state( + room_id, {"avatar": event_content.get("url")} + ) + elif typ == EventTypes.CanonicalAlias: + yield self.store.update_room_state( + room_id, {"canonical_alias": event_content.get("alias")} + ) + + @defer.inlineCallbacks + def update_public_room_stats(self, ts, room_id, is_public): + """ + Increment/decrement a user's number of public rooms when a room they are + in changes to/from public visibility. + + Args: + ts (int): Timestamp in seconds + room_id (str) + is_public (bool) + """ + # For now, blindly iterate over all local users in the room so that + # we can handle the whole problem of copying buckets over as needed + user_ids = yield self.store.get_users_in_room(room_id) + + for user_id in user_ids: + if self.hs.is_mine(UserID.from_string(user_id)): + yield self.store.update_stats_delta( + ts, "user", user_id, "public_rooms", +1 if is_public else -1 + ) + yield self.store.update_stats_delta( + ts, "user", user_id, "private_rooms", -1 if is_public else +1 + ) + + @defer.inlineCallbacks + def _is_public_room(self, room_id): + join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules) + history_visibility = yield self.state.get_current_state( + room_id, EventTypes.RoomHistoryVisibility + ) + + if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or ( + ( + history_visibility + and history_visibility.content.get("history_visibility") + == "world_readable" + ) + ): + defer.returnValue(True) + else: + defer.returnValue(False) diff --git a/synapse/server.py b/synapse/server.py index 80d40b9272..9229a68a8d 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -72,6 +72,7 @@ from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.set_password import SetPasswordHandler +from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler from synapse.handlers.typing import TypingHandler from synapse.handlers.user_directory import UserDirectoryHandler @@ -139,6 +140,7 @@ class HomeServer(object): 'acme_handler', 'auth_handler', 'device_handler', + 'stats_handler', 'e2e_keys_handler', 'e2e_room_keys_handler', 'event_handler', @@ -191,6 +193,7 @@ class HomeServer(object): REQUIRED_ON_MASTER_STARTUP = [ "user_directory_handler", + "stats_handler" ] # This is overridden in derived application classes @@ -474,6 +477,9 @@ class HomeServer(object): def build_secrets(self): return Secrets() + def build_stats_handler(self): + return StatsHandler(self) + def build_spam_checker(self): return SpamChecker(self) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 7522d3fd57..66675d08ae 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -55,6 +55,7 @@ from .roommember import RoomMemberStore from .search import SearchStore from .signatures import SignatureStore from .state import StateStore +from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore from .transactions import TransactionStore @@ -100,6 +101,7 @@ class DataStore( GroupServerStore, UserErasureStore, MonthlyActiveUsersStore, + StatsStore, RelationsStore, ): def __init__(self, db_conn, hs): diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index adc6cf26b5..83ffae2132 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -611,3 +611,27 @@ class EventsWorkerStore(SQLBaseStore): return res return self.runInteraction("get_rejection_reasons", f) + + def _get_total_state_event_counts_txn(self, txn, room_id): + """ + See get_state_event_counts. + """ + sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?" + txn.execute(sql, (room_id,)) + row = txn.fetchone() + return row[0] if row else 0 + + def get_total_state_event_counts(self, room_id): + """ + Gets the total number of state events in a room. + + Args: + room_id (str) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_total_state_event_counts", + self._get_total_state_event_counts_txn, room_id + ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 57df17bcc2..4bd1669458 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -142,6 +142,38 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self.runInteraction("get_room_summary", _get_room_summary_txn) + def _get_user_count_in_room_txn(self, txn, room_id, membership): + """ + See get_user_count_in_room. + """ + sql = ( + "SELECT count(*) FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id " + " AND m.room_id = c.room_id " + " AND m.user_id = c.state_key" + " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" + ) + + txn.execute(sql, (room_id, membership)) + row = txn.fetchone() + return row[0] + + def get_user_count_in_room(self, room_id, membership): + """ + Get the user count in a room with a particular membership. + + Args: + room_id (str) + membership (Membership) + + Returns: + Deferred[int] + """ + return self.runInteraction( + "get_users_in_room", self._get_user_count_in_room_txn, room_id, membership + ) + @cached() def get_invited_rooms_for_user(self, user_id): """ Get all the rooms the user is invited to diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/schema/delta/54/stats.sql new file mode 100644 index 0000000000..652e58308e --- /dev/null +++ b/synapse/storage/schema/delta/54/stats.sql @@ -0,0 +1,80 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stats_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO stats_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_stats ( + user_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + public_rooms INT NOT NULL, + private_rooms INT NOT NULL +); + +CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts); + +CREATE TABLE room_stats ( + room_id TEXT NOT NULL, + ts BIGINT NOT NULL, + bucket_size INT NOT NULL, + current_state_events INT NOT NULL, + joined_members INT NOT NULL, + invited_members INT NOT NULL, + left_members INT NOT NULL, + banned_members INT NOT NULL, + state_events INT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts); + +-- cache of current room state; useful for the publicRooms list +CREATE TABLE room_state ( + room_id TEXT NOT NULL, + join_rules TEXT, + history_visibility TEXT, + encryption TEXT, + name TEXT, + topic TEXT, + avatar TEXT, + canonical_alias TEXT + -- get aliases straight from the right table +); + +CREATE UNIQUE INDEX room_state_room ON room_state(room_id); + +CREATE TABLE room_stats_earliest_token ( + room_id TEXT NOT NULL, + token BIGINT NOT NULL +); + +CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id); + +-- Set up staging tables +INSERT INTO background_updates (update_name, progress_json) VALUES + ('populate_stats_createtables', '{}'); + +-- Run through each room and update stats +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_process_rooms', '{}', 'populate_stats_createtables'); + +-- Clean up staging tables +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms'); diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py index 31a0279b18..5fdb442104 100644 --- a/synapse/storage/state_deltas.py +++ b/synapse/storage/state_deltas.py @@ -84,10 +84,16 @@ class StateDeltasStore(SQLBaseStore): "get_current_state_deltas", get_current_state_deltas_txn ) - def get_max_stream_id_in_current_state_deltas(self): - return self._simple_select_one_onecol( + def _get_max_stream_id_in_current_state_deltas_txn(self, txn): + return self._simple_select_one_onecol_txn( + txn, table="current_state_delta_stream", keyvalues={}, retcol="COALESCE(MAX(stream_id), -1)", - desc="get_max_stream_id_in_current_state_deltas", + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self.runInteraction( + "get_max_stream_id_in_current_state_deltas", + self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py new file mode 100644 index 0000000000..71b80a891d --- /dev/null +++ b/synapse/storage/stats.py @@ -0,0 +1,450 @@ +# -*- coding: utf-8 -*- +# Copyright 2018, 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.storage.state_deltas import StateDeltasStore +from synapse.util.caches.descriptors import cached + +logger = logging.getLogger(__name__) + +# these fields track absolutes (e.g. total number of rooms on the server) +ABSOLUTE_STATS_FIELDS = { + "room": ( + "current_state_events", + "joined_members", + "invited_members", + "left_members", + "banned_members", + "state_events", + ), + "user": ("public_rooms", "private_rooms"), +} + +TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")} + +TEMP_TABLE = "_temp_populate_stats" + + +class StatsStore(StateDeltasStore): + def __init__(self, db_conn, hs): + super(StatsStore, self).__init__(db_conn, hs) + + self.server_name = hs.hostname + self.clock = self.hs.get_clock() + self.stats_enabled = hs.config.stats_enabled + self.stats_bucket_size = hs.config.stats_bucket_size + + self.register_background_update_handler( + "populate_stats_createtables", self._populate_stats_createtables + ) + self.register_background_update_handler( + "populate_stats_process_rooms", self._populate_stats_process_rooms + ) + self.register_background_update_handler( + "populate_stats_cleanup", self._populate_stats_cleanup + ) + + @defer.inlineCallbacks + def _populate_stats_createtables(self, progress, batch_size): + + if not self.stats_enabled: + yield self._end_background_update("populate_stats_createtables") + defer.returnValue(1) + + # Get all the rooms that we want to process. + def _make_staging_area(txn): + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)" + ) + txn.execute(sql) + + sql = ( + "CREATE TABLE IF NOT EXISTS " + + TEMP_TABLE + + "_position(position TEXT NOT NULL)" + ) + txn.execute(sql) + + # Get rooms we want to process from the database + sql = """ + SELECT room_id, count(*) FROM current_state_events + GROUP BY room_id + """ + txn.execute(sql) + rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] + self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + del rooms + + new_pos = yield self.get_max_stream_id_in_current_state_deltas() + yield self.runInteraction("populate_stats_temp_build", _make_staging_area) + yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + self.get_earliest_token_for_room_stats.invalidate_all() + + yield self._end_background_update("populate_stats_createtables") + defer.returnValue(1) + + @defer.inlineCallbacks + def _populate_stats_cleanup(self, progress, batch_size): + """ + Update the user directory stream position, then clean up the old tables. + """ + if not self.stats_enabled: + yield self._end_background_update("populate_stats_cleanup") + defer.returnValue(1) + + position = yield self._simple_select_one_onecol( + TEMP_TABLE + "_position", None, "position" + ) + yield self.update_stats_stream_pos(position) + + def _delete_staging_area(txn): + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms") + txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") + + yield self.runInteraction("populate_stats_cleanup", _delete_staging_area) + + yield self._end_background_update("populate_stats_cleanup") + defer.returnValue(1) + + @defer.inlineCallbacks + def _populate_stats_process_rooms(self, progress, batch_size): + + if not self.stats_enabled: + yield self._end_background_update("populate_stats_process_rooms") + defer.returnValue(1) + + # If we don't have progress filed, delete everything. + if not progress: + yield self.delete_all_stats() + + def _get_next_batch(txn): + # Only fetch 250 rooms, so we don't fetch too many at once, even + # if those 250 rooms have less than batch_size state events. + sql = """ + SELECT room_id, events FROM %s_rooms + ORDER BY events DESC + LIMIT 250 + """ % ( + TEMP_TABLE, + ) + txn.execute(sql) + rooms_to_work_on = txn.fetchall() + + if not rooms_to_work_on: + return None + + # Get how many are left to process, so we can give status on how + # far we are in processing + txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms") + progress["remaining"] = txn.fetchone()[0] + + return rooms_to_work_on + + rooms_to_work_on = yield self.runInteraction( + "populate_stats_temp_read", _get_next_batch + ) + + # No more rooms -- complete the transaction. + if not rooms_to_work_on: + yield self._end_background_update("populate_stats_process_rooms") + defer.returnValue(1) + + logger.info( + "Processing the next %d rooms of %d remaining", + (len(rooms_to_work_on), progress["remaining"]), + ) + + # Number of state events we've processed by going through each room + processed_event_count = 0 + + for room_id, event_count in rooms_to_work_on: + + current_state_ids = yield self.get_current_state_ids(room_id) + + join_rules = yield self.get_event( + current_state_ids.get((EventTypes.JoinRules, "")), allow_none=True + ) + history_visibility = yield self.get_event( + current_state_ids.get((EventTypes.RoomHistoryVisibility, "")), + allow_none=True, + ) + encryption = yield self.get_event( + current_state_ids.get((EventTypes.RoomEncryption, "")), allow_none=True + ) + name = yield self.get_event( + current_state_ids.get((EventTypes.Name, "")), allow_none=True + ) + topic = yield self.get_event( + current_state_ids.get((EventTypes.Topic, "")), allow_none=True + ) + avatar = yield self.get_event( + current_state_ids.get((EventTypes.RoomAvatar, "")), allow_none=True + ) + canonical_alias = yield self.get_event( + current_state_ids.get((EventTypes.CanonicalAlias, "")), allow_none=True + ) + + def _or_none(x, arg): + if x: + return x.content.get(arg) + return None + + yield self.update_room_state( + room_id, + { + "join_rules": _or_none(join_rules, "join_rule"), + "history_visibility": _or_none( + history_visibility, "history_visibility" + ), + "encryption": _or_none(encryption, "algorithm"), + "name": _or_none(name, "name"), + "topic": _or_none(topic, "topic"), + "avatar": _or_none(avatar, "url"), + "canonical_alias": _or_none(canonical_alias, "alias"), + }, + ) + + now = self.hs.get_reactor().seconds() + + # quantise time to the nearest bucket + now = (now // self.stats_bucket_size) * self.stats_bucket_size + + def _fetch_data(txn): + + # Get the current token of the room + current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn) + + current_state_events = len(current_state_ids) + joined_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.JOIN + ) + invited_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.INVITE + ) + left_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.LEAVE + ) + banned_members = self._get_user_count_in_room_txn( + txn, room_id, Membership.BAN + ) + total_state_events = self._get_total_state_event_counts_txn( + txn, room_id + ) + + self._update_stats_txn( + txn, + "room", + room_id, + now, + { + "bucket_size": self.stats_bucket_size, + "current_state_events": current_state_events, + "joined_members": joined_members, + "invited_members": invited_members, + "left_members": left_members, + "banned_members": banned_members, + "state_events": total_state_events, + }, + ) + self._simple_insert_txn( + txn, + "room_stats_earliest_token", + {"room_id": room_id, "token": current_token}, + ) + + yield self.runInteraction("update_room_stats", _fetch_data) + + # We've finished a room. Delete it from the table. + yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + # Update the remaining counter. + progress["remaining"] -= 1 + yield self.runInteraction( + "populate_stats", + self._background_update_progress_txn, + "populate_stats_process_rooms", + progress, + ) + + processed_event_count += event_count + + if processed_event_count > batch_size: + # Don't process any more rooms, we've hit our batch size. + defer.returnValue(processed_event_count) + + defer.returnValue(processed_event_count) + + def delete_all_stats(self): + """ + Delete all statistics records. + """ + + def _delete_all_stats_txn(txn): + txn.execute("DELETE FROM room_state") + txn.execute("DELETE FROM room_stats") + txn.execute("DELETE FROM room_stats_earliest_token") + txn.execute("DELETE FROM user_stats") + + return self.runInteraction("delete_all_stats", _delete_all_stats_txn) + + def get_stats_stream_pos(self): + return self._simple_select_one_onecol( + table="stats_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="stats_stream_pos", + ) + + def update_stats_stream_pos(self, stream_id): + return self._simple_update_one( + table="stats_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_stats_stream_pos", + ) + + def update_room_state(self, room_id, fields): + """ + Args: + room_id (str) + fields (dict[str:Any]) + """ + return self._simple_upsert( + table="room_state", + keyvalues={"room_id": room_id}, + values=fields, + desc="update_room_state", + ) + + def get_deltas_for_room(self, room_id, start, size=100): + """ + Get statistics deltas for a given room. + + Args: + room_id (str) + start (int): Pagination start. Number of entries, not timestamp. + size (int): How many entries to return. + + Returns: + Deferred[list[dict]], where the dict has the keys of + ABSOLUTE_STATS_FIELDS["room"] and "ts". + """ + return self._simple_select_list_paginate( + "room_stats", + {"room_id": room_id}, + "ts", + start, + size, + retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]), + order_direction="DESC", + ) + + def get_all_room_state(self): + return self._simple_select_list( + "room_state", None, retcols=("name", "topic", "canonical_alias") + ) + + @cached() + def get_earliest_token_for_room_stats(self, room_id): + """ + Fetch the "earliest token". This is used by the room stats delta + processor to ignore deltas that have been processed between the + start of the background task and any particular room's stats + being calculated. + + Returns: + Deferred[int] + """ + return self._simple_select_one_onecol( + "room_stats_earliest_token", + {"room_id": room_id}, + retcol="token", + allow_none=True, + ) + + def update_stats(self, stats_type, stats_id, ts, fields): + table, id_col = TYPE_TO_ROOM[stats_type] + return self._simple_upsert( + table=table, + keyvalues={id_col: stats_id, "ts": ts}, + values=fields, + desc="update_stats", + ) + + def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields): + table, id_col = TYPE_TO_ROOM[stats_type] + return self._simple_upsert_txn( + txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields + ) + + def update_stats_delta(self, ts, stats_type, stats_id, field, value): + def _update_stats_delta(txn): + table, id_col = TYPE_TO_ROOM[stats_type] + + sql = ( + "SELECT * FROM %s" + " WHERE %s=? and ts=(" + " SELECT MAX(ts) FROM %s" + " WHERE %s=?" + ")" + ) % (table, id_col, table, id_col) + txn.execute(sql, (stats_id, stats_id)) + rows = self.cursor_to_dict(txn) + if len(rows) == 0: + # silently skip as we don't have anything to apply a delta to yet. + # this tries to minimise any race between the initial sync and + # subsequent deltas arriving. + return + + current_ts = ts + latest_ts = rows[0]["ts"] + if current_ts < latest_ts: + # This one is in the past, but we're just encountering it now. + # Mark it as part of the current bucket. + current_ts = latest_ts + elif ts != latest_ts: + # we have to copy our absolute counters over to the new entry. + values = { + key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type] + } + values[id_col] = stats_id + values["ts"] = ts + values["bucket_size"] = self.stats_bucket_size + + self._simple_insert_txn(txn, table=table, values=values) + + # actually update the new value + if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]: + self._simple_update_txn( + txn, + table=table, + keyvalues={id_col: stats_id, "ts": current_ts}, + updatevalues={field: value}, + ) + else: + sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % ( + table, + field, + field, + id_col, + ) + txn.execute(sql, (value, stats_id, current_ts)) + + return self.runInteraction("update_stats_delta", _update_stats_delta) diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py new file mode 100644 index 0000000000..249aba3d59 --- /dev/null +++ b/tests/handlers/test_stats.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests import unittest + + +class StatsRoomTests(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + + self.store = hs.get_datastore() + self.handler = self.hs.get_stats_handler() + + def _add_background_updates(self): + """ + Add the background updates we need to run. + """ + # Ugh, have to reset this flag + self.store._all_done = False + + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_process_rooms", + "progress_json": "{}", + "depends_on": "populate_stats_createtables", + }, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_rooms", + }, + ) + ) + + def test_initial_room(self): + """ + The background updates will build the table from scratch. + """ + r = self.get_success(self.store.get_all_room_state()) + self.assertEqual(len(r), 0) + + # Disable stats + self.hs.config.stats_enabled = False + self.handler.stats_enabled = False + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token + ) + + # Stats disabled, shouldn't have done anything + r = self.get_success(self.store.get_all_room_state()) + self.assertEqual(len(r), 0) + + # Enable stats + self.hs.config.stats_enabled = True + self.handler.stats_enabled = True + + # Do the initial population of the user directory via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + r = self.get_success(self.store.get_all_room_state()) + + self.assertEqual(len(r), 1) + self.assertEqual(r[0]["topic"], "foo") + + def test_initial_earliest_token(self): + """ + Ingestion via notify_new_event will ignore tokens that the background + update have already processed. + """ + self.reactor.advance(86401) + + self.hs.config.stats_enabled = False + self.handler.stats_enabled = False + + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + u2 = self.register_user("u2", "pass") + u2_token = self.login("u2", "pass") + + u3 = self.register_user("u3", "pass") + u3_token = self.login("u3", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.helper.send_state( + room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token + ) + + # Begin the ingestion by creating the temp tables. This will also store + # the position that the deltas should begin at, once they take over. + self.hs.config.stats_enabled = True + self.handler.stats_enabled = True + self.store._all_done = False + self.get_success(self.store.update_stats_stream_pos(None)) + + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_createtables", "progress_json": "{}"}, + ) + ) + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + # Now, before the table is actually ingested, add some more events. + self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) + self.helper.join(room=room_1, user=u2, tok=u2_token) + + # Now do the initial ingestion. + self.get_success( + self.store._simple_insert( + "background_updates", + {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, + ) + ) + self.get_success( + self.store._simple_insert( + "background_updates", + { + "update_name": "populate_stats_cleanup", + "progress_json": "{}", + "depends_on": "populate_stats_process_rooms", + }, + ) + ) + + self.store._all_done = False + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + self.reactor.advance(86401) + + # Now add some more events, triggering ingestion. Because of the stream + # position being set to before the events sent in the middle, a simpler + # implementation would reprocess those events, and say there were four + # users, not three. + self.helper.invite(room=room_1, src=u1, targ=u3, tok=u1_token) + self.helper.join(room=room_1, user=u3, tok=u3_token) + + # Get the deltas! There should be two -- day 1, and day 2. + r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) + + # The oldest has 2 joined members + self.assertEqual(r[-1]["joined_members"], 2) + + # The newest has 3 + self.assertEqual(r[0]["joined_members"], 3) + + def test_incorrect_state_transition(self): + """ + If the state transition is not one of (JOIN, INVITE, LEAVE, BAN) to + (JOIN, INVITE, LEAVE, BAN), an error is raised. + """ + events = { + "a1": {"membership": Membership.LEAVE}, + "a2": {"membership": "not a real thing"}, + } + + def get_event(event_id): + m = Mock() + m.content = events[event_id] + d = defer.Deferred() + self.reactor.callLater(0.0, d.callback, m) + return d + + def get_received_ts(event_id): + return defer.succeed(1) + + self.store.get_received_ts = get_received_ts + self.store.get_event = get_event + + deltas = [ + { + "type": EventTypes.Member, + "state_key": "some_user", + "room_id": "room", + "event_id": "a1", + "prev_event_id": "a2", + "stream_id": "bleb", + } + ] + + f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + self.assertEqual( + f.value.args[0], "'not a real thing' is not a valid prev_membership" + ) + + # And the other way... + deltas = [ + { + "type": EventTypes.Member, + "state_key": "some_user", + "room_id": "room", + "event_id": "a2", + "prev_event_id": "a1", + "stream_id": "bleb", + } + ] + + f = self.get_failure(self.handler._handle_deltas(deltas), ValueError) + self.assertEqual( + f.value.args[0], "'not a real thing' is not a valid membership" + ) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 05b0143c42..f7133fc12e 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -127,3 +127,20 @@ class RestHelper(object): ) return channel.json_body + + def send_state(self, room_id, event_type, body, tok, expect_code=200): + path = "/_matrix/client/r0/rooms/%s/state/%s" % (room_id, event_type) + if tok: + path = path + "?access_token=%s" % tok + + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps(body).encode('utf8') + ) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body -- cgit 1.5.1 From 75538813fcd0403ec8915484a813b99e6eb256c6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 5 Jun 2019 15:45:46 +0100 Subject: Fix background updates to handle redactions/rejections (#5352) * Fix background updates to handle redactions/rejections In background updates based on current state delta stream we need to handle that we may not have all the events (or at least that `get_events` may raise an exception). --- changelog.d/5352.bugfix | 1 + synapse/handlers/presence.py | 11 ++++--- synapse/handlers/stats.py | 18 ++++++++---- synapse/storage/events_worker.py | 37 ++++++++++++++++++++++++ tests/handlers/test_stats.py | 62 ++++++++++++++++++++++++++++++++++++++-- 5 files changed, 117 insertions(+), 12 deletions(-) create mode 100644 changelog.d/5352.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/5352.bugfix b/changelog.d/5352.bugfix new file mode 100644 index 0000000000..2ffefe5a68 --- /dev/null +++ b/changelog.d/5352.bugfix @@ -0,0 +1 @@ +Fix room stats and presence background updates to correctly handle missing events. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6209858bbb..e49c8203ef 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -828,14 +828,17 @@ class PresenceHandler(object): # joins. continue - event = yield self.store.get_event(event_id) - if event.content.get("membership") != Membership.JOIN: + event = yield self.store.get_event(event_id, allow_none=True) + if not event or event.content.get("membership") != Membership.JOIN: # We only care about joins continue if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id) - if prev_event.content.get("membership") == Membership.JOIN: + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + if ( + prev_event + and prev_event.content.get("membership") == Membership.JOIN + ): # Ignore changes to join events. continue diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 0e92b405ba..7ad16c8566 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -115,6 +115,7 @@ class StatsHandler(StateDeltasHandler): event_id = delta["event_id"] stream_id = delta["stream_id"] prev_event_id = delta["prev_event_id"] + stream_pos = delta["stream_id"] logger.debug("Handling: %r %r, %s", typ, state_key, event_id) @@ -136,10 +137,15 @@ class StatsHandler(StateDeltasHandler): event_content = {} if event_id is not None: - event_content = (yield self.store.get_event(event_id)).content or {} + event = yield self.store.get_event(event_id, allow_none=True) + if event: + event_content = event.content or {} + + # We use stream_pos here rather than fetch by event_id as event_id + # may be None + now = yield self.store.get_received_ts_by_stream_pos(stream_pos) # quantise time to the nearest bucket - now = yield self.store.get_received_ts(event_id) now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size if typ == EventTypes.Member: @@ -149,9 +155,11 @@ class StatsHandler(StateDeltasHandler): # compare them. prev_event_content = {} if prev_event_id is not None: - prev_event_content = ( - yield self.store.get_event(prev_event_id) - ).content + prev_event = yield self.store.get_event( + prev_event_id, allow_none=True, + ) + if prev_event: + prev_event_content = prev_event.content membership = event_content.get("membership", Membership.LEAVE) prev_membership = prev_event_content.get("membership", Membership.LEAVE) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 1782428048..cc7df5cf14 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -78,6 +78,43 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) + def get_received_ts_by_stream_pos(self, stream_ordering): + """Given a stream ordering get an approximate timestamp of when it + happened. + + This is done by simply taking the received ts of the first event that + has a stream ordering greater than or equal to the given stream pos. + If none exists returns the current time, on the assumption that it must + have happened recently. + + Args: + stream_ordering (int) + + Returns: + Deferred[int] + """ + + def _get_approximate_received_ts_txn(txn): + sql = """ + SELECT received_ts FROM events + WHERE stream_ordering >= ? + LIMIT 1 + """ + + txn.execute(sql, (stream_ordering,)) + row = txn.fetchone() + if row and row[0]: + ts = row[0] + else: + ts = self.clock.time_msec() + + return ts + + return self.runInteraction( + "get_approximate_received_ts", + _get_approximate_received_ts_txn, + ) + @defer.inlineCallbacks def get_event( self, diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 249aba3d59..2710c991cf 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -204,7 +204,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): "a2": {"membership": "not a real thing"}, } - def get_event(event_id): + def get_event(event_id, allow_none=True): m = Mock() m.content = events[event_id] d = defer.Deferred() @@ -224,7 +224,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): "room_id": "room", "event_id": "a1", "prev_event_id": "a2", - "stream_id": "bleb", + "stream_id": 60, } ] @@ -241,7 +241,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): "room_id": "room", "event_id": "a2", "prev_event_id": "a1", - "stream_id": "bleb", + "stream_id": 100, } ] @@ -249,3 +249,59 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual( f.value.args[0], "'not a real thing' is not a valid membership" ) + + def test_redacted_prev_event(self): + """ + If the prev_event does not exist, then it is assumed to be a LEAVE. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + # Do the initial population of the user directory via the background update + self._add_background_updates() + + while not self.get_success(self.store.has_completed_background_updates()): + self.get_success(self.store.do_next_background_update(100), by=0.1) + + events = { + "a1": None, + "a2": {"membership": Membership.JOIN}, + } + + def get_event(event_id, allow_none=True): + if events.get(event_id): + m = Mock() + m.content = events[event_id] + else: + m = None + d = defer.Deferred() + self.reactor.callLater(0.0, d.callback, m) + return d + + def get_received_ts(event_id): + return defer.succeed(1) + + self.store.get_received_ts = get_received_ts + self.store.get_event = get_event + + deltas = [ + { + "type": EventTypes.Member, + "state_key": "some_user:test", + "room_id": room_1, + "event_id": "a2", + "prev_event_id": "a1", + "stream_id": 100, + } + ] + + # Handle our fake deltas, which has a user going from LEAVE -> JOIN. + self.get_success(self.handler._handle_deltas(deltas)) + + # One delta, with two joined members -- the room creator, and our fake + # user. + r = self.get_success(self.store.get_deltas_for_room(room_1, 0)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0]["joined_members"], 2) -- cgit 1.5.1