From 4f3096d866a9810b1c982669d9567fe47b2db73f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Mon, 21 Sep 2020 12:34:06 +0100 Subject: Add a comment re #1691 --- synapse/crypto/context_factory.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 2b03f5ac76..79668a402e 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -45,7 +45,11 @@ _TLS_VERSION_MAP = { class ServerContextFactory(ContextFactory): """Factory for PyOpenSSL SSL contexts that are used to handle incoming - connections.""" + connections. + + TODO: replace this with an implementation of IOpenSSLServerConnectionCreator, + per https://github.com/matrix-org/synapse/issues/1691 + """ def __init__(self, config): # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version, -- cgit 1.5.1 From 37ca5924bddccc37521798236339b539677d101f Mon Sep 17 00:00:00 2001 From: Dionysis Grigoropoulos Date: Tue, 22 Sep 2020 13:42:55 +0300 Subject: Create function to check for long names in devices (#8364) * Create a new function to verify that the length of a device name is under a certain threshold. * Refactor old code and tests to use said function. * Verify device name length during registration of device * Add a test for the above Signed-off-by: Dionysis Grigoropoulos --- changelog.d/8364.bugfix | 2 ++ synapse/handlers/device.py | 30 ++++++++++++++++++++++++------ tests/handlers/test_device.py | 11 +++++++++++ tests/rest/admin/test_device.py | 2 +- 4 files changed, 38 insertions(+), 7 deletions(-) create mode 100644 changelog.d/8364.bugfix diff --git a/changelog.d/8364.bugfix b/changelog.d/8364.bugfix new file mode 100644 index 0000000000..7b82cbc388 --- /dev/null +++ b/changelog.d/8364.bugfix @@ -0,0 +1,2 @@ +Fix a bug where during device registration the length of the device name wasn't +limited. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 55a9787439..4149520d6c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional from synapse.api import errors from synapse.api.constants import EventTypes from synapse.api.errors import ( + Codes, FederationDeniedError, HttpResponseException, RequestSendFailed, @@ -265,6 +266,24 @@ class DeviceHandler(DeviceWorkerHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) + def _check_device_name_length(self, name: str): + """ + Checks whether a device name is longer than the maximum allowed length. + + Args: + name: The name of the device. + + Raises: + SynapseError: if the device name is too long. + """ + if name and len(name) > MAX_DEVICE_DISPLAY_NAME_LEN: + raise SynapseError( + 400, + "Device display name is too long (max %i)" + % (MAX_DEVICE_DISPLAY_NAME_LEN,), + errcode=Codes.TOO_LARGE, + ) + async def check_device_registered( self, user_id, device_id, initial_device_display_name=None ): @@ -282,6 +301,9 @@ class DeviceHandler(DeviceWorkerHandler): Returns: str: device id (generated if none was supplied) """ + + self._check_device_name_length(initial_device_display_name) + if device_id is not None: new_device = await self.store.store_device( user_id=user_id, @@ -397,12 +419,8 @@ class DeviceHandler(DeviceWorkerHandler): # Reject a new displayname which is too long. new_display_name = content.get("display_name") - if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN: - raise SynapseError( - 400, - "Device display name is too long (max %i)" - % (MAX_DEVICE_DISPLAY_NAME_LEN,), - ) + + self._check_device_name_length(new_display_name) try: await self.store.update_device( diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 6aa322bf3a..969d44c787 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -35,6 +35,17 @@ class DeviceTestCase(unittest.HomeserverTestCase): # These tests assume that it starts 1000 seconds in. self.reactor.advance(1000) + def test_device_is_created_with_invalid_name(self): + self.get_failure( + self.handler.check_device_registered( + user_id="@boris:foo", + device_id="foo", + initial_device_display_name="a" + * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1), + ), + synapse.api.errors.SynapseError, + ) + def test_device_is_created_if_doesnt_exist(self): res = self.get_success( self.handler.check_device_registered( diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index faa7f381a9..92c9058887 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -221,7 +221,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual(Codes.TOO_LARGE, channel.json_body["errcode"]) # Ensure the display name was not updated. request, channel = self.make_request( -- cgit 1.5.1 From 55bb5fda339f8ec232e8b2a65df01f1597e594ee Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 22 Sep 2020 15:18:31 +0100 Subject: 1.20.0 --- CHANGES.md | 6 ++++++ debian/changelog | 8 ++++++-- synapse/__init__.py | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 84976ab2bd..5a846daa4d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +Synapse 1.20.0 (2020-09-22) +=========================== + +No significant changes. + + Synapse 1.20.0rc5 (2020-09-18) ============================== diff --git a/debian/changelog b/debian/changelog index dbf01d6b1e..ae548f9f33 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,8 +1,12 @@ -matrix-synapse-py3 (1.20.0ubuntu1) UNRELEASED; urgency=medium +matrix-synapse-py3 (1.20.0) stable; urgency=medium + [ Synapse Packaging team ] + * New synapse release 1.20.0. + + [ Dexter Chua ] * Use Type=notify in systemd service - -- Dexter Chua Wed, 26 Aug 2020 12:41:36 +0000 + -- Synapse Packaging team Tue, 22 Sep 2020 15:19:32 +0100 matrix-synapse-py3 (1.19.3) stable; urgency=medium diff --git a/synapse/__init__.py b/synapse/__init__.py index a95753dcc7..8242d05f60 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.20.0rc5" +__version__ = "1.20.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From 012736ff070573fa15611a178e83421f4998abb7 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 22 Sep 2020 15:30:44 +0100 Subject: Deprecation warning for synapse admin api being accessible under /_matrix --- CHANGES.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 5a846daa4d..bac099e9b5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,7 +1,19 @@ Synapse 1.20.0 (2020-09-22) =========================== -No significant changes. +No significant changes since v1.20.0rc5. + +Removal warning +--------------- + +Historically, the [Synapse Admin +API](https://github.com/matrix-org/synapse/tree/master/docs) has been +accessible under the `/_matrix/client/api/v1/admin`, +`/_matrix/client/unstable/admin`, `/_matrix/client/r0/admin` and +`/_synapse/admin` prefixes. In a future release, we will soon be dropping +support for accessing Synapse's Admin API using the `/_matrix/client/*` +prefixes. This is to help make locking down external access to the Admin API +endpoints easier for homeserver admins. Synapse 1.20.0rc5 (2020-09-18) -- cgit 1.5.1 From d191dbdaa6b0a00c1148f0464f542c15db973efa Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 22 Sep 2020 15:42:53 +0100 Subject: Fix wording of deprecation notice in changelog --- CHANGES.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index bac099e9b5..84711de448 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,11 +10,10 @@ Historically, the [Synapse Admin API](https://github.com/matrix-org/synapse/tree/master/docs) has been accessible under the `/_matrix/client/api/v1/admin`, `/_matrix/client/unstable/admin`, `/_matrix/client/r0/admin` and -`/_synapse/admin` prefixes. In a future release, we will soon be dropping -support for accessing Synapse's Admin API using the `/_matrix/client/*` -prefixes. This is to help make locking down external access to the Admin API -endpoints easier for homeserver admins. - +`/_synapse/admin` prefixes. In a future release, we will be dropping support +for accessing Synapse's Admin API using the `/_matrix/client/*` prefixes. This +makes it easier for homeserver admins to lock down external access to the Admin +API endpoints. Synapse 1.20.0rc5 (2020-09-18) ============================== -- cgit 1.5.1 From 4da01f9c614f36a293235d6a1fd3602d550f2001 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 22 Sep 2020 19:15:04 +0200 Subject: Admin API for reported events (#8217) Add an admin API to read entries of table `event_reports`. API: `GET /_synapse/admin/v1/event_reports` --- changelog.d/8217.feature | 1 + docs/admin_api/event_reports.rst | 129 +++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/event_reports.py | 88 ++++++++ synapse/storage/databases/main/room.py | 95 ++++++++ tests/rest/admin/test_event_reports.py | 382 +++++++++++++++++++++++++++++++++ 6 files changed, 697 insertions(+) create mode 100644 changelog.d/8217.feature create mode 100644 docs/admin_api/event_reports.rst create mode 100644 synapse/rest/admin/event_reports.py create mode 100644 tests/rest/admin/test_event_reports.py diff --git a/changelog.d/8217.feature b/changelog.d/8217.feature new file mode 100644 index 0000000000..899cbf14ef --- /dev/null +++ b/changelog.d/8217.feature @@ -0,0 +1 @@ +Add an admin API `GET /_synapse/admin/v1/event_reports` to read entries of table `event_reports`. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/event_reports.rst b/docs/admin_api/event_reports.rst new file mode 100644 index 0000000000..461be01230 --- /dev/null +++ b/docs/admin_api/event_reports.rst @@ -0,0 +1,129 @@ +Show reported events +==================== + +This API returns information about reported events. + +The api is:: + + GET /_synapse/admin/v1/event_reports?from=0&limit=10 + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst `_. + +It returns a JSON body like the following: + +.. code:: jsonc + + { + "event_reports": [ + { + "content": { + "reason": "foo", + "score": -100 + }, + "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY", + "event_json": { + "auth_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M", + "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws" + ], + "content": { + "body": "matrix.org: This Week in Matrix", + "format": "org.matrix.custom.html", + "formatted_body": "matrix.org:
This Week in Matrix", + "msgtype": "m.notice" + }, + "depth": 546, + "hashes": { + "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw" + }, + "origin": "matrix.org", + "origin_server_ts": 1592291711430, + "prev_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M" + ], + "prev_state": [], + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "sender": "@foobar:matrix.org", + "signatures": { + "matrix.org": { + "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg" + } + }, + "type": "m.room.message", + "unsigned": { + "age_ts": 1592291711430, + } + }, + "id": 2, + "reason": "foo", + "received_ts": 1570897107409, + "room_alias": "#alias1:matrix.org", + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "sender": "@foobar:matrix.org", + "user_id": "@foo:matrix.org" + }, + { + "content": { + "reason": "bar", + "score": -100 + }, + "event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I", + "event_json": { + // hidden items + // see above + }, + "id": 3, + "reason": "bar", + "received_ts": 1598889612059, + "room_alias": "#alias2:matrix.org", + "room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org", + "sender": "@foobar:matrix.org", + "user_id": "@bar:matrix.org" + } + ], + "next_token": 2, + "total": 4 + } + +To paginate, check for ``next_token`` and if present, call the endpoint again +with ``from`` set to the value of ``next_token``. This will return a new page. + +If the endpoint does not return a ``next_token`` then there are no more +reports to paginate through. + +**URL parameters:** + +- ``limit``: integer - Is optional but is used for pagination, + denoting the maximum number of items to return in this call. Defaults to ``100``. +- ``from``: integer - Is optional but used for pagination, + denoting the offset in the returned results. This should be treated as an opaque value and + not explicitly set to anything other than the return value of ``next_token`` from a previous call. + Defaults to ``0``. +- ``dir``: string - Direction of event report order. Whether to fetch the most recent first (``b``) or the + oldest first (``f``). Defaults to ``b``. +- ``user_id``: string - Is optional and filters to only return users with user IDs that contain this value. + This is the user who reported the event and wrote the reason. +- ``room_id``: string - Is optional and filters to only return rooms with room IDs that contain this value. + +**Response** + +The following fields are returned in the JSON response body: + +- ``id``: integer - ID of event report. +- ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent. +- ``room_id``: string - The ID of the room in which the event being reported is located. +- ``event_id``: string - The ID of the reported event. +- ``user_id``: string - This is the user who reported the event and wrote the reason. +- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. +- ``content``: object - Content of reported event. + + - ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. + - ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive". + +- ``sender``: string - This is the ID of the user who sent the original message/event that was reported. +- ``room_alias``: string - The alias of the room. ``null`` if the room does not have a canonical alias set. +- ``event_json``: object - Details of the original event that was reported. +- ``next_token``: integer - Indication for pagination. See above. +- ``total``: integer - Total number of event reports related to the query (``user_id`` and ``room_id``). + diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 4a75c06480..5c5f00b213 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -31,6 +31,7 @@ from synapse.rest.admin.devices import ( DeviceRestServlet, DevicesRestServlet, ) +from synapse.rest.admin.event_reports import EventReportsRestServlet from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -216,6 +217,7 @@ def register_servlets(hs, http_server): DeviceRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server) + EventReportsRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py new file mode 100644 index 0000000000..5b8d0594cd --- /dev/null +++ b/synapse/rest/admin/event_reports.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin + +logger = logging.getLogger(__name__) + + +class EventReportsRestServlet(RestServlet): + """ + List all reported events that are known to the homeserver. Results are returned + in a dictionary containing report information. Supports pagination. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/event_reports + returns: + 200 OK with list of reports if success otherwise an error. + + Args: + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `dir` can be used to define the order of results. + The parameter `user_id` can be used to filter by user id. + The parameter `room_id` can be used to filter by room id. + Returns: + A list of reported events and an integer representing the total number of + reported events that exist given this query + """ + + PATTERNS = admin_patterns("/event_reports$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request): + await assert_requester_is_admin(self.auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + direction = parse_string(request, "dir", default="b") + user_id = parse_string(request, "user_id") + room_id = parse_string(request, "room_id") + + if start < 0: + raise SynapseError( + 400, + "The start parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "The limit parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if direction not in ("f", "b"): + raise SynapseError( + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + ) + + event_reports, total = await self.store.get_event_reports_paginate( + start, limit, direction, user_id, room_id + ) + ret = {"event_reports": event_reports, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(event_reports) + + return 200, ret diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index bd6f9553c6..3ee097abf7 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_reports_paginate( + self, + start: int, + limit: int, + direction: str = "b", + user_id: Optional[str] = None, + room_id: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], int]: + """Retrieve a paginated list of event reports + + Args: + start: event offset to begin the query from + limit: number of rows to retrieve + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`) + user_id: search for user_id. Ignored if user_id is None + room_id: search for room_id. Ignored if room_id is None + Returns: + event_reports: json list of event reports + count: total number of event reports matching the filter criteria + """ + + def _get_event_reports_paginate_txn(txn): + filters = [] + args = [] + + if user_id: + filters.append("er.user_id LIKE ?") + args.extend(["%" + user_id + "%"]) + if room_id: + filters.append("er.room_id LIKE ?") + args.extend(["%" + room_id + "%"]) + + if direction == "b": + order = "DESC" + else: + order = "ASC" + + where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" + + sql = """ + SELECT COUNT(*) as total_event_reports + FROM event_reports AS er + {} + """.format( + where_clause + ) + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.reason, + er.content, + events.sender, + room_aliases.room_alias, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN room_aliases + ON room_aliases.room_id = er.room_id + JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + {where_clause} + ORDER BY er.received_ts {order} + LIMIT ? + OFFSET ? + """.format( + where_clause=where_clause, order=order, + ) + + args += [limit, start] + txn.execute(sql, args) + event_reports = self.db_pool.cursor_to_dict(txn) + + if count > 0: + for row in event_reports: + try: + row["content"] = db_to_json(row["content"]) + row["event_json"] = db_to_json(row["event_json"]) + except Exception: + continue + + return event_reports, count + + return await self.db_pool.runInteraction( + "get_event_reports_paginate", _get_event_reports_paginate_txn + ) + def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py new file mode 100644 index 0000000000..bf79086f78 --- /dev/null +++ b/tests/rest/admin/test_event_reports.py @@ -0,0 +1,382 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import report_event + +from tests import unittest + + +class EventReportsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id1 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) + + self.room_id2 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id2, user=self.admin_user, tok=self.admin_user_tok) + + # Two rooms and two users. Every user sends and reports every room event + for i in range(5): + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.other_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id2, user_tok=self.other_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.admin_user_tok, + ) + for i in range(5): + self._create_event_and_report( + room_id=self.room_id2, user_tok=self.admin_user_tok, + ) + + self.url = "/_synapse/admin/v1/event_reports" + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self): + """ + Testing list of reported events + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + def test_limit(self): + """ + Testing list of reported events with limit + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["event_reports"]) + + def test_from(self): + """ + Testing list of reported events with a defined starting point (from) + """ + + request, channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + def test_limit_and_from(self): + """ + Testing list of reported events with a defined starting point and limit + """ + + request, channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(channel.json_body["next_token"], 15) + self.assertEqual(len(channel.json_body["event_reports"]), 10) + self._check_fields(channel.json_body["event_reports"]) + + def test_filter_room(self): + """ + Testing list of reported events with a filter of room + """ + + request, channel = self.make_request( + "GET", + self.url + "?room_id=%s" % self.room_id1, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["event_reports"]), 10) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["room_id"], self.room_id1) + + def test_filter_user(self): + """ + Testing list of reported events with a filter of user + """ + + request, channel = self.make_request( + "GET", + self.url + "?user_id=%s" % self.other_user, + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["event_reports"]), 10) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["user_id"], self.other_user) + + def test_filter_user_and_room(self): + """ + Testing list of reported events with a filter of user and room + """ + + request, channel = self.make_request( + "GET", + self.url + "?user_id=%s&room_id=%s" % (self.other_user, self.room_id1), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 5) + self.assertEqual(len(channel.json_body["event_reports"]), 5) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["event_reports"]) + + for report in channel.json_body["event_reports"]: + self.assertEqual(report["user_id"], self.other_user) + self.assertEqual(report["room_id"], self.room_id1) + + def test_valid_search_order(self): + """ + Testing search order. Order by timestamps. + """ + + # fetch the most recent first, largest timestamp + request, channel = self.make_request( + "GET", self.url + "?dir=b", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + report = 1 + while report < len(channel.json_body["event_reports"]): + self.assertGreaterEqual( + channel.json_body["event_reports"][report - 1]["received_ts"], + channel.json_body["event_reports"][report]["received_ts"], + ) + report += 1 + + # fetch the oldest first, smallest timestamp + request, channel = self.make_request( + "GET", self.url + "?dir=f", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + report = 1 + while report < len(channel.json_body["event_reports"]): + self.assertLessEqual( + channel.json_body["event_reports"][report - 1]["received_ts"], + channel.json_body["event_reports"][report]["received_ts"], + ) + report += 1 + + def test_invalid_search_order(self): + """ + Testing that a invalid search order returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?dir=bar", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual("Unknown direction: bar", channel.json_body["error"]) + + def test_limit_is_negative(self): + """ + Testing that a negative list parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self): + """ + Testing that a negative from parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + # `next_token` does not appear + # Number of results is the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 20) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 19) + self.assertEqual(channel.json_body["next_token"], 19) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + request, channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], 20) + self.assertEqual(len(channel.json_body["event_reports"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def _create_event_and_report(self, room_id, user_tok): + """Create and report events + """ + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + request, channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({"score": -100, "reason": "this makes me sad"}), + access_token=user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def _check_fields(self, content): + """Checks that all attributes are present in a event report + """ + for c in content: + self.assertIn("id", c) + self.assertIn("received_ts", c) + self.assertIn("room_id", c) + self.assertIn("event_id", c) + self.assertIn("user_id", c) + self.assertIn("reason", c) + self.assertIn("content", c) + self.assertIn("sender", c) + self.assertIn("room_alias", c) + self.assertIn("event_json", c) + self.assertIn("score", c["content"]) + self.assertIn("reason", c["content"]) + self.assertIn("auth_events", c["event_json"]) + self.assertIn("type", c["event_json"]) + self.assertIn("room_id", c["event_json"]) + self.assertIn("sender", c["event_json"]) + self.assertIn("content", c["event_json"]) -- cgit 1.5.1 From 8998217540bc41975e64e44c507632361ca95698 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 22 Sep 2020 19:19:01 +0200 Subject: Fixed a bug with reactivating users with the admin API (#8362) Fixes: #8359 Trying to reactivate a user with the admin API (`PUT /_synapse/admin/v2/users/`) causes an internal server error. Seems to be a regression in #8033. --- changelog.d/8362.bugfix | 1 + synapse/storage/databases/main/user_erasure_store.py | 2 +- tests/rest/admin/test_user.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8362.bugfix diff --git a/changelog.d/8362.bugfix b/changelog.d/8362.bugfix new file mode 100644 index 0000000000..4e50067c87 --- /dev/null +++ b/changelog.d/8362.bugfix @@ -0,0 +1 @@ +Fixed a regression in v1.19.0 with reactivating users through the admin API. diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index 2f7c95fc74..f9575b1f1f 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore): return # They are there, delete them. - self.simple_delete_one_txn( + self.db_pool.simple_delete_one_txn( txn, "erased_users", keyvalues={"user_id": user_id} ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index f96011fc1c..98d0623734 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -874,6 +874,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self._is_erased("@user:test", False) + d = self.store.mark_user_erased("@user:test") + self.assertIsNone(self.get_success(d)) + self._is_erased("@user:test", True) # Attempt to reactivate the user (without a password). request, channel = self.make_request( @@ -906,6 +910,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(False, channel.json_body["deactivated"]) + self._is_erased("@user:test", False) def test_set_user_as_admin(self): """ @@ -996,6 +1001,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) + def _is_erased(self, user_id, expect): + """Assert that the user is erased or not + """ + d = self.store.is_user_erased(user_id) + if expect: + self.assertTrue(self.get_success(d)) + else: + self.assertFalse(self.get_success(d)) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From 4325be1a52b9054a2c1096dcdb29ee79d9ad4ead Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 22 Sep 2020 19:39:29 +0100 Subject: Fix missing null character check on guest_access room state When updating room_stats_state, we try to check for null bytes slipping in to the content for state events. It turns out we had added guest_access as a field to room_stats_state without including it in the null byte check. Lo and behold, a null byte in a m.room.guest_access event then breaks room_stats_state updates. This PR adds the check for guest_access. A further PR will improve this function so that this hopefully does not happen again in future. --- synapse/storage/databases/main/stats.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index d7816a8606..5beb302be3 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -210,6 +210,7 @@ class StatsStore(StateDeltasStore): * topic * avatar * canonical_alias + * guest_access A is_federatable key can also be included with a boolean value. @@ -234,6 +235,7 @@ class StatsStore(StateDeltasStore): "topic", "avatar", "canonical_alias", + "guest_access", ): field = fields.get(col, sentinel) if field is not sentinel and (not isinstance(field, str) or "\0" in field): -- cgit 1.5.1 From 48336eeb85457e356a7a23619776dc598ebd2189 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 22 Sep 2020 14:54:23 +0100 Subject: Changelog --- changelog.d/8373.bugfix | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/8373.bugfix diff --git a/changelog.d/8373.bugfix b/changelog.d/8373.bugfix new file mode 100644 index 0000000000..e9d66a2088 --- /dev/null +++ b/changelog.d/8373.bugfix @@ -0,0 +1 @@ +Include `guest_access` in the fields that are checked for null bytes when updating `room_stats_state`. Broke in v1.7.2. \ No newline at end of file -- cgit 1.5.1 From a4e63e5a47a855884ae3aea41dfbfa464bddb744 Mon Sep 17 00:00:00 2001 From: Julian Fietkau <1278511+jfietkau@users.noreply.github.com> Date: Wed, 23 Sep 2020 12:14:08 +0200 Subject: Add note to reverse_proxy.md about disabling Apache's mod_security2 (#8375) This change adds a note and a few lines of configuration settings for Apache users to disable ModSecurity for Synapse's virtual hosts. With ModSecurity enabled and running with its default settings, Matrix clients are unable to send chat messages through the Synapse installation. With this change, ModSecurity can be disabled only for the Synapse virtual hosts. --- changelog.d/8375.doc | 1 + docs/reverse_proxy.md | 8 ++++++++ 2 files changed, 9 insertions(+) create mode 100644 changelog.d/8375.doc diff --git a/changelog.d/8375.doc b/changelog.d/8375.doc new file mode 100644 index 0000000000..d291fb92fa --- /dev/null +++ b/changelog.d/8375.doc @@ -0,0 +1 @@ +Add note to the reverse proxy settings documentation about disabling Apache's mod_security2. Contributed by Julian Fietkau (@jfietkau). diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index edd109fa7b..46d8f35771 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -121,6 +121,14 @@ example.com:8448 { **NOTE**: ensure the `nocanon` options are included. +**NOTE 2**: It appears that Synapse is currently incompatible with the ModSecurity module for Apache (`mod_security2`). If you need it enabled for other services on your web server, you can disable it for Synapse's two VirtualHosts by including the following lines before each of the two `` above: + +``` + + SecRuleEngine off + +``` + ### HAProxy ``` -- cgit 1.5.1 From bbde4038dff379fdf48b914782a73a6889135a56 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 23 Sep 2020 06:45:37 -0400 Subject: Do not check lint/test dependencies at runtime. (#8377) moves non-runtime dependencies out of synapse.python_dependencies (test and lint) --- changelog.d/8330.misc | 2 +- changelog.d/8377.misc | 1 + setup.py | 16 ++++++++++++++++ synapse/python_dependencies.py | 13 ++++--------- tox.ini | 8 +++----- 5 files changed, 25 insertions(+), 15 deletions(-) create mode 100644 changelog.d/8377.misc diff --git a/changelog.d/8330.misc b/changelog.d/8330.misc index c51370f215..fbfdd52473 100644 --- a/changelog.d/8330.misc +++ b/changelog.d/8330.misc @@ -1 +1 @@ -Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. \ No newline at end of file +Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. diff --git a/changelog.d/8377.misc b/changelog.d/8377.misc new file mode 100644 index 0000000000..fbfdd52473 --- /dev/null +++ b/changelog.d/8377.misc @@ -0,0 +1 @@ +Move lint-related dependencies to package-extra field, update CONTRIBUTING.md to utilise this. diff --git a/setup.py b/setup.py index 54ddec8f9f..926b1bc86f 100755 --- a/setup.py +++ b/setup.py @@ -94,6 +94,22 @@ ALL_OPTIONAL_REQUIREMENTS = dependencies["ALL_OPTIONAL_REQUIREMENTS"] # Make `pip install matrix-synapse[all]` install all the optional dependencies. CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS) +# Developer dependencies should not get included in "all". +# +# We pin black so that our tests don't start failing on new releases. +CONDITIONAL_REQUIREMENTS["lint"] = [ + "isort==5.0.3", + "black==19.10b0", + "flake8-comprehensions", + "flake8", +] + +# Dependencies which are exclusively required by unit test code. This is +# NOT a list of all modules that are necessary to run the unit tests. +# Tests assume that all optional dependencies are installed. +# +# parameterized_class decorator was introduced in parameterized 0.7.0 +CONDITIONAL_REQUIREMENTS["test"] = ["mock>=2.0", "parameterized>=0.7.0"] setup( name="matrix-synapse", diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 67f019fd22..288631477e 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -37,6 +37,9 @@ logger = logging.getLogger(__name__) # installed when that optional dependency requirement is specified. It is passed # to setup() as extras_require in setup.py # +# Note that these both represent runtime dependencies (and the versions +# installed are checked at runtime). +# # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. REQUIREMENTS = [ @@ -92,20 +95,12 @@ CONDITIONAL_REQUIREMENTS = { "oidc": ["authlib>=0.14.0"], "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], - # Dependencies which are exclusively required by unit test code. This is - # NOT a list of all modules that are necessary to run the unit tests. - # Tests assume that all optional dependencies are installed. - # - # parameterized_class decorator was introduced in parameterized 0.7.0 - "test": ["mock>=2.0", "parameterized>=0.7.0"], "sentry": ["sentry-sdk>=0.7.2"], "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"], "jwt": ["pyjwt>=1.6.4"], # hiredis is not a *strict* dependency, but it makes things much faster. # (if it is not installed, we fall back to slow code.) "redis": ["txredisapi>=1.4.7", "hiredis"], - # We pin black so that our tests don't start failing on new releases. - "lint": ["isort==5.0.3", "black==19.10b0", "flake8-comprehensions", "flake8"], } ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] @@ -113,7 +108,7 @@ ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): # Exclude systemd as it's a system-based requirement. # Exclude lint as it's a dev-based requirement. - if name not in ["systemd", "lint"]: + if name not in ["systemd"]: ALL_OPTIONAL_REQUIREMENTS = set(optional_deps) | ALL_OPTIONAL_REQUIREMENTS diff --git a/tox.ini b/tox.ini index ddcab0198f..4d132eff4c 100644 --- a/tox.ini +++ b/tox.ini @@ -2,13 +2,12 @@ envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort [base] +extras = test deps = - mock python-subunit junitxml coverage coverage-enable-subprocess - parameterized # cyptography 2.2 requires setuptools >= 18.5 # @@ -36,7 +35,7 @@ setenv = [testenv] deps = {[base]deps} -extras = all +extras = all, test whitelist_externals = sh @@ -84,7 +83,6 @@ deps = # Old automat version for Twisted Automat == 0.3.0 - mock lxml coverage coverage-enable-subprocess @@ -97,7 +95,7 @@ commands = /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "s/psycopg2==2.6//" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install' # Install Synapse itself. This won't update any libraries. - pip install -e . + pip install -e ".[test]" {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} -- cgit 1.5.1 From 916bb9d0d15cf941e73b2e808c553a1edd1c2eb9 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Wed, 23 Sep 2020 17:06:28 +0200 Subject: Don't push if an user account has expired (#8353) --- changelog.d/8353.bugfix | 1 + synapse/api/auth.py | 6 +----- synapse/push/pusherpool.py | 18 ++++++++++++++++++ synapse/storage/databases/main/registration.py | 14 ++++++++++++++ 4 files changed, 34 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8353.bugfix diff --git a/changelog.d/8353.bugfix b/changelog.d/8353.bugfix new file mode 100644 index 0000000000..45fc0adb8d --- /dev/null +++ b/changelog.d/8353.bugfix @@ -0,0 +1 @@ +Don't send push notifications to expired user accounts. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 75388643ee..1071a0576e 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -218,11 +218,7 @@ class Auth: # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: user_id = user.to_string() - expiration_ts = await self.store.get_expiration_ts_for_user(user_id) - if ( - expiration_ts is not None - and self.clock.time_msec() >= expiration_ts - ): + if await self.store.is_account_expired(user_id, self.clock.time_msec()): raise AuthError( 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index cc839ffce4..76150e117b 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -60,6 +60,8 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() + self._account_validity = hs.config.account_validity + # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() @@ -202,6 +204,14 @@ class PusherPool: ) for u in users_affected: + # Don't push if the user account has expired + if self._account_validity.enabled: + expired = await self.store.is_account_expired( + u, self.clock.time_msec() + ) + if expired: + continue + if u in self.pushers: for p in self.pushers[u].values(): p.on_new_notifications(max_stream_id) @@ -222,6 +232,14 @@ class PusherPool: ) for u in users_affected: + # Don't push if the user account has expired + if self._account_validity.enabled: + expired = await self.store.is_account_expired( + u, self.clock.time_msec() + ) + if expired: + continue + if u in self.pushers: for p in self.pushers[u].values(): p.on_new_receipts(min_stream_id, max_stream_id) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 675e81fe34..33825e8949 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_expiration_ts_for_user", ) + async def is_account_expired(self, user_id: str, current_ts: int) -> bool: + """ + Returns whether an user account is expired. + + Args: + user_id: The user's ID + current_ts: The current timestamp + + Returns: + Whether the user account has expired + """ + expiration_ts = await self.get_expiration_ts_for_user(user_id) + return expiration_ts is not None and current_ts >= expiration_ts + async def set_account_validity_for_user( self, user_id: str, -- cgit 1.5.1 From cbabb312e0b59090e5a8cf9e7e016a8618e62867 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 23 Sep 2020 16:11:18 +0100 Subject: Use `async with` for ID gens (#8383) This will allow us to hit the DB after we've finished using the generated stream ID. --- changelog.d/8383.misc | 1 + synapse/storage/databases/main/account_data.py | 4 +- synapse/storage/databases/main/deviceinbox.py | 4 +- synapse/storage/databases/main/devices.py | 6 +- synapse/storage/databases/main/end_to_end_keys.py | 2 +- synapse/storage/databases/main/events.py | 6 +- synapse/storage/databases/main/group_server.py | 2 +- synapse/storage/databases/main/presence.py | 4 +- synapse/storage/databases/main/push_rule.py | 8 +- synapse/storage/databases/main/pusher.py | 4 +- synapse/storage/databases/main/receipts.py | 2 +- synapse/storage/databases/main/room.py | 6 +- synapse/storage/databases/main/tags.py | 4 +- synapse/storage/util/id_generators.py | 130 +++++++++++++--------- tests/storage/test_id_generators.py | 66 ++++++----- 15 files changed, 144 insertions(+), 105 deletions(-) create mode 100644 changelog.d/8383.misc diff --git a/changelog.d/8383.misc b/changelog.d/8383.misc new file mode 100644 index 0000000000..cb8318bf57 --- /dev/null +++ b/changelog.d/8383.misc @@ -0,0 +1 @@ +Refactor ID generators to use `async with` syntax. diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index c5a36990e4..ef81d73573 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index e71217a41f..d42faa3f1f 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with await self._device_inbox_id_gen.get_next() as stream_id: + async with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index c04374e43d..fdf394c612 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with await self._device_list_id_gen.get_next() as stream_id: + async with self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( @@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with await self._device_list_id_gen.get_next_mult( + async with self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c8df0bcb3f..22e1ed15d0 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key (dict): the key data """ - with await self._cross_signing_id_gen.get_next() as stream_id: + async with self._cross_signing_id_gen.get_next() as stream_id: return await self.db_pool.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9a80f419e3..7723d82496 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -156,15 +156,15 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = await self._backfill_id_gen.get_next_mult( + stream_ordering_manager = self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = await self._stream_id_gen.get_next_mult( + stream_ordering_manager = self._stream_id_gen.get_next_mult( len(events_and_contexts) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index ccfbb2135e..7218191965 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with await self._group_updates_id_gen.get_next() as next_id: + async with self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index c9f655dfb7..dbbb99cb95 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = await self._presence_id_gen.get_next_mult( + stream_ordering_manager = self._presence_id_gen.get_next_mult( len(presence_states) ) - with stream_ordering_manager as stream_orderings: + async with stream_ordering_manager as stream_orderings: await self.db_pool.runInteraction( "update_presence", self._update_presence_txn, diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index e20a16f907..711d5aa23d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore): Raises: NotFoundError if the rule does not exist. """ - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( "_set_push_rule_enabled_txn", @@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with await self._push_rules_stream_id_gen.get_next() as stream_id: + async with self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index c388468273..df8609b97b 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with await self._pushers_id_gen.get_next() as stream_id: + async with self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index f880b5e562..c79ddff680 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - with await self._receipts_id_gen.get_next() as stream_id: + async with self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 3ee097abf7..3c7630857f 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with await self._public_room_id_gen.get_next() as next_id: + async with self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 96ffe26cc9..9f120d3cb6 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with await self._account_data_id_gen.get_next() as next_id: + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 1de2b91587..b0353ac2dc 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -12,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import contextlib import heapq import logging import threading from collections import deque -from typing import Dict, List, Set +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union +import attr from typing_extensions import Deque from synapse.storage.database import DatabasePool, LoggingTransaction @@ -86,7 +86,7 @@ class StreamIdGenerator: upwards, -1 to grow downwards. Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -101,10 +101,10 @@ class StreamIdGenerator: ) self._unfinished_ids = deque() # type: Deque[int] - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -113,7 +113,7 @@ class StreamIdGenerator: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_id @@ -121,12 +121,12 @@ class StreamIdGenerator: with self._lock: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) - async def get_next_mult(self, n): + def get_next_mult(self, n): """ Usage: - with await stream_id_gen.get_next(n) as stream_ids: + async with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -140,7 +140,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_ids @@ -149,7 +149,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or @@ -282,59 +282,23 @@ class MultiWriterIdGenerator: def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) - - # Assert the fetched ID is actually greater than what we currently - # believe the ID to be. If not, then the sequence and table have got - # out of sync somehow. - with self._lock: - assert self._current_positions.get(self._instance_name, 0) < next_id - - self._unfinished_ids.add(next_id) - - @contextlib.contextmanager - def manager(): - try: - # Multiply by the return factor so that the ID has correct sign. - yield self._return_factor * next_id - finally: - self._mark_id_as_finished(next_id) - return manager() + return _MultiWriterCtxManager(self) - async def get_next_mult(self, n: int): + def get_next_mult(self, n: int): """ Usage: - with await stream_id_gen.get_next_mult(5) as stream_ids: + async with stream_id_gen.get_next_mult(5) as stream_ids: # ... persist events ... """ - next_ids = await self._db.runInteraction( - "_load_next_mult_id", self._load_next_mult_id_txn, n - ) - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. - with self._lock: - assert max(self._current_positions.values(), default=0) < min(next_ids) - - self._unfinished_ids.update(next_ids) - - @contextlib.contextmanager - def manager(): - try: - yield [self._return_factor * i for i in next_ids] - finally: - for i in next_ids: - self._mark_id_as_finished(i) - - return manager() + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): """ @@ -482,3 +446,61 @@ class MultiWriterIdGenerator: # There was a gap in seen positions, so there is nothing more to # do. break + + +@attr.s(slots=True) +class _AsyncCtxManagerWrapper: + """Helper class to convert a plain context manager to an async one. + + This is mainly useful if you have a plain context manager but the interface + requires an async one. + """ + + inner = attr.ib() + + async def __aenter__(self): + return self.inner.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + return self.inner.__exit__(exc_type, exc, tb) + + +@attr.s(slots=True) +class _MultiWriterCtxManager: + """Async context manager returned by MultiWriterIdGenerator + """ + + id_gen = attr.ib(type=MultiWriterIdGenerator) + multiple_ids = attr.ib(type=Optional[int], default=None) + stream_ids = attr.ib(type=List[int], factory=list) + + async def __aenter__(self) -> Union[int, List[int]]: + self.stream_ids = await self.id_gen._db.runInteraction( + "_load_next_mult_id", + self.id_gen._load_next_mult_id_txn, + self.multiple_ids or 1, + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + with self.id_gen._lock: + assert max(self.id_gen._current_positions.values(), default=0) < min( + self.stream_ids + ) + + self.id_gen._unfinished_ids.update(self.stream_ids) + + if self.multiple_ids is None: + return self.stream_ids[0] * self.id_gen._return_factor + else: + return [i * self.id_gen._return_factor for i in self.stream_ids] + + async def __aexit__(self, exc_type, exc, tb): + for i in self.stream_ids: + self.id_gen._mark_id_as_finished(i) + + if exc_type is not None: + return False + + return False diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 20636fc400..fb8f5bc255 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. async def _get_next_async(): - with await id_gen.get_next() as stream_id: + async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) @@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ctx3 = self.get_success(id_gen.get_next()) ctx4 = self.get_success(id_gen.get_next()) - s1 = ctx1.__enter__() - s2 = ctx2.__enter__() - s3 = ctx3.__enter__() - s4 = ctx4.__enter__() + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + s3 = self.get_success(ctx3.__aenter__()) + s4 = self.get_success(ctx4.__aenter__()) self.assertEqual(s1, 8) self.assertEqual(s2, 9) @@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx2.__exit__(None, None, None) + self.get_success(ctx2.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx1.__exit__(None, None, None) + self.get_success(ctx1.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx4.__exit__(None, None, None) + self.get_success(ctx4.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx3.__exit__(None, None, None) + self.get_success(ctx3.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) @@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. async def _get_next_async(): - with await first_id_gen.get_next() as stream_id: + async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual( @@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # stream ID async def _get_next_async(): - with await second_id_gen.get_next() as stream_id: + async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) self.assertEqual( @@ -305,9 +305,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 3) - with self.get_success(id_gen.get_next()) as stream_id: - self.assertEqual(stream_id, 6) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_persisted_upto_position(), 6) @@ -373,16 +377,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ id_gen = self._create_id_generator() - with self.get_success(id_gen.get_next()) as stream_id: - self._insert_row("master", stream_id) + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": -1}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - with self.get_success(id_gen.get_next_mult(3)) as stream_ids: - for stream_id in stream_ids: - self._insert_row("master", stream_id) + async def _get_next_async2(): + async with id_gen.get_next_mult(3) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen.get_positions(), {"master": -4}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) @@ -402,18 +412,24 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen_1 = self._create_id_generator("first") id_gen_2 = self._create_id_generator("second") - with self.get_success(id_gen_1.get_next()) as stream_id: - self._insert_row("first", stream_id) - id_gen_2.advance("first", stream_id) + async def _get_next_async(): + async with id_gen_1.get_next() as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen_1.get_positions(), {"first": -1}) self.assertEqual(id_gen_2.get_positions(), {"first": -1}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - with self.get_success(id_gen_2.get_next()) as stream_id: - self._insert_row("second", stream_id) - id_gen_1.advance("second", stream_id) + async def _get_next_async2(): + async with id_gen_2.get_next() as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) -- cgit 1.5.1 From 302dc89f6a16f69e076943cb0a9b94f1e41741f9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Sep 2020 16:42:14 +0100 Subject: Fix bug which caused failure on join with malformed membership events (#8385) --- changelog.d/8385.bugfix | 1 + synapse/storage/databases/main/events.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8385.bugfix diff --git a/changelog.d/8385.bugfix b/changelog.d/8385.bugfix new file mode 100644 index 0000000000..c42502a8e0 --- /dev/null +++ b/changelog.d/8385.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause errors in rooms with malformed membership events, on servers using sqlite. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 7723d82496..18def01f50 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -17,7 +17,7 @@ import itertools import logging from collections import OrderedDict, namedtuple -from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple import attr from prometheus_client import Counter @@ -1108,6 +1108,10 @@ class PersistEventsStore: def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ + + def str_or_none(val: Any) -> Optional[str]: + return val if isinstance(val, str) else None + self.db_pool.simple_insert_many_txn( txn, table="room_memberships", @@ -1118,8 +1122,8 @@ class PersistEventsStore: "sender": event.user_id, "room_id": event.room_id, "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), + "display_name": str_or_none(event.content.get("displayname")), + "avatar_url": str_or_none(event.content.get("avatar_url")), } for event in events ], -- cgit 1.5.1 From 91c60f304256c08e8aff53ed13d5b282057277d6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Sep 2020 16:42:44 +0100 Subject: Improve logging of state resolution (#8371) I'd like to get a better insight into what we are doing with respect to state res. The list of state groups we are resolving across should be short (if it isn't, that's a massive problem in itself), so it should be fine to log it in ite entiretly. I've done some grepping and found approximately zero cases in which the "shortcut" code delivered the result, so I've ripped that out too. --- changelog.d/8371.misc | 1 + synapse/state/__init__.py | 64 ++++++++++++----------------------------------- 2 files changed, 17 insertions(+), 48 deletions(-) create mode 100644 changelog.d/8371.misc diff --git a/changelog.d/8371.misc b/changelog.d/8371.misc new file mode 100644 index 0000000000..6a54a9496a --- /dev/null +++ b/changelog.d/8371.misc @@ -0,0 +1 @@ +Improve logging of state resolution. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 56d6afb863..5a5ea39e01 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -25,7 +25,6 @@ from typing import ( Sequence, Set, Union, - cast, overload, ) @@ -42,7 +41,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import Collection, MutableStateMap, StateMap +from synapse.types import Collection, StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -472,10 +471,9 @@ class StateResolutionHandler: def __init__(self, hs): self.clock = hs.get_clock() - # dict of set of event_ids -> _StateCacheEntry. - self._state_cache = None self.resolve_linearizer = Linearizer(name="state_resolve_lock") + # dict of set of event_ids -> _StateCacheEntry. self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, @@ -519,57 +517,28 @@ class StateResolutionHandler: Returns: The resolved state """ - logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) - group_names = frozenset(state_groups_ids.keys()) with (await self.resolve_linearizer.queue(group_names)): - if self._state_cache is not None: - cache = self._state_cache.get(group_names, None) - if cache: - return cache + cache = self._state_cache.get(group_names, None) + if cache: + return cache logger.info( - "Resolving state for %s with %d groups", room_id, len(state_groups_ids) + "Resolving state for %s with groups %s", room_id, list(group_names), ) state_groups_histogram.observe(len(state_groups_ids)) - # start by assuming we won't have any conflicted state, and build up the new - # state map by iterating through the state groups. If we discover a conflict, - # we give up and instead use `resolve_events_with_store`. - # - # XXX: is this actually worthwhile, or should we just let - # resolve_events_with_store do it? - new_state = {} # type: MutableStateMap[str] - conflicted_state = False - for st in state_groups_ids.values(): - for key, e_id in st.items(): - if key in new_state: - conflicted_state = True - break - new_state[key] = e_id - if conflicted_state: - break - - if conflicted_state: - logger.info("Resolving conflicted state for %r", room_id) - with Measure(self.clock, "state._resolve_events"): - # resolve_events_with_store returns a StateMap, but we can - # treat it as a MutableStateMap as it is above. It isn't - # actually mutated anymore (and is frozen in - # _make_state_cache_entry below). - new_state = cast( - MutableStateMap, - await resolve_events_with_store( - self.clock, - room_id, - room_version, - list(state_groups_ids.values()), - event_map=event_map, - state_res_store=state_res_store, - ), - ) + with Measure(self.clock, "state._resolve_events"): + new_state = await resolve_events_with_store( + self.clock, + room_id, + room_version, + list(state_groups_ids.values()), + event_map=event_map, + state_res_store=state_res_store, + ) # if the new state matches any of the input state groups, we can # use that state group again. Otherwise we will generate a state_id @@ -579,8 +548,7 @@ class StateResolutionHandler: with Measure(self.clock, "state.create_group_ids"): cache = _make_state_cache_entry(new_state, state_groups_ids) - if self._state_cache is not None: - self._state_cache[group_names] = cache + self._state_cache[group_names] = cache return cache -- cgit 1.5.1 From 2983049a77557512519f3856fc88e3bc5f1915ed Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Sep 2020 18:18:43 +0100 Subject: Factor out `_send_dummy_event_for_room` (#8370) this makes it possible to use from the manhole, and seems cleaner anyway. --- changelog.d/8370.misc | 1 + synapse/handlers/message.py | 102 +++++++++++++++++++++++--------------------- 2 files changed, 55 insertions(+), 48 deletions(-) create mode 100644 changelog.d/8370.misc diff --git a/changelog.d/8370.misc b/changelog.d/8370.misc new file mode 100644 index 0000000000..1aaac1e0bf --- /dev/null +++ b/changelog.d/8370.misc @@ -0,0 +1 @@ +Factor out a `_send_dummy_event_for_room` method. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index a8fe5cf4e2..6ee559fd1d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1182,54 +1182,7 @@ class EventCreationHandler: ) for room_id in room_ids: - # For each room we need to find a joined member we can use to send - # the dummy event with. - - latest_event_ids = await self.store.get_prev_events_for_room(room_id) - - members = await self.state.get_current_users_in_room( - room_id, latest_event_ids=latest_event_ids - ) - dummy_event_sent = False - for user_id in members: - if not self.hs.is_mine_id(user_id): - continue - requester = create_requester(user_id) - try: - event, context = await self.create_event( - requester, - { - "type": "org.matrix.dummy_event", - "content": {}, - "room_id": room_id, - "sender": user_id, - }, - prev_event_ids=latest_event_ids, - ) - - event.internal_metadata.proactively_send = False - - # Since this is a dummy-event it is OK if it is sent by a - # shadow-banned user. - await self.send_nonmember_event( - requester, - event, - context, - ratelimit=False, - ignore_shadow_ban=True, - ) - dummy_event_sent = True - break - except ConsentNotGivenError: - logger.info( - "Failed to send dummy event into room %s for user %s due to " - "lack of consent. Will try another user" % (room_id, user_id) - ) - except AuthError: - logger.info( - "Failed to send dummy event into room %s for user %s due to " - "lack of power. Will try another user" % (room_id, user_id) - ) + dummy_event_sent = await self._send_dummy_event_for_room(room_id) if not dummy_event_sent: # Did not find a valid user in the room, so remove from future attempts @@ -1242,6 +1195,59 @@ class EventCreationHandler: now = self.clock.time_msec() self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now + async def _send_dummy_event_for_room(self, room_id: str) -> bool: + """Attempt to send a dummy event for the given room. + + Args: + room_id: room to try to send an event from + + Returns: + True if a dummy event was successfully sent. False if no user was able + to send an event. + """ + + # For each room we need to find a joined member we can use to send + # the dummy event with. + latest_event_ids = await self.store.get_prev_events_for_room(room_id) + members = await self.state.get_current_users_in_room( + room_id, latest_event_ids=latest_event_ids + ) + for user_id in members: + if not self.hs.is_mine_id(user_id): + continue + requester = create_requester(user_id) + try: + event, context = await self.create_event( + requester, + { + "type": "org.matrix.dummy_event", + "content": {}, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=latest_event_ids, + ) + + event.internal_metadata.proactively_send = False + + # Since this is a dummy-event it is OK if it is sent by a + # shadow-banned user. + await self.send_nonmember_event( + requester, event, context, ratelimit=False, ignore_shadow_ban=True, + ) + return True + except ConsentNotGivenError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of consent. Will try another user" % (room_id, user_id) + ) + except AuthError: + logger.info( + "Failed to send dummy event into room %s for user %s due to " + "lack of power. Will try another user" % (room_id, user_id) + ) + return False + def _expire_rooms_to_exclude_from_dummy_event_insertion(self): expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY to_expire = set() -- cgit 1.5.1 From 13099ae4311436b82ae47ca252cac1fa7fa58cc6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 24 Sep 2020 08:13:55 -0400 Subject: Mark the shadow_banned column as boolean in synapse_port_db. (#8386) --- .buildkite/test_db.db | Bin 18825216 -> 19279872 bytes changelog.d/8386.bugfix | 1 + scripts/synapse_port_db | 1 + 3 files changed, 2 insertions(+) create mode 100644 changelog.d/8386.bugfix diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db index f20567ba73..361369a581 100644 Binary files a/.buildkite/test_db.db and b/.buildkite/test_db.db differ diff --git a/changelog.d/8386.bugfix b/changelog.d/8386.bugfix new file mode 100644 index 0000000000..24983a1e95 --- /dev/null +++ b/changelog.d/8386.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index a34bdf1830..ecca8b6e8f 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], + "users": ["shadow_banned"], } -- cgit 1.5.1 From ac11fcbbb8ccfeb4c72b5aae9faef28469109277 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 24 Sep 2020 13:24:17 +0100 Subject: Add EventStreamPosition type (#8388) The idea is to remove some of the places we pass around `int`, where it can represent one of two things: 1. the position of an event in the stream; or 2. a token that partitions the stream, used as part of the stream tokens. The valid operations are then: 1. did a position happen before or after a token; 2. get all events that happened before or after a token; and 3. get all events between two tokens. (Note that we don't want to allow other operations as we want to change the tokens to be vector clocks rather than simple ints) --- changelog.d/8388.misc | 1 + synapse/handlers/federation.py | 16 +++++--- synapse/handlers/message.py | 6 +-- synapse/handlers/sync.py | 10 ++--- synapse/notifier.py | 55 ++++++++++++++------------ synapse/replication/tcp/client.py | 12 ++++-- synapse/storage/databases/main/roommember.py | 14 ++++--- synapse/storage/persist_events.py | 14 ++++--- synapse/storage/roommember.py | 2 +- synapse/types.py | 15 +++++++ tests/replication/slave/storage/test_events.py | 12 ++++-- 11 files changed, 100 insertions(+), 57 deletions(-) create mode 100644 changelog.d/8388.misc diff --git a/changelog.d/8388.misc b/changelog.d/8388.misc new file mode 100644 index 0000000000..aaaef88b66 --- /dev/null +++ b/changelog.d/8388.misc @@ -0,0 +1 @@ +Add `EventStreamPosition` type. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ea9264e751..9f773aefa7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -74,6 +74,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( JsonDict, MutableStateMap, + PersistedEventPosition, + RoomStreamToken, StateMap, UserID, get_domain_from_id, @@ -2956,7 +2958,7 @@ class FederationHandler(BaseHandler): ) return result["max_stream_id"] else: - max_stream_id = await self.storage.persistence.persist_events( + max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) @@ -2967,12 +2969,12 @@ class FederationHandler(BaseHandler): if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: - await self._notify_persisted_event(event, max_stream_id) + await self._notify_persisted_event(event, max_stream_token) - return max_stream_id + return max_stream_token.stream async def _notify_persisted_event( - self, event: EventBase, max_stream_id: int + self, event: EventBase, max_stream_token: RoomStreamToken ) -> None: """Checks to see if notifier/pushers should be notified about the event or not. @@ -2998,9 +3000,11 @@ class FederationHandler(BaseHandler): elif event.internal_metadata.is_outlier(): return - event_stream_id = event.internal_metadata.stream_ordering + event_pos = PersistedEventPosition( + self._instance_name, event.internal_metadata.stream_ordering + ) self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_pos, max_stream_token, extra_users=extra_users ) async def _clean_room_for_join(self, room_id: str) -> None: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6ee559fd1d..ee271e85e5 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1138,7 +1138,7 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = await self.storage.persistence.persist_event( + event_pos, max_stream_token = await self.storage.persistence.persist_event( event, context=context ) @@ -1149,7 +1149,7 @@ class EventCreationHandler: def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_pos, max_stream_token, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -1161,7 +1161,7 @@ class EventCreationHandler: # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - return event_stream_id + return event_pos.stream async def _bump_active_time(self, user: UserID) -> None: try: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9b3a4f638b..e948efef2e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -967,7 +967,7 @@ class SyncHandler: raise NotImplementedError() else: joined_room_ids = await self.get_rooms_for_user_at( - user_id, now_token.room_stream_id + user_id, now_token.room_key ) sync_result_builder = SyncResultBuilder( sync_config, @@ -1916,7 +1916,7 @@ class SyncHandler: raise Exception("Unrecognized rtype: %r", room_builder.rtype) async def get_rooms_for_user_at( - self, user_id: str, stream_ordering: int + self, user_id: str, room_key: RoomStreamToken ) -> FrozenSet[str]: """Get set of joined rooms for a user at the given stream ordering. @@ -1942,15 +1942,15 @@ class SyncHandler: # If the membership's stream ordering is after the given stream # ordering, we need to go and work out if the user was in the room # before. - for room_id, membership_stream_ordering in joined_rooms: - if membership_stream_ordering <= stream_ordering: + for room_id, event_pos in joined_rooms: + if not event_pos.persisted_after(room_key): joined_room_ids.add(room_id) continue logger.info("User joined room after current token: %s", room_id) extrems = await self.store.get_forward_extremeties_for_room( - room_id, stream_ordering + room_id, event_pos.stream ) users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: diff --git a/synapse/notifier.py b/synapse/notifier.py index a8fd3ef886..441b3d15e2 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -42,7 +42,13 @@ from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.streams.config import PaginationConfig -from synapse.types import Collection, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + Collection, + PersistedEventPosition, + RoomStreamToken, + StreamToken, + UserID, +) from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client @@ -187,7 +193,7 @@ class Notifier: self.store = hs.get_datastore() self.pending_new_room_events = ( [] - ) # type: List[Tuple[int, EventBase, Collection[UserID]]] + ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -246,8 +252,8 @@ class Notifier: def on_new_room_event( self, event: EventBase, - room_stream_id: int, - max_room_stream_id: int, + event_pos: PersistedEventPosition, + max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): """ Used by handlers to inform the notifier something has happened @@ -261,16 +267,16 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append((room_stream_id, event, extra_users)) - self._notify_pending_new_room_events(max_room_stream_id) + self.pending_new_room_events.append((event_pos, event, extra_users)) + self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_id: int): + def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken): """Notify for the room events that were queued waiting for a previous event to be persisted. Args: - max_room_stream_id: The highest stream_id below which all + max_room_stream_token: The highest stream_id below which all events have been persisted. """ pending = self.pending_new_room_events @@ -279,11 +285,9 @@ class Notifier: users = set() # type: Set[UserID] rooms = set() # type: Set[str] - for room_stream_id, event, extra_users in pending: - if room_stream_id > max_room_stream_id: - self.pending_new_room_events.append( - (room_stream_id, event, extra_users) - ) + for event_pos, event, extra_users in pending: + if event_pos.persisted_after(max_room_stream_token): + self.pending_new_room_events.append((event_pos, event, extra_users)) else: if ( event.type == EventTypes.Member @@ -296,39 +300,38 @@ class Notifier: if users or rooms: self.on_new_event( - "room_key", - RoomStreamToken(None, max_room_stream_id), - users=users, - rooms=rooms, + "room_key", max_room_stream_token, users=users, rooms=rooms, ) - self._on_updated_room_token(max_room_stream_id) + self._on_updated_room_token(max_room_stream_token) - def _on_updated_room_token(self, max_room_stream_id: int): + def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken): """Poke services that might care that the room position has been updated. """ # poke any interested application service. run_as_background_process( - "_notify_app_services", self._notify_app_services, max_room_stream_id + "_notify_app_services", self._notify_app_services, max_room_stream_token ) run_as_background_process( - "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id + "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token ) if self.federation_sender: - self.federation_sender.notify_new_events(max_room_stream_id) + self.federation_sender.notify_new_events(max_room_stream_token.stream) - async def _notify_app_services(self, max_room_stream_id: int): + async def _notify_app_services(self, max_room_stream_token: RoomStreamToken): try: - await self.appservice_handler.notify_interested_services(max_room_stream_id) + await self.appservice_handler.notify_interested_services( + max_room_stream_token.stream + ) except Exception: logger.exception("Error notifying application services of event") - async def _notify_pusher_pool(self, max_room_stream_id: int): + async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: - await self._pusher_pool.on_new_notifications(max_room_stream_id) + await self._pusher_pool.on_new_notifications(max_room_stream_token.stream) except Exception: logger.exception("Error pusher pool of event") diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e82b9e386f..55af3d41ea 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) -from synapse.types import UserID +from synapse.types import PersistedEventPosition, RoomStreamToken, UserID from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -151,8 +151,14 @@ class ReplicationDataHandler: extra_users = () # type: Tuple[UserID, ...] if event.type == EventTypes.Member: extra_users = (UserID.from_string(event.state_key),) - max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event(event, token, max_token, extra_users) + + max_token = RoomStreamToken( + None, self.store.get_room_max_stream_ordering() + ) + event_pos = PersistedEventPosition(instance_name, token) + self.notifier.on_new_room_event( + event, event_pos, max_token, extra_users + ) # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4fa8767b01..86ffe2479e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set @@ -37,7 +36,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import Collection, get_domain_from_id +from synapse.types import Collection, PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # for rooms the server is participating in. if self._current_state_events_membership_up_to_date: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN events AS e USING (room_id, event_id) WHERE @@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ else: sql = """ - SELECT room_id, e.stream_ordering + SELECT room_id, e.instance_name, e.stream_ordering FROM current_state_events AS c INNER JOIN room_memberships AS m USING (room_id, event_id) INNER JOIN events AS e USING (room_id, event_id) @@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (user_id, Membership.JOIN)) - return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) + return frozenset( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + for room_id, instance, stream_id in txn + ) async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index d89f6ed128..603cd7d825 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState -from synapse.types import Collection, StateMap +from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap from synapse.util.async_helpers import ObservableDeferred from synapse.util.metrics import Measure @@ -190,6 +190,7 @@ class EventsPersistenceStorage: self.persist_events_store = stores.persist_events self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() @@ -198,7 +199,7 @@ class EventsPersistenceStorage: self, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> int: + ) -> RoomStreamToken: """ Write events to the database Args: @@ -228,11 +229,11 @@ class EventsPersistenceStorage: defer.gatherResults(deferreds, consumeErrors=True) ) - return self.main_store.get_current_events_token() + return RoomStreamToken(None, self.main_store.get_current_events_token()) async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[int, int]: + ) -> Tuple[PersistedEventPosition, RoomStreamToken]: """ Returns: The stream ordering of `event`, and the stream ordering of the @@ -247,7 +248,10 @@ class EventsPersistenceStorage: await make_deferred_yieldable(deferred) max_persisted_id = self.main_store.get_current_events_token() - return (event.internal_metadata.stream_ordering, max_persisted_id) + event_stream_id = event.internal_metadata.stream_ordering + + pos = PersistedEventPosition(self._instance_name, event_stream_id) + return pos, RoomStreamToken(None, max_persisted_id) def _maybe_start_persisting(self, room_id: str): async def persisting_queue(item): diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 8c4a83a840..f152f63321 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -25,7 +25,7 @@ RoomsForUser = namedtuple( ) GetRoomsForUserWithStreamOrdering = namedtuple( - "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering") + "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") ) diff --git a/synapse/types.py b/synapse/types.py index a6fc7df22c..ec39f9e1e8 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -495,6 +495,21 @@ class StreamToken: StreamToken.START = StreamToken.from_string("s0_0") +@attr.s(slots=True, frozen=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name = attr.ib(type=str) + stream = attr.ib(type=int) + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.stream < self.stream + + class ThirdPartyInstanceID( namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) ): diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index bc578411d6..c0ee1cfbd6 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -20,6 +20,7 @@ from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_ from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser +from synapse.types import PersistedEventPosition from tests.server import FakeTransport @@ -204,10 +205,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) self.replicate() + + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering + ) self.check( "get_rooms_for_user_with_stream_ordering", (USER_ID_2,), - {(ROOM_ID, j2.internal_metadata.stream_ordering)}, + {(ROOM_ID, expected_pos)}, ) def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): @@ -293,9 +298,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # the membership change is only any use to us if the room is in the # joined_rooms list. if membership_changes: - self.assertEqual( - joined_rooms, {(ROOM_ID, j2.internal_metadata.stream_ordering)} + expected_pos = PersistedEventPosition( + "master", j2.internal_metadata.stream_ordering ) + self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)}) event_id = 0 -- cgit 1.5.1 From 6fdf5775939100121ad9e6e3a8cb21192a5444d6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 24 Sep 2020 13:43:49 +0100 Subject: Add new sequences to port DB script (#8387) --- changelog.d/8387.feature | 1 + scripts/synapse_port_db | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 changelog.d/8387.feature diff --git a/changelog.d/8387.feature b/changelog.d/8387.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8387.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index ecca8b6e8f..684a518b8e 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -628,6 +628,7 @@ class Porter(object): self.progress.set_state("Setting up sequence generators") await self._setup_state_group_id_seq() await self._setup_user_id_seq() + await self._setup_events_stream_seqs() self.progress.done() except Exception as e: @@ -804,6 +805,29 @@ class Porter(object): return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) + def _setup_events_stream_seqs(self): + def r(txn): + txn.execute("SELECT MAX(stream_ordering) FROM events") + curr_id = txn.fetchone()[0] + if curr_id: + next_id = curr_id + 1 + txn.execute( + "ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,) + ) + + txn.execute("SELECT -MIN(stream_ordering) FROM events") + curr_id = txn.fetchone()[0] + if curr_id: + next_id = curr_id + 1 + txn.execute( + "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s", + (next_id,), + ) + + return self.postgres_store.db_pool.runInteraction( + "_setup_events_stream_seqs", r + ) + ############################################## # The following is simply UI stuff -- cgit 1.5.1 From 11c9e17738277958f66d18015bf0e68f2c03bb8b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 24 Sep 2020 15:47:20 +0100 Subject: Add type annotations to SimpleHttpClient (#8372) --- changelog.d/8372.misc | 1 + synapse/appservice/api.py | 2 +- synapse/http/client.py | 187 ++++++++++++++++++-------- synapse/rest/media/v1/preview_url_resource.py | 14 +- 4 files changed, 143 insertions(+), 61 deletions(-) create mode 100644 changelog.d/8372.misc diff --git a/changelog.d/8372.misc b/changelog.d/8372.misc new file mode 100644 index 0000000000..a56e36de4b --- /dev/null +++ b/changelog.d/8372.misc @@ -0,0 +1 @@ +Add type annotations to `SimpleHttpClient`. diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 1514c0f691..c526c28b93 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -178,7 +178,7 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.parse.quote(protocol), ) try: - info = await self.get_json(uri, {}) + info = await self.get_json(uri) if not _is_valid_3pe_metadata(info): logger.warning( diff --git a/synapse/http/client.py b/synapse/http/client.py index 13fcab3378..4694adc400 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -17,6 +17,18 @@ import logging import urllib from io import BytesIO +from typing import ( + Any, + BinaryIO, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import treq from canonicaljson import encode_canonical_json @@ -37,6 +49,7 @@ from twisted.web._newclient import ResponseDone from twisted.web.client import Agent, HTTPConnectionPool, readBody from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import ( @@ -57,6 +70,19 @@ incoming_responses_counter = Counter( "synapse_http_client_responses", "", ["method", "code"] ) +# the type of the headers list, to be passed to the t.w.h.Headers. +# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so +# we simplify. +RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]] + +# the value actually has to be a List, but List is invariant so we can't specify that +# the entries can either be Lists or bytes. +RawHeaderValue = Sequence[Union[str, bytes]] + +# the type of the query params, to be passed into `urlencode` +QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] +QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] + def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): """ @@ -285,13 +311,26 @@ class SimpleHttpClient: ip_blacklist=self._ip_blacklist, ) - async def request(self, method, uri, data=None, headers=None): + async def request( + self, + method: str, + uri: str, + data: Optional[bytes] = None, + headers: Optional[Headers] = None, + ) -> IResponse: """ Args: - method (str): HTTP method to use. - uri (str): URI to query. - data (bytes): Data to send in the request body, if applicable. - headers (t.w.http_headers.Headers): Request headers. + method: HTTP method to use. + uri: URI to query. + data: Data to send in the request body, if applicable. + headers: Request headers. + + Returns: + Response object, once the headers have been read. + + Raises: + RequestTimedOutError if the request times out before the headers are read + """ # A small wrapper around self.agent.request() so we can easily attach # counters to it @@ -324,6 +363,8 @@ class SimpleHttpClient: headers=headers, **self._extra_treq_args ) + # we use our own timeout mechanism rather than treq's as a workaround + # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( request_deferred, 60, @@ -353,18 +394,26 @@ class SimpleHttpClient: set_tag("error_reason", e.args[0]) raise - async def post_urlencoded_get_json(self, uri, args={}, headers=None): + async def post_urlencoded_get_json( + self, + uri: str, + args: Mapping[str, Union[str, List[str]]] = {}, + headers: Optional[RawHeaders] = None, + ) -> Any: """ Args: - uri (str): - args (dict[str, str|List[str]]): query params - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: uri to query + args: parameters to be url-encoded in the body + headers: a map from header name to a list of values for that header Returns: - object: parsed json + parsed json Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException: On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -398,19 +447,24 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def post_json_get_json(self, uri, post_json, headers=None): + async def post_json_get_json( + self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None + ) -> Any: """ Args: - uri (str): - post_json (object): - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: URI to query. + post_json: request body, to be encoded as json + headers: a map from header name to a list of values for that header Returns: - object: parsed json + parsed json Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException: On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -440,21 +494,22 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def get_json(self, uri, args={}, headers=None): - """ Gets some json from the given URI. + async def get_json( + self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None, + ) -> Any: + """Gets some json from the given URI. Args: - uri (str): The URI to request, not including query parameters - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + args: A dictionary used to create query string + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the - HTTP body as JSON. + Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -466,22 +521,27 @@ class SimpleHttpClient: body = await self.get_raw(uri, args, headers=headers) return json_decoder.decode(body.decode("utf-8")) - async def put_json(self, uri, json_body, args={}, headers=None): - """ Puts some json to the given URI. + async def put_json( + self, + uri: str, + json_body: Any, + args: QueryParams = {}, + headers: RawHeaders = None, + ) -> Any: + """Puts some json to the given URI. Args: - uri (str): The URI to request, not including query parameters - json_body (dict): The JSON to put in the HTTP body, - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + json_body: The JSON to put in the HTTP body, + args: A dictionary used to create query strings + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the - HTTP body as JSON. + Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON. Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException On a non-2xx HTTP response. ValueError: if the response was not JSON @@ -513,21 +573,23 @@ class SimpleHttpClient: response.code, response.phrase.decode("ascii", errors="replace"), body ) - async def get_raw(self, uri, args={}, headers=None): - """ Gets raw text from the given URI. + async def get_raw( + self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None + ) -> bytes: + """Gets raw text from the given URI. Args: - uri (str): The URI to request, not including query parameters - args (dict): A dictionary used to create query strings, defaults to - None. - **Note**: The value of each key is assumed to be an iterable - and *not* a string. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + uri: The URI to request, not including query parameters + args: A dictionary used to create query strings + headers: a map from header name to a list of values for that header Returns: - Succeeds when we get *any* 2xx HTTP response, with the + Succeeds when we get a 2xx HTTP response, with the HTTP body as bytes. Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + HttpResponseException on a non-2xx HTTP response. """ if len(args): @@ -552,16 +614,29 @@ class SimpleHttpClient: # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. - async def get_file(self, url, output_stream, max_size=None, headers=None): + async def get_file( + self, + url: str, + output_stream: BinaryIO, + max_size: Optional[int] = None, + headers: Optional[RawHeaders] = None, + ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: """GETs a file from a given URL Args: - url (str): The URL to GET - output_stream (file): File to write the response body to. - headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from - header name to a list of values for that header + url: The URL to GET + output_stream: File to write the response body to. + headers: A map from header name to a list of values for that header Returns: - A (int,dict,string,int) tuple of the file length, dict of the response + A tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. + + Raises: + RequestTimedOutException: if there is a timeout before the response headers + are received. Note there is currently no timeout on reading the response + body. + + SynapseError: if the response is not a 2xx, the remote file is too large, or + another exception happens during the download. """ actual_headers = {b"User-Agent": [self.user_agent]} diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 987765e877..dce6c4d168 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url, user): + async def _download_url(self, url: str, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) # If this URL can be accessed via oEmbed, use that instead. - url_to_download = url + url_to_download = url # type: Optional[str] oembed_url = self._get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. @@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource): # FIXME: we should calculate a proper expiration based on the # Cache-Control and Expire headers. But for now, assume 1 hour. expires = ONE_HOUR - etag = headers["ETag"][0] if "ETag" in headers else None + etag = ( + headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + ) else: - html_bytes = oembed_result.html.encode("utf-8") # type: ignore + # we can only get here if we did an oembed request and have an oembed_result.html + assert oembed_result.html is not None + assert oembed_url is not None + + html_bytes = oembed_result.html.encode("utf-8") with self.media_storage.store_into_file(file_info) as (f, fname, finish): f.write(html_bytes) await finish() -- cgit 1.5.1 From 3f4a2a7064f79e77deaed8be96668020abef3c9d Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 24 Sep 2020 16:24:08 +0100 Subject: Hotfix: disable autoescape by default when rendering Jinja2 templates (#8394) #8037 changed the default `autoescape` option when rendering Jinja2 templates from `False` to `True`. This caused some bugs, noticeably around redirect URLs being escaped in SAML2 auth confirmation templates, causing those URLs to break for users. This change returns the previous behaviour as it stood. We may want to look at each template individually and see whether autoescaping is a good idea at some point, but for now lets just fix the breakage. --- changelog.d/8394.bugfix | 1 + synapse/config/_base.py | 10 ++++++++-- synapse/config/saml2_config.py | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8394.bugfix diff --git a/changelog.d/8394.bugfix b/changelog.d/8394.bugfix new file mode 100644 index 0000000000..0ac1eeca0a --- /dev/null +++ b/changelog.d/8394.bugfix @@ -0,0 +1 @@ +Fix URLs being accidentally escaped in Jinja2 templates. Broke in v1.20.0. \ No newline at end of file diff --git a/synapse/config/_base.py b/synapse/config/_base.py index ad5ab6ad62..f8ab8e38df 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -194,7 +194,10 @@ class Config: return file_stream.read() def read_templates( - self, filenames: List[str], custom_template_directory: Optional[str] = None, + self, + filenames: List[str], + custom_template_directory: Optional[str] = None, + autoescape: bool = False, ) -> List[jinja2.Template]: """Load a list of template files from disk using the given variables. @@ -210,6 +213,9 @@ class Config: custom_template_directory: A directory to try to look for the templates before using the default Synapse template directory instead. + autoescape: Whether to autoescape variables before inserting them into the + template. + Raises: ConfigError: if the file's path is incorrect or otherwise cannot be read. @@ -233,7 +239,7 @@ class Config: search_directories.insert(0, custom_template_directory) loader = jinja2.FileSystemLoader(search_directories) - env = jinja2.Environment(loader=loader, autoescape=True) + env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters env.filters.update( diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index cc7401888b..755478e2ff 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -169,8 +169,10 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "15m") ) + # We enable autoescape here as the message may potentially come from a + # remote resource self.saml2_error_html_template = self.read_templates( - ["saml_error.html"], saml2_config.get("template_dir") + ["saml_error.html"], saml2_config.get("template_dir"), autoescape=True )[0] def _default_saml_config_dict( -- cgit 1.5.1 From f3e5c2e702fb2bb5c59d354b92dec3a46f4dc962 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 24 Sep 2020 08:13:55 -0400 Subject: Mark the shadow_banned column as boolean in synapse_port_db. (#8386) --- .buildkite/test_db.db | Bin 18825216 -> 19279872 bytes changelog.d/8386.bugfix | 1 + scripts/synapse_port_db | 1 + 3 files changed, 2 insertions(+) create mode 100644 changelog.d/8386.bugfix diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db index f20567ba73..361369a581 100644 Binary files a/.buildkite/test_db.db and b/.buildkite/test_db.db differ diff --git a/changelog.d/8386.bugfix b/changelog.d/8386.bugfix new file mode 100644 index 0000000000..24983a1e95 --- /dev/null +++ b/changelog.d/8386.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index a34bdf1830..ecca8b6e8f 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = { "redactions": ["have_censored"], "room_stats_state": ["is_federatable"], "local_media_repository": ["safe_from_quarantine"], + "users": ["shadow_banned"], } -- cgit 1.5.1 From 920dd1083efb7e38b8b85b4b32f090277d5b69db Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 24 Sep 2020 16:25:33 +0100 Subject: 1.20.1 --- CHANGES.md | 10 ++++++++++ changelog.d/8386.bugfix | 1 - changelog.d/8394.bugfix | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 5 files changed, 17 insertions(+), 3 deletions(-) delete mode 100644 changelog.d/8386.bugfix delete mode 100644 changelog.d/8394.bugfix diff --git a/CHANGES.md b/CHANGES.md index 84711de448..650dc8487d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,13 @@ +Synapse 1.20.1 (2020-09-24) +=========================== + +Bugfixes +-------- + +- Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386)) +- Fix URLs being accidentally escaped in Jinja2 templates. Broke in v1.20.0. ([\#8394](https://github.com/matrix-org/synapse/issues/8394)) + + Synapse 1.20.0 (2020-09-22) =========================== diff --git a/changelog.d/8386.bugfix b/changelog.d/8386.bugfix deleted file mode 100644 index 24983a1e95..0000000000 --- a/changelog.d/8386.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. diff --git a/changelog.d/8394.bugfix b/changelog.d/8394.bugfix deleted file mode 100644 index 0ac1eeca0a..0000000000 --- a/changelog.d/8394.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix URLs being accidentally escaped in Jinja2 templates. Broke in v1.20.0. \ No newline at end of file diff --git a/debian/changelog b/debian/changelog index ae548f9f33..264ef9ce7c 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.20.1) stable; urgency=medium + + * New synapse release 1.20.1. + + -- Synapse Packaging team Thu, 24 Sep 2020 16:25:22 +0100 + matrix-synapse-py3 (1.20.0) stable; urgency=medium [ Synapse Packaging team ] diff --git a/synapse/__init__.py b/synapse/__init__.py index 8242d05f60..e40b582bd5 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.20.0" +__version__ = "1.20.1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From 5ce5a9f1447088bc29cac49f4a0ebfab6c0198d6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 24 Sep 2020 16:26:57 +0100 Subject: Update changelog wording --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 650dc8487d..16e83f6f10 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,7 +5,7 @@ Bugfixes -------- - Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386)) -- Fix URLs being accidentally escaped in Jinja2 templates. Broke in v1.20.0. ([\#8394](https://github.com/matrix-org/synapse/issues/8394)) +- Fix a bug introduced in v1.20.0 which caused URLs to be accidentally escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394)) Synapse 1.20.0 (2020-09-22) -- cgit 1.5.1 From 271086ebda55f9ef0a0bdee69c96d79c5005e21d Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 24 Sep 2020 16:33:49 +0100 Subject: s/accidentally/incorrectly in changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 16e83f6f10..7ea08fa117 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,7 +5,7 @@ Bugfixes -------- - Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386)) -- Fix a bug introduced in v1.20.0 which caused URLs to be accidentally escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394)) +- Fix a bug introduced in v1.20.0 which caused URLs to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394)) Synapse 1.20.0 (2020-09-22) -- cgit 1.5.1 From ab903e7337f6c2c7cfcdac69b13dedf67e56d801 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Thu, 24 Sep 2020 16:35:31 +0100 Subject: s/URLs/variables in changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 7ea08fa117..5de819ea1e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,7 +5,7 @@ Bugfixes -------- - Fix a bug introduced in v1.20.0 which caused the `synapse_port_db` script to fail. ([\#8386](https://github.com/matrix-org/synapse/issues/8386)) -- Fix a bug introduced in v1.20.0 which caused URLs to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394)) +- Fix a bug introduced in v1.20.0 which caused variables to be incorrectly escaped in Jinja2 templates. ([\#8394](https://github.com/matrix-org/synapse/issues/8394)) Synapse 1.20.0 (2020-09-22) -- cgit 1.5.1 From f112cfe5bb2c918c9e942941686a05664d8bd7da Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 24 Sep 2020 16:53:51 +0100 Subject: Fix MultiWriteIdGenerator's handling of restarts. (#8374) On startup `MultiWriteIdGenerator` fetches the maximum stream ID for each instance from the table and uses that as its initial "current position" for each writer. This is problematic as a) it involves either a scan of events table or an index (neither of which is ideal), and b) if rows are being persisted out of order elsewhere while the process restarts then using the maximum stream ID is not correct. This could theoretically lead to race conditions where e.g. events that are persisted out of order are not sent down sync streams. We fix this by creating a new table that tracks the current positions of each writer to the stream, and update it each time we finish persisting a new entry. This is a relatively small overhead when persisting events. However for the cache invalidation stream this is a much bigger relative overhead, so instead we note that for invalidation we don't actually care about reliability over restarts (as there's no caches to invalidate) and simply don't bother reading and writing to the new table in that particular case. --- changelog.d/8374.bugfix | 1 + synapse/replication/slave/storage/_base.py | 2 + synapse/storage/databases/main/__init__.py | 8 +- synapse/storage/databases/main/events_worker.py | 4 + .../main/schema/delta/58/18stream_positions.sql | 22 +++ synapse/storage/util/id_generators.py | 148 ++++++++++++++++++--- tests/storage/test_id_generators.py | 119 +++++++++++++++-- 7 files changed, 274 insertions(+), 30 deletions(-) create mode 100644 changelog.d/8374.bugfix create mode 100644 synapse/storage/databases/main/schema/delta/58/18stream_positions.sql diff --git a/changelog.d/8374.bugfix b/changelog.d/8374.bugfix new file mode 100644 index 0000000000..155bc3404f --- /dev/null +++ b/changelog.d/8374.bugfix @@ -0,0 +1 @@ +Fix theoretical race condition where events are not sent down `/sync` if the synchrotron worker is restarted without restarting other workers. diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d25fa49e1a..d0089fe06c 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -31,11 +31,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, + stream_name="caches", instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index ccb3384db9..0cb12f4c61 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,14 +160,20 @@ class DataStore( ) if isinstance(self.database_engine, PostgresEngine): + # We set the `writers` to an empty list here as we don't care about + # missing updates over restarts, as we'll not have anything in our + # caches to invalidate. (This reduces the amount of writes to the DB + # that happen). self._cache_id_gen = MultiWriterIdGenerator( db_conn, database, - instance_name="master", + stream_name="caches", + instance_name=hs.get_instance_name(), table="cache_invalidation_stream_by_instance", instance_column="instance_name", id_column="stream_id", sequence_name="cache_invalidation_stream_seq", + writers=[], ) else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index de9e8d1dc6..f95679ebc4 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -83,21 +83,25 @@ class EventsWorkerStore(SQLBaseStore): self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="events", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_stream_seq", + writers=hs.config.worker.writers.events, ) self._backfill_id_gen = MultiWriterIdGenerator( db_conn=db_conn, db=database, + stream_name="backfill", instance_name=hs.get_instance_name(), table="events", instance_column="instance_name", id_column="stream_ordering", sequence_name="events_backfill_stream_seq", positive=False, + writers=hs.config.worker.writers.events, ) else: # We shouldn't be running in worker mode with SQLite, but its useful diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql new file mode 100644 index 0000000000..985fd949a2 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE stream_positions ( + stream_name TEXT NOT NULL, + instance_name TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b0353ac2dc..727fcc521c 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Set, Union import attr from typing_extensions import Deque +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator @@ -184,12 +185,16 @@ class MultiWriterIdGenerator: Args: db_conn db + stream_name: A name for the stream. instance_name: The name of this instance. table: Database table associated with stream. instance_column: Column that stores the row's writer's instance name id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + writers: A list of known writers to use to populate current positions + on startup. Can be empty if nothing uses `get_current_token` or + `get_positions` (e.g. caches stream). positive: Whether the IDs are positive (true) or negative (false). When using negative IDs we go backwards from -1 to -2, -3, etc. """ @@ -198,16 +203,20 @@ class MultiWriterIdGenerator: self, db_conn, db: DatabasePool, + stream_name: str, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + writers: List[str], positive: bool = True, ): self._db = db + self._stream_name = stream_name self._instance_name = instance_name self._positive = positive + self._writers = writers self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -216,9 +225,7 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = self._load_current_ids( - db_conn, table, instance_column, id_column - ) + self._current_positions = {} # type: Dict[str, int] # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). @@ -251,30 +258,80 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, table, instance_column, id_column) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str - ) -> Dict[str, int]: - # If positive stream aggregate via MAX. For negative stream use MIN - # *and* negate the result to get a positive number. - sql = """ - SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s - GROUP BY %(instance)s - """ % { - "instance": instance_column, - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - + ): cur = db_conn.cursor() - cur.execute(sql) - # `cur` is an iterable over returned rows, which are 2-tuples. - current_positions = dict(cur) + # Load the current positions of all writers for the stream. + if self._writers: + sql = """ + SELECT instance_name, stream_id FROM stream_positions + WHERE stream_name = ? + """ + sql = self._db.engine.convert_param_style(sql) - cur.close() + cur.execute(sql, (self._stream_name,)) + + self._current_positions = { + instance: stream_id * self._return_factor + for instance, stream_id in cur + if instance in self._writers + } + + # We set the `_persisted_upto_position` to be the minimum of all current + # positions. If empty we use the max stream ID from the DB table. + min_stream_id = min(self._current_positions.values(), default=None) + + if min_stream_id is None: + sql = """ + SELECT COALESCE(%(agg)s(%(id)s), 1) FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + self._persisted_upto_position = stream_id + else: + # If we have a min_stream_id then we pull out everything greater + # than it from the DB so that we can prefill + # `_known_persisted_positions` and get a more accurate + # `_persisted_upto_position`. + # + # We also check if any of the later rows are from this instance, in + # which case we use that for this instance's current position. This + # is to handle the case where we didn't finish persisting to the + # stream positions table before restart (or the stream position + # table otherwise got out of date). + + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (min_stream_id,)) + + self._persisted_upto_position = min_stream_id + + with self._lock: + for (instance, stream_id,) in cur: + stream_id = self._return_factor * stream_id + self._add_persisted_position(stream_id) - return current_positions + if instance == self._instance_name: + self._current_positions[instance] = stream_id + + cur.close() def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) @@ -316,6 +373,21 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persited row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): @@ -447,6 +519,28 @@ class MultiWriterIdGenerator: # do. break + def _update_stream_positions_table_txn(self, txn): + """Update the `stream_positions` table with newly persisted position. + """ + + if not self._writers: + return + + # We upsert the value, ensuring on conflict that we always increase the + # value (or decrease if stream goes backwards). + sql = """ + INSERT INTO stream_positions (stream_name, instance_name, stream_id) + VALUES (?, ?, ?) + ON CONFLICT (stream_name, instance_name) + DO UPDATE SET + stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) + """ % { + "agg": "GREATEST" if self._positive else "LEAST", + } + + pos = (self.get_current_token_for_writer(self._instance_name),) + txn.execute(sql, (self._stream_name, self._instance_name, pos)) + @attr.s(slots=True) class _AsyncCtxManagerWrapper: @@ -503,4 +597,16 @@ class _MultiWriterCtxManager: if exc_type is not None: return False + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + if self.id_gen._writers: + await self.id_gen._db.runInteraction( + "MultiWriterIdGenerator._update_table", + self.id_gen._update_stream_positions_table_txn, + ) + return False diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index fb8f5bc255..d4ff55fbff 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, ) return self.get_success(self.db_pool.runWithConnection(_create)) @@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", (instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) @@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,)) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, stream_id, stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) @@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_rows("first", 3) self._insert_rows("second", 4) - first_id_gen = self._create_id_generator("first") - second_id_gen = self._create_id_generator("second") + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) @@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -300,7 +318,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first") + id_gen = self._create_id_generator("first", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -319,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). + def test_restart_during_out_of_order_persistence(self): + """Test that restarting a process while another process is writing out + of order updates are handled correctly. + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # Persist two rows at once + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + # We finish persisting the second row before restart + self.get_success(ctx2.__aexit__(None, None, None)) + + # We simulate a restart of another worker by just creating a new ID gen. + id_gen_worker = self._create_id_generator("worker") + + # Restarted worker should not see the second persisted row + self.assertEqual(id_gen_worker.get_positions(), {"master": 7}) + self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7) + + # Now if we persist the first row then both instances should jump ahead + # correctly. + self.get_success(ctx1.__aexit__(None, None, None)) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + id_gen_worker.advance("master", 9) + self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) + + def test_writer_config_change(self): + """Test that changing the writer config correctly works. + """ + + self._insert_row_with_id("first", 3) + self._insert_row_with_id("second", 5) + + # Initial config has two writers + id_gen = self._create_id_generator("first", writers=["first", "second"]) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + # New config removes one of the configs. Note that if the writer is + # removed from config we assume that it has been shut down and has + # finished persisting, hence why the persisted upto position is 5. + id_gen_2 = self._create_id_generator("second", writers=["second"]) + self.assertEqual(id_gen_2.get_persisted_upto_position(), 5) + + # This config points to a single, previously unused writer. + id_gen_3 = self._create_id_generator("third", writers=["third"]) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 5) + + # Check that we get a sane next stream ID with this new config. + + async def _get_next_async(): + async with id_gen_3.get_next() as stream_id: + self.assertEqual(stream_id, 6) + + self.get_success(_get_next_async()) + self.assertEqual(id_gen_3.get_persisted_upto_position(), 6) + class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests MultiWriterIdGenerator that produce *negative* stream IDs. @@ -345,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ ) - def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: def _create(conn): return MultiWriterIdGenerator( conn, self.db_pool, + stream_name="test_stream", instance_name=instance_name, table="foobar", instance_column="instance_name", id_column="stream_id", sequence_name="foobar_seq", + writers=writers, positive=False, ) @@ -368,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): txn.execute( "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), ) + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, ?) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ? + """, + (instance_name, -stream_id, -stream_id), + ) self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) @@ -409,8 +512,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """Tests that having multiple instances that get advanced over federation works corretly. """ - id_gen_1 = self._create_id_generator("first") - id_gen_2 = self._create_id_generator("second") + id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) + id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) async def _get_next_async(): async with id_gen_1.get_next() as stream_id: -- cgit 1.5.1 From 3e87d79e1c6ef894387ee2f24e008dfb8f5f853f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 25 Sep 2020 09:58:32 +0100 Subject: Fix schema delta for servers that have not backfilled (#8396) Fixes #8395. --- changelog.d/8396.feature | 1 + .../main/schema/delta/58/14events_instance_name.sql.postgres | 4 +++- synapse/storage/util/id_generators.py | 6 +++++- 3 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8396.feature diff --git a/changelog.d/8396.feature b/changelog.d/8396.feature new file mode 100644 index 0000000000..b363e929ea --- /dev/null +++ b/changelog.d/8396.feature @@ -0,0 +1 @@ +Add experimental support for sharding event persister. diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres index 97c1e6a0c5..c31f9af82a 100644 --- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres +++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres @@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', ( CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq; +-- If the server has never backfilled a room then doing `-MIN(...)` will give +-- a negative result, hence why we do `GREATEST(...)` SELECT setval('events_backfill_stream_seq', ( - SELECT COALESCE(-MIN(stream_ordering), 1) FROM events + SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events )); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 727fcc521c..4269eaf918 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -287,8 +287,12 @@ class MultiWriterIdGenerator: min_stream_id = min(self._current_positions.values(), default=None) if min_stream_id is None: + # We add a GREATEST here to ensure that the result is always + # positive. (This can be a problem for e.g. backfill streams where + # the server has never backfilled). sql = """ - SELECT COALESCE(%(agg)s(%(id)s), 1) FROM %(table)s + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s """ % { "id": id_column, "table": table, -- cgit 1.5.1 From abd04b6af0671517a01781c8bd10fef2a6c32cc4 Mon Sep 17 00:00:00 2001 From: Tdxdxoz Date: Fri, 25 Sep 2020 19:01:45 +0800 Subject: Allow existing users to login via OpenID Connect. (#8345) Co-authored-by: Benjamin Koch This adds configuration flags that will match a user to pre-existing users when logging in via OpenID Connect. This is useful when switching to an existing SSO system. --- changelog.d/8345.feature | 1 + docs/sample_config.yaml | 5 +++ synapse/config/oidc_config.py | 6 ++++ synapse/handlers/oidc_handler.py | 42 +++++++++++++++++--------- synapse/storage/databases/main/registration.py | 4 +-- tests/handlers/test_oidc.py | 35 +++++++++++++++++++++ 6 files changed, 76 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8345.feature diff --git a/changelog.d/8345.feature b/changelog.d/8345.feature new file mode 100644 index 0000000000..4ee5b6a56e --- /dev/null +++ b/changelog.d/8345.feature @@ -0,0 +1 @@ +Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index fb04ff283d..845f537795 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1689,6 +1689,11 @@ oidc_config: # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index e0939bce84..70fc8a2f62 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -56,6 +56,7 @@ class OIDCConfig(Config): self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") self.oidc_skip_verification = oidc_config.get("skip_verification", False) + self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) ump_config = oidc_config.get("user_mapping_provider", {}) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) @@ -158,6 +159,11 @@ class OIDCConfig(Config): # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4230dbaf99..0e06e4408d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -114,6 +114,7 @@ class OidcHandler: hs.config.oidc_user_mapping_provider_config ) # type: OidcMappingProvider self._skip_verification = hs.config.oidc_skip_verification # type: bool + self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() self._auth_handler = hs.get_auth_handler() @@ -849,7 +850,8 @@ class OidcHandler: If we don't find the user that way, we should register the user, mapping the localpart and the display name from the UserInfo. - If a user already exists with the mxid we've mapped, raise an exception. + If a user already exists with the mxid we've mapped and allow_existing_users + is disabled, raise an exception. Args: userinfo: an object representing the user @@ -905,21 +907,31 @@ class OidcHandler: localpart = map_username_to_mxid_localpart(attributes["localpart"]) - user_id = UserID(localpart, self._hostname) - if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): - # This mxid is taken - raise MappingException( - "mxid '{}' is already taken".format(user_id.to_string()) + user_id = UserID(localpart, self._hostname).to_string() + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + if self._allow_existing_users: + if len(users) == 1: + registered_user_id = next(iter(users)) + elif user_id in users: + registered_user_id = user_id + else: + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, list(users.keys()) + ) + ) + else: + # This mxid is taken + raise MappingException("mxid '{}' is already taken".format(user_id)) + else: + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) - - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), - ) - await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id, ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 33825e8949..48ce7ecd16 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -393,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -401,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 89ec5fcb31..5910772aa8 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -617,3 +617,38 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) self.assertEqual(mxid, "@test_user_2:test") + + # Test if the mxid is already taken + store = self.hs.get_datastore() + user3 = UserID.from_string("@test_user_3:test") + self.get_success( + store.register_user(user_id=user3.to_string(), password_hash=None) + ) + userinfo = {"sub": "test3", "username": "test_user_3"} + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") + + @override_config({"oidc_config": {"allow_existing_users": True}}) + def test_map_userinfo_to_existing_user(self): + """Existing users can log in with OpenID Connect when allow_existing_users is True.""" + store = self.hs.get_datastore() + user4 = UserID.from_string("@test_user_4:test") + self.get_success( + store.register_user(user_id=user4.to_string(), password_hash=None) + ) + userinfo = { + "sub": "test4", + "username": "test_user_4", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_4:test") -- cgit 1.5.1 From fec6f9ac178867a8e7c5410e0d25898f29bab35c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 25 Sep 2020 12:29:54 +0100 Subject: Fix occasional "Re-starting finished log context" from keyring (#8398) * Fix test_verify_json_objects_for_server_awaits_previous_requests It turns out that this wasn't really testing what it thought it was testing (in particular, `check_context` was turning failures into success, which was making the tests pass even though it wasn't clear they should have been. It was also somewhat overcomplex - we can test what it was trying to test without mocking out perspectives servers. * Fix warnings about finished logcontexts in the keyring We need to make sure that we finish the key fetching magic before we run the verifying code, to ensure that we don't mess up our logcontexts. --- changelog.d/8398.bugfix | 1 + synapse/crypto/keyring.py | 70 +++++++++++++++---------- tests/crypto/test_keyring.py | 120 ++++++++++++++++++++----------------------- 3 files changed, 101 insertions(+), 90 deletions(-) create mode 100644 changelog.d/8398.bugfix diff --git a/changelog.d/8398.bugfix b/changelog.d/8398.bugfix new file mode 100644 index 0000000000..e432aeebf1 --- /dev/null +++ b/changelog.d/8398.bugfix @@ -0,0 +1 @@ +Fix "Re-starting finished log context" warning when receiving an event we already had over federation. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 42e4087a92..c04ad77cf9 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -42,7 +42,6 @@ from synapse.api.errors import ( ) from synapse.logging.context import ( PreserveLoggingContext, - current_context, make_deferred_yieldable, preserve_fn, run_in_background, @@ -233,8 +232,6 @@ class Keyring: """ try: - ctx = current_context() - # map from server name to a set of outstanding request ids server_to_request_ids = {} @@ -265,12 +262,8 @@ class Keyring: # if there are no more requests for this server, we can drop the lock. if not server_requests: - with PreserveLoggingContext(ctx): - logger.debug("Releasing key lookup lock on %s", server_name) - - # ... but not immediately, as that can cause stack explosions if - # we get a long queue of lookups. - self.clock.call_later(0, drop_server_lock, server_name) + logger.debug("Releasing key lookup lock on %s", server_name) + drop_server_lock(server_name) return res @@ -335,20 +328,32 @@ class Keyring: ) # look for any requests which weren't satisfied - with PreserveLoggingContext(): - for verify_request in remaining_requests: - verify_request.key_ready.errback( - SynapseError( - 401, - "No key for %s with ids in %s (min_validity %i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ), - Codes.UNAUTHORIZED, - ) + while remaining_requests: + verify_request = remaining_requests.pop() + rq_str = ( + "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, ) + ) + + # If we run the errback immediately, it may cancel our + # loggingcontext while we are still in it, so instead we + # schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we + # has a massive queue of lookups waiting for this server). + self.clock.call_later( + 0, + verify_request.key_ready.errback, + SynapseError( + 401, + "Failed to find any key to satisfy %s" % (rq_str,), + Codes.UNAUTHORIZED, + ), + ) except Exception as err: # we don't really expect to get here, because any errors should already # have been caught and logged. But if we do, let's log the error and make @@ -410,10 +415,23 @@ class Keyring: # key was not valid at this point continue - with PreserveLoggingContext(): - verify_request.key_ready.callback( - (server_name, key_id, fetch_key_result.verify_key) - ) + # we have a valid key for this request. If we run the callback + # immediately, it may cancel our loggingcontext while we are still in + # it, so instead we schedule it for the next time round the reactor. + # + # (this also ensures that we don't get a stack overflow if we had + # a massive queue of lookups waiting for this server). + logger.debug( + "Found key %s:%s for %s", + server_name, + key_id, + verify_request.request_name, + ) + self.clock.call_later( + 0, + verify_request.key_ready.callback, + (server_name, key_id, fetch_key_result.verify_key), + ) completed.append(verify_request) break diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2e6e7abf1f..5cf408f21f 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -23,6 +23,7 @@ from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer +from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError from synapse.crypto import keyring @@ -33,7 +34,6 @@ from synapse.crypto.keyring import ( ) from synapse.logging.context import ( LoggingContext, - PreserveLoggingContext, current_context, make_deferred_yieldable, ) @@ -68,54 +68,40 @@ class MockPerspectiveServer: class KeyringTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): - self.mock_perspective_server = MockPerspectiveServer() - self.http_client = Mock() - - config = self.default_config() - config["trusted_key_servers"] = [ - { - "server_name": self.mock_perspective_server.server_name, - "verify_keys": self.mock_perspective_server.get_verify_keys(), - } - ] - - return self.setup_test_homeserver( - handlers=None, http_client=self.http_client, config=config - ) - - def check_context(self, _, expected): + def check_context(self, val, expected): self.assertEquals(getattr(current_context(), "request", None), expected) + return val def test_verify_json_objects_for_server_awaits_previous_requests(self): - key1 = signedjson.key.generate_signing_key(1) + mock_fetcher = keyring.KeyFetcher() + mock_fetcher.get_keys = Mock() + kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,)) - kr = keyring.Keyring(self.hs) + # a signed object that we are going to try to validate + key1 = signedjson.key.generate_signing_key(1) json1 = {} signedjson.sign.sign_json(json1, "server10", key1) - persp_resp = { - "server_keys": [ - self.mock_perspective_server.get_signed_key( - "server10", signedjson.key.get_verify_key(key1) - ) - ] - } - persp_deferred = defer.Deferred() + # start off a first set of lookups. We make the mock fetcher block until this + # deferred completes. + first_lookup_deferred = Deferred() + + async def first_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_11") + self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) - async def get_perspectives(**kwargs): - self.assertEquals(current_context().request, "11") - with PreserveLoggingContext(): - await persp_deferred - return persp_resp + await make_deferred_yieldable(first_lookup_deferred) + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } - self.http_client.post_json.side_effect = get_perspectives + mock_fetcher.get_keys.side_effect = first_lookup_fetch - # start off a first set of lookups - @defer.inlineCallbacks - def first_lookup(): - with LoggingContext("11") as context_11: - context_11.request = "11" + async def first_lookup(): + with LoggingContext("context_11") as context_11: + context_11.request = "context_11" res_deferreds = kr.verify_json_objects_for_server( [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] @@ -124,7 +110,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): # the unsigned json should be rejected pretty quickly self.assertTrue(res_deferreds[1].called) try: - yield res_deferreds[1] + await res_deferreds[1] self.assertFalse("unsigned json didn't cause a failure") except SynapseError: pass @@ -132,45 +118,51 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertFalse(res_deferreds[0].called) res_deferreds[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds[0]) + await make_deferred_yieldable(res_deferreds[0]) - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) + d0 = ensureDeferred(first_lookup()) - d0 = first_lookup() - - # wait a tick for it to send the request to the perspectives server - # (it first tries the datastore) - self.pump() - self.http_client.post_json.assert_called_once() + mock_fetcher.get_keys.assert_called_once() # a second request for a server with outstanding requests # should block rather than start a second call - @defer.inlineCallbacks - def second_lookup(): - with LoggingContext("12") as context_12: - context_12.request = "12" - self.http_client.post_json.reset_mock() - self.http_client.post_json.return_value = defer.Deferred() + + async def second_lookup_fetch(keys_to_fetch): + self.assertEquals(current_context().request, "context_12") + return { + "server10": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) + } + } + + mock_fetcher.get_keys.reset_mock() + mock_fetcher.get_keys.side_effect = second_lookup_fetch + second_lookup_state = [0] + + async def second_lookup(): + with LoggingContext("context_12") as context_12: + context_12.request = "context_12" res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1, 0, "test")] ) res_deferreds_2[0].addBoth(self.check_context, None) - yield make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 1 + await make_deferred_yieldable(res_deferreds_2[0]) + second_lookup_state[0] = 2 - # let verify_json_objects_for_server finish its work before we kill the - # logcontext - yield self.clock.sleep(0) - - d2 = second_lookup() + d2 = ensureDeferred(second_lookup()) self.pump() - self.http_client.post_json.assert_not_called() + # the second request should be pending, but the fetcher should not yet have been + # called + self.assertEqual(second_lookup_state[0], 1) + mock_fetcher.get_keys.assert_not_called() # complete the first request - persp_deferred.callback(persp_resp) + first_lookup_deferred.callback(None) + + # and now both verifications should succeed. self.get_success(d0) self.get_success(d2) -- cgit 1.5.1 From 31acc5c30938bd532670d45304f6750de6e6e759 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 25 Sep 2020 11:05:54 -0400 Subject: Escape the error description on the sso_error template. (#8405) --- changelog.d/8405.feature | 1 + synapse/res/templates/sso_error.html | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8405.feature diff --git a/changelog.d/8405.feature b/changelog.d/8405.feature new file mode 100644 index 0000000000..f3c4a74bc7 --- /dev/null +++ b/changelog.d/8405.feature @@ -0,0 +1 @@ +Consolidate the SSO error template across all configuration. diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index af8459719a..944bc9c9ca 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -12,7 +12,7 @@

There was an error during authentication:

-
{{ error_description }}
+
{{ error_description | e }}

If you are seeing this page after clicking a link sent to you via email, make sure you only click the confirmation link once, and that you open the -- cgit 1.5.1 From 4b3a1faa08f5ad16e0e00dc629fb25be520575d7 Mon Sep 17 00:00:00 2001 From: Matthew Hodgson Date: Mon, 28 Sep 2020 00:23:35 +0100 Subject: typo --- synapse/storage/databases/main/schema/delta/56/event_labels.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql index 5e29c1da19..ccf287971c 100644 --- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql +++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql @@ -13,7 +13,7 @@ * limitations under the License. */ --- room_id and topoligical_ordering are denormalised from the events table in order to +-- room_id and topological_ordering are denormalised from the events table in order to -- make the index work. CREATE TABLE IF NOT EXISTS event_labels ( event_id TEXT, -- cgit 1.5.1