diff --git a/synapse/__init__.py b/synapse/__init__.py
index 01d6bf17f0..5da6c924fc 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
except ImportError:
pass
-__version__ = "1.39.0rc2"
+__version__ = "1.39.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py
index 01dc0c4237..9de913db88 100644
--- a/synapse/_scripts/review_recent_signups.py
+++ b/synapse/_scripts/review_recent_signups.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8c7ad2a407..a986fdb47a 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -120,6 +120,7 @@ class EventTypes:
SpaceParent = "m.space.parent"
MSC2716_INSERTION = "org.matrix.msc2716.insertion"
+ MSC2716_CHUNK = "org.matrix.msc2716.chunk"
MSC2716_MARKER = "org.matrix.msc2716.marker"
@@ -198,15 +199,13 @@ class EventContentFields:
# Used on normal messages to indicate they were historically imported after the fact
MSC2716_HISTORICAL = "org.matrix.msc2716.historical"
- # For "insertion" events
+ # For "insertion" events to indicate what the next chunk ID should be in
+ # order to connect to it
MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id"
- # Used on normal message events to indicate where the chunk connects to
+ # Used on "chunk" events to indicate which insertion event it connects to
MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id"
# For "marker" events
MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"
- MSC2716_MARKER_INSERTION_PREV_EVENTS = (
- "org.matrix.msc2716.marker.insertion_prev_events"
- )
class RoomTypes:
@@ -230,3 +229,7 @@ class HistoryVisibility:
JOINED = "joined"
SHARED = "shared"
WORLD_READABLE = "world_readable"
+
+
+class ReadReceiptEventFields:
+ MSC2285_HIDDEN = "org.matrix.msc2285.hidden"
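Reviewer note: the reworded MSC2716 comments above describe how an insertion event and a chunk event pair up via a chunk ID. A minimal sketch with made-up event contents and a made-up chunk ID (none of these values come from this diff):

```python
# Hypothetical MSC2716 events, illustrating how the fields link a "chunk"
# event back to an "insertion" event: the insertion event announces the
# chunk ID (MSC2716_NEXT_CHUNK_ID) that a later chunk event must carry
# (MSC2716_CHUNK_ID) in order to connect to it.
insertion_event = {
    "type": "org.matrix.msc2716.insertion",  # EventTypes.MSC2716_INSERTION
    "content": {
        "org.matrix.msc2716.next_chunk_id": "chunk-abc",  # made-up chunk ID
        "org.matrix.msc2716.historical": True,
    },
}

chunk_event = {
    "type": "org.matrix.msc2716.chunk",  # EventTypes.MSC2716_CHUNK, added above
    "content": {
        "org.matrix.msc2716.chunk_id": "chunk-abc",  # points at the insertion event
    },
}

assert (
    chunk_event["content"]["org.matrix.msc2716.chunk_id"]
    == insertion_event["content"]["org.matrix.msc2716.next_chunk_id"]
)
```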
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 054ab14ab6..dc662bca83 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -75,6 +75,9 @@ class Codes:
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
USER_DEACTIVATED = "M_USER_DEACTIVATED"
BAD_ALIAS = "M_BAD_ALIAS"
+ # For restricted join rules.
+ UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN"
+ UNABLE_TO_GRANT_JOIN = "M_UNABLE_TO_GRANT_JOIN"
class CodeMessageException(RuntimeError):
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index a20abc5a65..bc678efe49 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict
+from typing import Callable, Dict, Optional
import attr
@@ -73,6 +73,9 @@ class RoomVersion:
# MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
# m.room.membership event with membership 'knock'.
msc2403_knocking = attr.ib(type=bool)
+ # MSC2716: Adds m.room.power_levels -> content.historical field to control
+ # whether "insertion", "chunk", "marker" events can be sent
+ msc2716_historical = attr.ib(type=bool)
class RoomVersions:
@@ -88,6 +91,7 @@ class RoomVersions:
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
+ msc2716_historical=False,
)
V2 = RoomVersion(
"2",
@@ -101,6 +105,7 @@ class RoomVersions:
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
+ msc2716_historical=False,
)
V3 = RoomVersion(
"3",
@@ -114,6 +119,7 @@ class RoomVersions:
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
+ msc2716_historical=False,
)
V4 = RoomVersion(
"4",
@@ -127,6 +133,7 @@ class RoomVersions:
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
+ msc2716_historical=False,
)
V5 = RoomVersion(
"5",
@@ -140,6 +147,7 @@ class RoomVersions:
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
+ msc2716_historical=False,
)
V6 = RoomVersion(
"6",
@@ -153,6 +161,7 @@ class RoomVersions:
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=False,
+ msc2716_historical=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -166,9 +175,10 @@ class RoomVersions:
msc2176_redaction_rules=True,
msc3083_join_rules=False,
msc2403_knocking=False,
+ msc2716_historical=False,
)
MSC3083 = RoomVersion(
- "org.matrix.msc3083",
+ "org.matrix.msc3083.v2",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
@@ -179,6 +189,7 @@ class RoomVersions:
msc2176_redaction_rules=False,
msc3083_join_rules=True,
msc2403_knocking=False,
+ msc2716_historical=False,
)
V7 = RoomVersion(
"7",
@@ -192,6 +203,21 @@ class RoomVersions:
msc2176_redaction_rules=False,
msc3083_join_rules=False,
msc2403_knocking=True,
+ msc2716_historical=False,
+ )
+ MSC2716 = RoomVersion(
+ "org.matrix.msc2716",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
+ strict_canonicaljson=True,
+ limit_notifications_power_levels=True,
+ msc2176_redaction_rules=False,
+ msc3083_join_rules=False,
+ msc2403_knocking=True,
+ msc2716_historical=True,
)
@@ -207,6 +233,41 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.MSC2176,
RoomVersions.MSC3083,
RoomVersions.V7,
+ RoomVersions.MSC2716,
+ )
+}
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RoomVersionCapability:
+ """An object which describes the unique attributes of a room version."""
+
+ identifier: str # the identifier for this capability
+ preferred_version: Optional[RoomVersion]
+ support_check_lambda: Callable[[RoomVersion], bool]
+
+
+MSC3244_CAPABILITIES = {
+ cap.identifier: {
+ "preferred": cap.preferred_version.identifier
+ if cap.preferred_version is not None
+ else None,
+ "support": [
+ v.identifier
+ for v in KNOWN_ROOM_VERSIONS.values()
+ if cap.support_check_lambda(v)
+ ],
+ }
+ for cap in (
+ RoomVersionCapability(
+ "knock",
+ RoomVersions.V7,
+ lambda room_version: room_version.msc2403_knocking,
+ ),
+ RoomVersionCapability(
+ "restricted",
+ None,
+ lambda room_version: room_version.msc3083_join_rules,
+ ),
)
- # Note that we do not include MSC2043 here unless it is enabled in the config.
}
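Reviewer note: given the room versions defined above, the `MSC3244_CAPABILITIES` comprehension should evaluate to roughly the following. This is a sketch, not code from the diff; list order follows the insertion order of `KNOWN_ROOM_VERSIONS`.

```python
# Approximate value of MSC3244_CAPABILITIES for the versions in this diff:
# "knock" lists every version with msc2403_knocking=True, and "restricted"
# every version with msc3083_join_rules=True.
expected_capabilities = {
    "knock": {
        "preferred": "7",
        "support": ["7", "org.matrix.msc2716"],
    },
    "restricted": {
        "preferred": None,
        "support": ["org.matrix.msc3083.v2"],
    },
}
```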
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 2878d2c140..3234d9ebba 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 2d50060ffb..de1bcee0a7 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 2d50060ffb..de1bcee0a7 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 57af28f10a..885454ed44 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 2d50060ffb..de1bcee0a7 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 2d50060ffb..de1bcee0a7 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 2d50060ffb..de1bcee0a7 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index c3d4992518..3b7131af8f 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 920b34d97b..7dae163c1a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
#
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 2d50060ffb..de1bcee0a7 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 2d50060ffb..de1bcee0a7 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 2d50060ffb..de1bcee0a7 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index a368efb354..14bde27179 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 3d7d92f615..651e31b576 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -33,6 +33,9 @@ DEFAULT_CONFIG = """\
# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
# 'psycopg2' (for PostgreSQL).
#
+# 'txn_limit' gives the maximum number of transactions to run per connection
+# before reconnecting. Defaults to 0, which means no limit.
+#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
@@ -53,6 +56,7 @@ DEFAULT_CONFIG = """\
#
#database:
# name: psycopg2
+# txn_limit: 10000
# args:
# user: synapse_user
# password: secretpassword
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index bcecbfec03..8d8f166e9b 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -39,12 +39,13 @@ DEFAULT_SUBJECTS = {
"messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...",
"invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...",
"invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...",
+ "invite_from_person_to_space": "[%(app)s] %(person)s has invited you to join the %(space)s space on %(app)s...",
"password_reset": "[%(server_name)s] Password reset",
"email_validation": "[%(server_name)s] Validate your email",
}
-@attr.s
+@attr.s(slots=True, frozen=True)
class EmailSubjectConfig:
message_from_person_in_room = attr.ib(type=str)
message_from_person = attr.ib(type=str)
@@ -54,6 +55,7 @@ class EmailSubjectConfig:
messages_from_person_and_others = attr.ib(type=str)
invite_from_person = attr.ib(type=str)
invite_from_person_to_room = attr.ib(type=str)
+ invite_from_person_to_space = attr.ib(type=str)
password_reset = attr.ib(type=str)
email_validation = attr.ib(type=str)
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index e25ccba9ac..4c60ee8c28 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -32,3 +32,9 @@ class ExperimentalConfig(Config):
# MSC2716 (backfill existing history)
self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
+
+ # MSC2285 (hidden read receipts)
+ self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)
+
+ # MSC3244 (room version capabilities)
+ self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False)
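Reviewer note: both new flags default to false. A hypothetical homeserver.yaml snippet enabling them, assuming the usual `experimental_features` section this config class reads from:

```yaml
experimental_features:
  msc2285_enabled: true   # hidden read receipts
  msc3244_enabled: true   # room version capabilities
```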
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index ad4e6e61c3..dcd3ed1dac 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -71,7 +71,7 @@ handlers:
# will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
# logs will still be flushed immediately.
buffer:
- class: logging.handlers.MemoryHandler
+ class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler
target: file
# The capacity is the number of log lines that are buffered before
# being written to disk. Increasing this will lead to better
@@ -79,6 +79,9 @@ handlers:
# be written to disk.
capacity: 10
flushLevel: 30 # Flush for WARNING logs as well
+ # The period of time, in seconds, between forced flushes.
+ # Messages will not be delayed for longer than this time.
+ period: 5
# A handler that writes logs to stderr. Unused by default, but can be used
# instead of "buffer" and "file" in the logger handlers.
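Reviewer note: the default log config now points at `synapse.logging.handlers.PeriodicallyFlushingMemoryHandler`. Below is a minimal sketch of the idea, not the actual implementation: a `MemoryHandler` that additionally flushes every `period` seconds, so buffered INFO/DEBUG lines are never delayed indefinitely.

```python
import logging
import logging.handlers
import threading


class PeriodicallyFlushingMemoryHandlerSketch(logging.handlers.MemoryHandler):
    """Sketch: flush the buffer on a timer as well as on capacity/flushLevel."""

    def __init__(self, capacity, flushLevel=logging.ERROR, target=None, period=5.0):
        super().__init__(capacity, flushLevel=flushLevel, target=target)
        self._period = period  # seconds, matching the "period" config key above
        self._schedule_flush()

    def _schedule_flush(self):
        timer = threading.Timer(self._period, self._flush_and_reschedule)
        timer.daemon = True  # don't keep the process alive just to flush logs
        timer.start()

    def _flush_and_reschedule(self):
        self.flush()  # forward buffered records to the target handler
        self._schedule_flush()
```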
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 137dff2513..4c92e9a2d4 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -106,6 +106,18 @@ def check(
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
+ is_invite_via_allow_rule = (
+ event.type == EventTypes.Member
+ and event.membership == Membership.JOIN
+ and "join_authorised_via_users_server" in event.content
+ )
+ if is_invite_via_allow_rule:
+ authoriser_domain = get_domain_from_id(
+ event.content["join_authorised_via_users_server"]
+ )
+ if not event.signatures.get(authoriser_domain):
+ raise AuthError(403, "Event not signed by authorising server")
+
# Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
#
# 1. If type is m.room.create:
@@ -177,7 +189,7 @@ def check(
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = get_user_power_level(event.user_id, auth_events)
- invite_level = _get_named_level(auth_events, "invite", 0)
+ invite_level = get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users")
@@ -193,6 +205,13 @@ def check(
if event.type == EventTypes.Redaction:
check_redaction(room_version_obj, event, auth_events)
+ if (
+ event.type == EventTypes.MSC2716_INSERTION
+ or event.type == EventTypes.MSC2716_CHUNK
+ or event.type == EventTypes.MSC2716_MARKER
+ ):
+ check_historical(room_version_obj, event, auth_events)
+
logger.debug("Allowing! %s", event)
@@ -285,8 +304,8 @@ def _is_membership_change_allowed(
user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(target_user_id, auth_events)
- # FIXME (erikj): What should we do here as the default?
- ban_level = _get_named_level(auth_events, "ban", 50)
+ invite_level = get_named_level(auth_events, "invite", 0)
+ ban_level = get_named_level(auth_events, "ban", 50)
logger.debug(
"_is_membership_change_allowed: %s",
@@ -336,8 +355,6 @@ def _is_membership_change_allowed(
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." % target_user_id)
else:
- invite_level = _get_named_level(auth_events, "invite", 0)
-
if user_level < invite_level:
raise AuthError(403, "You don't have permission to invite users")
elif Membership.JOIN == membership:
@@ -345,16 +362,41 @@ def _is_membership_change_allowed(
# * They are not banned.
# * They are accepting a previously sent invitation.
# * They are already joined (it's a NOOP).
- # * The room is public or restricted.
+ # * The room is public.
+ # * The room is restricted and the user meets the allow rules.
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
- elif join_rule == JoinRules.PUBLIC or (
+ elif join_rule == JoinRules.PUBLIC:
+ pass
+ elif (
room_version.msc3083_join_rules
and join_rule == JoinRules.MSC3083_RESTRICTED
):
- pass
+ # This is the same as public, but the event must contain a reference
+ # to the server that authorised the join. If the event does not
+ # contain the proper content, it is rejected.
+ #
+ # Note that if the caller is in the room or invited, then they do
+ # not need to meet the allow rules.
+ if not caller_in_room and not caller_invited:
+ authorising_user = event.content.get("join_authorised_via_users_server")
+
+ if authorising_user is None:
+ raise AuthError(403, "Join event is missing authorising user.")
+
+ # The authorising user must be in the room.
+ key = (EventTypes.Member, authorising_user)
+ member_event = auth_events.get(key)
+ _check_joined_room(member_event, authorising_user, event.room_id)
+
+ authorising_user_level = get_user_power_level(
+ authorising_user, auth_events
+ )
+ if authorising_user_level < invite_level:
+ raise AuthError(403, "Join event authorised by invalid server.")
+
elif join_rule == JoinRules.INVITE or (
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
):
@@ -369,7 +411,7 @@ def _is_membership_change_allowed(
if target_banned and user_level < ban_level:
raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
elif target_user_id != event.user_id:
- kick_level = _get_named_level(auth_events, "kick", 50)
+ kick_level = get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(403, "You cannot kick user %s." % target_user_id)
@@ -445,7 +487,7 @@ def get_send_level(
def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
- power_levels_event = _get_power_level_event(auth_events)
+ power_levels_event = get_power_level_event(auth_events)
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
user_level = get_user_power_level(event.user_id, auth_events)
@@ -485,7 +527,7 @@ def check_redaction(
"""
user_level = get_user_power_level(event.user_id, auth_events)
- redact_level = _get_named_level(auth_events, "redact", 50)
+ redact_level = get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
@@ -504,6 +546,37 @@ def check_redaction(
raise AuthError(403, "You don't have permission to redact events")
+def check_historical(
+ room_version_obj: RoomVersion,
+ event: EventBase,
+ auth_events: StateMap[EventBase],
+) -> None:
+ """Check whether the event sender is allowed to send historical related
+ events like "insertion", "chunk", and "marker".
+
+ Returns:
+ None
+
+ Raises:
+ AuthError if the event sender is not allowed to send historical related events
+ ("insertion", "chunk", and "marker").
+ """
+ # Ignore the auth checks in room versions that do not support historical
+ # events
+ if not room_version_obj.msc2716_historical:
+ return
+
+ user_level = get_user_power_level(event.user_id, auth_events)
+
+ historical_level = get_named_level(auth_events, "historical", 100)
+
+ if user_level < historical_level:
+ raise AuthError(
+ 403,
+ 'You don\'t have permission to send historical related events ("insertion", "chunk", and "marker")',
+ )
+
+
def _check_power_levels(
room_version_obj: RoomVersion,
event: EventBase,
@@ -600,7 +673,7 @@ def _check_power_levels(
)
-def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
+def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
return auth_events.get((EventTypes.PowerLevels, ""))
@@ -616,10 +689,10 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
Returns:
the user's power level in this room.
"""
- power_level_event = _get_power_level_event(auth_events)
+ power_level_event = get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
- if not level:
+ if level is None:
level = power_level_event.content.get("users_default", 0)
if level is None:
@@ -640,8 +713,8 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
return 0
-def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
- power_level_event = _get_power_level_event(auth_events)
+def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
+ power_level_event = get_power_level_event(auth_events)
if not power_level_event:
return default
@@ -728,7 +801,9 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
return public_keys
-def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]:
+def auth_types_for_event(
+ room_version: RoomVersion, event: Union[EventBase, EventBuilder]
+) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
@@ -760,4 +835,12 @@ def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str
)
auth_types.add(key)
+ if room_version.msc3083_join_rules and membership == Membership.JOIN:
+ if "join_authorised_via_users_server" in event.content:
+ key = (
+ EventTypes.Member,
+ event.content["join_authorised_via_users_server"],
+ )
+ auth_types.add(key)
+
return auth_types
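Reviewer note: a standalone sketch of the restricted-join checks added to `_is_membership_change_allowed` above. All names and values here are illustrative; the real checks operate on `EventBase` and `StateMap` objects.

```python
# Mirrors the three checks above: the join event must name an authorising
# user, that user must be joined, and they must have at least invite power.
def get_domain_from_id(user_id: str) -> str:
    # Stand-in for synapse.types.get_domain_from_id.
    return user_id.split(":", 1)[1]


def check_restricted_join(content: dict, joined: set, invite_level: int, power) -> None:
    authorising_user = content.get("join_authorised_via_users_server")
    if authorising_user is None:
        raise PermissionError("Join event is missing authorising user.")
    if authorising_user not in joined:
        raise PermissionError(f"{authorising_user} is not in the room.")
    if power(authorising_user) < invite_level:
        raise PermissionError("Join event authorised by invalid server.")


# Example: @alice:example.org (power 50) can vouch when invite_level is 0;
# get_domain_from_id tells us example.org must also have signed the event.
check_restricted_join(
    {"join_authorised_via_users_server": "@alice:example.org"},
    joined={"@alice:example.org"},
    invite_level=0,
    power=lambda user: 50,
)
```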
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index ec96999e4e..a0c07f62f4 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -109,6 +109,8 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
add_fields("creator")
elif event_type == EventTypes.JoinRules:
add_fields("join_rule")
+ if room_version.msc3083_join_rules:
+ add_fields("allow")
elif event_type == EventTypes.PowerLevels:
add_fields(
"users",
@@ -124,6 +126,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
if room_version.msc2176_redaction_rules:
add_fields("invite")
+ if room_version.msc2716_historical:
+ add_fields("historical")
+
elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
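Reviewer note: the practical effect of the pruning changes, sketched with a made-up `allow` rule (the `m.room_membership` shape comes from MSC3083, not from this diff):

```python
# In a room version with msc3083_join_rules=True, redacting an
# m.room.join_rules event now preserves "allow" alongside "join_rule";
# everything else in content is still stripped.
content_before_redaction = {
    "join_rule": "restricted",
    "allow": [{"type": "m.room_membership", "room_id": "!space:example.org"}],
    "unrelated_key": "stripped by redaction",
}
content_after_redaction = {
    "join_rule": "restricted",
    "allow": [{"type": "m.room_membership", "room_id": "!space:example.org"}],
}
```

Similarly, in `msc2716_historical` room versions the `historical` key now survives redaction of m.room.power_levels.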
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2bfe6a3d37..024e440ff4 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -178,6 +178,34 @@ async def _check_sigs_on_pdu(
)
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+ # If this is a join event for a restricted room it may have been authorised
+ # via a different server from the sending server. Check those signatures.
+ if (
+ room_version.msc3083_join_rules
+ and pdu.type == EventTypes.Member
+ and pdu.membership == Membership.JOIN
+ and "join_authorised_via_users_server" in pdu.content
+ ):
+ authorising_server = get_domain_from_id(
+ pdu.content["join_authorised_via_users_server"]
+ )
+ try:
+ await keyring.verify_event_for_server(
+ authorising_server,
+ pdu,
+ pdu.origin_server_ts if room_version.enforce_key_validity else 0,
+ )
+ except Exception as e:
+ errmsg = (
+ "event id %s: unable to verify signature for authorising server %s: %s"
+ % (
+ pdu.event_id,
+ authorising_server,
+ e,
+ )
+ )
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
+
def _is_invite_via_3pid(event: EventBase) -> bool:
return (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c767d30627..b7a10da15a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,10 +19,10 @@ import itertools
import logging
from typing import (
TYPE_CHECKING,
- Any,
Awaitable,
Callable,
Collection,
+ Container,
Dict,
Iterable,
List,
@@ -79,7 +79,15 @@ class InvalidResponseError(RuntimeError):
we couldn't parse
"""
- pass
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SendJoinResult:
+ # The event to persist.
+ event: EventBase
+ # A string giving the server the event was sent to.
+ origin: str
+ state: List[EventBase]
+ auth_chain: List[EventBase]
class FederationClient(FederationBase):
@@ -506,6 +514,7 @@ class FederationClient(FederationBase):
description: str,
destinations: Iterable[str],
callback: Callable[[str], Awaitable[T]],
+ failover_errcodes: Optional[Container[str]] = None,
failover_on_unknown_endpoint: bool = False,
) -> T:
"""Try an operation on a series of servers, until it succeeds
@@ -526,6 +535,9 @@ class FederationClient(FederationBase):
next server tried. Normally the stacktrace is logged but this is
suppressed if the exception is an InvalidResponseError.
+ failover_errcodes: Error codes (specific to this endpoint) which should
+ cause a failover when received as part of an HTTP 400 error.
+
failover_on_unknown_endpoint: if True, we will try other servers if it looks
like a server doesn't support the endpoint. This is typically useful
if the endpoint in question is new or experimental.
@@ -537,6 +549,9 @@ class FederationClient(FederationBase):
SynapseError if the chosen remote server returns a 300/400 code, or
no servers were reachable.
"""
+ if failover_errcodes is None:
+ failover_errcodes = ()
+
for destination in destinations:
if destination == self.server_name:
continue
@@ -551,11 +566,17 @@ class FederationClient(FederationBase):
synapse_error = e.to_synapse_error()
failover = False
- # Failover on an internal server error, or if the destination
- # doesn't implemented the endpoint for some reason.
+ # Failover should occur:
+ #
+ # * On internal server errors.
+ # * If the destination responds that it cannot complete the request.
+ # * If the destination doesn't implement the endpoint for some reason.
if 500 <= e.code < 600:
failover = True
+ elif e.code == 400 and synapse_error.errcode in failover_errcodes:
+ failover = True
+
elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
e, synapse_error
):
@@ -671,13 +692,25 @@ class FederationClient(FederationBase):
return destination, ev, room_version
+ # MSC3083 defines additional error codes for room joins. Unfortunately we
+ # do not yet know the room version, so we assume these will only be
+ # returned by valid room versions.
+ failover_errcodes = (
+ (Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN)
+ if membership == Membership.JOIN
+ else None
+ )
+
return await self._try_destination_list(
- "make_" + membership, destinations, send_request
+ "make_" + membership,
+ destinations,
+ send_request,
+ failover_errcodes=failover_errcodes,
)
async def send_join(
self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
- ) -> Dict[str, Any]:
+ ) -> SendJoinResult:
"""Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph,
@@ -691,18 +724,38 @@ class FederationClient(FederationBase):
did the make_join)
Returns:
- a dict with members ``origin`` (a string
- giving the server the event was sent to, ``state`` (?) and
- ``auth_chain``.
+ The result of the send join request.
Raises:
SynapseError: if the chosen remote server returns a 300/400 code, or
no servers successfully handle the request.
"""
- async def send_request(destination) -> Dict[str, Any]:
+ async def send_request(destination) -> SendJoinResult:
response = await self._do_send_join(room_version, destination, pdu)
+ # If an event was returned (and expected to be returned):
+ #
+ # * Ensure it has the same event ID (note that the event ID is a hash
+ # of the event fields for versions which support MSC3083).
+ # * Ensure the signatures are good.
+ #
+ # Otherwise, fall back to the provided event.
+ if room_version.msc3083_join_rules and response.event:
+ event = response.event
+
+ valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+ pdu=event,
+ origin=destination,
+ outlier=True,
+ room_version=room_version,
+ )
+
+ if valid_pdu is None or event.event_id != pdu.event_id:
+ raise InvalidResponseError("Returned an invalid join event")
+ else:
+ event = pdu
+
state = response.state
auth_chain = response.auth_events
@@ -784,13 +837,32 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,)
)
- return {
- "state": signed_state,
- "auth_chain": signed_auth,
- "origin": destination,
- }
+ return SendJoinResult(
+ event=event,
+ state=signed_state,
+ auth_chain=signed_auth,
+ origin=destination,
+ )
- return await self._try_destination_list("send_join", destinations, send_request)
+ # MSC3083 defines additional error codes for room joins.
+ failover_errcodes = None
+ if room_version.msc3083_join_rules:
+ failover_errcodes = (
+ Codes.UNABLE_AUTHORISE_JOIN,
+ Codes.UNABLE_TO_GRANT_JOIN,
+ )
+
+ # If the join is being authorised via allow rules, we need to send
+ # the /send_join back to the same server that was originally used
+ # with /make_join.
+ if "join_authorised_via_users_server" in pdu.content:
+ destinations = [
+ get_domain_from_id(pdu.content["join_authorised_via_users_server"])
+ ]
+
+ return await self._try_destination_list(
+ "send_join", destinations, send_request, failover_errcodes=failover_errcodes
+ )
async def _do_send_join(
self, room_version: RoomVersion, destination: str, pdu: EventBase
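Reviewer note: the failover decision in `_try_destination_list` now has three triggers. A simplified, self-contained sketch (the real code inspects a `SynapseError` and calls `_is_unknown_endpoint`):

```python
def should_failover(
    code: int,
    errcode: str,
    failover_errcodes: tuple = (),
    unknown_endpoint: bool = False,
) -> bool:
    if 500 <= code < 600:
        return True  # internal server error at the destination
    if code == 400 and errcode in failover_errcodes:
        return True  # e.g. the MSC3083 join codes passed in above
    return unknown_endpoint  # destination doesn't implement the endpoint


assert should_failover(502, "M_UNKNOWN")
assert should_failover(
    400,
    "M_UNABLE_TO_GRANT_JOIN",
    failover_errcodes=("M_UNABLE_TO_AUTHORISE_JOIN", "M_UNABLE_TO_GRANT_JOIN"),
)
assert not should_failover(403, "M_FORBIDDEN")
```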
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 29619aeeb8..145b9161d9 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.crypto.event_signing import compute_event_signature
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
-from synapse.types import JsonDict
+from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -586,7 +587,7 @@ class FederationServer(FederationBase):
async def on_send_join_request(
self, origin: str, content: JsonDict, room_id: str
) -> Dict[str, Any]:
- context = await self._on_send_membership_event(
+ event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)
@@ -597,6 +598,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
return {
+ "org.matrix.msc3083.v2.event": event.get_pdu_json(),
"state": [p.get_pdu_json(time_now) for p in state.values()],
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
}
@@ -681,7 +683,7 @@ class FederationServer(FederationBase):
Returns:
The stripped room state.
"""
- event_context = await self._on_send_membership_event(
+ _, context = await self._on_send_membership_event(
origin, content, Membership.KNOCK, room_id
)
@@ -690,14 +692,14 @@ class FederationServer(FederationBase):
# related to the room while the knock request is pending.
stripped_room_state = (
await self.store.get_stripped_room_state_from_event_context(
- event_context, self._room_prejoin_state_types
+ context, self._room_prejoin_state_types
)
)
return {"knock_state_events": stripped_room_state}
async def _on_send_membership_event(
self, origin: str, content: JsonDict, membership_type: str, room_id: str
- ) -> EventContext:
+ ) -> Tuple[EventBase, EventContext]:
"""Handle an on_send_{join,leave,knock} request
Does some preliminary validation before passing the request on to the
@@ -712,7 +714,7 @@ class FederationServer(FederationBase):
in the event
Returns:
- The context of the event after inserting it into the room graph.
+ The event and context of the event after inserting it into the room graph.
Raises:
SynapseError if there is a problem with the request, including things like
@@ -748,6 +750,33 @@ class FederationServer(FederationBase):
logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
+ # Sign the event since we're vouching on behalf of the remote server that
+ # the event is valid to be sent into the room. Currently this is only done
+ # if the user is being joined via restricted join rules.
+ if (
+ room_version.msc3083_join_rules
+ and event.membership == Membership.JOIN
+ and "join_authorised_via_users_server" in event.content
+ ):
+ # We can only authorise our own users.
+ authorising_server = get_domain_from_id(
+ event.content["join_authorised_via_users_server"]
+ )
+ if authorising_server != self.server_name:
+ raise SynapseError(
+ 400,
+ f"Cannot authorise request from resident server: {authorising_server}",
+ )
+
+ event.signatures.update(
+ compute_event_signature(
+ room_version,
+ event.get_pdu_json(),
+ self.hs.hostname,
+ self.hs.signing_key,
+ )
+ )
+
event = await self._check_sigs_and_hash(room_version, event)
return await self.handler.on_send_membership_event(origin, event)
@@ -995,6 +1024,23 @@ class FederationServer(FederationBase):
origin, event = next
+ # Prune the event queue if it's getting large.
+ #
+ # We do this *after* handling the first event, as the common case is
+ # that the queue is empty (or contains only that single event), and so
+ # there's no need to do this check.
+ pruned = await self.store.prune_staged_events_in_room(room_id, room_version)
+ if pruned:
+ # If we have pruned the queue, we need to refetch the next event
+ # to handle.
+ next = await self.store.get_next_staged_event_for_room(
+ room_id, room_version
+ )
+ if not next:
+ break
+
+ origin, event = next
+
lock = await self.store.try_acquire_lock(
_INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
)
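Reviewer note: the resident-server guard added to `_on_send_membership_event` boils down to a domain comparison. A tiny sketch with made-up names:

```python
# A server only counter-signs joins that were authorised via one of its own
# users: the domain of join_authorised_via_users_server must be this server.
def assert_can_authorise(server_name: str, authorising_user: str) -> None:
    authorising_server = authorising_user.split(":", 1)[1]
    if authorising_server != server_name:
        raise ValueError(
            f"Cannot authorise request from resident server: {authorising_server}"
        )


assert_can_authorise("resident.example.org", "@alice:resident.example.org")  # ok
```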
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 98b1bf77fd..6a8d3ad4fe 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,7 +15,7 @@
import logging
import urllib
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import attr
import ijson
@@ -29,6 +29,7 @@ from synapse.api.urls import (
FEDERATION_V2_PREFIX,
)
from synapse.events import EventBase, make_event_from_dict
+from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
from synapse.logging.utils import log_function
from synapse.types import JsonDict
@@ -49,23 +50,25 @@ class TransportLayerClient:
self.client = hs.get_federation_http_client()
@log_function
- def get_room_state_ids(self, destination, room_id, event_id):
+ async def get_room_state_ids(
+ self, destination: str, room_id: str, event_id: str
+ ) -> JsonDict:
"""Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
Args:
- destination (str): The host name of the remote homeserver we want
+ destination: The host name of the remote homeserver we want
to get the state from.
- context (str): The name of the context we want the state of
- event_id (str): The event we want the context at.
+ context: The name of the context we want the state of
+ event_id: The event we want the context at.
Returns:
- Awaitable: Results in a dict received from the remote homeserver.
+ The dict received from the remote homeserver.
"""
logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
path = _create_v1_path("/state_ids/%s", room_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination,
path=path,
args={"event_id": event_id},
@@ -73,39 +76,43 @@ class TransportLayerClient:
)
@log_function
- def get_event(self, destination, event_id, timeout=None):
+ async def get_event(
+ self, destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> JsonDict:
"""Requests the pdu with give id and origin from the given server.
Args:
- destination (str): The host name of the remote homeserver we want
+ destination: The host name of the remote homeserver we want
to get the state from.
- event_id (str): The id of the event being requested.
- timeout (int): How long to try (in ms) the destination for before
+ event_id: The id of the event being requested.
+ timeout: How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns:
- Awaitable: Results in a dict received from the remote homeserver.
+ The dict received from the remote homeserver.
"""
logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
path = _create_v1_path("/event/%s", event_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
)
@log_function
- def backfill(self, destination, room_id, event_tuples, limit):
+ async def backfill(
+ self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int
+ ) -> Optional[JsonDict]:
"""Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
- dest (str)
- room_id (str)
- event_tuples (list)
- limit (int)
+ destination
+ room_id
+ event_tuples
+ limit
Returns:
- Awaitable: Results in a dict received from the remote homeserver.
+ The dict received from the remote homeserver.
"""
logger.debug(
"backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
@@ -117,18 +124,22 @@ class TransportLayerClient:
if not event_tuples:
# TODO: raise?
- return
+ return None
path = _create_v1_path("/backfill/%s", room_id)
args = {"v": event_tuples, "limit": [str(limit)]}
- return self.client.get_json(
+ return await self.client.get_json(
destination, path=path, args=args, try_trailing_slash_on_400=True
)
@log_function
- async def send_transaction(self, transaction, json_data_callback=None):
+ async def send_transaction(
+ self,
+ transaction: Transaction,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ ) -> JsonDict:
"""Sends the given Transaction to its destination
Args:
@@ -149,21 +160,21 @@ class TransportLayerClient:
"""
logger.debug(
"send_data dest=%s, txid=%s",
- transaction.destination,
- transaction.transaction_id,
+ transaction.destination, # type: ignore
+ transaction.transaction_id, # type: ignore
)
- if transaction.destination == self.server_name:
+ if transaction.destination == self.server_name: # type: ignore
raise RuntimeError("Transport layer cannot send to itself!")
# FIXME: This is only used by the tests. The actual json sent is
# generated by the json_data_callback.
json_data = transaction.get_dict()
- path = _create_v1_path("/send/%s", transaction.transaction_id)
+ path = _create_v1_path("/send/%s", transaction.transaction_id) # type: ignore
- response = await self.client.put_json(
- transaction.destination,
+ return await self.client.put_json(
+ transaction.destination, # type: ignore
path=path,
data=json_data,
json_data_callback=json_data_callback,
@@ -172,8 +183,6 @@ class TransportLayerClient:
try_trailing_slash_on_400=True,
)
- return response
-
@log_function
async def make_query(
self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
@@ -193,8 +202,13 @@ class TransportLayerClient:
@log_function
async def make_membership_event(
- self, destination, room_id, user_id, membership, params
- ):
+ self,
+ destination: str,
+ room_id: str,
+ user_id: str,
+ membership: str,
+ params: Optional[Mapping[str, Union[str, Iterable[str]]]],
+ ) -> JsonDict:
"""Asks a remote server to build and sign us a membership event
Note that this does not append any events to any graphs.
@@ -240,7 +254,7 @@ class TransportLayerClient:
ignore_backoff = True
retry_on_dns_fail = True
- content = await self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args=params,
@@ -249,20 +263,18 @@ class TransportLayerClient:
ignore_backoff=ignore_backoff,
)
- return content
-
@log_function
async def send_join_v1(
self,
- room_version,
- destination,
- room_id,
- event_id,
- content,
+ room_version: RoomVersion,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ content: JsonDict,
) -> "SendJoinResponse":
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
data=content,
@@ -270,15 +282,18 @@ class TransportLayerClient:
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
- return response
-
@log_function
async def send_join_v2(
- self, room_version, destination, room_id, event_id, content
+ self,
+ room_version: RoomVersion,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ content: JsonDict,
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
data=content,
@@ -286,13 +301,13 @@ class TransportLayerClient:
max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
)
- return response
-
@log_function
- async def send_leave_v1(self, destination, room_id, event_id, content):
+ async def send_leave_v1(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict
+ ) -> Tuple[int, JsonDict]:
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
data=content,
@@ -303,13 +318,13 @@ class TransportLayerClient:
ignore_backoff=True,
)
- return response
-
@log_function
- async def send_leave_v2(self, destination, room_id, event_id, content):
+ async def send_leave_v2(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict
+ ) -> JsonDict:
path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
data=content,
@@ -320,8 +335,6 @@ class TransportLayerClient:
ignore_backoff=True,
)
- return response
-
@log_function
async def send_knock_v1(
self,
@@ -357,25 +370,25 @@ class TransportLayerClient:
)
@log_function
- async def send_invite_v1(self, destination, room_id, event_id, content):
+ async def send_invite_v1(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict
+ ) -> Tuple[int, JsonDict]:
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
- return response
-
@log_function
- async def send_invite_v2(self, destination, room_id, event_id, content):
+ async def send_invite_v2(
+ self, destination: str, room_id: str, event_id: str, content: JsonDict
+ ) -> JsonDict:
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
- return response
-
@log_function
async def get_public_rooms(
self,
@@ -385,7 +398,7 @@ class TransportLayerClient:
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
- ):
+ ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver
See synapse.federation.federation_client.FederationClient.get_public_rooms for
@@ -450,25 +463,27 @@ class TransportLayerClient:
return response
@log_function
- async def exchange_third_party_invite(self, destination, room_id, event_dict):
+ async def exchange_third_party_invite(
+ self, destination: str, room_id: str, event_dict: JsonDict
+ ) -> JsonDict:
path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
- response = await self.client.put_json(
+ return await self.client.put_json(
destination=destination, path=path, data=event_dict
)
- return response
-
@log_function
- async def get_event_auth(self, destination, room_id, event_id):
+ async def get_event_auth(
+ self, destination: str, room_id: str, event_id: str
+ ) -> JsonDict:
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
- content = await self.client.get_json(destination=destination, path=path)
-
- return content
+ return await self.client.get_json(destination=destination, path=path)
@log_function
- async def query_client_keys(self, destination, query_content, timeout):
+ async def query_client_keys(
+ self, destination: str, query_content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote
server.
@@ -496,20 +511,21 @@ class TransportLayerClient:
}
Args:
- destination(str): The server to query.
- query_content(dict): The user ids to query.
+ destination: The server to query.
+ query_content: The user ids to query.
Returns:
A dict containing device and cross-signing keys.
"""
path = _create_v1_path("/user/keys/query")
- content = await self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
- return content
@log_function
- async def query_user_devices(self, destination, user_id, timeout):
+ async def query_user_devices(
+ self, destination: str, user_id: str, timeout: int
+ ) -> JsonDict:
"""Query the devices for a user id hosted on a remote server.
Response:
@@ -535,20 +551,21 @@ class TransportLayerClient:
}
Args:
- destination(str): The server to query.
- query_content(dict): The user ids to query.
+ destination: The server to query.
+ user_id: The user id whose devices to query.
Returns:
A dict containing device and cross-signing keys.
"""
path = _create_v1_path("/user/devices/%s", user_id)
- content = await self.client.get_json(
+ return await self.client.get_json(
destination=destination, path=path, timeout=timeout
)
- return content
@log_function
- async def claim_client_keys(self, destination, query_content, timeout):
+ async def claim_client_keys(
+ self, destination: str, query_content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
@@ -572,33 +589,32 @@ class TransportLayerClient:
}
Args:
- destination(str): The server to query.
- query_content(dict): The user ids to query.
+ destination: The server to query.
+ query_content: The user ids to query.
Returns:
A dict containing the one-time keys.
"""
path = _create_v1_path("/user/keys/claim")
- content = await self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
- return content
@log_function
async def get_missing_events(
self,
- destination,
- room_id,
- earliest_events,
- latest_events,
- limit,
- min_depth,
- timeout,
- ):
+ destination: str,
+ room_id: str,
+ earliest_events: Iterable[str],
+ latest_events: Iterable[str],
+ limit: int,
+ min_depth: int,
+ timeout: int,
+ ) -> JsonDict:
path = _create_v1_path("/get_missing_events/%s", room_id)
- content = await self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
data={
@@ -610,14 +626,14 @@ class TransportLayerClient:
timeout=timeout,
)
- return content
-
@log_function
- def get_group_profile(self, destination, group_id, requester_user_id):
+ async def get_group_profile(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get a group profile"""
path = _create_v1_path("/groups/%s/profile", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -625,14 +641,16 @@ class TransportLayerClient:
)
@log_function
- def update_group_profile(self, destination, group_id, requester_user_id, content):
+ async def update_group_profile(
+ self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Update a remote group profile
Args:
- destination (str)
- group_id (str)
- requester_user_id (str)
- content (dict): The new profile of the group
+ destination
+ group_id
+ requester_user_id
+ content: The new profile of the group
"""
path = _create_v1_path("/groups/%s/profile", group_id)
@@ -645,11 +663,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_summary(self, destination, group_id, requester_user_id):
+ async def get_group_summary(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get a group summary"""
path = _create_v1_path("/groups/%s/summary", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -657,24 +677,31 @@ class TransportLayerClient:
)
@log_function
- def get_rooms_in_group(self, destination, group_id, requester_user_id):
+ async def get_rooms_in_group(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get all rooms in a group"""
path = _create_v1_path("/groups/%s/rooms", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
- def add_room_to_group(
- self, destination, group_id, requester_user_id, room_id, content
- ):
+ async def add_room_to_group(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ room_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Add a room to a group"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -682,15 +709,21 @@ class TransportLayerClient:
ignore_backoff=True,
)
- def update_room_in_group(
- self, destination, group_id, requester_user_id, room_id, config_key, content
- ):
+ async def update_room_in_group(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ room_id: str,
+ config_key: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update room in group"""
path = _create_v1_path(
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key
)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -698,11 +731,13 @@ class TransportLayerClient:
ignore_backoff=True,
)
- def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
+ async def remove_room_from_group(
+ self, destination: str, group_id: str, requester_user_id: str, room_id: str
+ ) -> JsonDict:
"""Remove a room from a group"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -710,11 +745,13 @@ class TransportLayerClient:
)
@log_function
- def get_users_in_group(self, destination, group_id, requester_user_id):
+ async def get_users_in_group(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get users in a group"""
path = _create_v1_path("/groups/%s/users", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -722,11 +759,13 @@ class TransportLayerClient:
)
@log_function
- def get_invited_users_in_group(self, destination, group_id, requester_user_id):
+ async def get_invited_users_in_group(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get users that have been invited to a group"""
path = _create_v1_path("/groups/%s/invited_users", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -734,16 +773,20 @@ class TransportLayerClient:
)
@log_function
- def accept_group_invite(self, destination, group_id, user_id, content):
+ async def accept_group_invite(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Accept a group invite"""
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def join_group(self, destination, group_id, user_id, content):
+ def join_group(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Attempts to join a group"""
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
@@ -752,13 +795,18 @@ class TransportLayerClient:
)
@log_function
- def invite_to_group(
- self, destination, group_id, user_id, requester_user_id, content
- ):
+ async def invite_to_group(
+ self,
+ destination: str,
+ group_id: str,
+ user_id: str,
+ requester_user_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Invite a user to a group"""
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -767,25 +815,32 @@ class TransportLayerClient:
)
@log_function
- def invite_to_group_notification(self, destination, group_id, user_id, content):
+ async def invite_to_group_notification(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sent by group server to inform a user's server that they have been
invited.
"""
path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def remove_user_from_group(
- self, destination, group_id, requester_user_id, user_id, content
- ):
+ async def remove_user_from_group(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ user_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Remove a user from a group"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -794,35 +849,43 @@ class TransportLayerClient:
)
@log_function
- def remove_user_from_group_notification(
- self, destination, group_id, user_id, content
- ):
+ async def remove_user_from_group_notification(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sent by group server to inform a user's server that they have been
kicked from the group.
"""
path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def renew_group_attestation(self, destination, group_id, user_id, content):
+ async def renew_group_attestation(
+ self, destination: str, group_id: str, user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sent by either a group server or a user's server to periodically update
the attestations
"""
path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def update_group_summary_room(
- self, destination, group_id, user_id, room_id, category_id, content
- ):
+ async def update_group_summary_room(
+ self,
+ destination: str,
+ group_id: str,
+ user_id: str,
+ room_id: str,
+ category_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update a room entry in a group summary"""
if category_id:
path = _create_v1_path(
@@ -834,7 +897,7 @@ class TransportLayerClient:
else:
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": user_id},
@@ -843,9 +906,14 @@ class TransportLayerClient:
)
@log_function
- def delete_group_summary_room(
- self, destination, group_id, user_id, room_id, category_id
- ):
+ async def delete_group_summary_room(
+ self,
+ destination: str,
+ group_id: str,
+ user_id: str,
+ room_id: str,
+ category_id: str,
+ ) -> JsonDict:
"""Delete a room entry in a group summary"""
if category_id:
path = _create_v1_path(
@@ -857,7 +925,7 @@ class TransportLayerClient:
else:
path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": user_id},
@@ -865,11 +933,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_categories(self, destination, group_id, requester_user_id):
+ async def get_group_categories(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get all categories in a group"""
path = _create_v1_path("/groups/%s/categories", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -877,11 +947,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_category(self, destination, group_id, requester_user_id, category_id):
+ async def get_group_category(
+ self, destination: str, group_id: str, requester_user_id: str, category_id: str
+ ) -> JsonDict:
"""Get category info in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -889,13 +961,18 @@ class TransportLayerClient:
)
@log_function
- def update_group_category(
- self, destination, group_id, requester_user_id, category_id, content
- ):
+ async def update_group_category(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ category_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update a category in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -904,13 +981,13 @@ class TransportLayerClient:
)
@log_function
- def delete_group_category(
- self, destination, group_id, requester_user_id, category_id
- ):
+ async def delete_group_category(
+ self, destination: str, group_id: str, requester_user_id: str, category_id: str
+ ) -> JsonDict:
"""Delete a category in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -918,11 +995,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_roles(self, destination, group_id, requester_user_id):
+ async def get_group_roles(
+ self, destination: str, group_id: str, requester_user_id: str
+ ) -> JsonDict:
"""Get all roles in a group"""
path = _create_v1_path("/groups/%s/roles", group_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -930,11 +1009,13 @@ class TransportLayerClient:
)
@log_function
- def get_group_role(self, destination, group_id, requester_user_id, role_id):
+ async def get_group_role(
+ self, destination: str, group_id: str, requester_user_id: str, role_id: str
+ ) -> JsonDict:
"""Get a roles info"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
- return self.client.get_json(
+ return await self.client.get_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -942,13 +1023,18 @@ class TransportLayerClient:
)
@log_function
- def update_group_role(
- self, destination, group_id, requester_user_id, role_id, content
- ):
+ async def update_group_role(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ role_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update a role in a group"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -957,11 +1043,13 @@ class TransportLayerClient:
)
@log_function
- def delete_group_role(self, destination, group_id, requester_user_id, role_id):
+ async def delete_group_role(
+ self, destination: str, group_id: str, requester_user_id: str, role_id: str
+ ) -> JsonDict:
"""Delete a role in a group"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -969,9 +1057,15 @@ class TransportLayerClient:
)
@log_function
- def update_group_summary_user(
- self, destination, group_id, requester_user_id, user_id, role_id, content
- ):
+ async def update_group_summary_user(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ user_id: str,
+ role_id: str,
+ content: JsonDict,
+ ) -> JsonDict:
"""Update a users entry in a group"""
if role_id:
path = _create_v1_path(
@@ -980,7 +1074,7 @@ class TransportLayerClient:
else:
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -989,11 +1083,13 @@ class TransportLayerClient:
)
@log_function
- def set_group_join_policy(self, destination, group_id, requester_user_id, content):
+ async def set_group_join_policy(
+ self, destination: str, group_id: str, requester_user_id: str, content: JsonDict
+ ) -> JsonDict:
"""Sets the join policy for a group"""
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
- return self.client.put_json(
+ return await self.client.put_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
@@ -1002,9 +1098,14 @@ class TransportLayerClient:
)
@log_function
- def delete_group_summary_user(
- self, destination, group_id, requester_user_id, user_id, role_id
- ):
+ async def delete_group_summary_user(
+ self,
+ destination: str,
+ group_id: str,
+ requester_user_id: str,
+ user_id: str,
+ role_id: str,
+ ) -> JsonDict:
"""Delete a users entry in a group"""
if role_id:
path = _create_v1_path(
@@ -1013,33 +1114,35 @@ class TransportLayerClient:
else:
path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
- return self.client.delete_json(
+ return await self.client.delete_json(
destination=destination,
path=path,
args={"requester_user_id": requester_user_id},
ignore_backoff=True,
)
- def bulk_get_publicised_groups(self, destination, user_ids):
+ async def bulk_get_publicised_groups(
+ self, destination: str, user_ids: Iterable[str]
+ ) -> JsonDict:
"""Get the groups a list of users are publicising"""
path = _create_v1_path("/get_groups_publicised")
content = {"user_ids": user_ids}
- return self.client.post_json(
+ return await self.client.post_json(
destination=destination, path=path, data=content, ignore_backoff=True
)
- def get_room_complexity(self, destination, room_id):
+ async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict:
"""
Args:
- destination (str): The remote server
- room_id (str): The room ID to ask about.
+ destination: The remote server
+ room_id: The room ID to ask about.
"""
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id)
- return self.client.get_json(destination=destination, path=path)
+ return await self.client.get_json(destination=destination, path=path)
async def get_space_summary(
self,
@@ -1075,14 +1178,14 @@ class TransportLayerClient:
)
-def _create_path(federation_prefix, path, *args):
+def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""
Ensures that all args are url encoded.
"""
return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
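
As a self-contained sketch of what that quoting buys (the prefix value is assumed here for illustration), reserved characters in an event ID are escaped so they cannot break the path:

```python
import urllib.parse

FEDERATION_V1_PREFIX = "/_matrix/federation/v1"  # assumed value for this sketch

def _create_path(federation_prefix: str, path: str, *args: str) -> str:
    # Percent-encode every arg, including "/" (safe=""), before substitution.
    return federation_prefix + path % tuple(
        urllib.parse.quote(arg, "") for arg in args
    )

# "$", "/" and ":" are all escaped, so the event ID stays one path segment.
assert (
    _create_path(FEDERATION_V1_PREFIX, "/event/%s", "$abc/def:example.com")
    == "/_matrix/federation/v1/event/%24abc%2Fdef%3Aexample.com"
)
```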
-def _create_v1_path(path, *args):
+def _create_v1_path(path: str, *args: str) -> str:
"""Creates a path against V1 federation API from the path template and
args. Ensures that all args are url encoded.
@@ -1091,16 +1194,13 @@ def _create_v1_path(path, *args):
_create_v1_path("/event/%s", event_id)
Args:
- path (str): String template for the path
- args: ([str]): Args to insert into path. Each arg will be url encoded
-
- Returns:
- str
+ path: String template for the path
+ args: Args to insert into path. Each arg will be url encoded
"""
return _create_path(FEDERATION_V1_PREFIX, path, *args)
-def _create_v2_path(path, *args):
+def _create_v2_path(path: str, *args: str) -> str:
"""Creates a path against V2 federation API from the path template and
args. Ensures that all args are url encoded.
@@ -1109,11 +1209,8 @@ def _create_v2_path(path, *args):
_create_v2_path("/event/%s", event_id)
Args:
- path (str): String template for the path
- args: ([str]): Args to insert into path. Each arg will be url encoded
-
- Returns:
- str
+ path: String template for the path
+ args: Args to insert into path. Each arg will be url encoded
"""
return _create_path(FEDERATION_V2_PREFIX, path, *args)
@@ -1122,8 +1219,26 @@ def _create_v2_path(path, *args):
class SendJoinResponse:
"""The parsed response of a `/send_join` request."""
+ # The list of auth events from the /send_join response.
auth_events: List[EventBase]
+ # The list of state from the /send_join response.
state: List[EventBase]
+ # The raw join event from the /send_join response.
+ event_dict: JsonDict
+ # The parsed join event from the /send_join response. This will be None if
+ # "event" is not included in the response.
+ event: Optional[EventBase] = None
+
+
+@ijson.coroutine
+def _event_parser(event_dict: JsonDict):
+ """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
+ and add them to a given dictionary.
+ """
+
+ while True:
+ key, value = yield
+ event_dict[key] = value
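
A standalone sketch (assuming the `ijson` package used above) of how such a coroutine fills a dict as bytes arrive, which is the mechanism `SendJoinParser.write` relies on; the JSON payload and prefix are illustrative:

```python
import ijson

@ijson.coroutine
def _collect(target: dict):
    # Receives (key, value) pairs pushed in by ijson.kvitems_coro.
    while True:
        key, value = yield
        target[key] = value

event_dict: dict = {}
coro = ijson.kvitems_coro(_collect(event_dict), "event")

# The body can be fed in arbitrary chunks; no manual buffering is needed.
for chunk in (b'{"event": {"type": "m.room.mem', b'ber", "depth": 12}}'):
    coro.send(chunk)

assert event_dict == {"type": "m.room.member", "depth": 12}
```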
@ijson.coroutine
@@ -1149,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
CONTENT_TYPE = "application/json"
def __init__(self, room_version: RoomVersion, v1_api: bool):
- self._response = SendJoinResponse([], [])
+ self._response = SendJoinResponse([], [], {})
+ self._room_version = room_version
# The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`.
@@ -1163,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
_event_list_parser(room_version, self._response.auth_events),
prefix + "auth_chain.item",
)
+ self._coro_event = ijson.kvitems_coro(
+ _event_parser(self._response.event_dict),
+ prefix + "org.matrix.msc3083.v2.event",
+ )
def write(self, data: bytes) -> int:
self._coro_state.send(data)
self._coro_auth.send(data)
+ self._coro_event.send(data)
return len(data)
def finish(self) -> SendJoinResponse:
+ if self._response.event_dict:
+ self._response.event = make_event_from_dict(
+ self._response.event_dict, self._room_version
+ )
return self._response
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2974d4d0cc..5e059d6e09 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -984,7 +984,7 @@ class PublicRoomList(BaseFederationServlet):
limit = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
- query, "include_all_networks", False
+ query, "include_all_networks", default=False
)
third_party_instance_id = parse_string_from_args(
query, "third_party_instance_id", None
@@ -1908,16 +1908,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
- exclude_rooms = []
- if b"exclude_rooms" in query:
- try:
- exclude_rooms = [
- room_id.decode("ascii") for room_id in query[b"exclude_rooms"]
- ]
- except Exception:
- raise SynapseError(
- 400, "Bad query parameter for exclude_rooms", Codes.INVALID_PARAM
- )
+ exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
return 200, await self.handler.federation_space_summary(
origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 41dbdfd0a1..53fac1f8a3 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from typing import TYPE_CHECKING, Collection, List, Optional, Union
from synapse import event_auth
@@ -20,16 +21,18 @@ from synapse.api.constants import (
Membership,
RestrictedJoinRuleTypes,
)
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
-from synapse.types import StateMap
+from synapse.types import StateMap, get_domain_from_id
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
+logger = logging.getLogger(__name__)
+
class EventAuthHandler:
"""
@@ -39,6 +42,7 @@ class EventAuthHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastore()
+ self._server_name = hs.hostname
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
@@ -81,15 +85,76 @@ class EventAuthHandler:
# introduce undesirable "state reset" behaviour.
#
# All of which sounds a bit tricky so we don't bother for now.
-
auth_ids = []
- for etype, state_key in event_auth.auth_types_for_event(event):
+ for etype, state_key in event_auth.auth_types_for_event(
+ event.room_version, event
+ ):
auth_ev_id = current_state_ids.get((etype, state_key))
if auth_ev_id:
auth_ids.append(auth_ev_id)
return auth_ids
+ async def get_user_which_could_invite(
+ self, room_id: str, current_state_ids: StateMap[str]
+ ) -> str:
+ """
+ Searches the room state for a local user who has the power level necessary
+ to invite other users.
+
+ Args:
+ room_id: The room ID under search.
+ current_state_ids: The current state of the room.
+
+ Returns:
+ The MXID of the user which could issue an invite.
+
+ Raises:
+ SynapseError if no appropriate user is found.
+ """
+ power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
+ invite_level = 0
+ users_default_level = 0
+ if power_level_event_id:
+ power_level_event = await self._store.get_event(power_level_event_id)
+ invite_level = power_level_event.content.get("invite", invite_level)
+ users_default_level = power_level_event.content.get(
+ "users_default", users_default_level
+ )
+ users = power_level_event.content.get("users", {})
+ else:
+ users = {}
+
+ # Find the user with the highest power level.
+ users_in_room = await self._store.get_users_in_room(room_id)
+ # Only interested in local users.
+ local_users_in_room = [
+ u for u in users_in_room if get_domain_from_id(u) == self._server_name
+ ]
+ chosen_user = max(
+ local_users_in_room,
+ key=lambda user: users.get(user, users_default_level),
+ default=None,
+ )
+
+ # Return the chosen user if they can issue invites.
+ user_power_level = users.get(chosen_user, users_default_level)
+ if chosen_user and user_power_level >= invite_level:
+ logger.debug(
+ "Found a user who can issue invites %s with power level %d >= invite level %d",
+ chosen_user,
+ user_power_level,
+ invite_level,
+ )
+ return chosen_user
+
+ # No user was found.
+ raise SynapseError(
+ 400,
+ "Unable to find a user which could issue an invite",
+ Codes.UNABLE_TO_GRANT_JOIN,
+ )
+
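The selection above reduces to a `max` over power levels followed by a threshold check; a toy run with invented users and levels:

```python
users = {"@admin:example.com": 100, "@mod:example.com": 50}
users_default_level = 0
invite_level = 50
local_users_in_room = ["@admin:example.com", "@alice:example.com"]

chosen_user = max(
    local_users_in_room,
    key=lambda user: users.get(user, users_default_level),
    default=None,
)
# @admin wins with power 100 >= invite level 50, so the join can be granted
# locally; @alice alone (default level 0) would trip UNABLE_TO_GRANT_JOIN.
assert chosen_user == "@admin:example.com"
```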
async def check_host_in_room(self, room_id: str, host: str) -> bool:
with Measure(self._clock, "check_host_in_room"):
return await self._store.is_host_joined(room_id, host)
@@ -134,6 +199,18 @@ class EventAuthHandler:
# in any of them.
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
if not await self.is_user_in_rooms(allowed_rooms, user_id):
+
+ # If this is a remote request, the user might be in an allowed room
+ # that we do not know about.
+ if get_domain_from_id(user_id) != self._server_name:
+ for room_id in allowed_rooms:
+ if not await self._store.is_host_joined(room_id, self._server_name):
+ raise SynapseError(
+ 400,
+ f"Unable to check if {user_id} is in allowed rooms.",
+ Codes.UNABLE_AUTHORISE_JOIN,
+ )
+
raise AuthError(
403,
"You do not belong to any of the required rooms to join this room.",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5728719909..8197b60b76 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1494,9 +1494,10 @@ class FederationHandler(BaseHandler):
host_list, event, room_version_obj
)
- origin = ret["origin"]
- state = ret["state"]
- auth_chain = ret["auth_chain"]
+ event = ret.event
+ origin = ret.origin
+ state = ret.state
+ auth_chain = ret.auth_chain
auth_chain.sort(key=lambda e: e.depth)
logger.debug("do_invite_join auth_chain: %s", auth_chain)
@@ -1676,7 +1677,7 @@ class FederationHandler(BaseHandler):
# checking the room version will check that we've actually heard of the room
# (and return a 404 otherwise)
- room_version = await self.store.get_room_version_id(room_id)
+ room_version = await self.store.get_room_version(room_id)
# now check that we are *still* in the room
is_in_room = await self._event_auth_handler.check_host_in_room(
@@ -1691,8 +1692,38 @@ class FederationHandler(BaseHandler):
event_content = {"membership": Membership.JOIN}
+ # If the current room is using restricted join rules, additional information
+ # may need to be included in the event content in order to efficiently
+ # validate the event.
+ #
+ # Note that this requires the /send_join request to come back to the
+ # same server.
+ if room_version.msc3083_join_rules:
+ state_ids = await self.store.get_current_state_ids(room_id)
+ if await self._event_auth_handler.has_restricted_join_rules(
+ state_ids, room_version
+ ):
+ prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None)
+ # If the user is invited or joined to the room already, then
+ # no additional info is needed.
+ include_auth_user_id = True
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+ include_auth_user_id = prev_member_event.membership not in (
+ Membership.JOIN,
+ Membership.INVITE,
+ )
+
+ if include_auth_user_id:
+ event_content[
+ "join_authorised_via_users_server"
+ ] = await self._event_auth_handler.get_user_which_could_invite(
+ room_id,
+ state_ids,
+ )
+
builder = self.event_builder_factory.new(
- room_version,
+ room_version.identifier,
{
"type": EventTypes.Member,
"content": event_content,
@@ -1710,10 +1741,13 @@ class FederationHandler(BaseHandler):
logger.warning("Failed to create join to %s because %s", room_id, e)
raise
+ # Ensure the user can even join the room.
+ await self._check_join_restrictions(context, event)
+
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
await self._event_auth_handler.check_from_context(
- room_version, event, context, do_sig_check=False
+ room_version.identifier, event, context, do_sig_check=False
)
return event
@@ -1958,7 +1992,7 @@ class FederationHandler(BaseHandler):
@log_function
async def on_send_membership_event(
self, origin: str, event: EventBase
- ) -> EventContext:
+ ) -> Tuple[EventBase, EventContext]:
"""
We have received a join/leave/knock event for a room via send_join/leave/knock.
@@ -1981,7 +2015,7 @@ class FederationHandler(BaseHandler):
event: The member event that has been signed by the remote homeserver.
Returns:
- The context of the event after inserting it into the room graph.
+ The event and context of the event after inserting it into the room graph.
Raises:
SynapseError if the event is not accepted into the room
@@ -2037,7 +2071,7 @@ class FederationHandler(BaseHandler):
# all looks good, we can persist the event.
await self._run_push_actions_and_persist_event(event, context)
- return context
+ return event, context
async def _check_join_restrictions(
self, context: EventContext, event: EventBase
@@ -2473,7 +2507,7 @@ class FederationHandler(BaseHandler):
)
# Now check if event pass auth against said current state
- auth_types = auth_types_for_event(event)
+ auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
@@ -2714,9 +2748,11 @@ class FederationHandler(BaseHandler):
event.event_id,
e.event_id,
)
- context = await self.state_handler.compute_event_context(e)
+ missing_auth_event_context = (
+ await self.state_handler.compute_event_context(e)
+ )
await self._auth_and_persist_event(
- origin, e, context, auth_events=auth
+ origin, e, missing_auth_event_context, auth_events=auth
)
if e.event_id in event_auth_events:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 5d49640760..e1c544a3c9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -21,6 +21,7 @@ from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
+from synapse.handlers.receipts import ReceiptEventSource
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
@@ -134,6 +135,8 @@ class InitialSyncHandler(BaseHandler):
joined_rooms,
to_key=int(now_token.receipt_key),
)
+ if self.hs.config.experimental.msc2285_enabled:
+ receipt = ReceiptEventSource.filter_out_hidden(receipt, user_id)
tags_by_room = await self.store.get_tags_for_user(user_id)
@@ -430,7 +433,9 @@ class InitialSyncHandler(BaseHandler):
room_id, to_key=now_token.receipt_key
)
if not receipts:
- receipts = []
+ return []
+ if self.hs.config.experimental.msc2285_enabled:
+ receipts = ReceiptEventSource.filter_out_hidden(receipts, user_id)
return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 283483fc2c..b9085bbccb 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,9 +14,10 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
+from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
-from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
+from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -137,7 +138,7 @@ class ReceiptsHandler(BaseHandler):
return True
async def received_client_receipt(
- self, room_id: str, receipt_type: str, user_id: str, event_id: str
+ self, room_id: str, receipt_type: str, user_id: str, event_id: str, hidden: bool
) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
@@ -147,23 +148,67 @@ class ReceiptsHandler(BaseHandler):
receipt_type=receipt_type,
user_id=user_id,
event_ids=[event_id],
- data={"ts": int(self.clock.time_msec())},
+ data={"ts": int(self.clock.time_msec()), "hidden": hidden},
)
is_new = await self._handle_new_receipts([receipt])
if not is_new:
return
- if self.federation_sender:
+ if self.federation_sender and not (
+ self.hs.config.experimental.msc2285_enabled and hidden
+ ):
await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
+ self.config = hs.config
+
+ @staticmethod
+ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
+ visible_events = []
+
+ # filter out hidden receipts the user shouldn't see
+ for event in events:
+ content = event.get("content", {})
+ new_event = event.copy()
+ new_event["content"] = {}
+
+ for event_id in content.keys():
+ event_content = content.get(event_id, {})
+ m_read = event_content.get("m.read", {})
+
+ # If m_read is missing, copy over the original event_content as there is nothing to process here.
+ if not m_read:
+ new_event["content"][event_id] = event_content.copy()
+ continue
+
+ new_users = {}
+ for rr_user_id, user_rr in m_read.items():
+ hidden = user_rr.get("hidden", None)
+ if hidden is not True or rr_user_id == user_id:
+ new_users[rr_user_id] = user_rr.copy()
+ # If hidden has a value, replace it with the correctly prefixed key.
+ if hidden is not None:
+ new_users[rr_user_id].pop("hidden")
+ new_users[rr_user_id][
+ ReadReceiptEventFields.MSC2285_HIDDEN
+ ] = hidden
+
+ # Set new users unless empty
+ if len(new_users.keys()) > 0:
+ new_event["content"][event_id] = {"m.read": new_users}
+
+ # Append new_event to visible_events unless empty
+ if len(new_event["content"].keys()) > 0:
+ visible_events.append(new_event)
+
+ return visible_events
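A hedged illustration of the transformation (event and user IDs invented, and assuming the `ReceiptEventSource` above is importable): the caller's own hidden receipt is kept under the MSC2285-prefixed key, while other users' hidden receipts are dropped:

```python
events = [
    {
        "type": "m.receipt",
        "room_id": "!room:example.com",
        "content": {
            "$event1": {
                "m.read": {
                    "@me:example.com": {"ts": 1, "hidden": True},
                    "@other:example.com": {"ts": 2, "hidden": True},
                }
            }
        },
    }
]

filtered = ReceiptEventSource.filter_out_hidden(events, "@me:example.com")
m_read = filtered[0]["content"]["$event1"]["m.read"]
assert "@other:example.com" not in m_read
assert m_read["@me:example.com"]["org.matrix.msc2285.hidden"] is True
```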
async def get_new_events(
- self, from_key: int, room_ids: List[str], **kwargs
+ self, from_key: int, room_ids: List[str], user: UserID, **kwargs
) -> Tuple[List[JsonDict], int]:
from_key = int(from_key)
to_key = self.get_current_key()
@@ -175,6 +220,9 @@ class ReceiptEventSource:
room_ids, from_key=from_key, to_key=to_key
)
+ if self.config.experimental.msc2285_enabled:
+ events = ReceiptEventSource.filter_out_hidden(events, user.to_string())
+
return (events, to_key)
async def get_new_events_as(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 370561e549..b33fe09f77 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -951,6 +951,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
+ "historical": 100,
}
if config["original_invitees_have_ops"]:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 3bff2fc489..45467c7e42 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,7 +16,7 @@ import abc
import logging
import random
from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
from synapse import types
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
@@ -28,6 +28,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
+from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import (
@@ -341,16 +342,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if event.membership == Membership.JOIN:
newly_joined = True
- prev_member_event = None
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
- # Check if the member should be allowed access via membership in a space.
- await self.event_auth_handler.check_restricted_join_rules(
- prev_state_ids, event.room_version, user_id, prev_member_event
- )
-
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
if newly_joined and ratelimit:
@@ -721,7 +716,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
- if not is_host_in_room:
+ # Check if a remote join should be performed.
+ remote_join, remote_room_hosts = await self._should_perform_remote_join(
+ target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+ )
+ if remote_join:
if ratelimit:
time_now_s = self.clock.time()
(
@@ -846,6 +845,106 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier=outlier,
)
+ async def _should_perform_remote_join(
+ self,
+ user_id: str,
+ room_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ is_host_in_room: bool,
+ ) -> Tuple[bool, List[str]]:
+ """
+ Check whether the server should do a remote join (as opposed to a local
+ join) for a user.
+
+ Generally a remote join is used if:
+
+ * The server is not yet in the room.
+ * The server is in the room, the room has restricted join rules, the user
+ is not joined or invited to the room, and the server does not have
+ another user who is capable of issuing invites.
+
+ Args:
+ user_id: The user joining the room.
+ room_id: The room being joined.
+ remote_room_hosts: A list of remote room hosts.
+ content: The content to use as the event body of the join. This may
+ be modified.
+ is_host_in_room: True if the host is in the room.
+
+ Returns:
+ A tuple of:
+ True if a remote join should be performed. False if the join can be
+ done locally.
+
+ A list of remote room hosts to use. This is an empty list if a
+ local join is to be done.
+ """
+ # If the host isn't in the room, pass through the prospective hosts.
+ if not is_host_in_room:
+ return True, remote_room_hosts
+
+ # If the host is in the room, but not one of the authorised hosts
+ # for restricted join rules, a remote join must be used.
+ room_version = await self.store.get_room_version(room_id)
+ current_state_ids = await self.store.get_current_state_ids(room_id)
+
+ # If restricted join rules are not being used, a local join can always
+ # be used.
+ if not await self.event_auth_handler.has_restricted_join_rules(
+ current_state_ids, room_version
+ ):
+ return False, []
+
+ # If the user is invited to the room or already joined, the join
+ # event can always be issued locally.
+ prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+ prev_member_event = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+ if prev_member_event.membership in (
+ Membership.JOIN,
+ Membership.INVITE,
+ ):
+ return False, []
+
+ # If the local host has a user who can issue invites, then a local
+ # join can be done.
+ #
+ # If not, generate a new list of remote hosts based on which
+ # can issue invites.
+ event_map = await self.store.get_events(current_state_ids.values())
+ current_state = {
+ state_key: event_map[event_id]
+ for state_key, event_id in current_state_ids.items()
+ }
+ allowed_servers = get_servers_from_users(
+ get_users_which_can_issue_invite(current_state)
+ )
+
+ # If the local server is not one of allowed servers, then a remote
+ # join must be done. Return the list of prospective servers based on
+ # which can issue invites.
+ if self.hs.hostname not in allowed_servers:
+ return True, list(allowed_servers)
+
+ # Ensure the member is allowed access via membership in an allowed room.
+ await self.event_auth_handler.check_restricted_join_rules(
+ current_state_ids, room_version, user_id, prev_member_event
+ )
+
+ # If this is going to be a local join, additional information must
+ # be included in the event content in order to efficiently validate
+ # the event.
+ content[
+ "join_authorised_via_users_server"
+ ] = await self.event_auth_handler.get_user_which_could_invite(
+ room_id,
+ current_state_ids,
+ )
+
+ return False, []
+
async def transfer_room_state_on_room_upgrade(
self, old_room_id: str, room_id: str
) -> None:
@@ -1534,3 +1633,63 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if membership:
await self.store.forget(user_id, room_id)
+
+
+def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
+ """
+ Return the list of users which can issue invites.
+
+ This is done by exploring the joined users and comparing their power levels
+ to the necessary power level to issue an invite.
+
+ Args:
+ auth_events: state in force at this point in the room
+
+ Returns:
+ The users which can issue invites.
+ """
+ invite_level = get_named_level(auth_events, "invite", 0)
+ users_default_level = get_named_level(auth_events, "users_default", 0)
+ power_level_event = get_power_level_event(auth_events)
+
+ # Custom power-levels for users.
+ if power_level_event:
+ users = power_level_event.content.get("users", {})
+ else:
+ users = {}
+
+ result = []
+
+ # Check which members are able to invite by ensuring they're joined and have
+ # the necessary power level.
+ for (event_type, state_key), event in auth_events.items():
+ if event_type != EventTypes.Member:
+ continue
+
+ if event.membership != Membership.JOIN:
+ continue
+
+ # Check if the user has a custom power level.
+ if users.get(state_key, users_default_level) >= invite_level:
+ result.append(state_key)
+
+ return result
+
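Distilled to plain data (a stub membership map stands in for real EventBase state), the filter amounts to:

```python
invite_level, users_default_level = 50, 0
users = {"@admin:example.com": 100}  # custom power levels
memberships = {"@admin:example.com": "join", "@guest:example.com": "join"}

result = [
    user_id
    for user_id, membership in memberships.items()
    if membership == "join"
    and users.get(user_id, users_default_level) >= invite_level
]
# Only @admin clears the invite level; @guest sits at users_default (0).
assert result == ["@admin:example.com"]
```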
+
+def get_servers_from_users(users: List[str]) -> Set[str]:
+ """
+ Resolve a list of users into their servers.
+
+ Args:
+ users: A list of users.
+
+ Returns:
+ A set of servers.
+ """
+ servers = set()
+ for user in users:
+ try:
+ servers.add(get_domain_from_id(user))
+ except SynapseError:
+ pass
+ return servers
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 2ac76b15c2..c2ea51ee16 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -847,7 +847,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
def read_body_with_max_size(
response: IResponse, stream: ByteWriteable, max_size: Optional[int]
-) -> defer.Deferred:
+) -> "defer.Deferred[int]":
"""
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
@@ -862,7 +862,7 @@ def read_body_with_max_size(
Returns:
A Deferred which resolves to the length of the read body.
"""
- d = defer.Deferred()
+ d: "defer.Deferred[int]" = defer.Deferred()
# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 950770201a..c16b7f10e6 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -27,7 +27,7 @@ from twisted.internet.interfaces import (
)
from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
+from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.client import BlacklistingAgentWrapper
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
- ) -> Generator[defer.Deferred, Any, defer.Deferred]:
+ ) -> Generator[defer.Deferred, Any, IResponse]:
"""
Args:
method: HTTP method: GET/POST/etc
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index f7193e60bd..19e987f118 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -14,21 +14,32 @@
import base64
import logging
import re
-from typing import Optional, Tuple
-from urllib.request import getproxies_environment, proxy_bypass_environment
+from typing import Any, Dict, Optional, Tuple
+from urllib.parse import urlparse
+from urllib.request import ( # type: ignore[attr-defined]
+ getproxies_environment,
+ proxy_bypass_environment,
+)
import attr
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
from twisted.python.failure import Failure
-from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.client import (
+ URI,
+ BrowserLikePolicyForHTTPS,
+ HTTPConnectionPool,
+ _AgentBase,
+)
from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IPolicyForHTTPS
+from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@@ -63,35 +74,38 @@ class ProxyAgent(_AgentBase):
reactor might have some blacklisting applied (i.e. for DNS queries),
but we need unblocked access to the proxy.
- contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+ contextFactory: A factory for TLS contexts, to control the
verification parameters of OpenSSL. The default is to use a
`BrowserLikePolicyForHTTPS`, so unless you have special
requirements you can leave this as-is.
- connectTimeout (Optional[float]): The amount of time that this Agent will wait
+ connectTimeout: The amount of time that this Agent will wait
for the peer to accept a connection, in seconds. If 'None',
HostnameEndpoint's default (30s) will be used.
-
This is used for connections to both proxies and destination servers.
- bindAddress (bytes): The local address for client sockets to bind to.
+ bindAddress: The local address for client sockets to bind to.
- pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+ pool: connection pool to be used. If None, a
non-persistent pool instance will be created.
- use_proxy (bool): Whether proxy settings should be discovered and used
+ use_proxy: Whether proxy settings should be discovered and used
from conventional environment variables.
+
+ Raises:
+ ValueError if use_proxy is set and the environment variables
+ contain an invalid proxy specification.
"""
def __init__(
self,
- reactor,
- proxy_reactor=None,
+ reactor: IReactorCore,
+ proxy_reactor: Optional[ISynapseReactor] = None,
contextFactory: Optional[IPolicyForHTTPS] = None,
- connectTimeout=None,
- bindAddress=None,
- pool=None,
- use_proxy=False,
+ connectTimeout: Optional[float] = None,
+ bindAddress: Optional[bytes] = None,
+ pool: Optional[HTTPConnectionPool] = None,
+ use_proxy: bool = False,
):
contextFactory = contextFactory or BrowserLikePolicyForHTTPS()
@@ -102,7 +116,7 @@ class ProxyAgent(_AgentBase):
else:
self.proxy_reactor = proxy_reactor
- self._endpoint_kwargs = {}
+ self._endpoint_kwargs: Dict[str, Any] = {}
if connectTimeout is not None:
self._endpoint_kwargs["timeout"] = connectTimeout
if bindAddress is not None:
@@ -117,16 +131,12 @@ class ProxyAgent(_AgentBase):
https_proxy = proxies["https"].encode() if "https" in proxies else None
no_proxy = proxies["no"] if "no" in proxies else None
- # Parse credentials from http and https proxy connection string if present
- self.http_proxy_creds, http_proxy = parse_username_password(http_proxy)
- self.https_proxy_creds, https_proxy = parse_username_password(https_proxy)
-
- self.http_proxy_endpoint = _http_proxy_endpoint(
- http_proxy, self.proxy_reactor, **self._endpoint_kwargs
+ self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint(
+ http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
)
- self.https_proxy_endpoint = _http_proxy_endpoint(
- https_proxy, self.proxy_reactor, **self._endpoint_kwargs
+ self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint(
+ https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs
)
self.no_proxy = no_proxy
@@ -134,7 +144,13 @@ class ProxyAgent(_AgentBase):
self._policy_for_https = contextFactory
self._reactor = reactor
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
"""
Issue a request to the server indicated by the given uri.
@@ -146,16 +162,15 @@ class ProxyAgent(_AgentBase):
See also: twisted.web.iweb.IAgent.request
Args:
- method (bytes): The request method to use, such as `GET`, `POST`, etc
+ method: The request method to use, such as `GET`, `POST`, etc
- uri (bytes): The location of the resource to request.
+ uri: The location of the resource to request.
- headers (Headers|None): Extra headers to send with the request
+ headers: Extra headers to send with the request
- bodyProducer (IBodyProducer|None): An object which can generate bytes to
- make up the body of this request (for example, the properly encoded
- contents of a file for a file upload). Or, None if the request is to
- have no body.
+ bodyProducer: An object which can generate bytes to make up the body of
+ this request (for example, the properly encoded contents of a file for
+ a file upload). Or, None if the request is to have no body.
Returns:
Deferred[IResponse]: completes when the header of the response has
@@ -253,70 +268,89 @@ class ProxyAgent(_AgentBase):
)
-def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs):
+def _http_proxy_endpoint(
+ proxy: Optional[bytes],
+ reactor: IReactorCore,
+ tls_options_factory: IPolicyForHTTPS,
+ **kwargs,
+) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
"""Parses an http proxy setting and returns an endpoint for the proxy
Args:
- proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>]
- Note that compared to other apps, this function currently lacks support
- for specifying a protocol schema (i.e. protocol://...).
+ proxy: the proxy setting in the form: [scheme://][<username>:<password>@]<host>[:<port>]
+ This currently supports http:// and https:// proxies.
+ A hostname without a scheme is assumed to be http.
reactor: reactor to be used to connect to the proxy
+ tls_options_factory: the TLS options to use when connecting through an https proxy
+
kwargs: other args to be passed to HostnameEndpoint
Returns:
- interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
- or None
+ a tuple of:
+ the endpoint to use to connect to the proxy, or None
+ the ProxyCredentials, or None if no credentials were found
+
+ Raises:
+ ValueError if proxy has no hostname or unsupported scheme.
"""
if proxy is None:
- return None
+ return None, None
- # Parse the connection string
- host, port = parse_host_port(proxy, default_port=1080)
- return HostnameEndpoint(reactor, host, port, **kwargs)
+ # Note: urlsplit/urlparse cannot be used here as that does not work (for Python
+ # 3.9+) on scheme-less proxies, e.g. host:port.
+ scheme, host, port, credentials = parse_proxy(proxy)
+ proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs)
-def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
- """
- Parses the username and password from a proxy declaration e.g
- username:password@hostname:port.
+ if scheme == b"https":
+ tls_options = tls_options_factory.creatorForNetloc(host, port)
+ proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint)
- Args:
- proxy: The proxy connection string.
+ return proxy_endpoint, credentials
- Returns
- An instance of ProxyCredentials and the proxy connection string with any credentials
- stripped, i.e u:p@host:port -> host:port. If no credentials were found, the
- ProxyCredentials instance is replaced with None.
- """
- if proxy and b"@" in proxy:
- # We use rsplit here as the password could contain an @ character
- credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
- return ProxyCredentials(credentials), proxy_without_credentials
- return None, proxy
+def parse_proxy(
+ proxy: bytes, default_scheme: bytes = b"http", default_port: int = 1080
+) -> Tuple[bytes, bytes, int, Optional[ProxyCredentials]]:
+ """
+ Parse a proxy connection string.
+ Given an HTTP proxy URL, breaks it down into components and checks that it
+ has a hostname (otherwise it is not useful to us when trying to find a
+ proxy) and asserts that the URL has a scheme we support.
-def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]:
- """
- Parse the hostname and port from a proxy connection byte string.
Args:
- hostport: The proxy connection string. Must be in the form 'host[:port]'.
- default_port: The default port to return if one is not found in `hostport`.
+ proxy: The proxy connection string. Must be in the form '[scheme://][<username>:<password>@]host[:port]'.
+ default_scheme: The default scheme to return if one is not found in `proxy`. Defaults to http
+ default_port: The default port to return if one is not found in `proxy`. Defaults to 1080
Returns:
- A tuple containing the hostname and port. Uses `default_port` if one was not found.
+ A tuple containing the scheme, hostname, port and ProxyCredentials.
+ If no credentials were found, the ProxyCredentials instance is replaced with None.
+
+ Raises:
+ ValueError if proxy has no hostname or unsupported scheme.
"""
- if b":" in hostport:
- host, port = hostport.rsplit(b":", 1)
- try:
- port = int(port)
- return host, port
- except ValueError:
- # the thing after the : wasn't a valid port; presumably this is an
- # IPv6 address.
- pass
+ # First check if we have a scheme present
+ # Note: urlsplit/urlparse cannot be used (for Python 3.9+) on scheme-less proxies, e.g. host:port.
+ if b"://" not in proxy:
+ proxy = b"".join([default_scheme, b"://", proxy])
+
+ url = urlparse(proxy)
+
+ if not url.hostname:
+ raise ValueError("Proxy URL did not contain a hostname! Please specify one.")
+
+ if url.scheme not in (b"http", b"https"):
+ raise ValueError(
+ f"Unknown proxy scheme {url.scheme!s}; only 'http' and 'https' is supported."
+ )
+
+ credentials = None
+ if url.username and url.password:
+ credentials = ProxyCredentials(b"".join([url.username, b":", url.password]))
- return hostport, default_port
+ return url.scheme, url.hostname, url.port or default_port, credentials
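
Assuming the `parse_proxy` defined above, a few representative inputs (hostnames invented):

```python
# Scheme-less input: default scheme http, default port 1080, no credentials.
assert parse_proxy(b"squid.example.com") == (
    b"http", b"squid.example.com", 1080, None,
)

# An explicit https scheme and port are preserved.
scheme, host, port, creds = parse_proxy(b"https://squid.example.com:8443")
assert (scheme, host, port, creds) == (b"https", b"squid.example.com", 8443, None)

# Credentials are split out so they never leak into the endpoint itself.
scheme, host, port, creds = parse_proxy(b"http://user:pass@squid.example.com:3128")
assert (scheme, host, port) == (b"http", b"squid.example.com", 3128)
assert creds is not None  # ProxyCredentials wrapping b"user:pass"
```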
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 04560fb589..732a1e6aeb 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -14,47 +14,86 @@
""" This module contains base REST classes for constructing REST servlets. """
import logging
-from typing import Dict, Iterable, List, Optional, overload
+from typing import Iterable, List, Mapping, Optional, Sequence, overload
from typing_extensions import Literal
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
+from synapse.types import JsonDict
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
-def parse_integer(request, name, default=None, required=False):
+@overload
+def parse_integer(request: Request, name: str, default: int) -> int:
+ ...
+
+
+@overload
+def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int:
+ ...
+
+
+@overload
+def parse_integer(
+ request: Request, name: str, default: Optional[int] = None, required: bool = False
+) -> Optional[int]:
+ ...
+
+
+def parse_integer(
+ request: Request, name: str, default: Optional[int] = None, required: bool = False
+) -> Optional[int]:
"""Parse an integer parameter from the request string
Args:
request: the twisted HTTP request.
- name (bytes/unicode): the name of the query parameter.
- default (int|None): value to use if the parameter is absent, defaults
- to None.
- required (bool): whether to raise a 400 SynapseError if the
- parameter is absent, defaults to False.
+ name: the name of the query parameter.
+ default: value to use if the parameter is absent, defaults to None.
+ required: whether to raise a 400 SynapseError if the parameter is absent,
+ defaults to False.
Returns:
- int|None: An int value or the default.
+ An int value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not an integer.
"""
- return parse_integer_from_args(request.args, name, default, required)
+ args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
+ return parse_integer_from_args(args, name, default, required)
+
+def parse_integer_from_args(
+ args: Mapping[bytes, Sequence[bytes]],
+ name: str,
+ default: Optional[int] = None,
+ required: bool = False,
+) -> Optional[int]:
+ """Parse an integer parameter from the request string
+
+ Args:
+ args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+ name: the name of the query parameter.
+ default: value to use if the parameter is absent, defaults to None.
+ required: whether to raise a 400 SynapseError if the parameter is absent,
+ defaults to False.
-def parse_integer_from_args(args, name, default=None, required=False):
+ Returns:
+ An int value or the default.
- if not isinstance(name, bytes):
- name = name.encode("ascii")
+ Raises:
+ SynapseError: if the parameter is absent and required, or if the
+ parameter is present and not an integer.
+ """
+ name_bytes = name.encode("ascii")
- if name in args:
+ if name_bytes in args:
try:
- return int(args[name][0])
+ return int(args[name_bytes][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
@@ -66,36 +105,102 @@ def parse_integer_from_args(args, name, default=None, required=False):
return default
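
The `@overload` stanzas exist purely for the type checker; a self-contained miniature of the same pattern (the `get_flag` helper is hypothetical) shows how a supplied default or `required=True` removes the `Optional` from the inferred return type:

```python
from typing import Mapping, Optional, overload

from typing_extensions import Literal

@overload
def get_flag(args: Mapping[str, bool], name: str, default: bool) -> bool: ...

@overload
def get_flag(
    args: Mapping[str, bool], name: str, *, required: Literal[True]
) -> bool: ...

@overload
def get_flag(
    args: Mapping[str, bool],
    name: str,
    default: Optional[bool] = None,
    required: bool = False,
) -> Optional[bool]: ...

def get_flag(args, name, default=None, required=False):
    if name in args:
        return args[name]
    if required:
        raise KeyError(name)
    return default

flag: bool = get_flag({"a": True}, "a", default=False)  # checker sees bool
maybe: Optional[bool] = get_flag({}, "b")  # checker sees Optional[bool]
```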
-def parse_boolean(request, name, default=None, required=False):
+@overload
+def parse_boolean(request: Request, name: str, default: bool) -> bool:
+ ...
+
+
+@overload
+def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool:
+ ...
+
+
+@overload
+def parse_boolean(
+ request: Request, name: str, default: Optional[bool] = None, required: bool = False
+) -> Optional[bool]:
+ ...
+
+
+def parse_boolean(
+ request: Request, name: str, default: Optional[bool] = None, required: bool = False
+) -> Optional[bool]:
"""Parse a boolean parameter from the request query string
Args:
request: the twisted HTTP request.
- name (bytes/unicode): the name of the query parameter.
- default (bool|None): value to use if the parameter is absent, defaults
- to None.
- required (bool): whether to raise a 400 SynapseError if the
- parameter is absent, defaults to False.
+ name: the name of the query parameter.
+ default: value to use if the parameter is absent, defaults to None.
+ required: whether to raise a 400 SynapseError if the parameter is absent,
+ defaults to False.
Returns:
- bool|None: A bool value or the default.
+ A bool value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not one of "true" or "false".
"""
+ args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
+ return parse_boolean_from_args(args, name, default, required)
+
+
+@overload
+def parse_boolean_from_args(
+ args: Mapping[bytes, Sequence[bytes]],
+ name: str,
+ default: bool,
+) -> bool:
+ ...
+
- return parse_boolean_from_args(request.args, name, default, required)
+@overload
+def parse_boolean_from_args(
+ args: Mapping[bytes, Sequence[bytes]],
+ name: str,
+ *,
+ required: Literal[True],
+) -> bool:
+ ...
-def parse_boolean_from_args(args, name, default=None, required=False):
+@overload
+def parse_boolean_from_args(
+ args: Mapping[bytes, Sequence[bytes]],
+ name: str,
+ default: Optional[bool] = None,
+ required: bool = False,
+) -> Optional[bool]:
+ ...
- if not isinstance(name, bytes):
- name = name.encode("ascii")
- if name in args:
+def parse_boolean_from_args(
+ args: Mapping[bytes, Sequence[bytes]],
+ name: str,
+ default: Optional[bool] = None,
+ required: bool = False,
+) -> Optional[bool]:
+ """Parse a boolean parameter from the request query string
+
+ Args:
+ args: A mapping of request args as bytes to a list of bytes (e.g. request.args).
+ name: the name of the query parameter.
+ default: value to use if the parameter is absent, defaults to None.
+ required: whether to raise a 400 SynapseError if the parameter is absent,
+ defaults to False.
+
+ Returns:
+ A bool value or the default.
+
+ Raises:
+ SynapseError: if the parameter is absent and required, or if the
+ parameter is present and not one of "true" or "false".
+ """
+ name_bytes = name.encode("ascii")
+
+ if name_bytes in args:
try:
- return {b"true": True, b"false": False}[args[name][0]]
+ return {b"true": True, b"false": False}[args[name_bytes][0]]
except Exception:
message = (
"Boolean query parameter %r must be one of ['true', 'false']"
@@ -111,7 +216,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
@overload
def parse_bytes_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[bytes] = None,
) -> Optional[bytes]:
@@ -120,7 +225,7 @@ def parse_bytes_from_args(
@overload
def parse_bytes_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Literal[None] = None,
*,
@@ -131,7 +236,7 @@ def parse_bytes_from_args(
@overload
def parse_bytes_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[bytes] = None,
required: bool = False,
@@ -140,7 +245,7 @@ def parse_bytes_from_args(
def parse_bytes_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[bytes] = None,
required: bool = False,
@@ -172,6 +277,42 @@ def parse_bytes_from_args(
return default
+@overload
+def parse_string(
+ request: Request,
+ name: str,
+ default: str,
+ *,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> str:
+ ...
+
+
+@overload
+def parse_string(
+ request: Request,
+ name: str,
+ *,
+ required: Literal[True],
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> str:
+ ...
+
+
+@overload
+def parse_string(
+ request: Request,
+ name: str,
+ *,
+ required: bool = False,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> Optional[str]:
+ ...
+
+
def parse_string(
request: Request,
name: str,
@@ -179,7 +320,7 @@ def parse_string(
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
-):
+) -> Optional[str]:
"""
Parse a string parameter from the request query string.
@@ -205,7 +346,7 @@ def parse_string(
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
- args: Dict[bytes, List[bytes]] = request.args # type: ignore
+ args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
return parse_string_from_args(
args,
name,
@@ -239,9 +380,8 @@ def _parse_string_value(
@overload
def parse_strings_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
- default: Optional[List[str]] = None,
*,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
@@ -251,9 +391,20 @@ def parse_strings_from_args(
@overload
def parse_strings_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
+ name: str,
+ default: List[str],
+ *,
+ allowed_values: Optional[Iterable[str]] = None,
+ encoding: str = "ascii",
+) -> List[str]:
+ ...
+
+
+@overload
+def parse_strings_from_args(
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
- default: Optional[List[str]] = None,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
@@ -264,7 +415,7 @@ def parse_strings_from_args(
@overload
def parse_strings_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[List[str]] = None,
*,
@@ -276,7 +427,7 @@ def parse_strings_from_args(
def parse_strings_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[List[str]] = None,
required: bool = False,
@@ -325,7 +476,7 @@ def parse_strings_from_args(
@overload
def parse_string_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[str] = None,
*,
@@ -337,7 +488,7 @@ def parse_string_from_args(
@overload
def parse_string_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[str] = None,
*,
@@ -350,7 +501,7 @@ def parse_string_from_args(
@overload
def parse_string_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[str] = None,
required: bool = False,
@@ -361,7 +512,7 @@ def parse_string_from_args(
def parse_string_from_args(
- args: Dict[bytes, List[bytes]],
+ args: Mapping[bytes, Sequence[bytes]],
name: str,
default: Optional[str] = None,
required: bool = False,
@@ -409,13 +560,14 @@ def parse_string_from_args(
return strings[0]
-def parse_json_value_from_request(request, allow_empty_body=False):
+def parse_json_value_from_request(
+ request: Request, allow_empty_body: bool = False
+) -> Optional[JsonDict]:
"""Parse a JSON value from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
- allow_empty_body (bool): if True, an empty body will be accepted and
- turned into None
+ allow_empty_body: if True, an empty body will be accepted and turned into None
Returns:
The JSON value.
@@ -424,7 +576,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
SynapseError if the request body couldn't be decoded as JSON.
"""
try:
- content_bytes = request.content.read()
+ content_bytes = request.content.read() # type: ignore
except Exception:
raise SynapseError(400, "Error reading JSON content.")
@@ -440,13 +592,15 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return content
-def parse_json_object_from_request(request, allow_empty_body=False):
+def parse_json_object_from_request(
+ request: Request, allow_empty_body: bool = False
+) -> JsonDict:
"""Parse a JSON object from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
- allow_empty_body (bool): if True, an empty body will be accepted and
- turned into an empty dict.
+ allow_empty_body: if True, an empty body will be accepted and turned into
+ an empty dict.
Raises:
SynapseError if the request body couldn't be decoded as JSON or
@@ -457,14 +611,14 @@ def parse_json_object_from_request(request, allow_empty_body=False):
if allow_empty_body and content is None:
return {}
- if type(content) != dict:
+ if not isinstance(content, dict):
message = "Content must be a JSON object."
raise SynapseError(400, message, errcode=Codes.BAD_JSON)
return content
-def assert_params_in_dict(body, required):
+def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
absent = []
for k in required:
if k not in body:
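
To make the typing changes above concrete, here is a minimal, self-contained sketch (not Synapse's actual helper) of the overload pattern these hunks apply: the `args` parameter is widened to the read-only `Mapping[bytes, Sequence[bytes]]`, and the return type narrows to `str` whenever a default is supplied or `required=True` is passed, staying `Optional[str]` otherwise.

```python
from typing import Literal, Mapping, Optional, Sequence, overload

@overload
def get_param(
    args: Mapping[bytes, Sequence[bytes]], name: str, default: str
) -> str: ...

@overload
def get_param(
    args: Mapping[bytes, Sequence[bytes]], name: str, *, required: Literal[True]
) -> str: ...

@overload
def get_param(
    args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None
) -> Optional[str]: ...

def get_param(
    args: Mapping[bytes, Sequence[bytes]],
    name: str,
    default: Optional[str] = None,
    required: bool = False,
) -> Optional[str]:
    # Query args arrive as bytes; take the first value if present.
    values = args.get(name.encode("ascii"))
    if not values:
        if required:
            raise ValueError(f"Missing required parameter: {name}")
        return default
    return values[0].decode("ascii")
```

With this pattern, mypy sees `get_param(args, "u", default="")` as `str` but `get_param(args, "filter")` as `Optional[str]`, which is why several call sites elsewhere in this diff can drop their explicit `default=None`.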
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 18ac507802..02e5ddd2ef 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -25,7 +25,7 @@ See doc/log_contexts.rst for details on how this works.
import inspect
import logging
import threading
-import types
+import typing
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
@@ -745,7 +745,7 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()
- if isinstance(res, types.CoroutineType):
+ if isinstance(res, typing.Coroutine):
res = defer.ensureDeferred(res)
# At this point we should have a Deferred, if not then f was a synchronous
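
A short illustration of why the check above switches to `typing.Coroutine`: it is an alias of `collections.abc.Coroutine`, so the `isinstance` test matches anything implementing the coroutine protocol rather than only native coroutine objects.

```python
# Illustrative only: both checks pass for a native coroutine, but the
# typing.Coroutine (== collections.abc.Coroutine) check is protocol-based
# and therefore broader.
import types
import typing

async def work() -> int:
    return 1

coro = work()
assert isinstance(coro, types.CoroutineType)   # native coroutines only
assert isinstance(coro, typing.Coroutine)      # anything coroutine-shaped
coro.close()  # avoid a "coroutine was never awaited" warning
```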
diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py
new file mode 100644
index 0000000000..a6c212f300
--- /dev/null
+++ b/synapse/logging/handlers.py
@@ -0,0 +1,88 @@
+import logging
+import time
+from logging import Handler, LogRecord
+from logging.handlers import MemoryHandler
+from threading import Thread
+from typing import Optional
+
+from twisted.internet.interfaces import IReactorCore
+
+
+class PeriodicallyFlushingMemoryHandler(MemoryHandler):
+ """
+ This is a subclass of MemoryHandler that additionally spawns a background
+ thread to periodically flush the buffer.
+
+ This prevents messages from being buffered for too long.
+
+ Additionally, all messages will be immediately flushed if the reactor has
+ not yet been started.
+ """
+
+ def __init__(
+ self,
+ capacity: int,
+ flushLevel: int = logging.ERROR,
+ target: Optional[Handler] = None,
+ flushOnClose: bool = True,
+ period: float = 5.0,
+ reactor: Optional[IReactorCore] = None,
+ ) -> None:
+ """
+ period: the period between automatic flushes
+
+        reactor: if specified, a custom reactor to use. If not specified,
+ defaults to the globally-installed reactor.
+ Log entries will be flushed immediately until this reactor has
+ started.
+ """
+ super().__init__(capacity, flushLevel, target, flushOnClose)
+
+ self._flush_period: float = period
+ self._active: bool = True
+ self._reactor_started = False
+
+ self._flushing_thread: Thread = Thread(
+ name="PeriodicallyFlushingMemoryHandler flushing thread",
+ target=self._flush_periodically,
+ )
+ self._flushing_thread.start()
+
+ def on_reactor_running():
+ self._reactor_started = True
+
+ reactor_to_use: IReactorCore
+ if reactor is None:
+ from twisted.internet import reactor as global_reactor
+
+ reactor_to_use = global_reactor # type: ignore[assignment]
+ else:
+ reactor_to_use = reactor
+
+        # call our hook when the reactor starts up
+ reactor_to_use.callWhenRunning(on_reactor_running)
+
+ def shouldFlush(self, record: LogRecord) -> bool:
+ """
+ Before reactor start-up, log everything immediately.
+ Otherwise, fall back to original behaviour of waiting for the buffer to fill.
+ """
+
+ if self._reactor_started:
+ return super().shouldFlush(record)
+ else:
+ return True
+
+    def _flush_periodically(self) -> None:
+ """
+ Whilst this handler is active, flush the handler periodically.
+ """
+
+ while self._active:
+ # flush is thread-safe; it acquires and releases the lock internally
+ self.flush()
+ time.sleep(self._flush_period)
+
+ def close(self) -> None:
+ self._active = False
+ super().close()
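
In practice Synapse wires this handler up through its YAML log config; the following direct Python wiring is only an illustrative sketch of the behaviour the new class provides.

```python
# Illustrative wiring: buffer up to 10 records in front of a (slow) file
# handler, flush at once on ERROR, and otherwise at least every 5 seconds.
import logging
from synapse.logging.handlers import PeriodicallyFlushingMemoryHandler

file_handler = logging.FileHandler("homeserver.log")
buffering_handler = PeriodicallyFlushingMemoryHandler(
    capacity=10,               # buffer at most 10 records...
    flushLevel=logging.ERROR,  # ...but flush immediately on ERROR or above
    target=file_handler,
    period=5.0,                # and flush at least every 5 seconds
)

logger = logging.getLogger("synapse")
logger.addHandler(buffering_handler)
# Until the reactor is running, shouldFlush() returns True, so this record
# is written straight through to homeserver.log.
logger.warning("hello from the buffering handler")
```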
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 1259fc2d90..473812b8e2 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -484,7 +484,7 @@ class ModuleApi:
@defer.inlineCallbacks
def get_state_events_in_room(
self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
- ) -> Generator[defer.Deferred, Any, defer.Deferred]:
+ ) -> Generator[defer.Deferred, Any, Iterable[EventBase]]:
"""Gets current state events for the given room.
(This is exposed for compatibility with the old SpamCheckerApi. We should
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c5fbebc17d..bbe337949a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -111,8 +111,9 @@ class _NotifierUserStream:
self.last_notified_token = current_token
self.last_notified_ms = time_now_ms
- with PreserveLoggingContext():
- self.notify_deferred = ObservableDeferred(defer.Deferred())
+ self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
+ defer.Deferred()
+ )
def notify(
self,
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 7be5fe1e9b..941fb238b7 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
import bleach
import jinja2
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig
from synapse.events import EventBase
@@ -600,6 +600,22 @@ class Mailer:
"app": self.app_name,
}
+        # If the room is a space, it gets a slightly different subject.
+ create_event_id = room_state_ids.get(("m.room.create", ""))
+ if create_event_id:
+ create_event = await self.store.get_event(
+ create_event_id, allow_none=True
+ )
+ if (
+ create_event
+ and create_event.content.get("room_type") == RoomTypes.SPACE
+ ):
+ return self.email_subjects.invite_from_person_to_space % {
+ "person": inviter_name,
+ "space": room_name,
+ "app": self.app_name,
+ }
+
return self.email_subjects.invite_from_person_to_room % {
"person": inviter_name,
"room": room_name,
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 85621f33ef..a1436f3930 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 9d4859798b..3fd2811713 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -285,7 +285,7 @@ class ReplicationDataHandler:
# Create a new deferred that times out after N seconds, as we don't want
# to wedge here forever.
- deferred = Deferred()
+ deferred: "Deferred[None]" = Deferred()
deferred = timeout_deferred(
deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
)
@@ -393,6 +393,11 @@ class FederationSenderHandler:
# we only want to send on receipts for our own users
if not self._is_mine_id(receipt.user_id):
continue
+ if (
+ receipt.data.get("hidden", False)
+ and self._hs.config.experimental.msc2285_enabled
+ ):
+ continue
receipt_info = ReadReceipt(
receipt.room_id,
receipt.receipt_type,
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 589e47fa47..eef76ab18a 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -62,6 +62,7 @@ class UsersRestServletV2(RestServlet):
The parameter `name` can be used to filter by user id or display name.
The parameter `guests` can be used to exclude guest users.
The parameter `deactivated` can be used to include deactivated users.
+ The parameter `order_by` can be used to order the result.
"""
def __init__(self, hs: "HomeServer"):
@@ -90,8 +91,8 @@ class UsersRestServletV2(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- user_id = parse_string(request, "user_id", default=None)
- name = parse_string(request, "name", default=None)
+ user_id = parse_string(request, "user_id")
+ name = parse_string(request, "name")
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
@@ -108,6 +109,7 @@ class UsersRestServletV2(RestServlet):
UserSortOrder.USER_TYPE.value,
UserSortOrder.AVATAR_URL.value,
UserSortOrder.SHADOW_BANNED.value,
+ UserSortOrder.CREATION_TS.value,
),
)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 31a1193cd3..502a917588 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -413,7 +413,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
assert_params_in_dict(body, ["state_events_at_start", "events"])
prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
- chunk_id_from_query = parse_string(request, "chunk_id", default=None)
+ chunk_id_from_query = parse_string(request, "chunk_id")
if prev_events_from_query is None:
raise SynapseError(
@@ -504,7 +504,6 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
events_to_create = body["events"]
- prev_event_ids = prev_events_from_query
inherited_depth = await self._inherit_depth_from_prev_ids(
prev_events_from_query
)
@@ -516,6 +515,10 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
chunk_id_to_connect_to = chunk_id_from_query
base_insertion_event = None
if chunk_id_from_query:
+ # All but the first base insertion event should point at a fake
+ # event, which causes the HS to ask for the state at the start of
+ # the chunk later.
+ prev_event_ids = [fake_prev_event_id]
# TODO: Verify the chunk_id_from_query corresponds to an insertion event
pass
# Otherwise, create an insertion event to act as a starting point.
@@ -526,6 +529,8 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
# an insertion event), in which case we just create a new insertion event
# that can then get pointed to by a "marker" event later.
else:
+ prev_event_ids = prev_events_from_query
+
base_insertion_event_dict = self._create_insertion_event_dict(
sender=requester.user.to_string(),
room_id=room_id,
@@ -553,9 +558,18 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
]
# Connect this current chunk to the insertion event from the previous chunk
- last_event_in_chunk["content"][
- EventContentFields.MSC2716_CHUNK_ID
- ] = chunk_id_to_connect_to
+ chunk_event = {
+ "type": EventTypes.MSC2716_CHUNK,
+ "sender": requester.user.to_string(),
+ "room_id": room_id,
+ "content": {EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to},
+ # Since the chunk event is put at the end of the chunk,
+ # where the newest-in-time event is, copy the origin_server_ts from
+ # the last event we're inserting
+ "origin_server_ts": last_event_in_chunk["origin_server_ts"],
+ }
+ # Add the chunk event to the end of the chunk (newest-in-time)
+ events_to_create.append(chunk_event)
# Add an "insertion" event to the start of each chunk (next to the oldest-in-time
# event in the chunk) so the next chunk can be connected to this one.
@@ -567,7 +581,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
# the first event we're inserting
origin_server_ts=events_to_create[0]["origin_server_ts"],
)
- # Prepend the insertion event to the start of the chunk
+ # Prepend the insertion event to the start of the chunk (oldest-in-time)
events_to_create = [insertion_event] + events_to_create
event_ids = []
@@ -726,7 +740,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
self.auth = hs.get_auth()
async def on_GET(self, request):
- server = parse_string(request, "server", default=None)
+ server = parse_string(request, "server")
try:
await self.auth.get_user_by_req(request, allow_guest=True)
@@ -745,8 +759,8 @@ class PublicRoomListRestServlet(TransactionRestServlet):
if server:
raise e
- limit = parse_integer(request, "limit", 0)
- since_token = parse_string(request, "since", None)
+ limit: Optional[int] = parse_integer(request, "limit", 0)
+ since_token = parse_string(request, "since")
if limit == 0:
# zero is a special value which corresponds to no limit.
@@ -780,7 +794,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
async def on_POST(self, request):
await self.auth.get_user_by_req(request, allow_guest=True)
- server = parse_string(request, "server", default=None)
+ server = parse_string(request, "server")
content = parse_json_object_from_request(request)
limit: Optional[int] = int(content.get("limit", 100))
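
For reference, a "chunk" event built by the hunk above looks roughly like the following (the identifiers and timestamp are hypothetical; the shape mirrors the `chunk_event` dict and the MSC2716 constants introduced earlier in this diff).

```python
# Hypothetical example values. The chunk event sits at the newest-in-time
# end of a batch and points back at the previous batch's insertion event
# via its chunk_id.
chunk_event = {
    "type": "org.matrix.msc2716.chunk",
    "sender": "@historical_bot:example.org",
    "room_id": "!room:example.org",
    "content": {
        # taken from the next_chunk_id of the previous insertion event
        "org.matrix.msc2716.chunk_id": "chunk_abc123",
    },
    # copied from the last (newest-in-time) event being inserted
    "origin_server_ts": 1626914471000,
}
```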
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 085561d3e9..fb5ad2906e 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -884,7 +884,14 @@ class WhoamiRestServlet(RestServlet):
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
- return 200, {"user_id": requester.user.to_string()}
+ response = {"user_id": requester.user.to_string()}
+
+ # Appservices and similar accounts do not have device IDs
+ # that we can report on, so exclude them for compliance.
+ if requester.device_id is not None:
+ response["device_id"] = requester.device_id
+
+ return 200, response
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index 6a24021484..88e3aac797 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -14,7 +14,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
@@ -55,6 +55,12 @@ class CapabilitiesRestServlet(RestServlet):
"m.change_password": {"enabled": change_password},
}
}
+
+ if self.config.experimental.msc3244_enabled:
+ response["capabilities"]["m.room_versions"][
+ "org.matrix.msc3244.room_capabilities"
+ ] = MSC3244_CAPABILITIES
+
return 200, response
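
When the experimental flag is on, the `/capabilities` response gains the extra key shown below. This is an assumed example shape: the exact contents of `MSC3244_CAPABILITIES` live in `synapse/api/room_versions.py` and are not shown in this diff.

```python
# Assumed example of the augmented response (values are illustrative):
response = {
    "capabilities": {
        "m.room_versions": {
            "default": "6",
            "available": {"1": "stable", "6": "stable"},
            "org.matrix.msc3244.room_capabilities": {
                "knock": {"preferred": "7", "support": ["7"]},
            },
        },
        "m.change_password": {"enabled": True},
    }
}
```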
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 33cf8de186..d0d9d30d40 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -194,7 +194,7 @@ class KeyChangesServlet(RestServlet):
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- from_token_string = parse_string(request, "from")
+ from_token_string = parse_string(request, "from", required=True)
set_tag("from", from_token_string)
# We want to enforce they do pass us one, but we ignore it and return
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 5988fa47e5..027f8b81fa 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -14,6 +14,8 @@
import logging
+from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_patterns
@@ -37,14 +39,24 @@ class ReadMarkerRestServlet(RestServlet):
await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request)
-
read_event_id = body.get("m.read", None)
+ hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
+
+ if not isinstance(hidden, bool):
+ raise SynapseError(
+ 400,
+ "Param %s must be a boolean, if given"
+ % ReadReceiptEventFields.MSC2285_HIDDEN,
+ Codes.BAD_JSON,
+ )
+
if read_event_id:
await self.receipts_handler.received_client_receipt(
room_id,
"m.read",
user_id=requester.user.to_string(),
event_id=read_event_id,
+ hidden=hidden,
)
read_marker_event_id = body.get("m.fully_read", None)
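
A client opting into MSC2285 sends a body like the sketch below (event IDs are illustrative); a non-boolean value for the hidden flag is rejected with a 400 `M_BAD_JSON` as per the validation above.

```python
# Example request body for POST /rooms/{roomId}/read_markers:
body = {
    "m.fully_read": "$fully_read_event_id",
    "m.read": "$read_event_id",
    # must be a boolean if given; the receipt is then withheld from
    # other users and from federation
    "org.matrix.msc2285.hidden": True,
}
```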
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 8cf4aebdbe..4b98979b47 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -14,8 +14,9 @@
import logging
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_patterns
@@ -42,10 +43,25 @@ class ReceiptRestServlet(RestServlet):
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
+ body = parse_json_object_from_request(request)
+ hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
+
+ if not isinstance(hidden, bool):
+ raise SynapseError(
+ 400,
+ "Param %s must be a boolean, if given"
+ % ReadReceiptEventFields.MSC2285_HIDDEN,
+ Codes.BAD_JSON,
+ )
+
await self.presence_handler.bump_presence_active_time(requester.user)
await self.receipts_handler.received_client_receipt(
- room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
+ room_id,
+ receipt_type,
+ user_id=requester.user.to_string(),
+ event_id=event_id,
+ hidden=hidden,
)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index c7da6759db..0821cd285f 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -158,19 +158,21 @@ class RelationPaginationServlet(RestServlet):
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5)
- from_token = parse_string(request, "from")
- to_token = parse_string(request, "to")
+ from_token_str = parse_string(request, "from")
+ to_token_str = parse_string(request, "to")
if event.internal_metadata.is_redacted():
# If the event is redacted, return an empty list of relations
pagination_chunk = PaginationChunk(chunk=[])
else:
# Return the relations
- if from_token:
- from_token = RelationPaginationToken.from_string(from_token)
+ from_token = None
+ if from_token_str:
+ from_token = RelationPaginationToken.from_string(from_token_str)
- if to_token:
- to_token = RelationPaginationToken.from_string(to_token)
+ to_token = None
+ if to_token_str:
+ to_token = RelationPaginationToken.from_string(to_token_str)
pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
@@ -256,19 +258,21 @@ class RelationAggregationPaginationServlet(RestServlet):
raise SynapseError(400, "Relation type must be 'annotation'")
limit = parse_integer(request, "limit", default=5)
- from_token = parse_string(request, "from")
- to_token = parse_string(request, "to")
+ from_token_str = parse_string(request, "from")
+ to_token_str = parse_string(request, "to")
if event.internal_metadata.is_redacted():
# If the event is redacted, return an empty list of relations
pagination_chunk = PaginationChunk(chunk=[])
else:
# Return the relations
- if from_token:
- from_token = AggregationPaginationToken.from_string(from_token)
+ from_token = None
+ if from_token_str:
+ from_token = AggregationPaginationToken.from_string(from_token_str)
- if to_token:
- to_token = AggregationPaginationToken.from_string(to_token)
+ to_token = None
+ if to_token_str:
+ to_token = AggregationPaginationToken.from_string(to_token_str)
pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
@@ -336,14 +340,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
raise SynapseError(400, "Relation type must be 'annotation'")
limit = parse_integer(request, "limit", default=5)
- from_token = parse_string(request, "from")
- to_token = parse_string(request, "to")
+ from_token_str = parse_string(request, "from")
+ to_token_str = parse_string(request, "to")
- if from_token:
- from_token = RelationPaginationToken.from_string(from_token)
+ from_token = None
+ if from_token_str:
+ from_token = RelationPaginationToken.from_string(from_token_str)
- if to_token:
- to_token = RelationPaginationToken.from_string(to_token)
+ to_token = None
+ if to_token_str:
+ to_token = RelationPaginationToken.from_string(to_token_str)
result = await self.store.get_relations_for_event(
event_id=parent_id,
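
The generic form of the refactor applied throughout this file: keep the raw `Optional[str]` and the parsed token in separate variables so each name has a single, stable type, instead of rebinding the same variable from `str` to a token object. A minimal illustrative sketch (using `int` as a stand-in for the token classes):

```python
from typing import Optional

def parse_from_token(from_token_str: Optional[str]) -> Optional[int]:
    from_token: Optional[int] = None
    if from_token_str:
        # stand-in for RelationPaginationToken.from_string(...)
        from_token = int(from_token_str)
    return from_token

assert parse_from_token(None) is None
assert parse_from_token("42") == 42
```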
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 32e8500795..e321668698 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -112,7 +112,7 @@ class SyncRestServlet(RestServlet):
default="online",
allowed_values=self.ALLOWED_PRESENCE,
)
- filter_id = parse_string(request, "filter", default=None)
+ filter_id = parse_string(request, "filter")
full_state = parse_boolean(request, "full_state", default=False)
logger.debug(
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 4582c274c7..fa2e4e9cba 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -82,6 +82,8 @@ class VersionsRestServlet(RestServlet):
"io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
# Supports the busy presence state described in MSC3026.
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
+ # Supports receiving hidden read receipts as per MSC2285
+ "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
},
},
)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4282e2b228..11f7320832 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -112,7 +112,7 @@ class ConsentResource(DirectServeHtmlResource):
request (twisted.web.http.Request):
"""
version = parse_string(request, "v", default=self._default_consent_version)
- username = parse_string(request, "u", required=False, default="")
+ username = parse_string(request, "u", default="")
userhmac = None
has_consented = False
public_version = username == ""
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index cd2468f9c5..d6d938953e 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -49,6 +49,8 @@ class DownloadResource(DirectServeJsonResource):
b" media-src 'self';"
b" object-src 'self';",
)
+ # Limited non-standard form of CSP for IE11
+ request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
request.setHeader(
b"Referrer-Policy",
b"no-referrer",
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 8e7fead3a2..0f051d4041 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -58,9 +58,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
+_charset_match = re.compile(
+ br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
+)
_xml_encoding_match = re.compile(
- br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
+ br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
@@ -186,15 +188,11 @@ class PreviewUrlResource(DirectServeJsonResource):
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request: SynapseRequest) -> None:
- # This will always be set by the time Twisted calls us.
- assert request.args is not None
-
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
- url = parse_string(request, "url")
- if b"ts" in request.args:
- ts = parse_integer(request, "ts")
- else:
+ url = parse_string(request, "url", required=True)
+ ts = parse_integer(request, "ts")
+ if ts is None:
ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 6223daf522..463ce58dae 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,6 +16,7 @@ import heapq
import logging
from collections import defaultdict, namedtuple
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -52,6 +53,10 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
+
logger = logging.getLogger(__name__)
metrics_logger = logging.getLogger("synapse.state.metrics")
@@ -74,7 +79,7 @@ _NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
-def _gen_state_id():
+def _gen_state_id() -> str:
global _NEXT_STATE_ID
s = "X%d" % (_NEXT_STATE_ID,)
_NEXT_STATE_ID += 1
@@ -109,7 +114,7 @@ class _StateCacheEntry:
# `state_id` is either a state_group (and so an int) or a string. This
        # ensures we don't accidentally persist a state_id as a state_group
if state_group:
- self.state_id = state_group
+ self.state_id: Union[str, int] = state_group
else:
self.state_id = _gen_state_id()
@@ -122,7 +127,7 @@ class StateHandler:
where necessary
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state_store = hs.get_storage().state
@@ -507,7 +512,7 @@ class StateResolutionHandler:
be storage-independent.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
@@ -636,16 +641,20 @@ class StateResolutionHandler:
"""
try:
with Measure(self.clock, "state._resolve_events") as m:
- v = KNOWN_ROOM_VERSIONS[room_version]
- if v.state_res == StateResolutionVersions.V1:
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ if room_version_obj.state_res == StateResolutionVersions.V1:
return await v1.resolve_events_with_store(
- room_id, state_sets, event_map, state_res_store.get_events
+ room_id,
+ room_version_obj,
+ state_sets,
+ event_map,
+ state_res_store.get_events,
)
else:
return await v2.resolve_events_with_store(
self.clock,
room_id,
- room_version,
+ room_version_obj,
state_sets,
event_map,
state_res_store,
@@ -653,13 +662,15 @@ class StateResolutionHandler:
finally:
self._record_state_res_metrics(room_id, m.get_resource_usage())
- def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage):
+ def _record_state_res_metrics(
+ self, room_id: str, rusage: ContextResourceUsage
+ ) -> None:
room_metrics = self._state_res_metrics[room_id]
room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime
room_metrics.db_time += rusage.db_txn_duration_sec
room_metrics.db_events += rusage.evt_db_fetch_count
- def _report_metrics(self):
+ def _report_metrics(self) -> None:
if not self._state_res_metrics:
# no state res has happened since the last iteration: don't bother logging.
return
@@ -769,16 +780,13 @@ def _make_state_cache_entry(
)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
class StateResolutionStore:
"""Interface that allows state resolution algorithms to access the database
in well defined way.
-
- Args:
- store (DataStore)
"""
- store = attr.ib()
+ store: "DataStore"
def get_events(
self, event_ids: Iterable[str], allow_rejected: bool = False
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 267193cedf..92336d7cc8 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -29,7 +29,7 @@ from typing import (
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
@@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store(
room_id: str,
+ room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
@@ -104,7 +105,7 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
auth_events = _create_auth_events_from_maps(
- unconflicted_state, conflicted_state, state_map
+ room_version, unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(auth_events.values())
@@ -132,7 +133,7 @@ async def resolve_events_with_store(
state_map.update(state_map_new)
return _resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
+ room_version, unconflicted_state, conflicted_state, auth_events, state_map
)
@@ -187,6 +188,7 @@ def _seperate(
def _create_auth_events_from_maps(
+ room_version: RoomVersion,
unconflicted_state: StateMap[str],
conflicted_state: StateMap[Set[str]],
state_map: Dict[str, EventBase],
@@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
"""
Args:
+ room_version: The room version.
unconflicted_state: The unconflicted state map.
conflicted_state: The conflicted state map.
state_map:
@@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
for event_ids in conflicted_state.values():
for event_id in event_ids:
if event_id in state_map:
- keys = event_auth.auth_types_for_event(state_map[event_id])
+ keys = event_auth.auth_types_for_event(
+ room_version, state_map[event_id]
+ )
for key in keys:
if key not in auth_events:
auth_event_id = unconflicted_state.get(key, None)
@@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
def _resolve_with_state(
+ room_version: RoomVersion,
unconflicted_state_ids: MutableStateMap[str],
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
@@ -235,7 +241,9 @@ def _resolve_with_state(
}
try:
- resolved_state = _resolve_state_events(conflicted_state, auth_events)
+ resolved_state = _resolve_state_events(
+ room_version, conflicted_state, auth_events
+ )
except Exception:
logger.exception("Failed to resolve state")
raise
@@ -248,7 +256,9 @@ def _resolve_with_state(
def _resolve_state_events(
- conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+ room_version: RoomVersion,
+ conflicted_state: StateMap[List[EventBase]],
+ auth_events: MutableStateMap[EventBase],
) -> StateMap[EventBase]:
"""This is where we actually decide which of the conflicted state to
use.
@@ -263,21 +273,27 @@ def _resolve_state_events(
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
- resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
+ resolved_state[POWER_KEY] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
- resolved_state[key] = _resolve_auth_events(events, auth_events)
+ resolved_state[key] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
- resolved_state[key] = _resolve_auth_events(events, auth_events)
+ resolved_state[key] = _resolve_auth_events(
+ room_version, events, auth_events
+ )
auth_events.update(resolved_state)
@@ -290,12 +306,14 @@ def _resolve_state_events(
def _resolve_auth_events(
- events: List[EventBase], auth_events: StateMap[EventBase]
+ room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
- key for event in events for key in event_auth.auth_types_for_event(event)
+ key
+ for event in events
+ for key in event_auth.auth_types_for_event(room_version, event)
}
new_auth_events = {}
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e66e6571c8..7b1e8361de 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -36,7 +36,7 @@ import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock
@@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
async def resolve_events_with_store(
clock: Clock,
room_id: str,
- room_version: str,
+ room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
@@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks(
clock: Clock,
room_id: str,
- room_version: str,
+ room_version: RoomVersion,
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
@@ -519,7 +519,6 @@ async def _iterative_auth_checks(
Returns the final updated state
"""
resolved_state = dict(base_state)
- room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
for idx, event_id in enumerate(event_ids, start=1):
event = event_map[event_id]
@@ -538,7 +537,7 @@ async def _iterative_auth_checks(
if ev.rejected_reason is None:
auth_events[(ev.type, ev.state_key)] = ev
- for key in event_auth.auth_types_for_event(event):
+ for key in event_auth.auth_types_for_event(room_version, event):
if key in resolved_state:
ev_id = resolved_state[key]
ev = await _get_event(room_id, ev_id, event_map, state_res_store)
@@ -548,7 +547,7 @@ async def _iterative_auth_checks(
try:
event_auth.check(
- room_version_obj,
+ room_version,
event,
auth_events,
do_sig_check=False,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ccf9ac51ef..c8015a3848 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
import time
+from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
from typing import (
@@ -397,6 +398,7 @@ class DatabasePool:
):
self.hs = hs
self._clock = hs.get_clock()
+ self._txn_limit = database_config.config.get("txn_limit", 0)
self._database_config = database_config
self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
@@ -406,6 +408,9 @@ class DatabasePool:
self._current_txn_total_time = 0.0
self._previous_loop_ts = 0.0
+ # Transaction counter: key is the twisted thread id, value is the current count
+ self._txn_counters: Dict[int, int] = defaultdict(int)
+
# TODO(paul): These can eventually be removed once the metrics code
# is running in mainline, and we have some nice monitoring frontends
# to watch it
@@ -750,10 +755,26 @@ class DatabasePool:
sql_scheduling_timer.observe(sched_duration_sec)
context.add_database_scheduled(sched_duration_sec)
+ if self._txn_limit > 0:
+ tid = self._db_pool.threadID()
+ self._txn_counters[tid] += 1
+
+ if self._txn_counters[tid] > self._txn_limit:
+ logger.debug(
+ "Reconnecting database connection over transaction limit"
+ )
+ conn.reconnect()
+ opentracing.log_kv(
+ {"message": "reconnected due to txn limit"}
+ )
+ self._txn_counters[tid] = 1
+
if self.engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
opentracing.log_kv({"message": "reconnected"})
+ if self._txn_limit > 0:
+ self._txn_counters[tid] = 1
try:
if db_autocommit:
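
The counting logic above can be distilled as follows (a sketch, not the real `DatabasePool` code): `txn_limit` comes from the per-database config (e.g. `txn_limit: 10000`; the default of 0 disables recycling), and each database thread counts its own transactions.

```python
# Minimal model of per-thread transaction counting with connection
# recycling; `conn.reconnect()` mirrors the call used above.
from collections import defaultdict
from typing import DefaultDict

txn_limit = 10000
txn_counters: DefaultDict[int, int] = defaultdict(int)

def maybe_recycle_connection(tid: int, conn) -> None:
    if txn_limit <= 0:
        return
    txn_counters[tid] += 1
    if txn_counters[tid] > txn_limit:
        conn.reconnect()       # swap in a fresh connection
        txn_counters[tid] = 1  # the current transaction counts as the first
```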
@@ -832,31 +853,16 @@ class DatabasePool:
self,
table: str,
values: Dict[str, Any],
- or_ignore: bool = False,
desc: str = "simple_insert",
- ) -> bool:
+ ) -> None:
"""Executes an INSERT query on the named table.
Args:
table: string giving the table name
values: dict of new column names and values for them
- or_ignore: bool stating whether an exception should be raised
- when a conflicting row already exists. If True, False will be
- returned by the function instead
desc: description of the transaction, for logging and metrics
-
- Returns:
- Whether the row was inserted or not. Only useful when `or_ignore` is True
"""
- try:
- await self.runInteraction(desc, self.simple_insert_txn, table, values)
- except self.engine.module.IntegrityError:
- # We have to do or_ignore flag at this layer, since we can't reuse
- # a cursor after we receive an error from the db.
- if not or_ignore:
- raise
- return False
- return True
+ await self.runInteraction(desc, self.simple_insert_txn, table, values)
@staticmethod
def simple_insert_txn(
@@ -930,7 +936,7 @@ class DatabasePool:
insertion_values: Optional[Dict[str, Any]] = None,
desc: str = "simple_upsert",
lock: bool = True,
- ) -> Optional[bool]:
+ ) -> bool:
"""
`lock` should generally be set to True (the default), but can be set
@@ -951,8 +957,8 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
lock: True to lock the table when doing the upsert.
Returns:
- Native upserts always return None. Emulated upserts return True if a
- new entry was created, False if an existing one was updated.
+ Returns True if a row was inserted or updated (i.e. if `values` is
+ not empty then this always returns True)
"""
insertion_values = insertion_values or {}
@@ -995,7 +1001,7 @@ class DatabasePool:
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
lock: bool = True,
- ) -> Optional[bool]:
+ ) -> bool:
"""
Pick the UPSERT method which works best on the platform. Either the
native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
@@ -1008,16 +1014,15 @@ class DatabasePool:
insertion_values: additional key/values to use only when inserting
lock: True to lock the table when doing the upsert.
Returns:
- Native upserts always return None. Emulated upserts return True if a
- new entry was created, False if an existing one was updated.
+ Returns True if a row was inserted or updated (i.e. if `values` is
+ not empty then this always returns True)
"""
insertion_values = insertion_values or {}
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
- self.simple_upsert_txn_native_upsert(
+ return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
- return None
else:
return self.simple_upsert_txn_emulated(
txn,
@@ -1045,8 +1050,8 @@ class DatabasePool:
insertion_values: additional key/values to use only when inserting
lock: True to lock the table when doing the upsert.
Returns:
- Returns True if a new entry was created, False if an existing
- one was updated.
+ Returns True if a row was inserted or updated (i.e. if `values` is
+ not empty then this always returns True)
"""
insertion_values = insertion_values or {}
@@ -1086,8 +1091,7 @@ class DatabasePool:
txn.execute(sql, sqlargs)
if txn.rowcount > 0:
- # successfully updated at least one row.
- return False
+ return True
# We didn't find any existing rows, so insert a new one
allvalues: Dict[str, Any] = {}
@@ -1111,15 +1115,19 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
- ) -> None:
+ ) -> bool:
"""
- Use the native UPSERT functionality in recent PostgreSQL versions.
+ Use the native UPSERT functionality in PostgreSQL.
Args:
table: The table to upsert into
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
+
+ Returns:
+ Returns True if a row was inserted or updated (i.e. if `values` is
+ not empty then this always returns True)
"""
allvalues: Dict[str, Any] = {}
allvalues.update(keyvalues)
@@ -1140,6 +1148,8 @@ class DatabasePool:
)
txn.execute(sql, list(allvalues.values()))
+ return bool(txn.rowcount)
+
async def simple_upsert_many(
self,
table: str,
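
The hunks above unify the upsert return value: both the native and emulated paths now return `True` iff a row was inserted or updated. A pure-Python model of that contract (illustrative only, not Synapse code) is sketched below; it is also why `store_device` later in this diff can replace `simple_insert(..., or_ignore=True)` with `simple_upsert(..., values={})`.

```python
from typing import Any, Dict, Optional

def model_upsert(
    row: Optional[Dict[str, Any]],       # the existing row, if any
    values: Dict[str, Any],              # columns to update on conflict
    insertion_values: Dict[str, Any],    # columns only set on insert
) -> bool:
    if row is None:
        row = {**values, **insertion_values}  # INSERT path -> True
        return True
    if not values:
        return False  # conflict + nothing to update: DO NOTHING -> False
    row.update(values)  # UPDATE path -> True
    return True

assert model_upsert(None, {}, {"display_name": "Alice"}) is True
assert model_upsert({"display_name": "Alice"}, {}, {}) is False
assert model_upsert({"display_name": "Alice"}, {"display_name": "Bob"}, {}) is True
```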
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a3fddea042..8d9f07111d 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -249,7 +249,7 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
- order_by: UserSortOrder = UserSortOrder.USER_ID.value,
+ order_by: str = UserSortOrder.USER_ID.value,
direction: str = "f",
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
@@ -297,27 +297,22 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
- sql_base = """
+ sql_base = f"""
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
- {}
- """.format(
- where_clause
- )
+ {where_clause}
+ """
sql = "SELECT COUNT(*) as total_users " + sql_base
txn.execute(sql, args)
count = txn.fetchone()[0]
- sql = """
- SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url
+ sql = f"""
+ SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
+ displayname, avatar_url, creation_ts * 1000 as creation_ts
{sql_base}
ORDER BY {order_by_column} {order}, u.name ASC
LIMIT ? OFFSET ?
- """.format(
- sql_base=sql_base,
- order_by_column=order_by_column,
- order=order,
- )
+ """
args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 18f07d96dc..3816a0ca53 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1078,16 +1078,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False
try:
- inserted = await self.db_pool.simple_insert(
+ inserted = await self.db_pool.simple_upsert(
"devices",
- values={
+ keyvalues={
"user_id": user_id,
"device_id": device_id,
+ },
+ values={},
+ insertion_values={
"display_name": initial_device_display_name,
"hidden": False,
},
desc="store_device",
- or_ignore=True,
)
if not inserted:
# if the device already exists, check if it's a real device, or
@@ -1099,6 +1101,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
+
self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index d39368c20e..44018c1c31 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,11 +16,11 @@ import logging
from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
-from prometheus_client import Gauge
+from prometheus_client import Counter, Gauge
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
-from synapse.api.room_versions import RoomVersion
+from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@@ -44,6 +44,12 @@ number_pdus_in_federation_queue = Gauge(
"The total number of events in the inbound federation staging",
)
+pdus_pruned_from_federation_queue = Counter(
+ "synapse_federation_server_number_inbound_pdu_pruned",
+ "The number of events in the inbound federation staging that have been "
+ "pruned due to the queue getting too long",
+)
+
logger = logging.getLogger(__name__)
@@ -936,15 +942,46 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# We want to make sure that we do a breadth-first, "depth" ordered
# search.
- query = (
- "SELECT depth, prev_event_id FROM event_edges"
- " INNER JOIN events"
- " ON prev_event_id = events.event_id"
- " WHERE event_edges.event_id = ?"
- " AND event_edges.is_state = ?"
- " LIMIT ?"
- )
+ # Look for the prev_event_id connected to the given event_id
+ query = """
+ SELECT depth, prev_event_id FROM event_edges
+ /* Get the depth of the prev_event_id from the events table */
+ INNER JOIN events
+ ON prev_event_id = events.event_id
+ /* Find an event which matches the given event_id */
+ WHERE event_edges.event_id = ?
+ AND event_edges.is_state = ?
+ LIMIT ?
+ """
+
+ # Look for the "insertion" events connected to the given event_id
+ connected_insertion_event_query = """
+ SELECT e.depth, i.event_id FROM insertion_event_edges AS i
+ /* Get the depth of the insertion event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ /* Find an insertion event which points via prev_events to the given event_id */
+ WHERE i.insertion_prev_event_id = ?
+ LIMIT ?
+ """
+
+ # Find any chunk connections of a given insertion event
+ chunk_connection_query = """
+ SELECT e.depth, c.event_id FROM insertion_events AS i
+ /* Find the chunk that connects to the given insertion event */
+ INNER JOIN chunk_events AS c
+ ON i.next_chunk_id = c.chunk_id
+ /* Get the depth of the chunk start event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ /* Find an insertion event which matches the given event_id */
+ WHERE i.event_id = ?
+ LIMIT ?
+ """
+ # In a PriorityQueue, the lowest valued entries are retrieved first.
+ # We're using depth as the priority in the queue.
+        # Depth is lowest at the oldest-in-time message and highest at the
+        # newest-in-time message. We add events to the queue with a negative depth so that
+ # we process the newest-in-time messages first going backwards in time.
queue = PriorityQueue()
for event_id in event_list:
@@ -970,9 +1007,48 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.add(event_id)
+ # Try and find any potential historical chunks of message history.
+ #
+ # First we look for an insertion event connected to the current
+ # event (by prev_event). If we find any, we need to go and try to
+ # find any chunk events connected to the insertion event (by
+ # chunk_id). If we find any, we'll add them to the queue and
+ # navigate up the DAG like normal in the next iteration of the loop.
+ txn.execute(
+ connected_insertion_event_query, (event_id, limit - len(event_results))
+ )
+ connected_insertion_event_id_results = txn.fetchall()
+ logger.debug(
+ "_get_backfill_events: connected_insertion_event_query %s",
+ connected_insertion_event_id_results,
+ )
+ for row in connected_insertion_event_id_results:
+ connected_insertion_event_depth = row[0]
+ connected_insertion_event = row[1]
+ queue.put((-connected_insertion_event_depth, connected_insertion_event))
+
+ # Find any chunk connections for the given insertion event
+ txn.execute(
+ chunk_connection_query,
+ (connected_insertion_event, limit - len(event_results)),
+ )
+ chunk_start_event_id_results = txn.fetchall()
+ logger.debug(
+ "_get_backfill_events: chunk_start_event_id_results %s",
+ chunk_start_event_id_results,
+ )
+ for row in chunk_start_event_id_results:
+ if row[1] not in event_results:
+ queue.put((-row[0], row[1]))
+
+ # Navigate up the DAG by prev_event
txn.execute(query, (event_id, False, limit - len(event_results)))
+ prev_event_id_results = txn.fetchall()
+ logger.debug(
+ "_get_backfill_events: prev_event_ids %s", prev_event_id_results
+ )
- for row in txn:
+ for row in prev_event_id_results:
if row[1] not in event_results:
queue.put((-row[0], row[1]))
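
A toy illustration of the queue discipline described in the comments above (hypothetical event IDs and depths): `PriorityQueue` pops the smallest tuple first, so negating the depth makes the newest-in-time (highest-depth) events come out first as we walk backwards through the DAG.

```python
from queue import PriorityQueue

queue: "PriorityQueue[tuple[int, str]]" = PriorityQueue()
queue.put((-3, "$newest"))
queue.put((-1, "$oldest"))
queue.put((-2, "$middle"))

while not queue.empty():
    neg_depth, event_id = queue.get()
    # prints: 3 $newest, then 2 $middle, then 1 $oldest
    print(-neg_depth, event_id)
```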
@@ -1207,6 +1283,100 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return origin, event
+ async def prune_staged_events_in_room(
+ self,
+ room_id: str,
+ room_version: RoomVersion,
+ ) -> bool:
+ """Checks if there are lots of staged events for the room, and if so
+ prune them down.
+
+ Returns:
+ Whether any events were pruned
+ """
+
+ # First check the size of the queue.
+ count = await self.db_pool.simple_select_one_onecol(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcol="COALESCE(COUNT(*), 0)",
+ desc="prune_staged_events_in_room_count",
+ )
+
+ if count < 100:
+ return False
+
+        # If the queue is too large, then we want to clear the entire queue,
+ # keeping only the forward extremities (i.e. the events not referenced
+ # by other events in the queue). We do this so that we can always
+        # backpaginate into all the events we have dropped.
+ rows = await self.db_pool.simple_select_list(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ retcols=("event_id", "event_json"),
+ desc="prune_staged_events_in_room_fetch",
+ )
+
+ # Find the set of events referenced by those in the queue, as well as
+ # collecting all the event IDs in the queue.
+ referenced_events: Set[str] = set()
+ seen_events: Set[str] = set()
+ for row in rows:
+ event_id = row["event_id"]
+ seen_events.add(event_id)
+ event_d = db_to_json(row["event_json"])
+
+ # We don't bother parsing the dicts into full blown event objects,
+ # as that is needlessly expensive.
+
+ # We haven't checked that the `prev_events` have the right format
+ # yet, so we check as we go.
+ prev_events = event_d.get("prev_events", [])
+ if not isinstance(prev_events, list):
+ logger.info("Invalid prev_events for %s", event_id)
+ continue
+
+ if room_version.event_format == EventFormatVersions.V1:
+ for prev_event_tuple in prev_events:
+                    if not isinstance(prev_event_tuple, list) or len(prev_event_tuple) != 2:
+ logger.info("Invalid prev_events for %s", event_id)
+ break
+
+ prev_event_id = prev_event_tuple[0]
+ if not isinstance(prev_event_id, str):
+ logger.info("Invalid prev_events for %s", event_id)
+ break
+
+ referenced_events.add(prev_event_id)
+ else:
+ for prev_event_id in prev_events:
+ if not isinstance(prev_event_id, str):
+ logger.info("Invalid prev_events for %s", event_id)
+ break
+
+ referenced_events.add(prev_event_id)
+
+ to_delete = referenced_events & seen_events
+ if not to_delete:
+ return False
+
+ pdus_pruned_from_federation_queue.inc(len(to_delete))
+ logger.info(
+ "Pruning %d events in room %s from federation queue",
+ len(to_delete),
+ room_id,
+ )
+
+ await self.db_pool.simple_delete_many(
+ table="federation_inbound_events_staging",
+ keyvalues={"room_id": room_id},
+ iterable=to_delete,
+ column="event_id",
+ desc="prune_staged_events_in_room_delete",
+ )
+
+ return True
+
async def get_all_rooms_with_staged_incoming_events(self) -> List[str]:
"""Get the room IDs of all events currently staged."""
return await self.db_pool.simple_select_onecol(
@@ -1227,12 +1397,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
(count,) = txn.fetchone()
txn.execute(
- "SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging"
+ "SELECT min(received_ts) FROM federation_inbound_events_staging"
)
(received_ts,) = txn.fetchone()
- age = self._clock.time_msec() - received_ts
+        # If there is nothing in the staging area, default the age to 0.
+ age = 0
+ if received_ts is not None:
+ age = self._clock.time_msec() - received_ts
return count, age
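
The pruning rule in `prune_staged_events_in_room` above boils down to a set computation: any staged event that another staged event points at via `prev_events` can be dropped (backfill can re-fetch it), while the forward extremities are kept. A distilled sketch with hypothetical events:

```python
# event_id -> prev_event_ids, all within the staging area
staged = {
    "$a": [],
    "$b": ["$a"],
    "$c": ["$b"],
}

seen = set(staged)
referenced = {prev for prevs in staged.values() for prev in prevs}
to_delete = referenced & seen

print(to_delete)         # {'$a', '$b'} – referenced, so safe to drop
print(seen - to_delete)  # {'$c'} – the forward extremity we keep
```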
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a396a201d4..86baf397fb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1502,6 +1502,9 @@ class PersistEventsStore:
self._handle_event_relations(txn, event)
+ self._handle_insertion_event(txn, event)
+ self._handle_chunk_event(txn, event)
+
# Store the labels for this event.
labels = event.content.get(EventContentFields.LABELS)
if labels:
@@ -1754,6 +1757,94 @@ class PersistEventsStore:
if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
+ def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+ """Handles keeping track of insertion events and edges/connections.
+ Part of MSC2716.
+
+ Args:
+ txn: The database transaction object
+ event: The event to process
+ """
+
+ if event.type != EventTypes.MSC2716_INSERTION:
+            # Not an insertion event
+ return
+
+        # Skip processing an insertion event if the room version doesn't
+ # support it.
+ room_version = self.store.get_room_version_txn(txn, event.room_id)
+ if not room_version.msc2716_historical:
+ return
+
+ next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID)
+ if next_chunk_id is None:
+ # Invalid insertion event without next chunk ID
+ return
+
+ logger.debug(
+ "_handle_insertion_event (next_chunk_id=%s) %s", next_chunk_id, event
+ )
+
+ # Keep track of the insertion event and the chunk ID
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="insertion_events",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "next_chunk_id": next_chunk_id,
+ },
+ )
+
+ # Insert an edge for every prev_event connection
+ for prev_event_id in event.prev_events:
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="insertion_event_edges",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "insertion_prev_event_id": prev_event_id,
+ },
+ )
+
+ def _handle_chunk_event(self, txn: LoggingTransaction, event: EventBase):
+ """Handles inserting the chunk edges/connections between the chunk event
+ and an insertion event. Part of MSC2716.
+
+ Args:
+ txn: The database transaction object
+ event: The event to process
+ """
+
+ if event.type != EventTypes.MSC2716_CHUNK:
+ # Not a chunk event
+ return
+
+ # Skip processing a chunk event if the room version doesn't
+ # support it.
+ room_version = self.store.get_room_version_txn(txn, event.room_id)
+ if not room_version.msc2716_historical:
+ return
+
+ chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID)
+ if chunk_id is None:
+ # Invalid chunk event without a chunk ID
+ return
+
+ logger.debug("_handle_chunk_event chunk_id=%s %s", chunk_id, event)
+
+        # Keep track of the chunk event and the chunk ID it is connected to
+ self.db_pool.simple_insert_txn(
+ txn,
+ table="chunk_events",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "chunk_id": chunk_id,
+ },
+ )
+
def _handle_redaction(self, txn, redacted_event_id):
"""Handles receiving a redaction and checking whether we need to remove
any redacted relations from the database.
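
To tie the two handlers above together, here is a hypothetical insertion event's content (field names come from the MSC2716 constants added earlier in this diff; the values are illustrative). `_handle_insertion_event` records one `insertion_events` row plus one `insertion_event_edges` row per prev_event for such an event.

```python
# Hypothetical content of an insertion event:
insertion_event_content = {
    "org.matrix.msc2716.next_chunk_id": "chunk_abc123",
    "org.matrix.msc2716.historical": True,
}

# The chunk event that closes the next historical batch then carries the
# matching ID, which _handle_chunk_event stores in the chunk_events table
# so backfill can hop from the insertion event into the batch:
chunk_event_content = {
    "org.matrix.msc2716.chunk_id": "chunk_abc123",
}
```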
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index fe25638289..d213b26703 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -297,17 +297,13 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
Args:
txn (cursor):
user_id (str): user to add/update
-
- Returns:
- bool: True if a new entry was created, False if an
- existing one was updated.
"""
        # Am consciously deciding to lock the table on the basis that it ought
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = self.db_pool.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
@@ -322,8 +318,6 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,)
)
- return is_insert
-
async def populate_monthly_active_users(self, user_id):
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
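
Context for dropping the return value: with native upserts, `simple_upsert_txn` issues a single `INSERT ... ON CONFLICT DO UPDATE`, which succeeds identically whether it inserted or updated, so the insert-vs-update boolean is no longer cheap to produce (and no caller used it). An illustrative sketch with the standard sqlite3 module (SQLite ≥ 3.24; not Synapse's DatabasePool, and the schema is abridged):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE monthly_active_users(user_id TEXT PRIMARY KEY, timestamp BIGINT)"
)

def upsert_mau(user_id: str, ts: int) -> None:
    # Succeeds identically whether it inserted a new row or updated an
    # existing one, which is why the is_insert return value had to go.
    conn.execute(
        "INSERT INTO monthly_active_users(user_id, timestamp) VALUES (?, ?) "
        "ON CONFLICT(user_id) DO UPDATE SET timestamp = excluded.timestamp",
        (user_id, ts),
    )

upsert_mau("@alice:example.org", 1000)  # creates the row
upsert_mau("@alice:example.org", 2000)  # updates it; indistinguishable here
```
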
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6ddafe5434..443e5f3315 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -363,7 +363,7 @@ class RoomWorkerStore(SQLBaseStore):
self,
start: int,
limit: int,
- order_by: RoomSortOrder,
+ order_by: str,
reverse_order: bool,
search_term: Optional[str],
) -> Tuple[List[Dict[str, Any]], int]:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 1757064a68..8e22da99ae 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -22,7 +22,7 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
@@ -58,15 +58,32 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_version(self, room_id: str) -> RoomVersion:
"""Get the room_version of a given room
-
Raises:
NotFoundError: if the room is unknown
+ UnsupportedRoomVersionError: if the room uses an unknown room version.
+ Typically this happens if support for the room's version has been
+ removed from Synapse.
+ """
+ return await self.db_pool.runInteraction(
+ "get_room_version_txn",
+ self.get_room_version_txn,
+ room_id,
+ )
+ def get_room_version_txn(
+ self, txn: LoggingTransaction, room_id: str
+ ) -> RoomVersion:
+ """Get the room_version of a given room
+ Args:
+ txn: Transaction object
+ room_id: The room_id of the room you are trying to get the version for
+ Raises:
+ NotFoundError: if the room is unknown
UnsupportedRoomVersionError: if the room uses an unknown room version.
Typically this happens if support for the room's version has been
removed from Synapse.
"""
- room_version_id = await self.get_room_version_id(room_id)
+ room_version_id = self.get_room_version_id_txn(txn, room_id)
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v:
@@ -80,7 +97,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
@cached(max_entries=10000)
async def get_room_version_id(self, room_id: str) -> str:
"""Get the room_version of a given room
+ Raises:
+ NotFoundError: if the room is unknown
+ """
+ return await self.db_pool.runInteraction(
+ "get_room_version_id_txn",
+ self.get_room_version_id_txn,
+ room_id,
+ )
+ def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str:
+ """Get the room_version of a given room
+ Args:
+ txn: Transaction object
+ room_id: The room_id of the room you are trying to get the version for
Raises:
NotFoundError: if the room is unknown
"""
@@ -88,24 +118,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# First we try looking up room version from the database, but for old
# rooms we might not have added the room version to it yet so we fall
# back to previous behaviour and look in current state events.
-
+ #
# We really should have an entry in the rooms table for every room we
# care about, but let's be a bit paranoid (at least while the background
# update is happening) to avoid breaking existing rooms.
- version = await self.db_pool.simple_select_one_onecol(
+ room_version = self.db_pool.simple_select_one_onecol_txn(
+ txn,
table="rooms",
keyvalues={"room_id": room_id},
retcol="room_version",
- desc="get_room_version",
allow_none=True,
)
- if version is not None:
- return version
+ if room_version is None:
+ raise NotFoundError("Could not find room_version for %s" % (room_id,))
- # Retrieve the room's create event
- create_event = await self.get_create_event_for_room(room_id)
- return create_event.content.get("room_version", "1")
+ return room_version
async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
"""Get the predecessor of an upgraded room if it exists.
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 59d67c255b..42edbcc057 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -75,6 +75,7 @@ class UserSortOrder(Enum):
USER_TYPE = ordered alphabetically by `user_type`
AVATAR_URL = ordered alphabetically by `avatar_url`
SHADOW_BANNED = ordered by `shadow_banned`
+ CREATION_TS = ordered by `creation_ts`
"""
MEDIA_LENGTH = "media_length"
@@ -88,6 +89,7 @@ class UserSortOrder(Enum):
USER_TYPE = "user_type"
AVATAR_URL = "avatar_url"
SHADOW_BANNED = "shadow_banned"
+ CREATION_TS = "creation_ts"
class StatsStore(StateDeltasStore):
@@ -647,10 +649,10 @@ class StatsStore(StateDeltasStore):
limit: int,
from_ts: Optional[int] = None,
until_ts: Optional[int] = None,
- order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value,
+ order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Optional[str] = "f",
search_term: Optional[str] = None,
- ) -> Tuple[List[JsonDict], Dict[str, int]]:
+ ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
(size and number). This will return a json list of users and the
total number of users matching the filter criteria.
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index d211c423b2..7728d5f102 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -134,16 +134,18 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
response_dict: The response, to be encoded into JSON.
"""
- await self.db_pool.simple_insert(
+ await self.db_pool.simple_upsert(
table="received_transactions",
- values={
+ keyvalues={
"transaction_id": transaction_id,
"origin": origin,
+ },
+ values={},
+ insertion_values={
"response_code": code,
"response_json": db_binary_type(encode_canonical_json(response_dict)),
"ts": self._clock.time_msec(),
},
- or_ignore=True,
desc="set_received_txn_response",
)
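
Replacing `simple_insert(..., or_ignore=True)` with an upsert whose `values` dict is empty keeps the semantics: with nothing to update on conflict, the statement degenerates to insert-or-ignore, so the first stored response for a `(transaction_id, origin)` pair wins. An equivalent-SQL sketch in sqlite3 (schema abridged; not Synapse's actual table definition):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE received_transactions("
    "  transaction_id TEXT, origin TEXT, response_code INTEGER,"
    "  PRIMARY KEY (transaction_id, origin))"
)

def set_received_txn_response(txn_id: str, origin: str, code: int) -> None:
    # Empty update set on conflict == insert-or-ignore: duplicates of the
    # same transaction do not overwrite the first stored response.
    conn.execute(
        "INSERT INTO received_transactions(transaction_id, origin, response_code)"
        " VALUES (?, ?, ?) ON CONFLICT(transaction_id, origin) DO NOTHING",
        (txn_id, origin, code),
    )

set_received_txn_response("txn1", "other.example.org", 200)
set_received_txn_response("txn1", "other.example.org", 500)  # ignored
assert conn.execute(
    "SELECT response_code FROM received_transactions"
).fetchall() == [(200,)]
```
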
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index a6bfb4902a..9d28d69ac7 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -377,7 +377,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
avatar_url = None
def _update_profile_in_user_dir_txn(txn):
- new_entry = self.db_pool.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
@@ -388,8 +388,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
if isinstance(self.database_engine, PostgresEngine):
# We weight the localpart most highly, then display name and finally
# server name
- if self.database_engine.can_native_upsert:
- sql = """
+ sql = """
INSERT INTO user_directory_search(user_id, vector)
VALUES (?,
setweight(to_tsvector('simple', ?), 'A')
@@ -397,58 +396,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|| setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
"""
- txn.execute(
- sql,
- (
- user_id,
- get_localpart_from_id(user_id),
- get_domain_from_id(user_id),
- display_name,
- ),
- )
- else:
- # TODO: Remove this code after we've bumped the minimum version
- # of postgres to always support upserts, so we can get rid of
- # `new_entry` usage
- if new_entry is True:
- sql = """
- INSERT INTO user_directory_search(user_id, vector)
- VALUES (?,
- setweight(to_tsvector('simple', ?), 'A')
- || setweight(to_tsvector('simple', ?), 'D')
- || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
- )
- """
- txn.execute(
- sql,
- (
- user_id,
- get_localpart_from_id(user_id),
- get_domain_from_id(user_id),
- display_name,
- ),
- )
- elif new_entry is False:
- sql = """
- UPDATE user_directory_search
- SET vector = setweight(to_tsvector('simple', ?), 'A')
- || setweight(to_tsvector('simple', ?), 'D')
- || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
- WHERE user_id = ?
- """
- txn.execute(
- sql,
- (
- get_localpart_from_id(user_id),
- get_domain_from_id(user_id),
- display_name,
- user_id,
- ),
- )
- else:
- raise RuntimeError(
- "upsert returned None when 'can_native_upsert' is False"
- )
+ txn.execute(
+ sql,
+ (
+ user_id,
+ get_localpart_from_id(user_id),
+ get_domain_from_id(user_id),
+ display_name,
+ ),
+ )
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name) if display_name else user_id
self.db_pool.simple_upsert_txn(
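
With the minimum supported PostgreSQL now always providing native upserts, the insert/update branching collapses into the single `ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector` statement, always executed with the same four parameters. A plain-Python sketch of how those parameters line up with the weights (the helper bodies here are stand-ins for Synapse's `get_localpart_from_id`/`get_domain_from_id`):

```python
from typing import Optional, Tuple

def get_localpart_from_id(user_id: str) -> str:
    # Stand-in for Synapse's helper of the same name.
    return user_id.partition(":")[0].lstrip("@")

def get_domain_from_id(user_id: str) -> str:
    # Stand-in for Synapse's helper of the same name.
    return user_id.partition(":")[2]

def search_upsert_args(user_id: str, display_name: Optional[str]) -> Tuple:
    # Placeholder order in the statement: full user ID, then localpart
    # (weight A), domain (weight D) and display name (weight B, may be NULL).
    return (
        user_id,
        get_localpart_from_id(user_id),
        get_domain_from_id(user_id),
        display_name,
    )

assert search_upsert_args("@alice:example.org", "Alice") == (
    "@alice:example.org",
    "alice",
    "example.org",
    "Alice",
)
```
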
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e38461adbc..f839c0c24f 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -372,18 +372,23 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
async def store_state_group(
- self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ self,
+ event_id: str,
+ room_id: str,
+ prev_group: Optional[int],
+ delta_ids: Optional[StateMap[str]],
+ current_state_ids: StateMap[str],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
- event_id (str): The event ID for which the state was calculated
- room_id (str)
- prev_group (int|None): A previous state group for the room, optional.
- delta_ids (dict|None): The delta between state at `prev_group` and
+ event_id: The event ID for which the state was calculated
+ room_id
+ prev_group: A previous state group for the room, optional.
+ delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
- current_state_ids (dict): The state to store. Map of (type, state_key)
+ current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index a39877f0d5..0e8270746d 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -170,7 +170,9 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
end_item = queue[-1]
else:
# need to make a new queue item
- deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+ deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
+ defer.Deferred(), consumeErrors=True
+ )
end_item = _EventPersistQueueItem(
events_and_contexts=[],
diff --git a/synapse/storage/schema/main/delta/61/01insertion_event_lookups.sql b/synapse/storage/schema/main/delta/61/01insertion_event_lookups.sql
new file mode 100644
index 0000000000..7d7bafc631
--- /dev/null
+++ b/synapse/storage/schema/main/delta/61/01insertion_event_lookups.sql
@@ -0,0 +1,49 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a table that keeps track of "insertion" events and
+-- their next_chunk_ids so we can navigate to the next chunk of history.
+CREATE TABLE IF NOT EXISTS insertion_events(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ next_chunk_id TEXT NOT NULL
+);
+CREATE UNIQUE INDEX IF NOT EXISTS insertion_events_event_id ON insertion_events(event_id);
+CREATE INDEX IF NOT EXISTS insertion_events_next_chunk_id ON insertion_events(next_chunk_id);
+
+-- Add a table that keeps track of all of the events we are inserting between.
+-- We use this when navigating the DAG: when we hit an event which matches
+-- `insertion_prev_event_id`, we should backfill from the "insertion" event and
+-- navigate the historical messages from there.
+CREATE TABLE IF NOT EXISTS insertion_event_edges(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ insertion_prev_event_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS insertion_event_edges_event_id ON insertion_event_edges(event_id);
+CREATE INDEX IF NOT EXISTS insertion_event_edges_insertion_room_id ON insertion_event_edges(room_id);
+CREATE INDEX IF NOT EXISTS insertion_event_edges_insertion_prev_event_id ON insertion_event_edges(insertion_prev_event_id);
+
+-- Add a table that keeps track of how each chunk is labeled. The chunks are
+-- connected together based on an insertion event's `next_chunk_id`.
+CREATE TABLE IF NOT EXISTS chunk_events(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ chunk_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS chunk_events_event_id ON chunk_events(event_id);
+CREATE INDEX IF NOT EXISTS chunk_events_chunk_id ON chunk_events(chunk_id);
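
Taken together, the tables form a lookup chain: a chunk event's `chunk_id` matches the `next_chunk_id` of the insertion event it continues, and `insertion_event_edges` records where an insertion event was attached to the normal DAG. A sketch of the central join, exercised in sqlite3 with made-up event IDs:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE insertion_events(
        event_id TEXT NOT NULL, room_id TEXT NOT NULL, next_chunk_id TEXT NOT NULL);
    CREATE TABLE chunk_events(
        event_id TEXT NOT NULL, room_id TEXT NOT NULL, chunk_id TEXT NOT NULL);

    INSERT INTO insertion_events VALUES ('$insertion1', '!r:example.org', 'chunkA');
    INSERT INTO chunk_events VALUES ('$chunk1', '!r:example.org', 'chunkA');
    """
)

# Which insertion event does this chunk continue from?
row = conn.execute(
    """
    SELECT i.event_id
    FROM chunk_events AS c
    INNER JOIN insertion_events AS i ON i.next_chunk_id = c.chunk_id
    WHERE c.event_id = ?
    """,
    ("$chunk1",),
).fetchone()
assert row == ("$insertion1",)
```
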
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index f8fbba9d38..e5400d681a 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -570,8 +570,8 @@ class StateGroupStorage:
event_id: str,
room_id: str,
prev_group: Optional[int],
- delta_ids: Optional[dict],
- current_state_ids: dict,
+ delta_ids: Optional[StateMap[str]],
+ current_state_ids: StateMap[str],
) -> int:
"""Store a new set of state, returning a newly assigned state group.
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 13d300588b..cf4005984b 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -47,20 +47,22 @@ class PaginationConfig:
) -> "PaginationConfig":
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
- from_tok = parse_string(request, "from")
- to_tok = parse_string(request, "to")
+ from_tok_str = parse_string(request, "from")
+ to_tok_str = parse_string(request, "to")
try:
- if from_tok == "END":
+ from_tok = None
+ if from_tok_str == "END":
from_tok = None # For backwards compat.
- elif from_tok:
- from_tok = await StreamToken.from_string(store, from_tok)
+ elif from_tok_str:
+ from_tok = await StreamToken.from_string(store, from_tok_str)
except Exception:
raise SynapseError(400, "'from' parameter is invalid")
try:
- if to_tok:
- to_tok = await StreamToken.from_string(store, to_tok)
+ to_tok = None
+ if to_tok_str:
+ to_tok = await StreamToken.from_string(store, to_tok_str)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 014db1355b..a3b65aee27 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -23,6 +23,7 @@ from typing import (
Awaitable,
Callable,
Dict,
+ Generic,
Hashable,
Iterable,
List,
@@ -39,6 +40,7 @@ from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
+from twisted.python.failure import Failure
from synapse.logging.context import (
PreserveLoggingContext,
@@ -49,8 +51,10 @@ from synapse.util import Clock, unwrapFirstError
logger = logging.getLogger(__name__)
+_T = TypeVar("_T")
-class ObservableDeferred:
+
+class ObservableDeferred(Generic[_T]):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
deferred.
@@ -68,7 +72,7 @@ class ObservableDeferred:
__slots__ = ["_deferred", "_observers", "_result"]
- def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
+ def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
@@ -113,7 +117,7 @@ class ObservableDeferred:
deferred.addCallbacks(callback, errback)
- def observe(self) -> defer.Deferred:
+ def observe(self) -> "defer.Deferred[_T]":
"""Observe the underlying deferred.
This returns a brand new deferred that is resolved when the underlying
@@ -121,7 +125,7 @@ class ObservableDeferred:
effect the underlying deferred.
"""
if not self._result:
- d = defer.Deferred()
+ d: "defer.Deferred[_T]" = defer.Deferred()
def remove(r):
self._observers.discard(d)
@@ -135,7 +139,7 @@ class ObservableDeferred:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self) -> List[defer.Deferred]:
+ def observers(self) -> "List[defer.Deferred[_T]]":
return self._observers
def has_called(self) -> bool:
@@ -144,7 +148,7 @@ class ObservableDeferred:
def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True
- def get_result(self) -> Any:
+ def get_result(self) -> Union[_T, Failure]:
return self._result[1]
def __getattr__(self, name: str) -> Any:
@@ -415,7 +419,7 @@ class ReadWriteLock:
self.key_to_current_writer: Dict[str, defer.Deferred] = {}
async def read(self, key: str) -> ContextManager:
- new_defer = defer.Deferred()
+ new_defer: "defer.Deferred[None]" = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
@@ -438,7 +442,7 @@ class ReadWriteLock:
return _ctx_manager()
async def write(self, key: str) -> ContextManager:
- new_defer = defer.Deferred()
+ new_defer: "defer.Deferred[None]" = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set())
curr_writer = self.key_to_current_writer.get(key, None)
@@ -471,10 +475,8 @@ R = TypeVar("R")
def timeout_deferred(
- deferred: defer.Deferred,
- timeout: float,
- reactor: IReactorTime,
-) -> defer.Deferred:
+ deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
+) -> "defer.Deferred[_T]":
"""The in built twisted `Deferred.addTimeout` fails to time out deferreds
that have a canceller that throws exceptions. This method creates a new
deferred that wraps and times out the given deferred, correctly handling
@@ -497,7 +499,7 @@ def timeout_deferred(
Returns:
A new Deferred, which will errback with defer.TimeoutError on timeout.
"""
- new_d = defer.Deferred()
+ new_d: "defer.Deferred[_T]" = defer.Deferred()
timed_out = [False]
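
Parametrising `ObservableDeferred` over `_T` lets mypy carry the result type through `observe()`, `observers()` and `get_result()` instead of collapsing to `Any`. A toy illustration of the `Generic[_T]` wrapper pattern, with a simple callback fan-out standing in for the deferred machinery:

```python
from typing import Callable, Generic, List, TypeVar

_T = TypeVar("_T")

class Observable(Generic[_T]):
    """Toy fan-out: one result delivered to every registered observer."""

    def __init__(self) -> None:
        self._observers: List[Callable[[_T], None]] = []

    def observe(self, cb: Callable[[_T], None]) -> None:
        self._observers.append(cb)

    def fire(self, result: _T) -> None:
        for cb in self._observers:
            cb(result)

obs: Observable[int] = Observable()
obs.observe(lambda n: print(n + 1))  # mypy knows n is an int, not Any
obs.fire(41)  # prints 42
```
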
diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py
index 891bee0b33..e58dd91eda 100644
--- a/synapse/util/caches/cached_call.py
+++ b/synapse/util/caches/cached_call.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import enum
from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
from twisted.internet.defer import Deferred
@@ -22,6 +22,10 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
TV = TypeVar("TV")
+class _Sentinel(enum.Enum):
+ sentinel = object()
+
+
class CachedCall(Generic[TV]):
"""A wrapper for asynchronous calls whose results should be shared
@@ -65,7 +69,7 @@ class CachedCall(Generic[TV]):
"""
self._callable: Optional[Callable[[], Awaitable[TV]]] = f
self._deferred: Optional[Deferred] = None
- self._result: Union[None, Failure, TV] = None
+ self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel
async def get(self) -> TV:
"""Kick off the call if necessary, and return the result"""
@@ -78,8 +82,9 @@ class CachedCall(Generic[TV]):
self._callable = None
# once the deferred completes, store the result. We cannot simply leave the
- # result in the deferred, since if it's a Failure, GCing the deferred
- # would then log a critical error about unhandled Failures.
+ # result in the deferred, since `awaiting` a deferred destroys its result.
+ # (Also, if it's a Failure, GCing the deferred would log a critical error
+ # about unhandled Failures)
def got_result(r):
self._result = r
@@ -92,13 +97,15 @@ class CachedCall(Generic[TV]):
# and any eventual exception may not be reported.
# we can now await the deferred, and once it completes, return the result.
- await make_deferred_yieldable(self._deferred)
+ if isinstance(self._result, _Sentinel):
+ await make_deferred_yieldable(self._deferred)
+ assert not isinstance(self._result, _Sentinel)
+
+ if isinstance(self._result, Failure):
+ self._result.raiseException()
+ raise AssertionError("unexpected return from Failure.raiseException")
- # I *think* this is the easiest way to correctly raise a Failure without having
- # to gut-wrench into the implementation of Deferred.
- d = Deferred()
- d.callback(self._result)
- return await d
+ return self._result
class RetryOnExceptionCachedCall(Generic[TV]):
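
The `_Sentinel` enum exists because `None` is ambiguous here: the wrapped coroutine may legitimately return `None`, so "no result yet" needs a value no caller can produce. A stripped-down sketch of the pattern in plain asyncio (the real `CachedCall` also shares a Deferred between concurrent callers and re-raises stored Failures):

```python
import asyncio
import enum
from typing import Union

class _Sentinel(enum.Enum):
    sentinel = object()

class OnceCall:
    """Toy stand-in for CachedCall: run a 0-arg coroutine function once."""

    def __init__(self, f):
        self._f = f
        # None cannot mean "no result yet", because f may legitimately
        # return None; the enum member is a value no caller can collide with.
        self._result: Union[_Sentinel, object] = _Sentinel.sentinel

    async def get(self):
        if isinstance(self._result, _Sentinel):
            self._result = await self._f()
        return self._result

calls = 0

async def fetch() -> None:
    global calls
    calls += 1
    return None  # a real None result, not "missing"

async def main() -> None:
    cached = OnceCall(fetch)
    assert await cached.get() is None
    assert await cached.get() is None
    assert calls == 1  # the None result was cached, not recomputed

asyncio.run(main())
```
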
diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index 8c6fafc677..b6456392cd 100644
--- a/synapse/util/caches/deferred_cache.py
+++ b/synapse/util/caches/deferred_cache.py
@@ -16,7 +16,16 @@
import enum
import threading
-from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union
+from typing import (
+ Callable,
+ Generic,
+ Iterable,
+ MutableMapping,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
from prometheus_client import Gauge
@@ -166,7 +175,7 @@ class DeferredCache(Generic[KT, VT]):
def set(
self,
key: KT,
- value: defer.Deferred,
+ value: "defer.Deferred[VT]",
callback: Optional[Callable[[], None]] = None,
) -> defer.Deferred:
"""Adds a new entry to the cache (or updates an existing one).
@@ -214,7 +223,7 @@ class DeferredCache(Generic[KT, VT]):
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
- self.cache.set(key, result, callbacks)
+ self.cache.set(key, cast(VT, result), callbacks)
return value
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 1e8e6b1d01..1ca31e41ac 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -413,7 +413,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
# relevant result for that key.
deferreds_map = {}
for arg in missing:
- deferred = defer.Deferred()
+ deferred: "defer.Deferred[Any]" = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
cache.set(key, deferred, callback=invalidate_callback)
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index dfa30a6229..cb08af7385 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");