summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/experimental.py2
-rw-r--r--synapse/handlers/receipts.py23
-rw-r--r--synapse/replication/tcp/client.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py1
-rw-r--r--synapse/rest/client/read_marker.py2
-rw-r--r--synapse/rest/client/receipts.py14
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/storage/database.py2
-rw-r--r--synapse/storage/databases/main/receipts.py87
-rw-r--r--synapse/types.py1
10 files changed, 110 insertions, 27 deletions
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 595eb007a5..933779c23a 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -83,6 +83,8 @@ class ExperimentalConfig(Config):
         # MSC3786 (Add a default push rule to ignore m.room.server_acl events)
         self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)
 
+        # MSC3771: Thread read receipts
+        self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False)
         # MSC3772: A push rule for mutual relations.
         self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index afaf3261df..4768a34c07 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -63,6 +63,8 @@ class ReceiptsHandler:
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
+        self._msc3771_enabled = hs.config.experimental.msc3771_enabled
+
     async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
         """Called when we receive an EDU of type m.receipt from a remote HS."""
         receipts = []
@@ -91,13 +93,23 @@ class ReceiptsHandler:
                         )
                         continue
 
+                    # Check if these receipts apply to a thread.
+                    thread_id = None
+                    data = user_values.get("data", {})
+                    if self._msc3771_enabled and isinstance(data, dict):
+                        thread_id = data.get("thread_id")
+                        # If the thread ID is invalid, consider it missing.
+                        if not isinstance(thread_id, str):
+                            thread_id = None
+
                     receipts.append(
                         ReadReceipt(
                             room_id=room_id,
                             receipt_type=receipt_type,
                             user_id=user_id,
                             event_ids=user_values["event_ids"],
-                            data=user_values.get("data", {}),
+                            thread_id=thread_id,
+                            data=data,
                         )
                     )
 
@@ -114,6 +126,7 @@ class ReceiptsHandler:
                 receipt.receipt_type,
                 receipt.user_id,
                 receipt.event_ids,
+                receipt.thread_id,
                 receipt.data,
             )
 
@@ -146,7 +159,12 @@ class ReceiptsHandler:
         return True
 
     async def received_client_receipt(
-        self, room_id: str, receipt_type: str, user_id: str, event_id: str
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_id: str,
+        thread_id: Optional[str],
     ) -> None:
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
@@ -156,6 +174,7 @@ class ReceiptsHandler:
             receipt_type=receipt_type,
             user_id=user_id,
             event_ids=[event_id],
+            thread_id=thread_id,
             data={"ts": int(self.clock.time_msec())},
         )
 
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index cf9cd6833b..b2522f98ca 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -427,7 +427,8 @@ class FederationSenderHandler:
                 receipt.receipt_type,
                 receipt.user_id,
                 [receipt.event_id],
-                receipt.data,
+                thread_id=receipt.thread_id,
+                data=receipt.data,
             )
             await self.federation_sender.send_read_receipt(receipt_info)
 
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 398bebeaa6..e01155ad59 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -361,6 +361,7 @@ class ReceiptsStream(Stream):
         receipt_type: str
         user_id: str
         event_id: str
+        thread_id: Optional[str]
         data: dict
 
     NAME = "receipts"
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 5e53096539..852838515c 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -83,6 +83,8 @@ class ReadMarkerRestServlet(RestServlet):
                     receipt_type,
                     user_id=requester.user.to_string(),
                     event_id=event_id,
+                    # Setting the thread ID is not possible with the /read_markers endpoint.
+                    thread_id=None,
                 )
 
         return 200, {}
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 5b7fad7402..f3ff156abe 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -49,6 +49,7 @@ class ReceiptRestServlet(RestServlet):
             ReceiptTypes.READ_PRIVATE,
             ReceiptTypes.FULLY_READ,
         }
+        self._msc3771_enabled = hs.config.experimental.msc3771_enabled
 
     async def on_POST(
         self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
@@ -61,7 +62,17 @@ class ReceiptRestServlet(RestServlet):
                 f"Receipt type must be {', '.join(self._known_receipt_types)}",
             )
 
-        parse_json_object_from_request(request, allow_empty_body=False)
+        body = parse_json_object_from_request(request)
+
+        # Pull the thread ID, if one exists.
+        thread_id = None
+        if self._msc3771_enabled:
+            if "thread_id" in body:
+                thread_id = body.get("thread_id")
+                if not thread_id or not isinstance(thread_id, str):
+                    raise SynapseError(
+                        400, "thread_id field must be a non-empty string"
+                    )
 
         await self.presence_handler.bump_presence_active_time(requester.user)
 
@@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet):
                 receipt_type,
                 user_id=requester.user.to_string(),
                 event_id=event_id,
+                thread_id=thread_id,
             )
 
         return 200, {}
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index b3917a5abc..c95b0d6f19 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -103,6 +103,8 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
                     # Adds support for thread relations, per MSC3440.
                     "org.matrix.msc3440.stable": True,  # TODO: remove when "v1.3" is added above
+                    # Support for thread read receipts.
+                    "org.matrix.msc3771": self.config.experimental.msc3771_enabled,
                     # Allows moderators to fetch redacted event content as described in MSC2815
                     "fi.mau.msc2815": self.config.experimental.msc2815_enabled,
                     # Adds support for login token requests as per MSC3882
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 921cd4dc5e..9d116f6925 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -95,6 +95,8 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
     "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
     "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
     "event_push_summary": "event_push_summary_unique_index",
+    "receipts_linearized": "receipts_linearized_unique_index",
+    "receipts_graph": "receipts_graph_unique_index",
 }
 
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index ddb8e80b69..52fe0db924 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -540,7 +540,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     async def get_all_updated_receipts(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[
+        List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool
+    ]:
         """Get updates for receipts replication stream.
 
         Args:
@@ -567,9 +569,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         def get_all_updated_receipts_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        ) -> Tuple[
+            List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
+            int,
+            bool,
+        ]:
             sql = """
-                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
+                SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                 FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -578,8 +584,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id, limit))
 
             updates = cast(
-                List[Tuple[int, list]],
-                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+                List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
+                [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn],
             )
 
             limited = False
@@ -631,6 +637,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_id: str,
+        thread_id: Optional[str],
         data: JsonDict,
         stream_id: int,
     ) -> Optional[int]:
@@ -657,12 +664,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
         if stream_ordering is not None:
-            sql = (
-                "SELECT stream_ordering, event_id FROM events"
-                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
-                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+            if thread_id is None:
+                thread_clause = "r.thread_id IS NULL"
+                thread_args: Tuple[str, ...] = ()
+            else:
+                thread_clause = "r.thread_id = ?"
+                thread_args = (thread_id,)
+
+            sql = f"""
+            SELECT stream_ordering, event_id FROM events
+            INNER JOIN receipts_linearized AS r USING (event_id, room_id)
+            WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause}
+            """
+            txn.execute(
+                sql,
+                (
+                    room_id,
+                    receipt_type,
+                    user_id,
+                )
+                + thread_args,
             )
-            txn.execute(sql, (room_id, receipt_type, user_id))
 
             for so, eid in txn:
                 if int(so) >= stream_ordering:
@@ -682,21 +704,28 @@ class ReceiptsWorkerStore(SQLBaseStore):
             self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )
 
+        keyvalues = {
+            "room_id": room_id,
+            "receipt_type": receipt_type,
+            "user_id": user_id,
+        }
+        where_clause = ""
+        if thread_id is None:
+            where_clause = "thread_id IS NULL"
+        else:
+            keyvalues["thread_id"] = thread_id
+
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_linearized",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
-            },
+            keyvalues=keyvalues,
             values={
                 "stream_id": stream_id,
                 "event_id": event_id,
                 "event_stream_ordering": stream_ordering,
                 "data": json_encoder.encode(data),
-                "thread_id": None,
             },
+            where_clause=where_clause,
             # receipts_linearized has a unique constraint on
             # (user_id, room_id, receipt_type), so no need to lock
             lock=False,
@@ -748,6 +777,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_ids: List[str],
+        thread_id: Optional[str],
         data: dict,
     ) -> Optional[Tuple[int, int]]:
         """Insert a receipt, either from local client or remote server.
@@ -780,6 +810,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 receipt_type,
                 user_id,
                 linearized_event_id,
+                thread_id,
                 data,
                 stream_id=stream_id,
                 # Read committed is actually beneficial here because we check for a receipt with
@@ -794,7 +825,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         now = self._clock.time_msec()
         logger.debug(
-            "RR for event %s in %s (%i ms old)",
+            "Receipt %s for event %s in %s (%i ms old)",
+            receipt_type,
             linearized_event_id,
             room_id,
             now - event_ts,
@@ -807,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type,
             user_id,
             event_ids,
+            thread_id,
             data,
         )
 
@@ -821,6 +854,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_ids: List[str],
+        thread_id: Optional[str],
         data: JsonDict,
     ) -> None:
         assert self._can_write_to_receipts
@@ -832,19 +866,26 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
 
+        keyvalues = {
+            "room_id": room_id,
+            "receipt_type": receipt_type,
+            "user_id": user_id,
+        }
+        where_clause = ""
+        if thread_id is None:
+            where_clause = "thread_id IS NULL"
+        else:
+            keyvalues["thread_id"] = thread_id
+
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_graph",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
-            },
+            keyvalues=keyvalues,
             values={
                 "event_ids": json_encoder.encode(event_ids),
                 "data": json_encoder.encode(data),
-                "thread_id": None,
             },
+            where_clause=where_clause,
             # receipts_graph has a unique constraint on
             # (user_id, room_id, receipt_type), so no need to lock
             lock=False,
diff --git a/synapse/types.py b/synapse/types.py
index ec44601f54..773f0438d5 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -835,6 +835,7 @@ class ReadReceipt:
     receipt_type: str
     user_id: str
     event_ids: List[str]
+    thread_id: Optional[str]
     data: JsonDict