diff --git a/synapse/__init__.py b/synapse/__init__.py
index 6c1801862b..2f9c22a833 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -25,7 +25,11 @@ from synapse.util.rust import check_rust_lib_up_to_date
from synapse.util.stringutils import strtobool
# Check that we're not running on an unsupported Python version.
-if sys.version_info < (3, 8):
+#
+# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the
+# if-statement completely.
+py_version = sys.version_info
+if py_version < (3, 8):
print("Synapse requires Python 3.8 or above.")
sys.exit(1)
@@ -78,7 +82,7 @@ try:
except ImportError:
pass
-import synapse.util
+import synapse.util # noqa: E402
__version__ = synapse.util.SYNAPSE_VERSION
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 22c84fbd5b..49242800b8 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -123,7 +123,7 @@ BOOLEAN_COLUMNS = {
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"rooms": ["is_public", "has_auth_chain_index"],
- "users": ["shadow_banned", "approved"],
+ "users": ["shadow_banned", "approved", "locked"],
"un_partial_stated_event_stream": ["rejection_status_changed"],
"users_who_share_rooms": ["share_private"],
"per_user_experimental_features": ["enabled"],
@@ -1205,10 +1205,10 @@ class CursesProgress(Progress):
self.total_processed = 0
self.total_remaining = 0
- super(CursesProgress, self).__init__()
+ super().__init__()
def update(self, table: str, num_done: int) -> None:
- super(CursesProgress, self).update(table, num_done)
+ super().update(table, num_done)
self.total_processed = 0
self.total_remaining = 0
@@ -1304,7 +1304,7 @@ class TerminalProgress(Progress):
"""Just prints progress to the terminal"""
def update(self, table: str, num_done: int) -> None:
- super(TerminalProgress, self).update(table, num_done)
+ super().update(table, num_done)
data = self.tables[table]
diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py
index 0adf94bba6..f97aecf8d5 100644
--- a/synapse/_scripts/update_synapse_database.py
+++ b/synapse/_scripts/update_synapse_database.py
@@ -38,7 +38,7 @@ class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config: HomeServerConfig):
- super(MockHomeserver, self).__init__(
+ super().__init__(
hostname=config.server.server_name,
config=config,
reactor=reactor,
diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py
index 90cfe39d76..bb3f50f2dd 100644
--- a/synapse/api/auth/__init__.py
+++ b/synapse/api/auth/__init__.py
@@ -60,6 +60,7 @@ class Auth(Protocol):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
+ allow_locked: bool = False,
) -> Requester:
"""Get a registered user's ID.
diff --git a/synapse/api/auth/internal.py b/synapse/api/auth/internal.py
index e2ae198b19..6a5fd44ec0 100644
--- a/synapse/api/auth/internal.py
+++ b/synapse/api/auth/internal.py
@@ -58,6 +58,7 @@ class InternalAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
+ allow_locked: bool = False,
) -> Requester:
"""Get a registered user's ID.
@@ -79,7 +80,7 @@ class InternalAuth(BaseAuth):
parent_span = active_span()
with start_active_span("get_user_by_req"):
requester = await self._wrapped_get_user_by_req(
- request, allow_guest, allow_expired
+ request, allow_guest, allow_expired, allow_locked
)
if parent_span:
@@ -107,6 +108,7 @@ class InternalAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool,
allow_expired: bool,
+ allow_locked: bool,
) -> Requester:
"""Helper for get_user_by_req
@@ -126,6 +128,17 @@ class InternalAuth(BaseAuth):
access_token, allow_expired=allow_expired
)
+ # Deny the request if the user account is locked.
+ if not allow_locked and await self.store.get_user_locked_status(
+ requester.user.to_string()
+ ):
+ raise AuthError(
+ 401,
+ "User account has been locked",
+ errcode=Codes.USER_LOCKED,
+ additional_fields={"soft_logout": True},
+ )
+
# Deny the request if the user account has expired.
# This check is only done for regular users, not appservice ones.
if not allow_expired:
diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index bd4fc9c0ee..3a516093f5 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -27,6 +27,7 @@ from twisted.web.http_headers import Headers
from synapse.api.auth.base import BaseAuth
from synapse.api.errors import (
AuthError,
+ Codes,
HttpResponseException,
InvalidClientTokenError,
OAuthInsufficientScopeError,
@@ -38,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.caches.expiringcache import ExpiringCache
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -105,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth):
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+ self._clock = hs.get_clock()
+ self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
+ cache_name="introspection_token_cache",
+ clock=self._clock,
+ max_len=10000,
+ expiry_ms=5 * 60 * 1000,
+ )
+
if isinstance(auth_method, PrivateKeyJWTWithKid):
# Use the JWK as the client secret when using the private_key_jwt method
assert self._config.jwk, "No JWK provided"
@@ -143,6 +153,20 @@ class MSC3861DelegatedAuth(BaseAuth):
Returns:
The introspection response
"""
+ # check the cache before doing a request
+ introspection_token = self._token_cache.get(token, None)
+
+ if introspection_token:
+ # check the expiration field of the token (if it exists)
+ exp = introspection_token.get("exp", None)
+ if exp:
+ time_now = self._clock.time()
+ expired = time_now > exp
+ if not expired:
+ return introspection_token
+ else:
+ return introspection_token
+
metadata = await self._issuer_metadata.get()
introspection_endpoint = metadata.get("introspection_endpoint")
raw_headers: Dict[str, str] = {
@@ -156,7 +180,10 @@ class MSC3861DelegatedAuth(BaseAuth):
# Fill the body/headers with credentials
uri, raw_headers, body = self._client_auth.prepare(
- method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
+ method="POST",
+ uri=introspection_endpoint,
+ headers=raw_headers,
+ body=body,
)
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
@@ -186,7 +213,17 @@ class MSC3861DelegatedAuth(BaseAuth):
"The introspection endpoint returned an invalid JSON response."
)
- return IntrospectionToken(**resp)
+ expiration = resp.get("exp", None)
+ if expiration:
+ if self._clock.time() > expiration:
+ raise InvalidClientTokenError("Token is expired.")
+
+ introspection_token = IntrospectionToken(**resp)
+
+ # add token to cache
+ self._token_cache[token] = introspection_token
+
+ return introspection_token
async def is_server_admin(self, requester: Requester) -> bool:
return "urn:synapse:admin:*" in requester.scope
@@ -196,6 +233,7 @@ class MSC3861DelegatedAuth(BaseAuth):
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
+ allow_locked: bool = False,
) -> Requester:
access_token = self.get_access_token_from_request(request)
@@ -205,6 +243,17 @@ class MSC3861DelegatedAuth(BaseAuth):
# so that we don't provision the user if they don't have enough permission:
requester = await self.get_user_by_access_token(access_token, allow_expired)
+ # Deny the request if the user account is locked.
+ if not allow_locked and await self.store.get_user_locked_status(
+ requester.user.to_string()
+ ):
+ raise AuthError(
+ 401,
+ "User account has been locked",
+ errcode=Codes.USER_LOCKED,
+ additional_fields={"soft_logout": True},
+ )
+
if not allow_guest and requester.is_guest:
raise OAuthInsufficientScopeError([SCOPE_MATRIX_API])
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index dc32553d0c..bf311b636d 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -18,8 +18,7 @@
"""Contains constants from the specification."""
import enum
-
-from typing_extensions import Final
+from typing import Final
# the max size of a (canonical-json-encoded) event
MAX_PDU_SIZE = 65536
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 3546aaf7c3..7ffd72c42c 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -80,6 +80,8 @@ class Codes(str, Enum):
WEAK_PASSWORD = "M_WEAK_PASSWORD"
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
+ # USER_LOCKED = "M_USER_LOCKED"
+ USER_LOCKED = "ORG_MATRIX_MSC3939_USER_LOCKED"
# Part of MSC3848
# https://github.com/matrix-org/matrix-spec-proposals/pull/3848
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 9152c06bd6..c4e63e7411 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -47,6 +47,10 @@ class CasConfig(Config):
required_attributes
)
+ self.idp_name = cas_config.get("idp_name", "CAS")
+ self.idp_icon = cas_config.get("idp_icon")
+ self.idp_brand = cas_config.get("idp_brand")
+
else:
self.cas_server_url = None
self.cas_service_url = None
diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py
index 49ca663dde..c69e24cf26 100644
--- a/synapse/config/saml2.py
+++ b/synapse/config/saml2.py
@@ -89,8 +89,14 @@ class SAML2Config(Config):
"grandfathered_mxid_source_attribute", "uid"
)
+ # refers to a SAML IdP entity ID
self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
+ # IdP properties for Matrix clients
+ self.idp_name = saml2_config.get("idp_name", "SAML")
+ self.idp_icon = saml2_config.get("idp_icon")
+ self.idp_brand = saml2_config.get("idp_brand")
+
# user_mapping_provider may be None if the key is present but has no value
ump_dict = saml2_config.get("user_mapping_provider") or {}
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c9e18b91e9..f60ec2ea66 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -35,3 +35,4 @@ class UserDirectoryConfig(Config):
self.user_directory_search_prefer_local_users = user_directory_config.get(
"prefer_local_users", False
)
+ self.show_locked_users = user_directory_config.get("show_locked_users", False)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index a90d99c4d6..f9915e5a3f 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -63,7 +63,7 @@ from synapse.federation.federation_base import (
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@@ -1245,7 +1245,7 @@ class FederationServer(FederationBase):
# while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
# lock.
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
await self._federation_event_handler.on_receive_pdu(
origin, event
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 119c7f8384..0e812a6d8b 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -67,6 +67,7 @@ class AdminHandler:
"name",
"admin",
"deactivated",
+ "locked",
"shadow_banned",
"creation_ts",
"appservice_id",
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index fc467bc7c1..5c71637038 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -76,12 +76,13 @@ class CasHandler:
self.idp_id = "cas"
# user-facing name of this auth provider
- self.idp_name = "CAS"
+ self.idp_name = hs.config.cas.idp_name
- # we do not currently support brands/icons for CAS auth, but this is required by
- # the SsoIdentityProvider protocol type.
- self.idp_icon = None
- self.idp_brand = None
+ # MXC URI for icon for this auth provider
+ self.idp_icon = hs.config.cas.idp_icon
+
+ # optional brand identifier for this auth provider
+ self.idp_brand = hs.config.cas.idp_brand
self._sso_handler = hs.get_sso_handler()
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b7bf70a72d..5ae427d52c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender = hs.get_federation_sender()
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
+ self.db_pool = hs.get_datastores().main.db_pool
self.device_list_updater = DeviceListUpdater(hs, self)
@@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler):
device_id: Optional[str],
device_data: JsonDict,
initial_device_display_name: Optional[str] = None,
+ keys_for_device: Optional[JsonDict] = None,
) -> str:
- """Store a dehydrated device for a user. If the user had a previous
- dehydrated device, it is removed.
+ """Store a dehydrated device for a user, optionally storing the keys associated with
+ it as well. If the user had a previous dehydrated device, it is removed.
Args:
user_id: the user that we are storing the device for
device_id: device id supplied by client
device_data: the dehydrated device information
initial_device_display_name: The display name to use for the device
+ keys_for_device: keys for the dehydrated device
Returns:
device id of the dehydrated device
"""
@@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler):
device_id,
initial_device_display_name,
)
+
+ time_now = self.clock.time_msec()
+
old_device_id = await self.store.store_dehydrated_device(
- user_id, device_id, device_data
+ user_id, device_id, device_data, time_now, keys_for_device
)
+
if old_device_id is not None:
await self.delete_devices(user_id, [old_device_id])
+
return device_id
async def rehydrate_device(
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 15e94a03cb..17ff8821d9 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -367,19 +367,6 @@ class DeviceMessageHandler:
errcode=Codes.INVALID_PARAM,
)
- # if we have a since token, delete any to-device messages before that token
- # (since we now know that the device has received them)
- deleted = await self.store.delete_messages_for_device(
- user_id, device_id, since_stream_id
- )
- logger.debug(
- "Deleted %d to-device messages up to %d for user_id %s device_id %s",
- deleted,
- since_stream_id,
- user_id,
- device_id,
- )
-
to_token = self.event_sources.get_current_token().to_device_key
messages, stream_id = await self.store.get_messages_for_device(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d485f21e49..a74db1dccf 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -53,7 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -1034,7 +1034,7 @@ class EventCreationHandler:
)
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
return await self._create_and_send_nonmember_event_locked(
requester=requester,
@@ -1978,7 +1978,7 @@ class EventCreationHandler:
for room_id in room_ids:
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
dummy_event_sent = await self._send_dummy_event_for_room(room_id)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index da34658470..1be6ebc6d9 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.room import ShutdownRoomResponse
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.admin._base import assert_user_is_admin
@@ -46,9 +47,10 @@ logger = logging.getLogger(__name__)
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
-PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
-
-DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+# This is used to avoid purging a room several time at the same moment,
+# and also paginating during a purge. Pagination can trigger backfill,
+# which would create old events locally, and would potentially clash with the room delete.
+PURGE_PAGINATION_LOCK_NAME = "purge_pagination_lock"
@attr.s(slots=True, auto_attribs=True)
@@ -363,7 +365,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self._worker_locks.acquire_read_write_lock(
- PURGE_HISTORY_LOCK_NAME, room_id, write=True
+ PURGE_PAGINATION_LOCK_NAME, room_id, write=True
):
await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
@@ -421,7 +423,10 @@ class PaginationHandler:
force: set true to skip checking for joined users.
"""
async with self._worker_locks.acquire_multi_read_write_lock(
- [(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
+ [
+ (PURGE_PAGINATION_LOCK_NAME, room_id),
+ (NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id),
+ ],
write=True,
):
# first check that we have no users in this room
@@ -483,7 +488,7 @@ class PaginationHandler:
room_token = from_token.room_key
async with self._worker_locks.acquire_read_write_lock(
- PURGE_HISTORY_LOCK_NAME, room_id, write=False
+ PURGE_PAGINATION_LOCK_NAME, room_id, write=False
):
(membership, member_event_id) = (None, None)
if not use_admin_priviledge:
@@ -761,7 +766,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self._worker_locks.acquire_read_write_lock(
- PURGE_HISTORY_LOCK_NAME, room_id, write=True
+ PURGE_PAGINATION_LOCK_NAME, room_id, write=True
):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index cd7df0525f..e8e9db4b91 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -30,9 +30,9 @@ from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
- Awaitable,
Callable,
Collection,
+ ContextManager,
Dict,
Generator,
Iterable,
@@ -44,7 +44,6 @@ from typing import (
)
from prometheus_client import Counter
-from typing_extensions import ContextManager
import synapse.metrics
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
@@ -54,7 +53,10 @@ from synapse.appservice import ApplicationService
from synapse.events.presence_router import PresenceRouter
from synapse.logging.context import run_in_background
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ run_as_background_process,
+ wrap_as_background_process,
+)
from synapse.replication.http.presence import (
ReplicationBumpPresenceActiveTime,
ReplicationPresenceSetState,
@@ -141,6 +143,8 @@ class BasePresenceHandler(abc.ABC):
self.state = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
+ self._presence_enabled = hs.config.server.use_presence
+
self._federation = None
if hs.should_send_federation():
self._federation = hs.get_federation_sender()
@@ -149,6 +153,15 @@ class BasePresenceHandler(abc.ABC):
self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
+ self.VALID_PRESENCE: Tuple[str, ...] = (
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
+ )
+
+ if self._busy_presence_enabled:
+ self.VALID_PRESENCE += (PresenceState.BUSY,)
+
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -395,8 +408,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
self._presence_writer_instance = hs.config.worker.writers.presence[0]
- self._presence_enabled = hs.config.server.use_presence
-
# Route presence EDUs to the right worker
hs.get_federation_registry().register_instances_for_edu(
EduTypes.PRESENCE,
@@ -421,8 +432,6 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
- self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
-
hs.get_reactor().addSystemEventTrigger(
"before",
"shutdown",
@@ -490,7 +499,9 @@ class WorkerPresenceHandler(BasePresenceHandler):
# what the spec wants: see comment in the BasePresenceHandler version
# of this function.
await self.set_state(
- UserID.from_string(user_id), {"presence": presence_state}, True
+ UserID.from_string(user_id),
+ {"presence": presence_state},
+ ignore_status_msg=True,
)
curr_sync = self._user_to_num_current_syncs.get(user_id, 0)
@@ -601,22 +612,13 @@ class WorkerPresenceHandler(BasePresenceHandler):
"""
presence = state["presence"]
- valid_presence = (
- PresenceState.ONLINE,
- PresenceState.UNAVAILABLE,
- PresenceState.OFFLINE,
- PresenceState.BUSY,
- )
-
- if presence not in valid_presence or (
- presence == PresenceState.BUSY and not self._busy_presence_enabled
- ):
+ if presence not in self.VALID_PRESENCE:
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
# If presence is disabled, no-op
- if not self.hs.config.server.use_presence:
+ if not self._presence_enabled:
return
# Proxy request to instance that writes presence
@@ -633,7 +635,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.server.use_presence:
+ if not self._presence_enabled:
return
# Proxy request to instance that writes presence
@@ -649,7 +651,6 @@ class PresenceHandler(BasePresenceHandler):
self.hs = hs
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()
- self._presence_enabled = hs.config.server.use_presence
federation_registry = hs.get_federation_registry()
@@ -700,8 +701,6 @@ class PresenceHandler(BasePresenceHandler):
self._on_shutdown,
)
- self._next_serial = 1
-
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
self.user_to_num_current_syncs: Dict[str, int] = {}
@@ -723,21 +722,16 @@ class PresenceHandler(BasePresenceHandler):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
- def run_timeout_handler() -> Awaitable[None]:
- return run_as_background_process(
- "handle_presence_timeouts", self._handle_timeouts
- )
-
self.clock.call_later(
- 30, self.clock.looping_call, run_timeout_handler, 5000
+ 30, self.clock.looping_call, self._handle_timeouts, 5000
)
- def run_persister() -> Awaitable[None]:
- return run_as_background_process(
- "persist_presence_changes", self._persist_unpersisted_changes
- )
-
- self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
+ self.clock.call_later(
+ 60,
+ self.clock.looping_call,
+ self._persist_unpersisted_changes,
+ 60 * 1000,
+ )
LaterGauge(
"synapse_handlers_presence_wheel_timer_size",
@@ -783,6 +777,7 @@ class PresenceHandler(BasePresenceHandler):
)
logger.info("Finished _on_shutdown")
+ @wrap_as_background_process("persist_presence_changes")
async def _persist_unpersisted_changes(self) -> None:
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
@@ -898,6 +893,7 @@ class PresenceHandler(BasePresenceHandler):
states, [destination]
)
+ @wrap_as_background_process("handle_presence_timeouts")
async def _handle_timeouts(self) -> None:
"""Checks the presence of users that have timed out and updates as
appropriate.
@@ -955,7 +951,7 @@ class PresenceHandler(BasePresenceHandler):
with the app.
"""
# If presence is disabled, no-op
- if not self.hs.config.server.use_presence:
+ if not self._presence_enabled:
return
user_id = user.to_string()
@@ -990,56 +986,51 @@ class PresenceHandler(BasePresenceHandler):
client that is being used by a user.
presence_state: The presence state indicated in the sync request
"""
- # Override if it should affect the user's presence, if presence is
- # disabled.
- if not self.hs.config.server.use_presence:
- affect_presence = False
+ if not affect_presence or not self._presence_enabled:
+ return _NullContextManager()
- if affect_presence:
- curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
- self.user_to_num_current_syncs[user_id] = curr_sync + 1
+ curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
+ self.user_to_num_current_syncs[user_id] = curr_sync + 1
- prev_state = await self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
- # If they're busy then they don't stop being busy just by syncing,
- # so just update the last sync time.
- if prev_state.state != PresenceState.BUSY:
- # XXX: We set_state separately here and just update the last_active_ts above
- # This keeps the logic as similar as possible between the worker and single
- # process modes. Using set_state will actually cause last_active_ts to be
- # updated always, which is not what the spec calls for, but synapse has done
- # this for... forever, I think.
- await self.set_state(
- UserID.from_string(user_id), {"presence": presence_state}, True
- )
- # Retrieve the new state for the logic below. This should come from the
- # in-memory cache.
- prev_state = await self.current_state_for_user(user_id)
+ # If they're busy then they don't stop being busy just by syncing,
+ # so just update the last sync time.
+ if prev_state.state != PresenceState.BUSY:
+ # XXX: We set_state separately here and just update the last_active_ts above
+ # This keeps the logic as similar as possible between the worker and single
+ # process modes. Using set_state will actually cause last_active_ts to be
+ # updated always, which is not what the spec calls for, but synapse has done
+ # this for... forever, I think.
+ await self.set_state(
+ UserID.from_string(user_id),
+ {"presence": presence_state},
+ ignore_status_msg=True,
+ )
+ # Retrieve the new state for the logic below. This should come from the
+ # in-memory cache.
+ prev_state = await self.current_state_for_user(user_id)
- # To keep the single process behaviour consistent with worker mode, run the
- # same logic as `update_external_syncs_row`, even though it looks weird.
- if prev_state.state == PresenceState.OFFLINE:
- await self._update_states(
- [
- prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=self.clock.time_msec(),
- last_user_sync_ts=self.clock.time_msec(),
- )
- ]
- )
- # otherwise, set the new presence state & update the last sync time,
- # but don't update last_active_ts as this isn't an indication that
- # they've been active (even though it's probably been updated by
- # set_state above)
- else:
- await self._update_states(
- [
- prev_state.copy_and_replace(
- last_user_sync_ts=self.clock.time_msec()
- )
- ]
- )
+ # To keep the single process behaviour consistent with worker mode, run the
+ # same logic as `update_external_syncs_row`, even though it looks weird.
+ if prev_state.state == PresenceState.OFFLINE:
+ await self._update_states(
+ [
+ prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=self.clock.time_msec(),
+ last_user_sync_ts=self.clock.time_msec(),
+ )
+ ]
+ )
+ # otherwise, set the new presence state & update the last sync time,
+ # but don't update last_active_ts as this isn't an indication that
+ # they've been active (even though it's probably been updated by
+ # set_state above)
+ else:
+ await self._update_states(
+ [prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())]
+ )
async def _end() -> None:
try:
@@ -1061,8 +1052,7 @@ class PresenceHandler(BasePresenceHandler):
try:
yield
finally:
- if affect_presence:
- run_in_background(_end)
+ run_in_background(_end)
return _user_syncing()
@@ -1229,20 +1219,11 @@ class PresenceHandler(BasePresenceHandler):
status_msg = state.get("status_msg", None)
presence = state["presence"]
- valid_presence = (
- PresenceState.ONLINE,
- PresenceState.UNAVAILABLE,
- PresenceState.OFFLINE,
- PresenceState.BUSY,
- )
-
- if presence not in valid_presence or (
- presence == PresenceState.BUSY and not self._busy_presence_enabled
- ):
+ if presence not in self.VALID_PRESENCE:
raise SynapseError(400, "Invalid presence state")
# If presence is disabled, no-op
- if not self.hs.config.server.use_presence:
+ if not self._presence_enabled:
return
user_id = target_user.to_string()
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index bd8277e736..bb409f97b7 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -39,7 +39,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -629,7 +629,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async with self.member_linearizer.queue(key):
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
diff = self.clock.time_msec() - then
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 6083c9f4b5..d00035c332 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -74,12 +74,13 @@ class SamlHandler:
self.idp_id = "saml"
# user-facing name of this auth provider
- self.idp_name = "SAML"
+ self.idp_name = hs.config.saml2.idp_name
- # we do not currently support icons/brands for SAML auth, but this is required by
- # the SsoIdentityProvider protocol type.
- self.idp_icon = None
- self.idp_brand = None
+ # MXC URI for icon for this auth provider
+ self.idp_icon = hs.config.saml2.idp_icon
+
+ # optional brand identifier for this auth provider
+ self.idp_brand = hs.config.saml2.idp_brand
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict: Dict[str, Saml2SessionData] = {}
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 4d29328a74..e9a544e754 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -24,13 +24,14 @@ from typing import (
Iterable,
List,
Mapping,
+ NoReturn,
Optional,
Set,
)
from urllib.parse import urlencode
import attr
-from typing_extensions import NoReturn, Protocol
+from typing_extensions import Protocol
from twisted.web.iweb import IRequest
from twisted.web.server import Request
@@ -791,7 +792,7 @@ class SsoHandler:
if code != 200:
raise Exception(
- "GET request to download sso avatar image returned {}".format(code)
+ f"GET request to download sso avatar image returned {code}"
)
# upload name includes hash of the image file's content so that we can
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 7cabf7980a..3dde19fc81 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -14,9 +14,15 @@
# limitations under the License.
import logging
from collections import Counter
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
-
-from typing_extensions import Counter as CounterType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Counter as CounterType,
+ Dict,
+ Iterable,
+ Optional,
+ Tuple,
+)
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import event_processing_positions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c010405be6..60a9f341b5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -387,16 +387,16 @@ class SyncHandler:
from_token=since_token,
)
- # if nothing has happened in any of the users' rooms since /sync was called,
- # the resultant next_batch will be the same as since_token (since the result
- # is generated when wait_for_events is first called, and not regenerated
- # when wait_for_events times out).
- #
- # If that happens, we mustn't cache it, so that when the client comes back
- # with the same cache token, we don't immediately return the same empty
- # result, causing a tightloop. (#8518)
- if result.next_batch == since_token:
- cache_context.should_cache = False
+ # if nothing has happened in any of the users' rooms since /sync was called,
+ # the resultant next_batch will be the same as since_token (since the result
+ # is generated when wait_for_events is first called, and not regenerated
+ # when wait_for_events times out).
+ #
+ # If that happens, we mustn't cache it, so that when the client comes back
+ # with the same cache token, we don't immediately return the same empty
+ # result, causing a tightloop. (#8518)
+ if result.next_batch == since_token:
+ cache_context.should_cache = False
if result:
if sync_config.filter_collection.lazy_load_members():
@@ -1442,11 +1442,9 @@ class SyncHandler:
# Now we have our list of joined room IDs, exclude as configured and freeze
joined_room_ids = frozenset(
- (
- room_id
- for room_id in mutable_joined_room_ids
- if room_id not in mutable_rooms_to_exclude
- )
+ room_id
+ for room_id in mutable_joined_room_ids
+ if room_id not in mutable_rooms_to_exclude
)
logger.debug(
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 05197edc95..a0f5568000 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -94,6 +94,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.worker.should_update_user_directory
self.search_all_users = hs.config.userdirectory.user_directory_search_all_users
+ self.show_locked_users = hs.config.userdirectory.show_locked_users
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self._hs = hs
@@ -144,7 +145,9 @@ class UserDirectoryHandler(StateDeltasHandler):
]
}
"""
- results = await self.store.search_user_dir(user_id, search_term, limit)
+ results = await self.store.search_user_dir(
+ user_id, search_term, limit, self.show_locked_users
+ )
# Remove any spammy users from the results.
non_spammy_users = []
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index 72df773a86..58efe7116b 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -42,7 +42,11 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
-DELETE_ROOM_LOCK_NAME = "delete_room_lock"
+# This lock is used to avoid creating an event while we are purging the room.
+# We take a read lock when creating an event, and a write one when purging a room.
+# This is because it is fine to create several events concurrently, since referenced events
+# will not disappear under our feet as long as we don't delete the room.
+NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"
class WorkerLocksHandler:
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 5a61b21eaf..284fbac524 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -18,10 +18,9 @@ import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
-from typing import Callable, Optional
+from typing import Callable, Deque, Optional
import attr
-from typing_extensions import Deque
from zope.interface import implementer
from twisted.application.internet import ClientService
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index acee1dafd3..9ad8e038ae 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -31,7 +31,7 @@ from typing import (
import attr
import jinja2
-from typing_extensions import ParamSpec
+from typing_extensions import Concatenate, ParamSpec
from twisted.internet import defer
from twisted.internet.interfaces import IDelayedCall
@@ -885,7 +885,7 @@ class ModuleApi:
def run_db_interaction(
self,
desc: str,
- func: Callable[P, T],
+ func: Callable[Concatenate[LoggingTransaction, P], T],
*args: P.args,
**kwargs: P.kwargs,
) -> "defer.Deferred[T]":
diff --git a/synapse/module_api/callbacks/spamchecker_callbacks.py b/synapse/module_api/callbacks/spamchecker_callbacks.py
index e191450323..32db7cce8d 100644
--- a/synapse/module_api/callbacks/spamchecker_callbacks.py
+++ b/synapse/module_api/callbacks/spamchecker_callbacks.py
@@ -426,9 +426,7 @@ class SpamCheckerModuleApiCallbacks:
generally discouraged as it doesn't support internationalization.
"""
for callback in self._check_event_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(event))
if res is False or res == self.NOT_SPAM:
# This spam-checker accepts the event.
@@ -481,9 +479,7 @@ class SpamCheckerModuleApiCallbacks:
True if the event should be silently dropped
"""
for callback in self._should_drop_federated_event_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res: Union[bool, str] = await delay_cancellation(callback(event))
if res:
return res
@@ -505,9 +501,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
"""
for callback in self._user_may_join_room_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(user_id, room_id, is_invited))
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is True or res is self.NOT_SPAM:
@@ -546,9 +540,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise.
"""
for callback in self._user_may_invite_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(inviter_userid, invitee_userid, room_id)
)
@@ -593,9 +585,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise.
"""
for callback in self._user_may_send_3pid_invite_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(inviter_userid, medium, address, room_id)
)
@@ -630,9 +620,7 @@ class SpamCheckerModuleApiCallbacks:
userid: The ID of the user attempting to create a room
"""
for callback in self._user_may_create_room_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid))
if res is True or res is self.NOT_SPAM:
continue
@@ -666,9 +654,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._user_may_create_room_alias_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid, room_alias))
if res is True or res is self.NOT_SPAM:
continue
@@ -701,9 +687,7 @@ class SpamCheckerModuleApiCallbacks:
room_id: The ID of the room that would be published
"""
for callback in self._user_may_publish_room_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid, room_id))
if res is True or res is self.NOT_SPAM:
continue
@@ -742,9 +726,7 @@ class SpamCheckerModuleApiCallbacks:
True if the user is spammy.
"""
for callback in self._check_username_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
# Make a copy of the user profile object to ensure the spam checker cannot
# modify it.
res = await delay_cancellation(callback(user_profile.copy()))
@@ -776,9 +758,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_registration_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
behaviour = await delay_cancellation(
callback(email_threepid, username, request_info, auth_provider_id)
)
@@ -820,9 +800,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_media_file_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(file_wrapper, file_info))
# Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is False or res is self.NOT_SPAM:
@@ -869,9 +847,7 @@ class SpamCheckerModuleApiCallbacks:
"""
for callback in self._check_login_for_spam_callbacks:
- with Measure(
- self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
- ):
+ with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(
callback(
user_id,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index a2cabba7b1..38adcbe1d0 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Any,
Awaitable,
+ Deque,
Dict,
Iterable,
Iterator,
@@ -29,7 +30,6 @@ from typing import (
)
from prometheus_client import Counter
-from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index e0257daa75..04d9ef25b7 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -280,6 +280,17 @@ class UserRestServletV2(RestServlet):
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
)
+ lock = body.get("locked", False)
+ if not isinstance(lock, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "'locked' parameter is not of type boolean"
+ )
+
+ if deactivate and lock:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "An user can't be deactivated and locked"
+ )
+
approved: Optional[bool] = None
if "approved" in body and self._msc3866_enabled:
approved = body["approved"]
@@ -397,6 +408,12 @@ class UserRestServletV2(RestServlet):
target_user.to_string()
)
+ if "locked" in body:
+ if lock and not user["locked"]:
+ await self.store.set_user_locked_status(user_id, True)
+ elif not lock and user["locked"]:
+ await self.store.set_user_locked_status(user_id, False)
+
if "user_type" in body:
await self.store.set_user_type(target_user, user_type)
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 51f17f80da..925f037743 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -29,7 +29,6 @@ from synapse.http.servlet import (
parse_integer,
)
from synapse.http.site import SynapseRequest
-from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.rest.client.models import AuthenticationData
from synapse.rest.models import RequestBodyModel
@@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = handler
- if hs.config.worker.worker_app is None:
- # if main process
- self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
- else:
- # then a worker
- self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
-
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet):
"Device key(s) not found, these must be provided.",
)
- # TODO: Those two operations, creating a device and storing the
- # device's keys should be atomic.
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
submission.device_id,
submission.device_data.dict(),
submission.initial_device_display_name,
- )
-
- # TODO: Do we need to do something with the result here?
- await self.key_uploader(
- user_id=user_id, device_id=submission.device_id, keys=submission.dict()
+ device_info,
)
return 200, {"device_id": device_id}
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 94ad90942f..2e104d4888 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -40,7 +40,9 @@ class LogoutRestServlet(RestServlet):
self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_expired=True)
+ requester = await self.auth.get_user_by_req(
+ request, allow_expired=True, allow_locked=True
+ )
if requester.device_id is None:
# The access token wasn't associated with a device.
@@ -67,7 +69,9 @@ class LogoutAllRestServlet(RestServlet):
self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_expired=True)
+ requester = await self.auth.get_user_by_req(
+ request, allow_expired=True, allow_locked=True
+ )
user_id = requester.user.to_string()
# first delete all of the user's devices
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 5c9fece3ba..5ed3b83a03 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -32,6 +32,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.types import JsonDict
+from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -53,26 +54,32 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None
self._push_rules_handler = hs.get_push_rules_handler()
+ self._push_rule_linearizer = Linearizer(name="push_rules")
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ async with self._push_rule_linearizer.queue(user_id):
+ return await self.handle_put(request, path, user_id)
+
+ async def handle_put(
+ self, request: SynapseRequest, path: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
spec = _rule_spec_from_path(path.split("/"))
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
- requester = await self.auth.get_user_by_req(request)
-
if "/" in spec.rule_id or "\\" in spec.rule_id:
raise SynapseError(400, "rule_id may not contain slashes")
content = parse_json_value_from_request(request)
- user_id = requester.user.to_string()
-
if spec.attr:
try:
await self._push_rules_handler.set_rule_attr(user_id, spec, content)
@@ -126,11 +133,20 @@ class PushRuleRestServlet(RestServlet):
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
- spec = _rule_spec_from_path(path.split("/"))
-
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
+ async with self._push_rule_linearizer.queue(user_id):
+ return await self.handle_delete(request, path, user_id)
+
+ async def handle_delete(
+ self,
+ request: SynapseRequest,
+ path: str,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
+ spec = _rule_spec_from_path(path.split("/"))
+
namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
try:
diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py
index 4a5d9e13e7..b1f6b5d1b7 100644
--- a/synapse/rest/client/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/room_upgrade_rest_servlet.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -81,7 +81,7 @@ class RoomUpgradeRestServlet(RestServlet):
try:
async with self._worker_lock_handler.acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 8f3865d412..981fd1f58a 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
from signedjson.sign import sign_json
@@ -27,6 +27,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
)
+from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
) -> JsonDict:
logger.info("Handling query for keys %r", query)
- store_queries = []
+ server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items():
- if not key_ids:
- key_ids = (None,)
- for key_id in key_ids:
- store_queries.append((server_name, key_id, None))
+ if key_ids:
+ results: Mapping[
+ str, Optional[FetchKeyResultForRemote]
+ ] = await self.store.get_server_keys_json_for_remote(
+ server_name, key_ids
+ )
+ else:
+ results = await self.store.get_all_server_keys_json_for_remote(
+ server_name
+ )
- cached = await self.store.get_server_keys_json_for_remote(store_queries)
+ server_keys.update(
+ ((server_name, key_id), res) for key_id, res in results.items()
+ )
json_results: Set[bytes] = set()
@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
# Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
- for (server_name, key_id, _), key_results in cached.items():
- results = [(result["ts_added_ms"], result) for result in key_results]
-
- if key_id is None:
+ for (server_name, key_id), key_result in server_keys.items():
+ if not query[server_name]:
# all keys were requested. Just return what we have without worrying
# about validity
- for _, result in results:
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(result["key_json"]))
+ if key_result:
+ json_results.add(key_result.key_json)
continue
miss = False
- if not results:
+ if key_result is None:
miss = True
else:
- ts_added_ms, most_recent_result = max(results)
- ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+ ts_added_ms = key_result.added_ts
+ ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
ts_valid_until_ms,
time_now_ms,
)
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(most_recent_result["key_json"]))
+
+ json_results.add(key_result.key_json)
if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 2d5ddc3e7b..ddca0af1da 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -238,6 +238,7 @@ class BackgroundUpdater:
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
+ self.hs = hs
self._database_name = database.name()
@@ -758,6 +759,11 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
+ # override the global statement timeout to avoid accidentally squashing
+ # a long-running index creation process
+ timeout_sql = "SET SESSION statement_timeout = 0"
+ c.execute(timeout_sql)
+
sql = (
"CREATE %(unique)s INDEX CONCURRENTLY %(name)s"
" ON %(table)s"
@@ -778,6 +784,12 @@ class BackgroundUpdater:
logger.debug("[SQL] %s", sql)
c.execute(sql)
finally:
+ # mypy ignore - `statement_timeout` is defined on PostgresEngine
+ # reset the global timeout to the default
+ default_timeout = self.db_pool.engine.statement_timeout # type: ignore[attr-defined]
+ undo_timeout_sql = f"SET statement_timeout = {default_timeout}"
+ conn.cursor().execute(undo_timeout_sql)
+
conn.set_session(autocommit=False) # type: ignore
def create_index_sqlite(conn: Connection) -> None:
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 35cd1089d6..abd1d149db 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,7 +45,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
+from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.opentracing import (
SynapseTags,
@@ -357,7 +357,7 @@ class EventsPersistenceStorageController:
# it. We might already have taken out the lock, but since this is just a
# "read" lock its inherently reentrant.
async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
- DELETE_ROOM_LOCK_NAME, room_id, write=False
+ NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
if isinstance(task, _PersistEventsTask):
return await self._persist_event_batch(room_id, task)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
cast,
)
+from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
def _store_dehydrated_device_txn(
- self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ device_data: str,
+ time: int,
+ keys: Optional[JsonDict] = None,
) -> Optional[str]:
+ # TODO: make keys non-optional once support for msc2697 is dropped
+ if keys:
+ device_keys = keys.get("device_keys", None)
+ if device_keys:
+ # Type ignore - this function is defined on EndToEndKeyStore which we do
+ # have access to due to hs.get_datastore() "magic"
+ self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
+ txn, user_id, device_id, time, device_keys
+ )
+
+ one_time_keys = keys.get("one_time_keys", None)
+ if one_time_keys:
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append(
+ (
+ algorithm,
+ key_id,
+ encode_canonical_json(key_obj).decode("ascii"),
+ )
+ )
+ self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+ fallback_keys = keys.get("fallback_keys", None)
+ if fallback_keys:
+ self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)
+
return old_device_id
async def store_dehydrated_device(
- self, user_id: str, device_id: str, device_data: JsonDict
+ self,
+ user_id: str,
+ device_id: str,
+ device_data: JsonDict,
+ time_now: int,
+ keys: Optional[dict] = None,
) -> Optional[str]:
"""Store a dehydrated device for a user.
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device
device_data: the dehydrated device information
+ time_now: current time at the request in milliseconds
+ keys: keys for the dehydrated device
+
Returns:
device id of the user's previous dehydrated device, if any
"""
+
return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
json_encoder.encode(device_data),
+ time_now,
+ keys,
)
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 91ae9c457d..b49dea577c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
- def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("new_keys", str(new_keys))
- # We are protected from race between lookup and insertion due to
- # a unique constraint. If there is a race of two calls to
- # `add_e2e_one_time_keys` then they'll conflict and we will only
- # insert one set.
- self.db_pool.simple_insert_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- keys=(
- "user_id",
- "device_id",
- "algorithm",
- "key_id",
- "ts_added_ms",
- "key_json",
- ),
- values=[
- (user_id, device_id, algorithm, key_id, time_now, json_bytes)
- for algorithm, key_id, json_bytes in new_keys
- ],
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
await self.db_pool.runInteraction(
- "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+ "add_e2e_one_time_keys_insert",
+ self._add_e2e_one_time_keys_txn,
+ user_id,
+ device_id,
+ time_now,
+ new_keys,
+ )
+
+ def _add_e2e_one_time_keys_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ new_keys: Iterable[Tuple[str, str, str]],
+ ) -> None:
+ """Insert some new one time keys for a device. Errors if any of the keys already exist.
+
+ Args:
+ user_id: id of user to get keys for
+ device_id: id of device to get keys for
+ time_now: insertion time to record (ms since epoch)
+ new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
+ that the key JSON must be in canonical JSON form
+ """
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("new_keys", str(new_keys))
+ # We are protected from race between lookup and insertion due to
+ # a unique constraint. If there is a race of two calls to
+ # `add_e2e_one_time_keys` then they'll conflict and we will only
+ # insert one set.
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keys=(
+ "user_id",
+ "device_id",
+ "algorithm",
+ "key_id",
+ "ts_added_ms",
+ "key_json",
+ ),
+ values=[
+ (user_id, device_id, algorithm, key_id, time_now, json_bytes)
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
@cached(max_entries=10000)
@@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_id: str,
fallback_keys: JsonDict,
) -> None:
+ """Set the user's e2e fallback keys.
+
+ Args:
+ user_id: the user whose keys are being set
+ device_id: the device whose keys are being set
+ fallback_keys: the keys to set. This is a map from key ID (which is
+ of the form "algorithm:id") to key data.
+ """
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
@@ -1304,42 +1333,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
+
+ Args:
+ user_id: user_id of the user to store keys for
+ device_id: device_id of the device to store keys for
+ time_now: time at the request to store the keys
+ device_keys: the keys to store
"""
- def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
- set_tag("user_id", user_id)
- set_tag("device_id", device_id)
- set_tag("time_now", time_now)
- set_tag("device_keys", str(device_keys))
+ return await self.db_pool.runInteraction(
+ "set_e2e_device_keys",
+ self._set_e2e_device_keys_txn,
+ user_id,
+ device_id,
+ time_now,
+ device_keys,
+ )
- old_key_json = self.db_pool.simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcol="key_json",
- allow_none=True,
- )
+ def _set_e2e_device_keys_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ device_keys: JsonDict,
+ ) -> bool:
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+ Args:
+ user_id: user_id of the user to store keys for
+ device_id: device_id of the device to store keys for
+ time_now: time at the request to store the keys
+ device_keys: the keys to store
+ """
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("time_now", time_now)
+ set_tag("device_keys", str(device_keys))
+
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="key_json",
+ allow_none=True,
+ )
- if old_key_json == new_key_json:
- log_kv({"Message": "Device key already stored."})
- return False
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
- self.db_pool.simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- values={"ts_added_ms": time_now, "key_json": new_key_json},
- )
- log_kv({"message": "Device keys stored."})
- return True
+ if old_key_json == new_key_json:
+ log_kv({"Message": "Device key already stored."})
+ return False
- return await self.db_pool.runInteraction(
- "set_e2e_device_keys", _set_e2e_device_keys_txn
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"ts_added_ms": time_now, "key_json": new_key_json},
)
+ log_kv({"message": "Device keys stored."})
+ return True
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index fff417f9e3..047de6283a 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json
-from typing_extensions import TYPE_CHECKING
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 1666e3c43b..a3b4744855 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
import itertools
import json
import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""
@cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- self._get_server_keys_json.invalidate((((server_name, key_id),)))
+ await self.invalidate_cache_and_stream(
+ "_get_server_keys_json", ((server_name, key_id),)
+ )
+ await self.invalidate_cache_and_stream(
+ "get_server_key_json_for_remote", (server_name, key_id)
+ )
@cached()
def _get_server_keys_json(
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
return await self.db_pool.runInteraction("get_server_keys_json", _txn)
- async def get_server_keys_json_for_remote(
- self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- """Retrieve the key json for a list of server_keys and key ids.
- If no keys are found for a given server, key_id and source then
- that server, key_id, and source triplet entry will be an empty list.
- The JSON is returned as a byte array so that it can be efficiently
- used in an HTTP response.
+ @cached()
+ def get_server_key_json_for_remote(
+ self,
+ server_name: str,
+ key_id: str,
+ ) -> Optional[FetchKeyResultForRemote]:
+ raise NotImplementedError()
- Args:
- server_keys: List of (server_name, key_id, source) triplets.
+ @cachedList(
+ cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+ )
+ async def get_server_keys_json_for_remote(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+ """Fetch the cached keys for the given server/key IDs.
- Returns:
- A mapping from (server_name, key_id, source) triplets to a list of dicts
+ If we have multiple entries for a given key ID, returns the most recent.
"""
+ rows = await self.db_pool.simple_select_many_batch(
+ table="server_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
+ )
- def _get_server_keys_json_txn(
- txn: LoggingTransaction,
- ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
- results = {}
- for server_name, key_id, from_server in server_keys:
- keyvalues = {"server_name": server_name}
- if key_id is not None:
- keyvalues["key_id"] = key_id
- if from_server is not None:
- keyvalues["from_server"] = from_server
- rows = self.db_pool.simple_select_list_txn(
- txn,
- "server_keys_json",
- keyvalues=keyvalues,
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
- ),
- )
- results[(server_name, key_id, from_server)] = rows
- return results
+ if not rows:
+ return {}
+
+ # We sort the rows so that the most recently added entry is picked up.
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
- return await self.db_pool.runInteraction(
- "get_server_keys_json", _get_server_keys_json_txn
+ async def get_all_server_keys_json_for_remote(
+ self,
+ server_name: str,
+ ) -> Dict[str, FetchKeyResultForRemote]:
+ """Fetch the cached keys for the given server.
+
+ If we have multiple entries for a given key ID, returns the most recent.
+ """
+ rows = await self.db_pool.simple_select_list(
+ table="server_keys_json",
+ keyvalues={"server_name": server_name},
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ desc="get_server_keys_json_for_remote",
)
+
+ if not rows:
+ return {}
+
+ rows.sort(key=lambda r: r["ts_added_ms"])
+
+ return {
+ row["key_id"]: FetchKeyResultForRemote(
+ # Cast to bytes since postgresql returns a memoryview.
+ key_json=bytes(row["key_json"]),
+ valid_until_ts=row["ts_valid_until_ms"],
+ added_ts=row["ts_added_ms"],
+ )
+ for row in rows
+ }
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 1680bf6168..54d40e7a3a 100644
--- a/synapse/storage/databases/main/lock.py
+++ b/synapse/storage/databases/main/lock.py
@@ -26,7 +26,6 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.engines import PostgresEngine
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -96,6 +95,10 @@ class LockStore(SQLBaseStore):
self._acquiring_locks: Set[Tuple[str, str]] = set()
+ self._clock.looping_call(
+ self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0
+ )
+
@wrap_as_background_process("LockStore._on_shutdown")
async def _on_shutdown(self) -> None:
"""Called when the server is shutting down"""
@@ -216,6 +219,7 @@ class LockStore(SQLBaseStore):
lock_name,
lock_key,
write,
+ db_autocommit=True,
)
except self.database_engine.module.IntegrityError:
return None
@@ -233,61 +237,22 @@ class LockStore(SQLBaseStore):
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
- #
- # Before that though we clear the table of any stale locks.
now = self._clock.time_msec()
token = random_string(6)
- delete_sql = """
- DELETE FROM worker_read_write_locks
- WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
- """
-
- insert_sql = """
- INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
- VALUES (?, ?, ?, ?, ?, ?)
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # For Postgres we can send these queries at the same time.
- txn.execute(
- delete_sql + ";" + insert_sql,
- (
- # DELETE args
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- # UPSERT args
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
- else:
- # For SQLite these need to be two queries.
- txn.execute(
- delete_sql,
- (
- now - _LOCK_TIMEOUT_MS,
- lock_name,
- lock_key,
- ),
- )
- txn.execute(
- insert_sql,
- (
- lock_name,
- lock_key,
- write,
- self._instance_name,
- token,
- now,
- ),
- )
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="worker_read_write_locks",
+ values={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "write_lock": write,
+ "instance_name": self._instance_name,
+ "token": token,
+ "last_renewed_ts": now,
+ },
+ )
lock = Lock(
self._reactor,
@@ -351,6 +316,24 @@ class LockStore(SQLBaseStore):
return locks
+ @wrap_as_background_process("_reap_stale_read_write_locks")
+ async def _reap_stale_read_write_locks(self) -> None:
+ delete_sql = """
+ DELETE FROM worker_read_write_locks
+ WHERE last_renewed_ts < ?
+ """
+
+ def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None:
+ txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,))
+ if txn.rowcount:
+ logger.info("Reaped %d stale locks", txn.rowcount)
+
+ await self.db_pool.runInteraction(
+ "_reap_stale_read_write_locks",
+ reap_stale_read_write_locks_txn,
+ db_autocommit=True,
+ )
+
class Lock:
"""An async context manager that manages an acquired lock, ensuring it is
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c582cf0573..d3a01d526f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -205,7 +205,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
name, password_hash, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
- COALESCE(approved, TRUE) AS approved
+ COALESCE(approved, TRUE) AS approved,
+ COALESCE(locked, FALSE) AS locked
FROM users
WHERE name = ?
""",
@@ -230,10 +231,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
- boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
+ boolean_columns = [
+ "admin",
+ "deactivated",
+ "shadow_banned",
+ "approved",
+ "locked",
+ ]
for column in boolean_columns:
- if not isinstance(row[column], bool):
- row[column] = bool(row[column])
+ row[column] = bool(row[column])
return row
@@ -1116,6 +1122,27 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# Convert the integer into a boolean.
return res == 1
+ @cached()
+ async def get_user_locked_status(self, user_id: str) -> bool:
+ """Retrieve the value for the `locked` property for the provided user.
+
+ Args:
+ user_id: The ID of the user to retrieve the status for.
+
+ Returns:
+ True if the user was locked, false if the user is still active.
+ """
+
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="locked",
+ desc="get_user_locked_status",
+ )
+
+ # Convert the potential integer into a boolean.
+ return bool(res)
+
async def get_threepid_validation_session(
self,
medium: Optional[str],
@@ -2111,6 +2138,33 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
+ async def set_user_locked_status(self, user_id: str, locked: bool) -> None:
+ """Set the `locked` property for the provided user to the provided value.
+
+ Args:
+ user_id: The ID of the user to set the status for.
+ locked: The value to set for `locked`.
+ """
+
+ await self.db_pool.runInteraction(
+ "set_user_locked_status",
+ self.set_user_locked_status_txn,
+ user_id,
+ locked,
+ )
+
+ def set_user_locked_status_txn(
+ self, txn: LoggingTransaction, user_id: str, locked: bool
+ ) -> None:
+ self.db_pool.simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"locked": locked},
+ )
+ self._invalidate_cache_and_stream(txn, self.get_user_locked_status, (user_id,))
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
+
def update_user_approval_status_txn(
self, txn: LoggingTransaction, user_id: str, approved: bool
) -> None:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index f34b7ce8f4..6298f0984d 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -19,6 +19,7 @@ from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
+ Counter,
Dict,
Iterable,
List,
@@ -28,8 +29,6 @@ from typing import (
cast,
)
-from typing_extensions import Counter
-
from twisted.internet.defer import DeferredLock
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 2a136f2ff6..f0dc31fee6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -995,7 +995,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
async def search_user_dir(
- self, user_id: str, search_term: str, limit: int
+ self,
+ user_id: str,
+ search_term: str,
+ limit: int,
+ show_locked_users: bool = False,
) -> SearchResult:
"""Searches for users in directory
@@ -1029,6 +1033,9 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
"""
+ if not show_locked_users:
+ where_clause += " AND (u.locked IS NULL OR u.locked = FALSE)"
+
# We allow manipulating the ranking algorithm by injecting statements
# based on config options.
additional_ordering_statements = []
@@ -1060,6 +1067,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SELECT d.user_id AS user_id, display_name, avatar_url
FROM matching_users as t
INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users AS u ON t.user_id = u.name
WHERE
%(where_clause)s
ORDER BY
@@ -1115,6 +1123,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
+ LEFT JOIN users AS u ON t.user_id = u.name
WHERE
%(where_clause)s
AND value MATCH ?
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 0363cdc038..0b5b3bf03e 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -145,5 +145,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
This is not provided by DBAPI2, and so needs engine-specific support.
"""
- with open(filepath, "rt") as f:
+ with open(filepath) as f:
cls.executescript(cursor, f.read())
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 71584f3f74..e74b2269d2 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
class FetchKeyResult:
verify_key: VerifyKey # the key itself
valid_until_ts: int # how long we can use this key for
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FetchKeyResultForRemote:
+ key_json: bytes # the full key JSON
+ valid_until_ts: int # how long we can use this key for, in milliseconds.
+ added_ts: int # When we added this key, in milliseconds.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 38b7abd801..31501fd573 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -16,10 +16,18 @@ import logging
import os
import re
from collections import Counter
-from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple
+from typing import (
+ Collection,
+ Counter as CounterType,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ TextIO,
+ Tuple,
+)
import attr
-from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction
diff --git a/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql
new file mode 100644
index 0000000000..21c7971441
--- /dev/null
+++ b/synapse/storage/schema/main/delta/80/01_users_alter_locked.sql
@@ -0,0 +1,16 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE users ADD locked BOOLEAN DEFAULT FALSE NOT NULL;
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 39a1ae4ac3..073f682aca 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -21,6 +21,7 @@ from typing import (
Any,
ClassVar,
Dict,
+ Final,
List,
Mapping,
Match,
@@ -38,7 +39,7 @@ import attr
from immutabledict import immutabledict
from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
-from typing_extensions import Final, TypedDict
+from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 4041e49e71..943ad54456 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -22,6 +22,7 @@ import logging
from contextlib import asynccontextmanager
from typing import (
Any,
+ AsyncContextManager,
AsyncIterator,
Awaitable,
Callable,
@@ -42,7 +43,7 @@ from typing import (
)
import attr
-from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
+from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.internet import defer
from twisted.internet.defer import CancelledError
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
index 644c341e8c..db6c40a3e1 100644
--- a/synapse/util/macaroons.py
+++ b/synapse/util/macaroons.py
@@ -218,7 +218,7 @@ class MacaroonGenerator:
# to avoid validating those as guest tokens, we explicitely verify if
# the macaroon includes the "guest = true" caveat.
is_guest = any(
- (caveat.caveat_id == "guest = true" for caveat in macaroon.caveats)
+ caveat.caveat_id == "guest = true" for caveat in macaroon.caveats
)
if not is_guest:
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 48b8195ca1..8cb766860e 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -98,7 +98,9 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
SynapseManhole, dict(globals, __name__="__console__")
)
- factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
+ # type-ignore: This is an error in Twisted's annotations. See
+ # https://github.com/twisted/twisted/issues/11812 and /11813 .
+ factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) # type: ignore[arg-type]
# conch has the wrong type on these dicts (says bytes to bytes,
# should be bytes to Keys judging by how it's used).
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 2ad55ac13e..cde4a0780f 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -20,6 +20,7 @@ import typing
from typing import (
Any,
Callable,
+ ContextManager,
DefaultDict,
Dict,
Iterator,
@@ -33,7 +34,6 @@ from typing import (
from weakref import WeakSet
from prometheus_client.core import Counter
-from typing_extensions import ContextManager
from twisted.internet import defer
diff --git a/synapse/visibility.py b/synapse/visibility.py
index fc71dc92a4..eac10f6438 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -17,6 +17,7 @@ from enum import Enum, auto
from typing import (
Collection,
Dict,
+ Final,
FrozenSet,
List,
Mapping,
@@ -27,7 +28,6 @@ from typing import (
)
import attr
-from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
|