summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py12
-rw-r--r--synapse/api/auth.py2
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/config/registration.py18
-rw-r--r--synapse/config/server.py7
-rw-r--r--synapse/events/builder.py19
-rw-r--r--synapse/events/snapshot.py46
-rw-r--r--synapse/events/third_party_rules.py37
-rw-r--r--synapse/events/utils.py15
-rw-r--r--synapse/groups/attestations.py25
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/room_member.py8
-rw-r--r--synapse/handlers/saml_handler.py3
-rw-r--r--synapse/handlers/sync.py6
-rw-r--r--synapse/http/client.py16
-rw-r--r--synapse/http/matrixfederationclient.py7
-rw-r--r--synapse/push/action_generator.py7
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py62
-rw-r--r--synapse/push/httppusher.py58
-rw-r--r--synapse/push/presentable_names.py15
-rw-r--r--synapse/push/push_tools.py35
-rw-r--r--synapse/push/pusherpool.py72
-rw-r--r--synapse/python_dependencies.py2
-rw-r--r--synapse/replication/http/federation.py4
-rw-r--r--synapse/replication/http/send_event.py2
-rw-r--r--synapse/rest/admin/rooms.py11
-rw-r--r--synapse/rest/client/v2_alpha/sync.py1
-rw-r--r--synapse/rest/media/v1/_base.py8
-rw-r--r--synapse/rest/media/v1/media_repository.py105
-rw-r--r--synapse/rest/media/v1/media_storage.py52
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py10
-rw-r--r--synapse/rest/media/v1/storage_provider.py62
-rw-r--r--synapse/storage/data_stores/main/cache.py1
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py4
-rw-r--r--synapse/storage/data_stores/main/events.py48
-rw-r--r--synapse/storage/data_stores/main/events_worker.py86
-rw-r--r--synapse/storage/data_stores/main/purge_events.py2
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql18
-rw-r--r--synapse/storage/database.py18
-rw-r--r--synapse/storage/persist_events.py40
-rw-r--r--synapse/storage/purge_events.py38
-rw-r--r--synapse/storage/state.py207
-rw-r--r--synapse/visibility.py30
43 files changed, 690 insertions, 533 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index 5155e719a1..f70381bc71 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -17,6 +17,7 @@ """ This is a reference implementation of a Matrix homeserver. """ +import json import os import sys @@ -25,6 +26,9 @@ if sys.version_info < (3, 5): print("Synapse requires Python 3.5 or above.") sys.exit(1) +# Twisted and canonicaljson will fail to import when this file is executed to +# get the __version__ during a fresh install. That's OK and subsequent calls to +# actually start Synapse will import these libraries fine. try: from twisted.internet import protocol from twisted.internet.protocol import Factory @@ -36,6 +40,14 @@ try: except ImportError: pass +# Use the standard library json implementation instead of simplejson. +try: + from canonicaljson import set_json_library + + set_json_library(json) +except ImportError: + pass + __version__ = "1.18.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 71859c9bc0..afb5702439 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py
@@ -82,7 +82,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, room_version: str, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index c1ab8378e1..e77467491e 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -628,7 +628,7 @@ class GenericWorkerServer(HomeServer): self.get_tcp_replication().start_replication(self) - def remove_pusher(self, app_id, push_key, user_id): + async def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) def build_replication_data_handler(self): diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 080b9bc445..3c2e951a71 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -412,24 +412,6 @@ class RegistrationConfig(Config): # #default_identity_server: https://matrix.org - # The list of identity servers trusted to verify third party - # identifiers by this server. - # - # Also defines the ID server which will be called when an account is - # deactivated (one will be picked arbitrarily). - # - # Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity - # server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a - # background migration script, informing itself that the identity server all of its - # 3PIDs have been bound to is likely one of the below. - # - # As of Synapse v1.4.0, all other functionality of this option has been deprecated, and - # it is now solely used for the purposes of the background migration script, and can be - # removed once it has run. - #trusted_third_party_id_servers: - # - matrix.org - # - vector.im - # If enabled, user IDs, display names and avatar URLs will be replicated # to this server whenever they change. # This is an experimental API currently implemented by sydent to support diff --git a/synapse/config/server.py b/synapse/config/server.py
index 57b0b97164..c73f7077b5 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -445,6 +445,9 @@ class ServerConfig(Config): validator=attr.validators.instance_of(str), default=ROOM_COMPLEXITY_TOO_GREAT, ) + admins_can_join = attr.ib( + validator=attr.validators.instance_of(bool), default=False + ) self.limit_remote_rooms = LimitRemoteRoomsConfig( **(config.get("limit_remote_rooms") or {}) @@ -927,6 +930,10 @@ class ServerConfig(Config): # #complexity_error: "This room is too complex." + # allow server admins to join complex rooms. Default is false. + # + #admins_can_join: true + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 0bb216419a..69b53ca2bc 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py
@@ -17,8 +17,6 @@ from typing import Optional import attr from nacl.signing import SigningKey -from twisted.internet import defer - from synapse.api.constants import MAX_DEPTH from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( @@ -95,31 +93,30 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - @defer.inlineCallbacks - def build(self, prev_event_ids): + async def build(self, prev_event_ids): """Transform into a fully signed and hashed event Args: prev_event_ids (list[str]): The event IDs to use as the prev events Returns: - Deferred[FrozenEvent] + FrozenEvent """ - state_ids = yield defer.ensureDeferred( - self._state.get_current_state_ids(self.room_id, prev_event_ids) + state_ids = await self._state.get_current_state_ids( + self.room_id, prev_event_ids ) - auth_ids = yield self._auth.compute_auth_events(self, state_ids) + auth_ids = await self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = yield self._store.add_event_hashes(auth_ids) - prev_events = yield self._store.add_event_hashes(prev_event_ids) + auth_events = await self._store.add_event_hashes(auth_ids) + prev_events = await self._store.add_event_hashes(prev_event_ids) else: auth_events = auth_ids prev_events = prev_event_ids - old_depth = yield self._store.get_max_depth_of(prev_event_ids) + old_depth = await self._store.get_max_depth_of(prev_event_ids) depth = old_depth + 1 # we cap depth of generated events, to ensure that they are not diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index f94cdcbaba..cca93e3a46 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py
@@ -12,17 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import attr from frozendict import frozendict -from twisted.internet import defer - from synapse.appservice import ApplicationService +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import StateMap +if TYPE_CHECKING: + from synapse.storage.data_stores.main import DataStore + @attr.s(slots=True) class EventContext: @@ -129,8 +131,7 @@ class EventContext: delta_ids=delta_ids, ) - @defer.inlineCallbacks - def serialize(self, event, store): + async def serialize(self, event: EventBase, store: "DataStore") -> dict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -146,7 +147,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids() + prev_state_ids = await self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -214,8 +215,7 @@ class EventContext: return self._state_group - @defer.inlineCallbacks - def get_current_state_ids(self): + async def get_current_state_ids(self) -> Optional[StateMap[str]]: """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -224,32 +224,31 @@ class EventContext: ``rejected`` is set. Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group - is None, which happens when the associated event is an outlier. + Returns None if state_group is None, which happens when the associated + event is an outlier. - Maps a (type, state_key) to the event ID of the state event matching - this tuple. + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched() + await self._ensure_fetched() return self._current_state_ids - @defer.inlineCallbacks - def get_prev_state_ids(self): + async def get_prev_state_ids(self): """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group + dict[(str, str), str]|None: Returns None if state_group is None, which happens when the associated event is an outlier. Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched() + await self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -269,8 +268,8 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self): - return defer.succeed(None) + async def _ensure_fetched(self): + return None @attr.s(slots=True) @@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext): _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self): + async def _ensure_fetched(self): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background(self._fill_out_state) - return make_deferred_yieldable(self._fetching_state_deferred) + return await make_deferred_yieldable(self._fetching_state_deferred) - @defer.inlineCallbacks - def _fill_out_state(self): + async def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self._current_state_ids = await self._storage.state.get_state_ids_for_group( self.state_group ) if self._event_state_key is not None: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 317a6eb638..7fe1525e0f 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py
@@ -15,8 +15,9 @@ from typing import Callable from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.module_api import ModuleApi -from synapse.types import StateMap +from synapse.types import Requester, StateMap class ThirdPartyEventRules(object): @@ -42,15 +43,17 @@ class ThirdPartyEventRules(object): config=config, module_api=ModuleApi(hs, hs.get_auth_handler()), ) - async def check_event_allowed(self, event, context): + async def check_event_allowed( + self, event: EventBase, context: EventContext + ) -> bool: """Check if a provided event should be allowed in the given context. Args: - event (synapse.events.EventBase): The event to be checked. - context (synapse.events.snapshot.EventContext): The context of the event. + event: The event to be checked. + context: The context of the event. Returns: - defer.Deferred[bool]: True if the event should be allowed, False if not. + True if the event should be allowed, False if not. """ if self.third_party_rules is None: return True @@ -65,17 +68,19 @@ class ThirdPartyEventRules(object): ret = await self.third_party_rules.check_event_allowed(event, state_events) return ret - async def on_create_room(self, requester, config, is_requester_admin): + async def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ) -> bool: """Intercept requests to create room to allow, deny or update the request config. Args: - requester (Requester) - config (dict): The creation config from the client. - is_requester_admin (bool): If the requester is an admin + requester + config: The creation config from the client. + is_requester_admin: If the requester is an admin Returns: - defer.Deferred[bool]: Whether room creation is allowed or denied. + Whether room creation is allowed or denied. """ if self.third_party_rules is None: @@ -86,16 +91,18 @@ class ThirdPartyEventRules(object): ) return ret - async def check_threepid_can_be_invited(self, medium, address, room_id): + async def check_threepid_can_be_invited( + self, medium: str, address: str, room_id: str + ) -> bool: """Check if a provided 3PID can be invited in the given room. Args: - medium (str): The 3PID's medium. - address (str): The 3PID's address. - room_id (str): The room we want to invite the threepid to. + medium: The 3PID's medium. + address: The 3PID's address. + room_id: The room we want to invite the threepid to. Returns: - defer.Deferred[bool], True if the 3PID can be invited, False if not. + True if the 3PID can be invited, False if not. """ if self.third_party_rules is None: diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 11f0d34ec8..2d42e268c6 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py
@@ -18,8 +18,6 @@ from typing import Any, Mapping, Union from frozendict import frozendict -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion @@ -337,8 +335,9 @@ class EventClientSerializer(object): hs.config.experimental_msc1849_support_enabled ) - @defer.inlineCallbacks - def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs): + async def serialize_event( + self, event, time_now, bundle_aggregations=True, **kwargs + ): """Serializes a single event. Args: @@ -348,7 +347,7 @@ class EventClientSerializer(object): **kwargs: Arguments to pass to `serialize_event` Returns: - Deferred[dict]: The serialized event + dict: The serialized event """ # To handle the case of presence events and the like if not isinstance(event, EventBase): @@ -363,8 +362,8 @@ class EventClientSerializer(object): if not event.internal_metadata.is_redacted() and ( self.experimental_msc1849_support_enabled and bundle_aggregations ): - annotations = yield self.store.get_aggregation_groups_for_event(event_id) - references = yield self.store.get_relations_for_event( + annotations = await self.store.get_aggregation_groups_for_event(event_id) + references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) @@ -378,7 +377,7 @@ class EventClientSerializer(object): edit = None if event.type == EventTypes.Message: - edit = yield self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id) if edit: # If there is an edit replace the content, preserving existing diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index dab13c243f..e674bf44a2 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py
@@ -41,8 +41,6 @@ from typing import Tuple from signedjson.sign import sign_json -from twisted.internet import defer - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id @@ -72,8 +70,9 @@ class GroupAttestationSigning(object): self.server_name = hs.hostname self.signing_key = hs.signing_key - @defer.inlineCallbacks - def verify_attestation(self, attestation, group_id, user_id, server_name=None): + async def verify_attestation( + self, attestation, group_id, user_id, server_name=None + ): """Verifies that the given attestation matches the given parameters. An optional server_name can be supplied to explicitly set which server's @@ -102,7 +101,7 @@ class GroupAttestationSigning(object): if valid_until_ms < now: raise SynapseError(400, "Attestation expired") - yield self.keyring.verify_json_for_server( + await self.keyring.verify_json_for_server( server_name, attestation, now, "Group attestation" ) @@ -142,8 +141,7 @@ class GroupAttestionRenewer(object): self._start_renew_attestations, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def on_renew_attestation(self, group_id, user_id, content): + async def on_renew_attestation(self, group_id, user_id, content): """When a remote updates an attestation """ attestation = content["attestation"] @@ -151,11 +149,11 @@ class GroupAttestionRenewer(object): if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): raise SynapseError(400, "Neither user not group are on this server") - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( attestation, user_id=user_id, group_id=group_id ) - yield self.store.update_remote_attestion(group_id, user_id, attestation) + await self.store.update_remote_attestion(group_id, user_id, attestation) return {} @@ -172,8 +170,7 @@ class GroupAttestionRenewer(object): now + UPDATE_ATTESTATION_TIME_MS ) - @defer.inlineCallbacks - def _renew_attestation(group_user: Tuple[str, str]): + async def _renew_attestation(group_user: Tuple[str, str]): group_id, user_id = group_user try: if not self.is_mine_id(group_id): @@ -186,16 +183,16 @@ class GroupAttestionRenewer(object): user_id, group_id, ) - yield self.store.remove_attestation_renewal(group_id, user_id) + await self.store.remove_attestation_renewal(group_id, user_id) return attestation = self.attestations.create_attestation(group_id, user_id) - yield self.transport_client.renew_group_attestation( + await self.transport_client.renew_group_attestation( destination, group_id, user_id, content={"attestation": attestation} ) - yield self.store.update_attestation_renewal( + await self.store.update_attestation_renewal( group_id, user_id, attestation ) except (RequestSendFailed, HttpResponseException) as e: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8a7ceb0df2..8b8e989bdc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -2480,7 +2480,7 @@ class FederationHandler(BaseHandler): } current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) + current_state_ids = dict(current_state_ids) # type: ignore current_state_ids.update(state_updates) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index c3fee116c2..a11ffae753 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -1007,7 +1007,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - if self.hs.config.limit_remote_rooms.enabled: + check_complexity = self.hs.config.limit_remote_rooms.enabled + if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = not await self.hs.auth.is_server_admin(user) + + if check_complexity: # Fetch the room complexity too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts @@ -1030,7 +1034,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. - if self.hs.config.limit_remote_rooms.enabled: + if check_complexity: if too_complex is False: # We checked, and we're under the limit. return event_id, stream_id diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index abecaa8313..2d506dc1f2 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py
@@ -96,6 +96,9 @@ class SamlHandler: relay_state=client_redirect_url ) + # Since SAML sessions timeout it is useful to log when they were created. + logger.info("Initiating a new SAML session: %s" % (reqid,)) + now = self._clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( creation_time=now, ui_auth_session_id=ui_auth_session_id, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ebd3e98105..eaa4eeadf7 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -103,6 +103,7 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) + unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1886,6 +1887,10 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] + + unread_count = await self.store.get_unread_message_count_for_user( + room_id, sync_config.user.to_string(), + ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1894,6 +1899,7 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, + unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/http/client.py b/synapse/http/client.py
index 6bc51202cd..155b7460d4 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -395,7 +395,9 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: return json.loads(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) @defer.inlineCallbacks def post_json_get_json(self, uri, post_json, headers=None): @@ -436,7 +438,9 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: return json.loads(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) @defer.inlineCallbacks def get_json(self, uri, args={}, headers=None): @@ -509,7 +513,9 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: return json.loads(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) @defer.inlineCallbacks def get_raw(self, uri, args={}, headers=None): @@ -544,7 +550,9 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: return body else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 148eeb19dc..ea026ed9f4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -447,6 +447,7 @@ class MatrixFederationHttpClient(object): ).inc() set_tag(tags.HTTP_STATUS_CODE, response.code) + response_phrase = response.phrase.decode("ascii", errors="replace") if 200 <= response.code < 300: logger.debug( @@ -454,7 +455,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode("ascii", errors="replace"), + response_phrase, ) pass else: @@ -463,7 +464,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode("ascii", errors="replace"), + response_phrase, ) # :'( # Update transactions table? @@ -487,7 +488,7 @@ class MatrixFederationHttpClient(object): ) body = None - e = HttpResponseException(response.code, response.phrase, body) + e = HttpResponseException(response.code, response_phrase, body) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 1ffd5e2df3..0d23142653 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py
@@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.metrics import Measure from .bulk_push_rule_evaluator import BulkPushRuleEvaluator @@ -37,7 +35,6 @@ class ActionGenerator(object): # event stream, so we just run the rules for a client with no profile # tag (ie. we just need all the users). - @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, context): + async def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user(event, context) + await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 472ddf9f7d..04b9d8ac82 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,8 +19,6 @@ from collections import namedtuple from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY @@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object): resizable=False, ) - @defer.inlineCallbacks - def _get_rules_for_event(self, event, context): + async def _get_rules_for_event(self, event, context): """This gets the rules for all users in the room at the time of the event, as well as the push rules for the invitee if the event is an invite. @@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = yield self._get_rules_for_room(room_id) + rules_for_room = await self._get_rules_for_room(room_id) - rules_by_user = yield rules_for_room.get_rules(event, context) + rules_by_user = await rules_for_room.get_rules(event, context) # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): - has_pusher = yield self.store.user_has_pusher(invited) + has_pusher = await self.store.user_has_pusher(invited) if has_pusher: rules_by_user = dict(rules_by_user) - rules_by_user[invited] = yield self.store.get_push_rules_for_user( + rules_by_user[invited] = await self.store.get_push_rules_for_user( invited ) @@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics, ) - @defer.inlineCallbacks - def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids() + async def _get_power_levels_and_sender_level(self, event, context): + prev_state_ids = await context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case - pl_event = yield self.store.get_event(pl_event_id) + pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = yield self.auth.compute_auth_events( + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object): return pl_event.content if pl_event else {}, sender_level - @defer.inlineCallbacks - def action_for_event_by_user(self, event, context): + async def action_for_event_by_user(self, event, context) -> None: """Given an event and context, evaluate the push rules and insert the results into the event_push_actions_staging table. - - Returns: - Deferred """ - rules_by_user = yield self._get_rules_for_event(event, context) + rules_by_user = await self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context(event, context) + room_members = await self.store.get_joined_users_from_context(event, context) ( power_levels, sender_power_level, - ) = yield self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels @@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object): continue if not event.is_state(): - is_ignored = yield self.store.is_ignored_by(event.sender, uid) + is_ignored = await self.store.is_ignored_by(event.sender, uid) if is_ignored: continue @@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) + await self.store.add_push_actions_to_staging(event.event_id, actions_by_user) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -274,8 +266,7 @@ class RulesForRoom(object): # to self around in the callback. self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) - @defer.inlineCallbacks - def get_rules(self, event, context): + async def get_rules(self, event, context): """Given an event context return the rules for all users who are currently in the room. """ @@ -286,7 +277,7 @@ class RulesForRoom(object): self.room_push_rule_cache_metrics.inc_hits() return self.rules_by_user - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() @@ -304,9 +295,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield defer.ensureDeferred( - context.get_current_state_ids() - ) + current_state_ids = await context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) @@ -353,7 +342,7 @@ class RulesForRoom(object): # If we have some memebr events we haven't seen, look them up # and fetch push rules for them if appropriate. logger.debug("Found new member events %r", missing_member_event_ids) - yield self._update_rules_with_member_event_ids( + await self._update_rules_with_member_event_ids( ret_rules_by_user, missing_member_event_ids, state_group, event ) else: @@ -371,8 +360,7 @@ class RulesForRoom(object): ) return ret_rules_by_user - @defer.inlineCallbacks - def _update_rules_with_member_event_ids( + async def _update_rules_with_member_event_ids( self, ret_rules_by_user, member_event_ids, state_group, event ): """Update the partially filled rules_by_user dict by fetching rules for @@ -388,7 +376,7 @@ class RulesForRoom(object): """ sequence = self.sequence - rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) + rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} @@ -410,7 +398,7 @@ class RulesForRoom(object): logger.debug("Joined: %r", interested_in_user_ids) - if_users_with_pushers = yield self.store.get_if_users_have_pushers( + if_users_with_pushers = await self.store.get_if_users_have_pushers( interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) @@ -420,7 +408,7 @@ class RulesForRoom(object): logger.debug("With pushers: %r", user_ids) - users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + users_with_receipts = await self.store.get_users_with_read_receipts_in_room( self.room_id, on_invalidate=self.invalidate_all_cb ) @@ -431,7 +419,7 @@ class RulesForRoom(object): if uid in interested_in_user_ids: user_ids.add(uid) - rules_by_user = yield self.store.bulk_get_push_rules( + rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 2fac07593b..4c469efb20 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py
@@ -17,7 +17,6 @@ import logging from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.api.constants import EventTypes @@ -128,12 +127,11 @@ class HttpPusher(object): # but currently that's the only type of receipt anyway... run_as_background_process("http_pusher.on_new_receipts", self._update_badge) - @defer.inlineCallbacks - def _update_badge(self): + async def _update_badge(self): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - yield self._send_badge(badge) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + await self._send_badge(badge) def on_timer(self): self._start_processing() @@ -152,8 +150,7 @@ class HttpPusher(object): run_as_background_process("httppush.process", self._process) - @defer.inlineCallbacks - def _process(self): + async def _process(self): # we should never get here if we are already processing assert not self._is_processing @@ -164,7 +161,7 @@ class HttpPusher(object): while True: starting_max_ordering = self.max_stream_ordering try: - yield self._unsafe_process() + await self._unsafe_process() except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: @@ -172,8 +169,7 @@ class HttpPusher(object): finally: self._is_processing = False - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): """ Looks for unset notifications and dispatch them, in order Never call this directly: use _process which will only allow this to @@ -181,7 +177,7 @@ class HttpPusher(object): """ fn = self.store.get_unread_push_actions_for_user_in_range_for_http - unprocessed = yield fn( + unprocessed = await fn( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -203,13 +199,13 @@ class HttpPusher(object): "app_display_name": self.app_display_name, }, ): - processed = yield self._process_one(push_action) + processed = await self._process_one(push_action) if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.pushkey, self.user_id, @@ -224,14 +220,14 @@ class HttpPusher(object): if self.failing_since: self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) @@ -250,7 +246,7 @@ class HttpPusher(object): ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, @@ -263,7 +259,7 @@ class HttpPusher(object): return self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: @@ -276,18 +272,17 @@ class HttpPusher(object): ) break - @defer.inlineCallbacks - def _process_one(self, push_action): + async def _process_one(self, push_action): if "notify" not in push_action["actions"]: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield self.store.get_event(push_action["event_id"], allow_none=True) + event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: return True # It's been redacted - rejected = yield self.dispatch_push(event, tweaks, badge) + rejected = await self.dispatch_push(event, tweaks, badge) if rejected is False: return False @@ -301,11 +296,10 @@ class HttpPusher(object): ) else: logger.info("Pushkey %s was rejected: removing", pk) - yield self.hs.remove_pusher(self.app_id, pk, self.user_id) + await self.hs.remove_pusher(self.app_id, pk, self.user_id) return True - @defer.inlineCallbacks - def _build_notification_dict(self, event, tweaks, badge): + async def _build_notification_dict(self, event, tweaks, badge): priority = "low" if ( event.type == EventTypes.Encrypted @@ -335,7 +329,7 @@ class HttpPusher(object): } return d - ctx = yield push_tools.get_context_for_event( + ctx = await push_tools.get_context_for_event( self.storage, self.state_handler, event, self.user_id ) @@ -377,13 +371,12 @@ class HttpPusher(object): return d - @defer.inlineCallbacks - def dispatch_push(self, event, tweaks, badge): - notification_dict = yield self._build_notification_dict(event, tweaks, badge) + async def dispatch_push(self, event, tweaks, badge): + notification_dict = await self._build_notification_dict(event, tweaks, badge) if not notification_dict: return [] try: - resp = yield self.http_client.post_json_get_json( + resp = await self.http_client.post_json_get_json( self.url, notification_dict ) except Exception as e: @@ -400,8 +393,7 @@ class HttpPusher(object): rejected = resp["rejected"] return rejected - @defer.inlineCallbacks - def _send_badge(self, badge): + async def _send_badge(self, badge): """ Args: badge (int): number of unread messages @@ -424,7 +416,7 @@ class HttpPusher(object): } } try: - yield self.http_client.post_json_get_json(self.url, d) + await self.http_client.post_json_get_json(self.url, d) http_badges_processed_counter.inc() except Exception as e: logger.warning( diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 0644a13cfc..d8f4a453cd 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py
@@ -16,8 +16,6 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$") ALL_ALONE = "Empty Room" -@defer.inlineCallbacks -def calculate_room_name( +async def calculate_room_name( store, room_state_ids, user_id, @@ -53,7 +50,7 @@ def calculate_room_name( """ # does it have a name? if (EventTypes.Name, "") in room_state_ids: - m_room_name = yield store.get_event( + m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) if m_room_name and m_room_name.content and m_room_name.content["name"]: @@ -61,7 +58,7 @@ def calculate_room_name( # does it have a canonical alias? if (EventTypes.CanonicalAlias, "") in room_state_ids: - canon_alias = yield store.get_event( + canon_alias = await store.get_event( room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True ) if ( @@ -81,7 +78,7 @@ def calculate_room_name( my_member_event = None if (EventTypes.Member, user_id) in room_state_ids: - my_member_event = yield store.get_event( + my_member_event = await store.get_event( room_state_ids[(EventTypes.Member, user_id)], allow_none=True ) @@ -90,7 +87,7 @@ def calculate_room_name( and my_member_event.content["membership"] == "invite" ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: - inviter_member_event = yield store.get_event( + inviter_member_event = await store.get_event( room_state_ids[(EventTypes.Member, my_member_event.sender)], allow_none=True, ) @@ -107,7 +104,7 @@ def calculate_room_name( # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: - member_events = yield store.get_events( + member_events = await store.get_events( list(room_state_bytype_ids[EventTypes.Member].values()) ) all_members = [ diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 5dae4648c0..bc8f71916b 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py
@@ -13,53 +13,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage -@defer.inlineCallbacks -def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_local_user(user_id) - joins = yield store.get_rooms_for_user(user_id) - - my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") +async def get_badge_count(store, user_id): + invites = await store.get_invited_rooms_for_local_user(user_id) + joins = await store.get_rooms_for_user(user_id) badge = len(invites) for room_id in joins: - if room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[room_id] - - notifs = yield ( - store.get_unread_event_push_actions_by_room_for_user( - room_id, user_id, last_unread_event_id - ) - ) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + unread_count = await store.get_unread_message_count_for_user(room_id, user_id) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if unread_count else 0 return badge -@defer.inlineCallbacks -def get_context_for_event(storage: Storage, state_handler, ev, user_id): +async def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) + room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room - name = yield calculate_room_name( + name = await calculate_room_name( storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield storage.main.get_event(sender_state_event_id) + sender_state_event = await storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 7c5e47bc81..305fd00fc0 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py
@@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union from prometheus_client import Gauge -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -52,7 +50,7 @@ class PusherPool: Note that it is expected that each pusher will have its own 'processing' loop which will send out the notifications in the background, rather than blocking until the notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and - Pusher.on_new_receipts are not expected to return deferreds. + Pusher.on_new_receipts are not expected to return awaitables. """ def __init__(self, hs: "HomeServer"): @@ -79,8 +77,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -96,7 +93,7 @@ class PusherPool: """Creates a new pusher and adds it to the pool Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ time_now_msec = self.clock.time_msec() @@ -126,9 +123,9 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() + last_stream_ordering = await self.store.get_latest_push_action_stream_ordering() - yield self.store.add_pusher( + await self.store.add_pusher( user_id=user_id, access_token=access_token, kind=kind, @@ -142,15 +139,14 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) return pusher - @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user( + async def remove_pushers_by_app_id_and_pushkey_not_user( self, app_id, pushkey, not_user_id ): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: if p["user_name"] != not_user_id: logger.info( @@ -159,10 +155,9 @@ class PusherPool: pushkey, p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token(self, user_id, access_tokens): """Remove the pushers for a given user corresponding to a set of access_tokens. @@ -175,7 +170,7 @@ class PusherPool: return tokens = set(access_tokens) - for p in (yield self.store.get_pushers_by_user_id(user_id)): + for p in await self.store.get_pushers_by_user_id(user_id): if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", @@ -183,16 +178,15 @@ class PusherPool: p["pushkey"], p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, min_stream_id, max_stream_id): if not self.pushers: # nothing to do here. return try: - users_affected = yield self.store.get_push_action_users_in_range( + users_affected = await self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) @@ -200,7 +194,7 @@ class PusherPool: if u in self.pushers: # Don't push if the user account has expired if self._account_validity.enabled: - expired = yield self.store.is_account_expired( + expired = await self.store.is_account_expired( u, self.clock.time_msec() ) if expired: @@ -212,8 +206,7 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks - def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): if not self.pushers: # nothing to do here. return @@ -221,7 +214,7 @@ class PusherPool: try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive - users_affected = yield self.store.get_users_sent_receipts_between( + users_affected = await self.store.get_users_sent_receipts_between( min_stream_id - 1, max_stream_id ) @@ -241,12 +234,11 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - @defer.inlineCallbacks - def start_pusher_by_id(self, app_id, pushkey, user_id): + async def start_pusher_by_id(self, app_id, pushkey, user_id): """Look up the details for the given pusher, and start it Returns: - Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any + EmailPusher|HttpPusher|None: The pusher started, if any """ if not self._should_start_pushers: return @@ -254,7 +246,7 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: @@ -263,34 +255,29 @@ class PusherPool: pusher = None if pusher_dict: - pusher = yield self._start_pusher(pusher_dict) + pusher = await self._start_pusher(pusher_dict) return pusher - @defer.inlineCallbacks - def _start_pushers(self): + async def _start_pushers(self) -> None: """Start all the pushers - - Returns: - Deferred """ - pushers = yield self.store.get_all_pushers() + pushers = await self.store.get_all_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. - yield concurrently_execute(self._start_pusher, pushers, 10) + await concurrently_execute(self._start_pusher, pushers, 10) logger.info("Started pushers") - @defer.inlineCallbacks - def _start_pusher(self, pusherdict): + async def _start_pusher(self, pusherdict): """Start the given pusher Args: pusherdict (dict): dict with the values pulled from the db table Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ if not self._pusher_shard_config.should_handle( self._instance_name, pusherdict["user_name"] @@ -333,7 +320,7 @@ class PusherPool: user_id = pusherdict["user_name"] last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: - have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( + have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering ) else: @@ -345,8 +332,7 @@ class PusherPool: return p - @defer.inlineCallbacks - def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id, pushkey, user_id): appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) @@ -358,6 +344,6 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 8cfcdb0573..abea2be4ef 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py
@@ -43,7 +43,7 @@ REQUIREMENTS = [ "jsonschema>=2.5.1", "frozendict>=1", "unpaddedbase64>=1.1.0", - "canonicaljson>=1.1.3", + "canonicaljson>=1.2.0", # we use the type definitions added in signedjson 1.1. "signedjson>=1.1.0", "pynacl>=1.2.1", diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index c287c4e269..ca065e819e 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py
@@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """ event_payloads = [] for event, context in event_and_contexts: - serialized_context = yield context.serialize(event, store) + serialized_context = yield defer.ensureDeferred( + context.serialize(event, store) + ) event_payloads.append( { diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index c981723c1a..b30e4d5039 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py
@@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): extra_users (list(UserID)): Any extra users to notify about event """ - serialized_context = yield context.serialize(event, store) + serialized_context = yield defer.ensureDeferred(context.serialize(event, store)) payload = { "event": event.get_pdu_json(), diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index b8c95d045a..a8364d9793 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet): Codes.BAD_JSON, ) + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + ret = await self.room_shutdown_handler.shutdown_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), @@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet): ) # Purge room - await self.pagination_handler.purge_room(room_id) + if purge: + await self.pagination_handler.purge_room(room_id) return (200, ret) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..3f5bf75e59 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py
@@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 9a847130c0..20ddb9550b 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py
@@ -17,7 +17,9 @@ import logging import os import urllib +from typing import Awaitable +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error @@ -240,14 +242,14 @@ class Responder(object): held can be cleaned up. """ - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: """Stream response into consumer Args: - consumer (IConsumer) + consumer: The consumer to stream into. Returns: - Deferred: Resolves once the response has finished being written + Resolves once the response has finished being written """ pass diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 45628c07b4..6fb4039e98 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -18,10 +18,11 @@ import errno import logging import os import shutil -from typing import Dict, Tuple +from typing import IO, Dict, Optional, Tuple import twisted.internet.error import twisted.web.http +from twisted.web.http import Request from twisted.web.resource import Resource from synapse.api.errors import ( @@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string from ._base import ( FileInfo, + Responder, get_filename_from_headers, respond_404, respond_with_responder, @@ -135,19 +137,24 @@ class MediaRepository(object): self.recently_accessed_locals.add(media_id) async def create_content( - self, media_type, upload_name, content, content_length, auth_user - ): + self, + media_type: str, + upload_name: str, + content: IO, + content_length: int, + auth_user: str, + ) -> str: """Store uploaded content for a local user and return the mxc URL Args: - media_type(str): The content type of the file - upload_name(str): The name of the file + media_type: The content type of the file + upload_name: The name of the file content: A file like object that is the content to store - content_length(int): The length of the content - auth_user(str): The user_id of the uploader + content_length: The length of the content + auth_user: The user_id of the uploader Returns: - Deferred[str]: The mxc url of the stored content + The mxc url of the stored content """ media_id = random_string(24) @@ -170,19 +177,20 @@ class MediaRepository(object): return "mxc://%s/%s" % (self.server_name, media_id) - async def get_local_media(self, request, media_id, name): + async def get_local_media( + self, request: Request, media_id: str, name: Optional[str] + ) -> None: """Responds to reqests for local media, if exists, or returns 404. Args: - request(twisted.web.http.Request) - media_id (str): The media ID of the content. (This is the same as + request: The incoming request. + media_id: The media ID of the content. (This is the same as the file_id for local content.) - name (str|None): Optional name that, if specified, will be used as + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: @@ -203,20 +211,20 @@ class MediaRepository(object): request, responder, media_type, media_length, upload_name ) - async def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media( + self, request: Request, server_name: str, media_id: str, name: Optional[str] + ) -> None: """Respond to requests for remote media. Args: - request(twisted.web.http.Request) - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). - name (str|None): Optional name that, if specified, will be used as + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ if ( self.federation_domain_whitelist is not None @@ -245,17 +253,16 @@ class MediaRepository(object): else: respond_404(request) - async def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: """Gets the media info associated with the remote file, downloading if necessary. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: - Deferred[dict]: The media_info of the file + The media info of the file """ if ( self.federation_domain_whitelist is not None @@ -278,7 +285,9 @@ class MediaRepository(object): return media_info - async def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -288,7 +297,7 @@ class MediaRepository(object): remote server). Returns: - Deferred[(Responder, media_info)] + A tuple of responder and the media info of the file. """ media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -319,19 +328,21 @@ class MediaRepository(object): responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file( + self, server_name: str, media_id: str, file_id: str + ) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. Args: - server_name (str): Originating server - media_id (str): The media ID of the content (as defined by the + server_name: Originating server + media_id: The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. - file_id (str): Local file ID + file_id: Local file ID Returns: - Deferred[MediaInfo] + The media info of the file. """ file_info = FileInfo(server_name=server_name, file_id=file_id) @@ -549,25 +560,31 @@ class MediaRepository(object): return output_path async def _generate_thumbnails( - self, server_name, media_id, file_id, media_type, url_cache=False - ): + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: """Generate and store thumbnails for an image. Args: - server_name (str|None): The server name if remote media, else None if local - media_id (str): The media ID of the content. (This is the same as + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as the file_id for local content) - file_id (str): Local file ID - media_type (str): The content type of the file - url_cache (bool): If we are thumbnailing images downloaded for the URL cache, + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, used exclusively by the url previewer Returns: - Deferred[dict]: Dict with "width" and "height" keys of original image + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: - return + return None input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) @@ -584,7 +601,7 @@ class MediaRepository(object): m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = await defer_to_thread( diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 66bc1c3360..858b6d3005 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py
@@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib import inspect import logging import os import shutil -from typing import Optional +from typing import IO, TYPE_CHECKING, Any, Optional, Sequence from twisted.protocols.basic import FileSender @@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.server import HomeServer + + from .storage_provider import StorageProvider logger = logging.getLogger(__name__) @@ -34,20 +39,25 @@ class MediaStorage(object): """Responsible for storing/fetching files from local sources. Args: - hs (synapse.server.Homeserver) - local_media_directory (str): Base path where we store media on disk - filepaths (MediaFilePaths) - storage_providers ([StorageProvider]): List of StorageProvider that are - used to fetch and store files. + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. """ - def __init__(self, hs, local_media_directory, filepaths, storage_providers): + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): self.hs = hs self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers - async def store_file(self, source, file_info: FileInfo) -> str: + async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers @@ -69,7 +79,7 @@ class MediaStorage(object): return fname @contextlib.contextmanager - def store_into_file(self, file_info): + def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as described by file_info. @@ -85,7 +95,7 @@ class MediaStorage(object): error. Args: - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Example: @@ -143,9 +153,9 @@ class MediaStorage(object): return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. if inspect.isawaitable(res): res = await res if res: @@ -174,9 +184,9 @@ class MediaStorage(object): os.makedirs(dirname) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. if inspect.isawaitable(res): res = await res if res: @@ -190,17 +200,11 @@ class MediaStorage(object): raise Exception("file could not be found") - def _file_info_to_path(self, file_info): + def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. The path is suitable for storing files under a directory, e.g. used to store files on local FS under the base media repository directory. - - Args: - file_info (FileInfo) - - Returns: - str """ if file_info.url_cache: if file_info.thumbnail: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 13d1a6d2ed..e12f65a206 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource): og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url, user, ts): + async def _do_preview(self, url: str, user: str, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: - url (str): - user (str): - ts (int): + url: The URL to preview. + user: The user requesting the preview. + ts: The timestamp requested for the preview. Returns: - Deferred[bytes]: json-encoded og data + json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 858680be26..a33f56e806 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py
@@ -16,62 +16,62 @@ import logging import os import shutil - -from twisted.internet import defer +from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from ._base import FileInfo, Responder from .media_storage import FileResponder logger = logging.getLogger(__name__) -class StorageProvider(object): +class StorageProvider: """A storage provider is a service that can store uploaded media and retrieve them. """ - def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) - - Returns: - Deferred + path: Relative path of file in local cache + file_info: The metadata of the file. """ - pass - def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) + path: Relative path of file in local cache + file_info: The metadata of the file. Returns: - Deferred(Responder): Returns a Responder if the provider has the file, - otherwise returns None. + Returns a Responder if the provider has the file, otherwise returns None. """ - pass class StorageProviderWrapper(StorageProvider): """Wraps a storage provider and provides various config options Args: - backend (StorageProvider) - store_local (bool): Whether to store new local files or not. - store_synchronous (bool): Whether to wait for file to be successfully + backend: The storage provider to wrap. + store_local: Whether to store new local files or not. + store_synchronous: Whether to wait for file to be successfully uploaded, or todo the upload in the background. - store_remote (bool): Whether remote media should be uploaded + store_remote: Whether remote media should be uploaded """ - def __init__(self, backend, store_local, store_synchronous, store_remote): + def __init__( + self, + backend: StorageProvider, + store_local: bool, + store_synchronous: bool, + store_remote: bool, + ): self.backend = backend self.store_local = store_local self.store_synchronous = store_synchronous @@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider): def __str__(self): return "StorageProviderWrapper[%s]" % (self.backend,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): if not file_info.server_name and not self.store_local: - return defer.succeed(None) + return None if file_info.server_name and not self.store_remote: - return defer.succeed(None) + return None if self.store_synchronous: - return self.backend.store_file(path, file_info) + return await self.backend.store_file(path, file_info) else: # TODO: Handle errors. def store(): @@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider): logger.exception("Error storing file") run_in_background(store) - return defer.succeed(None) + return None - def fetch(self, path, file_info): - return self.backend.fetch(path, file_info) + async def fetch(self, path, file_info): + return await self.backend.fetch(path, file_info) class FileStorageProviderBackend(StorageProvider): @@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return defer_to_thread( + return await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - def fetch(self, path, file_info): + async def fetch(self, path, file_info): """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index f39f556c20..edc3624fed 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py
@@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) + self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 504babaa7e..18297cf3b8 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -411,7 +411,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): _get_if_maybe_push_in_range_for_user_txn, ) - def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging(self, event_id, user_id_actions): """Add the push actions for the event to the push action staging area. Args: @@ -457,7 +457,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - return self.db.runInteraction( + return await self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 6f2e0d15cc..0c9c02afa1 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py
@@ -53,6 +53,47 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. + if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + def encode_json(json_object): """ @@ -196,6 +237,10 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() + self.store.get_unread_message_count_for_user.invalidate_many( + (event.room_id,), + ) + for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -817,8 +862,9 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), + "count_as_unread": should_count_as_unread(event, context), } - for event, _ in events_and_contexts + for event, context in events_and_contexts ], ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e812c67078..b03b259636 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py
@@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database +from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import ( + Cache, + _CacheContext, + cached, + cachedInlineCallbacks, +) from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + @cached(tree=True, cache_context=True) + async def get_unread_message_count_for_user( + self, room_id: str, user_id: str, cache_context: _CacheContext, + ) -> int: + """Retrieve the count of unread messages for the given room and user. + + Args: + room_id: The ID of the room to count unread messages in. + user_id: The ID of the user to count unread messages for. + + Returns: + The number of unread messages for the given user in the given room. + """ + with Measure(self._clock, "get_unread_message_count_for_user"): + last_read_event_id = await self.get_last_receipt_event_id_for_user( + user_id=user_id, + room_id=room_id, + receipt_type="m.read", + on_invalidate=cache_context.invalidate, + ) + + return await self.db.runInteraction( + "get_unread_message_count_for_user", + self._get_unread_message_count_for_user_txn, + user_id, + room_id, + last_read_event_id, + ) + + def _get_unread_message_count_for_user_txn( + self, + txn: Cursor, + user_id: str, + room_id: str, + last_read_event_id: Optional[str], + ) -> int: + if last_read_event_id: + # Get the stream ordering for the last read event. + stream_ordering = self.db.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"room_id": room_id, "event_id": last_read_event_id}, + retcol="stream_ordering", + ) + else: + # If there's no read receipt for that room, it probably means the user hasn't + # opened it yet, in which case use the stream ID of their join event. + # We can't just set it to 0 otherwise messages from other local users from + # before this user joined will be counted as well. + txn.execute( + """ + SELECT stream_ordering FROM local_current_membership + LEFT JOIN events USING (event_id, room_id) + WHERE membership = 'join' + AND user_id = ? + AND room_id = ? + """, + (user_id, room_id), + ) + row = txn.fetchone() + + if row is None: + return 0 + + stream_ordering = row[0] + + # Count the messages that qualify as unread after the stream ordering we've just + # retrieved. + sql = """ + SELECT COUNT(*) FROM events + WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread + """ + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + + return row[0] if row else 0 + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
index 6546569139..b53fe35c33 100644 --- a/synapse/storage/data_stores/main/purge_events.py +++ b/synapse/storage/data_stores/main/purge_events.py
@@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # event_json # event_push_actions # event_reference_hashes + # event_relations # event_search # event_to_state_groups # events @@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "event_edges", "event_forward_extremities", "event_reference_hashes", + "event_relations", "event_search", "rejections", ): diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql new file mode 100644
index 0000000000..531b532c73 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
@@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Store a boolean value in the events table for whether the event should be counted in +-- the unread_count property of sync responses. +ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3be20c866a..ce8757a400 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E from synapse.storage.types import Connection, Cursor from synapse.types import Collection -logger = logging.getLogger(__name__) - # python 3 does not have a maximum int value MAX_TXN_ID = 2 ** 63 - 1 +logger = logging.getLogger(__name__) + sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") perf_logger = logging.getLogger("synapse.storage.TIME") @@ -233,7 +233,7 @@ class LoggingTransaction: try: return func(sql, *args) except Exception as e: - logger.debug("[SQL FAIL] {%s} %s", self.name, e) + sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e) raise finally: secs = time.time() - start @@ -419,7 +419,7 @@ class Database(object): except self.engine.module.OperationalError as e: # This can happen if the database disappears mid # transaction. - logger.warning( + transaction_logger.warning( "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, ) if i < N: @@ -427,18 +427,20 @@ class Database(object): try: conn.rollback() except self.engine.module.Error as e1: - logger.warning("[TXN EROLL] {%s} %s", name, e1) + transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1) continue raise except self.engine.module.DatabaseError as e: if self.engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + transaction_logger.warning( + "[TXN DEADLOCK] {%s} %d/%d", name, i, N + ) if i < N: i += 1 try: conn.rollback() except self.engine.module.Error as e1: - logger.warning( + transaction_logger.warning( "[TXN EROLL] {%s} %s", name, e1, ) continue @@ -478,7 +480,7 @@ class Database(object): # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 cursor.close() except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) + transaction_logger.debug("[TXN FAIL] {%s} %s", name, e) raise finally: end = monotonic_time() diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 78fbdcdee8..4a164834d9 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py
@@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import FrozenEvent +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process @@ -192,12 +192,11 @@ class EventsPersistenceStorage(object): self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def persist_events( + async def persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ): + ) -> int: """ Write events to the database Args: @@ -207,7 +206,7 @@ class EventsPersistenceStorage(object): which might update the current state etc. Returns: - Deferred[int]: the stream ordering of the latest persisted event + the stream ordering of the latest persisted event """ partitioned = {} for event, ctx in events_and_contexts: @@ -223,22 +222,19 @@ class EventsPersistenceStorage(object): for room_id in partitioned: self._maybe_start_persisting(room_id) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) - max_persisted_id = yield self.main_store.get_current_events_token() - - return max_persisted_id + return self.main_store.get_current_events_token() - @defer.inlineCallbacks - def persist_event( - self, event: FrozenEvent, context: EventContext, backfilled: bool = False - ): + async def persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ) -> Tuple[int, int]: """ Returns: - Deferred[Tuple[int, int]]: the stream ordering of ``event``, - and the stream ordering of the latest persisted event + The stream ordering of `event`, and the stream ordering of the + latest persisted event """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled @@ -246,9 +242,9 @@ class EventsPersistenceStorage(object): self._maybe_start_persisting(event.room_id) - yield make_deferred_yieldable(deferred) + await make_deferred_yieldable(deferred) - max_persisted_id = yield self.main_store.get_current_events_token() + max_persisted_id = self.main_store.get_current_events_token() return (event.internal_metadata.stream_ordering, max_persisted_id) def _maybe_start_persisting(self, room_id: str): @@ -262,7 +258,7 @@ class EventsPersistenceStorage(object): async def _persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, ): """Calculates the change to current state and forward extremities, and @@ -439,7 +435,7 @@ class EventsPersistenceStorage(object): async def _calculate_new_extremities( self, room_id: str, - event_contexts: List[Tuple[FrozenEvent, EventContext]], + event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: List[str], ): """Calculates the new forward extremities for a room given events to @@ -497,7 +493,7 @@ class EventsPersistenceStorage(object): async def _get_new_state_after_events( self, room_id: str, - events_context: List[Tuple[FrozenEvent, EventContext]], + events_context: List[Tuple[EventBase, EventContext]], old_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: @@ -683,7 +679,7 @@ class EventsPersistenceStorage(object): async def _is_server_still_joined( self, room_id: str, - ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], + ev_ctx_rm: List[Tuple[EventBase, EventContext]], delta: DeltaState, current_state: Optional[StateMap[str]], potentially_left_users: Set[str], diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index fdc0abf5cf..79d9f06e2e 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py
@@ -15,8 +15,7 @@ import itertools import logging - -from twisted.internet import defer +from typing import Set logger = logging.getLogger(__name__) @@ -28,49 +27,48 @@ class PurgeEventsStorage(object): def __init__(self, hs, stores): self.stores = stores - @defer.inlineCallbacks - def purge_room(self, room_id: str): + async def purge_room(self, room_id: str): """Deletes all record of a room """ - state_groups_to_delete = yield self.stores.main.purge_room(room_id) - yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) + state_groups_to_delete = await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id, state_groups_to_delete) - @defer.inlineCallbacks - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> None: """Deletes room history before a certain point Args: - room_id (str): + room_id: The room ID - token (str): A topological token to delete events before + token: A topological token to delete events before - delete_local_events (bool): + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). """ - state_groups = yield self.stores.main.purge_history( + state_groups = await self.stores.main.purge_history( room_id, token, delete_local_events ) logger.info("[purge] finding state groups that can be deleted") - sg_to_delete = yield self._find_unreferenced_groups(state_groups) + sg_to_delete = await self._find_unreferenced_groups(state_groups) - yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) - @defer.inlineCallbacks - def _find_unreferenced_groups(self, state_groups): + async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: """Used when purging history to figure out which state groups can be deleted. Args: - state_groups (set[int]): Set of state groups referenced by events + state_groups: Set of state groups referenced by events that are going to be deleted. Returns: - Deferred[set[int]] The set of state groups that can be deleted. + The set of state groups that can be deleted. """ # Graph of state group -> previous group graph = {} @@ -93,7 +91,7 @@ class PurgeEventsStorage(object): current_search = set(itertools.islice(next_to_search, 100)) next_to_search -= current_search - referenced = yield self.stores.main.get_referenced_state_groups( + referenced = await self.stores.main.get_referenced_state_groups( current_search ) referenced_groups |= referenced @@ -102,7 +100,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.state.get_previous_state_groups(current_search) + edges = await self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dc568476f4..49ee9c9a74 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py
@@ -14,13 +14,12 @@ # limitations under the License. import logging -from typing import Iterable, List, TypeVar +from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr -from twisted.internet import defer - from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -34,16 +33,16 @@ class StateFilter(object): """A filter used when querying for state. Attributes: - types (dict[str, set[str]|None]): Map from type to set of state keys (or - None). This specifies which state_keys for the given type to fetch - from the DB. If None then all events with that type are fetched. If - the set is empty then no events with that type are fetched. - include_others (bool): Whether to fetch events with types that do not + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not appear in `types`. """ - types = attr.ib() - include_others = attr.ib(default=False) + types = attr.ib(type=Dict[str, Optional[Set[str]]]) + include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing @@ -52,36 +51,35 @@ class StateFilter(object): self.types = {k: v for k, v in self.types.items() if v is not None} @staticmethod - def all(): + def all() -> "StateFilter": """Creates a filter that fetches everything. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=True) @staticmethod - def none(): + def none() -> "StateFilter": """Creates a filter that fetches nothing. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=False) @staticmethod - def from_types(types): + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": """Creates a filter that only fetches the given types Args: - types (Iterable[tuple[str, str|None]]): A list of type and state - keys to fetch. A state_key of None fetches everything for - that type + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type Returns: - StateFilter + The new state filter. """ - type_dict = {} + type_dict = {} # type: Dict[str, Optional[Set[str]]] for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -91,24 +89,24 @@ class StateFilter(object): type_dict[typ] = None continue - type_dict.setdefault(typ, set()).add(s) + type_dict.setdefault(typ, set()).add(s) # type: ignore return StateFilter(types=type_dict) @staticmethod - def from_lazy_load_member_list(members): + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member events for the given users Args: - members (iterable[str]): Set of user IDs + members: Set of user IDs Returns: - StateFilter + The new state filter """ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) - def return_expanded(self): + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the current one, i.e. anything that passes the current filter will pass @@ -130,7 +128,7 @@ class StateFilter(object): return all non-member events Returns: - StateFilter + The new state filter. """ if self.is_full(): @@ -167,7 +165,7 @@ class StateFilter(object): include_others=True, ) - def make_sql_filter_clause(self): + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. For example: @@ -179,13 +177,12 @@ class StateFilter(object): Returns: - tuple[str, list]: The SQL string (may be empty) and arguments. An - empty SQL string is returned when the filter matches everything - (i.e. is "full"). + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). """ where_clause = "" - where_args = [] + where_args = [] # type: List[str] if self.is_full(): return where_clause, where_args @@ -221,7 +218,7 @@ class StateFilter(object): return where_clause, where_args - def max_entries_returned(self): + def max_entries_returned(self) -> Optional[int]: """Returns the maximum number of entries this filter will return if known, otherwise returns None. @@ -260,33 +257,33 @@ class StateFilter(object): return filtered_state - def is_full(self): + def is_full(self) -> bool: """Whether this filter fetches everything or not Returns: - bool + True if the filter fetches everything. """ return self.include_others and not self.types - def has_wildcards(self): + def has_wildcards(self) -> bool: """Whether the filter includes wildcards or is attempting to fetch specific state. Returns: - bool + True if the filter includes wildcards. """ return self.include_others or any( state_keys is None for state_keys in self.types.values() ) - def concrete_types(self): + def concrete_types(self) -> List[Tuple[str, str]]: """Returns a list of concrete type/state_keys (i.e. not None) that will be fetched. This will be a complete list if `has_wildcards` returns False, but otherwise will be a subset (or even empty). Returns: - list[tuple[str,str]] + A list of type/state_keys tuples. """ return [ (t, s) @@ -295,7 +292,7 @@ class StateFilter(object): for s in state_keys ] - def get_member_split(self): + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching against non member state. @@ -307,7 +304,7 @@ class StateFilter(object): state caches). Returns: - tuple[StateFilter, StateFilter]: The member and non member filters + The member and non member filters """ if EventTypes.Member in self.types: @@ -340,6 +337,9 @@ class StateGroupStorage(object): """Given a state group try to return a previous group and a delta between the old and the new. + Args: + state_group: The state group used to retrieve state deltas. + Returns: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) @@ -347,55 +347,59 @@ class StateGroupStorage(object): return self.stores.state.get_state_group_delta(state_group) - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): + async def get_state_groups_ids( + self, _room_id: str, event_ids: Iterable[str] + ) -> Dict[int, StateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events + _room_id: id of the room for these events + event_ids: ids of the events Returns: - Deferred[dict[int, StateMap[str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: return {} - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups(groups) + group_to_state = await self.stores.state._get_state_for_groups(groups) return group_to_state - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): + async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: """Get the event IDs of all the state in the given state group Args: - state_group (int) + state_group: A state group for which we want to get the state IDs. Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + Resolves to a map of (type, state_key) -> event_id """ - group_to_state = yield self._get_state_for_groups((state_group,)) + group_to_state = await self._get_state_for_groups((state_group,)) return group_to_state[state_group] - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): + async def get_state_groups( + self, room_id: str, event_ids: Iterable[str] + ) -> Dict[int, List[EventBase]]: """ Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. + dict of state_group_id -> list of state events. """ if not event_ids: return {} - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ ev_id for group_ids in group_to_ids.values() @@ -423,31 +427,34 @@ class StateGroupStorage(object): groups: list of state group IDs to query state_filter: The state filter used to fetch state from the database. + Returns: Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """Given a list of event_ids and type tuples, return a list of state dicts for each event. + Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], get_prev_content=False, ) @@ -463,24 +470,24 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_ids_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from event_id -> (type, state_key) -> event_id + A dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) @@ -491,36 +498,36 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_ids_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] def _get_state_for_groups( @@ -530,9 +537,8 @@ class StateGroupStorage(object): filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. from the database. Returns: Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. @@ -540,18 +546,23 @@ class StateGroupStorage(object): return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[dict], + current_state_ids: dict, ): """Store a new set of state, returning a newly assigned state group. Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) + current_state_ids: The state to store. Map of (type, state_key) to event_id. Returns: diff --git a/synapse/visibility.py b/synapse/visibility.py
index 0f042c5696..e3da7744d2 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py
@@ -16,8 +16,6 @@ import logging import operator -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event from synapse.storage import Storage @@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = ( ) -@defer.inlineCallbacks -def filter_events_for_client( +async def filter_events_for_client( storage: Storage, user_id, events, @@ -67,19 +64,19 @@ def filter_events_for_client( also be called to check whether a user can see the state at a given point. Returns: - Deferred[list[synapse.events.EventBase]] + list[synapse.events.EventBase] """ # Filter out events that have been soft failed so that we don't relay them # to clients. events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield storage.state.get_state_for_events( + event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( + ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -90,7 +87,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -99,7 +96,7 @@ def filter_events_for_client( for room_id in room_ids: retention_policies[ room_id - ] = yield storage.main.get_retention_policy_for_room(room_id) + ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event): """ @@ -254,8 +251,7 @@ def filter_events_for_client( return list(filtered_events) -@defer.inlineCallbacks -def filter_events_for_server( +async def filter_events_for_server( storage: Storage, server_name, events, @@ -277,7 +273,7 @@ def filter_events_for_server( backfill or not. Returns - Deferred[list[FrozenEvent]] + list[FrozenEvent] """ def is_sender_erased(event, erased_senders): @@ -321,7 +317,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -339,14 +335,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield storage.main.get_events(visibility_ids) + event_map = await storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in event_map.values() ) if not check_history_visibility_only: - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -375,7 +371,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -405,7 +401,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield storage.main.get_events( + event_map = await storage.main.get_events( [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] )