From 0147b3de20f313975226a9a3f319c77b90aa2793 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 14 Dec 2021 17:35:28 +0000 Subject: Add missing type hints to `synapse.logging.context` (#11556) --- synapse/federation/federation_server.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8e37e76206..cf067b56c6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -30,7 +30,6 @@ from typing import ( from prometheus_client import Counter, Gauge, Histogram -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -67,7 +66,7 @@ from synapse.replication.http.federation import ( from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError -from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name @@ -360,13 +359,13 @@ class FederationServer(FederationBase): # want to block things like to device messages from reaching clients # behind the potentially expensive handling of PDUs. pdu_results, _ = await make_deferred_yieldable( - defer.gatherResults( - [ + gather_results( + ( run_in_background( self._handle_pdus_in_txn, origin, transaction, request_time ), run_in_background(self._handle_edus_in_txn, origin, transaction), - ], + ), consumeErrors=True, ).addErrback(unwrapFirstError) ) -- cgit 1.5.1 From 60fa4935b5d3ee26f9ebb4b25ec74bed26d3c98d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 20 Dec 2021 17:45:03 +0000 Subject: Improve opentracing for incoming HTTP requests (#11618) * remove `start_active_span_from_request` Instead, pull out a separate function, `span_context_from_request`, to extract the parent span, which we can then pass into `start_active_span` as normal. This seems to be clearer all round. * Remove redundant tags from `incoming-federation-request` These are all wrapped up inside a parent span generated in AsyncResource, so there's no point duplicating all the tags that are set there. * Leave request spans open until the request completes It may take some time for the response to be encoded into JSON, and that JSON to be streamed back to the client, and really we want that inside the top-level span, so let's hand responsibility for closure to the SynapseRequest. * opentracing logs for HTTP request events * changelog --- changelog.d/11618.misc | 1 + synapse/federation/transport/server/_base.py | 39 ++++++---------- synapse/http/site.py | 30 +++++++++++- synapse/logging/opentracing.py | 68 +++++++++------------------- 4 files changed, 65 insertions(+), 73 deletions(-) create mode 100644 changelog.d/11618.misc (limited to 'synapse/federation') diff --git a/changelog.d/11618.misc b/changelog.d/11618.misc new file mode 100644 index 0000000000..4076b30bf7 --- /dev/null +++ b/changelog.d/11618.misc @@ -0,0 +1 @@ +Improve opentracing support for incoming HTTP requests. diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index dc39e3537b..da1fbf8b63 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -22,13 +22,11 @@ from synapse.api.urls import FEDERATION_V1_PREFIX from synapse.http.server import HttpServer, ServletCallback from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.logging import opentracing from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( - SynapseTags, - start_active_span, - start_active_span_from_request, - tags, + set_tag, + span_context_from_request, + start_active_span_follows_from, whitelisted_homeserver, ) from synapse.server import HomeServer @@ -279,30 +277,19 @@ class BaseFederationServlet: logger.warning("authenticate_request failed: %s", e) raise - request_tags = { - SynapseTags.REQUEST_ID: request.get_request_id(), - tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, - tags.HTTP_METHOD: request.get_method(), - tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientIP(), - "authenticated_entity": origin, - "servlet_name": request.request_metrics.name, - } - - # Only accept the span context if the origin is authenticated - # and whitelisted + # update the active opentracing span with the authenticated entity + set_tag("authenticated_entity", origin) + + # if the origin is authenticated and whitelisted, link to its span context + context = None if origin and whitelisted_homeserver(origin): - scope = start_active_span_from_request( - request, "incoming-federation-request", tags=request_tags - ) - else: - scope = start_active_span( - "incoming-federation-request", tags=request_tags - ) + context = span_context_from_request(request) - with scope: - opentracing.inject_response_headers(request.responseHeaders) + scope = start_active_span_follows_from( + "incoming-federation-request", contexts=(context,) if context else () + ) + with scope: if origin and self.RATELIMIT: with ratelimiter.ratelimit(origin) as d: await d diff --git a/synapse/http/site.py b/synapse/http/site.py index 9f68d7e191..80f7a2ff58 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Any, Generator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union import attr from zope.interface import implementer @@ -35,6 +35,9 @@ from synapse.logging.context import ( ) from synapse.types import Requester +if TYPE_CHECKING: + import opentracing + logger = logging.getLogger(__name__) _next_request_seq = 0 @@ -81,6 +84,10 @@ class SynapseRequest(Request): # server name, for client requests this is the Requester object. self._requester: Optional[Union[Requester, str]] = None + # An opentracing span for this request. Will be closed when the request is + # completely processed. + self._opentracing_span: "Optional[opentracing.Span]" = None + # we can't yet create the logcontext, as we don't know the method. self.logcontext: Optional[LoggingContext] = None @@ -148,6 +155,13 @@ class SynapseRequest(Request): # If there's no authenticated entity, it was the requester. self.logcontext.request.authenticated_entity = authenticated_entity or requester + def set_opentracing_span(self, span: "opentracing.Span") -> None: + """attach an opentracing span to this request + + Doing so will cause the span to be closed when we finish processing the request + """ + self._opentracing_span = span + def get_request_id(self) -> str: return "%s-%i" % (self.get_method(), self.request_seq) @@ -286,6 +300,9 @@ class SynapseRequest(Request): self._processing_finished_time = time.time() self._is_processing = False + if self._opentracing_span: + self._opentracing_span.log_kv({"event": "finished processing"}) + # if we've already sent the response, log it now; otherwise, we wait for the # response to be sent. if self.finish_time is not None: @@ -299,6 +316,8 @@ class SynapseRequest(Request): """ self.finish_time = time.time() Request.finish(self) + if self._opentracing_span: + self._opentracing_span.log_kv({"event": "response sent"}) if not self._is_processing: assert self.logcontext is not None with PreserveLoggingContext(self.logcontext): @@ -333,6 +352,11 @@ class SynapseRequest(Request): with PreserveLoggingContext(self.logcontext): logger.info("Connection from client lost before response was sent") + if self._opentracing_span: + self._opentracing_span.log_kv( + {"event": "client connection lost", "reason": str(reason.value)} + ) + if not self._is_processing: self._finished_processing() @@ -421,6 +445,10 @@ class SynapseRequest(Request): usage.evt_db_fetch_count, ) + # complete the opentracing span, if any. + if self._opentracing_span: + self._opentracing_span.finish() + try: self.request_metrics.stop(self.finish_time, self.code, self.sentLength) except Exception as e: diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 5d93ab07f1..6364290615 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -173,6 +173,7 @@ from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Typ import attr from twisted.internet import defer +from twisted.web.http import Request from twisted.web.http_headers import Headers from synapse.config import ConfigError @@ -490,48 +491,6 @@ def start_active_span_follows_from( return scope -def start_active_span_from_request( - request, - operation_name, - references=None, - tags=None, - start_time=None, - ignore_active_span=False, - finish_on_close=True, -): - """ - Extracts a span context from a Twisted Request. - args: - headers (twisted.web.http.Request) - - For the other args see opentracing.tracer - - returns: - span_context (opentracing.span.SpanContext) - """ - # Twisted encodes the values as lists whereas opentracing doesn't. - # So, we take the first item in the list. - # Also, twisted uses byte arrays while opentracing expects strings. - - if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] - - header_dict = { - k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() - } - context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict) - - return opentracing.tracer.start_active_span( - operation_name, - child_of=context, - references=references, - tags=tags, - start_time=start_time, - ignore_active_span=ignore_active_span, - finish_on_close=finish_on_close, - ) - - def start_active_span_from_edu( edu_content, operation_name, @@ -743,6 +702,20 @@ def active_span_context_as_string(): return json_encoder.encode(carrier) +def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]": + """Extract an opentracing context from the headers on an HTTP request + + This is useful when we have received an HTTP request from another part of our + system, and want to link our spans to those of the remote system. + """ + if not opentracing: + return None + header_dict = { + k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() + } + return opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict) + + @only_if_tracing def span_context_from_string(carrier): """ @@ -882,10 +855,13 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False): } request_name = request.request_metrics.name - if extract_context: - scope = start_active_span_from_request(request, request_name) - else: - scope = start_active_span(request_name) + context = span_context_from_request(request) if extract_context else None + + # we configure the scope not to finish the span immediately on exit, and instead + # pass the span into the SynapseRequest, which will finish it once we've finished + # sending the response to the client. + scope = start_active_span(request_name, child_of=context, finish_on_close=False) + request.set_opentracing_span(scope.span) with scope: inject_response_headers(request.responseHeaders) -- cgit 1.5.1 From cbd82d0b2db069400b5d43373838817d8a0209e7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Dec 2021 13:47:12 -0500 Subject: Convert all namedtuples to attrs. (#11665) To improve type hints throughout the code. --- changelog.d/11665.misc | 1 + synapse/api/filtering.py | 3 +- synapse/config/repository.py | 34 +++---- synapse/federation/federation_base.py | 5 - synapse/federation/send_queue.py | 47 +++++----- synapse/handlers/appservice.py | 4 +- synapse/handlers/directory.py | 10 +- synapse/handlers/room_list.py | 22 ++--- synapse/handlers/typing.py | 14 ++- synapse/http/server.py | 10 +- synapse/replication/tcp/streams/_base.py | 129 +++++++++++++------------- synapse/replication/tcp/streams/federation.py | 15 ++- synapse/rest/media/v1/media_repository.py | 19 ++-- synapse/state/__init__.py | 5 +- synapse/storage/databases/main/directory.py | 10 +- synapse/storage/databases/main/events.py | 13 ++- synapse/storage/databases/main/room.py | 26 ++++-- synapse/storage/databases/main/search.py | 16 +++- synapse/storage/databases/main/state.py | 14 --- synapse/storage/databases/main/stream.py | 12 ++- synapse/types.py | 22 ++--- tests/replication/test_federation_ack.py | 6 +- 22 files changed, 231 insertions(+), 206 deletions(-) create mode 100644 changelog.d/11665.misc (limited to 'synapse/federation') diff --git a/changelog.d/11665.misc b/changelog.d/11665.misc new file mode 100644 index 0000000000..e7cc8ff23f --- /dev/null +++ b/changelog.d/11665.misc @@ -0,0 +1 @@ +Convert `namedtuples` to `attrs`. diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 13dd6ce248..d087c816db 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -351,8 +351,7 @@ class Filter: True if the event matches the filter. """ # We usually get the full "events" as dictionaries coming through, - # except for presence which actually gets passed around as its own - # namedtuple type. + # except for presence which actually gets passed around as its own type. if isinstance(event, UserPresenceState): user_id = event.user_id field_matchers = { diff --git a/synapse/config/repository.py b/synapse/config/repository.py index b129b9dd68..1980351e77 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -14,10 +14,11 @@ import logging import os -from collections import namedtuple from typing import Dict, List, Tuple from urllib.request import getproxies_environment # type: ignore +import attr + from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import JsonDict @@ -44,18 +45,20 @@ THUMBNAIL_SIZE_YAML = """\ HTTP_PROXY_SET_WARNING = """\ The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" -ThumbnailRequirement = namedtuple( - "ThumbnailRequirement", ["width", "height", "method", "media_type"] -) -MediaStorageProviderConfig = namedtuple( - "MediaStorageProviderConfig", - ( - "store_local", # Whether to store newly uploaded local files - "store_remote", # Whether to store newly downloaded remote files - "store_synchronous", # Whether to wait for successful storage for local uploads - ), -) +@attr.s(frozen=True, slots=True, auto_attribs=True) +class ThumbnailRequirement: + width: int + height: int + method: str + media_type: str + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class MediaStorageProviderConfig: + store_local: bool # Whether to store newly uploaded local files + store_remote: bool # Whether to store newly downloaded remote files + store_synchronous: bool # Whether to wait for successful storage for local uploads def parse_thumbnail_requirements( @@ -66,11 +69,10 @@ def parse_thumbnail_requirements( method, and thumbnail media type to precalculate Args: - thumbnail_sizes(list): List of dicts with "width", "height", and - "method" keys + thumbnail_sizes: List of dicts with "width", "height", and "method" keys + Returns: - Dictionary mapping from media type string to list of - ThumbnailRequirement tuples. + Dictionary mapping from media type string to list of ThumbnailRequirement. """ requirements: Dict[str, List[ThumbnailRequirement]] = {} for size in thumbnail_sizes: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index f56344a3b9..4df90e02d7 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership @@ -104,10 +103,6 @@ class FederationBase: return pdu -class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])): - pass - - async def _check_sigs_on_pdu( keyring: Keyring, room_version: RoomVersion, pdu: EventBase ) -> None: diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 63289a5a33..0d7c4f5067 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -30,7 +30,6 @@ Events are replicated via a separate events stream. """ import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Dict, @@ -43,6 +42,7 @@ from typing import ( Type, ) +import attr from sortedcontainers import SortedDict from synapse.api.presence import UserPresenceState @@ -382,13 +382,11 @@ class BaseFederationRow: raise NotImplementedError() -class PresenceDestinationsRow( - BaseFederationRow, - namedtuple( - "PresenceDestinationsRow", - ("state", "destinations"), # UserPresenceState # list[str] - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PresenceDestinationsRow(BaseFederationRow): + state: UserPresenceState + destinations: List[str] + TypeId = "pd" @staticmethod @@ -404,17 +402,15 @@ class PresenceDestinationsRow( buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow( - BaseFederationRow, - namedtuple( - "KeyedEduRow", - ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu - ), -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class KeyedEduRow(BaseFederationRow): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ + key: Tuple[str, ...] # the edu key passed to send_edu + edu: Edu + TypeId = "k" @staticmethod @@ -428,9 +424,12 @@ class KeyedEduRow( buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EduRow(BaseFederationRow): """Streams EDUs that don't have keys. See KeyedEduRow""" + edu: Edu + TypeId = "e" @staticmethod @@ -453,14 +452,14 @@ _rowtypes: Tuple[Type[BaseFederationRow], ...] = ( TypeToRow = {Row.TypeId: Row for Row in _rowtypes} -ParsedFederationStreamData = namedtuple( - "ParsedFederationStreamData", - ( - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - ), -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ParsedFederationStreamData: + # list of tuples of UserPresenceState and destinations + presence_destinations: List[Tuple[UserPresenceState, List[str]]] + # dict of destination -> { key -> Edu } + keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]] + # dict of destination -> [Edu] + edus: Dict[str, List[Edu]] def process_rows_for_federation( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 9abdad262b..7833e77e2b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -462,9 +462,9 @@ class ApplicationServicesHandler: Args: room_alias: The room alias to query. + Returns: - namedtuple: with keys "room_id" and "servers" or None if no - association can be found. + RoomAliasMapping or None if no association can be found. """ room_alias_str = room_alias.to_string() services = self.store.get_app_services() diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7ee5c47fd9..082f521791 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -278,13 +278,15 @@ class DirectoryHandler: users = await self.store.get_users_in_room(room_id) extra_servers = {get_domain_from_id(u) for u in users} - servers = set(extra_servers) | set(servers) + servers_set = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. - if self.server_name in servers: - servers = [self.server_name] + [s for s in servers if s != self.server_name] + if self.server_name in servers_set: + servers = [self.server_name] + [ + s for s in servers_set if s != self.server_name + ] else: - servers = list(servers) + servers = list(servers_set) return {"room_id": room_id, "servers": servers} diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index ba7a14d651..1a33211a1f 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -13,9 +13,9 @@ # limitations under the License. import logging -from collections import namedtuple from typing import TYPE_CHECKING, Any, Optional, Tuple +import attr import msgpack from unpaddedbase64 import decode_base64, encode_base64 @@ -474,16 +474,12 @@ class RoomListHandler: ) -class RoomListNextBatch( - namedtuple( - "RoomListNextBatch", - ( - "last_joined_members", # The count to get rooms after/before - "last_room_id", # The room_id to get rooms after/before - "direction_is_forward", # Bool if this is a next_batch, false if prev_batch - ), - ) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomListNextBatch: + last_joined_members: int # The count to get rooms after/before + last_room_id: str # The room_id to get rooms after/before + direction_is_forward: bool # True if this is a next_batch, false if prev_batch + KEY_DICT = { "last_joined_members": "m", "last_room_id": "r", @@ -502,12 +498,12 @@ class RoomListNextBatch( def to_token(self) -> str: return encode_base64( msgpack.dumps( - {self.KEY_DICT[key]: val for key, val in self._asdict().items()} + {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()} ) ) def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch": - return self._replace(**kwds) + return attr.evolve(self, **kwds) def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 1676ebd057..e43c22832d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -13,9 +13,10 @@ # limitations under the License. import logging import random -from collections import namedtuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple +import attr + from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( @@ -37,7 +38,10 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomMember: + room_id: str + user_id: str # How often we expect remote servers to resend us presence. @@ -119,7 +123,7 @@ class FollowerTypingHandler: self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member: RoomMember) -> bool: - return member.user_id in self._room_typing.get(member.room_id, []) + return member.user_id in self._room_typing.get(member.room_id, set()) async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: @@ -166,9 +170,9 @@ class FollowerTypingHandler: for row in rows: self._room_serials[row.room_id] = token - prev_typing = set(self._room_typing.get(row.room_id, [])) + prev_typing = self._room_typing.get(row.room_id, set()) now_typing = set(row.user_ids) - self._room_typing[row.room_id] = row.user_ids + self._room_typing[row.room_id] = now_typing if self.federation: run_as_background_process( diff --git a/synapse/http/server.py b/synapse/http/server.py index e302946591..09b4125489 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,7 +14,6 @@ # limitations under the License. import abc -import collections import html import logging import types @@ -37,6 +36,7 @@ from typing import ( Union, ) +import attr import jinja2 from canonicaljson import encode_canonical_json from typing_extensions import Protocol @@ -354,9 +354,11 @@ class DirectServeJsonResource(_AsyncResource): return_json_error(f, request) -_PathEntry = collections.namedtuple( - "_PathEntry", ["pattern", "callback", "servlet_classname"] -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PathEntry: + pattern: Pattern + callback: ServletCallback + servlet_classname: str class JsonResource(DirectServeJsonResource): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 743a01da08..5a2d90c530 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -15,7 +15,6 @@ import heapq import logging -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -30,6 +29,7 @@ from typing import ( import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -226,17 +226,14 @@ class BackfillStream(Stream): or it went from being an outlier to not. """ - BackfillStreamRow = namedtuple( - "BackfillStreamRow", - ( - "event_id", # str - "room_id", # str - "type", # str - "state_key", # str, optional - "redacts", # str, optional - "relates_to", # str, optional - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class BackfillStreamRow: + event_id: str + room_id: str + type: str + state_key: Optional[str] + redacts: Optional[str] + relates_to: Optional[str] NAME = "backfill" ROW_TYPE = BackfillStreamRow @@ -256,18 +253,15 @@ class BackfillStream(Stream): class PresenceStream(Stream): - PresenceStreamRow = namedtuple( - "PresenceStreamRow", - ( - "user_id", # str - "state", # str - "last_active_ts", # int - "last_federation_update_ts", # int - "last_user_sync_ts", # int - "status_msg", # str - "currently_active", # bool - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PresenceStreamRow: + user_id: str + state: str + last_active_ts: int + last_federation_update_ts: int + last_user_sync_ts: int + status_msg: str + currently_active: bool NAME = "presence" ROW_TYPE = PresenceStreamRow @@ -302,7 +296,7 @@ class PresenceFederationStream(Stream): send. """ - @attr.s(slots=True, auto_attribs=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class PresenceFederationStreamRow: destination: str user_id: str @@ -320,9 +314,10 @@ class PresenceFederationStream(Stream): class TypingStream(Stream): - TypingStreamRow = namedtuple( - "TypingStreamRow", ("room_id", "user_ids") # str # list(str) - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TypingStreamRow: + room_id: str + user_ids: List[str] NAME = "typing" ROW_TYPE = TypingStreamRow @@ -348,16 +343,13 @@ class TypingStream(Stream): class ReceiptsStream(Stream): - ReceiptsStreamRow = namedtuple( - "ReceiptsStreamRow", - ( - "room_id", # str - "receipt_type", # str - "user_id", # str - "event_id", # str - "data", # dict - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ReceiptsStreamRow: + room_id: str + receipt_type: str + user_id: str + event_id: str + data: dict NAME = "receipts" ROW_TYPE = ReceiptsStreamRow @@ -374,7 +366,9 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): """A user has changed their push rules""" - PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushRulesStreamRow: + user_id: str NAME = "push_rules" ROW_TYPE = PushRulesStreamRow @@ -396,10 +390,12 @@ class PushRulesStream(Stream): class PushersStream(Stream): """A user has added/changed/removed a pusher""" - PushersStreamRow = namedtuple( - "PushersStreamRow", - ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class PushersStreamRow: + user_id: str + app_id: str + pushkey: str + deleted: bool NAME = "pushers" ROW_TYPE = PushersStreamRow @@ -419,7 +415,7 @@ class CachesStream(Stream): the cache on the workers """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class CachesStreamRow: """Stream to inform workers they should invalidate their cache. @@ -430,9 +426,9 @@ class CachesStream(Stream): invalidation_ts: Timestamp of when the invalidation took place. """ - cache_func = attr.ib(type=str) - keys = attr.ib(type=Optional[List[Any]]) - invalidation_ts = attr.ib(type=int) + cache_func: str + keys: Optional[List[Any]] + invalidation_ts: int NAME = "caches" ROW_TYPE = CachesStreamRow @@ -451,9 +447,9 @@ class DeviceListsStream(Stream): told about a device update. """ - @attr.s(slots=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListsStreamRow: - entity = attr.ib(type=str) + entity: str NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow @@ -470,7 +466,9 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): """New to_device messages for a client""" - ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceStreamRow: + entity: str NAME = "to_device" ROW_TYPE = ToDeviceStreamRow @@ -487,9 +485,11 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): """Someone added/removed a tag for a room""" - TagAccountDataStreamRow = namedtuple( - "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class TagAccountDataStreamRow: + user_id: str + room_id: str + data: JsonDict NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow @@ -506,10 +506,11 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): """Global or per room account data was changed""" - AccountDataStreamRow = namedtuple( - "AccountDataStreamRow", - ("user_id", "room_id", "data_type"), # str # Optional[str] # str - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class AccountDataStreamRow: + user_id: str + room_id: Optional[str] + data_type: str NAME = "account_data" ROW_TYPE = AccountDataStreamRow @@ -573,10 +574,12 @@ class AccountDataStream(Stream): class GroupServerStream(Stream): - GroupsStreamRow = namedtuple( - "GroupsStreamRow", - ("group_id", "user_id", "type", "content"), # str # str # str # dict - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class GroupsStreamRow: + group_id: str + user_id: str + type: str + content: JsonDict NAME = "groups" ROW_TYPE = GroupsStreamRow @@ -593,7 +596,9 @@ class GroupServerStream(Stream): class UserSignatureStream(Stream): """A user has signed their own device with their user-signing key""" - UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str + @attr.s(slots=True, frozen=True, auto_attribs=True) + class UserSignatureStreamRow: + user_id: str NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index 0600cdbf36..4046bdec69 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -12,14 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple +import attr + from synapse.replication.tcp.streams._base import ( Stream, current_token_without_instance, make_http_update_function, ) +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -30,13 +32,10 @@ class FederationStream(Stream): sending disabled. """ - FederationStreamRow = namedtuple( - "FederationStreamRow", - ( - "type", # str, the type of data as defined in the BaseFederationRows - "data", # dict, serialization of a federation.send_queue.BaseFederationRow - ), - ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class FederationStreamRow: + type: str # the type of data as defined in the BaseFederationRows + data: JsonDict # serialization of a federation.send_queue.BaseFederationRow NAME = "federation" ROW_TYPE = FederationStreamRow diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 244ba261bb..71b9a34b14 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -739,14 +739,21 @@ class MediaRepository: # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. thumbnails: Dict[Tuple[int, int, str], str] = {} - for r_width, r_height, r_method, r_type in requirements: - if r_method == "crop": - thumbnails.setdefault((r_width, r_height, r_type), r_method) - elif r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - thumbnails[(t_width, t_height, r_type)] = r_method + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method # Now we generate the thumbnails for each dimension, store it for (t_width, t_height, t_type), t_method in thumbnails.items(): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 446204dbe5..69ac8c3423 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. import heapq import logging -from collections import defaultdict, namedtuple +from collections import defaultdict from typing import ( TYPE_CHECKING, Any, @@ -69,9 +69,6 @@ state_groups_histogram = Histogram( ) -KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) - - EVICTION_TIMEOUT_SECONDS = 60 * 60 diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index a3442814d7..f76c6121e8 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Iterable, List, Optional, Tuple +import attr + from synapse.api.errors import SynapseError from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.types import RoomAlias from synapse.util.caches.descriptors import cached -RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers")) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomAliasMapping: + room_id: str + room_alias: str + servers: List[str] class DirectoryWorkerStore(CacheInvalidationWorkerStore): diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 81e67ece55..dd255aefb9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1976,14 +1976,17 @@ class PersistEventsStore: txn, self.store.get_retention_policy_for_room, (event.room_id,) ) - def store_event_search_txn(self, txn, event, key, value): + def store_event_search_txn( + self, txn: LoggingTransaction, event: EventBase, key: str, value: str + ) -> None: """Add event to the search table Args: - txn (cursor): - event (EventBase): - key (str): - value (str): + txn: The database transaction. + event: The event being added to the search table. + key: A key describing the search value (one of "content.name", + "content.topic", or "content.body") + value: The value from the event's content. """ self.store.store_search_entries_txn( txn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 4472335af9..c0e837854a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -13,11 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging from abc import abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) + +import attr from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError @@ -43,9 +54,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -RatelimitOverride = collections.namedtuple( - "RatelimitOverride", ("messages_per_second", "burst_count") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RatelimitOverride: + messages_per_second: int + burst_count: int class RoomSortOrder(Enum): @@ -207,6 +219,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ @@ -284,7 +297,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): """ where_clauses = [] - query_args = [] + query_args: List[Union[str, int]] = [] if network_tuple: if network_tuple.appservice_id: @@ -293,6 +306,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): WHERE appservice_id = ? AND network_id = ? """ query_args.append(network_tuple.appservice_id) + assert network_tuple.network_id is not None query_args.append(network_tuple.network_id) else: published_sql = """ diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index f87acfb866..2d085a5764 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,9 +14,10 @@ import logging import re -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set +import attr + from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -33,10 +34,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -SearchEntry = namedtuple( - "SearchEntry", - ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"], -) + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SearchEntry: + key: str + value: str + event_id: str + room_id: str + stream_ordering: Optional[int] + origin_server_ts: int def _clean_value_for_search(value: str) -> str: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 4bc044fb16..7e5a6aae18 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -14,7 +14,6 @@ # limitations under the License. import collections.abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership @@ -43,19 +42,6 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class _GetStateGroupDelta( - namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) -): - """Return type of get_state_group_delta that implements __len__, which lets - us use the itrable flag when caching - """ - - __slots__ = [] - - def __len__(self): - return len(self.delta_ids) if self.delta_ids else 0 - - # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 9488fd5094..b0642ca69f 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -36,9 +36,9 @@ what sort order was used: """ import abc import logging -from collections import namedtuple from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +import attr from frozendict import frozendict from twisted.internet import defer @@ -74,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological" # Used as return values for pagination APIs -_EventDictReturn = namedtuple( - "_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering") -) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventDictReturn: + event_id: str + topological_ordering: Optional[int] + stream_ordering: int def generate_pagination_where_clause( @@ -825,7 +827,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): for event, row in zip(events, rows): stream = row.stream_ordering if topo_order and row.topological_ordering: - topo = row.topological_ordering + topo: Optional[int] = row.topological_ordering else: topo = None internal = event.internal_metadata diff --git a/synapse/types.py b/synapse/types.py index b06979e8e8..42aeaf6270 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -15,7 +15,6 @@ import abc import re import string -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, @@ -227,8 +226,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta): localpart = attr.ib(type=str) domain = attr.ib(type=str) - # Because this class is a namedtuple of strings and booleans, it is deeply - # immutable. + # Because this is a frozen class, it is deeply immutable. def __copy__(self): return self @@ -708,16 +706,18 @@ class PersistedEventPosition: return RoomStreamToken(None, self.stream) -class ThirdPartyInstanceID( - namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThirdPartyInstanceID: + appservice_id: Optional[str] + network_id: Optional[str] + # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) def __iter__(self): raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) - # Because this class is a namedtuple of strings, it is deeply immutable. + # Because this class is a frozen class, it is deeply immutable. def __copy__(self): return self @@ -725,22 +725,18 @@ class ThirdPartyInstanceID( return self @classmethod - def from_string(cls, s): + def from_string(cls, s: str) -> "ThirdPartyInstanceID": bits = s.split("|", 2) if len(bits) != 2: raise SynapseError(400, "Invalid ID %r" % (s,)) return cls(appservice_id=bits[0], network_id=bits[1]) - def to_string(self): + def to_string(self) -> str: return "%s|%s" % (self.appservice_id, self.network_id) __str__ = to_string - @classmethod - def create(cls, appservice_id, network_id): - return cls(appservice_id=appservice_id, network_id=network_id) - @attr.s(slots=True) class ReadReceipt: diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 04a869e295..1b6a4bf4b0 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -62,7 +62,11 @@ class FederationAckTestCase(HomeserverTestCase): "federation", "master", token=10, - rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])], + rows=[ + FederationStream.FederationStreamRow( + type="x", data={"test": [1, 2, 3]} + ) + ], ) ) -- cgit 1.5.1 From 878aa5529399ece8fbb68674a6de86d6958bacae Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 4 Jan 2022 16:31:32 +0000 Subject: `FederationClient.backfill`: stop flagging events as outliers (#11632) Events returned by `backfill` should not be flagged as outliers. Fixes: ``` AssertionError: null File "synapse/handlers/federation.py", line 313, in try_backfill dom, room_id, limit=100, extremities=extremities File "synapse/handlers/federation_event.py", line 517, in backfill await self._process_pulled_events(dest, events, backfilled=True) File "synapse/handlers/federation_event.py", line 642, in _process_pulled_events await self._process_pulled_event(origin, ev, backfilled=backfilled) File "synapse/handlers/federation_event.py", line 669, in _process_pulled_event assert not event.internal_metadata.is_outlier() ``` See https://sentry.matrix.org/sentry/synapse-matrixorg/issues/231992 Fixes #8894. --- changelog.d/11632.bugfix | 1 + synapse/federation/federation_client.py | 2 +- synapse/handlers/federation_event.py | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11632.bugfix (limited to 'synapse/federation') diff --git a/changelog.d/11632.bugfix b/changelog.d/11632.bugfix new file mode 100644 index 0000000000..c73d41652a --- /dev/null +++ b/changelog.d/11632.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.19.3 which could sometimes cause `AssertionError`s when backfilling rooms over federation. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index fee1477ab6..7353c2b6b1 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -272,7 +272,7 @@ class FederationClient(FederationBase): # Check signatures and hash of pdus, removing any from the list that fail checks pdus[:] = await self._check_sigs_and_hash_and_fetch( - dest, pdus, outlier=True, room_version=room_version + dest, pdus, room_version=room_version ) return pdus diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index d08e48da58..7c81e3651f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -666,7 +666,9 @@ class FederationEventHandler: logger.info("Processing pulled event %s", event) # these should not be outliers. - assert not event.internal_metadata.is_outlier() + assert ( + not event.internal_metadata.is_outlier() + ), "pulled event unexpectedly flagged as outlier" event_id = event.event_id -- cgit 1.5.1 From 84bfe47b01402cf037417f0026ed4077223ed259 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 5 Jan 2022 11:41:49 +0000 Subject: Re-apply: Move glob_to_regex and re_word_boundary to matrix-python-common #11505 (#11687) Co-authored-by: Sean Quah --- changelog.d/11505.misc | 1 + changelog.d/11687.misc | 1 + synapse/config/room_directory.py | 3 +- synapse/config/tls.py | 3 +- synapse/federation/federation_server.py | 3 +- synapse/push/push_rule_evaluator.py | 7 ++-- synapse/python_dependencies.py | 1 + synapse/util/__init__.py | 59 +-------------------------------- tests/util/test_glob_to_regex.py | 59 --------------------------------- 9 files changed, 14 insertions(+), 123 deletions(-) create mode 100644 changelog.d/11505.misc create mode 100644 changelog.d/11687.misc delete mode 100644 tests/util/test_glob_to_regex.py (limited to 'synapse/federation') diff --git a/changelog.d/11505.misc b/changelog.d/11505.misc new file mode 100644 index 0000000000..926b562fad --- /dev/null +++ b/changelog.d/11505.misc @@ -0,0 +1 @@ +Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. diff --git a/changelog.d/11687.misc b/changelog.d/11687.misc new file mode 100644 index 0000000000..926b562fad --- /dev/null +++ b/changelog.d/11687.misc @@ -0,0 +1 @@ +Move `glob_to_regex` and `re_word_boundary` to `matrix-python-common`. diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 57316c59b6..3c5e0f7ce7 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -15,8 +15,9 @@ from typing import List +from matrix_common.regex import glob_to_regex + from synapse.types import JsonDict -from synapse.util import glob_to_regex from ._base import Config, ConfigError diff --git a/synapse/config/tls.py b/synapse/config/tls.py index ffb316e4c0..6e673d65a7 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -16,11 +16,12 @@ import logging import os from typing import List, Optional, Pattern +from matrix_common.regex import glob_to_regex + from OpenSSL import SSL, crypto from twisted.internet._sslverify import Certificate, trustRootFromCertificates from synapse.config._base import Config, ConfigError -from synapse.util import glob_to_regex logger = logging.getLogger(__name__) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index cf067b56c6..ee71f289c8 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,6 +28,7 @@ from typing import ( Union, ) +from matrix_common.regex import glob_to_regex from prometheus_client import Counter, Gauge, Histogram from twisted.internet.abstract import isIPAddress @@ -65,7 +66,7 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict, get_domain_from_id -from synapse.util import glob_to_regex, json_decoder, unwrapFirstError +from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 7f68092ec5..659a53805d 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -17,9 +17,10 @@ import logging import re from typing import Any, Dict, List, Optional, Pattern, Tuple, Union +from matrix_common.regex import glob_to_regex, to_word_pattern + from synapse.events import EventBase from synapse.types import JsonDict, UserID -from synapse.util import glob_to_regex, re_word_boundary from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -184,7 +185,7 @@ class PushRuleEvaluatorForEvent: r = regex_cache.get((display_name, False, True), None) if not r: r1 = re.escape(display_name) - r1 = re_word_boundary(r1) + r1 = to_word_pattern(r1) r = re.compile(r1, flags=re.IGNORECASE) regex_cache[(display_name, False, True)] = r @@ -213,7 +214,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: try: r = regex_cache.get((glob, True, word_boundary), None) if not r: - r = glob_to_regex(glob, word_boundary) + r = glob_to_regex(glob, word_boundary=word_boundary) regex_cache[(glob, True, word_boundary)] = r return bool(r.search(value)) except re.error: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 13fb69460e..d844fbb3b3 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -88,6 +88,7 @@ REQUIREMENTS = [ # with the latest security patches. "cryptography>=3.4.7", "ijson>=3.1", + "matrix-common==1.0.0", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 95f23e27b6..f157132210 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -14,9 +14,8 @@ import json import logging -import re import typing -from typing import Any, Callable, Dict, Generator, Optional, Pattern +from typing import Any, Callable, Dict, Generator, Optional import attr from frozendict import frozendict @@ -35,9 +34,6 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) -_WILDCARD_RUN = re.compile(r"([\?\*]+)") - - def _reject_invalid_json(val: Any) -> None: """Do not allow Infinity, -Infinity, or NaN values in JSON.""" raise ValueError("Invalid JSON value: '%s'" % val) @@ -185,56 +181,3 @@ def log_failure( if not consumeErrors: return failure return None - - -def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern: - """Converts a glob to a compiled regex object. - - Args: - glob: pattern to match - word_boundary: If True, the pattern will be allowed to match at word boundaries - anywhere in the string. Otherwise, the pattern is anchored at the start and - end of the string. - - Returns: - compiled regex pattern - """ - - # Patterns with wildcards must be simplified to avoid performance cliffs - # - The glob `?**?**?` is equivalent to the glob `???*` - # - The glob `???*` is equivalent to the regex `.{3,}` - chunks = [] - for chunk in _WILDCARD_RUN.split(glob): - # No wildcards? re.escape() - if not _WILDCARD_RUN.match(chunk): - chunks.append(re.escape(chunk)) - continue - - # Wildcards? Simplify. - qmarks = chunk.count("?") - if "*" in chunk: - chunks.append(".{%d,}" % qmarks) - else: - chunks.append(".{%d}" % qmarks) - - res = "".join(chunks) - - if word_boundary: - res = re_word_boundary(res) - else: - # \A anchors at start of string, \Z at end of string - res = r"\A" + res + r"\Z" - - return re.compile(res, re.IGNORECASE) - - -def re_word_boundary(r: str) -> str: - """ - Adds word boundary characters to the start and end of an - expression to require that the match occur as a whole word, - but do so respecting the fact that strings starting or ending - with non-word characters will change word boundaries. - """ - # we can't use \b as it chokes on unicode. however \W seems to be okay - # as shorthand for [^0-9A-Za-z_]. - return r"(^|\W)%s(\W|$)" % (r,) diff --git a/tests/util/test_glob_to_regex.py b/tests/util/test_glob_to_regex.py deleted file mode 100644 index 220accb92b..0000000000 --- a/tests/util/test_glob_to_regex.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from synapse.util import glob_to_regex - -from tests.unittest import TestCase - - -class GlobToRegexTestCase(TestCase): - def test_literal_match(self): - """patterns without wildcards should match""" - pat = glob_to_regex("foobaz") - self.assertTrue( - pat.match("FoobaZ"), "patterns should match and be case-insensitive" - ) - self.assertFalse( - pat.match("x foobaz"), "pattern should not match at word boundaries" - ) - - def test_wildcard_match(self): - pat = glob_to_regex("f?o*baz") - - self.assertTrue( - pat.match("FoobarbaZ"), - "* should match string and pattern should be case-insensitive", - ) - self.assertTrue(pat.match("foobaz"), "* should match 0 characters") - self.assertFalse(pat.match("fooxaz"), "the character after * must match") - self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters") - self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters") - - def test_multi_wildcard(self): - """patterns with multiple wildcards in a row should match""" - pat = glob_to_regex("**baz") - self.assertTrue(pat.match("agsgsbaz"), "** should match any string") - self.assertTrue(pat.match("baz"), "** should match the empty string") - self.assertEqual(pat.pattern, r"\A.{0,}baz\Z") - - pat = glob_to_regex("*?baz") - self.assertTrue(pat.match("agsgsbaz"), "*? should match any string") - self.assertTrue(pat.match("abaz"), "*? should match a single char") - self.assertFalse(pat.match("baz"), "*? should not match the empty string") - self.assertEqual(pat.pattern, r"\A.{1,}baz\Z") - - pat = glob_to_regex("a?*?*?baz") - self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars") - self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars") - self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars") - self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z") -- cgit 1.5.1 From 0fb3dd0830e476c0e0b89c3bf6c7855a4129ff11 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 5 Jan 2022 12:26:11 +0000 Subject: Refactor the way we set `outlier` (#11634) * `_auth_and_persist_outliers`: mark persisted events as outliers Mark any events that get persisted via `_auth_and_persist_outliers` as, well, outliers. Currently this will be a no-op as everything will already be flagged as an outlier, but I'm going to change that. * `process_remote_join`: stop flagging as outlier The events are now flagged as outliers later on, by `_auth_and_persist_outliers`. * `send_join`: remove `outlier=True` The events created here are returned in the result of `send_join` to `FederationHandler.do_invite_join`. From there they are passed into `FederationEventHandler.process_remote_join`, which passes them to `_auth_and_persist_outliers`... which sets the `outlier` flag. * `get_event_auth`: remove `outlier=True` stop flagging the events returned by `get_event_auth` as outliers. This method is only called by `_get_remote_auth_chain_for_event`, which passes the results into `_auth_and_persist_outliers`, which will flag them as outliers. * `_get_remote_auth_chain_for_event`: remove `outlier=True` we pass all the events into `_auth_and_persist_outliers`, which will now flag the events as outliers. * `_check_sigs_and_hash_and_fetch`: remove unused `outlier` parameter This param is now never set to True, so we can remove it. * `_check_sigs_and_hash_and_fetch_one`: remove unused `outlier` param This is no longer set anywhere, so we can remove it. * `get_pdu`: remove unused `outlier` parameter ... and chase it down into `get_pdu_from_destination_raw`. * `event_from_pdu_json`: remove redundant `outlier` param This is never set to `True`, so can be removed. * changelog * update docstring --- changelog.d/11634.misc | 1 + synapse/federation/federation_base.py | 7 +------ synapse/federation/federation_client.py | 36 +++++---------------------------- synapse/handlers/federation_event.py | 15 +++++++------- tests/handlers/test_federation.py | 4 +--- 5 files changed, 16 insertions(+), 47 deletions(-) create mode 100644 changelog.d/11634.misc (limited to 'synapse/federation') diff --git a/changelog.d/11634.misc b/changelog.d/11634.misc new file mode 100644 index 0000000000..4069cbc2f4 --- /dev/null +++ b/changelog.d/11634.misc @@ -0,0 +1 @@ +Refactor the way that the `outlier` flag is set on events received over federation. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 4df90e02d7..addc0bf000 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -215,15 +215,12 @@ def _is_invite_via_3pid(event: EventBase) -> bool: ) -def event_from_pdu_json( - pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False -) -> EventBase: +def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventBase: """Construct an EventBase from an event json received over federation Args: pdu_json: pdu as received over federation room_version: The version of the room this event belongs to - outlier: True to mark this event as an outlier Raises: SynapseError: if the pdu is missing required fields or is otherwise @@ -247,6 +244,4 @@ def event_from_pdu_json( validate_canonicaljson(pdu_json) event = make_event_from_dict(pdu_json, room_version) - event.internal_metadata.outlier = outlier - return event diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7353c2b6b1..6ea4edfc71 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -265,10 +265,7 @@ class FederationClient(FederationBase): room_version = await self.store.get_room_version(room_id) - pdus = [ - event_from_pdu_json(p, room_version, outlier=False) - for p in transaction_data_pdus - ] + pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus] # Check signatures and hash of pdus, removing any from the list that fail checks pdus[:] = await self._check_sigs_and_hash_and_fetch( @@ -282,7 +279,6 @@ class FederationClient(FederationBase): destination: str, event_id: str, room_version: RoomVersion, - outlier: bool = False, timeout: Optional[int] = None, ) -> Optional[EventBase]: """Requests the PDU with given origin and ID from the remote home @@ -292,9 +288,6 @@ class FederationClient(FederationBase): destination: Which homeserver to query event_id: event to fetch room_version: version of the room - outlier: Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitrary point in the context as opposed to part - of the current block of PDUs. Defaults to `False` timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. @@ -316,8 +309,7 @@ class FederationClient(FederationBase): ) pdu_list: List[EventBase] = [ - event_from_pdu_json(p, room_version, outlier=outlier) - for p in transaction_data["pdus"] + event_from_pdu_json(p, room_version) for p in transaction_data["pdus"] ] if pdu_list and pdu_list[0]: @@ -334,7 +326,6 @@ class FederationClient(FederationBase): destinations: Iterable[str], event_id: str, room_version: RoomVersion, - outlier: bool = False, timeout: Optional[int] = None, ) -> Optional[EventBase]: """Requests the PDU with given origin and ID from the remote home @@ -347,9 +338,6 @@ class FederationClient(FederationBase): destinations: Which homeservers to query event_id: event to fetch room_version: version of the room - outlier: Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitrary point in the context as opposed to part - of the current block of PDUs. Defaults to `False` timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. @@ -377,7 +365,6 @@ class FederationClient(FederationBase): destination=destination, event_id=event_id, room_version=room_version, - outlier=outlier, timeout=timeout, ) @@ -435,7 +422,6 @@ class FederationClient(FederationBase): origin: str, pdus: Collection[EventBase], room_version: RoomVersion, - outlier: bool = False, ) -> List[EventBase]: """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in @@ -451,7 +437,6 @@ class FederationClient(FederationBase): origin pdu room_version - outlier: Whether the events are outliers or not Returns: A list of PDUs that have valid signatures and hashes. @@ -466,7 +451,6 @@ class FederationClient(FederationBase): valid_pdu = await self._check_sigs_and_hash_and_fetch_one( pdu=pdu, origin=origin, - outlier=outlier, room_version=room_version, ) @@ -482,7 +466,6 @@ class FederationClient(FederationBase): pdu: EventBase, origin: str, room_version: RoomVersion, - outlier: bool = False, ) -> Optional[EventBase]: """Takes a PDU and checks its signatures and hashes. If the PDU fails its signature check then we check if we have it in the database and if @@ -494,9 +477,6 @@ class FederationClient(FederationBase): origin pdu room_version - outlier: Whether the events are outliers or not - include_none: Whether to include None in the returned list - for events that have failed their checks Returns: The PDU (possibly redacted) if it has valid signatures and hashes. @@ -521,7 +501,6 @@ class FederationClient(FederationBase): destinations=[pdu_origin], event_id=pdu.event_id, room_version=room_version, - outlier=outlier, timeout=10000, ) except SynapseError: @@ -541,13 +520,10 @@ class FederationClient(FederationBase): room_version = await self.store.get_room_version(room_id) - auth_chain = [ - event_from_pdu_json(p, room_version, outlier=True) - for p in res["auth_chain"] - ] + auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]] signed_auth = await self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version + destination, auth_chain, room_version=room_version ) return signed_auth @@ -816,7 +792,6 @@ class FederationClient(FederationBase): valid_pdu = await self._check_sigs_and_hash_and_fetch_one( pdu=event, origin=destination, - outlier=True, room_version=room_version, ) @@ -864,7 +839,6 @@ class FederationClient(FederationBase): valid_pdu = await self._check_sigs_and_hash_and_fetch_one( pdu=pdu, origin=destination, - outlier=True, room_version=room_version, ) @@ -1235,7 +1209,7 @@ class FederationClient(FederationBase): ] signed_events = await self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version + destination, events, room_version=room_version ) except HttpResponseException as e: if not e.code == 400: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 7c81e3651f..11771f3c9c 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -421,9 +421,6 @@ class FederationEventHandler: Raises: SynapseError if the response is in some way invalid. """ - for e in itertools.chain(auth_events, state): - e.internal_metadata.outlier = True - event_map = {e.event_id: e for e in itertools.chain(auth_events, state)} create_event = None @@ -1194,7 +1191,6 @@ class FederationEventHandler: [destination], event_id, room_version, - outlier=True, ) if event is None: logger.warning( @@ -1223,9 +1219,10 @@ class FederationEventHandler: """Persist a batch of outlier events fetched from remote servers. We first sort the events to make sure that we process each event's auth_events - before the event itself, and then auth and persist them. + before the event itself. - Notifies about the events where appropriate. + We then mark the events as outliers, persist them to the database, and, where + appropriate (eg, an invite), awake the notifier. Params: room_id: the room that the events are meant to be in (though this has @@ -1276,7 +1273,8 @@ class FederationEventHandler: Persists a batch of events where we have (theoretically) already persisted all of their auth events. - Notifies about the events where appropriate. + Marks the events as outliers, auths them, persists them to the database, and, + where appropriate (eg, an invite), awakes the notifier. Params: origin: where the events came from @@ -1314,6 +1312,9 @@ class FederationEventHandler: return None auth.append(ae) + # we're not bothering about room state, so flag the event as an outlier. + event.internal_metadata.outlier = True + context = EventContext.for_outlier() try: validate_event_for_room_version(room_version_obj, event) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index e1557566e4..496b581726 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -373,9 +373,7 @@ class FederationTestCase(unittest.HomeserverTestCase): destination: str, room_id: str, event_id: str ) -> List[EventBase]: return [ - event_from_pdu_json( - ae.get_pdu_json(), room_version=room_version, outlier=True - ) + event_from_pdu_json(ae.get_pdu_json(), room_version=room_version) for ae in auth_events ] -- cgit 1.5.1