From b849e46139675c3098fdaca8ceff6b76be3f2f02 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Thu, 7 Jan 2021 23:01:59 +0200 Subject: Add forward extremities endpoint to rooms admin API GET /_synapse/admin/v1/rooms//forward_extremities now gets forward extremities for a room, returning count and the list of extremities. Signed-off-by: Jason Robinson --- synapse/storage/databases/main/__init__.py | 2 ++ .../databases/main/events_forward_extremities.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 synapse/storage/databases/main/events_forward_extremities.py (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index c4de07a0a8..93b25af057 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore +from .events_forward_extremities import EventForwardExtremitiesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore @@ -118,6 +119,7 @@ class DataStore( UIAuthStore, CacheInvalidationWorkerStore, ServerMetricsStore, + EventForwardExtremitiesStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py new file mode 100644 index 0000000000..250a424cc0 --- /dev/null +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -0,0 +1,20 @@ +from typing import List, Dict + +from synapse.storage._base import SQLBaseStore + + +class EventForwardExtremitiesStore(SQLBaseStore): + async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: + def get_forward_extremities_for_room_txn(txn): + sql = ( + "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups " + "WHERE room_id = ?" + ) + + txn.execute(sql, (room_id,)) + rows = txn.fetchall() + return [{"event_id": row[0], "state_group": row[1]} for row in rows] + + return await self.db_pool.runInteraction( + "get_forward_extremities_for_room", get_forward_extremities_for_room_txn + ) -- cgit 1.5.1 From 85c0999bfb70f2e8438a9730b8858e7845027190 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Fri, 8 Jan 2021 00:12:23 +0200 Subject: Add Rooms admin forward extremities DELETE endpoint Signed-off-by: Jason Robinson --- synapse/rest/admin/rooms.py | 5 +++ .../databases/main/events_forward_extremities.py | 49 +++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) (limited to 'synapse/storage/databases') diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1f7b7daea9..76f8603821 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -543,6 +543,11 @@ class ForwardExtremitiesRestServlet(RestServlet): room_id = await self.resolve_room_id(room_identifier) + deleted_count = await self.store.delete_forward_extremities_for_room(room_id) + return 200, { + "deleted": deleted_count, + } + async def on_GET(self, request, room_identifier): requester = await self.auth.get_user_by_req(request) await assert_user_is_admin(self.auth, requester.user) diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 250a424cc0..cc684a94fe 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -4,7 +4,54 @@ from synapse.storage._base import SQLBaseStore class EventForwardExtremitiesStore(SQLBaseStore): + + async def delete_forward_extremities_for_room(self, room_id: str) -> int: + """Delete any extra forward extremities for a room. + + Returns count deleted. + """ + def delete_forward_extremities_for_room_txn(txn): + # First we need to get the event_id to not delete + sql = ( + "SELECT " + " last_value(event_id) OVER w AS event_id" + " FROM event_forward_extremities" + " NATURAL JOIN events" + " where room_id = ?" + " WINDOW w AS (" + " PARTITION BY room_id" + " ORDER BY stream_ordering" + " range between unbounded preceding and unbounded following" + " )" + " ORDER BY stream_ordering" + ) + txn.execute(sql, (room_id,)) + rows = txn.fetchall() + + # TODO: should this raise a SynapseError instead of better to blow? + event_id = rows[0][0] + + # Now delete the extra forward extremities + sql = ( + "DELETE FROM event_forward_extremities " + "WHERE" + " event_id != ?" + " AND room_id = ?" + ) + + # TODO we should not commit yet + txn.execute(sql, (event_id, room_id)) + + # TODO flush the cache then commit + + return txn.rowcount + + return await self.db_pool.runInteraction( + "delete_forward_extremities_for_room", delete_forward_extremities_for_room_txn, + ) + async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: + """Get list of forward extremities for a room.""" def get_forward_extremities_for_room_txn(txn): sql = ( "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups " @@ -16,5 +63,5 @@ class EventForwardExtremitiesStore(SQLBaseStore): return [{"event_id": row[0], "state_group": row[1]} for row in rows] return await self.db_pool.runInteraction( - "get_forward_extremities_for_room", get_forward_extremities_for_room_txn + "get_forward_extremities_for_room", get_forward_extremities_for_room_txn, ) -- cgit 1.5.1 From 90ad4d443a109ad95741b499d914006578acceef Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Sat, 9 Jan 2021 21:57:41 +0200 Subject: Implement clearing cache after deleting forward extremities Also run linter. Signed-off-by: Jason Robinson --- synapse/rest/admin/rooms.py | 21 +++++------ .../databases/main/events_forward_extremities.py | 41 +++++++++++++++++----- 2 files changed, 42 insertions(+), 20 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 76f8603821..6757a8100b 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -524,18 +524,20 @@ class ForwardExtremitiesRestServlet(RestServlet): async def resolve_room_id(self, room_identifier: str) -> str: """Resolve to a room ID, if necessary.""" if RoomID.is_valid(room_identifier): - room_id = room_identifier + resolved_room_id = room_identifier elif RoomAlias.is_valid(room_identifier): room_alias = RoomAlias.from_string(room_identifier) room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() + resolved_room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - if not room_id: - raise SynapseError(400, "Unknown room ID or room alias %s" % room_identifier) - return room_id + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id async def on_DELETE(self, request, room_identifier): requester = await self.auth.get_user_by_req(request) @@ -544,9 +546,7 @@ class ForwardExtremitiesRestServlet(RestServlet): room_id = await self.resolve_room_id(room_identifier) deleted_count = await self.store.delete_forward_extremities_for_room(room_id) - return 200, { - "deleted": deleted_count, - } + return 200, {"deleted": deleted_count} async def on_GET(self, request, room_identifier): requester = await self.auth.get_user_by_req(request) @@ -555,7 +555,4 @@ class ForwardExtremitiesRestServlet(RestServlet): room_id = await self.resolve_room_id(room_identifier) extremities = await self.store.get_forward_extremities_for_room(room_id) - return 200, { - "count": len(extremities), - "results": extremities, - } + return 200, {"count": len(extremities), "results": extremities} diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index cc684a94fe..6b8da52fee 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -1,15 +1,22 @@ -from typing import List, Dict +import logging +from typing import Dict, List +from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore +logger = logging.getLogger(__name__) -class EventForwardExtremitiesStore(SQLBaseStore): +class EventForwardExtremitiesStore(SQLBaseStore): async def delete_forward_extremities_for_room(self, room_id: str) -> int: """Delete any extra forward extremities for a room. + Invalidates the "get_latest_event_ids_in_room" cache if any forward + extremities were deleted. + Returns count deleted. """ + def delete_forward_extremities_for_room_txn(txn): # First we need to get the event_id to not delete sql = ( @@ -27,9 +34,17 @@ class EventForwardExtremitiesStore(SQLBaseStore): ) txn.execute(sql, (room_id,)) rows = txn.fetchall() - - # TODO: should this raise a SynapseError instead of better to blow? - event_id = rows[0][0] + try: + event_id = rows[0][0] + logger.debug( + "Found event_id %s as the forward extremity to keep for room %s", + event_id, + room_id, + ) + except KeyError: + msg = f"No forward extremity event found for room {room_id}" + logger.warning(msg) + raise SynapseError(400, msg) # Now delete the extra forward extremities sql = ( @@ -39,19 +54,29 @@ class EventForwardExtremitiesStore(SQLBaseStore): " AND room_id = ?" ) - # TODO we should not commit yet txn.execute(sql, (event_id, room_id)) + logger.info( + "Deleted %s extra forward extremities for room %s", + txn.rowcount, + room_id, + ) - # TODO flush the cache then commit + if txn.rowcount > 0: + # Invalidate the cache + self._invalidate_cache_and_stream( + txn, self.get_latest_event_ids_in_room, (room_id,), + ) return txn.rowcount return await self.db_pool.runInteraction( - "delete_forward_extremities_for_room", delete_forward_extremities_for_room_txn, + "delete_forward_extremities_for_room", + delete_forward_extremities_for_room_txn, ) async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: """Get list of forward extremities for a room.""" + def get_forward_extremities_for_room_txn(txn): sql = ( "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups " -- cgit 1.5.1 From b52fb703f788b3de3afa1142852354b876f6bacf Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 11 Jan 2021 09:47:03 +0200 Subject: Don't try to use f-strings Signed-off-by: Jason Robinson --- synapse/storage/databases/main/events_forward_extremities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 6b8da52fee..83f751cf5b 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -42,7 +42,7 @@ class EventForwardExtremitiesStore(SQLBaseStore): room_id, ) except KeyError: - msg = f"No forward extremity event found for room {room_id}" + msg = "No forward extremity event found for room %s" % room_id logger.warning(msg) raise SynapseError(400, msg) -- cgit 1.5.1 From da16d06301aec83d144812d727c24192eb890c93 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 11 Jan 2021 23:43:58 +0200 Subject: Address pr feedback * docs updates * prettify SQL * add missing copyright * cursor_to_dict * update touched files copyright years Signed-off-by: Jason Robinson --- docs/admin_api/rooms.md | 12 +--- synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/rooms.py | 2 +- synapse/storage/databases/main/__init__.py | 2 +- .../databases/main/events_forward_extremities.py | 64 +++++++++++++--------- 5 files changed, 46 insertions(+), 36 deletions(-) (limited to 'synapse/storage/databases') diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 1d59bb5c4b..86daa393a7 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -516,11 +516,8 @@ optionally be specified, e.g.: # Forward Extremities Admin API Enables querying and deleting forward extremities from rooms. When a lot of forward -extremities accumulate in a room, performance can become degraded. - -When using this API endpoint to delete any extra forward extremities for a room, -the server does not need to be restarted as the relevant caches will be cleared -in the API call. +extremities accumulate in a room, performance can become degraded. For details, see +[#1760](https://github.com/matrix-org/synapse/issues/1760). ## Check for forward extremities @@ -537,7 +534,7 @@ A response as follows will be returned: "count": 1, "results": [ { - "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdfgh", + "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh", "state_group": 439 } ] @@ -561,6 +558,3 @@ that were deleted. "deleted": 1 } ``` - -The cache `get_latest_event_ids_in_room` will be invalidated, if any forward extremities -were deleted. diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index b80b036090..319ad7bf7f 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018-2019 New Vector Ltd +# Copyright 2020, 2021 The Matrix.org Foundation C.I.C. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6757a8100b..da1499cab3 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 93b25af057..b936f54f1e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 83f751cf5b..e6c2d6e122 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -1,3 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging from typing import Dict, List @@ -19,19 +34,19 @@ class EventForwardExtremitiesStore(SQLBaseStore): def delete_forward_extremities_for_room_txn(txn): # First we need to get the event_id to not delete - sql = ( - "SELECT " - " last_value(event_id) OVER w AS event_id" - " FROM event_forward_extremities" - " NATURAL JOIN events" - " where room_id = ?" - " WINDOW w AS (" - " PARTITION BY room_id" - " ORDER BY stream_ordering" - " range between unbounded preceding and unbounded following" - " )" - " ORDER BY stream_ordering" - ) + sql = """ + SELECT + last_value(event_id) OVER w AS event_id + FROM event_forward_extremities + NATURAL JOIN events + WHERE room_id = ? + WINDOW w AS ( + PARTITION BY room_id + ORDER BY stream_ordering + range between unbounded preceding and unbounded following + ) + ORDER BY stream_ordering + """ txn.execute(sql, (room_id,)) rows = txn.fetchall() try: @@ -47,12 +62,10 @@ class EventForwardExtremitiesStore(SQLBaseStore): raise SynapseError(400, msg) # Now delete the extra forward extremities - sql = ( - "DELETE FROM event_forward_extremities " - "WHERE" - " event_id != ?" - " AND room_id = ?" - ) + sql = """ + DELETE FROM event_forward_extremities + WHERE event_id != ? AND room_id = ? + """ txn.execute(sql, (event_id, room_id)) logger.info( @@ -78,14 +91,15 @@ class EventForwardExtremitiesStore(SQLBaseStore): """Get list of forward extremities for a room.""" def get_forward_extremities_for_room_txn(txn): - sql = ( - "SELECT event_id, state_group FROM event_forward_extremities NATURAL JOIN event_to_state_groups " - "WHERE room_id = ?" - ) + sql = """ + SELECT event_id, state_group + FROM event_forward_extremities + NATURAL JOIN event_to_state_groups + WHERE room_id = ? + """ txn.execute(sql, (room_id,)) - rows = txn.fetchall() - return [{"event_id": row[0], "state_group": row[1]} for row in rows] + return self.db_pool.cursor_to_dict(txn) return await self.db_pool.runInteraction( "get_forward_extremities_for_room", get_forward_extremities_for_room_txn, -- cgit 1.5.1 From 49c619a9a2203da61f496fe6e3ae308be87efda8 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 11 Jan 2021 23:49:58 +0200 Subject: Simplify delete_forward_extremities_for_room_txn SQL As per feedback. Signed-off-by: Jason Robinson --- .../storage/databases/main/events_forward_extremities.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index e6c2d6e122..c7ec08469d 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -35,17 +35,11 @@ class EventForwardExtremitiesStore(SQLBaseStore): def delete_forward_extremities_for_room_txn(txn): # First we need to get the event_id to not delete sql = """ - SELECT - last_value(event_id) OVER w AS event_id - FROM event_forward_extremities - NATURAL JOIN events + SELECT event_id FROM event_forward_extremities + INNER JOIN events USING (room_id, event_id) WHERE room_id = ? - WINDOW w AS ( - PARTITION BY room_id - ORDER BY stream_ordering - range between unbounded preceding and unbounded following - ) - ORDER BY stream_ordering + ORDER BY stream_ordering DESC + LIMIT 1 """ txn.execute(sql, (room_id,)) rows = txn.fetchall() -- cgit 1.5.1 From c177faf5a92d8ef02dd59e16dcf6ca9fb5ca6a33 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Mon, 11 Jan 2021 23:55:44 +0200 Subject: Remove trailing whitespace to appease the linter Signed-off-by: Jason Robinson --- synapse/storage/databases/main/events_forward_extremities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index c7ec08469d..5fea974050 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -86,8 +86,8 @@ class EventForwardExtremitiesStore(SQLBaseStore): def get_forward_extremities_for_room_txn(txn): sql = """ - SELECT event_id, state_group - FROM event_forward_extremities + SELECT event_id, state_group + FROM event_forward_extremities NATURAL JOIN event_to_state_groups WHERE room_id = ? """ -- cgit 1.5.1 From 0cd2938bc854d947ae8102ded688a626c9fac5b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 13:15:14 +0000 Subject: Support icons for Identity Providers (#9154) --- changelog.d/9154.feature | 1 + docs/sample_config.yaml | 4 ++ mypy.ini | 1 + synapse/config/oidc_config.py | 20 ++++++ synapse/config/server.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/server.py | 2 +- synapse/handlers/cas_handler.py | 4 ++ synapse/handlers/oidc_handler.py | 3 + synapse/handlers/room.py | 2 +- synapse/handlers/saml_handler.py | 4 ++ synapse/handlers/sso.py | 5 ++ synapse/http/endpoint.py | 79 --------------------- synapse/res/templates/sso_login_idp_picker.html | 3 + synapse/rest/client/v1/room.py | 3 +- synapse/storage/databases/main/room.py | 6 +- synapse/types.py | 2 +- synapse/util/stringutils.py | 92 +++++++++++++++++++++++++ tests/http/test_endpoint.py | 2 +- 19 files changed, 146 insertions(+), 91 deletions(-) create mode 100644 changelog.d/9154.feature delete mode 100644 synapse/http/endpoint.py (limited to 'synapse/storage/databases') diff --git a/changelog.d/9154.feature b/changelog.d/9154.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9154.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7fdd798d70..b49a5da8cc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1726,6 +1726,10 @@ saml2_config: # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # +# idp_icon: An optional icon for this identity provider, which is presented +# by identity picker pages. If given, must be an MXC URI of the format +# mxc:/// +# # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # diff --git a/mypy.ini b/mypy.ini index b996867121..bd99069c81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -100,6 +100,7 @@ files = synapse/util/async_helpers.py, synapse/util/caches, synapse/util/metrics.py, + synapse/util/stringutils.py, tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index df55367434..f257fcd412 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -23,6 +23,7 @@ from synapse.config._util import validate_config from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module +from synapse.util.stringutils import parse_and_validate_mxc_uri from ._base import Config, ConfigError @@ -66,6 +67,10 @@ class OIDCConfig(Config): # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # + # idp_icon: An optional icon for this identity provider, which is presented + # by identity picker pages. If given, must be an MXC URI of the format + # mxc:/// + # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # @@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "properties": { "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, + "idp_icon": {"type": "string"}, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -336,9 +342,20 @@ def _parse_oidc_config_dict( config_path + ("idp_id",), ) + # MSC2858 also specifies that the idp_icon must be a valid MXC uri + idp_icon = oidc_config.get("idp_icon") + if idp_icon is not None: + try: + parse_and_validate_mxc_uri(idp_icon) + except ValueError as e: + raise ConfigError( + "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) + ) from e + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), + idp_icon=idp_icon, discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -366,6 +383,9 @@ class OidcProviderConfig: # user-facing name for this identity provider. idp_name = attr.ib(type=str) + # Optional MXC URI for icon for this IdP. + idp_icon = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/config/server.py b/synapse/config/server.py index 75ba161f35..47a0370173 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -26,7 +26,7 @@ import yaml from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e5339aca23..171d25c945 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -49,7 +49,6 @@ from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction -from synapse.http.endpoint import parse_server_name from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_server_name if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cfd094e58f..95c64510a9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,7 +28,6 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -45,6 +44,7 @@ from synapse.logging.opentracing import ( ) from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f3430c6713..0f342c607b 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,6 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" + # we do not currently support icons for CAS auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ba686d74b2..1607e12935 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -271,6 +271,9 @@ class OidcProvider: # user-facing name of this auth provider self.idp_name = provider.idp_name + # MXC URI for icon for this auth provider + self.idp_icon = provider.idp_icon + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3bece6d668..ee27d99135 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,7 +38,6 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents -from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -55,6 +54,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a8376543c9..38461cf79d 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,6 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" + # we do not currently support icons for SAML auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index dcc85e9871..d493327a10 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol): def idp_name(self) -> str: """User-facing name for this provider""" + @property + def idp_icon(self) -> Optional[str]: + """Optional MXC URI for user-facing icon""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py deleted file mode 100644 index 92a5b606c8..0000000000 --- a/synapse/http/endpoint.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re - -logger = logging.getLogger(__name__) - - -def parse_server_name(server_name): - """Split a server name into host/port parts. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - try: - if server_name[-1] == "]": - # ipv6 literal, hopefully - return server_name, None - - domain_port = server_name.rsplit(":", 1) - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None - return domain, port - except Exception: - raise ValueError("Invalid server name '%s'" % server_name) - - -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") - - -def parse_and_validate_server_name(server_name): - """Split a server name into host/port parts and do some basic validation. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - host, port = parse_server_name(server_name) - - # these tests don't need to be bulletproof as we'll find out soon enough - # if somebody is giving us invalid data. What we *do* need is to be sure - # that nobody is sneaking IP literals in that look like hostnames, etc. - - # look for ipv6 literals - if host[0] == "[": - if host[-1] != "]": - raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) - - return host, port diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index f53c9cd679..5b38481012 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -17,6 +17,9 @@
  • +{% if p.idp_icon %} + +{% endif %}
  • {% endfor %} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e6725b03b0..f95627ee61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -32,7 +32,6 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder -from synapse.util.stringutils import random_string +from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: import synapse.server diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 284f2ce77c..a9fcb5f59c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -16,7 +16,6 @@ import collections import logging -import re from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple @@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +from synapse.util.stringutils import MXC_REGEX logger = logging.getLogger(__name__) @@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - sql = """ SELECT stream_ordering, json FROM events JOIN event_json USING (room_id, event_id) @@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore): for url in (content_url, thumbnail_url): if not url: continue - matches = mxc_re.match(url) + matches = MXC_REGEX.match(url) if matches: hostname = matches.group(1) media_id = matches.group(2) diff --git a/synapse/types.py b/synapse/types.py index 20a43d05bf..eafe729dfe 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b103c8694c..f8038bf861 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -18,6 +18,7 @@ import random import re import string from collections.abc import Iterable +from typing import Optional, Tuple from synapse.api.errors import Codes, SynapseError @@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, +# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically +# says "there is no grammar for media ids" +# +# The server_name part of this is purposely lax: use parse_and_validate_mxc for +# additional validation. +# +MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret): ) +def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == "]": + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") + + +def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) + + return host, port + + +def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: + """Parse the given string as an MXC URI + + Checks that the "server name" part is a valid server name + + Args: + mxc: the (alleged) MXC URI to be checked + Returns: + hostname, port, media id + Raises: + ValueError if the URI cannot be parsed + """ + m = MXC_REGEX.match(mxc) + if not m: + raise ValueError("mxc URI %r did not match expected format" % (mxc,)) + server_name = m.group(1) + media_id = m.group(2) + host, port = parse_and_validate_server_name(server_name) + return host, port, media_id + + def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index b2e9533b07..d06ea518ce 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name +from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name from tests import unittest -- cgit 1.5.1 From eee6fcf5fa857af95c46185fc11d540343c77d2d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 21 Jan 2021 10:22:53 +0000 Subject: Use execute_batch instead of executemany in places (#9181) `execute_batch` does fewer round trips in postgres than `executemany`, but does not give a correct `txn.rowcount` result after. --- changelog.d/9181.misc | 1 + synapse/storage/database.py | 5 ++--- synapse/storage/databases/main/events.py | 18 +++++++++--------- 3 files changed, 12 insertions(+), 12 deletions(-) create mode 100644 changelog.d/9181.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/9181.misc b/changelog.d/9181.misc new file mode 100644 index 0000000000..7820d09cd0 --- /dev/null +++ b/changelog.d/9181.misc @@ -0,0 +1 @@ +Speed up batch insertion when using PostgreSQL. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index a19d65ad23..c7220bc778 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -267,8 +267,7 @@ class LoggingTransaction: self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: - for val in args: - self.execute(sql, val) + self.executemany(sql, args) def execute_values(self, sql: str, *args: Any) -> List[Tuple]: """Corresponds to psycopg2.extras.execute_values. Only available when @@ -888,7 +887,7 @@ class DatabasePool: ", ".join("?" for _ in keys[0]), ) - txn.executemany(sql, vals) + txn.execute_batch(sql, vals) async def simple_upsert( self, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3216b3f3c8..5db7d7aaa8 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -876,7 +876,7 @@ class PersistEventsStore: WHERE room_id = ? AND type = ? AND state_key = ? ) """ - txn.executemany( + txn.execute_batch( sql, ( ( @@ -895,7 +895,7 @@ class PersistEventsStore: ) # Now we actually update the current_state_events table - txn.executemany( + txn.execute_batch( "DELETE FROM current_state_events" " WHERE room_id = ? AND type = ? AND state_key = ?", ( @@ -907,7 +907,7 @@ class PersistEventsStore: # We include the membership in the current state table, hence we do # a lookup when we insert. This assumes that all events have already # been inserted into room_memberships. - txn.executemany( + txn.execute_batch( """INSERT INTO current_state_events (room_id, type, state_key, event_id, membership) VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) @@ -927,7 +927,7 @@ class PersistEventsStore: # we have no record of the fact the user *was* a member of the # room but got, say, state reset out of it. if to_delete or to_insert: - txn.executemany( + txn.execute_batch( "DELETE FROM local_current_membership" " WHERE room_id = ? AND user_id = ?", ( @@ -938,7 +938,7 @@ class PersistEventsStore: ) if to_insert: - txn.executemany( + txn.execute_batch( """INSERT INTO local_current_membership (room_id, user_id, event_id, membership) VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) @@ -1738,7 +1738,7 @@ class PersistEventsStore: """ if events_and_contexts: - txn.executemany( + txn.execute_batch( sql, ( ( @@ -1767,7 +1767,7 @@ class PersistEventsStore: # Now we delete the staging area for *all* events that were being # persisted. - txn.executemany( + txn.execute_batch( "DELETE FROM event_push_actions_staging WHERE event_id = ?", ((event.event_id,) for event, _ in all_events_and_contexts), ) @@ -1886,7 +1886,7 @@ class PersistEventsStore: " )" ) - txn.executemany( + txn.execute_batch( query, [ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) @@ -1900,7 +1900,7 @@ class PersistEventsStore: "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" ) - txn.executemany( + txn.execute_batch( query, [ (ev.event_id, ev.room_id) -- cgit 1.5.1 From 7a43482f1916622967f5a4b389f93944dd5deb07 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 21 Jan 2021 14:44:12 +0000 Subject: Use execute_batch in more places (#9188) * Use execute_batch in more places * Newsfile --- changelog.d/9188.misc | 1 + synapse/storage/database.py | 6 ++++++ synapse/storage/databases/main/devices.py | 4 ++-- synapse/storage/databases/main/event_push_actions.py | 4 ++-- synapse/storage/databases/main/events_bg_updates.py | 12 ++---------- synapse/storage/databases/main/media_repository.py | 10 +++++----- synapse/storage/databases/main/purge_events.py | 2 +- synapse/storage/databases/main/registration.py | 2 +- synapse/storage/databases/main/roommember.py | 6 +----- .../storage/databases/main/schema/delta/59/01ignored_user.py | 2 +- synapse/storage/databases/main/search.py | 4 ++-- synapse/storage/databases/state/store.py | 4 ++-- 12 files changed, 26 insertions(+), 31 deletions(-) create mode 100644 changelog.d/9188.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/9188.misc b/changelog.d/9188.misc new file mode 100644 index 0000000000..7820d09cd0 --- /dev/null +++ b/changelog.d/9188.misc @@ -0,0 +1 @@ +Speed up batch insertion when using PostgreSQL. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c7220bc778..d2ba4bd2fc 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -262,6 +262,12 @@ class LoggingTransaction: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: + """Similar to `executemany`, except `txn.rowcount` will not be correct + afterwards. + + More efficient than `executemany` on PostgreSQL + """ + if isinstance(self.database_engine, PostgresEngine): from psycopg2.extras import execute_batch # type: ignore diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 9097677648..659d8f245f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore): DELETE FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? """ - txn.executemany(sql, ((row[0], row[1]) for row in rows)) + txn.execute_batch(sql, ((row[0], row[1]) for row in rows)) logger.info("Pruned %d device list outbound pokes", count) @@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Delete older entries in the table, as we really only care about # when the latest change happened. - txn.executemany( + txn.execute_batch( """ DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 1b657191a9..438383abe1 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): VALUES (?, ?, ?, ?, ?, ?) """ - txn.executemany( + txn.execute_batch( sql, ( _gen_entry(user_id, actions) @@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ], ) - txn.executemany( + txn.execute_batch( """ UPDATE event_push_summary SET notif_count = ?, unread_count = ?, stream_ordering = ? diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index e46e44ba54..5ca4fa6817 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 - def reindex_txn(txn): sql = ( "SELECT stream_ordering, event_id, json FROM events" @@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + txn.execute_batch(sql, update_rows) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 - def reindex_search_txn(txn): sql = ( "SELECT stream_ordering, event_id FROM events" @@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + txn.execute_batch(sql, rows_to_update) progress = { "target_min_stream_id_inclusive": target_min_stream_id, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 283c8a5e22..e017177655 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_origin = ? AND media_id = ?" ) - txn.executemany( + txn.execute_batch( sql, ( (time_ms, media_origin, media_id) @@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_id = ?" ) - txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) + txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) return await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn @@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" def _delete_url_cache_txn(txn): - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) return await self.db_pool.runInteraction( "delete_url_cache", _delete_url_cache_txn @@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_media_txn(txn): sql = "DELETE FROM local_media_repository WHERE media_id = ?" - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) return await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 5d668aadb2..ecfc9f20b1 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): ) # Update backward extremeties - txn.executemany( + txn.execute_batch( "INSERT INTO event_backward_extremities (room_id, event_id)" " VALUES (?, ?)", [(room_id, event_id) for event_id, in new_backwards_extrems], diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 8d05288ed4..585b4049d6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1104,7 +1104,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): FROM user_threepids """ - txn.executemany(sql, [(id_server,) for id_server in id_servers]) + txn.execute_batch(sql, [(id_server,) for id_server in id_servers]) if id_servers: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index dcdaf09682..92382bed28 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): "max_stream_id_exclusive", self._stream_order_on_start + 1 ) - INSERT_CLUMP_SIZE = 1000 - def add_membership_profile_txn(txn): sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json @@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): UPDATE room_memberships SET display_name = ?, avatar_url = ? WHERE event_id = ? AND room_id = ? """ - for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(to_update_sql, clump) + txn.execute_batch(to_update_sql, to_update) progress = { "target_min_stream_id_inclusive": target_min_stream_id, diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py index f35c70b699..9e8f35c1d2 100644 --- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py +++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py @@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs # { "ignored_users": "@someone:example.org": {} } ignored_users = content.get("ignored_users", {}) if isinstance(ignored_users, dict) and ignored_users: - cur.executemany(insert_sql, [(user_id, u) for u in ignored_users]) + cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users]) # Add indexes after inserting data for efficiency. logger.info("Adding constraints to ignored_users table") diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index e34fce6281..871af64b11 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.executemany(sql, args) + txn.execute_batch(sql, args) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( @@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.executemany(sql, args) + txn.execute_batch(sql, args) else: # This should be unreachable. raise Exception("Unrecognized database engine") diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 0e31cc811a..89cdc84a9c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) logger.info("[purge] removing redundant state groups") - txn.executemany( + txn.execute_batch( "DELETE FROM state_groups_state WHERE state_group = ?", ((sg,) for sg in state_groups_to_delete), ) - txn.executemany( + txn.execute_batch( "DELETE FROM state_groups WHERE id = ?", ((sg,) for sg in state_groups_to_delete), ) -- cgit 1.5.1 From 2506074ef0a880b527d61457c32cd397a0d3ab2d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 21 Jan 2021 15:09:09 +0000 Subject: Fix receipts or account data not being sent down sync (#9193) Introduced in #9104 This wasn't picked up by the tests as this is all fine the first time you run Synapse (after upgrading), but then when you restart the wrong value is pulled from `stream_positions`. --- changelog.d/9193.bugfix | 1 + synapse/storage/databases/main/account_data.py | 2 +- synapse/storage/databases/main/receipts.py | 4 +- synapse/storage/util/id_generators.py | 6 ++- synapse/storage/util/sequence.py | 56 ++++++++++++++++++++++++-- 5 files changed, 62 insertions(+), 7 deletions(-) create mode 100644 changelog.d/9193.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/9193.bugfix b/changelog.d/9193.bugfix new file mode 100644 index 0000000000..5233ffc3e7 --- /dev/null +++ b/changelog.d/9193.bugfix @@ -0,0 +1 @@ +Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 68896f34af..a277a1ef13 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -68,7 +68,7 @@ class AccountDataWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: + if hs.get_instance_name() in hs.config.worker.writers.account_data: self._account_data_id_gen = StreamIdGenerator( db_conn, "room_account_data", diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e0e57f0578..e4843a202c 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -45,7 +45,7 @@ class ReceiptsWorkerStore(SQLBaseStore): self._receipts_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, - stream_name="account_data", + stream_name="receipts", instance_name=self._instance_name, tables=[("receipts_linearized", "instance_name", "stream_id")], sequence_name="receipts_sequence", @@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). - if hs.get_instance_name() in hs.config.worker.writers.events: + if hs.get_instance_name() in hs.config.worker.writers.receipts: self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 39a3ab1162..bb84c0d792 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -261,7 +261,11 @@ class MultiWriterIdGenerator: # We check that the table and sequence haven't diverged. for table, _, id_column in tables: self._sequence_gen.check_consistency( - db_conn, table=table, id_column=id_column, positive=positive + db_conn, + table=table, + id_column=id_column, + stream_name=stream_name, + positive=positive, ) # This goes and fills out the above state from the database. diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 412df6b8ef..b6fe136fb7 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -45,6 +45,21 @@ and run the following SQL: See docs/postgres.md for more information. """ +_INCONSISTENT_STREAM_ERROR = """ +Postgres sequence '%(seq)s' is inconsistent with associated stream position +of '%(stream_name)s' in the 'stream_positions' table. + +This is likely a programming error and should be reported at +https://github.com/matrix-org/synapse. + +A temporary workaround to fix this error is to shut down Synapse (including +any and all workers) and run the following SQL: + + DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s'; + +This will need to be done every time the server is restarted. +""" + class SequenceGenerator(metaclass=abc.ABCMeta): """A class which generates a unique sequence of integers""" @@ -60,14 +75,20 @@ class SequenceGenerator(metaclass=abc.ABCMeta): db_conn: "LoggingDatabaseConnection", table: str, id_column: str, + stream_name: Optional[str] = None, positive: bool = True, ): """Should be called during start up to test that the current value of the sequence is greater than or equal to the maximum ID in the table. - This is to handle various cases where the sequence value can get out - of sync with the table, e.g. if Synapse gets rolled back to a previous + This is to handle various cases where the sequence value can get out of + sync with the table, e.g. if Synapse gets rolled back to a previous version and the rolled forwards again. + + If a stream name is given then this will check that any value in the + `stream_positions` table is less than or equal to the current sequence + value. If it isn't then it's likely that streams have been crossed + somewhere (e.g. two ID generators have the same stream name). """ ... @@ -93,8 +114,12 @@ class PostgresSequenceGenerator(SequenceGenerator): db_conn: "LoggingDatabaseConnection", table: str, id_column: str, + stream_name: Optional[str] = None, positive: bool = True, ): + """See SequenceGenerator.check_consistency for docstring. + """ + txn = db_conn.cursor(txn_name="sequence.check_consistency") # First we get the current max ID from the table. @@ -118,6 +143,18 @@ class PostgresSequenceGenerator(SequenceGenerator): "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} ) last_value, is_called = txn.fetchone() + + # If we have an associated stream check the stream_positions table. + max_in_stream_positions = None + if stream_name: + txn.execute( + "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?", + (stream_name,), + ) + row = txn.fetchone() + if row: + max_in_stream_positions = row[0] + txn.close() # If `is_called` is False then `last_value` is actually the value that @@ -138,6 +175,14 @@ class PostgresSequenceGenerator(SequenceGenerator): % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} ) + # If we have values in the stream positions table then they have to be + # less than or equal to `last_value` + if max_in_stream_positions and max_in_stream_positions > last_value: + raise IncorrectDatabaseSetup( + _INCONSISTENT_STREAM_ERROR + % {"seq": self._sequence_name, "stream": stream_name} + ) + GetFirstCallbackType = Callable[[Cursor], int] @@ -175,7 +220,12 @@ class LocalSequenceGenerator(SequenceGenerator): return self._current_max_id def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: Connection, + table: str, + id_column: str, + stream_name: Optional[str] = None, + positive: bool = True, ): # There is nothing to do for in memory sequences pass -- cgit 1.5.1 From ccfafac88245c806ad5bde1ebe9312ff1032d829 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 21 Jan 2021 16:03:25 +0000 Subject: Add schema update to fix existing DBs affected by #9193 (#9195) --- changelog.d/9195.bugfix | 1 + .../main/schema/delta/59/07shard_account_data_fix.sql | 18 ++++++++++++++++++ synapse/storage/util/sequence.py | 2 +- 3 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9195.bugfix create mode 100644 synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql (limited to 'synapse/storage/databases') diff --git a/changelog.d/9195.bugfix b/changelog.d/9195.bugfix new file mode 100644 index 0000000000..5233ffc3e7 --- /dev/null +++ b/changelog.d/9195.bugfix @@ -0,0 +1 @@ +Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1. diff --git a/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql new file mode 100644 index 0000000000..9f2b5ebc5a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/07shard_account_data_fix.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We incorrectly populated these, so we delete them and let the +-- MultiWriterIdGenerator repopulate it. +DELETE FROM stream_positions WHERE stream_name = 'receipts' OR stream_name = 'account_data'; diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index b6fe136fb7..c780ade077 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -180,7 +180,7 @@ class PostgresSequenceGenerator(SequenceGenerator): if max_in_stream_positions and max_in_stream_positions > last_value: raise IncorrectDatabaseSetup( _INCONSISTENT_STREAM_ERROR - % {"seq": self._sequence_name, "stream": stream_name} + % {"seq": self._sequence_name, "stream_name": stream_name} ) -- cgit 1.5.1 From 758ed5f1bc16f4b73d73d94129761a8680fd71c5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 21 Jan 2021 17:00:12 +0000 Subject: Speed up chain cover calculation (#9176) --- changelog.d/9176.misc | 1 + synapse/storage/databases/main/events.py | 199 ++++++++++++++++++++++--------- synapse/storage/util/sequence.py | 16 +++ 3 files changed, 161 insertions(+), 55 deletions(-) create mode 100644 changelog.d/9176.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/9176.misc b/changelog.d/9176.misc new file mode 100644 index 0000000000..9c41d7b0f9 --- /dev/null +++ b/changelog.d/9176.misc @@ -0,0 +1 @@ +Speed up chain cover calculation when persisting a batch of state events at once. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 5db7d7aaa8..ccda9f1caa 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -473,8 +473,9 @@ class PersistEventsStore: txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, ) - @staticmethod + @classmethod def _add_chain_cover_index( + cls, txn, db_pool: DatabasePool, event_to_room_id: Dict[str, str], @@ -614,60 +615,17 @@ class PersistEventsStore: if not events_to_calc_chain_id_for: return - # We now calculate the chain IDs/sequence numbers for the events. We - # do this by looking at the chain ID and sequence number of any auth - # event with the same type/state_key and incrementing the sequence - # number by one. If there was no match or the chain ID/sequence - # number is already taken we generate a new chain. - # - # We need to do this in a topologically sorted order as we want to - # generate chain IDs/sequence numbers of an event's auth events - # before the event itself. - chains_tuples_allocated = set() # type: Set[Tuple[int, int]] - new_chain_tuples = {} # type: Dict[str, Tuple[int, int]] - for event_id in sorted_topologically( - events_to_calc_chain_id_for, event_to_auth_chain - ): - existing_chain_id = None - for auth_id in event_to_auth_chain.get(event_id, []): - if event_to_types.get(event_id) == event_to_types.get(auth_id): - existing_chain_id = chain_map[auth_id] - break - - new_chain_tuple = None - if existing_chain_id: - # We found a chain ID/sequence number candidate, check its - # not already taken. - proposed_new_id = existing_chain_id[0] - proposed_new_seq = existing_chain_id[1] + 1 - if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: - already_allocated = db_pool.simple_select_one_onecol_txn( - txn, - table="event_auth_chains", - keyvalues={ - "chain_id": proposed_new_id, - "sequence_number": proposed_new_seq, - }, - retcol="event_id", - allow_none=True, - ) - if already_allocated: - # Mark it as already allocated so we don't need to hit - # the DB again. - chains_tuples_allocated.add((proposed_new_id, proposed_new_seq)) - else: - new_chain_tuple = ( - proposed_new_id, - proposed_new_seq, - ) - - if not new_chain_tuple: - new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1) - - chains_tuples_allocated.add(new_chain_tuple) - - chain_map[event_id] = new_chain_tuple - new_chain_tuples[event_id] = new_chain_tuple + # Allocate chain ID/sequence numbers to each new event. + new_chain_tuples = cls._allocate_chain_ids( + txn, + db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, + events_to_calc_chain_id_for, + chain_map, + ) + chain_map.update(new_chain_tuples) db_pool.simple_insert_many_txn( txn, @@ -794,6 +752,137 @@ class PersistEventsStore: ], ) + @staticmethod + def _allocate_chain_ids( + txn, + db_pool: DatabasePool, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + events_to_calc_chain_id_for: Set[str], + chain_map: Dict[str, Tuple[int, int]], + ) -> Dict[str, Tuple[int, int]]: + """Allocates, but does not persist, chain ID/sequence numbers for the + events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index + for info on args) + """ + + # We now calculate the chain IDs/sequence numbers for the events. We do + # this by looking at the chain ID and sequence number of any auth event + # with the same type/state_key and incrementing the sequence number by + # one. If there was no match or the chain ID/sequence number is already + # taken we generate a new chain. + # + # We try to reduce the number of times that we hit the database by + # batching up calls, to make this more efficient when persisting large + # numbers of state events (e.g. during joins). + # + # We do this by: + # 1. Calculating for each event which auth event will be used to + # inherit the chain ID, i.e. converting the auth chain graph to a + # tree that we can allocate chains on. We also keep track of which + # existing chain IDs have been referenced. + # 2. Fetching the max allocated sequence number for each referenced + # existing chain ID, generating a map from chain ID to the max + # allocated sequence number. + # 3. Iterating over the tree and allocating a chain ID/seq no. to the + # new event, by incrementing the sequence number from the + # referenced event's chain ID/seq no. and checking that the + # incremented sequence number hasn't already been allocated (by + # looking in the map generated in the previous step). We generate a + # new chain if the sequence number has already been allocated. + # + + existing_chains = set() # type: Set[int] + tree = [] # type: List[Tuple[str, Optional[str]]] + + # We need to do this in a topologically sorted order as we want to + # generate chain IDs/sequence numbers of an event's auth events before + # the event itself. + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + for auth_id in event_to_auth_chain.get(event_id, []): + if event_to_types.get(event_id) == event_to_types.get(auth_id): + existing_chain_id = chain_map.get(auth_id) + if existing_chain_id: + existing_chains.add(existing_chain_id[0]) + + tree.append((event_id, auth_id)) + break + else: + tree.append((event_id, None)) + + # Fetch the current max sequence number for each existing referenced chain. + sql = """ + SELECT chain_id, MAX(sequence_number) FROM event_auth_chains + WHERE %s + GROUP BY chain_id + """ + clause, args = make_in_list_sql_clause( + db_pool.engine, "chain_id", existing_chains + ) + txn.execute(sql % (clause,), args) + + chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int] + + # Allocate the new events chain ID/sequence numbers. + # + # To reduce the number of calls to the database we don't allocate a + # chain ID number in the loop, instead we use a temporary `object()` for + # each new chain ID. Once we've done the loop we generate the necessary + # number of new chain IDs in one call, replacing all temporary + # objects with real allocated chain IDs. + + unallocated_chain_ids = set() # type: Set[object] + new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]] + for event_id, auth_event_id in tree: + # If we reference an auth_event_id we fetch the allocated chain ID, + # either from the existing `chain_map` or the newly generated + # `new_chain_tuples` map. + existing_chain_id = None + if auth_event_id: + existing_chain_id = new_chain_tuples.get(auth_event_id) + if not existing_chain_id: + existing_chain_id = chain_map[auth_event_id] + + new_chain_tuple = None # type: Optional[Tuple[Any, int]] + if existing_chain_id: + # We found a chain ID/sequence number candidate, check its + # not already taken. + proposed_new_id = existing_chain_id[0] + proposed_new_seq = existing_chain_id[1] + 1 + + if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq: + new_chain_tuple = ( + proposed_new_id, + proposed_new_seq, + ) + + # If we need to start a new chain we allocate a temporary chain ID. + if not new_chain_tuple: + new_chain_tuple = (object(), 1) + unallocated_chain_ids.add(new_chain_tuple[0]) + + new_chain_tuples[event_id] = new_chain_tuple + chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] + + # Generate new chain IDs for all unallocated chain IDs. + newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( + txn, len(unallocated_chain_ids) + ) + + # Map from potentially temporary chain ID to real chain ID + chain_id_to_allocated_map = dict( + zip(unallocated_chain_ids, newly_allocated_chain_ids) + ) # type: Dict[Any, int] + chain_id_to_allocated_map.update((c, c) for c in existing_chains) + + return { + event_id: (chain_id_to_allocated_map[chain_id], seq) + for event_id, (chain_id, seq) in new_chain_tuples.items() + } + def _persist_transaction_ids_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index c780ade077..0ec4dc2918 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -69,6 +69,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """Gets the next ID in the sequence""" ... + @abc.abstractmethod + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + """Get the next `n` IDs in the sequence""" + ... + @abc.abstractmethod def check_consistency( self, @@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + first_id = self._current_max_id + 1 + self._current_max_id += n + return [first_id + i for i in range(n)] + def check_consistency( self, db_conn: Connection, -- cgit 1.5.1 From 930ba009719788ebc2004c6ef89329dae1b9689b Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Sat, 23 Jan 2021 21:34:32 +0200 Subject: Add depth and received_ts to forward_extremities admin API response Also add a warning on the admin API documentation. Signed-off-by: Jason Robinson --- docs/admin_api/rooms.md | 8 +++++++- synapse/storage/databases/main/events_forward_extremities.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) (limited to 'synapse/storage/databases') diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 86daa393a7..f34cec1ff7 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -535,7 +535,9 @@ A response as follows will be returned: "results": [ { "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh", - "state_group": 439 + "state_group": 439, + "depth": 123, + "received_ts": 1611263016761 } ] } @@ -543,6 +545,10 @@ A response as follows will be returned: ## Deleting forward extremities +**WARNING**: Please ensure you know what you're doing and have read +the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760). +Under no situations should this API be executed as an automated maintenance task! + If a room has lots of forward extremities, the extra can be deleted as follows: diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 5fea974050..84aaa919fb 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -86,9 +86,10 @@ class EventForwardExtremitiesStore(SQLBaseStore): def get_forward_extremities_for_room_txn(txn): sql = """ - SELECT event_id, state_group + SELECT event_id, state_group, depth, received_ts FROM event_forward_extremities NATURAL JOIN event_to_state_groups + NATURAL JOIN events WHERE room_id = ? """ -- cgit 1.5.1 From 4a55d267eef1388690e6781b580910e341358f95 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 25 Jan 2021 14:49:39 -0500 Subject: Add an admin API for shadow-banning users. (#9209) This expands the current shadow-banning feature to be usable via the admin API and adds documentation for it. A shadow-banned users receives successful responses to their client-server API requests, but the events are not propagated into rooms. Shadow-banning a user should be used as a tool of last resort and may lead to confusing or broken behaviour for the client. --- changelog.d/9209.feature | 1 + docs/admin_api/user_admin_api.rst | 30 ++++++++++++ stubs/txredisapi.pyi | 1 - synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/users.py | 36 +++++++++++++++ synapse/storage/databases/main/registration.py | 29 ++++++++++++ tests/rest/admin/test_user.py | 64 ++++++++++++++++++++++++++ tests/rest/client/test_shadow_banned.py | 8 +--- 8 files changed, 164 insertions(+), 7 deletions(-) create mode 100644 changelog.d/9209.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/9209.feature b/changelog.d/9209.feature new file mode 100644 index 0000000000..ec926e8eb4 --- /dev/null +++ b/changelog.d/9209.feature @@ -0,0 +1 @@ +Add an admin API endpoint for shadow-banning users. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index b3d413cf57..1eb674939e 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -760,3 +760,33 @@ The following fields are returned in the JSON response body: - ``total`` - integer - Number of pushers. See also `Client-Server API Spec `_ + +Shadow-banning users +==================== + +Shadow-banning is a useful tool for moderating malicious or egregiously abusive users. +A shadow-banned users receives successful responses to their client-server API requests, +but the events are not propagated into rooms. This can be an effective tool as it +(hopefully) takes longer for the user to realise they are being moderated before +pivoting to another account. + +Shadow-banning a user should be used as a tool of last resort and may lead to confusing +or broken behaviour for the client. A shadow-banned user will not receive any +notification and it is generally more appropriate to ban or kick abusive users. +A shadow-banned user will be unable to contact anyone on the server. + +The API is:: + + POST /_synapse/admin/v1/users//shadow_ban + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst `_. + +An empty JSON dict is returned. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must + be local. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index bfac6840e6..726454ba31 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,7 +15,6 @@ """Contains *incomplete* type hints for txredisapi. """ - from typing import List, Optional, Type, Union class RedisProtocol: diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6f7dc06503..f04740cd38 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -51,6 +51,7 @@ from synapse.rest.admin.users import ( PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, + ShadowBanRestServlet, UserAdminServlet, UserMediaRestServlet, UserMembershipRestServlet, @@ -230,6 +231,7 @@ def register_servlets(hs, http_server): EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) + ShadowBanRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 86198bab30..68c3c64a0d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -890,3 +890,39 @@ class UserTokenRestServlet(RestServlet): ) return 200, {"access_token": token} + + +class ShadowBanRestServlet(RestServlet): + """An admin API for shadow-banning a user. + + A shadow-banned users receives successful responses to their client-server + API requests, but the events are not propagated into rooms. + + Shadow-banning a user should be used as a tool of last resort and may lead + to confusing or broken behaviour for the client. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} + """ + + PATTERNS = admin_patterns("/users/(?P[^/]*)/shadow_ban") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), True) + + return 200, {} diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 585b4049d6..0618b4387a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None: + """Sets whether a user shadow-banned. + + Args: + user: user ID of the user to test + shadow_banned: true iff the user is to be shadow-banned, false otherwise. + """ + + def set_shadow_banned_txn(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user.to_string()}, + updatevalues={"shadow_banned": shadow_banned}, + ) + # In order for this to apply immediately, clear the cache for this user. + tokens = self.db_pool.simple_select_onecol_txn( + txn, + table="access_tokens", + keyvalues={"user_id": user.to_string()}, + retcol="token", + ) + for token in tokens: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e48f8c1d7b..ee05ee60bc 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2380,3 +2380,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(self.other_user, channel.json_body["user_id"]) self.assertIn("devices", channel.json_body) + + +class ShadowBanRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + + self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to get information of an user without authentication. + """ + channel = self.make_request("POST", self.url) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + channel = self.make_request("POST", self.url, access_token=other_user_token) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that shadow-banning for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" + + channel = self.make_request("POST", url, access_token=self.admin_user_tok) + self.assertEqual(400, channel.code, msg=channel.json_body) + + def test_success(self): + """ + Shadow-banning should succeed for an admin. + """ + # The user starts off as not shadow-banned. + other_user_token = self.login("user", "pass") + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertFalse(result.shadow_banned) + + channel = self.make_request("POST", self.url, access_token=self.admin_user_tok) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual({}, channel.json_body) + + # Ensure the user is shadow-banned (and the cache was cleared). + result = self.get_success(self.store.get_user_by_access_token(other_user_token)) + self.assertTrue(result.shadow_banned) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index e689c3fbea..0ebdf1415b 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -18,6 +18,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import directory, login, profile, room from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.types import UserID from tests import unittest @@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase): self.store = self.hs.get_datastore() self.get_success( - self.store.db_pool.simple_update( - table="users", - keyvalues={"name": self.banned_user_id}, - updatevalues={"shadow_banned": True}, - desc="shadow_ban", - ) + self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True) ) self.other_user_id = self.register_user("otheruser", "pass") -- cgit 1.5.1 From 5b857b77f7de62bb9be0aa88a3fffcf7cb11efe6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 25 Jan 2021 14:52:30 -0500 Subject: Don't error if deleting a non-existent pusher. (#9121) --- changelog.d/9121.bugfix | 1 + synapse/storage/databases/main/pusher.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9121.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/9121.bugfix b/changelog.d/9121.bugfix new file mode 100644 index 0000000000..a566878ec0 --- /dev/null +++ b/changelog.d/9121.bugfix @@ -0,0 +1 @@ +Fix spurious errors in logs when deleting a non-existant pusher. diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index bc7621b8d6..2687ef3e43 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self.db_pool.simple_delete_one_txn( + # It is expected that there is exactly one pusher to delete, but + # if it isn't there (or there are multiple) delete them all. + self.db_pool.simple_delete_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, -- cgit 1.5.1 From e20f18a76680bc16fd8299a61dd81dc07f1a3ffd Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Tue, 26 Jan 2021 10:13:35 +0200 Subject: Make natural join inner join Co-authored-by: Erik Johnston --- synapse/storage/databases/main/events_forward_extremities.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 84aaa919fb..68b64838bb 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -88,8 +88,8 @@ class EventForwardExtremitiesStore(SQLBaseStore): sql = """ SELECT event_id, state_group, depth, received_ts FROM event_forward_extremities - NATURAL JOIN event_to_state_groups - NATURAL JOIN events + INNER JOIN event_to_state_groups USING (event_id) + INNER JOIN events INNER JOIN USING (event_id) WHERE room_id = ? """ -- cgit 1.5.1 From 4936fc59fcf23582c940cb1cbf4286039b3504de Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Tue, 26 Jan 2021 10:21:02 +0200 Subject: Fix get forward extremities query Signed-off-by: Jason Robinson --- synapse/storage/databases/main/events_forward_extremities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py index 68b64838bb..0ac1da9c35 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py @@ -89,7 +89,7 @@ class EventForwardExtremitiesStore(SQLBaseStore): SELECT event_id, state_group, depth, received_ts FROM event_forward_extremities INNER JOIN event_to_state_groups USING (event_id) - INNER JOIN events INNER JOIN USING (event_id) + INNER JOIN events USING (room_id, event_id) WHERE room_id = ? """ -- cgit 1.5.1 From 1baab2035265cf2543fe3c0ef5412c1ac0740c7e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jan 2021 10:50:21 -0500 Subject: Add type hints to various handlers. (#9223) With this change all handlers except the e2e_* ones have type hints enabled. --- changelog.d/9223.misc | 1 + mypy.ini | 14 ++++ synapse/handlers/acme.py | 12 ++-- synapse/handlers/acme_issuing_service.py | 27 +++++--- synapse/handlers/groups_local.py | 83 ++++++++++++------------ synapse/handlers/search.py | 38 ++++++----- synapse/handlers/set_password.py | 10 +-- synapse/handlers/state_deltas.py | 14 +++- synapse/handlers/stats.py | 39 ++++++----- synapse/handlers/typing.py | 69 +++++++++++--------- synapse/handlers/user_directory.py | 9 +-- synapse/storage/databases/main/search.py | 3 +- synapse/storage/databases/main/stats.py | 22 ++++--- synapse/storage/databases/main/user_directory.py | 2 +- 14 files changed, 205 insertions(+), 138 deletions(-) create mode 100644 changelog.d/9223.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/9223.misc b/changelog.d/9223.misc new file mode 100644 index 0000000000..9d44b621c9 --- /dev/null +++ b/changelog.d/9223.misc @@ -0,0 +1 @@ +Add type hints to handlers code. diff --git a/mypy.ini b/mypy.ini index bd99069c81..f3700d323c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -26,6 +26,8 @@ files = synapse/handlers/_base.py, synapse/handlers/account_data.py, synapse/handlers/account_validity.py, + synapse/handlers/acme.py, + synapse/handlers/acme_issuing_service.py, synapse/handlers/admin.py, synapse/handlers/appservice.py, synapse/handlers/auth.py, @@ -36,6 +38,7 @@ files = synapse/handlers/directory.py, synapse/handlers/events.py, synapse/handlers/federation.py, + synapse/handlers/groups_local.py, synapse/handlers/identity.py, synapse/handlers/initial_sync.py, synapse/handlers/message.py, @@ -52,8 +55,13 @@ files = synapse/handlers/room_member.py, synapse/handlers/room_member_worker.py, synapse/handlers/saml_handler.py, + synapse/handlers/search.py, + synapse/handlers/set_password.py, synapse/handlers/sso.py, + synapse/handlers/state_deltas.py, + synapse/handlers/stats.py, synapse/handlers/sync.py, + synapse/handlers/typing.py, synapse/handlers/user_directory.py, synapse/handlers/ui_auth, synapse/http/client.py, @@ -194,3 +202,9 @@ ignore_missing_imports = True [mypy-hiredis] ignore_missing_imports = True + +[mypy-josepy.*] +ignore_missing_imports = True + +[mypy-txacme.*] +ignore_missing_imports = True diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 8476256a59..5ecb2da1ac 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING import twisted import twisted.internet.error @@ -22,6 +23,9 @@ from twisted.web.resource import Resource from synapse.app import check_bind_error +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) ACME_REGISTER_FAIL_ERROR = """ @@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC class AcmeHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.reactor = hs.get_reactor() self._acme_domain = hs.config.acme_domain - async def start_listening(self): + async def start_listening(self) -> None: from synapse.handlers import acme_issuing_service # Configure logging for txacme, if you need to debug @@ -85,7 +89,7 @@ class AcmeHandler: logger.error(ACME_REGISTER_FAIL_ERROR) raise - async def provision_certificate(self): + async def provision_certificate(self) -> None: logger.warning("Reprovisioning %s", self._acme_domain) @@ -110,5 +114,3 @@ class AcmeHandler: except Exception: logger.exception("Failed saving!") raise - - return True diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py index 7294649d71..ae2a9dd9c2 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py @@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to imported conditionally. """ import logging +from typing import Dict, Iterable, List import attr +import pem from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from josepy import JWKRSA @@ -36,20 +38,27 @@ from txacme.util import generate_private_key from zope.interface import implementer from twisted.internet import defer +from twisted.internet.interfaces import IReactorTCP from twisted.python.filepath import FilePath from twisted.python.url import URL +from twisted.web.resource import IResource logger = logging.getLogger(__name__) -def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource): +def create_issuing_service( + reactor: IReactorTCP, + acme_url: str, + account_key_file: str, + well_known_resource: IResource, +) -> AcmeIssuingService: """Create an ACME issuing service, and attach it to a web Resource Args: reactor: twisted reactor - acme_url (str): URL to use to request certificates - account_key_file (str): where to store the account key - well_known_resource (twisted.web.IResource): web resource for .well-known. + acme_url: URL to use to request certificates + account_key_file: where to store the account key + well_known_resource: web resource for .well-known. we will attach a child resource for "acme-challenge". Returns: @@ -83,18 +92,20 @@ class ErsatzStore: A store that only stores in memory. """ - certs = attr.ib(default=attr.Factory(dict)) + certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict)) - def store(self, server_name, pem_objects): + def store( + self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject] + ) -> defer.Deferred: self.certs[server_name] = [o.as_bytes() for o in pem_objects] return defer.succeed(None) -def load_or_create_client_key(key_file): +def load_or_create_client_key(key_file: str) -> JWKRSA: """Load the ACME account key from a file, creating it if it does not exist. Args: - key_file (str): name of the file to use as the account key + key_file: name of the file to use as the account key """ # this is based on txacme.endpoint.load_or_create_client_key, but doesn't # hardcode the 'client.key' filename diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index df29edeb83..71f11ef94a 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -15,9 +15,13 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Dict, Iterable, List, Set from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import GroupID, get_domain_from_id +from synapse.types import GroupID, JsonDict, get_domain_from_id + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ def _create_rerouter(func_name): class GroupsLocalWorkerHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.room_list_handler = hs.get_room_list_handler() @@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler: get_group_role = _create_rerouter("get_group_role") get_group_roles = _create_rerouter("get_group_roles") - async def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the group summary for a group. If the group is remote we check that the users have valid attestations. @@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler: return res - async def get_users_in_group(self, group_id, requester_user_id): + async def get_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get users in a group """ if self.is_mine_id(group_id): - res = await self.groups_server_handler.get_users_in_group( + return await self.groups_server_handler.get_users_in_group( group_id, requester_user_id ) - return res group_server_name = get_domain_from_id(group_id) @@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler: return res - async def get_joined_groups(self, user_id): + async def get_joined_groups(self, user_id: str) -> JsonDict: group_ids = await self.store.get_joined_groups(user_id) return {"groups": group_ids} - async def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: if self.hs.is_mine_id(user_id): result = await self.store.get_publicised_groups_for_user(user_id) @@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler: # TODO: Verify attestations return {"groups": result} - async def bulk_get_publicised_groups(self, user_ids, proxy=True): - destinations = {} + async def bulk_get_publicised_groups( + self, user_ids: Iterable[str], proxy: bool = True + ) -> JsonDict: + destinations = {} # type: Dict[str, Set[str]] local_users = set() for user_id in user_ids: @@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler: raise SynapseError(400, "Some user_ids are not local") results = {} - failed_results = [] + failed_results = [] # type: List[str] for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( @@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler: class GroupsLocalHandler(GroupsLocalWorkerHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) # Ensure attestations get renewed @@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - async def create_group(self, group_id, user_id, content): + async def create_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Create a group """ @@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = None remote_attestation = None else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - content["user_profile"] = await self.profile_handler.get_profile(user_id) - - try: - res = await self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - server_name=get_domain_from_id(group_id), - ) + raise SynapseError(400, "Unable to create remote groups") is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( @@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def join_group(self, group_id, user_id, content): + async def join_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Request to join a group """ if self.is_mine_id(group_id): @@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def accept_invite(self, group_id, user_id, content): + async def accept_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Accept an invite to a group """ if self.is_mine_id(group_id): @@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def invite(self, group_id, user_id, requester_user_id, config): + async def invite( + self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict + ) -> JsonDict: """Invite a user to a group """ content = {"requester_user_id": requester_user_id, "config": config} @@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def on_invite(self, group_id, user_id, content): + async def on_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """One of our users were invited to a group """ # TODO: Support auto join and rejection @@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} async def remove_user_from_group( - self, group_id, user_id, requester_user_id, content - ): + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Remove a user from a group """ if user_id == requester_user_id: @@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def user_removed_from_group(self, group_id, user_id, content): + async def user_removed_from_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> None: """One of our users was removed/kicked from a group """ # TODO: Check if user in group diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 66f1bbcfc4..94062e79cb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -15,23 +15,28 @@ import itertools import logging -from typing import Iterable +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.storage.state import StateFilter +from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() @@ -87,13 +92,15 @@ class SearchHandler(BaseHandler): return historical_room_ids - async def search(self, user, content, batch=None): + async def search( + self, user: UserID, content: JsonDict, batch: Optional[str] = None + ) -> JsonDict: """Performs a full text search for a user. Args: - user (UserID) - content (dict): Search parameters - batch (str): The next_batch parameter. Used for pagination. + user + content: Search parameters + batch: The next_batch parameter. Used for pagination. Returns: dict to be returned to the client with results of search @@ -186,7 +193,7 @@ class SearchHandler(BaseHandler): # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: - historical_room_ids = [] + historical_room_ids = [] # type: List[str] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist ids = await self.get_old_rooms_from_upgraded_room(room_id) @@ -209,8 +216,10 @@ class SearchHandler(BaseHandler): rank_map = {} # event_id -> rank of event allowed_events = [] - room_groups = {} # Holds result of grouping by room, if applicable - sender_group = {} # Holds result of grouping by sender, if applicable + # Holds result of grouping by room, if applicable + room_groups = {} # type: Dict[str, JsonDict] + # Holds result of grouping by sender, if applicable + sender_group = {} # type: Dict[str, JsonDict] # Holds the next_batch for the entire result set if one of those exists global_next_batch = None @@ -254,7 +263,7 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - room_events = [] + room_events = [] # type: List[EventBase] i = 0 pagination_token = batch_token @@ -418,13 +427,10 @@ class SearchHandler(BaseHandler): state_results = {} if include_state: - rooms = {e.room_id for e in allowed_events} - for room_id in rooms: + for room_id in {e.room_id for e in allowed_events}: state = await self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) - state_results.values() - # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise the 'age' will be wrong @@ -448,9 +454,9 @@ class SearchHandler(BaseHandler): if state_results: s = {} - for room_id, state in state_results.items(): + for room_id, state_events in state_results.items(): s[room_id] = await self._event_serializer.serialize_events( - state, time_now + state_events, time_now ) rooms_cat_res["state"] = s diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index a5d67f828f..84af2dde7e 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -13,24 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - self._password_policy_handler = hs.get_password_policy_handler() async def set_password( self, @@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler): password_hash: str, logout_devices: bool, requester: Optional[Requester] = None, - ): + ) -> None: if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index fb4f70e8e2..b3f9875358 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -14,15 +14,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) class StateDeltasHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change( + self, + prev_event_id: Optional[str], + event_id: Optional[str], + key_name: str, + public_value: str, + ) -> Optional[bool]: """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index dc62b21c06..d261d7cd4e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -12,13 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from collections import Counter +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple + +from typing_extensions import Counter as CounterType from synapse.api.constants import EventTypes, Membership from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -31,7 +37,7 @@ class StatsHandler: Heavily derived from UserDirectoryHandler """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -44,7 +50,7 @@ class StatsHandler: self.stats_enabled = hs.config.stats_enabled # The current position in the current_state_delta stream - self.pos = None + self.pos = None # type: Optional[int] # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -56,7 +62,7 @@ class StatsHandler: # we start populating stats self.clock.call_later(0, self.notify_new_event) - def notify_new_event(self): + def notify_new_event(self) -> None: """Called when there may be more deltas to process """ if not self.stats_enabled or self._is_processing: @@ -72,7 +78,7 @@ class StatsHandler: run_as_background_process("stats.notify_new_event", process) - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB if self.pos is None: self.pos = await self.store.get_stats_positions() @@ -110,10 +116,10 @@ class StatsHandler: ) for room_id, fields in room_count.items(): - room_deltas.setdefault(room_id, {}).update(fields) + room_deltas.setdefault(room_id, Counter()).update(fields) for user_id, fields in user_count.items(): - user_deltas.setdefault(user_id, {}).update(fields) + user_deltas.setdefault(user_id, Counter()).update(fields) logger.debug("room_deltas: %s", room_deltas) logger.debug("user_deltas: %s", user_deltas) @@ -131,19 +137,20 @@ class StatsHandler: self.pos = max_pos - async def _handle_deltas(self, deltas): + async def _handle_deltas( + self, deltas: Iterable[JsonDict] + ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]: """Called with the state deltas to process Returns: - tuple[dict[str, Counter], dict[str, counter]] Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. """ - room_to_stats_deltas = {} - user_to_stats_deltas = {} + room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] + user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] - room_to_state_updates = {} + room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] for delta in deltas: typ = delta["type"] @@ -173,7 +180,7 @@ class StatsHandler: ) continue - event_content = {} + event_content = {} # type: JsonDict sender = None if event_id is not None: @@ -257,13 +264,13 @@ class StatsHandler: ) if has_changed_joinedness: - delta = +1 if membership == Membership.JOIN else -1 + membership_delta = +1 if membership == Membership.JOIN else -1 user_to_stats_deltas.setdefault(user_id, Counter())[ "joined_rooms" - ] += delta + ] += membership_delta - room_stats_delta["local_users_in_room"] += delta + room_stats_delta["local_users_in_room"] += membership_delta elif typ == EventTypes.Create: room_state["is_federatable"] = ( diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e919a8f9ed..3f0dfc7a74 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,13 +15,13 @@ import logging import random from collections import namedtuple -from typing import TYPE_CHECKING, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -65,17 +65,17 @@ class FollowerTypingHandler: ) # map room IDs to serial numbers - self._room_serials = {} + self._room_serials = {} # type: Dict[str, int] # map room IDs to sets of users currently typing - self._room_typing = {} + self._room_typing = {} # type: Dict[str, Set[str]] - self._member_last_federation_poke = {} + self._member_last_federation_poke = {} # type: Dict[RoomMember, int] self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 self.clock.looping_call(self._handle_timeouts, 5000) - def _reset(self): + def _reset(self) -> None: """Reset the typing handler's data caches. """ # map room IDs to serial numbers @@ -86,7 +86,7 @@ class FollowerTypingHandler: self._member_last_federation_poke = {} self.wheel_timer = WheelTimer(bucket_size=5000) - def _handle_timeouts(self): + def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() @@ -96,7 +96,7 @@ class FollowerTypingHandler: for member in members: self._handle_timeout_for_member(now, member) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: if not self.is_typing(member): # Nothing to do if they're no longer typing return @@ -114,10 +114,10 @@ class FollowerTypingHandler: # each person typing. self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) - def is_typing(self, member): + def is_typing(self, member: RoomMember) -> bool: return member.user_id in self._room_typing.get(member.room_id, []) - async def _push_remote(self, member, typing): + async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: return @@ -148,7 +148,7 @@ class FollowerTypingHandler: def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: """Should be called whenever we receive updates for typing stream. """ @@ -178,7 +178,7 @@ class FollowerTypingHandler: async def _send_changes_in_typing_to_remotes( self, room_id: str, prev_typing: Set[str], now_typing: Set[str] - ): + ) -> None: """Process a change in typing of a room from replication, sending EDUs for any local users. """ @@ -194,12 +194,12 @@ class FollowerTypingHandler: if self.is_mine_id(user_id): await self._push_remote(RoomMember(room_id, user_id), False) - def get_current_token(self): + def get_current_token(self) -> int: return self._latest_room_serial class TypingWriterHandler(FollowerTypingHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) assert hs.config.worker.writers.typing == hs.get_instance_name() @@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - self._member_typing_until = {} # clock time we expect to stop + # clock time we expect to stop + self._member_typing_until = {} # type: Dict[RoomMember, int] # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( "TypingStreamChangeCache", self._latest_room_serial ) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: super()._handle_timeout_for_member(now, member) if not self.is_typing(member): @@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) return - async def started_typing(self, target_user, requester, room_id, timeout): + async def started_typing( + self, target_user: UserID, requester: Requester, room_id: str, timeout: int + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler): if was_present: # No point sending another notification - return None + return self._push_update(member=member, typing=True) - async def stopped_typing(self, target_user, requester, room_id): + async def stopped_typing( + self, target_user: UserID, requester: Requester, room_id: str + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) - def user_left_room(self, user, room_id): + def user_left_room(self, user: UserID, room_id: str) -> None: user_id = user.to_string() if self.is_mine_id(user_id): member = RoomMember(room_id=room_id, user_id=user_id) self._stopped_typing(member) - def _stopped_typing(self, member): + def _stopped_typing(self, member: RoomMember) -> None: if member.user_id not in self._room_typing.get(member.room_id, set()): # No point - return None + return self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) self._push_update(member=member, typing=False) - def _push_update(self, member, typing): + def _push_update(self, member: RoomMember, typing: bool) -> None: if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_as_background_process( @@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler): self._push_update_local(member=member, typing=typing) - async def _recv_edu(self, origin, content): + async def _recv_edu(self, origin: str, content: JsonDict) -> None: room_id = content["room_id"] user_id = content["user_id"] @@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler): self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) self._push_update_local(member=member, typing=content["typing"]) - def _push_update_local(self, member, typing): + def _push_update_local(self, member: RoomMember, typing: bool) -> None: room_set = self._room_typing.setdefault(member.room_id, set()) if typing: room_set.add(member.user_id) @@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler): changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id - ) + ) # type: Optional[Iterable[str]] if changed_rooms is None: changed_rooms = self._room_serials @@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler): def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: # The writing process should never get updates from replication. raise Exception("Typing writer instance got typing info over replication") class TypingNotificationEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: @@ -427,7 +432,7 @@ class TypingNotificationEventSource: # self.get_typing_handler = hs.get_typing_handler - def _make_event_for(self, room_id): + def _make_event_for(self, room_id: str) -> JsonDict: typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", @@ -462,7 +467,9 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - async def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events( + self, from_key: int, room_ids: Iterable[str], **kwargs + ) -> Tuple[List[JsonDict], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() @@ -478,5 +485,5 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - def get_current_key(self): + def get_current_key(self) -> int: return self.get_typing_handler()._latest_room_serial diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d4651c8348..8aedf5072e 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler): if self.pos is None: self.pos = await self.store.get_user_directory_stream_pos() - # If still None then the initial background update hasn't happened yet - if self.pos is None: - return None - # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "user_dir_delta"): @@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler): if change: # The user joined event = await self.store.get_event(event_id, allow_none=True) + # It isn't expected for this event to not exist, but we + # don't want the entire background process to break. + if event is None: + continue + profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 871af64b11..f5e7d9ef98 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import Collection logger = logging.getLogger(__name__) @@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore): async def search_rooms( self, - room_ids: List[str], + room_ids: Collection[str], search_term: str, keys: List[str], limit, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 0cdb3ec1f7..d421d18f8d 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -15,11 +15,12 @@ # limitations under the License. import logging -from collections import Counter from enum import Enum from itertools import chain from typing import Any, Dict, List, Optional, Tuple +from typing_extensions import Counter + from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership @@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: + async def get_earliest_token_for_stats( + self, stats_type: str, id: str + ) -> Optional[int]: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore): ) async def bulk_update_stats_delta( - self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int + self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int ) -> None: """Bulk update stats tables for a given stream_id and updates the stats incremental position. @@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore): async def get_changes_room_total_events_and_bytes( self, min_pos: int, max_pos: int - ) -> Dict[str, Dict[str, int]]: + ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: """Fetches the counts of events in the given range of stream IDs. Args: @@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore): max_pos, ) - def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): + def get_changes_room_total_events_and_bytes_txn( + self, txn, low_pos: int, high_pos: int + ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: """Gets the total_events and total_event_bytes counts for rooms and senders, in a range of stream_orderings (including backfilled events). Args: txn - low_pos (int): Low stream ordering - high_pos (int): High stream ordering + low_pos: Low stream ordering + high_pos: High stream ordering Returns: - tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The - room and user deltas for total_events/total_event_bytes in the + The room and user deltas for total_events/total_event_bytes in the format of `stats_id` -> fields """ diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ef11f1c3b3..7b9729da09 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - async def update_user_directory_stream_pos(self, stream_id: str) -> None: + async def update_user_directory_stream_pos(self, stream_id: int) -> None: await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, -- cgit 1.5.1 From a78016dadfb1680f5f77daae9948086b37cbeef8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 28 Jan 2021 08:34:19 -0500 Subject: Add type hints to E2E handler. (#9232) This finishes adding type hints to the `synapse.handlers` module. --- changelog.d/9232.misc | 1 + mypy.ini | 42 +--- synapse/handlers/device.py | 12 +- synapse/handlers/e2e_keys.py | 223 +++++++++++++--------- synapse/handlers/e2e_room_keys.py | 91 +++++---- synapse/logging/opentracing.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 4 +- 7 files changed, 198 insertions(+), 177 deletions(-) create mode 100644 changelog.d/9232.misc (limited to 'synapse/storage/databases') diff --git a/changelog.d/9232.misc b/changelog.d/9232.misc new file mode 100644 index 0000000000..9d44b621c9 --- /dev/null +++ b/changelog.d/9232.misc @@ -0,0 +1 @@ +Add type hints to handlers code. diff --git a/mypy.ini b/mypy.ini index f3700d323c..68a4533973 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,47 +23,7 @@ files = synapse/events/validator.py, synapse/events/spamcheck.py, synapse/federation, - synapse/handlers/_base.py, - synapse/handlers/account_data.py, - synapse/handlers/account_validity.py, - synapse/handlers/acme.py, - synapse/handlers/acme_issuing_service.py, - synapse/handlers/admin.py, - synapse/handlers/appservice.py, - synapse/handlers/auth.py, - synapse/handlers/cas_handler.py, - synapse/handlers/deactivate_account.py, - synapse/handlers/device.py, - synapse/handlers/devicemessage.py, - synapse/handlers/directory.py, - synapse/handlers/events.py, - synapse/handlers/federation.py, - synapse/handlers/groups_local.py, - synapse/handlers/identity.py, - synapse/handlers/initial_sync.py, - synapse/handlers/message.py, - synapse/handlers/oidc_handler.py, - synapse/handlers/pagination.py, - synapse/handlers/password_policy.py, - synapse/handlers/presence.py, - synapse/handlers/profile.py, - synapse/handlers/read_marker.py, - synapse/handlers/receipts.py, - synapse/handlers/register.py, - synapse/handlers/room.py, - synapse/handlers/room_list.py, - synapse/handlers/room_member.py, - synapse/handlers/room_member_worker.py, - synapse/handlers/saml_handler.py, - synapse/handlers/search.py, - synapse/handlers/set_password.py, - synapse/handlers/sso.py, - synapse/handlers/state_deltas.py, - synapse/handlers/stats.py, - synapse/handlers/sync.py, - synapse/handlers/typing.py, - synapse/handlers/user_directory.py, - synapse/handlers/ui_auth, + synapse/handlers, synapse/http/client.py, synapse/http/federation/matrix_federation_agent.py, synapse/http/federation/well_known_resolver.py, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index debb1b4f29..0863154f7a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: + async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ Retrieve the given user's devices @@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device(self, user_id: str, device_id: str) -> JsonDict: """ Retrieve the given device Args: @@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips( - device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] + device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) @@ -946,8 +946,8 @@ class DeviceListUpdater: async def process_cross_signing_key_update( self, user_id: str, - master_key: Optional[Dict[str, Any]], - self_signing_key: Optional[Dict[str, Any]], + master_key: Optional[JsonDict], + self_signing_key: Optional[JsonDict], ) -> List[str]: """Process the given new master and self-signing key for the given remote user. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 929752150d..8f3a6b35a4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -16,7 +16,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( + JsonDict, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class E2eKeysHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() @@ -78,7 +82,9 @@ class E2eKeysHandler: ) @trace - async def query_devices(self, query_body, timeout, from_user_id): + async def query_devices( + self, query_body: JsonDict, timeout: int, from_user_id: str + ) -> JsonDict: """ Handle a device key query from a client { @@ -98,12 +104,14 @@ class E2eKeysHandler: } Args: - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Iterable[str]] # separate users by domain. # make a map from domain to user_id to device_ids @@ -121,7 +129,8 @@ class E2eKeysHandler: set_tag("remote_key_query", remote_queries) # First get local devices. - failures = {} + # A map of destination -> failure response. + failures = {} # type: Dict[str, JsonDict] results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -135,9 +144,10 @@ class E2eKeysHandler: ) # Now attempt to get any remote devices from our local cache. - remote_queries_not_in_cache = {} + # A map of destination -> user ID -> device IDs. + remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] if remote_queries: - query_list = [] + query_list = [] # type: List[Tuple[str, Optional[str]]] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) @@ -284,15 +294,15 @@ class E2eKeysHandler: return ret async def get_cross_signing_keys_from_cache( - self, query, from_user_id + self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: """Get cross-signing keys for users from the database Args: - query (Iterable[string]) an iterable of user IDs. A dict whose keys + query: an iterable of user IDs. A dict whose keys are user IDs satisfies this, so the query format used for query_devices can be used here. - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. @@ -315,14 +325,12 @@ class E2eKeysHandler: if "self_signing" in user_info: self_signing_keys[user_id] = user_info["self_signing"] - if ( - from_user_id in keys - and keys[from_user_id] is not None - and "user_signing" in keys[from_user_id] - ): - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + if from_user_id: + from_user_key = keys.get(from_user_id) + if from_user_key and "user_signing" in from_user_key: + user_signing_keys[from_user_id] = from_user_key["user_signing"] return { "master_keys": master_keys, @@ -344,9 +352,9 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] + local_query = [] # type: List[Tuple[str, Optional[str]]] - result_dict = {} + result_dict = {} # type: Dict[str, Dict[str, dict]] for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -380,10 +388,14 @@ class E2eKeysHandler: log_kv(results) return result_dict - async def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys( + self, query_body: Dict[str, Dict[str, Optional[List[str]]]] + ) -> JsonDict: """ Handle a device key query from a federated server """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Optional[List[str]]] res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -397,31 +409,34 @@ class E2eKeysHandler: return ret @trace - async def claim_one_time_keys(self, query, timeout): - local_query = [] - remote_queries = {} + async def claim_one_time_keys( + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + ) -> JsonDict: + local_query = [] # type: List[Tuple[str, str, str]] + remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] - for user_id, device_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in device_keys.items(): + for device_id, algorithm in one_time_keys.items(): local_query.append((user_id, device_id, algorithm)) else: domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_keys + remote_queries.setdefault(domain, {})[user_id] = one_time_keys set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) results = await self.store.claim_e2e_one_time_keys(local_query) - json_result = {} - failures = {} + # A map of user ID -> device ID -> key ID -> key. + json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + failures = {} # type: Dict[str, JsonDict] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_bytes) + key_id: json_decoder.decode(json_str) } @trace @@ -468,7 +483,9 @@ class E2eKeysHandler: return {"one_time_keys": json_result, "failures": failures} @tag_args - async def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user( + self, user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: time_now = self.clock.time_msec() @@ -543,8 +560,8 @@ class E2eKeysHandler: return {"one_time_key_counts": result} async def _upload_one_time_keys_for_user( - self, user_id, device_id, time_now, one_time_keys - ): + self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict + ) -> None: logger.info( "Adding one_time_keys %r for device %r for user %r at %d", one_time_keys.keys(), @@ -585,12 +602,14 @@ class E2eKeysHandler: log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - async def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user( + self, user_id: str, keys: JsonDict + ) -> JsonDict: """Upload signing keys for cross-signing Args: - user_id (string): the user uploading the keys - keys (dict[string, dict]): the signing keys + user_id: the user uploading the keys + keys: the signing keys """ # if a master key is uploaded, then check it. Otherwise, load the @@ -667,16 +686,17 @@ class E2eKeysHandler: return {} - async def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys( + self, user_id: str, signatures: JsonDict + ) -> JsonDict: """Upload device signatures for cross-signing Args: - user_id (string): the user uploading the signatures - signatures (dict[string, dict[string, dict]]): map of users to - devices to signed keys. This is the submission from the user; an - exception will be raised if it is malformed. + user_id: the user uploading the signatures + signatures: map of users to devices to signed keys. This is the submission + from the user; an exception will be raised if it is malformed. Returns: - dict: response to be sent back to the client. The response will have + The response to be sent back to the client. The response will have a "failures" key, which will be a dict mapping users to devices to errors for the signatures that failed. Raises: @@ -719,7 +739,9 @@ class E2eKeysHandler: return {"failures": failures} - async def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures( + self, user_id: str, signatures: JsonDict + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -731,15 +753,14 @@ class E2eKeysHandler: signatures (dict[string, dict]): map of devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure - reasons + A tuple of a list of signatures to store, and a map of users to + devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -834,19 +855,24 @@ class E2eKeysHandler: return signature_list, failures def _check_master_key_signature( - self, user_id, master_key_id, signed_master_key, stored_master_key, devices - ): + self, + user_id: str, + master_key_id: str, + signed_master_key: JsonDict, + stored_master_key: JsonDict, + devices: Dict[str, Dict[str, JsonDict]], + ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. Args: - user_id (string): the user whose master key is being checked - master_key_id (string): the ID of the user's master key - signed_master_key (dict): the user's signed master key that was uploaded - stored_master_key (dict): our previously-stored copy of the user's master key - devices (iterable(dict)): the user's devices + user_id: the user whose master key is being checked + master_key_id: the ID of the user's master key + signed_master_key: the user's signed master key that was uploaded + stored_master_key: our previously-stored copy of the user's master key + devices: the user's devices Returns: - list[SignatureListItem]: a list of signatures to store + A list of signatures to store Raises: SynapseError: if a signature is invalid @@ -877,25 +903,26 @@ class E2eKeysHandler: return master_key_signature_list - async def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures( + self, user_id: str, signatures: Dict[str, dict] + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. Args: - user_id (string): the user uploading the keys - signatures (dict[string, dict]): map of users to devices to signed keys + user_id: the user uploading the keys + signatures: map of users to devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure + A list of signatures to store, and a map of users to devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -983,7 +1010,7 @@ class E2eKeysHandler: async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None - ): + ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -997,8 +1024,7 @@ class E2eKeysHandler: This affects what signatures are fetched. Returns: - dict, str, VerifyKey: the raw key data, the key ID, and the - signedjson verify key + The raw key data, the key ID, and the signedjson verify key Raises: NotFoundError: if the key is not found @@ -1135,16 +1161,18 @@ class E2eKeysHandler: return desired_key, desired_key_id, desired_verify_key -def _check_cross_signing_key(key, user_id, key_type, signing_key=None): +def _check_cross_signing_key( + key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None +) -> None: """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. Args: - key (dict): the key data to verify - user_id (str): the user whose key is being checked - key_type (str): the type of key that the key should be - signing_key (VerifyKey): (optional) the signing key that the key should - be signed with. If omitted, signatures will not be checked. + key: the key data to verify + user_id: the user whose key is being checked + key_type: the type of key that the key should be + signing_key: the signing key that the key should be signed with. If + omitted, signatures will not be checked. """ if ( key.get("user_id") != user_id @@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): ) -def _check_device_signature(user_id, verify_key, signed_device, stored_device): +def _check_device_signature( + user_id: str, + verify_key: VerifyKey, + signed_device: JsonDict, + stored_device: JsonDict, +) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. Throws an exception if an error is detected. Args: - user_id (str): the user ID whose signature is being checked - verify_key (VerifyKey): the key to verify the device with - signed_device (dict): the uploaded signed device data - stored_device (dict): our previously stored copy of the device + user_id: the user ID whose signature is being checked + verify_key: the key to verify the device with + signed_device: the uploaded signed device data + stored_device: our previously stored copy of the device Raises: SynapseError: if the signature was invalid or the sent device is not the @@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) -def _exception_to_failure(e): +def _exception_to_failure(e: Exception) -> JsonDict: if isinstance(e, SynapseError): return {"status": e.code, "errcode": e.errcode, "message": str(e)} @@ -1218,7 +1251,7 @@ def _exception_to_failure(e): return {"status": 503, "message": str(e)} -def _one_time_keys_match(old_key_json, new_key): +def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly @@ -1239,16 +1272,16 @@ class SignatureListItem: """An item in the signature list as used by upload_signatures_for_device_keys. """ - signing_key_id = attr.ib() - target_user_id = attr.ib() - target_device_id = attr.ib() - signature = attr.ib() + signing_key_id = attr.ib(type=str) + target_user_id = attr.ib(type=str) + target_device_id = attr.ib(type=str) + signature = attr.ib(type=JsonDict) class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs, e2e_keys_handler): + def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.clock = hs.get_clock() @@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater: self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} + self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious @@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater: iterable=True, ) - async def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update( + self, origin: str, edu_content: JsonDict + ) -> None: """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. Args: - origin (string): the server that sent the EDU - edu_content (dict): the contents of the EDU + origin: the server that sent the EDU + edu_content: the contents of the EDU """ user_id = edu_content.pop("user_id") @@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater: await self._handle_signing_key_updates(user_id) - async def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id: str) -> None: """Actually handle pending updates. Args: - user_id (string): the user whose updates we are processing + user_id: the user whose updates we are processing """ device_handler = self.e2e_keys_handler.device_handler @@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater: # This can happen since we batch updates return - device_ids = [] + device_ids = [] # type: List[str] logger.info("pending updates: %r", pending_updates) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f01b090772..622cae23be 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, List, Optional from synapse.api.errors import ( Codes, @@ -24,8 +25,12 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, trace +from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +42,7 @@ class E2eRoomKeysHandler: The actual payload of the encrypted keys is completely opaque to the handler. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # Used to lock whenever a client is uploading key data. This prevents collisions @@ -48,21 +53,27 @@ class E2eRoomKeysHandler: self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - async def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> List[JsonDict]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. Args: - user_id(str): the user whose keys we're getting - version(str): the version ID of the backup we're getting keys from - room_id(string): room ID to get keys for, for None to get keys for all rooms - session_id(string): session ID to get keys for, for None to get keys for all + user_id: the user whose keys we're getting + version: the version ID of the backup we're getting keys from + room_id: room ID to get keys for, for None to get keys for all rooms + session_id: session ID to get keys for, for None to get keys for all sessions Raises: NotFoundError: if the backup version does not exist Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -86,17 +97,23 @@ class E2eRoomKeysHandler: return results @trace - async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> JsonDict: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. Args: - user_id(str): the user whose backup we're deleting - version(str): the version ID of the backup we're deleting - room_id(string): room ID to delete keys for, for None to delete keys for all + user_id: the user whose backup we're deleting + version: the version ID of the backup we're deleting + room_id: room ID to delete keys for, for None to delete keys for all rooms - session_id(string): session ID to delete keys for, for None to delete keys + session_id: session ID to delete keys for, for None to delete keys for all sessions Raises: NotFoundError: if the backup version does not exist @@ -128,15 +145,17 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @trace - async def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys( + self, user_id: str, version: str, room_keys: JsonDict + ) -> JsonDict: """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_keys(dict): a nested dict describing the room_keys we're setting: + user_id: the user whose backup we're setting + version: the version ID of the backup we're updating + room_keys: a nested dict describing the room_keys we're setting: { "rooms": { @@ -254,14 +273,16 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @staticmethod - def _should_replace_room_key(current_room_key, room_key): + def _should_replace_room_key( + current_room_key: Optional[JsonDict], room_key: JsonDict + ) -> bool: """ Determine whether to replace a given current_room_key (if any) with a newly uploaded room_key backup Args: - current_room_key (dict): Optional, the current room_key dict if any - room_key (dict): The new room_key dict which may or may not be fit to + current_room_key: Optional, the current room_key dict if any + room_key : The new room_key dict which may or may not be fit to replace the current_room_key Returns: @@ -286,14 +307,14 @@ class E2eRoomKeysHandler: return True @trace - async def create_version(self, user_id, version_info): + async def create_version(self, user_id: str, version_info: JsonDict) -> str: """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. Args: - user_id(str): the user whose backup version we're creating - version_info(dict): metadata about the new version being created + user_id: the user whose backup version we're creating + version_info: metadata about the new version being created { "algorithm": "m.megolm_backup.v1", @@ -301,7 +322,7 @@ class E2eRoomKeysHandler: } Returns: - A deferred of a string that gives the new version number. + The new version number. """ # TODO: Validate the JSON to make sure it has the right keys. @@ -313,17 +334,19 @@ class E2eRoomKeysHandler: ) return new_version - async def get_version_info(self, user_id, version=None): + async def get_version_info( + self, user_id: str, version: Optional[str] = None + ) -> JsonDict: """Get the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're querying - version(str): Optional; if None gives the most recent version + user_id: the user whose current backup version we're querying + version: Optional; if None gives the most recent version otherwise a historical one. Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of a info dict that gives the info about the new version. + A info dict that gives the info about the new version. { "version": "1234", @@ -346,7 +369,7 @@ class E2eRoomKeysHandler: return res @trace - async def delete_version(self, user_id, version=None): + async def delete_version(self, user_id: str, version: Optional[str] = None) -> None: """Deletes a given version of the user's e2e_room_keys backup Args: @@ -366,17 +389,19 @@ class E2eRoomKeysHandler: raise @trace - async def update_version(self, user_id, version, version_info): + async def update_version( + self, user_id: str, version: str, version_info: JsonDict + ) -> JsonDict: """Update the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're updating - version(str): the backup version we're updating - version_info(dict): the new information about the backup + user_id: the user whose current backup version we're updating + version: the backup version we're updating + version_info: the new information about the backup Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of an empty dict. + An empty dict. """ if "version" not in version_info: version_info["version"] = version diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index ab586c318c..0538350f38 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -791,7 +791,7 @@ def tag_args(func): @wraps(func) def _tag_args_inner(*args, **kwargs): - argspec = inspect.getargspec(func) + argspec = inspect.getfullargspec(func) for i, arg in enumerate(argspec.args[1:]): set_tag("ARG_" + arg, args[i]) set_tag("args", args[len(argspec.args) :]) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c128889bf9..309f1e865b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, dict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] - ) -> Dict[str, Dict[str, Dict[str, bytes]]]: + ) -> Dict[str, Dict[str, Dict[str, str]]]: """Take a list of one time keys out of the database. Args: -- cgit 1.5.1 From 9c715a5f1981891815c124353ba15cf4d17bf9bb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:47:59 +0000 Subject: Fix SSO on workers (#9271) Fixes #8966. * Factor out build_synapse_client_resource_tree Start a function which will mount resources common to all workers. * Move sso init into build_synapse_client_resource_tree ... so that we don't have to do it for each worker * Fix SSO-login-via-a-worker Expose the SSO login endpoints on workers, like the documentation says. * Update workers config for new endpoints Add documentation for endpoints recently added (#8942, #9017, #9262) * remove submit_token from workers endpoints list this *doesn't* work on workers (yet). * changelog * Add a comment about the odd path for SAML2Resource --- changelog.d/9271.bugfix | 1 + docs/workers.md | 18 +++++----- synapse/app/generic_worker.py | 11 +++--- synapse/app/homeserver.py | 18 ++-------- synapse/rest/synapse/client/__init__.py | 49 +++++++++++++++++++++++++- synapse/storage/databases/main/registration.py | 40 ++++++++++----------- tests/rest/client/v1/test_login.py | 15 ++------ tests/rest/client/v2_alpha/test_auth.py | 6 ++-- 8 files changed, 93 insertions(+), 65 deletions(-) create mode 100644 changelog.d/9271.bugfix (limited to 'synapse/storage/databases') diff --git a/changelog.d/9271.bugfix b/changelog.d/9271.bugfix new file mode 100644 index 0000000000..ef30c6570f --- /dev/null +++ b/changelog.d/9271.bugfix @@ -0,0 +1 @@ +Fix single-sign-on when the endpoints are routed to synapse workers. diff --git a/docs/workers.md b/docs/workers.md index d01683681f..6b8887de36 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -225,7 +225,6 @@ expressions: ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$ ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/ - ^/_synapse/client/password_reset/email/submit_token$ # Registration/login requests ^/_matrix/client/(api/v1|r0|unstable)/login$ @@ -256,25 +255,28 @@ Additionally, the following endpoints should be included if Synapse is configure to use SSO (you only need to include the ones for whichever SSO provider you're using): + # for all SSO providers + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect + ^/_synapse/client/pick_idp$ + ^/_synapse/client/pick_username + ^/_synapse/client/sso_register$ + # OpenID Connect requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ ^/_synapse/oidc/callback$ # SAML requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ ^/_matrix/saml2/authn_response$ # CAS requests. - ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$ ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ -Note that a HTTP listener with `client` and `federation` resources must be -configured in the `worker_listeners` option in the worker config. - -Ensure that all SSO logins go to a single process (usually the main process). +Ensure that all SSO logins go to a single process. For multiple workers not handling the SSO endpoints properly, see [#7530](https://github.com/matrix-org/synapse/issues/7530). +Note that a HTTP listener with `client` and `federation` resources must be +configured in the `worker_listeners` option in the worker config. + #### Load balancing It is possible to run multiple instances of this worker app, with incoming requests diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e60988fa4a..516f2464b4 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -22,6 +22,7 @@ from typing import Dict, Iterable, Optional, Set from typing_extensions import ContextManager from twisted.internet import address +from twisted.web.resource import IResource import synapse import synapse.events @@ -90,9 +91,8 @@ from synapse.replication.tcp.streams import ( ToDeviceStream, ) from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, room +from synapse.rest.client.v1 import events, login, room from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.login import LoginRestServlet from synapse.rest.client.v1.profile import ( ProfileAvatarURLRestServlet, ProfileDisplaynameRestServlet, @@ -127,6 +127,7 @@ from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer, cache_in_self from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore @@ -507,7 +508,7 @@ class GenericWorkerServer(HomeServer): site_tag = port # We always include a health resource. - resources = {"/health": HealthResource()} + resources = {"/health": HealthResource()} # type: Dict[str, IResource] for res in listener_config.http_options.resources: for name in res.names: @@ -517,7 +518,7 @@ class GenericWorkerServer(HomeServer): resource = JsonResource(self, canonical_json=False) RegisterRestServlet(self).register(resource) - LoginRestServlet(self).register(resource) + login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) KeyQueryServlet(self).register(resource) @@ -557,6 +558,8 @@ class GenericWorkerServer(HomeServer): groups.register_servlets(self, resource) resources.update({CLIENT_API_PREFIX: resource}) + + resources.update(build_synapse_client_resource_tree(self)) elif name == "federation": resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) elif name == "media": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 86d6f73674..244657cb88 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -60,9 +60,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource -from synapse.rest.synapse.client.sso_register import SsoRegisterResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -191,22 +189,10 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), - "/_synapse/client/pick_username": pick_username_resource(self), - "/_synapse/client/pick_idp": PickIdpResource(self), - "/_synapse/client/sso_register": SsoRegisterResource(self), + **build_synapse_client_resource_tree(self), } ) - if self.get_config().oidc_enabled: - from synapse.rest.oidc import OIDCResource - - resources["/_synapse/oidc"] = OIDCResource(self) - - if self.get_config().saml2_enabled: - from synapse.rest.saml2 import SAML2Resource - - resources["/_matrix/saml2"] = SAML2Resource(self) - if self.get_config().threepid_behaviour_email == ThreepidBehaviour.LOCAL: from synapse.rest.synapse.client.password_reset import ( PasswordResetSubmitTokenResource, diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index c0b733488b..6acbc03d73 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,3 +12,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from typing import TYPE_CHECKING, Mapping + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.pick_idp import PickIdpResource +from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]: + """Builds a resource tree to include synapse-specific client resources + + These are resources which should be loaded on all workers which expose a C-S API: + ie, the main process, and any generic workers so configured. + + Returns: + map from path to Resource. + """ + resources = { + # SSO bits. These are always loaded, whether or not SSO login is actually + # enabled (they just won't work very well if it's not) + "/_synapse/client/pick_idp": PickIdpResource(hs), + "/_synapse/client/pick_username": pick_username_resource(hs), + "/_synapse/client/sso_register": SsoRegisterResource(hs), + } + + # provider-specific SSO bits. Only load these if they are enabled, since they + # rely on optional dependencies. + if hs.config.oidc_enabled: + from synapse.rest.oidc import OIDCResource + + resources["/_synapse/oidc"] = OIDCResource(hs) + + if hs.config.saml2_enabled: + from synapse.rest.saml2 import SAML2Resource + + # This is mounted under '/_matrix' for backwards-compatibility. + resources["/_matrix/saml2"] = SAML2Resource(hs) + + return resources + + +__all__ = ["build_synapse_client_resource_tree"] diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 8d05288ed4..14c0878d81 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -443,6 +443,26 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) + async def record_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Record a mapping from an external user id to a mxid + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_insert( + table="user_external_ids", + values={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="record_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: @@ -1371,26 +1391,6 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - async def record_user_external_id( - self, auth_provider: str, external_id: str, user_id: str - ) -> None: - """Record a mapping from an external user id to a mxid - - Args: - auth_provider: identifier for the remote auth provider - external_id: id on that system - user_id: complete mxid that it is mapped to - """ - await self.db_pool.simple_insert( - table="user_external_ids", - values={ - "auth_provider": auth_provider, - "external_id": external_id, - "user_id": user_id, - }, - desc="record_user_external_id", - ) - async def user_set_password_hash( self, user_id: str, password_hash: Optional[str] ) -> None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index f01215ed1c..ded22a9767 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -29,9 +29,7 @@ from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet -from synapse.rest.synapse.client.pick_idp import PickIdpResource -from synapse.rest.synapse.client.pick_username import pick_username_resource -from synapse.rest.synapse.client.sso_register import SsoRegisterResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester from tests import unittest @@ -424,11 +422,8 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_get_login_flows(self): @@ -1212,12 +1207,8 @@ class UsernamePickerTestCase(HomeserverTestCase): return config def create_resource_dict(self) -> Dict[str, Resource]: - from synapse.rest.oidc import OIDCResource - d = super().create_resource_dict() - d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) - d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs) - d["/_synapse/oidc"] = OIDCResource(self.hs) + d.update(build_synapse_client_resource_tree(self.hs)) return d def test_username_picker(self): diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index a6488a3d29..3f50c56745 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -22,7 +22,7 @@ from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register -from synapse.rest.oidc import OIDCResource +from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID from tests import unittest @@ -173,9 +173,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def create_resource_dict(self): resource_dict = super().create_resource_dict() - if HAS_OIDC: - # mount the OIDC resource at /_synapse/oidc - resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + resource_dict.update(build_synapse_client_resource_tree(self.hs)) return resource_dict def prepare(self, reactor, clock, hs): -- cgit 1.5.1 From 43dd93bb262c8fa7b6c201013891ef540c331682 Mon Sep 17 00:00:00 2001 From: Jan Christian Grünhage Date: Mon, 1 Feb 2021 18:06:22 +0100 Subject: Add phone home stats for encrypted messages. (#9283) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jan Christian Grünhage --- changelog.d/9283.feature | 1 + synapse/app/phone_stats_home.py | 9 +++-- synapse/storage/databases/main/metrics.py | 56 +++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9283.feature (limited to 'synapse/storage/databases') diff --git a/changelog.d/9283.feature b/changelog.d/9283.feature new file mode 100644 index 0000000000..54f133a064 --- /dev/null +++ b/changelog.d/9283.feature @@ -0,0 +1 @@ +Add phone home stats for encrypted messages. diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index c38cf8231f..8f86cecb76 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): stats["daily_active_users"] = await hs.get_datastore().count_daily_users() stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms() + stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms + stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages() + daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages() + stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() stats["daily_messages"] = await hs.get_datastore().count_daily_messages() + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages r30_results = await hs.get_datastore().count_r30_users() for name, count in r30_results.items(): stats["r30_users_" + name] = count - daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index ab18cc4d79..92e65aa640 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (x[0] - 1) * x[1] for x in res if x[1] ) + async def count_daily_e2ee_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + + def _count_messages(txn): + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.encrypted' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) + + async def count_daily_sent_e2ee_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.encrypted' + AND sender LIKE ? + AND stream_ordering > ? + """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction( + "count_daily_sent_e2ee_messages", _count_messages + ) + + async def count_daily_active_e2ee_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.encrypted' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction( + "count_daily_active_e2ee_rooms", _count + ) + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. -- cgit 1.5.1