diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/_base.py | 6 | ||||
-rw-r--r-- | synapse/storage/events_worker.py | 37 | ||||
-rw-r--r-- | synapse/storage/prepare_database.py | 2 | ||||
-rw-r--r-- | synapse/storage/registration.py | 290 | ||||
-rw-r--r-- | synapse/storage/schema/delta/55/track_threepid_validations.sql | 31 | ||||
-rw-r--r-- | synapse/storage/stats.py | 16 | ||||
-rw-r--r-- | synapse/storage/stream.py | 12 |
7 files changed, 388 insertions, 6 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 52891bb9eb..ae891aa332 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -588,6 +588,10 @@ class SQLBaseStore(object): Args: table : string giving the table name values : dict of new column names and values for them + or_ignore : bool stating whether an exception should be raised + when a conflicting row already exists. If True, False will be + returned by the function instead + desc : string giving a description of the transaction Returns: bool: Whether the row was inserted or not. Only useful when @@ -1228,8 +1232,8 @@ class SQLBaseStore(object): ) txn.execute(select_sql, list(keyvalues.values())) - row = txn.fetchone() + if not row: if allow_none: return None diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 1782428048..cc7df5cf14 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -78,6 +78,43 @@ class EventsWorkerStore(SQLBaseStore): desc="get_received_ts", ) + def get_received_ts_by_stream_pos(self, stream_ordering): + """Given a stream ordering get an approximate timestamp of when it + happened. + + This is done by simply taking the received ts of the first event that + has a stream ordering greater than or equal to the given stream pos. + If none exists returns the current time, on the assumption that it must + have happened recently. + + Args: + stream_ordering (int) + + Returns: + Deferred[int] + """ + + def _get_approximate_received_ts_txn(txn): + sql = """ + SELECT received_ts FROM events + WHERE stream_ordering >= ? + LIMIT 1 + """ + + txn.execute(sql, (stream_ordering,)) + row = txn.fetchone() + if row and row[0]: + ts = row[0] + else: + ts = self.clock.time_msec() + + return ts + + return self.runInteraction( + "get_approximate_received_ts", + _get_approximate_received_ts_txn, + ) + @defer.inlineCallbacks def get_event( self, diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b81c05369f..f2c1bed487 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 54 +SCHEMA_VERSION = 55 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 4cf159ba81..9b41cbd757 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -17,17 +17,20 @@ import re +from six import iterkeys from six.moves import range from twisted.internet import defer from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError +from synapse.api.errors import Codes, StoreError, ThreepidValidationError from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 + class RegistrationWorkerStore(SQLBaseStore): def __init__(self, db_conn, hs): @@ -422,7 +425,7 @@ class RegistrationWorkerStore(SQLBaseStore): defer.returnValue(None) @defer.inlineCallbacks - def get_user_id_by_threepid(self, medium, address): + def get_user_id_by_threepid(self, medium, address, require_verified=False): """Returns user id from threepid Args: @@ -595,6 +598,11 @@ class RegistrationStore( "user_threepids_grandfather", self._bg_user_threepids_grandfather, ) + # Create a background job for culling expired 3PID validity tokens + hs.get_clock().looping_call( + self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS, + ) + @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): """Adds an access token for the given user. @@ -963,7 +971,6 @@ class RegistrationStore( We do this by grandfathering in existing user threepids assuming that they used one of the server configured trusted identity servers. """ - id_servers = set(self.config.trusted_third_party_id_servers) def _bg_user_threepids_grandfather_txn(txn): @@ -984,3 +991,280 @@ class RegistrationStore( yield self._end_background_update("user_threepids_grandfather") defer.returnValue(1) + + def get_threepid_validation_session( + self, + medium, + client_secret, + address=None, + sid=None, + validated=None, + ): + """Gets a session_id and last_send_attempt (if available) for a + client_secret/medium/(address|session_id) combo + + Args: + medium (str|None): The medium of the 3PID + address (str|None): The address of the 3PID + sid (str|None): The ID of the validation session + client_secret (str|None): A unique string provided by the client to + help identify this validation attempt + validated (bool|None): Whether sessions should be filtered by + whether they have been validated already or not. None to + perform no filtering + + Returns: + deferred {str, int}|None: A dict containing the + latest session_id and send_attempt count for this 3PID. + Otherwise None if there hasn't been a previous attempt + """ + keyvalues = { + "medium": medium, + "client_secret": client_secret, + } + if address: + keyvalues["address"] = address + if sid: + keyvalues["session_id"] = sid + + assert(address or sid) + + def get_threepid_validation_session_txn(txn): + sql = """ + SELECT address, session_id, medium, client_secret, + last_send_attempt, validated_at + FROM threepid_validation_session WHERE %s + """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),) + + if validated is not None: + sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL") + + sql += " LIMIT 1" + + txn.execute(sql, list(keyvalues.values())) + rows = self.cursor_to_dict(txn) + if not rows: + return None + + return rows[0] + + return self.runInteraction( + "get_threepid_validation_session", + get_threepid_validation_session_txn, + ) + + def validate_threepid_session( + self, + session_id, + client_secret, + token, + current_ts, + ): + """Attempt to validate a threepid session using a token + + Args: + session_id (str): The id of a validation session + client_secret (str): A unique string provided by the client to + help identify this validation attempt + token (str): A validation token + current_ts (int): The current unix time in milliseconds. Used for + checking token expiry status + + Returns: + deferred str|None: A str representing a link to redirect the user + to if there is one. + """ + # Insert everything into a transaction in order to run atomically + def validate_threepid_session_txn(txn): + row = self._simple_select_one_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + retcols=["client_secret", "validated_at"], + allow_none=True, + ) + + if not row: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] + validated_at = row["validated_at"] + + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id", + ) + + row = self._simple_select_one_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id, "token": token}, + retcols=["expires", "next_link"], + allow_none=True, + ) + + if not row: + raise ThreepidValidationError( + 400, "Validation token not found or has expired", + ) + expires = row["expires"] + next_link = row["next_link"] + + # If the session is already validated, no need to revalidate + if validated_at: + return next_link + + if expires <= current_ts: + raise ThreepidValidationError( + 400, "This token has expired. Please request a new one", + ) + + # Looks good. Validate the session + self._simple_update_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + updatevalues={"validated_at": self.clock.time_msec()}, + ) + + return next_link + + # Return next_link if it exists + return self.runInteraction( + "validate_threepid_session_txn", + validate_threepid_session_txn, + ) + + def upsert_threepid_validation_session( + self, + medium, + address, + client_secret, + send_attempt, + session_id, + validated_at=None, + ): + """Upsert a threepid validation session + Args: + medium (str): The medium of the 3PID + address (str): The address of the 3PID + client_secret (str): A unique string provided by the client to + help identify this validation attempt + send_attempt (int): The latest send_attempt on this session + session_id (str): The id of this validation session + validated_at (int|None): The unix timestamp in milliseconds of + when the session was marked as valid + """ + insertion_values = { + "medium": medium, + "address": address, + "client_secret": client_secret, + } + + if validated_at: + insertion_values["validated_at"] = validated_at + + return self._simple_upsert( + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + values={"last_send_attempt": send_attempt}, + insertion_values=insertion_values, + desc="upsert_threepid_validation_session", + ) + + def start_or_continue_validation_session( + self, + medium, + address, + session_id, + client_secret, + send_attempt, + next_link, + token, + token_expires, + ): + """Creates a new threepid validation session if it does not already + exist and associates a new validation token with it + + Args: + medium (str): The medium of the 3PID + address (str): The address of the 3PID + session_id (str): The id of this validation session + client_secret (str): A unique string provided by the client to + help identify this validation attempt + send_attempt (int): The latest send_attempt on this session + next_link (str|None): The link to redirect the user to upon + successful validation + token (str): The validation token + token_expires (int): The timestamp for which after the token + will no longer be valid + """ + def start_or_continue_validation_session_txn(txn): + # Create or update a validation session + self._simple_upsert_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + values={"last_send_attempt": send_attempt}, + insertion_values={ + "medium": medium, + "address": address, + "client_secret": client_secret, + }, + ) + + # Create a new validation token with this session ID + self._simple_insert_txn( + txn, + table="threepid_validation_token", + values={ + "session_id": session_id, + "token": token, + "next_link": next_link, + "expires": token_expires, + }, + ) + + return self.runInteraction( + "start_or_continue_validation_session", + start_or_continue_validation_session_txn, + ) + + def cull_expired_threepid_validation_tokens(self): + """Remove threepid validation tokens with expiry dates that have passed""" + def cull_expired_threepid_validation_tokens_txn(txn, ts): + sql = """ + DELETE FROM threepid_validation_token WHERE + expires < ? + """ + return txn.execute(sql, (ts,)) + + return self.runInteraction( + "cull_expired_threepid_validation_tokens", + cull_expired_threepid_validation_tokens_txn, + self.clock.time_msec(), + ) + + def delete_threepid_session(self, session_id): + """Removes a threepid validation session from the database. This can + be done after validation has been performed and whatever action was + waiting on it has been carried out + + Args: + session_id (str): The ID of the session to delete + """ + def delete_threepid_session_txn(txn): + self._simple_delete_txn( + txn, + table="threepid_validation_token", + keyvalues={"session_id": session_id}, + ) + self._simple_delete_txn( + txn, + table="threepid_validation_session", + keyvalues={"session_id": session_id}, + ) + + return self.runInteraction( + "delete_threepid_session", + delete_threepid_session_txn, + ) diff --git a/synapse/storage/schema/delta/55/track_threepid_validations.sql b/synapse/storage/schema/delta/55/track_threepid_validations.sql new file mode 100644 index 0000000000..a8eced2e0a --- /dev/null +++ b/synapse/storage/schema/delta/55/track_threepid_validations.sql @@ -0,0 +1,31 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS threepid_validation_session ( + session_id TEXT PRIMARY KEY, + medium TEXT NOT NULL, + address TEXT NOT NULL, + client_secret TEXT NOT NULL, + last_send_attempt BIGINT NOT NULL, + validated_at BIGINT +); + +CREATE TABLE IF NOT EXISTS threepid_validation_token ( + token TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + next_link TEXT, + expires BIGINT NOT NULL +); + +CREATE INDEX threepid_validation_token_session_id ON threepid_validation_token(session_id); diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py index 1c0b183a56..ff266b09b0 100644 --- a/synapse/storage/stats.py +++ b/synapse/storage/stats.py @@ -328,6 +328,22 @@ class StatsStore(StateDeltasStore): room_id (str) fields (dict[str:Any]) """ + + # For whatever reason some of the fields may contain null bytes, which + # postgres isn't a fan of, so we replace those fields with null. + for col in ( + "join_rules", + "history_visibility", + "encryption", + "name", + "topic", + "avatar", + "canonical_alias" + ): + field = fields.get(col) + if field and "\0" in field: + fields[col] = None + return self._simple_upsert( table="room_state", keyvalues={"room_id": room_id}, diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 529ad4ea79..6f7f65d96b 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -592,8 +592,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) def get_max_topological_token(self, room_id, stream_key): + """Get the max topological token in a room before the given stream + ordering. + + Args: + room_id (str) + stream_key (int) + + Returns: + Deferred[int] + """ sql = ( - "SELECT max(topological_ordering) FROM events" + "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) return self._execute( |