summary refs log tree commit diff
path: root/synapse/storage/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/receipts.py')
-rw-r--r--synapse/storage/receipts.py191
1 files changed, 89 insertions, 102 deletions
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 89a1f7e3d7..a1647e50a1 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -64,10 +64,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     def get_receipts_for_room(self, room_id, receipt_type):
         return self._simple_select_list(
             table="receipts_linearized",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-            },
+            keyvalues={"room_id": room_id, "receipt_type": receipt_type},
             retcols=("user_id", "event_id"),
             desc="get_receipts_for_room",
         )
@@ -79,7 +76,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             keyvalues={
                 "room_id": room_id,
                 "receipt_type": receipt_type,
-                "user_id": user_id
+                "user_id": user_id,
             },
             retcol="event_id",
             desc="get_own_receipt_for_user",
@@ -90,10 +87,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     def get_receipts_for_user(self, user_id, receipt_type):
         rows = yield self._simple_select_list(
             table="receipts_linearized",
-            keyvalues={
-                "user_id": user_id,
-                "receipt_type": receipt_type,
-            },
+            keyvalues={"user_id": user_id, "receipt_type": receipt_type},
             retcols=("room_id", "event_id"),
             desc="get_receipts_for_user",
         )
@@ -114,16 +108,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (user_id,))
             return txn.fetchall()
-        rows = yield self.runInteraction(
-            "get_receipts_for_user_with_orderings", f
+
+        rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
+        defer.returnValue(
+            {
+                row[0]: {
+                    "event_id": row[1],
+                    "topological_ordering": row[2],
+                    "stream_ordering": row[3],
+                }
+                for row in rows
+            }
         )
-        defer.returnValue({
-            row[0]: {
-                "event_id": row[1],
-                "topological_ordering": row[2],
-                "stream_ordering": row[3],
-            } for row in rows
-        })
 
     @defer.inlineCallbacks
     def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -177,6 +173,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
         """See get_linearized_receipts_for_room
         """
+
         def f(txn):
             if from_key:
                 sql = (
@@ -184,48 +181,40 @@ class ReceiptsWorkerStore(SQLBaseStore):
                     " room_id = ? AND stream_id > ? AND stream_id <= ?"
                 )
 
-                txn.execute(
-                    sql,
-                    (room_id, from_key, to_key)
-                )
+                txn.execute(sql, (room_id, from_key, to_key))
             else:
                 sql = (
                     "SELECT * FROM receipts_linearized WHERE"
                     " room_id = ? AND stream_id <= ?"
                 )
 
-                txn.execute(
-                    sql,
-                    (room_id, to_key)
-                )
+                txn.execute(sql, (room_id, to_key))
 
             rows = self.cursor_to_dict(txn)
 
             return rows
 
-        rows = yield self.runInteraction(
-            "get_linearized_receipts_for_room", f
-        )
+        rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
 
         if not rows:
             defer.returnValue([])
 
         content = {}
         for row in rows:
-            content.setdefault(
-                row["event_id"], {}
-            ).setdefault(
-                row["receipt_type"], {}
-            )[row["user_id"]] = json.loads(row["data"])
-
-        defer.returnValue([{
-            "type": "m.receipt",
-            "room_id": room_id,
-            "content": content,
-        }])
-
-    @cachedList(cached_method_name="_get_linearized_receipts_for_room",
-                list_name="room_ids", num_args=3, inlineCallbacks=True)
+            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
+                row["user_id"]
+            ] = json.loads(row["data"])
+
+        defer.returnValue(
+            [{"type": "m.receipt", "room_id": room_id, "content": content}]
+        )
+
+    @cachedList(
+        cached_method_name="_get_linearized_receipts_for_room",
+        list_name="room_ids",
+        num_args=3,
+        inlineCallbacks=True,
+    )
     def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
             defer.returnValue({})
@@ -235,9 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 sql = (
                     "SELECT * FROM receipts_linearized WHERE"
                     " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
-                ) % (
-                    ",".join(["?"] * len(room_ids))
-                )
+                ) % (",".join(["?"] * len(room_ids)))
                 args = list(room_ids)
                 args.extend([from_key, to_key])
 
@@ -246,9 +233,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 sql = (
                     "SELECT * FROM receipts_linearized WHERE"
                     " room_id IN (%s) AND stream_id <= ?"
-                ) % (
-                    ",".join(["?"] * len(room_ids))
-                )
+                ) % (",".join(["?"] * len(room_ids)))
 
                 args = list(room_ids)
                 args.append(to_key)
@@ -257,19 +242,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return self.cursor_to_dict(txn)
 
-        txn_results = yield self.runInteraction(
-            "_get_linearized_receipts_for_rooms", f
-        )
+        txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
 
         results = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
-            room_event = results.setdefault(row["room_id"], {
-                "type": "m.receipt",
-                "room_id": row["room_id"],
-                "content": {},
-            })
+            room_event = results.setdefault(
+                row["room_id"],
+                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
+            )
 
             # The content is of the form:
             # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
@@ -301,21 +283,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 args.append(limit)
             txn.execute(sql, args)
 
-            return (
-                r[0:5] + (json.loads(r[5]), ) for r in txn
-            )
+            return (r[0:5] + (json.loads(r[5]),) for r in txn)
+
         return self.runInteraction(
             "get_all_updated_receipts", get_all_updated_receipts_txn
         )
 
-    def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
-                                                    user_id):
+    def _invalidate_get_users_with_receipts_in_room(
+        self, room_id, receipt_type, user_id
+    ):
         if receipt_type != "m.read":
             return
 
         # Returns either an ObservableDeferred or the raw result
         res = self.get_users_with_read_receipts_in_room.cache.get(
-            room_id, None, update_metrics=False,
+            room_id, None, update_metrics=False
         )
 
         # first handle the Deferred case
@@ -346,8 +328,9 @@ class ReceiptsStore(ReceiptsWorkerStore):
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
 
-    def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
-                                      user_id, event_id, data, stream_id):
+    def insert_linearized_receipt_txn(
+        self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
+    ):
         """Inserts a read-receipt into the database if it's newer than the current RR
 
         Returns: int|None
@@ -360,7 +343,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             table="events",
             retcols=["stream_ordering", "received_ts"],
             keyvalues={"event_id": event_id},
-            allow_none=True
+            allow_none=True,
         )
 
         stream_ordering = int(res["stream_ordering"]) if res else None
@@ -381,31 +364,31 @@ class ReceiptsStore(ReceiptsWorkerStore):
                     logger.debug(
                         "Ignoring new receipt for %s in favour of existing "
                         "one for later event %s",
-                        event_id, eid,
+                        event_id,
+                        eid,
                     )
                     return None
 
-        txn.call_after(
-            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
-        )
+        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
         txn.call_after(
             self._invalidate_get_users_with_receipts_in_room,
-            room_id, receipt_type, user_id,
+            room_id,
+            receipt_type,
+            user_id,
         )
+        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
+        # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(
-            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
         )
-        # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
 
         txn.call_after(
-            self._receipts_stream_cache.entity_has_changed,
-            room_id, stream_id
+            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )
 
         txn.call_after(
             self.get_last_receipt_event_id_for_user.invalidate,
-            (user_id, room_id, receipt_type)
+            (user_id, room_id, receipt_type),
         )
 
         self._simple_delete_txn(
@@ -415,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "room_id": room_id,
                 "receipt_type": receipt_type,
                 "user_id": user_id,
-            }
+            },
         )
 
         self._simple_insert_txn(
@@ -428,15 +411,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "user_id": user_id,
                 "event_id": event_id,
                 "data": json.dumps(data),
-            }
+            },
         )
 
         if receipt_type == "m.read" and stream_ordering is not None:
             self._remove_old_push_actions_before_txn(
-                txn,
-                room_id=room_id,
-                user_id=user_id,
-                stream_ordering=stream_ordering,
+                txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
 
         return rx_ts
@@ -479,7 +459,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
             event_ts = yield self.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
-                room_id, receipt_type, user_id, linearized_event_id,
+                room_id,
+                receipt_type,
+                user_id,
+                linearized_event_id,
                 data,
                 stream_id=stream_id,
             )
@@ -490,39 +473,43 @@ class ReceiptsStore(ReceiptsWorkerStore):
         now = self._clock.time_msec()
         logger.debug(
             "RR for event %s in %s (%i ms old)",
-            linearized_event_id, room_id, now - event_ts,
+            linearized_event_id,
+            room_id,
+            now - event_ts,
         )
 
-        yield self.insert_graph_receipt(
-            room_id, receipt_type, user_id, event_ids, data
-        )
+        yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
 
         max_persisted_id = self._receipts_id_gen.get_current_token()
 
         defer.returnValue((stream_id, max_persisted_id))
 
-    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
-                             data):
+    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
         return self.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
-            room_id, receipt_type, user_id, event_ids, data
+            room_id,
+            receipt_type,
+            user_id,
+            event_ids,
+            data,
         )
 
-    def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
-                                 user_id, event_ids, data):
-        txn.call_after(
-            self.get_receipts_for_room.invalidate, (room_id, receipt_type)
-        )
+    def insert_graph_receipt_txn(
+        self, txn, room_id, receipt_type, user_id, event_ids, data
+    ):
+        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
         txn.call_after(
             self._invalidate_get_users_with_receipts_in_room,
-            room_id, receipt_type, user_id,
+            room_id,
+            receipt_type,
+            user_id,
         )
+        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
+        # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(
-            self.get_receipts_for_user.invalidate, (user_id, receipt_type)
+            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
         )
-        # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
 
         self._simple_delete_txn(
             txn,
@@ -531,7 +518,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "room_id": room_id,
                 "receipt_type": receipt_type,
                 "user_id": user_id,
-            }
+            },
         )
         self._simple_insert_txn(
             txn,
@@ -542,5 +529,5 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "user_id": user_id,
                 "event_ids": json.dumps(event_ids),
                 "data": json.dumps(data),
-            }
+            },
         )