From 7862f821de30b17ed035a3355e552d027429dc6b Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 25 Nov 2021 16:14:23 +0000 Subject: Annotate string constants in `synapse.api.constants` with `Final` (#11356) This change makes mypy complain if the constants are ever reassigned, and, more usefully, makes mypy type them as `Literal`s instead of `str`s, allowing code of the following form to pass mypy: ```py def do_something(membership: Literal["join", "leave"], ...): ... do_something(Membership.JOIN, ...) ``` --- synapse/api/constants.py | 198 ++++++++++++++++++++++++----------------------- 1 file changed, 100 insertions(+), 98 deletions(-) (limited to 'synapse/api/constants.py') diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a33ac34161..f7d29b4319 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -17,6 +17,8 @@ """Contains constants from the specification.""" +from typing_extensions import Final + # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 @@ -39,125 +41,125 @@ class Membership: """Represents the membership states of a user in a room.""" - INVITE = "invite" - JOIN = "join" - KNOCK = "knock" - LEAVE = "leave" - BAN = "ban" - LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) + INVITE: Final = "invite" + JOIN: Final = "join" + KNOCK: Final = "knock" + LEAVE: Final = "leave" + BAN: Final = "ban" + LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN) class PresenceState: """Represents the presence state of a user.""" - OFFLINE = "offline" - UNAVAILABLE = "unavailable" - ONLINE = "online" - BUSY = "org.matrix.msc3026.busy" + OFFLINE: Final = "offline" + UNAVAILABLE: Final = "unavailable" + ONLINE: Final = "online" + BUSY: Final = "org.matrix.msc3026.busy" class JoinRules: - PUBLIC = "public" - KNOCK = "knock" - INVITE = "invite" - PRIVATE = "private" + PUBLIC: Final = "public" + KNOCK: Final = "knock" + INVITE: Final = "invite" + PRIVATE: Final = "private" # As defined for MSC3083. - RESTRICTED = "restricted" + RESTRICTED: Final = "restricted" class RestrictedJoinRuleTypes: """Understood types for the allow rules in restricted join rules.""" - ROOM_MEMBERSHIP = "m.room_membership" + ROOM_MEMBERSHIP: Final = "m.room_membership" class LoginType: - PASSWORD = "m.login.password" - EMAIL_IDENTITY = "m.login.email.identity" - MSISDN = "m.login.msisdn" - RECAPTCHA = "m.login.recaptcha" - TERMS = "m.login.terms" - SSO = "m.login.sso" - DUMMY = "m.login.dummy" - REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" + PASSWORD: Final = "m.login.password" + EMAIL_IDENTITY: Final = "m.login.email.identity" + MSISDN: Final = "m.login.msisdn" + RECAPTCHA: Final = "m.login.recaptcha" + TERMS: Final = "m.login.terms" + SSO: Final = "m.login.sso" + DUMMY: Final = "m.login.dummy" + REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token" # This is used in the `type` parameter for /register when called by # an appservice to register a new user. -APP_SERVICE_REGISTRATION_TYPE = "m.login.application_service" +APP_SERVICE_REGISTRATION_TYPE: Final = "m.login.application_service" class EventTypes: - Member = "m.room.member" - Create = "m.room.create" - Tombstone = "m.room.tombstone" - JoinRules = "m.room.join_rules" - PowerLevels = "m.room.power_levels" - Aliases = "m.room.aliases" - Redaction = "m.room.redaction" - ThirdPartyInvite = "m.room.third_party_invite" - RelatedGroups = "m.room.related_groups" - - RoomHistoryVisibility = "m.room.history_visibility" - CanonicalAlias = "m.room.canonical_alias" - Encrypted = "m.room.encrypted" - RoomAvatar = "m.room.avatar" - RoomEncryption = "m.room.encryption" - GuestAccess = "m.room.guest_access" + Member: Final = "m.room.member" + Create: Final = "m.room.create" + Tombstone: Final = "m.room.tombstone" + JoinRules: Final = "m.room.join_rules" + PowerLevels: Final = "m.room.power_levels" + Aliases: Final = "m.room.aliases" + Redaction: Final = "m.room.redaction" + ThirdPartyInvite: Final = "m.room.third_party_invite" + RelatedGroups: Final = "m.room.related_groups" + + RoomHistoryVisibility: Final = "m.room.history_visibility" + CanonicalAlias: Final = "m.room.canonical_alias" + Encrypted: Final = "m.room.encrypted" + RoomAvatar: Final = "m.room.avatar" + RoomEncryption: Final = "m.room.encryption" + GuestAccess: Final = "m.room.guest_access" # These are used for validation - Message = "m.room.message" - Topic = "m.room.topic" - Name = "m.room.name" + Message: Final = "m.room.message" + Topic: Final = "m.room.topic" + Name: Final = "m.room.name" - ServerACL = "m.room.server_acl" - Pinned = "m.room.pinned_events" + ServerACL: Final = "m.room.server_acl" + Pinned: Final = "m.room.pinned_events" - Retention = "m.room.retention" + Retention: Final = "m.room.retention" - Dummy = "org.matrix.dummy_event" + Dummy: Final = "org.matrix.dummy_event" - SpaceChild = "m.space.child" - SpaceParent = "m.space.parent" + SpaceChild: Final = "m.space.child" + SpaceParent: Final = "m.space.parent" - MSC2716_INSERTION = "org.matrix.msc2716.insertion" - MSC2716_BATCH = "org.matrix.msc2716.batch" - MSC2716_MARKER = "org.matrix.msc2716.marker" + MSC2716_INSERTION: Final = "org.matrix.msc2716.insertion" + MSC2716_BATCH: Final = "org.matrix.msc2716.batch" + MSC2716_MARKER: Final = "org.matrix.msc2716.marker" class ToDeviceEventTypes: - RoomKeyRequest = "m.room_key_request" + RoomKeyRequest: Final = "m.room_key_request" class DeviceKeyAlgorithms: """Spec'd algorithms for the generation of per-device keys""" - ED25519 = "ed25519" - CURVE25519 = "curve25519" - SIGNED_CURVE25519 = "signed_curve25519" + ED25519: Final = "ed25519" + CURVE25519: Final = "curve25519" + SIGNED_CURVE25519: Final = "signed_curve25519" class EduTypes: - Presence = "m.presence" + Presence: Final = "m.presence" class RejectedReason: - AUTH_ERROR = "auth_error" + AUTH_ERROR: Final = "auth_error" class RoomCreationPreset: - PRIVATE_CHAT = "private_chat" - PUBLIC_CHAT = "public_chat" - TRUSTED_PRIVATE_CHAT = "trusted_private_chat" + PRIVATE_CHAT: Final = "private_chat" + PUBLIC_CHAT: Final = "public_chat" + TRUSTED_PRIVATE_CHAT: Final = "trusted_private_chat" class ThirdPartyEntityKind: - USER = "user" - LOCATION = "location" + USER: Final = "user" + LOCATION: Final = "location" -ServerNoticeMsgType = "m.server_notice" -ServerNoticeLimitReached = "m.server_notice.usage_limit_reached" +ServerNoticeMsgType: Final = "m.server_notice" +ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached" class UserTypes: @@ -165,91 +167,91 @@ class UserTypes: 'admin' and 'guest' users should also be UserTypes. Normal users are type None """ - SUPPORT = "support" - BOT = "bot" - ALL_USER_TYPES = (SUPPORT, BOT) + SUPPORT: Final = "support" + BOT: Final = "bot" + ALL_USER_TYPES: Final = (SUPPORT, BOT) class RelationTypes: """The types of relations known to this server.""" - ANNOTATION = "m.annotation" - REPLACE = "m.replace" - REFERENCE = "m.reference" - THREAD = "io.element.thread" + ANNOTATION: Final = "m.annotation" + REPLACE: Final = "m.replace" + REFERENCE: Final = "m.reference" + THREAD: Final = "io.element.thread" class LimitBlockingTypes: """Reasons that a server may be blocked""" - MONTHLY_ACTIVE_USER = "monthly_active_user" - HS_DISABLED = "hs_disabled" + MONTHLY_ACTIVE_USER: Final = "monthly_active_user" + HS_DISABLED: Final = "hs_disabled" class EventContentFields: """Fields found in events' content, regardless of type.""" # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 - LABELS = "org.matrix.labels" + LABELS: Final = "org.matrix.labels" # Timestamp to delete the event after # cf https://github.com/matrix-org/matrix-doc/pull/2228 - SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" + SELF_DESTRUCT_AFTER: Final = "org.matrix.self_destruct_after" # cf https://github.com/matrix-org/matrix-doc/pull/1772 - ROOM_TYPE = "type" + ROOM_TYPE: Final = "type" # Whether a room can federate. - FEDERATE = "m.federate" + FEDERATE: Final = "m.federate" # The creator of the room, as used in `m.room.create` events. - ROOM_CREATOR = "creator" + ROOM_CREATOR: Final = "creator" # Used in m.room.guest_access events. - GUEST_ACCESS = "guest_access" + GUEST_ACCESS: Final = "guest_access" # Used on normal messages to indicate they were historically imported after the fact - MSC2716_HISTORICAL = "org.matrix.msc2716.historical" + MSC2716_HISTORICAL: Final = "org.matrix.msc2716.historical" # For "insertion" events to indicate what the next batch ID should be in # order to connect to it - MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id" + MSC2716_NEXT_BATCH_ID: Final = "org.matrix.msc2716.next_batch_id" # Used on "batch" events to indicate which insertion event it connects to - MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id" + MSC2716_BATCH_ID: Final = "org.matrix.msc2716.batch_id" # For "marker" events - MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion" + MSC2716_MARKER_INSERTION: Final = "org.matrix.msc2716.marker.insertion" # The authorising user for joining a restricted room. - AUTHORISING_USER = "join_authorised_via_users_server" + AUTHORISING_USER: Final = "join_authorised_via_users_server" class RoomTypes: """Understood values of the room_type field of m.room.create events.""" - SPACE = "m.space" + SPACE: Final = "m.space" class RoomEncryptionAlgorithms: - MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" - DEFAULT = MEGOLM_V1_AES_SHA2 + MEGOLM_V1_AES_SHA2: Final = "m.megolm.v1.aes-sha2" + DEFAULT: Final = MEGOLM_V1_AES_SHA2 class AccountDataTypes: - DIRECT = "m.direct" - IGNORED_USER_LIST = "m.ignored_user_list" + DIRECT: Final = "m.direct" + IGNORED_USER_LIST: Final = "m.ignored_user_list" class HistoryVisibility: - INVITED = "invited" - JOINED = "joined" - SHARED = "shared" - WORLD_READABLE = "world_readable" + INVITED: Final = "invited" + JOINED: Final = "joined" + SHARED: Final = "shared" + WORLD_READABLE: Final = "world_readable" class GuestAccess: - CAN_JOIN = "can_join" + CAN_JOIN: Final = "can_join" # anything that is not "can_join" is considered "forbidden", but for completeness: - FORBIDDEN = "forbidden" + FORBIDDEN: Final = "forbidden" class ReadReceiptEventFields: - MSC2285_HIDDEN = "org.matrix.msc2285.hidden" + MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" -- cgit 1.5.1 From d93362d87fbbf4941da06ade65eaf5df1672bccb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Dec 2021 12:26:29 -0500 Subject: Add a constant for receipt types (m.read). (#11531) And expand some type hints in the receipts storage module. --- changelog.d/11531.misc | 1 + synapse/api/constants.py | 4 ++ synapse/handlers/receipts.py | 6 +- synapse/handlers/sync.py | 4 +- synapse/push/push_tools.py | 3 +- synapse/rest/client/notifications.py | 3 +- synapse/rest/client/read_marker.py | 6 +- synapse/rest/client/receipts.py | 4 +- synapse/storage/databases/main/receipts.py | 101 +++++++++++++++++++---------- 9 files changed, 87 insertions(+), 45 deletions(-) create mode 100644 changelog.d/11531.misc (limited to 'synapse/api/constants.py') diff --git a/changelog.d/11531.misc b/changelog.d/11531.misc new file mode 100644 index 0000000000..ed6ef3bb3e --- /dev/null +++ b/changelog.d/11531.misc @@ -0,0 +1 @@ +Add a receipt types constant for `m.read`. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f7d29b4319..52c083a20b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -253,5 +253,9 @@ class GuestAccess: FORBIDDEN: Final = "forbidden" +class ReceiptTypes: + READ: Final = "m.read" + + class ReadReceiptEventFields: MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4911a11535..5cb1ff749d 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -178,7 +178,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): for event_id in content.keys(): event_content = content.get(event_id, {}) - m_read = event_content.get("m.read", {}) + m_read = event_content.get(ReceiptTypes.READ, {}) # If m_read is missing copy over the original event_content as there is nothing to process here if not m_read: @@ -206,7 +206,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): # Set new users unless empty if len(new_users.keys()) > 0: - new_event["content"][event_id] = {"m.read": new_users} + new_event["content"][event_id] = {ReceiptTypes.READ: new_users} # Append new_event to visible_events unless empty if len(new_event["content"].keys()) > 0: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f3039c3c3f..96f37e9f42 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,7 +28,7 @@ from typing import ( import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -1046,7 +1046,7 @@ class SyncHandler: last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read", + receipt_type=ReceiptTypes.READ, ) notifs = await self.store.get_unread_event_push_actions_by_room_for_user( diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 9c85200c0f..da641aca47 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Dict +from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage @@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ) badge = len(invites) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index d1d8a984c6..b12a332776 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import ReceiptTypes from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -54,7 +55,7 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, "m.read" + user_id, ReceiptTypes.READ ) notif_event_ids = [pa["event_id"] for pa in push_actions] diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 43c04fac6f..f51be511d1 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet): await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) - read_event_id = body.get("m.read", None) + read_event_id = body.get(ReceiptTypes.READ, None) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): @@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet): if read_event_id: await self.receipts_handler.received_client_receipt( room_id, - "m.read", + ReceiptTypes.READ, user_id=requester.user.to_string(), event_id=read_event_id, hidden=hidden, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 2b25b9aad6..b24ad2d1be 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import logging import re from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http import get_request_user_agent from synapse.http.server import HttpServer @@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if receipt_type != "m.read": + if receipt_type != ReceiptTypes.READ: raise SynapseError(400, "Receipt type must be 'm.read'") # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c99f8aebdb..9c5625c8bb 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,14 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from twisted.internet import defer +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict @@ -78,17 +89,13 @@ class ReceiptsWorkerStore(SQLBaseStore): "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - def get_max_receipt_stream_id(self): - """Get the current max stream ID for receipts stream - - Returns: - int - """ + def get_max_receipt_stream_id(self) -> int: + """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @cached() - async def get_users_with_read_receipts_in_room(self, room_id): - receipts = await self.get_receipts_for_room(room_id, "m.read") + async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]: + receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -119,7 +126,9 @@ class ReceiptsWorkerStore(SQLBaseStore): ) @cached(num_args=2) - async def get_receipts_for_user(self, user_id, receipt_type): + async def get_receipts_for_user( + self, user_id: str, receipt_type: str + ) -> Dict[str, str]: rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, @@ -129,8 +138,10 @@ class ReceiptsWorkerStore(SQLBaseStore): return {row["room_id"]: row["event_id"] for row in rows} - async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): + async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_type: str + ) -> JsonDict: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" @@ -209,10 +220,10 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonDict]: """See get_linearized_receipts_for_room""" - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" @@ -250,11 +261,13 @@ class ReceiptsWorkerStore(SQLBaseStore): list_name="room_ids", num_args=3, ) - async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms( + self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + ) -> Dict[str, List[JsonDict]]: if not room_ids: return {} - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -323,7 +336,7 @@ class ReceiptsWorkerStore(SQLBaseStore): A dictionary of roomids to a list of receipts. """ - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -379,7 +392,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return defer.succeed([]) - def _get_users_sent_receipts_between_txn(txn): + def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT user_id FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? @@ -419,7 +432,9 @@ class ReceiptsWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_receipts_txn(txn): + def get_all_updated_receipts_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized @@ -446,8 +461,8 @@ class ReceiptsWorkerStore(SQLBaseStore): def _invalidate_get_users_with_receipts_in_room( self, room_id: str, receipt_type: str, user_id: str - ): - if receipt_type != "m.read": + ) -> None: + if receipt_type != ReceiptTypes.READ: return res = self.get_users_with_read_receipts_in_room.cache.get_immediate( @@ -461,7 +476,9 @@ class ReceiptsWorkerStore(SQLBaseStore): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + def invalidate_caches_for_receipt( + self, room_id: str, receipt_type: str, user_id: str + ) -> None: self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( @@ -482,11 +499,18 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + data: JsonDict, + stream_id: int, + ) -> Optional[int]: """Inserts a read-receipt into the database if it's newer than the current RR - Returns: int|None + Returns: None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) @@ -550,7 +574,7 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) - if receipt_type == "m.read" and stream_ordering is not None: + if receipt_type == ReceiptTypes.READ and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -580,7 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore): else: # we need to points in graph -> linearized form. # TODO: Make this better. - def graph_to_linear(txn): + def graph_to_linear(txn: LoggingTransaction) -> str: clause, args = make_in_list_sql_clause( self.database_engine, "event_id", event_ids ) @@ -634,11 +658,16 @@ class ReceiptsWorkerStore(SQLBaseStore): return stream_id, max_persisted_id async def insert_graph_receipt( - self, room_id, receipt_type, user_id, event_ids, data - ): + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -649,8 +678,14 @@ class ReceiptsWorkerStore(SQLBaseStore): ) def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): + self, + txn: LoggingTransaction, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) -- cgit 1.5.1