diff --git a/changelog.d/16288.bugfix b/changelog.d/16288.bugfix
new file mode 100644
index 0000000000..f08d10d1f3
--- /dev/null
+++ b/changelog.d/16288.bugfix
@@ -0,0 +1 @@
+Fix bug introduced in Synapse 1.49.0 when using dehydrated devices ([MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697)) and refresh tokens. Contributed by Hanadi.
diff --git a/changelog.d/16301.misc b/changelog.d/16301.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/16301.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/changelog.d/16304.doc b/changelog.d/16304.doc
new file mode 100644
index 0000000000..53660ec9a4
--- /dev/null
+++ b/changelog.d/16304.doc
@@ -0,0 +1 @@
+Link to the Alpine Linux community package for Synapse.
diff --git a/changelog.d/16313.misc b/changelog.d/16313.misc
new file mode 100644
index 0000000000..4f266c1fb0
--- /dev/null
+++ b/changelog.d/16313.misc
@@ -0,0 +1 @@
+Delete device messages asynchronously and in staged batches using the task scheduler.
diff --git a/changelog.d/16314.misc b/changelog.d/16314.misc
new file mode 100644
index 0000000000..a32b07112a
--- /dev/null
+++ b/changelog.d/16314.misc
@@ -0,0 +1 @@
+Remove a reference cycle for in background processes.
diff --git a/changelog.d/16316.misc b/changelog.d/16316.misc
new file mode 100644
index 0000000000..aa0644f278
--- /dev/null
+++ b/changelog.d/16316.misc
@@ -0,0 +1 @@
+Refactor `get_user_by_id`.
diff --git a/changelog.d/16318.misc b/changelog.d/16318.misc
new file mode 100644
index 0000000000..1433a2f246
--- /dev/null
+++ b/changelog.d/16318.misc
@@ -0,0 +1 @@
+Speed up task to delete to-device messages.
diff --git a/docs/setup/installation.md b/docs/setup/installation.md
index 0357d2a0fb..1f13864a8f 100644
--- a/docs/setup/installation.md
+++ b/docs/setup/installation.md
@@ -155,6 +155,14 @@ sudo pip uninstall py-bcrypt
sudo pip install py-bcrypt
```
+#### Alpine Linux
+
+6543 maintains [Synapse packages for Alpine Linux](https://pkgs.alpinelinux.org/packages?name=synapse&branch=edge) in the community repository. Install with:
+
+```sh
+sudo apk add synapse
+```
+
#### Void Linux
Synapse can be found in the void repositories as
diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py
index 6a5fd44ec0..a75f6f2cc4 100644
--- a/synapse/api/auth/internal.py
+++ b/synapse/api/auth/internal.py
@@ -268,7 +268,7 @@ class InternalAuth(BaseAuth):
stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
- if not stored_user["is_guest"]:
+ if not stored_user.is_guest:
raise InvalidClientTokenError(
"Guest access token used for regular user"
)
diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index ef5d3f9b81..31bb035cc8 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -300,7 +300,7 @@ class MSC3861DelegatedAuth(BaseAuth):
user_id = UserID(username, self._hostname)
# First try to find a user from the username claim
- user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
+ user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
if user_info is None:
# If the user does not exist, we should create it on the fly
# TODO: we could use SCIM to provision users ahead of time and listen
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index a94b57a671..9ac7e4313e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -27,9 +27,7 @@ from typing import (
Any,
Awaitable,
Callable,
- Collection,
Dict,
- Iterable,
List,
NoReturn,
Optional,
@@ -76,7 +74,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrCollection
from synapse.util import SYNAPSE_VERSION
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
from synapse.util.daemonize import daemonize_process
@@ -278,7 +276,7 @@ def register_start(
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
-def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
+def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
"""
Start Prometheus metrics server.
"""
@@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:
def listen_manhole(
- bind_addresses: Collection[str],
+ bind_addresses: StrCollection,
port: int,
manhole_settings: ManholeConfig,
manhole_globals: dict,
@@ -339,7 +337,7 @@ def listen_manhole(
def listen_tcp(
- bind_addresses: Collection[str],
+ bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
reactor: IReactorTCP = reactor,
@@ -448,7 +446,7 @@ def listen_http(
def listen_ssl(
- bind_addresses: Collection[str],
+ bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
context_factory: IOpenSSLContextFactory,
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 58856839e1..c5816105f4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -26,7 +26,6 @@ from textwrap import dedent
from typing import (
Any,
ClassVar,
- Collection,
Dict,
Iterable,
Iterator,
@@ -384,7 +383,7 @@ class RootConfig:
config_classes: List[Type[Config]] = []
- def __init__(self, config_files: Collection[str] = ()):
+ def __init__(self, config_files: StrSequence = ()):
# Capture absolute paths here, so we can reload config after we daemonize.
self.config_files = [os.path.abspath(path) for path in config_files]
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 35257a3b1b..3c1777b7ec 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -25,7 +25,6 @@ from typing import (
Iterable,
List,
Optional,
- Sequence,
Tuple,
Type,
TypeVar,
@@ -408,7 +407,7 @@ class EventBase(metaclass=abc.ABCMeta):
def keys(self) -> Iterable[str]:
return self._dict.keys()
- def prev_event_ids(self) -> Sequence[str]:
+ def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
@@ -553,7 +552,7 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
- def prev_event_ids(self) -> Sequence[str]:
+ def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 14ea0e6640..1165c017ba 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
from signedjson.types import SigningKey
@@ -28,7 +28,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
-from synapse.types import EventID, JsonDict
+from synapse.types import EventID, JsonDict, StrCollection
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -103,7 +103,7 @@ class EventBuilder:
async def build(
self,
- prev_event_ids: Collection[str],
+ prev_event_ids: StrCollection,
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
- prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
+ prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 34625dd7a1..5da50cb0d2 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
-from typing import Iterable, List, Type, Union, cast
+from typing import List, Type, Union, cast
import jsonschema
from pydantic import Field, StrictBool, StrictStr
@@ -36,7 +36,7 @@ from synapse.events.utils import (
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
-from synapse.types import EventID, JsonDict, RoomID, UserID
+from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
class EventValidator:
@@ -225,7 +225,7 @@ class EventValidator:
self._ensure_state_event(event)
- def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
+ def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py
index c05a14304c..fa043cca86 100644
--- a/synapse/handlers/account.py
+++ b/synapse/handlers/account.py
@@ -102,7 +102,7 @@ class AccountHandler:
"""
status = {"exists": False}
- userinfo = await self._main_store.get_userinfo_by_id(user_id.to_string())
+ userinfo = await self._main_store.get_user_by_id(user_id.to_string())
if userinfo is not None:
status = {
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 2f0e5f3b0a..7092ff3449 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
-from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
+from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID, UserInfo
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -57,38 +57,30 @@ class AdminHandler:
async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
- user_info_dict = await self._store.get_user_by_id(user.to_string())
- if user_info_dict is None:
+ user_info: Optional[UserInfo] = await self._store.get_user_by_id(
+ user.to_string()
+ )
+ if user_info is None:
return None
- # Restrict returned information to a known set of fields. This prevents additional
- # fields added to get_user_by_id from modifying Synapse's external API surface.
- user_info_to_return = {
- "name",
- "admin",
- "deactivated",
- "locked",
- "shadow_banned",
- "creation_ts",
- "appservice_id",
- "consent_server_notice_sent",
- "consent_version",
- "consent_ts",
- "user_type",
- "is_guest",
- "last_seen_ts",
+ user_info_dict = {
+ "name": user.to_string(),
+ "admin": user_info.is_admin,
+ "deactivated": user_info.is_deactivated,
+ "locked": user_info.locked,
+ "shadow_banned": user_info.is_shadow_banned,
+ "creation_ts": user_info.creation_ts,
+ "appservice_id": user_info.appservice_id,
+ "consent_server_notice_sent": user_info.consent_server_notice_sent,
+ "consent_version": user_info.consent_version,
+ "consent_ts": user_info.consent_ts,
+ "user_type": user_info.user_type,
+ "is_guest": user_info.is_guest,
}
if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
- user_info_to_return.add("approved")
-
- # Restrict returned keys to a known set.
- user_info_dict = {
- key: value
- for key, value in user_info_dict.items()
- if key in user_info_to_return
- }
+ user_info_dict["approved"] = user_info.approved
# Add additional user metadata
profile = await self._store.get_profileinfo(user)
@@ -105,6 +97,9 @@ class AdminHandler:
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
+ last_seen_ts = await self._store.get_last_seen_for_user_id(user.to_string())
+ user_info_dict["last_seen_ts"] = last_seen_ts
+
return user_info_dict
async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any:
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index e2ae3da67e..86ad96d030 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -388,7 +388,8 @@ class DeviceWorkerHandler:
"Trying handling device list state for partial join: not supported on workers."
)
- DEVICE_MSGS_DELETE_BATCH_LIMIT = 100
+ DEVICE_MSGS_DELETE_BATCH_LIMIT = 1000
+ DEVICE_MSGS_DELETE_SLEEP_MS = 1000
async def _delete_device_messages(
self,
@@ -400,19 +401,19 @@ class DeviceWorkerHandler:
device_id = task.params["device_id"]
up_to_stream_id = task.params["up_to_stream_id"]
- res = await self.store.delete_messages_for_device(
- user_id=user_id,
- device_id=device_id,
- up_to_stream_id=up_to_stream_id,
- limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
- )
+ # Delete the messages in batches to avoid too much DB load.
+ while True:
+ res = await self.store.delete_messages_for_device(
+ user_id=user_id,
+ device_id=device_id,
+ up_to_stream_id=up_to_stream_id,
+ limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT,
+ )
- if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
- return TaskStatus.COMPLETE, None, None
- else:
- # There is probably still device messages to be deleted, let's keep the task active and it will be run
- # again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running).
- return TaskStatus.ACTIVE, None, None
+ if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT:
+ return TaskStatus.COMPLETE, None, None
+
+ await self.clock.sleep(DeviceHandler.DEVICE_MSGS_DELETE_SLEEP_MS / 1000.0)
class DeviceHandler(DeviceWorkerHandler):
@@ -758,12 +759,13 @@ class DeviceHandler(DeviceWorkerHandler):
# If the dehydrated device was successfully deleted (the device ID
# matched the stored dehydrated device), then modify the access
- # token to use the dehydrated device's ID and copy the old device
- # display name to the dehydrated device, and destroy the old device
- # ID
+ # token and refresh token to use the dehydrated device's ID and
+ # copy the old device display name to the dehydrated device,
+ # and destroy the old device ID
old_device_id = await self.store.set_device_for_access_token(
access_token, device_id
)
+ await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
old_device = await self.store.get_device(user_id, old_device_id)
if old_device is None:
raise errors.NotFoundError()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d6be18cdef..c036578a3d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -828,13 +828,13 @@ class EventCreationHandler:
u = await self.store.get_user_by_id(user_id)
assert u is not None
- if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
+ if u.user_type in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
return
- if u["appservice_id"] is not None:
+ if u.appservice_id is not None:
# users registered by an appservice are exempt
return
- if u["consent_version"] == self.config.consent.user_consent_version:
+ if u.consent_version == self.config.consent.user_consent_version:
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ca2cdbc6e2..c750e03b36 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -78,7 +78,7 @@ from synapse.http.replicationagent import ReplicationAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
-from synapse.types import ISynapseReactor
+from synapse.types import ISynapseReactor, StrSequence
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -108,10 +108,9 @@ RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValu
# the value actually has to be a List, but List is invariant so we can't specify that
# the entries can either be Lists or bytes.
RawHeaderValue = Union[
- List[str],
+ StrSequence,
List[bytes],
List[Union[str, bytes]],
- Tuple[str, ...],
Tuple[bytes, ...],
Tuple[Union[str, bytes], ...],
]
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fc62793628..5d79d31579 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -18,7 +18,6 @@ import logging
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
- Iterable,
List,
Mapping,
Optional,
@@ -38,7 +37,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http import redact_uri
from synapse.http.server import HttpServer
-from synapse.types import JsonDict, RoomAlias, RoomID
+from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -340,7 +339,7 @@ def parse_string(
name: str,
default: str,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -352,7 +351,7 @@ def parse_string(
name: str,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -365,7 +364,7 @@ def parse_string(
*,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -376,7 +375,7 @@ def parse_string(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@@ -485,7 +484,7 @@ def parse_enum(
def _parse_string_value(
value: bytes,
- allowed_values: Optional[Iterable[str]],
+ allowed_values: Optional[StrCollection],
name: str,
encoding: str,
) -> str:
@@ -511,7 +510,7 @@ def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@@ -523,7 +522,7 @@ def parse_strings_from_args(
name: str,
default: List[str],
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@@ -535,7 +534,7 @@ def parse_strings_from_args(
name: str,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
@@ -548,7 +547,7 @@ def parse_strings_from_args(
default: Optional[List[str]] = None,
*,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
@@ -559,7 +558,7 @@ def parse_strings_from_args(
name: str,
default: Optional[List[str]] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
"""
@@ -610,7 +609,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
*,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -623,7 +622,7 @@ def parse_string_from_args(
default: Optional[str] = None,
*,
required: Literal[True],
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
@@ -635,7 +634,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
@@ -646,7 +645,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
- allowed_values: Optional[Iterable[str]] = None,
+ allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
@@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
return validate_json_object(content, model_type)
-def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
+def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
absent = []
for k in required:
if k not in body:
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 39fc629937..3cf2fbc3e2 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -25,7 +25,6 @@ from typing import (
Iterable,
Mapping,
Optional,
- Sequence,
Set,
Tuple,
Type,
@@ -49,6 +48,7 @@ import synapse.metrics._reactor_metrics # noqa: F401
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
from synapse.metrics._types import Collector
+from synapse.types import StrSequence
from synapse.util import SYNAPSE_VERSION
logger = logging.getLogger(__name__)
@@ -81,7 +81,7 @@ class LaterGauge(Collector):
name: str
desc: str
- labels: Optional[Sequence[str]] = attr.ib(hash=False)
+ labels: Optional[StrSequence] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
@@ -143,8 +143,8 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
self,
name: str,
desc: str,
- labels: Sequence[str],
- sub_metrics: Sequence[str],
+ labels: StrSequence,
+ sub_metrics: StrSequence,
):
self.name = name
self.desc = desc
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 9ea4e23b31..f1f1f0cdf9 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -322,13 +322,21 @@ class BackgroundProcessLoggingContext(LoggingContext):
if instance_id is None:
instance_id = id(self)
super().__init__("%s-%s" % (name, instance_id))
- self._proc = _BackgroundProcess(name, self)
+ self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(name, self)
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""Log context has started running (again)."""
super().start(rusage)
+ if self._proc is None:
+ logger.error(
+ "Background process re-entered without a proc: %s",
+ self.name,
+ stack_info=True,
+ )
+ return
+
# We've become active again so we make sure we're in the list of active
# procs. (Note that "start" here means we've become active, as opposed
# to starting for the first time.)
@@ -345,6 +353,14 @@ class BackgroundProcessLoggingContext(LoggingContext):
super().__exit__(type, value, traceback)
+ if self._proc is None:
+ logger.error(
+ "Background process exited without a proc: %s",
+ self.name,
+ stack_info=True,
+ )
+ return
+
# The background process has finished. We explicitly remove and manually
# update the metrics here so that if nothing is scraping metrics the set
# doesn't infinitely grow.
@@ -352,3 +368,6 @@ class BackgroundProcessLoggingContext(LoggingContext):
_background_processes_active_since_last_scrape.discard(self._proc)
self._proc.update_metrics()
+
+ # Set proc to None to break the reference cycle.
+ self._proc = None
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d6efe10a28..7ec202be23 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -572,7 +572,7 @@ class ModuleApi:
Returns:
UserInfo object if a user was found, otherwise None
"""
- return await self._store.get_userinfo_by_id(user_id)
+ return await self._store.get_user_by_id(user_id)
async def get_user_by_req(
self,
@@ -1878,7 +1878,7 @@ class AccountDataManager:
raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}")
# Ensure the user exists, so we don't just write to users that aren't there.
- if await self._store.get_userinfo_by_id(user_id) is None:
+ if await self._store.get_user_by_id(user_id) is None:
raise ValueError(f"User {user_id} does not exist on this server.")
await self._handler.add_account_data_for_user(user_id, data_type, new_data)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 68115bca70..fc39e5c963 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -104,7 +104,7 @@ class _NotifierUserStream:
def __init__(
self,
user_id: str,
- rooms: Collection[str],
+ rooms: StrCollection,
current_token: StreamToken,
time_now_ms: int,
):
@@ -457,7 +457,7 @@ class Notifier:
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Optional[Collection[Union[str, UserID]]] = None,
- rooms: Optional[Collection[str]] = None,
+ rooms: Optional[StrCollection] = None,
) -> None:
"""Used to inform listeners that something has happened event wise.
@@ -529,7 +529,7 @@ class Notifier:
user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
- room_ids: Optional[Collection[str]] = None,
+ room_ids: Optional[StrCollection] = None,
from_token: StreamToken = StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 5642666411..b668bb5da1 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -672,14 +672,12 @@ class ReplicationCommandHandler:
cmd.instance_name, cmd.lock_name, cmd.lock_key
)
- async def on_NEW_ACTIVE_TASK(
+ def on_NEW_ACTIVE_TASK(
self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
) -> None:
"""Called when get a new NEW_ACTIVE_TASK command."""
if self._task_scheduler:
- task = await self._task_scheduler.get_task(cmd.data)
- if task:
- await self._task_scheduler._launch_task(task)
+ self._task_scheduler.launch_task_by_id(cmd.data)
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py
index 5c1c19e1f3..73c568ef75 100644
--- a/synapse/rest/client/_base.py
+++ b/synapse/rest/client/_base.py
@@ -20,14 +20,14 @@ from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar,
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
logger = logging.getLogger(__name__)
def client_patterns(
path_regex: str,
- releases: Iterable[str] = ("r0", "v3"),
+ releases: StrCollection = ("r0", "v3"),
unstable: bool = True,
v1: bool = False,
) -> Iterable[Pattern]:
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 25f9ea285b..88d3ec1baf 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -129,7 +129,7 @@ class ConsentResource(DirectServeHtmlResource):
if u is None:
raise NotFoundError("Unknown user")
- has_consented = u["consent_version"] == version
+ has_consented = u.consent_version == version
userhmac = userhmac_bytes.decode("ascii")
try:
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 94025ba41f..a879b6505e 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -79,15 +79,15 @@ class ConsentServerNotices:
if u is None:
return
- if u["is_guest"] and not self._send_to_guests:
+ if u.is_guest and not self._send_to_guests:
# don't send to guests
return
- if u["consent_version"] == self._current_consent_version:
+ if u.consent_version == self._current_consent_version:
# user has already consented
return
- if u["consent_server_notice_sent"] == self._current_consent_version:
+ if u.consent_server_notice_sent == self._current_consent_version:
# we've already sent a notice to the user
return
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1b91cf5eaa..e977ed1044 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -20,7 +20,6 @@ from typing import (
Any,
Awaitable,
Callable,
- Collection,
DefaultDict,
Dict,
FrozenSet,
@@ -49,7 +48,7 @@ from synapse.logging.opentracing import tag_args, trace
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap
+from synapse.types import StateMap, StrCollection
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -197,7 +196,7 @@ class StateHandler:
async def compute_state_after_events(
self,
room_id: str,
- event_ids: Collection[str],
+ event_ids: StrCollection,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
@@ -231,7 +230,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_user_ids_in_room(
- self, room_id: str, latest_event_ids: Collection[str]
+ self, room_id: str, latest_event_ids: StrCollection
) -> Set[str]:
"""
Get the users IDs who are currently in a room.
@@ -256,7 +255,7 @@ class StateHandler:
return await self.store.get_joined_user_ids_from_state(room_id, state)
async def get_hosts_in_room_at_events(
- self, room_id: str, event_ids: Collection[str]
+ self, room_id: str, event_ids: StrCollection
) -> FrozenSet[str]:
"""Get the hosts that were in a room at the given event ids
@@ -470,7 +469,7 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
- self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
+ self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -882,7 +881,7 @@ class StateResolutionStore:
store: "DataStore"
def get_events(
- self, event_ids: Collection[str], allow_rejected: bool = False
+ self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 500e384695..c76a2f082e 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -17,7 +17,6 @@ import logging
from typing import (
Awaitable,
Callable,
- Collection,
Dict,
Iterable,
List,
@@ -32,7 +31,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
logger = logging.getLogger(__name__)
@@ -45,7 +44,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
+ state_map_factory: Callable[[StrCollection], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 44c49274a9..1752f95db8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -19,7 +19,6 @@ from typing import (
Any,
Awaitable,
Callable,
- Collection,
Dict,
Generator,
Iterable,
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
logger = logging.getLogger(__name__)
@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
# TestStateResolutionStore in tests.
def get_events(
- self, event_ids: Collection[str], allow_rejected: bool = False
+ self, event_ids: StrCollection, allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
...
@@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
- auth_difference_unpersisted_part: Collection[str] = union - intersection
+ auth_difference_unpersisted_part: StrCollection = union - intersection
else:
auth_difference_unpersisted_part = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index e97b844dfa..16170e0436 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -764,3 +764,14 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
}
return list(results.values())
+
+ async def get_last_seen_for_user_id(self, user_id: str) -> Optional[int]:
+ """Get the last seen timestamp for a user, if we have it."""
+
+ return await self.db_pool.simple_select_one_onecol(
+ table="user_ips",
+ keyvalues={"user_id": user_id},
+ retcol="MAX(last_seen)",
+ allow_none=True,
+ desc="get_last_seen_for_user_id",
+ )
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index fab7008a8f..09de8f55e2 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -47,7 +47,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, StrCollection
+from synapse.types import JsonDict, StrCollection, StrSequence
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -1179,7 +1179,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@cached(max_entries=5000, iterable=True)
- async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
+ async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7e85b73e8e..cc964604e2 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
import logging
import random
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr
@@ -192,8 +192,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@cached()
- async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
- """Deprecated: use get_userinfo_by_id instead"""
+ async def get_user_by_id(self, user_id: str) -> Optional[UserInfo]:
+ """Returns info about the user account, if it exists."""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform
@@ -202,16 +202,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
txn.execute(
"""
SELECT
- name, password_hash, is_guest, admin, consent_version, consent_ts,
+ name, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved,
- COALESCE(locked, FALSE) AS locked, last_seen_ts
+ COALESCE(locked, FALSE) AS locked
FROM users
- LEFT JOIN (
- SELECT user_id, MAX(last_seen) AS last_seen_ts
- FROM user_ips GROUP BY user_id
- ) ls ON users.name = ls.user_id
WHERE name = ?
""",
(user_id,),
@@ -228,51 +224,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_id",
func=get_user_by_id_txn,
)
-
- if row is not None:
- # If we're using SQLite our boolean values will be integers. Because we
- # present some of this data as is to e.g. server admins via REST APIs, we
- # want to make sure we're returning the right type of data.
- # Note: when adding a column name to this list, be wary of NULLable columns,
- # since NULL values will be turned into False.
- boolean_columns = [
- "admin",
- "deactivated",
- "shadow_banned",
- "approved",
- "locked",
- ]
- for column in boolean_columns:
- row[column] = bool(row[column])
-
- return row
-
- async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
- """Get a UserInfo object for a user by user ID.
-
- Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
- this method should be cached.
-
- Args:
- user_id: The user to fetch user info for.
- Returns:
- `UserInfo` object if user found, otherwise `None`.
- """
- user_data = await self.get_user_by_id(user_id)
- if not user_data:
+ if row is None:
return None
+
return UserInfo(
- appservice_id=user_data["appservice_id"],
- consent_server_notice_sent=user_data["consent_server_notice_sent"],
- consent_version=user_data["consent_version"],
- creation_ts=user_data["creation_ts"],
- is_admin=bool(user_data["admin"]),
- is_deactivated=bool(user_data["deactivated"]),
- is_guest=bool(user_data["is_guest"]),
- is_shadow_banned=bool(user_data["shadow_banned"]),
- user_id=UserID.from_string(user_data["name"]),
- user_type=user_data["user_type"],
- last_seen_ts=user_data["last_seen_ts"],
+ appservice_id=row["appservice_id"],
+ consent_server_notice_sent=row["consent_server_notice_sent"],
+ consent_version=row["consent_version"],
+ consent_ts=row["consent_ts"],
+ creation_ts=row["creation_ts"],
+ is_admin=bool(row["admin"]),
+ is_deactivated=bool(row["deactivated"]),
+ is_guest=bool(row["is_guest"]),
+ is_shadow_banned=bool(row["shadow_banned"]),
+ user_id=UserID.from_string(row["name"]),
+ user_type=row["user_type"],
+ approved=bool(row["approved"]),
+ locked=bool(row["locked"]),
)
async def is_trial_user(self, user_id: str) -> bool:
@@ -290,10 +258,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
now = self._clock.time_msec()
days = self.config.server.mau_appservice_trial_days.get(
- info["appservice_id"], self.config.server.mau_trial_days
+ info.appservice_id, self.config.server.mau_trial_days
)
trial_duration_ms = days * 24 * 60 * 60 * 1000
- is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
+ is_trial = (now - info.creation_ts * 1000) < trial_duration_ms
return is_trial
@cached()
@@ -2312,6 +2280,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
return next_id
+ async def set_device_for_refresh_token(
+ self, user_id: str, old_device_id: str, device_id: str
+ ) -> None:
+ """Moves refresh tokens from old device to current device
+
+ Args:
+ user_id: The user of the devices.
+ old_device_id: The old device.
+ device_id: The new device ID.
+ Returns:
+ None
+ """
+
+ await self.db_pool.simple_update(
+ "refresh_tokens",
+ keyvalues={"user_id": user_id, "device_id": old_device_id},
+ updatevalues={"device_id": device_id},
+ desc="set_device_for_refresh_token",
+ )
+
def _set_device_for_access_token_txn(
self, txn: LoggingTransaction, token: str, device_id: str
) -> str:
diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py
index 9ab120eea9..5c5372a825 100644
--- a/synapse/storage/databases/main/task_scheduler.py
+++ b/synapse/storage/databases/main/task_scheduler.py
@@ -53,6 +53,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
resource_id: Optional[str] = None,
statuses: Optional[List[TaskStatus]] = None,
max_timestamp: Optional[int] = None,
+ limit: Optional[int] = None,
) -> List[ScheduledTask]:
"""Get a list of scheduled tasks from the DB.
@@ -62,6 +63,7 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
statuses: Limit the returned tasks to the specific statuses
max_timestamp: Limit the returned tasks to the ones that have
a timestamp inferior to the specified one
+ limit: Only return `limit` number of rows if set.
Returns: a list of `ScheduledTask`, ordered by increasing timestamps
"""
@@ -94,6 +96,10 @@ class TaskSchedulerWorkerStore(SQLBaseStore):
sql = sql + " ORDER BY timestamp"
+ if limit is not None:
+ sql += " LIMIT ?"
+ args.append(limit)
+
txn.execute(sql, args)
return self.db_pool.cursor_to_dict(txn)
diff --git a/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql b/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql
new file mode 100644
index 0000000000..6b90275139
--- /dev/null
+++ b/synapse/storage/schema/main/delta/82/02_scheduled_tasks_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX IF NOT EXISTS scheduled_tasks_timestamp ON scheduled_tasks(timestamp);
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 488714f60c..76b0e3e694 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -933,33 +933,37 @@ def get_verify_key_from_cross_signing_key(
@attr.s(auto_attribs=True, frozen=True, slots=True)
class UserInfo:
- """Holds information about a user. Result of get_userinfo_by_id.
+ """Holds information about a user. Result of get_user_by_id.
Attributes:
user_id: ID of the user.
appservice_id: Application service ID that created this user.
consent_server_notice_sent: Version of policy documents the user has been sent.
consent_version: Version of policy documents the user has consented to.
+ consent_ts: Time the user consented
creation_ts: Creation timestamp of the user.
is_admin: True if the user is an admin.
is_deactivated: True if the user has been deactivated.
is_guest: True if the user is a guest user.
is_shadow_banned: True if the user has been shadow-banned.
user_type: User type (None for normal user, 'support' and 'bot' other options).
- last_seen_ts: Last activity timestamp of the user.
+ approved: If the user has been "approved" to register on the server.
+ locked: Whether the user's account has been locked
"""
user_id: UserID
appservice_id: Optional[int]
consent_server_notice_sent: Optional[str]
consent_version: Optional[str]
+ consent_ts: Optional[int]
user_type: Optional[str]
creation_ts: int
is_admin: bool
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
- last_seen_ts: Optional[int]
+ approved: bool
+ locked: bool
class UserProfile(TypedDict):
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index b7de201bde..caf13b3474 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -15,12 +15,14 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple
-from prometheus_client import Gauge
-
from twisted.python.failure import Failure
from synapse.logging.context import nested_logging_context
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.types import JsonMapping, ScheduledTask, TaskStatus
from synapse.util.stringutils import random_string
@@ -30,12 +32,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-running_tasks_gauge = Gauge(
- "synapse_scheduler_running_tasks",
- "The number of concurrent running tasks handled by the TaskScheduler",
-)
-
-
class TaskScheduler:
"""
This is a simple task sheduler aimed at resumable tasks: usually we use `run_in_background`
@@ -70,6 +66,8 @@ class TaskScheduler:
# Precision of the scheduler, evaluation of tasks to run will only happen
# every `SCHEDULE_INTERVAL_MS` ms
SCHEDULE_INTERVAL_MS = 1 * 60 * 1000 # 1mn
+ # How often to clean up old tasks.
+ CLEANUP_INTERVAL_MS = 30 * 60 * 1000
# Time before a complete or failed task is deleted from the DB
KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week
# Maximum number of tasks that can run at the same time
@@ -92,14 +90,26 @@ class TaskScheduler:
] = {}
self._run_background_tasks = hs.config.worker.run_background_tasks
+ # Flag to make sure we only try and launch new tasks once at a time.
+ self._launching_new_tasks = False
+
if self._run_background_tasks:
self._clock.looping_call(
- run_as_background_process,
+ self._launch_scheduled_tasks,
+ TaskScheduler.SCHEDULE_INTERVAL_MS,
+ )
+ self._clock.looping_call(
+ self._clean_scheduled_tasks,
TaskScheduler.SCHEDULE_INTERVAL_MS,
- "handle_scheduled_tasks",
- self._handle_scheduled_tasks,
)
+ LaterGauge(
+ "synapse_scheduler_running_tasks",
+ "The number of concurrent running tasks handled by the TaskScheduler",
+ labels=None,
+ caller=lambda: len(self._running_tasks),
+ )
+
def register_action(
self,
function: Callable[
@@ -234,6 +244,7 @@ class TaskScheduler:
resource_id: Optional[str] = None,
statuses: Optional[List[TaskStatus]] = None,
max_timestamp: Optional[int] = None,
+ limit: Optional[int] = None,
) -> List[ScheduledTask]:
"""Get a list of tasks. Returns all the tasks if no args is provided.
@@ -247,6 +258,7 @@ class TaskScheduler:
statuses: Limit the returned tasks to the specific statuses
max_timestamp: Limit the returned tasks to the ones that have
a timestamp inferior to the specified one
+ limit: Only return `limit` number of rows if set.
Returns
A list of `ScheduledTask`, ordered by increasing timestamps
@@ -256,6 +268,7 @@ class TaskScheduler:
resource_id=resource_id,
statuses=statuses,
max_timestamp=max_timestamp,
+ limit=limit,
)
async def delete_task(self, id: str) -> None:
@@ -273,34 +286,58 @@ class TaskScheduler:
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
await self._store.delete_scheduled_task(id)
- async def _handle_scheduled_tasks(self) -> None:
- """Main loop taking care of launching tasks and cleaning up old ones."""
- await self._launch_scheduled_tasks()
- await self._clean_scheduled_tasks()
+ def launch_task_by_id(self, id: str) -> None:
+ """Try launching the task with the given ID."""
+ # Don't bother trying to launch new tasks if we're already at capacity.
+ if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
+ return
+
+ run_as_background_process("launch_task_by_id", self._launch_task_by_id, id)
+
+ async def _launch_task_by_id(self, id: str) -> None:
+ """Helper async function for `launch_task_by_id`."""
+ task = await self.get_task(id)
+ if task:
+ await self._launch_task(task)
+ @wrap_as_background_process("launch_scheduled_tasks")
async def _launch_scheduled_tasks(self) -> None:
"""Retrieve and launch scheduled tasks that should be running at that time."""
- for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
- await self._launch_task(task)
- for task in await self.get_tasks(
- statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
- ):
- await self._launch_task(task)
+ # Don't bother trying to launch new tasks if we're already at capacity.
+ if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
+ return
+
+ if self._launching_new_tasks:
+ return
- running_tasks_gauge.set(len(self._running_tasks))
+ self._launching_new_tasks = True
+ try:
+ for task in await self.get_tasks(
+ statuses=[TaskStatus.ACTIVE], limit=self.MAX_CONCURRENT_RUNNING_TASKS
+ ):
+ await self._launch_task(task)
+ for task in await self.get_tasks(
+ statuses=[TaskStatus.SCHEDULED],
+ max_timestamp=self._clock.time_msec(),
+ limit=self.MAX_CONCURRENT_RUNNING_TASKS,
+ ):
+ await self._launch_task(task)
+
+ finally:
+ self._launching_new_tasks = False
+
+ @wrap_as_background_process("clean_scheduled_tasks")
async def _clean_scheduled_tasks(self) -> None:
"""Clean old complete or failed jobs to avoid clutter the DB."""
+ now = self._clock.time_msec()
for task in await self._store.get_scheduled_tasks(
- statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
+ statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE],
+ max_timestamp=now - TaskScheduler.KEEP_TASKS_FOR_MS,
):
# FAILED and COMPLETE tasks should never be running
assert task.id not in self._running_tasks
- if (
- self._clock.time_msec()
- > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
- ):
- await self._store.delete_scheduled_task(task.id)
+ await self._store.delete_scheduled_task(task.id)
async def _launch_task(self, task: ScheduledTask) -> None:
"""Launch a scheduled task now.
@@ -339,6 +376,9 @@ class TaskScheduler:
)
self._running_tasks.remove(task.id)
+ # Try launch a new task since we've finished with this one.
+ self._clock.call_later(1, self._launch_scheduled_tasks)
+
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
return
@@ -355,4 +395,4 @@ class TaskScheduler:
self._running_tasks.add(task.id)
await self.update_task(task.id, status=TaskStatus.ACTIVE)
- run_as_background_process(task.action, wrapper)
+ run_as_background_process(f"task-{task.action}", wrapper)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index eac10f6438..f15fdd8314 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -36,7 +36,7 @@ from synapse.events.utils import prune_event
from synapse.logging.opentracing import trace
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
-from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
+from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import Clock
@@ -150,12 +150,12 @@ async def filter_events_for_client(
async def filter_event_for_clients_with_state(
store: DataStore,
- user_ids: Collection[str],
+ user_ids: StrCollection,
event: EventBase,
context: EventContext,
is_peeking: bool = False,
filter_send_to_client: bool = True,
-) -> Collection[str]:
+) -> StrCollection:
"""
Checks to see if an event is visible to the users in the list at the time of
the event.
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index dcd01d5688..e00d7215df 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -188,8 +188,11 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- # This just needs to return a truth-y value.
- self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+
+ class FakeUserInfo:
+ is_guest = False
+
+ self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={})
@@ -341,7 +344,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
def test_get_guest_user_from_macaroon(self) -> None:
- self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
+ class FakeUserInfo:
+ is_guest = True
+
+ self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo())
self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org"
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 79d327499b..d4ed068357 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -461,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main
return hs
@@ -487,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
# Create a new login for the user and dehydrated the device
- device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+ device_id, access_token, _expiration_time, refresh_token = self.get_success(
self.registration.register_device(
user_id=user_id,
device_id=None,
initial_display_name="new device",
+ should_issue_refresh_token=True,
)
)
@@ -522,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.assertEqual(user_info.device_id, retrieved_device_id)
+ # make sure the user device has the refresh token
+ assert refresh_token is not None
+ self.get_success(
+ self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
+ )
+
# make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 95c9792d54..0cca34d355 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, UserInfo
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, override_config
@@ -35,24 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase):
self.get_success(self.store.register_user(self.user_id, self.pwhash))
self.assertEqual(
- {
+ UserInfo(
# TODO(paul): Surely this field should be 'user_id', not 'name'
- "name": self.user_id,
- "password_hash": self.pwhash,
- "admin": 0,
- "is_guest": 0,
- "consent_version": None,
- "consent_ts": None,
- "consent_server_notice_sent": None,
- "appservice_id": None,
- "creation_ts": 0,
- "user_type": None,
- "deactivated": 0,
- "locked": 0,
- "shadow_banned": 0,
- "approved": 1,
- "last_seen_ts": None,
- },
+ user_id=UserID.from_string(self.user_id),
+ is_admin=False,
+ is_guest=False,
+ consent_server_notice_sent=None,
+ consent_ts=None,
+ consent_version=None,
+ appservice_id=None,
+ creation_ts=0,
+ user_type=None,
+ is_deactivated=False,
+ locked=False,
+ is_shadow_banned=False,
+ approved=True,
+ ),
(self.get_success(self.store.get_user_by_id(self.user_id))),
)
@@ -65,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user
- self.assertEqual(user["consent_version"], "1")
- self.assertGreater(user["consent_ts"], before_consent)
- self.assertLess(user["consent_ts"], self.clock.time_msec())
+ self.assertEqual(user.consent_version, "1")
+ self.assertIsNotNone(user.consent_ts)
+ assert user.consent_ts is not None
+ self.assertGreater(user.consent_ts, before_consent)
+ self.assertLess(user.consent_ts, self.clock.time_msec())
def test_add_tokens(self) -> None:
self.get_success(self.store.register_user(self.user_id, self.pwhash))
@@ -215,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
- self.assertTrue(user["approved"])
+ self.assertTrue(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
@@ -228,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
- self.assertFalse(user["approved"])
+ self.assertFalse(user.approved)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
@@ -248,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
user = self.get_success(self.store.get_user_by_id(self.user_id))
self.assertIsNotNone(user)
assert user is not None
- self.assertEqual(user["approved"], 1)
+ self.assertEqual(user.approved, 1)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
|