From 3ab55d43bd66b377c1ed94a40931eba98dd07b01 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 18 Oct 2021 15:01:10 -0400 Subject: Add missing type hints to synapse.api. (#11109) * Convert UserPresenceState to attrs. * Remove args/kwargs from error classes and explicitly pass msg/errorcode. --- synapse/api/auth.py | 14 +++++++-- synapse/api/errors.py | 69 +++++++++++++++------------------------------ synapse/api/filtering.py | 18 ++++++------ synapse/api/presence.py | 51 ++++++++++++++++----------------- synapse/api/ratelimiting.py | 4 +-- synapse/api/urls.py | 13 ++++----- 6 files changed, 75 insertions(+), 94 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e6ca9232ee..44883c6663 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -245,7 +245,7 @@ class Auth: async def validate_appservice_can_control_user_id( self, app_service: ApplicationService, user_id: str - ): + ) -> None: """Validates that the app service is allowed to control the given user. @@ -618,5 +618,13 @@ class Auth: % (user_id, room_id), ) - async def check_auth_blocking(self, *args, **kwargs) -> None: - await self._auth_blocking.check_auth_blocking(*args, **kwargs) + async def check_auth_blocking( + self, + user_id: Optional[str] = None, + threepid: Optional[dict] = None, + user_type: Optional[str] = None, + requester: Optional[Requester] = None, + ) -> None: + await self._auth_blocking.check_auth_blocking( + user_id=user_id, threepid=threepid, user_type=user_type, requester=requester + ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 9480f448d7..685d1c25cf 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -18,7 +18,7 @@ import logging import typing from http import HTTPStatus -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from twisted.web import http @@ -143,7 +143,7 @@ class SynapseError(CodeMessageException): super().__init__(code, msg) self.errcode = errcode - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode) @@ -175,7 +175,7 @@ class ProxiedRequestError(SynapseError): else: self._additional_fields = dict(additional_fields) - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, **self._additional_fields) @@ -196,7 +196,7 @@ class ConsentNotGivenError(SynapseError): ) self._consent_uri = consent_uri - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri) @@ -262,14 +262,10 @@ class InteractiveAuthIncompleteError(Exception): class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.UNRECOGNIZED - if len(args) == 0: - message = "Unrecognized request" - else: - message = args[0] - super().__init__(400, message, **kwargs) + def __init__( + self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED + ): + super().__init__(400, msg, errcode) class NotFoundError(SynapseError): @@ -284,10 +280,8 @@ class AuthError(SynapseError): other poorly-defined times. """ - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.FORBIDDEN - super().__init__(*args, **kwargs) + def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN): + super().__init__(code, msg, errcode) class InvalidClientCredentialsError(SynapseError): @@ -321,7 +315,7 @@ class InvalidClientTokenError(InvalidClientCredentialsError): super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") self._soft_logout = soft_logout - def error_dict(self): + def error_dict(self) -> "JsonDict": d = super().error_dict() d["soft_logout"] = self._soft_logout return d @@ -345,7 +339,7 @@ class ResourceLimitError(SynapseError): self.limit_type = limit_type super().__init__(code, msg, errcode=errcode) - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error( self.msg, self.errcode, @@ -357,32 +351,17 @@ class ResourceLimitError(SynapseError): class EventSizeError(SynapseError): """An error raised when an event is too big.""" - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.TOO_LARGE - super().__init__(413, *args, **kwargs) - - -class EventStreamError(SynapseError): - """An error raised when there a problem with the event stream.""" - - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.BAD_PAGINATION - super().__init__(*args, **kwargs) + def __init__(self, msg: str): + super().__init__(413, msg, Codes.TOO_LARGE) class LoginError(SynapseError): """An error raised when there was a problem logging in.""" - pass - class StoreError(SynapseError): """An error raised when there was a problem storing some data.""" - pass - class InvalidCaptchaError(SynapseError): def __init__( @@ -395,7 +374,7 @@ class InvalidCaptchaError(SynapseError): super().__init__(code, msg, errcode) self.error_url = error_url - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, error_url=self.error_url) @@ -412,7 +391,7 @@ class LimitExceededError(SynapseError): super().__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) @@ -443,10 +422,8 @@ class UnsupportedRoomVersionError(SynapseError): class ThreepidValidationError(SynapseError): """An error raised when there was a problem authorising an event.""" - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.FORBIDDEN - super().__init__(*args, **kwargs) + def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN): + super().__init__(400, msg, errcode) class IncompatibleRoomVersionError(SynapseError): @@ -466,7 +443,7 @@ class IncompatibleRoomVersionError(SynapseError): self._room_version = room_version - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, room_version=self._room_version) @@ -494,7 +471,7 @@ class RequestSendFailed(RuntimeError): errors (like programming errors). """ - def __init__(self, inner_exception, can_retry): + def __init__(self, inner_exception: BaseException, can_retry: bool): super().__init__( "Failed to send request: %s: %s" % (type(inner_exception).__name__, inner_exception) @@ -503,7 +480,7 @@ class RequestSendFailed(RuntimeError): self.can_retry = can_retry -def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): +def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict": """Utility method for constructing an error response for client-server interactions. @@ -551,7 +528,7 @@ class FederationError(RuntimeError): msg = "%s %s: %s" % (level, code, reason) super().__init__(msg) - def get_dict(self): + def get_dict(self) -> "JsonDict": return { "level": self.level, "code": self.code, @@ -580,7 +557,7 @@ class HttpResponseException(CodeMessageException): super().__init__(code, msg) self.response = response - def to_synapse_error(self): + def to_synapse_error(self) -> SynapseError: """Make a SynapseError based on an HTTPResponseException This is useful when a proxied request has failed, and we need to diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 20e91a115d..bc550ae646 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -231,24 +231,24 @@ class FilterCollection: def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members() - def filter_presence(self, events): + def filter_presence( + self, events: Iterable[UserPresenceState] + ) -> List[UserPresenceState]: return self._presence_filter.filter(events) - def filter_account_data(self, events): + def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: return self._account_data.filter(events) - def filter_room_state(self, events): + def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: return self._room_state_filter.filter(self._room_filter.filter(events)) - def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: + def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]: return self._room_timeline_filter.filter(self._room_filter.filter(events)) - def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: + def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]: return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) - def filter_room_account_data( - self, events: Iterable[FilterEvent] - ) -> List[FilterEvent]: + def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: return self._room_account_data.filter(self._room_filter.filter(events)) def blocks_all_presence(self) -> bool: @@ -309,7 +309,7 @@ class Filter: # except for presence which actually gets passed around as its own # namedtuple type. if isinstance(event, UserPresenceState): - sender = event.user_id + sender: Optional[str] = event.user_id room_id = None ev_type = "m.presence" contains_url = False diff --git a/synapse/api/presence.py b/synapse/api/presence.py index a3bf0348d1..b80aa83cb3 100644 --- a/synapse/api/presence.py +++ b/synapse/api/presence.py @@ -12,49 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from typing import Any, Optional + +import attr from synapse.api.constants import PresenceState +from synapse.types import JsonDict -class UserPresenceState( - namedtuple( - "UserPresenceState", - ( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - ) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UserPresenceState: """Represents the current presence state of the user. - user_id (str) - last_active (int): Time in msec that the user last interacted with server. - last_federation_update (int): Time in msec since either a) we sent a presence + user_id + last_active: Time in msec that the user last interacted with server. + last_federation_update: Time in msec since either a) we sent a presence update to other servers or b) we received a presence update, depending on if is a local user or not. - last_user_sync (int): Time in msec that the user last *completed* a sync + last_user_sync: Time in msec that the user last *completed* a sync (or event stream). - status_msg (str): User set status message. + status_msg: User set status message. """ - def as_dict(self): - return dict(self._asdict()) + user_id: str + state: str + last_active_ts: int + last_federation_update_ts: int + last_user_sync_ts: int + status_msg: Optional[str] + currently_active: bool + + def as_dict(self) -> JsonDict: + return attr.asdict(self) @staticmethod - def from_dict(d): + def from_dict(d: JsonDict) -> "UserPresenceState": return UserPresenceState(**d) - def copy_and_replace(self, **kwargs): - return self._replace(**kwargs) + def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState": + return attr.evolve(self, **kwargs) @classmethod - def default(cls, user_id): + def default(cls, user_id: str) -> "UserPresenceState": """Returns a default presence state.""" return cls( user_id=user_id, diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index e8964097d3..849c18ceda 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -161,7 +161,7 @@ class Ratelimiter: return allowed, time_allowed - def _prune_message_counts(self, time_now_s: float): + def _prune_message_counts(self, time_now_s: float) -> None: """Remove message count entries that have not exceeded their defined rate_hz limit @@ -190,7 +190,7 @@ class Ratelimiter: update: bool = True, n_actions: int = 1, _time_now_s: Optional[float] = None, - ): + ) -> None: """Checks if an action can be performed. If not, raises a LimitExceededError Checks if the user has ratelimiting disabled in the database by looking diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 032c69b210..6e84b1524f 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -19,6 +19,7 @@ from hashlib import sha256 from urllib.parse import urlencode from synapse.config import ConfigError +from synapse.config.homeserver import HomeServerConfig SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client" CLIENT_API_PREFIX = "/_matrix/client" @@ -34,11 +35,7 @@ LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" class ConsentURIBuilder: - def __init__(self, hs_config): - """ - Args: - hs_config (synapse.config.homeserver.HomeServerConfig): - """ + def __init__(self, hs_config: HomeServerConfig): if hs_config.key.form_secret is None: raise ConfigError("form_secret not set in config") if hs_config.server.public_baseurl is None: @@ -47,15 +44,15 @@ class ConsentURIBuilder: self._hmac_secret = hs_config.key.form_secret.encode("utf-8") self._public_baseurl = hs_config.server.public_baseurl - def build_user_consent_uri(self, user_id): + def build_user_consent_uri(self, user_id: str) -> str: """Build a URI which we can give to the user to do their privacy policy consent Args: - user_id (str): mxid or username of user + user_id: mxid or username of user Returns - (str) the URI where the user can do consent + The URI where the user can do consent """ mac = hmac.new( key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256 -- cgit 1.5.1 From ba00e20234eadae66f105f8bda64e39beed9a92d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 21 Oct 2021 14:39:16 -0400 Subject: Add a thread relation type per MSC3440. (#11088) Adds experimental support for MSC3440's `io.element.thread` relation type (and the aggregation for it). --- changelog.d/11088.feature | 1 + synapse/api/constants.py | 1 + synapse/config/experimental.py | 2 + synapse/events/utils.py | 17 +++++++++ synapse/rest/client/relations.py | 3 +- synapse/storage/databases/main/events.py | 4 ++ synapse/storage/databases/main/relations.py | 59 ++++++++++++++++++++++++++++- tests/rest/client/test_relations.py | 40 ++++++++++++++++--- 8 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 changelog.d/11088.feature (limited to 'synapse/api') diff --git a/changelog.d/11088.feature b/changelog.d/11088.feature new file mode 100644 index 0000000000..76b0d28084 --- /dev/null +++ b/changelog.d/11088.feature @@ -0,0 +1 @@ +Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a31f037748..a33ac34161 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -176,6 +176,7 @@ class RelationTypes: ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" + THREAD = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b013a3918c..8b098ad48d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -26,6 +26,8 @@ class ExperimentalConfig(Config): # Whether to enable experimental MSC1849 (aka relations) support self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True) + # MSC3440 (thread relation) + self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False) # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 3f3eba86a8..6fa631aa1d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -386,6 +386,7 @@ class EventClientSerializer: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self._msc1849_enabled = hs.config.experimental.msc1849_enabled + self._msc3440_enabled = hs.config.experimental.msc3440_enabled async def serialize_event( self, @@ -462,6 +463,22 @@ class EventClientSerializer: "sender": edit.sender, } + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.store.get_thread_summary(event_id) + if latest_thread_event: + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": await self.serialize_event( + latest_thread_event, time_now, bundle_aggregations=False + ), + "count": thread_count, + } + return serialized_event async def serialize_events( diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d695c18be2..58f6699073 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -128,9 +128,10 @@ class RelationSendServlet(RestServlet): content["m.relates_to"] = { "event_id": parent_id, - "key": aggregation_key, "rel_type": relation_type, } + if aggregation_key is not None: + content["m.relates_to"]["key"] = aggregation_key event_dict = { "type": event_type, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 37439f8562..8d9086ecf0 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1710,6 +1710,7 @@ class PersistEventsStore: RelationTypes.ANNOTATION, RelationTypes.REFERENCE, RelationTypes.REPLACE, + RelationTypes.THREAD, ): # Unknown relation type return @@ -1740,6 +1741,9 @@ class PersistEventsStore: if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + if rel_type == RelationTypes.THREAD: + txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. Part of MSC2716. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2bbf6d6a95..40760fbd1b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Optional +from typing import Optional, Tuple import attr @@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore): return await self.get_event(edit_id, allow_none=True) + @cached() + async def get_thread_summary( + self, event_id: str + ) -> Tuple[int, Optional[EventBase]]: + """Get the number of threaded replies, the senders of those replies, and + the latest reply (if any) for the given event. + + Args: + event_id: The original event ID + + Returns: + The number of items in the thread and the most recent response, if any. + """ + + def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]: + # Fetch the count of threaded events and the latest event ID. + # TODO Should this only allow m.room.message events. + sql = """ + SELECT event_id + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + + txn.execute(sql, (event_id, RelationTypes.THREAD)) + row = txn.fetchone() + if row is None: + return 0, None + + latest_event_id = row[0] + + sql = """ + SELECT COALESCE(COUNT(event_id), 0) + FROM event_relations + WHERE + relates_to_id = ? + AND relation_type = ? + """ + txn.execute(sql, (event_id, RelationTypes.THREAD)) + count = txn.fetchone()[0] + + return count, latest_event_id + + count, latest_event_id = await self.db_pool.runInteraction( + "get_thread_summary", _get_thread_summary_txn + ) + + latest_event = None + if latest_event_id: + latest_event = await self.get_event(latest_event_id, allow_none=True) + + return count, latest_event + async def has_user_annotated_event( self, parent_id: str, event_type: str, aggregation_key: str, sender: str ) -> bool: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 3c7d49f0b4..78c2fb86b9 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -101,10 +101,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_basic_paginate_relations(self): """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] @@ -141,8 +141,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ expected_event_ids = [] - for _ in range(10): - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + for idx in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) + ) self.assertEquals(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) @@ -386,8 +388,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(400, channel.code, channel.json_body) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_aggregation_get_event(self): - """Test that annotations and references get correctly bundled when + """Test that annotations, references, and threads get correctly bundled when getting the parent event. """ @@ -410,6 +413,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + thread_2 = channel.json_body["event_id"] + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), @@ -429,6 +439,25 @@ class RelationsTestCase(unittest.HomeserverTestCase): RelationTypes.REFERENCE: { "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] }, + RelationTypes.THREAD: { + "count": 2, + "latest_event": { + "age": 100, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "origin_server_ts": 1600, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "unsigned": {"age": 100}, + "user_id": self.user_id, + }, + }, }, ) @@ -559,7 +588,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): { "m.relates_to": { "event_id": self.parent_id, - "key": None, "rel_type": "m.reference", } }, -- cgit 1.5.1