diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 67e789eef7..7e01c18c6c 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -155,9 +155,6 @@ class AccountDataHandler:
max_stream_id = await self._store.remove_account_data_for_room(
user_id, room_id, account_data_type
)
- if max_stream_id is None:
- # The referenced account data did not exist, so no delete occurred.
- return None
self._notifier.on_new_event(
StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
@@ -230,9 +227,6 @@ class AccountDataHandler:
max_stream_id = await self._store.remove_account_data_for_user(
user_id, account_data_type
)
- if max_stream_id is None:
- # The referenced account data did not exist, so no delete occurred.
- return None
self._notifier.on_new_event(
StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
@@ -248,7 +242,6 @@ class AccountDataHandler:
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
account_data_type=account_data_type,
- content={},
)
return response["max_stream_id"]
@@ -343,10 +336,12 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
}
)
- (
- account_data,
- room_account_data,
- ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+ account_data = await self.store.get_updated_global_account_data_for_user(
+ user_id, last_stream_id
+ )
+ room_account_data = await self.store.get_updated_room_account_data_for_user(
+ user_id, last_stream_id
+ )
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index b03c214b14..b06f25b03c 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,7 +14,7 @@
import abc
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class AdminHandler:
def __init__(self, hs: "HomeServer"):
- self.store = hs.get_datastores().main
+ self._store = hs.get_datastores().main
self._device_handler = hs.get_device_handler()
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -38,7 +38,7 @@ class AdminHandler:
async def get_whois(self, user: UserID) -> JsonDict:
connections = []
- sessions = await self.store.get_user_ip_and_agents(user)
+ sessions = await self._store.get_user_ip_and_agents(user)
for session in sessions:
connections.append(
{
@@ -57,7 +57,7 @@ class AdminHandler:
async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
- user_info_dict = await self.store.get_user_by_id(user.to_string())
+ user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
return None
@@ -89,11 +89,11 @@ class AdminHandler:
}
# Add additional user metadata
- profile = await self.store.get_profileinfo(user.localpart)
- threepids = await self.store.user_get_threepids(user.to_string())
+ profile = await self._store.get_profileinfo(user.localpart)
+ threepids = await self._store.user_get_threepids(user.to_string())
external_ids = [
({"auth_provider": auth_provider, "external_id": external_id})
- for auth_provider, external_id in await self.store.get_external_ids_by_user(
+ for auth_provider, external_id in await self._store.get_external_ids_by_user(
user.to_string()
)
]
@@ -101,7 +101,7 @@ class AdminHandler:
user_info_dict["avatar_url"] = profile.avatar_url
user_info_dict["threepids"] = threepids
user_info_dict["external_ids"] = external_ids
- user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
+ user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
return user_info_dict
@@ -117,7 +117,7 @@ class AdminHandler:
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
- rooms = await self.store.get_rooms_for_local_user_where_membership_is(
+ rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,
@@ -131,7 +131,7 @@ class AdminHandler:
# We only try and fetch events for rooms the user has been in. If
# they've been e.g. invited to a room without joining then we handle
# those separately.
- rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
+ rooms_user_has_been_in = await self._store.get_rooms_user_has_been_in(user_id)
for index, room in enumerate(rooms):
room_id = room.room_id
@@ -140,7 +140,7 @@ class AdminHandler:
"[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
)
- forgotten = await self.store.did_forget(user_id, room_id)
+ forgotten = await self._store.did_forget(user_id, room_id)
if forgotten:
logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
continue
@@ -152,14 +152,14 @@ class AdminHandler:
if room.membership == Membership.INVITE:
event_id = room.event_id
- invite = await self.store.get_event(event_id, allow_none=True)
+ invite = await self._store.get_event(event_id, allow_none=True)
if invite:
invited_state = invite.unsigned["invite_room_state"]
writer.write_invite(room_id, invite, invited_state)
if room.membership == Membership.KNOCK:
event_id = room.event_id
- knock = await self.store.get_event(event_id, allow_none=True)
+ knock = await self._store.get_event(event_id, allow_none=True)
if knock:
knock_state = knock.unsigned["knock_room_state"]
writer.write_knock(room_id, knock, knock_state)
@@ -170,7 +170,7 @@ class AdminHandler:
# were joined. We estimate that point by looking at the
# stream_ordering of the last membership if it wasn't a join.
if room.membership == Membership.JOIN:
- stream_ordering = self.store.get_room_max_stream_ordering()
+ stream_ordering = self._store.get_room_max_stream_ordering()
else:
stream_ordering = room.stream_ordering
@@ -197,7 +197,7 @@ class AdminHandler:
# events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything.
while True:
- events, _ = await self.store.paginate_room_events(
+ events, _ = await self._store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
)
if not events:
@@ -252,16 +252,49 @@ class AdminHandler:
profile = await self.get_user(UserID.from_string(user_id))
if profile is not None:
writer.write_profile(profile)
+ logger.info("[%s] Written profile", user_id)
# Get all devices the user has
devices = await self._device_handler.get_devices_by_user(user_id)
writer.write_devices(devices)
+ logger.info("[%s] Written %s devices", user_id, len(devices))
# Get all connections the user has
connections = await self.get_whois(UserID.from_string(user_id))
writer.write_connections(
connections["devices"][""]["sessions"][0]["connections"]
)
+ logger.info("[%s] Written %s connections", user_id, len(connections))
+
+ # Get all account data the user has, both global and per-room
+ global_data = await self._store.get_global_account_data_for_user(user_id)
+ by_room_data = await self._store.get_room_account_data_for_user(user_id)
+ writer.write_account_data("global", global_data)
+ for room_id in by_room_data:
+ writer.write_account_data(room_id, by_room_data[room_id])
+ logger.info(
+ "[%s] Written account data for %s rooms", user_id, len(by_room_data)
+ )
+
+ # Get all media ids the user has
+ limit = 100
+ start = 0
+ while True:
+ media_ids, total = await self._store.get_local_media_by_user_paginate(
+ start, limit, user_id
+ )
+ for media in media_ids:
+ writer.write_media_id(media["media_id"], media)
+
+ logger.info(
+ "[%s] Written %d media_ids of %s",
+ user_id,
+ (start + len(media_ids)),
+ total,
+ )
+ if (start + limit) >= total:
+ break
+ start += limit
return writer.finished()
@@ -341,6 +374,30 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ def write_account_data(
+ self, file_name: str, account_data: Mapping[str, JsonDict]
+ ) -> None:
+ """Write the account data of a user.
+
+ Args:
+ file_name: file name to write the data to: "global" for global account
+ data, otherwise a room ID
+ account_data: mapping of event type to account data content
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
+ """Write the media's metadata of a user.
+ Exports only the metadata, as this can be fetched from the database via
+ read only. In order to access the files, a connection to the correct
+ media repository would be required.
+
+ Args:
+ media_id: ID of the media.
+ media_metadata: Metadata of one media file.
+ """
+
+ @abc.abstractmethod
def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index b4a3ad217a..455d4005df 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -737,7 +737,7 @@ class ApplicationServicesHandler:
)
ret = []
- for (success, result) in results:
+ for success, result in results:
if success:
ret.extend(result)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 30f2d46c3c..308e38edea 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -201,7 +201,7 @@ class AuthHandler:
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
- self.checkers[inst.AUTH_TYPE] = inst # type: ignore
+ self.checkers[inst.AUTH_TYPE] = inst
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
@@ -815,7 +815,6 @@ class AuthHandler:
now_ms = self._clock.time_msec()
if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
-
raise SynapseError(
HTTPStatus.FORBIDDEN,
"The supplied refresh token has expired",
@@ -1543,6 +1542,17 @@ class AuthHandler:
async def add_threepid(
self, user_id: str, medium: str, address: str, validated_at: int
) -> None:
+ """
+ Adds an association between a user's Matrix ID and a third-party ID (email,
+ phone number).
+
+ Args:
+ user_id: The ID of the user to associate.
+ medium: The medium of the third-party ID (email, msisdn).
+ address: The address of the third-party ID (i.e. an email address).
+ validated_at: The timestamp in ms of when the validation that the user owns
+ this third-party ID occurred.
+ """
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -1567,43 +1577,44 @@ class AuthHandler:
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
+ # Inform Synapse modules that a 3PID association has been created.
+ await self._third_party_rules.on_add_user_third_party_identifier(
+ user_id, medium, address
+ )
+
+ # Deprecated method for informing Synapse modules that a 3PID association
+ # has successfully been created.
await self._third_party_rules.on_threepid_bind(user_id, medium, address)
- async def delete_threepid(
- self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
- ) -> bool:
- """Attempts to unbind the 3pid on the identity servers and deletes it
- from the local database.
+ async def delete_local_threepid(
+ self, user_id: str, medium: str, address: str
+ ) -> None:
+ """Deletes an association between a third-party ID and a user ID from the local
+ database. This method does not unbind the association from any identity servers.
+
+ If `medium` is 'email' and a pusher is associated with this third-party ID, the
+ pusher will also be deleted.
Args:
user_id: ID of user to remove the 3pid from.
medium: The medium of the 3pid being removed: "email" or "msisdn".
address: The 3pid address to remove.
- id_server: Use the given identity server when unbinding
- any threepids. If None then will attempt to unbind using the
- identity server specified when binding (if known).
-
- Returns:
- Returns True if successfully unbound the 3pid on
- the identity server, False if identity server doesn't support the
- unbind API.
"""
-
# 'Canonicalise' email addresses as per above
if medium == "email":
address = canonicalise_email(address)
- identity_handler = self.hs.get_identity_handler()
- result = await identity_handler.try_unbind_threepid(
- user_id, {"medium": medium, "address": address, "id_server": id_server}
+ await self.store.user_delete_threepid(user_id, medium, address)
+
+ # Inform Synapse modules that a 3PID association has been deleted.
+ await self._third_party_rules.on_remove_user_third_party_identifier(
+ user_id, medium, address
)
- await self.store.user_delete_threepid(user_id, medium, address)
if medium == "email":
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id="m.email", pushkey=address, user_id=user_id
)
- return result
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
@@ -2260,7 +2271,6 @@ class PasswordAuthProvider:
async def on_logged_out(
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
-
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
try:
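
A minimal sketch of a module hooking the two 3PID callbacks invoked above. The callback names come from the handler code in this diff; the registration call and module scaffolding are assumptions about the module API, not confirmed by the patch:

```python
import logging

logger = logging.getLogger(__name__)


class ThreepidAuditModule:
    """Illustrative module that logs 3PID associations as they change."""

    def __init__(self, config: dict, api) -> None:
        # Assumption: these callbacks are registered via the third-party rules
        # callback registration on the module API.
        api.register_third_party_rules_callbacks(
            on_add_user_third_party_identifier=self.on_add,
            on_remove_user_third_party_identifier=self.on_remove,
        )

    async def on_add(self, user_id: str, medium: str, address: str) -> None:
        logger.info("3PID added: %s/%s for %s", medium, address, user_id)

    async def on_remove(self, user_id: str, medium: str, address: str) -> None:
        logger.info("3PID removed: %s/%s for %s", medium, address, user_id)
```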
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index ba58f150d1..c1850521e4 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -100,31 +100,28 @@ class DeactivateAccountHandler:
# unbinding
identity_server_supports_unbinding = True
- # Retrieve the 3PIDs this user has bound to an identity server
- threepids = await self.store.user_get_bound_threepids(user_id)
-
- for threepid in threepids:
+ # Attempt to unbind from identity server(s) any threepids known to be bound
+ # to this account.
+ bound_threepids = await self.store.user_get_bound_threepids(user_id)
+ for threepid in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
- user_id,
- {
- "medium": threepid["medium"],
- "address": threepid["address"],
- "id_server": id_server,
- },
+ user_id, threepid["medium"], threepid["address"], id_server
)
- identity_server_supports_unbinding &= result
except Exception:
# Do we want this to be a fatal error or should we carry on?
logger.exception("Failed to remove threepid from ID server")
raise SynapseError(400, "Failed to remove threepid from ID server")
- await self.store.user_delete_threepid(
+
+ identity_server_supports_unbinding &= result
+
+ # Remove any local threepid associations for this account.
+ local_threepids = await self.store.user_get_threepids(user_id)
+ for threepid in local_threepids:
+ await self._auth_handler.delete_local_threepid(
user_id, threepid["medium"], threepid["address"]
)
- # Remove all 3PIDs this user has bound to the homeserver
- await self.store.user_delete_threepids(user_id)
-
# delete any devices belonging to the user, which will also
# delete corresponding access tokens.
await self._device_handler.delete_all_devices_for_user(user_id)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 2ea52257cb..1fb23cc9bf 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,7 +14,7 @@
import logging
import string
-from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
from typing_extensions import Literal
@@ -485,7 +485,8 @@ class DirectoryHandler:
)
)
if canonical_alias:
- room_aliases.append(canonical_alias)
+ # Ensure we do not mutate room_aliases.
+ room_aliases = list(room_aliases) + [canonical_alias]
if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
@@ -496,9 +497,11 @@ class DirectoryHandler:
raise SynapseError(403, "Not allowed to publish room")
# Check if publishing is blocked by a third party module
- allowed_by_third_party_rules = await (
- self.third_party_event_rules.check_visibility_can_be_modified(
- room_id, visibility
+ allowed_by_third_party_rules = (
+ await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
)
)
if not allowed_by_third_party_rules:
@@ -528,7 +531,7 @@ class DirectoryHandler:
async def get_aliases_for_room(
self, requester: Requester, room_id: str
- ) -> List[str]:
+ ) -> Sequence[str]:
"""
Get a list of the aliases that currently point to this room on this server
"""
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d2188ca08f..4e9c8d8db0 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -159,19 +159,22 @@ class E2eKeysHandler:
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
- query_list: List[Tuple[str, Optional[str]]] = []
+ user_ids = set()
+ user_and_device_ids: List[Tuple[str, str]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
- query_list.extend(
+ user_and_device_ids.extend(
(user_id, device_id) for device_id in device_ids
)
else:
- query_list.append((user_id, None))
+ user_ids.add(user_id)
(
user_ids_not_in_cache,
remote_results,
- ) = await self.store.get_user_devices_from_cache(query_list)
+ ) = await self.store.get_user_devices_from_cache(
+ user_ids, user_and_device_ids
+ )
# Check that the homeserver still shares a room with all cached users.
# Note that this check may be slightly racy when a remote user leaves a
@@ -1298,6 +1301,20 @@ class E2eKeysHandler:
return desired_key_data
+ async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool:
+ """Checks if the user has cross-signing set up
+
+ Args:
+ user_id: The user to check
+
+ Returns:
+ True if the user has cross-signing set up, False otherwise
+ """
+ existing_master_key = await self.store.get_e2e_cross_signing_key(
+ user_id, "master"
+ )
+ return existing_master_key is not None
+
def _check_cross_signing_key(
key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
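
A trivial usage sketch for the new helper. The `get_e2e_keys_handler` accessor on the homeserver object is an assumption for illustration:

```python
# Sketch: decide whether a client should be prompted to bootstrap
# cross-signing, based on whether a master key already exists.
async def should_prompt_cross_signing_setup(hs, user_id: str) -> bool:
    handler = hs.get_e2e_keys_handler()
    return not await handler.is_cross_signing_set_up_for_user(user_id)
```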
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 83f53ceb88..50317ec753 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -188,7 +188,6 @@ class E2eRoomKeysHandler:
# XXX: perhaps we should use a finer grained lock here?
async with self._upload_linearizer.queue(user_id):
-
# Check that the version we're trying to upload is the current version
try:
version_info = await self.store.get_e2e_room_keys_version_info(user_id)
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a23a8ce2a1..0db0bd7304 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -63,9 +63,18 @@ class EventAuthHandler:
self._store, event, batched_auth_events
)
auth_event_ids = event.auth_event_ids()
- auth_events_by_id = await self._store.get_events(auth_event_ids)
+
if batched_auth_events:
- auth_events_by_id.update(batched_auth_events)
+ # Copy the batched auth events to avoid mutating them.
+ auth_events_by_id = dict(batched_auth_events)
+ needed_auth_event_ids = set(auth_event_ids) - set(batched_auth_events)
+ if needed_auth_event_ids:
+ auth_events_by_id.update(
+ await self._store.get_events(needed_auth_event_ids)
+ )
+ else:
+ auth_events_by_id = await self._store.get_events(auth_event_ids)
+
check_state_dependent_auth_rules(event, auth_events_by_id.values())
def compute_auth_events(
@@ -202,7 +211,7 @@ class EventAuthHandler:
state_ids: StateMap[str],
room_version: RoomVersion,
user_id: str,
- prev_member_event: Optional[EventBase],
+ prev_membership: Optional[str],
) -> None:
"""
Check whether a user can join a room without an invite due to restricted join rules.
@@ -214,15 +223,14 @@ class EventAuthHandler:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
- prev_member_event: The current membership event for this user.
+ prev_membership: The current membership state for this user. `None` if the
+ user has never joined the room (equivalent to "leave").
Raises:
AuthError if the user cannot join the room.
"""
# If the member is invited or currently joined, then nothing to do.
- if prev_member_event and (
- prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
- ):
+ if prev_membership in (Membership.JOIN, Membership.INVITE):
return
# This is not a room with a restricted join rule, so we don't need to do the
@@ -237,7 +245,6 @@ class EventAuthHandler:
# in any of them.
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
if not await self.is_user_in_rooms(allowed_rooms, user_id):
-
# If this is a remote request, the user might be in an allowed room
# that we do not know about.
if get_domain_from_id(user_id) != self._server_name:
@@ -255,13 +262,14 @@ class EventAuthHandler:
)
async def has_restricted_join_rules(
- self, state_ids: StateMap[str], room_version: RoomVersion
+ self, partial_state_ids: StateMap[str], room_version: RoomVersion
) -> bool:
"""
Return if the room has the proper join rules set for access via rooms.
Args:
- state_ids: The state of the room as it currently is.
+            partial_state_ids: The state of the room as it currently is. May be full
+                or partial state.
room_version: The room version of the room to query.
Returns:
@@ -272,7 +280,7 @@ class EventAuthHandler:
return False
# If there's no join rule, then it defaults to invite (so this doesn't apply).
- join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+ join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return False
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 949b69cb41..68c07f0265 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -23,7 +23,7 @@ from synapse.events.utils import SerializeEventConfig
from synapse.handlers.presence import format_user_presence_state
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, Requester, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -46,13 +46,12 @@ class EventStreamHandler:
async def get_stream(
self,
- auth_user_id: str,
+ requester: Requester,
pagin_config: PaginationConfig,
timeout: int = 0,
as_client_event: bool = True,
affect_presence: bool = True,
room_id: Optional[str] = None,
- is_guest: bool = False,
) -> JsonDict:
"""Fetches the events stream for a given user."""
@@ -62,13 +61,12 @@ class EventStreamHandler:
raise SynapseError(403, "This room has been blocked on this server")
# send any outstanding server notices to the user.
- await self._server_notices_sender.on_user_syncing(auth_user_id)
+ await self._server_notices_sender.on_user_syncing(requester.user.to_string())
- auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
context = await presence_handler.user_syncing(
- auth_user_id,
+ requester.user.to_string(),
affect_presence=affect_presence,
presence_state=PresenceState.ONLINE,
)
@@ -82,10 +80,10 @@ class EventStreamHandler:
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
stream_result = await self.notifier.get_events_for(
- auth_user,
+ requester.user,
pagin_config,
timeout,
- is_guest=is_guest,
+ is_guest=requester.is_guest,
explicit_room_id=room_id,
)
events = stream_result.events
@@ -102,7 +100,7 @@ class EventStreamHandler:
if event.membership != Membership.JOIN:
continue
# Send down presence.
- if event.state_key == auth_user_id:
+ if event.state_key == requester.user.to_string():
# Send down presence for everyone in the room.
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
@@ -124,7 +122,9 @@ class EventStreamHandler:
chunks = self._event_serializer.serialize_events(
events,
time_now,
- config=SerializeEventConfig(as_client_event=as_client_event),
+ config=SerializeEventConfig(
+ as_client_event=as_client_event, requester=requester
+ ),
)
chunk = {
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4e77bfa55e..2944deaa00 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.api.errors import (
FederationPullAttemptBackoffError,
HttpResponseException,
NotFoundError,
+ PartialStateConflictError,
RequestSendFailed,
SynapseError,
)
@@ -56,7 +57,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.http.servlet import assert_params_in_dict
@@ -68,7 +69,6 @@ from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
@@ -952,7 +952,20 @@ class FederationHandler:
#
# Note that this requires the /send_join request to come back to the
# same server.
+ prev_event_ids = None
if room_version.msc3083_join_rules:
+ # Note that the room's state can change out from under us and render our
+ # nice join rules-conformant event non-conformant by the time we build the
+ # event. When this happens, our validation at the end fails and we respond
+ # to the requesting server with a 403, which is misleading — it indicates
+ # that the user is not allowed to join the room and the joining server
+ # should not bother retrying via this homeserver or any others, when
+ # in fact we've just messed up with building the event.
+ #
+ # To reduce the likelihood of this race, we capture the forward extremities
+ # of the room (prev_event_ids) just before fetching the current state, and
+ # hope that the state we fetch corresponds to the prev events we chose.
+ prev_event_ids = await self.store.get_prev_events_for_room(room_id)
state_ids = await self._state_storage_controller.get_current_state_ids(
room_id
)
@@ -990,15 +1003,21 @@ class FederationHandler:
)
try:
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(
+ builder=builder,
+ prev_event_ids=prev_event_ids,
)
except SynapseError as e:
logger.warning("Failed to create join to %s because %s", room_id, e)
raise
# Ensure the user can even join the room.
- await self._federation_event_handler.check_join_restrictions(context, event)
+ await self._federation_event_handler.check_join_restrictions(
+ unpersisted_context, event
+ )
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
@@ -1178,7 +1197,7 @@ class FederationHandler:
},
)
- event, context = await self.event_creation_handler.create_new_client_event(
+ event, _ = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1228,12 +1247,13 @@ class FederationHandler:
},
)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(builder=builder)
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
- event, context
+ event, unpersisted_context
)
if not event_allowed:
logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1426,20 @@ class FederationHandler:
try:
(
event,
- context,
+ unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
+ (
+ event,
+ unpersisted_context,
+ ) = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, unpersisted_context
)
+ context = await unpersisted_context.persist(event)
+
EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
@@ -1483,14 +1508,19 @@ class FederationHandler:
try:
(
event,
- context,
+ unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
+ (
+ event,
+ unpersisted_context,
+ ) = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, unpersisted_context
)
+ context = await unpersisted_context.persist(event)
+
try:
validate_event_for_room_version(event)
await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1552,8 @@ class FederationHandler:
room_version_obj: RoomVersion,
event_dict: JsonDict,
event: EventBase,
- context: EventContext,
- ) -> Tuple[EventBase, EventContext]:
+ context: UnpersistedEventContextBase,
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1587,14 @@ class FederationHandler:
room_version_obj, event_dict
)
EventValidator().validate_builder(builder)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
+
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(builder=builder)
+
EventValidator().validate_new(event, self.config)
- return event, context
+ return event, unpersisted_context
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
@@ -1861,6 +1894,11 @@ class FederationHandler:
logger.info("Updating current state for %s", room_id)
# TODO(faster_joins): notify workers in notify_room_un_partial_stated
# https://github.com/matrix-org/synapse/issues/12994
+ #
+ # NB: there's a potential race here. If the room is purged just before we
+ # call this, we _might_ end up inserting rows into current_state_events.
+ # (The logic is hard to chase through.) We think this is fine, but if
+ # not the HS admin should purge the room again.
await self.state_handler.update_current_state(room_id)
logger.info("Handling any pending device list updates")
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 2e19df0976..3a65ccbb55 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -47,6 +47,7 @@ from synapse.api.errors import (
FederationError,
FederationPullAttemptBackoffError,
HttpResponseException,
+ PartialStateConflictError,
RequestSendFailed,
SynapseError,
)
@@ -58,7 +59,7 @@ from synapse.event_auth import (
validate_event_for_room_version,
)
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import (
@@ -74,7 +75,6 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
)
from synapse.state import StateResolutionStore
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
PersistedEventPosition,
@@ -426,7 +426,9 @@ class FederationEventHandler:
return event, context
async def check_join_restrictions(
- self, context: EventContext, event: EventBase
+ self,
+ context: UnpersistedEventContextBase,
+ event: EventBase,
) -> None:
"""Check that restrictions in restricted join rules are matched
@@ -439,16 +441,17 @@ class FederationEventHandler:
# Check if the user is already in the room or invited to the room.
user_id = event.state_key
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
+ prev_membership = None
if prev_member_event_id:
prev_member_event = await self._store.get_event(prev_member_event_id)
+ prev_membership = prev_member_event.membership
# Check if the member should be allowed access via membership in a space.
await self._event_auth_handler.check_restricted_join_rules(
prev_state_ids,
event.room_version,
user_id,
- prev_member_event,
+ prev_membership,
)
@trace
@@ -524,11 +527,57 @@ class FederationEventHandler:
"Peristing join-via-remote %s (partial_state: %s)", event, partial_state
)
with nested_logging_context(suffix=event.event_id):
+ if partial_state:
+ # When handling a second partial state join into a partial state room,
+ # the returned state will exclude the membership from the first join. To
+ # preserve prior memberships, we try to compute the partial state before
+ # the event ourselves if we know about any of the prev events.
+ #
+ # When we don't know about any of the prev events, it's fine to just use
+ # the returned state, since the new join will create a new forward
+ # extremity, and leave the forward extremity containing our prior
+ # memberships alone.
+ prev_event_ids = set(event.prev_event_ids())
+ seen_event_ids = await self._store.have_events_in_timeline(
+ prev_event_ids
+ )
+ missing_event_ids = prev_event_ids - seen_event_ids
+
+ state_maps_to_resolve: List[StateMap[str]] = []
+
+ # Fetch the state after the prev events that we know about.
+ state_maps_to_resolve.extend(
+ (
+ await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen_event_ids, await_full_state=False
+ )
+ ).values()
+ )
+
+ # When there are prev events we do not have the state for, we state
+ # resolve with the state returned by the remote homeserver.
+ if missing_event_ids or len(state_maps_to_resolve) == 0:
+ state_maps_to_resolve.append(
+ {(e.type, e.state_key): e.event_id for e in state}
+ )
+
+ state_ids_before_event = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version.identifier,
+ state_maps_to_resolve,
+ event_map=None,
+ state_res_store=StateResolutionStore(self._store),
+ )
+ )
+ else:
+ state_ids_before_event = {
+ (e.type, e.state_key): e.event_id for e in state
+ }
+
context = await self._state_handler.compute_event_context(
event,
- state_ids_before_event={
- (e.type, e.state_key): e.event_id for e in state
- },
+ state_ids_before_event=state_ids_before_event,
partial_state=partial_state,
)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 848e46eb9b..bf0f7acf80 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -219,28 +219,31 @@ class IdentityHandler:
data = json_decoder.decode(e.msg) # XXX WAT?
return data
- async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
- """Attempt to remove a 3PID from an identity server, or if one is not provided, all
- identity servers we're aware the binding is present on
+ async def try_unbind_threepid(
+ self, mxid: str, medium: str, address: str, id_server: Optional[str]
+ ) -> bool:
+ """Attempt to remove a 3PID from one or more identity servers.
Args:
mxid: Matrix user ID of binding to be removed
- threepid: Dict with medium & address of binding to be
- removed, and an optional id_server.
+ medium: The medium of the third-party ID.
+ address: The address of the third-party ID.
+ id_server: An identity server to attempt to unbind from. If None,
+ attempt to remove the association from all identity servers
+ known to potentially have it.
Raises:
- SynapseError: If we failed to contact the identity server
+ SynapseError: If we failed to contact one or more identity servers.
Returns:
- True on success, otherwise False if the identity
- server doesn't support unbinding (or no identity server found to
- contact).
+ True on success, otherwise False if the identity server doesn't
+ support unbinding (or no identity server to contact was found).
"""
- if threepid.get("id_server"):
- id_servers = [threepid["id_server"]]
+ if id_server:
+ id_servers = [id_server]
else:
id_servers = await self.store.get_id_servers_user_bound(
- user_id=mxid, medium=threepid["medium"], address=threepid["address"]
+ mxid, medium, address
)
# We don't know where to unbind, so we don't have a choice but to return
@@ -249,20 +252,21 @@ class IdentityHandler:
changed = True
for id_server in id_servers:
- changed &= await self.try_unbind_threepid_with_id_server(
- mxid, threepid, id_server
+ changed &= await self._try_unbind_threepid_with_id_server(
+ mxid, medium, address, id_server
)
return changed
- async def try_unbind_threepid_with_id_server(
- self, mxid: str, threepid: dict, id_server: str
+ async def _try_unbind_threepid_with_id_server(
+ self, mxid: str, medium: str, address: str, id_server: str
) -> bool:
"""Removes a binding from an identity server
Args:
mxid: Matrix user ID of binding to be removed
- threepid: Dict with medium & address of binding to be removed
+ medium: The medium of the third-party ID
+ address: The address of the third-party ID
id_server: Identity server to unbind from
Raises:
@@ -286,7 +290,7 @@ class IdentityHandler:
content = {
"mxid": mxid,
- "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
+ "threepid": {"medium": medium, "address": address},
}
# we abuse the federation http client to sign the request, but we have to send it
@@ -319,12 +323,7 @@ class IdentityHandler:
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
- await self.store.remove_user_bound_threepid(
- user_id=mxid,
- medium=threepid["medium"],
- address=threepid["address"],
- id_server=id_server,
- )
+ await self.store.remove_user_bound_threepid(mxid, medium, address, id_server)
return changed
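
Call sites now pass the 3PID fields as explicit parameters instead of a dict. A minimal sketch of the new shape (the handler and values are placeholders):

```python
# Old shape (dict-based), for comparison:
#     await identity_handler.try_unbind_threepid(
#         user_id, {"medium": "email", "address": address, "id_server": None}
#     )

# New shape: explicit medium/address, with id_server=None meaning "try every
# identity server the binding is known to exist on".
async def unbind_email(identity_handler, user_id: str, address: str) -> bool:
    # Returns False if an identity server did not support the unbind API.
    return await identity_handler.try_unbind_threepid(user_id, "email", address, None)
```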
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 191529bd8e..b3be7a86f0 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,6 @@ class InitialSyncHandler:
as_client_event: bool = True,
include_archived: bool = False,
) -> JsonDict:
-
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
@@ -154,9 +153,8 @@ class InitialSyncHandler:
tags_by_room = await self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = await self.store.get_account_data_for_user(
- user_id
- )
+ account_data = await self.store.get_global_account_data_for_user(user_id)
+ account_data_by_room = await self.store.get_room_account_data_for_user(user_id)
public_room_ids = await self.store.get_public_room_ids()
@@ -320,11 +318,9 @@ class InitialSyncHandler:
)
is_peeking = member_event_id is None
- user_id = requester.user.to_string()
-
if membership == Membership.JOIN:
result = await self._room_initial_sync_joined(
- user_id, room_id, pagin_config, membership, is_peeking
+ requester, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
# The member_event_id will always be available if membership is set
@@ -332,10 +328,16 @@ class InitialSyncHandler:
assert member_event_id
result = await self._room_initial_sync_parted(
- user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+ requester,
+ room_id,
+ pagin_config,
+ membership,
+ member_event_id,
+ is_peeking,
)
account_data_events = []
+ user_id = requester.user.to_string()
tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append(
@@ -352,7 +354,7 @@ class InitialSyncHandler:
async def _room_initial_sync_parted(
self,
- user_id: str,
+ requester: Requester,
room_id: str,
pagin_config: PaginationConfig,
membership: str,
@@ -371,13 +373,17 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
- self._storage_controllers, user_id, messages, is_peeking=is_peeking
+ self._storage_controllers,
+ requester.user.to_string(),
+ messages,
+ is_peeking=is_peeking,
)
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token)
time_now = self.clock.time_msec()
+ serialize_options = SerializeEventConfig(requester=requester)
return {
"membership": membership,
@@ -385,14 +391,18 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
- self._event_serializer.serialize_events(messages, time_now)
+ self._event_serializer.serialize_events(
+ messages, time_now, config=serialize_options
+ )
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
- self._event_serializer.serialize_events(room_state.values(), time_now)
+ self._event_serializer.serialize_events(
+ room_state.values(), time_now, config=serialize_options
+ )
),
"presence": [],
"receipts": [],
@@ -400,7 +410,7 @@ class InitialSyncHandler:
async def _room_initial_sync_joined(
self,
- user_id: str,
+ requester: Requester,
room_id: str,
pagin_config: PaginationConfig,
membership: str,
@@ -412,9 +422,12 @@ class InitialSyncHandler:
# TODO: These concurrently
time_now = self.clock.time_msec()
+ serialize_options = SerializeEventConfig(requester=requester)
# Don't bundle aggregations as this is a deprecated API.
state = self._event_serializer.serialize_events(
- current_state.values(), time_now
+ current_state.values(),
+ time_now,
+ config=serialize_options,
)
now_token = self.hs.get_event_sources().get_current_token()
@@ -452,7 +465,10 @@ class InitialSyncHandler:
if not receipts:
return []
- return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
+ return ReceiptEventSource.filter_out_private_receipts(
+ receipts,
+ requester.user.to_string(),
+ )
presence, receipts, (messages, token) = await make_deferred_yieldable(
gather_results(
@@ -471,20 +487,23 @@ class InitialSyncHandler:
)
messages = await filter_events_for_client(
- self._storage_controllers, user_id, messages, is_peeking=is_peeking
+ self._storage_controllers,
+ requester.user.to_string(),
+ messages,
+ is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
end_token = now_token
- time_now = self.clock.time_msec()
-
ret = {
"room_id": room_id,
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
- self._event_serializer.serialize_events(messages, time_now)
+ self._event_serializer.serialize_events(
+ messages, time_now, config=serialize_options
+ )
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1c5fdca12a..29ec7e3544 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
Codes,
ConsentNotGivenError,
NotFoundError,
+ PartialStateConflictError,
ShadowBanError,
SynapseError,
UnstableSpecAuthError,
@@ -48,8 +49,8 @@ from synapse.api.urls import ConsentURIBuilder
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
-from synapse.events.utils import maybe_upsert_event_field
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
+from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
@@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
MutableStateMap,
@@ -245,8 +245,11 @@ class MessageHandler:
)
room_state = room_state_events[membership_event_id]
- now = self.clock.time_msec()
- events = self._event_serializer.serialize_events(room_state.values(), now)
+ events = self._event_serializer.serialize_events(
+ room_state.values(),
+ self.clock.time_msec(),
+ config=SerializeEventConfig(requester=requester),
+ )
return events
async def _user_can_see_state_at_event(
@@ -499,9 +502,9 @@ class EventCreationHandler:
self.request_ratelimiter = hs.get_request_ratelimiter()
- # We arbitrarily limit concurrent event creation for a room to 5.
- # This is to stop us from diverging history *too* much.
- self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
+ # We limit concurrent event creation for a room to 1. This prevents state
+ # resolution from occurring when sending bursts of events to a local room.
+ self.limiter = Linearizer(max_count=1, name="room_event_creation_limit")
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
@@ -574,7 +577,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
@@ -708,7 +711,7 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
- event, context = await self.create_new_client_event(
+ event, unpersisted_context = await self.create_new_client_event(
builder=builder,
requester=requester,
allow_no_prev_events=allow_no_prev_events,
@@ -737,7 +740,7 @@ class EventCreationHandler:
assert state_map is not None
prev_event_id = state_map.get((EventTypes.Member, event.sender))
else:
- prev_state_ids = await context.get_prev_state_ids(
+ prev_state_ids = await unpersisted_context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@@ -762,8 +765,7 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
-
- return event, context
+ return event, unpersisted_context
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@@ -1003,7 +1005,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.create_event(
+ event, unpersisted_context = await self.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1014,6 +1016,7 @@ class EventCreationHandler:
historical=historical,
depth=depth,
)
+ context = await unpersisted_context.persist(event)
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
event.sender,
@@ -1083,13 +1086,14 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event for a local client. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
the event using the parameters state_map and current_state_group, thus these parameters
must be provided in this case if for_batch is True. The subsequently created event
and context are suitable for being batched up and bulk persisted to the database
- with other similarly created events.
+ with other similarly created events. Note that this returns an UnpersistedEventContext,
+ which must be converted to an EventContext before it can be sent to the DB.
Args:
builder:
@@ -1131,7 +1135,7 @@ class EventCreationHandler:
batch persisting
Returns:
- Tuple of created event, context
+ Tuple of created event, UnpersistedEventContext
"""
# Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
@@ -1187,14 +1191,20 @@ class EventCreationHandler:
if for_batch:
assert prev_event_ids is not None
assert state_map is not None
- assert current_state_group is not None
auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
)
- context = await self.state.compute_event_context_for_batched(
- event, state_map, current_state_group
+
+ context: UnpersistedEventContextBase = (
+ await self.state.calculate_context_info(
+ event,
+ state_ids_before_event=state_map,
+ partial_state=False,
+ state_group_before_event=current_state_group,
+ )
)
+
else:
event = await builder.build(
prev_event_ids=prev_event_ids,
@@ -1244,16 +1254,17 @@ class EventCreationHandler:
state_map_for_event[(data.event_type, data.state_key)] = state_id
- context = await self.state.compute_event_context(
+ # TODO(faster_joins): check how MSC2716 works and whether we can have
+ # partial state here
+ # https://github.com/matrix-org/synapse/issues/13003
+ context = await self.state.calculate_context_info(
event,
state_ids_before_event=state_map_for_event,
- # TODO(faster_joins): check how MSC2716 works and whether we can have
- # partial state here
- # https://github.com/matrix-org/synapse/issues/13003
partial_state=False,
)
+
else:
- context = await self.state.compute_event_context(event)
+ context = await self.state.calculate_context_info(event)
if requester:
context.app_service = requester.app_service
@@ -1326,7 +1337,11 @@ class EventCreationHandler:
relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
- raise SynapseError(400, "Can't send same reaction twice")
+ raise SynapseError(
+ 400,
+ "Can't send same reaction twice",
+ errcode=Codes.DUPLICATE_ANNOTATION,
+ )
# Don't attempt to start a thread if the parent event is a relation.
elif relation.rel_type == RelationTypes.THREAD:
@@ -2031,7 +2046,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.create_event(
+ event, unpersisted_context = await self.create_event(
requester,
{
"type": EventTypes.Dummy,
@@ -2040,6 +2055,7 @@ class EventCreationHandler:
"sender": user_id,
},
)
+ context = await unpersisted_context.persist(event)
event.internal_metadata.proactively_send = False
@@ -2082,9 +2098,9 @@ class EventCreationHandler:
async def _rebuild_event_after_third_party_rules(
self, third_party_result: dict, original_event: EventBase
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
# the third_party_event_rules want to replace the event.
- # we do some basic checks, and then return the replacement event and context.
+ # we do some basic checks, and then return the replacement event.
# Construct a new EventBuilder and validate it, which helps with the
# rest of these checks.
@@ -2138,5 +2154,6 @@ class EventCreationHandler:
# we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update.
- context = await self.state.compute_event_context(event)
+ context = await self.state.calculate_context_info(event)
+
return event, context
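
The recurring change in this file is the two-step event creation flow: `create_event` now returns an `UnpersistedEventContextBase`, which the caller converts with `persist(event)` before the event reaches the database. A condensed sketch of the pattern (error handling and retries omitted):

```python
# Sketch of the pattern used at the call sites above.
async def create_and_persist_event(event_creation_handler, requester, event_dict):
    event, unpersisted_context = await event_creation_handler.create_event(
        requester, event_dict
    )
    # Allocate a real EventContext (state group etc.) for the event before it
    # is handed to the persistence layer.
    context = await unpersisted_context.persist(event)
    return event, context
```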
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index f2095ce164..d7085c001d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -579,7 +579,9 @@ class PaginationHandler:
time_now = self.clock.time_msec()
- serialize_options = SerializeEventConfig(as_client_event=as_client_event)
+ serialize_options = SerializeEventConfig(
+ as_client_event=as_client_event, requester=requester
+ )
chunk = {
"chunk": (
@@ -681,7 +683,7 @@ class PaginationHandler:
await self._storage_controllers.purge_events.purge_room(room_id)
- logger.info("complete")
+ logger.info("purge complete for room_id %s", room_id)
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
except Exception:
f = Failure()
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index b4c0577e4d..b289b6cb23 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler):
)
if self.unpersisted_users_changes:
-
await self.store.update_presence(
[
self.user_to_current_state[user_id]
@@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler):
now = self.clock.time_msec()
with Measure(self.clock, "presence_update_states"):
-
# NOTE: We purposefully don't await between now and when we've
# calculated what we want to do with the new states, to avoid races.
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 04c61ae3dd..2bacdebfb5 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple
from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService
@@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
@staticmethod
def filter_out_private_receipts(
- rooms: List[JsonDict], user_id: str
+ rooms: Sequence[JsonDict], user_id: str
) -> List[JsonDict]:
"""
Filters a list of serialized receipts (as returned by /sync and /initialSync)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c611efb760..e4e506e62c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -476,7 +476,7 @@ class RegistrationHandler:
# create room expects the localpart of the room alias
config["room_alias_name"] = room_alias.localpart
- info, _ = await room_creation_handler.create_room(
+ room_id, _, _ = await room_creation_handler.create_room(
fake_requester,
config=config,
ratelimit=False,
@@ -490,7 +490,7 @@ class RegistrationHandler:
user_id, authenticated_entity=self._server_name
),
target=UserID.from_string(user_id),
- room_id=info["room_id"],
+ room_id=room_id,
# Since it was just created, there are no remote hosts.
remote_room_hosts=[],
action="join",
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0fb15391e0..1d09fdf135 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -20,6 +20,7 @@ import attr
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
+from synapse.events.utils import SerializeEventConfig
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
@@ -60,13 +61,12 @@ class BundledAggregations:
Some values require additional processing during serialization.
"""
- annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
- return bool(self.annotations or self.references or self.replace or self.thread)
+ return bool(self.references or self.replace or self.thread)
class RelationsHandler:
@@ -152,16 +152,23 @@ class RelationsHandler:
)
now = self._clock.time_msec()
+ serialize_options = SerializeEventConfig(requester=requester)
return_value: JsonDict = {
"chunk": self._event_serializer.serialize_events(
- events, now, bundle_aggregations=aggregations
+ events,
+ now,
+ bundle_aggregations=aggregations,
+ config=serialize_options,
),
}
if include_original_event:
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
return_value["original_event"] = self._event_serializer.serialize_event(
- event, now, bundle_aggregations=None
+ event,
+ now,
+ bundle_aggregations=None,
+ config=serialize_options,
)
if next_token:
@@ -227,67 +234,6 @@ class RelationsHandler:
e.msg,
)
- async def get_annotations_for_events(
- self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
- ) -> Dict[str, List[JsonDict]]:
- """Get a list of annotations to the given events, grouped by event type and
- aggregation key, sorted by count.
-
- This is used e.g. to get the what and how many reactions have happened
- on an event.
-
- Args:
- event_ids: Fetch events that relate to these event IDs.
- ignored_users: The users ignored by the requesting user.
-
- Returns:
- A map of event IDs to a list of groups of annotations that match.
- Each entry is a dict with `type`, `key` and `count` fields.
- """
- # Get the base results for all users.
- full_results = await self._main_store.get_aggregation_groups_for_events(
- event_ids
- )
-
- # Avoid additional logic if there are no ignored users.
- if not ignored_users:
- return {
- event_id: results
- for event_id, results in full_results.items()
- if results
- }
-
- # Then subtract off the results for any ignored users.
- ignored_results = await self._main_store.get_aggregation_groups_for_users(
- [event_id for event_id, results in full_results.items() if results],
- ignored_users,
- )
-
- filtered_results = {}
- for event_id, results in full_results.items():
- # If no annotations, skip.
- if not results:
- continue
-
- # If there are not ignored results for this event, copy verbatim.
- if event_id not in ignored_results:
- filtered_results[event_id] = results
- continue
-
- # Otherwise, subtract out the ignored results.
- event_ignored_results = ignored_results[event_id]
- for result in results:
- key = (result["type"], result["key"])
- if key in event_ignored_results:
- # Ensure to not modify the cache.
- result = result.copy()
- result["count"] -= event_ignored_results[key]
- if result["count"] <= 0:
- continue
- filtered_results.setdefault(event_id, []).append(result)
-
- return filtered_results
-
async def get_references_for_events(
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
) -> Dict[str, List[_RelatedEvent]]:
@@ -531,17 +477,6 @@ class RelationsHandler:
# (as that is what makes it part of the thread).
relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
- async def _fetch_annotations() -> None:
- """Fetch any annotations (ie, reactions) to bundle with this event."""
- annotations_by_event_id = await self.get_annotations_for_events(
- events_by_id.keys(), ignored_users=ignored_users
- )
- for event_id, annotations in annotations_by_event_id.items():
- if annotations:
- results.setdefault(event_id, BundledAggregations()).annotations = {
- "chunk": annotations
- }
-
async def _fetch_references() -> None:
"""Fetch any references to bundle with this event."""
references_by_event_id = await self.get_references_for_events(
@@ -575,7 +510,6 @@ class RelationsHandler:
await make_deferred_yieldable(
gather_results(
(
- run_in_background(_fetch_annotations),
run_in_background(_fetch_references),
run_in_background(_fetch_edits),
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7ba7c4ff07..be120cb12f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
Codes,
LimitExceededError,
NotFoundError,
+ PartialStateConflictError,
StoreError,
SynapseError,
)
@@ -50,11 +51,11 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext
from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
@@ -211,7 +212,7 @@ class RoomCreationHandler:
# the required power level to send the tombstone event.
(
tombstone_event,
- tombstone_context,
+ tombstone_unpersisted_context,
) = await self.event_creation_handler.create_event(
requester,
{
@@ -225,6 +226,9 @@ class RoomCreationHandler:
},
},
)
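+        # `create_event` returns an unpersisted context; the tombstone is sent on
+        # its own rather than batched, so persist its context straight away.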
+ tombstone_context = await tombstone_unpersisted_context.persist(
+ tombstone_event
+ )
validate_event_for_room_version(tombstone_event)
await self._event_auth_handler.check_auth_rules_from_context(
tombstone_event
@@ -565,7 +569,7 @@ class RoomCreationHandler:
new_room_id,
# we expect to override all the presets with initial_state, so this is
# somewhat arbitrary.
- preset_config=RoomCreationPreset.PRIVATE_CHAT,
+ room_config={"preset": RoomCreationPreset.PRIVATE_CHAT},
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
@@ -690,13 +694,14 @@ class RoomCreationHandler:
config: JsonDict,
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
- ) -> Tuple[dict, int]:
+ ) -> Tuple[str, Optional[RoomAlias], int]:
"""Creates a new room.
Args:
- requester:
- The user who requested the room creation.
- config : A dict of configuration options.
+ requester: The user who requested the room creation.
+ config: A dict of configuration options. This will be the body of
+ a /createRoom request; see
+ https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
ratelimit: set to False to disable the rate limiter
creator_join_profile:
@@ -707,14 +712,17 @@ class RoomCreationHandler:
`avatar_url` and/or `displayname`.
Returns:
- First, a dict containing the keys `room_id` and, if an alias
- was, requested, `room_alias`. Secondly, the stream_id of the
- last persisted event.
+ A 3-tuple containing:
+ - the room ID;
+ - if requested, the room alias, otherwise None; and
+ - the `stream_id` of the last persisted event.
Raises:
- SynapseError if the room ID couldn't be stored, 3pid invitation config
- validation failed, or something went horribly wrong.
- ResourceLimitError if server is blocked to some resource being
- exceeded
+ SynapseError:
+ if the room ID couldn't be stored, 3pid invitation config
+ validation failed, or something went horribly wrong.
+ ResourceLimitError:
+                if the server is blocked due to some resource limit
+                being exceeded
"""
user_id = requester.user.to_string()
@@ -864,9 +872,11 @@ class RoomCreationHandler:
)
# Check whether this visibility value is blocked by a third party module
- allowed_by_third_party_rules = await (
- self.third_party_event_rules.check_visibility_can_be_modified(
- room_id, visibility
+ allowed_by_third_party_rules = (
+ await (
+ self.third_party_event_rules.check_visibility_can_be_modified(
+ room_id, visibility
+ )
)
)
if not allowed_by_third_party_rules:
@@ -894,13 +904,6 @@ class RoomCreationHandler:
check_membership=False,
)
- preset_config = config.get(
- "preset",
- RoomCreationPreset.PRIVATE_CHAT
- if visibility == "private"
- else RoomCreationPreset.PUBLIC_CHAT,
- )
-
raw_initial_state = config.get("initial_state", [])
initial_state = OrderedDict()
@@ -919,7 +922,7 @@ class RoomCreationHandler:
) = await self._send_events_for_new_room(
requester,
room_id,
- preset_config=preset_config,
+ room_config=config,
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
@@ -928,48 +931,6 @@ class RoomCreationHandler:
creator_join_profile=creator_join_profile,
)
- if "name" in config:
- name = config["name"]
- (
- name_event,
- last_stream_id,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Name,
- "room_id": room_id,
- "sender": user_id,
- "state_key": "",
- "content": {"name": name},
- },
- ratelimit=False,
- prev_event_ids=[last_sent_event_id],
- depth=depth,
- )
- last_sent_event_id = name_event.event_id
- depth += 1
-
- if "topic" in config:
- topic = config["topic"]
- (
- topic_event,
- last_stream_id,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Topic,
- "room_id": room_id,
- "sender": user_id,
- "state_key": "",
- "content": {"topic": topic},
- },
- ratelimit=False,
- prev_event_ids=[last_sent_event_id],
- depth=depth,
- )
- last_sent_event_id = topic_event.event_id
- depth += 1
-
# we avoid dropping the lock between invites, as otherwise joins can
# start coming in and making the createRoom slow.
#
@@ -1024,11 +985,6 @@ class RoomCreationHandler:
last_sent_event_id = member_event_id
depth += 1
- result = {"room_id": room_id}
-
- if room_alias:
- result["room_alias"] = room_alias.to_string()
-
# Always wait for room creation to propagate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.events_shard_config.get_instance(room_id),
@@ -1036,13 +992,13 @@ class RoomCreationHandler:
last_stream_id,
)
- return result, last_stream_id
+ return room_id, room_alias, last_stream_id
async def _send_events_for_new_room(
self,
creator: Requester,
room_id: str,
- preset_config: str,
+ room_config: JsonDict,
invite_list: List[str],
initial_state: MutableStateMap,
creation_content: JsonDict,
@@ -1059,11 +1015,33 @@ class RoomCreationHandler:
Rate limiting should already have been applied by this point.
+ Args:
+ creator:
+ the user requesting the room creation
+ room_id:
+ room id for the room being created
+ room_config:
+ A dict of configuration options. This will be the body of
+ a /createRoom request; see
+ https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
+ invite_list:
+ a list of user ids to invite to the room
+ initial_state:
+                A map from (event type, state key) pairs to event content: the
+                state to set in the new room.
+ creation_content:
+ Extra keys, such as m.federate, to be added to the content of the m.room.create event.
+ room_alias:
+ alias for the room
+ power_level_content_override:
+ The power level content to override in the default power level event.
+ creator_join_profile:
+ Set to override the displayname and avatar for the creating
+ user in this room.
+
Returns:
A tuple containing the stream ID, event ID and depth of the last
event sent to the room.
"""
-
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
depth = 1
@@ -1074,9 +1052,6 @@ class RoomCreationHandler:
# created (but not persisted to the db) to determine state for future created events
# (as this info can't be pulled from the db)
state_map: MutableStateMap[str] = {}
- # current_state_group of last event created. Used for computing event context of
- # events to be batched
- current_state_group = None
def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
e = {"type": etype, "content": content}
@@ -1091,7 +1066,7 @@ class RoomCreationHandler:
content: JsonDict,
for_batch: bool,
**kwargs: Any,
- ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+ ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
"""
Creates an event and associated event context.
Args:
@@ -1110,20 +1085,33 @@ class RoomCreationHandler:
event_dict = create_event_dict(etype, content, **kwargs)
- new_event, new_context = await self.event_creation_handler.create_event(
+ (
+ new_event,
+ new_unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
creator,
event_dict,
prev_event_ids=prev_event,
depth=depth,
- state_map=state_map,
+ # Take a copy to ensure each event gets a unique copy of
+ # state_map since it is modified below.
+ state_map=dict(state_map),
for_batch=for_batch,
- current_state_group=current_state_group,
)
+
depth += 1
prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id
- return new_event, new_context
+ return new_event, new_unpersisted_context
+
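+        # Pick the preset: an explicit `preset` in the request wins, otherwise
+        # default based on the requested visibility.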
+ visibility = room_config.get("visibility", "private")
+ preset_config = room_config.get(
+ "preset",
+ RoomCreationPreset.PRIVATE_CHAT
+ if visibility == "private"
+ else RoomCreationPreset.PUBLIC_CHAT,
+ )
try:
config = self._presets_dict[preset_config]
@@ -1133,10 +1121,10 @@ class RoomCreationHandler:
)
creation_content.update({"creator": creator_id})
- creation_event, creation_context = await create_event(
+ creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False
)
-
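+        # Unlike the state events batched below, the create event is sent
+        # immediately, so its context must be persisted now.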
+ creation_context = await unpersisted_creation_context.persist(creation_event)
logger.debug("Sending %s in new room", EventTypes.Member)
ev = await self.event_creation_handler.handle_new_client_event(
requester=creator,
@@ -1180,7 +1168,6 @@ class RoomCreationHandler:
power_event, power_context = await create_event(
EventTypes.PowerLevels, pl_content, True
)
- current_state_group = power_context._state_group
events_to_send.append((power_event, power_context))
else:
power_level_content: JsonDict = {
@@ -1229,14 +1216,12 @@ class RoomCreationHandler:
power_level_content,
True,
)
- current_state_group = pl_context._state_group
events_to_send.append((pl_event, pl_context))
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
)
- current_state_group = room_alias_context._state_group
events_to_send.append((room_alias_event, room_alias_context))
if (EventTypes.JoinRules, "") not in initial_state:
@@ -1245,7 +1230,6 @@ class RoomCreationHandler:
{"join_rule": config["join_rules"]},
True,
)
- current_state_group = join_rules_context._state_group
events_to_send.append((join_rules_event, join_rules_context))
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
@@ -1254,7 +1238,6 @@ class RoomCreationHandler:
{"history_visibility": config["history_visibility"]},
True,
)
- current_state_group = visibility_context._state_group
events_to_send.append((visibility_event, visibility_context))
if config["guest_can_join"]:
@@ -1264,14 +1247,12 @@ class RoomCreationHandler:
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True,
)
- current_state_group = guest_access_context._state_group
events_to_send.append((guest_access_event, guest_access_context))
for (etype, state_key), content in initial_state.items():
event, context = await create_event(
etype, content, True, state_key=state_key
)
- current_state_group = context._state_group
events_to_send.append((event, context))
if config["encrypted"]:
@@ -1283,9 +1264,34 @@ class RoomCreationHandler:
)
events_to_send.append((encryption_event, encryption_context))
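+        # Any m.room.name and m.room.topic from the create request are created
+        # here too, so they are persisted in the same batch as the rest of the
+        # initial state.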
+ if "name" in room_config:
+ name = room_config["name"]
+ name_event, name_context = await create_event(
+ EventTypes.Name,
+ {"name": name},
+ True,
+ )
+ events_to_send.append((name_event, name_context))
+
+ if "topic" in room_config:
+ topic = room_config["topic"]
+ topic_event, topic_context = await create_event(
+ EventTypes.Topic,
+ {"topic": topic},
+ True,
+ )
+ events_to_send.append((topic_event, topic_context))
+
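+        # The contexts collected above are all unpersisted; persist them
+        # (computing their state groups) in a single batch now that the full
+        # chain of events is known.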
+ datastore = self.hs.get_datastores().state
+ events_and_context = (
+ await UnpersistedEventContext.batch_persist_unpersisted_contexts(
+ events_to_send, room_id, current_state_group, datastore
+ )
+ )
+
last_event = await self.event_creation_handler.handle_new_client_event(
creator,
- events_to_send,
+ events_and_context,
ignore_shadow_ban=True,
ratelimit=False,
)
@@ -1825,7 +1831,7 @@ class RoomShutdownHandler:
new_room_user_id, authenticated_entity=requester_user_id
)
- info, stream_id = await self._room_creation_handler.create_room(
+ new_room_id, _, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": RoomCreationPreset.PUBLIC_CHAT,
@@ -1834,7 +1840,6 @@ class RoomShutdownHandler:
},
ratelimit=False,
)
- new_room_id = info["room_id"]
logger.info(
"Shutting down room %r, joining to new room: %r", room_id, new_room_id
@@ -1887,6 +1892,7 @@ class RoomShutdownHandler:
# Join users to new room
if new_room_user_id:
+ assert new_room_id is not None
await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
@@ -1919,6 +1925,7 @@ class RoomShutdownHandler:
aliases_for_room = await self.store.get_aliases_for_room(room_id)
+ assert new_room_id is not None
await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id
)
@@ -1928,6 +1935,6 @@ class RoomShutdownHandler:
return {
"kicked_users": kicked_users,
"failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
+ "local_aliases": list(aliases_for_room),
"new_room_id": new_room_id,
}
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index c73d2adaad..bf9df60218 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,7 @@ class RoomBatchHandler:
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
- event, context = await self.event_creation_handler.create_event(
+ event, unpersisted_context = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service
),
@@ -345,7 +345,7 @@ class RoomBatchHandler:
historical=True,
depth=inherited_depth,
)
-
+ context = await unpersisted_context.persist(event)
assert context._state_group
# Normally this is done when persisting the event but we have to
@@ -374,7 +374,7 @@ class RoomBatchHandler:
# correct stream_ordering as they are backfilled (which decrements).
# Events are sorted by (topological_ordering, stream_ordering)
# where topological_ordering is just depth.
- for (event, context) in reversed(events_to_persist):
+ for event, context in reversed(events_to_persist):
# This call can't raise `PartialStateConflictError` since we forbid
# use of the historical batch API during partial state
await self.event_creation_handler.handle_new_client_event(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d236cc09b5..509c557889 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -26,7 +26,13 @@ from synapse.api.constants import (
GuestAccess,
Membership,
)
-from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ PartialStateConflictError,
+ ShadowBanError,
+ SynapseError,
+)
from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
@@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.logging import opentracing
from synapse.module_api import NOT_SPAM
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.types import (
JsonDict,
Requester,
@@ -56,6 +61,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class NoKnownServersError(SynapseError):
+    """No server already resident in the room was provided to the join/knock operation."""
+
+ def __init__(self, msg: str = "No known servers"):
+ super().__init__(404, msg)
+
+
class RoomMemberHandler(metaclass=abc.ABCMeta):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
@@ -185,12 +197,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: Room that we are trying to join
user: User who is trying to join
content: A dict that should be used as the content of the join event.
+
+ Raises:
+ NoKnownServersError: if remote_room_hosts does not contain a server joined to
+ the room.
"""
raise NotImplementedError()
@abc.abstractmethod
async def remote_knock(
self,
+ requester: Requester,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
@@ -398,7 +415,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.event_creation_handler.create_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@@ -419,7 +439,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
outlier=outlier,
historical=historical,
)
-
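+                # Persist the context before reading prev_state_ids from it below.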
+ context = await unpersisted_context.persist(event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
@@ -484,7 +504,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
user_id: The user's ID.
"""
# Retrieve user account data for predecessor room
- user_account_data, _ = await self.store.get_account_data_for_user(user_id)
+ user_account_data = await self.store.get_global_account_data_for_user(user_id)
# Copy direct message state if applicable
direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
@@ -823,14 +843,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- state_before_join = await self.state_handler.compute_state_after_events(
- room_id, latest_event_ids
+ is_partial_state_room = await self.store.is_partial_state_room(room_id)
+ partial_state_before_join = await self.state_handler.compute_state_after_events(
+ room_id, latest_event_ids, await_full_state=False
)
+ # `is_partial_state_room` also indicates whether `partial_state_before_join` is
+ # partial.
# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
- old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
+ old_state_id = partial_state_before_join.get(
+ (EventTypes.Member, target.to_string())
+ )
if old_state_id:
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
@@ -881,11 +906,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
- is_host_in_room = await self._is_host_in_room(state_before_join)
+ is_host_in_room = await self._is_host_in_room(partial_state_before_join)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
- guest_can_join = await self._can_guest_join(state_before_join)
+ guest_can_join = await self._can_guest_join(partial_state_before_join)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@@ -927,8 +952,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id,
remote_room_hosts,
content,
+ is_partial_state_room,
is_host_in_room,
- state_before_join,
+ partial_state_before_join,
)
if remote_join:
if ratelimit:
@@ -1048,7 +1074,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
return await self.remote_knock(
- remote_room_hosts, room_id, target, content
+ requester, remote_room_hosts, room_id, target, content
)
return await self._local_membership_update(
@@ -1073,8 +1099,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: str,
remote_room_hosts: List[str],
content: JsonDict,
+ is_partial_state_room: bool,
is_host_in_room: bool,
- state_before_join: StateMap[str],
+ partial_state_before_join: StateMap[str],
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
@@ -1093,9 +1120,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
remote_room_hosts: A list of remote room hosts.
content: The content to use as the event body of the join. This may
be modified.
- is_host_in_room: True if the host is in the room.
- state_before_join: The state before the join event (i.e. the resolution of
- the states after its parent events).
+ is_partial_state_room: `True` if the server currently doesn't hold the full
+ state of the room.
+ is_host_in_room: `True` if the host is in the room.
+ partial_state_before_join: The state before the join event (i.e. the
+ resolution of the states after its parent events). May be full or
+ partial state, depending on `is_partial_state_room`.
Returns:
A tuple of:
@@ -1109,6 +1139,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if not is_host_in_room:
return True, remote_room_hosts
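+        # Look up the target user's existing membership, if any; it is needed
+        # for both the partial state check below and the restricted join rule
+        # checks further down.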
+ prev_member_event_id = partial_state_before_join.get(
+ (EventTypes.Member, user_id), None
+ )
+ previous_membership = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+ previous_membership = prev_member_event.membership
+
+        # If we are not fully joined yet, and the target is not already in the room,
+        # let's do a remote join so another server with the full state can validate,
+        # for example, that the user has not been banned.
+        # We could just accept the join and let state resolution sort it out later,
+        # but until then we would leak room history to this user, which is pretty
+        # bad.
+ if is_partial_state_room and previous_membership != Membership.JOIN:
+ return True, remote_room_hosts
+
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
@@ -1116,21 +1163,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
- state_before_join, room_version
+ partial_state_before_join, room_version
):
return False, []
# If the user is invited to the room or already joined, the join
# event can always be issued locally.
- prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
- prev_member_event = None
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
- if prev_member_event.membership in (
- Membership.JOIN,
- Membership.INVITE,
- ):
- return False, []
+ if previous_membership in (Membership.JOIN, Membership.INVITE):
+ return False, []
+
+ # All the partial state cases are covered above. We have been given the full
+ # state of the room.
+ assert not is_partial_state_room
+ state_before_join = partial_state_before_join
# If the local host has a user who can issue invites, then a local
# join can be done.
@@ -1154,7 +1199,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
- state_before_join, room_version, user_id, prev_member_event
+ state_before_join, room_version, user_id, previous_membership
)
# If this is going to be a local join, additional information must
@@ -1304,11 +1349,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
- async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
+ async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
+
+ Args:
+ partial_current_state_ids: The current state of the room. May be full or
+ partial state.
"""
- guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+ guest_access_id = partial_current_state_ids.get(
+ (EventTypes.GuestAccess, ""), None
+ )
if not guest_access_id:
return False
@@ -1634,19 +1685,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
return event, stream_id
- async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
+ async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool:
+ """Returns whether the homeserver is in the room based on its current state.
+
+ Args:
+ partial_current_state_ids: The current state of the room. May be full or
+ partial state.
+ """
# Have we just created the room, and is this about to be the very
# first member event?
- create_event_id = current_state_ids.get(("m.room.create", ""))
- if len(current_state_ids) == 1 and create_event_id:
+ create_event_id = partial_current_state_ids.get(("m.room.create", ""))
+ if len(partial_current_state_ids) == 1 and create_event_id:
# We can only get here if we're in the process of creating the room
return True
- for etype, state_key in current_state_ids:
+ for etype, state_key in partial_current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue
- event_id = current_state_ids[(etype, state_key)]
+ event_id = partial_current_state_ids[(etype, state_key)]
event = await self.store.get_event(event_id, allow_none=True)
if not event:
continue
@@ -1715,8 +1772,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
]
if len(remote_room_hosts) == 0:
- raise SynapseError(
- 404,
+ raise NoKnownServersError(
"Can't join remote room because no servers "
"that are in the room have been provided.",
)
@@ -1892,7 +1948,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
max_retries = 5
for i in range(max_retries):
try:
- event, context = await self.event_creation_handler.create_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1900,6 +1959,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
auth_event_ids=auth_event_ids,
outlier=True,
)
+ context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True
result_event = (
@@ -1925,6 +1985,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
async def remote_knock(
self,
+ requester: Requester,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
@@ -1947,7 +2008,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
]
if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
+ raise NoKnownServersError()
return await self.federation_handler.do_knock(
remote_room_hosts, room_id, user.to_string(), content=content
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 221552a2a6..76e36b8a6d 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -15,8 +15,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
-from synapse.api.errors import SynapseError
-from synapse.handlers.room_member import RoomMemberHandler
+from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler
from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
@@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join"""
if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
+ raise NoKnownServersError()
ret = await self._remote_join_client(
requester=requester,
@@ -114,6 +113,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
async def remote_knock(
self,
+ requester: Requester,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
@@ -124,9 +124,10 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
Implements RoomMemberHandler.remote_knock
"""
ret = await self._remote_knock_client(
+ requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
- user=user,
+ user_id=user.to_string(),
content=content,
)
return ret["event_id"], ret["stream_id"]
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 4472019fbc..807245160d 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -521,8 +521,8 @@ class RoomSummaryHandler:
It should return true if:
- * The requester is joined or can join the room (per MSC3173).
- * The origin server has any user that is joined or can join the room.
+ * The requesting user is joined or can join the room (per MSC3173); or
+ * The origin server has any user that is joined or can join the room; or
* The history visibility is set to world readable.
Args:
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9bbf83047d..aad4706f14 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -23,7 +23,8 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
-from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
+from synapse.events.utils import SerializeEventConfig
+from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
from synapse.types.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -109,12 +110,12 @@ class SearchHandler:
return historical_room_ids
async def search(
- self, user: UserID, content: JsonDict, batch: Optional[str] = None
+ self, requester: Requester, content: JsonDict, batch: Optional[str] = None
) -> JsonDict:
"""Performs a full text search for a user.
Args:
- user: The user performing the search.
+ requester: The user performing the search.
content: Search parameters
batch: The next_batch parameter. Used for pagination.
@@ -199,7 +200,7 @@ class SearchHandler:
)
return await self._search(
- user,
+ requester,
batch_group,
batch_group_key,
batch_token,
@@ -217,7 +218,7 @@ class SearchHandler:
async def _search(
self,
- user: UserID,
+ requester: Requester,
batch_group: Optional[str],
batch_group_key: Optional[str],
batch_token: Optional[str],
@@ -235,7 +236,7 @@ class SearchHandler:
"""Performs a full text search for a user.
Args:
- user: The user performing the search.
+ requester: The user performing the search.
batch_group: Pagination information.
batch_group_key: Pagination information.
batch_token: Pagination information.
@@ -269,7 +270,7 @@ class SearchHandler:
# TODO: Search through left rooms too
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
- user.to_string(),
+ requester.user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
@@ -303,13 +304,13 @@ class SearchHandler:
if order_by == "rank":
search_result, sender_group = await self._search_by_rank(
- user, room_ids, search_term, keys, search_filter
+ requester.user, room_ids, search_term, keys, search_filter
)
# Unused return values for rank search.
global_next_batch = None
elif order_by == "recent":
search_result, global_next_batch = await self._search_by_recent(
- user,
+ requester.user,
room_ids,
search_term,
keys,
@@ -334,7 +335,7 @@ class SearchHandler:
assert after_limit is not None
contexts = await self._calculate_event_contexts(
- user,
+ requester.user,
search_result.allowed_events,
before_limit,
after_limit,
@@ -363,27 +364,37 @@ class SearchHandler:
# The returned events.
search_result.allowed_events,
),
- user.to_string(),
+ requester.user.to_string(),
)
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise, the 'age' will be wrong.
time_now = self.clock.time_msec()
+ serialize_options = SerializeEventConfig(requester=requester)
for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
- context["events_before"], time_now, bundle_aggregations=aggregations
+ context["events_before"],
+ time_now,
+ bundle_aggregations=aggregations,
+ config=serialize_options,
)
context["events_after"] = self._event_serializer.serialize_events(
- context["events_after"], time_now, bundle_aggregations=aggregations
+ context["events_after"],
+ time_now,
+ bundle_aggregations=aggregations,
+ config=serialize_options,
)
results = [
{
"rank": search_result.rank_map[e.event_id],
"result": self._event_serializer.serialize_event(
- e, time_now, bundle_aggregations=aggregations
+ e,
+ time_now,
+ bundle_aggregations=aggregations,
+ config=serialize_options,
),
"context": contexts.get(e.event_id, {}),
}
@@ -398,7 +409,9 @@ class SearchHandler:
if state_results:
rooms_cat_res["state"] = {
- room_id: self._event_serializer.serialize_events(state_events, time_now)
+ room_id: self._event_serializer.serialize_events(
+ state_events, time_now, config=serialize_options
+ )
for room_id, state_events in state_results.items()
}
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 3566537894..9f5b83ed54 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -269,6 +269,8 @@ class SyncHandler:
self._state_storage_controller = self._storage_controllers.state
self._device_handler = hs.get_device_handler()
+ self.should_calculate_push_rules = hs.config.push.enable_push
+
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
# that sets 'since' to 'next_batch'), we know that device won't need a
@@ -1224,6 +1226,10 @@ class SyncHandler:
continue
event_with_membership_auth = events_with_membership_auth[member]
+ is_create = (
+ event_with_membership_auth.is_state()
+ and event_with_membership_auth.type == EventTypes.Create
+ )
is_join = (
event_with_membership_auth.is_state()
and event_with_membership_auth.type == EventTypes.Member
@@ -1231,9 +1237,10 @@ class SyncHandler:
and event_with_membership_auth.content.get("membership")
== Membership.JOIN
)
- if not is_join:
+ if not is_create and not is_join:
# The event must include the desired membership as an auth event, unless
- # it's the first join event for a given user.
+ # it's the `m.room.create` event for a room or the first join event for
+ # a given user.
missing_members.add(member)
auth_event_ids.update(event_with_membership_auth.auth_event_ids())
@@ -1288,8 +1295,13 @@ class SyncHandler:
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> RoomNotifCounts:
- with Measure(self.clock, "unread_notifs_for_room_id"):
+ if not self.should_calculate_push_rules:
+ # If push rules have been universally disabled then we know we won't
+ # have any unread counts in the DB, so we may as well skip asking
+ # the DB.
+ return RoomNotifCounts.empty()
+ with Measure(self.clock, "unread_notifs_for_room_id"):
return await self.store.get_unread_event_push_actions_by_room_for_user(
room_id,
sync_config.user.to_string(),
@@ -1391,6 +1403,11 @@ class SyncHandler:
for room_id, is_partial_state in results.items()
if is_partial_state
)
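+            # Ignore membership changes in rooms that are still partial-stated;
+            # such rooms are omitted from the sync response until their state is
+            # complete.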
+ membership_change_events = [
+ event
+ for event in membership_change_events
+ if not results.get(event.room_id, False)
+ ]
# Incremental eager syncs should additionally include rooms that
# - we are joined to
@@ -1444,9 +1461,9 @@ class SyncHandler:
logger.debug("Fetching account data")
- account_data_by_room = await self._generate_sync_entry_for_account_data(
- sync_result_builder
- )
+ # Global account data is included if it is not filtered out.
+ if not sync_config.filter_collection.blocks_all_global_account_data():
+ await self._generate_sync_entry_for_account_data(sync_result_builder)
# Presence data is included if the server has it enabled and not filtered out.
include_presence_data = bool(
@@ -1472,9 +1489,7 @@ class SyncHandler:
(
newly_joined_rooms,
newly_left_rooms,
- ) = await self._generate_sync_entry_for_rooms(
- sync_result_builder, account_data_by_room
- )
+ ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
# Work out which users have joined or left rooms we're in. We use this
# to build the presence and device_list parts of the sync response in
@@ -1521,7 +1536,7 @@ class SyncHandler:
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
- unused_fallback_key_types = (
+ unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
@@ -1717,35 +1732,29 @@ class SyncHandler:
async def _generate_sync_entry_for_account_data(
self, sync_result_builder: "SyncResultBuilder"
- ) -> Dict[str, Dict[str, JsonDict]]:
- """Generates the account data portion of the sync response.
+ ) -> None:
+ """Generates the global account data portion of the sync response.
Account data (called "Client Config" in the spec) can be set either globally
or for a specific room. Account data consists of a list of events which
accumulate state, much like a room.
- This function retrieves global and per-room account data. The former is written
- to the given `sync_result_builder`. The latter is returned directly, to be
- later written to the `sync_result_builder` on a room-by-room basis.
+ This function retrieves global account data and writes it to the given
+ `sync_result_builder`. See `_generate_sync_entry_for_rooms` for handling
+ of per-room account data.
Args:
sync_result_builder
-
- Returns:
- A dictionary whose keys (room ids) map to the per room account data for that
- room.
"""
sync_config = sync_result_builder.sync_config
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and not sync_result_builder.full_state:
- # TODO Do not fetch room account data if it will be unused.
- (
- global_account_data,
- account_data_by_room,
- ) = await self.store.get_updated_account_data_for_user(
- user_id, since_token.account_data_key
+ global_account_data = (
+ await self.store.get_updated_global_account_data_for_user(
+ user_id, since_token.account_data_key
+ )
)
push_rules_changed = await self.store.have_push_rules_changed_for_user(
@@ -1753,31 +1762,31 @@ class SyncHandler:
)
if push_rules_changed:
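+                # Take a shallow copy before inserting the push rules entry: the
+                # mapping returned by the store must not be mutated.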
+ global_account_data = dict(global_account_data)
global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
else:
- # TODO Do not fetch room account data if it will be unused.
- (
- global_account_data,
- account_data_by_room,
- ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
+ all_global_account_data = await self.store.get_global_account_data_for_user(
+ user_id
+ )
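+            # As above, copy the mapping before inserting the push rules entry.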
+ global_account_data = dict(all_global_account_data)
global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
- account_data_for_user = await sync_config.filter_collection.filter_account_data(
- [
- {"type": account_data_type, "content": content}
- for account_data_type, content in global_account_data.items()
- ]
+ account_data_for_user = (
+ await sync_config.filter_collection.filter_global_account_data(
+ [
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in global_account_data.items()
+ ]
+ )
)
sync_result_builder.account_data = account_data_for_user
- return account_data_by_room
-
async def _generate_sync_entry_for_presence(
self,
sync_result_builder: "SyncResultBuilder",
@@ -1837,9 +1846,7 @@ class SyncHandler:
sync_result_builder.presence = presence
async def _generate_sync_entry_for_rooms(
- self,
- sync_result_builder: "SyncResultBuilder",
- account_data_by_room: Dict[str, Dict[str, JsonDict]],
+ self, sync_result_builder: "SyncResultBuilder"
) -> Tuple[AbstractSet[str], AbstractSet[str]]:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1850,7 +1857,6 @@ class SyncHandler:
Args:
sync_result_builder
- account_data_by_room: Dictionary of per room account data
Returns:
Returns a 2-tuple describing rooms the user has joined or left.
@@ -1863,9 +1869,30 @@ class SyncHandler:
since_token = sync_result_builder.since_token
user_id = sync_result_builder.sync_config.user.to_string()
+ blocks_all_rooms = (
+ sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+ )
+
+ # 0. Start by fetching room account data (if required).
+ if (
+ blocks_all_rooms
+ or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
+ ):
+ account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {}
+ elif since_token and not sync_result_builder.full_state:
+ account_data_by_room = (
+ await self.store.get_updated_room_account_data_for_user(
+ user_id, since_token.account_data_key
+ )
+ )
+ else:
+ account_data_by_room = await self.store.get_room_account_data_for_user(
+ user_id
+ )
+
# 1. Start by fetching all ephemeral events in rooms we've joined (if required).
block_all_room_ephemeral = (
- sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+ blocks_all_rooms
or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
if block_all_room_ephemeral:
@@ -2291,8 +2318,8 @@ class SyncHandler:
sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
- tags: Optional[Dict[str, Dict[str, Any]]],
- account_data: Dict[str, JsonDict],
+ tags: Optional[Mapping[str, Mapping[str, Any]]],
+ account_data: Mapping[str, JsonDict],
always_include: bool = False,
) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder`
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 332edcca24..78a75bfed6 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -13,7 +13,8 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
from twisted.web.client import PartialDownloadError
@@ -27,19 +28,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class UserInteractiveAuthChecker:
+class UserInteractiveAuthChecker(ABC):
"""Abstract base class for an interactive auth checker"""
- def __init__(self, hs: "HomeServer"):
+ # This should really be an "abstract class property", i.e. it should
+ # be an error to instantiate a subclass that doesn't specify an AUTH_TYPE.
+ # But calling this a `ClassVar` is simpler than a decorator stack of
+ # @property @abstractmethod and @classmethod (if that's even the right order).
+ AUTH_TYPE: ClassVar[str]
+
+ def __init__(self, hs: "HomeServer"): # noqa: B027
pass
+ @abstractmethod
def is_enabled(self) -> bool:
"""Check if the configuration of the homeserver allows this checker to work
Returns:
True if this login type is enabled.
"""
+ raise NotImplementedError()
+ @abstractmethod
async def check_auth(self, authdict: dict, clientip: str) -> Any:
"""Given the authentication dict from the client, attempt to check this step
@@ -304,7 +314,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
)
-INTERACTIVE_AUTH_CHECKERS = [
+INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
|