From d93362d87fbbf4941da06ade65eaf5df1672bccb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Dec 2021 12:26:29 -0500 Subject: Add a constant for receipt types (m.read). (#11531) And expand some type hints in the receipts storage module. --- synapse/api/constants.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'synapse/api') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f7d29b4319..52c083a20b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -253,5 +253,9 @@ class GuestAccess: FORBIDDEN: Final = "forbidden" +class ReceiptTypes: + READ: Final = "m.read" + + class ReadReceiptEventFields: MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" -- cgit 1.5.1 From 17886d2603112531d4eda459d312f84d0d677652 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 15 Dec 2021 10:40:52 +0000 Subject: Add experimental support for MSC3202: allowing application services to masquerade as specific devices. (#11538) --- changelog.d/11538.feature | 1 + synapse/api/auth.py | 86 ++++++++++++++++++++++++++----- synapse/config/experimental.py | 5 ++ synapse/storage/databases/main/devices.py | 20 +++++++ tests/api/test_auth.py | 64 +++++++++++++++++++++++ 5 files changed, 162 insertions(+), 14 deletions(-) create mode 100644 changelog.d/11538.feature (limited to 'synapse/api') diff --git a/changelog.d/11538.feature b/changelog.d/11538.feature new file mode 100644 index 0000000000..b6229e2b45 --- /dev/null +++ b/changelog.d/11538.feature @@ -0,0 +1 @@ +Add experimental support for MSC3202: allowing application services to masquerade as specific devices. \ No newline at end of file diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 44883c6663..0bf58dff08 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -155,7 +155,11 @@ class Auth: access_token = self.get_access_token_from_request(request) - user_id, app_service = await self._get_appservice_user_id(request) + ( + user_id, + device_id, + app_service, + ) = await self._get_appservice_user_id_and_device_id(request) if user_id and app_service: if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( @@ -163,16 +167,22 @@ class Auth: access_token=access_token, ip=ip_addr, user_agent=user_agent, - device_id="dummy-device", # stubbed + device_id="dummy-device" + if device_id is None + else device_id, # stubbed ) - requester = create_requester(user_id, app_service=app_service) + requester = create_requester( + user_id, app_service=app_service, device_id=device_id + ) request.requester = user_id if user_id in self._force_tracing_for_users: opentracing.force_tracing() opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("user_id", user_id) + if device_id is not None: + opentracing.set_tag("device_id", device_id) opentracing.set_tag("appservice_id", app_service.id) return requester @@ -274,33 +284,81 @@ class Auth: 403, "Application service has not registered this user (%s)" % user_id ) - async def _get_appservice_user_id( + async def _get_appservice_user_id_and_device_id( self, request: Request - ) -> Tuple[Optional[str], Optional[ApplicationService]]: + ) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]: + """ + Given a request, reads the request parameters to determine: + - whether it's an application service that's making this request + - what user the application service should be treated as controlling + (the user_id URI parameter allows an application service to masquerade + any applicable user in its namespace) + - what device the application service should be treated as controlling + (the device_id[^1] URI parameter allows an application service to masquerade + as any device that exists for the relevant user) + + [^1] Unstable and provided by MSC3202. + Must use `org.matrix.msc3202.device_id` in place of `device_id` for now. + + Returns: + 3-tuple of + (user ID?, device ID?, application service?) + + Postconditions: + - If an application service is returned, so is a user ID + - A user ID is never returned without an application service + - A device ID is never returned without a user ID or an application service + - The returned application service, if present, is permitted to control the + returned user ID. + - The returned device ID, if present, has been checked to be a valid device ID + for the returned user ID. + """ + DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id" + app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) if app_service is None: - return None, None + return None, None, None if app_service.ip_range_whitelist: ip_address = IPAddress(request.getClientIP()) if ip_address not in app_service.ip_range_whitelist: - return None, None + return None, None, None # This will always be set by the time Twisted calls us. assert request.args is not None - if b"user_id" not in request.args: - return app_service.sender, app_service + if b"user_id" in request.args: + effective_user_id = request.args[b"user_id"][0].decode("utf8") + await self.validate_appservice_can_control_user_id( + app_service, effective_user_id + ) + else: + effective_user_id = app_service.sender - user_id = request.args[b"user_id"][0].decode("utf8") - await self.validate_appservice_can_control_user_id(app_service, user_id) + effective_device_id: Optional[str] = None - if app_service.sender == user_id: - return app_service.sender, app_service + if ( + self.hs.config.experimental.msc3202_device_masquerading_enabled + and DEVICE_ID_ARG_NAME in request.args + ): + effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8") + # We only just set this so it can't be None! + assert effective_device_id is not None + device_opt = await self.store.get_device( + effective_user_id, effective_device_id + ) + if device_opt is None: + # For now, use 400 M_EXCLUSIVE if the device doesn't exist. + # This is an open thread of discussion on MSC3202 as of 2021-12-09. + raise AuthError( + 400, + f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})", + Codes.EXCLUSIVE, + ) - return user_id, app_service + return effective_user_id, effective_device_id, app_service async def get_user_by_access_token( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d78a15097c..678c78d565 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -49,3 +49,8 @@ class ExperimentalConfig(Config): # MSC3030 (Jump to date API endpoint) self.msc3030_enabled: bool = experimental.get("msc3030_enabled", False) + + # The portion of MSC3202 which is related to device masquerading. + self.msc3202_device_masquerading_enabled: bool = experimental.get( + "msc3202_device_masquerading", False + ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 3932599988..273adb61fd 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -128,6 +128,26 @@ class DeviceWorkerStore(SQLBaseStore): allow_none=True, ) + async def get_device_opt( + self, user_id: str, device_id: str + ) -> Optional[Dict[str, Any]]: + """Retrieve a device. Only returns devices that are not marked as + hidden. + + Args: + user_id: The ID of the user which owns the device + device_id: The ID of the device to retrieve + Returns: + A dict containing the device information, or None if the device does not exist. + """ + return await self.db_pool.simple_select_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), + desc="get_device", + allow_none=True, + ) + async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]: """Retrieve all of a user's registered devices. Only returns devices that are not marked as hidden. diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 3aa9ba3c43..a2dfa1ed05 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -31,6 +31,7 @@ from synapse.types import Requester from tests import unittest from tests.test_utils import simple_async_mock +from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -210,6 +211,69 @@ class AuthTestCase(unittest.HomeserverTestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_failure(self.auth.get_user_by_req(request), AuthError) + @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) + def test_get_user_by_req_appservice_valid_token_valid_device_id(self): + """ + Tests that when an application service passes the device_id URL parameter + with the ID of a valid device for the user in question, + the requester instance tracks that device ID. + """ + masquerading_user_id = b"@doppelganger:matrix.org" + masquerading_device_id = b"DOPPELDEVICE" + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None + ) + app_service.is_interested_in_user = Mock(return_value=True) + self.store.get_app_service_by_token = Mock(return_value=app_service) + # This just needs to return a truth-y value. + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) + # This also needs to just return a truth-y value + self.store.get_device = simple_async_mock({"hidden": False}) + + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] + request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.assertEquals( + requester.user.to_string(), masquerading_user_id.decode("utf8") + ) + self.assertEquals(requester.device_id, masquerading_device_id.decode("utf8")) + + @override_config({"experimental_features": {"msc3202_device_masquerading": True}}) + def test_get_user_by_req_appservice_valid_token_invalid_device_id(self): + """ + Tests that when an application service passes the device_id URL parameter + with an ID that is not a valid device ID for the user in question, + the request fails with the appropriate error code. + """ + masquerading_user_id = b"@doppelganger:matrix.org" + masquerading_device_id = b"NOT_A_REAL_DEVICE_ID" + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None + ) + app_service.is_interested_in_user = Mock(return_value=True) + self.store.get_app_service_by_token = Mock(return_value=app_service) + # This just needs to return a truth-y value. + self.store.get_user_by_id = simple_async_mock({"is_guest": False}) + self.store.get_user_by_access_token = simple_async_mock(None) + # This also needs to just return a falsey value + self.store.get_device = simple_async_mock(None) + + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args[b"access_token"] = [self.test_token] + request.args[b"user_id"] = [masquerading_user_id] + request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + failure = self.get_failure(self.auth.get_user_by_req(request), AuthError) + self.assertEquals(failure.value.code, 400) + self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE) + def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = simple_async_mock( TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") -- cgit 1.5.1 From 221595414751f7b8fd0c79772c5ac4ffefeca10a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 21 Dec 2021 11:10:36 +0000 Subject: Various opentracing enhancements (#11619) * Wrap `auth.get_user_by_req` in an opentracing span give `get_user_by_req` its own opentracing span, since it can result in a non-trivial number of sub-spans which it is useful to group together. This requires a bit of reorganisation because it also sets some tags (and may force tracing) on the servlet span. * Emit opentracing span for encoding json responses This can be a significant time sink. * Rename all sync spans with a prefix * Write an opentracing span for encoding sync response * opentracing span to group generate_room_entries * opentracing spans within sync.encode_response * changelog * Use the `trace` decorator instead of context managers --- changelog.d/11619.misc | 1 + synapse/api/auth.py | 53 +++++++++++++++++++++++++++++++-------------- synapse/handlers/sync.py | 7 +++--- synapse/http/server.py | 19 ++++++++++++++-- synapse/rest/client/sync.py | 6 +++++ 5 files changed, 65 insertions(+), 21 deletions(-) create mode 100644 changelog.d/11619.misc (limited to 'synapse/api') diff --git a/changelog.d/11619.misc b/changelog.d/11619.misc new file mode 100644 index 0000000000..2125cbddd2 --- /dev/null +++ b/changelog.d/11619.misc @@ -0,0 +1 @@ +A number of improvements to opentracing support. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0bf58dff08..4a32d430bd 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -32,7 +32,7 @@ from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest -from synapse.logging import opentracing as opentracing +from synapse.logging.opentracing import active_span, force_tracing, start_active_span from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester, StateMap, UserID, create_requester from synapse.util.caches.lrucache import LruCache @@ -149,6 +149,42 @@ class Auth: is invalid. AuthError if access is denied for the user in the access token """ + parent_span = active_span() + with start_active_span("get_user_by_req"): + requester = await self._wrapped_get_user_by_req( + request, allow_guest, rights, allow_expired + ) + + if parent_span: + if requester.authenticated_entity in self._force_tracing_for_users: + # request tracing is enabled for this user, so we need to force it + # tracing on for the parent span (which will be the servlet span). + # + # It's too late for the get_user_by_req span to inherit the setting, + # so we also force it on for that. + force_tracing() + force_tracing(parent_span) + parent_span.set_tag( + "authenticated_entity", requester.authenticated_entity + ) + parent_span.set_tag("user_id", requester.user.to_string()) + if requester.device_id is not None: + parent_span.set_tag("device_id", requester.device_id) + if requester.app_service is not None: + parent_span.set_tag("appservice_id", requester.app_service.id) + return requester + + async def _wrapped_get_user_by_req( + self, + request: SynapseRequest, + allow_guest: bool, + rights: str, + allow_expired: bool, + ) -> Requester: + """Helper for get_user_by_req + + Once get_user_by_req has set up the opentracing span, this does the actual work. + """ try: ip_addr = request.getClientIP() user_agent = get_request_user_agent(request) @@ -177,14 +213,6 @@ class Auth: ) request.requester = user_id - if user_id in self._force_tracing_for_users: - opentracing.force_tracing() - opentracing.set_tag("authenticated_entity", user_id) - opentracing.set_tag("user_id", user_id) - if device_id is not None: - opentracing.set_tag("device_id", device_id) - opentracing.set_tag("appservice_id", app_service.id) - return requester user_info = await self.get_user_by_access_token( @@ -242,13 +270,6 @@ class Auth: ) request.requester = requester - if user_info.token_owner in self._force_tracing_for_users: - opentracing.force_tracing() - opentracing.set_tag("authenticated_entity", user_info.token_owner) - opentracing.set_tag("user_id", user_info.user_id) - if device_id: - opentracing.set_tag("device_id", device_id) - return requester except KeyError: raise MissingClientTokenError() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index bcd10cbb30..d24124d6ac 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -421,7 +421,7 @@ class SyncHandler: span to track the sync. See `generate_sync_result` for the next part of your indoctrination. """ - with start_active_span("current_sync_for_user"): + with start_active_span("sync.current_sync_for_user"): log_kv({"since_token": since_token}) sync_result = await self.generate_sync_result( sync_config, since_token, full_state @@ -1585,7 +1585,8 @@ class SyncHandler: ) logger.debug("Generated room entry for %s", room_entry.room_id) - await concurrently_execute(handle_room_entries, room_entries, 10) + with start_active_span("sync.generate_room_entries"): + await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) sync_result_builder.knocked.extend(knocked) @@ -2045,7 +2046,7 @@ class SyncHandler: since_token = room_builder.since_token upto_token = room_builder.upto_token - with start_active_span("generate_room_entry"): + with start_active_span("sync.generate_room_entry"): set_tag("room_id", room_id) log_kv({"events": len(events or ())}) diff --git a/synapse/http/server.py b/synapse/http/server.py index 7bbbe7648b..e302946591 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -58,12 +58,14 @@ from synapse.api.errors import ( ) from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background -from synapse.logging.opentracing import trace_servlet +from synapse.logging.opentracing import active_span, start_active_span, trace_servlet from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: + import opentracing + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -759,7 +761,20 @@ async def _async_write_json_to_request_in_thread( expensive. """ - json_str = await defer_to_thread(request.reactor, json_encoder, json_object) + def encode(opentracing_span: "Optional[opentracing.Span]") -> bytes: + # it might take a while for the threadpool to schedule us, so we write + # opentracing logs once we actually get scheduled, so that we can see how + # much that contributed. + if opentracing_span: + opentracing_span.log_kv({"event": "scheduled"}) + res = json_encoder(json_object) + if opentracing_span: + opentracing_span.log_kv({"event": "encoded"}) + return res + + with start_active_span("encode_json_response"): + span = active_span() + json_str = await defer_to_thread(request.reactor, encode, span) _write_bytes_to_request(request, json_str) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8c4b0f6e5d..e99a943d0d 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -48,6 +48,7 @@ from synapse.handlers.sync import ( from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest +from synapse.logging.opentracing import trace from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -222,6 +223,7 @@ class SyncRestServlet(RestServlet): logger.debug("Event formatting complete") return 200, response_content + @trace(opname="sync.encode_response") async def encode_response( self, time_now: int, @@ -332,6 +334,7 @@ class SyncRestServlet(RestServlet): ] } + @trace(opname="sync.encode_joined") async def encode_joined( self, rooms: List[JoinedSyncResult], @@ -368,6 +371,7 @@ class SyncRestServlet(RestServlet): return joined + @trace(opname="sync.encode_invited") async def encode_invited( self, rooms: List[InvitedSyncResult], @@ -406,6 +410,7 @@ class SyncRestServlet(RestServlet): return invited + @trace(opname="sync.encode_knocked") async def encode_knocked( self, rooms: List[KnockedSyncResult], @@ -460,6 +465,7 @@ class SyncRestServlet(RestServlet): return knocked + @trace(opname="sync.encode_archived") async def encode_archived( self, rooms: List[ArchivedSyncResult], -- cgit 1.5.1 From cbd82d0b2db069400b5d43373838817d8a0209e7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Dec 2021 13:47:12 -0500 Subject: Convert all namedtuples to attrs. (#11665) To improve type hints throughout the code. --- changelog.d/11665.misc | 1 + synapse/api/filtering.py | 3 +- synapse/config/repository.py | 34 +++---- synapse/federation/federation_base.py | 5 - synapse/federation/send_queue.py | 47 +++++----- synapse/handlers/appservice.py | 4 +- synapse/handlers/directory.py | 10 +- synapse/handlers/room_list.py | 22 ++--- synapse/handlers/typing.py | 14 ++- synapse/http/server.py | 10 +- synapse/replication/tcp/streams/_base.py | 129 +++++++++++++------------- synapse/replication/tcp/streams/federation.py | 15 ++- synapse/rest/media/v1/media_repository.py | 19 ++-- synapse/state/__init__.py | 5 +- synapse/storage/databases/main/directory.py | 10 +- synapse/storage/databases/main/events.py | 13 ++- synapse/storage/databases/main/room.py | 26 ++++-- synapse/storage/databases/main/search.py | 16 +++- synapse/storage/databases/main/state.py | 14 --- synapse/storage/databases/main/stream.py | 12 ++- synapse/types.py | 22 ++--- tests/replication/test_federation_ack.py | 6 +- 22 files changed, 231 insertions(+), 206 deletions(-) create mode 100644 changelog.d/11665.misc (limited to 'synapse/api') diff --git a/changelog.d/11665.misc b/changelog.d/11665.misc new file mode 100644 index 0000000000..e7cc8ff23f --- /dev/null +++ b/changelog.d/11665.misc @@ -0,0 +1 @@ +Convert `namedtuples` to `attrs`. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 13dd6ce248..d087c816db 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -351,8 +351,7 @@ class Filter: True if the event matches the filter. """ # We usually get the full "events" as dictionaries coming through, - # except for presence which actually gets passed around as its own - # namedtuple type. + # except for presence which actually gets passed around as its own type. if isinstance(event, UserPresenceState): user_id = event.user_id field_matchers = { diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b129b9dd68..1980351e77 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -14,10 +14,11 @@ import logging import os -from collections import namedtuple from typing import Dict, List, Tuple from urllib.request import getproxies_environment # type: ignore +import attr + from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict @@ -44,18 +45,20 @@ THUMBNAIL_SIZE_YAML = """\ HTTP_PROXY_SET_WARNING = """\ The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" -ThumbnailRequirement = namedtuple( - "ThumbnailRequirement", ["width", "height", "method", "media_type"] -) -MediaStorageProviderConfig = namedtuple( - "MediaStorageProviderConfig", - ( - "store_local", # Whether to store newly uploaded local files - "store_remote", # Whether to store newly downloaded remote files - "store_synchronous", # Whether to wait for successful storage for local uploads - ), -) +@attr.s(frozen=True, slots=True, auto_attribs=True) +class ThumbnailRequirement: + width: int + height: int + method: str + media_type: str + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class MediaStorageProviderConfig: + store_local: bool # Whether to store newly uploaded local files + store_remote: bool # Whether to store newly downloaded remote files + store_synchronous: bool # Whether to wait for successful storage for local uploads def parse_thumbnail_requirements( @@ -66,11 +69,10 @@ def parse_thumbnail_requirements( method, and thumbnail media type to precalculate Args: - thumbnail_sizes(list): List of dicts with "width", "height", and - "method" keys + thumbnail_sizes: List of dicts with "width", "height", and "method" keys + Returns: - Dictionary mapping from media type string to list of - ThumbnailRequirement tuples. + Dictionary mapping from media type string to list of ThumbnailRequirement. """ requirements: Dict[str, List[ThumbnailRequirement]] = {} for size in thumbnail_sizes: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index f56344a3b9..4df90e02d7 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership @@ -104,10 +103,6 @@ class FederationBase: return pdu -class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): - pass - - async def _check_sigs_on_pdu( keyring: Keyring, room_version: RoomVersion, pdu: EventBase ) -> None: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 63289a5a33..0d7c4f5067 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -30,7 +30,6 @@ Events are replicated via a separate events stream. """ import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Dict, @@ -43,6 +42,7 @@ from typing import ( Type, ) +import attr from sortedcontainers import SortedDict from synapse.api.presence import UserPresenceState @@ -382,13 +382,11 @@ class BaseFederationRow: raise NotImplementedError() -class PresenceDestinationsRow( - BaseFederationRow, - namedtuple( - "PresenceDestinationsRow", - ("state", "destinations"), # UserPresenceState # list[str] - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PresenceDestinationsRow(BaseFederationRow): + state: UserPresenceState + destinations: List[str] + TypeId = "pd" @staticmethod @@ -404,17 +402,15 @@ class PresenceDestinationsRow( buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow( - BaseFederationRow, - namedtuple( - "KeyedEduRow", - ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class KeyedEduRow(BaseFederationRow): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ + key: Tuple[str, ...] # the edu key passed to send_edu + edu: Edu + TypeId = "k" @staticmethod @@ -428,9 +424,12 @@ class KeyedEduRow( buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EduRow(BaseFederationRow): """Streams EDUs that don't have keys. See KeyedEduRow""" + edu: Edu + TypeId = "e" @staticmethod @@ -453,14 +452,14 @@ _rowtypes: Tuple[Type[BaseFederationRow], ...] = ( TypeToRow = {Row.TypeId: Row for Row in _rowtypes} -ParsedFederationStreamData = namedtuple( - "ParsedFederationStreamData", - ( - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - ), -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ParsedFederationStreamData: + # list of tuples of UserPresenceState and destinations + presence_destinations: List[Tuple[UserPresenceState, List[str]]] + # dict of destination -> { key -> Edu } + keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]] + # dict of destination -> [Edu] + edus: Dict[str, List[Edu]] def process_rows_for_federation( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 9abdad262b..7833e77e2b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -462,9 +462,9 @@ class ApplicationServicesHandler: Args: room_alias: The room alias to query. + Returns: - namedtuple: with keys "room_id" and "servers" or None if no - association can be found. + RoomAliasMapping or None if no association can be found. """ room_alias_str = room_alias.to_string() services = self.store.get_app_services() diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7ee5c47fd9..082f521791 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -278,13 +278,15 @@ class DirectoryHandler: users = await self.store.get_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} - servers = set(extra_servers) | set(servers) + servers_set = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. - if self.server_name in servers: - servers = [self.server_name] + [s for s in servers if s != self.server_name] + if self.server_name in servers_set: + servers = [self.server_name] + [ + s for s in servers_set if s != self.server_name + ] else: - servers = list(servers) + servers = list(servers_set) return {"room_id": room_id, "servers": servers} diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index ba7a14d651..1a33211a1f 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -13,9 +13,9 @@ # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING, Any, Optional, Tuple +import attr import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -474,16 +474,12 @@ class RoomListHandler: ) -class RoomListNextBatch( - namedtuple( - "RoomListNextBatch", - ( - "last_joined_members", # The count to get rooms after/before - "last_room_id", # The room_id to get rooms after/before - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch - ), - ) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomListNextBatch: + last_joined_members: int # The count to get rooms after/before + last_room_id: str # The room_id to get rooms after/before + direction_is_forward: bool # True if this is a next_batch, false if prev_batch + KEY_DICT = { "last_joined_members": "m", "last_room_id": "r", @@ -502,12 +498,12 @@ class RoomListNextBatch( def to_token(self) -> str: return encode_base64( msgpack.dumps( - {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()} ) ) def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch": - return self._replace(**kwds) + return attr.evolve(self, **kwds) def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 1676ebd057..e43c22832d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -13,9 +13,10 @@ # limitations under the License. import logging import random -from collections import namedtuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +import attr + from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( @@ -37,7 +38,10 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomMember: + room_id: str + user_id: str # How often we expect remote servers to resend us presence. @@ -119,7 +123,7 @@ class FollowerTypingHandler: self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member: RoomMember) -> bool: - return member.user_id in self._room_typing.get(member.room_id, []) + return member.user_id in self._room_typing.get(member.room_id, set()) async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: @@ -166,9 +170,9 @@ class FollowerTypingHandler: for row in rows: self._room_serials[row.room_id] = token - prev_typing = set(self._room_typing.get(row.room_id, [])) + prev_typing = self._room_typing.get(row.room_id, set()) now_typing = set(row.user_ids) - self._room_typing[row.room_id] = row.user_ids + self._room_typing[row.room_id] = now_typing if self.federation: run_as_background_process( diff --git a/synapse/http/server.py b/synapse/http/server.py index e302946591..09b4125489 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,7 +14,6 @@ # limitations under the License. import abc -import collections import html import logging import types @@ -37,6 +36,7 @@ from typing import ( Union, ) +import attr import jinja2 from canonicaljson import encode_canonical_json from typing_extensions import Protocol @@ -354,9 +354,11 @@ class DirectServeJsonResource(_AsyncResource): return_json_error(f, request) -_PathEntry = collections.namedtuple( - "_PathEntry", ["pattern", "callback", "servlet_classname"] -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PathEntry: + pattern: Pattern + callback: ServletCallback + servlet_classname: str class JsonResource(DirectServeJsonResource): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 743a01da08..5a2d90c530 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -15,7 +15,6 @@ import heapq import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -30,6 +29,7 @@ from typing import ( import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -226,17 +226,14 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ - BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class BackfillStreamRow: + event_id: str + room_id: str + type: str + state_key: Optional[str] + redacts: Optional[str] + relates_to: Optional[str] NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -256,18 +253,15 @@ class BackfillStream(Stream): class PresenceStream(Stream): - PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PresenceStreamRow: + user_id: str + state: str + last_active_ts: int + last_federation_update_ts: int + last_user_sync_ts: int + status_msg: str + currently_active: bool NAME = "presence" ROW_TYPE = PresenceStreamRow @@ -302,7 +296,7 @@ class PresenceFederationStream(Stream): send. """ - @attr.s(slots=True, auto_attribs=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class PresenceFederationStreamRow: destination: str user_id: str @@ -320,9 +314,10 @@ class PresenceFederationStream(Stream): class TypingStream(Stream): - TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TypingStreamRow: + room_id: str + user_ids: List[str] NAME = "typing" ROW_TYPE = TypingStreamRow @@ -348,16 +343,13 @@ class TypingStream(Stream): class ReceiptsStream(Stream): - ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ReceiptsStreamRow: + room_id: str + receipt_type: str + user_id: str + event_id: str + data: dict NAME = "receipts" ROW_TYPE = ReceiptsStreamRow @@ -374,7 +366,9 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): """A user has changed their push rules""" - PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushRulesStreamRow: + user_id: str NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -396,10 +390,12 @@ class PushRulesStream(Stream): class PushersStream(Stream): """A user has added/changed/removed a pusher""" - PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushersStreamRow: + user_id: str + app_id: str + pushkey: str + deleted: bool NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -419,7 +415,7 @@ class CachesStream(Stream): the cache on the workers """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class CachesStreamRow: """Stream to inform workers they should invalidate their cache. @@ -430,9 +426,9 @@ class CachesStream(Stream): invalidation_ts: Timestamp of when the invalidation took place. """ - cache_func = attr.ib(type=str) - keys = attr.ib(type=Optional[List[Any]]) - invalidation_ts = attr.ib(type=int) + cache_func: str + keys: Optional[List[Any]] + invalidation_ts: int NAME = "caches" ROW_TYPE = CachesStreamRow @@ -451,9 +447,9 @@ class DeviceListsStream(Stream): told about a device update. """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: - entity = attr.ib(type=str) + entity: str NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -470,7 +466,9 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): """New to_device messages for a client""" - ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceStreamRow: + entity: str NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -487,9 +485,11 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room""" - TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TagAccountDataStreamRow: + user_id: str + room_id: str + data: JsonDict NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -506,10 +506,11 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): """Global or per room account data was changed""" - AccountDataStreamRow = namedtuple( - "AccountDataStreamRow", - ("user_id", "room_id", "data_type"), # str # Optional[str] # str - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class AccountDataStreamRow: + user_id: str + room_id: Optional[str] + data_type: str NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -573,10 +574,12 @@ class AccountDataStream(Stream): class GroupServerStream(Stream): - GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class GroupsStreamRow: + group_id: str + user_id: str + type: str + content: JsonDict NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -593,7 +596,9 @@ class GroupServerStream(Stream): class UserSignatureStream(Stream): """A user has signed their own device with their user-signing key""" - UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class UserSignatureStreamRow: + user_id: str NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 0600cdbf36..4046bdec69 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -12,14 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple +import attr + from synapse.replication.tcp.streams._base import ( Stream, current_token_without_instance, make_http_update_function, ) +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -30,13 +32,10 @@ class FederationStream(Stream): sending disabled. """ - FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class FederationStreamRow: + type: str # the type of data as defined in the BaseFederationRows + data: JsonDict # serialization of a federation.send_queue.BaseFederationRow NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 244ba261bb..71b9a34b14 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -739,14 +739,21 @@ class MediaRepository: # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. thumbnails: Dict[Tuple[int, int, str], str] = {} - for r_width, r_height, r_method, r_type in requirements: - if r_method == "crop": - thumbnails.setdefault((r_width, r_height, r_type), r_method) - elif r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - thumbnails[(t_width, t_height, r_type)] = r_method + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method # Now we generate the thumbnails for each dimension, store it for (t_width, t_height, t_type), t_method in thumbnails.items(): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 446204dbe5..69ac8c3423 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq import logging -from collections import defaultdict, namedtuple +from collections import defaultdict from typing import ( TYPE_CHECKING, Any, @@ -69,9 +69,6 @@ state_groups_histogram = Histogram( ) -KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) - - EVICTION_TIMEOUT_SECONDS = 60 * 60 diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index a3442814d7..f76c6121e8 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr + from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached -RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomAliasMapping: + room_id: str + room_alias: str + servers: List[str] class DirectoryWorkerStore(CacheInvalidationWorkerStore): diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 81e67ece55..dd255aefb9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1976,14 +1976,17 @@ class PersistEventsStore: txn, self.store.get_retention_policy_for_room, (event.room_id,) ) - def store_event_search_txn(self, txn, event, key, value): + def store_event_search_txn( + self, txn: LoggingTransaction, event: EventBase, key: str, value: str + ) -> None: """Add event to the search table Args: - txn (cursor): - event (EventBase): - key (str): - value (str): + txn: The database transaction. + event: The event being added to the search table. + key: A key describing the search value (one of "content.name", + "content.topic", or "content.body") + value: The value from the event's content. """ self.store.store_search_entries_txn( txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4472335af9..c0e837854a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -13,11 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging from abc import abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + +import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError @@ -43,9 +54,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RatelimitOverride: + messages_per_second: int + burst_count: int class RoomSortOrder(Enum): @@ -207,6 +219,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -284,7 +297,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): """ where_clauses = [] - query_args = [] + query_args: List[Union[str, int]] = [] if network_tuple: if network_tuple.appservice_id: @@ -293,6 +306,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f87acfb866..2d085a5764 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,9 +14,10 @@ import logging import re -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set +import attr + from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -33,10 +34,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -SearchEntry = namedtuple( - "SearchEntry", - ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], -) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SearchEntry: + key: str + value: str + event_id: str + room_id: str + stream_ordering: Optional[int] + origin_server_ts: int def _clean_value_for_search(value: str) -> str: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 4bc044fb16..7e5a6aae18 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,6 @@ # limitations under the License. import collections.abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership @@ -43,19 +42,6 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 9488fd5094..b0642ca69f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -36,9 +36,9 @@ what sort order was used: """ import abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +import attr from frozendict import frozendict from twisted.internet import defer @@ -74,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological" # Used as return values for pagination APIs -_EventDictReturn = namedtuple( - "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventDictReturn: + event_id: str + topological_ordering: Optional[int] + stream_ordering: int def generate_pagination_where_clause( @@ -825,7 +827,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: - topo = row.topological_ordering + topo: Optional[int] = row.topological_ordering else: topo = None internal = event.internal_metadata diff --git a/synapse/types.py b/synapse/types.py index b06979e8e8..42aeaf6270 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -15,7 +15,6 @@ import abc import re import string -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -227,8 +226,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): localpart = attr.ib(type=str) domain = attr.ib(type=str) - # Because this class is a namedtuple of strings and booleans, it is deeply - # immutable. + # Because this is a frozen class, it is deeply immutable. def __copy__(self): return self @@ -708,16 +706,18 @@ class PersistedEventPosition: return RoomStreamToken(None, self.stream) -class ThirdPartyInstanceID( - namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThirdPartyInstanceID: + appservice_id: Optional[str] + network_id: Optional[str] + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) def __iter__(self): raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) - # Because this class is a namedtuple of strings, it is deeply immutable. + # Because this class is a frozen class, it is deeply immutable. def __copy__(self): return self @@ -725,22 +725,18 @@ class ThirdPartyInstanceID( return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str) -> "ThirdPartyInstanceID": bits = s.split("|", 2) if len(bits) != 2: raise SynapseError(400, "Invalid ID %r" % (s,)) return cls(appservice_id=bits[0], network_id=bits[1]) - def to_string(self): + def to_string(self) -> str: return "%s|%s" % (self.appservice_id, self.network_id) __str__ = to_string - @classmethod - def create(cls, appservice_id, network_id): - return cls(appservice_id=appservice_id, network_id=network_id) - @attr.s(slots=True) class ReadReceipt: diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 04a869e295..1b6a4bf4b0 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -62,7 +62,11 @@ class FederationAckTestCase(HomeserverTestCase): "federation", "master", token=10, - rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])], + rows=[ + FederationStream.FederationStreamRow( + type="x", data={"test": [1, 2, 3]} + ) + ], ) ) -- cgit 1.5.1