diff --git a/synapse/__init__.py b/synapse/__init__.py
index 355b36fc63..5ef34bce40 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.46.0rc1"
+__version__ = "1.46.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 685d1c25cf..85302163da 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -596,3 +596,10 @@ class ShadowBanError(Exception):
This should be caught and a proper "fake" success response sent to the user.
"""
+
+
+class ModuleFailedException(Exception):
+ """
+ Raised when a module API callback fails, for example because it raised an
+ exception.
+ """
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 51eadf122d..218826741e 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -100,6 +100,7 @@ from synapse.rest.client.register import (
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer
from synapse.storage.databases.main.censor_events import CensorEventsStore
from synapse.storage.databases.main.client_ips import ClientIpWorkerStore
@@ -318,6 +319,8 @@ class GenericWorkerServer(HomeServer):
resources.update({CLIENT_API_PREFIX: resource})
resources.update(build_synapse_client_resource_tree(self))
+ resources.update({"/.well-known": well_known_resource(self)})
+
elif name == "federation":
resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
elif name == "media":
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 93e2299266..336c279a44 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -66,7 +66,7 @@ from synapse.rest.admin import AdminRestResource
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
-from synapse.rest.well_known import WellKnownResource
+from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.util.httpresourcetree import create_resource_tree
@@ -189,7 +189,7 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
- "/.well-known/matrix/client": WellKnownResource(self),
+ "/.well-known": well_known_resource(self),
"/_synapse/admin": AdminRestResource(self),
**build_synapse_client_resource_tree(self),
}
diff --git a/synapse/config/server.py b/synapse/config/server.py
index ed094bdc44..a387fd9310 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -262,6 +262,7 @@ class ServerConfig(Config):
self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
+ self.serve_server_wellknown = config.get("serve_server_wellknown", False)
self.public_baseurl = config.get("public_baseurl")
if self.public_baseurl is not None:
@@ -774,6 +775,24 @@ class ServerConfig(Config):
#
#public_baseurl: https://example.com/
+ # Uncomment the following to tell other servers to send federation traffic on
+ # port 443.
+ #
+ # By default, other servers will try to reach our server on port 8448, which can
+ # be inconvenient in some environments.
+ #
+ # Provided 'https://<server_name>/' on port 443 is routed to Synapse, this
+ # option configures Synapse to serve a file at
+ # 'https://<server_name>/.well-known/matrix/server'. This will tell other
+ # servers to send traffic to port 443 instead.
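+ #
+ # For example, with 'server_name: example.com' the served file will
+ # contain: {"m.server": "example.com:443"}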
+ #
+ # See https://matrix-org.github.io/synapse/latest/delegate.html for more
+ # information.
+ #
+ # Defaults to 'false'.
+ #
+ #serve_server_wellknown: true
+
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 462630201d..4507992031 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -63,7 +63,8 @@ class WriterLocations:
Attributes:
events: The instances that write to the event and backfill streams.
- typing: The instance that writes to the typing stream.
+ typing: The instances that write to the typing stream. Currently
+ can only be a single instance.
to_device: The instances that write to the to_device stream. Currently
can only be a single instance.
account_data: The instances that write to the account data streams. Currently
@@ -75,9 +76,15 @@ class WriterLocations:
"""
events = attr.ib(
- default=["master"], type=List[str], converter=_instance_to_list_converter
+ default=["master"],
+ type=List[str],
+ converter=_instance_to_list_converter,
+ )
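+ # `typing` is a list for consistency with the other stream writers, but
+ # (as enforced in WorkerConfig below) it may currently only contain a
+ # single instance.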
+ typing = attr.ib(
+ default=["master"],
+ type=List[str],
+ converter=_instance_to_list_converter,
)
- typing = attr.ib(default="master", type=str)
to_device = attr.ib(
default=["master"],
type=List[str],
@@ -217,6 +224,11 @@ class WorkerConfig(Config):
% (instance, stream)
)
+ if len(self.writers.typing) != 1:
+ raise ConfigError(
+ "Must only specify one instance to handle `typing` messages."
+ )
+
if len(self.writers.to_device) != 1:
raise ConfigError(
"Must only specify one instance to handle `to_device` messages."
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 157669ea88..38f3cf4d33 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -16,8 +16,23 @@
import abc
import os
-from typing import Dict, Optional, Tuple, Type
-
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
+
+from typing_extensions import Literal
from unpaddedbase64 import encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
@@ -26,6 +41,9 @@ from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
+if TYPE_CHECKING:
+ from synapse.events.builder import EventBuilder
+
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a
# dict to frozen_dicts is expensive.
@@ -37,7 +55,23 @@ from synapse.util.stringutils import strtobool
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
-class DictProperty:
+T = TypeVar("T")
+
+
+# DictProperty (and DefaultDictProperty) require the classes they're used with to
+# have a _dict property to pull properties from.
+#
+# TODO _DictPropertyInstance should not include EventBuilder but due to
+# https://github.com/python/mypy/issues/5570 it thinks the DictProperty and
+# DefaultDictProperty get applied to EventBuilder when it is in a Union with
+# EventBase. This is the least invasive hack to get mypy to comply.
+#
+# Note that DictProperty/DefaultDictProperty cannot actually be used with
+# EventBuilder as it lacks a _dict property.
+_DictPropertyInstance = Union["_EventInternalMetadata", "EventBase", "EventBuilder"]
+
+
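+ # For example, `txn_id: DictProperty[str] = DictProperty("txn_id")` declares
+ # a descriptor whose reads and writes go to `instance._dict["txn_id"]` and
+ # which mypy sees as producing a `str`.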
+class DictProperty(Generic[T]):
"""An object property which delegates to the `_dict` within its parent object."""
__slots__ = ["key"]
@@ -45,12 +79,33 @@ class DictProperty:
def __init__(self, key: str):
self.key = key
- def __get__(self, instance, owner=None):
+ @overload
+ def __get__(
+ self,
+ instance: Literal[None],
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> "DictProperty":
+ ...
+
+ @overload
+ def __get__(
+ self,
+ instance: _DictPropertyInstance,
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> T:
+ ...
+
+ def __get__(
+ self,
+ instance: Optional[_DictPropertyInstance],
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> Union[T, "DictProperty"]:
# if the property is accessed as a class property rather than an instance
# property, return the property itself rather than the value
if instance is None:
return self
try:
+ assert isinstance(instance, (EventBase, _EventInternalMetadata))
return instance._dict[self.key]
except KeyError as e1:
# We want this to look like a regular attribute error (mostly so that
@@ -65,10 +120,12 @@ class DictProperty:
"'%s' has no '%s' property" % (type(instance), self.key)
) from e1.__context__
- def __set__(self, instance, v):
+ def __set__(self, instance: _DictPropertyInstance, v: T) -> None:
+ assert isinstance(instance, (EventBase, _EventInternalMetadata))
instance._dict[self.key] = v
- def __delete__(self, instance):
+ def __delete__(self, instance: _DictPropertyInstance) -> None:
+ assert isinstance(instance, (EventBase, _EventInternalMetadata))
try:
del instance._dict[self.key]
except KeyError as e1:
@@ -77,7 +134,7 @@ class DictProperty:
) from e1.__context__
-class DefaultDictProperty(DictProperty):
+class DefaultDictProperty(DictProperty, Generic[T]):
"""An extension of DictProperty which provides a default if the property is
not present in the parent's _dict.
@@ -86,13 +143,34 @@ class DefaultDictProperty(DictProperty):
__slots__ = ["default"]
- def __init__(self, key, default):
+ def __init__(self, key: str, default: T):
super().__init__(key)
self.default = default
- def __get__(self, instance, owner=None):
+ @overload
+ def __get__(
+ self,
+ instance: Literal[None],
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> "DefaultDictProperty":
+ ...
+
+ @overload
+ def __get__(
+ self,
+ instance: _DictPropertyInstance,
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> T:
+ ...
+
+ def __get__(
+ self,
+ instance: Optional[_DictPropertyInstance],
+ owner: Optional[Type[_DictPropertyInstance]] = None,
+ ) -> Union[T, "DefaultDictProperty"]:
if instance is None:
return self
+ assert isinstance(instance, (EventBase, _EventInternalMetadata))
return instance._dict.get(self.key, self.default)
@@ -111,22 +189,22 @@ class _EventInternalMetadata:
# in the DAG)
self.outlier = False
- out_of_band_membership: bool = DictProperty("out_of_band_membership")
- send_on_behalf_of: str = DictProperty("send_on_behalf_of")
- recheck_redaction: bool = DictProperty("recheck_redaction")
- soft_failed: bool = DictProperty("soft_failed")
- proactively_send: bool = DictProperty("proactively_send")
- redacted: bool = DictProperty("redacted")
- txn_id: str = DictProperty("txn_id")
- token_id: int = DictProperty("token_id")
- historical: bool = DictProperty("historical")
+ out_of_band_membership: DictProperty[bool] = DictProperty("out_of_band_membership")
+ send_on_behalf_of: DictProperty[str] = DictProperty("send_on_behalf_of")
+ recheck_redaction: DictProperty[bool] = DictProperty("recheck_redaction")
+ soft_failed: DictProperty[bool] = DictProperty("soft_failed")
+ proactively_send: DictProperty[bool] = DictProperty("proactively_send")
+ redacted: DictProperty[bool] = DictProperty("redacted")
+ txn_id: DictProperty[str] = DictProperty("txn_id")
+ token_id: DictProperty[int] = DictProperty("token_id")
+ historical: DictProperty[bool] = DictProperty("historical")
# XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't
# be here
- before: RoomStreamToken = DictProperty("before")
- after: RoomStreamToken = DictProperty("after")
- order: Tuple[int, int] = DictProperty("order")
+ before: DictProperty[RoomStreamToken] = DictProperty("before")
+ after: DictProperty[RoomStreamToken] = DictProperty("after")
+ order: DictProperty[Tuple[int, int]] = DictProperty("order")
def get_dict(self) -> JsonDict:
return dict(self._dict)
@@ -162,9 +240,6 @@ class _EventInternalMetadata:
If the sender of the redaction event is allowed to redact any event
due to auth rules, then this will always return false.
-
- Returns:
- bool
"""
return self._dict.get("recheck_redaction", False)
@@ -176,32 +251,23 @@ class _EventInternalMetadata:
sent to clients.
2. They should not be added to the forward extremities (and
therefore not to current state).
-
- Returns:
- bool
"""
return self._dict.get("soft_failed", False)
- def should_proactively_send(self):
+ def should_proactively_send(self) -> bool:
"""Whether the event, if ours, should be sent to other clients and
servers.
This is used for sending dummy events internally. Servers and clients
can still explicitly fetch the event.
-
- Returns:
- bool
"""
return self._dict.get("proactively_send", True)
- def is_redacted(self):
+ def is_redacted(self) -> bool:
"""Whether the event has been redacted.
This is used for efficiently checking whether an event has been
marked as redacted without needing to make another database call.
-
- Returns:
- bool
"""
return self._dict.get("redacted", False)
@@ -241,29 +307,31 @@ class EventBase(metaclass=abc.ABCMeta):
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
- auth_events = DictProperty("auth_events")
- depth = DictProperty("depth")
- content = DictProperty("content")
- hashes = DictProperty("hashes")
- origin = DictProperty("origin")
- origin_server_ts = DictProperty("origin_server_ts")
- prev_events = DictProperty("prev_events")
- redacts = DefaultDictProperty("redacts", None)
- room_id = DictProperty("room_id")
- sender = DictProperty("sender")
- state_key = DictProperty("state_key")
- type = DictProperty("type")
- user_id = DictProperty("sender")
+ depth: DictProperty[int] = DictProperty("depth")
+ content: DictProperty[JsonDict] = DictProperty("content")
+ hashes: DictProperty[Dict[str, str]] = DictProperty("hashes")
+ origin: DictProperty[str] = DictProperty("origin")
+ origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
+ redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None)
+ room_id: DictProperty[str] = DictProperty("room_id")
+ sender: DictProperty[str] = DictProperty("sender")
+ # TODO state_key should be Optional[str], this is generally asserted in Synapse
+ # by calling is_state() first (which ensures this), but it is hard (not possible?)
+ # to properly annotate that calling is_state() asserts that state_key exists
+ # and is non-None.
+ state_key: DictProperty[str] = DictProperty("state_key")
+ type: DictProperty[str] = DictProperty("type")
+ user_id: DictProperty[str] = DictProperty("sender")
@property
def event_id(self) -> str:
raise NotImplementedError()
@property
- def membership(self):
+ def membership(self) -> str:
return self.content["membership"]
- def is_state(self):
+ def is_state(self) -> bool:
return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self) -> JsonDict:
@@ -272,13 +340,13 @@ class EventBase(metaclass=abc.ABCMeta):
return d
- def get(self, key, default=None):
+ def get(self, key: str, default: Optional[Any] = None) -> Any:
return self._dict.get(key, default)
- def get_internal_metadata_dict(self):
+ def get_internal_metadata_dict(self) -> JsonDict:
return self.internal_metadata.get_dict()
- def get_pdu_json(self, time_now=None) -> JsonDict:
+ def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict:
pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
@@ -305,49 +373,46 @@ class EventBase(metaclass=abc.ABCMeta):
return template_json
- def __set__(self, instance, value):
- raise AttributeError("Unrecognized attribute %s" % (instance,))
-
- def __getitem__(self, field):
+ def __getitem__(self, field: str) -> Optional[Any]:
return self._dict[field]
- def __contains__(self, field):
+ def __contains__(self, field: str) -> bool:
return field in self._dict
- def items(self):
+ def items(self) -> List[Tuple[str, Optional[Any]]]:
return list(self._dict.items())
- def keys(self):
+ def keys(self) -> Iterable[str]:
return self._dict.keys()
- def prev_event_ids(self):
+ def prev_event_ids(self) -> Sequence[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
- list[str]: The list of event IDs of this event's prev_events
+ The list of event IDs of this event's prev_events
"""
- return [e for e, _ in self.prev_events]
+ return [e for e, _ in self._dict["prev_events"]]
- def auth_event_ids(self):
+ def auth_event_ids(self) -> Sequence[str]:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
- list[str]: The list of event IDs of this event's auth_events
+ The list of event IDs of this event's auth_events
"""
- return [e for e, _ in self.auth_events]
+ return [e for e, _ in self._dict["auth_events"]]
- def freeze(self):
+ def freeze(self) -> None:
"""'Freeze' the event dict, so it cannot be modified by accident"""
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
- def __str__(self):
+ def __str__(self) -> str:
return self.__repr__()
- def __repr__(self):
+ def __repr__(self) -> str:
rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else ""
return (
@@ -443,7 +508,7 @@ class FrozenEventV2(EventBase):
else:
frozen_dict = event_dict
- self._event_id = None
+ self._event_id: Optional[str] = None
super().__init__(
frozen_dict,
@@ -455,7 +520,7 @@ class FrozenEventV2(EventBase):
)
@property
- def event_id(self):
+ def event_id(self) -> str:
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
@@ -465,23 +530,23 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
- def prev_event_ids(self):
+ def prev_event_ids(self) -> Sequence[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
- list[str]: The list of event IDs of this event's prev_events
+ The list of event IDs of this event's prev_events
"""
- return self.prev_events
+ return self._dict["prev_events"]
- def auth_event_ids(self):
+ def auth_event_ids(self) -> Sequence[str]:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
- list[str]: The list of event IDs of this event's auth_events
+ The list of event IDs of this event's auth_events
"""
- return self.auth_events
+ return self._dict["auth_events"]
class FrozenEventV3(FrozenEventV2):
@@ -490,7 +555,7 @@ class FrozenEventV3(FrozenEventV2):
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
- def event_id(self):
+ def event_id(self) -> str:
# We have to import this here as otherwise we get an import loop which
# is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash
@@ -503,12 +568,14 @@ class FrozenEventV3(FrozenEventV2):
return self._event_id
-def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
+def _event_type_from_format_version(
+ format_version: int,
+) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
"""Returns the python type to use to construct an Event object for the
given event format version.
Args:
- format_version (int): The event format version
+ format_version: The event format version
Returns:
type: A type that can be initialized as per the initializer of
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 8816ef4b76..1bb8ca7145 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -14,7 +14,7 @@
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
-from synapse.api.errors import SynapseError
+from synapse.api.errors import ModuleFailedException, SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester, StateMap
@@ -233,9 +233,10 @@ class ThirdPartyEventRules:
# This module callback needs a rework so that hacks such as
# this one are not necessary.
raise e
- except Exception as e:
- logger.warning("Failed to run module API callback %s: %s", callback, e)
- continue
+ except Exception:
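+ # Don't just log and skip a failing callback (as we used to):
+ # raise instead, so the failure is surfaced to the caller.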
+ raise ModuleFailedException(
+ "Failed to run `check_event_allowed` module API callback"
+ )
# Return if the event shouldn't be allowed or if the module came up with a
# replacement dict for the event.
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 4d459c17f1..cf86934968 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -55,7 +55,7 @@ class EventValidator:
]
for k in required:
- if not hasattr(event, k):
+ if k not in event:
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 32a75993d9..9a8758e9a6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -213,6 +213,11 @@ class FederationServer(FederationBase):
self._started_handling_of_staged_events = True
self._handle_old_staged_events()
+ # Start a periodic check for old staged events. This is to handle
+ # the case where locks time out, e.g. if another process gets killed
+ # without dropping its locks.
+ self._clock.looping_call(self._handle_old_staged_events, 60 * 1000)
+
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -1232,10 +1237,6 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler
- def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
- """Register that the EDU handler is on a different instance than master."""
- self._edu_type_to_instance[edu_type] = [instance_name]
-
def register_instances_for_edu(
self, edu_type: str, instance_names: List[str]
) -> None:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index d963178838..10b5aa5af8 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1310,14 +1310,17 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
self._coro_state = ijson.items_coro(
_event_list_parser(room_version, self._response.state),
prefix + "state.item",
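+ # use_float=True makes ijson return non-integer numbers as floats
+ # rather than Decimals (Decimal is not JSON-serialisable by the
+ # stdlib encoder).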
+ use_float=True,
)
self._coro_auth = ijson.items_coro(
_event_list_parser(room_version, self._response.auth_events),
prefix + "auth_chain.item",
+ use_float=True,
)
self._coro_event = ijson.kvitems_coro(
_event_parser(self._response.event_dict),
prefix + "org.matrix.msc3083.v2.event",
+ use_float=True,
)
def write(self, data: bytes) -> int:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 36c206dae6..ddc9105ee9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -34,6 +34,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
+from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -58,6 +59,10 @@ class ApplicationServicesHandler:
self.current_max = 0
self.is_processing = False
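+ # Serialises reading and persisting each appservice's ephemeral stream
+ # positions; the Linearizer is keyed off (appservice id, stream) below.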
+ self._ephemeral_events_linearizer = Linearizer(
+ name="appservice_ephemeral_events"
+ )
+
def notify_interested_services(self, max_token: RoomStreamToken) -> None:
"""Notifies (pushes) all application services interested in this event.
@@ -182,7 +187,7 @@ class ApplicationServicesHandler:
def notify_interested_services_ephemeral(
self,
stream_key: str,
- new_token: Optional[int],
+ new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
) -> None:
"""
@@ -203,7 +208,7 @@ class ApplicationServicesHandler:
Appservices will only receive ephemeral events that fall within their
registered user and room namespaces.
- new_token: The latest stream token.
+ new_token: The stream token of the event.
users: The users that should be informed of the new event, if any.
"""
if not self.notify_appservices:
@@ -212,6 +217,19 @@ class ApplicationServicesHandler:
if stream_key not in ("typing_key", "receipt_key", "presence_key"):
return
+ # Assert that new_token is an integer (and not a RoomStreamToken).
+ # All of the supported streams that this function handles use an
+ # integer to track progress (rather than a RoomStreamToken - a
+ # vector clock implementation) as they don't support multiple
+ # stream writers.
+ #
+ # As a result, we simply assert that new_token is an integer.
+ # If we do end up needing to pass a RoomStreamToken down here
+ # in the future, using RoomStreamToken.stream (the minimum stream
+ # position) to convert to an ascending integer value should work.
+ # Additional context: https://github.com/matrix-org/synapse/pull/11137
+ assert isinstance(new_token, int)
+
services = [
service
for service in self.store.get_app_services()
@@ -231,14 +249,13 @@ class ApplicationServicesHandler:
self,
services: List[ApplicationService],
stream_key: str,
- new_token: Optional[int],
+ new_token: int,
users: Collection[Union[str, UserID]],
) -> None:
logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
- # Only handle typing if we have the latest token
- if stream_key == "typing_key" and new_token is not None:
+ if stream_key == "typing_key":
# Note that we don't persist the token (via set_type_stream_id_for_appservice)
# for typing_key due to performance reasons and due to their highly
# ephemeral nature.
@@ -248,26 +265,37 @@ class ApplicationServicesHandler:
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
+ continue
- elif stream_key == "receipt_key":
- events = await self._handle_receipts(service)
- if events:
- self.scheduler.submit_ephemeral_events_for_as(service, events)
-
- # Persist the latest handled stream token for this appservice
- await self.store.set_type_stream_id_for_appservice(
- service, "read_receipt", new_token
+ # Since we read/update the stream position for this AS/stream, serialise access via the linearizer.
+ with (
+ await self._ephemeral_events_linearizer.queue(
+ (service.id, stream_key)
)
+ ):
+ if stream_key == "receipt_key":
+ events = await self._handle_receipts(service, new_token)
+ if events:
+ self.scheduler.submit_ephemeral_events_for_as(
+ service, events
+ )
+
+ # Persist the latest handled stream token for this appservice
+ await self.store.set_type_stream_id_for_appservice(
+ service, "read_receipt", new_token
+ )
- elif stream_key == "presence_key":
- events = await self._handle_presence(service, users)
- if events:
- self.scheduler.submit_ephemeral_events_for_as(service, events)
+ elif stream_key == "presence_key":
+ events = await self._handle_presence(service, users, new_token)
+ if events:
+ self.scheduler.submit_ephemeral_events_for_as(
+ service, events
+ )
- # Persist the latest handled stream token for this appservice
- await self.store.set_type_stream_id_for_appservice(
- service, "presence", new_token
- )
+ # Persist the latest handled stream token for this appservice
+ await self.store.set_type_stream_id_for_appservice(
+ service, "presence", new_token
+ )
async def _handle_typing(
self, service: ApplicationService, new_token: int
@@ -304,7 +332,9 @@ class ApplicationServicesHandler:
)
return typing
- async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
+ async def _handle_receipts(
+ self, service: ApplicationService, new_token: Optional[int]
+ ) -> List[JsonDict]:
"""
Return the latest read receipts that the given application service should receive.
@@ -323,6 +353,12 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
+ if new_token is not None and new_token <= from_key:
+ logger.debug(
+ "Rejecting token lower than or equal to stored: %s" % (new_token,)
+ )
+ return []
+
receipts_source = self.event_sources.sources.receipt
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
@@ -330,7 +366,10 @@ class ApplicationServicesHandler:
return receipts
async def _handle_presence(
- self, service: ApplicationService, users: Collection[Union[str, UserID]]
+ self,
+ service: ApplicationService,
+ users: Collection[Union[str, UserID]],
+ new_token: Optional[int],
) -> List[JsonDict]:
"""
Return the latest presence updates that the given application service should receive.
@@ -353,6 +392,12 @@ class ApplicationServicesHandler:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
+ if new_token is not None and new_token <= from_key:
+ logger.debug(
+ "Rejecting token lower than or equal to stored: %s" % (new_token,)
+ )
+ return []
+
for user in users:
if isinstance(user, str):
user = UserID.from_string(user)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d508d7d32a..60e59d11a0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1989,7 +1989,9 @@ class PasswordAuthProvider:
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
- auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
+ auth_checkers: Optional[
+ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
+ ] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e617db4c0d..1a1cd93b1a 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1643,7 +1643,7 @@ class FederationEventHandler:
event: the event whose auth_events we want
Returns:
- all of the events in `event.auth_events`, after deduplication
+ all of the events listed in `event.auth_event_ids()`, after deduplication
Raises:
AuthError if we were unable to fetch the auth_events for any reason.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4a0fccfcc6..b7bc187169 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1318,6 +1318,8 @@ class EventCreationHandler:
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
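+ # `redacts` is typed as Optional[str] (it is a DefaultDictProperty
+ # defaulting to None), so assert it is set, presumably to narrow
+ # the type for mypy.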
+ assert event.redacts is not None
+
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -1413,6 +1415,8 @@ class EventCreationHandler:
)
if event.type == EventTypes.Redaction:
+ assert event.redacts is not None
+
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -1500,11 +1504,13 @@ class EventCreationHandler:
next_batch_id = event.content.get(
EventContentFields.MSC2716_NEXT_BATCH_ID
)
- conflicting_insertion_event_id = (
- await self.store.get_insertion_event_by_batch_id(
- event.room_id, next_batch_id
+ conflicting_insertion_event_id = None
+ if next_batch_id:
+ conflicting_insertion_event_id = (
+ await self.store.get_insertion_event_by_batch_id(
+ event.room_id, next_batch_id
+ )
)
- )
if conflicting_insertion_event_id is not None:
# The current insertion event that we're processing is invalid
# because an insertion event already exists in the room with the
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 99e9b37344..969eb3b9b0 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -525,7 +525,7 @@ class RoomCreationHandler:
):
await self.room_member_handler.update_membership(
requester,
- UserID.from_string(old_event["state_key"]),
+ UserID.from_string(old_event.state_key),
new_room_id,
"ban",
ratelimit=False,
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 2f5a3e4d19..0723286383 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -355,7 +355,7 @@ class RoomBatchHandler:
for (event, context) in reversed(events_to_persist):
await self.event_creation_handler.handle_new_client_event(
await self.create_requester_for_user_id_from_app_service(
- event["sender"], app_service_requester.app_service
+ event.sender, app_service_requester.app_service
),
event=event,
context=context,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 74e6c7eca6..08244b690d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1669,7 +1669,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
# the prev_events consist solely of the previous membership event.
prev_event_ids = [previous_membership_event.event_id]
- auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
+ auth_event_ids = (
+ list(previous_membership_event.auth_event_ids()) + prev_event_ids
+ )
event, context = await self.event_creation_handler.create_event(
requester,
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c411d69924..22c6174821 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -62,8 +62,8 @@ class FollowerTypingHandler:
if hs.should_send_federation():
self.federation = hs.get_federation_sender()
- if hs.config.worker.writers.typing != hs.get_instance_name():
- hs.get_federation_registry().register_instance_for_edu(
+ if hs.get_instance_name() not in hs.config.worker.writers.typing:
+ hs.get_federation_registry().register_instances_for_edu(
"m.typing",
hs.config.worker.writers.typing,
)
@@ -205,7 +205,7 @@ class TypingWriterHandler(FollowerTypingHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- assert hs.config.worker.writers.typing == hs.get_instance_name()
+ assert hs.get_instance_name() in hs.config.worker.writers.typing
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 1882fffd2a..60e5409895 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -383,29 +383,6 @@ class Notifier:
except Exception:
logger.exception("Error notifying application services of event")
- def _notify_app_services_ephemeral(
- self,
- stream_key: str,
- new_token: Union[int, RoomStreamToken],
- users: Optional[Collection[Union[str, UserID]]] = None,
- ) -> None:
- """Notify application services of ephemeral event activity.
-
- Args:
- stream_key: The stream the event came from.
- new_token: The value of the new stream token.
- users: The users that should be informed of the new event, if any.
- """
- try:
- stream_token = None
- if isinstance(new_token, int):
- stream_token = new_token
- self.appservice_handler.notify_interested_services_ephemeral(
- stream_key, stream_token, users or []
- )
- except Exception:
- logger.exception("Error notifying application services of event")
-
def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try:
self._pusher_pool.on_new_notifications(max_room_stream_token)
@@ -467,12 +444,15 @@ class Notifier:
self.notify_replication()
- # Notify appservices
- self._notify_app_services_ephemeral(
- stream_key,
- new_token,
- users,
- )
+ # Notify appservices.
+ try:
+ self.appservice_handler.notify_interested_services_ephemeral(
+ stream_key,
+ new_token,
+ users,
+ )
+ except Exception:
+ logger.exception("Error notifying application services of event")
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happened
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 0622a37ae8..009d8e77b0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -232,6 +232,8 @@ class BulkPushRuleEvaluator:
# that user, as they might not be already joined.
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
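+ # `displayname` comes from event content and so may be any JSON
+ # value; ignore anything that isn't a string rather than crashing.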
+ if not isinstance(display_name, str):
+ display_name = None
if count_as_unread:
# Add an element for the current user if the event needs to be marked as
@@ -268,7 +270,7 @@ def _condition_checker(
evaluator: PushRuleEvaluatorForEvent,
conditions: List[dict],
uid: str,
- display_name: str,
+ display_name: Optional[str],
cache: Dict[str, bool],
) -> bool:
for cond in conditions:
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 7a8dc63976..7f68092ec5 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -18,7 +18,7 @@ import re
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
from synapse.util import glob_to_regex, re_word_boundary
from synapse.util.caches.lrucache import LruCache
@@ -129,7 +129,7 @@ class PushRuleEvaluatorForEvent:
self._value_cache = _flatten_dict(event)
def matches(
- self, condition: Dict[str, Any], user_id: str, display_name: str
+ self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
) -> bool:
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
@@ -172,7 +172,7 @@ class PushRuleEvaluatorForEvent:
return _glob_matches(pattern, haystack)
- def _contains_display_name(self, display_name: str) -> bool:
+ def _contains_display_name(self, display_name: Optional[str]) -> bool:
if not display_name:
return False
@@ -222,7 +222,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
def _flatten_dict(
- d: Union[EventBase, dict],
+ d: Union[EventBase, JsonDict],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
@@ -233,7 +233,7 @@ def _flatten_dict(
for key, value in d.items():
if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower()
- elif hasattr(value, "items"):
+ elif isinstance(value, dict):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
return result
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 06fd06fdf3..21293038ef 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -138,7 +138,7 @@ class ReplicationCommandHandler:
if isinstance(stream, TypingStream):
# Only add TypingStream as a source on the instance in charge of
# typing.
- if hs.config.worker.writers.typing == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.typing:
self._streams_to_replicate.append(stream)
continue
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index c8b188ae4e..743a01da08 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -328,8 +328,7 @@ class TypingStream(Stream):
ROW_TYPE = TypingStreamRow
def __init__(self, hs: "HomeServer"):
- writer_instance = hs.config.worker.writers.typing
- if writer_instance == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.typing:
# On the writer, query the typing handler
typing_writer_handler = hs.get_typing_writer_handler()
update_function: Callable[
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index e1506deb2b..70514e814f 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -42,7 +42,6 @@ from synapse.rest.admin.registration_tokens import (
RegistrationTokenRestServlet,
)
from synapse.rest.admin.rooms import (
- DeleteRoomRestServlet,
ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
@@ -221,7 +220,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomStateRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
RoomMembersRestServlet(hs).register(http_server)
- DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index a4823ca6e7..05c5b4bf0c 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -46,41 +46,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class DeleteRoomRestServlet(RestServlet):
- """Delete a room from server.
-
- It is a combination and improvement of shutdown and purge room.
-
- Shuts down a room by removing all local users from the room.
- Blocking all future invites and joins to the room is optional.
-
- If desired any local aliases will be repointed to a new room
- created by `new_room_user_id` and kicked users will be auto-
- joined to the new room.
-
- If 'purge' is true, it will remove all traces of a room from the database.
- """
-
- PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
-
- def __init__(self, hs: "HomeServer"):
- self.hs = hs
- self.auth = hs.get_auth()
- self.room_shutdown_handler = hs.get_room_shutdown_handler()
- self.pagination_handler = hs.get_pagination_handler()
-
- async def on_POST(
- self, request: SynapseRequest, room_id: str
- ) -> Tuple[int, JsonDict]:
- return await _delete_room(
- request,
- room_id,
- self.auth,
- self.room_shutdown_handler,
- self.pagination_handler,
- )
-
-
class ListRoomRestServlet(RestServlet):
"""
List all rooms that are known to the homeserver. Results are returned
@@ -218,7 +183,7 @@ class RoomRestServlet(RestServlet):
async def on_DELETE(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
- return await _delete_room(
+ return await self._delete_room(
request,
room_id,
self.auth,
@@ -226,6 +191,58 @@ class RoomRestServlet(RestServlet):
self.pagination_handler,
)
+ async def _delete_room(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ auth: "Auth",
+ room_shutdown_handler: "RoomShutdownHandler",
+ pagination_handler: "PaginationHandler",
+ ) -> Tuple[int, JsonDict]:
+ requester = await auth.get_user_by_req(request)
+ await assert_user_is_admin(auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+
+ block = content.get("block", False)
+ if not isinstance(block, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'block' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ purge = content.get("purge", True)
+ if not isinstance(purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ force_purge = content.get("force_purge", False)
+ if not isinstance(force_purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'force_purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ ret = await room_shutdown_handler.shutdown_room(
+ room_id=room_id,
+ new_room_user_id=content.get("new_room_user_id"),
+ new_room_name=content.get("room_name"),
+ message=content.get("message"),
+ requester_user_id=requester.user.to_string(),
+ block=block,
+ )
+
+ # Purge room
+ if purge:
+ await pagination_handler.purge_room(room_id, force=force_purge)
+
+ return 200, ret
+
class RoomMembersRestServlet(RestServlet):
"""
@@ -617,55 +634,3 @@ class RoomEventContextServlet(RestServlet):
)
return 200, results
-
-
-async def _delete_room(
- request: SynapseRequest,
- room_id: str,
- auth: "Auth",
- room_shutdown_handler: "RoomShutdownHandler",
- pagination_handler: "PaginationHandler",
-) -> Tuple[int, JsonDict]:
- requester = await auth.get_user_by_req(request)
- await assert_user_is_admin(auth, requester.user)
-
- content = parse_json_object_from_request(request)
-
- block = content.get("block", False)
- if not isinstance(block, bool):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Param 'block' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- purge = content.get("purge", True)
- if not isinstance(purge, bool):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Param 'purge' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- force_purge = content.get("force_purge", False)
- if not isinstance(force_purge, bool):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Param 'force_purge' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- ret = await room_shutdown_handler.shutdown_room(
- room_id=room_id,
- new_room_user_id=content.get("new_room_user_id"),
- new_room_name=content.get("room_name"),
- message=content.get("message"),
- requester_user_id=requester.user.to_string(),
- block=block,
- )
-
- # Purge room
- if purge:
- await pagination_handler.purge_room(room_id, force=force_purge)
-
- return 200, ret
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index ed95189b6d..6a876cfa2f 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -914,7 +914,7 @@ class RoomTypingRestServlet(RestServlet):
# If we're not on the typing writer instance we should scream if we get
# requests.
self._is_typing_writer = (
- hs.config.worker.writers.typing == hs.get_instance_name()
+ hs.get_instance_name() in hs.config.worker.writers.typing
)
async def on_PUT(
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 99f8156ad0..46f033eee2 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -131,20 +131,22 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_event_ids_from_query
)
+ state_event_ids_at_start = []
# Create and persist all of the state events that float off on their own
# before the batch. These will most likely be all of the invite/member
# state events used to auth the upcoming historical messages.
- state_event_ids_at_start = (
- await self.room_batch_handler.persist_state_events_at_start(
- state_events_at_start=body["state_events_at_start"],
- room_id=room_id,
- initial_auth_event_ids=auth_event_ids,
- app_service_requester=requester,
+ if body["state_events_at_start"]:
+ state_event_ids_at_start = (
+ await self.room_batch_handler.persist_state_events_at_start(
+ state_events_at_start=body["state_events_at_start"],
+ room_id=room_id,
+ initial_auth_event_ids=auth_event_ids,
+ app_service_requester=requester,
+ )
)
- )
- # Update our ongoing auth event ID list with all of the new state we
- # just created
- auth_event_ids.extend(state_event_ids_at_start)
+ # Update our ongoing auth event ID list with all of the new state we
+ # just created
+ auth_event_ids.extend(state_event_ids_at_start)
inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
prev_event_ids_from_query
@@ -191,14 +193,17 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth,
)
- batch_id_to_connect_to = base_insertion_event["content"][
+ batch_id_to_connect_to = base_insertion_event.content[
EventContentFields.MSC2716_NEXT_BATCH_ID
]
# Also connect the historical event chain to the end of the floating
# state chain, which causes the HS to ask for the state at the start of
- # the batch later.
- prev_event_ids = [state_event_ids_at_start[-1]]
+ # the batch later. If there is no state chain to connect to, just make
+ # the insertion event float itself.
+ prev_event_ids = []
+ if state_event_ids_at_start:
+ prev_event_ids = [state_event_ids_at_start[-1]]
# Create and persist all of the historical events as well as insertion
# and batch meta events to make the batch navigable in the DAG.
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index abd88a2d4f..244ba261bb 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -215,6 +215,8 @@ class MediaRepository:
self.mark_recently_accessed(None, media_id)
media_type = media_info["media_type"]
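+ # The media may have been uploaded without a Content-Type (the upload
+ # endpoint now applies the same default); fall back to a generic
+ # binary type.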
+ if not media_type:
+ media_type = "application/octet-stream"
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
@@ -333,6 +335,9 @@ class MediaRepository:
logger.info("Media is quarantined")
raise NotFoundError()
+ if not media_info["media_type"]:
+ media_info["media_type"] = "application/octet-stream"
+
responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
@@ -354,6 +359,8 @@ class MediaRepository:
raise e
file_id = media_info["filesystem_id"]
+ if not media_info["media_type"]:
+ media_info["media_type"] = "application/octet-stream"
file_info = FileInfo(server_name, file_id)
# We generate thumbnails even if another process downloaded the media
@@ -445,7 +452,10 @@ class MediaRepository:
await finish()
- media_type = headers[b"Content-Type"][0].decode("ascii")
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 278fd901e2..8ca97b5b18 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -718,9 +718,12 @@ def decode_body(
if not body:
return None
+ # The idea here is that multiple encodings are tried until one works.
+ # Unfortunately the decoded result is thrown away: lxml decodes the
+ # body again itself, using the encoding that worked.
for encoding in get_html_media_encodings(body, content_type):
try:
- body_str = body.decode(encoding)
+ body.decode(encoding)
except Exception:
pass
else:
@@ -732,11 +735,11 @@ def decode_body(
from lxml import etree
# Create an HTML parser.
- parser = etree.HTMLParser(recover=True, encoding="utf-8")
+ parser = etree.HTMLParser(recover=True, encoding=encoding)
# Attempt to parse the body. Returns None if the body was successfully
# parsed, but no tree was found.
- return etree.fromstring(body_str, parser)
+ return etree.fromstring(body, parser)
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 7dcb1428e4..8162094cf6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -80,7 +80,7 @@ class UploadResource(DirectServeJsonResource):
assert content_type_headers # for mypy
media_type = content_type_headers[0].decode("ascii")
else:
- raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
+ media_type = "application/octet-stream"
# if headers.hasHeader(b"Content-Disposition"):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 7ac01faab4..edbf5ce5d0 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -21,6 +21,7 @@ from twisted.web.server import Request
from synapse.http.server import set_cors_headers
from synapse.types import JsonDict
from synapse.util import json_encoder
+from synapse.util.stringutils import parse_server_name
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -47,8 +48,8 @@ class WellKnownBuilder:
return result
-class WellKnownResource(Resource):
- """A Twisted web resource which renders the .well-known file"""
+class ClientWellKnownResource(Resource):
+ """A Twisted web resource which renders the .well-known/matrix/client file"""
isLeaf = 1
@@ -67,3 +68,45 @@ class WellKnownResource(Resource):
logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
return json_encoder.encode(r).encode("utf-8")
+
+
+class ServerWellKnownResource(Resource):
+ """Resource for .well-known/matrix/server, redirecting to port 443"""
+
+ isLeaf = 1
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._serve_server_wellknown = hs.config.server.serve_server_wellknown
+
+ host, port = parse_server_name(hs.config.server.server_name)
+
+ # If we've got this far, then https://<server_name>/ must route to us, so
+ # we advertise port 443 to other servers rather than the default 8448.
+ if port is None:
+ port = 443
+
+ self._response = json_encoder.encode({"m.server": f"{host}:{port}"}).encode(
+ "utf-8"
+ )
+
+ def render_GET(self, request: Request) -> bytes:
+ if not self._serve_server_wellknown:
+ request.setResponseCode(404)
+ request.setHeader(b"Content-Type", b"text/plain")
+ return b"404. Is anything ever truly *well* known?\n"
+
+ request.setHeader(b"Content-Type", b"application/json")
+ return self._response
+
+
+def well_known_resource(hs: "HomeServer") -> Resource:
+ """Returns a Twisted web resource which handles '.well-known' requests"""
+ res = Resource()
+ matrix_resource = Resource()
+ res.putChild(b"matrix", matrix_resource)
+
+ matrix_resource.putChild(b"server", ServerWellKnownResource(hs))
+ matrix_resource.putChild(b"client", ClientWellKnownResource(hs))
+
+ return res
diff --git a/synapse/server.py b/synapse/server.py
index 0fbf36ba99..013a7bacaa 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -463,7 +463,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_typing_writer_handler(self) -> TypingWriterHandler:
- if self.config.worker.writers.typing == self.get_instance_name():
+ if self.get_instance_name() in self.config.worker.writers.typing:
return TypingWriterHandler(self)
else:
raise Exception("Workers cannot write typing")
@@ -474,7 +474,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_typing_handler(self) -> FollowerTypingHandler:
- if self.config.worker.writers.typing == self.get_instance_name():
+ if self.get_instance_name() in self.config.worker.writers.typing:
# Use get_typing_writer_handler to ensure that we use the same
# cached version.
return self.get_typing_writer_handler()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 98a0239759..1605411b00 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -247,7 +247,7 @@ class StateHandler:
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(
- self, room_id: str, event_ids: List[str]
+ self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 25e9c1efe1..264e625bd7 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -561,6 +561,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
+ REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
@@ -581,6 +582,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self._remove_deleted_devices_from_device_inbox,
)
+ self.db_pool.updates.register_background_update_handler(
+ self.REMOVE_HIDDEN_DEVICES,
+ self._remove_hidden_devices_from_device_inbox,
+ )
+
async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
@@ -676,6 +682,89 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return number_deleted
+ async def _remove_hidden_devices_from_device_inbox(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """A background update that deletes all device_inboxes for hidden devices.
+
+ This should only need to be run once (when users upgrade to v1.47.0).
+
+ Args:
+ progress: JsonDict used to store progress of this background update
+ batch_size: the maximum number of rows to retrieve in a single select query
+
+ Returns:
+ The number of deleted rows
+ """
+
+ def _remove_hidden_devices_from_device_inbox_txn(
+ txn: LoggingTransaction,
+ ) -> int:
+ """stream_id is not unique
+ we need to use an inclusive `stream_id >= ?` clause,
+ since we might not have deleted all hidden device messages for the stream_id
+ returned from the previous query
+
+ Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+ to avoid problems of deleting a large number of rows all at once
+ due to a single device having lots of device messages.
+ """
+
+ last_stream_id = progress.get("stream_id", 0)
+
+ sql = """
+ SELECT device_id, user_id, stream_id
+ FROM device_inbox
+ WHERE
+ stream_id >= ?
+ AND (device_id, user_id) IN (
+ SELECT device_id, user_id FROM devices WHERE hidden = ?
+ )
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_stream_id, True, batch_size))
+ rows = txn.fetchall()
+
+ num_deleted = 0
+ for row in rows:
+ num_deleted += self.db_pool.simple_delete_txn(
+ txn,
+ "device_inbox",
+ {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+ )
+
+ if rows:
+                # We save the whole row, not just the `stream_id`: on large
+                # deployments the stream_id may not change across several runs,
+                # which would make the logged progress of this update look
+                # stalled even though rows are being deleted.
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.REMOVE_HIDDEN_DEVICES,
+ {
+ "device_id": rows[-1][0],
+ "user_id": rows[-1][1],
+ "stream_id": rows[-1][2],
+ },
+ )
+
+ return num_deleted
+
+ number_deleted = await self.db_pool.runInteraction(
+ "_remove_hidden_devices_from_device_inbox",
+ _remove_hidden_devices_from_device_inbox_txn,
+ )
+
+        # The background update is finished when no more rows are deleted.
+ if not number_deleted:
+ await self.db_pool.updates._end_background_update(
+ self.REMOVE_HIDDEN_DEVICES
+ )
+
+ return number_deleted
+
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
pass
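
For reference, a standalone sketch of the batched-delete pattern the new background update uses, written against plain sqlite3 with a hypothetical two-table schema rather than Synapse's storage layer: page through matching rows starting from the last saved stream_id inclusive, delete one `(device_id, user_id, stream_id)` tuple at a time, and stop once a pass deletes nothing.

import sqlite3

def delete_hidden_in_batches(conn: sqlite3.Connection, batch_size: int = 100) -> int:
    last_stream_id = 0
    total = 0
    while True:
        # Row-value IN needs SQLite 3.15+; `hidden` is assumed to be 0/1 here.
        rows = conn.execute(
            """
            SELECT device_id, user_id, stream_id FROM device_inbox
            WHERE stream_id >= ?
              AND (device_id, user_id) IN
                  (SELECT device_id, user_id FROM devices WHERE hidden = 1)
            ORDER BY stream_id LIMIT ?
            """,
            (last_stream_id, batch_size),
        ).fetchall()
        if not rows:
            conn.commit()
            return total
        for device_id, user_id, stream_id in rows:
            conn.execute(
                "DELETE FROM device_inbox"
                " WHERE device_id = ? AND user_id = ? AND stream_id = ?",
                (device_id, user_id, stream_id),
            )
            total += 1
        # stream_id is not unique, so the next pass must restart from it inclusively.
        last_stream_id = rows[-1][2]
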
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b15cd030e0..9ccc66e589 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -427,7 +427,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_ids: the users who were signed
Returns:
- THe new stream ID.
+ The new stream ID.
"""
async with self._device_list_id_gen.get_next() as stream_id:
@@ -1322,7 +1322,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: List[str]
- ):
+ ) -> int:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8d9086ecf0..596275c23c 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -24,6 +24,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
)
@@ -494,7 +495,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
) -> None:
"""Calculate the chain cover index for the given events.
@@ -786,7 +787,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, List[str]],
+ event_to_auth_chain: Dict[str, Sequence[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
@@ -1794,7 +1795,7 @@ class PersistEventsStore:
)
# Insert an edge for every prev_event connection
- for prev_event_id in event.prev_events:
+ for prev_event_id in event.prev_event_ids():
self.db_pool.simple_insert_txn(
txn,
table="insertion_event_edges",
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index ae37901be9..c6bf316d5b 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -28,6 +28,7 @@ from typing import (
import attr
from constantly import NamedConstant, Names
+from prometheus_client import Gauge
from typing_extensions import Literal
from twisted.internet import defer
@@ -81,6 +82,12 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+event_fetch_ongoing_gauge = Gauge(
+ "synapse_event_fetch_ongoing",
+ "The number of event fetchers that are running",
+)
+
+
@attr.s(slots=True, auto_attribs=True)
class _EventCacheEntry:
event: EventBase
@@ -222,6 +229,7 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
@@ -732,28 +740,31 @@ class EventsWorkerStore(SQLBaseStore):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if (
- not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
- or single_threaded
- or i > EVENT_QUEUE_ITERATIONS
- ):
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- self._fetch_event_list(conn, event_list)
+ try:
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if (
+ not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+ or single_threaded
+ or i > EVENT_QUEUE_ITERATIONS
+ ):
+ break
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ self._fetch_event_list(conn, event_list)
+ finally:
+ self._event_fetch_ongoing -= 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
@@ -977,6 +988,7 @@ class EventsWorkerStore(SQLBaseStore):
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
+ event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
should_start = True
else:
should_start = False
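
The events_worker.py hunks above do two things: mirror the `_event_fetch_ongoing` counter into a Prometheus gauge, and move the decrement into a `finally` block so a crashing fetcher loop can no longer leak the count. A minimal sketch of that pattern with assumed names (plain prometheus_client, not Synapse code):

import threading
from typing import Callable

from prometheus_client import Gauge

fetchers_gauge = Gauge("example_fetchers_ongoing", "Number of running fetchers")

_lock = threading.Lock()
_ongoing = 0

def run_fetcher(work: Callable[[], None]) -> None:
    global _ongoing
    with _lock:
        _ongoing += 1
        fetchers_gauge.set(_ongoing)
    try:
        work()  # may raise; the gauge must still come back down
    finally:
        # Runs on both normal exit and exceptions, keeping the gauge accurate.
        with _lock:
            _ongoing -= 1
            fetchers_gauge.set(_ongoing)
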
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index e70d3649ff..bb621df0dd 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
_DEFAULT_CATEGORY_ID = ""
@@ -35,6 +40,16 @@ class _RoomInGroup(TypedDict):
class GroupServerWorkerStore(SQLBaseStore):
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ database.updates.register_background_index_update(
+ update_name="local_group_updates_index",
+ index_name="local_group_updates_stream_id_index",
+ table="local_group_updates",
+ columns=("stream_id",),
+ unique=True,
+ )
+ super().__init__(database, db_conn, hs)
+
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="groups",
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 3d1dff660b..3d0df0cbd4 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -14,6 +14,7 @@
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -61,7 +62,7 @@ class LockStore(SQLBaseStore):
# A map from `(lock_name, lock_key)` to the token of any locks that we
# think we currently hold.
- self._live_tokens: Dict[Tuple[str, str], str] = {}
+ self._live_tokens: Dict[Tuple[str, str], Lock] = WeakValueDictionary()
# When we shut down we want to remove the locks. Technically this can
# lead to a race, as we may drop the lock while we are still processing.
@@ -80,10 +81,10 @@ class LockStore(SQLBaseStore):
# We need to take a copy of the tokens dict as dropping the locks will
# cause the dictionary to change.
- tokens = dict(self._live_tokens)
+ locks = dict(self._live_tokens)
- for (lock_name, lock_key), token in tokens.items():
- await self._drop_lock(lock_name, lock_key, token)
+ for lock in locks.values():
+ await lock.release()
logger.info("Dropped locks due to shutdown")
@@ -93,6 +94,11 @@ class LockStore(SQLBaseStore):
used (otherwise the lock will leak).
"""
+ # Check if this process has taken out a lock and if it's still valid.
+ lock = self._live_tokens.get((lock_name, lock_key))
+ if lock and await lock.is_still_valid():
+ return None
+
now = self._clock.time_msec()
token = random_string(6)
@@ -100,7 +106,9 @@ class LockStore(SQLBaseStore):
def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
# We take out the lock if either a) there is no row for the lock
- # already or b) the existing row has timed out.
+ # already, b) the existing row has timed out, or c) the row is
+ # for this instance (which means the process got killed and
+ # restarted)
sql = """
INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?)
@@ -112,6 +120,7 @@ class LockStore(SQLBaseStore):
last_renewed_ts = EXCLUDED.last_renewed_ts
WHERE
worker_locks.last_renewed_ts < ?
+ OR worker_locks.instance_name = EXCLUDED.instance_name
"""
txn.execute(
sql,
@@ -148,11 +157,11 @@ class LockStore(SQLBaseStore):
WHERE
lock_name = ?
AND lock_key = ?
- AND last_renewed_ts < ?
+ AND (last_renewed_ts < ? OR instance_name = ?)
"""
txn.execute(
sql,
- (lock_name, lock_key, now - _LOCK_TIMEOUT_MS),
+ (lock_name, lock_key, now - _LOCK_TIMEOUT_MS, self._instance_name),
)
inserted = self.db_pool.simple_upsert_txn_emulated(
@@ -179,9 +188,7 @@ class LockStore(SQLBaseStore):
if not did_lock:
return None
- self._live_tokens[(lock_name, lock_key)] = token
-
- return Lock(
+ lock = Lock(
self._reactor,
self._clock,
self,
@@ -190,6 +197,10 @@ class LockStore(SQLBaseStore):
token=token,
)
+ self._live_tokens[(lock_name, lock_key)] = lock
+
+ return lock
+
async def _is_lock_still_valid(
self, lock_name: str, lock_key: str, token: str
) -> bool:
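
A small illustration (toy Lock class, not Synapse's) of why a WeakValueDictionary suits the `_live_tokens` map above: an entry disappears automatically once the last strong reference to its Lock is dropped, so the store never hands back a stale in-process lock.

import gc
from weakref import WeakValueDictionary

class Lock:
    def __init__(self, token: str) -> None:
        self.token = token

live: "WeakValueDictionary[tuple, Lock]" = WeakValueDictionary()

lock = Lock("abc123")
live[("name", "key")] = lock
assert ("name", "key") in live

del lock      # drop the last strong reference to the Lock
gc.collect()  # immediate on CPython; explicit collect for other runtimes
assert ("name", "key") not in live
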
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 12cf6995eb..cc0eebdb46 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -92,7 +92,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
prefilled_cache=presence_cache_prefill,
)
- async def update_presence(self, presence_states):
+ async def update_presence(self, presence_states) -> Tuple[int, int]:
assert self._can_persist_presence
stream_ordering_manager = self._presence_id_gen.get_next_mult(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f879bbe7c7..cefc77fa0f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -412,22 +412,33 @@ class RoomWorkerStore(SQLBaseStore):
limit: maximum amount of rooms to retrieve
order_by: the sort order of the returned list
reverse_order: whether to reverse the room list
- search_term: a string to filter room names by
+            search_term: a string to filter room names, canonical aliases
+                and room IDs by. A room ID must match exactly; a canonical
+                alias matches on a substring of its local part.
Returns:
A list of room dicts and an integer representing the total number of
rooms that exist given this query
"""
# Filter room names by a string
where_statement = ""
+ search_pattern = []
if search_term:
- where_statement = "WHERE LOWER(state.name) LIKE ?"
+ where_statement = """
+ WHERE LOWER(state.name) LIKE ?
+ OR LOWER(state.canonical_alias) LIKE ?
+ OR state.room_id = ?
+ """
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_term = "%" + search_term.lower() + "%"
+ search_pattern = [
+ "%" + search_term.lower() + "%",
+ "#%" + search_term.lower() + "%:%",
+ search_term,
+ ]
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
@@ -519,12 +530,9 @@ class RoomWorkerStore(SQLBaseStore):
)
def _get_rooms_paginate_txn(txn):
- # Execute the data query
- sql_values = (limit, start)
- if search_term:
- # Add the search term into the WHERE clause
- sql_values = (search_term,) + sql_values
- txn.execute(info_sql, sql_values)
+ # Add the search term into the WHERE clause
+ # and execute the data query
+ txn.execute(info_sql, search_pattern + [limit, start])
# Refactor room query data into a structured dictionary
rooms = []
@@ -551,8 +559,7 @@ class RoomWorkerStore(SQLBaseStore):
# Execute the count query
# Add the search term into the WHERE clause if present
- sql_values = (search_term,) if search_term else ()
- txn.execute(count_sql, sql_values)
+ txn.execute(count_sql, search_pattern)
room_count = txn.fetchone()
return rooms, room_count[0]
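
A quick illustration (hypothetical helper) of the three patterns the room.py hunks build from a single search term: a substring match on the room name, a substring match on the local part of the canonical alias, and an exact room-ID match.

from typing import List

def build_search_pattern(search_term: str) -> List[str]:
    return [
        "%" + search_term.lower() + "%",     # room name contains the term
        "#%" + search_term.lower() + "%:%",  # alias local part contains it
        search_term,                         # room ID must match exactly
    ]

assert build_search_pattern("Weird") == ["%weird%", "#%weird%:%", "Weird"]
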
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4b288bb2e7..033a9831d6 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
- ):
+ ) -> Dict[str, ProfileInfo]:
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event.room_id, state_group, current_state_ids, event=event, context=context
)
- async def get_joined_users_from_state(self, room_id, state_entry):
+ async def get_joined_users_from_state(
+ self, room_id, state_entry
+ ) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache_context,
event=None,
context=None,
- ):
+ ) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1629d2a53c..b5c1c14ee3 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -133,22 +133,23 @@ def prepare_database(
# if it's a worker app, refuse to upgrade the database, to avoid multiple
# workers doing it at once.
- if (
- config.worker.worker_app is not None
- and version_info.current_version != SCHEMA_VERSION
- ):
+ if config.worker.worker_app is None:
+ _upgrade_existing_database(
+ cur,
+ version_info,
+ database_engine,
+ config,
+ databases=databases,
+ )
+ elif version_info.current_version < SCHEMA_VERSION:
+        # If the DB is on an older version than we expect, then we refuse
+        # to start the worker (as the main process needs to run first to
+        # update the schema).
raise UpgradeDatabaseException(
OUTDATED_SCHEMA_ON_WORKER_ERROR
% (SCHEMA_VERSION, version_info.current_version)
)
- _upgrade_existing_database(
- cur,
- version_info,
- database_engine,
- config,
- databases=databases,
- )
else:
logger.info("%r: Initialising new database", databases)
diff --git a/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql b/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql
new file mode 100644
index 0000000000..7b3592dcf0
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/03remove_hidden_devices_from_device_inbox.sql
@@ -0,0 +1,22 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Remove messages from the device_inbox table which were orphaned
+-- because a device was hidden using a Synapse version earlier than 1.47.0.
+-- This runs as a background task, but may take a while to finish.
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6503, 'remove_hidden_devices_from_device_inbox', '{}');
diff --git a/synapse/storage/schema/main/delta/65/04_local_group_updates.sql b/synapse/storage/schema/main/delta/65/04_local_group_updates.sql
new file mode 100644
index 0000000000..a178abfe12
--- /dev/null
+++ b/synapse/storage/schema/main/delta/65/04_local_group_updates.sql
@@ -0,0 +1,18 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Create a unique index on `local_group_updates.stream_id`.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (6504, 'local_group_updates_index', '{}');
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5df80ea8e7..96efc5f3e3 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -22,11 +22,11 @@ from typing import (
Any,
Awaitable,
Callable,
+ Collection,
Dict,
Generic,
Hashable,
Iterable,
- List,
Optional,
Set,
TypeVar,
@@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
- object.__setattr__(self, "_observers", set())
+ object.__setattr__(self, "_observers", [])
def callback(r):
object.__setattr__(self, "_result", (True, r))
- while self._observers:
- observer = self._observers.pop()
+
+ # once we have set _result, no more entries will be added to _observers,
+ # so it's safe to replace it with the empty tuple.
+ observers = self._observers
+ object.__setattr__(self, "_observers", ())
+
+ for observer in observers:
try:
observer.callback(r)
except Exception as e:
@@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]):
def errback(f):
object.__setattr__(self, "_result", (False, f))
- while self._observers:
+
+ # once we have set _result, no more entries will be added to _observers,
+ # so it's safe to replace it with the empty tuple.
+ observers = self._observers
+ object.__setattr__(self, "_observers", ())
+
+ for observer in observers:
# This is a little bit of magic to correctly propagate stack
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
-
- observer = self._observers.pop()
try:
observer.errback(f)
except Exception as e:
@@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]):
"""
if not self._result:
d: "defer.Deferred[_T]" = defer.Deferred()
-
- def remove(r):
- self._observers.discard(d)
- return r
-
- d.addBoth(remove)
-
- self._observers.add(d)
+ self._observers.append(d)
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self) -> "List[defer.Deferred[_T]]":
+ def observers(self) -> "Collection[defer.Deferred[_T]]":
return self._observers
def has_called(self) -> bool:
|