summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py58
-rw-r--r--synapse/storage/events.py12
-rw-r--r--synapse/storage/registration.py16
-rw-r--r--synapse/storage/relations.py71
-rw-r--r--synapse/storage/stream.py21
5 files changed, 155 insertions, 23 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 983ce026e1..fa6839ceca 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -227,6 +229,8 @@ class SQLBaseStore(object):
         # A set of tables that are not safe to use native upserts in.
         self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
 
+        self._account_validity = self.hs.config.account_validity
+
         # We add the user_directory_search table to the blacklist on SQLite
         # because the existing search table does not have an index, making it
         # unsafe to use native upserts.
@@ -243,6 +247,14 @@ class SQLBaseStore(object):
                 self._check_safe_to_upsert,
             )
 
+        if self._account_validity.enabled:
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "account_validity_set_expiration_dates",
+                self._set_expiration_date_when_missing,
+            )
+
     @defer.inlineCallbacks
     def _check_safe_to_upsert(self):
         """
@@ -275,6 +287,52 @@ class SQLBaseStore(object):
                 self._check_safe_to_upsert,
             )
 
+    @defer.inlineCallbacks
+    def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL;"
+            )
+            txn.execute(sql, [])
+
+            res = self.cursor_to_dict(txn)
+            if res:
+                for user in res:
+                    self.set_expiration_date_for_user_txn(txn, user["name"])
+
+        yield self.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id):
+        """Sets an expiration date to the account with the given user ID.
+
+        Args:
+             user_id (str): User ID to set an expiration date for.
+        """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
+        self._simple_insert_txn(
+            txn,
+            "account_validity",
+            values={
+                "user_id": user_id,
+                "expiration_ts_ms": expiration_ts,
+                "email_sent": False,
+            },
+        )
+
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index b025ebc926..2ffc27ff41 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -575,10 +575,11 @@ class EventsStore(
 
         def _get_events(txn, batch):
             sql = """
-            SELECT prev_event_id
+            SELECT prev_event_id, internal_metadata
             FROM event_edges
                 INNER JOIN events USING (event_id)
                 LEFT JOIN rejections USING (event_id)
+                LEFT JOIN event_json USING (event_id)
             WHERE
                 prev_event_id IN (%s)
                 AND NOT events.outlier
@@ -588,7 +589,11 @@ class EventsStore(
             )
 
             txn.execute(sql, batch)
-            results.extend(r[0] for r in txn)
+            results.extend(
+                r[0]
+                for r in txn
+                if not json.loads(r[1]).get("soft_failed")
+            )
 
         for chunk in batch_iter(event_ids, 100):
             yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk)
@@ -1325,6 +1330,9 @@ class EventsStore(
                     txn, event.room_id, event.redacts
                 )
 
+                # Remove from relations table.
+                self._handle_redaction(txn, event.redacts)
+
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
         self._handle_mult_prev_events(
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 03a06a83d6..4cf159ba81 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -725,17 +727,7 @@ class RegistrationStore(
             raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
 
         if self._account_validity.enabled:
-            now_ms = self.clock.time_msec()
-            expiration_ts = now_ms + self._account_validity.period
-            self._simple_insert_txn(
-                txn,
-                "account_validity",
-                values={
-                    "user_id": user_id,
-                    "expiration_ts_ms": expiration_ts,
-                    "email_sent": False,
-                }
-            )
+            self.set_expiration_date_for_user_txn(txn, user_id)
 
         if token:
             # it's possible for this to get a conflict, but only for a single user
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 42f587a7d8..4c83800cca 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -350,9 +350,7 @@ class RelationsWorkerStore(SQLBaseStore):
         """
 
         def _get_applicable_edit_txn(txn):
-            txn.execute(
-                sql, (event_id, RelationTypes.REPLACES,)
-            )
+            txn.execute(sql, (event_id, RelationTypes.REPLACE))
             row = txn.fetchone()
             if row:
                 return row[0]
@@ -367,6 +365,50 @@ class RelationsWorkerStore(SQLBaseStore):
         edit_event = yield self.get_event(edit_id, allow_none=True)
         defer.returnValue(edit_event)
 
+    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+        """Check if a user has already annotated an event with the same key
+        (e.g. already liked an event).
+
+        Args:
+            parent_id (str): The event being annotated
+            event_type (str): The event type of the annotation
+            aggregation_key (str): The aggregation key of the annotation
+            sender (str): The sender of the annotation
+
+        Returns:
+            Deferred[bool]
+        """
+
+        sql = """
+            SELECT 1 FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                relates_to_id = ?
+                AND relation_type = ?
+                AND type = ?
+                AND sender = ?
+                AND aggregation_key = ?
+            LIMIT 1;
+        """
+
+        def _get_if_user_has_annotated_event(txn):
+            txn.execute(
+                sql,
+                (
+                    parent_id,
+                    RelationTypes.ANNOTATION,
+                    event_type,
+                    sender,
+                    aggregation_key,
+                ),
+            )
+
+            return bool(txn.fetchone())
+
+        return self.runInteraction(
+            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
+        )
+
 
 class RelationsStore(RelationsWorkerStore):
     def _handle_event_relations(self, txn, event):
@@ -384,8 +426,8 @@ class RelationsStore(RelationsWorkerStore):
         rel_type = relation.get("rel_type")
         if rel_type not in (
             RelationTypes.ANNOTATION,
-            RelationTypes.REFERENCES,
-            RelationTypes.REPLACES,
+            RelationTypes.REFERENCE,
+            RelationTypes.REPLACE,
         ):
             # Unknown relation type
             return
@@ -413,5 +455,22 @@ class RelationsStore(RelationsWorkerStore):
             self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
         )
 
-        if rel_type == RelationTypes.REPLACES:
+        if rel_type == RelationTypes.REPLACE:
             txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+    def _handle_redaction(self, txn, redacted_event_id):
+        """Handles receiving a redaction and checking whether we need to remove
+        any redacted relations from the database.
+
+        Args:
+            txn
+            redacted_event_id (str): The event that was redacted.
+        """
+
+        self._simple_delete_txn(
+            txn,
+            table="event_relations",
+            keyvalues={
+                "event_id": redacted_event_id,
+            }
+        )
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 163363c0c2..529ad4ea79 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -77,14 +77,27 @@ def generate_pagination_where_clause(
 
     would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
 
+    Note that tokens are considered to be after the row they are in, e.g. if
+    a row A has a token T, then we consider A to be before T. This convention
+    is important when figuring out inequalities for the generated SQL, and
+    produces the following result:
+        - If paginating forwards then we exclude any rows matching the from
+          token, but include those that match the to token.
+        - If paginating backwards then we include any rows matching the from
+          token, but include those that match the to token.
+
     Args:
         direction (str): Whether we're paginating backwards("b") or
             forwards ("f").
         column_names (tuple[str, str]): The column names to bound. Must *not*
             be user defined as these get inserted directly into the SQL
             statement without escapes.
-        from_token (tuple[int, int]|None)
-        to_token (tuple[int, int]|None)
+        from_token (tuple[int, int]|None): The start point for the pagination.
+            This is an exclusive minimum bound if direction is "f", and an
+            inclusive maximum bound if direction is "b".
+        to_token (tuple[int, int]|None): The endpoint point for the pagination.
+            This is an inclusive maximum bound if direction is "f", and an
+            exclusive minimum bound if direction is "b".
         engine: The database engine to generate the clauses for
 
     Returns:
@@ -131,7 +144,9 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
         names (tuple[str, str]): The column names. Must *not* be user defined
             as these get inserted directly into the SQL statement without
             escapes.
-        values (tuple[int, int]): The values to bound the columns by.
+        values (tuple[int|None, int]): The values to bound the columns by. If
+            the first value is None then only creates a bound on the second
+            column.
         engine: The database engine to generate the SQL for
 
     Returns: