summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/errors.py7
-rw-r--r--synapse/app/generic_worker.py3
-rw-r--r--synapse/app/homeserver.py4
-rw-r--r--synapse/config/server.py19
-rw-r--r--synapse/config/workers.py18
-rw-r--r--synapse/events/__init__.py227
-rw-r--r--synapse/events/third_party_rules.py9
-rw-r--r--synapse/events/validator.py2
-rw-r--r--synapse/federation/federation_server.py9
-rw-r--r--synapse/federation/transport/client.py3
-rw-r--r--synapse/handlers/appservice.py91
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/federation_event.py2
-rw-r--r--synapse/handlers/message.py14
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/room_batch.py2
-rw-r--r--synapse/handlers/room_member.py4
-rw-r--r--synapse/handlers/typing.py6
-rw-r--r--synapse/notifier.py38
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py4
-rw-r--r--synapse/push/push_rule_evaluator.py10
-rw-r--r--synapse/replication/tcp/handler.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py3
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/rooms.py141
-rw-r--r--synapse/rest/client/room.py2
-rw-r--r--synapse/rest/client/room_batch.py31
-rw-r--r--synapse/rest/media/v1/media_repository.py12
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py9
-rw-r--r--synapse/rest/media/v1/upload_resource.py2
-rw-r--r--synapse/rest/well_known.py47
-rw-r--r--synapse/server.py4
-rw-r--r--synapse/state/__init__.py2
-rw-r--r--synapse/storage/databases/main/deviceinbox.py89
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/events_worker.py56
-rw-r--r--synapse/storage/databases/main/group_server.py17
-rw-r--r--synapse/storage/databases/main/lock.py31
-rw-r--r--synapse/storage/databases/main/presence.py2
-rw-r--r--synapse/storage/databases/main/room.py29
-rw-r--r--synapse/storage/databases/main/roommember.py8
-rw-r--r--synapse/storage/prepare_database.py23
-rw-r--r--synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql22
-rw-r--r--synapse/storage/schema/main/delta/65/04_local_group_updates.sql18
-rw-r--r--synapse/util/async_helpers.py34
47 files changed, 715 insertions, 362 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index 355b36fc63..5ef34bce40 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.46.0rc1" +__version__ = "1.46.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 685d1c25cf..85302163da 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -596,3 +596,10 @@ class ShadowBanError(Exception): This should be caught and a proper "fake" success response sent to the user. """ + + +class ModuleFailedException(Exception): + """ + Raised when a module API callback fails, for example because it raised an + exception. + """ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 51eadf122d..218826741e 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -100,6 +100,7 @@ from synapse.rest.client.register import ( from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore @@ -318,6 +319,8 @@ class GenericWorkerServer(HomeServer): resources.update({CLIENT_API_PREFIX: resource}) resources.update(build_synapse_client_resource_tree(self)) + resources.update({"/.well-known": well_known_resource(self)}) + elif name == "federation": resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) elif name == "media": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 93e2299266..336c279a44 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -66,7 +66,7 @@ from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.synapse.client import build_synapse_client_resource_tree -from synapse.rest.well_known import WellKnownResource +from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer from synapse.storage import DataStore from synapse.util.httpresourcetree import create_resource_tree @@ -189,7 +189,7 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/unstable": client_resource, "/_matrix/client/v2_alpha": client_resource, "/_matrix/client/versions": client_resource, - "/.well-known/matrix/client": WellKnownResource(self), + "/.well-known": well_known_resource(self), "/_synapse/admin": AdminRestResource(self), **build_synapse_client_resource_tree(self), } diff --git a/synapse/config/server.py b/synapse/config/server.py
index ed094bdc44..a387fd9310 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -262,6 +262,7 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) + self.serve_server_wellknown = config.get("serve_server_wellknown", False) self.public_baseurl = config.get("public_baseurl") if self.public_baseurl is not None: @@ -774,6 +775,24 @@ class ServerConfig(Config): # #public_baseurl: https://example.com/ + # Uncomment the following to tell other servers to send federation traffic on + # port 443. + # + # By default, other servers will try to reach our server on port 8448, which can + # be inconvenient in some environments. + # + # Provided 'https://<server_name>/' on port 443 is routed to Synapse, this + # option configures Synapse to serve a file at + # 'https://<server_name>/.well-known/matrix/server'. This will tell other + # servers to send traffic to port 443 instead. + # + # See https://matrix-org.github.io/synapse/latest/delegate.html for more + # information. + # + # Defaults to 'false'. + # + #serve_server_wellknown: true + # Set the soft limit on the number of file descriptors synapse can use # Zero is used to indicate synapse should set the soft limit to the # hard limit. diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 462630201d..4507992031 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py
@@ -63,7 +63,8 @@ class WriterLocations: Attributes: events: The instances that write to the event and backfill streams. - typing: The instance that writes to the typing stream. + typing: The instances that write to the typing stream. Currently + can only be a single instance. to_device: The instances that write to the to_device stream. Currently can only be a single instance. account_data: The instances that write to the account data streams. Currently @@ -75,9 +76,15 @@ class WriterLocations: """ events = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter + default=["master"], + type=List[str], + converter=_instance_to_list_converter, + ) + typing = attr.ib( + default=["master"], + type=List[str], + converter=_instance_to_list_converter, ) - typing = attr.ib(default="master", type=str) to_device = attr.ib( default=["master"], type=List[str], @@ -217,6 +224,11 @@ class WorkerConfig(Config): % (instance, stream) ) + if len(self.writers.typing) != 1: + raise ConfigError( + "Must only specify one instance to handle `typing` messages." + ) + if len(self.writers.to_device) != 1: raise ConfigError( "Must only specify one instance to handle `to_device` messages." diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 157669ea88..38f3cf4d33 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py
@@ -16,8 +16,23 @@ import abc import os -from typing import Dict, Optional, Tuple, Type - +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +from typing_extensions import Literal from unpaddedbase64 import encode_base64 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions @@ -26,6 +41,9 @@ from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze from synapse.util.stringutils import strtobool +if TYPE_CHECKING: + from synapse.events.builder import EventBuilder + # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # bugs where we accidentally share e.g. signature dicts. However, converting a # dict to frozen_dicts is expensive. @@ -37,7 +55,23 @@ from synapse.util.stringutils import strtobool USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) -class DictProperty: +T = TypeVar("T") + + +# DictProperty (and DefaultDictProperty) require the classes they're used with to +# have a _dict property to pull properties from. +# +# TODO _DictPropertyInstance should not include EventBuilder but due to +# https://github.com/python/mypy/issues/5570 it thinks the DictProperty and +# DefaultDictProperty get applied to EventBuilder when it is in a Union with +# EventBase. This is the least invasive hack to get mypy to comply. +# +# Note that DictProperty/DefaultDictProperty cannot actually be used with +# EventBuilder as it lacks a _dict property. +_DictPropertyInstance = Union["_EventInternalMetadata", "EventBase", "EventBuilder"] + + +class DictProperty(Generic[T]): """An object property which delegates to the `_dict` within its parent object.""" __slots__ = ["key"] @@ -45,12 +79,33 @@ class DictProperty: def __init__(self, key: str): self.key = key - def __get__(self, instance, owner=None): + @overload + def __get__( + self, + instance: Literal[None], + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> "DictProperty": + ... + + @overload + def __get__( + self, + instance: _DictPropertyInstance, + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> T: + ... + + def __get__( + self, + instance: Optional[_DictPropertyInstance], + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> Union[T, "DictProperty"]: # if the property is accessed as a class property rather than an instance # property, return the property itself rather than the value if instance is None: return self try: + assert isinstance(instance, (EventBase, _EventInternalMetadata)) return instance._dict[self.key] except KeyError as e1: # We want this to look like a regular attribute error (mostly so that @@ -65,10 +120,12 @@ class DictProperty: "'%s' has no '%s' property" % (type(instance), self.key) ) from e1.__context__ - def __set__(self, instance, v): + def __set__(self, instance: _DictPropertyInstance, v: T) -> None: + assert isinstance(instance, (EventBase, _EventInternalMetadata)) instance._dict[self.key] = v - def __delete__(self, instance): + def __delete__(self, instance: _DictPropertyInstance) -> None: + assert isinstance(instance, (EventBase, _EventInternalMetadata)) try: del instance._dict[self.key] except KeyError as e1: @@ -77,7 +134,7 @@ class DictProperty: ) from e1.__context__ -class DefaultDictProperty(DictProperty): +class DefaultDictProperty(DictProperty, Generic[T]): """An extension of DictProperty which provides a default if the property is not present in the parent's _dict. @@ -86,13 +143,34 @@ class DefaultDictProperty(DictProperty): __slots__ = ["default"] - def __init__(self, key, default): + def __init__(self, key: str, default: T): super().__init__(key) self.default = default - def __get__(self, instance, owner=None): + @overload + def __get__( + self, + instance: Literal[None], + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> "DefaultDictProperty": + ... + + @overload + def __get__( + self, + instance: _DictPropertyInstance, + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> T: + ... + + def __get__( + self, + instance: Optional[_DictPropertyInstance], + owner: Optional[Type[_DictPropertyInstance]] = None, + ) -> Union[T, "DefaultDictProperty"]: if instance is None: return self + assert isinstance(instance, (EventBase, _EventInternalMetadata)) return instance._dict.get(self.key, self.default) @@ -111,22 +189,22 @@ class _EventInternalMetadata: # in the DAG) self.outlier = False - out_of_band_membership: bool = DictProperty("out_of_band_membership") - send_on_behalf_of: str = DictProperty("send_on_behalf_of") - recheck_redaction: bool = DictProperty("recheck_redaction") - soft_failed: bool = DictProperty("soft_failed") - proactively_send: bool = DictProperty("proactively_send") - redacted: bool = DictProperty("redacted") - txn_id: str = DictProperty("txn_id") - token_id: int = DictProperty("token_id") - historical: bool = DictProperty("historical") + out_of_band_membership: DictProperty[bool] = DictProperty("out_of_band_membership") + send_on_behalf_of: DictProperty[str] = DictProperty("send_on_behalf_of") + recheck_redaction: DictProperty[bool] = DictProperty("recheck_redaction") + soft_failed: DictProperty[bool] = DictProperty("soft_failed") + proactively_send: DictProperty[bool] = DictProperty("proactively_send") + redacted: DictProperty[bool] = DictProperty("redacted") + txn_id: DictProperty[str] = DictProperty("txn_id") + token_id: DictProperty[int] = DictProperty("token_id") + historical: DictProperty[bool] = DictProperty("historical") # XXX: These are set by StreamWorkerStore._set_before_and_after. # I'm pretty sure that these are never persisted to the database, so shouldn't # be here - before: RoomStreamToken = DictProperty("before") - after: RoomStreamToken = DictProperty("after") - order: Tuple[int, int] = DictProperty("order") + before: DictProperty[RoomStreamToken] = DictProperty("before") + after: DictProperty[RoomStreamToken] = DictProperty("after") + order: DictProperty[Tuple[int, int]] = DictProperty("order") def get_dict(self) -> JsonDict: return dict(self._dict) @@ -162,9 +240,6 @@ class _EventInternalMetadata: If the sender of the redaction event is allowed to redact any event due to auth rules, then this will always return false. - - Returns: - bool """ return self._dict.get("recheck_redaction", False) @@ -176,32 +251,23 @@ class _EventInternalMetadata: sent to clients. 2. They should not be added to the forward extremities (and therefore not to current state). - - Returns: - bool """ return self._dict.get("soft_failed", False) - def should_proactively_send(self): + def should_proactively_send(self) -> bool: """Whether the event, if ours, should be sent to other clients and servers. This is used for sending dummy events internally. Servers and clients can still explicitly fetch the event. - - Returns: - bool """ return self._dict.get("proactively_send", True) - def is_redacted(self): + def is_redacted(self) -> bool: """Whether the event has been redacted. This is used for efficiently checking whether an event has been marked as redacted without needing to make another database call. - - Returns: - bool """ return self._dict.get("redacted", False) @@ -241,29 +307,31 @@ class EventBase(metaclass=abc.ABCMeta): self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) - auth_events = DictProperty("auth_events") - depth = DictProperty("depth") - content = DictProperty("content") - hashes = DictProperty("hashes") - origin = DictProperty("origin") - origin_server_ts = DictProperty("origin_server_ts") - prev_events = DictProperty("prev_events") - redacts = DefaultDictProperty("redacts", None) - room_id = DictProperty("room_id") - sender = DictProperty("sender") - state_key = DictProperty("state_key") - type = DictProperty("type") - user_id = DictProperty("sender") + depth: DictProperty[int] = DictProperty("depth") + content: DictProperty[JsonDict] = DictProperty("content") + hashes: DictProperty[Dict[str, str]] = DictProperty("hashes") + origin: DictProperty[str] = DictProperty("origin") + origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts") + redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None) + room_id: DictProperty[str] = DictProperty("room_id") + sender: DictProperty[str] = DictProperty("sender") + # TODO state_key should be Optional[str], this is generally asserted in Synapse + # by calling is_state() first (which ensures this), but it is hard (not possible?) + # to properly annotate that calling is_state() asserts that state_key exists + # and is non-None. + state_key: DictProperty[str] = DictProperty("state_key") + type: DictProperty[str] = DictProperty("type") + user_id: DictProperty[str] = DictProperty("sender") @property def event_id(self) -> str: raise NotImplementedError() @property - def membership(self): + def membership(self) -> str: return self.content["membership"] - def is_state(self): + def is_state(self) -> bool: return hasattr(self, "state_key") and self.state_key is not None def get_dict(self) -> JsonDict: @@ -272,13 +340,13 @@ class EventBase(metaclass=abc.ABCMeta): return d - def get(self, key, default=None): + def get(self, key: str, default: Optional[Any] = None) -> Any: return self._dict.get(key, default) - def get_internal_metadata_dict(self): + def get_internal_metadata_dict(self) -> JsonDict: return self.internal_metadata.get_dict() - def get_pdu_json(self, time_now=None) -> JsonDict: + def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict: pdu_json = self.get_dict() if time_now is not None and "age_ts" in pdu_json["unsigned"]: @@ -305,49 +373,46 @@ class EventBase(metaclass=abc.ABCMeta): return template_json - def __set__(self, instance, value): - raise AttributeError("Unrecognized attribute %s" % (instance,)) - - def __getitem__(self, field): + def __getitem__(self, field: str) -> Optional[Any]: return self._dict[field] - def __contains__(self, field): + def __contains__(self, field: str) -> bool: return field in self._dict - def items(self): + def items(self) -> List[Tuple[str, Optional[Any]]]: return list(self._dict.items()) - def keys(self): + def keys(self) -> Iterable[str]: return self._dict.keys() - def prev_event_ids(self): + def prev_event_ids(self) -> Sequence[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. Returns: - list[str]: The list of event IDs of this event's prev_events + The list of event IDs of this event's prev_events """ - return [e for e, _ in self.prev_events] + return [e for e, _ in self._dict["prev_events"]] - def auth_event_ids(self): + def auth_event_ids(self) -> Sequence[str]: """Returns the list of auth event IDs. The order matches the order specified in the event, though there is no meaning to it. Returns: - list[str]: The list of event IDs of this event's auth_events + The list of event IDs of this event's auth_events """ - return [e for e, _ in self.auth_events] + return [e for e, _ in self._dict["auth_events"]] - def freeze(self): + def freeze(self) -> None: """'Freeze' the event dict, so it cannot be modified by accident""" # this will be a no-op if the event dict is already frozen. self._dict = freeze(self._dict) - def __str__(self): + def __str__(self) -> str: return self.__repr__() - def __repr__(self): + def __repr__(self) -> str: rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" return ( @@ -443,7 +508,7 @@ class FrozenEventV2(EventBase): else: frozen_dict = event_dict - self._event_id = None + self._event_id: Optional[str] = None super().__init__( frozen_dict, @@ -455,7 +520,7 @@ class FrozenEventV2(EventBase): ) @property - def event_id(self): + def event_id(self) -> str: # We have to import this here as otherwise we get an import loop which # is hard to break. from synapse.crypto.event_signing import compute_event_reference_hash @@ -465,23 +530,23 @@ class FrozenEventV2(EventBase): self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) return self._event_id - def prev_event_ids(self): + def prev_event_ids(self) -> Sequence[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. Returns: - list[str]: The list of event IDs of this event's prev_events + The list of event IDs of this event's prev_events """ - return self.prev_events + return self._dict["prev_events"] - def auth_event_ids(self): + def auth_event_ids(self) -> Sequence[str]: """Returns the list of auth event IDs. The order matches the order specified in the event, though there is no meaning to it. Returns: - list[str]: The list of event IDs of this event's auth_events + The list of event IDs of this event's auth_events """ - return self.auth_events + return self._dict["auth_events"] class FrozenEventV3(FrozenEventV2): @@ -490,7 +555,7 @@ class FrozenEventV3(FrozenEventV2): format_version = EventFormatVersions.V3 # All events of this type are V3 @property - def event_id(self): + def event_id(self) -> str: # We have to import this here as otherwise we get an import loop which # is hard to break. from synapse.crypto.event_signing import compute_event_reference_hash @@ -503,12 +568,14 @@ class FrozenEventV3(FrozenEventV2): return self._event_id -def _event_type_from_format_version(format_version: int) -> Type[EventBase]: +def _event_type_from_format_version( + format_version: int, +) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]: """Returns the python type to use to construct an Event object for the given event format version. Args: - format_version (int): The event format version + format_version: The event format version Returns: type: A type that can be initialized as per the initializer of diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 8816ef4b76..1bb8ca7145 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py
@@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple -from synapse.api.errors import SynapseError +from synapse.api.errors import ModuleFailedException, SynapseError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import Requester, StateMap @@ -233,9 +233,10 @@ class ThirdPartyEventRules: # This module callback needs a rework so that hacks such as # this one are not necessary. raise e - except Exception as e: - logger.warning("Failed to run module API callback %s: %s", callback, e) - continue + except Exception: + raise ModuleFailedException( + "Failed to run `check_event_allowed` module API callback" + ) # Return if the event shouldn't be allowed or if the module came up with a # replacement dict for the event. diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 4d459c17f1..cf86934968 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py
@@ -55,7 +55,7 @@ class EventValidator: ] for k in required: - if not hasattr(event, k): + if k not in event: raise SynapseError(400, "Event does not have key %s" % (k,)) # Check that the following keys have string values diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 32a75993d9..9a8758e9a6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -213,6 +213,11 @@ class FederationServer(FederationBase): self._started_handling_of_staged_events = True self._handle_old_staged_events() + # Start a periodic check for old staged events. This is to handle + # the case where locks time out, e.g. if another process gets killed + # without dropping its locks. + self._clock.looping_call(self._handle_old_staged_events, 60 * 1000) + # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -1232,10 +1237,6 @@ class FederationHandlerRegistry: self.query_handlers[query_type] = handler - def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None: - """Register that the EDU handler is on a different instance than master.""" - self._edu_type_to_instance[edu_type] = [instance_name] - def register_instances_for_edu( self, edu_type: str, instance_names: List[str] ) -> None: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index d963178838..10b5aa5af8 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -1310,14 +1310,17 @@ class SendJoinParser(ByteParser[SendJoinResponse]): self._coro_state = ijson.items_coro( _event_list_parser(room_version, self._response.state), prefix + "state.item", + use_float=True, ) self._coro_auth = ijson.items_coro( _event_list_parser(room_version, self._response.auth_events), prefix + "auth_chain.item", + use_float=True, ) self._coro_event = ijson.kvitems_coro( _event_parser(self._response.event_dict), prefix + "org.matrix.msc3083.v2.event", + use_float=True, ) def write(self, data: bytes) -> int: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 36c206dae6..ddc9105ee9 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -34,6 +34,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID +from synapse.util.async_helpers import Linearizer from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -58,6 +59,10 @@ class ApplicationServicesHandler: self.current_max = 0 self.is_processing = False + self._ephemeral_events_linearizer = Linearizer( + name="appservice_ephemeral_events" + ) + def notify_interested_services(self, max_token: RoomStreamToken) -> None: """Notifies (pushes) all application services interested in this event. @@ -182,7 +187,7 @@ class ApplicationServicesHandler: def notify_interested_services_ephemeral( self, stream_key: str, - new_token: Optional[int], + new_token: Union[int, RoomStreamToken], users: Optional[Collection[Union[str, UserID]]] = None, ) -> None: """ @@ -203,7 +208,7 @@ class ApplicationServicesHandler: Appservices will only receive ephemeral events that fall within their registered user and room namespaces. - new_token: The latest stream token. + new_token: The stream token of the event. users: The users that should be informed of the new event, if any. """ if not self.notify_appservices: @@ -212,6 +217,19 @@ class ApplicationServicesHandler: if stream_key not in ("typing_key", "receipt_key", "presence_key"): return + # Assert that new_token is an integer (and not a RoomStreamToken). + # All of the supported streams that this function handles use an + # integer to track progress (rather than a RoomStreamToken - a + # vector clock implementation) as they don't support multiple + # stream writers. + # + # As a result, we simply assert that new_token is an integer. + # If we do end up needing to pass a RoomStreamToken down here + # in the future, using RoomStreamToken.stream (the minimum stream + # position) to convert to an ascending integer value should work. + # Additional context: https://github.com/matrix-org/synapse/pull/11137 + assert isinstance(new_token, int) + services = [ service for service in self.store.get_app_services() @@ -231,14 +249,13 @@ class ApplicationServicesHandler: self, services: List[ApplicationService], stream_key: str, - new_token: Optional[int], + new_token: int, users: Collection[Union[str, UserID]], ) -> None: logger.debug("Checking interested services for %s" % (stream_key)) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: - # Only handle typing if we have the latest token - if stream_key == "typing_key" and new_token is not None: + if stream_key == "typing_key": # Note that we don't persist the token (via set_type_stream_id_for_appservice) # for typing_key due to performance reasons and due to their highly # ephemeral nature. @@ -248,26 +265,37 @@ class ApplicationServicesHandler: events = await self._handle_typing(service, new_token) if events: self.scheduler.submit_ephemeral_events_for_as(service, events) + continue - elif stream_key == "receipt_key": - events = await self._handle_receipts(service) - if events: - self.scheduler.submit_ephemeral_events_for_as(service, events) - - # Persist the latest handled stream token for this appservice - await self.store.set_type_stream_id_for_appservice( - service, "read_receipt", new_token + # Since we read/update the stream position for this AS/stream + with ( + await self._ephemeral_events_linearizer.queue( + (service.id, stream_key) ) + ): + if stream_key == "receipt_key": + events = await self._handle_receipts(service, new_token) + if events: + self.scheduler.submit_ephemeral_events_for_as( + service, events + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_type_stream_id_for_appservice( + service, "read_receipt", new_token + ) - elif stream_key == "presence_key": - events = await self._handle_presence(service, users) - if events: - self.scheduler.submit_ephemeral_events_for_as(service, events) + elif stream_key == "presence_key": + events = await self._handle_presence(service, users, new_token) + if events: + self.scheduler.submit_ephemeral_events_for_as( + service, events + ) - # Persist the latest handled stream token for this appservice - await self.store.set_type_stream_id_for_appservice( - service, "presence", new_token - ) + # Persist the latest handled stream token for this appservice + await self.store.set_type_stream_id_for_appservice( + service, "presence", new_token + ) async def _handle_typing( self, service: ApplicationService, new_token: int @@ -304,7 +332,9 @@ class ApplicationServicesHandler: ) return typing - async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: + async def _handle_receipts( + self, service: ApplicationService, new_token: Optional[int] + ) -> List[JsonDict]: """ Return the latest read receipts that the given application service should receive. @@ -323,6 +353,12 @@ class ApplicationServicesHandler: from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) + if new_token is not None and new_token <= from_key: + logger.debug( + "Rejecting token lower than or equal to stored: %s" % (new_token,) + ) + return [] + receipts_source = self.event_sources.sources.receipt receipts, _ = await receipts_source.get_new_events_as( service=service, from_key=from_key @@ -330,7 +366,10 @@ class ApplicationServicesHandler: return receipts async def _handle_presence( - self, service: ApplicationService, users: Collection[Union[str, UserID]] + self, + service: ApplicationService, + users: Collection[Union[str, UserID]], + new_token: Optional[int], ) -> List[JsonDict]: """ Return the latest presence updates that the given application service should receive. @@ -353,6 +392,12 @@ class ApplicationServicesHandler: from_key = await self.store.get_type_stream_id_for_appservice( service, "presence" ) + if new_token is not None and new_token <= from_key: + logger.debug( + "Rejecting token lower than or equal to stored: %s" % (new_token,) + ) + return [] + for user in users: if isinstance(user, str): user = UserID.from_string(user) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d508d7d32a..60e59d11a0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -1989,7 +1989,9 @@ class PasswordAuthProvider: self, check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, - auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None, + auth_checkers: Optional[ + Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e617db4c0d..1a1cd93b1a 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py
@@ -1643,7 +1643,7 @@ class FederationEventHandler: event: the event whose auth_events we want Returns: - all of the events in `event.auth_events`, after deduplication + all of the events listed in `event.auth_events_ids`, after deduplication Raises: AuthError if we were unable to fetch the auth_events for any reason. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4a0fccfcc6..b7bc187169 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -1318,6 +1318,8 @@ class EventCreationHandler: # user is actually admin or not). is_admin_redaction = False if event.type == EventTypes.Redaction: + assert event.redacts is not None + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -1413,6 +1415,8 @@ class EventCreationHandler: ) if event.type == EventTypes.Redaction: + assert event.redacts is not None + original_event = await self.store.get_event( event.redacts, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -1500,11 +1504,13 @@ class EventCreationHandler: next_batch_id = event.content.get( EventContentFields.MSC2716_NEXT_BATCH_ID ) - conflicting_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id( - event.room_id, next_batch_id + conflicting_insertion_event_id = None + if next_batch_id: + conflicting_insertion_event_id = ( + await self.store.get_insertion_event_by_batch_id( + event.room_id, next_batch_id + ) ) - ) if conflicting_insertion_event_id is not None: # The current insertion event that we're processing is invalid # because an insertion event already exists in the room with the diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 99e9b37344..969eb3b9b0 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -525,7 +525,7 @@ class RoomCreationHandler: ): await self.room_member_handler.update_membership( requester, - UserID.from_string(old_event["state_key"]), + UserID.from_string(old_event.state_key), new_room_id, "ban", ratelimit=False, diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 2f5a3e4d19..0723286383 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py
@@ -355,7 +355,7 @@ class RoomBatchHandler: for (event, context) in reversed(events_to_persist): await self.event_creation_handler.handle_new_client_event( await self.create_requester_for_user_id_from_app_service( - event["sender"], app_service_requester.app_service + event.sender, app_service_requester.app_service ), event=event, context=context, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 74e6c7eca6..08244b690d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -1669,7 +1669,9 @@ class RoomMemberMasterHandler(RoomMemberHandler): # # the prev_events consist solely of the previous membership event. prev_event_ids = [previous_membership_event.event_id] - auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids + auth_event_ids = ( + list(previous_membership_event.auth_event_ids()) + prev_event_ids + ) event, context = await self.event_creation_handler.create_event( requester, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c411d69924..22c6174821 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -62,8 +62,8 @@ class FollowerTypingHandler: if hs.should_send_federation(): self.federation = hs.get_federation_sender() - if hs.config.worker.writers.typing != hs.get_instance_name(): - hs.get_federation_registry().register_instance_for_edu( + if hs.get_instance_name() not in hs.config.worker.writers.typing: + hs.get_federation_registry().register_instances_for_edu( "m.typing", hs.config.worker.writers.typing, ) @@ -205,7 +205,7 @@ class TypingWriterHandler(FollowerTypingHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - assert hs.config.worker.writers.typing == hs.get_instance_name() + assert hs.get_instance_name() in hs.config.worker.writers.typing self.auth = hs.get_auth() self.notifier = hs.get_notifier() diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1882fffd2a..60e5409895 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py
@@ -383,29 +383,6 @@ class Notifier: except Exception: logger.exception("Error notifying application services of event") - def _notify_app_services_ephemeral( - self, - stream_key: str, - new_token: Union[int, RoomStreamToken], - users: Optional[Collection[Union[str, UserID]]] = None, - ) -> None: - """Notify application services of ephemeral event activity. - - Args: - stream_key: The stream the event came from. - new_token: The value of the new stream token. - users: The users that should be informed of the new event, if any. - """ - try: - stream_token = None - if isinstance(new_token, int): - stream_token = new_token - self.appservice_handler.notify_interested_services_ephemeral( - stream_key, stream_token, users or [] - ) - except Exception: - logger.exception("Error notifying application services of event") - def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: self._pusher_pool.on_new_notifications(max_room_stream_token) @@ -467,12 +444,15 @@ class Notifier: self.notify_replication() - # Notify appservices - self._notify_app_services_ephemeral( - stream_key, - new_token, - users, - ) + # Notify appservices. + try: + self.appservice_handler.notify_interested_services_ephemeral( + stream_key, + new_token, + users, + ) + except Exception: + logger.exception("Error notifying application services of event") def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happened diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 0622a37ae8..009d8e77b0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -232,6 +232,8 @@ class BulkPushRuleEvaluator: # that user, as they might not be already joined. if event.type == EventTypes.Member and event.state_key == uid: display_name = event.content.get("displayname", None) + if not isinstance(display_name, str): + display_name = None if count_as_unread: # Add an element for the current user if the event needs to be marked as @@ -268,7 +270,7 @@ def _condition_checker( evaluator: PushRuleEvaluatorForEvent, conditions: List[dict], uid: str, - display_name: str, + display_name: Optional[str], cache: Dict[str, bool], ) -> bool: for cond in conditions: diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 7a8dc63976..7f68092ec5 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py
@@ -18,7 +18,7 @@ import re from typing import Any, Dict, List, Optional, Pattern, Tuple, Union from synapse.events import EventBase -from synapse.types import UserID +from synapse.types import JsonDict, UserID from synapse.util import glob_to_regex, re_word_boundary from synapse.util.caches.lrucache import LruCache @@ -129,7 +129,7 @@ class PushRuleEvaluatorForEvent: self._value_cache = _flatten_dict(event) def matches( - self, condition: Dict[str, Any], user_id: str, display_name: str + self, condition: Dict[str, Any], user_id: str, display_name: Optional[str] ) -> bool: if condition["kind"] == "event_match": return self._event_match(condition, user_id) @@ -172,7 +172,7 @@ class PushRuleEvaluatorForEvent: return _glob_matches(pattern, haystack) - def _contains_display_name(self, display_name: str) -> bool: + def _contains_display_name(self, display_name: Optional[str]) -> bool: if not display_name: return False @@ -222,7 +222,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _flatten_dict( - d: Union[EventBase, dict], + d: Union[EventBase, JsonDict], prefix: Optional[List[str]] = None, result: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: @@ -233,7 +233,7 @@ def _flatten_dict( for key, value in d.items(): if isinstance(value, str): result[".".join(prefix + [key])] = value.lower() - elif hasattr(value, "items"): + elif isinstance(value, dict): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 06fd06fdf3..21293038ef 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -138,7 +138,7 @@ class ReplicationCommandHandler: if isinstance(stream, TypingStream): # Only add TypingStream as a source on the instance in charge of # typing. - if hs.config.worker.writers.typing == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.typing: self._streams_to_replicate.append(stream) continue diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c8b188ae4e..743a01da08 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -328,8 +328,7 @@ class TypingStream(Stream): ROW_TYPE = TypingStreamRow def __init__(self, hs: "HomeServer"): - writer_instance = hs.config.worker.writers.typing - if writer_instance == hs.get_instance_name(): + if hs.get_instance_name() in hs.config.worker.writers.typing: # On the writer, query the typing handler typing_writer_handler = hs.get_typing_writer_handler() update_function: Callable[ diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index e1506deb2b..70514e814f 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -42,7 +42,6 @@ from synapse.rest.admin.registration_tokens import ( RegistrationTokenRestServlet, ) from synapse.rest.admin.rooms import ( - DeleteRoomRestServlet, ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, @@ -221,7 +220,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) - DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index a4823ca6e7..05c5b4bf0c 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -46,41 +46,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class DeleteRoomRestServlet(RestServlet): - """Delete a room from server. - - It is a combination and improvement of shutdown and purge room. - - Shuts down a room by removing all local users from the room. - Blocking all future invites and joins to the room is optional. - - If desired any local aliases will be repointed to a new room - created by `new_room_user_id` and kicked users will be auto- - joined to the new room. - - If 'purge' is true, it will remove all traces of a room from the database. - """ - - PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$") - - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.auth = hs.get_auth() - self.room_shutdown_handler = hs.get_room_shutdown_handler() - self.pagination_handler = hs.get_pagination_handler() - - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - return await _delete_room( - request, - room_id, - self.auth, - self.room_shutdown_handler, - self.pagination_handler, - ) - - class ListRoomRestServlet(RestServlet): """ List all rooms that are known to the homeserver. Results are returned @@ -218,7 +183,7 @@ class RoomRestServlet(RestServlet): async def on_DELETE( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: - return await _delete_room( + return await self._delete_room( request, room_id, self.auth, @@ -226,6 +191,58 @@ class RoomRestServlet(RestServlet): self.pagination_handler, ) + async def _delete_room( + self, + request: SynapseRequest, + room_id: str, + auth: "Auth", + room_shutdown_handler: "RoomShutdownHandler", + pagination_handler: "PaginationHandler", + ) -> Tuple[int, JsonDict]: + requester = await auth.get_user_by_req(request) + await assert_user_is_admin(auth, requester.user) + + content = parse_json_object_from_request(request) + + block = content.get("block", False) + if not isinstance(block, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'block' must be a boolean, if given", + Codes.BAD_JSON, + ) + + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + force_purge = content.get("force_purge", False) + if not isinstance(force_purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'force_purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + + ret = await room_shutdown_handler.shutdown_room( + room_id=room_id, + new_room_user_id=content.get("new_room_user_id"), + new_room_name=content.get("room_name"), + message=content.get("message"), + requester_user_id=requester.user.to_string(), + block=block, + ) + + # Purge room + if purge: + await pagination_handler.purge_room(room_id, force=force_purge) + + return 200, ret + class RoomMembersRestServlet(RestServlet): """ @@ -617,55 +634,3 @@ class RoomEventContextServlet(RestServlet): ) return 200, results - - -async def _delete_room( - request: SynapseRequest, - room_id: str, - auth: "Auth", - room_shutdown_handler: "RoomShutdownHandler", - pagination_handler: "PaginationHandler", -) -> Tuple[int, JsonDict]: - requester = await auth.get_user_by_req(request) - await assert_user_is_admin(auth, requester.user) - - content = parse_json_object_from_request(request) - - block = content.get("block", False) - if not isinstance(block, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'block' must be a boolean, if given", - Codes.BAD_JSON, - ) - - purge = content.get("purge", True) - if not isinstance(purge, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'purge' must be a boolean, if given", - Codes.BAD_JSON, - ) - - force_purge = content.get("force_purge", False) - if not isinstance(force_purge, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'force_purge' must be a boolean, if given", - Codes.BAD_JSON, - ) - - ret = await room_shutdown_handler.shutdown_room( - room_id=room_id, - new_room_user_id=content.get("new_room_user_id"), - new_room_name=content.get("room_name"), - message=content.get("message"), - requester_user_id=requester.user.to_string(), - block=block, - ) - - # Purge room - if purge: - await pagination_handler.purge_room(room_id, force=force_purge) - - return 200, ret diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index ed95189b6d..6a876cfa2f 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py
@@ -914,7 +914,7 @@ class RoomTypingRestServlet(RestServlet): # If we're not on the typing writer instance we should scream if we get # requests. self._is_typing_writer = ( - hs.config.worker.writers.typing == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.typing ) async def on_PUT( diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 99f8156ad0..46f033eee2 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py
@@ -131,20 +131,22 @@ class RoomBatchSendEventRestServlet(RestServlet): prev_event_ids_from_query ) + state_event_ids_at_start = [] # Create and persist all of the state events that float off on their own # before the batch. These will most likely be all of the invite/member # state events used to auth the upcoming historical messages. - state_event_ids_at_start = ( - await self.room_batch_handler.persist_state_events_at_start( - state_events_at_start=body["state_events_at_start"], - room_id=room_id, - initial_auth_event_ids=auth_event_ids, - app_service_requester=requester, + if body["state_events_at_start"]: + state_event_ids_at_start = ( + await self.room_batch_handler.persist_state_events_at_start( + state_events_at_start=body["state_events_at_start"], + room_id=room_id, + initial_auth_event_ids=auth_event_ids, + app_service_requester=requester, + ) ) - ) - # Update our ongoing auth event ID list with all of the new state we - # just created - auth_event_ids.extend(state_event_ids_at_start) + # Update our ongoing auth event ID list with all of the new state we + # just created + auth_event_ids.extend(state_event_ids_at_start) inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids( prev_event_ids_from_query @@ -191,14 +193,17 @@ class RoomBatchSendEventRestServlet(RestServlet): depth=inherited_depth, ) - batch_id_to_connect_to = base_insertion_event["content"][ + batch_id_to_connect_to = base_insertion_event.content[ EventContentFields.MSC2716_NEXT_BATCH_ID ] # Also connect the historical event chain to the end of the floating # state chain, which causes the HS to ask for the state at the start of - # the batch later. - prev_event_ids = [state_event_ids_at_start[-1]] + # the batch later. If there is no state chain to connect to, just make + # the insertion event float itself. + prev_event_ids = [] + if len(state_event_ids_at_start): + prev_event_ids = [state_event_ids_at_start[-1]] # Create and persist all of the historical events as well as insertion # and batch meta events to make the batch navigable in the DAG. diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index abd88a2d4f..244ba261bb 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -215,6 +215,8 @@ class MediaRepository: self.mark_recently_accessed(None, media_id) media_type = media_info["media_type"] + if not media_type: + media_type = "application/octet-stream" media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] url_cache = media_info["url_cache"] @@ -333,6 +335,9 @@ class MediaRepository: logger.info("Media is quarantined") raise NotFoundError() + if not media_info["media_type"]: + media_info["media_type"] = "application/octet-stream" + responder = await self.media_storage.fetch_media(file_info) if responder: return responder, media_info @@ -354,6 +359,8 @@ class MediaRepository: raise e file_id = media_info["filesystem_id"] + if not media_info["media_type"]: + media_info["media_type"] = "application/octet-stream" file_info = FileInfo(server_name, file_id) # We generate thumbnails even if another process downloaded the media @@ -445,7 +452,10 @@ class MediaRepository: await finish() - media_type = headers[b"Content-Type"][0].decode("ascii") + if b"Content-Type" in headers: + media_type = headers[b"Content-Type"][0].decode("ascii") + else: + media_type = "application/octet-stream" upload_name = get_filename_from_headers(headers) time_now_ms = self.clock.time_msec() diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 278fd901e2..8ca97b5b18 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -718,9 +718,12 @@ def decode_body( if not body: return None + # The idea here is that multiple encodings are tried until one works. + # Unfortunately the result is never used and then LXML will decode the string + # again with the found encoding. for encoding in get_html_media_encodings(body, content_type): try: - body_str = body.decode(encoding) + body.decode(encoding) except Exception: pass else: @@ -732,11 +735,11 @@ def decode_body( from lxml import etree # Create an HTML parser. - parser = etree.HTMLParser(recover=True, encoding="utf-8") + parser = etree.HTMLParser(recover=True, encoding=encoding) # Attempt to parse the body. Returns None if the body was successfully # parsed, but no tree was found. - return etree.fromstring(body_str, parser) + return etree.fromstring(body, parser) def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 7dcb1428e4..8162094cf6 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py
@@ -80,7 +80,7 @@ class UploadResource(DirectServeJsonResource): assert content_type_headers # for mypy media_type = content_type_headers[0].decode("ascii") else: - raise SynapseError(msg="Upload request missing 'Content-Type'", code=400) + media_type = "application/octet-stream" # if headers.hasHeader(b"Content-Disposition"): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 7ac01faab4..edbf5ce5d0 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py
@@ -21,6 +21,7 @@ from twisted.web.server import Request from synapse.http.server import set_cors_headers from synapse.types import JsonDict from synapse.util import json_encoder +from synapse.util.stringutils import parse_server_name if TYPE_CHECKING: from synapse.server import HomeServer @@ -47,8 +48,8 @@ class WellKnownBuilder: return result -class WellKnownResource(Resource): - """A Twisted web resource which renders the .well-known file""" +class ClientWellKnownResource(Resource): + """A Twisted web resource which renders the .well-known/matrix/client file""" isLeaf = 1 @@ -67,3 +68,45 @@ class WellKnownResource(Resource): logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") return json_encoder.encode(r).encode("utf-8") + + +class ServerWellKnownResource(Resource): + """Resource for .well-known/matrix/server, redirecting to port 443""" + + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._serve_server_wellknown = hs.config.server.serve_server_wellknown + + host, port = parse_server_name(hs.config.server.server_name) + + # If we've got this far, then https://<server_name>/ must route to us, so + # we just redirect the traffic to port 443 instead of 8448. + if port is None: + port = 443 + + self._response = json_encoder.encode({"m.server": f"{host}:{port}"}).encode( + "utf-8" + ) + + def render_GET(self, request: Request) -> bytes: + if not self._serve_server_wellknown: + request.setResponseCode(404) + request.setHeader(b"Content-Type", b"text/plain") + return b"404. Is anything ever truly *well* known?\n" + + request.setHeader(b"Content-Type", b"application/json") + return self._response + + +def well_known_resource(hs: "HomeServer") -> Resource: + """Returns a Twisted web resource which handles '.well-known' requests""" + res = Resource() + matrix_resource = Resource() + res.putChild(b"matrix", matrix_resource) + + matrix_resource.putChild(b"server", ServerWellKnownResource(hs)) + matrix_resource.putChild(b"client", ClientWellKnownResource(hs)) + + return res diff --git a/synapse/server.py b/synapse/server.py
index 0fbf36ba99..013a7bacaa 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -463,7 +463,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_typing_writer_handler(self) -> TypingWriterHandler: - if self.config.worker.writers.typing == self.get_instance_name(): + if self.get_instance_name() in self.config.worker.writers.typing: return TypingWriterHandler(self) else: raise Exception("Workers cannot write typing") @@ -474,7 +474,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_typing_handler(self) -> FollowerTypingHandler: - if self.config.worker.writers.typing == self.get_instance_name(): + if self.get_instance_name() in self.config.worker.writers.typing: # Use get_typing_writer_handler to ensure that we use the same # cached version. return self.get_typing_writer_handler() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 98a0239759..1605411b00 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py
@@ -247,7 +247,7 @@ class StateHandler: return await self.get_hosts_in_room_at_events(room_id, event_ids) async def get_hosts_in_room_at_events( - self, room_id: str, event_ids: List[str] + self, room_id: str, event_ids: Iterable[str] ) -> Set[str]: """Get the hosts that were in a room at the given event ids diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 25e9c1efe1..264e625bd7 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -561,6 +561,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox" + REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox" def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) @@ -581,6 +582,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): self._remove_deleted_devices_from_device_inbox, ) + self.db_pool.updates.register_background_update_handler( + self.REMOVE_HIDDEN_DEVICES, + self._remove_hidden_devices_from_device_inbox, + ) + async def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): txn = conn.cursor() @@ -676,6 +682,89 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore): return number_deleted + async def _remove_hidden_devices_from_device_inbox( + self, progress: JsonDict, batch_size: int + ) -> int: + """A background update that deletes all device_inboxes for hidden devices. + + This should only need to be run once (when users upgrade to v1.47.0) + + Args: + progress: JsonDict used to store progress of this background update + batch_size: the maximum number of rows to retrieve in a single select query + + Returns: + The number of deleted rows + """ + + def _remove_hidden_devices_from_device_inbox_txn( + txn: LoggingTransaction, + ) -> int: + """stream_id is not unique + we need to use an inclusive `stream_id >= ?` clause, + since we might not have deleted all hidden device messages for the stream_id + returned from the previous query + + Then delete only rows matching the `(user_id, device_id, stream_id)` tuple, + to avoid problems of deleting a large number of rows all at once + due to a single device having lots of device messages. + """ + + last_stream_id = progress.get("stream_id", 0) + + sql = """ + SELECT device_id, user_id, stream_id + FROM device_inbox + WHERE + stream_id >= ? + AND (device_id, user_id) IN ( + SELECT device_id, user_id FROM devices WHERE hidden = ? + ) + ORDER BY stream_id + LIMIT ? + """ + + txn.execute(sql, (last_stream_id, True, batch_size)) + rows = txn.fetchall() + + num_deleted = 0 + for row in rows: + num_deleted += self.db_pool.simple_delete_txn( + txn, + "device_inbox", + {"device_id": row[0], "user_id": row[1], "stream_id": row[2]}, + ) + + if rows: + # We don't just save the `stream_id` in progress as + # otherwise it can happen in large deployments that + # no change of status is visible in the log file, as + # it may be that the stream_id does not change in several runs + self.db_pool.updates._background_update_progress_txn( + txn, + self.REMOVE_HIDDEN_DEVICES, + { + "device_id": rows[-1][0], + "user_id": rows[-1][1], + "stream_id": rows[-1][2], + }, + ) + + return num_deleted + + number_deleted = await self.db_pool.runInteraction( + "_remove_hidden_devices_from_device_inbox", + _remove_hidden_devices_from_device_inbox_txn, + ) + + # The task is finished when no more lines are deleted. + if not number_deleted: + await self.db_pool.updates._end_background_update( + self.REMOVE_HIDDEN_DEVICES + ) + + return number_deleted + class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): pass diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b15cd030e0..9ccc66e589 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore): user_ids: the users who were signed Returns: - THe new stream ID. + The new stream ID. """ async with self._device_list_id_gen.get_next() as stream_id: @@ -1322,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): async def add_device_change_to_streams( self, user_id: str, device_ids: Collection[str], hosts: List[str] - ): + ) -> int: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8d9086ecf0..596275c23c 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, ) @@ -494,7 +495,7 @@ class PersistEventsStore: event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], - event_to_auth_chain: Dict[str, List[str]], + event_to_auth_chain: Dict[str, Sequence[str]], ) -> None: """Calculate the chain cover index for the given events. @@ -786,7 +787,7 @@ class PersistEventsStore: event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], - event_to_auth_chain: Dict[str, List[str]], + event_to_auth_chain: Dict[str, Sequence[str]], events_to_calc_chain_id_for: Set[str], chain_map: Dict[str, Tuple[int, int]], ) -> Dict[str, Tuple[int, int]]: @@ -1794,7 +1795,7 @@ class PersistEventsStore: ) # Insert an edge for every prev_event connection - for prev_event_id in event.prev_events: + for prev_event_id in event.prev_event_ids(): self.db_pool.simple_insert_txn( txn, table="insertion_event_edges", diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index ae37901be9..c6bf316d5b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -28,6 +28,7 @@ from typing import ( import attr from constantly import NamedConstant, Names +from prometheus_client import Gauge from typing_extensions import Literal from twisted.internet import defer @@ -81,6 +82,12 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events +event_fetch_ongoing_gauge = Gauge( + "synapse_event_fetch_ongoing", + "The number of event fetchers that are running", +) + + @attr.s(slots=True, auto_attribs=True) class _EventCacheEntry: event: EventBase @@ -222,6 +229,7 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) # We define this sequence here so that it can be referenced from both # the DataStore and PersistEventStore. @@ -732,28 +740,31 @@ class EventsWorkerStore(SQLBaseStore): """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if ( - not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING - or single_threaded - or i > EVENT_QUEUE_ITERATIONS - ): - self._event_fetch_ongoing -= 1 - return - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - self._fetch_event_list(conn, event_list) + try: + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): + break + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + self._fetch_event_list(conn, event_list) + finally: + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) def _fetch_event_list( self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] @@ -977,6 +988,7 @@ class EventsWorkerStore(SQLBaseStore): if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: self._event_fetch_ongoing += 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) should_start = True else: should_start = False diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e70d3649ff..bb621df0dd 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py
@@ -13,15 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import DatabasePool +from synapse.storage.types import Connection from synapse.types import JsonDict from synapse.util import json_encoder +if TYPE_CHECKING: + from synapse.server import HomeServer + # The category ID for the "default" category. We don't store as null in the # database to avoid the fun of null != null _DEFAULT_CATEGORY_ID = "" @@ -35,6 +40,16 @@ class _RoomInGroup(TypedDict): class GroupServerWorkerStore(SQLBaseStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + database.updates.register_background_index_update( + update_name="local_group_updates_index", + index_name="local_group_updates_stream_id_index", + table="local_group_updates", + columns=("stream_id",), + unique=True, + ) + super().__init__(database, db_conn, hs) + async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: return await self.db_pool.simple_select_one( table="groups", diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 3d1dff660b..3d0df0cbd4 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py
@@ -14,6 +14,7 @@ import logging from types import TracebackType from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type +from weakref import WeakValueDictionary from twisted.internet.interfaces import IReactorCore @@ -61,7 +62,7 @@ class LockStore(SQLBaseStore): # A map from `(lock_name, lock_key)` to the token of any locks that we # think we currently hold. - self._live_tokens: Dict[Tuple[str, str], str] = {} + self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary() # When we shut down we want to remove the locks. Technically this can # lead to a race, as we may drop the lock while we are still processing. @@ -80,10 +81,10 @@ class LockStore(SQLBaseStore): # We need to take a copy of the tokens dict as dropping the locks will # cause the dictionary to change. - tokens = dict(self._live_tokens) + locks = dict(self._live_tokens) - for (lock_name, lock_key), token in tokens.items(): - await self._drop_lock(lock_name, lock_key, token) + for lock in locks.values(): + await lock.release() logger.info("Dropped locks due to shutdown") @@ -93,6 +94,11 @@ class LockStore(SQLBaseStore): used (otherwise the lock will leak). """ + # Check if this process has taken out a lock and if it's still valid. + lock = self._live_tokens.get((lock_name, lock_key)) + if lock and await lock.is_still_valid(): + return None + now = self._clock.time_msec() token = random_string(6) @@ -100,7 +106,9 @@ class LockStore(SQLBaseStore): def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool: # We take out the lock if either a) there is no row for the lock - # already or b) the existing row has timed out. + # already, b) the existing row has timed out, or c) the row is + # for this instance (which means the process got killed and + # restarted) sql = """ INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts) VALUES (?, ?, ?, ?, ?) @@ -112,6 +120,7 @@ class LockStore(SQLBaseStore): last_renewed_ts = EXCLUDED.last_renewed_ts WHERE worker_locks.last_renewed_ts < ? + OR worker_locks.instance_name = EXCLUDED.instance_name """ txn.execute( sql, @@ -148,11 +157,11 @@ class LockStore(SQLBaseStore): WHERE lock_name = ? AND lock_key = ? - AND last_renewed_ts < ? + AND (last_renewed_ts < ? OR instance_name = ?) """ txn.execute( sql, - (lock_name, lock_key, now - _LOCK_TIMEOUT_MS), + (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name), ) inserted = self.db_pool.simple_upsert_txn_emulated( @@ -179,9 +188,7 @@ class LockStore(SQLBaseStore): if not did_lock: return None - self._live_tokens[(lock_name, lock_key)] = token - - return Lock( + lock = Lock( self._reactor, self._clock, self, @@ -190,6 +197,10 @@ class LockStore(SQLBaseStore): token=token, ) + self._live_tokens[(lock_name, lock_key)] = lock + + return lock + async def _is_lock_still_valid( self, lock_name: str, lock_key: str, token: str ) -> bool: diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): prefilled_cache=presence_cache_prefill, ) - async def update_presence(self, presence_states): + async def update_presence(self, presence_states) -> Tuple[int, int]: assert self._can_persist_presence stream_ordering_manager = self._presence_id_gen.get_next_mult( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f879bbe7c7..cefc77fa0f 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -412,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore): limit: maximum amount of rooms to retrieve order_by: the sort order of the returned list reverse_order: whether to reverse the room list - search_term: a string to filter room names by + search_term: a string to filter room names, + canonical alias and room ids by. + Room ID must match exactly. Canonical alias must match a substring of the local part. Returns: A list of room dicts and an integer representing the total number of rooms that exist given this query """ # Filter room names by a string where_statement = "" + search_pattern = [] if search_term: - where_statement = "WHERE LOWER(state.name) LIKE ?" + where_statement = """ + WHERE LOWER(state.name) LIKE ? + OR LOWER(state.canonical_alias) LIKE ? + OR state.room_id = ? + """ # Our postgres db driver converts ? -> %s in SQL strings as that's the # placeholder for postgres. # HOWEVER, if you put a % into your SQL then everything goes wibbly. # To get around this, we're going to surround search_term with %'s # before giving it to the database in python instead - search_term = "%" + search_term.lower() + "%" + search_pattern = [ + "%" + search_term.lower() + "%", + "#%" + search_term.lower() + "%:%", + search_term, + ] # Set ordering if RoomSortOrder(order_by) == RoomSortOrder.SIZE: @@ -519,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore): ) def _get_rooms_paginate_txn(txn): - # Execute the data query - sql_values = (limit, start) - if search_term: - # Add the search term into the WHERE clause - sql_values = (search_term,) + sql_values - txn.execute(info_sql, sql_values) + # Add the search term into the WHERE clause + # and execute the data query + txn.execute(info_sql, search_pattern + [limit, start]) # Refactor room query data into a structured dictionary rooms = [] @@ -551,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore): # Execute the count query # Add the search term into the WHERE clause if present - sql_values = (search_term,) if search_term else () - txn.execute(count_sql, sql_values) + txn.execute(count_sql, search_pattern) room_count = txn.fetchone() return rooms, room_count[0] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4b288bb2e7..033a9831d6 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): async def get_joined_users_from_context( self, event: EventBase, context: EventContext - ): + ) -> Dict[str, ProfileInfo]: state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): event.room_id, state_group, current_state_ids, event=event, context=context ) - async def get_joined_users_from_state(self, room_id, state_entry): + async def get_joined_users_from_state( + self, room_id, state_entry + ) -> Dict[str, ProfileInfo]: state_group = state_entry.state_group if not state_group: # If state_group is None it means it has yet to be assigned a @@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cache_context, event=None, context=None, - ): + ) -> Dict[str, ProfileInfo]: # We don't use `state_group`, it's there so that we can cache based # on it. However, it's important that it's never None, since two current_states # with a state_group of None are likely to be different. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1629d2a53c..b5c1c14ee3 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py
@@ -133,22 +133,23 @@ def prepare_database( # if it's a worker app, refuse to upgrade the database, to avoid multiple # workers doing it at once. - if ( - config.worker.worker_app is not None - and version_info.current_version != SCHEMA_VERSION - ): + if config.worker.worker_app is None: + _upgrade_existing_database( + cur, + version_info, + database_engine, + config, + databases=databases, + ) + elif version_info.current_version < SCHEMA_VERSION: + # If the DB is on an older version than we expect the we refuse + # to start the worker (as the main process needs to run first to + # update the schema). raise UpgradeDatabaseException( OUTDATED_SCHEMA_ON_WORKER_ERROR % (SCHEMA_VERSION, version_info.current_version) ) - _upgrade_existing_database( - cur, - version_info, - database_engine, - config, - databases=databases, - ) else: logger.info("%r: Initialising new database", databases) diff --git a/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql new file mode 100644
index 0000000000..7b3592dcf0 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql
@@ -0,0 +1,22 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Remove messages from the device_inbox table which were orphaned +-- because a device was hidden using Synapse earlier than 1.47.0. +-- This runs as background task, but may take a bit to finish. + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6503, 'remove_hidden_devices_from_device_inbox', '{}'); diff --git a/synapse/storage/schema/main/delta/65/04_local_group_updates.sql b/synapse/storage/schema/main/delta/65/04_local_group_updates.sql new file mode 100644
index 0000000000..a178abfe12 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/04_local_group_updates.sql
@@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Check index on `local_group_updates.stream_id`. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6504, 'local_group_updates_index', '{}'); diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5df80ea8e7..96efc5f3e3 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -22,11 +22,11 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, Iterable, - List, Optional, Set, TypeVar, @@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]): def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) - object.__setattr__(self, "_observers", set()) + object.__setattr__(self, "_observers", []) def callback(r): object.__setattr__(self, "_result", (True, r)) - while self._observers: - observer = self._observers.pop() + + # once we have set _result, no more entries will be added to _observers, + # so it's safe to replace it with the empty tuple. + observers = self._observers + object.__setattr__(self, "_observers", ()) + + for observer in observers: try: observer.callback(r) except Exception as e: @@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]): def errback(f): object.__setattr__(self, "_result", (False, f)) - while self._observers: + + # once we have set _result, no more entries will be added to _observers, + # so it's safe to replace it with the empty tuple. + observers = self._observers + object.__setattr__(self, "_observers", ()) + + for observer in observers: # This is a little bit of magic to correctly propagate stack # traces when we `await` on one of the observer deferreds. f.value.__failure__ = f - - observer = self._observers.pop() try: observer.errback(f) except Exception as e: @@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]): """ if not self._result: d: "defer.Deferred[_T]" = defer.Deferred() - - def remove(r): - self._observers.discard(d) - return r - - d.addBoth(remove) - - self._observers.add(d) + self._observers.append(d) return d else: success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self) -> "List[defer.Deferred[_T]]": + def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers def has_called(self) -> bool: