diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/_base.py | 58 | ||||
-rw-r--r-- | synapse/storage/events.py | 12 | ||||
-rw-r--r-- | synapse/storage/registration.py | 16 | ||||
-rw-r--r-- | synapse/storage/relations.py | 71 | ||||
-rw-r--r-- | synapse/storage/stream.py | 21 |
5 files changed, 155 insertions, 23 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 983ce026e1..fa6839ceca 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -227,6 +229,8 @@ class SQLBaseStore(object): # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + self._account_validity = self.hs.config.account_validity + # We add the user_directory_search table to the blacklist on SQLite # because the existing search table does not have an index, making it # unsafe to use native upserts. @@ -243,6 +247,14 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ @@ -275,6 +287,52 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL;" + ) + txn.execute(sql, []) + + res = self.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn(txn, user["name"]) + + yield self.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + self._simple_insert_txn( + txn, + "account_validity", + values={ + "user_id": user_id, + "expiration_ts_ms": expiration_ts, + "email_sent": False, + }, + ) + def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() diff --git a/synapse/storage/events.py b/synapse/storage/events.py index b025ebc926..2ffc27ff41 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -575,10 +575,11 @@ class EventsStore( def _get_events(txn, batch): sql = """ - SELECT prev_event_id + SELECT prev_event_id, internal_metadata FROM event_edges INNER JOIN events USING (event_id) LEFT JOIN rejections USING (event_id) + LEFT JOIN event_json USING (event_id) WHERE prev_event_id IN (%s) AND NOT events.outlier @@ -588,7 +589,11 @@ class EventsStore( ) txn.execute(sql, batch) - results.extend(r[0] for r in txn) + results.extend( + r[0] + for r in txn + if not json.loads(r[1]).get("soft_failed") + ) for chunk in batch_iter(event_ids, 100): yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk) @@ -1325,6 +1330,9 @@ class EventsStore( txn, event.room_id, event.redacts ) + # Remove from relations table. + self._handle_redaction(txn, event.redacts) + # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 03a06a83d6..4cf159ba81 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- -# Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -725,17 +727,7 @@ class RegistrationStore( raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) if self._account_validity.enabled: - now_ms = self.clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - self._simple_insert_txn( - txn, - "account_validity", - values={ - "user_id": user_id, - "expiration_ts_ms": expiration_ts, - "email_sent": False, - } - ) + self.set_expiration_date_for_user_txn(txn, user_id) if token: # it's possible for this to get a conflict, but only for a single user diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 42f587a7d8..4c83800cca 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -350,9 +350,7 @@ class RelationsWorkerStore(SQLBaseStore): """ def _get_applicable_edit_txn(txn): - txn.execute( - sql, (event_id, RelationTypes.REPLACES,) - ) + txn.execute(sql, (event_id, RelationTypes.REPLACE)) row = txn.fetchone() if row: return row[0] @@ -367,6 +365,50 @@ class RelationsWorkerStore(SQLBaseStore): edit_event = yield self.get_event(edit_id, allow_none=True) defer.returnValue(edit_event) + def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + """Check if a user has already annotated an event with the same key + (e.g. already liked an event). + + Args: + parent_id (str): The event being annotated + event_type (str): The event type of the annotation + aggregation_key (str): The aggregation key of the annotation + sender (str): The sender of the annotation + + Returns: + Deferred[bool] + """ + + sql = """ + SELECT 1 FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + AND type = ? + AND sender = ? + AND aggregation_key = ? + LIMIT 1; + """ + + def _get_if_user_has_annotated_event(txn): + txn.execute( + sql, + ( + parent_id, + RelationTypes.ANNOTATION, + event_type, + sender, + aggregation_key, + ), + ) + + return bool(txn.fetchone()) + + return self.runInteraction( + "get_if_user_has_annotated_event", _get_if_user_has_annotated_event + ) + class RelationsStore(RelationsWorkerStore): def _handle_event_relations(self, txn, event): @@ -384,8 +426,8 @@ class RelationsStore(RelationsWorkerStore): rel_type = relation.get("rel_type") if rel_type not in ( RelationTypes.ANNOTATION, - RelationTypes.REFERENCES, - RelationTypes.REPLACES, + RelationTypes.REFERENCE, + RelationTypes.REPLACE, ): # Unknown relation type return @@ -413,5 +455,22 @@ class RelationsStore(RelationsWorkerStore): self.get_aggregation_groups_for_event.invalidate_many, (parent_id,) ) - if rel_type == RelationTypes.REPLACES: + if rel_type == RelationTypes.REPLACE: txn.call_after(self.get_applicable_edit.invalidate, (parent_id,)) + + def _handle_redaction(self, txn, redacted_event_id): + """Handles receiving a redaction and checking whether we need to remove + any redacted relations from the database. + + Args: + txn + redacted_event_id (str): The event that was redacted. + """ + + self._simple_delete_txn( + txn, + table="event_relations", + keyvalues={ + "event_id": redacted_event_id, + } + ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 163363c0c2..529ad4ea79 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -77,14 +77,27 @@ def generate_pagination_where_clause( would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3). + Note that tokens are considered to be after the row they are in, e.g. if + a row A has a token T, then we consider A to be before T. This convention + is important when figuring out inequalities for the generated SQL, and + produces the following result: + - If paginating forwards then we exclude any rows matching the from + token, but include those that match the to token. + - If paginating backwards then we include any rows matching the from + token, but include those that match the to token. + Args: direction (str): Whether we're paginating backwards("b") or forwards ("f"). column_names (tuple[str, str]): The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - from_token (tuple[int, int]|None) - to_token (tuple[int, int]|None) + from_token (tuple[int, int]|None): The start point for the pagination. + This is an exclusive minimum bound if direction is "f", and an + inclusive maximum bound if direction is "b". + to_token (tuple[int, int]|None): The endpoint point for the pagination. + This is an inclusive maximum bound if direction is "f", and an + exclusive minimum bound if direction is "b". engine: The database engine to generate the clauses for Returns: @@ -131,7 +144,9 @@ def _make_generic_sql_bound(bound, column_names, values, engine): names (tuple[str, str]): The column names. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. - values (tuple[int, int]): The values to bound the columns by. + values (tuple[int|None, int]): The values to bound the columns by. If + the first value is None then only creates a bound on the second + column. engine: The database engine to generate the SQL for Returns: |