summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py23
-rw-r--r--synapse/storage/data_stores/main/events.py126
-rw-r--r--synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql21
3 files changed, 160 insertions, 10 deletions
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index d8ad59ad93..643327b57b 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -145,13 +145,28 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         txn.execute(signature_sql, signature_query_params)
         rows = self.cursor_to_dict(txn)
 
+        # add each cross-signing signature to the correct device in the result dict.
         for row in rows:
+            signing_user_id = row["user_id"]
+            signing_key_id = row["key_id"]
             target_user_id = row["target_user_id"]
             target_device_id = row["target_device_id"]
-            if target_user_id in result and target_device_id in result[target_user_id]:
-                result[target_user_id][target_device_id].setdefault(
-                    "signatures", {}
-                ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"]
+            signature = row["signature"]
+
+            target_user_result = result.get(target_user_id)
+            if not target_user_result:
+                continue
+
+            target_device_result = target_user_result.get(target_device_id)
+            if not target_device_result:
+                # note that target_device_result will be None for deleted devices.
+                continue
+
+            target_device_signatures = target_device_result.setdefault("signatures", {})
+            signing_user_signatures = target_device_signatures.setdefault(
+                signing_user_id, {}
+            )
+            signing_user_signatures[signing_key_id] = signature
 
         log_kv(result)
         return result
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 2737a1d3ae..79c91fe284 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -130,6 +130,8 @@ class EventsStore(
         if self.hs.config.redaction_retention_period is not None:
             hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
 
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def _read_forward_extremities(self):
         def fetch(txn):
@@ -940,6 +942,12 @@ class EventsStore(
                     txn, event.event_id, labels, event.room_id, event.depth
                 )
 
+            if self._ephemeral_messages_enabled:
+                # If there's an expiry timestamp on the event, store it.
+                expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+                if isinstance(expiry_ts, int) and not event.is_state():
+                    self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
+
         # Insert into the room_memberships table.
         self._store_room_members_txn(
             txn,
@@ -1101,12 +1109,7 @@ class EventsStore(
         def _update_censor_txn(txn):
             for redaction_id, event_id, pruned_json in updates:
                 if pruned_json:
-                    self._simple_update_one_txn(
-                        txn,
-                        table="event_json",
-                        keyvalues={"event_id": event_id},
-                        updatevalues={"json": pruned_json},
-                    )
+                    self._censor_event_txn(txn, event_id, pruned_json)
 
                 self._simple_update_one_txn(
                     txn,
@@ -1117,6 +1120,22 @@ class EventsStore(
 
         yield self.runInteraction("_update_censor_txn", _update_censor_txn)
 
+    def _censor_event_txn(self, txn, event_id, pruned_json):
+        """Censor an event by replacing its JSON in the event_json table with the
+        provided pruned JSON.
+
+        Args:
+            txn (LoggingTransaction): The database transaction.
+            event_id (str): The ID of the event to censor.
+            pruned_json (str): The pruned JSON
+        """
+        self._simple_update_one_txn(
+            txn,
+            table="event_json",
+            keyvalues={"event_id": event_id},
+            updatevalues={"json": pruned_json},
+        )
+
     @defer.inlineCallbacks
     def count_daily_messages(self):
         """
@@ -1957,6 +1976,101 @@ class EventsStore(
             ],
         )
 
+    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+        """Save the expiry timestamp associated with a given event ID.
+
+        Args:
+            txn (LoggingTransaction): The database transaction to use.
+            event_id (str): The event ID the expiry timestamp is associated with.
+            expiry_ts (int): The timestamp at which to expire (delete) the event.
+        """
+        return self._simple_insert_txn(
+            txn=txn,
+            table="event_expiry",
+            values={"event_id": event_id, "expiry_ts": expiry_ts},
+        )
+
+    @defer.inlineCallbacks
+    def expire_event(self, event_id):
+        """Retrieve and expire an event that has expired, and delete its associated
+        expiry timestamp. If the event can't be retrieved, delete its associated
+        timestamp so we don't try to expire it again in the future.
+
+        Args:
+             event_id (str): The ID of the event to delete.
+        """
+        # Try to retrieve the event's content from the database or the event cache.
+        event = yield self.get_event(event_id)
+
+        def delete_expired_event_txn(txn):
+            # Delete the expiry timestamp associated with this event from the database.
+            self._delete_event_expiry_txn(txn, event_id)
+
+            if not event:
+                # If we can't find the event, log a warning and delete the expiry date
+                # from the database so that we don't try to expire it again in the
+                # future.
+                logger.warning(
+                    "Can't expire event %s because we don't have it.", event_id
+                )
+                return
+
+            # Prune the event's dict then convert it to JSON.
+            pruned_json = encode_json(prune_event_dict(event.get_dict()))
+
+            # Update the event_json table to replace the event's JSON with the pruned
+            # JSON.
+            self._censor_event_txn(txn, event.event_id, pruned_json)
+
+            # We need to invalidate the event cache entry for this event because we
+            # changed its content in the database. We can't call
+            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+            # right type.
+            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            # Send that invalidation to replication so that other workers also invalidate
+            # the event cache.
+            self._send_invalidation_to_replication(
+                txn, "_get_event_cache", (event.event_id,)
+            )
+
+        yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+    def _delete_event_expiry_txn(self, txn, event_id):
+        """Delete the expiry timestamp associated with an event ID without deleting the
+        actual event.
+
+        Args:
+            txn (LoggingTransaction): The transaction to use to perform the deletion.
+            event_id (str): The event ID to delete the associated expiry timestamp of.
+        """
+        return self._simple_delete_txn(
+            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        )
+
+    def get_next_event_to_expire(self):
+        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+        table, or None if there's no more event to expire.
+
+        Returns: Deferred[Optional[Tuple[str, int]]]
+            A tuple containing the event ID as its first element and an expiry timestamp
+            as its second one, if there's at least one row in the event_expiry table.
+            None otherwise.
+        """
+
+        def get_next_event_to_expire_txn(txn):
+            txn.execute(
+                """
+                SELECT event_id, expiry_ts FROM event_expiry
+                ORDER BY expiry_ts ASC LIMIT 1
+                """
+            )
+
+            return txn.fetchone()
+
+        return self.runInteraction(
+            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+        )
+
 
 AllNewEventsResult = namedtuple(
     "AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
new file mode 100644
index 0000000000..81a36a8b1d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+    event_id TEXT PRIMARY KEY,
+    expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);