summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-10-20 16:45:58 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-10-20 16:45:58 +0100
commitb7bb088b8456b0f7e1007a3720cdff6611260eb3 (patch)
treec9e9b020da28fe3b131eac5ce51d6734700b091b /synapse/storage/databases
parentMerge commit 'b79d69796' into anoa/dinsic_release_1_21_x (diff)
parentFix rate limiting unit tests. (#8167) (diff)
downloadsynapse-b7bb088b8456b0f7e1007a3720cdff6611260eb3.tar.xz
Merge commit '56efa9ec7' into anoa/dinsic_release_1_21_x
* commit '56efa9ec7': (22 commits)
  Fix rate limiting unit tests. (#8167)
  Add functions to `MultiWriterIdGen` used by events stream (#8164)
  Do not allow send_nonmember_event to be called with shadow-banned users. (#8158)
  Changelog fixes
  Make StreamIdGen `get_next` and `get_next_mult` async  (#8161)
  Wording fixes to 'name' user admin api filter (#8163)
  Fix missing double-backtick in RST document
  Search in columns 'name' and 'displayname' in the admin users endpoint (#7377)
  Add type hints for state. (#8140)
  Stop shadow-banned users from sending non-member events. (#8142)
  Allow capping a room's retention policy (#8104)
  Add healthcheck for default localhost 8008 port on /health endpoint. (#8147)
  Fix flaky shadow-ban tests. (#8152)
  Don't fail /submit_token requests on incorrect session ID if request_token_inhibit_3pid_errors is turned on (#7991)
  Do not apply ratelimiting on joins to appservices (#8139)
  Micro-optimisations to get_auth_chain_ids (#8132)
  Allow denying or shadow banning registrations via the spam checker (#8034)
  Stop shadow-banned users from sending invites. (#8095)
  Be more tolerant of membership events in unknown rooms (#8110)
  Improve the error code when trying to register using a name reserved for guests. (#8135)
  ...
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/__init__.py31
-rw-r--r--synapse/storage/databases/main/account_data.py4
-rw-r--r--synapse/storage/databases/main/appservice.py5
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/devices.py8
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py43
-rw-r--r--synapse/storage/databases/main/event_federation.py40
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/events_worker.py31
-rw-r--r--synapse/storage/databases/main/group_server.py2
-rw-r--r--synapse/storage/databases/main/presence.py2
-rw-r--r--synapse/storage/databases/main/push_rule.py8
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/receipts.py3
-rw-r--r--synapse/storage/databases/main/registration.py25
-rw-r--r--synapse/storage/databases/main/room.py28
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql25
-rw-r--r--synapse/storage/databases/main/tags.py11
-rw-r--r--synapse/storage/databases/main/ui_auth.py61
19 files changed, 218 insertions, 121 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py

index 17fa470919..0934ae276c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -498,7 +498,7 @@ class DataStore( ) def get_users_paginate( - self, start, limit, name=None, guests=True, deactivated=False + self, start, limit, user_id=None, name=None, guests=True, deactivated=False ): """Function to retrieve a paginated list of users from users list. This will return a json list of users and the @@ -507,7 +507,8 @@ class DataStore( Args: start (int): start number to begin the query from limit (int): number of rows to retrieve - name (string): filter for user names + user_id (string): search for user_id. ignored if name is not None + name (string): search for local part of user_id or display name guests (bool): whether to in include guest users deactivated (bool): whether to include deactivated users Returns: @@ -516,11 +517,14 @@ class DataStore( def get_users_paginate_txn(txn): filters = [] - args = [] + args = [self.hs.config.server_name] if name: + filters.append("(name LIKE ? OR displayname LIKE ?)") + args.extend(["@%" + name + "%:%", "%" + name + "%"]) + elif user_id: filters.append("name LIKE ?") - args.append("%" + name + "%") + args.extend(["%" + user_id + "%"]) if not guests: filters.append("is_guest = 0") @@ -530,20 +534,23 @@ class DataStore( where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause) - txn.execute(sql, args) - count = txn.fetchone()[0] - - args = [self.hs.config.server_name] + args + [limit, start] - sql = """ - SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url + sql_base = """ FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? {} - ORDER BY u.name LIMIT ? OFFSET ? """.format( where_clause ) + sql = "SELECT COUNT(*) as total_users " + sql_base + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = ( + "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url " + + sql_base + + " ORDER BY u.name LIMIT ? OFFSET ?" + ) + args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) return users, count diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. @@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore): """ content_json = json_encoder.encode(content) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 02568a2391..77723f7d4d 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@ import logging import re -from canonicaljson import json - from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util import json_encoder logger = logging.getLogger(__name__) @@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore( new_txn_id = max(highest_txn_id, last_txn_id) + 1 # Insert new txn into txn table - event_ids = json.dumps([e.event_id for e in events]) + event_ids = json_encoder.encode([e.event_id for e in events]) txn.execute( "INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "VALUES(?,?,?)", diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id @@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) txn, stream_id, local_messages_by_user_then_device ) - with self._device_inbox_id_gen.get_next() as stream_id: + with await self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9a786e2929..03b45dbc4d 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore): THe new stream ID. """ - with self._device_list_id_gen.get_next() as stream_id: + with await self._device_list_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, @@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not device_ids: return - with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + with await self._device_list_id_gen.get_next_mult( + len(device_ids) + ) as stream_ids: await self.db_pool.runInteraction( "add_device_change_to_stream", self._add_device_change_to_stream_txn, @@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return stream_ids[-1] context = get_active_span_text_map() - with self._device_list_id_gen.get_next_mult( + with await self._device_list_id_gen.get_next_mult( len(hosts) * len(device_ids) ) as stream_ids: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..385868bdab 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id): """Set a user's cross-signing key. Args: @@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data + stream_id (int) """ # the 'key' dict will look something like: # { @@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) # and finally, store the key itself - with self._cross_signing_id_gen.get_next() as stream_id: - self.db_pool.simple_insert_txn( - txn, - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json_encoder.encode(key), - "stream_id": stream_id, - }, - ) + self.db_pool.simple_insert_txn( + txn, + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json_encoder.encode(key), + "stream_id": stream_id, + }, + ) self._invalidate_cache_and_stream( txn, self._get_bare_e2e_cross_signing_keys, (user_id,) ) - def set_e2e_cross_signing_key(self, user_id, key_type, key): + async def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. Args: @@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.db_pool.runInteraction( - "add_e2e_cross_signing_key", - self._set_e2e_cross_signing_key_txn, - user_id, - key_type, - key, - ) + + with await self._cross_signing_id_gen.get_next() as stream_id: + return await self.db_pool.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + stream_id, + ) def store_e2e_cross_signing_signatures(self, user_id, signatures): """Stores cross-signing signatures. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4826be630c..e6a97b018c 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Dict, Iterable, List, Optional, Set, Tuple +from typing import Dict, Iterable, List, Set, Tuple from synapse.api.errors import StoreError +from synapse.events import EventBase from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.types import Collection from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter @@ -30,12 +32,14 @@ logger = logging.getLogger(__name__) class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - async def get_auth_chain(self, event_ids, include_given=False): + async def get_auth_chain( + self, event_ids: Collection[str], include_given: bool = False + ) -> List[EventBase]: """Get auth events for given event_ids. The events *must* be state events. Args: - event_ids (list): state events - include_given (bool): include the given events in result + event_ids: state events + include_given: include the given events in result Returns: list of events @@ -45,43 +49,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) return await self.get_events_as_list(event_ids) - def get_auth_chain_ids( - self, - event_ids: List[str], - include_given: bool = False, - ignore_events: Optional[Set[str]] = None, - ): + async def get_auth_chain_ids( + self, event_ids: Collection[str], include_given: bool = False, + ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. Args: event_ids: state events include_given: include the given events in result - ignore_events: Set of events to exclude from the returned auth - chain. This is useful if the caller will just discard the - given events anyway, and saves us from figuring out their auth - chains if not required. Returns: list of event_ids """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given, - ignore_events, ) - def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events): - if ignore_events is None: - ignore_events = set() - + def _get_auth_chain_ids_txn( + self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool + ) -> List[str]: if include_given: results = set(event_ids) else: results = set() - base_sql = "SELECT auth_id FROM event_auth WHERE " + base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE " front = set(event_ids) while front: @@ -93,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(base_sql + clause, args) new_front.update(r[0] for r in txn) - new_front -= ignore_events new_front -= results front = new_front diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b90e6de2d5..6313b41eef 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -153,11 +153,11 @@ class PersistEventsStore: # Note: Multiple instances of this function cannot be in flight at # the same time for the same room. if backfilled: - stream_ordering_manager = self._backfill_id_gen.get_next_mult( + stream_ordering_manager = await self._backfill_id_gen.get_next_mult( len(events_and_contexts) ) else: - stream_ordering_manager = self._stream_id_gen.get_next_mult( + stream_ordering_manager = await self._stream_id_gen.get_next_mult( len(events_and_contexts) ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4a3333c0db..e1241a724b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -620,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore): room_version_id = row["room_version_id"] if not room_version_id: - # this should only happen for out-of-band membership events - if not internal_metadata.get("out_of_band_membership"): - logger.warning( - "Room %s for event %s is unknown", d["room_id"], event_id + # this should only happen for out-of-band membership events which + # arrived before #6983 landed. For all other events, we should have + # an entry in the 'rooms' table. + # + # However, the 'out_of_band_membership' flag is unreliable for older + # invites, so just accept it for all membership events. + # + if d["type"] != EventTypes.Member: + raise Exception( + "Room %s for event %s is unknown" % (d["room_id"], event_id) ) - continue - # take a wild stab at the room version based on the event format + # so, assuming this is an out-of-band-invite that arrived before #6983 + # landed, we know that the room version must be v5 or earlier (because + # v6 hadn't been invented at that point, so invites from such rooms + # would have been rejected.) + # + # The main reason we need to know the room version here (other than + # choosing the right python Event class) is in case the event later has + # to be redacted - and all the room versions up to v5 used the same + # redaction algorithm. + # + # So, the following approximations should be adequate. + if format_version == EventFormatVersions.V1: + # if it's event format v1 then it must be room v1 or v2 room_version = RoomVersions.V1 elif format_version == EventFormatVersions.V2: + # if it's event format v2 then it must be room v3 room_version = RoomVersions.V3 else: + # if it's event format v3 then it must be room v4 or v5 room_version = RoomVersions.V5 else: room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 0e3b8739c6..a488e0924b 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py
@@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore): return next_id - with self._group_updates_id_gen.get_next() as next_id: + with await self._group_updates_id_gen.get_next() as next_id: res = await self.db_pool.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4e3ec02d14..c9f655dfb7 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter class PresenceStore(SQLBaseStore): async def update_presence(self, presence_states): - stream_ordering_manager = self._presence_id_gen.get_next_mult( + stream_ordering_manager = await self._presence_id_gen.get_next_mult( len(presence_states) ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a585e54812..2fb5b02d7d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore): ) -> None: conditions_json = json_encoder.encode(conditions) actions_json = json_encoder.encode(actions) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() if before or after: @@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore): txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE" ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore): ) async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None: - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( @@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore): data={"actions": actions_json}, ) - with self._push_rules_stream_id_gen.get_next() as stream_id: + with await self._push_rules_stream_id_gen.get_next() as stream_id: event_stream_ordering = self._stream_id_gen.get_current_token() await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 1126fd0751..c388468273 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering, profile_tag="", ) -> None: - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on # (app_id, pushkey, user_name) so simple_upsert will retry await self.db_pool.simple_upsert( @@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore): }, ) - with self._pushers_id_gen.get_next() as stream_id: + with await self._pushers_id_gen.get_next() as stream_id: await self.db_pool.runInteraction( "delete_pusher", delete_pusher_txn, stream_id ) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 19ad1c056f..6821476ee0 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "insert_receipt_conv", graph_to_linear ) - stream_id_manager = self._receipts_id_gen.get_next() - with stream_id_manager as stream_id: + with await self._receipts_id_gen.get_next() as stream_id: event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 5986d32b18..336b578e23 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -968,6 +968,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors if self._account_validity.enabled: self._clock.call_later( @@ -1381,15 +1382,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) if not row: - raise ThreepidValidationError(400, "Unknown session_id") + if self._ignore_unknown_session_error: + # If we need to inhibit the error caused by an incorrect session ID, + # use None as placeholder values for the client secret and the + # validation timestamp. + # It shouldn't be an issue because they're both only checked after + # the token check, which should fail. And if it doesn't for some + # reason, the next check is on the client secret, which is NOT NULL, + # so we don't have to worry about the client secret matching by + # accident. + row = {"client_secret": None, "validated_at": None} + else: + raise ThreepidValidationError(400, "Unknown session_id") + retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] - if retrieved_client_secret != client_secret: - raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" - ) - row = self.db_pool.simple_select_one_txn( txn, table="threepid_validation_token", @@ -1405,6 +1413,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): expires = row["expires"] next_link = row["next_link"] + if retrieved_client_secret != client_secret: + raise ThreepidValidationError( + 400, "This client_secret does not match the provided session_id" + ) + # If the session is already validated, no need to revalidate if validated_at: return next_link diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0142a856d5..99a8a9fab0 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -21,10 +21,6 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from canonicaljson import json - -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions @@ -32,6 +28,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore from synapse.types import ThirdPartyInstanceID +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -342,23 +339,22 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) - @defer.inlineCallbacks - def is_room_published(self, room_id): + async def is_room_published(self, room_id: str) -> bool: """Check whether a room has been published in the local public room directory. Args: - room_id (str) + room_id Returns: - bool: Whether the room is currently published in the room directory + Whether the room is currently published in the room directory """ # Get room information - room_info = yield self.get_room(room_id) + room_info = await self.get_room(room_id) if not room_info: - defer.returnValue(False) + return False # Check the is_public value - defer.returnValue(room_info.get("is_public", False)) + return room_info.get("is_public", False) async def get_rooms_paginate( self, @@ -572,7 +568,7 @@ class RoomWorkerStore(SQLBaseStore): # maximum, in order not to filter out events we should filter out when sending to # the client. if not self.config.retention_enabled: - defer.returnValue({"min_lifetime": None, "max_lifetime": None}) + return {"min_lifetime": None, "max_lifetime": None} def get_retention_policy_for_room_txn(txn): txn.execute( @@ -1155,7 +1151,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "store_room_txn", store_room_txn, next_id ) @@ -1222,7 +1218,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) @@ -1302,7 +1298,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): }, ) - with self._public_room_id_gen.get_next() as next_id: + with await self._public_room_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, @@ -1335,7 +1331,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "event_id": event_id, "user_id": user_id, "reason": reason, - "content": json.dumps(content), + "content": json_encoder.encode(content), }, desc="add_event_report", ) diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644
index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..0c34bbf21a 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@ import logging from typing import Dict, List, Tuple -from canonicaljson import json - from synapse.storage._base import db_to_json from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (user_id, room_id)) tags = [] for tag, content in txn: - tags.append(json.dumps(tag) + ":" + content) + tags.append(json_encoder.encode(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" results.append((stream_id, (user_id, room_id, tag_json))) @@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ - content_json = json.dumps(content) + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): self.db_pool.simple_upsert_txn( @@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore): ) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore): txn.execute(sql, (user_id, room_id, tag)) self._update_revision_txn(txn, user_id, room_id, next_id) - with self._account_data_id_gen.get_next() as next_id: + with await self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..9eef8e57c5 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr -from canonicaljson import json from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict -from synapse.util import stringutils as stringutils +from synapse.util import json_encoder, stringutils @attr.s @@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore): StoreError if a unique session ID cannot be generated. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) # autogen a session ID and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore): await self.db_pool.simple_upsert( table="ui_auth_sessions_credentials", keyvalues={"session_id": session_id, "stage_type": stage_type}, - values={"result": json.dumps(result)}, + values={"result": json_encoder.encode(result)}, desc="mark_ui_auth_stage_complete", ) except self.db_pool.engine.module.IntegrityError: @@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore): The dictionary from the client root level, not the 'auth' key. """ # The clientdict gets stored as JSON. - clientdict_json = json.dumps(clientdict) + clientdict_json = json_encoder.encode(clientdict) await self.db_pool.simple_update_one( table="ui_auth_sessions", @@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore): value, ) - def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + def _set_ui_auth_session_data_txn( + self, txn: LoggingTransaction, session_id: str, key: str, value: Any + ): # Get the current value. result = self.db_pool.simple_select_one_txn( txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, retcols=("serverdict",), - ) + ) # type: Dict[str, Any] # type: ignore # Update it and add it back to the database. serverdict = db_to_json(result["serverdict"]) @@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore): txn, table="ui_auth_sessions", keyvalues={"session_id": session_id}, - updatevalues={"serverdict": json.dumps(serverdict)}, + updatevalues={"serverdict": json_encoder.encode(serverdict)}, ) async def get_ui_auth_session_data( @@ -258,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class UIAuthStore(UIAuthWorkerStore): def delete_old_ui_auth_sessions(self, expiration_time: int): @@ -275,12 +305,23 @@ class UIAuthStore(UIAuthWorkerStore): expiration_time, ) - def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + def _delete_old_ui_auth_sessions_txn( + self, txn: LoggingTransaction, expiration_time: int + ): # Get the expired sessions. sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn,