summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-12-08 12:26:29 -0500
committerGitHub <noreply@github.com>2021-12-08 12:26:29 -0500
commitd93362d87fbbf4941da06ade65eaf5df1672bccb (patch)
treec357e1655faa4e6c286b0ed3493ffc69ae9a529c
parentClean up `synapse.rest.admin` (#11535) (diff)
downloadsynapse-d93362d87fbbf4941da06ade65eaf5df1672bccb.tar.xz
Add a constant for receipt types (m.read). (#11531)
And expand some type hints in the receipts storage module.
-rw-r--r--changelog.d/11531.misc1
-rw-r--r--synapse/api/constants.py4
-rw-r--r--synapse/handlers/receipts.py6
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/push/push_tools.py3
-rw-r--r--synapse/rest/client/notifications.py3
-rw-r--r--synapse/rest/client/read_marker.py6
-rw-r--r--synapse/rest/client/receipts.py4
-rw-r--r--synapse/storage/databases/main/receipts.py101
9 files changed, 87 insertions, 45 deletions
diff --git a/changelog.d/11531.misc b/changelog.d/11531.misc
new file mode 100644
index 0000000000..ed6ef3bb3e
--- /dev/null
+++ b/changelog.d/11531.misc
@@ -0,0 +1 @@
+Add a receipt types constant for `m.read`.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index f7d29b4319..52c083a20b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -253,5 +253,9 @@ class GuestAccess:
     FORBIDDEN: Final = "forbidden"
 
 
+class ReceiptTypes:
+    READ: Final = "m.read"
+
+
 class ReadReceiptEventFields:
     MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4911a11535..5cb1ff749d 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,7 +14,7 @@
 import logging
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
 from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@@ -178,7 +178,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
             for event_id in content.keys():
                 event_content = content.get(event_id, {})
-                m_read = event_content.get("m.read", {})
+                m_read = event_content.get(ReceiptTypes.READ, {})
 
                 # If m_read is missing copy over the original event_content as there is nothing to process here
                 if not m_read:
@@ -206,7 +206,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
                 # Set new users unless empty
                 if len(new_users.keys()) > 0:
-                    new_event["content"][event_id] = {"m.read": new_users}
+                    new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
 
             # Append new_event to visible_events unless empty
             if len(new_event["content"].keys()) > 0:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f3039c3c3f..96f37e9f42 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,7 +28,7 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -1046,7 +1046,7 @@ class SyncHandler:
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
                 room_id=room_id,
-                receipt_type="m.read",
+                receipt_type=ReceiptTypes.READ,
             )
 
             notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 9c85200c0f..da641aca47 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Dict
 
+from synapse.api.constants import ReceiptTypes
 from synapse.events import EventBase
 from synapse.push.presentable_names import calculate_room_name, name_from_member_event
 from synapse.storage import Storage
@@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     invites = await store.get_invited_rooms_for_local_user(user_id)
     joins = await store.get_rooms_for_user(user_id)
 
-    my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
+    my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
 
     badge = len(invites)
 
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index d1d8a984c6..b12a332776 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -15,6 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
+from synapse.api.constants import ReceiptTypes
 from synapse.events.utils import format_event_for_client_v2_without_room_id
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -54,7 +55,7 @@ class NotificationsServlet(RestServlet):
         )
 
         receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
-            user_id, "m.read"
+            user_id, ReceiptTypes.READ
         )
 
         notif_event_ids = [pa["event_id"] for pa in push_actions]
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 43c04fac6f..f51be511d1 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
         await self.presence_handler.bump_presence_active_time(requester.user)
 
         body = parse_json_object_from_request(request)
-        read_event_id = body.get("m.read", None)
+        read_event_id = body.get(ReceiptTypes.READ, None)
         hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
 
         if not isinstance(hidden, bool):
@@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet):
         if read_event_id:
             await self.receipts_handler.received_client_receipt(
                 room_id,
-                "m.read",
+                ReceiptTypes.READ,
                 user_id=requester.user.to_string(),
                 event_id=read_event_id,
                 hidden=hidden,
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 2b25b9aad6..b24ad2d1be 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -16,7 +16,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.http import get_request_user_agent
 from synapse.http.server import HttpServer
@@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        if receipt_type != "m.read":
+        if receipt_type != ReceiptTypes.READ:
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
         # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index c99f8aebdb..9c5625c8bb 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,14 +14,25 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from twisted.internet import defer
 
+from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
@@ -78,17 +89,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )
 
-    def get_max_receipt_stream_id(self):
-        """Get the current max stream ID for receipts stream
-
-        Returns:
-            int
-        """
+    def get_max_receipt_stream_id(self) -> int:
+        """Get the current max stream ID for receipts stream"""
         return self._receipts_id_gen.get_current_token()
 
     @cached()
-    async def get_users_with_read_receipts_in_room(self, room_id):
-        receipts = await self.get_receipts_for_room(room_id, "m.read")
+    async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
+        receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
         return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
@@ -119,7 +126,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=2)
-    async def get_receipts_for_user(self, user_id, receipt_type):
+    async def get_receipts_for_user(
+        self, user_id: str, receipt_type: str
+    ) -> Dict[str, str]:
         rows = await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -129,8 +138,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return {row["room_id"]: row["event_id"] for row in rows}
 
-    async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
-        def f(txn):
+    async def get_receipts_for_user_with_orderings(
+        self, user_id: str, receipt_type: str
+    ) -> JsonDict:
+        def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
             sql = (
                 "SELECT rl.room_id, rl.event_id,"
                 " e.topological_ordering, e.stream_ordering"
@@ -209,10 +220,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached(num_args=3, tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[dict]:
+    ) -> List[JsonDict]:
         """See get_linearized_receipts_for_room"""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = (
                     "SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +261,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         list_name="room_ids",
         num_args=3,
     )
-    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def _get_linearized_receipts_for_rooms(
+        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
+    ) -> Dict[str, List[JsonDict]]:
         if not room_ids:
             return {}
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = """
                     SELECT * FROM receipts_linearized WHERE
@@ -323,7 +336,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             A dictionary of roomids to a list of receipts.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
             if from_key:
                 sql = """
                     SELECT * FROM receipts_linearized WHERE
@@ -379,7 +392,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return defer.succeed([])
 
-        def _get_users_sent_receipts_between_txn(txn):
+        def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
             sql = """
                 SELECT DISTINCT user_id FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +432,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_receipts_txn(txn):
+        def get_all_updated_receipts_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, list]], int, bool]:
             sql = """
                 SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                 FROM receipts_linearized
@@ -446,8 +461,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_users_with_receipts_in_room(
         self, room_id: str, receipt_type: str, user_id: str
-    ):
-        if receipt_type != "m.read":
+    ) -> None:
+        if receipt_type != ReceiptTypes.READ:
             return
 
         res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
@@ -461,7 +476,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         self.get_users_with_read_receipts_in_room.invalidate((room_id,))
 
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+    def invalidate_caches_for_receipt(
+        self, room_id: str, receipt_type: str, user_id: str
+    ) -> None:
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
@@ -482,11 +499,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def insert_linearized_receipt_txn(
-        self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
-    ):
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_id: str,
+        data: JsonDict,
+        stream_id: int,
+    ) -> Optional[int]:
         """Inserts a read-receipt into the database if it's newer than the current RR
 
-        Returns: int|None
+        Returns:
             None if the RR is older than the current RR
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
@@ -550,7 +574,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             lock=False,
         )
 
-        if receipt_type == "m.read" and stream_ordering is not None:
+        if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
             self._remove_old_push_actions_before_txn(
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
@@ -580,7 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         else:
             # we need to points in graph -> linearized form.
             # TODO: Make this better.
-            def graph_to_linear(txn):
+            def graph_to_linear(txn: LoggingTransaction) -> str:
                 clause, args = make_in_list_sql_clause(
                     self.database_engine, "event_id", event_ids
                 )
@@ -634,11 +658,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return stream_id, max_persisted_id
 
     async def insert_graph_receipt(
-        self, room_id, receipt_type, user_id, event_ids, data
-    ):
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: JsonDict,
+    ) -> None:
         assert self._can_write_to_receipts
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
@@ -649,8 +678,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     def insert_graph_receipt_txn(
-        self, txn, room_id, receipt_type, user_id, event_ids, data
-    ):
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: JsonDict,
+    ) -> None:
         assert self._can_write_to_receipts
 
         txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))