From dc46f12725001dde99c536a9189045709cf7e06c Mon Sep 17 00:00:00 2001 From: Dagfinn Ilmari Mannsåker Date: Tue, 3 Aug 2021 14:35:49 +0100 Subject: Include room ID in ignored EDU log messages (#10507) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dagfinn Ilmari Mannsåker --- synapse/handlers/receipts.py | 3 ++- synapse/handlers/typing.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index b9085bbccb..5fd4525700 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -70,7 +70,8 @@ class ReceiptsHandler(BaseHandler): ) if not is_in_room: logger.info( - "Ignoring receipt from %s as we're not in the room", + "Ignoring receipt for room %r from server %s as we're not in the room", + room_id, origin, ) continue diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0cb651a400..a97c448595 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -335,7 +335,8 @@ class TypingWriterHandler(FollowerTypingHandler): ) if not is_in_room: logger.info( - "Ignoring typing update from %s as we're not in the room", + "Ignoring typing update for room %r from server %s as we're not in the room", + room_id, origin, ) return -- cgit 1.5.1 From 4b10880da363efed5d066191190237f1c64fddfd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 3 Aug 2021 14:45:04 +0100 Subject: Make sync response cache time configurable. (#10513) --- changelog.d/10513.feature | 1 + docs/sample_config.yaml | 9 +++++++++ synapse/config/cache.py | 13 +++++++++++++ synapse/handlers/sync.py | 14 +++++++++++--- 4 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 changelog.d/10513.feature (limited to 'synapse') diff --git a/changelog.d/10513.feature b/changelog.d/10513.feature new file mode 100644 index 0000000000..153b2df7b2 --- /dev/null +++ b/changelog.d/10513.feature @@ -0,0 +1 @@ +Add a configuration setting for the time a `/sync` response is cached for. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1a217f35db..a2efc14100 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -711,6 +711,15 @@ caches: # #expiry_time: 30m + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. + # + #sync_response_cache_duration: 2m + ## Database ## diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 8d5f38b5d9..d119427ad8 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -151,6 +151,15 @@ class CacheConfig(Config): # entries are never evicted based on time. # #expiry_time: 30m + + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. + # + #sync_response_cache_duration: 2m """ def read_config(self, config, **kwargs): @@ -212,6 +221,10 @@ class CacheConfig(Config): else: self.expiry_time_msec = None + self.sync_response_cache_duration = self.parse_duration( + cache_config.get("sync_response_cache_duration", 0) + ) + # Resize all caches (if necessary) with the new factors we've loaded self.resize_all_caches() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f30bfcc93c..590642f510 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -269,14 +269,22 @@ class SyncHandler: self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( - hs.get_clock(), "sync" - ) self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() self.state_store = self.storage.state + # TODO: flush cache entries on subsequent sync request. + # Once we get the next /sync request (ie, one with the same access token + # that sets 'since' to 'next_batch'), we know that device won't need a + # cached result any more, and we could flush the entry from the cache to save + # memory. + self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( + hs.get_clock(), + "sync", + timeout_ms=hs.config.caches.sync_response_cache_duration, + ) + # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) self.lazy_loaded_members_cache: ExpiringCache[ Tuple[str, Optional[str]], LruCache[str, str] -- cgit 1.5.1 From 72935b7c5053af122c8cb5767a3e85a3a0f3a20c Mon Sep 17 00:00:00 2001 From: Kento Okamoto Date: Tue, 3 Aug 2021 11:13:34 -0700 Subject: Add warnings to ip_range_blacklist usage with proxies (#10129) Per issue #9812 using `url_preview_ip_range_blacklist` with a proxy via `HTTPS_PROXY` or `HTTP_PROXY` environment variables has some inconsistent bahavior than mentioned. This PR changes the following: - Changes the Sample Config file to include a note mentioning that `url_preview_ip_range_blacklist` and `ip_range_blacklist` is ignored when using a proxy - Changes some logic in synapse/config/repository.py to send a warning when both `*ip_range_blacklist` configs and a proxy environment variable are set and but no longer throws an error. Signed-off-by: Kento Okamoto --- changelog.d/10129.bugfix | 1 + docs/sample_config.yaml | 4 ++++ synapse/config/repository.py | 24 +++++++++++++++++++----- synapse/config/server.py | 2 ++ 4 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10129.bugfix (limited to 'synapse') diff --git a/changelog.d/10129.bugfix b/changelog.d/10129.bugfix new file mode 100644 index 0000000000..292676ec8d --- /dev/null +++ b/changelog.d/10129.bugfix @@ -0,0 +1 @@ +Add some clarification to the sample config file. Contributed by @Kentokamoto. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a2efc14100..16843dd8c9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -210,6 +210,8 @@ presence: # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # +# Note: The value is ignored when an HTTP proxy is in use +# #ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' @@ -972,6 +974,8 @@ media_store_path: "DATADIR/media_store" # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # +# Note: The value is ignored when an HTTP proxy is in use +# #url_preview_ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 0dfb3a227a..7481f3bf5f 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from collections import namedtuple from typing import Dict, List +from urllib.request import getproxies_environment # type: ignore from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements @@ -22,6 +24,8 @@ from synapse.util.module_loader import load_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + DEFAULT_THUMBNAIL_SIZES = [ {"width": 32, "height": 32, "method": "crop"}, {"width": 96, "height": 96, "method": "crop"}, @@ -36,6 +40,9 @@ THUMBNAIL_SIZE_YAML = """\ # method: %(method)s """ +HTTP_PROXY_SET_WARNING = """\ +The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" + ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) @@ -180,12 +187,17 @@ class ContentRepositoryConfig(Config): e.message # noqa: B306, DependencyException.message is a property ) + proxy_env = getproxies_environment() if "url_preview_ip_range_blacklist" not in config: - raise ConfigError( - "For security, you must specify an explicit target IP address " - "blacklist in url_preview_ip_range_blacklist for url previewing " - "to work" - ) + if "http" not in proxy_env or "https" not in proxy_env: + raise ConfigError( + "For security, you must specify an explicit target IP address " + "blacklist in url_preview_ip_range_blacklist for url previewing " + "to work" + ) + else: + if "http" in proxy_env or "https" in proxy_env: + logger.warning("".join(HTTP_PROXY_SET_WARNING)) # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. @@ -292,6 +304,8 @@ class ContentRepositoryConfig(Config): # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # + # Note: The value is ignored when an HTTP proxy is in use + # #url_preview_ip_range_blacklist: %(ip_range_blacklist)s diff --git a/synapse/config/server.py b/synapse/config/server.py index b9e0c0b300..187b4301a0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -960,6 +960,8 @@ class ServerConfig(Config): # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # + # Note: The value is ignored when an HTTP proxy is in use + # #ip_range_blacklist: %(ip_range_blacklist)s -- cgit 1.5.1 From c2000ab35b76288a625f598d2382d4e3f29f65f6 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Wed, 4 Aug 2021 13:40:25 +0300 Subject: Add `get_userinfo_by_id` method to `ModuleApi` (#9581) Makes it easier to fetch user details in for example spam checker modules, without needing to use api._store or figure out database interactions. Signed-off-by: Jason Robinson --- changelog.d/9581.feature | 1 + synapse/module_api/__init__.py | 12 ++++++++++- synapse/storage/databases/main/registration.py | 30 +++++++++++++++++++++++++- synapse/types.py | 29 +++++++++++++++++++++++++ tests/module_api/test_api.py | 10 +++++++++ 5 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9581.feature (limited to 'synapse') diff --git a/changelog.d/9581.feature b/changelog.d/9581.feature new file mode 100644 index 0000000000..fa1949cd4b --- /dev/null +++ b/changelog.d/9581.feature @@ -0,0 +1 @@ +Add `get_userinfo_by_id` method to ModuleApi. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 473812b8e2..1cc13fc97b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester from synapse.util import Clock from synapse.util.caches.descriptors import cached @@ -174,6 +174,16 @@ class ModuleApi: """The application name configured in the homeserver's configuration.""" return self._hs.config.email.email_app_name + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get user info by user_id + + Args: + user_id: Fully qualified user id. + Returns: + UserInfo object if a user was found, otherwise None + """ + return await self._store.get_userinfo_by_id(user_id) + async def get_user_by_req( self, req: SynapseRequest, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6ad1a0cf7f..14670c2881 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. + """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. within the first N days of registration defined by `mau_trial_days` config diff --git a/synapse/types.py b/synapse/types.py index 429bb013d2..80fa903c4b 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info): # and return that one key for key_id, key_data in keys.items(): return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 81d9e2f484..0b817cc701 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_get_userinfo_by_id(self): + user_id = self.register_user("alice", "1234") + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, False) + + def test_get_userinfo_by_id__no_user_found(self): + found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) + self.assertIsNone(found_user) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent -- cgit 1.5.1 From 11540be55ed15da920fa6f3ea805315517c02c76 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Aug 2021 13:09:04 +0100 Subject: Fix `could not serialize access` errors for `claim_e2e_one_time_keys` (#10504) --- changelog.d/10504.misc | 1 + synapse/storage/databases/main/end_to_end_keys.py | 188 +++++++++++++++------- 2 files changed, 127 insertions(+), 62 deletions(-) create mode 100644 changelog.d/10504.misc (limited to 'synapse') diff --git a/changelog.d/10504.misc b/changelog.d/10504.misc new file mode 100644 index 0000000000..1479a5022d --- /dev/null +++ b/changelog.d/10504.misc @@ -0,0 +1 @@ +Reduce errors in PostgreSQL logs due to concurrent serialization errors. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1edc96042b..1f0a39eac4 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -755,81 +755,145 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): """ @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + def _claim_e2e_one_time_key_simple( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that don't support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + sql = """ + SELECT key_id, key_json FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + """ + + txn.execute(sql, (user_id, device_id, algorithm)) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + + self.db_pool.simple_delete_one_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, ) - fallback_sql = ( - "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - result = {} - delete = [] - used_fallbacks = [] - for user_id, device_id, algorithm in query_list: - user_result = result.setdefault(user_id, {}) - device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm)) - otk_row = txn.fetchone() - if otk_row is not None: - key_id, key_json = otk_row - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - else: - # no one-time key available, so see if there's a fallback - # key - txn.execute(fallback_sql, (user_id, device_id, algorithm)) - fallback_row = txn.fetchone() - if fallback_row is not None: - key_id, key_json, used = fallback_row - device_result[algorithm + ":" + key_id] = key_json - if not used: - used_fallbacks.append( - (user_id, device_id, algorithm, key_id) - ) - - # drop any one-time keys that were claimed - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND key_id = ?" + + return f"{algorithm}:{key_id}", key_json + + @trace + def _claim_e2e_one_time_key_returning( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + # We can use RETURNING to do the fetch and DELETE in once step. + sql = """ + DELETE FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + AND key_id IN ( + SELECT key_id FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + ) + RETURNING key_id, key_json + """ + + txn.execute( + sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) ) - for user_id, device_id, algorithm, key_id in delete: - log_kv( - { - "message": "Executing claim e2e_one_time_keys transaction on database." - } - ) - txn.execute(sql, (user_id, device_id, algorithm, key_id)) - log_kv({"message": "finished executing and invalidating cache"}) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + return f"{algorithm}:{key_id}", key_json + + results = {} + for user_id, device_id, algorithm in query_list: + if self.database_engine.supports_returning: + # If we support RETURNING clause we can use a single query that + # allows us to use autocommit mode. + _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning + db_autocommit = True + else: + _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple + db_autocommit = False + + row = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + _claim_e2e_one_time_key, + user_id, + device_id, + algorithm, + db_autocommit=db_autocommit, + ) + if row: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} ) - # mark fallback keys as used - for user_id, device_id, algorithm, key_id in used_fallbacks: - self.db_pool.simple_update_txn( - txn, - "e2e_fallback_keys_json", - { + device_results[row[0]] = row[1] + continue + + # No one-time key available, so see if there's a fallback + # key + row = await self.db_pool.simple_select_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + retcols=("key_id", "key_json", "used"), + desc="_get_fallback_key", + allow_none=True, + ) + if row is None: + continue + + key_id = row["key_id"] + key_json = row["key_json"] + used = row["used"] + + # Mark fallback key as used if not already. + if not used: + await self.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, "key_id": key_id, }, - {"used": True}, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", ) - self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) ) - return result + device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) + device_results[f"{algorithm}:{key_id}"] = key_json - return await self.db_pool.runInteraction( - "claim_e2e_one_time_keys", _claim_e2e_one_time_keys - ) + return results class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): -- cgit 1.5.1 From c37dad67ab04980ac934554399f52a27e54292ab Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Aug 2021 13:54:51 +0100 Subject: Improve event caching code (#10119) Ensure we only load an event from the DB once when the same event is requested multiple times at once. --- changelog.d/10119.misc | 1 + synapse/storage/databases/main/events_worker.py | 144 +++++++++++++++------ synapse/storage/databases/main/roommember.py | 6 +- tests/storage/databases/main/test_events_worker.py | 50 +++++++ 4 files changed, 158 insertions(+), 43 deletions(-) create mode 100644 changelog.d/10119.misc (limited to 'synapse') diff --git a/changelog.d/10119.misc b/changelog.d/10119.misc new file mode 100644 index 0000000000..f70dc6496f --- /dev/null +++ b/changelog.d/10119.misc @@ -0,0 +1 @@ +Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3c86adab56..375463e4e9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,7 +14,6 @@ import logging import threading -from collections import namedtuple from typing import ( Collection, Container, @@ -27,6 +26,7 @@ from typing import ( overload, ) +import attr from constantly import NamedConstant, Names from typing_extensions import Literal @@ -42,7 +42,11 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event -from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.logging.context import ( + PreserveLoggingContext, + current_context, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +@attr.s(slots=True, auto_attribs=True) +class _EventCacheEntry: + event: EventBase + redacted_event: Optional[EventBase] class EventRedactBehaviour(Names): @@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore): max_size=hs.config.caches.event_cache_size, ) + # Map from event ID to a deferred that will result in a map from event + # ID to cache entry. Note that the returned dict may not have the + # requested event in it if the event isn't in the DB. + self._current_event_fetches: Dict[ + str, ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = {} + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore): return events - async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore): Args: - event_ids (Iterable[str]): The event_ids of the events to fetch + event_ids: The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events. If False, + allow_rejected: Whether to include rejected events. If False, rejected events are omitted from the response. Returns: - Dict[str, _EventCacheEntry]: - map from event id to result + map from event id to result """ event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected + event_ids, ) - missing_events_ids = [e for e in event_ids if e not in event_entry_map] + missing_events_ids = {e for e in event_ids if e not in event_entry_map} + + # We now look up if we're already fetching some of the events in the DB, + # if so we wait for those lookups to finish instead of pulling the same + # events out of the DB multiple times. + already_fetching: Dict[str, defer.Deferred] = {} + + for event_id in missing_events_ids: + deferred = self._current_event_fetches.get(event_id) + if deferred is not None: + # We're already pulling the event out of the DB. Add the deferred + # to the collection of deferreds to wait on. + already_fetching[event_id] = deferred.observe() + + missing_events_ids.difference_update(already_fetching) if missing_events_ids: log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) + for event_id in missing_events_ids: + self._current_event_fetches[event_id] = fetching_deferred + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = await self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) - event_entry_map.update(missing_events) + event_entry_map.update(missing_events) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + if already_fetching: + # Wait for the other event requests to finish and add their results + # to ours. + results = await make_deferred_yieldable( + defer.gatherResults( + already_fetching.values(), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + + for result in results: + event_entry_map.update(result) + + if not allow_rejected: + event_entry_map = { + event_id: entry + for event_id, entry in event_entry_map.items() + if not entry.event.rejected_reason + } return event_entry_map def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches + def _get_events_from_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, _EventCacheEntry]: + """Fetch events from the caches. - Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics + May return rejected events. - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics """ event_map = {} @@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore): if not ret: continue - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None + event_map[event_id] = ret return event_map @@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - async def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db( + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. + May return rejected events. + Returned events will be added to the cache for future lookups. Unknown events are omitted from the response. Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. + event_ids: The event_ids of the events to fetch Returns: - Dict[str, _EventCacheEntry]: - map from event id to result. May return extra events which - weren't asked for. + map from event id to result. May return extra events which + weren't asked for. """ fetched_events = {} events_to_fetch = event_ids @@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason = row["rejected_reason"] - if not allow_rejected and rejected_reason: - continue - # If the event or metadata cannot be parsed, log the error and act # as if the event is unknown. try: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 68f1b40ea6..e8157ba3d4 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) + event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) missing_member_event_ids = [] for event_id in member_event_ids: ev_entry = event_map.get(event_id) - if ev_entry: + if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: users_in_room[ev_entry.event.state_key] = ProfileInfo( display_name=ev_entry.event.content.get("displayname", None), diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 932970fd9a..d05d367685 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -14,7 +14,10 @@ import json from synapse.logging.context import LoggingContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success(self.store.have_seen_events("room1", ["event10"])) self.assertEquals(res, {"event10"}) self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + +class EventCacheTestCase(unittest.HomeserverTestCase): + """Test that the various layers of event cache works.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + def test_simple(self): + """Test that we cache events that we pull from the DB.""" + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + def test_dedupe(self): + """Test that if we request the same event multiple times we only pull it + out once. + """ + + with LoggingContext("test") as ctx: + d = yieldable_gather_results( + self.store.get_event, [self.event_id, self.event_id] + ) + self.get_success(d) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) -- cgit 1.5.1 From 684d19a11c3b93c9dd5fb90f43d38aa7e8c6005f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 4 Aug 2021 12:07:57 -0500 Subject: Add support for MSC2716 marker events (#10498) * Make historical messages available to federated servers Part of MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 Follow-up to https://github.com/matrix-org/synapse/pull/9247 * Debug message not available on federation * Add base starting insertion point when no chunk ID is provided * Fix messages from multiple senders in historical chunk Follow-up to https://github.com/matrix-org/synapse/pull/9247 Part of MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 --- Previously, Synapse would throw a 403, `Cannot force another user to join.`, because we were trying to use `?user_id` from a single virtual user which did not match with messages from other users in the chunk. * Remove debug lines * Messing with selecting insertion event extremeties * Move db schema change to new version * Add more better comments * Make a fake requester with just what we need See https://github.com/matrix-org/synapse/pull/10276#discussion_r660999080 * Store insertion events in table * Make base insertion event float off on its own See https://github.com/matrix-org/synapse/pull/10250#issuecomment-875711889 Conflicts: synapse/rest/client/v1/room.py * Validate that the app service can actually control the given user See https://github.com/matrix-org/synapse/pull/10276#issuecomment-876316455 Conflicts: synapse/rest/client/v1/room.py * Add some better comments on what we're trying to check for * Continue debugging * Share validation logic * Add inserted historical messages to /backfill response * Remove debug sql queries * Some marker event implemntation trials * Clean up PR * Rename insertion_event_id to just event_id * Add some better sql comments * More accurate description * Add changelog * Make it clear what MSC the change is part of * Add more detail on which insertion event came through * Address review and improve sql queries * Only use event_id as unique constraint * Fix test case where insertion event is already in the normal DAG * Remove debug changes * Add support for MSC2716 marker events * Process markers when we receive it over federation * WIP: make hs2 backfill historical messages after marker event * hs2 to better ask for insertion event extremity But running into the `sqlite3.IntegrityError: NOT NULL constraint failed: event_to_state_groups.state_group` error * Add insertion_event_extremities table * Switch to chunk events so we can auth via power_levels Previously, we were using `content.chunk_id` to connect one chunk to another. But these events can be from any `sender` and we can't tell who should be able to send historical events. We know we only want the application service to do it but these events have the sender of a real historical message, not the application service user ID as the sender. Other federated homeservers also have no indicator which senders are an application service on the originating homeserver. So we want to auth all of the MSC2716 events via power_levels and have them be sent by the application service with proper PL levels in the room. * Switch to chunk events for federation * Add unstable room version to support new historical PL * Messy: Fix undefined state_group for federated historical events ``` 2021-07-13 02:27:57,810 - synapse.handlers.federation - 1248 - ERROR - GET-4 - Failed to backfill from hs1 because NOT NULL constraint failed: event_to_state_groups.state_group Traceback (most recent call last): File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 1216, in try_backfill await self.backfill( File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 1035, in backfill await self._auth_and_persist_event(dest, event, context, backfilled=True) File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 2222, in _auth_and_persist_event await self._run_push_actions_and_persist_event(event, context, backfilled) File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 2244, in _run_push_actions_and_persist_event await self.persist_events_and_notify( File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 3290, in persist_events_and_notify events, max_stream_token = await self.storage.persistence.persist_events( File "/usr/local/lib/python3.8/site-packages/synapse/logging/opentracing.py", line 774, in _trace_inner return await func(*args, **kwargs) File "/usr/local/lib/python3.8/site-packages/synapse/storage/persist_events.py", line 320, in persist_events ret_vals = await yieldable_gather_results(enqueue, partitioned.items()) File "/usr/local/lib/python3.8/site-packages/synapse/storage/persist_events.py", line 237, in handle_queue_loop ret = await self._per_item_callback( File "/usr/local/lib/python3.8/site-packages/synapse/storage/persist_events.py", line 577, in _persist_event_batch await self.persist_events_store._persist_events_and_state_updates( File "/usr/local/lib/python3.8/site-packages/synapse/storage/databases/main/events.py", line 176, in _persist_events_and_state_updates await self.db_pool.runInteraction( File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 681, in runInteraction result = await self.runWithConnection( File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 770, in runWithConnection return await make_deferred_yieldable( File "/usr/local/lib/python3.8/site-packages/twisted/python/threadpool.py", line 238, in inContext result = inContext.theWork() # type: ignore[attr-defined] File "/usr/local/lib/python3.8/site-packages/twisted/python/threadpool.py", line 254, in inContext.theWork = lambda: context.call( # type: ignore[attr-defined] File "/usr/local/lib/python3.8/site-packages/twisted/python/context.py", line 118, in callWithContext return self.currentContext().callWithContext(ctx, func, *args, **kw) File "/usr/local/lib/python3.8/site-packages/twisted/python/context.py", line 83, in callWithContext return func(*args, **kw) File "/usr/local/lib/python3.8/site-packages/twisted/enterprise/adbapi.py", line 293, in _runWithConnection compat.reraise(excValue, excTraceback) File "/usr/local/lib/python3.8/site-packages/twisted/python/deprecate.py", line 298, in deprecatedFunction return function(*args, **kwargs) File "/usr/local/lib/python3.8/site-packages/twisted/python/compat.py", line 403, in reraise raise exception.with_traceback(traceback) File "/usr/local/lib/python3.8/site-packages/twisted/enterprise/adbapi.py", line 284, in _runWithConnection result = func(conn, *args, **kw) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 765, in inner_func return func(db_conn, *args, **kwargs) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 549, in new_transaction r = func(cursor, *args, **kwargs) File "/usr/local/lib/python3.8/site-packages/synapse/logging/utils.py", line 69, in wrapped return f(*args, **kwargs) File "/usr/local/lib/python3.8/site-packages/synapse/storage/databases/main/events.py", line 385, in _persist_events_txn self._store_event_state_mappings_txn(txn, events_and_contexts) File "/usr/local/lib/python3.8/site-packages/synapse/storage/databases/main/events.py", line 2065, in _store_event_state_mappings_txn self.db_pool.simple_insert_many_txn( File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 923, in simple_insert_many_txn txn.execute_batch(sql, vals) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 280, in execute_batch self.executemany(sql, args) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 300, in executemany self._do_execute(self.txn.executemany, sql, *args) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 330, in _do_execute return func(sql, *args) sqlite3.IntegrityError: NOT NULL constraint failed: event_to_state_groups.state_group ``` * Revert "Messy: Fix undefined state_group for federated historical events" This reverts commit 187ab28611546321e02770944c86f30ee2bc742a. * Fix federated events being rejected for no state_groups Add fix from https://github.com/matrix-org/synapse/pull/10439 until it merges. * Adapting to experimental room version * Some log cleanup * Add better comments around extremity fetching code and why * Rename to be more accurate to what the function returns * Add changelog * Ignore rejected events * Use simplified upsert * Add Erik's explanation of extra event checks See https://github.com/matrix-org/synapse/pull/10498#discussion_r680880332 * Clarify that the depth is not directly correlated to the backwards extremity that we return See https://github.com/matrix-org/synapse/pull/10498#discussion_r681725404 * lock only matters for sqlite See https://github.com/matrix-org/synapse/pull/10498#discussion_r681728061 * Move new SQL changes to its own delta file * Clean up upsert docstring * Bump database schema version (62) --- changelog.d/10498.feature | 1 + scripts-dev/complement.sh | 2 +- synapse/handlers/federation.py | 119 +++++++++++++++++++-- synapse/storage/database.py | 14 +-- synapse/storage/databases/main/event_federation.py | 114 +++++++++++++++++--- synapse/storage/databases/main/events.py | 24 ++++- synapse/storage/schema/__init__.py | 2 +- .../delta/62/01insertion_event_extremities.sql | 24 +++++ 8 files changed, 265 insertions(+), 35 deletions(-) create mode 100644 changelog.d/10498.feature create mode 100644 synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql (limited to 'synapse') diff --git a/changelog.d/10498.feature b/changelog.d/10498.feature new file mode 100644 index 0000000000..5df896572d --- /dev/null +++ b/changelog.d/10498.feature @@ -0,0 +1 @@ +Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of MSC2716). diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index cba015d942..5d0ef8dd3a 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! -go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8197b60b76..8b602e3813 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -42,6 +42,7 @@ from twisted.internet import defer from synapse import event_auth from synapse.api.constants import ( + EventContentFields, EventTypes, Membership, RejectedReason, @@ -262,7 +263,12 @@ class FederationHandler(BaseHandler): state = None - # Get missing pdus if necessary. + # Check that the event passes auth based on the state at the event. This is + # done for events that are to be added to the timeline (non-outliers). + # + # Get missing pdus if necessary: + # - Fetching any missing prev events to fill in gaps in the graph + # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. min_depth = await self.get_min_depth_for_context(pdu.room_id) @@ -432,6 +438,13 @@ class FederationHandler(BaseHandler): affected=event_id, ) + # A second round of checks for all events. Check that the event passes auth + # based on `auth_events`, this allows us to assert that the event would + # have been allowed at some point. If an event passes this check its OK + # for it to be used as part of a returned `/state` request, as either + # a) we received the event as part of the original join and so trust it, or + # b) we'll do a state resolution with existing state before it becomes + # part of the "current state", which adds more protection. await self._process_received_pdu(origin, pdu, state=state) async def _get_missing_events_for_pdu( @@ -889,6 +902,79 @@ class FederationHandler(BaseHandler): "resync_device_due_to_pdu", self._resync_device, event.sender ) + await self._handle_marker_event(origin, event) + + async def _handle_marker_event(self, origin: str, marker_event: EventBase): + """Handles backfilling the insertion event when we receive a marker + event that points to one. + + Args: + origin: Origin of the event. Will be called to get the insertion event + marker_event: The event to process + """ + + if marker_event.type != EventTypes.MSC2716_MARKER: + # Not a marker event + return + + if marker_event.rejected_reason is not None: + # Rejected event + return + + # Skip processing a marker event if the room version doesn't + # support it. + room_version = await self.store.get_room_version(marker_event.room_id) + if not room_version.msc2716_historical: + return + + logger.debug("_handle_marker_event: received %s", marker_event) + + insertion_event_id = marker_event.content.get( + EventContentFields.MSC2716_MARKER_INSERTION + ) + + if insertion_event_id is None: + # Nothing to retrieve then (invalid marker) + return + + logger.debug( + "_handle_marker_event: backfilling insertion event %s", insertion_event_id + ) + + await self._get_events_and_persist( + origin, + marker_event.room_id, + [insertion_event_id], + ) + + insertion_event = await self.store.get_event( + insertion_event_id, allow_none=True + ) + if insertion_event is None: + logger.warning( + "_handle_marker_event: server %s didn't return insertion event %s for marker %s", + origin, + insertion_event_id, + marker_event.event_id, + ) + return + + logger.debug( + "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", + insertion_event, + marker_event, + ) + + await self.store.insert_insertion_extremity( + insertion_event_id, marker_event.room_id + ) + + logger.debug( + "_handle_marker_event: insertion extremity added for %s from marker event %s", + insertion_event, + marker_event, + ) + async def _resync_device(self, sender: str) -> None: """We have detected that the device list for the given user may be out of sync, so we try and resync them. @@ -1057,9 +1143,19 @@ class FederationHandler(BaseHandler): async def _maybe_backfill_inner( self, room_id: str, current_depth: int, limit: int ) -> bool: - extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) + oldest_events_with_depth = ( + await self.store.get_oldest_event_ids_with_depth_in_room(room_id) + ) + insertion_events_to_be_backfilled = ( + await self.store.get_insertion_event_backwards_extremities_in_room(room_id) + ) + logger.debug( + "_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s", + oldest_events_with_depth, + insertion_events_to_be_backfilled, + ) - if not extremities: + if not oldest_events_with_depth and not insertion_events_to_be_backfilled: logger.debug("Not backfilling as no extremeties found.") return False @@ -1089,10 +1185,12 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = await self.store.get_successor_events(list(extremities)) + forward_event_ids = await self.store.get_successor_events( + list(oldest_events_with_depth) + ) extremities_events = await self.store.get_events( - forward_events, + forward_event_ids, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, ) @@ -1106,10 +1204,19 @@ class FederationHandler(BaseHandler): redact=False, check_history_visibility_only=True, ) + logger.debug( + "_maybe_backfill_inner: filtered_extremities %s", filtered_extremities + ) - if not filtered_extremities: + if not filtered_extremities and not insertion_events_to_be_backfilled: return False + extremities = { + **oldest_events_with_depth, + # TODO: insertion_events_to_be_backfilled is currently skipping the filtered_extremities checks + **insertion_events_to_be_backfilled, + } + # Check if we reached a point where we should start backfilling. sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c8015a3848..95d2caff62 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -941,13 +941,13 @@ class DatabasePool: `lock` should generally be set to True (the default), but can be set to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. + 1. there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + 2. we somehow know that we are the only thread which will be updating + this table. + As an additional note, this parameter only matters for old SQLite versions + because we will use native upserts otherwise. Args: table: The table to upsert into diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 44018c1c31..bddf5ef192 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -671,27 +671,97 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - async def get_oldest_events_with_depth_in_room(self, room_id): + async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]: + """Gets the oldest events(backwards extremities) in the room along with the + aproximate depth. + + We use this function so that we can compare and see if someones current + depth at their current scrollback is within pagination range of the + event extremeties. If the current depth is close to the depth of given + oldest event, we can trigger a backfill. + + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): + # Assemble a dictionary with event_id -> depth for the oldest events + # we know of in the room. Backwards extremeties are the oldest + # events we know of in the room but we only know of them because + # some other event referenced them by prev_event and aren't peristed + # in our database yet (meaning we don't know their depth + # specifically). So we need to look for the aproximate depth from + # the events connected to the current backwards extremeties. + sql = """ + SELECT b.event_id, MAX(e.depth) FROM events as e + /** + * Get the edge connections from the event_edges table + * so we can see whether this event's prev_events points + * to a backward extremity in the next join. + */ + INNER JOIN event_edges as g + ON g.event_id = e.event_id + /** + * We find the "oldest" events in the room by looking for + * events connected to backwards extremeties (oldest events + * in the room that we know of so far). + */ + INNER JOIN event_backward_extremities as b + ON g.prev_event_id = b.event_id + WHERE b.room_id = ? AND g.is_state is ? + GROUP BY b.event_id + """ + + txn.execute(sql, (room_id, False)) + + return dict(txn) + return await self.db_pool.runInteraction( - "get_oldest_events_with_depth_in_room", - self.get_oldest_events_with_depth_in_room_txn, + "get_oldest_event_ids_with_depth_in_room", + get_oldest_event_ids_with_depth_in_room_txn, room_id, ) - def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): - sql = ( - "SELECT b.event_id, MAX(e.depth) FROM events as e" - " INNER JOIN event_edges as g" - " ON g.event_id = e.event_id" - " INNER JOIN event_backward_extremities as b" - " ON g.prev_event_id = b.event_id" - " WHERE b.room_id = ? AND g.is_state is ?" - " GROUP BY b.event_id" - ) + async def get_insertion_event_backwards_extremities_in_room( + self, room_id + ) -> Dict[str, int]: + """Get the insertion events we know about that we haven't backfilled yet. - txn.execute(sql, (room_id, False)) + We use this function so that we can compare and see if someones current + depth at their current scrollback is within pagination range of the + insertion event. If the current depth is close to the depth of given + insertion event, we can trigger a backfill. - return dict(txn) + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id): + sql = """ + SELECT b.event_id, MAX(e.depth) FROM insertion_events as i + /* We only want insertion events that are also marked as backwards extremities */ + INNER JOIN insertion_event_extremities as b USING (event_id) + /* Get the depth of the insertion event from the events table */ + INNER JOIN events AS e USING (event_id) + WHERE b.room_id = ? + GROUP BY b.event_id + """ + + txn.execute(sql, (room_id,)) + + return dict(txn) + + return await self.db_pool.runInteraction( + "get_insertion_event_backwards_extremities_in_room", + get_insertion_event_backwards_extremities_in_room_txn, + room_id, + ) async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs @@ -1041,7 +1111,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if row[1] not in event_results: queue.put((-row[0], row[1])) - # Navigate up the DAG by prev_event txn.execute(query, (event_id, False, limit - len(event_results))) prev_event_id_results = txn.fetchall() logger.debug( @@ -1136,6 +1205,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: + await self.db_pool.simple_upsert( + table="insertion_event_extremities", + keyvalues={"event_id": event_id}, + values={ + "event_id": event_id, + "room_id": room_id, + }, + insertion_values={}, + desc="insert_insertion_extremity", + lock=False, + ) + async def insert_received_event_to_staging( self, origin: str, event: EventBase ) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 86baf397fb..40b53274fb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1845,6 +1845,18 @@ class PersistEventsStore: }, ) + # When we receive an event with a `chunk_id` referencing the + # `next_chunk_id` of the insertion event, we can remove it from the + # `insertion_event_extremities` table. + sql = """ + DELETE FROM insertion_event_extremities WHERE event_id IN ( + SELECT event_id FROM insertion_events + WHERE next_chunk_id = ? + ) + """ + + txn.execute(sql, (chunk_id,)) + def _handle_redaction(self, txn, redacted_event_id): """Handles receiving a redaction and checking whether we need to remove any redacted relations from the database. @@ -2101,15 +2113,17 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ + # From the events passed in, add all of the prev events as backwards extremities. + # Ignore any events that are already backwards extrems or outliers. query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" " )" " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" " )" ) @@ -2123,6 +2137,8 @@ class PersistEventsStore: ], ) + # Delete all these events that we've already fetched and now know that their + # prev events are the new backwards extremeties. query = ( "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 36340a652a..fd4dd67d91 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 61 +SCHEMA_VERSION = 62 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the diff --git a/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql new file mode 100644 index 0000000000..b731ef284a --- /dev/null +++ b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql @@ -0,0 +1,24 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Add a table that keeps track of which "insertion" events need to be backfilled +CREATE TABLE IF NOT EXISTS insertion_event_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS insertion_event_extremities_event_id ON insertion_event_extremities(event_id); +CREATE INDEX IF NOT EXISTS insertion_event_extremities_room_id ON insertion_event_extremities(room_id); -- cgit 1.5.1 From 9db24cc50d252b1685a4ac69a736b49ed225dcb6 Mon Sep 17 00:00:00 2001 From: Michael Telatynski <7t3chguy@gmail.com> Date: Wed, 4 Aug 2021 18:39:57 +0100 Subject: Send unstable-prefixed room_type in store-invite IS API requests (#10435) The room type is per MSC3288 to allow the identity-server to change invitation wording based on whether the invitation is to a room or a space. The prefixed key will be replaced once MSC3288 is accepted into the spec. --- changelog.d/10435.feature | 1 + synapse/handlers/identity.py | 6 ++++++ synapse/handlers/room_member.py | 13 ++++++++++++- 3 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10435.feature (limited to 'synapse') diff --git a/changelog.d/10435.feature b/changelog.d/10435.feature new file mode 100644 index 0000000000..f93ef5b415 --- /dev/null +++ b/changelog.d/10435.feature @@ -0,0 +1 @@ +Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0961dec5ab..8ffeabacf9 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -824,6 +824,7 @@ class IdentityHandler(BaseHandler): room_avatar_url: str, room_join_rules: str, room_name: str, + room_type: Optional[str], inviter_display_name: str, inviter_avatar_url: str, id_access_token: Optional[str] = None, @@ -843,6 +844,7 @@ class IdentityHandler(BaseHandler): notifications. room_join_rules: The join rules of the email (e.g. "public"). room_name: The m.room.name of the room. + room_type: The type of the room from its m.room.create event (e.g "m.space"). inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. @@ -869,6 +871,10 @@ class IdentityHandler(BaseHandler): "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, } + + if room_type is not None: + invite_config["org.matrix.msc3288.room_type"] = room_type + # If a custom web client location is available, include it in the request. if self._web_client_location: invite_config["org.matrix.web_client_location"] = self._web_client_location diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 65ad3efa6a..ba13196218 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -19,7 +19,12 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.errors import ( AuthError, Codes, @@ -1237,6 +1242,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if room_name_event: room_name = room_name_event.content.get("name", "") + room_type = None + room_create_event = room_state.get((EventTypes.Create, "")) + if room_create_event: + room_type = room_create_event.content.get(EventContentFields.ROOM_TYPE) + room_join_rules = "" join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: @@ -1263,6 +1273,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_avatar_url=room_avatar_url, room_join_rules=room_join_rules, room_name=room_name, + room_type=room_type, inviter_display_name=inviter_display_name, inviter_avatar_url=inviter_avatar_url, id_access_token=id_access_token, -- cgit 1.5.1 From a8a27b2b8bac2995c3edd20518680366eb543ac9 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Thu, 5 Aug 2021 13:22:14 +0100 Subject: Only return an appservice protocol if it has a service providing it. (#10532) If there are no services providing a protocol, omit it completely instead of returning an empty dictionary. This fixes a long-standing spec compliance bug. --- changelog.d/10532.bugfix | 1 + synapse/handlers/appservice.py | 7 +-- tests/handlers/test_appservice.py | 122 +++++++++++++++++++++++++++++++++++++- 3 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10532.bugfix (limited to 'synapse') diff --git a/changelog.d/10532.bugfix b/changelog.d/10532.bugfix new file mode 100644 index 0000000000..d95e3d9b59 --- /dev/null +++ b/changelog.d/10532.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 21a17cd2e8..4ab4046650 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -392,9 +392,6 @@ class ApplicationServicesHandler: protocols[p].append(info) def _merge_instances(infos: List[JsonDict]) -> JsonDict: - if not infos: - return {} - # Merge the 'instances' lists of multiple results, but just take # the other fields from the first as they ought to be identical # copy the result so as not to corrupt the cached one @@ -406,7 +403,9 @@ class ApplicationServicesHandler: return combined - return {p: _merge_instances(protocols[p]) for p in protocols.keys()} + return { + p: _merge_instances(protocols[p]) for p in protocols.keys() if protocols[p] + } async def _get_services_for_event( self, event: EventBase diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 024c5e963c..43998020b2 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -133,11 +133,131 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - def _mkservice(self, is_interested): + def test_get_3pe_protocols_no_appservices(self): + self.mock_store.get_app_services.return_value = [] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_no_protocols(self): + service = self._mkservice(False, []) + self.mock_store.get_app_services.return_value = [service] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_protocol_no_response(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals(response, {}) + + def test_get_3pe_protocols_select_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_multiple_protocol(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["other-protocol"]) + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called() + self.assertEquals( + response, + { + "my-protocol": {"x-protocol-data": 42, "instances": []}, + "other-protocol": {"x-protocol-data": 42, "instances": []}, + }, + ) + + def test_get_3pe_protocols_multiple_info(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["my-protocol"]) + + async def get_3pe_protocol(service, unusedProtocol): + if service == service_one: + return { + "x-protocol-data": 42, + "instances": [{"desc": "Alice's service"}], + } + if service == service_two: + return { + "x-protocol-data": 36, + "x-not-used": 45, + "instances": [{"desc": "Bob's service"}], + } + raise Exception("Unexpected service") + + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol = get_3pe_protocol + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + # It's expected that the second service's data doesn't appear in the response + self.assertEquals( + response, + { + "my-protocol": { + "x-protocol-data": 42, + "instances": [ + { + "desc": "Alice's service", + }, + {"desc": "Bob's service"}, + ], + }, + }, + ) + + def _mkservice(self, is_interested, protocols=None): service = Mock() service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" + service.protocols = protocols return service def _mkservice_alias(self, is_interested_in_alias): -- cgit 1.5.1 From 3b354faad0e6b1f41ed5dd0269a1785d3f505465 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 5 Aug 2021 08:39:17 -0400 Subject: Refactoring before implementing the updated spaces summary. (#10527) This should have no user-visible changes, but refactors some pieces of the SpaceSummaryHandler before adding support for the updated MSC2946. --- changelog.d/10527.misc | 1 + synapse/federation/federation_client.py | 23 ++-- synapse/handlers/space_summary.py | 125 ++++++++++++--------- tests/handlers/test_space_summary.py | 185 ++++++++++++++++++-------------- 4 files changed, 198 insertions(+), 136 deletions(-) create mode 100644 changelog.d/10527.misc (limited to 'synapse') diff --git a/changelog.d/10527.misc b/changelog.d/10527.misc new file mode 100644 index 0000000000..3cf22f9daf --- /dev/null +++ b/changelog.d/10527.misc @@ -0,0 +1 @@ +Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b7a10da15a..007d1a27dc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1290,7 +1290,7 @@ class FederationClient(FederationBase): ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: """Represents a single event in the result of a successful get_space_summary call. @@ -1299,12 +1299,13 @@ class FederationSpaceSummaryEventResult: object attributes. """ - event_type = attr.ib(type=str) - state_key = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + event_type: str + room_id: str + state_key: str + via: Sequence[str] # the raw data, including the above keys - data = attr.ib(type=JsonDict) + data: JsonDict @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": @@ -1321,6 +1322,10 @@ class FederationSpaceSummaryEventResult: if not isinstance(event_type, str): raise ValueError("Invalid event: 'event_type' must be a str") + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") + state_key = d.get("state_key") if not isinstance(state_key, str): raise ValueError("Invalid event: 'state_key' must be a str") @@ -1335,15 +1340,15 @@ class FederationSpaceSummaryEventResult: if any(not isinstance(v, str) for v in via): raise ValueError("Invalid event: 'via' must be a list of strings") - return cls(event_type, state_key, via, d) + return cls(event_type, room_id, state_key, via, d) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryResult: """Represents the data returned by a successful get_space_summary call.""" - rooms = attr.ib(type=Sequence[JsonDict]) - events = attr.ib(type=Sequence[FederationSpaceSummaryEventResult]) + rooms: Sequence[JsonDict] + events: Sequence[FederationSpaceSummaryEventResult] @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 5f7d4602bd..3eb232c83e 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -16,7 +16,17 @@ import itertools import logging import re from collections import deque -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) import attr @@ -116,20 +126,22 @@ class SpaceSummaryHandler: max_children = max_rooms_per_space if processed_rooms else None if is_in_room: - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( requester, None, room_id, suggested_only, max_children ) + events: Collection[JsonDict] = [] + if room_entry: + rooms_result.append(room_entry.room) + events = room_entry.children + logger.debug( "Query of local room %s returned events %s", room_id, ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) - - if room: - rooms_result.append(room) else: - fed_rooms, fed_events = await self._summarize_remote_room( + fed_rooms = await self._summarize_remote_room( queue_entry, suggested_only, max_children, @@ -141,12 +153,10 @@ class SpaceSummaryHandler: # user is not permitted see. # # Filter the returned results to only what is accessible to the user. - room_ids = set() events = [] - for room in fed_rooms: - fed_room_id = room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue + for room_entry in fed_rooms: + room = room_entry.room + fed_room_id = room_entry.room_id # The room should only be included in the summary if: # a. the user is in the room; @@ -189,21 +199,17 @@ class SpaceSummaryHandler: # The user can see the room, include it! if include_room: rooms_result.append(room) - room_ids.add(fed_room_id) + events.extend(room_entry.children) # All rooms returned don't need visiting again (even if the user # didn't have access to them). processed_rooms.add(fed_room_id) - for event in fed_events: - if event.get("room_id") in room_ids: - events.append(event) - logger.debug( "Query of %s returned rooms %s, events %s", room_id, - [room.get("room_id") for room in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events], + [room_entry.room.get("room_id") for room_entry in fed_rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) # the room we queried may or may not have been returned, but don't process @@ -283,20 +289,20 @@ class SpaceSummaryHandler: # already done this room continue - logger.debug("Processing room %s", room_id) - - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( None, origin, room_id, suggested_only, max_rooms_per_space ) processed_rooms.add(room_id) - if room: - rooms_result.append(room) - events_result.extend(events) + if room_entry: + rooms_result.append(room_entry.room) + events_result.extend(room_entry.children) - # add any children to the queue - room_queue.extend(edge_event["state_key"] for edge_event in events) + # add any children to the queue + room_queue.extend( + edge_event["state_key"] for edge_event in room_entry.children + ) return {"rooms": rooms_result, "events": events_result} @@ -307,7 +313,7 @@ class SpaceSummaryHandler: room_id: str, suggested_only: bool, max_children: Optional[int], - ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]: + ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -326,21 +332,16 @@ class SpaceSummaryHandler: to a server-set limit. Returns: - A tuple of: - The room information, if the room should be returned to the - user. None, otherwise. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + A room entry if the room should be returned. None, otherwise. """ if not await self._is_room_accessible(room_id, requester, origin): - return None, () + return None room_entry = await self._build_room_entry(room_id) # If the room is not a space, return just the room information. if room_entry.get("room_type") != RoomTypes.SPACE: - return room_entry, () + return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. child_events = await self._get_child_events(room_id) @@ -363,7 +364,7 @@ class SpaceSummaryHandler: ) ) - return room_entry, events_result + return _RoomEntry(room_id, room_entry, events_result) async def _summarize_remote_room( self, @@ -371,7 +372,7 @@ class SpaceSummaryHandler: suggested_only: bool, max_children: Optional[int], exclude_rooms: Iterable[str], - ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + ) -> Iterable["_RoomEntry"]: """ Request room entries and a list of event entries for a given room by querying a remote server. @@ -386,11 +387,7 @@ class SpaceSummaryHandler: Rooms IDs which do not need to be summarized. Returns: - A tuple of: - An iterable of rooms. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + An iterable of room entries. """ room_id = room.room_id logger.info("Requesting summary for %s via %s", room_id, room.via) @@ -414,11 +411,30 @@ class SpaceSummaryHandler: e, exc_info=logger.isEnabledFor(logging.DEBUG), ) - return (), () + return () + + # Group the events by their room. + children_by_room: Dict[str, List[JsonDict]] = {} + for ev in res.events: + if ev.event_type == EventTypes.SpaceChild: + children_by_room.setdefault(ev.room_id, []).append(ev.data) + + # Generate the final results. + results = [] + for fed_room in res.rooms: + fed_room_id = fed_room.get("room_id") + if not fed_room_id or not isinstance(fed_room_id, str): + continue - return res.rooms, tuple( - ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild - ) + results.append( + _RoomEntry( + fed_room_id, + fed_room, + children_by_room.get(fed_room_id, []), + ) + ) + + return results async def _is_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] @@ -606,10 +622,21 @@ class SpaceSummaryHandler: return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class _RoomQueueEntry: - room_id = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + room_id: str + via: Sequence[str] + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomEntry: + room_id: str + # The room summary for this room. + room: JsonDict + # An iterable of the sorted, stripped children events for children of this room. + # + # This may not include all children. + children: Collection[JsonDict] = () def _has_valid_via(e: EventBase) -> bool: diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 3f73ad7f94..f982a8c8b4 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -26,7 +26,7 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key +from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.server import HomeServer @@ -351,26 +351,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # events before child events). # Note that these entries are brief, but should contain enough info. - rooms = [ - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - { - "room_id": subroom, - "world_readable": True, - }, - ] - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": subroom, - "content": event_content, - }, + return [ + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ), + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), ] - return rooms, events # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) @@ -436,70 +440,95 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ): # Note that these entries are brief, but should contain enough info. rooms = [ - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, - "allowed_spaces": [], - }, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, - "allowed_spaces": [self.room], - }, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ] - - # Place each room in the sub-space. - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": room["room_id"], - "content": event_content, - } - for room in rooms + _RoomEntry( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + _RoomEntry( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + _RoomEntry( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [], + }, + ), + _RoomEntry( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + _RoomEntry( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), ] # Also include the subspace. rooms.insert( 0, - { - "room_id": subspace, - "world_readable": True, - }, + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "room_id": subspace, + "state_key": room.room_id, + "content": {"via": [fed_hostname]}, + } + for room in rooms + ], + ), ) - return rooms, events + return rooms # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) -- cgit 1.5.1 From f5a368bb48df85dd488afdead01a39f77f50de99 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 5 Aug 2021 20:35:53 -0500 Subject: Mark all MSC2716 events as historical (#10537) * Mark all MSC2716 events as historical --- changelog.d/10537.misc | 1 + synapse/rest/client/v1/room.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10537.misc (limited to 'synapse') diff --git a/changelog.d/10537.misc b/changelog.d/10537.misc new file mode 100644 index 0000000000..c9e045300c --- /dev/null +++ b/changelog.d/10537.misc @@ -0,0 +1 @@ +Mark all events stemming from the MSC2716 `/batch_send` endpoint as historical. diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 502a917588..982f134148 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -458,6 +458,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "state_key": state_event["state_key"], } + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + # Make the state events float off on their own fake_prev_event_id = "$" + random_string(43) @@ -562,7 +565,10 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "type": EventTypes.MSC2716_CHUNK, "sender": requester.user.to_string(), "room_id": room_id, - "content": {EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to}, + "content": { + EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, # Since the chunk event is put at the end of the chunk, # where the newest-in-time event is, copy the origin_server_ts from # the last event we're inserting @@ -589,10 +595,6 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): for ev in events_to_create: assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - # Mark all events as historical - # This has important semantics within the Synapse internals to backfill properly - ev["content"][EventContentFields.MSC2716_HISTORICAL] = True - event_dict = { "type": ev["type"], "origin_server_ts": ev["origin_server_ts"], @@ -602,6 +604,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "prev_events": prev_event_ids.copy(), } + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + event, context = await self.event_creation_handler.create_event( await self._create_requester_for_user_id_from_app_service( ev["sender"], requester.app_service -- cgit 1.5.1 From 74d7336686e7de1d0923d67af61b510ec801fa84 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 6 Aug 2021 11:13:34 +0100 Subject: Add a setting to disable TLS for sending email (#10546) This is mostly useful in case the server offers TLS, but doesn't present a valid certificate. --- changelog.d/10546.feature | 1 + docs/sample_config.yaml | 8 +++ synapse/config/emailconfig.py | 14 +++++ synapse/handlers/send_email.py | 94 +++++++++++++++++++++++------ synapse/server.py | 6 -- tests/push/test_email.py | 20 +++--- tests/rest/client/v2_alpha/test_account.py | 33 ++++++---- tests/rest/client/v2_alpha/test_register.py | 12 ++-- 8 files changed, 138 insertions(+), 50 deletions(-) create mode 100644 changelog.d/10546.feature (limited to 'synapse') diff --git a/changelog.d/10546.feature b/changelog.d/10546.feature new file mode 100644 index 0000000000..7709d010b3 --- /dev/null +++ b/changelog.d/10546.feature @@ -0,0 +1 @@ +Add a setting to disable TLS when sending email. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 16843dd8c9..aeebcaf45f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2242,6 +2242,14 @@ email: # #require_transport_security: true + # Uncomment the following to disable TLS for SMTP. + # + # By default, if the server supports TLS, it will be used, and the server + # must present a certificate that is valid for 'smtp_host'. If this option + # is set to false, TLS will not be used. + # + #enable_tls: false + # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 8d8f166e9b..42526502f0 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -80,6 +80,12 @@ class EmailConfig(Config): self.require_transport_security = email_config.get( "require_transport_security", False ) + self.enable_smtp_tls = email_config.get("enable_tls", True) + if self.require_transport_security and not self.enable_smtp_tls: + raise ConfigError( + "email.require_transport_security requires email.enable_tls to be true" + ) + if "app_name" in email_config: self.email_app_name = email_config["app_name"] else: @@ -368,6 +374,14 @@ class EmailConfig(Config): # #require_transport_security: true + # Uncomment the following to disable TLS for SMTP. + # + # By default, if the server supports TLS, it will be used, and the server + # must present a certificate that is valid for 'smtp_host'. If this option + # is set to false, TLS will not be used. + # + #enable_tls: false + # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index e9f6aef06f..dda9659c11 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -16,7 +16,12 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING +from io import BytesIO +from typing import TYPE_CHECKING, Optional + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IReactorTCP +from twisted.mail.smtp import ESMTPSenderFactory from synapse.logging.context import make_deferred_yieldable @@ -26,19 +31,75 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +async def _sendmail( + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + username: Optional[bytes] = None, + password: Optional[bytes] = None, + require_auth: bool = False, + require_tls: bool = False, + tls_hostname: Optional[str] = None, +) -> None: + """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests + + Params: + reactor: reactor to use to make the outbound connection + smtphost: hostname to connect to + smtpport: port to connect to + from_addr: "From" address for email + to_addr: "To" address for email + msg_bytes: Message content + username: username to authenticate with, if auth is enabled + password: password to give when authenticating + require_auth: if auth is not offered, fail the request + require_tls: if TLS is not offered, fail the reqest + tls_hostname: TLS hostname to check for. None to disable TLS. + """ + msg = BytesIO(msg_bytes) + + d: "Deferred[object]" = Deferred() + + factory = ESMTPSenderFactory( + username, + password, + from_addr, + to_addr, + msg, + d, + heloFallback=True, + requireAuthentication=require_auth, + requireTransportSecurity=require_tls, + hostname=tls_hostname, + ) + + # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong + reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] + + await make_deferred_yieldable(d) + + class SendEmailHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self._sendmail = hs.get_sendmail() self._reactor = hs.get_reactor() self._from = hs.config.email.email_notif_from self._smtp_host = hs.config.email.email_smtp_host self._smtp_port = hs.config.email.email_smtp_port - self._smtp_user = hs.config.email.email_smtp_user - self._smtp_pass = hs.config.email.email_smtp_pass + + user = hs.config.email.email_smtp_user + self._smtp_user = user.encode("utf-8") if user is not None else None + passwd = hs.config.email.email_smtp_pass + self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None self._require_transport_security = hs.config.email.require_transport_security + self._enable_tls = hs.config.email.enable_smtp_tls + + self._sendmail = _sendmail async def send_email( self, @@ -82,17 +143,16 @@ class SendEmailHandler: logger.info("Sending email to %s" % email_address) - await make_deferred_yieldable( - self._sendmail( - self._smtp_host, - raw_from, - raw_to, - multipart_msg.as_string().encode("utf8"), - reactor=self._reactor, - port=self._smtp_port, - requireAuthentication=self._smtp_user is not None, - username=self._smtp_user, - password=self._smtp_pass, - requireTransportSecurity=self._require_transport_security, - ) + await self._sendmail( + self._reactor, + self._smtp_host, + self._smtp_port, + raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + username=self._smtp_user, + password=self._smtp_pass, + require_auth=self._smtp_user is not None, + require_tls=self._require_transport_security, + tls_hostname=self._smtp_host if self._enable_tls else None, ) diff --git a/synapse/server.py b/synapse/server.py index 095dba9ad0..6c867f0f47 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -34,8 +34,6 @@ from typing import ( ) import twisted.internet.tcp -from twisted.internet import defer -from twisted.mail.smtp import sendmail from twisted.web.iweb import IPolicyForHTTPS from twisted.web.resource import IResource @@ -442,10 +440,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_room_shutdown_handler(self) -> RoomShutdownHandler: return RoomShutdownHandler(self) - @cache_in_self - def get_sendmail(self) -> Callable[..., defer.Deferred]: - return sendmail - @cache_in_self def get_state_handler(self) -> StateHandler: return StateHandler(self) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e04bc5c9a6..a487706758 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -45,14 +45,6 @@ class EmailPusherTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): - # List[Tuple[Deferred, args, kwargs]] - self.email_attempts = [] - - def sendmail(*args, **kwargs): - d = Deferred() - self.email_attempts.append((d, args, kwargs)) - return d - config = self.default_config() config["email"] = { "enable_notifs": True, @@ -75,7 +67,17 @@ class EmailPusherTests(HomeserverTestCase): config["public_baseurl"] = "aaa" config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts = [] + + def sendmail(*args, **kwargs): + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + hs.get_send_email_handler()._sendmail = sendmail return hs diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 317a2287e3..e7e617e9df 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -47,12 +47,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -67,7 +61,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + return hs def prepare(self, reactor, clock, hs): @@ -511,11 +514,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -530,7 +528,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + return self.hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 1cad5f00eb..a52e5e608a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -509,10 +509,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } # Email config. - self.email_attempts = [] - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) config["email"] = { "enable_notifs": True, @@ -532,7 +528,13 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "aaa" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail self.store = self.hs.get_datastore() -- cgit 1.5.1 From f4ade972ada6d61ca9370d26784ac9f3ed8e5282 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 6 Aug 2021 07:40:29 -0400 Subject: Update the API response for spaces summary over federation. (#10530) This adds 'allowed_room_ids' (in addition to 'allowed_spaces', for backwards compatibility) to the federation response of the spaces summary. A future PR will remove the 'allowed_spaces' flag. --- changelog.d/10530.misc | 1 + synapse/handlers/space_summary.py | 57 ++++++++++++++++++++++++++------------- 2 files changed, 39 insertions(+), 19 deletions(-) create mode 100644 changelog.d/10530.misc (limited to 'synapse') diff --git a/changelog.d/10530.misc b/changelog.d/10530.misc new file mode 100644 index 0000000000..3cf22f9daf --- /dev/null +++ b/changelog.d/10530.misc @@ -0,0 +1 @@ +Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 3eb232c83e..2517f278b6 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -179,7 +179,9 @@ class SpaceSummaryHandler: # Check if the user is a member of any of the allowed spaces # from the response. - allowed_rooms = room.get("allowed_spaces") + allowed_rooms = room.get("allowed_room_ids") or room.get( + "allowed_spaces" + ) if ( not include_room and allowed_rooms @@ -198,6 +200,11 @@ class SpaceSummaryHandler: # The user can see the room, include it! if include_room: + # Before returning to the client, remove the allowed_room_ids + # and allowed_spaces keys. + room.pop("allowed_room_ids", None) + room.pop("allowed_spaces", None) + rooms_result.append(room) events.extend(room_entry.children) @@ -236,11 +243,6 @@ class SpaceSummaryHandler: ) processed_events.add(ev_key) - # Before returning to the client, remove the allowed_spaces key for any - # rooms. - for room in rooms_result: - room.pop("allowed_spaces", None) - return {"rooms": rooms_result, "events": events_result} async def federation_space_summary( @@ -337,7 +339,7 @@ class SpaceSummaryHandler: if not await self._is_room_accessible(room_id, requester, origin): return None - room_entry = await self._build_room_entry(room_id) + room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) # If the room is not a space, return just the room information. if room_entry.get("room_type") != RoomTypes.SPACE: @@ -548,8 +550,18 @@ class SpaceSummaryHandler: ) return False - async def _build_room_entry(self, room_id: str) -> JsonDict: - """Generate en entry suitable for the 'rooms' list in the summary response""" + async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: + """ + Generate en entry suitable for the 'rooms' list in the summary response. + + Args: + room_id: The room ID to summarize. + for_federation: True if this is a summary requested over federation + (which includes additional fields). + + Returns: + The JSON dictionary for the room. + """ stats = await self._store.get_room_with_stats(room_id) # currently this should be impossible because we call @@ -562,15 +574,6 @@ class SpaceSummaryHandler: current_state_ids[(EventTypes.Create, "")] ) - room_version = await self._store.get_room_version(room_id) - allowed_rooms = None - if await self._event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version - ): - allowed_rooms = await self._event_auth_handler.get_rooms_that_allow_join( - current_state_ids - ) - entry = { "room_id": stats["room_id"], "name": stats["name"], @@ -585,9 +588,25 @@ class SpaceSummaryHandler: "guest_can_join": stats["guest_access"] == "can_join", "creation_ts": create_event.origin_server_ts, "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), - "allowed_spaces": allowed_rooms, } + # Federation requests need to provide additional information so the + # requested server is able to filter the response appropriately. + if for_federation: + room_version = await self._store.get_room_version(room_id) + if await self._event_auth_handler.has_restricted_join_rules( + current_state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join( + current_state_ids + ) + ) + if allowed_rooms: + entry["allowed_room_ids"] = allowed_rooms + # TODO Remove this key once the API is stable. + entry["allowed_spaces"] = allowed_rooms + # Filter out Nones – rather omit the field altogether room_entry = {k: v for k, v in entry.items() if v is not None} -- cgit 1.5.1 From 1bebc0b78cbedffb6b69fd76327f0eb7663c3c96 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 6 Aug 2021 13:54:23 +0100 Subject: Clean up federation event auth code (#10539) * drop old-room hack pretty sure we don't need this any more. * Remove incorrect comment about modifying `context` It doesn't look like the supplied context is ever modified. * Stop `_auth_and_persist_event` modifying its parameters This is only called in three places. Two of them don't pass `auth_events`, and the third doesn't use the dict after passing it in, so this should be non-functional. * Stop `_check_event_auth` modifying its parameters `_check_event_auth` is only called in three places. `on_send_membership_event` doesn't pass an `auth_events`, and `prep` and `_auth_and_persist_event` do not use the map after passing it in. * Stop `_update_auth_events_and_context_for_auth` modifying its parameters Return the updated auth event dict, rather than modifying the parameter. This is only called from `_check_event_auth`. * Improve documentation on `_auth_and_persist_event` Rename `auth_events` parameter to better reflect what it contains. * Improve documentation on `_NewEventInfo` * Improve documentation on `_check_event_auth` rename `auth_events` parameter to better describe what it contains * changelog --- changelog.d/10539.misc | 1 + synapse/handlers/federation.py | 118 +++++++++++++++++++++++------------------ tests/test_federation.py | 6 +-- 3 files changed, 69 insertions(+), 56 deletions(-) create mode 100644 changelog.d/10539.misc (limited to 'synapse') diff --git a/changelog.d/10539.misc b/changelog.d/10539.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10539.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8b602e3813..9a5e726533 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -109,21 +109,33 @@ soft_failed_event_counter = Counter( ) -@attr.s(slots=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class _NewEventInfo: """Holds information about a received event, ready for passing to _auth_and_persist_events Attributes: event: the received event - state: the state at that event + state: the state at that event, according to /state_ids from a remote + homeserver. Only populated for backfilled events which are going to be a + new backwards extremity. + + claimed_auth_event_map: a map of (type, state_key) => event for the event's + claimed auth_events. + + This can include events which have not yet been persisted, in the case that + we are backfilling a batch of events. + + Note: May be incomplete: if we were unable to find all of the claimed auth + events. Also, treat the contents with caution: the events might also have + been rejected, might not yet have been authorized themselves, or they might + be in the wrong room. - auth_events: the auth_event map for that event """ - event = attr.ib(type=EventBase) - state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) + event: EventBase + state: Optional[Sequence[EventBase]] + claimed_auth_event_map: StateMap[EventBase] class FederationHandler(BaseHandler): @@ -1086,7 +1098,7 @@ class FederationHandler(BaseHandler): _NewEventInfo( event=ev, state=events_to_state[e_id], - auth_events={ + claimed_auth_event_map={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -2315,7 +2327,7 @@ class FederationHandler(BaseHandler): event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - auth_events: Optional[MutableStateMap[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> None: """ @@ -2327,17 +2339,18 @@ class FederationHandler(BaseHandler): context: The event context. - NB that this function potentially modifies it. state: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers. + backfilled: True if the event was backfilled. """ context = await self._check_event_auth( @@ -2345,7 +2358,7 @@ class FederationHandler(BaseHandler): event, context, state=state, - auth_events=auth_events, + claimed_auth_event_map=claimed_auth_event_map, backfilled=backfilled, ) @@ -2409,7 +2422,7 @@ class FederationHandler(BaseHandler): event, res, state=ev_info.state, - auth_events=ev_info.auth_events, + claimed_auth_event_map=ev_info.claimed_auth_event_map, backfilled=backfilled, ) return res @@ -2675,7 +2688,7 @@ class FederationHandler(BaseHandler): event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - auth_events: Optional[MutableStateMap[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> EventContext: """ @@ -2687,21 +2700,19 @@ class FederationHandler(BaseHandler): context: The event context. - NB that this function potentially modifies it. state: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. - Also NB that this function adds entries to it. + Only populated where we may not already have persisted these events - + for example, when populating outliers, or the state for a backwards + extremity. - If this is not provided, it is calculated from the previous state IDs. backfilled: True if the event was backfilled. Returns: @@ -2710,7 +2721,12 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if not auth_events: + if claimed_auth_event_map: + # if we have a copy of the auth events from the event, use that as the + # basis for auth. + auth_events = claimed_auth_event_map + else: + # otherwise, we calculate what the auth events *should* be, and use that prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True @@ -2718,18 +2734,11 @@ class FederationHandler(BaseHandler): auth_events_x = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - # This is a hack to fix some old rooms where the initial join event - # didn't reference the create event in its auth events. - if event.type == EventTypes.Member and not event.auth_event_ids(): - if len(event.prev_event_ids()) == 1 and event.depth < 5: - c = await self.store.get_event( - event.prev_event_ids()[0], allow_none=True - ) - if c and c.type == EventTypes.Create: - auth_events[(c.type, c.state_key)] = c - try: - context = await self._update_auth_events_and_context_for_auth( + ( + context, + auth_events_for_auth, + ) = await self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2742,9 +2751,10 @@ class FederationHandler(BaseHandler): "Ignoring failure and continuing processing of event.", event.event_id, ) + auth_events_for_auth = auth_events try: - event_auth.check(room_version_obj, event, auth_events=auth_events) + event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR @@ -2769,8 +2779,8 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: MutableStateMap[EventBase], - ) -> EventContext: + input_auth_events: StateMap[EventBase], + ) -> Tuple[EventContext, StateMap[EventBase]]: """Helper for _check_event_auth. See there for docs. Checks whether a given event has the expected auth events. If it @@ -2787,7 +2797,7 @@ class FederationHandler(BaseHandler): event: context: - auth_events: + input_auth_events: Map from (event_type, state_key) to event Normally, our calculated auth_events based on the state of the room @@ -2795,11 +2805,12 @@ class FederationHandler(BaseHandler): event is an outlier), may be the auth events claimed by the remote server. - Also NB that this function adds entries to it. - Returns: - updated context + updated context, updated auth event map """ + # take a copy of input_auth_events before we modify it. + auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + event_auth_events = set(event.auth_event_ids()) # missing_auth is the set of the event's auth_events which we don't yet have @@ -2828,7 +2839,7 @@ class FederationHandler(BaseHandler): # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e1) - return context + return context, auth_events seen_remotes = await self.store.have_seen_events( event.room_id, [e.event_id for e in remote_auth_chain] @@ -2859,7 +2870,10 @@ class FederationHandler(BaseHandler): await self.state_handler.compute_event_context(e) ) await self._auth_and_persist_event( - origin, e, missing_auth_event_context, auth_events=auth + origin, + e, + missing_auth_event_context, + claimed_auth_event_map=auth, ) if e.event_id in event_auth_events: @@ -2877,14 +2891,14 @@ class FederationHandler(BaseHandler): # obviously be empty # (b) alternatively, why don't we do it earlier? logger.info("Skipping auth_event fetch for outlier") - return context + return context, auth_events different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) if not different_auth: - return context + return context, auth_events logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2910,7 +2924,7 @@ class FederationHandler(BaseHandler): # XXX: should we reject the event in this case? It feels like we should, # but then shouldn't we also do so if we've failed to fetch any of the # auth events? - return context + return context, auth_events # now we state-resolve between our own idea of the auth events, and the remote's # idea of them. @@ -2940,7 +2954,7 @@ class FederationHandler(BaseHandler): event, context, auth_events ) - return context + return context, auth_events async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] diff --git a/tests/test_federation.py b/tests/test_federation.py index 0ed8326f55..3785799f46 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,10 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = ( - lambda origin, event, context, state, auth_events, backfilled: succeed( - context - ) + self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( + context ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( -- cgit 1.5.1 From 60f0534b6e910a497800da2454638bcf4aae006e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Aug 2021 14:05:41 +0100 Subject: Fix exceptions in logs when failing to get remote room list (#10541) --- changelog.d/10541.bugfix | 1 + synapse/federation/federation_client.py | 3 +- synapse/handlers/room_list.py | 46 ++++++++++------- synapse/rest/client/v1/room.py | 30 +++++------ tests/rest/client/v1/test_rooms.py | 92 ++++++++++++++++++++++++++++++++- 5 files changed, 134 insertions(+), 38 deletions(-) create mode 100644 changelog.d/10541.bugfix (limited to 'synapse') diff --git a/changelog.d/10541.bugfix b/changelog.d/10541.bugfix new file mode 100644 index 0000000000..bb946e0920 --- /dev/null +++ b/changelog.d/10541.bugfix @@ -0,0 +1 @@ +Fix exceptions in logs when failing to get remote room list. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 007d1a27dc..2eefac04fd 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1108,7 +1108,8 @@ class FederationClient(FederationBase): The response from the remote server. Raises: - HttpResponseException: There was an exception returned from the remote server + HttpResponseException / RequestSendFailed: There was an exception + returned from the remote server SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom requests over federation diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index fae2c098e3..6d433fad41 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -356,6 +356,12 @@ class RoomListHandler(BaseHandler): include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, ) -> JsonDict: + """Get the public room list from remote server + + Raises: + SynapseError + """ + if not self.enable_room_list_search: return {"chunk": [], "total_room_count_estimate": 0} @@ -395,13 +401,16 @@ class RoomListHandler(BaseHandler): limit = None since_token = None - res = await self._get_remote_list_cached( - server_name, - limit=limit, - since_token=since_token, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) + try: + res = await self._get_remote_list_cached( + server_name, + limit=limit, + since_token=since_token, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + except (RequestSendFailed, HttpResponseException): + raise SynapseError(502, "Failed to fetch room list") if search_filter: res = { @@ -423,20 +432,21 @@ class RoomListHandler(BaseHandler): include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, ) -> JsonDict: + """Wrapper around FederationClient.get_public_rooms that caches the + result. + """ + repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search - try: - return await repl_layer.get_public_rooms( - server_name, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - except (RequestSendFailed, HttpResponseException): - raise SynapseError(502, "Failed to fetch room list") + return await repl_layer.get_public_rooms( + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) key = ( server_name, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 982f134148..f887970b76 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -23,7 +23,6 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, - HttpResponseException, InvalidClientCredentialsError, ShadowBanError, SynapseError, @@ -783,12 +782,9 @@ class PublicRoomListRestServlet(TransactionRestServlet): Codes.INVALID_PARAM, ) - try: - data = await handler.get_remote_public_room_list( - server, limit=limit, since_token=since_token - ) - except HttpResponseException as e: - raise e.to_synapse_error() + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) else: data = await handler.get_local_public_room_list( limit=limit, since_token=since_token @@ -836,17 +832,15 @@ class PublicRoomListRestServlet(TransactionRestServlet): Codes.INVALID_PARAM, ) - try: - data = await handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - except HttpResponseException as e: - raise e.to_synapse_error() + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + else: data = await handler.get_local_public_room_list( limit=limit, diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 3df070c936..1a9528ec20 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,11 +19,14 @@ import json from typing import Iterable -from unittest.mock import Mock +from unittest.mock import Mock, call from urllib import parse as urlparse +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room @@ -1124,6 +1127,93 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fallback to local filtering if a remote server + doesn't support search. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters. + self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From 1de26b346796ec8d6b51b4395017f8107f640c47 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 6 Aug 2021 09:39:59 -0400 Subject: Convert Transaction and Edu object to attrs (#10542) Instead of wrapping the JSON into an object, this creates concrete instances for Transaction and Edu. This allows for improved type hints and simplified code. --- changelog.d/10542.misc | 1 + synapse/federation/federation_server.py | 50 ++++++----- synapse/federation/persistence.py | 4 +- synapse/federation/sender/transaction_manager.py | 9 +- synapse/federation/transport/client.py | 2 +- synapse/federation/transport/server.py | 11 +-- synapse/federation/units.py | 90 ++++++++------------ synapse/util/jsonobject.py | 102 ----------------------- 8 files changed, 75 insertions(+), 194 deletions(-) create mode 100644 changelog.d/10542.misc delete mode 100644 synapse/util/jsonobject.py (limited to 'synapse') diff --git a/changelog.d/10542.misc b/changelog.d/10542.misc new file mode 100644 index 0000000000..44b70b4730 --- /dev/null +++ b/changelog.d/10542.misc @@ -0,0 +1 @@ +Convert `Transaction` and `Edu` objects to attrs. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 145b9161d9..0385aadefa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -195,13 +195,17 @@ class FederationServer(FederationBase): origin, room_id, versions, limit ) - res = self._transaction_from_pdus(pdus).get_dict() + res = self._transaction_dict_from_pdus(pdus) return 200, res async def on_incoming_transaction( - self, origin: str, transaction_data: JsonDict - ) -> Tuple[int, Dict[str, Any]]: + self, + origin: str, + transaction_id: str, + destination: str, + transaction_data: JsonDict, + ) -> Tuple[int, JsonDict]: # If we receive a transaction we should make sure that kick off handling # any old events in the staging area. if not self._started_handling_of_staged_events: @@ -212,8 +216,14 @@ class FederationServer(FederationBase): # accurate as possible. request_time = self._clock.time_msec() - transaction = Transaction(**transaction_data) - transaction_id = transaction.transaction_id # type: ignore + transaction = Transaction( + transaction_id=transaction_id, + destination=destination, + origin=origin, + origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore + pdus=transaction_data.get("pdus"), # type: ignore + edus=transaction_data.get("edus"), + ) if not transaction_id: raise Exception("Transaction missing transaction_id") @@ -221,9 +231,7 @@ class FederationServer(FederationBase): logger.debug("[%s] Got transaction", transaction_id) # Reject malformed transactions early: reject if too many PDUs/EDUs - if len(transaction.pdus) > 50 or ( # type: ignore - hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore - ): + if len(transaction.pdus) > 50 or len(transaction.edus) > 100: logger.info("Transaction PDU or EDU count too large. Returning 400") return 400, {} @@ -263,7 +271,7 @@ class FederationServer(FederationBase): # CRITICAL SECTION: the first thing we must do (before awaiting) is # add an entry to _active_transactions. assert origin not in self._active_transactions - self._active_transactions[origin] = transaction.transaction_id # type: ignore + self._active_transactions[origin] = transaction.transaction_id try: result = await self._handle_incoming_transaction( @@ -291,11 +299,11 @@ class FederationServer(FederationBase): if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id, # type: ignore + transaction.transaction_id, ) return response - logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore + logger.debug("[%s] Transaction is new", transaction.transaction_id) # We process PDUs and EDUs in parallel. This is important as we don't # want to block things like to device messages from reaching clients @@ -334,7 +342,7 @@ class FederationServer(FederationBase): report back to the sending server. """ - received_pdus_counter.inc(len(transaction.pdus)) # type: ignore + received_pdus_counter.inc(len(transaction.pdus)) origin_host, _ = parse_server_name(origin) @@ -342,7 +350,7 @@ class FederationServer(FederationBase): newest_pdu_ts = 0 - for p in transaction.pdus: # type: ignore + for p in transaction.pdus: # FIXME (richardv): I don't think this works: # https://github.com/matrix-org/synapse/issues/8429 if "unsigned" in p: @@ -436,10 +444,10 @@ class FederationServer(FederationBase): return pdu_results - async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): + async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None: """Process the EDUs in a received transaction.""" - async def _process_edu(edu_dict): + async def _process_edu(edu_dict: JsonDict) -> None: received_edus_counter.inc() edu = Edu( @@ -452,7 +460,7 @@ class FederationServer(FederationBase): await concurrently_execute( _process_edu, - getattr(transaction, "edus", []), + transaction.edus, TRANSACTION_CONCURRENCY_LIMIT, ) @@ -538,7 +546,7 @@ class FederationServer(FederationBase): pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: - return 200, self._transaction_from_pdus([pdu]).get_dict() + return 200, self._transaction_dict_from_pdus([pdu]) else: return 404, "" @@ -879,18 +887,20 @@ class FederationServer(FederationBase): ts_now_ms = self._clock.time_msec() return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) - def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: + def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict: """Returns a new Transaction containing the given PDUs suitable for transmission. """ time_now = self._clock.time_msec() pdus = [p.get_pdu_json(time_now) for p in pdu_list] return Transaction( + # Just need a dummy transaction ID and destination since it won't be used. + transaction_id="", origin=self.server_name, pdus=pdus, origin_server_ts=int(time_now), - destination=None, - ) + destination="", + ).get_dict() async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: """Process a PDU received in a federation /send/ transaction. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 2f9c9bc2cd..4fead6ca29 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -45,7 +45,7 @@ class TransactionActions: `None` if we have not previously responded to this transaction or a 2-tuple of `(int, dict)` representing the response code and response body. """ - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") @@ -56,7 +56,7 @@ class TransactionActions: self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: """Persist how we responded to a transaction.""" - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 72a635830b..dc555cca0b 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -27,6 +27,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.metrics import measure_func @@ -104,13 +105,13 @@ class TransactionManager: len(edus), ) - transaction = Transaction.create_new( + transaction = Transaction( origin_server_ts=int(self.clock.time_msec()), transaction_id=txn_id, origin=self._server_name, destination=destination, - pdus=pdus, - edus=edus, + pdus=[p.get_pdu_json() for p in pdus], + edus=[edu.get_dict() for edu in edus], ) self._next_txn_id += 1 @@ -131,7 +132,7 @@ class TransactionManager: # FIXME (richardv): I also believe it no longer works. We (now?) store # "age_ts" in "unsigned" rather than at the top level. See # https://github.com/matrix-org/synapse/issues/8429. - def json_data_cb(): + def json_data_cb() -> JsonDict: data = transaction.get_dict() now = int(self.clock.time_msec()) if "pdus" in data: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 6a8d3ad4fe..90a7c16b62 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -143,7 +143,7 @@ class TransportLayerClient: """Sends the given Transaction to its destination Args: - transaction (Transaction) + transaction Returns: Succeeds when we get a 2xx HTTP response. The result diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5e059d6e09..640f46fff6 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -450,21 +450,12 @@ class FederationSendServlet(BaseFederationServerServlet): len(transaction_data.get("edus", [])), ) - # We should ideally be getting this from the security layer. - # origin = body["origin"] - - # Add some extra data to the transaction dict that isn't included - # in the request body. - transaction_data.update( - transaction_id=transaction_id, destination=self.server_name - ) - except Exception as e: logger.exception(e) return 400, {"error": "Invalid transaction"} code, response = await self.handler.on_incoming_transaction( - origin, transaction_data + origin, transaction_id, self.server_name, transaction_data ) return code, response diff --git a/synapse/federation/units.py b/synapse/federation/units.py index c83a261918..b9b12fbea5 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -17,18 +17,17 @@ server protocol. """ import logging -from typing import Optional +from typing import List, Optional import attr from synapse.types import JsonDict -from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) -@attr.s(slots=True) -class Edu(JsonEncodedObject): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Edu: """An Edu represents a piece of data sent from one homeserver to another. In comparison to Pdus, Edus are not persisted for a long time on disk, are @@ -36,10 +35,10 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - edu_type = attr.ib(type=str) - content = attr.ib(type=dict) - origin = attr.ib(type=str) - destination = attr.ib(type=str) + edu_type: str + content: dict + origin: str + destination: str def get_dict(self) -> JsonDict: return { @@ -55,14 +54,21 @@ class Edu(JsonEncodedObject): "destination": self.destination, } - def get_context(self): + def get_context(self) -> str: return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") - def strip_context(self): + def strip_context(self) -> None: getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" -class Transaction(JsonEncodedObject): +def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]: + if edus is None: + return [] + return edus + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Transaction: """A transaction is a list of Pdus and Edus to be sent to a remote home server with some extra metadata. @@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject): """ - valid_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "previous_ids", - "pdus", - "edus", - ] - - internal_keys = ["transaction_id", "destination"] - - required_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "pdus", - ] - - def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs): - """If we include a list of pdus then we decode then as PDU's - automatically. - """ - - # If there's no EDUs then remove the arg - if "edus" in kwargs and not kwargs["edus"]: - del kwargs["edus"] - - super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs) - - @staticmethod - def create_new(pdus, **kwargs): - """Used to create a new transaction. Will auto fill out - transaction_id and origin_server_ts keys. - """ - if "origin_server_ts" not in kwargs: - raise KeyError("Require 'origin_server_ts' to construct a Transaction") - if "transaction_id" not in kwargs: - raise KeyError("Require 'transaction_id' to construct a Transaction") - - kwargs["pdus"] = [p.get_pdu_json() for p in pdus] - - return Transaction(**kwargs) + # Required keys. + transaction_id: str + origin: str + destination: str + origin_server_ts: int + pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + + def get_dict(self) -> JsonDict: + """A JSON-ready dictionary of valid keys which aren't internal.""" + result = { + "origin": self.origin, + "origin_server_ts": self.origin_server_ts, + "pdus": self.pdus, + } + if self.edus: + result["edus"] = self.edus + return result diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py deleted file mode 100644 index abc12f0837..0000000000 --- a/synapse/util/jsonobject.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class JsonEncodedObject: - """A common base class for defining protocol units that are represented - as JSON. - - Attributes: - unrecognized_keys (dict): A dict containing all the key/value pairs we - don't recognize. - """ - - valid_keys = [] # keys we will store - """A list of strings that represent keys we know about - and can handle. If we have values for these keys they will be - included in the `dictionary` instance variable. - """ - - internal_keys = [] # keys to ignore while building dict - """A list of strings that should *not* be encoded into JSON. - """ - - required_keys = [] - """A list of strings that we require to exist. If they are not given upon - construction it raises an exception. - """ - - def __init__(self, **kwargs): - """Takes the dict of `kwargs` and loads all keys that are *valid* - (i.e., are included in the `valid_keys` list) into the dictionary` - instance variable. - - Any keys that aren't recognized are added to the `unrecognized_keys` - attribute. - - Args: - **kwargs: Attributes associated with this protocol unit. - """ - for required_key in self.required_keys: - if required_key not in kwargs: - raise RuntimeError("Key %s is required" % required_key) - - self.unrecognized_keys = {} # Keys we were given not listed as valid - for k, v in kwargs.items(): - if k in self.valid_keys or k in self.internal_keys: - self.__dict__[k] = v - else: - self.unrecognized_keys[k] = v - - def get_dict(self): - """Converts this protocol unit into a :py:class:`dict`, ready to be - encoded as JSON. - - The keys it encodes are: `valid_keys` - `internal_keys` - - Returns - dict - """ - d = { - k: _encode(v) - for (k, v) in self.__dict__.items() - if k in self.valid_keys and k not in self.internal_keys - } - d.update(self.unrecognized_keys) - return d - - def get_internal_dict(self): - d = { - k: _encode(v, internal=True) - for (k, v) in self.__dict__.items() - if k in self.valid_keys - } - d.update(self.unrecognized_keys) - return d - - def __str__(self): - return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) - - -def _encode(obj, internal=False): - if type(obj) is list: - return [_encode(o, internal=internal) for o in obj] - - if isinstance(obj, JsonEncodedObject): - if internal: - return obj.get_internal_dict() - else: - return obj.get_dict() - - return obj -- cgit 1.5.1 From ad35b7739e72fe198fa78fa4279f58cacfc9fa37 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 9 Aug 2021 13:41:29 +0100 Subject: 1.40.0rc3 --- CHANGES.md | 21 +++++++++++++++++++++ changelog.d/10449.bugfix | 1 - changelog.d/10449.feature | 1 - changelog.d/10543.doc | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 6 files changed, 28 insertions(+), 4 deletions(-) delete mode 100644 changelog.d/10449.bugfix delete mode 100644 changelog.d/10449.feature delete mode 100644 changelog.d/10543.doc (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index 62ea684e58..b04abbeb4d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,24 @@ +Synapse 1.40.0rc3 (2021-08-09) +============================== + +Features +-------- + +- Support [MSC3289: room version 8](https://github.com/matrix-org/matrix-doc/pull/3289). ([\#10449](https://github.com/matrix-org/synapse/issues/10449)) + + +Bugfixes +-------- + +- Mark the experimental room version from [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) as unstable. ([\#10449](https://github.com/matrix-org/synapse/issues/10449)) + + +Improved Documentation +---------------------- + +- Fix broken links in `upgrade.md`. Contributed by @dklimpel. ([\#10543](https://github.com/matrix-org/synapse/issues/10543)) + + Synapse 1.40.0rc2 (2021-08-04) ============================== diff --git a/changelog.d/10449.bugfix b/changelog.d/10449.bugfix deleted file mode 100644 index c5e23ba019..0000000000 --- a/changelog.d/10449.bugfix +++ /dev/null @@ -1 +0,0 @@ -Mark the experimental room version from [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) as unstable. diff --git a/changelog.d/10449.feature b/changelog.d/10449.feature deleted file mode 100644 index a45a17cb28..0000000000 --- a/changelog.d/10449.feature +++ /dev/null @@ -1 +0,0 @@ -Support [MSC3289: room version 8](https://github.com/matrix-org/matrix-doc/pull/3289). diff --git a/changelog.d/10543.doc b/changelog.d/10543.doc deleted file mode 100644 index 6c06722eb4..0000000000 --- a/changelog.d/10543.doc +++ /dev/null @@ -1 +0,0 @@ -Fix broken links in `upgrade.md`. Contributed by @dklimpel. diff --git a/debian/changelog b/debian/changelog index c523101f9a..7b44341bc6 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.40.0~rc3) stable; urgency=medium + + * New synapse release 1.40.0~rc3. + + -- Synapse Packaging team Mon, 09 Aug 2021 13:41:08 +0100 + matrix-synapse-py3 (1.40.0~rc2) stable; urgency=medium * New synapse release 1.40.0~rc2. diff --git a/synapse/__init__.py b/synapse/__init__.py index da52463531..5cca899f7d 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.40.0rc2" +__version__ = "1.40.0rc3" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From 6b61debf5cf571ae9e230b102c758865eee2a788 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Aug 2021 18:21:04 +0200 Subject: Do not remove `status_msg` when user going offline (#10550) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10550.bugfix | 1 + synapse/handlers/presence.py | 11 +-- tests/handlers/test_presence.py | 163 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10550.bugfix (limited to 'synapse') diff --git a/changelog.d/10550.bugfix b/changelog.d/10550.bugfix new file mode 100644 index 0000000000..2e1b7c8bbb --- /dev/null +++ b/changelog.d/10550.bugfix @@ -0,0 +1 @@ +Fix longstanding bug which caused the user "status" to be reset when the user went offline. Contributed by @dklimpel. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 016c5df2ca..7ca14e1d84 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1184,8 +1184,7 @@ class PresenceHandler(BasePresenceHandler): new_fields = {"state": presence} if not ignore_status_msg: - msg = status_msg if presence != PresenceState.OFFLINE else None - new_fields["status_msg"] = msg + new_fields["status_msg"] = status_msg if presence == PresenceState.ONLINE or ( presence == PresenceState.BUSY and self._busy_presence_enabled @@ -1478,7 +1477,7 @@ def format_user_presence_state( content["user_id"] = state.user_id if state.last_active_ts: content["last_active_ago"] = now - state.last_active_ts - if state.status_msg and state.state != PresenceState.OFFLINE: + if state.status_msg: content["status_msg"] = state.status_msg if state.state == PresenceState.ONLINE: content["currently_active"] = state.currently_active @@ -1840,9 +1839,7 @@ def handle_timeout( # don't set them as offline. sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) if now - sync_or_active > SYNC_ONLINE_TIMEOUT: - state = state.copy_and_replace( - state=PresenceState.OFFLINE, status_msg=None - ) + state = state.copy_and_replace(state=PresenceState.OFFLINE) changed = True else: # We expect to be poked occasionally by the other side. @@ -1850,7 +1847,7 @@ def handle_timeout( # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: # The other side seems to have disappeared. - state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None) + state = state.copy_and_replace(state=PresenceState.OFFLINE) changed = True return state if changed else None diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 18e92e90d7..29845a80da 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Optional from unittest.mock import Mock, call from signedjson.key import generate_signing_key @@ -339,8 +339,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): class PresenceTimeoutTestCase(unittest.TestCase): + """Tests different timers and that the timer does not change `status_msg` of user.""" + def test_idle_timer(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -348,12 +351,14 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) + self.assertEquals(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -361,6 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): presence state into unavailable. """ user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -368,15 +374,18 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.BUSY, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.BUSY) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -384,15 +393,18 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=0, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -400,6 +412,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=now - SYNC_ONLINE_TIMEOUT - 1, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -408,9 +421,11 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.ONLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -419,12 +434,13 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state, new_state) + self.assertEquals(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -444,6 +460,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_federation_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -452,6 +469,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -460,9 +478,11 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_last_active(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -471,6 +491,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now - LAST_ACTIVE_GRANULARITY - 1, last_user_sync_ts=now, last_federation_update_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) @@ -516,6 +537,144 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): ) self.assertEqual(state.state, PresenceState.OFFLINE) + def test_user_goes_offline_by_timeout_status_msg_remain(self): + """Test that if a user doesn't update the records for a while + users presence goes `OFFLINE` because of timeout and `status_msg` remains. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Check that if we wait a while without telling the handler the user has + # stopped syncing that their presence state doesn't get timed out. + self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.status_msg, status_msg) + + # Check that if the timeout fires, then the syncing user gets timed out + self.reactor.advance(SYNC_ONLINE_TIMEOUT) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + # status_msg should remain even after going offline + self.assertEqual(state.state, PresenceState.OFFLINE) + self.assertEqual(state.status_msg, status_msg) + + def test_user_goes_offline_manually_with_no_status_msg(self): + """Test that if a user change presence manually to `OFFLINE` + and no status is set, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as offline + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), {"presence": PresenceState.OFFLINE} + ) + ) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + self.assertEqual(state.status_msg, None) + + def test_user_goes_offline_manually_with_status_msg(self): + """Test that if a user change presence manually to `OFFLINE` + and a status is set, that `status_msg` appears. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as offline + self._set_presencestate_with_status_msg( + user_id, PresenceState.OFFLINE, "And now here." + ) + + def test_user_reset_online_with_no_status(self): + """Test that if a user set again the presence manually + and no status is set, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as online again + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), {"presence": PresenceState.ONLINE} + ) + ) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + # status_msg should remain even after going offline + self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.status_msg, None) + + def test_set_presence_with_status_msg_none(self): + """Test that if a user set again the presence manually + and status is `None`, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as online and `status_msg = None` + self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) + + def _set_presencestate_with_status_msg( + self, user_id: str, state: PresenceState, status_msg: Optional[str] + ): + """Set a PresenceState and status_msg and check the result. + + Args: + user_id: User for that the status is to be set. + PresenceState: The new PresenceState. + status_msg: Status message that is to be set. + """ + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), + {"presence": state, "status_msg": status_msg}, + ) + ) + + new_state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(new_state.state, state) + self.assertEqual(new_state.status_msg, status_msg) + class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): -- cgit 1.5.1 From 7afb615839a2df05d39f87718016d278ebdadf5c Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 9 Aug 2021 20:23:31 -0500 Subject: When redacting, keep event fields around that maintain the historical event structure intact (MSC2716) (#10538) * Keep event fields that maintain the historical event structure intact Fix https://github.com/matrix-org/synapse/issues/10521 * Add changelog * Bump room version * Better changelog text * Fix up room version after develop merge --- changelog.d/10538.feature | 1 + synapse/api/room_versions.py | 37 ++++++++++++++++++++++++++++++++----- synapse/events/utils.py | 8 +++++++- 3 files changed, 40 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10538.feature (limited to 'synapse') diff --git a/changelog.d/10538.feature b/changelog.d/10538.feature new file mode 100644 index 0000000000..120c8e8ca0 --- /dev/null +++ b/changelog.d/10538.feature @@ -0,0 +1 @@ +Add support for new redaction rules for historical events specified in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index f32a40ba4a..11280c4462 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -76,6 +76,8 @@ class RoomVersion: # MSC2716: Adds m.room.power_levels -> content.historical field to control # whether "insertion", "chunk", "marker" events can be sent msc2716_historical = attr.ib(type=bool) + # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events + msc2716_redactions = attr.ib(type=bool) class RoomVersions: @@ -92,6 +94,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V2 = RoomVersion( "2", @@ -106,6 +109,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V3 = RoomVersion( "3", @@ -120,6 +124,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V4 = RoomVersion( "4", @@ -134,6 +139,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V5 = RoomVersion( "5", @@ -148,6 +154,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V6 = RoomVersion( "6", @@ -162,6 +169,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -176,6 +184,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V7 = RoomVersion( "7", @@ -190,6 +199,22 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=True, msc2716_historical=False, + msc2716_redactions=False, + ) + V8 = RoomVersion( + "8", + RoomDisposition.STABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, + msc2403_knocking=True, + msc2716_historical=False, + msc2716_redactions=False, ) MSC2716 = RoomVersion( "org.matrix.msc2716", @@ -204,10 +229,11 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=True, msc2716_historical=True, + msc2716_redactions=False, ) - V8 = RoomVersion( - "8", - RoomDisposition.STABLE, + MSC2716v2 = RoomVersion( + "org.matrix.msc2716v2", + RoomDisposition.UNSTABLE, EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=True, @@ -215,9 +241,10 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=False, - msc3083_join_rules=True, + msc3083_join_rules=False, msc2403_knocking=True, - msc2716_historical=False, + msc2716_historical=True, + msc2716_redactions=True, ) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a0c07f62f4..b6da2f60af 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -17,7 +17,7 @@ from typing import Any, Mapping, Union from frozendict import frozendict -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results @@ -135,6 +135,12 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: add_fields("history_visibility") elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules: add_fields("redacts") + elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION: + add_fields(EventContentFields.MSC2716_NEXT_CHUNK_ID) + elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_CHUNK: + add_fields(EventContentFields.MSC2716_CHUNK_ID) + elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER: + add_fields(EventContentFields.MSC2716_MARKER_INSERTION) allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys} -- cgit 1.5.1 From 9f7c038272318bab09535e85e6bb4345ed2f1368 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 10 Aug 2021 13:50:58 +0100 Subject: 1.40.0 --- CHANGES.md | 6 ++++++ debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index b04abbeb4d..0e5e052951 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +Synapse 1.40.0 (2021-08-10) +=========================== + +No significant changes. + + Synapse 1.40.0rc3 (2021-08-09) ============================== diff --git a/debian/changelog b/debian/changelog index 7b44341bc6..d3da448b0f 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.40.0) stable; urgency=medium + + * New synapse release 1.40.0. + + -- Synapse Packaging team Tue, 10 Aug 2021 13:50:48 +0100 + matrix-synapse-py3 (1.40.0~rc3) stable; urgency=medium * New synapse release 1.40.0~rc3. diff --git a/synapse/__init__.py b/synapse/__init__.py index 5cca899f7d..919293cd80 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.40.0rc3" +__version__ = "1.40.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From 691593bf719edb4c8b0d7a6bee95fcb41d0c56ae Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 10 Aug 2021 10:56:54 -0400 Subject: Fix an edge-case with invited rooms over federation in the spaces summary. (#10560) If a room which the requesting user was invited to was queried over federation it will now properly appear in the spaces summary (instead of being stripped out by the requesting server). --- changelog.d/10560.feature | 1 + synapse/handlers/space_summary.py | 93 ++++++++++++++++-------------- tests/handlers/test_space_summary.py | 106 ++++++++++++++++++++++++++++------- 3 files changed, 138 insertions(+), 62 deletions(-) create mode 100644 changelog.d/10560.feature (limited to 'synapse') diff --git a/changelog.d/10560.feature b/changelog.d/10560.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10560.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 2517f278b6..d04afe6c31 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -158,48 +158,10 @@ class SpaceSummaryHandler: room = room_entry.room fed_room_id = room_entry.room_id - # The room should only be included in the summary if: - # a. the user is in the room; - # b. the room is world readable; or - # c. the user could join the room, e.g. the join rules - # are set to public or the user is in a space that - # has been granted access to the room. - # - # Note that we know the user is not in the root room (which is - # why the remote call was made in the first place), but the user - # could be in one of the children rooms and we just didn't know - # about the link. - - # The API doesn't return the room version so assume that a - # join rule of knock is valid. - include_room = ( - room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) - or room.get("world_readable") is True - ) - - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_room_ids") or room.get( - "allowed_spaces" - ) - if ( - not include_room - and allowed_rooms - and isinstance(allowed_rooms, list) - ): - include_room = await self._event_auth_handler.is_user_in_rooms( - allowed_rooms, requester - ) - - # Finally, if this isn't the requested room, check ourselves - # if we can access the room. - if not include_room and fed_room_id != queue_entry.room_id: - include_room = await self._is_room_accessible( - fed_room_id, requester, None - ) - # The user can see the room, include it! - if include_room: + if await self._is_remote_room_accessible( + requester, fed_room_id, room + ): # Before returning to the client, remove the allowed_room_ids # and allowed_spaces keys. room.pop("allowed_room_ids", None) @@ -336,7 +298,7 @@ class SpaceSummaryHandler: Returns: A room entry if the room should be returned. None, otherwise. """ - if not await self._is_room_accessible(room_id, requester, origin): + if not await self._is_local_room_accessible(room_id, requester, origin): return None room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) @@ -438,7 +400,7 @@ class SpaceSummaryHandler: return results - async def _is_room_accessible( + async def _is_local_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] ) -> bool: """ @@ -550,6 +512,51 @@ class SpaceSummaryHandler: ) return False + async def _is_remote_room_accessible( + self, requester: str, room_id: str, room: JsonDict + ) -> bool: + """ + Calculate whether the room received over federation should be shown in the spaces summary. + + It should be included if: + + * The requester is joined or can join the room (per MSC3173). + * The history visibility is set to world readable. + + Note that the local server is not in the requested room (which is why the + remote call was made in the first place), but the user could have access + due to an invite, etc. + + Args: + requester: The user requesting the summary. + room_id: The room ID returned over federation. + room: The summary of the child room returned over federation. + + Returns: + True if the room should be included in the spaces summary. + """ + # The API doesn't return the room version so assume that a + # join rule of knock is valid. + if ( + room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) + or room.get("world_readable") is True + ): + return True + + # Check if the user is a member of any of the allowed spaces + # from the response. + allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") + if allowed_rooms and isinstance(allowed_rooms, list): + if await self._event_auth_handler.is_user_in_rooms( + allowed_rooms, requester + ): + return True + + # Finally, check locally if we can access the room. The user might + # already be in the room (if it was a child room), or there might be a + # pending invite, etc. + return await self._is_local_room_accessible(room_id, requester, None) + async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: """ Generate en entry suitable for the 'rooms' list in the summary response. diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 6cc1a02e12..f470c81ea2 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -30,7 +30,7 @@ from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEn from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.server import HomeServer -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from tests import unittest @@ -149,6 +149,36 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): events, ) + def _poke_fed_invite(self, room_id: str, from_user: str) -> None: + """ + Creates a invite (as if received over federation) for the room from the + given hostname. + + Args: + room_id: The room ID to issue an invite for. + fed_hostname: The user to invite from. + """ + # Poke an invite over federation into the database. + fed_handler = self.hs.get_federation_handler() + fed_hostname = UserID.from_string(from_user).domain + event = make_event_from_dict( + { + "room_id": room_id, + "event_id": "!abcd:" + fed_hostname, + "type": EventTypes.Member, + "sender": from_user, + "state_key": self.user, + "content": {"membership": Membership.INVITE}, + "prev_events": [], + "auth_events": [], + "depth": 1, + "origin_server_ts": 1234, + } + ) + self.get_success( + fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) + ) + def test_simple_space(self): """Test a simple space with a single room.""" result = self.get_success(self.handler.get_space_summary(self.user, self.space)) @@ -416,24 +446,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): joined_room = self.helper.create_room_as(self.user, tok=self.token) # Poke an invite over federation into the database. - fed_handler = self.hs.get_federation_handler() - event = make_event_from_dict( - { - "room_id": invited_room, - "event_id": "!abcd:" + fed_hostname, - "type": EventTypes.Member, - "sender": "@remote:" + fed_hostname, - "state_key": self.user, - "content": {"membership": Membership.INVITE}, - "prev_events": [], - "auth_events": [], - "depth": 1, - "origin_server_ts": 1234, - } - ) - self.get_success( - fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) - ) + self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms @@ -570,3 +583,58 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (subspace, joined_room), ], ) + + def test_fed_invited(self): + """ + A room which the user was invited to should be included in the response. + + This differs from test_fed_filtering in that the room itself is being + queried over federation, instead of it being included as a sub-room of + a space in the response. + """ + fed_hostname = self.hs.hostname + "2" + fed_room = "#subroom:" + fed_hostname + + # Poke an invite over federation into the database. + self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [ + _RoomEntry( + fed_room, + { + "room_id": fed_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ] + + # Add a room to the space which is on another server. + self._add_child(self.space, fed_room, self.token) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + self._assert_rooms( + result, + [ + self.space, + self.room, + fed_room, + ], + ) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, fed_room), + ], + ) -- cgit 1.5.1 From fe1d0c86180ea025dfb444597c7ad72b036bbb10 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 10 Aug 2021 13:08:17 -0400 Subject: Add local support for the new spaces summary endpoint (MSC2946) (#10549) This adds support for the /hierarchy endpoint, which is an update to MSC2946. Currently this only supports rooms known locally to the homeserver. --- changelog.d/10527.misc | 2 +- changelog.d/10530.misc | 2 +- changelog.d/10549.feature | 1 + synapse/handlers/space_summary.py | 201 +++++++++++++++++- synapse/rest/client/v1/room.py | 41 ++++ tests/handlers/test_space_summary.py | 386 +++++++++++++++++++++++++---------- 6 files changed, 521 insertions(+), 112 deletions(-) create mode 100644 changelog.d/10549.feature (limited to 'synapse') diff --git a/changelog.d/10527.misc b/changelog.d/10527.misc index 3cf22f9daf..ffc4e4289c 100644 --- a/changelog.d/10527.misc +++ b/changelog.d/10527.misc @@ -1 +1 @@ -Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10530.misc b/changelog.d/10530.misc index 3cf22f9daf..ffc4e4289c 100644 --- a/changelog.d/10530.misc +++ b/changelog.d/10530.misc @@ -1 +1 @@ -Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10549.feature b/changelog.d/10549.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10549.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index d04afe6c31..fd76c34695 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -18,7 +18,7 @@ import re from collections import deque from typing import ( TYPE_CHECKING, - Collection, + Deque, Dict, Iterable, List, @@ -38,9 +38,12 @@ from synapse.api.constants import ( Membership, RoomTypes, ) +from synapse.api.errors import Codes, SynapseError from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -57,6 +60,29 @@ MAX_ROOMS_PER_SPACE = 50 MAX_SERVERS_PER_SPACE = 3 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PaginationKey: + """The key used to find unique pagination session.""" + + # The first three entries match the request parameters (and cannot change + # during a pagination session). + room_id: str + suggested_only: bool + max_depth: Optional[int] + # The randomly generated token. + token: str + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PaginationSession: + """The information that is stored for pagination.""" + + # The queue of rooms which are still to process. + room_queue: Deque["_RoomQueueEntry"] + # A set of rooms which have been processed. + processed_rooms: Set[str] + + class SpaceSummaryHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() @@ -67,6 +93,21 @@ class SpaceSummaryHandler: self._server_name = hs.hostname self._federation_client = hs.get_federation_client() + # A map of query information to the current pagination state. + # + # TODO Allow for multiple workers to share this data. + # TODO Expire pagination tokens. + self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {} + + # If a user tries to fetch the same page multiple times in quick succession, + # only process the first attempt and return its result to subsequent requests. + self._pagination_response_cache: ResponseCache[ + Tuple[str, bool, Optional[int], Optional[int], Optional[str]] + ] = ResponseCache( + hs.get_clock(), + "get_room_hierarchy", + ) + async def get_space_summary( self, requester: str, @@ -130,7 +171,7 @@ class SpaceSummaryHandler: requester, None, room_id, suggested_only, max_children ) - events: Collection[JsonDict] = [] + events: Sequence[JsonDict] = [] if room_entry: rooms_result.append(room_entry.room) events = room_entry.children @@ -207,6 +248,154 @@ class SpaceSummaryHandler: return {"rooms": rooms_result, "events": events_result} + async def get_room_hierarchy( + self, + requester: str, + requested_room_id: str, + suggested_only: bool = False, + max_depth: Optional[int] = None, + limit: Optional[int] = None, + from_token: Optional[str] = None, + ) -> JsonDict: + """ + Implementation of the room hierarchy C-S API. + + Args: + requester: The user ID of the user making this request. + requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: Whether we should only return children with the "suggested" + flag set. + max_depth: The maximum depth in the tree to explore, must be a + non-negative integer. + + 0 would correspond to just the root room, 1 would include just + the root room's children, etc. + limit: An optional limit on the number of rooms to return per + page. Must be a positive integer. + from_token: An optional pagination token. + + Returns: + The JSON hierarchy dictionary. + """ + # If a user tries to fetch the same page multiple times in quick succession, + # only process the first attempt and return its result to subsequent requests. + # + # This is due to the pagination process mutating internal state, attempting + # to process multiple requests for the same page will result in errors. + return await self._pagination_response_cache.wrap( + (requested_room_id, suggested_only, max_depth, limit, from_token), + self._get_room_hierarchy, + requester, + requested_room_id, + suggested_only, + max_depth, + limit, + from_token, + ) + + async def _get_room_hierarchy( + self, + requester: str, + requested_room_id: str, + suggested_only: bool = False, + max_depth: Optional[int] = None, + limit: Optional[int] = None, + from_token: Optional[str] = None, + ) -> JsonDict: + """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" + + # first of all, check that the user is in the room in question (or it's + # world-readable) + await self._auth.check_user_in_room_or_world_readable( + requested_room_id, requester + ) + + # If this is continuing a previous session, pull the persisted data. + if from_token: + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, from_token + ) + if pagination_key not in self._pagination_sessions: + raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) + + # Load the previous state. + pagination_session = self._pagination_sessions[pagination_key] + room_queue = pagination_session.room_queue + processed_rooms = pagination_session.processed_rooms + else: + # the queue of rooms to process + room_queue = deque((_RoomQueueEntry(requested_room_id, ()),)) + + # Rooms we have already processed. + processed_rooms = set() + + rooms_result: List[JsonDict] = [] + + # Cap the limit to a server-side maximum. + if limit is None: + limit = MAX_ROOMS + else: + limit = min(limit, MAX_ROOMS) + + # Iterate through the queue until we reach the limit or run out of + # rooms to include. + while room_queue and len(rooms_result) < limit: + queue_entry = room_queue.popleft() + room_id = queue_entry.room_id + current_depth = queue_entry.depth + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + if is_in_room: + room_entry = await self._summarize_local_room( + requester, + None, + room_id, + suggested_only, + # TODO Handle max children. + max_children=None, + ) + + if room_entry: + rooms_result.append(room_entry.as_json()) + + # Add the child to the queue. We have already validated + # that the vias are a list of server names. + # + # If the current depth is the maximum depth, do not queue + # more entries. + if max_depth is None or current_depth < max_depth: + room_queue.extendleft( + _RoomQueueEntry( + ev["state_key"], ev["content"]["via"], current_depth + 1 + ) + for ev in reversed(room_entry.children) + ) + + processed_rooms.add(room_id) + else: + # TODO Federation. + pass + + result: JsonDict = {"rooms": rooms_result} + + # If there's additional data, generate a pagination token (and persist state). + if room_queue: + next_token = random_string(24) + result["next_token"] = next_token + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, next_token + ) + self._pagination_sessions[pagination_key] = _PaginationSession( + room_queue, processed_rooms + ) + + return result + async def federation_space_summary( self, origin: str, @@ -652,6 +841,7 @@ class SpaceSummaryHandler: class _RoomQueueEntry: room_id: str via: Sequence[str] + depth: int = 0 @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -662,7 +852,12 @@ class _RoomEntry: # An iterable of the sorted, stripped children events for children of this room. # # This may not include all children. - children: Collection[JsonDict] = () + children: Sequence[JsonDict] = () + + def as_json(self) -> JsonDict: + result = dict(self.room) + result["children_state"] = self.children + return result def _has_valid_via(e: EventBase) -> bool: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f887970b76..b28b72bfbd 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -1445,6 +1445,46 @@ class RoomSpaceSummaryRestServlet(RestServlet): ) +class RoomHierarchyRestServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2946" + "/rooms/(?P[^/]*)/hierarchy$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._space_summary_handler = hs.get_space_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + + max_depth = parse_integer(request, "max_depth") + if max_depth is not None and max_depth < 0: + raise SynapseError( + 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON + ) + + limit = parse_integer(request, "limit") + if limit is not None and limit <= 0: + raise SynapseError( + 400, "'limit' must be a positive integer", Codes.BAD_JSON + ) + + return 200, await self._space_summary_handler.get_room_hierarchy( + requester.user.to_string(), + room_id, + suggested_only=parse_boolean(request, "suggested_only", default=False), + max_depth=max_depth, + limit=limit, + from_token=parse_string(request, "from"), + ) + + def register_servlets(hs: "HomeServer", http_server, is_worker=False): msc2716_enabled = hs.config.experimental.msc2716_enabled @@ -1463,6 +1503,7 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) RoomSpaceSummaryRestServlet(hs).register(http_server) + RoomHierarchyRestServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server) RoomAliasListServlet(hs).register(http_server) diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index f470c81ea2..255dd17f86 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -23,7 +23,7 @@ from synapse.api.constants import ( RestrictedJoinRuleTypes, RoomTypes, ) -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry @@ -123,32 +123,83 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.room = self.helper.create_room_as(self.user, tok=self.token) self._add_child(self.space, self.room, self.token) - def _add_child(self, space_id: str, room_id: str, token: str) -> None: + def _add_child( + self, space_id: str, room_id: str, token: str, order: Optional[str] = None + ) -> None: """Add a child room to a space.""" + content = {"via": [self.hs.hostname]} + if order is not None: + content["order"] = order self.helper.send_state( space_id, event_type=EventTypes.SpaceChild, - body={"via": [self.hs.hostname]}, + body=content, tok=token, state_key=room_id, ) - def _assert_rooms(self, result: JsonDict, rooms: Iterable[str]) -> None: - """Assert that the expected room IDs are in the response.""" - self.assertCountEqual([room.get("room_id") for room in result["rooms"]], rooms) - - def _assert_events( - self, result: JsonDict, events: Iterable[Tuple[str, str]] + def _assert_rooms( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] ) -> None: - """Assert that the expected parent / child room IDs are in the response.""" + """ + Assert that the expected room IDs and events are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + if children: + children_ids.extend([(room_id, child_id) for child_id in children]) + self.assertCountEqual( + [room.get("room_id") for room in result["rooms"]], room_ids + ) self.assertCountEqual( [ (event.get("room_id"), event.get("state_key")) for event in result["events"] ], - events, + children_ids, ) + def _assert_hierarchy( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] + ) -> None: + """ + Assert that the expected room IDs are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + result_room_ids = [] + result_children_ids = [] + for result_room in result["rooms"]: + result_room_ids.append(result_room["room_id"]) + result_children_ids.append( + [ + (cs["room_id"], cs["state_key"]) + for cs in result_room.get("children_state") + ] + ) + + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + children_ids.append([(room_id, child_id) for child_id in children]) + + # Note that order matters. + self.assertEqual(result_room_ids, room_ids) + self.assertEqual(result_children_ids, children_ids) + def _poke_fed_invite(self, room_id: str, from_user: str) -> None: """ Creates a invite (as if received over federation) for the room from the @@ -184,8 +235,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have the space and the room in it, along with a link # from space -> room. - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) def test_visibility(self): """A user not in a space cannot inspect it.""" @@ -194,6 +250,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # The user cannot see the space. self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # If the space is made world-readable it should return a result. self.helper.send_state( @@ -203,8 +260,11 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) # Make it not world-readable again and confirm it results in an error. self.helper.send_state( @@ -214,12 +274,15 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # Join the space and results should be returned. self.helper.join(self.space, user2, tok=token2) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) def _create_room_with_join_rule( self, join_rule: str, room_version: Optional[str] = None, **extra_content @@ -290,34 +353,33 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Join the space. self.helper.join(self.space, user2, tok=token2) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - - self._assert_rooms( - result, - [ + expected = [ + ( self.space, - self.room, - public_room, - knock_room, - invited_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, public_room), - (self.space, knock_room), - (self.space, not_invited_room), - (self.space, invited_room), - (self.space, restricted_room), - (self.space, restricted_accessible_room), - (self.space, world_readable_room), - (self.space, joined_room), - ], - ) + [ + self.room, + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (self.room, ()), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) def test_complex_space(self): """ @@ -349,19 +411,145 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should include each room a single time and each link. - self._assert_rooms(result, [self.space, self.room, subspace, subroom]) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, room2), - (self.space, subspace), - (subspace, subroom), - (subspace, self.room), - (subspace, room2), - ], + expected = [ + (self.space, [self.room, room2, subspace]), + (self.room, ()), + (subspace, [subroom, self.room, room2]), + (subroom, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_pagination(self): + """Test simple pagination works.""" + room_ids = [] + for i in range(1, 10): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token, order=str(i)) + room_ids.append(room) + # The room created initially doesn't have an order, so comes last. + room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + # The result should have the space and all of the links, plus some of the + # rooms and a pagination token. + expected = [(self.space, room_ids)] + [ + (room_id, ()) for room_id in room_ids[:6] + ] + self._assert_hierarchy(result, expected) + self.assertIn("next_token", result) + + # Check the next page. + result = self.get_success( + self.handler.get_room_hierarchy( + self.user, self.space, limit=5, from_token=result["next_token"] + ) + ) + # The result should have the space and the room in it, along with a link + # from space -> room. + expected = [(room_id, ()) for room_id in room_ids[6:]] + self._assert_hierarchy(result, expected) + self.assertNotIn("next_token", result) + + def test_invalid_pagination_token(self): + """""" + room_ids = [] + for i in range(1, 10): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token, order=str(i)) + room_ids.append(room) + # The room created initially doesn't have an order, so comes last. + room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + self.assertIn("next_token", result) + + # Changing the room ID, suggested-only, or max-depth causes an error. + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.room, from_token=result["next_token"] + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, + self.space, + suggested_only=True, + from_token=result["next_token"], + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.space, max_depth=0, from_token=result["next_token"] + ), + SynapseError, ) + # An invalid token is ignored. + self.get_failure( + self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), + SynapseError, + ) + + def test_max_depth(self): + """Create a deep tree to test the max depth against.""" + spaces = [self.space] + rooms = [self.room] + for _ in range(5): + spaces.append( + self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": { + EventContentFields.ROOM_TYPE: RoomTypes.SPACE + } + }, + ) + ) + self._add_child(spaces[-2], spaces[-1], self.token) + rooms.append(self.helper.create_room_as(self.user, tok=self.token)) + self._add_child(spaces[-1], rooms[-1], self.token) + + # Test just the space itself. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) + ) + expected = [(spaces[0], [rooms[0], spaces[1]])] + self._assert_hierarchy(result, expected) + + # A single additional layer. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) + ) + expected += [ + (rooms[0], ()), + (spaces[1], [rooms[1], spaces[2]]), + ] + self._assert_hierarchy(result, expected) + + # A few layers. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) + ) + expected += [ + (rooms[1], ()), + (spaces[2], [rooms[2], spaces[3]]), + (rooms[2], ()), + (spaces[3], [rooms[3], spaces[4]]), + ] + self._assert_hierarchy(result, expected) + def test_fed_complex(self): """ Return data over federation and ensure that it is handled properly. @@ -417,15 +605,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.handler.get_space_summary(self.user, self.space) ) - self._assert_rooms(result, [self.space, self.room, subspace, subroom]) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, subspace), - (subspace, subroom), - ], - ) + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + (subspace, [subroom]), + (subroom, ()), + ] + self._assert_rooms(result, expected) def test_fed_filtering(self): """ @@ -554,35 +740,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.handler.get_space_summary(self.user, self.space) ) - self._assert_rooms( - result, - [ - self.space, - self.room, + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + ( subspace, - public_room, - knock_room, - invited_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, subspace), - (subspace, public_room), - (subspace, knock_room), - (subspace, not_invited_room), - (subspace, invited_room), - (subspace, restricted_room), - (subspace, restricted_accessible_room), - (subspace, world_readable_room), - (subspace, joined_room), - ], - ) + [ + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) def test_fed_invited(self): """ @@ -623,18 +804,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.handler.get_space_summary(self.user, self.space) ) - self._assert_rooms( - result, - [ - self.space, - self.room, - fed_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, fed_room), - ], - ) + expected = [ + (self.space, [self.room, fed_room]), + (self.room, ()), + (fed_room, ()), + ] + self._assert_rooms(result, expected) -- cgit 1.5.1 From 8c654b73095a594b36101aa81cf91a8e1bebc29f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 10 Aug 2021 18:10:40 -0500 Subject: Only return state events that the AS passed in via `state_events_at_start` (MSC2716) (#10552) * Only return state events that the AS passed in via state_events_at_start As discovered by @Half-Shot in https://github.com/matrix-org/matrix-doc/pull/2716#discussion_r684158448 Part of MSC2716 * Add changelog * Fix changelog extension --- changelog.d/10552.misc | 1 + synapse/rest/client/v1/room.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10552.misc (limited to 'synapse') diff --git a/changelog.d/10552.misc b/changelog.d/10552.misc new file mode 100644 index 0000000000..fc5f6aea5f --- /dev/null +++ b/changelog.d/10552.misc @@ -0,0 +1 @@ +Update `/batch_send` endpoint to only return `state_events` created by the `state_events_from_before` passed in. diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index b28b72bfbd..f1bc43be2d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -437,6 +437,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): prev_state_ids = list(prev_state_map.values()) auth_event_ids = prev_state_ids + state_events_at_start = [] for state_event in body["state_events_at_start"]: assert_params_in_dict( state_event, ["type", "origin_server_ts", "content", "sender"] @@ -502,6 +503,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): ) event_id = event.event_id + state_events_at_start.append(event_id) auth_event_ids.append(event_id) events_to_create = body["events"] @@ -651,7 +653,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): event_ids.append(base_insertion_event.event_id) return 200, { - "state_events": auth_event_ids, + "state_events": state_events_at_start, "events": event_ids, "next_chunk_id": insertion_event["content"][ EventContentFields.MSC2716_NEXT_CHUNK_ID -- cgit 1.5.1 From 339c3918e1301d53b998c98282137b12d9d16c45 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 11 Aug 2021 16:34:59 +0200 Subject: support federation queries through http connect proxy (#10475) Signed-off-by: Marcus Hoffmann Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10475.feature | 1 + docs/setup/forward_proxy.md | 6 +- docs/upgrade.md | 27 ++ synapse/http/connectproxyclient.py | 68 ++-- synapse/http/federation/matrix_federation_agent.py | 100 ++++- synapse/http/matrixfederationclient.py | 12 +- synapse/http/proxyagent.py | 51 +-- .../federation/test_matrix_federation_agent.py | 406 +++++++++++++++++---- tests/http/test_proxyagent.py | 75 ++-- 9 files changed, 555 insertions(+), 191 deletions(-) create mode 100644 changelog.d/10475.feature (limited to 'synapse') diff --git a/changelog.d/10475.feature b/changelog.d/10475.feature new file mode 100644 index 0000000000..52eab11b03 --- /dev/null +++ b/changelog.d/10475.feature @@ -0,0 +1 @@ +Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. \ No newline at end of file diff --git a/docs/setup/forward_proxy.md b/docs/setup/forward_proxy.md index a0720ab342..494c14893b 100644 --- a/docs/setup/forward_proxy.md +++ b/docs/setup/forward_proxy.md @@ -45,18 +45,18 @@ The proxy will be **used** for: - recaptcha validation - CAS auth validation - OpenID Connect +- Outbound federation - Federation (checking public key revocation) +- Fetching public keys of other servers +- Downloading remote media It will **not be used** for: - Application Services - Identity servers -- Outbound federation - In worker configurations - connections between workers - connections from workers to Redis -- Fetching public keys of other servers -- Downloading remote media ## Troubleshooting diff --git a/docs/upgrade.md b/docs/upgrade.md index ce9167e6de..8831c9d6cf 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -86,6 +86,33 @@ process, for example: ``` +# Upgrading to v1.xx.0 + +## Add support for routing outbound HTTP requests via a proxy for federation + +Since Synapse 1.6.0 (2019-11-26) you can set a proxy for outbound HTTP requests via +http_proxy/https_proxy environment variables. This proxy was set for: +- push +- url previews +- phone-home stats +- recaptcha validation +- CAS auth validation +- OpenID Connect +- Federation (checking public key revocation) + +In this version we have added support for outbound requests for: +- Outbound federation +- Downloading remote media +- Fetching public keys of other servers + +These requests use the same proxy configuration. If you have a proxy configuration we +recommend to verify the configuration. It may be necessary to adjust the `no_proxy` +environment variable. + +See [using a forward proxy with Synapse documentation](setup/forward_proxy.md) for +details. + + # Upgrading to v1.39.0 ## Deprecation of the current third-party rules module interface diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index 17e1c5abb1..c577142268 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging +from typing import Optional +import attr from zope.interface import implementer from twisted.internet import defer, protocol @@ -21,7 +24,6 @@ from twisted.internet.error import ConnectError from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint from twisted.internet.protocol import ClientFactory, Protocol, connectionDone from twisted.web import http -from twisted.web.http_headers import Headers logger = logging.getLogger(__name__) @@ -30,6 +32,22 @@ class ProxyConnectError(ConnectError): pass +@attr.s +class ProxyCredentials: + username_password = attr.ib(type=bytes) + + def as_proxy_authorization_value(self) -> bytes: + """ + Return the value for a Proxy-Authorization header (i.e. 'Basic abdef=='). + + Returns: + A transformation of the authentication string the encoded value for + a Proxy-Authorization header. + """ + # Encode as base64 and prepend the authorization type + return b"Basic " + base64.encodebytes(self.username_password) + + @implementer(IStreamClientEndpoint) class HTTPConnectProxyEndpoint: """An Endpoint implementation which will send a CONNECT request to an http proxy @@ -46,7 +64,7 @@ class HTTPConnectProxyEndpoint: proxy_endpoint: the endpoint to use to connect to the proxy host: hostname that we want to CONNECT to port: port that we want to connect to - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -55,20 +73,20 @@ class HTTPConnectProxyEndpoint: proxy_endpoint: IStreamClientEndpoint, host: bytes, port: int, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self._reactor = reactor self._proxy_endpoint = proxy_endpoint self._host = host self._port = port - self._headers = headers + self._proxy_creds = proxy_creds def __repr__(self): return "" % (self._proxy_endpoint,) def connect(self, protocolFactory: ClientFactory): f = HTTPProxiedClientFactory( - self._host, self._port, protocolFactory, self._headers + self._host, self._port, protocolFactory, self._proxy_creds ) d = self._proxy_endpoint.connect(f) # once the tcp socket connects successfully, we need to wait for the @@ -87,7 +105,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): dst_host: hostname that we want to CONNECT to dst_port: port that we want to connect to wrapped_factory: The original Factory - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -95,12 +113,12 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): dst_host: bytes, dst_port: int, wrapped_factory: ClientFactory, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self.dst_host = dst_host self.dst_port = dst_port self.wrapped_factory = wrapped_factory - self.headers = headers + self.proxy_creds = proxy_creds self.on_connection = defer.Deferred() def startedConnecting(self, connector): @@ -114,7 +132,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): self.dst_port, wrapped_protocol, self.on_connection, - self.headers, + self.proxy_creds, ) def clientConnectionFailed(self, connector, reason): @@ -145,7 +163,7 @@ class HTTPConnectProtocol(protocol.Protocol): connected_deferred: a Deferred which will be callbacked with wrapped_protocol when the CONNECT completes - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -154,16 +172,16 @@ class HTTPConnectProtocol(protocol.Protocol): port: int, wrapped_protocol: Protocol, connected_deferred: defer.Deferred, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self.host = host self.port = port self.wrapped_protocol = wrapped_protocol self.connected_deferred = connected_deferred - self.headers = headers + self.proxy_creds = proxy_creds self.http_setup_client = HTTPConnectSetupClient( - self.host, self.port, self.headers + self.host, self.port, self.proxy_creds ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) @@ -205,30 +223,38 @@ class HTTPConnectSetupClient(http.HTTPClient): Args: host: The hostname to send in the CONNECT message port: The port to send in the CONNECT message - headers: Extra headers to send with the CONNECT message + proxy_creds: credentials to authenticate at proxy """ - def __init__(self, host: bytes, port: int, headers: Headers): + def __init__( + self, + host: bytes, + port: int, + proxy_creds: Optional[ProxyCredentials], + ): self.host = host self.port = port - self.headers = headers + self.proxy_creds = proxy_creds self.on_connected = defer.Deferred() def connectionMade(self): logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) - # Send any additional specified headers - for name, values in self.headers.getAllRawHeaders(): - for value in values: - self.sendHeader(name, value) + # Determine whether we need to set Proxy-Authorization headers + if self.proxy_creds: + # Set a Proxy-Authorization header + self.sendHeader( + b"Proxy-Authorization", + self.proxy_creds.as_proxy_authorization_value(), + ) self.endHeaders() def handleStatus(self, version: bytes, status: bytes, message: bytes): logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": - raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}") def handleEndHeaders(self): logger.debug("End Headers") diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index c16b7f10e6..1238bfd287 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,6 +14,10 @@ import logging import urllib.parse from typing import Any, Generator, List, Optional +from urllib.request import ( # type: ignore[attr-defined] + getproxies_environment, + proxy_bypass_environment, +) from netaddr import AddrFormatError, IPAddress, IPSet from zope.interface import implementer @@ -30,9 +34,12 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse from synapse.crypto.context_factory import FederationPolicyForHTTPS -from synapse.http.client import BlacklistingAgentWrapper +from synapse.http import proxyagent +from synapse.http.client import BlacklistingAgentWrapper, BlacklistingReactorWrapper +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import ISynapseReactor from synapse.util import Clock @@ -57,6 +64,14 @@ class MatrixFederationAgent: user_agent: The user agent header to use for federation requests. + ip_whitelist: Allowed IP addresses. + + ip_blacklist: Disallowed IP addresses. + + proxy_reactor: twisted reactor to use for connections to the proxy server + reactor might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. + _srv_resolver: SrvResolver implementation to use for looking up SRV records. None to use a default implementation. @@ -71,11 +86,18 @@ class MatrixFederationAgent: reactor: ISynapseReactor, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, + ip_whitelist: IPSet, ip_blacklist: IPSet, _srv_resolver: Optional[SrvResolver] = None, _well_known_resolver: Optional[WellKnownResolver] = None, ): - self._reactor = reactor + # proxy_reactor is not blacklisted + proxy_reactor = reactor + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. + reactor = BlacklistingReactorWrapper(reactor, ip_whitelist, ip_blacklist) + self._clock = Clock(reactor) self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False @@ -83,24 +105,27 @@ class MatrixFederationAgent: self._pool.cachedConnectionTimeout = 2 * 60 self._agent = Agent.usingEndpointFactory( - self._reactor, + reactor, MatrixHostnameEndpointFactory( - reactor, tls_client_options_factory, _srv_resolver + reactor, + proxy_reactor, + tls_client_options_factory, + _srv_resolver, ), pool=self._pool, ) self.user_agent = user_agent if _well_known_resolver is None: - # Note that the name resolver has already been wrapped in a - # IPBlacklistingResolver by MatrixFederationHttpClient. _well_known_resolver = WellKnownResolver( - self._reactor, + reactor, agent=BlacklistingAgentWrapper( - Agent( - self._reactor, + ProxyAgent( + reactor, + proxy_reactor, pool=self._pool, contextFactory=tls_client_options_factory, + use_proxy=True, ), ip_blacklist=ip_blacklist, ), @@ -200,10 +225,12 @@ class MatrixHostnameEndpointFactory: def __init__( self, reactor: IReactorCore, + proxy_reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], srv_resolver: Optional[SrvResolver], ): self._reactor = reactor + self._proxy_reactor = proxy_reactor self._tls_client_options_factory = tls_client_options_factory if srv_resolver is None: @@ -211,9 +238,10 @@ class MatrixHostnameEndpointFactory: self._srv_resolver = srv_resolver - def endpointForURI(self, parsed_uri): + def endpointForURI(self, parsed_uri: URI): return MatrixHostnameEndpoint( self._reactor, + self._proxy_reactor, self._tls_client_options_factory, self._srv_resolver, parsed_uri, @@ -227,23 +255,45 @@ class MatrixHostnameEndpoint: Args: reactor: twisted reactor to use for underlying requests + proxy_reactor: twisted reactor to use for connections to the proxy server. + 'reactor' might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. srv_resolver: The SRV resolver to use parsed_uri: The parsed URI that we're wanting to connect to. + + Raises: + ValueError if the environment variables contain an invalid proxy specification. + RuntimeError if no tls_options_factory is given for a https connection """ def __init__( self, reactor: IReactorCore, + proxy_reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], srv_resolver: SrvResolver, parsed_uri: URI, ): self._reactor = reactor - self._parsed_uri = parsed_uri + # http_proxy is not needed because federation is always over TLS + proxies = getproxies_environment() + https_proxy = proxies["https"].encode() if "https" in proxies else None + self.no_proxy = proxies["no"] if "no" in proxies else None + + # endpoint and credentials to use to connect to the outbound https proxy, if any. + ( + self._https_proxy_endpoint, + self._https_proxy_creds, + ) = proxyagent.http_proxy_endpoint( + https_proxy, + proxy_reactor, + tls_client_options_factory, + ) + # set up the TLS connection params # # XXX disabling TLS is really only supported here for the benefit of the @@ -273,9 +323,33 @@ class MatrixHostnameEndpoint: host = server.host port = server.port + should_skip_proxy = False + if self.no_proxy is not None: + should_skip_proxy = proxy_bypass_environment( + host.decode(), + proxies={"no": self.no_proxy}, + ) + + endpoint: IStreamClientEndpoint try: - logger.debug("Connecting to %s:%i", host.decode("ascii"), port) - endpoint = HostnameEndpoint(self._reactor, host, port) + if self._https_proxy_endpoint and not should_skip_proxy: + logger.debug( + "Connecting to %s:%i via %s", + host.decode("ascii"), + port, + self._https_proxy_endpoint, + ) + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self._https_proxy_endpoint, + host, + port, + proxy_creds=self._https_proxy_creds, + ) + else: + logger.debug("Connecting to %s:%i", host.decode("ascii"), port) + # not using a proxy + endpoint = HostnameEndpoint(self._reactor, host, port) if self._tls_options: endpoint = wrapClientTLS(self._tls_options, endpoint) result = await make_deferred_yieldable( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2efa15bf04..2e9898997c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -59,7 +59,6 @@ from synapse.api.errors import ( from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( BlacklistingAgentWrapper, - BlacklistingReactorWrapper, BodyExceededMaxSize, ByteWriteable, encode_query_args, @@ -69,7 +68,7 @@ from synapse.http.federation.matrix_federation_agent import MatrixFederationAgen from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags -from synapse.types import ISynapseReactor, JsonDict +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -325,13 +324,7 @@ class MatrixFederationHttpClient: self.signing_key = hs.signing_key self.server_name = hs.hostname - # We need to use a DNS resolver which filters out blacklisted IP - # addresses, to prevent DNS rebinding. - self.reactor: ISynapseReactor = BlacklistingReactorWrapper( - hs.get_reactor(), - hs.config.federation_ip_range_whitelist, - hs.config.federation_ip_range_blacklist, - ) + self.reactor = hs.get_reactor() user_agent = hs.version_string if hs.config.user_agent_suffix: @@ -342,6 +335,7 @@ class MatrixFederationHttpClient: self.reactor, tls_client_options_factory, user_agent, + hs.config.federation_ip_range_whitelist, hs.config.federation_ip_range_blacklist, ) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 19e987f118..a3f31452d0 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import base64 import logging import re from typing import Any, Dict, Optional, Tuple @@ -21,7 +20,6 @@ from urllib.request import ( # type: ignore[attr-defined] proxy_bypass_environment, ) -import attr from zope.interface import implementer from twisted.internet import defer @@ -38,7 +36,7 @@ from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS -from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -46,22 +44,6 @@ logger = logging.getLogger(__name__) _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") -@attr.s -class ProxyCredentials: - username_password = attr.ib(type=bytes) - - def as_proxy_authorization_value(self) -> bytes: - """ - Return the value for a Proxy-Authorization header (i.e. 'Basic abdef=='). - - Returns: - A transformation of the authentication string the encoded value for - a Proxy-Authorization header. - """ - # Encode as base64 and prepend the authorization type - return b"Basic " + base64.encodebytes(self.username_password) - - @implementer(IAgent) class ProxyAgent(_AgentBase): """An Agent implementation which will use an HTTP proxy if one was requested @@ -95,6 +77,7 @@ class ProxyAgent(_AgentBase): Raises: ValueError if use_proxy is set and the environment variables contain an invalid proxy specification. + RuntimeError if no tls_options_factory is given for a https connection """ def __init__( @@ -131,11 +114,11 @@ class ProxyAgent(_AgentBase): https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None - self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint( + self.http_proxy_endpoint, self.http_proxy_creds = http_proxy_endpoint( http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs ) - self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint( + self.https_proxy_endpoint, self.https_proxy_creds = http_proxy_endpoint( https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs ) @@ -224,22 +207,12 @@ class ProxyAgent(_AgentBase): and self.https_proxy_endpoint and not should_skip_proxy ): - connect_headers = Headers() - - # Determine whether we need to set Proxy-Authorization headers - if self.https_proxy_creds: - # Set a Proxy-Authorization header - connect_headers.addRawHeader( - b"Proxy-Authorization", - self.https_proxy_creds.as_proxy_authorization_value(), - ) - endpoint = HTTPConnectProxyEndpoint( self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, - headers=connect_headers, + self.https_proxy_creds, ) else: # not using a proxy @@ -268,10 +241,10 @@ class ProxyAgent(_AgentBase): ) -def _http_proxy_endpoint( +def http_proxy_endpoint( proxy: Optional[bytes], reactor: IReactorCore, - tls_options_factory: IPolicyForHTTPS, + tls_options_factory: Optional[IPolicyForHTTPS], **kwargs, ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: """Parses an http proxy setting and returns an endpoint for the proxy @@ -294,6 +267,7 @@ def _http_proxy_endpoint( Raise: ValueError if proxy has no hostname or unsupported scheme. + RuntimeError if no tls_options_factory is given for a https connection """ if proxy is None: return None, None @@ -305,8 +279,13 @@ def _http_proxy_endpoint( proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs) if scheme == b"https": - tls_options = tls_options_factory.creatorForNetloc(host, port) - proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) + if tls_options_factory: + tls_options = tls_options_factory.creatorForNetloc(host, port) + proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) + else: + raise RuntimeError( + f"No TLS options for a https connection via proxy {proxy!s}" + ) return proxy_endpoint, credentials diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index a37bce08c3..992d8f94fd 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging -from typing import Optional -from unittest.mock import Mock +import os +from typing import Iterable, Optional +from unittest.mock import Mock, patch import treq from netaddr import IPSet @@ -22,11 +24,12 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions +from twisted.internet.interfaces import IProtocolFactory from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent -from twisted.web.http import HTTPChannel +from twisted.web.http import HTTPChannel, Request from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS @@ -49,24 +52,6 @@ from tests.utils import default_config logger = logging.getLogger(__name__) -test_server_connection_factory = None - - -def get_connection_factory(): - # this needs to happen once, but not until we are ready to run the first test - global test_server_connection_factory - if test_server_connection_factory is None: - test_server_connection_factory = TestServerTLSConnectionFactory( - sanlist=[ - b"DNS:testserv", - b"DNS:target-server", - b"DNS:xn--bcher-kva.com", - b"IP:1.2.3.4", - b"IP:::1", - ] - ) - return test_server_connection_factory - # Once Async Mocks or lambdas are supported this can go away. def generate_resolve_service(result): @@ -100,24 +85,38 @@ class MatrixFederationAgentTests(unittest.TestCase): had_well_known_cache=self.had_well_known_cache, ) - self.agent = MatrixFederationAgent( - reactor=self.reactor, - tls_client_options_factory=self.tls_factory, - user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. - ip_blacklist=IPSet(), - _srv_resolver=self.mock_resolver, - _well_known_resolver=self.well_known_resolver, - ) - - def _make_connection(self, client_factory, expected_sni): + def _make_connection( + self, + client_factory: IProtocolFactory, + ssl: bool = True, + expected_sni: bytes = None, + tls_sanlist: Optional[Iterable[bytes]] = None, + ) -> HTTPChannel: """Builds a test server, and completes the outgoing client connection + Args: + client_factory: the the factory that the + application is trying to use to make the outbound connection. We will + invoke it to build the client Protocol + + ssl: If true, we will expect an ssl connection and wrap + server_factory with a TLSMemoryBIOFactory + False is set only for when proxy expect http connection. + Otherwise federation requests use always https. + + expected_sni: the expected SNI value + + tls_sanlist: list of SAN entries for the TLS cert presented by the server. Returns: - HTTPChannel: the test server + the server Protocol returned by server_factory """ # build the test server - server_tls_protocol = _build_test_server(get_connection_factory()) + server_factory = _get_test_protocol_factory() + if ssl: + server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) + + server_protocol = server_factory.buildProtocol(None) # now, tell the client protocol factory to build the client protocol (it will be a # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an @@ -128,35 +127,39 @@ class MatrixFederationAgentTests(unittest.TestCase): # stubbing that out here. client_protocol = client_factory.buildProtocol(None) client_protocol.makeConnection( - FakeTransport(server_tls_protocol, self.reactor, client_protocol) + FakeTransport(server_protocol, self.reactor, client_protocol) ) - # tell the server tls protocol to send its stuff back to the client, too - server_tls_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, server_tls_protocol) + # tell the server protocol to send its stuff back to the client, too + server_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_protocol) ) - # grab a hold of the TLS connection, in case it gets torn down - server_tls_connection = server_tls_protocol._tlsConnection - - # fish the test server back out of the server-side TLS protocol. - http_protocol = server_tls_protocol.wrappedProtocol + if ssl: + # fish the test server back out of the server-side TLS protocol. + http_protocol = server_protocol.wrappedProtocol + # grab a hold of the TLS connection, in case it gets torn down + tls_connection = server_protocol._tlsConnection + else: + http_protocol = server_protocol + tls_connection = None - # give the reactor a pump to get the TLS juices flowing. - self.reactor.pump((0.1,)) + # give the reactor a pump to get the TLS juices flowing (if needed) + self.reactor.advance(0) # check the SNI - server_name = server_tls_connection.get_servername() - self.assertEqual( - server_name, - expected_sni, - "Expected SNI %s but got %s" % (expected_sni, server_name), - ) + if expected_sni is not None: + server_name = tls_connection.get_servername() + self.assertEqual( + server_name, + expected_sni, + f"Expected SNI {expected_sni!s} but got {server_name!s}", + ) return http_protocol @defer.inlineCallbacks - def _make_get_request(self, uri): + def _make_get_request(self, uri: bytes): """ Sends a simple GET request via the agent, and checks its logcontext management """ @@ -180,20 +183,20 @@ class MatrixFederationAgentTests(unittest.TestCase): def _handle_well_known_connection( self, - client_factory, - expected_sni, - content, + client_factory: IProtocolFactory, + expected_sni: bytes, + content: bytes, response_headers: Optional[dict] = None, - ): + ) -> HTTPChannel: """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. Args: - client_factory (IProtocolFactory): outgoing connection - expected_sni (bytes): SNI that we expect the outgoing connection to send - content (bytes): content to send back as the .well-known + client_factory: outgoing connection + expected_sni: SNI that we expect the outgoing connection to send + content: content to send back as the .well-known Returns: - HTTPChannel: server impl + server impl """ # make the connection for .well-known well_known_server = self._make_connection( @@ -209,7 +212,10 @@ class MatrixFederationAgentTests(unittest.TestCase): return well_known_server def _send_well_known_response( - self, request, content, headers: Optional[dict] = None + self, + request: Request, + content: bytes, + headers: Optional[dict] = None, ): """Check that an incoming request looks like a valid .well-known request, and send back the response. @@ -225,10 +231,37 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) - def test_get(self): + def _make_agent(self) -> MatrixFederationAgent: """ - happy-path test of a GET request with an explicit port + If a proxy server is set, the MatrixFederationAgent must be created again + because it is created too early during setUp """ + return MatrixFederationAgent( + reactor=self.reactor, + tls_client_options_factory=self.tls_factory, + user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. + ip_whitelist=IPSet(), + ip_blacklist=IPSet(), + _srv_resolver=self.mock_resolver, + _well_known_resolver=self.well_known_resolver, + ) + + def test_get(self): + """happy-path test of a GET request with an explicit port""" + self._do_get() + + @patch.dict( + os.environ, + {"https_proxy": "proxy.com", "no_proxy": "testserv"}, + ) + def test_get_bypass_proxy(self): + """test of a GET request with an explicit port and bypass proxy""" + self._do_get() + + def _do_get(self): + """test of a GET request with an explicit port""" + self.agent = self._make_agent() + self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") @@ -282,10 +315,188 @@ class MatrixFederationAgentTests(unittest.TestCase): json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) + @patch.dict( + os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"} + ) + def test_get_via_http_proxy(self): + """test for federation request through a http proxy""" + self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"}, + ) + def test_get_via_http_proxy_with_auth(self): + """test for federation request through a http proxy with authentication""" + self._do_get_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"user:pass" + ) + + @patch.dict( + os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} + ) + def test_get_via_https_proxy(self): + """test for federation request through a https proxy""" + self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"}, + ) + def test_get_via_https_proxy_with_auth(self): + """test for federation request through a https proxy with authentication""" + self._do_get_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"user:pass" + ) + + def _do_get_via_proxy( + self, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, + ): + """Send a https federation request via an agent and check that it is correctly + received at the proxy and client. The proxy can use either http or https. + Args: + expect_proxy_ssl: True if we expect the request to connect to the proxy via https. + expected_auth_credentials: credentials we expect to be presented to authenticate at the proxy + """ + self.agent = self._make_agent() + + self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["proxy.com"] = "9.9.9.9" + test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + # make sure we are connecting to the proxy + self.assertEqual(host, "9.9.9.9") + self.assertEqual(port, 1080) + + # make a test server to act as the proxy, and wire up the client + proxy_server = self._make_connection( + client_factory, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, + ) + + assert isinstance(proxy_server, HTTPChannel) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"testserv:8448") + + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if expected_auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(expected_auth_credentials) + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + request.finish() + + # now we make another test server to act as the upstream HTTP server. + server_ssl_protocol = _wrap_server_factory_for_tls( + _get_test_protocol_factory() + ).buildProtocol(None) + + # Tell the HTTP server to send outgoing traffic back via the proxy's transport. + proxy_server_transport = proxy_server.transport + server_ssl_protocol.makeConnection(proxy_server_transport) + + # ... and replace the protocol on the proxy's transport with the + # TLSMemoryBIOProtocol for the test server, so that incoming traffic + # to the proxy gets sent over to the HTTP(s) server. + + # See also comment at `_do_https_request_via_proxy` + # in ../test_proxyagent.py for more details + if expect_proxy_ssl: + assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) + proxy_server_transport.wrappedProtocol = server_ssl_protocol + else: + assert isinstance(proxy_server_transport, FakeTransport) + client_protocol = proxy_server_transport.other + c2s_transport = client_protocol.transport + c2s_transport.other = server_ssl_protocol + + self.reactor.advance(0) + + server_name = server_ssl_protocol._tlsConnection.get_servername() + expected_sni = b"testserv" + self.assertEqual( + server_name, + expected_sni, + f"Expected SNI {expected_sni!s} but got {server_name!s}", + ) + + # now there should be a pending request + http_server = server_ssl_protocol.wrappedProtocol + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) + # Check that the destination server DID NOT receive proxy credentials + self.assertIsNone(request.requestHeaders.getRawHeaders(b"Proxy-Authorization")) + content = request.content.read() + self.assertEqual(content, b"") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # send the headers + request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) + request.write("") + + self.reactor.pump((0.1,)) + + response = self.successResultOf(test_d) + + # that should give us a Response object + self.assertEqual(response.code, 200) + + # Send the body + request.write('{ "a": 1 }'.encode("ascii")) + request.finish() + + self.reactor.pump((0.1,)) + + # check it can be read + json = self.successResultOf(treq.json_content(response)) + self.assertEqual(json, {"a": 1}) + def test_get_ip_address(self): """ Test the behaviour when the server name contains an explicit IP (with no port) """ + self.agent = self._make_agent() + # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.4"] = "1.2.3.4" @@ -320,6 +531,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IPv6 address (with no port) """ + self.agent = self._make_agent() # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" @@ -355,6 +567,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IPv6 address (with explicit port) """ + self.agent = self._make_agent() # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" @@ -389,6 +602,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the certificate on the server doesn't match the hostname """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" @@ -441,6 +656,8 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IP, but the server cert doesn't cover it """ + self.agent = self._make_agent() + # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.5"] = "1.2.3.5" @@ -471,6 +688,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the server name has no port, no SRV, and no well-known """ + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -524,6 +742,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_get_well_known(self): """Test the behaviour when the .well-known delegates elsewhere""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -587,6 +806,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -675,6 +896,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -743,6 +965,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. + ip_whitelist=IPSet(), ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( @@ -780,6 +1003,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when there is a single SRV record """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"srvtarget", port=8443)] ) @@ -820,6 +1045,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known redirects to a place where there is a SRV. """ + self.agent = self._make_agent() + self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "5.6.7.8" @@ -876,6 +1103,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) @@ -937,6 +1165,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com @@ -1140,6 +1369,8 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails.""" + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [ Server(host=b"target.com", port=8443), @@ -1266,34 +1497,49 @@ def _check_logcontext(context): raise AssertionError("Expected logcontext %s but was %s" % (context, current)) -def _build_test_server(connection_creator): - """Construct a test server - - This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol - +def _wrap_server_factory_for_tls( + factory: IProtocolFactory, sanlist: Iterable[bytes] = None +) -> IProtocolFactory: + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + The resultant factory will create a TLS server which presents a certificate + signed by our test CA, valid for the domains in `sanlist` Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + factory: protocol factory to wrap + sanlist: list of domains the cert should be valid for + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [ + b"DNS:testserv", + b"DNS:target-server", + b"DNS:xn--bcher-kva.com", + b"IP:1.2.3.4", + b"IP:::1", + ] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + +def _get_test_protocol_factory() -> IProtocolFactory: + """Get a protocol Factory which will build an HTTPChannel Returns: - TLSMemoryBIOProtocol + interfaces.IProtocolFactory """ server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. server_factory.log = _log_request - server_tls_factory = TLSMemoryBIOFactory( - connection_creator, isClient=False, wrappedFactory=server_factory - ) - - return server_tls_factory.buildProtocol(None) + return server_factory -def _log_request(request): +def _log_request(request: str): """Implements Factory.log, which is expected by Request.finish""" - logger.info("Completed request %s", request) + logger.info(f"Completed request {request}") @implementer(IPolicyForHTTPS) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index e5865c161d..2db77c6a73 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -29,7 +29,8 @@ from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from synapse.http.client import BlacklistingReactorWrapper -from synapse.http.proxyagent import ProxyAgent, ProxyCredentials, parse_proxy +from synapse.http.connectproxyclient import ProxyCredentials +from synapse.http.proxyagent import ProxyAgent, parse_proxy from tests.http import TestServerTLSConnectionFactory, get_test_https_policy from tests.server import FakeTransport, ThreadedMemoryReactorClock @@ -392,7 +393,9 @@ class MatrixFederationAgentTests(TestCase): """ Tests that requests can be made through a proxy. """ - self._do_http_request_via_proxy(ssl=False, auth_credentials=None) + self._do_http_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -402,13 +405,17 @@ class MatrixFederationAgentTests(TestCase): """ Tests that authenticated requests can be made through a proxy. """ - self._do_http_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + self._do_http_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict( os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} ) def test_http_request_via_https_proxy(self): - self._do_http_request_via_proxy(ssl=True, auth_credentials=None) + self._do_http_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -418,12 +425,16 @@ class MatrixFederationAgentTests(TestCase): }, ) def test_http_request_via_https_proxy_with_auth(self): - self._do_http_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") + self._do_http_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) def test_https_request_via_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=False, auth_credentials=None) + self._do_https_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -431,14 +442,18 @@ class MatrixFederationAgentTests(TestCase): ) def test_https_request_via_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + self._do_https_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) def test_https_request_via_https_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=True, auth_credentials=None) + self._do_https_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -446,20 +461,22 @@ class MatrixFederationAgentTests(TestCase): ) def test_https_request_via_https_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") + self._do_https_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + ) def _do_http_request_via_proxy( self, - ssl: bool = False, - auth_credentials: Optional[bytes] = None, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, ): """Send a http request via an agent and check that it is correctly received at the proxy. The proxy can use either http or https. Args: - ssl: True if we expect the request to connect via https to proxy - auth_credentials: credentials to authenticate at proxy + expect_proxy_ssl: True if we expect the request to connect via https to proxy + expected_auth_credentials: credentials to authenticate at proxy """ - if ssl: + if expect_proxy_ssl: agent = ProxyAgent( self.reactor, use_proxy=True, contextFactory=get_test_https_policy() ) @@ -480,9 +497,9 @@ class MatrixFederationAgentTests(TestCase): http_server = self._make_connection( client_factory, _get_test_protocol_factory(), - ssl=ssl, - tls_sanlist=[b"DNS:proxy.com"] if ssl else None, - expected_sni=b"proxy.com" if ssl else None, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) # the FakeTransport is async, so we need to pump the reactor @@ -498,9 +515,9 @@ class MatrixFederationAgentTests(TestCase): b"Proxy-Authorization" ) - if auth_credentials is not None: + if expected_auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(auth_credentials) + encoded_credentials = base64.b64encode(expected_auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -523,14 +540,14 @@ class MatrixFederationAgentTests(TestCase): def _do_https_request_via_proxy( self, - ssl: bool = False, - auth_credentials: Optional[bytes] = None, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, ): """Send a https request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. Args: - ssl: True if we expect the request to connect via https to proxy - auth_credentials: credentials to authenticate at proxy + expect_proxy_ssl: True if we expect the request to connect via https to proxy + expected_auth_credentials: credentials to authenticate at proxy """ agent = ProxyAgent( self.reactor, @@ -552,9 +569,9 @@ class MatrixFederationAgentTests(TestCase): proxy_server = self._make_connection( client_factory, _get_test_protocol_factory(), - ssl=ssl, - tls_sanlist=[b"DNS:proxy.com"] if ssl else None, - expected_sni=b"proxy.com" if ssl else None, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) assert isinstance(proxy_server, HTTPChannel) @@ -570,9 +587,9 @@ class MatrixFederationAgentTests(TestCase): b"Proxy-Authorization" ) - if auth_credentials is not None: + if expected_auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(auth_credentials) + encoded_credentials = base64.b64encode(expected_auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -606,7 +623,7 @@ class MatrixFederationAgentTests(TestCase): # Protocol to implement the proxy, which starts out by forwarding to an # HTTPChannel (to implement the CONNECT command) and can then be switched # into a mode where it forwards its traffic to another Protocol.) - if ssl: + if expect_proxy_ssl: assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) proxy_server_transport.wrappedProtocol = server_ssl_protocol else: -- cgit 1.5.1 From 2ae2a04616a627eabbf3ca69700462a52f344e69 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 14:31:39 -0400 Subject: Clarify error message when joining a restricted room. (#10572) --- changelog.d/10572.misc | 1 + synapse/handlers/event_auth.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10572.misc (limited to 'synapse') diff --git a/changelog.d/10572.misc b/changelog.d/10572.misc new file mode 100644 index 0000000000..008d7be444 --- /dev/null +++ b/changelog.d/10572.misc @@ -0,0 +1 @@ +Clarify error message when failing to join a restricted room. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index e2410e482f..4288ffff09 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -213,7 +213,7 @@ class EventAuthHandler: raise AuthError( 403, - "You do not belong to any of the required rooms to join this room.", + "You do not belong to any of the required rooms/spaces to join this room.", ) async def has_restricted_join_rules( -- cgit 1.5.1 From 5acd8b5a960b1c53ce0b9efa304010ec5f0f6682 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 14:52:09 -0400 Subject: Expire old spaces summary pagination sessions. (#10574) --- changelog.d/10574.feature | 1 + synapse/handlers/space_summary.py | 24 +++++++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10574.feature (limited to 'synapse') diff --git a/changelog.d/10574.feature b/changelog.d/10574.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10574.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index fd76c34695..8c9852bc89 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -77,6 +77,8 @@ class _PaginationKey: class _PaginationSession: """The information that is stored for pagination.""" + # The time the pagination session was created, in milliseconds. + creation_time_ms: int # The queue of rooms which are still to process. room_queue: Deque["_RoomQueueEntry"] # A set of rooms which have been processed. @@ -84,6 +86,9 @@ class _PaginationSession: class SpaceSummaryHandler: + # The time a pagination session remains valid for. + _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 + def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._auth = hs.get_auth() @@ -108,6 +113,21 @@ class SpaceSummaryHandler: "get_room_hierarchy", ) + def _expire_pagination_sessions(self): + """Expire pagination session which are old.""" + expire_before = ( + self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS + ) + to_expire = [] + + for key, value in self._pagination_sessions.items(): + if value.creation_time_ms < expire_before: + to_expire.append(key) + + for key in to_expire: + logger.debug("Expiring pagination session id %s", key) + del self._pagination_sessions[key] + async def get_space_summary( self, requester: str, @@ -312,6 +332,8 @@ class SpaceSummaryHandler: # If this is continuing a previous session, pull the persisted data. if from_token: + self._expire_pagination_sessions() + pagination_key = _PaginationKey( requested_room_id, suggested_only, max_depth, from_token ) @@ -391,7 +413,7 @@ class SpaceSummaryHandler: requested_room_id, suggested_only, max_depth, next_token ) self._pagination_sessions[pagination_key] = _PaginationSession( - room_queue, processed_rooms + self._clock.time_msec(), room_queue, processed_rooms ) return result -- cgit 1.5.1 From 3ebb6694f018eedb7d3c4fda829540f07b45a5b1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 15:04:51 -0400 Subject: Allow requesting the summary of a space which is joinable. (#10580) As opposed to only allowing the summary of spaces which the user is already in or has world-readable visibility. This makes the logic consistent with whether a space/room is returned as part of a space and whether a space summary can start at a space. --- changelog.d/10580.bugfix | 1 + synapse/handlers/space_summary.py | 31 ++++++++++++++++++------------- tests/handlers/test_space_summary.py | 28 ++++++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 15 deletions(-) create mode 100644 changelog.d/10580.bugfix (limited to 'synapse') diff --git a/changelog.d/10580.bugfix b/changelog.d/10580.bugfix new file mode 100644 index 0000000000..f8da7382b7 --- /dev/null +++ b/changelog.d/10580.bugfix @@ -0,0 +1 @@ +Allow public rooms to be previewed in the spaces summary APIs from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 8c9852bc89..893546e661 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -38,7 +38,7 @@ from synapse.api.constants import ( Membership, RoomTypes, ) -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict @@ -91,7 +91,6 @@ class SpaceSummaryHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastore() self._event_serializer = hs.get_event_client_serializer() @@ -153,9 +152,13 @@ class SpaceSummaryHandler: Returns: summary dict to return """ - # first of all, check that the user is in the room in question (or it's - # world-readable) - await self._auth.check_user_in_room_or_world_readable(room_id, requester) + # First of all, check that the room is accessible. + if not await self._is_local_room_accessible(room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, room_id), + ) # the queue of rooms to process room_queue = deque((_RoomQueueEntry(room_id, ()),)) @@ -324,11 +327,13 @@ class SpaceSummaryHandler: ) -> JsonDict: """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" - # first of all, check that the user is in the room in question (or it's - # world-readable) - await self._auth.check_user_in_room_or_world_readable( - requested_room_id, requester - ) + # First of all, check that the room is accessible. + if not await self._is_local_room_accessible(requested_room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, requested_room_id), + ) # If this is continuing a previous session, pull the persisted data. if from_token: @@ -612,7 +617,7 @@ class SpaceSummaryHandler: return results async def _is_local_room_accessible( - self, room_id: str, requester: Optional[str], origin: Optional[str] + self, room_id: str, requester: Optional[str], origin: Optional[str] = None ) -> bool: """ Calculate whether the room should be shown in the spaces summary. @@ -766,7 +771,7 @@ class SpaceSummaryHandler: # Finally, check locally if we can access the room. The user might # already be in the room (if it was a child room), or there might be a # pending invite, etc. - return await self._is_local_room_accessible(room_id, requester, None) + return await self._is_local_room_accessible(room_id, requester) async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: """ @@ -783,7 +788,7 @@ class SpaceSummaryHandler: stats = await self._store.get_room_with_stats(room_id) # currently this should be impossible because we call - # check_user_in_room_or_world_readable on the room before we get here, so + # _is_local_room_accessible on the room before we get here, so # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 04da9bcc25..806b886fe4 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -248,7 +248,21 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") - # The user cannot see the space. + # The user can see the space since it is publicly joinable. + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # If the space is made invite-only, it should no longer be viewable. + self.helper.send_state( + self.space, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) @@ -260,7 +274,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - expected = [(self.space, [self.room]), (self.room, ())] self._assert_rooms(result, expected) result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) @@ -277,6 +290,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # Join the space and results should be returned. + self.helper.invite(self.space, targ=user2, tok=self.token) self.helper.join(self.space, user2, tok=token2) result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) @@ -284,6 +298,16 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) + # Attempting to view an unknown room returns the same error. + self.get_failure( + self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + self.get_failure( + self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + def _create_room_with_join_rule( self, join_rule: str, room_version: Optional[str] = None, **extra_content ) -> str: -- cgit 1.5.1 From 915b37e5efd4e0fb9e57ce9895300017b4b3dd43 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 11 Aug 2021 21:29:59 +0200 Subject: Admin API to delete media for a specific user (#10558) --- changelog.d/10558.feature | 1 + docs/admin_api/media_admin_api.md | 9 +- docs/admin_api/user_admin_api.md | 54 ++++- synapse/rest/admin/media.py | 4 +- synapse/rest/admin/users.py | 80 +++++++- synapse/rest/media/v1/media_repository.py | 6 +- tests/rest/admin/test_user.py | 321 +++++++++++++++++++----------- 7 files changed, 347 insertions(+), 128 deletions(-) create mode 100644 changelog.d/10558.feature (limited to 'synapse') diff --git a/changelog.d/10558.feature b/changelog.d/10558.feature new file mode 100644 index 0000000000..1f461bc70a --- /dev/null +++ b/changelog.d/10558.feature @@ -0,0 +1 @@ +Admin API to delete several media for a specific user. Contributed by @dklimpel. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 61bed1e0d5..ea05bd6e44 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -12,6 +12,7 @@ - [Delete local media](#delete-local-media) * [Delete a specific local media](#delete-a-specific-local-media) * [Delete local media by date or size](#delete-local-media-by-date-or-size) + * [Delete media uploaded by a user](#delete-media-uploaded-by-a-user) - [Purge Remote Media API](#purge-remote-media-api) # Querying media @@ -47,7 +48,8 @@ The API returns a JSON body like the following: ## List all media uploaded by a user Listing all media that has been uploaded by a local user can be achieved through -the use of the [List media of a user](user_admin_api.md#list-media-of-a-user) +the use of the +[List media uploaded by a user](user_admin_api.md#list-media-uploaded-by-a-user) Admin API. # Quarantine media @@ -281,6 +283,11 @@ The following fields are returned in the JSON response body: * `deleted_media`: an array of strings - List of deleted `media_id` * `total`: integer - Total number of deleted `media_id` +## Delete media uploaded by a user + +You can find details of how to delete multiple media uploaded by a user in +[User Admin API](user_admin_api.md#delete-media-uploaded-by-a-user). + # Purge Remote Media API The purge remote media API allows server admins to purge old cached remote media. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 160899754e..33811f5bbb 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -443,8 +443,9 @@ The following fields are returned in the JSON response body: - `joined_rooms` - An array of `room_id`. - `total` - Number of rooms. +## User media -## List media of a user +### List media uploaded by a user Gets a list of all local media that a specific `user_id` has created. By default, the response is ordered by descending creation date and ascending media ID. The newest media is on top. You can change the order with parameters @@ -543,7 +544,6 @@ The following fields are returned in the JSON response body: - `media` - An array of objects, each containing information about a media. Media objects contain the following fields: - - `created_ts` - integer - Timestamp when the content was uploaded in ms. - `last_access_ts` - integer - Timestamp when the content was last accessed in ms. - `media_id` - string - The id used to refer to the media. @@ -551,13 +551,58 @@ The following fields are returned in the JSON response body: - `media_type` - string - The MIME-type of the media. - `quarantined_by` - string - The user ID that initiated the quarantine request for this media. - - `safe_from_quarantine` - bool - Status if this media is safe from quarantining. - `upload_name` - string - The name the media was uploaded with. - - `next_token`: integer - Indication for pagination. See above. - `total` - integer - Total number of media. +### Delete media uploaded by a user + +This API deletes the *local* media from the disk of your own server +that a specific `user_id` has created. This includes any local thumbnails. + +This API will not affect media that has been uploaded to external +media repositories (e.g https://github.com/turt2live/matrix-media-repo/). + +By default, the API deletes media ordered by descending creation date and ascending media ID. +The newest media is deleted first. You can change the order with parameters +`order_by` and `dir`. If no `limit` is set the API deletes `100` files per request. + +The API is: + +``` +DELETE /_synapse/admin/v1/users//media +``` + +To use it, you will need to authenticate by providing an `access_token` for a +server admin: [Admin API](../usage/administration/admin_api) + +A response body like the following is returned: + +```json +{ + "deleted_media": [ + "abcdefghijklmnopqrstuvwx" + ], + "total": 1 +} +``` + +The following fields are returned in the JSON response body: + +* `deleted_media`: an array of strings - List of deleted `media_id` +* `total`: integer - Total number of deleted `media_id` + +**Note**: There is no `next_token`. This is not useful for deleting media, because +after deleting media the remaining media have a new order. + +**Parameters** + +This API has the same parameters as +[List media uploaded by a user](#list-media-uploaded-by-a-user). +With the parameters you can for example limit the number of files to delete at once or +delete largest/smallest or newest/oldest files first. + ## Login as a user Get an access token that can be used to authenticate as that user. Useful for @@ -1012,4 +1057,3 @@ The following parameters should be set in the URL: - `user_id` - The fully qualified MXID: for example, `@user:server.com`. The user must be local. - diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 0a19a333d7..5f0555039d 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -259,7 +259,9 @@ class DeleteMediaByID(RestServlet): logging.info("Deleting local media by ID: %s", media_id) - deleted_media, total = await self.media_repository.delete_local_media(media_id) + deleted_media, total = await self.media_repository.delete_local_media_ids( + [media_id] + ) return 200, {"deleted_media": deleted_media, "total": total} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index eef76ab18a..41f21ba118 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -172,7 +172,7 @@ class UserRestServletV2(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") ret = await self.admin_handler.get_user(target_user) @@ -796,7 +796,7 @@ class PushersRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -811,10 +811,10 @@ class PushersRestServlet(RestServlet): class UserMediaRestServlet(RestServlet): """ Gets information about all uploaded local media for a specific `user_id`. + With DELETE request you can delete all this media. Example: - http://localhost:8008/_synapse/admin/v1/users/ - @user:server/media + http://localhost:8008/_synapse/admin/v1/users/@user:server/media Args: The parameters `from` and `limit` are required for pagination. @@ -830,6 +830,7 @@ class UserMediaRestServlet(RestServlet): self.is_mine = hs.is_mine self.auth = hs.get_auth() self.store = hs.get_datastore() + self.media_repository = hs.get_media_repository() async def on_GET( self, request: SynapseRequest, user_id: str @@ -840,7 +841,7 @@ class UserMediaRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") user = await self.store.get_user_by_id(user_id) if user is None: @@ -898,6 +899,73 @@ class UserMediaRestServlet(RestServlet): return 200, ret + async def on_DELETE( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only look up local users") + + user = await self.store.get_user_by_id(user_id) + if user is None: + raise NotFoundError("Unknown user") + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + # If neither `order_by` nor `dir` is set, set the default order + # to newest media is on top for backward compatibility. + if b"order_by" not in request.args and b"dir" not in request.args: + order_by = MediaSortOrder.CREATED_TS.value + direction = "b" + else: + order_by = parse_string( + request, + "order_by", + default=MediaSortOrder.CREATED_TS.value, + allowed_values=( + MediaSortOrder.MEDIA_ID.value, + MediaSortOrder.UPLOAD_NAME.value, + MediaSortOrder.CREATED_TS.value, + MediaSortOrder.LAST_ACCESS_TS.value, + MediaSortOrder.MEDIA_LENGTH.value, + MediaSortOrder.MEDIA_TYPE.value, + MediaSortOrder.QUARANTINED_BY.value, + MediaSortOrder.SAFE_FROM_QUARANTINE.value, + ), + ) + direction = parse_string( + request, "dir", default="f", allowed_values=("f", "b") + ) + + media, _ = await self.store.get_local_media_by_user_paginate( + start, limit, user_id, order_by, direction + ) + + deleted_media, total = await self.media_repository.delete_local_media_ids( + ([row["media_id"] for row in media]) + ) + + return 200, {"deleted_media": deleted_media, "total": total} + class UserTokenRestServlet(RestServlet): """An admin API for logging in as a user. @@ -1017,7 +1085,7 @@ class RateLimitRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.hs.is_mine_id(user_id): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 4f702f890c..0f5ce41ff8 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -836,7 +836,9 @@ class MediaRepository: return {"deleted": deleted} - async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]: + async def delete_local_media_ids( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: """ Delete the given local or remote media ID from this server @@ -845,7 +847,7 @@ class MediaRepository: Returns: A tuple of (list of deleted media IDs, total deleted media IDs). """ - return await self._remove_local_media_from_disk([media_id]) + return await self._remove_local_media_from_disk(media_ids) async def delete_old_local_media( self, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 42f50c0921..13fab5579b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -15,17 +15,21 @@ import hashlib import hmac import json +import os import urllib.parse from binascii import unhexlify from typing import List, Optional from unittest.mock import Mock, patch +from parameterized import parameterized + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.rest.client.v1 import login, logout, profile, room from synapse.rest.client.v2_alpha import devices, sync +from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.types import JsonDict, UserID from tests import unittest @@ -72,7 +76,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual( "Shared secret registration is not enabled", channel.json_body["error"] ) @@ -104,7 +108,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({"nonce": nonce}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # 61 seconds @@ -112,7 +116,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_register_incorrect_nonce(self): @@ -166,7 +170,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) def test_nonce_reuse(self): @@ -191,13 +195,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) def test_missing_parts(self): @@ -219,7 +223,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) # @@ -230,28 +234,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({"nonce": nonce()}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": 1234}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # @@ -262,28 +266,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): body = json.dumps({"nonce": nonce(), "username": "a"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -301,7 +305,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self): @@ -322,11 +326,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -348,11 +352,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -374,7 +378,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") @@ -399,11 +403,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -449,7 +453,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -638,7 +642,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order @@ -1085,7 +1089,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): content={"erase": False}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Get user channel = self.make_request( @@ -2180,7 +2184,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) def test_get_pushers(self): """ @@ -2249,6 +2253,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() + self.filepaths = MediaFilePaths(hs.config.media_store_path) self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -2258,37 +2263,34 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.other_user ) - def test_no_auth(self): - """ - Try to list media of an user without authentication. - """ - channel = self.make_request("GET", self.url, b"{}") + @parameterized.expand(["GET", "DELETE"]) + def test_no_auth(self, method: str): + """Try to list media of an user without authentication.""" + channel = self.make_request(method, self.url, {}) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): - """ - If the user is not a server admin, an error is returned. - """ + @parameterized.expand(["GET", "DELETE"]) + def test_requester_is_no_admin(self, method: str): + """If the user is not a server admin, an error is returned.""" other_user_token = self.login("user", "pass") channel = self.make_request( - "GET", + method, self.url, access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): - """ - Tests that a lookup for a user that does not exist returns a 404 - """ + @parameterized.expand(["GET", "DELETE"]) + def test_user_does_not_exist(self, method: str): + """Tests that a lookup for a user that does not exist returns a 404""" url = "/_synapse/admin/v1/users/@unknown_person:test/media" channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) @@ -2296,25 +2298,22 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): - """ - Tests that a lookup for a user that is not a local returns a 400 - """ + @parameterized.expand(["GET", "DELETE"]) + def test_user_is_not_local(self, method: str): + """Tests that a lookup for a user that is not a local returns a 400""" url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) - def test_limit(self): - """ - Testing list of media with limit - """ + def test_limit_GET(self): + """Testing list of media with limit""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2326,16 +2325,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 5) self.assertEqual(channel.json_body["next_token"], 5) self._check_fields(channel.json_body["media"]) - def test_from(self): - """ - Testing list of media with a defined starting point (from) - """ + def test_limit_DELETE(self): + """Testing delete of media with limit""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?limit=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 5) + self.assertEqual(len(channel.json_body["deleted_media"]), 5) + + def test_from_GET(self): + """Testing list of media with a defined starting point (from)""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2347,16 +2361,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 15) self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) - def test_limit_and_from(self): - """ - Testing list of media with a defined starting point and limit - """ + def test_from_DELETE(self): + """Testing delete of media with a defined starting point (from)""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?from=5", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 15) + self.assertEqual(len(channel.json_body["deleted_media"]), 15) + + def test_limit_and_from_GET(self): + """Testing list of media with a defined starting point and limit""" number_media = 20 other_user_tok = self.login("user", "pass") @@ -2368,59 +2397,78 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(channel.json_body["next_token"], 15) self.assertEqual(len(channel.json_body["media"]), 10) self._check_fields(channel.json_body["media"]) - def test_invalid_parameter(self): - """ - If parameters are invalid, an error is returned. - """ + def test_limit_and_from_DELETE(self): + """Testing delete of media with a defined starting point and limit""" + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media_for_user(other_user_tok, number_media) + + channel = self.make_request( + "DELETE", + self.url + "?from=5&limit=10", + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(channel.json_body["total"], 10) + self.assertEqual(len(channel.json_body["deleted_media"]), 10) + + @parameterized.expand(["GET", "DELETE"]) + def test_invalid_parameter(self, method: str): + """If parameters are invalid, an error is returned.""" # unkown order_by channel = self.make_request( - "GET", + method, self.url + "?order_by=bar", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid search order channel = self.make_request( - "GET", + method, self.url + "?dir=bar", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # negative limit channel = self.make_request( - "GET", + method, self.url + "?limit=-5", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from channel = self.make_request( - "GET", + method, self.url + "?from=-5", access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_next_token(self): """ Testing that `next_token` appears at the right place + + For deleting media `next_token` is not useful, because + after deleting media the media has a new order. """ number_media = 20 @@ -2435,7 +2483,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2448,7 +2496,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), number_media) self.assertNotIn("next_token", channel.json_body) @@ -2461,7 +2509,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 19) self.assertEqual(channel.json_body["next_token"], 19) @@ -2475,12 +2523,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_media) self.assertEqual(len(channel.json_body["media"]), 1) self.assertNotIn("next_token", channel.json_body) - def test_user_has_no_media(self): + def test_user_has_no_media_GET(self): """ Tests that a normal lookup for media is successfully if user has no media created @@ -2496,11 +2544,24 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, len(channel.json_body["media"])) - def test_get_media(self): + def test_user_has_no_media_DELETE(self): """ - Tests that a normal lookup for media is successfully + Tests that a delete is successful if user has no media """ + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["deleted_media"])) + + def test_get_media(self): + """Tests that a normal lookup for media is successful""" + number_media = 5 other_user_tok = self.login("user", "pass") self._create_media_for_user(other_user_tok, number_media) @@ -2517,6 +2578,35 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) + def test_delete_media(self): + """Tests that a normal delete of media is successful""" + + number_media = 5 + other_user_tok = self.login("user", "pass") + media_ids = self._create_media_for_user(other_user_tok, number_media) + + # Test if the file exists + local_paths = [] + for media_id in media_ids: + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + local_paths.append(local_path) + + channel = self.make_request( + "DELETE", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_media, channel.json_body["total"]) + self.assertEqual(number_media, len(channel.json_body["deleted_media"])) + self.assertCountEqual(channel.json_body["deleted_media"], media_ids) + + # Test if the file is deleted + for local_path in local_paths: + self.assertFalse(os.path.exists(local_path)) + def test_order_by(self): """ Testing order list with parameter `order_by` @@ -2622,13 +2712,16 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): [media2] + sorted([media1, media3]), "safe_from_quarantine", "b" ) - def _create_media_for_user(self, user_token: str, number_media: int): + def _create_media_for_user(self, user_token: str, number_media: int) -> List[str]: """ Create a number of media for a specific user Args: user_token: Access token of the user number_media: Number of media to be created for the user + Returns: + List of created media ID """ + media_ids = [] for _ in range(number_media): # file size is 67 Byte image_data = unhexlify( @@ -2637,7 +2730,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): b"0a2db40000000049454e44ae426082" ) - self._create_media_and_access(user_token, image_data) + media_ids.append(self._create_media_and_access(user_token, image_data)) + + return media_ids def _create_media_and_access( self, @@ -2680,7 +2775,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): 200, channel.code, msg=( - "Expected to receive a 200 on accessing media: %s" % server_and_media_id + f"Expected to receive a 200 on accessing media: {server_and_media_id}" ), ) @@ -2718,12 +2813,12 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): url = self.url + "?" if order_by is not None: - url += "order_by=%s&" % (order_by,) + url += f"order_by={order_by}&" if dir is not None and dir in ("b", "f"): - url += "dir=%s" % (dir,) + url += f"dir={dir}" channel = self.make_request( "GET", - url.encode("ascii"), + url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -2762,7 +2857,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self): @@ -2803,7 +2898,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -2815,11 +2910,11 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -2829,7 +2924,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_user_logout_all(self): """Tests that the target user calling `/logout/all` does *not* expire @@ -2840,17 +2935,17 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # .. but the real user's tokens shouldn't channel = self.make_request( @@ -2867,13 +2962,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -2883,7 +2978,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3243,7 +3338,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) channel = self.make_request( "POST", @@ -3279,7 +3374,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": "string"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3290,7 +3385,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": -1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3301,7 +3396,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": "string"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3312,7 +3407,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": -1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self): @@ -3337,7 +3432,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3351,7 +3446,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3362,7 +3457,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3373,7 +3468,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3383,7 +3478,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3393,7 +3488,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3403,6 +3498,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) -- cgit 1.5.1 From 98a3355d9a58538cfbc1c88020e6b6d9bccea516 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 15:44:45 -0400 Subject: Update the pagination parameter name based on MSC2946 review. (#10579) --- changelog.d/10579.feature | 1 + synapse/handlers/space_summary.py | 6 +++--- tests/handlers/test_space_summary.py | 14 +++++++------- 3 files changed, 11 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10579.feature (limited to 'synapse') diff --git a/changelog.d/10579.feature b/changelog.d/10579.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10579.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 893546e661..d0060f9046 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -412,10 +412,10 @@ class SpaceSummaryHandler: # If there's additional data, generate a pagination token (and persist state). if room_queue: - next_token = random_string(24) - result["next_token"] = next_token + next_batch = random_string(24) + result["next_batch"] = next_batch pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, next_token + requested_room_id, suggested_only, max_depth, next_batch ) self._pagination_sessions[pagination_key] = _PaginationSession( self._clock.time_msec(), room_queue, processed_rooms diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 806b886fe4..83c2bdd8f9 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -466,19 +466,19 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)] expected += [(room_id, ()) for room_id in room_ids[:6]] self._assert_hierarchy(result, expected) - self.assertIn("next_token", result) + self.assertIn("next_batch", result) # Check the next page. result = self.get_success( self.handler.get_room_hierarchy( - self.user, self.space, limit=5, from_token=result["next_token"] + self.user, self.space, limit=5, from_token=result["next_batch"] ) ) # The result should have the space and the room in it, along with a link # from space -> room. expected = [(room_id, ()) for room_id in room_ids[6:]] self._assert_hierarchy(result, expected) - self.assertNotIn("next_token", result) + self.assertNotIn("next_batch", result) def test_invalid_pagination_token(self): """""" @@ -493,12 +493,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success( self.handler.get_room_hierarchy(self.user, self.space, limit=7) ) - self.assertIn("next_token", result) + self.assertIn("next_batch", result) # Changing the room ID, suggested-only, or max-depth causes an error. self.get_failure( self.handler.get_room_hierarchy( - self.user, self.room, from_token=result["next_token"] + self.user, self.room, from_token=result["next_batch"] ), SynapseError, ) @@ -507,13 +507,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.user, self.space, suggested_only=True, - from_token=result["next_token"], + from_token=result["next_batch"], ), SynapseError, ) self.get_failure( self.handler.get_room_hierarchy( - self.user, self.space, max_depth=0, from_token=result["next_token"] + self.user, self.space, max_depth=0, from_token=result["next_batch"] ), SynapseError, ) -- cgit 1.5.1 From c12b5577f22ee587b60ad7b65e88322ce1d86b7b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 13 Aug 2021 07:49:06 -0400 Subject: Fix a harmless exception when the staged events queue is empty. (#10592) --- changelog.d/10592.bugfix | 1 + synapse/federation/federation_server.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10592.bugfix (limited to 'synapse') diff --git a/changelog.d/10592.bugfix b/changelog.d/10592.bugfix new file mode 100644 index 0000000000..efcdab1136 --- /dev/null +++ b/changelog.d/10592.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.37.1 where an error could occur in the asyncronous processing of PDUs when the queue was empty. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 0385aadefa..78d5aac6af 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -972,13 +972,18 @@ class FederationServer(FederationBase): # the room, so instead of pulling the event out of the DB and parsing # the event we just pull out the next event ID and check if that matches. if latest_event is not None and latest_origin is not None: - ( - next_origin, - next_event_id, - ) = await self.store.get_next_staged_event_id_for_room(room_id) - if next_origin != latest_origin or next_event_id != latest_event.event_id: + result = await self.store.get_next_staged_event_id_for_room(room_id) + if result is None: latest_origin = None latest_event = None + else: + next_origin, next_event_id = result + if ( + next_origin != latest_origin + or next_event_id != latest_event.event_id + ): + latest_origin = None + latest_event = None if latest_origin is None or latest_event is None: next = await self.store.get_next_staged_event_for_room( -- cgit 1.5.1 From c8d54be44c1da451f01504664d568dd2f2b37316 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 13 Aug 2021 14:37:24 -0500 Subject: Move /batch_send to /v2_alpha directory (MSC2716) (#10576) * Move /batch_send to /v2_alpha directory As pointed out by @erikjohnston, https://github.com/matrix-org/synapse/pull/10552#discussion_r685836624 --- changelog.d/10576.misc | 1 + synapse/rest/__init__.py | 2 + synapse/rest/client/v1/room.py | 410 +------------------------------- synapse/rest/client/v2_alpha/room.py | 441 +++++++++++++++++++++++++++++++++++ 4 files changed, 445 insertions(+), 409 deletions(-) create mode 100644 changelog.d/10576.misc create mode 100644 synapse/rest/client/v2_alpha/room.py (limited to 'synapse') diff --git a/changelog.d/10576.misc b/changelog.d/10576.misc new file mode 100644 index 0000000000..f9f9c9a6fd --- /dev/null +++ b/changelog.d/10576.misc @@ -0,0 +1 @@ +Move `/batch_send` endpoint defined by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) to the `/v2_alpha` directory. diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index d29f2fea5e..9cffe59ce5 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -47,6 +47,7 @@ from synapse.rest.client.v2_alpha import ( register, relations, report_event, + room as roomv2, room_keys, room_upgrade_rest_servlet, sendtodevice, @@ -117,6 +118,7 @@ class ClientRestResource(JsonResource): user_directory.register_servlets(hs, client_resource) groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) + roomv2.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f1bc43be2d..2c3be23bc8 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -19,7 +19,7 @@ import re from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from urllib import parse as urlparse -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -28,7 +28,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.filtering import Filter -from synapse.appservice import ApplicationService from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( RestServlet, @@ -47,13 +46,11 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, - Requester, RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID, - create_requester, ) from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -268,407 +265,6 @@ class RoomSendEventRestServlet(TransactionRestServlet): ) -class RoomBatchSendEventRestServlet(TransactionRestServlet): - """ - API endpoint which can insert a chunk of events historically back in time - next to the given `prev_event`. - - `chunk_id` comes from `next_chunk_id `in the response of the batch send - endpoint and is derived from the "insertion" events added to each chunk. - It's not required for the first batch send. - - `state_events_at_start` is used to define the historical state events - needed to auth the events like join events. These events will float - outside of the normal DAG as outlier's and won't be visible in the chat - history which also allows us to insert multiple chunks without having a bunch - of `@mxid joined the room` noise between each chunk. - - `events` is chronological chunk/list of events you want to insert. - There is a reverse-chronological constraint on chunks so once you insert - some messages, you can only insert older ones after that. - tldr; Insert chunks from your most recent history -> oldest history. - - POST /_matrix/client/unstable/org.matrix.msc2716/rooms//batch_send?prev_event=&chunk_id= - { - "events": [ ... ], - "state_events_at_start": [ ... ] - } - """ - - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2716" - "/rooms/(?P[^/]*)/batch_send$" - ), - ) - - def __init__(self, hs): - super().__init__(hs) - self.hs = hs - self.store = hs.get_datastore() - self.state_store = hs.get_storage().state - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - - async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = await self.store.get_max_depth_of(prev_event_ids) - - # We want to insert the historical event after the `prev_event` but before the successor event - # - # We inherit depth from the successor event instead of the `prev_event` - # because events returned from `/messages` are first sorted by `topological_ordering` - # which is just the `depth` and then tie-break with `stream_ordering`. - # - # We mark these inserted historical events as "backfilled" which gives them a - # negative `stream_ordering`. If we use the same depth as the `prev_event`, - # then our historical event will tie-break and be sorted before the `prev_event` - # when it should come after. - # - # We want to use the successor event depth so they appear after `prev_event` because - # it has a larger `depth` but before the successor event because the `stream_ordering` - # is negative before the successor event. - successor_event_ids = await self.store.get_successor_events( - [most_recent_prev_event_id] - ) - - # If we can't find any successor events, then it's a forward extremity of - # historical messages and we can just inherit from the previous historical - # event which we can already assume has the correct depth where we want - # to insert into. - if not successor_event_ids: - depth = most_recent_prev_event_depth - else: - ( - _, - oldest_successor_depth, - ) = await self.store.get_min_depth_of(successor_event_ids) - - depth = oldest_successor_depth - - return depth - - def _create_insertion_event_dict( - self, sender: str, room_id: str, origin_server_ts: int - ): - """Creates an event dict for an "insertion" event with the proper fields - and a random chunk ID. - - Args: - sender: The event author MXID - room_id: The room ID that the event belongs to - origin_server_ts: Timestamp when the event was sent - - Returns: - Tuple of event ID and stream ordering position - """ - - next_chunk_id = random_string(8) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": sender, - "room_id": room_id, - "content": { - EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, - "origin_server_ts": origin_server_ts, - } - - return insertion_event - - async def _create_requester_for_user_id_from_app_service( - self, user_id: str, app_service: ApplicationService - ) -> Requester: - """Creates a new requester for the given user_id - and validates that the app service is allowed to control - the given user. - - Args: - user_id: The author MXID that the app service is controlling - app_service: The app service that controls the user - - Returns: - Requester object - """ - - await self.auth.validate_appservice_can_control_user_id(app_service, user_id) - - return create_requester(user_id, app_service=app_service) - - async def on_POST(self, request, room_id): - requester = await self.auth.get_user_by_req(request, allow_guest=False) - - if not requester.app_service: - raise AuthError( - 403, - "Only application services can use the /batchsend endpoint", - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["state_events_at_start", "events"]) - - prev_events_from_query = parse_strings_from_args(request.args, "prev_event") - chunk_id_from_query = parse_string(request, "chunk_id") - - if prev_events_from_query is None: - raise SynapseError( - 400, - "prev_event query parameter is required when inserting historical messages back in time", - errcode=Codes.MISSING_PARAM, - ) - - # For the event we are inserting next to (`prev_events_from_query`), - # find the most recent auth events (derived from state events) that - # allowed that message to be sent. We will use that as a base - # to auth our historical messages against. - ( - most_recent_prev_event_id, - _, - ) = await self.store.get_max_depth_of(prev_events_from_query) - # mapping from (type, state_key) -> state_event_id - prev_state_map = await self.state_store.get_state_ids_for_event( - most_recent_prev_event_id - ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - - state_events_at_start = [] - for state_event in body["state_events_at_start"]: - assert_params_in_dict( - state_event, ["type", "origin_server_ts", "content", "sender"] - ) - - logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", - state_event, - auth_event_ids, - ) - - event_dict = { - "type": state_event["type"], - "origin_server_ts": state_event["origin_server_ts"], - "content": state_event["content"], - "room_id": room_id, - "sender": state_event["sender"], - "state_key": state_event["state_key"], - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - # Make the state events float off on their own - fake_prev_event_id = "$" + random_string(43) - - # TODO: This is pretty much the same as some other code to handle inserting state in this file - if event_dict["type"] == EventTypes.Member: - membership = event_dict["content"].get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - target=UserID.from_string(event_dict["state_key"]), - room_id=room_id, - action=membership, - content=event_dict["content"], - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - else: - # TODO: Add some complement tests that adds state that is not member joins - # and will use this code path. Maybe we only want to support join state events - # and can get rid of this `else`? - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - event_dict, - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - event_id = event.event_id - - state_events_at_start.append(event_id) - auth_event_ids.append(event_id) - - events_to_create = body["events"] - - inherited_depth = await self._inherit_depth_from_prev_ids( - prev_events_from_query - ) - - # Figure out which chunk to connect to. If they passed in - # chunk_id_from_query let's use it. The chunk ID passed in comes - # from the chunk_id in the "insertion" event from the previous chunk. - last_event_in_chunk = events_to_create[-1] - chunk_id_to_connect_to = chunk_id_from_query - base_insertion_event = None - if chunk_id_from_query: - # All but the first base insertion event should point at a fake - # event, which causes the HS to ask for the state at the start of - # the chunk later. - prev_event_ids = [fake_prev_event_id] - # TODO: Verify the chunk_id_from_query corresponds to an insertion event - pass - # Otherwise, create an insertion event to act as a starting point. - # - # We don't always have an insertion event to start hanging more history - # off of (ideally there would be one in the main DAG, but that's not the - # case if we're wanting to add history to e.g. existing rooms without - # an insertion event), in which case we just create a new insertion event - # that can then get pointed to by a "marker" event later. - else: - prev_event_ids = prev_events_from_query - - base_insertion_event_dict = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - origin_server_ts=last_event_in_chunk["origin_server_ts"], - ) - base_insertion_event_dict["prev_events"] = prev_event_ids.copy() - - ( - base_insertion_event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - base_insertion_event_dict["sender"], - requester.app_service, - ), - base_insertion_event_dict, - prev_event_ids=base_insertion_event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - - chunk_id_to_connect_to = base_insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ] - - # Connect this current chunk to the insertion event from the previous chunk - chunk_event = { - "type": EventTypes.MSC2716_CHUNK, - "sender": requester.user.to_string(), - "room_id": room_id, - "content": { - EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, - EventContentFields.MSC2716_HISTORICAL: True, - }, - # Since the chunk event is put at the end of the chunk, - # where the newest-in-time event is, copy the origin_server_ts from - # the last event we're inserting - "origin_server_ts": last_event_in_chunk["origin_server_ts"], - } - # Add the chunk event to the end of the chunk (newest-in-time) - events_to_create.append(chunk_event) - - # Add an "insertion" event to the start of each chunk (next to the oldest-in-time - # event in the chunk) so the next chunk can be connected to this one. - insertion_event = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - # Since the insertion event is put at the start of the chunk, - # where the oldest-in-time event is, copy the origin_server_ts from - # the first event we're inserting - origin_server_ts=events_to_create[0]["origin_server_ts"], - ) - # Prepend the insertion event to the start of the chunk (oldest-in-time) - events_to_create = [insertion_event] + events_to_create - - event_ids = [] - events_to_persist = [] - for ev in events_to_create: - assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - - event_dict = { - "type": ev["type"], - "origin_server_ts": ev["origin_server_ts"], - "content": ev["content"], - "room_id": room_id, - "sender": ev["sender"], # requester.user.to_string(), - "prev_events": prev_event_ids.copy(), - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - event, context = await self.event_creation_handler.create_event( - await self._create_requester_for_user_id_from_app_service( - ev["sender"], requester.app_service - ), - event_dict, - prev_event_ids=event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", - event, - prev_event_ids, - auth_event_ids, - ) - - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) - - events_to_persist.append((event, context)) - event_id = event.event_id - - event_ids.append(event_id) - prev_event_ids = [event_id] - - # Persist events in reverse-chronological order so they have the - # correct stream_ordering as they are backfilled (which decrements). - # Events are sorted by (topological_ordering, stream_ordering) - # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): - ev = await self.event_creation_handler.handle_new_client_event( - await self._create_requester_for_user_id_from_app_service( - event["sender"], requester.app_service - ), - event=event, - context=context, - ) - - # Add the base_insertion_event to the bottom of the list we return - if base_insertion_event is not None: - event_ids.append(base_insertion_event.event_id) - - return 200, { - "state_events": state_events_at_start, - "events": event_ids, - "next_chunk_id": insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ], - } - - def on_GET(self, request, room_id): - return 501, "Not implemented" - - def on_PUT(self, request, room_id): - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id - ) - - # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(TransactionRestServlet): def __init__(self, hs): @@ -1488,8 +1084,6 @@ class RoomHierarchyRestServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server, is_worker=False): - msc2716_enabled = hs.config.experimental.msc2716_enabled - RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -1497,8 +1091,6 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): JoinRoomAliasServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server) - if msc2716_enabled: - RoomBatchSendEventRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/room.py b/synapse/rest/client/v2_alpha/room.py new file mode 100644 index 0000000000..3172aba605 --- /dev/null +++ b/synapse/rest/client/v2_alpha/room.py @@ -0,0 +1,441 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.appservice import ApplicationService +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, + parse_strings_from_args, +) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import Requester, UserID, create_requester +from synapse.util.stringutils import random_string + +logger = logging.getLogger(__name__) + + +class RoomBatchSendEventRestServlet(RestServlet): + """ + API endpoint which can insert a chunk of events historically back in time + next to the given `prev_event`. + + `chunk_id` comes from `next_chunk_id `in the response of the batch send + endpoint and is derived from the "insertion" events added to each chunk. + It's not required for the first batch send. + + `state_events_at_start` is used to define the historical state events + needed to auth the events like join events. These events will float + outside of the normal DAG as outlier's and won't be visible in the chat + history which also allows us to insert multiple chunks without having a bunch + of `@mxid joined the room` noise between each chunk. + + `events` is chronological chunk/list of events you want to insert. + There is a reverse-chronological constraint on chunks so once you insert + some messages, you can only insert older ones after that. + tldr; Insert chunks from your most recent history -> oldest history. + + POST /_matrix/client/unstable/org.matrix.msc2716/rooms//batch_send?prev_event=&chunk_id= + { + "events": [ ... ], + "state_events_at_start": [ ... ] + } + """ + + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2716" + "/rooms/(?P[^/]*)/batch_send$" + ), + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.store = hs.get_datastore() + self.state_store = hs.get_storage().state + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) + + async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + ( + most_recent_prev_event_id, + most_recent_prev_event_depth, + ) = await self.store.get_max_depth_of(prev_event_ids) + + # We want to insert the historical event after the `prev_event` but before the successor event + # + # We inherit depth from the successor event instead of the `prev_event` + # because events returned from `/messages` are first sorted by `topological_ordering` + # which is just the `depth` and then tie-break with `stream_ordering`. + # + # We mark these inserted historical events as "backfilled" which gives them a + # negative `stream_ordering`. If we use the same depth as the `prev_event`, + # then our historical event will tie-break and be sorted before the `prev_event` + # when it should come after. + # + # We want to use the successor event depth so they appear after `prev_event` because + # it has a larger `depth` but before the successor event because the `stream_ordering` + # is negative before the successor event. + successor_event_ids = await self.store.get_successor_events( + [most_recent_prev_event_id] + ) + + # If we can't find any successor events, then it's a forward extremity of + # historical messages and we can just inherit from the previous historical + # event which we can already assume has the correct depth where we want + # to insert into. + if not successor_event_ids: + depth = most_recent_prev_event_depth + else: + ( + _, + oldest_successor_depth, + ) = await self.store.get_min_depth_of(successor_event_ids) + + depth = oldest_successor_depth + + return depth + + def _create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ): + """Creates an event dict for an "insertion" event with the proper fields + and a random chunk ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + Tuple of event ID and stream ordering position + """ + + next_chunk_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + + async def _create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user. + + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=False) + + if not requester.app_service: + raise AuthError( + 403, + "Only application services can use the /batchsend endpoint", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["state_events_at_start", "events"]) + + prev_events_from_query = parse_strings_from_args(request.args, "prev_event") + chunk_id_from_query = parse_string(request, "chunk_id") + + if prev_events_from_query is None: + raise SynapseError( + 400, + "prev_event query parameter is required when inserting historical messages back in time", + errcode=Codes.MISSING_PARAM, + ) + + # For the event we are inserting next to (`prev_events_from_query`), + # find the most recent auth events (derived from state events) that + # allowed that message to be sent. We will use that as a base + # to auth our historical messages against. + ( + most_recent_prev_event_id, + _, + ) = await self.store.get_max_depth_of(prev_events_from_query) + # mapping from (type, state_key) -> state_event_id + prev_state_map = await self.state_store.get_state_ids_for_event( + most_recent_prev_event_id + ) + # List of state event ID's + prev_state_ids = list(prev_state_map.values()) + auth_event_ids = prev_state_ids + + state_events_at_start = [] + for state_event in body["state_events_at_start"]: + assert_params_in_dict( + state_event, ["type", "origin_server_ts", "content", "sender"] + ) + + logger.debug( + "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", + state_event, + auth_event_ids, + ) + + event_dict = { + "type": state_event["type"], + "origin_server_ts": state_event["origin_server_ts"], + "content": state_event["content"], + "room_id": room_id, + "sender": state_event["sender"], + "state_key": state_event["state_key"], + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + # Make the state events float off on their own + fake_prev_event_id = "$" + random_string(43) + + # TODO: This is pretty much the same as some other code to handle inserting state in this file + if event_dict["type"] == EventTypes.Member: + membership = event_dict["content"].get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + target=UserID.from_string(event_dict["state_key"]), + room_id=room_id, + action=membership, + content=event_dict["content"], + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + else: + # TODO: Add some complement tests that adds state that is not member joins + # and will use this code path. Maybe we only want to support join state events + # and can get rid of this `else`? + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + event_dict, + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + event_id = event.event_id + + state_events_at_start.append(event_id) + auth_event_ids.append(event_id) + + events_to_create = body["events"] + + inherited_depth = await self._inherit_depth_from_prev_ids( + prev_events_from_query + ) + + # Figure out which chunk to connect to. If they passed in + # chunk_id_from_query let's use it. The chunk ID passed in comes + # from the chunk_id in the "insertion" event from the previous chunk. + last_event_in_chunk = events_to_create[-1] + chunk_id_to_connect_to = chunk_id_from_query + base_insertion_event = None + if chunk_id_from_query: + # All but the first base insertion event should point at a fake + # event, which causes the HS to ask for the state at the start of + # the chunk later. + prev_event_ids = [fake_prev_event_id] + # TODO: Verify the chunk_id_from_query corresponds to an insertion event + pass + # Otherwise, create an insertion event to act as a starting point. + # + # We don't always have an insertion event to start hanging more history + # off of (ideally there would be one in the main DAG, but that's not the + # case if we're wanting to add history to e.g. existing rooms without + # an insertion event), in which case we just create a new insertion event + # that can then get pointed to by a "marker" event later. + else: + prev_event_ids = prev_events_from_query + + base_insertion_event_dict = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_chunk["origin_server_ts"], + ) + base_insertion_event_dict["prev_events"] = prev_event_ids.copy() + + ( + base_insertion_event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + base_insertion_event_dict["sender"], + requester.app_service, + ), + base_insertion_event_dict, + prev_event_ids=base_insertion_event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + + chunk_id_to_connect_to = base_insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ] + + # Connect this current chunk to the insertion event from the previous chunk + chunk_event = { + "type": EventTypes.MSC2716_CHUNK, + "sender": requester.user.to_string(), + "room_id": room_id, + "content": { + EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, + # Since the chunk event is put at the end of the chunk, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_chunk["origin_server_ts"], + } + # Add the chunk event to the end of the chunk (newest-in-time) + events_to_create.append(chunk_event) + + # Add an "insertion" event to the start of each chunk (next to the oldest-in-time + # event in the chunk) so the next chunk can be connected to this one. + insertion_event = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + # Since the insertion event is put at the start of the chunk, + # where the oldest-in-time event is, copy the origin_server_ts from + # the first event we're inserting + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) + # Prepend the insertion event to the start of the chunk (oldest-in-time) + events_to_create = [insertion_event] + events_to_create + + event_ids = [] + events_to_persist = [] + for ev in events_to_create: + assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) + + event_dict = { + "type": ev["type"], + "origin_server_ts": ev["origin_server_ts"], + "content": ev["content"], + "room_id": room_id, + "sender": ev["sender"], # requester.user.to_string(), + "prev_events": prev_event_ids.copy(), + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + event, context = await self.event_creation_handler.create_event( + await self._create_requester_for_user_id_from_app_service( + ev["sender"], requester.app_service + ), + event_dict, + prev_event_ids=event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + logger.debug( + "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", + event, + prev_event_ids, + auth_event_ids, + ) + + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, + ) + + events_to_persist.append((event, context)) + event_id = event.event_id + + event_ids.append(event_id) + prev_event_ids = [event_id] + + # Persist events in reverse-chronological order so they have the + # correct stream_ordering as they are backfilled (which decrements). + # Events are sorted by (topological_ordering, stream_ordering) + # where topological_ordering is just depth. + for (event, context) in reversed(events_to_persist): + ev = await self.event_creation_handler.handle_new_client_event( + await self._create_requester_for_user_id_from_app_service( + event["sender"], requester.app_service + ), + event=event, + context=context, + ) + + # Add the base_insertion_event to the bottom of the list we return + if base_insertion_event is not None: + event_ids.append(base_insertion_event.event_id) + + return 200, { + "state_events": state_events_at_start, + "events": event_ids, + "next_chunk_id": insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ], + } + + def on_GET(self, request, room_id): + return 501, "Not implemented" + + def on_PUT(self, request, room_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id + ) + + +def register_servlets(hs, http_server): + msc2716_enabled = hs.config.experimental.msc2716_enabled + + if msc2716_enabled: + RoomBatchSendEventRestServlet(hs).register(http_server) -- cgit 1.5.1 From a3a7514570f21dcad6f7ef4c1ee3ed1e30115825 Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Mon, 16 Aug 2021 13:22:38 +0200 Subject: Handle string read receipt data (#10606) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Handle string read receipt data Signed-off-by: Šimon Brandner * Test that we handle string read receipt data Signed-off-by: Šimon Brandner * Add changelog for #10606 Signed-off-by: Šimon Brandner * Add docs Signed-off-by: Šimon Brandner * Ignore malformed RRs Signed-off-by: Šimon Brandner * Only surround hidden = ... Signed-off-by: Šimon Brandner * Remove unnecessary argument Signed-off-by: Šimon Brandner * Update changelog.d/10606.bugfix Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/10606.bugfix | 1 + synapse/handlers/receipts.py | 9 ++++++++- tests/handlers/test_receipts.py | 23 +++++++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10606.bugfix (limited to 'synapse') diff --git a/changelog.d/10606.bugfix b/changelog.d/10606.bugfix new file mode 100644 index 0000000000..bab9fd2a61 --- /dev/null +++ b/changelog.d/10606.bugfix @@ -0,0 +1 @@ +Fix errors on /sync when read receipt data is a string. Only affects homeservers with the experimental flag for [MSC2285](https://github.com/matrix-org/matrix-doc/pull/2285) enabled. Contributed by @SimonBrandner. diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 5fd4525700..fb495229a7 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -188,7 +188,14 @@ class ReceiptEventSource: new_users = {} for rr_user_id, user_rr in m_read.items(): - hidden = user_rr.get("hidden", None) + try: + hidden = user_rr.get("hidden") + except AttributeError: + # Due to https://github.com/matrix-org/synapse/issues/10376 + # there are cases where user_rr is a string, in those cases + # we just ignore the read receipt + continue + if hidden is not True or rr_user_id == user_id: new_users[rr_user_id] = user_rr.copy() # If hidden has a value replace hidden with the correct prefixed key diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 93a9a084b2..732a12c9bd 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -286,6 +286,29 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) + def test_handles_string_data(self): + """ + Tests that an invalid shape for read-receipts is handled. + Context: https://github.com/matrix-org/synapse/issues/10603 + """ + + self._test_filters_hidden( + [ + { + "content": { + "$14356419edgd14394fHBLK:matrix.org": { + "m.read": { + "@rikj:jki.re": "string", + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + }, + ], + [], + ) + def _test_filters_hidden( self, events: List[JsonDict], expected_output: List[JsonDict] ): -- cgit 1.5.1 From 7de445161f2fec115ce8518cde7a3b333a611f16 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 16 Aug 2021 08:06:17 -0400 Subject: Support federation in the new spaces summary API (MSC2946). (#10569) --- changelog.d/10569.feature | 1 + synapse/federation/federation_client.py | 82 +++++++++ synapse/federation/transport/client.py | 22 +++ synapse/federation/transport/server.py | 28 +++ synapse/handlers/space_summary.py | 258 +++++++++++++++++++++++----- tests/handlers/test_space_summary.py | 292 ++++++++++++++++++-------------- 6 files changed, 518 insertions(+), 165 deletions(-) create mode 100644 changelog.d/10569.feature (limited to 'synapse') diff --git a/changelog.d/10569.feature b/changelog.d/10569.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10569.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 2eefac04fd..0af953a5d6 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1290,6 +1290,88 @@ class FederationClient(FederationBase): failover_on_unknown_endpoint=True, ) + async def get_room_hierarchy( + self, + destinations: Iterable[str], + room_id: str, + suggested_only: bool, + ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: + """ + Call other servers to get a hierarchy of the given room. + + Performs simple data validates and parsing of the response. + + Args: + destinations: The remote servers. We will try them in turn, omitting any + that have been blacklisted. + room_id: ID of the space to be queried + suggested_only: If true, ask the remote server to only return children + with the "suggested" flag set + + Returns: + A tuple of: + The room as a JSON dictionary. + A list of children rooms, as JSON dictionaries. + A list of inaccessible children room IDs. + + Raises: + SynapseError if we were unable to get a valid summary from any of the + remote servers + """ + + async def send_request( + destination: str, + ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: + res = await self.transport_layer.get_room_hierarchy( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) + + room = res.get("room") + if not isinstance(room, dict): + raise InvalidResponseError("'room' must be a dict") + + # Validate children_state of the room. + children_state = room.get("children_state", []) + if not isinstance(children_state, Sequence): + raise InvalidResponseError("'room.children_state' must be a list") + if any(not isinstance(e, dict) for e in children_state): + raise InvalidResponseError("Invalid event in 'children_state' list") + try: + [ + FederationSpaceSummaryEventResult.from_json_dict(e) + for e in children_state + ] + except ValueError as e: + raise InvalidResponseError(str(e)) + + # Validate the children rooms. + children = res.get("children", []) + if not isinstance(children, Sequence): + raise InvalidResponseError("'children' must be a list") + if any(not isinstance(r, dict) for r in children): + raise InvalidResponseError("Invalid room in 'children' list") + + # Validate the inaccessible children. + inaccessible_children = res.get("inaccessible_children", []) + if not isinstance(inaccessible_children, Sequence): + raise InvalidResponseError("'inaccessible_children' must be a list") + if any(not isinstance(r, str) for r in inaccessible_children): + raise InvalidResponseError( + "Invalid room ID in 'inaccessible_children' list" + ) + + return room, children, inaccessible_children + + # TODO Fallback to the old federation API and translate the results. + return await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) + @attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 90a7c16b62..8b247fe206 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1177,6 +1177,28 @@ class TransportLayerClient: destination=destination, path=path, data=params ) + async def get_room_hierarchy( + self, + destination: str, + room_id: str, + suggested_only: bool, + ) -> JsonDict: + """ + Args: + destination: The remote server + room_id: The room ID to ask about. + suggested_only: if True, only suggested rooms will be returned + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/hierarchy/%s", room_id + ) + + return await self.client.get_json( + destination=destination, + path=path, + args={"suggested_only": "true" if suggested_only else "false"}, + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 640f46fff6..79a2e1afa0 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -1936,6 +1936,33 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): ) +class FederationRoomHierarchyServlet(BaseFederationServlet): + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" + PATH = "/hierarchy/(?P[^/]*)" + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_space_summary_handler() + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) + return 200, await self.handler.get_federation_hierarchy( + origin, room_id, suggested_only + ) + + class RoomComplexityServlet(BaseFederationServlet): """ Indicates to other servers how complex (and therefore likely @@ -1999,6 +2026,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationVersionServlet, RoomComplexityServlet, FederationSpaceSummaryServlet, + FederationRoomHierarchyServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, ) diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index d0060f9046..c74e90abbc 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -16,17 +16,7 @@ import itertools import logging import re from collections import deque -from typing import ( - TYPE_CHECKING, - Deque, - Dict, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple import attr @@ -80,7 +70,7 @@ class _PaginationSession: # The time the pagination session was created, in milliseconds. creation_time_ms: int # The queue of rooms which are still to process. - room_queue: Deque["_RoomQueueEntry"] + room_queue: List["_RoomQueueEntry"] # A set of rooms which have been processed. processed_rooms: Set[str] @@ -197,7 +187,7 @@ class SpaceSummaryHandler: events: Sequence[JsonDict] = [] if room_entry: rooms_result.append(room_entry.room) - events = room_entry.children + events = room_entry.children_state_events logger.debug( "Query of local room %s returned events %s", @@ -232,7 +222,7 @@ class SpaceSummaryHandler: room.pop("allowed_spaces", None) rooms_result.append(room) - events.extend(room_entry.children) + events.extend(room_entry.children_state_events) # All rooms returned don't need visiting again (even if the user # didn't have access to them). @@ -350,8 +340,8 @@ class SpaceSummaryHandler: room_queue = pagination_session.room_queue processed_rooms = pagination_session.processed_rooms else: - # the queue of rooms to process - room_queue = deque((_RoomQueueEntry(requested_room_id, ()),)) + # The queue of rooms to process, the next room is last on the stack. + room_queue = [_RoomQueueEntry(requested_room_id, ())] # Rooms we have already processed. processed_rooms = set() @@ -367,7 +357,7 @@ class SpaceSummaryHandler: # Iterate through the queue until we reach the limit or run out of # rooms to include. while room_queue and len(rooms_result) < limit: - queue_entry = room_queue.popleft() + queue_entry = room_queue.pop() room_id = queue_entry.room_id current_depth = queue_entry.depth if room_id in processed_rooms: @@ -376,6 +366,18 @@ class SpaceSummaryHandler: logger.debug("Processing room %s", room_id) + # A map of summaries for children rooms that might be returned over + # federation. The rationale for caching these and *maybe* using them + # is to prefer any information local to the homeserver before trusting + # data received over federation. + children_room_entries: Dict[str, JsonDict] = {} + # A set of room IDs which are children that did not have information + # returned over federation and are known to be inaccessible to the + # current server. We should not reach out over federation to try to + # summarise these rooms. + inaccessible_children: Set[str] = set() + + # If the room is known locally, summarise it! is_in_room = await self._store.is_host_joined(room_id, self._server_name) if is_in_room: room_entry = await self._summarize_local_room( @@ -387,26 +389,68 @@ class SpaceSummaryHandler: max_children=None, ) - if room_entry: - rooms_result.append(room_entry.as_json()) - - # Add the child to the queue. We have already validated - # that the vias are a list of server names. - # - # If the current depth is the maximum depth, do not queue - # more entries. - if max_depth is None or current_depth < max_depth: - room_queue.extendleft( - _RoomQueueEntry( - ev["state_key"], ev["content"]["via"], current_depth + 1 - ) - for ev in reversed(room_entry.children) - ) - - processed_rooms.add(room_id) + # Otherwise, attempt to use information for federation. else: - # TODO Federation. - pass + # A previous call might have included information for this room. + # It can be used if either: + # + # 1. The room is not a space. + # 2. The maximum depth has been achieved (since no children + # information is needed). + if queue_entry.remote_room and ( + queue_entry.remote_room.get("room_type") != RoomTypes.SPACE + or (max_depth is not None and current_depth >= max_depth) + ): + room_entry = _RoomEntry( + queue_entry.room_id, queue_entry.remote_room + ) + + # If the above isn't true, attempt to fetch the room + # information over federation. + else: + ( + room_entry, + children_room_entries, + inaccessible_children, + ) = await self._summarize_remote_room_hiearchy( + queue_entry, + suggested_only, + ) + + # Ensure this room is accessible to the requester (and not just + # the homeserver). + if room_entry and not await self._is_remote_room_accessible( + requester, queue_entry.room_id, room_entry.room + ): + room_entry = None + + # This room has been processed and should be ignored if it appears + # elsewhere in the hierarchy. + processed_rooms.add(room_id) + + # There may or may not be a room entry based on whether it is + # inaccessible to the requesting user. + if room_entry: + # Add the room (including the stripped m.space.child events). + rooms_result.append(room_entry.as_json()) + + # If this room is not at the max-depth, check if there are any + # children to process. + if max_depth is None or current_depth < max_depth: + # The children get added in reverse order so that the next + # room to process, according to the ordering, is the last + # item in the list. + room_queue.extend( + _RoomQueueEntry( + ev["state_key"], + ev["content"]["via"], + current_depth + 1, + children_room_entries.get(ev["state_key"]), + ) + for ev in reversed(room_entry.children_state_events) + if ev["type"] == EventTypes.SpaceChild + and ev["state_key"] not in inaccessible_children + ) result: JsonDict = {"rooms": rooms_result} @@ -477,15 +521,78 @@ class SpaceSummaryHandler: if room_entry: rooms_result.append(room_entry.room) - events_result.extend(room_entry.children) + events_result.extend(room_entry.children_state_events) # add any children to the queue room_queue.extend( - edge_event["state_key"] for edge_event in room_entry.children + edge_event["state_key"] + for edge_event in room_entry.children_state_events ) return {"rooms": rooms_result, "events": events_result} + async def get_federation_hierarchy( + self, + origin: str, + requested_room_id: str, + suggested_only: bool, + ): + """ + Implementation of the room hierarchy Federation API. + + This is similar to get_room_hierarchy, but does not recurse into the space. + It also considers whether anyone on the server may be able to access the + room, as opposed to whether a specific user can. + + Args: + origin: The server requesting the spaces summary. + requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: whether we should only return children with the "suggested" + flag set. + + Returns: + The JSON hierarchy dictionary. + """ + root_room_entry = await self._summarize_local_room( + None, origin, requested_room_id, suggested_only, max_children=None + ) + if root_room_entry is None: + # Room is inaccessible to the requesting server. + raise SynapseError(404, "Unknown room: %s" % (requested_room_id,)) + + children_rooms_result: List[JsonDict] = [] + inaccessible_children: List[str] = [] + + # Iterate through each child and potentially add it, but not its children, + # to the response. + for child_room in root_room_entry.children_state_events: + room_id = child_room.get("state_key") + assert isinstance(room_id, str) + # If the room is unknown, skip it. + if not await self._store.is_host_joined(room_id, self._server_name): + continue + + room_entry = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_children=0 + ) + # If the room is accessible, include it in the results. + # + # Note that only the room summary (without information on children) + # is included in the summary. + if room_entry: + children_rooms_result.append(room_entry.room) + # Otherwise, note that the requesting server shouldn't bother + # trying to summarize this room - they do not have access to it. + else: + inaccessible_children.append(room_id) + + return { + # Include the requested room (including the stripped children events). + "room": root_room_entry.as_json(), + "children": children_rooms_result, + "inaccessible_children": inaccessible_children, + } + async def _summarize_local_room( self, requester: Optional[str], @@ -519,8 +626,9 @@ class SpaceSummaryHandler: room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) - # If the room is not a space, return just the room information. - if room_entry.get("room_type") != RoomTypes.SPACE: + # If the room is not a space or the children don't matter, return just + # the room information. + if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. @@ -616,6 +724,59 @@ class SpaceSummaryHandler: return results + async def _summarize_remote_room_hiearchy( + self, room: "_RoomQueueEntry", suggested_only: bool + ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: + """ + Request room entries and a list of event entries for a given room by querying a remote server. + + Args: + room: The room to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + + Returns: + A tuple of: + The room entry. + Partial room data return over federation. + A set of inaccessible children room IDs. + """ + room_id = room.room_id + logger.info("Requesting summary for %s via %s", room_id, room.via) + + via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) + try: + ( + room_response, + children, + inaccessible_children, + ) = await self._federation_client.get_room_hierarchy( + via, + room_id, + suggested_only=suggested_only, + ) + except Exception as e: + logger.warning( + "Unable to get hierarchy of %s via federation: %s", + room_id, + e, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return None, {}, set() + + # Map the children to their room ID. + children_by_room_id = { + c["room_id"]: c + for c in children + if "room_id" in c and isinstance(c["room_id"], str) + } + + return ( + _RoomEntry(room_id, room_response, room_response.pop("children_state", ())), + children_by_room_id, + set(inaccessible_children), + ) + async def _is_local_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] = None ) -> bool: @@ -866,9 +1027,16 @@ class SpaceSummaryHandler: @attr.s(frozen=True, slots=True, auto_attribs=True) class _RoomQueueEntry: + # The room ID of this entry. room_id: str + # The server to query if the room is not known locally. via: Sequence[str] + # The minimum number of hops necessary to get to this room (compared to the + # originally requested room). depth: int = 0 + # The room summary for this room returned via federation. This will only be + # used if the room is not known locally (and is not a space). + remote_room: Optional[JsonDict] = None @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -879,11 +1047,17 @@ class _RoomEntry: # An iterable of the sorted, stripped children events for children of this room. # # This may not include all children. - children: Sequence[JsonDict] = () + children_state_events: Sequence[JsonDict] = () def as_json(self) -> JsonDict: + """ + Returns a JSON dictionary suitable for the room hierarchy endpoint. + + It returns the room summary including the stripped m.space.child events + as a sub-key. + """ result = dict(self.room) - result["children_state"] = self.children + result["children_state"] = self.children_state_events return result diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 83c2bdd8f9..bc8e131f4a 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -481,7 +481,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_batch", result) def test_invalid_pagination_token(self): - """""" + """An invalid pagination token, or changing other parameters, shoudl be rejected.""" room_ids = [] for i in range(1, 10): room = self.helper.create_room_as(self.user, tok=self.token) @@ -581,33 +581,40 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): subspace = "#subspace:" + fed_hostname subroom = "#subroom:" + fed_hostname + # Generate some good data, and some bad data: + # + # * Event *back* to the root room. + # * Unrelated events / rooms + # * Multiple levels of events (in a not-useful order, e.g. grandchild + # events before child events). + + # Note that these entries are brief, but should contain enough info. + requested_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ) + child_room = { + "room_id": subroom, + "world_readable": True, + } + async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms ): - # Return some good data, and some bad data: - # - # * Event *back* to the root room. - # * Unrelated events / rooms - # * Multiple levels of events (in a not-useful order, e.g. grandchild - # events before child events). - - # Note that these entries are brief, but should contain enough info. return [ - _RoomEntry( - subspace, - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - [ - { - "room_id": subspace, - "state_key": subroom, - "content": {"via": [fed_hostname]}, - } - ], - ), + requested_room_entry, _RoomEntry( subroom, { @@ -617,6 +624,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ), ] + async def summarize_remote_room_hiearchy(_self, room, suggested_only): + return requested_room_entry, {subroom: child_room}, set() + # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) @@ -636,6 +646,15 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ] self._assert_rooms(result, expected) + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", + new=summarize_remote_room_hiearchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + def test_fed_filtering(self): """ Rooms returned over federation should be properly filtered to only include @@ -657,100 +676,106 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Poke an invite over federation into the database. self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) + # Note that these entries are brief, but should contain enough info. + children_rooms = ( + ( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + ( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + ( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [], + }, + ), + ( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + ( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + ( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ) + + subspace_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": room_id, + "content": {"via": [fed_hostname]}, + } + for room_id, _ in children_rooms + ], + ) + async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms ): - # Note that these entries are brief, but should contain enough info. - rooms = [ - _RoomEntry( - public_room, - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - ), - _RoomEntry( - knock_room, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - ), - _RoomEntry( - not_invited_room, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - _RoomEntry( - invited_room, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - _RoomEntry( - restricted_room, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], - }, - ), - _RoomEntry( - restricted_accessible_room, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], - }, - ), - _RoomEntry( - world_readable_room, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - ), - _RoomEntry( - joined_room, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), + return [subspace_room_entry] + [ + # A copy is made of the room data since the allowed_spaces key + # is removed. + _RoomEntry(child_room[0], dict(child_room[1])) + for child_room in children_rooms ] - # Also include the subspace. - rooms.insert( - 0, - _RoomEntry( - subspace, - { - "room_id": subspace, - "world_readable": True, - }, - # Place each room in the sub-space. - [ - { - "room_id": subspace, - "state_key": room.room_id, - "content": {"via": [fed_hostname]}, - } - for room in rooms - ], - ), - ) - return rooms + async def summarize_remote_room_hiearchy(_self, room, suggested_only): + return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) @@ -788,6 +813,15 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ] self._assert_rooms(result, expected) + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", + new=summarize_remote_room_hiearchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + def test_fed_invited(self): """ A room which the user was invited to should be included in the response. @@ -802,19 +836,22 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Poke an invite over federation into the database. self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) + fed_room_entry = _RoomEntry( + fed_room, + { + "room_id": fed_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ) + async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms ): - return [ - _RoomEntry( - fed_room, - { - "room_id": fed_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - ] + return [fed_room_entry] + + async def summarize_remote_room_hiearchy(_self, room, suggested_only): + return fed_room_entry, {}, set() # Add a room to the space which is on another server. self._add_child(self.space, fed_room, self.token) @@ -833,3 +870,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (fed_room, ()), ] self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", + new=summarize_remote_room_hiearchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) -- cgit 1.5.1 From 2d9ca4ca77c2cdf98ddb738aee8d5699c7c8749f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 16 Aug 2021 13:19:02 +0100 Subject: Clean up some logging in the federation event handler (#10591) * Include outlier status in `str(event)` In places where we log event objects, knowing whether or not you're dealing with an outlier is super useful. * Remove duplicated logging in get_missing_events When we process events received from get_missing_events, we log them twice (once in `_get_missing_events_for_pdu`, and once in `on_receive_pdu`). Reduce the duplication by removing the logging in `on_receive_pdu`, and ensuring the call sites do sensible logging. * log in `on_receive_pdu` when we already have the event * Log which prev_events we are missing * changelog --- changelog.d/10591.misc | 1 + synapse/events/__init__.py | 3 +- synapse/federation/federation_server.py | 1 + synapse/handlers/federation.py | 52 +++++++++++++++------------------ 4 files changed, 28 insertions(+), 29 deletions(-) create mode 100644 changelog.d/10591.misc (limited to 'synapse') diff --git a/changelog.d/10591.misc b/changelog.d/10591.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10591.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 0298af4c02..a730c1719a 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -396,10 +396,11 @@ class FrozenEvent(EventBase): return self.__repr__() def __repr__(self): - return "" % ( + return "" % ( self.get("event_id", None), self.get("type", None), self.get("state_key", None), + self.internal_metadata.is_outlier(), ) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 78d5aac6af..afd8f8580a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1003,6 +1003,7 @@ class FederationServer(FederationBase): # has started processing). while True: async with lock: + logger.info("handling received PDU: %s", event) try: await self.handler.on_receive_pdu( origin, event, sent_to_us_directly=True diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 9a5e726533..c0e13bdaac 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -220,8 +220,6 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - logger.info("handling received PDU: %s", pdu) - # We reprocess pdus when we have seen them only as outliers existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True @@ -229,14 +227,19 @@ class FederationHandler(BaseHandler): # FIXME: Currently we fetch an event again when we already have it # if it has been marked as an outlier. - - already_seen = existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) - if already_seen: - logger.debug("Already seen pdu") - return + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", event_id + ) + return + if pdu.internal_metadata.is_outlier(): + logger.info( + "Ignoring received outlier %s which we already have as an outlier", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) # do some initial sanity-checking of the event. In particular, make # sure it doesn't have hundreds of prev_events or auth_events, which @@ -331,7 +334,8 @@ class FederationHandler(BaseHandler): "Found all missing prev_events", ) - if prevs - seen: + missing_prevs = prevs - seen + if missing_prevs: # We've still not been able to get all of the prev_events for this event. # # In this case, we need to fall back to asking another server in the @@ -359,8 +363,8 @@ class FederationHandler(BaseHandler): if sent_to_us_directly: logger.warning( "Rejecting: failed to fetch %d prev events: %s", - len(prevs - seen), - shortstr(prevs - seen), + len(missing_prevs), + shortstr(missing_prevs), ) raise FederationError( "ERROR", @@ -373,9 +377,10 @@ class FederationHandler(BaseHandler): ) logger.info( - "Event %s is missing prev_events: calculating state for a " + "Event %s is missing prev_events %s: calculating state for a " "backwards extremity", event_id, + shortstr(missing_prevs), ) # Calculate the state after each of the previous events, and @@ -393,7 +398,7 @@ class FederationHandler(BaseHandler): # Ask the remote server for the states we don't # know about - for p in prevs - seen: + for p in missing_prevs: logger.info("Requesting state after missing prev_event %s", p) with nested_logging_context(p): @@ -556,21 +561,14 @@ class FederationHandler(BaseHandler): logger.warning("Failed to get prev_events: %s", e) return - logger.info( - "Got %d prev_events: %s", - len(missing_events), - shortstr(missing_events), - ) + logger.info("Got %d prev_events", len(missing_events)) # We want to sort these by depth so we process them and # tell clients about them in order. missing_events.sort(key=lambda x: x.depth) for ev in missing_events: - logger.info( - "Handling received prev_event %s", - ev.event_id, - ) + logger.info("Handling received prev_event %s", ev) with nested_logging_context(ev.event_id): try: await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) @@ -1762,10 +1760,8 @@ class FederationHandler(BaseHandler): for p, origin in room_queue: try: logger.info( - "Processing queued PDU %s which was received " - "while we were joining %s", - p.event_id, - p.room_id, + "Processing queued PDU %s which was received while we were joining", + p, ) with nested_logging_context(p.event_id): await self.on_receive_pdu(origin, p, sent_to_us_directly=True) -- cgit 1.5.1 From 87b62f8bb23f99d76bf0ee62c8217fa45a087673 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 16 Aug 2021 10:14:31 -0400 Subject: Split `synapse.federation.transport.server` into multiple files. (#10590) --- changelog.d/10590.misc | 1 + synapse/federation/transport/server.py | 2158 -------------------- synapse/federation/transport/server/__init__.py | 332 +++ synapse/federation/transport/server/_base.py | 328 +++ synapse/federation/transport/server/federation.py | 692 +++++++ .../federation/transport/server/groups_local.py | 113 + .../federation/transport/server/groups_server.py | 753 +++++++ 7 files changed, 2219 insertions(+), 2158 deletions(-) create mode 100644 changelog.d/10590.misc delete mode 100644 synapse/federation/transport/server.py create mode 100644 synapse/federation/transport/server/__init__.py create mode 100644 synapse/federation/transport/server/_base.py create mode 100644 synapse/federation/transport/server/federation.py create mode 100644 synapse/federation/transport/server/groups_local.py create mode 100644 synapse/federation/transport/server/groups_server.py (limited to 'synapse') diff --git a/changelog.d/10590.misc b/changelog.d/10590.misc new file mode 100644 index 0000000000..62fec717da --- /dev/null +++ b/changelog.d/10590.misc @@ -0,0 +1 @@ +Re-organize the `synapse.federation.transport.server` module to create smaller files. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py deleted file mode 100644 index 79a2e1afa0..0000000000 --- a/synapse/federation/transport/server.py +++ /dev/null @@ -1,2158 +0,0 @@ -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. -# Copyright 2020 Sorunome -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools -import logging -import re -from typing import ( - Container, - Dict, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, -) - -from typing_extensions import Literal - -import synapse -from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH -from synapse.api.errors import Codes, FederationDeniedError, SynapseError -from synapse.api.room_versions import RoomVersions -from synapse.api.urls import ( - FEDERATION_UNSTABLE_PREFIX, - FEDERATION_V1_PREFIX, - FEDERATION_V2_PREFIX, -) -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.http.server import HttpServer, JsonResource -from synapse.http.servlet import ( - parse_boolean_from_args, - parse_integer_from_args, - parse_json_object_from_request, - parse_string_from_args, - parse_strings_from_args, -) -from synapse.logging import opentracing -from synapse.logging.context import run_in_background -from synapse.logging.opentracing import ( - SynapseTags, - start_active_span, - start_active_span_from_request, - tags, - whitelisted_homeserver, -) -from synapse.server import HomeServer -from synapse.types import JsonDict, ThirdPartyInstanceID, get_domain_from_id -from synapse.util.ratelimitutils import FederationRateLimiter -from synapse.util.stringutils import parse_and_validate_server_name -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger(__name__) - - -class TransportLayerServer(JsonResource): - """Handles incoming federation HTTP requests""" - - def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): - """Initialize the TransportLayerServer - - Will by default register all servlets. For custom behaviour, pass in - a list of servlet_groups to register. - - Args: - hs: homeserver - servlet_groups: List of servlet groups to register. - Defaults to ``DEFAULT_SERVLET_GROUPS``. - """ - self.hs = hs - self.clock = hs.get_clock() - self.servlet_groups = servlet_groups - - super().__init__(hs, canonical_json=False) - - self.authenticator = Authenticator(hs) - self.ratelimiter = hs.get_federation_ratelimiter() - - self.register_servlets() - - def register_servlets(self) -> None: - register_servlets( - self.hs, - resource=self, - ratelimiter=self.ratelimiter, - authenticator=self.authenticator, - servlet_groups=self.servlet_groups, - ) - - -class AuthenticationError(SynapseError): - """There was a problem authenticating the request""" - - -class NoAuthenticationError(AuthenticationError): - """The request had no authentication information""" - - -class Authenticator: - def __init__(self, hs: HomeServer): - self._clock = hs.get_clock() - self.keyring = hs.get_keyring() - self.server_name = hs.hostname - self.store = hs.get_datastore() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist - self.notifier = hs.get_notifier() - - self.replication_client = None - if hs.config.worker.worker_app: - self.replication_client = hs.get_tcp_replication() - - # A method just so we can pass 'self' as the authenticator to the Servlets - async def authenticate_request(self, request, content): - now = self._clock.time_msec() - json_request = { - "method": request.method.decode("ascii"), - "uri": request.uri.decode("ascii"), - "destination": self.server_name, - "signatures": {}, - } - - if content is not None: - json_request["content"] = content - - origin = None - - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") - - if not auth_headers: - raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED - ) - - for auth in auth_headers: - if auth.startswith(b"X-Matrix"): - (origin, key, sig) = _parse_auth_header(auth) - json_request["origin"] = origin - json_request["signatures"].setdefault(origin, {})[key] = sig - - if ( - self.federation_domain_whitelist is not None - and origin not in self.federation_domain_whitelist - ): - raise FederationDeniedError(origin) - - if origin is None or not json_request["signatures"]: - raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED - ) - - await self.keyring.verify_json_for_server( - origin, - json_request, - now, - ) - - logger.debug("Request from %s", origin) - request.requester = origin - - # If we get a valid signed request from the other side, its probably - # alive - retry_timings = await self.store.get_destination_retry_timings(origin) - if retry_timings and retry_timings.retry_last_ts: - run_in_background(self._reset_retry_timings, origin) - - return origin - - async def _reset_retry_timings(self, origin): - try: - logger.info("Marking origin %r as up", origin) - await self.store.set_destination_retry_timings(origin, None, 0, 0) - - # Inform the relevant places that the remote server is back up. - self.notifier.notify_remote_server_up(origin) - if self.replication_client: - # If we're on a worker we try and inform master about this. The - # replication client doesn't hook into the notifier to avoid - # infinite loops where we send a `REMOTE_SERVER_UP` command to - # master, which then echoes it back to us which in turn pokes - # the notifier. - self.replication_client.send_remote_server_up(origin) - - except Exception: - logger.exception("Error resetting retry timings on %s", origin) - - -def _parse_auth_header(header_bytes): - """Parse an X-Matrix auth header - - Args: - header_bytes (bytes): header value - - Returns: - Tuple[str, str, str]: origin, key id, signature. - - Raises: - AuthenticationError if the header could not be parsed - """ - try: - header_str = header_bytes.decode("utf-8") - params = header_str.split(" ")[1].split(",") - param_dict = dict(kv.split("=") for kv in params) - - def strip_quotes(value): - if value.startswith('"'): - return value[1:-1] - else: - return value - - origin = strip_quotes(param_dict["origin"]) - - # ensure that the origin is a valid server name - parse_and_validate_server_name(origin) - - key = strip_quotes(param_dict["key"]) - sig = strip_quotes(param_dict["sig"]) - return origin, key, sig - except Exception as e: - logger.warning( - "Error parsing auth header '%s': %s", - header_bytes.decode("ascii", "replace"), - e, - ) - raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED - ) - - -class BaseFederationServlet: - """Abstract base class for federation servlet classes. - - The servlet object should have a PATH attribute which takes the form of a regexp to - match against the request path (excluding the /federation/v1 prefix). - - The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match - the appropriate HTTP method. These methods must be *asynchronous* and have the - signature: - - on_(self, origin, content, query, **kwargs) - - With arguments: - - origin (unicode|None): The authenticated server_name of the calling server, - unless REQUIRE_AUTH is set to False and authentication failed. - - content (unicode|None): decoded json body of the request. None if the - request was a GET. - - query (dict[bytes, list[bytes]]): Query params from the request. url-decoded - (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded - yet. - - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. - - Returns: - Optional[Tuple[int, object]]: either (response code, response object) to - return a JSON response, or None if the request has already been handled. - - Raises: - SynapseError: to return an error code - - Exception: other exceptions will be caught, logged, and a 500 will be - returned. - """ - - PATH = "" # Overridden in subclasses, the regex to match against the path. - - REQUIRE_AUTH = True - - PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version - - RATELIMIT = True # Whether to rate limit requests or not - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - self.hs = hs - self.authenticator = authenticator - self.ratelimiter = ratelimiter - self.server_name = server_name - - def _wrap(self, func): - authenticator = self.authenticator - ratelimiter = self.ratelimiter - - @functools.wraps(func) - async def new_func(request, *args, **kwargs): - """A callback which can be passed to HttpServer.RegisterPaths - - Args: - request (twisted.web.http.Request): - *args: unused? - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. - - Returns: - Tuple[int, object]|None: (response code, response object) as returned by - the callback method. None if the request has already been handled. - """ - content = None - if request.method in [b"PUT", b"POST"]: - # TODO: Handle other method types? other content types? - content = parse_json_object_from_request(request) - - try: - origin = await authenticator.authenticate_request(request, content) - except NoAuthenticationError: - origin = None - if self.REQUIRE_AUTH: - logger.warning( - "authenticate_request failed: missing authentication" - ) - raise - except Exception as e: - logger.warning("authenticate_request failed: %s", e) - raise - - request_tags = { - SynapseTags.REQUEST_ID: request.get_request_id(), - tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, - tags.HTTP_METHOD: request.get_method(), - tags.HTTP_URL: request.get_redacted_uri(), - tags.PEER_HOST_IPV6: request.getClientIP(), - "authenticated_entity": origin, - "servlet_name": request.request_metrics.name, - } - - # Only accept the span context if the origin is authenticated - # and whitelisted - if origin and whitelisted_homeserver(origin): - scope = start_active_span_from_request( - request, "incoming-federation-request", tags=request_tags - ) - else: - scope = start_active_span( - "incoming-federation-request", tags=request_tags - ) - - with scope: - opentracing.inject_response_headers(request.responseHeaders) - - if origin and self.RATELIMIT: - with ratelimiter.ratelimit(origin) as d: - await d - if request._disconnected: - logger.warning( - "client disconnected before we started processing " - "request" - ) - return -1, None - response = await func( - origin, content, request.args, *args, **kwargs - ) - else: - response = await func( - origin, content, request.args, *args, **kwargs - ) - - return response - - return new_func - - def register(self, server): - pattern = re.compile("^" + self.PREFIX + self.PATH + "$") - - for method in ("GET", "PUT", "POST"): - code = getattr(self, "on_%s" % (method), None) - if code is None: - continue - - server.register_paths( - method, - (pattern,), - self._wrap(code), - self.__class__.__name__, - ) - - -class BaseFederationServerServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a federation server handler. - - See BaseFederationServlet for more information. - """ - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_federation_server() - - -class FederationSendServlet(BaseFederationServerServlet): - PATH = "/send/(?P[^/]*)/?" - - # We ratelimit manually in the handler as we queue up the requests and we - # don't want to fill up the ratelimiter with blocked requests. - RATELIMIT = False - - # This is when someone is trying to send us a bunch of data. - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - transaction_id: str, - ) -> Tuple[int, JsonDict]: - """Called on PUT /send// - - Args: - transaction_id: The transaction_id associated with this request. This - is *not* None. - - Returns: - Tuple of `(code, response)`, where - `response` is a python dict to be converted into JSON that is - used as the response body. - """ - # Parse the request - try: - transaction_data = content - - logger.debug("Decoded %s: %s", transaction_id, str(transaction_data)) - - logger.info( - "Received txn %s from %s. (PDUs: %d, EDUs: %d)", - transaction_id, - origin, - len(transaction_data.get("pdus", [])), - len(transaction_data.get("edus", [])), - ) - - except Exception as e: - logger.exception(e) - return 400, {"error": "Invalid transaction"} - - code, response = await self.handler.on_incoming_transaction( - origin, transaction_id, self.server_name, transaction_data - ) - - return code, response - - -class FederationEventServlet(BaseFederationServerServlet): - PATH = "/event/(?P[^/]*)/?" - - # This is when someone asks for a data item for a given server data_id pair. - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - event_id: str, - ) -> Tuple[int, Union[JsonDict, str]]: - return await self.handler.on_pdu_request(origin, event_id) - - -class FederationStateV1Servlet(BaseFederationServerServlet): - PATH = "/state/(?P[^/]*)/?" - - # This is when someone asks for all data for a given room. - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - return await self.handler.on_room_state_request( - origin, - room_id, - parse_string_from_args(query, "event_id", None, required=False), - ) - - -class FederationStateIdsServlet(BaseFederationServerServlet): - PATH = "/state_ids/(?P[^/]*)/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - return await self.handler.on_state_ids_request( - origin, - room_id, - parse_string_from_args(query, "event_id", None, required=True), - ) - - -class FederationBackfillServlet(BaseFederationServerServlet): - PATH = "/backfill/(?P[^/]*)/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - versions = [x.decode("ascii") for x in query[b"v"]] - limit = parse_integer_from_args(query, "limit", None) - - if not limit: - return 400, {"error": "Did not include limit param"} - - return await self.handler.on_backfill_request(origin, room_id, versions, limit) - - -class FederationQueryServlet(BaseFederationServerServlet): - PATH = "/query/(?P[^/]*)" - - # This is when we receive a server-server Query - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - query_type: str, - ) -> Tuple[int, JsonDict]: - args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} - args["origin"] = origin - return await self.handler.on_query_request(query_type, args) - - -class FederationMakeJoinServlet(BaseFederationServerServlet): - PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - """ - Args: - origin: The authenticated server_name of the calling server - - content: (GETs don't have bodies) - - query: Query params from the request. - - **kwargs: the dict mapping keys to path components as specified in - the path match regexp. - - Returns: - Tuple of (response code, response object) - """ - supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") - if supported_versions is None: - supported_versions = ["1"] - - result = await self.handler.on_make_join_request( - origin, room_id, user_id, supported_versions=supported_versions - ) - return 200, result - - -class FederationMakeLeaveServlet(BaseFederationServerServlet): - PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - result = await self.handler.on_make_leave_request(origin, room_id, user_id) - return 200, result - - -class FederationV1SendLeaveServlet(BaseFederationServerServlet): - PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, Tuple[int, JsonDict]]: - result = await self.handler.on_send_leave_request(origin, content, room_id) - return 200, (200, result) - - -class FederationV2SendLeaveServlet(BaseFederationServerServlet): - PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" - - PREFIX = FEDERATION_V2_PREFIX - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - result = await self.handler.on_send_leave_request(origin, content, room_id) - return 200, result - - -class FederationMakeKnockServlet(BaseFederationServerServlet): - PATH = "/make_knock/(?P[^/]*)/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - # Retrieve the room versions the remote homeserver claims to support - supported_versions = parse_strings_from_args( - query, "ver", required=True, encoding="utf-8" - ) - - result = await self.handler.on_make_knock_request( - origin, room_id, user_id, supported_versions=supported_versions - ) - return 200, result - - -class FederationV1SendKnockServlet(BaseFederationServerServlet): - PATH = "/send_knock/(?P[^/]*)/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - result = await self.handler.on_send_knock_request(origin, content, room_id) - return 200, result - - -class FederationEventAuthServlet(BaseFederationServerServlet): - PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - return await self.handler.on_event_auth(origin, room_id, event_id) - - -class FederationV1SendJoinServlet(BaseFederationServerServlet): - PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, Tuple[int, JsonDict]]: - # TODO(paul): assert that event_id parsed from path actually - # match those given in content - result = await self.handler.on_send_join_request(origin, content, room_id) - return 200, (200, result) - - -class FederationV2SendJoinServlet(BaseFederationServerServlet): - PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" - - PREFIX = FEDERATION_V2_PREFIX - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - # TODO(paul): assert that event_id parsed from path actually - # match those given in content - result = await self.handler.on_send_join_request(origin, content, room_id) - return 200, result - - -class FederationV1InviteServlet(BaseFederationServerServlet): - PATH = "/invite/(?P[^/]*)/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, Tuple[int, JsonDict]]: - # We don't get a room version, so we have to assume its EITHER v1 or - # v2. This is "fine" as the only difference between V1 and V2 is the - # state resolution algorithm, and we don't use that for processing - # invites - result = await self.handler.on_invite_request( - origin, content, room_version_id=RoomVersions.V1.identifier - ) - - # V1 federation API is defined to return a content of `[200, {...}]` - # due to a historical bug. - return 200, (200, result) - - -class FederationV2InviteServlet(BaseFederationServerServlet): - PATH = "/invite/(?P[^/]*)/(?P[^/]*)" - - PREFIX = FEDERATION_V2_PREFIX - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - # TODO(paul): assert that room_id/event_id parsed from path actually - # match those given in content - - room_version = content["room_version"] - event = content["event"] - invite_room_state = content["invite_room_state"] - - # Synapse expects invite_room_state to be in unsigned, as it is in v1 - # API - - event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state - - result = await self.handler.on_invite_request( - origin, event, room_version_id=room_version - ) - return 200, result - - -class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): - PATH = "/exchange_third_party_invite/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - await self.handler.on_exchange_third_party_invite_request(content) - return 200, {} - - -class FederationClientKeysQueryServlet(BaseFederationServerServlet): - PATH = "/user/keys/query" - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - return await self.handler.on_query_client_keys(origin, content) - - -class FederationUserDevicesQueryServlet(BaseFederationServerServlet): - PATH = "/user/devices/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - user_id: str, - ) -> Tuple[int, JsonDict]: - return await self.handler.on_query_user_devices(origin, user_id) - - -class FederationClientKeysClaimServlet(BaseFederationServerServlet): - PATH = "/user/keys/claim" - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - response = await self.handler.on_claim_client_keys(origin, content) - return 200, response - - -class FederationGetMissingEventsServlet(BaseFederationServerServlet): - # TODO(paul): Why does this path alone end with "/?" optional? - PATH = "/get_missing_events/(?P[^/]*)/?" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - limit = int(content.get("limit", 10)) - earliest_events = content.get("earliest_events", []) - latest_events = content.get("latest_events", []) - - result = await self.handler.on_get_missing_events( - origin, - room_id=room_id, - earliest_events=earliest_events, - latest_events=latest_events, - limit=limit, - ) - - return 200, result - - -class On3pidBindServlet(BaseFederationServerServlet): - PATH = "/3pid/onbind" - - REQUIRE_AUTH = False - - async def on_POST( - self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - if "invites" in content: - last_exception = None - for invite in content["invites"]: - try: - if "signed" not in invite or "token" not in invite["signed"]: - message = ( - "Rejecting received notification of third-" - "party invite without signed: %s" % (invite,) - ) - logger.info(message) - raise SynapseError(400, message) - await self.handler.exchange_third_party_invite( - invite["sender"], - invite["mxid"], - invite["room_id"], - invite["signed"], - ) - except Exception as e: - last_exception = e - if last_exception: - raise last_exception - return 200, {} - - -class OpenIdUserInfo(BaseFederationServerServlet): - """ - Exchange a bearer token for information about a user. - - The response format should be compatible with: - http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse - - GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1 - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "sub": "@userpart:example.org", - } - """ - - PATH = "/openid/userinfo" - - REQUIRE_AUTH = False - - async def on_GET( - self, - origin: Optional[str], - content: Literal[None], - query: Dict[bytes, List[bytes]], - ) -> Tuple[int, JsonDict]: - token = parse_string_from_args(query, "access_token") - if token is None: - return ( - 401, - {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, - ) - - user_id = await self.handler.on_openid_userinfo(token) - - if user_id is None: - return ( - 401, - { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired", - }, - ) - - return 200, {"sub": user_id} - - -class PublicRoomList(BaseFederationServlet): - """ - Fetch the public room list for this server. - - This API returns information in the same format as /publicRooms on the - client API, but will only ever include local public rooms and hence is - intended for consumption by other homeservers. - - GET /publicRooms HTTP/1.1 - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "chunk": [ - { - "aliases": [ - "#test:localhost" - ], - "guest_can_join": false, - "name": "test room", - "num_joined_members": 3, - "room_id": "!whkydVegtvatLfXmPN:localhost", - "world_readable": false - } - ], - "end": "END", - "start": "START" - } - """ - - PATH = "/publicRooms" - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - allow_access: bool, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_room_list_handler() - self.allow_access = allow_access - - async def on_GET( - self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - if not self.allow_access: - raise FederationDeniedError(origin) - - limit = parse_integer_from_args(query, "limit", 0) - since_token = parse_string_from_args(query, "since", None) - include_all_networks = parse_boolean_from_args( - query, "include_all_networks", default=False - ) - third_party_instance_id = parse_string_from_args( - query, "third_party_instance_id", None - ) - - if include_all_networks: - network_tuple = None - elif third_party_instance_id: - network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) - else: - network_tuple = ThirdPartyInstanceID(None, None) - - if limit == 0: - # zero is a special value which corresponds to no limit. - limit = None - - data = await self.handler.get_local_public_room_list( - limit, since_token, network_tuple=network_tuple, from_federation=True - ) - return 200, data - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - # This implements MSC2197 (Search Filtering over Federation) - if not self.allow_access: - raise FederationDeniedError(origin) - - limit: Optional[int] = int(content.get("limit", 100)) - since_token = content.get("since", None) - search_filter = content.get("filter", None) - - include_all_networks = content.get("include_all_networks", False) - third_party_instance_id = content.get("third_party_instance_id", None) - - if include_all_networks: - network_tuple = None - if third_party_instance_id is not None: - raise SynapseError( - 400, "Can't use include_all_networks with an explicit network" - ) - elif third_party_instance_id is None: - network_tuple = ThirdPartyInstanceID(None, None) - else: - network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) - - if search_filter is None: - logger.warning("Nonefilter") - - if limit == 0: - # zero is a special value which corresponds to no limit. - limit = None - - data = await self.handler.get_local_public_room_list( - limit=limit, - since_token=since_token, - search_filter=search_filter, - network_tuple=network_tuple, - from_federation=True, - ) - - return 200, data - - -class FederationVersionServlet(BaseFederationServlet): - PATH = "/version" - - REQUIRE_AUTH = False - - async def on_GET( - self, - origin: Optional[str], - content: Literal[None], - query: Dict[bytes, List[bytes]], - ) -> Tuple[int, JsonDict]: - return ( - 200, - {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, - ) - - -class BaseGroupsServerServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups server handler. - - See BaseFederationServlet for more information. - """ - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_server_handler() - - -class FederationGroupsProfileServlet(BaseGroupsServerServlet): - """Get/set the basic profile of a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/profile" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_group_profile(group_id, requester_user_id) - - return 200, new_content - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.update_group_profile( - group_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryServlet(BaseGroupsServerServlet): - PATH = "/groups/(?P[^/]*)/summary" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_group_summary(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsRoomsServlet(BaseGroupsServerServlet): - """Get the rooms in a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/rooms" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): - """Add/remove room from group""" - - PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.add_room_to_group( - group_id, requester_user_id, room_id, content - ) - - return 200, new_content - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.remove_room_from_group( - group_id, requester_user_id, room_id - ) - - return 200, new_content - - -class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): - """Update room config in group""" - - PATH = ( - "/groups/(?P[^/]*)/room/(?P[^/]*)" - "/config/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - config_key: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - result = await self.handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content - ) - - return 200, result - - -class FederationGroupsUsersServlet(BaseGroupsServerServlet): - """Get the users in a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/users" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_users_in_group(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): - """Get the users that have been invited to a group""" - - PATH = "/groups/(?P[^/]*)/invited_users" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_invited_users_in_group( - group_id, requester_user_id - ) - - return 200, new_content - - -class FederationGroupsInviteServlet(BaseGroupsServerServlet): - """Ask a group server to invite someone to the group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.invite_to_group( - group_id, user_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): - """Accept an invitation from the group server""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(user_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - new_content = await self.handler.accept_invite(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsJoinServlet(BaseGroupsServerServlet): - """Attempt to join a group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(user_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - new_content = await self.handler.join_group(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): - """Leave or kick a user from the group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.remove_user_from_group( - group_id, user_id, requester_user_id, content - ) - - return 200, new_content - - -class BaseGroupsLocalServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups local handler. - - See BaseFederationServlet for more information. - """ - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_local_handler() - - -class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): - """A group server has invited a local user""" - - PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(group_id) != origin: - raise SynapseError(403, "group_id doesn't match origin") - - assert isinstance( - self.handler, GroupsLocalHandler - ), "Workers cannot handle group invites." - - new_content = await self.handler.on_invite(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): - """A group server has removed a local user""" - - PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, None]: - if get_domain_from_id(group_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - assert isinstance( - self.handler, GroupsLocalHandler - ), "Workers cannot handle group removals." - - await self.handler.user_removed_from_group(group_id, user_id, content) - - return 200, None - - -class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): - """A group or user's server renews their attestation""" - - PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_attestation_renewer() - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - # We don't need to check auth here as we check the attestation signatures - - new_content = await self.handler.on_renew_attestation( - group_id, user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): - """Add/remove a room from the group summary, with optional category. - - Matches both: - - /groups/:group/summary/rooms/:room_id - - /groups/:group/summary/categories/:category/rooms/:room_id - """ - - PATH = ( - "/groups/(?P[^/]*)/summary" - "(/categories/(?P[^/]+))?" - "/rooms/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError( - 400, "category_id cannot be empty string", Codes.INVALID_PARAM - ) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - resp = await self.handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): - """Get all categories for a group""" - - PATH = "/groups/(?P[^/]*)/categories/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_categories(group_id, requester_user_id) - - return 200, resp - - -class FederationGroupsCategoryServlet(BaseGroupsServerServlet): - """Add/remove/get a category in a group""" - - PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_category( - group_id, requester_user_id, category_id - ) - - return 200, resp - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.upsert_group_category( - group_id, requester_user_id, category_id, content - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - resp = await self.handler.delete_group_category( - group_id, requester_user_id, category_id - ) - - return 200, resp - - -class FederationGroupsRolesServlet(BaseGroupsServerServlet): - """Get roles in a group""" - - PATH = "/groups/(?P[^/]*)/roles/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_roles(group_id, requester_user_id) - - return 200, resp - - -class FederationGroupsRoleServlet(BaseGroupsServerServlet): - """Add/remove/get a role in a group""" - - PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_role(group_id, requester_user_id, role_id) - - return 200, resp - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError( - 400, "role_id cannot be empty string", Codes.INVALID_PARAM - ) - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_role( - group_id, requester_user_id, role_id, content - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - resp = await self.handler.delete_group_role( - group_id, requester_user_id, role_id - ) - - return 200, resp - - -class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): - """Add/remove a user from the group summary, with optional role. - - Matches both: - - /groups/:group/summary/users/:user_id - - /groups/:group/summary/roles/:role/users/:user_id - """ - - PATH = ( - "/groups/(?P[^/]*)/summary" - "(/roles/(?P[^/]+))?" - "/users/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_summary_user( - group_id, - requester_user_id, - user_id=user_id, - role_id=role_id, - content=content, - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - resp = await self.handler.delete_group_summary_user( - group_id, requester_user_id, user_id=user_id, role_id=role_id - ) - - return 200, resp - - -class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): - """Get roles in a group""" - - PATH = "/get_groups_publicised" - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - resp = await self.handler.bulk_get_publicised_groups( - content["user_ids"], proxy=False - ) - - return 200, resp - - -class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): - """Sets whether a group is joinable without an invite or knock""" - - PATH = "/groups/(?P[^/]*)/settings/m.join_policy" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.set_group_join_policy( - group_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationSpaceSummaryServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - PATH = "/spaces/(?P[^/]*)" - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_space_summary_handler() - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) - max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") - - exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - exclude_rooms = content.get("exclude_rooms", []) - if not isinstance(exclude_rooms, list) or any( - not isinstance(x, str) for x in exclude_rooms - ): - raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON - ) - - return 200, await self.handler.federation_space_summary( - origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms - ) - - -class FederationRoomHierarchyServlet(BaseFederationServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - PATH = "/hierarchy/(?P[^/]*)" - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_space_summary_handler() - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Mapping[bytes, Sequence[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) - return 200, await self.handler.get_federation_hierarchy( - origin, room_id, suggested_only - ) - - -class RoomComplexityServlet(BaseFederationServlet): - """ - Indicates to other servers how complex (and therefore likely - resource-intensive) a public room this server knows about is. - """ - - PATH = "/rooms/(?P[^/]*)/complexity" - PREFIX = FEDERATION_UNSTABLE_PREFIX - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self._store = self.hs.get_datastore() - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - is_public = await self._store.is_room_world_readable_or_publicly_joinable( - room_id - ) - - if not is_public: - raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) - - complexity = await self._store.get_room_complexity(room_id) - return 200, complexity - - -FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( - FederationSendServlet, - FederationEventServlet, - FederationStateV1Servlet, - FederationStateIdsServlet, - FederationBackfillServlet, - FederationQueryServlet, - FederationMakeJoinServlet, - FederationMakeLeaveServlet, - FederationEventServlet, - FederationV1SendJoinServlet, - FederationV2SendJoinServlet, - FederationV1SendLeaveServlet, - FederationV2SendLeaveServlet, - FederationV1InviteServlet, - FederationV2InviteServlet, - FederationGetMissingEventsServlet, - FederationEventAuthServlet, - FederationClientKeysQueryServlet, - FederationUserDevicesQueryServlet, - FederationClientKeysClaimServlet, - FederationThirdPartyInviteExchangeServlet, - On3pidBindServlet, - FederationVersionServlet, - RoomComplexityServlet, - FederationSpaceSummaryServlet, - FederationRoomHierarchyServlet, - FederationV1SendKnockServlet, - FederationMakeKnockServlet, -) - -OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,) - -ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,) - -GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( - FederationGroupsProfileServlet, - FederationGroupsSummaryServlet, - FederationGroupsRoomsServlet, - FederationGroupsUsersServlet, - FederationGroupsInvitedUsersServlet, - FederationGroupsInviteServlet, - FederationGroupsAcceptInviteServlet, - FederationGroupsJoinServlet, - FederationGroupsRemoveUserServlet, - FederationGroupsSummaryRoomsServlet, - FederationGroupsCategoriesServlet, - FederationGroupsCategoryServlet, - FederationGroupsRolesServlet, - FederationGroupsRoleServlet, - FederationGroupsSummaryUsersServlet, - FederationGroupsAddRoomsServlet, - FederationGroupsAddRoomsConfigServlet, - FederationGroupsSettingJoinPolicyServlet, -) - - -GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( - FederationGroupsLocalInviteServlet, - FederationGroupsRemoveLocalUserServlet, - FederationGroupsBulkPublicisedServlet, -) - - -GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( - FederationGroupsRenewAttestaionServlet, -) - - -DEFAULT_SERVLET_GROUPS = ( - "federation", - "room_list", - "group_server", - "group_local", - "group_attestation", - "openid", -) - - -def register_servlets( - hs: HomeServer, - resource: HttpServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - servlet_groups: Optional[Container[str]] = None, -): - """Initialize and register servlet classes. - - Will by default register all servlets. For custom behaviour, pass in - a list of servlet_groups to register. - - Args: - hs: homeserver - resource: resource class to register to - authenticator: authenticator to use - ratelimiter: ratelimiter to use - servlet_groups: List of servlet groups to register. - Defaults to ``DEFAULT_SERVLET_GROUPS``. - """ - if not servlet_groups: - servlet_groups = DEFAULT_SERVLET_GROUPS - - if "federation" in servlet_groups: - for servletclass in FEDERATION_SERVLET_CLASSES: - servletclass( - hs=hs, - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - - if "openid" in servlet_groups: - for servletclass in OPENID_SERVLET_CLASSES: - servletclass( - hs=hs, - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - - if "room_list" in servlet_groups: - for servletclass in ROOM_LIST_CLASSES: - servletclass( - hs=hs, - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - allow_access=hs.config.allow_public_rooms_over_federation, - ).register(resource) - - if "group_server" in servlet_groups: - for servletclass in GROUP_SERVER_SERVLET_CLASSES: - servletclass( - hs=hs, - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - - if "group_local" in servlet_groups: - for servletclass in GROUP_LOCAL_SERVLET_CLASSES: - servletclass( - hs=hs, - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - - if "group_attestation" in servlet_groups: - for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: - servletclass( - hs=hs, - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py new file mode 100644 index 0000000000..95176ba6f9 --- /dev/null +++ b/synapse/federation/transport/server/__init__.py @@ -0,0 +1,332 @@ +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Dict, Iterable, List, Optional, Tuple, Type + +from typing_extensions import Literal + +from synapse.api.errors import FederationDeniedError, SynapseError +from synapse.federation.transport.server._base import ( + Authenticator, + BaseFederationServlet, +) +from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES +from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES +from synapse.federation.transport.server.groups_server import ( + GROUP_SERVER_SERVLET_CLASSES, +) +from synapse.http.server import HttpServer, JsonResource +from synapse.http.servlet import ( + parse_boolean_from_args, + parse_integer_from_args, + parse_string_from_args, +) +from synapse.server import HomeServer +from synapse.types import JsonDict, ThirdPartyInstanceID +from synapse.util.ratelimitutils import FederationRateLimiter + +logger = logging.getLogger(__name__) + + +class TransportLayerServer(JsonResource): + """Handles incoming federation HTTP requests""" + + def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): + """Initialize the TransportLayerServer + + Will by default register all servlets. For custom behaviour, pass in + a list of servlet_groups to register. + + Args: + hs: homeserver + servlet_groups: List of servlet groups to register. + Defaults to ``DEFAULT_SERVLET_GROUPS``. + """ + self.hs = hs + self.clock = hs.get_clock() + self.servlet_groups = servlet_groups + + super().__init__(hs, canonical_json=False) + + self.authenticator = Authenticator(hs) + self.ratelimiter = hs.get_federation_ratelimiter() + + self.register_servlets() + + def register_servlets(self) -> None: + register_servlets( + self.hs, + resource=self, + ratelimiter=self.ratelimiter, + authenticator=self.authenticator, + servlet_groups=self.servlet_groups, + ) + + +class PublicRoomList(BaseFederationServlet): + """ + Fetch the public room list for this server. + + This API returns information in the same format as /publicRooms on the + client API, but will only ever include local public rooms and hence is + intended for consumption by other homeservers. + + GET /publicRooms HTTP/1.1 + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "chunk": [ + { + "aliases": [ + "#test:localhost" + ], + "guest_can_join": false, + "name": "test room", + "num_joined_members": 3, + "room_id": "!whkydVegtvatLfXmPN:localhost", + "world_readable": false + } + ], + "end": "END", + "start": "START" + } + """ + + PATH = "/publicRooms" + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_room_list_handler() + self.allow_access = hs.config.allow_public_rooms_over_federation + + async def on_GET( + self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + if not self.allow_access: + raise FederationDeniedError(origin) + + limit = parse_integer_from_args(query, "limit", 0) + since_token = parse_string_from_args(query, "since", None) + include_all_networks = parse_boolean_from_args( + query, "include_all_networks", default=False + ) + third_party_instance_id = parse_string_from_args( + query, "third_party_instance_id", None + ) + + if include_all_networks: + network_tuple = None + elif third_party_instance_id: + network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + else: + network_tuple = ThirdPartyInstanceID(None, None) + + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + + data = await self.handler.get_local_public_room_list( + limit, since_token, network_tuple=network_tuple, from_federation=True + ) + return 200, data + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + # This implements MSC2197 (Search Filtering over Federation) + if not self.allow_access: + raise FederationDeniedError(origin) + + limit: Optional[int] = int(content.get("limit", 100)) + since_token = content.get("since", None) + search_filter = content.get("filter", None) + + include_all_networks = content.get("include_all_networks", False) + third_party_instance_id = content.get("third_party_instance_id", None) + + if include_all_networks: + network_tuple = None + if third_party_instance_id is not None: + raise SynapseError( + 400, "Can't use include_all_networks with an explicit network" + ) + elif third_party_instance_id is None: + network_tuple = ThirdPartyInstanceID(None, None) + else: + network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + + if search_filter is None: + logger.warning("Nonefilter") + + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + + data = await self.handler.get_local_public_room_list( + limit=limit, + since_token=since_token, + search_filter=search_filter, + network_tuple=network_tuple, + from_federation=True, + ) + + return 200, data + + +class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): + """A group or user's server renews their attestation""" + + PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_attestation_renewer() + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + # We don't need to check auth here as we check the attestation signatures + + new_content = await self.handler.on_renew_attestation( + group_id, user_id, content + ) + + return 200, new_content + + +class OpenIdUserInfo(BaseFederationServlet): + """ + Exchange a bearer token for information about a user. + + The response format should be compatible with: + http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse + + GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1 + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "sub": "@userpart:example.org", + } + """ + + PATH = "/openid/userinfo" + + REQUIRE_AUTH = False + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_federation_server() + + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: + token = parse_string_from_args(query, "access_token") + if token is None: + return ( + 401, + {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, + ) + + user_id = await self.handler.on_openid_userinfo(token) + + if user_id is None: + return ( + 401, + { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired", + }, + ) + + return 200, {"sub": user_id} + + +DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = { + "federation": FEDERATION_SERVLET_CLASSES, + "room_list": (PublicRoomList,), + "group_server": GROUP_SERVER_SERVLET_CLASSES, + "group_local": GROUP_LOCAL_SERVLET_CLASSES, + "group_attestation": (FederationGroupsRenewAttestaionServlet,), + "openid": (OpenIdUserInfo,), +} + + +def register_servlets( + hs: HomeServer, + resource: HttpServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + servlet_groups: Optional[Iterable[str]] = None, +): + """Initialize and register servlet classes. + + Will by default register all servlets. For custom behaviour, pass in + a list of servlet_groups to register. + + Args: + hs: homeserver + resource: resource class to register to + authenticator: authenticator to use + ratelimiter: ratelimiter to use + servlet_groups: List of servlet groups to register. + Defaults to ``DEFAULT_SERVLET_GROUPS``. + """ + if not servlet_groups: + servlet_groups = DEFAULT_SERVLET_GROUPS.keys() + + for servlet_group in servlet_groups: + # Skip unknown servlet groups. + if servlet_group not in DEFAULT_SERVLET_GROUPS: + raise RuntimeError( + f"Attempting to register unknown federation servlet: '{servlet_group}'" + ) + + for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]: + servletclass( + hs=hs, + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py new file mode 100644 index 0000000000..624c859f1e --- /dev/null +++ b/synapse/federation/transport/server/_base.py @@ -0,0 +1,328 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +import re + +from synapse.api.errors import Codes, FederationDeniedError, SynapseError +from synapse.api.urls import FEDERATION_V1_PREFIX +from synapse.http.servlet import parse_json_object_from_request +from synapse.logging import opentracing +from synapse.logging.context import run_in_background +from synapse.logging.opentracing import ( + SynapseTags, + start_active_span, + start_active_span_from_request, + tags, + whitelisted_homeserver, +) +from synapse.server import HomeServer +from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.stringutils import parse_and_validate_server_name + +logger = logging.getLogger(__name__) + + +class AuthenticationError(SynapseError): + """There was a problem authenticating the request""" + + +class NoAuthenticationError(AuthenticationError): + """The request had no authentication information""" + + +class Authenticator: + def __init__(self, hs: HomeServer): + self._clock = hs.get_clock() + self.keyring = hs.get_keyring() + self.server_name = hs.hostname + self.store = hs.get_datastore() + self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.notifier = hs.get_notifier() + + self.replication_client = None + if hs.config.worker.worker_app: + self.replication_client = hs.get_tcp_replication() + + # A method just so we can pass 'self' as the authenticator to the Servlets + async def authenticate_request(self, request, content): + now = self._clock.time_msec() + json_request = { + "method": request.method.decode("ascii"), + "uri": request.uri.decode("ascii"), + "destination": self.server_name, + "signatures": {}, + } + + if content is not None: + json_request["content"] = content + + origin = None + + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + + if not auth_headers: + raise NoAuthenticationError( + 401, "Missing Authorization headers", Codes.UNAUTHORIZED + ) + + for auth in auth_headers: + if auth.startswith(b"X-Matrix"): + (origin, key, sig) = _parse_auth_header(auth) + json_request["origin"] = origin + json_request["signatures"].setdefault(origin, {})[key] = sig + + if ( + self.federation_domain_whitelist is not None + and origin not in self.federation_domain_whitelist + ): + raise FederationDeniedError(origin) + + if origin is None or not json_request["signatures"]: + raise NoAuthenticationError( + 401, "Missing Authorization headers", Codes.UNAUTHORIZED + ) + + await self.keyring.verify_json_for_server( + origin, + json_request, + now, + ) + + logger.debug("Request from %s", origin) + request.requester = origin + + # If we get a valid signed request from the other side, its probably + # alive + retry_timings = await self.store.get_destination_retry_timings(origin) + if retry_timings and retry_timings.retry_last_ts: + run_in_background(self._reset_retry_timings, origin) + + return origin + + async def _reset_retry_timings(self, origin): + try: + logger.info("Marking origin %r as up", origin) + await self.store.set_destination_retry_timings(origin, None, 0, 0) + + # Inform the relevant places that the remote server is back up. + self.notifier.notify_remote_server_up(origin) + if self.replication_client: + # If we're on a worker we try and inform master about this. The + # replication client doesn't hook into the notifier to avoid + # infinite loops where we send a `REMOTE_SERVER_UP` command to + # master, which then echoes it back to us which in turn pokes + # the notifier. + self.replication_client.send_remote_server_up(origin) + + except Exception: + logger.exception("Error resetting retry timings on %s", origin) + + +def _parse_auth_header(header_bytes): + """Parse an X-Matrix auth header + + Args: + header_bytes (bytes): header value + + Returns: + Tuple[str, str, str]: origin, key id, signature. + + Raises: + AuthenticationError if the header could not be parsed + """ + try: + header_str = header_bytes.decode("utf-8") + params = header_str.split(" ")[1].split(",") + param_dict = dict(kv.split("=") for kv in params) + + def strip_quotes(value): + if value.startswith('"'): + return value[1:-1] + else: + return value + + origin = strip_quotes(param_dict["origin"]) + + # ensure that the origin is a valid server name + parse_and_validate_server_name(origin) + + key = strip_quotes(param_dict["key"]) + sig = strip_quotes(param_dict["sig"]) + return origin, key, sig + except Exception as e: + logger.warning( + "Error parsing auth header '%s': %s", + header_bytes.decode("ascii", "replace"), + e, + ) + raise AuthenticationError( + 400, "Malformed Authorization header", Codes.UNAUTHORIZED + ) + + +class BaseFederationServlet: + """Abstract base class for federation servlet classes. + + The servlet object should have a PATH attribute which takes the form of a regexp to + match against the request path (excluding the /federation/v1 prefix). + + The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match + the appropriate HTTP method. These methods must be *asynchronous* and have the + signature: + + on_(self, origin, content, query, **kwargs) + + With arguments: + + origin (unicode|None): The authenticated server_name of the calling server, + unless REQUIRE_AUTH is set to False and authentication failed. + + content (unicode|None): decoded json body of the request. None if the + request was a GET. + + query (dict[bytes, list[bytes]]): Query params from the request. url-decoded + (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded + yet. + + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Optional[Tuple[int, object]]: either (response code, response object) to + return a JSON response, or None if the request has already been handled. + + Raises: + SynapseError: to return an error code + + Exception: other exceptions will be caught, logged, and a 500 will be + returned. + """ + + PATH = "" # Overridden in subclasses, the regex to match against the path. + + REQUIRE_AUTH = True + + PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version + + RATELIMIT = True # Whether to rate limit requests or not + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + self.hs = hs + self.authenticator = authenticator + self.ratelimiter = ratelimiter + self.server_name = server_name + + def _wrap(self, func): + authenticator = self.authenticator + ratelimiter = self.ratelimiter + + @functools.wraps(func) + async def new_func(request, *args, **kwargs): + """A callback which can be passed to HttpServer.RegisterPaths + + Args: + request (twisted.web.http.Request): + *args: unused? + **kwargs (dict[unicode, unicode]): the dict mapping keys to path + components as specified in the path match regexp. + + Returns: + Tuple[int, object]|None: (response code, response object) as returned by + the callback method. None if the request has already been handled. + """ + content = None + if request.method in [b"PUT", b"POST"]: + # TODO: Handle other method types? other content types? + content = parse_json_object_from_request(request) + + try: + origin = await authenticator.authenticate_request(request, content) + except NoAuthenticationError: + origin = None + if self.REQUIRE_AUTH: + logger.warning( + "authenticate_request failed: missing authentication" + ) + raise + except Exception as e: + logger.warning("authenticate_request failed: %s", e) + raise + + request_tags = { + SynapseTags.REQUEST_ID: request.get_request_id(), + tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, + tags.HTTP_METHOD: request.get_method(), + tags.HTTP_URL: request.get_redacted_uri(), + tags.PEER_HOST_IPV6: request.getClientIP(), + "authenticated_entity": origin, + "servlet_name": request.request_metrics.name, + } + + # Only accept the span context if the origin is authenticated + # and whitelisted + if origin and whitelisted_homeserver(origin): + scope = start_active_span_from_request( + request, "incoming-federation-request", tags=request_tags + ) + else: + scope = start_active_span( + "incoming-federation-request", tags=request_tags + ) + + with scope: + opentracing.inject_response_headers(request.responseHeaders) + + if origin and self.RATELIMIT: + with ratelimiter.ratelimit(origin) as d: + await d + if request._disconnected: + logger.warning( + "client disconnected before we started processing " + "request" + ) + return -1, None + response = await func( + origin, content, request.args, *args, **kwargs + ) + else: + response = await func( + origin, content, request.args, *args, **kwargs + ) + + return response + + return new_func + + def register(self, server): + pattern = re.compile("^" + self.PREFIX + self.PATH + "$") + + for method in ("GET", "PUT", "POST"): + code = getattr(self, "on_%s" % (method), None) + if code is None: + continue + + server.register_paths( + method, + (pattern,), + self._wrap(code), + self.__class__.__name__, + ) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py new file mode 100644 index 0000000000..2806337846 --- /dev/null +++ b/synapse/federation/transport/server/federation.py @@ -0,0 +1,692 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union + +from typing_extensions import Literal + +import synapse +from synapse.api.errors import Codes, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX +from synapse.federation.transport.server._base import ( + Authenticator, + BaseFederationServlet, +) +from synapse.http.servlet import ( + parse_boolean_from_args, + parse_integer_from_args, + parse_string_from_args, + parse_strings_from_args, +) +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.versionstring import get_version_string + +logger = logging.getLogger(__name__) + + +class BaseFederationServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a federation server handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_federation_server() + + +class FederationSendServlet(BaseFederationServerServlet): + PATH = "/send/(?P[^/]*)/?" + + # We ratelimit manually in the handler as we queue up the requests and we + # don't want to fill up the ratelimiter with blocked requests. + RATELIMIT = False + + # This is when someone is trying to send us a bunch of data. + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + transaction_id: str, + ) -> Tuple[int, JsonDict]: + """Called on PUT /send// + + Args: + transaction_id: The transaction_id associated with this request. This + is *not* None. + + Returns: + Tuple of `(code, response)`, where + `response` is a python dict to be converted into JSON that is + used as the response body. + """ + # Parse the request + try: + transaction_data = content + + logger.debug("Decoded %s: %s", transaction_id, str(transaction_data)) + + logger.info( + "Received txn %s from %s. (PDUs: %d, EDUs: %d)", + transaction_id, + origin, + len(transaction_data.get("pdus", [])), + len(transaction_data.get("edus", [])), + ) + + except Exception as e: + logger.exception(e) + return 400, {"error": "Invalid transaction"} + + code, response = await self.handler.on_incoming_transaction( + origin, transaction_id, self.server_name, transaction_data + ) + + return code, response + + +class FederationEventServlet(BaseFederationServerServlet): + PATH = "/event/(?P[^/]*)/?" + + # This is when someone asks for a data item for a given server data_id pair. + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + event_id: str, + ) -> Tuple[int, Union[JsonDict, str]]: + return await self.handler.on_pdu_request(origin, event_id) + + +class FederationStateV1Servlet(BaseFederationServerServlet): + PATH = "/state/(?P[^/]*)/?" + + # This is when someone asks for all data for a given room. + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + return await self.handler.on_room_state_request( + origin, + room_id, + parse_string_from_args(query, "event_id", None, required=False), + ) + + +class FederationStateIdsServlet(BaseFederationServerServlet): + PATH = "/state_ids/(?P[^/]*)/?" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + return await self.handler.on_state_ids_request( + origin, + room_id, + parse_string_from_args(query, "event_id", None, required=True), + ) + + +class FederationBackfillServlet(BaseFederationServerServlet): + PATH = "/backfill/(?P[^/]*)/?" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + versions = [x.decode("ascii") for x in query[b"v"]] + limit = parse_integer_from_args(query, "limit", None) + + if not limit: + return 400, {"error": "Did not include limit param"} + + return await self.handler.on_backfill_request(origin, room_id, versions, limit) + + +class FederationQueryServlet(BaseFederationServerServlet): + PATH = "/query/(?P[^/]*)" + + # This is when we receive a server-server Query + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + query_type: str, + ) -> Tuple[int, JsonDict]: + args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} + args["origin"] = origin + return await self.handler.on_query_request(query_type, args) + + +class FederationMakeJoinServlet(BaseFederationServerServlet): + PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + """ + Args: + origin: The authenticated server_name of the calling server + + content: (GETs don't have bodies) + + query: Query params from the request. + + **kwargs: the dict mapping keys to path components as specified in + the path match regexp. + + Returns: + Tuple of (response code, response object) + """ + supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") + if supported_versions is None: + supported_versions = ["1"] + + result = await self.handler.on_make_join_request( + origin, room_id, user_id, supported_versions=supported_versions + ) + return 200, result + + +class FederationMakeLeaveServlet(BaseFederationServerServlet): + PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_make_leave_request(origin, room_id, user_id) + return 200, result + + +class FederationV1SendLeaveServlet(BaseFederationServerServlet): + PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: + result = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, result) + + +class FederationV2SendLeaveServlet(BaseFederationServerServlet): + PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, result + + +class FederationMakeKnockServlet(BaseFederationServerServlet): + PATH = "/make_knock/(?P[^/]*)/(?P[^/]*)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + # Retrieve the room versions the remote homeserver claims to support + supported_versions = parse_strings_from_args( + query, "ver", required=True, encoding="utf-8" + ) + + result = await self.handler.on_make_knock_request( + origin, room_id, user_id, supported_versions=supported_versions + ) + return 200, result + + +class FederationV1SendKnockServlet(BaseFederationServerServlet): + PATH = "/send_knock/(?P[^/]*)/(?P[^/]*)" + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_send_knock_request(origin, content, room_id) + return 200, result + + +class FederationEventAuthServlet(BaseFederationServerServlet): + PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + return await self.handler.on_event_auth(origin, room_id, event_id) + + +class FederationV1SendJoinServlet(BaseFederationServerServlet): + PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: + # TODO(paul): assert that event_id parsed from path actually + # match those given in content + result = await self.handler.on_send_join_request(origin, content, room_id) + return 200, (200, result) + + +class FederationV2SendJoinServlet(BaseFederationServerServlet): + PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + # TODO(paul): assert that event_id parsed from path actually + # match those given in content + result = await self.handler.on_send_join_request(origin, content, room_id) + return 200, result + + +class FederationV1InviteServlet(BaseFederationServerServlet): + PATH = "/invite/(?P[^/]*)/(?P[^/]*)" + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: + # We don't get a room version, so we have to assume its EITHER v1 or + # v2. This is "fine" as the only difference between V1 and V2 is the + # state resolution algorithm, and we don't use that for processing + # invites + result = await self.handler.on_invite_request( + origin, content, room_version_id=RoomVersions.V1.identifier + ) + + # V1 federation API is defined to return a content of `[200, {...}]` + # due to a historical bug. + return 200, (200, result) + + +class FederationV2InviteServlet(BaseFederationServerServlet): + PATH = "/invite/(?P[^/]*)/(?P[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + # TODO(paul): assert that room_id/event_id parsed from path actually + # match those given in content + + room_version = content["room_version"] + event = content["event"] + invite_room_state = content["invite_room_state"] + + # Synapse expects invite_room_state to be in unsigned, as it is in v1 + # API + + event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state + + result = await self.handler.on_invite_request( + origin, event, room_version_id=room_version + ) + return 200, result + + +class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): + PATH = "/exchange_third_party_invite/(?P[^/]*)" + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + await self.handler.on_exchange_third_party_invite_request(content) + return 200, {} + + +class FederationClientKeysQueryServlet(BaseFederationServerServlet): + PATH = "/user/keys/query" + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + return await self.handler.on_query_client_keys(origin, content) + + +class FederationUserDevicesQueryServlet(BaseFederationServerServlet): + PATH = "/user/devices/(?P[^/]*)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + user_id: str, + ) -> Tuple[int, JsonDict]: + return await self.handler.on_query_user_devices(origin, user_id) + + +class FederationClientKeysClaimServlet(BaseFederationServerServlet): + PATH = "/user/keys/claim" + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + response = await self.handler.on_claim_client_keys(origin, content) + return 200, response + + +class FederationGetMissingEventsServlet(BaseFederationServerServlet): + # TODO(paul): Why does this path alone end with "/?" optional? + PATH = "/get_missing_events/(?P[^/]*)/?" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + limit = int(content.get("limit", 10)) + earliest_events = content.get("earliest_events", []) + latest_events = content.get("latest_events", []) + + result = await self.handler.on_get_missing_events( + origin, + room_id=room_id, + earliest_events=earliest_events, + latest_events=latest_events, + limit=limit, + ) + + return 200, result + + +class On3pidBindServlet(BaseFederationServerServlet): + PATH = "/3pid/onbind" + + REQUIRE_AUTH = False + + async def on_POST( + self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + if "invites" in content: + last_exception = None + for invite in content["invites"]: + try: + if "signed" not in invite or "token" not in invite["signed"]: + message = ( + "Rejecting received notification of third-" + "party invite without signed: %s" % (invite,) + ) + logger.info(message) + raise SynapseError(400, message) + await self.handler.exchange_third_party_invite( + invite["sender"], + invite["mxid"], + invite["room_id"], + invite["signed"], + ) + except Exception as e: + last_exception = e + if last_exception: + raise last_exception + return 200, {} + + +class FederationVersionServlet(BaseFederationServlet): + PATH = "/version" + + REQUIRE_AUTH = False + + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: + return ( + 200, + {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, + ) + + +class FederationSpaceSummaryServlet(BaseFederationServlet): + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" + PATH = "/spaces/(?P[^/]*)" + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_space_summary_handler() + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) + max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") + + exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) + + return 200, await self.handler.federation_space_summary( + origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms + ) + + # TODO When switching to the stable endpoint, remove the POST handler. + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + suggested_only = content.get("suggested_only", False) + if not isinstance(suggested_only, bool): + raise SynapseError( + 400, "'suggested_only' must be a boolean", Codes.BAD_JSON + ) + + exclude_rooms = content.get("exclude_rooms", []) + if not isinstance(exclude_rooms, list) or any( + not isinstance(x, str) for x in exclude_rooms + ): + raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) + + max_rooms_per_space = content.get("max_rooms_per_space") + if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON + ) + + return 200, await self.handler.federation_space_summary( + origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms + ) + + +class FederationRoomHierarchyServlet(BaseFederationServlet): + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" + PATH = "/hierarchy/(?P[^/]*)" + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_space_summary_handler() + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Mapping[bytes, Sequence[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) + return 200, await self.handler.get_federation_hierarchy( + origin, room_id, suggested_only + ) + + +class RoomComplexityServlet(BaseFederationServlet): + """ + Indicates to other servers how complex (and therefore likely + resource-intensive) a public room this server knows about is. + """ + + PATH = "/rooms/(?P[^/]*)/complexity" + PREFIX = FEDERATION_UNSTABLE_PREFIX + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._store = self.hs.get_datastore() + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: + is_public = await self._store.is_room_world_readable_or_publicly_joinable( + room_id + ) + + if not is_public: + raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) + + complexity = await self._store.get_room_complexity(room_id) + return 200, complexity + + +FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( + FederationSendServlet, + FederationEventServlet, + FederationStateV1Servlet, + FederationStateIdsServlet, + FederationBackfillServlet, + FederationQueryServlet, + FederationMakeJoinServlet, + FederationMakeLeaveServlet, + FederationEventServlet, + FederationV1SendJoinServlet, + FederationV2SendJoinServlet, + FederationV1SendLeaveServlet, + FederationV2SendLeaveServlet, + FederationV1InviteServlet, + FederationV2InviteServlet, + FederationGetMissingEventsServlet, + FederationEventAuthServlet, + FederationClientKeysQueryServlet, + FederationUserDevicesQueryServlet, + FederationClientKeysClaimServlet, + FederationThirdPartyInviteExchangeServlet, + On3pidBindServlet, + FederationVersionServlet, + RoomComplexityServlet, + FederationSpaceSummaryServlet, + FederationRoomHierarchyServlet, + FederationV1SendKnockServlet, + FederationMakeKnockServlet, +) diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py new file mode 100644 index 0000000000..a12cd18d58 --- /dev/null +++ b/synapse/federation/transport/server/groups_local.py @@ -0,0 +1,113 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Tuple, Type + +from synapse.api.errors import SynapseError +from synapse.federation.transport.server._base import ( + Authenticator, + BaseFederationServlet, +) +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.server import HomeServer +from synapse.types import JsonDict, get_domain_from_id +from synapse.util.ratelimitutils import FederationRateLimiter + + +class BaseGroupsLocalServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups local handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_local_handler() + + +class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): + """A group server has invited a local user""" + + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "group_id doesn't match origin") + + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group invites." + + new_content = await self.handler.on_invite(group_id, user_id, content) + + return 200, new_content + + +class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): + """A group server has removed a local user""" + + PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, None]: + if get_domain_from_id(group_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + assert isinstance( + self.handler, GroupsLocalHandler + ), "Workers cannot handle group removals." + + await self.handler.user_removed_from_group(group_id, user_id, content) + + return 200, None + + +class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): + """Get roles in a group""" + + PATH = "/get_groups_publicised" + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + resp = await self.handler.bulk_get_publicised_groups( + content["user_ids"], proxy=False + ) + + return 200, resp + + +GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( + FederationGroupsLocalInviteServlet, + FederationGroupsRemoveLocalUserServlet, + FederationGroupsBulkPublicisedServlet, +) diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py new file mode 100644 index 0000000000..b30e92a5eb --- /dev/null +++ b/synapse/federation/transport/server/groups_server.py @@ -0,0 +1,753 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Tuple, Type + +from typing_extensions import Literal + +from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH +from synapse.api.errors import Codes, SynapseError +from synapse.federation.transport.server._base import ( + Authenticator, + BaseFederationServlet, +) +from synapse.http.servlet import parse_string_from_args +from synapse.server import HomeServer +from synapse.types import JsonDict, get_domain_from_id +from synapse.util.ratelimitutils import FederationRateLimiter + + +class BaseGroupsServerServlet(BaseFederationServlet): + """Abstract base class for federation servlet classes which provides a groups server handler. + + See BaseFederationServlet for more information. + """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_server_handler() + + +class FederationGroupsProfileServlet(BaseGroupsServerServlet): + """Get/set the basic profile of a group on behalf of a user""" + + PATH = "/groups/(?P[^/]*)/profile" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_group_profile(group_id, requester_user_id) + + return 200, new_content + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.update_group_profile( + group_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsSummaryServlet(BaseGroupsServerServlet): + PATH = "/groups/(?P[^/]*)/summary" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_group_summary(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsRoomsServlet(BaseGroupsServerServlet): + """Get the rooms in a group on behalf of a user""" + + PATH = "/groups/(?P[^/]*)/rooms" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): + """Add/remove room from group""" + + PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + return 200, new_content + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.remove_room_from_group( + group_id, requester_user_id, room_id + ) + + return 200, new_content + + +class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): + """Update room config in group""" + + PATH = ( + "/groups/(?P[^/]*)/room/(?P[^/]*)" + "/config/(?P[^/]*)" + ) + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + config_key: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + result = await self.handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content + ) + + return 200, result + + +class FederationGroupsUsersServlet(BaseGroupsServerServlet): + """Get the users in a group on behalf of a user""" + + PATH = "/groups/(?P[^/]*)/users" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_users_in_group(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): + """Get the users that have been invited to a group""" + + PATH = "/groups/(?P[^/]*)/invited_users" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + return 200, new_content + + +class FederationGroupsInviteServlet(BaseGroupsServerServlet): + """Ask a group server to invite someone to the group""" + + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.invite_to_group( + group_id, user_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): + """Accept an invitation from the group server""" + + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = await self.handler.accept_invite(group_id, user_id, content) + + return 200, new_content + + +class FederationGroupsJoinServlet(BaseGroupsServerServlet): + """Attempt to join a group""" + + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + if get_domain_from_id(user_id) != origin: + raise SynapseError(403, "user_id doesn't match origin") + + new_content = await self.handler.join_group(group_id, user_id, content) + + return 200, new_content + + +class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): + """Leave or kick a user from the group""" + + PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.remove_user_from_group( + group_id, user_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): + """Add/remove a room from the group summary, with optional category. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + + PATH = ( + "/groups/(?P[^/]*)/summary" + "(/categories/(?P[^/]+))?" + "/rooms/(?P[^/]*)" + ) + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError( + 400, "category_id cannot be empty string", Codes.INVALID_PARAM + ) + + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.update_group_summary_room( + group_id, + requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = await self.handler.delete_group_summary_room( + group_id, requester_user_id, room_id=room_id, category_id=category_id + ) + + return 200, resp + + +class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): + """Get all categories for a group""" + + PATH = "/groups/(?P[^/]*)/categories/?" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_categories(group_id, requester_user_id) + + return 200, resp + + +class FederationGroupsCategoryServlet(BaseGroupsServerServlet): + """Add/remove/get a category in a group""" + + PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_category( + group_id, requester_user_id, category_id + ) + + return 200, resp + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.upsert_group_category( + group_id, requester_user_id, category_id, content + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty string") + + resp = await self.handler.delete_group_category( + group_id, requester_user_id, category_id + ) + + return 200, resp + + +class FederationGroupsRolesServlet(BaseGroupsServerServlet): + """Get roles in a group""" + + PATH = "/groups/(?P[^/]*)/roles/?" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_roles(group_id, requester_user_id) + + return 200, resp + + +class FederationGroupsRoleServlet(BaseGroupsServerServlet): + """Add/remove/get a role in a group""" + + PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + resp = await self.handler.get_group_role(group_id, requester_user_id, role_id) + + return 200, resp + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError( + 400, "role_id cannot be empty string", Codes.INVALID_PARAM + ) + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.update_group_role( + group_id, requester_user_id, role_id, content + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.delete_group_role( + group_id, requester_user_id, role_id + ) + + return 200, resp + + +class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): + """Add/remove a user from the group summary, with optional role. + + Matches both: + - /groups/:group/summary/users/:user_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + + PATH = ( + "/groups/(?P[^/]*)/summary" + "(/roles/(?P[^/]+))?" + "/users/(?P[^/]*)" + ) + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.update_group_summary_user( + group_id, + requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.delete_group_summary_user( + group_id, requester_user_id, user_id=user_id, role_id=role_id + ) + + return 200, resp + + +class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): + """Sets whether a group is joinable without an invite or knock""" + + PATH = "/groups/(?P[^/]*)/settings/m.join_policy" + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + return 200, new_content + + +GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( + FederationGroupsProfileServlet, + FederationGroupsSummaryServlet, + FederationGroupsRoomsServlet, + FederationGroupsUsersServlet, + FederationGroupsInvitedUsersServlet, + FederationGroupsInviteServlet, + FederationGroupsAcceptInviteServlet, + FederationGroupsJoinServlet, + FederationGroupsRemoveUserServlet, + FederationGroupsSummaryRoomsServlet, + FederationGroupsCategoriesServlet, + FederationGroupsCategoryServlet, + FederationGroupsRolesServlet, + FederationGroupsRoleServlet, + FederationGroupsSummaryUsersServlet, + FederationGroupsAddRoomsServlet, + FederationGroupsAddRoomsConfigServlet, + FederationGroupsSettingJoinPolicyServlet, +) -- cgit 1.5.1 From 0ace38b7b310fc1b4f88ac93d01ec900f33f7a07 Mon Sep 17 00:00:00 2001 From: Michael Telatynski <7t3chguy@gmail.com> Date: Mon, 16 Aug 2021 15:49:12 +0100 Subject: Experimental support for MSC3266 Room Summary API. (#10394) --- changelog.d/10394.feature | 1 + mypy.ini | 2 +- synapse/config/experimental.py | 3 + synapse/federation/transport/server/federation.py | 4 +- synapse/handlers/room_summary.py | 1171 +++++++++++++++++++++ synapse/handlers/space_summary.py | 1116 -------------------- synapse/http/servlet.py | 58 +- synapse/rest/admin/rooms.py | 45 +- synapse/rest/client/v1/room.py | 90 +- synapse/server.py | 6 +- tests/handlers/test_room_summary.py | 959 +++++++++++++++++ tests/handlers/test_space_summary.py | 881 ---------------- 12 files changed, 2255 insertions(+), 2081 deletions(-) create mode 100644 changelog.d/10394.feature create mode 100644 synapse/handlers/room_summary.py delete mode 100644 synapse/handlers/space_summary.py create mode 100644 tests/handlers/test_room_summary.py delete mode 100644 tests/handlers/test_space_summary.py (limited to 'synapse') diff --git a/changelog.d/10394.feature b/changelog.d/10394.feature new file mode 100644 index 0000000000..c8bbc5a740 --- /dev/null +++ b/changelog.d/10394.feature @@ -0,0 +1 @@ +Initial local support for [MSC3266](https://github.com/matrix-org/synapse/pull/10394), Room Summary over the unstable `/rooms/{roomIdOrAlias}/summary` API. diff --git a/mypy.ini b/mypy.ini index 5d6cd557bc..e1b9405daa 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,7 +86,7 @@ files = tests/test_event_auth.py, tests/test_utils, tests/handlers/test_password_providers.py, - tests/handlers/test_space_summary.py, + tests/handlers/test_room_summary.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_itertools.py, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4c60ee8c28..b918fb15b0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -38,3 +38,6 @@ class ExperimentalConfig(Config): # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) + + # MSC3266 (room summary api) + self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 2806337846..7d81cc642c 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -547,7 +547,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_space_summary_handler() + self.handler = hs.get_room_summary_handler() async def on_GET( self, @@ -608,7 +608,7 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_space_summary_handler() + self.handler = hs.get_room_summary_handler() async def on_GET( self, diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py new file mode 100644 index 0000000000..ac6cfc0da9 --- /dev/null +++ b/synapse/handlers/room_summary.py @@ -0,0 +1,1171 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +import re +from collections import deque +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple + +import attr + +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RoomTypes, +) +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.events import EventBase +from synapse.events.utils import format_event_for_client_v2 +from synapse.types import JsonDict +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# number of rooms to return. We'll stop once we hit this limit. +MAX_ROOMS = 50 + +# max number of events to return per room. +MAX_ROOMS_PER_SPACE = 50 + +# max number of federation servers to hit per room +MAX_SERVERS_PER_SPACE = 3 + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PaginationKey: + """The key used to find unique pagination session.""" + + # The first three entries match the request parameters (and cannot change + # during a pagination session). + room_id: str + suggested_only: bool + max_depth: Optional[int] + # The randomly generated token. + token: str + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PaginationSession: + """The information that is stored for pagination.""" + + # The time the pagination session was created, in milliseconds. + creation_time_ms: int + # The queue of rooms which are still to process. + room_queue: List["_RoomQueueEntry"] + # A set of rooms which have been processed. + processed_rooms: Set[str] + + +class RoomSummaryHandler: + # The time a pagination session remains valid for. + _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 + + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._event_auth_handler = hs.get_event_auth_handler() + self._store = hs.get_datastore() + self._event_serializer = hs.get_event_client_serializer() + self._server_name = hs.hostname + self._federation_client = hs.get_federation_client() + + # A map of query information to the current pagination state. + # + # TODO Allow for multiple workers to share this data. + # TODO Expire pagination tokens. + self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {} + + # If a user tries to fetch the same page multiple times in quick succession, + # only process the first attempt and return its result to subsequent requests. + self._pagination_response_cache: ResponseCache[ + Tuple[str, bool, Optional[int], Optional[int], Optional[str]] + ] = ResponseCache( + hs.get_clock(), + "get_room_hierarchy", + ) + + def _expire_pagination_sessions(self): + """Expire pagination session which are old.""" + expire_before = ( + self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS + ) + to_expire = [] + + for key, value in self._pagination_sessions.items(): + if value.creation_time_ms < expire_before: + to_expire.append(key) + + for key in to_expire: + logger.debug("Expiring pagination session id %s", key) + del self._pagination_sessions[key] + + async def get_space_summary( + self, + requester: str, + room_id: str, + suggested_only: bool = False, + max_rooms_per_space: Optional[int] = None, + ) -> JsonDict: + """ + Implementation of the space summary C-S API + + Args: + requester: user id of the user making this request + + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. This does not apply to the root room (ie, room_id), and + is overridden by MAX_ROOMS_PER_SPACE. + + Returns: + summary dict to return + """ + # First of all, check that the room is accessible. + if not await self._is_local_room_accessible(room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, room_id), + ) + + # the queue of rooms to process + room_queue = deque((_RoomQueueEntry(room_id, ()),)) + + # rooms we have already processed + processed_rooms: Set[str] = set() + + # events we have already processed. We don't necessarily have their event ids, + # so instead we key on (room id, state key) + processed_events: Set[Tuple[str, str]] = set() + + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] + + while room_queue and len(rooms_result) < MAX_ROOMS: + queue_entry = room_queue.popleft() + room_id = queue_entry.room_id + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + + # The client-specified max_rooms_per_space limit doesn't apply to the + # room_id specified in the request, so we ignore it if this is the + # first room we are processing. + max_children = max_rooms_per_space if processed_rooms else None + + if is_in_room: + room_entry = await self._summarize_local_room( + requester, None, room_id, suggested_only, max_children + ) + + events: Sequence[JsonDict] = [] + if room_entry: + rooms_result.append(room_entry.room) + events = room_entry.children_state_events + + logger.debug( + "Query of local room %s returned events %s", + room_id, + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], + ) + else: + fed_rooms = await self._summarize_remote_room( + queue_entry, + suggested_only, + max_children, + exclude_rooms=processed_rooms, + ) + + # The results over federation might include rooms that the we, + # as the requesting server, are allowed to see, but the requesting + # user is not permitted see. + # + # Filter the returned results to only what is accessible to the user. + events = [] + for room_entry in fed_rooms: + room = room_entry.room + fed_room_id = room_entry.room_id + + # The user can see the room, include it! + if await self._is_remote_room_accessible( + requester, fed_room_id, room + ): + # Before returning to the client, remove the allowed_room_ids + # and allowed_spaces keys. + room.pop("allowed_room_ids", None) + room.pop("allowed_spaces", None) + + rooms_result.append(room) + events.extend(room_entry.children_state_events) + + # All rooms returned don't need visiting again (even if the user + # didn't have access to them). + processed_rooms.add(fed_room_id) + + logger.debug( + "Query of %s returned rooms %s, events %s", + room_id, + [room_entry.room.get("room_id") for room_entry in fed_rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], + ) + + # the room we queried may or may not have been returned, but don't process + # it again, anyway. + processed_rooms.add(room_id) + + # XXX: is it ok that we blindly iterate through any events returned by + # a remote server, whether or not they actually link to any rooms in our + # tree? + for ev in events: + # remote servers might return events we have already processed + # (eg, Dendrite returns inward pointers as well as outward ones), so + # we need to filter them out, to avoid returning duplicate links to the + # client. + ev_key = (ev["room_id"], ev["state_key"]) + if ev_key in processed_events: + continue + events_result.append(ev) + + # add the child to the queue. we have already validated + # that the vias are a list of server names. + room_queue.append( + _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) + ) + processed_events.add(ev_key) + + return {"rooms": rooms_result, "events": events_result} + + async def get_room_hierarchy( + self, + requester: str, + requested_room_id: str, + suggested_only: bool = False, + max_depth: Optional[int] = None, + limit: Optional[int] = None, + from_token: Optional[str] = None, + ) -> JsonDict: + """ + Implementation of the room hierarchy C-S API. + + Args: + requester: The user ID of the user making this request. + requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: Whether we should only return children with the "suggested" + flag set. + max_depth: The maximum depth in the tree to explore, must be a + non-negative integer. + + 0 would correspond to just the root room, 1 would include just + the root room's children, etc. + limit: An optional limit on the number of rooms to return per + page. Must be a positive integer. + from_token: An optional pagination token. + + Returns: + The JSON hierarchy dictionary. + """ + # If a user tries to fetch the same page multiple times in quick succession, + # only process the first attempt and return its result to subsequent requests. + # + # This is due to the pagination process mutating internal state, attempting + # to process multiple requests for the same page will result in errors. + return await self._pagination_response_cache.wrap( + (requested_room_id, suggested_only, max_depth, limit, from_token), + self._get_room_hierarchy, + requester, + requested_room_id, + suggested_only, + max_depth, + limit, + from_token, + ) + + async def _get_room_hierarchy( + self, + requester: str, + requested_room_id: str, + suggested_only: bool = False, + max_depth: Optional[int] = None, + limit: Optional[int] = None, + from_token: Optional[str] = None, + ) -> JsonDict: + """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" + + # First of all, check that the room is accessible. + if not await self._is_local_room_accessible(requested_room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, requested_room_id), + ) + + # If this is continuing a previous session, pull the persisted data. + if from_token: + self._expire_pagination_sessions() + + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, from_token + ) + if pagination_key not in self._pagination_sessions: + raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) + + # Load the previous state. + pagination_session = self._pagination_sessions[pagination_key] + room_queue = pagination_session.room_queue + processed_rooms = pagination_session.processed_rooms + else: + # The queue of rooms to process, the next room is last on the stack. + room_queue = [_RoomQueueEntry(requested_room_id, ())] + + # Rooms we have already processed. + processed_rooms = set() + + rooms_result: List[JsonDict] = [] + + # Cap the limit to a server-side maximum. + if limit is None: + limit = MAX_ROOMS + else: + limit = min(limit, MAX_ROOMS) + + # Iterate through the queue until we reach the limit or run out of + # rooms to include. + while room_queue and len(rooms_result) < limit: + queue_entry = room_queue.pop() + room_id = queue_entry.room_id + current_depth = queue_entry.depth + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + # A map of summaries for children rooms that might be returned over + # federation. The rationale for caching these and *maybe* using them + # is to prefer any information local to the homeserver before trusting + # data received over federation. + children_room_entries: Dict[str, JsonDict] = {} + # A set of room IDs which are children that did not have information + # returned over federation and are known to be inaccessible to the + # current server. We should not reach out over federation to try to + # summarise these rooms. + inaccessible_children: Set[str] = set() + + # If the room is known locally, summarise it! + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + if is_in_room: + room_entry = await self._summarize_local_room( + requester, + None, + room_id, + suggested_only, + # TODO Handle max children. + max_children=None, + ) + + # Otherwise, attempt to use information for federation. + else: + # A previous call might have included information for this room. + # It can be used if either: + # + # 1. The room is not a space. + # 2. The maximum depth has been achieved (since no children + # information is needed). + if queue_entry.remote_room and ( + queue_entry.remote_room.get("room_type") != RoomTypes.SPACE + or (max_depth is not None and current_depth >= max_depth) + ): + room_entry = _RoomEntry( + queue_entry.room_id, queue_entry.remote_room + ) + + # If the above isn't true, attempt to fetch the room + # information over federation. + else: + ( + room_entry, + children_room_entries, + inaccessible_children, + ) = await self._summarize_remote_room_hierarchy( + queue_entry, + suggested_only, + ) + + # Ensure this room is accessible to the requester (and not just + # the homeserver). + if room_entry and not await self._is_remote_room_accessible( + requester, queue_entry.room_id, room_entry.room + ): + room_entry = None + + # This room has been processed and should be ignored if it appears + # elsewhere in the hierarchy. + processed_rooms.add(room_id) + + # There may or may not be a room entry based on whether it is + # inaccessible to the requesting user. + if room_entry: + # Add the room (including the stripped m.space.child events). + rooms_result.append(room_entry.as_json()) + + # If this room is not at the max-depth, check if there are any + # children to process. + if max_depth is None or current_depth < max_depth: + # The children get added in reverse order so that the next + # room to process, according to the ordering, is the last + # item in the list. + room_queue.extend( + _RoomQueueEntry( + ev["state_key"], + ev["content"]["via"], + current_depth + 1, + children_room_entries.get(ev["state_key"]), + ) + for ev in reversed(room_entry.children_state_events) + if ev["type"] == EventTypes.SpaceChild + and ev["state_key"] not in inaccessible_children + ) + + result: JsonDict = {"rooms": rooms_result} + + # If there's additional data, generate a pagination token (and persist state). + if room_queue: + next_batch = random_string(24) + result["next_batch"] = next_batch + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, next_batch + ) + self._pagination_sessions[pagination_key] = _PaginationSession( + self._clock.time_msec(), room_queue, processed_rooms + ) + + return result + + async def federation_space_summary( + self, + origin: str, + room_id: str, + suggested_only: bool, + max_rooms_per_space: Optional[int], + exclude_rooms: Iterable[str], + ) -> JsonDict: + """ + Implementation of the space summary Federation API + + Args: + origin: The server requesting the spaces summary. + + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. Unlike the C-S API, this applies to the root room (room_id). + It is clipped to MAX_ROOMS_PER_SPACE. + + exclude_rooms: a list of rooms to skip over (presumably because the + calling server has already seen them). + + Returns: + summary dict to return + """ + # the queue of rooms to process + room_queue = deque((room_id,)) + + # the set of rooms that we should not walk further. Initialise it with the + # excluded-rooms list; we will add other rooms as we process them so that + # we do not loop. + processed_rooms: Set[str] = set(exclude_rooms) + + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] + + while room_queue and len(rooms_result) < MAX_ROOMS: + room_id = room_queue.popleft() + if room_id in processed_rooms: + # already done this room + continue + + room_entry = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_rooms_per_space + ) + + processed_rooms.add(room_id) + + if room_entry: + rooms_result.append(room_entry.room) + events_result.extend(room_entry.children_state_events) + + # add any children to the queue + room_queue.extend( + edge_event["state_key"] + for edge_event in room_entry.children_state_events + ) + + return {"rooms": rooms_result, "events": events_result} + + async def get_federation_hierarchy( + self, + origin: str, + requested_room_id: str, + suggested_only: bool, + ): + """ + Implementation of the room hierarchy Federation API. + + This is similar to get_room_hierarchy, but does not recurse into the space. + It also considers whether anyone on the server may be able to access the + room, as opposed to whether a specific user can. + + Args: + origin: The server requesting the spaces summary. + requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: whether we should only return children with the "suggested" + flag set. + + Returns: + The JSON hierarchy dictionary. + """ + root_room_entry = await self._summarize_local_room( + None, origin, requested_room_id, suggested_only, max_children=None + ) + if root_room_entry is None: + # Room is inaccessible to the requesting server. + raise SynapseError(404, "Unknown room: %s" % (requested_room_id,)) + + children_rooms_result: List[JsonDict] = [] + inaccessible_children: List[str] = [] + + # Iterate through each child and potentially add it, but not its children, + # to the response. + for child_room in root_room_entry.children_state_events: + room_id = child_room.get("state_key") + assert isinstance(room_id, str) + # If the room is unknown, skip it. + if not await self._store.is_host_joined(room_id, self._server_name): + continue + + room_entry = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_children=0 + ) + # If the room is accessible, include it in the results. + # + # Note that only the room summary (without information on children) + # is included in the summary. + if room_entry: + children_rooms_result.append(room_entry.room) + # Otherwise, note that the requesting server shouldn't bother + # trying to summarize this room - they do not have access to it. + else: + inaccessible_children.append(room_id) + + return { + # Include the requested room (including the stripped children events). + "room": root_room_entry.as_json(), + "children": children_rooms_result, + "inaccessible_children": inaccessible_children, + } + + async def _summarize_local_room( + self, + requester: Optional[str], + origin: Optional[str], + room_id: str, + suggested_only: bool, + max_children: Optional[int], + ) -> Optional["_RoomEntry"]: + """ + Generate a room entry and a list of event entries for a given room. + + Args: + requester: + The user requesting the summary, if it is a local request. None + if this is a federation request. + origin: + The server requesting the summary, if it is a federation request. + None if this is a local request. + room_id: The room ID to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + max_children: + The maximum number of children rooms to include. This is capped + to a server-set limit. + + Returns: + A room entry if the room should be returned. None, otherwise. + """ + if not await self._is_local_room_accessible(room_id, requester, origin): + return None + + room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) + + # If the room is not a space or the children don't matter, return just + # the room information. + if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: + return _RoomEntry(room_id, room_entry) + + # Otherwise, look for child rooms/spaces. + child_events = await self._get_child_events(room_id) + + if suggested_only: + # we only care about suggested children + child_events = filter(_is_suggested_child_event, child_events) + + if max_children is None or max_children > MAX_ROOMS_PER_SPACE: + max_children = MAX_ROOMS_PER_SPACE + + now = self._clock.time_msec() + events_result: List[JsonDict] = [] + for edge_event in itertools.islice(child_events, max_children): + events_result.append( + await self._event_serializer.serialize_event( + edge_event, + time_now=now, + event_format=format_event_for_client_v2, + ) + ) + + return _RoomEntry(room_id, room_entry, events_result) + + async def _summarize_remote_room( + self, + room: "_RoomQueueEntry", + suggested_only: bool, + max_children: Optional[int], + exclude_rooms: Iterable[str], + ) -> Iterable["_RoomEntry"]: + """ + Request room entries and a list of event entries for a given room by querying a remote server. + + Args: + room: The room to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + max_children: + The maximum number of children rooms to include. This is capped + to a server-set limit. + exclude_rooms: + Rooms IDs which do not need to be summarized. + + Returns: + An iterable of room entries. + """ + room_id = room.room_id + logger.info("Requesting summary for %s via %s", room_id, room.via) + + # we need to make the exclusion list json-serialisable + exclude_rooms = list(exclude_rooms) + + via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) + try: + res = await self._federation_client.get_space_summary( + via, + room_id, + suggested_only=suggested_only, + max_rooms_per_space=max_children, + exclude_rooms=exclude_rooms, + ) + except Exception as e: + logger.warning( + "Unable to get summary of %s via federation: %s", + room_id, + e, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return () + + # Group the events by their room. + children_by_room: Dict[str, List[JsonDict]] = {} + for ev in res.events: + if ev.event_type == EventTypes.SpaceChild: + children_by_room.setdefault(ev.room_id, []).append(ev.data) + + # Generate the final results. + results = [] + for fed_room in res.rooms: + fed_room_id = fed_room.get("room_id") + if not fed_room_id or not isinstance(fed_room_id, str): + continue + + results.append( + _RoomEntry( + fed_room_id, + fed_room, + children_by_room.get(fed_room_id, []), + ) + ) + + return results + + async def _summarize_remote_room_hierarchy( + self, room: "_RoomQueueEntry", suggested_only: bool + ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: + """ + Request room entries and a list of event entries for a given room by querying a remote server. + + Args: + room: The room to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + + Returns: + A tuple of: + The room entry. + Partial room data return over federation. + A set of inaccessible children room IDs. + """ + room_id = room.room_id + logger.info("Requesting summary for %s via %s", room_id, room.via) + + via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) + try: + ( + room_response, + children, + inaccessible_children, + ) = await self._federation_client.get_room_hierarchy( + via, + room_id, + suggested_only=suggested_only, + ) + except Exception as e: + logger.warning( + "Unable to get hierarchy of %s via federation: %s", + room_id, + e, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return None, {}, set() + + # Map the children to their room ID. + children_by_room_id = { + c["room_id"]: c + for c in children + if "room_id" in c and isinstance(c["room_id"], str) + } + + return ( + _RoomEntry(room_id, room_response, room_response.pop("children_state", ())), + children_by_room_id, + set(inaccessible_children), + ) + + async def _is_local_room_accessible( + self, room_id: str, requester: Optional[str], origin: Optional[str] = None + ) -> bool: + """ + Calculate whether the room should be shown to the requester. + + It should return true if: + + * The requester is joined or can join the room (per MSC3173). + * The origin server has any user that is joined or can join the room. + * The history visibility is set to world readable. + + Args: + room_id: The room ID to check accessibility of. + requester: + The user making the request, if it is a local request. + None if this is a federation request. + origin: + The server making the request, if it is a federation request. + None if this is a local request. + + Returns: + True if the room is accessible to the requesting user or server. + """ + state_ids = await self._store.get_current_state_ids(room_id) + + # If there's no state for the room, it isn't known. + if not state_ids: + # The user might have a pending invite for the room. + if requester and await self._store.get_invite_for_local_user_in_room( + requester, room_id + ): + return True + + logger.info("room %s is unknown, omitting from summary", room_id) + return False + + room_version = await self._store.get_room_version(room_id) + + # Include the room if it has join rules of public or knock. + join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) + if join_rules_event_id: + join_rules_event = await self._store.get_event(join_rules_event_id) + join_rule = join_rules_event.content.get("join_rule") + if join_rule == JoinRules.PUBLIC or ( + room_version.msc2403_knocking and join_rule == JoinRules.KNOCK + ): + return True + + # Include the room if it is peekable. + hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_event_id: + hist_vis_ev = await self._store.get_event(hist_vis_event_id) + hist_vis = hist_vis_ev.content.get("history_visibility") + if hist_vis == HistoryVisibility.WORLD_READABLE: + return True + + # Otherwise we need to check information specific to the user or server. + + # If we have an authenticated requesting user, check if they are a member + # of the room (or can join the room). + if requester: + member_event_id = state_ids.get((EventTypes.Member, requester), None) + + # If they're in the room they can see info on it. + if member_event_id: + member_event = await self._store.get_event(member_event_id) + if member_event.membership in (Membership.JOIN, Membership.INVITE): + return True + + # Otherwise, check if they should be allowed access via membership in a space. + if await self._event_auth_handler.has_restricted_join_rules( + state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join(state_ids) + ) + if await self._event_auth_handler.is_user_in_rooms( + allowed_rooms, requester + ): + return True + + # If this is a request over federation, check if the host is in the room or + # has a user who could join the room. + elif origin: + if await self._event_auth_handler.check_host_in_room( + room_id, origin + ) or await self._store.is_host_invited(room_id, origin): + return True + + # Alternately, if the host has a user in any of the spaces specified + # for access, then the host can see this room (and should do filtering + # if the requester cannot see it). + if await self._event_auth_handler.has_restricted_join_rules( + state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join(state_ids) + ) + for space_id in allowed_rooms: + if await self._event_auth_handler.check_host_in_room( + space_id, origin + ): + return True + + logger.info( + "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary", + room_id, + requester or origin, + ) + return False + + async def _is_remote_room_accessible( + self, requester: str, room_id: str, room: JsonDict + ) -> bool: + """ + Calculate whether the room received over federation should be shown to the requester. + + It should return true if: + + * The requester is joined or can join the room (per MSC3173). + * The history visibility is set to world readable. + + Note that the local server is not in the requested room (which is why the + remote call was made in the first place), but the user could have access + due to an invite, etc. + + Args: + requester: The user requesting the summary. + room_id: The room ID returned over federation. + room: The summary of the room returned over federation. + + Returns: + True if the room is accessible to the requesting user. + """ + # The API doesn't return the room version so assume that a + # join rule of knock is valid. + if ( + room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) + or room.get("world_readable") is True + ): + return True + + # Check if the user is a member of any of the allowed spaces + # from the response. + allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") + if allowed_rooms and isinstance(allowed_rooms, list): + if await self._event_auth_handler.is_user_in_rooms( + allowed_rooms, requester + ): + return True + + # Finally, check locally if we can access the room. The user might + # already be in the room (if it was a child room), or there might be a + # pending invite, etc. + return await self._is_local_room_accessible(room_id, requester) + + async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: + """ + Generate en entry summarising a single room. + + Args: + room_id: The room ID to summarize. + for_federation: True if this is a summary requested over federation + (which includes additional fields). + + Returns: + The JSON dictionary for the room. + """ + stats = await self._store.get_room_with_stats(room_id) + + # currently this should be impossible because we call + # _is_local_room_accessible on the room before we get here, so + # there should always be an entry + assert stats is not None, "unable to retrieve stats for %s" % (room_id,) + + current_state_ids = await self._store.get_current_state_ids(room_id) + create_event = await self._store.get_event( + current_state_ids[(EventTypes.Create, "")] + ) + + entry = { + "room_id": stats["room_id"], + "name": stats["name"], + "topic": stats["topic"], + "canonical_alias": stats["canonical_alias"], + "num_joined_members": stats["joined_members"], + "avatar_url": stats["avatar"], + "join_rules": stats["join_rules"], + "world_readable": ( + stats["history_visibility"] == HistoryVisibility.WORLD_READABLE + ), + "guest_can_join": stats["guest_access"] == "can_join", + "creation_ts": create_event.origin_server_ts, + "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), + } + + # Federation requests need to provide additional information so the + # requested server is able to filter the response appropriately. + if for_federation: + room_version = await self._store.get_room_version(room_id) + if await self._event_auth_handler.has_restricted_join_rules( + current_state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join( + current_state_ids + ) + ) + if allowed_rooms: + entry["allowed_room_ids"] = allowed_rooms + # TODO Remove this key once the API is stable. + entry["allowed_spaces"] = allowed_rooms + + # Filter out Nones – rather omit the field altogether + room_entry = {k: v for k, v in entry.items() if v is not None} + + return room_entry + + async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: + """ + Get the child events for a given room. + + The returned results are sorted for stability. + + Args: + room_id: The room id to get the children of. + + Returns: + An iterable of sorted child events. + """ + + # look for child rooms/spaces. + current_state_ids = await self._store.get_current_state_ids(room_id) + + events = await self._store.get_events_as_list( + [ + event_id + for key, event_id in current_state_ids.items() + if key[0] == EventTypes.SpaceChild + ] + ) + + # filter out any events without a "via" (which implies it has been redacted), + # and order to ensure we return stable results. + return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) + + async def get_room_summary( + self, + requester: Optional[str], + room_id: str, + remote_room_hosts: Optional[List[str]] = None, + ) -> JsonDict: + """ + Implementation of the room summary C-S API from MSC3266 + + Args: + requester: user id of the user making this request, will be None + for unauthenticated requests + + room_id: room id to summarise. + + remote_room_hosts: a list of homeservers to try fetching data through + if we don't know it ourselves + + Returns: + summary dict to return + """ + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + + if is_in_room: + room_entry = await self._summarize_local_room( + requester, + None, + room_id, + # Suggested-only doesn't matter since no children are requested. + suggested_only=False, + max_children=0, + ) + + if not room_entry: + raise NotFoundError("Room not found or is not accessible") + + room_summary = room_entry.room + + # If there was a requester, add their membership. + if requester: + ( + membership, + _, + ) = await self._store.get_local_current_membership_for_user_in_room( + requester, room_id + ) + + room_summary["membership"] = membership or "leave" + else: + # TODO federation API, descoped from initial unstable implementation + # as MSC needs more maturing on that side. + raise SynapseError(400, "Federation is not currently supported.") + + return room_summary + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomQueueEntry: + # The room ID of this entry. + room_id: str + # The server to query if the room is not known locally. + via: Sequence[str] + # The minimum number of hops necessary to get to this room (compared to the + # originally requested room). + depth: int = 0 + # The room summary for this room returned via federation. This will only be + # used if the room is not known locally (and is not a space). + remote_room: Optional[JsonDict] = None + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomEntry: + room_id: str + # The room summary for this room. + room: JsonDict + # An iterable of the sorted, stripped children events for children of this room. + # + # This may not include all children. + children_state_events: Sequence[JsonDict] = () + + def as_json(self) -> JsonDict: + """ + Returns a JSON dictionary suitable for the room hierarchy endpoint. + + It returns the room summary including the stripped m.space.child events + as a sub-key. + """ + result = dict(self.room) + result["children_state"] = self.children_state_events + return result + + +def _has_valid_via(e: EventBase) -> bool: + via = e.content.get("via") + if not via or not isinstance(via, Sequence): + return False + for v in via: + if not isinstance(v, str): + logger.debug("Ignoring edge event %s with invalid via entry", e.event_id) + return False + return True + + +def _is_suggested_child_event(edge_event: EventBase) -> bool: + suggested = edge_event.content.get("suggested") + if isinstance(suggested, bool) and suggested: + return True + logger.debug("Ignorning not-suggested child %s", edge_event.state_key) + return False + + +# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive. +_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") + + +def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: + """ + Generate a value for comparing two child events for ordering. + + The rules for ordering are supposed to be: + + 1. The 'order' key, if it is valid. + 2. The 'origin_server_ts' of the 'm.room.create' event. + 3. The 'room_id'. + + But we skip step 2 since we may not have any state from the room. + + Args: + child: The event for generating a comparison key. + + Returns: + The comparison key as a tuple of: + False if the ordering is valid. + The ordering field. + The room ID. + """ + order = child.content.get("order") + # If order is not a string or doesn't meet the requirements, ignore it. + if not isinstance(order, str): + order = None + elif len(order) > 50 or _INVALID_ORDER_CHARS_RE.search(order): + order = None + + # Items without an order come last. + return (order is None, order, child.room_id) diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py deleted file mode 100644 index c74e90abbc..0000000000 --- a/synapse/handlers/space_summary.py +++ /dev/null @@ -1,1116 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import logging -import re -from collections import deque -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple - -import attr - -from synapse.api.constants import ( - EventContentFields, - EventTypes, - HistoryVisibility, - JoinRules, - Membership, - RoomTypes, -) -from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.events import EventBase -from synapse.events.utils import format_event_for_client_v2 -from synapse.types import JsonDict -from synapse.util.caches.response_cache import ResponseCache -from synapse.util.stringutils import random_string - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -# number of rooms to return. We'll stop once we hit this limit. -MAX_ROOMS = 50 - -# max number of events to return per room. -MAX_ROOMS_PER_SPACE = 50 - -# max number of federation servers to hit per room -MAX_SERVERS_PER_SPACE = 3 - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _PaginationKey: - """The key used to find unique pagination session.""" - - # The first three entries match the request parameters (and cannot change - # during a pagination session). - room_id: str - suggested_only: bool - max_depth: Optional[int] - # The randomly generated token. - token: str - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _PaginationSession: - """The information that is stored for pagination.""" - - # The time the pagination session was created, in milliseconds. - creation_time_ms: int - # The queue of rooms which are still to process. - room_queue: List["_RoomQueueEntry"] - # A set of rooms which have been processed. - processed_rooms: Set[str] - - -class SpaceSummaryHandler: - # The time a pagination session remains valid for. - _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 - - def __init__(self, hs: "HomeServer"): - self._clock = hs.get_clock() - self._event_auth_handler = hs.get_event_auth_handler() - self._store = hs.get_datastore() - self._event_serializer = hs.get_event_client_serializer() - self._server_name = hs.hostname - self._federation_client = hs.get_federation_client() - - # A map of query information to the current pagination state. - # - # TODO Allow for multiple workers to share this data. - # TODO Expire pagination tokens. - self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {} - - # If a user tries to fetch the same page multiple times in quick succession, - # only process the first attempt and return its result to subsequent requests. - self._pagination_response_cache: ResponseCache[ - Tuple[str, bool, Optional[int], Optional[int], Optional[str]] - ] = ResponseCache( - hs.get_clock(), - "get_room_hierarchy", - ) - - def _expire_pagination_sessions(self): - """Expire pagination session which are old.""" - expire_before = ( - self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS - ) - to_expire = [] - - for key, value in self._pagination_sessions.items(): - if value.creation_time_ms < expire_before: - to_expire.append(key) - - for key in to_expire: - logger.debug("Expiring pagination session id %s", key) - del self._pagination_sessions[key] - - async def get_space_summary( - self, - requester: str, - room_id: str, - suggested_only: bool = False, - max_rooms_per_space: Optional[int] = None, - ) -> JsonDict: - """ - Implementation of the space summary C-S API - - Args: - requester: user id of the user making this request - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. This does not apply to the root room (ie, room_id), and - is overridden by MAX_ROOMS_PER_SPACE. - - Returns: - summary dict to return - """ - # First of all, check that the room is accessible. - if not await self._is_local_room_accessible(room_id, requester): - raise AuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester, room_id), - ) - - # the queue of rooms to process - room_queue = deque((_RoomQueueEntry(room_id, ()),)) - - # rooms we have already processed - processed_rooms: Set[str] = set() - - # events we have already processed. We don't necessarily have their event ids, - # so instead we key on (room id, state key) - processed_events: Set[Tuple[str, str]] = set() - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - while room_queue and len(rooms_result) < MAX_ROOMS: - queue_entry = room_queue.popleft() - room_id = queue_entry.room_id - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - - # The client-specified max_rooms_per_space limit doesn't apply to the - # room_id specified in the request, so we ignore it if this is the - # first room we are processing. - max_children = max_rooms_per_space if processed_rooms else None - - if is_in_room: - room_entry = await self._summarize_local_room( - requester, None, room_id, suggested_only, max_children - ) - - events: Sequence[JsonDict] = [] - if room_entry: - rooms_result.append(room_entry.room) - events = room_entry.children_state_events - - logger.debug( - "Query of local room %s returned events %s", - room_id, - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - else: - fed_rooms = await self._summarize_remote_room( - queue_entry, - suggested_only, - max_children, - exclude_rooms=processed_rooms, - ) - - # The results over federation might include rooms that the we, - # as the requesting server, are allowed to see, but the requesting - # user is not permitted see. - # - # Filter the returned results to only what is accessible to the user. - events = [] - for room_entry in fed_rooms: - room = room_entry.room - fed_room_id = room_entry.room_id - - # The user can see the room, include it! - if await self._is_remote_room_accessible( - requester, fed_room_id, room - ): - # Before returning to the client, remove the allowed_room_ids - # and allowed_spaces keys. - room.pop("allowed_room_ids", None) - room.pop("allowed_spaces", None) - - rooms_result.append(room) - events.extend(room_entry.children_state_events) - - # All rooms returned don't need visiting again (even if the user - # didn't have access to them). - processed_rooms.add(fed_room_id) - - logger.debug( - "Query of %s returned rooms %s, events %s", - room_id, - [room_entry.room.get("room_id") for room_entry in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - - # the room we queried may or may not have been returned, but don't process - # it again, anyway. - processed_rooms.add(room_id) - - # XXX: is it ok that we blindly iterate through any events returned by - # a remote server, whether or not they actually link to any rooms in our - # tree? - for ev in events: - # remote servers might return events we have already processed - # (eg, Dendrite returns inward pointers as well as outward ones), so - # we need to filter them out, to avoid returning duplicate links to the - # client. - ev_key = (ev["room_id"], ev["state_key"]) - if ev_key in processed_events: - continue - events_result.append(ev) - - # add the child to the queue. we have already validated - # that the vias are a list of server names. - room_queue.append( - _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) - ) - processed_events.add(ev_key) - - return {"rooms": rooms_result, "events": events_result} - - async def get_room_hierarchy( - self, - requester: str, - requested_room_id: str, - suggested_only: bool = False, - max_depth: Optional[int] = None, - limit: Optional[int] = None, - from_token: Optional[str] = None, - ) -> JsonDict: - """ - Implementation of the room hierarchy C-S API. - - Args: - requester: The user ID of the user making this request. - requested_room_id: The room ID to start the hierarchy at (the "root" room). - suggested_only: Whether we should only return children with the "suggested" - flag set. - max_depth: The maximum depth in the tree to explore, must be a - non-negative integer. - - 0 would correspond to just the root room, 1 would include just - the root room's children, etc. - limit: An optional limit on the number of rooms to return per - page. Must be a positive integer. - from_token: An optional pagination token. - - Returns: - The JSON hierarchy dictionary. - """ - # If a user tries to fetch the same page multiple times in quick succession, - # only process the first attempt and return its result to subsequent requests. - # - # This is due to the pagination process mutating internal state, attempting - # to process multiple requests for the same page will result in errors. - return await self._pagination_response_cache.wrap( - (requested_room_id, suggested_only, max_depth, limit, from_token), - self._get_room_hierarchy, - requester, - requested_room_id, - suggested_only, - max_depth, - limit, - from_token, - ) - - async def _get_room_hierarchy( - self, - requester: str, - requested_room_id: str, - suggested_only: bool = False, - max_depth: Optional[int] = None, - limit: Optional[int] = None, - from_token: Optional[str] = None, - ) -> JsonDict: - """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" - - # First of all, check that the room is accessible. - if not await self._is_local_room_accessible(requested_room_id, requester): - raise AuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester, requested_room_id), - ) - - # If this is continuing a previous session, pull the persisted data. - if from_token: - self._expire_pagination_sessions() - - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, from_token - ) - if pagination_key not in self._pagination_sessions: - raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) - - # Load the previous state. - pagination_session = self._pagination_sessions[pagination_key] - room_queue = pagination_session.room_queue - processed_rooms = pagination_session.processed_rooms - else: - # The queue of rooms to process, the next room is last on the stack. - room_queue = [_RoomQueueEntry(requested_room_id, ())] - - # Rooms we have already processed. - processed_rooms = set() - - rooms_result: List[JsonDict] = [] - - # Cap the limit to a server-side maximum. - if limit is None: - limit = MAX_ROOMS - else: - limit = min(limit, MAX_ROOMS) - - # Iterate through the queue until we reach the limit or run out of - # rooms to include. - while room_queue and len(rooms_result) < limit: - queue_entry = room_queue.pop() - room_id = queue_entry.room_id - current_depth = queue_entry.depth - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - # A map of summaries for children rooms that might be returned over - # federation. The rationale for caching these and *maybe* using them - # is to prefer any information local to the homeserver before trusting - # data received over federation. - children_room_entries: Dict[str, JsonDict] = {} - # A set of room IDs which are children that did not have information - # returned over federation and are known to be inaccessible to the - # current server. We should not reach out over federation to try to - # summarise these rooms. - inaccessible_children: Set[str] = set() - - # If the room is known locally, summarise it! - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - if is_in_room: - room_entry = await self._summarize_local_room( - requester, - None, - room_id, - suggested_only, - # TODO Handle max children. - max_children=None, - ) - - # Otherwise, attempt to use information for federation. - else: - # A previous call might have included information for this room. - # It can be used if either: - # - # 1. The room is not a space. - # 2. The maximum depth has been achieved (since no children - # information is needed). - if queue_entry.remote_room and ( - queue_entry.remote_room.get("room_type") != RoomTypes.SPACE - or (max_depth is not None and current_depth >= max_depth) - ): - room_entry = _RoomEntry( - queue_entry.room_id, queue_entry.remote_room - ) - - # If the above isn't true, attempt to fetch the room - # information over federation. - else: - ( - room_entry, - children_room_entries, - inaccessible_children, - ) = await self._summarize_remote_room_hiearchy( - queue_entry, - suggested_only, - ) - - # Ensure this room is accessible to the requester (and not just - # the homeserver). - if room_entry and not await self._is_remote_room_accessible( - requester, queue_entry.room_id, room_entry.room - ): - room_entry = None - - # This room has been processed and should be ignored if it appears - # elsewhere in the hierarchy. - processed_rooms.add(room_id) - - # There may or may not be a room entry based on whether it is - # inaccessible to the requesting user. - if room_entry: - # Add the room (including the stripped m.space.child events). - rooms_result.append(room_entry.as_json()) - - # If this room is not at the max-depth, check if there are any - # children to process. - if max_depth is None or current_depth < max_depth: - # The children get added in reverse order so that the next - # room to process, according to the ordering, is the last - # item in the list. - room_queue.extend( - _RoomQueueEntry( - ev["state_key"], - ev["content"]["via"], - current_depth + 1, - children_room_entries.get(ev["state_key"]), - ) - for ev in reversed(room_entry.children_state_events) - if ev["type"] == EventTypes.SpaceChild - and ev["state_key"] not in inaccessible_children - ) - - result: JsonDict = {"rooms": rooms_result} - - # If there's additional data, generate a pagination token (and persist state). - if room_queue: - next_batch = random_string(24) - result["next_batch"] = next_batch - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, next_batch - ) - self._pagination_sessions[pagination_key] = _PaginationSession( - self._clock.time_msec(), room_queue, processed_rooms - ) - - return result - - async def federation_space_summary( - self, - origin: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: Iterable[str], - ) -> JsonDict: - """ - Implementation of the space summary Federation API - - Args: - origin: The server requesting the spaces summary. - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. Unlike the C-S API, this applies to the root room (room_id). - It is clipped to MAX_ROOMS_PER_SPACE. - - exclude_rooms: a list of rooms to skip over (presumably because the - calling server has already seen them). - - Returns: - summary dict to return - """ - # the queue of rooms to process - room_queue = deque((room_id,)) - - # the set of rooms that we should not walk further. Initialise it with the - # excluded-rooms list; we will add other rooms as we process them so that - # we do not loop. - processed_rooms: Set[str] = set(exclude_rooms) - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - while room_queue and len(rooms_result) < MAX_ROOMS: - room_id = room_queue.popleft() - if room_id in processed_rooms: - # already done this room - continue - - room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_rooms_per_space - ) - - processed_rooms.add(room_id) - - if room_entry: - rooms_result.append(room_entry.room) - events_result.extend(room_entry.children_state_events) - - # add any children to the queue - room_queue.extend( - edge_event["state_key"] - for edge_event in room_entry.children_state_events - ) - - return {"rooms": rooms_result, "events": events_result} - - async def get_federation_hierarchy( - self, - origin: str, - requested_room_id: str, - suggested_only: bool, - ): - """ - Implementation of the room hierarchy Federation API. - - This is similar to get_room_hierarchy, but does not recurse into the space. - It also considers whether anyone on the server may be able to access the - room, as opposed to whether a specific user can. - - Args: - origin: The server requesting the spaces summary. - requested_room_id: The room ID to start the hierarchy at (the "root" room). - suggested_only: whether we should only return children with the "suggested" - flag set. - - Returns: - The JSON hierarchy dictionary. - """ - root_room_entry = await self._summarize_local_room( - None, origin, requested_room_id, suggested_only, max_children=None - ) - if root_room_entry is None: - # Room is inaccessible to the requesting server. - raise SynapseError(404, "Unknown room: %s" % (requested_room_id,)) - - children_rooms_result: List[JsonDict] = [] - inaccessible_children: List[str] = [] - - # Iterate through each child and potentially add it, but not its children, - # to the response. - for child_room in root_room_entry.children_state_events: - room_id = child_room.get("state_key") - assert isinstance(room_id, str) - # If the room is unknown, skip it. - if not await self._store.is_host_joined(room_id, self._server_name): - continue - - room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_children=0 - ) - # If the room is accessible, include it in the results. - # - # Note that only the room summary (without information on children) - # is included in the summary. - if room_entry: - children_rooms_result.append(room_entry.room) - # Otherwise, note that the requesting server shouldn't bother - # trying to summarize this room - they do not have access to it. - else: - inaccessible_children.append(room_id) - - return { - # Include the requested room (including the stripped children events). - "room": root_room_entry.as_json(), - "children": children_rooms_result, - "inaccessible_children": inaccessible_children, - } - - async def _summarize_local_room( - self, - requester: Optional[str], - origin: Optional[str], - room_id: str, - suggested_only: bool, - max_children: Optional[int], - ) -> Optional["_RoomEntry"]: - """ - Generate a room entry and a list of event entries for a given room. - - Args: - requester: - The user requesting the summary, if it is a local request. None - if this is a federation request. - origin: - The server requesting the summary, if it is a federation request. - None if this is a local request. - room_id: The room ID to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - - Returns: - A room entry if the room should be returned. None, otherwise. - """ - if not await self._is_local_room_accessible(room_id, requester, origin): - return None - - room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) - - # If the room is not a space or the children don't matter, return just - # the room information. - if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: - return _RoomEntry(room_id, room_entry) - - # Otherwise, look for child rooms/spaces. - child_events = await self._get_child_events(room_id) - - if suggested_only: - # we only care about suggested children - child_events = filter(_is_suggested_child_event, child_events) - - if max_children is None or max_children > MAX_ROOMS_PER_SPACE: - max_children = MAX_ROOMS_PER_SPACE - - now = self._clock.time_msec() - events_result: List[JsonDict] = [] - for edge_event in itertools.islice(child_events, max_children): - events_result.append( - await self._event_serializer.serialize_event( - edge_event, - time_now=now, - event_format=format_event_for_client_v2, - ) - ) - - return _RoomEntry(room_id, room_entry, events_result) - - async def _summarize_remote_room( - self, - room: "_RoomQueueEntry", - suggested_only: bool, - max_children: Optional[int], - exclude_rooms: Iterable[str], - ) -> Iterable["_RoomEntry"]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - exclude_rooms: - Rooms IDs which do not need to be summarized. - - Returns: - An iterable of room entries. - """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - # we need to make the exclusion list json-serialisable - exclude_rooms = list(exclude_rooms) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - res = await self._federation_client.get_space_summary( - via, - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_children, - exclude_rooms=exclude_rooms, - ) - except Exception as e: - logger.warning( - "Unable to get summary of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return () - - # Group the events by their room. - children_by_room: Dict[str, List[JsonDict]] = {} - for ev in res.events: - if ev.event_type == EventTypes.SpaceChild: - children_by_room.setdefault(ev.room_id, []).append(ev.data) - - # Generate the final results. - results = [] - for fed_room in res.rooms: - fed_room_id = fed_room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue - - results.append( - _RoomEntry( - fed_room_id, - fed_room, - children_by_room.get(fed_room_id, []), - ) - ) - - return results - - async def _summarize_remote_room_hiearchy( - self, room: "_RoomQueueEntry", suggested_only: bool - ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - - Returns: - A tuple of: - The room entry. - Partial room data return over federation. - A set of inaccessible children room IDs. - """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - ( - room_response, - children, - inaccessible_children, - ) = await self._federation_client.get_room_hierarchy( - via, - room_id, - suggested_only=suggested_only, - ) - except Exception as e: - logger.warning( - "Unable to get hierarchy of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return None, {}, set() - - # Map the children to their room ID. - children_by_room_id = { - c["room_id"]: c - for c in children - if "room_id" in c and isinstance(c["room_id"], str) - } - - return ( - _RoomEntry(room_id, room_response, room_response.pop("children_state", ())), - children_by_room_id, - set(inaccessible_children), - ) - - async def _is_local_room_accessible( - self, room_id: str, requester: Optional[str], origin: Optional[str] = None - ) -> bool: - """ - Calculate whether the room should be shown in the spaces summary. - - It should be included if: - - * The requester is joined or can join the room (per MSC3173). - * The origin server has any user that is joined or can join the room. - * The history visibility is set to world readable. - - Args: - room_id: The room ID to summarize. - requester: - The user requesting the summary, if it is a local request. None - if this is a federation request. - origin: - The server requesting the summary, if it is a federation request. - None if this is a local request. - - Returns: - True if the room should be included in the spaces summary. - """ - state_ids = await self._store.get_current_state_ids(room_id) - - # If there's no state for the room, it isn't known. - if not state_ids: - # The user might have a pending invite for the room. - if requester and await self._store.get_invite_for_local_user_in_room( - requester, room_id - ): - return True - - logger.info("room %s is unknown, omitting from summary", room_id) - return False - - room_version = await self._store.get_room_version(room_id) - - # Include the room if it has join rules of public or knock. - join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) - if join_rules_event_id: - join_rules_event = await self._store.get_event(join_rules_event_id) - join_rule = join_rules_event.content.get("join_rule") - if join_rule == JoinRules.PUBLIC or ( - room_version.msc2403_knocking and join_rule == JoinRules.KNOCK - ): - return True - - # Include the room if it is peekable. - hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) - if hist_vis_event_id: - hist_vis_ev = await self._store.get_event(hist_vis_event_id) - hist_vis = hist_vis_ev.content.get("history_visibility") - if hist_vis == HistoryVisibility.WORLD_READABLE: - return True - - # Otherwise we need to check information specific to the user or server. - - # If we have an authenticated requesting user, check if they are a member - # of the room (or can join the room). - if requester: - member_event_id = state_ids.get((EventTypes.Member, requester), None) - - # If they're in the room they can see info on it. - if member_event_id: - member_event = await self._store.get_event(member_event_id) - if member_event.membership in (Membership.JOIN, Membership.INVITE): - return True - - # Otherwise, check if they should be allowed access via membership in a space. - if await self._event_auth_handler.has_restricted_join_rules( - state_ids, room_version - ): - allowed_rooms = ( - await self._event_auth_handler.get_rooms_that_allow_join(state_ids) - ) - if await self._event_auth_handler.is_user_in_rooms( - allowed_rooms, requester - ): - return True - - # If this is a request over federation, check if the host is in the room or - # has a user who could join the room. - elif origin: - if await self._event_auth_handler.check_host_in_room( - room_id, origin - ) or await self._store.is_host_invited(room_id, origin): - return True - - # Alternately, if the host has a user in any of the spaces specified - # for access, then the host can see this room (and should do filtering - # if the requester cannot see it). - if await self._event_auth_handler.has_restricted_join_rules( - state_ids, room_version - ): - allowed_rooms = ( - await self._event_auth_handler.get_rooms_that_allow_join(state_ids) - ) - for space_id in allowed_rooms: - if await self._event_auth_handler.check_host_in_room( - space_id, origin - ): - return True - - logger.info( - "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary", - room_id, - requester or origin, - ) - return False - - async def _is_remote_room_accessible( - self, requester: str, room_id: str, room: JsonDict - ) -> bool: - """ - Calculate whether the room received over federation should be shown in the spaces summary. - - It should be included if: - - * The requester is joined or can join the room (per MSC3173). - * The history visibility is set to world readable. - - Note that the local server is not in the requested room (which is why the - remote call was made in the first place), but the user could have access - due to an invite, etc. - - Args: - requester: The user requesting the summary. - room_id: The room ID returned over federation. - room: The summary of the child room returned over federation. - - Returns: - True if the room should be included in the spaces summary. - """ - # The API doesn't return the room version so assume that a - # join rule of knock is valid. - if ( - room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) - or room.get("world_readable") is True - ): - return True - - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") - if allowed_rooms and isinstance(allowed_rooms, list): - if await self._event_auth_handler.is_user_in_rooms( - allowed_rooms, requester - ): - return True - - # Finally, check locally if we can access the room. The user might - # already be in the room (if it was a child room), or there might be a - # pending invite, etc. - return await self._is_local_room_accessible(room_id, requester) - - async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: - """ - Generate en entry suitable for the 'rooms' list in the summary response. - - Args: - room_id: The room ID to summarize. - for_federation: True if this is a summary requested over federation - (which includes additional fields). - - Returns: - The JSON dictionary for the room. - """ - stats = await self._store.get_room_with_stats(room_id) - - # currently this should be impossible because we call - # _is_local_room_accessible on the room before we get here, so - # there should always be an entry - assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - - current_state_ids = await self._store.get_current_state_ids(room_id) - create_event = await self._store.get_event( - current_state_ids[(EventTypes.Create, "")] - ) - - entry = { - "room_id": stats["room_id"], - "name": stats["name"], - "topic": stats["topic"], - "canonical_alias": stats["canonical_alias"], - "num_joined_members": stats["joined_members"], - "avatar_url": stats["avatar"], - "join_rules": stats["join_rules"], - "world_readable": ( - stats["history_visibility"] == HistoryVisibility.WORLD_READABLE - ), - "guest_can_join": stats["guest_access"] == "can_join", - "creation_ts": create_event.origin_server_ts, - "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), - } - - # Federation requests need to provide additional information so the - # requested server is able to filter the response appropriately. - if for_federation: - room_version = await self._store.get_room_version(room_id) - if await self._event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version - ): - allowed_rooms = ( - await self._event_auth_handler.get_rooms_that_allow_join( - current_state_ids - ) - ) - if allowed_rooms: - entry["allowed_room_ids"] = allowed_rooms - # TODO Remove this key once the API is stable. - entry["allowed_spaces"] = allowed_rooms - - # Filter out Nones – rather omit the field altogether - room_entry = {k: v for k, v in entry.items() if v is not None} - - return room_entry - - async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: - """ - Get the child events for a given room. - - The returned results are sorted for stability. - - Args: - room_id: The room id to get the children of. - - Returns: - An iterable of sorted child events. - """ - - # look for child rooms/spaces. - current_state_ids = await self._store.get_current_state_ids(room_id) - - events = await self._store.get_events_as_list( - [ - event_id - for key, event_id in current_state_ids.items() - if key[0] == EventTypes.SpaceChild - ] - ) - - # filter out any events without a "via" (which implies it has been redacted), - # and order to ensure we return stable results. - return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class _RoomQueueEntry: - # The room ID of this entry. - room_id: str - # The server to query if the room is not known locally. - via: Sequence[str] - # The minimum number of hops necessary to get to this room (compared to the - # originally requested room). - depth: int = 0 - # The room summary for this room returned via federation. This will only be - # used if the room is not known locally (and is not a space). - remote_room: Optional[JsonDict] = None - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class _RoomEntry: - room_id: str - # The room summary for this room. - room: JsonDict - # An iterable of the sorted, stripped children events for children of this room. - # - # This may not include all children. - children_state_events: Sequence[JsonDict] = () - - def as_json(self) -> JsonDict: - """ - Returns a JSON dictionary suitable for the room hierarchy endpoint. - - It returns the room summary including the stripped m.space.child events - as a sub-key. - """ - result = dict(self.room) - result["children_state"] = self.children_state_events - return result - - -def _has_valid_via(e: EventBase) -> bool: - via = e.content.get("via") - if not via or not isinstance(via, Sequence): - return False - for v in via: - if not isinstance(v, str): - logger.debug("Ignoring edge event %s with invalid via entry", e.event_id) - return False - return True - - -def _is_suggested_child_event(edge_event: EventBase) -> bool: - suggested = edge_event.content.get("suggested") - if isinstance(suggested, bool) and suggested: - return True - logger.debug("Ignorning not-suggested child %s", edge_event.state_key) - return False - - -# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive. -_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") - - -def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: - """ - Generate a value for comparing two child events for ordering. - - The rules for ordering are supposed to be: - - 1. The 'order' key, if it is valid. - 2. The 'origin_server_ts' of the 'm.room.create' event. - 3. The 'room_id'. - - But we skip step 2 since we may not have any state from the room. - - Args: - child: The event for generating a comparison key. - - Returns: - The comparison key as a tuple of: - False if the ordering is valid. - The ordering field. - The room ID. - """ - order = child.content.get("order") - # If order is not a string or doesn't meet the requirements, ignore it. - if not isinstance(order, str): - order = None - elif len(order) > 50 or _INVALID_ORDER_CHARS_RE.search(order): - order = None - - # Items without an order come last. - return (order is None, order, child.room_id) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 732a1e6aeb..a12fa30bfd 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,16 +14,28 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging -from typing import Iterable, List, Mapping, Optional, Sequence, overload +from typing import ( + TYPE_CHECKING, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + overload, +) from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError -from synapse.types import JsonDict +from synapse.types import JsonDict, RoomAlias, RoomID from synapse.util import json_decoder +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -663,3 +675,45 @@ class RestServlet: else: raise NotImplementedError("RestServlet must register something.") + + +class ResolveRoomIdMixin: + def __init__(self, hs: "HomeServer"): + self.room_member_handler = hs.get_room_member_handler() + + async def resolve_room_id( + self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None + ) -> Tuple[str, Optional[List[str]]]: + """ + Resolve a room identifier to a room ID, if necessary. + + This also performanes checks to ensure the room ID is of the proper form. + + Args: + room_identifier: The room ID or alias. + remote_room_hosts: The potential remote room hosts to use. + + Returns: + The resolved room ID. + + Raises: + SynapseError if the room ID is of the wrong form. + """ + if RoomID.is_valid(room_identifier): + resolved_room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + ( + room_id, + remote_room_hosts, + ) = await self.room_member_handler.lookup_room_alias(room_alias) + resolved_room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id, remote_room_hosts diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 40ee33646c..975c28b225 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -20,6 +20,7 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.http.servlet import ( + ResolveRoomIdMixin, RestServlet, assert_params_in_dict, parse_integer, @@ -33,7 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: @@ -45,48 +46,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ResolveRoomIdMixin: - def __init__(self, hs: "HomeServer"): - self.room_member_handler = hs.get_room_member_handler() - - async def resolve_room_id( - self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None - ) -> Tuple[str, Optional[List[str]]]: - """ - Resolve a room identifier to a room ID, if necessary. - - This also performanes checks to ensure the room ID is of the proper form. - - Args: - room_identifier: The room ID or alias. - remote_room_hosts: The potential remote room hosts to use. - - Returns: - The resolved room ID. - - Raises: - SynapseError if the room ID is of the wrong form. - """ - if RoomID.is_valid(room_identifier): - resolved_room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - ( - room_id, - remote_room_hosts, - ) = await self.room_member_handler.lookup_room_alias(room_alias) - resolved_room_id = room_id.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) - if not resolved_room_id: - raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier - ) - return resolved_room_id, remote_room_hosts - - class ShutdownRoomRestServlet(RestServlet): """Shuts down a room by removing all local users from the room and blocking all future invites and joins to the room. Any local aliases will be repointed diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2c3be23bc8..d3882a84e2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -24,12 +24,14 @@ from synapse.api.errors import ( AuthError, Codes, InvalidClientCredentialsError, + MissingClientTokenError, ShadowBanError, SynapseError, ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( + ResolveRoomIdMixin, RestServlet, assert_params_in_dict, parse_boolean, @@ -44,14 +46,7 @@ from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import ( - JsonDict, - RoomAlias, - RoomID, - StreamToken, - ThirdPartyInstanceID, - UserID, -) +from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -266,10 +261,10 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins -class JoinRoomAliasServlet(TransactionRestServlet): +class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): def __init__(self, hs): super().__init__(hs) - self.room_member_handler = hs.get_room_member_handler() + super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up self.auth = hs.get_auth() def register(self, http_server): @@ -292,24 +287,13 @@ class JoinRoomAliasServlet(TransactionRestServlet): # cheekily send invalid bodies. content = {} - if RoomID.is_valid(room_identifier): - room_id = room_identifier - - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] - args: Dict[bytes, List[bytes]] = request.args # type: ignore - - remote_room_hosts = parse_strings_from_args( - args, "server_name", required=False - ) - elif RoomAlias.is_valid(room_identifier): - handler = self.room_member_handler - room_alias = RoomAlias.from_string(room_identifier) - room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias) - room_id = room_id_obj.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "server_name", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) await self.room_member_handler.update_membership( requester=requester, @@ -1002,14 +986,14 @@ class RoomSpaceSummaryRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._space_summary_handler = hs.get_space_summary_handler() + self._room_summary_handler = hs.get_room_summary_handler() async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request, allow_guest=True) - return 200, await self._space_summary_handler.get_space_summary( + return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), @@ -1035,7 +1019,7 @@ class RoomSpaceSummaryRestServlet(RestServlet): 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON ) - return 200, await self._space_summary_handler.get_space_summary( + return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), room_id, suggested_only=suggested_only, @@ -1054,7 +1038,7 @@ class RoomHierarchyRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._space_summary_handler = hs.get_space_summary_handler() + self._room_summary_handler = hs.get_room_summary_handler() async def on_GET( self, request: SynapseRequest, room_id: str @@ -1073,7 +1057,7 @@ class RoomHierarchyRestServlet(RestServlet): 400, "'limit' must be a positive integer", Codes.BAD_JSON ) - return 200, await self._space_summary_handler.get_room_hierarchy( + return 200, await self._room_summary_handler.get_room_hierarchy( requester.user.to_string(), room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), @@ -1083,6 +1067,44 @@ class RoomHierarchyRestServlet(RestServlet): ) +class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/im.nheko.summary" + "/rooms/(?P[^/]*)/summary$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self._auth = hs.get_auth() + self._room_summary_handler = hs.get_room_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: + try: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + requester_user_id: Optional[str] = requester.user.to_string() + except MissingClientTokenError: + # auth is optional + requester_user_id = None + + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "via", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) + + return 200, await self._room_summary_handler.get_room_summary( + requester_user_id, + room_id, + remote_room_hosts, + ) + + def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) @@ -1098,6 +1120,8 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomEventContextServlet(hs).register(http_server) RoomSpaceSummaryRestServlet(hs).register(http_server) RoomHierarchyRestServlet(hs).register(http_server) + if hs.config.experimental.msc3266_enabled: + RoomSummaryRestServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server) RoomAliasListServlet(hs).register(http_server) diff --git a/synapse/server.py b/synapse/server.py index 6c867f0f47..de6517663e 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -99,10 +99,10 @@ from synapse.handlers.room import ( from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler +from synapse.handlers.room_summary import RoomSummaryHandler from synapse.handlers.search import SearchHandler from synapse.handlers.send_email import SendEmailHandler from synapse.handlers.set_password import SetPasswordHandler -from synapse.handlers.space_summary import SpaceSummaryHandler from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler @@ -772,8 +772,8 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountDataHandler(self) @cache_in_self - def get_space_summary_handler(self) -> SpaceSummaryHandler: - return SpaceSummaryHandler(self) + def get_room_summary_handler(self) -> RoomSummaryHandler: + return RoomSummaryHandler(self) @cache_in_self def get_event_auth_handler(self) -> EventAuthHandler: diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py new file mode 100644 index 0000000000..732d746e38 --- /dev/null +++ b/tests/handlers/test_room_summary.py @@ -0,0 +1,959 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Iterable, List, Optional, Tuple +from unittest import mock + +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RestrictedJoinRuleTypes, + RoomTypes, +) +from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict +from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID + +from tests import unittest + + +def _create_event(room_id: str, order: Optional[Any] = None): + result = mock.Mock() + result.room_id = room_id + result.content = {} + if order is not None: + result.content["order"] = order + return result + + +def _order(*events): + return sorted(events, key=_child_events_comparison_key) + + +class TestSpaceSummarySort(unittest.TestCase): + def test_no_order_last(self): + """An event with no ordering is placed behind those with an ordering.""" + ev1 = _create_event("!abc:test") + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_order(self): + """The ordering should be used.""" + ev1 = _create_event("!abc:test", "xyz") + ev2 = _create_event("!xyz:test", "abc") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_order_room_id(self): + """Room ID is a tie-breaker for ordering.""" + ev1 = _create_event("!abc:test", "abc") + ev2 = _create_event("!xyz:test", "abc") + + self.assertEqual([ev1, ev2], _order(ev1, ev2)) + + def test_invalid_ordering_type(self): + """Invalid orderings are considered the same as missing.""" + ev1 = _create_event("!abc:test", 1) + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", {}) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", []) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", True) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_invalid_ordering_value(self): + """Invalid orderings are considered the same as missing.""" + ev1 = _create_event("!abc:test", "foo\n") + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", "a" * 51) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + +class SpaceSummaryTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.hs = hs + self.handler = self.hs.get_room_summary_handler() + + # Create a user. + self.user = self.register_user("user", "pass") + self.token = self.login("user", "pass") + + # Create a space and a child room. + self.space = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + self.room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, self.room, self.token) + + def _add_child( + self, space_id: str, room_id: str, token: str, order: Optional[str] = None + ) -> None: + """Add a child room to a space.""" + content: JsonDict = {"via": [self.hs.hostname]} + if order is not None: + content["order"] = order + self.helper.send_state( + space_id, + event_type=EventTypes.SpaceChild, + body=content, + tok=token, + state_key=room_id, + ) + + def _assert_rooms( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] + ) -> None: + """ + Assert that the expected room IDs and events are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + if children: + children_ids.extend([(room_id, child_id) for child_id in children]) + self.assertCountEqual( + [room.get("room_id") for room in result["rooms"]], room_ids + ) + self.assertCountEqual( + [ + (event.get("room_id"), event.get("state_key")) + for event in result["events"] + ], + children_ids, + ) + + def _assert_hierarchy( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] + ) -> None: + """ + Assert that the expected room IDs are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + result_room_ids = [] + result_children_ids = [] + for result_room in result["rooms"]: + result_room_ids.append(result_room["room_id"]) + result_children_ids.append( + [ + (cs["room_id"], cs["state_key"]) + for cs in result_room.get("children_state") + ] + ) + + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + children_ids.append([(room_id, child_id) for child_id in children]) + + # Note that order matters. + self.assertEqual(result_room_ids, room_ids) + self.assertEqual(result_children_ids, children_ids) + + def _poke_fed_invite(self, room_id: str, from_user: str) -> None: + """ + Creates a invite (as if received over federation) for the room from the + given hostname. + + Args: + room_id: The room ID to issue an invite for. + fed_hostname: The user to invite from. + """ + # Poke an invite over federation into the database. + fed_handler = self.hs.get_federation_handler() + fed_hostname = UserID.from_string(from_user).domain + event = make_event_from_dict( + { + "room_id": room_id, + "event_id": "!abcd:" + fed_hostname, + "type": EventTypes.Member, + "sender": from_user, + "state_key": self.user, + "content": {"membership": Membership.INVITE}, + "prev_events": [], + "auth_events": [], + "depth": 1, + "origin_server_ts": 1234, + } + ) + self.get_success( + fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) + ) + + def test_simple_space(self): + """Test a simple space with a single room.""" + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + # The result should have the space and the room in it, along with a link + # from space -> room. + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_visibility(self): + """A user not in a space cannot inspect it.""" + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # The user can see the space since it is publicly joinable. + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # If the space is made invite-only, it should no longer be viewable. + self.helper.send_state( + self.space, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + + # If the space is made world-readable it should return a result. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # Make it not world-readable again and confirm it results in an error. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + + # Join the space and results should be returned. + self.helper.invite(self.space, targ=user2, tok=self.token) + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # Attempting to view an unknown room returns the same error. + self.get_failure( + self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + self.get_failure( + self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + + def _create_room_with_join_rule( + self, join_rule: str, room_version: Optional[str] = None, **extra_content + ) -> str: + """Create a room with the given join rule and add it to the space.""" + room_id = self.helper.create_room_as( + self.user, + room_version=room_version, + tok=self.token, + extra_content={ + "initial_state": [ + { + "type": EventTypes.JoinRules, + "state_key": "", + "content": { + "join_rule": join_rule, + **extra_content, + }, + } + ] + }, + ) + self._add_child(self.space, room_id, self.token) + return room_id + + def test_filtering(self): + """ + Rooms should be properly filtered to only include rooms the user has access to. + """ + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # Create a few rooms which will have different properties. + public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) + knock_room = self._create_room_with_join_rule( + JoinRules.KNOCK, room_version=RoomVersions.V7.identifier + ) + not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(invited_room, targ=user2, tok=self.token) + restricted_room = self._create_room_with_join_rule( + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, + allow=[], + ) + restricted_accessible_room = self._create_room_with_join_rule( + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, + allow=[ + { + "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, + "room_id": self.space, + "via": [self.hs.hostname], + } + ], + ) + world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.send_state( + world_readable_room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + joined_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(joined_room, targ=user2, tok=self.token) + self.helper.join(joined_room, user2, tok=token2) + + # Join the space. + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [ + ( + self.space, + [ + self.room, + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (self.room, ()), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + def test_complex_space(self): + """ + Create a "complex" space to see how it handles things like loops and subspaces. + """ + # Create an inaccessible room. + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) + # This is a bit odd as "user" is adding a room they don't know about, but + # it works for the tests. + self._add_child(self.space, room2, self.token) + + # Create a subspace under the space with an additional room in it. + subspace = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + subroom = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, subspace, token=self.token) + self._add_child(subspace, subroom, token=self.token) + # Also add the two rooms from the space into this subspace (causing loops). + self._add_child(subspace, self.room, token=self.token) + self._add_child(subspace, room2, self.token) + + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + + # The result should include each room a single time and each link. + expected = [ + (self.space, [self.room, room2, subspace]), + (self.room, ()), + (subspace, [subroom, self.room, room2]), + (subroom, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_pagination(self): + """Test simple pagination works.""" + room_ids = [] + for i in range(1, 10): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token, order=str(i)) + room_ids.append(room) + # The room created initially doesn't have an order, so comes last. + room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + # The result should have the space and all of the links, plus some of the + # rooms and a pagination token. + expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)] + expected += [(room_id, ()) for room_id in room_ids[:6]] + self._assert_hierarchy(result, expected) + self.assertIn("next_batch", result) + + # Check the next page. + result = self.get_success( + self.handler.get_room_hierarchy( + self.user, self.space, limit=5, from_token=result["next_batch"] + ) + ) + # The result should have the space and the room in it, along with a link + # from space -> room. + expected = [(room_id, ()) for room_id in room_ids[6:]] + self._assert_hierarchy(result, expected) + self.assertNotIn("next_batch", result) + + def test_invalid_pagination_token(self): + """An invalid pagination token, or changing other parameters, shoudl be rejected.""" + room_ids = [] + for i in range(1, 10): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token, order=str(i)) + room_ids.append(room) + # The room created initially doesn't have an order, so comes last. + room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + self.assertIn("next_batch", result) + + # Changing the room ID, suggested-only, or max-depth causes an error. + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.room, from_token=result["next_batch"] + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, + self.space, + suggested_only=True, + from_token=result["next_batch"], + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.space, max_depth=0, from_token=result["next_batch"] + ), + SynapseError, + ) + + # An invalid token is ignored. + self.get_failure( + self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), + SynapseError, + ) + + def test_max_depth(self): + """Create a deep tree to test the max depth against.""" + spaces = [self.space] + rooms = [self.room] + for _ in range(5): + spaces.append( + self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": { + EventContentFields.ROOM_TYPE: RoomTypes.SPACE + } + }, + ) + ) + self._add_child(spaces[-2], spaces[-1], self.token) + rooms.append(self.helper.create_room_as(self.user, tok=self.token)) + self._add_child(spaces[-1], rooms[-1], self.token) + + # Test just the space itself. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) + ) + expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] + self._assert_hierarchy(result, expected) + + # A single additional layer. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) + ) + expected += [ + (rooms[0], ()), + (spaces[1], [rooms[1], spaces[2]]), + ] + self._assert_hierarchy(result, expected) + + # A few layers. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) + ) + expected += [ + (rooms[1], ()), + (spaces[2], [rooms[2], spaces[3]]), + (rooms[2], ()), + (spaces[3], [rooms[3], spaces[4]]), + ] + self._assert_hierarchy(result, expected) + + def test_fed_complex(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + subroom = "#subroom:" + fed_hostname + + # Generate some good data, and some bad data: + # + # * Event *back* to the root room. + # * Unrelated events / rooms + # * Multiple levels of events (in a not-useful order, e.g. grandchild + # events before child events). + + # Note that these entries are brief, but should contain enough info. + requested_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ) + child_room = { + "room_id": subroom, + "world_readable": True, + } + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [ + requested_room_entry, + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), + ] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return requested_room_entry, {subroom: child_room}, set() + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + (subspace, [subroom]), + (subroom, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_filtering(self): + """ + Rooms returned over federation should be properly filtered to only include + rooms the user has access to. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + + # Create a few rooms which will have different properties. + public_room = "#public:" + fed_hostname + knock_room = "#knock:" + fed_hostname + not_invited_room = "#not_invited:" + fed_hostname + invited_room = "#invited:" + fed_hostname + restricted_room = "#restricted:" + fed_hostname + restricted_accessible_room = "#restricted_accessible:" + fed_hostname + world_readable_room = "#world_readable:" + fed_hostname + joined_room = self.helper.create_room_as(self.user, tok=self.token) + + # Poke an invite over federation into the database. + self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) + + # Note that these entries are brief, but should contain enough info. + children_rooms = ( + ( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + ( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + ( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [], + }, + ), + ( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + ( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + ( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ) + + subspace_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": room_id, + "content": {"via": [fed_hostname]}, + } + for room_id, _ in children_rooms + ], + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [subspace_room_entry] + [ + # A copy is made of the room data since the allowed_spaces key + # is removed. + _RoomEntry(child_room[0], dict(child_room[1])) + for child_room in children_rooms + ] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return subspace_room_entry, dict(children_rooms), set() + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + ( + subspace, + [ + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_invited(self): + """ + A room which the user was invited to should be included in the response. + + This differs from test_fed_filtering in that the room itself is being + queried over federation, instead of it being included as a sub-room of + a space in the response. + """ + fed_hostname = self.hs.hostname + "2" + fed_room = "#subroom:" + fed_hostname + + # Poke an invite over federation into the database. + self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) + + fed_room_entry = _RoomEntry( + fed_room, + { + "room_id": fed_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [fed_room_entry] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return fed_room_entry, {}, set() + + # Add a room to the space which is on another server. + self._add_child(self.space, fed_room, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, fed_room]), + (self.room, ()), + (fed_room, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + +class RoomSummaryTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.hs = hs + self.handler = self.hs.get_room_summary_handler() + + # Create a user. + self.user = self.register_user("user", "pass") + self.token = self.login("user", "pass") + + # Create a simple room. + self.room = self.helper.create_room_as(self.user, tok=self.token) + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + + def test_own_room(self): + """Test a simple room created by the requester.""" + result = self.get_success(self.handler.get_room_summary(self.user, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + def test_visibility(self): + """A user not in a private room cannot get its summary.""" + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # The user cannot see the room. + self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError) + + # If the room is made world-readable it should return a result. + self.helper.send_state( + self.room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + # Make it not world-readable again and confirm it results in an error. + self.helper.send_state( + self.room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.token, + ) + self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError) + + # If the room is made public it should return a result. + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.PUBLIC}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) + + # Join the space, make it invite-only again and results should be returned. + self.helper.join(self.room, user2, tok=token2) + self.helper.send_state( + self.room, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_room_summary(user2, self.room)) + self.assertEqual(result.get("room_id"), self.room) diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py deleted file mode 100644 index bc8e131f4a..0000000000 --- a/tests/handlers/test_space_summary.py +++ /dev/null @@ -1,881 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Iterable, List, Optional, Tuple -from unittest import mock - -from synapse.api.constants import ( - EventContentFields, - EventTypes, - HistoryVisibility, - JoinRules, - Membership, - RestrictedJoinRuleTypes, - RoomTypes, -) -from synapse.api.errors import AuthError, SynapseError -from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry -from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.server import HomeServer -from synapse.types import JsonDict, UserID - -from tests import unittest - - -def _create_event(room_id: str, order: Optional[Any] = None): - result = mock.Mock() - result.room_id = room_id - result.content = {} - if order is not None: - result.content["order"] = order - return result - - -def _order(*events): - return sorted(events, key=_child_events_comparison_key) - - -class TestSpaceSummarySort(unittest.TestCase): - def test_no_order_last(self): - """An event with no ordering is placed behind those with an ordering.""" - ev1 = _create_event("!abc:test") - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_order(self): - """The ordering should be used.""" - ev1 = _create_event("!abc:test", "xyz") - ev2 = _create_event("!xyz:test", "abc") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_order_room_id(self): - """Room ID is a tie-breaker for ordering.""" - ev1 = _create_event("!abc:test", "abc") - ev2 = _create_event("!xyz:test", "abc") - - self.assertEqual([ev1, ev2], _order(ev1, ev2)) - - def test_invalid_ordering_type(self): - """Invalid orderings are considered the same as missing.""" - ev1 = _create_event("!abc:test", 1) - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", {}) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", []) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", True) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_invalid_ordering_value(self): - """Invalid orderings are considered the same as missing.""" - ev1 = _create_event("!abc:test", "foo\n") - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", "a" * 51) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - -class SpaceSummaryTestCase(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, hs: HomeServer): - self.hs = hs - self.handler = self.hs.get_space_summary_handler() - - # Create a user. - self.user = self.register_user("user", "pass") - self.token = self.login("user", "pass") - - # Create a space and a child room. - self.space = self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - self.room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, self.room, self.token) - - def _add_child( - self, space_id: str, room_id: str, token: str, order: Optional[str] = None - ) -> None: - """Add a child room to a space.""" - content: JsonDict = {"via": [self.hs.hostname]} - if order is not None: - content["order"] = order - self.helper.send_state( - space_id, - event_type=EventTypes.SpaceChild, - body=content, - tok=token, - state_key=room_id, - ) - - def _assert_rooms( - self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] - ) -> None: - """ - Assert that the expected room IDs and events are in the response. - - Args: - result: The result from the API call. - rooms_and_children: An iterable of tuples where each tuple is: - The expected room ID. - The expected IDs of any children rooms. - """ - room_ids = [] - children_ids = [] - for room_id, children in rooms_and_children: - room_ids.append(room_id) - if children: - children_ids.extend([(room_id, child_id) for child_id in children]) - self.assertCountEqual( - [room.get("room_id") for room in result["rooms"]], room_ids - ) - self.assertCountEqual( - [ - (event.get("room_id"), event.get("state_key")) - for event in result["events"] - ], - children_ids, - ) - - def _assert_hierarchy( - self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] - ) -> None: - """ - Assert that the expected room IDs are in the response. - - Args: - result: The result from the API call. - rooms_and_children: An iterable of tuples where each tuple is: - The expected room ID. - The expected IDs of any children rooms. - """ - result_room_ids = [] - result_children_ids = [] - for result_room in result["rooms"]: - result_room_ids.append(result_room["room_id"]) - result_children_ids.append( - [ - (cs["room_id"], cs["state_key"]) - for cs in result_room.get("children_state") - ] - ) - - room_ids = [] - children_ids = [] - for room_id, children in rooms_and_children: - room_ids.append(room_id) - children_ids.append([(room_id, child_id) for child_id in children]) - - # Note that order matters. - self.assertEqual(result_room_ids, room_ids) - self.assertEqual(result_children_ids, children_ids) - - def _poke_fed_invite(self, room_id: str, from_user: str) -> None: - """ - Creates a invite (as if received over federation) for the room from the - given hostname. - - Args: - room_id: The room ID to issue an invite for. - fed_hostname: The user to invite from. - """ - # Poke an invite over federation into the database. - fed_handler = self.hs.get_federation_handler() - fed_hostname = UserID.from_string(from_user).domain - event = make_event_from_dict( - { - "room_id": room_id, - "event_id": "!abcd:" + fed_hostname, - "type": EventTypes.Member, - "sender": from_user, - "state_key": self.user, - "content": {"membership": Membership.INVITE}, - "prev_events": [], - "auth_events": [], - "depth": 1, - "origin_server_ts": 1234, - } - ) - self.get_success( - fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) - ) - - def test_simple_space(self): - """Test a simple space with a single room.""" - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The result should have the space and the room in it, along with a link - # from space -> room. - expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) - - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) - - def test_visibility(self): - """A user not in a space cannot inspect it.""" - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - - # The user can see the space since it is publicly joinable. - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) - - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) - self._assert_hierarchy(result, expected) - - # If the space is made invite-only, it should no longer be viewable. - self.helper.send_state( - self.space, - event_type=EventTypes.JoinRules, - body={"join_rule": JoinRules.INVITE}, - tok=self.token, - ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) - - # If the space is made world-readable it should return a result. - self.helper.send_state( - self.space, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.WORLD_READABLE}, - tok=self.token, - ) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) - self._assert_hierarchy(result, expected) - - # Make it not world-readable again and confirm it results in an error. - self.helper.send_state( - self.space, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.JOINED}, - tok=self.token, - ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) - - # Join the space and results should be returned. - self.helper.invite(self.space, targ=user2, tok=self.token) - self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) - self._assert_hierarchy(result, expected) - - # Attempting to view an unknown room returns the same error. - self.get_failure( - self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), - AuthError, - ) - self.get_failure( - self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), - AuthError, - ) - - def _create_room_with_join_rule( - self, join_rule: str, room_version: Optional[str] = None, **extra_content - ) -> str: - """Create a room with the given join rule and add it to the space.""" - room_id = self.helper.create_room_as( - self.user, - room_version=room_version, - tok=self.token, - extra_content={ - "initial_state": [ - { - "type": EventTypes.JoinRules, - "state_key": "", - "content": { - "join_rule": join_rule, - **extra_content, - }, - } - ] - }, - ) - self._add_child(self.space, room_id, self.token) - return room_id - - def test_filtering(self): - """ - Rooms should be properly filtered to only include rooms the user has access to. - """ - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - - # Create a few rooms which will have different properties. - public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) - knock_room = self._create_room_with_join_rule( - JoinRules.KNOCK, room_version=RoomVersions.V7.identifier - ) - not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) - invited_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.invite(invited_room, targ=user2, tok=self.token) - restricted_room = self._create_room_with_join_rule( - JoinRules.RESTRICTED, - room_version=RoomVersions.V8.identifier, - allow=[], - ) - restricted_accessible_room = self._create_room_with_join_rule( - JoinRules.RESTRICTED, - room_version=RoomVersions.V8.identifier, - allow=[ - { - "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, - "room_id": self.space, - "via": [self.hs.hostname], - } - ], - ) - world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.send_state( - world_readable_room, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.WORLD_READABLE}, - tok=self.token, - ) - joined_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.invite(joined_room, targ=user2, tok=self.token) - self.helper.join(joined_room, user2, tok=token2) - - # Join the space. - self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - expected = [ - ( - self.space, - [ - self.room, - public_room, - knock_room, - not_invited_room, - invited_room, - restricted_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ), - (self.room, ()), - (public_room, ()), - (knock_room, ()), - (invited_room, ()), - (restricted_accessible_room, ()), - (world_readable_room, ()), - (joined_room, ()), - ] - self._assert_rooms(result, expected) - - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) - self._assert_hierarchy(result, expected) - - def test_complex_space(self): - """ - Create a "complex" space to see how it handles things like loops and subspaces. - """ - # Create an inaccessible room. - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) - # This is a bit odd as "user" is adding a room they don't know about, but - # it works for the tests. - self._add_child(self.space, room2, self.token) - - # Create a subspace under the space with an additional room in it. - subspace = self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - subroom = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, subspace, token=self.token) - self._add_child(subspace, subroom, token=self.token) - # Also add the two rooms from the space into this subspace (causing loops). - self._add_child(subspace, self.room, token=self.token) - self._add_child(subspace, room2, self.token) - - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - - # The result should include each room a single time and each link. - expected = [ - (self.space, [self.room, room2, subspace]), - (self.room, ()), - (subspace, [subroom, self.room, room2]), - (subroom, ()), - ] - self._assert_rooms(result, expected) - - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) - - def test_pagination(self): - """Test simple pagination works.""" - room_ids = [] - for i in range(1, 10): - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, room, self.token, order=str(i)) - room_ids.append(room) - # The room created initially doesn't have an order, so comes last. - room_ids.append(self.room) - - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, limit=7) - ) - # The result should have the space and all of the links, plus some of the - # rooms and a pagination token. - expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)] - expected += [(room_id, ()) for room_id in room_ids[:6]] - self._assert_hierarchy(result, expected) - self.assertIn("next_batch", result) - - # Check the next page. - result = self.get_success( - self.handler.get_room_hierarchy( - self.user, self.space, limit=5, from_token=result["next_batch"] - ) - ) - # The result should have the space and the room in it, along with a link - # from space -> room. - expected = [(room_id, ()) for room_id in room_ids[6:]] - self._assert_hierarchy(result, expected) - self.assertNotIn("next_batch", result) - - def test_invalid_pagination_token(self): - """An invalid pagination token, or changing other parameters, shoudl be rejected.""" - room_ids = [] - for i in range(1, 10): - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, room, self.token, order=str(i)) - room_ids.append(room) - # The room created initially doesn't have an order, so comes last. - room_ids.append(self.room) - - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, limit=7) - ) - self.assertIn("next_batch", result) - - # Changing the room ID, suggested-only, or max-depth causes an error. - self.get_failure( - self.handler.get_room_hierarchy( - self.user, self.room, from_token=result["next_batch"] - ), - SynapseError, - ) - self.get_failure( - self.handler.get_room_hierarchy( - self.user, - self.space, - suggested_only=True, - from_token=result["next_batch"], - ), - SynapseError, - ) - self.get_failure( - self.handler.get_room_hierarchy( - self.user, self.space, max_depth=0, from_token=result["next_batch"] - ), - SynapseError, - ) - - # An invalid token is ignored. - self.get_failure( - self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), - SynapseError, - ) - - def test_max_depth(self): - """Create a deep tree to test the max depth against.""" - spaces = [self.space] - rooms = [self.room] - for _ in range(5): - spaces.append( - self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": { - EventContentFields.ROOM_TYPE: RoomTypes.SPACE - } - }, - ) - ) - self._add_child(spaces[-2], spaces[-1], self.token) - rooms.append(self.helper.create_room_as(self.user, tok=self.token)) - self._add_child(spaces[-1], rooms[-1], self.token) - - # Test just the space itself. - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) - ) - expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] - self._assert_hierarchy(result, expected) - - # A single additional layer. - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) - ) - expected += [ - (rooms[0], ()), - (spaces[1], [rooms[1], spaces[2]]), - ] - self._assert_hierarchy(result, expected) - - # A few layers. - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) - ) - expected += [ - (rooms[1], ()), - (spaces[2], [rooms[2], spaces[3]]), - (rooms[2], ()), - (spaces[3], [rooms[3], spaces[4]]), - ] - self._assert_hierarchy(result, expected) - - def test_fed_complex(self): - """ - Return data over federation and ensure that it is handled properly. - """ - fed_hostname = self.hs.hostname + "2" - subspace = "#subspace:" + fed_hostname - subroom = "#subroom:" + fed_hostname - - # Generate some good data, and some bad data: - # - # * Event *back* to the root room. - # * Unrelated events / rooms - # * Multiple levels of events (in a not-useful order, e.g. grandchild - # events before child events). - - # Note that these entries are brief, but should contain enough info. - requested_room_entry = _RoomEntry( - subspace, - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - [ - { - "type": EventTypes.SpaceChild, - "room_id": subspace, - "state_key": subroom, - "content": {"via": [fed_hostname]}, - } - ], - ) - child_room = { - "room_id": subroom, - "world_readable": True, - } - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [ - requested_room_entry, - _RoomEntry( - subroom, - { - "room_id": subroom, - "world_readable": True, - }, - ), - ] - - async def summarize_remote_room_hiearchy(_self, room, suggested_only): - return requested_room_entry, {subroom: child_room}, set() - - # Add a room to the space which is on another server. - self._add_child(self.space, subspace, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - expected = [ - (self.space, [self.room, subspace]), - (self.room, ()), - (subspace, [subroom]), - (subroom, ()), - ] - self._assert_rooms(result, expected) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", - new=summarize_remote_room_hiearchy, - ): - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) - - def test_fed_filtering(self): - """ - Rooms returned over federation should be properly filtered to only include - rooms the user has access to. - """ - fed_hostname = self.hs.hostname + "2" - subspace = "#subspace:" + fed_hostname - - # Create a few rooms which will have different properties. - public_room = "#public:" + fed_hostname - knock_room = "#knock:" + fed_hostname - not_invited_room = "#not_invited:" + fed_hostname - invited_room = "#invited:" + fed_hostname - restricted_room = "#restricted:" + fed_hostname - restricted_accessible_room = "#restricted_accessible:" + fed_hostname - world_readable_room = "#world_readable:" + fed_hostname - joined_room = self.helper.create_room_as(self.user, tok=self.token) - - # Poke an invite over federation into the database. - self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) - - # Note that these entries are brief, but should contain enough info. - children_rooms = ( - ( - public_room, - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - ), - ( - knock_room, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - ), - ( - not_invited_room, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - ( - invited_room, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - ( - restricted_room, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], - }, - ), - ( - restricted_accessible_room, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], - }, - ), - ( - world_readable_room, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - ), - ( - joined_room, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - ) - - subspace_room_entry = _RoomEntry( - subspace, - { - "room_id": subspace, - "world_readable": True, - }, - # Place each room in the sub-space. - [ - { - "type": EventTypes.SpaceChild, - "room_id": subspace, - "state_key": room_id, - "content": {"via": [fed_hostname]}, - } - for room_id, _ in children_rooms - ], - ) - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [subspace_room_entry] + [ - # A copy is made of the room data since the allowed_spaces key - # is removed. - _RoomEntry(child_room[0], dict(child_room[1])) - for child_room in children_rooms - ] - - async def summarize_remote_room_hiearchy(_self, room, suggested_only): - return subspace_room_entry, dict(children_rooms), set() - - # Add a room to the space which is on another server. - self._add_child(self.space, subspace, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - expected = [ - (self.space, [self.room, subspace]), - (self.room, ()), - ( - subspace, - [ - public_room, - knock_room, - not_invited_room, - invited_room, - restricted_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ), - (public_room, ()), - (knock_room, ()), - (invited_room, ()), - (restricted_accessible_room, ()), - (world_readable_room, ()), - (joined_room, ()), - ] - self._assert_rooms(result, expected) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", - new=summarize_remote_room_hiearchy, - ): - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) - - def test_fed_invited(self): - """ - A room which the user was invited to should be included in the response. - - This differs from test_fed_filtering in that the room itself is being - queried over federation, instead of it being included as a sub-room of - a space in the response. - """ - fed_hostname = self.hs.hostname + "2" - fed_room = "#subroom:" + fed_hostname - - # Poke an invite over federation into the database. - self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) - - fed_room_entry = _RoomEntry( - fed_room, - { - "room_id": fed_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ) - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [fed_room_entry] - - async def summarize_remote_room_hiearchy(_self, room, suggested_only): - return fed_room_entry, {}, set() - - # Add a room to the space which is on another server. - self._add_child(self.space, fed_room, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - expected = [ - (self.space, [self.room, fed_room]), - (self.room, ()), - (fed_room, ()), - ] - self._assert_rooms(result, expected) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", - new=summarize_remote_room_hiearchy, - ): - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) -- cgit 1.5.1 From 5af83efe8d106ee6fe6568f6758d458159341531 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 16 Aug 2021 12:01:30 -0400 Subject: Validate the max_rooms_per_space parameter to ensure it is non-negative. (#10611) --- changelog.d/10611.bugfix | 1 + synapse/federation/transport/server/federation.py | 22 ++++++++++++++++---- synapse/rest/client/v1/room.py | 25 ++++++++++++++++++----- 3 files changed, 39 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10611.bugfix (limited to 'synapse') diff --git a/changelog.d/10611.bugfix b/changelog.d/10611.bugfix new file mode 100644 index 0000000000..ecbe408b47 --- /dev/null +++ b/changelog.d/10611.bugfix @@ -0,0 +1 @@ +Additional validation for the spaces summary API to avoid errors like `ValueError: Stop argument for islice() must be None or an integer`. The missing validation has existed since v1.31.0. diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 7d81cc642c..2fdf6cc99e 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -557,7 +557,14 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): room_id: str, ) -> Tuple[int, JsonDict]: suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) + max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") + if max_rooms_per_space is not None and max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) @@ -586,10 +593,17 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON - ) + if max_rooms_per_space is not None: + if not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON + ) + if max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) return 200, await self.handler.federation_space_summary( origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index d3882a84e2..ba7250ad8e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -993,11 +993,19 @@ class RoomSpaceSummaryRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request, allow_guest=True) + max_rooms_per_space = parse_integer(request, "max_rooms_per_space") + if max_rooms_per_space is not None and max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) + return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=parse_integer(request, "max_rooms_per_space"), + max_rooms_per_space=max_rooms_per_space, ) # TODO When switching to the stable endpoint, remove the POST handler. @@ -1014,10 +1022,17 @@ class RoomSpaceSummaryRestServlet(RestServlet): ) max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) + if max_rooms_per_space is not None: + if not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON + ) + if max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), -- cgit 1.5.1 From 19e51b14d23f756883688fd8238da61c6ff29cc3 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 16 Aug 2021 18:11:48 +0100 Subject: Manhole: wrap coroutines in `defer.ensureDeferred` automatically (#10602) --- changelog.d/10602.feature | 1 + docs/manhole.md | 2 +- synapse/util/manhole.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10602.feature (limited to 'synapse') diff --git a/changelog.d/10602.feature b/changelog.d/10602.feature new file mode 100644 index 0000000000..ab18291a20 --- /dev/null +++ b/changelog.d/10602.feature @@ -0,0 +1 @@ +The Synapse manhole no longer needs coroutines to be wrapped in `defer.ensureDeferred`. diff --git a/docs/manhole.md b/docs/manhole.md index 37d1d7823c..db92df88dc 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -67,7 +67,7 @@ This gives a Python REPL in which `hs` gives access to the `synapse.server.HomeServer` object - which in turn gives access to many other parts of the process. -Note that any call which returns a coroutine will need to be wrapped in `ensureDeferred`. +Note that, prior to Synapse 1.41, any call which returns a coroutine will need to be wrapped in `ensureDeferred`. As a simple example, retrieving an event from the database: diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index da24ba0470..522daa323d 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import sys import traceback @@ -20,6 +21,7 @@ from twisted.conch.insults import insults from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal +from twisted.internet import defer PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" @@ -141,3 +143,15 @@ class SynapseManholeInterpreter(ManholeInterpreter): self.write("".join(lines)) finally: last_tb = ei = None + + def displayhook(self, obj): + """ + We override the displayhook so that we automatically convert coroutines + into Deferreds. (Our superclass' displayhook will take care of the rest, + by displaying the Deferred if it's ready, or registering a callback + if it's not). + """ + if inspect.iscoroutine(obj): + super().displayhook(defer.ensureDeferred(obj)) + else: + super().displayhook(obj) -- cgit 1.5.1 From a933c2c7d8ef49c3c98ef443d959f955600bfb6b Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Tue, 17 Aug 2021 10:52:38 +0100 Subject: Add an admin API to check if a username is available (#10578) This adds a new API GET /_synapse/admin/v1/username_available?username=foo to check if a username is available. It is the counterpart to https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available, except that it works even if registration is disabled. --- changelog.d/10578.feature | 1 + docs/admin_api/user_admin_api.md | 20 ++++++++++ synapse/rest/admin/__init__.py | 2 + synapse/rest/admin/username_available.py | 51 ++++++++++++++++++++++++ tests/rest/admin/test_username_available.py | 62 +++++++++++++++++++++++++++++ 5 files changed, 136 insertions(+) create mode 100644 changelog.d/10578.feature create mode 100644 synapse/rest/admin/username_available.py create mode 100644 tests/rest/admin/test_username_available.py (limited to 'synapse') diff --git a/changelog.d/10578.feature b/changelog.d/10578.feature new file mode 100644 index 0000000000..02397f0009 --- /dev/null +++ b/changelog.d/10578.feature @@ -0,0 +1 @@ +Add an admin API (`GET /_synapse/admin/username_available`) to check if a username is available (regardless of registration settings). \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 33811f5bbb..4b5dd4685a 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -1057,3 +1057,23 @@ The following parameters should be set in the URL: - `user_id` - The fully qualified MXID: for example, `@user:server.com`. The user must be local. + +### Check username availability + +Checks to see if a username is available, and valid, for the server. See [the client-server +API](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) +for more information. + +This endpoint will work even if registration is disabled on the server, unlike +`/_matrix/client/r0/register/available`. + +The API is: + +``` +POST /_synapse/admin/v1/username_availabile?username=$localpart +``` + +The request and response format is the same as the [/_matrix/client/r0/register/available](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) API. + +To use it, you will need to authenticate by providing an `access_token` for a +server admin: [Admin API](../usage/administration/admin_api) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index abf749b001..8a91068092 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -51,6 +51,7 @@ from synapse.rest.admin.rooms import ( ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet +from synapse.rest.admin.username_available import UsernameAvailableRestServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, DeactivateAccountRestServlet, @@ -241,6 +242,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ForwardExtremitiesRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server) + UsernameAvailableRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py new file mode 100644 index 0000000000..2bf1472967 --- /dev/null +++ b/synapse/rest/admin/username_available.py @@ -0,0 +1,51 @@ +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UsernameAvailableRestServlet(RestServlet): + """An admin API to check if a given username is available, regardless of whether registration is enabled. + + Example: + GET /_synapse/admin/v1/username_available?username=foo + 200 OK + { + "available": true + } + """ + + PATTERNS = admin_patterns("/username_available") + + def __init__(self, hs: "HomeServer"): + self.auth = hs.get_auth() + self.registration_handler = hs.get_registration_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + username = parse_string(request, "username", required=True) + await self.registration_handler.check_username(username) + return HTTPStatus.OK, {"available": True} diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py new file mode 100644 index 0000000000..53cbc8ddab --- /dev/null +++ b/tests/rest/admin/test_username_available.py @@ -0,0 +1,62 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import synapse.rest.admin +from synapse.api.errors import Codes, SynapseError +from synapse.rest.client.v1 import login + +from tests import unittest + + +class UsernameAvailableTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + url = "/_synapse/admin/v1/username_available" + + def prepare(self, reactor, clock, hs): + self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + async def check_username(username): + if username == "allowed": + return True + raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + + handler = self.hs.get_registration_handler() + handler.check_username = check_username + + def test_username_available(self): + """ + The endpoint should return a 200 response if the username does not exist + """ + + url = "%s?username=%s" % (self.url, "allowed") + channel = self.make_request("GET", url, None, self.admin_user_tok) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertTrue(channel.json_body["available"]) + + def test_username_unavailable(self): + """ + The endpoint should return a 200 response if the username does not exist + """ + + url = "%s?username=%s" % (self.url, "disallowed") + channel = self.make_request("GET", url, None, self.admin_user_tok) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") + self.assertEqual(channel.json_body["error"], "User ID already taken.") -- cgit 1.5.1 From ae2714c1f31f2a843e19dc44501784401181162c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 17 Aug 2021 12:23:14 +0200 Subject: Allow using several custom template directories (#10587) Allow using several directories in read_templates. --- changelog.d/10587.misc | 1 + synapse/config/_base.py | 43 ++++++++++++++----------- synapse/config/account_validity.py | 2 +- synapse/config/emailconfig.py | 8 +++-- synapse/config/sso.py | 2 +- synapse/module_api/__init__.py | 5 ++- tests/config/test_base.py | 64 ++++++++++++++++++++++++++++++++++++-- 7 files changed, 98 insertions(+), 27 deletions(-) create mode 100644 changelog.d/10587.misc (limited to 'synapse') diff --git a/changelog.d/10587.misc b/changelog.d/10587.misc new file mode 100644 index 0000000000..4c6167977c --- /dev/null +++ b/changelog.d/10587.misc @@ -0,0 +1 @@ +Allow multiple custom directories in `read_templates`. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index d6ec618f8f..2cc242782a 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -237,13 +237,14 @@ class Config: def read_templates( self, filenames: List[str], - custom_template_directory: Optional[str] = None, + custom_template_directories: Optional[Iterable[str]] = None, ) -> List[jinja2.Template]: """Load a list of template files from disk using the given variables. This function will attempt to load the given templates from the default Synapse - template directory. If `custom_template_directory` is supplied, that directory - is tried first. + template directory. If `custom_template_directories` is supplied, any directory + in this list is tried (in the order they appear in the list) before trying + Synapse's default directory. Files read are treated as Jinja templates. The templates are not rendered yet and have autoescape enabled. @@ -251,8 +252,8 @@ class Config: Args: filenames: A list of template filenames to read. - custom_template_directory: A directory to try to look for the templates - before using the default Synapse template directory instead. + custom_template_directories: A list of directory to try to look for the + templates before using the default Synapse template directory instead. Raises: ConfigError: if the file's path is incorrect or otherwise cannot be read. @@ -260,20 +261,26 @@ class Config: Returns: A list of jinja2 templates. """ - search_directories = [self.default_template_dir] - - # The loader will first look in the custom template directory (if specified) for the - # given filename. If it doesn't find it, it will use the default template dir instead - if custom_template_directory: - # Check that the given template directory exists - if not self.path_exists(custom_template_directory): - raise ConfigError( - "Configured template directory does not exist: %s" - % (custom_template_directory,) - ) + search_directories = [] + + # The loader will first look in the custom template directories (if specified) + # for the given filename. If it doesn't find it, it will use the default + # template dir instead. + if custom_template_directories is not None: + for custom_template_directory in custom_template_directories: + # Check that the given template directory exists + if not self.path_exists(custom_template_directory): + raise ConfigError( + "Configured template directory does not exist: %s" + % (custom_template_directory,) + ) + + # Search the custom template directory as well + search_directories.append(custom_template_directory) - # Search the custom template directory as well - search_directories.insert(0, custom_template_directory) + # Append the default directory at the end of the list so Jinja can fallback on it + # if a template is missing from any custom directory. + search_directories.append(self.default_template_dir) # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 6be4eafe55..9acce5996e 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -88,5 +88,5 @@ class AccountValidityConfig(Config): "account_previously_renewed.html", invalid_token_template_filename, ], - account_validity_template_dir, + (td for td in (account_validity_template_dir,) if td), ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 42526502f0..fc74b4a8b9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -257,7 +257,9 @@ class EmailConfig(Config): registration_template_success_html, add_threepid_template_success_html, ], - template_dir, + ( + td for td in (template_dir,) if td + ), # Filter out template_dir if not provided ) # Render templates that do not contain any placeholders @@ -297,7 +299,7 @@ class EmailConfig(Config): self.email_notif_template_text, ) = self.read_templates( [notif_template_html, notif_template_text], - template_dir, + (td for td in (template_dir,) if td), ) self.email_notif_for_new_users = email_config.get( @@ -320,7 +322,7 @@ class EmailConfig(Config): self.account_validity_template_text, ) = self.read_templates( [expiry_template_html, expiry_template_text], - template_dir, + (td for td in (template_dir,) if td), ) subjects_config = email_config.get("subjects", {}) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index d0f04cf8e6..4b590e0535 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -63,7 +63,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - self.sso_template_dir, + (td for td in (self.sso_template_dir,) if td), ) # These templates have no placeholders, so render them here diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1cc13fc97b..82725853bc 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -677,7 +677,10 @@ class ModuleApi: A list containing the loaded templates, with the orders matching the one of the filenames parameter. """ - return self._hs.config.read_templates(filenames, custom_template_directory) + return self._hs.config.read_templates( + filenames, + (td for td in (custom_template_directory,) if td), + ) class PublicRoomListManager: diff --git a/tests/config/test_base.py b/tests/config/test_base.py index 84ae3b88ae..baa5313fb3 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -30,7 +30,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +60,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], tmp_dir) + self.hs.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -74,8 +74,66 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): "Template file did not contain our test string", ) + def test_multiple_custom_template_directories(self): + """Tests that directories are searched in the right order if multiple custom + template directories are provided. + """ + # Create two temporary directories on the filesystem. + tempdirs = [ + tempfile.TemporaryDirectory(), + tempfile.TemporaryDirectory(), + ] + + # Create one template in each directory, whose content is the index of the + # directory in the list. + template_filename = "my_template.html.j2" + for i in range(len(tempdirs)): + tempdir = tempdirs[i] + template_path = os.path.join(tempdir.name, template_filename) + + with open(template_path, "w") as fp: + fp.write(str(i)) + fp.flush() + + # Retrieve the template. + template = ( + self.hs.config.read_templates( + [template_filename], + (td.name for td in tempdirs), + ) + )[0] + + # Test that we got the template we dropped in the first directory in the list. + self.assertEqual(template.render(), "0") + + # Add another template, this one only in the second directory in the list, so we + # can test that the second directory is still searched into when no matching file + # could be found in the first one. + other_template_name = "my_other_template.html.j2" + other_template_path = os.path.join(tempdirs[1].name, other_template_name) + + with open(other_template_path, "w") as fp: + fp.write("hello world") + fp.flush() + + # Retrieve the template. + template = ( + self.hs.config.read_templates( + [other_template_name], + (td.name for td in tempdirs), + ) + )[0] + + # Test that the file has the expected content. + self.assertEqual(template.render(), "hello world") + + # Cleanup the temporary directories manually since we're not using a context + # manager. + for td in tempdirs: + td.cleanup() + def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): self.hs.config.read_templates( - ["some_filename.html"], "a_nonexistent_directory" + ["some_filename.html"], ("a_nonexistent_directory",) ) -- cgit 1.5.1 From 58f0d97275e9ffc134f9aaf59ce01c0e745ec041 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 17 Aug 2021 11:45:35 +0100 Subject: update links to schema doc (#10620) --- changelog.d/10620.misc | 1 + synapse/storage/schema/README.md | 2 +- synapse/storage/schema/__init__.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/10620.misc (limited to 'synapse') diff --git a/changelog.d/10620.misc b/changelog.d/10620.misc new file mode 100644 index 0000000000..8b29668a1f --- /dev/null +++ b/changelog.d/10620.misc @@ -0,0 +1 @@ +Fix up a couple of links to the database schema documentation. diff --git a/synapse/storage/schema/README.md b/synapse/storage/schema/README.md index 729f44ea6c..4fc2061a3d 100644 --- a/synapse/storage/schema/README.md +++ b/synapse/storage/schema/README.md @@ -1,4 +1,4 @@ # Synapse Database Schemas This directory contains the schema files used to build Synapse databases. For more -information, see /docs/development/database_schema.md. +information, see https://matrix-org.github.io/synapse/develop/development/database_schema.html. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index fd4dd67d91..7e0687e197 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -19,8 +19,8 @@ This should be incremented whenever the codebase changes its requirements on the shape of the database schema (even if those requirements are backwards-compatible with older versions of Synapse). -See `README.md `_ for more information on how this -works. +See https://matrix-org.github.io/synapse/develop/development/database_schema.html +for more information on how this works. Changes in SCHEMA_VERSION = 61: - The `user_stats_historical` and `room_stats_historical` tables are not written and -- cgit 1.5.1 From 3bcd525b46678ff228c4275acad47c12974c9a33 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 Aug 2021 12:56:11 +0200 Subject: Allow to edit `external_ids` by Edit User admin API (#10598) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10598.feature | 1 + docs/admin_api/user_admin_api.md | 40 +++-- synapse/rest/admin/users.py | 139 +++++++++------ synapse/storage/databases/main/registration.py | 22 +++ tests/rest/admin/test_user.py | 227 +++++++++++++++++++++---- 5 files changed, 340 insertions(+), 89 deletions(-) create mode 100644 changelog.d/10598.feature (limited to 'synapse') diff --git a/changelog.d/10598.feature b/changelog.d/10598.feature new file mode 100644 index 0000000000..92c159118b --- /dev/null +++ b/changelog.d/10598.feature @@ -0,0 +1 @@ +Allow editing a user's `external_ids` via the "Edit User" admin API. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 4b5dd4685a..6a9335d6ec 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -81,6 +81,16 @@ with a body of: "address": "" } ], + "external_ids": [ + { + "auth_provider": "", + "external_id": "" + }, + { + "auth_provider": "", + "external_id": "" + } + ], "avatar_url": "", "admin": false, "deactivated": false @@ -90,26 +100,34 @@ with a body of: To use it, you will need to authenticate by providing an `access_token` for a server admin: [Admin API](../usage/administration/admin_api) +Returns HTTP status code: +- `201` - When a new user object was created. +- `200` - When a user was modified. + URL parameters: - `user_id`: fully-qualified user id: for example, `@user:server.com`. Body parameters: -- `password`, optional. If provided, the user's password is updated and all +- `password` - string, optional. If provided, the user's password is updated and all devices are logged out. - -- `displayname`, optional, defaults to the value of `user_id`. - -- `threepids`, optional, allows setting the third-party IDs (email, msisdn) +- `displayname` - string, optional, defaults to the value of `user_id`. +- `threepids` - array, optional, allows setting the third-party IDs (email, msisdn) + - `medium` - string. Kind of third-party ID, either `email` or `msisdn`. + - `address` - string. Value of third-party ID. belonging to a user. - -- `avatar_url`, optional, must be a +- `external_ids` - array, optional. Allow setting the identifier of the external identity + provider for SSO (Single sign-on). Details in + [Sample Configuration File](../usage/configuration/homeserver_sample_config.html) + section `sso` and `oidc_providers`. + - `auth_provider` - string. ID of the external identity provider. Value of `idp_id` + in homeserver configuration. + - `external_id` - string, user ID in the external identity provider. +- `avatar_url` - string, optional, must be a [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris). - -- `admin`, optional, defaults to `false`. - -- `deactivated`, optional. If unspecified, deactivation state will be left +- `admin` - bool, optional, defaults to `false`. +- `deactivated` - bool, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to `false` for new accounts. A user cannot be erased by deactivating with this API. For details on deactivating users see [Deactivate Account](#deactivate-account). diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 41f21ba118..c885fd77ab 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -196,20 +196,57 @@ class UserRestServletV2(RestServlet): user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() + # check for required parameters for each threepid + threepids = body.get("threepids") + if threepids is not None: + for threepid in threepids: + assert_params_in_dict(threepid, ["medium", "address"]) + + # check for required parameters for each external_id + external_ids = body.get("external_ids") + if external_ids is not None: + for external_id in external_ids: + assert_params_in_dict(external_id, ["auth_provider", "external_id"]) + + user_type = body.get("user_type", None) + if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + raise SynapseError(400, "Invalid user type") + + set_admin_to = body.get("admin", False) + if not isinstance(set_admin_to, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'admin' must be a boolean, if given", + Codes.BAD_JSON, + ) + + password = body.get("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + + deactivate = body.get("deactivated", False) + if not isinstance(deactivate, bool): + raise SynapseError(400, "'deactivated' parameter is not of type boolean") + + # convert into List[Tuple[str, str]] + if external_ids is not None: + new_external_ids = [] + for external_id in external_ids: + new_external_ids.append( + (external_id["auth_provider"], external_id["external_id"]) + ) + if user: # modify user if "displayname" in body: await self.profile_handler.set_displayname( target_user, requester, body["displayname"], True ) - if "threepids" in body: - # check for required parameters for each threepid - for threepid in body["threepids"]: - assert_params_in_dict(threepid, ["medium", "address"]) - + if threepids is not None: # remove old threepids from user - threepids = await self.store.user_get_threepids(user_id) - for threepid in threepids: + old_threepids = await self.store.user_get_threepids(user_id) + for threepid in old_threepids: try: await self.auth_handler.delete_threepid( user_id, threepid["medium"], threepid["address"], None @@ -220,18 +257,39 @@ class UserRestServletV2(RestServlet): # add new threepids to user current_time = self.hs.get_clock().time_msec() - for threepid in body["threepids"]: + for threepid in threepids: await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) - if "avatar_url" in body and type(body["avatar_url"]) == str: + if external_ids is not None: + # get changed external_ids (added and removed) + cur_external_ids = await self.store.get_external_ids_by_user(user_id) + add_external_ids = set(new_external_ids) - set(cur_external_ids) + del_external_ids = set(cur_external_ids) - set(new_external_ids) + + # remove old external_ids + for auth_provider, external_id in del_external_ids: + await self.store.remove_user_external_id( + auth_provider, + external_id, + user_id, + ) + + # add new external_ids + for auth_provider, external_id in add_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True ) if "admin" in body: - set_admin_to = bool(body["admin"]) if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: @@ -239,29 +297,18 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) - if "password" in body: - if not isinstance(body["password"], str) or len(body["password"]) > 512: - raise SynapseError(400, "Invalid password") - else: - new_password = body["password"] - logout_devices = True - - new_password_hash = await self.auth_handler.hash(new_password) - - await self.set_password_handler.set_password( - target_user.to_string(), - new_password_hash, - logout_devices, - requester, - ) + if password is not None: + logout_devices = True + new_password_hash = await self.auth_handler.hash(password) + + await self.set_password_handler.set_password( + target_user.to_string(), + new_password_hash, + logout_devices, + requester, + ) if "deactivated" in body: - deactivate = body["deactivated"] - if not isinstance(deactivate, bool): - raise SynapseError( - 400, "'deactivated' parameter is not of type boolean" - ) - if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False, requester, by_admin=True @@ -285,36 +332,24 @@ class UserRestServletV2(RestServlet): return 200, user else: # create user - password = body.get("password") + displayname = body.get("displayname", None) + password_hash = None if password is not None: - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) - admin = body.get("admin", None) - user_type = body.get("user_type", None) - displayname = body.get("displayname", None) - - if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") - user_id = await self.registration_handler.register_user( localpart=target_user.localpart, password_hash=password_hash, - admin=bool(admin), + admin=set_admin_to, default_display_name=displayname, user_type=user_type, by_admin=True, ) - if "threepids" in body: - # check for required parameters for each threepid - for threepid in body["threepids"]: - assert_params_in_dict(threepid, ["medium", "address"]) - + if threepids is not None: current_time = self.hs.get_clock().time_msec() - for threepid in body["threepids"]: + for threepid in threepids: await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) @@ -334,6 +369,14 @@ class UserRestServletV2(RestServlet): data={}, ) + if external_ids is not None: + for auth_provider, external_id in new_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 14670c2881..c67bea81c6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -599,6 +599,28 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="record_user_external_id", ) + async def remove_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Remove a mapping from an external user id to a mxid + + If the mapping is not found, this method does nothing. + + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_delete( + table="user_external_ids", + keyvalues={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="remove_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 13fab5579b..a736ec4754 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1240,56 +1240,114 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_get_user(self): + def test_invalid_parameter(self): """ - Test a simple get of a user. + If parameters are invalid, an error is returned. """ + + # admin not bool channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"admin": "not_bool"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual("User", channel.json_body["displayname"]) - self._check_fields(channel.json_body) + # deactivated not bool + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": "not_bool"}, + ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def test_get_user_with_sso(self): - """ - Test get a user with SSO details. - """ - self.get_success( - self.store.record_user_external_id( - "auth_provider1", "external_id1", self.other_user - ) + # password not str + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": True}, ) - self.get_success( - self.store.record_user_external_id( - "auth_provider2", "external_id2", self.other_user - ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # password not length + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": "x" * 513}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # user_type not valid channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"user_type": "new type"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual( - "external_id1", channel.json_body["external_ids"][0]["external_id"] + # external_ids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} + }, ) - self.assertEqual( - "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual( - "external_id2", channel.json_body["external_ids"][1]["external_id"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # threepids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual( - "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"address": "value"}}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_get_user(self): + """ + Test a simple get of a user. + """ + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) def test_create_server_admin(self): @@ -1353,6 +1411,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": False, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + ], "avatar_url": "mxc://fibble/wibble", } @@ -1368,6 +1432,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + ) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1632,6 +1702,103 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + def test_set_external_id(self): + """ + Test setting external id for an other user. + """ + + # Add two external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + # result does not always have the same sort order, therefore it becomes sorted + self.assertEqual( + sorted(channel.json_body["external_ids"], key=lambda k: k["auth_provider"]), + [ + {"auth_provider": "auth_provider1", "external_id": "external_id1"}, + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + ], + ) + self._check_fields(channel.json_body) + + # Set a new and remove an external_id + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + { + "external_id": "external_id3", + "auth_provider": "auth_provider3", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Remove external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["external_ids"])) + def test_deactivate_user(self): """ Test deactivating another user. -- cgit 1.5.1 From b62eba770522fde7bf1204eb5771ee24d9a5e7bc Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 17 Aug 2021 12:32:25 +0100 Subject: Always list fallback key types in /sync (#10623) --- changelog.d/10623.bugfix | 1 + synapse/rest/client/v2_alpha/sync.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10623.bugfix (limited to 'synapse') diff --git a/changelog.d/10623.bugfix b/changelog.d/10623.bugfix new file mode 100644 index 0000000000..759fba3513 --- /dev/null +++ b/changelog.d/10623.bugfix @@ -0,0 +1 @@ +Revert behaviour introduced in v1.38.0 that strips `org.matrix.msc2732.device_unused_fallback_key_types` from `/sync` when its value is empty. This field should instead always be present according to [MSC2732](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2732-olm-fallback-keys.md). \ No newline at end of file diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index e321668698..e18f4d01b3 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -259,10 +259,11 @@ class SyncRestServlet(RestServlet): # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456 response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count - if sync_result.device_unused_fallback_key_types: - response[ - "org.matrix.msc2732.device_unused_fallback_key_types" - ] = sync_result.device_unused_fallback_key_types + # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md + # states that this field should always be included, as long as the server supports the feature. + response[ + "org.matrix.msc2732.device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types if joined: response["rooms"][Membership.JOIN] = joined -- cgit 1.5.1 From 642a42eddece60afbbd5e5a6659fa9b939238b4a Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 17 Aug 2021 12:57:58 +0100 Subject: Flatten the synapse.rest.client package (#10600) --- changelog.d/10600.misc | 1 + synapse/app/generic_worker.py | 40 +- synapse/handlers/auth.py | 6 +- synapse/rest/__init__.py | 35 +- synapse/rest/admin/users.py | 4 +- synapse/rest/client/__init__.py | 2 +- synapse/rest/client/_base.py | 100 ++ synapse/rest/client/account.py | 910 ++++++++++++++++ synapse/rest/client/account_data.py | 122 +++ synapse/rest/client/account_validity.py | 104 ++ synapse/rest/client/auth.py | 143 +++ synapse/rest/client/capabilities.py | 68 ++ synapse/rest/client/devices.py | 300 +++++ synapse/rest/client/directory.py | 185 ++++ synapse/rest/client/events.py | 94 ++ synapse/rest/client/filter.py | 94 ++ synapse/rest/client/groups.py | 957 ++++++++++++++++ synapse/rest/client/initial_sync.py | 47 + synapse/rest/client/keys.py | 344 ++++++ synapse/rest/client/knock.py | 107 ++ synapse/rest/client/login.py | 600 ++++++++++ synapse/rest/client/logout.py | 72 ++ synapse/rest/client/notifications.py | 91 ++ synapse/rest/client/openid.py | 94 ++ synapse/rest/client/password_policy.py | 57 + synapse/rest/client/presence.py | 95 ++ synapse/rest/client/profile.py | 155 +++ synapse/rest/client/push_rule.py | 354 ++++++ synapse/rest/client/pusher.py | 171 +++ synapse/rest/client/read_marker.py | 74 ++ synapse/rest/client/receipts.py | 71 ++ synapse/rest/client/register.py | 879 +++++++++++++++ synapse/rest/client/relations.py | 381 +++++++ synapse/rest/client/report_event.py | 68 ++ synapse/rest/client/room.py | 1152 ++++++++++++++++++++ synapse/rest/client/room_batch.py | 441 ++++++++ synapse/rest/client/room_keys.py | 391 +++++++ synapse/rest/client/room_upgrade_rest_servlet.py | 88 ++ synapse/rest/client/sendtodevice.py | 67 ++ synapse/rest/client/shared_rooms.py | 67 ++ synapse/rest/client/sync.py | 532 +++++++++ synapse/rest/client/tags.py | 85 ++ synapse/rest/client/thirdparty.py | 111 ++ synapse/rest/client/tokenrefresh.py | 37 + synapse/rest/client/user_directory.py | 79 ++ synapse/rest/client/v1/__init__.py | 13 - synapse/rest/client/v1/directory.py | 185 ---- synapse/rest/client/v1/events.py | 94 -- synapse/rest/client/v1/initial_sync.py | 47 - synapse/rest/client/v1/login.py | 600 ---------- synapse/rest/client/v1/logout.py | 72 -- synapse/rest/client/v1/presence.py | 95 -- synapse/rest/client/v1/profile.py | 155 --- synapse/rest/client/v1/push_rule.py | 354 ------ synapse/rest/client/v1/pusher.py | 171 --- synapse/rest/client/v1/room.py | 1152 -------------------- synapse/rest/client/v1/voip.py | 73 -- synapse/rest/client/v2_alpha/__init__.py | 13 - synapse/rest/client/v2_alpha/_base.py | 100 -- synapse/rest/client/v2_alpha/account.py | 910 ---------------- synapse/rest/client/v2_alpha/account_data.py | 122 --- synapse/rest/client/v2_alpha/account_validity.py | 104 -- synapse/rest/client/v2_alpha/auth.py | 143 --- synapse/rest/client/v2_alpha/capabilities.py | 68 -- synapse/rest/client/v2_alpha/devices.py | 300 ----- synapse/rest/client/v2_alpha/filter.py | 94 -- synapse/rest/client/v2_alpha/groups.py | 957 ---------------- synapse/rest/client/v2_alpha/keys.py | 344 ------ synapse/rest/client/v2_alpha/knock.py | 107 -- synapse/rest/client/v2_alpha/notifications.py | 91 -- synapse/rest/client/v2_alpha/openid.py | 94 -- synapse/rest/client/v2_alpha/password_policy.py | 57 - synapse/rest/client/v2_alpha/read_marker.py | 74 -- synapse/rest/client/v2_alpha/receipts.py | 71 -- synapse/rest/client/v2_alpha/register.py | 879 --------------- synapse/rest/client/v2_alpha/relations.py | 381 ------- synapse/rest/client/v2_alpha/report_event.py | 68 -- synapse/rest/client/v2_alpha/room.py | 441 -------- synapse/rest/client/v2_alpha/room_keys.py | 391 ------- .../client/v2_alpha/room_upgrade_rest_servlet.py | 88 -- synapse/rest/client/v2_alpha/sendtodevice.py | 67 -- synapse/rest/client/v2_alpha/shared_rooms.py | 67 -- synapse/rest/client/v2_alpha/sync.py | 532 --------- synapse/rest/client/v2_alpha/tags.py | 85 -- synapse/rest/client/v2_alpha/thirdparty.py | 111 -- synapse/rest/client/v2_alpha/tokenrefresh.py | 37 - synapse/rest/client/v2_alpha/user_directory.py | 79 -- synapse/rest/client/voip.py | 73 ++ tests/app/test_phone_stats_home.py | 2 +- tests/events/test_presence_router.py | 2 +- tests/events/test_snapshot.py | 2 +- tests/federation/test_complexity.py | 2 +- tests/federation/test_federation_catch_up.py | 2 +- tests/federation/test_federation_sender.py | 2 +- tests/federation/test_federation_server.py | 2 +- tests/federation/transport/test_knocking.py | 2 +- tests/handlers/test_admin.py | 4 +- tests/handlers/test_directory.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_message.py | 2 +- tests/handlers/test_password_providers.py | 3 +- tests/handlers/test_presence.py | 2 +- tests/handlers/test_room_summary.py | 2 +- tests/handlers/test_stats.py | 2 +- tests/handlers/test_user_directory.py | 3 +- tests/module_api/test_api.py | 2 +- tests/push/test_email.py | 2 +- tests/push/test_http.py | 3 +- tests/replication/tcp/streams/test_events.py | 2 +- tests/replication/test_auth.py | 2 +- tests/replication/test_client_reader_shard.py | 2 +- tests/replication/test_federation_sender_shard.py | 2 +- tests/replication/test_multi_media_repo.py | 2 +- tests/replication/test_pusher_shard.py | 2 +- tests/replication/test_sharded_event_persister.py | 3 +- tests/rest/admin/test_admin.py | 3 +- tests/rest/admin/test_device.py | 2 +- tests/rest/admin/test_event_reports.py | 3 +- tests/rest/admin/test_media.py | 2 +- tests/rest/admin/test_room.py | 2 +- tests/rest/admin/test_statistics.py | 2 +- tests/rest/admin/test_user.py | 3 +- tests/rest/admin/test_username_available.py | 2 +- tests/rest/client/test_consent.py | 2 +- tests/rest/client/test_ephemeral_message.py | 2 +- tests/rest/client/test_identity.py | 2 +- tests/rest/client/test_power_levels.py | 3 +- tests/rest/client/test_redactions.py | 3 +- tests/rest/client/test_retention.py | 2 +- tests/rest/client/test_shadow_banned.py | 9 +- tests/rest/client/test_third_party_rules.py | 2 +- tests/rest/client/v1/test_directory.py | 2 +- tests/rest/client/v1/test_events.py | 2 +- tests/rest/client/v1/test_login.py | 5 +- tests/rest/client/v1/test_presence.py | 2 +- tests/rest/client/v1/test_profile.py | 2 +- tests/rest/client/v1/test_push_rule_attrs.py | 2 +- tests/rest/client/v1/test_rooms.py | 3 +- tests/rest/client/v1/test_typing.py | 2 +- tests/rest/client/v2_alpha/test_account.py | 3 +- tests/rest/client/v2_alpha/test_auth.py | 3 +- tests/rest/client/v2_alpha/test_capabilities.py | 3 +- tests/rest/client/v2_alpha/test_filter.py | 2 +- tests/rest/client/v2_alpha/test_password_policy.py | 3 +- tests/rest/client/v2_alpha/test_register.py | 3 +- tests/rest/client/v2_alpha/test_relations.py | 3 +- tests/rest/client/v2_alpha/test_report_event.py | 3 +- tests/rest/client/v2_alpha/test_sendtodevice.py | 3 +- tests/rest/client/v2_alpha/test_shared_rooms.py | 3 +- tests/rest/client/v2_alpha/test_sync.py | 3 +- tests/rest/client/v2_alpha/test_upgrade_room.py | 3 +- tests/rest/media/v1/test_media_storage.py | 2 +- tests/server_notices/test_consent.py | 3 +- .../test_resource_limits_server_notices.py | 3 +- tests/storage/databases/main/test_events_worker.py | 2 +- tests/storage/test_cleanup_extrems.py | 2 +- tests/storage/test_client_ips.py | 2 +- tests/storage/test_event_chain.py | 2 +- tests/storage/test_events.py | 2 +- tests/storage/test_purge.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/test_mau.py | 2 +- tests/test_terms_auth.py | 2 +- 163 files changed, 9984 insertions(+), 10035 deletions(-) create mode 100644 changelog.d/10600.misc create mode 100644 synapse/rest/client/_base.py create mode 100644 synapse/rest/client/account.py create mode 100644 synapse/rest/client/account_data.py create mode 100644 synapse/rest/client/account_validity.py create mode 100644 synapse/rest/client/auth.py create mode 100644 synapse/rest/client/capabilities.py create mode 100644 synapse/rest/client/devices.py create mode 100644 synapse/rest/client/directory.py create mode 100644 synapse/rest/client/events.py create mode 100644 synapse/rest/client/filter.py create mode 100644 synapse/rest/client/groups.py create mode 100644 synapse/rest/client/initial_sync.py create mode 100644 synapse/rest/client/keys.py create mode 100644 synapse/rest/client/knock.py create mode 100644 synapse/rest/client/login.py create mode 100644 synapse/rest/client/logout.py create mode 100644 synapse/rest/client/notifications.py create mode 100644 synapse/rest/client/openid.py create mode 100644 synapse/rest/client/password_policy.py create mode 100644 synapse/rest/client/presence.py create mode 100644 synapse/rest/client/profile.py create mode 100644 synapse/rest/client/push_rule.py create mode 100644 synapse/rest/client/pusher.py create mode 100644 synapse/rest/client/read_marker.py create mode 100644 synapse/rest/client/receipts.py create mode 100644 synapse/rest/client/register.py create mode 100644 synapse/rest/client/relations.py create mode 100644 synapse/rest/client/report_event.py create mode 100644 synapse/rest/client/room.py create mode 100644 synapse/rest/client/room_batch.py create mode 100644 synapse/rest/client/room_keys.py create mode 100644 synapse/rest/client/room_upgrade_rest_servlet.py create mode 100644 synapse/rest/client/sendtodevice.py create mode 100644 synapse/rest/client/shared_rooms.py create mode 100644 synapse/rest/client/sync.py create mode 100644 synapse/rest/client/tags.py create mode 100644 synapse/rest/client/thirdparty.py create mode 100644 synapse/rest/client/tokenrefresh.py create mode 100644 synapse/rest/client/user_directory.py delete mode 100644 synapse/rest/client/v1/__init__.py delete mode 100644 synapse/rest/client/v1/directory.py delete mode 100644 synapse/rest/client/v1/events.py delete mode 100644 synapse/rest/client/v1/initial_sync.py delete mode 100644 synapse/rest/client/v1/login.py delete mode 100644 synapse/rest/client/v1/logout.py delete mode 100644 synapse/rest/client/v1/presence.py delete mode 100644 synapse/rest/client/v1/profile.py delete mode 100644 synapse/rest/client/v1/push_rule.py delete mode 100644 synapse/rest/client/v1/pusher.py delete mode 100644 synapse/rest/client/v1/room.py delete mode 100644 synapse/rest/client/v1/voip.py delete mode 100644 synapse/rest/client/v2_alpha/__init__.py delete mode 100644 synapse/rest/client/v2_alpha/_base.py delete mode 100644 synapse/rest/client/v2_alpha/account.py delete mode 100644 synapse/rest/client/v2_alpha/account_data.py delete mode 100644 synapse/rest/client/v2_alpha/account_validity.py delete mode 100644 synapse/rest/client/v2_alpha/auth.py delete mode 100644 synapse/rest/client/v2_alpha/capabilities.py delete mode 100644 synapse/rest/client/v2_alpha/devices.py delete mode 100644 synapse/rest/client/v2_alpha/filter.py delete mode 100644 synapse/rest/client/v2_alpha/groups.py delete mode 100644 synapse/rest/client/v2_alpha/keys.py delete mode 100644 synapse/rest/client/v2_alpha/knock.py delete mode 100644 synapse/rest/client/v2_alpha/notifications.py delete mode 100644 synapse/rest/client/v2_alpha/openid.py delete mode 100644 synapse/rest/client/v2_alpha/password_policy.py delete mode 100644 synapse/rest/client/v2_alpha/read_marker.py delete mode 100644 synapse/rest/client/v2_alpha/receipts.py delete mode 100644 synapse/rest/client/v2_alpha/register.py delete mode 100644 synapse/rest/client/v2_alpha/relations.py delete mode 100644 synapse/rest/client/v2_alpha/report_event.py delete mode 100644 synapse/rest/client/v2_alpha/room.py delete mode 100644 synapse/rest/client/v2_alpha/room_keys.py delete mode 100644 synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py delete mode 100644 synapse/rest/client/v2_alpha/sendtodevice.py delete mode 100644 synapse/rest/client/v2_alpha/shared_rooms.py delete mode 100644 synapse/rest/client/v2_alpha/sync.py delete mode 100644 synapse/rest/client/v2_alpha/tags.py delete mode 100644 synapse/rest/client/v2_alpha/thirdparty.py delete mode 100644 synapse/rest/client/v2_alpha/tokenrefresh.py delete mode 100644 synapse/rest/client/v2_alpha/user_directory.py create mode 100644 synapse/rest/client/voip.py (limited to 'synapse') diff --git a/changelog.d/10600.misc b/changelog.d/10600.misc new file mode 100644 index 0000000000..489dc20b11 --- /dev/null +++ b/changelog.d/10600.misc @@ -0,0 +1 @@ +Flatten the `synapse.rest.client` package by moving the contents of `v1` and `v2_alpha` into the parent. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 3b7131af8f..d7b425a7ab 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -66,40 +66,40 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, login, presence, room -from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.profile import ( - ProfileAvatarURLRestServlet, - ProfileDisplaynameRestServlet, - ProfileRestServlet, -) -from synapse.rest.client.v1.push_rule import PushRuleRestServlet -from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import ( +from synapse.rest.client import ( account_data, + events, groups, + login, + presence, read_marker, receipts, + room, room_keys, sync, tags, user_directory, ) -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.rest.client.v2_alpha.account import ThreepidRestServlet -from synapse.rest.client.v2_alpha.account_data import ( - AccountDataServlet, - RoomAccountDataServlet, -) -from synapse.rest.client.v2_alpha.devices import DevicesRestServlet -from synapse.rest.client.v2_alpha.keys import ( +from synapse.rest.client._base import client_patterns +from synapse.rest.client.account import ThreepidRestServlet +from synapse.rest.client.account_data import AccountDataServlet, RoomAccountDataServlet +from synapse.rest.client.devices import DevicesRestServlet +from synapse.rest.client.initial_sync import InitialSyncRestServlet +from synapse.rest.client.keys import ( KeyChangesServlet, KeyQueryServlet, OneTimeKeyServlet, ) -from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet +from synapse.rest.client.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) +from synapse.rest.client.push_rule import PushRuleRestServlet +from synapse.rest.client.register import RegisterRestServlet +from synapse.rest.client.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet +from synapse.rest.client.voip import VoipRestServlet from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.synapse.client import build_synapse_client_resource_tree diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 22a8552241..161b3c933c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -73,7 +73,7 @@ from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: - from synapse.rest.client.v1.login import LoginResponse + from synapse.rest.client.login import LoginResponse from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -461,7 +461,7 @@ class AuthHandler(BaseHandler): If no auth flows have been completed successfully, raises an InteractiveAuthIncompleteError. To handle this, you can use - synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + synapse.rest.client._base.interactive_auth_handler as a decorator. Args: @@ -543,7 +543,7 @@ class AuthHandler(BaseHandler): # Note that the registration endpoint explicitly removes the # "initial_device_display_name" parameter if it is provided # without a "password" parameter. See the changes to - # synapse.rest.client.v2_alpha.register.RegisterRestServlet.on_POST + # synapse.rest.client.register.RegisterRestServlet.on_POST # in commit 544722bad23fc31056b9240189c3cbbbf0ffd3f9. if not clientdict: clientdict = session.clientdict diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 9cffe59ce5..3adc576124 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -14,40 +14,36 @@ # limitations under the License. from synapse.http.server import JsonResource from synapse.rest import admin -from synapse.rest.client import versions -from synapse.rest.client.v1 import ( - directory, - events, - initial_sync, - login as v1_login, - logout, - presence, - profile, - push_rule, - pusher, - room, - voip, -) -from synapse.rest.client.v2_alpha import ( +from synapse.rest.client import ( account, account_data, account_validity, auth, capabilities, devices, + directory, + events, filter, groups, + initial_sync, keys, knock, + login as v1_login, + logout, notifications, openid, password_policy, + presence, + profile, + push_rule, + pusher, read_marker, receipts, register, relations, report_event, - room as roomv2, + room, + room_batch, room_keys, room_upgrade_rest_servlet, sendtodevice, @@ -57,6 +53,8 @@ from synapse.rest.client.v2_alpha import ( thirdparty, tokenrefresh, user_directory, + versions, + voip, ) @@ -85,7 +83,6 @@ class ClientRestResource(JsonResource): # Partially deprecated in r0 events.register_servlets(hs, client_resource) - # "v1" + "r0" room.register_servlets(hs, client_resource) v1_login.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource) @@ -95,8 +92,6 @@ class ClientRestResource(JsonResource): pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) logout.register_servlets(hs, client_resource) - - # "v2" sync.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource) @@ -118,7 +113,7 @@ class ClientRestResource(JsonResource): user_directory.register_servlets(hs, client_resource) groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) - roomv2.register_servlets(hs, client_resource) + room_batch.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c885fd77ab..93193b0864 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -34,7 +34,7 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, assert_user_is_admin, ) -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.media_repository import MediaSortOrder from synapse.storage.databases.main.stats import UserSortOrder from synapse.types import JsonDict, UserID @@ -504,7 +504,7 @@ class UserRegisterServlet(RestServlet): raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication - from synapse.rest.client.v2_alpha.register import RegisterRestServlet + from synapse.rest.client.register import RegisterRestServlet register = RegisterRestServlet(self.hs) diff --git a/synapse/rest/client/__init__.py b/synapse/rest/client/__init__.py index 629e2df74a..f9830cc51f 100644 --- a/synapse/rest/client/__init__.py +++ b/synapse/rest/client/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2014-2016 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py new file mode 100644 index 0000000000..0443f4571c --- /dev/null +++ b/synapse/rest/client/_base.py @@ -0,0 +1,100 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" +import logging +import re +from typing import Iterable, Pattern + +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.urls import CLIENT_API_PREFIX +from synapse.types import JsonDict + +logger = logging.getLogger(__name__) + + +def client_patterns( + path_regex: str, + releases: Iterable[int] = (0,), + unstable: bool = True, + v1: bool = False, +) -> Iterable[Pattern]: + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex: The regex string to match. This should NOT have a ^ + as this will be prefixed. + releases: An iterable of releases to include this endpoint under. + unstable: If true, include this endpoint under the "unstable" prefix. + v1: If true, include this endpoint under the "api/v1" prefix. + Returns: + An iterable of patterns. + """ + patterns = [] + + if unstable: + unstable_prefix = CLIENT_API_PREFIX + "/unstable" + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + if v1: + v1_prefix = CLIENT_API_PREFIX + "/api/v1" + patterns.append(re.compile("^" + v1_prefix + path_regex)) + for release in releases: + new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) + patterns.append(re.compile("^" + new_prefix + path_regex)) + + return patterns + + +def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: + """ + Enforces a maximum limit of a timeline query. + + Params: + filter_json: The timeline query to modify. + filter_timeline_limit: The maximum limit to allow, passing -1 will + disable enforcing a maximum limit. + """ + if filter_timeline_limit < 0: + return # no upper limits + timeline = filter_json.get("room", {}).get("timeline", {}) + if "limit" in timeline: + filter_json["room"]["timeline"]["limit"] = min( + filter_json["room"]["timeline"]["limit"], filter_timeline_limit + ) + + +def interactive_auth_handler(orig): + """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors + + Takes a on_POST method which returns an Awaitable (errcode, body) response + and adds exception handling to turn a InteractiveAuthIncompleteError into + a 401 response. + + Normal usage is: + + @interactive_auth_handler + async def on_POST(self, request): + # ... + await self.auth_handler.check_auth + """ + + async def wrapped(*args, **kwargs): + try: + return await orig(*args, **kwargs) + except InteractiveAuthIncompleteError as e: + return 401, e.result + + return wrapped diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py new file mode 100644 index 0000000000..fb5ad2906e --- /dev/null +++ b/synapse/rest/client/account.py @@ -0,0 +1,910 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random +from http import HTTPStatus +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from synapse.api.constants import LoginType +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, +) +from synapse.config.emailconfig import ThreepidBehaviour +from synapse.handlers.ui_auth import UIAuthSessionDataConstants +from synapse.http.server import finish_request, respond_with_html +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, +) +from synapse.metrics import threepid_send_requests +from synapse.push.mailer import Mailer +from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.stringutils import assert_valid_client_secret, random_string +from synapse.util.threepids import check_3pid_allowed, validate_email + +from ._base import client_patterns, interactive_auth_handler + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class EmailPasswordRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/account/password/email/requestToken$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.datastore = hs.get_datastore() + self.config = hs.config + self.identity_handler = hs.get_identity_handler() + + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=self.config.email_password_reset_template_html, + template_text=self.config.email_password_reset_template_text, + ) + + async def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "User password resets have been disabled due to lack of email config" + ) + raise SynapseError( + 400, "Email-based password resets have been disabled on this server" + ) + + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) + + # Extract params from body + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. This allows the user to reset his password without having to + # know the exact spelling (eg. upper and lower case) of address in the database. + # Stored in the database "foo@bar.com" + # User requests with "FOO@bar.com" would raise a Not Found error + try: + email = validate_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) + + # The email will be sent to the stored address. + # This avoids a potential account hijack by requesting a password reset to + # an email address which is controlled by the attacker but which, after + # canonicalisation, matches the one in our database. + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + "email", email + ) + + if existing_user_id is None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + + if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( + self.hs.config.account_threepid_delegate_email, + email, + client_secret, + send_attempt, + next_link, + ) + else: + # Send password reset emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_password_reset_mail, + next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + + threepid_send_requests.labels(type="email", reason="password_reset").observe( + send_attempt + ) + + return 200, ret + + +class PasswordRestServlet(RestServlet): + PATTERNS = client_patterns("/account/password$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self.datastore = self.hs.get_datastore() + self.password_policy_handler = hs.get_password_policy_handler() + self._set_password_handler = hs.get_set_password_handler() + + @interactive_auth_handler + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + # we do basic sanity checks here because the auth layer will store these + # in sessions. Pull out the new password provided to us. + new_password = body.pop("new_password", None) + if new_password is not None: + if not isinstance(new_password, str) or len(new_password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(new_password) + + # there are two possibilities here. Either the user does not have an + # access token, and needs to do a password reset; or they have one and + # need to validate their identity. + # + # In the first case, we offer a couple of means of identifying + # themselves (email and msisdn, though it's unclear if msisdn actually + # works). + # + # In the second case, we require a password to confirm their identity. + + if self.auth.has_access_token(request): + requester = await self.auth.get_user_by_req(request) + try: + params, session_id = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, + ) + raise + user_id = requester.user.to_string() + else: + requester = None + try: + result, params, session_id = await self.auth_handler.check_ui_auth( + [[LoginType.EMAIL_IDENTITY]], + request, + body, + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, + ) + raise + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if "medium" not in threepid or "address" not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid["medium"] == "email": + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) + try: + threepid["address"] = validate_email(threepid["address"]) + except ValueError as e: + raise SynapseError(400, str(e)) + # if using email, we must know about the email they're authing with! + threepid_user_id = await self.datastore.get_user_id_by_threepid( + threepid["medium"], threepid["address"] + ) + if not threepid_user_id: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type! %r", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + # If we have a password in this request, prefer it. Otherwise, use the + # password hash from an earlier request. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + elif session_id is not None: + password_hash = await self.auth_handler.get_session_data( + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None + ) + else: + # UI validation was skipped, but the request did not include a new + # password. + password_hash = None + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + + logout_devices = params.get("logout_devices", True) + + await self._set_password_handler.set_password( + user_id, password_hash, logout_devices, requester + ) + + return 200, {} + + +class DeactivateAccountRestServlet(RestServlet): + PATTERNS = client_patterns("/account/deactivate$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self._deactivate_account_handler = hs.get_deactivate_account_handler() + + @interactive_auth_handler + async def on_POST(self, request): + body = parse_json_object_from_request(request) + erase = body.get("erase", False) + if not isinstance(erase, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'erase' must be a boolean, if given", + Codes.BAD_JSON, + ) + + requester = await self.auth.get_user_by_req(request) + + # allow ASes to deactivate their own users + if requester.app_service: + await self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), erase, requester + ) + return 200, {} + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "deactivate your account", + ) + result = await self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), + erase, + requester, + id_server=body.get("id_server"), + ) + if result: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + return 200, {"id_server_unbind_result": id_server_unbind_result} + + +class EmailThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/email/requestToken$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.config = hs.config + self.identity_handler = hs.get_identity_handler() + self.store = self.hs.get_datastore() + + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=self.config.email_add_threepid_template_html, + template_text=self.config.email_add_threepid_template_text, + ) + + async def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. + # This ensures that the validation email is sent to the canonicalised address + # as it will later be entered into the database. + # Otherwise the email will be sent to "FOO@bar.com" and stored as + # "foo@bar.com" in database. + try: + email = validate_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + if not check_3pid_allowed(self.hs, "email", email): + raise SynapseError( + 403, + "Your email domain is not authorized on this server", + Codes.THREEPID_DENIED, + ) + + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) + + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + + existing_user_id = await self.store.get_user_id_by_threepid("email", email) + + if existing_user_id is not None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( + self.hs.config.account_threepid_delegate_email, + email, + client_secret, + send_attempt, + next_link, + ) + else: + # Send threepid validation emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_add_threepid_mail, + next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + + threepid_send_requests.labels(type="email", reason="add_threepid").observe( + send_attempt + ) + + return 200, ret + + +class MsisdnThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + super().__init__() + self.store = self.hs.get_datastore() + self.identity_handler = hs.get_identity_handler() + + async def on_POST(self, request): + body = parse_json_object_from_request(request) + assert_params_in_dict( + body, ["client_secret", "country", "phone_number", "send_attempt"] + ) + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + country = body["country"] + phone_number = body["phone_number"] + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + msisdn = phone_number_to_msisdn(country, phone_number) + + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, + "Account phone numbers are not authorized on this server", + Codes.THREEPID_DENIED, + ) + + await self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) + + if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) + + if not self.hs.config.account_threepid_delegate_msisdn: + logger.warning( + "No upstream msisdn account_threepid_delegate configured on the server to " + "handle this request" + ) + raise SynapseError( + 400, + "Adding phone numbers to user account is not supported by this homeserver", + ) + + ret = await self.identity_handler.requestMsisdnToken( + self.hs.config.account_threepid_delegate_msisdn, + country, + phone_number, + client_secret, + send_attempt, + next_link, + ) + + threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe( + send_attempt + ) + + return 200, ret + + +class AddThreepidEmailSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding an email to a user's account""" + + PATTERNS = client_patterns( + "/add_threepid/email/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self._failure_email_template = ( + self.config.email_add_threepid_template_failure_html + ) + + async def on_GET(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + raise SynapseError( + 400, + "This homeserver is not validating threepids. Use an identity server " + "instead.", + ) + + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None + + # Otherwise show the success template + html = self.config.email_add_threepid_template_success_html_content + status_code = 200 + except ThreepidValidationError as e: + status_code = e.code + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html = self._failure_email_template.render(**template_vars) + + respond_with_html(request, status_code, html) + + +class AddThreepidMsisdnSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding a phone number to a user's + account + """ + + PATTERNS = client_patterns( + "/add_threepid/msisdn/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.identity_handler = hs.get_identity_handler() + + async def on_POST(self, request): + if not self.config.account_threepid_delegate_msisdn: + raise SynapseError( + 400, + "This homeserver is not validating phone numbers. Use an identity server " + "instead.", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["client_secret", "sid", "token"]) + assert_valid_client_secret(body["client_secret"]) + + # Proxy submit_token request to msisdn threepid delegate + response = await self.identity_handler.proxy_msisdn_submit_token( + self.config.account_threepid_delegate_msisdn, + body["client_secret"], + body["sid"], + body["token"], + ) + return 200, response + + +class ThreepidRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self.datastore = self.hs.get_datastore() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + + threepids = await self.datastore.user_get_threepids(requester.user.to_string()) + + return 200, {"threepids": threepids} + + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") + if threepid_creds is None: + raise SynapseError( + 400, "Missing param three_pid_creds", Codes.MISSING_PARAM + ) + assert_params_in_dict(threepid_creds, ["client_secret", "sid"]) + + sid = threepid_creds["sid"] + client_secret = threepid_creds["client_secret"] + assert_valid_client_secret(client_secret) + + validation_session = await self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + await self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} + + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED + ) + + +class ThreepidAddRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/add$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["client_secret", "sid"]) + sid = body["sid"] + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "add a third-party identifier to your account", + ) + + validation_session = await self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + await self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} + + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED + ) + + +class ThreepidBindRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/bind$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.auth = hs.get_auth() + + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) + id_server = body["id_server"] + sid = body["sid"] + id_access_token = body.get("id_access_token") # optional + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + await self.identity_handler.bind_threepid( + client_secret, sid, user_id, id_server, id_access_token + ) + + return 200, {} + + +class ThreepidUnbindRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/unbind$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.auth = hs.get_auth() + self.datastore = self.hs.get_datastore() + + async def on_POST(self, request): + """Unbind the given 3pid from a specific identity server, or identity servers that are + known to have this 3pid bound + """ + requester = await self.auth.get_user_by_req(request) + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["medium", "address"]) + + medium = body.get("medium") + address = body.get("address") + id_server = body.get("id_server") + + # Attempt to unbind the threepid from an identity server. If id_server is None, try to + # unbind from all identity servers this threepid has been added to in the past + result = await self.identity_handler.try_unbind_threepid( + requester.user.to_string(), + {"address": address, "medium": medium, "id_server": id_server}, + ) + return 200, {"id_server_unbind_result": "success" if result else "no-support"} + + +class ThreepidDeleteRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/delete$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["medium", "address"]) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + try: + ret = await self.auth_handler.delete_threepid( + user_id, body["medium"], body["address"], body.get("id_server") + ) + except Exception: + # NB. This endpoint should succeed if there is nothing to + # delete, so it should only throw if something is wrong + # that we ought to care about. + logger.exception("Failed to remove threepid") + raise SynapseError(500, "Failed to remove threepid") + + if ret: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + return 200, {"id_server_unbind_result": id_server_unbind_result} + + +def assert_valid_next_link(hs: "HomeServer", next_link: str): + """ + Raises a SynapseError if a given next_link value is invalid + + next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config + option is either empty or contains a domain that matches the one in the given next_link + + Args: + hs: The homeserver object + next_link: The next_link value given by the client + + Raises: + SynapseError: If the next_link is invalid + """ + valid = True + + # Parse the contents of the URL + next_link_parsed = urlparse(next_link) + + # Scheme must not point to the local drive + if next_link_parsed.scheme == "file": + valid = False + + # If the domain whitelist is set, the domain must be in it + if ( + valid + and hs.config.next_link_domain_whitelist is not None + and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist + ): + valid = False + + if not valid: + raise SynapseError( + 400, + "'next_link' domain not included in whitelist, or not http(s)", + errcode=Codes.INVALID_PARAM, + ) + + +class WhoamiRestServlet(RestServlet): + PATTERNS = client_patterns("/account/whoami$") + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + + response = {"user_id": requester.user.to_string()} + + # Appservices and similar accounts do not have device IDs + # that we can report on, so exclude them for compliance. + if requester.device_id is not None: + response["device_id"] = requester.device_id + + return 200, response + + +def register_servlets(hs, http_server): + EmailPasswordRequestTokenRestServlet(hs).register(http_server) + PasswordRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + EmailThreepidRequestTokenRestServlet(hs).register(http_server) + MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) + AddThreepidEmailSubmitTokenServlet(hs).register(http_server) + AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) + ThreepidRestServlet(hs).register(http_server) + ThreepidAddRestServlet(hs).register(http_server) + ThreepidBindRestServlet(hs).register(http_server) + ThreepidUnbindRestServlet(hs).register(http_server) + ThreepidDeleteRestServlet(hs).register(http_server) + WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py new file mode 100644 index 0000000000..7517e9304e --- /dev/null +++ b/synapse/rest/client/account_data.py @@ -0,0 +1,122 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class AccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 + GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/user/(?P[^/]*)/account_data/(?P[^/]*)" + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.handler = hs.get_account_data_handler() + + async def on_PUT(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + body = parse_json_object_from_request(request) + + await self.handler.add_account_data_for_user(user_id, account_data_type, body) + + return 200, {} + + async def on_GET(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot get account data for other users.") + + event = await self.store.get_global_account_data_by_type_for_user( + account_data_type, user_id + ) + + if event is None: + raise NotFoundError("Account data not found") + + return 200, event + + +class RoomAccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 + GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/user/(?P[^/]*)" + "/rooms/(?P[^/]*)" + "/account_data/(?P[^/]*)" + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.handler = hs.get_account_data_handler() + + async def on_PUT(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + body = parse_json_object_from_request(request) + + if account_data_type == "m.fully_read": + raise SynapseError( + 405, + "Cannot set m.fully_read through this API." + " Use /rooms/!roomId:server.name/read_markers", + ) + + await self.handler.add_account_data_to_room( + user_id, room_id, account_data_type, body + ) + + return 200, {} + + async def on_GET(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot get account data for other users.") + + event = await self.store.get_account_data_for_room_and_type( + user_id, room_id, account_data_type + ) + + if event is None: + raise NotFoundError("Room account data not found") + + return 200, event + + +def register_servlets(hs, http_server): + AccountDataServlet(hs).register(http_server) + RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py new file mode 100644 index 0000000000..3ebe401861 --- /dev/null +++ b/synapse/rest/client/account_validity.py @@ -0,0 +1,104 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.server import respond_with_html +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class AccountValidityRenewServlet(RestServlet): + PATTERNS = client_patterns("/account_validity/renew$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + + self.hs = hs + self.account_activity_handler = hs.get_account_validity_handler() + self.auth = hs.get_auth() + self.account_renewed_template = ( + hs.config.account_validity.account_validity_account_renewed_template + ) + self.account_previously_renewed_template = ( + hs.config.account_validity.account_validity_account_previously_renewed_template + ) + self.invalid_token_template = ( + hs.config.account_validity.account_validity_invalid_token_template + ) + + async def on_GET(self, request): + if b"token" not in request.args: + raise SynapseError(400, "Missing renewal token") + renewal_token = request.args[b"token"][0] + + ( + token_valid, + token_stale, + expiration_ts, + ) = await self.account_activity_handler.renew_account( + renewal_token.decode("utf8") + ) + + if token_valid: + status_code = 200 + response = self.account_renewed_template.render(expiration_ts=expiration_ts) + elif token_stale: + status_code = 200 + response = self.account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) + else: + status_code = 404 + response = self.invalid_token_template.render(expiration_ts=expiration_ts) + + respond_with_html(request, status_code, response) + + +class AccountValiditySendMailServlet(RestServlet): + PATTERNS = client_patterns("/account_validity/send_mail$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + + self.hs = hs + self.account_activity_handler = hs.get_account_validity_handler() + self.auth = hs.get_auth() + self.account_validity_renew_by_email_enabled = ( + hs.config.account_validity.account_validity_renew_by_email_enabled + ) + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_expired=True) + user_id = requester.user.to_string() + await self.account_activity_handler.send_renewal_email_to_user(user_id) + + return 200, {} + + +def register_servlets(hs, http_server): + AccountValidityRenewServlet(hs).register(http_server) + AccountValiditySendMailServlet(hs).register(http_server) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py new file mode 100644 index 0000000000..6ea1b50a62 --- /dev/null +++ b/synapse/rest/client/auth.py @@ -0,0 +1,143 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING + +from synapse.api.constants import LoginType +from synapse.api.errors import SynapseError +from synapse.api.urls import CLIENT_API_PREFIX +from synapse.http.server import respond_with_html +from synapse.http.servlet import RestServlet, parse_string + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class AuthRestServlet(RestServlet): + """ + Handles Client / Server API authentication in any situations where it + cannot be handled in the normal flow (with requests to the same endpoint). + Current use is for web fallback auth. + """ + + PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self.registration_handler = hs.get_registration_handler() + self.recaptcha_template = hs.config.recaptcha_template + self.terms_template = hs.config.terms_template + self.success_template = hs.config.fallback_success_template + + async def on_GET(self, request, stagetype): + session = parse_string(request, "session") + if not session: + raise SynapseError(400, "No session supplied") + + if stagetype == LoginType.RECAPTCHA: + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + sitekey=self.hs.config.recaptcha_public_key, + ) + elif stagetype == LoginType.TERMS: + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" + % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), + myurl="%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), + ) + + elif stagetype == LoginType.SSO: + # Display a confirmation page which prompts the user to + # re-authenticate with their SSO provider. + html = await self.auth_handler.start_sso_ui_auth(request, session) + + else: + raise SynapseError(404, "Unknown auth stage type") + + # Render the HTML and return. + respond_with_html(request, 200, html) + return None + + async def on_POST(self, request, stagetype): + + session = parse_string(request, "session") + if not session: + raise SynapseError(400, "No session supplied") + + if stagetype == LoginType.RECAPTCHA: + response = parse_string(request, "g-recaptcha-response") + + if not response: + raise SynapseError(400, "No captcha response supplied") + + authdict = {"response": response, "session": session} + + success = await self.auth_handler.add_oob_auth( + LoginType.RECAPTCHA, authdict, request.getClientIP() + ) + + if success: + html = self.success_template.render() + else: + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + sitekey=self.hs.config.recaptcha_public_key, + ) + elif stagetype == LoginType.TERMS: + authdict = {"session": session} + + success = await self.auth_handler.add_oob_auth( + LoginType.TERMS, authdict, request.getClientIP() + ) + + if success: + html = self.success_template.render() + else: + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" + % ( + self.hs.config.public_baseurl, + self.hs.config.user_consent_version, + ), + myurl="%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), + ) + elif stagetype == LoginType.SSO: + # The SSO fallback workflow should not post here, + raise SynapseError(404, "Fallback SSO auth does not support POST requests.") + else: + raise SynapseError(404, "Unknown auth stage type") + + # Render the HTML and return. + respond_with_html(request, 200, html) + return None + + +def register_servlets(hs, http_server): + AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py new file mode 100644 index 0000000000..88e3aac797 --- /dev/null +++ b/synapse/rest/client/capabilities.py @@ -0,0 +1,68 @@ +# Copyright 2019 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class CapabilitiesRestServlet(RestServlet): + """End point to expose the capabilities of the server.""" + + PATTERNS = client_patterns("/capabilities$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.config = hs.config + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + change_password = self.auth_handler.can_change_password() + + response = { + "capabilities": { + "m.room_versions": { + "default": self.config.default_room_version.identifier, + "available": { + v.identifier: v.disposition + for v in KNOWN_ROOM_VERSIONS.values() + }, + }, + "m.change_password": {"enabled": change_password}, + } + } + + if self.config.experimental.msc3244_enabled: + response["capabilities"]["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ] = MSC3244_CAPABILITIES + + return 200, response + + +def register_servlets(hs: "HomeServer", http_server): + CapabilitiesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py new file mode 100644 index 0000000000..8b9674db06 --- /dev/null +++ b/synapse/rest/client/devices.py @@ -0,0 +1,300 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api import errors +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest + +from ._base import client_patterns, interactive_auth_handler + +logger = logging.getLogger(__name__) + + +class DevicesRestServlet(RestServlet): + PATTERNS = client_patterns("/devices$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + devices = await self.device_handler.get_devices_by_user( + requester.user.to_string() + ) + return 200, {"devices": devices} + + +class DeleteDevicesRestServlet(RestServlet): + """ + API for bulk deletion of devices. Accepts a JSON object with a devices + key which lists the device_ids to delete. Requires user interactive auth. + """ + + PATTERNS = client_patterns("/delete_devices") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + + try: + body = parse_json_object_from_request(request) + except errors.SynapseError as e: + if e.errcode == errors.Codes.NOT_JSON: + # DELETE + # deal with older clients which didn't pass a JSON dict + # the same as those that pass an empty dict + body = {} + else: + raise e + + assert_params_in_dict(body, ["devices"]) + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "remove device(s) from your account", + # Users might call this multiple times in a row while cleaning up + # devices, allow a single UI auth session to be re-used. + can_skip_ui_auth=True, + ) + + await self.device_handler.delete_devices( + requester.user.to_string(), body["devices"] + ) + return 200, {} + + +class DeviceRestServlet(RestServlet): + PATTERNS = client_patterns("/devices/(?P[^/]*)$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.auth_handler = hs.get_auth_handler() + + async def on_GET(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + device = await self.device_handler.get_device( + requester.user.to_string(), device_id + ) + return 200, device + + @interactive_auth_handler + async def on_DELETE(self, request, device_id): + requester = await self.auth.get_user_by_req(request) + + try: + body = parse_json_object_from_request(request) + + except errors.SynapseError as e: + if e.errcode == errors.Codes.NOT_JSON: + # deal with older clients which didn't pass a JSON dict + # the same as those that pass an empty dict + body = {} + else: + raise + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "remove a device from your account", + # Users might call this multiple times in a row while cleaning up + # devices, allow a single UI auth session to be re-used. + can_skip_ui_auth=True, + ) + + await self.device_handler.delete_device(requester.user.to_string(), device_id) + return 200, {} + + async def on_PUT(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + body = parse_json_object_from_request(request) + await self.device_handler.update_device( + requester.user.to_string(), device_id, body + ) + return 200, {} + + +class DehydratedDeviceServlet(RestServlet): + """Retrieve or store a dehydrated device. + + GET /org.matrix.msc2697.v2/dehydrated_device + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + PUT /org.matrix.msc2697/dehydrated_device + Content-Type: application/json + + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + """ + + PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_GET(self, request: SynapseRequest): + requester = await self.auth.get_user_by_req(request) + dehydrated_device = await self.device_handler.get_dehydrated_device( + requester.user.to_string() + ) + if dehydrated_device is not None: + (device_id, device_data) = dehydrated_device + result = {"device_id": device_id, "device_data": device_data} + return (200, result) + else: + raise errors.NotFoundError("No dehydrated device available") + + async def on_PUT(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + requester = await self.auth.get_user_by_req(request) + + if "device_data" not in submission: + raise errors.SynapseError( + 400, + "device_data missing", + errcode=errors.Codes.MISSING_PARAM, + ) + elif not isinstance(submission["device_data"], dict): + raise errors.SynapseError( + 400, + "device_data must be an object", + errcode=errors.Codes.INVALID_PARAM, + ) + + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), + submission["device_data"], + submission.get("initial_device_display_name", None), + ) + return 200, {"device_id": device_id} + + +class ClaimDehydratedDeviceServlet(RestServlet): + """Claim a dehydrated device. + + POST /org.matrix.msc2697.v2/dehydrated_device/claim + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "success": true, + } + + """ + + PATTERNS = client_patterns( + "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_POST(self, request: SynapseRequest): + requester = await self.auth.get_user_by_req(request) + + submission = parse_json_object_from_request(request) + + if "device_id" not in submission: + raise errors.SynapseError( + 400, + "device_id missing", + errcode=errors.Codes.MISSING_PARAM, + ) + elif not isinstance(submission["device_id"], str): + raise errors.SynapseError( + 400, + "device_id must be a string", + errcode=errors.Codes.INVALID_PARAM, + ) + + result = await self.device_handler.rehydrate_device( + requester.user.to_string(), + self.auth.get_access_token_from_request(request), + submission["device_id"], + ) + + return (200, result) + + +def register_servlets(hs, http_server): + DeleteDevicesRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DehydratedDeviceServlet(hs).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py new file mode 100644 index 0000000000..ffa075c8e5 --- /dev/null +++ b/synapse/rest/client/directory.py @@ -0,0 +1,185 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + NotFoundError, + SynapseError, +) +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client._base import client_patterns +from synapse.types import RoomAlias + +logger = logging.getLogger(__name__) + + +def register_servlets(hs, http_server): + ClientDirectoryServer(hs).register(http_server) + ClientDirectoryListServer(hs).register(http_server) + ClientAppserviceDirectoryListServer(hs).register(http_server) + + +class ClientDirectoryServer(RestServlet): + PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.directory_handler = hs.get_directory_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_alias): + room_alias = RoomAlias.from_string(room_alias) + + res = await self.directory_handler.get_association(room_alias) + + return 200, res + + async def on_PUT(self, request, room_alias): + room_alias = RoomAlias.from_string(room_alias) + + content = parse_json_object_from_request(request) + if "room_id" not in content: + raise SynapseError( + 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON + ) + + logger.debug("Got content: %s", content) + logger.debug("Got room name: %s", room_alias.to_string()) + + room_id = content["room_id"] + servers = content["servers"] if "servers" in content else None + + logger.debug("Got room_id: %s", room_id) + logger.debug("Got servers: %s", servers) + + # TODO(erikj): Check types. + + room = await self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Room does not exist") + + requester = await self.auth.get_user_by_req(request) + + await self.directory_handler.create_association( + requester, room_alias, room_id, servers + ) + + return 200, {} + + async def on_DELETE(self, request, room_alias): + try: + service = self.auth.get_appservice_by_req(request) + room_alias = RoomAlias.from_string(room_alias) + await self.directory_handler.delete_appservice_association( + service, room_alias + ) + logger.info( + "Application service at %s deleted alias %s", + service.url, + room_alias.to_string(), + ) + return 200, {} + except InvalidClientCredentialsError: + # fallback to default user behaviour if they aren't an AS + pass + + requester = await self.auth.get_user_by_req(request) + user = requester.user + + room_alias = RoomAlias.from_string(room_alias) + + await self.directory_handler.delete_association(requester, room_alias) + + logger.info( + "User %s deleted alias %s", user.to_string(), room_alias.to_string() + ) + + return 200, {} + + +class ClientDirectoryListServer(RestServlet): + PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.directory_handler = hs.get_directory_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id): + room = await self.store.get_room(room_id) + if room is None: + raise NotFoundError("Unknown room") + + return 200, {"visibility": "public" if room["is_public"] else "private"} + + async def on_PUT(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + visibility = content.get("visibility", "public") + + await self.directory_handler.edit_published_room_list( + requester, room_id, visibility + ) + + return 200, {} + + async def on_DELETE(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + await self.directory_handler.edit_published_room_list( + requester, room_id, "private" + ) + + return 200, {} + + +class ClientAppserviceDirectoryListServer(RestServlet): + PATTERNS = client_patterns( + "/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True + ) + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.directory_handler = hs.get_directory_handler() + self.auth = hs.get_auth() + + def on_PUT(self, request, network_id, room_id): + content = parse_json_object_from_request(request) + visibility = content.get("visibility", "public") + return self._edit(request, network_id, room_id, visibility) + + def on_DELETE(self, request, network_id, room_id): + return self._edit(request, network_id, room_id, "private") + + async def _edit(self, request, network_id, room_id, visibility): + requester = await self.auth.get_user_by_req(request) + if not requester.app_service: + raise AuthError( + 403, "Only appservices can edit the appservice published room list" + ) + + await self.directory_handler.edit_published_appservice_room_list( + requester.app_service.id, network_id, room_id, visibility + ) + + return 200, {} diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py new file mode 100644 index 0000000000..52bb579cfd --- /dev/null +++ b/synapse/rest/client/events.py @@ -0,0 +1,94 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains REST servlets to do with event streaming, /events.""" +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import RestServlet +from synapse.rest.client._base import client_patterns +from synapse.streams.config import PaginationConfig + +logger = logging.getLogger(__name__) + + +class EventStreamRestServlet(RestServlet): + PATTERNS = client_patterns("/events$", v1=True) + + DEFAULT_LONGPOLL_TIME_MS = 30000 + + def __init__(self, hs): + super().__init__() + self.event_stream_handler = hs.get_event_stream_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + is_guest = requester.is_guest + room_id = None + if is_guest: + if b"room_id" not in request.args: + raise SynapseError(400, "Guest users must specify room_id param") + if b"room_id" in request.args: + room_id = request.args[b"room_id"][0].decode("ascii") + + pagin_config = await PaginationConfig.from_request(self.store, request) + timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS + if b"timeout" in request.args: + try: + timeout = int(request.args[b"timeout"][0]) + except ValueError: + raise SynapseError(400, "timeout must be in milliseconds.") + + as_client_event = b"raw" not in request.args + + chunk = await self.event_stream_handler.get_stream( + requester.user.to_string(), + pagin_config, + timeout=timeout, + as_client_event=as_client_event, + affect_presence=(not is_guest), + room_id=room_id, + is_guest=is_guest, + ) + + return 200, chunk + + +class EventRestServlet(RestServlet): + PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) + + def __init__(self, hs): + super().__init__() + self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() + self.auth = hs.get_auth() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET(self, request, event_id): + requester = await self.auth.get_user_by_req(request) + event = await self.event_handler.get_event(requester.user, None, event_id) + + time_now = self.clock.time_msec() + if event: + event = await self._event_serializer.serialize_event(event, time_now) + return 200, event + else: + return 404, "Event not found." + + +def register_servlets(hs, http_server): + EventStreamRestServlet(hs).register(http_server) + EventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py new file mode 100644 index 0000000000..411667a9c8 --- /dev/null +++ b/synapse/rest/client/filter.py @@ -0,0 +1,94 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import UserID + +from ._base import client_patterns, set_timeline_upper_limit + +logger = logging.getLogger(__name__) + + +class GetFilterRestServlet(RestServlet): + PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.filtering = hs.get_filtering() + + async def on_GET(self, request, user_id, filter_id): + target_user = UserID.from_string(user_id) + requester = await self.auth.get_user_by_req(request) + + if target_user != requester.user: + raise AuthError(403, "Cannot get filters for other users") + + if not self.hs.is_mine(target_user): + raise AuthError(403, "Can only get filters for local users") + + try: + filter_id = int(filter_id) + except Exception: + raise SynapseError(400, "Invalid filter_id") + + try: + filter_collection = await self.filtering.get_user_filter( + user_localpart=target_user.localpart, filter_id=filter_id + ) + except StoreError as e: + if e.code != 404: + raise + raise NotFoundError("No such filter") + + return 200, filter_collection.get_filter_json() + + +class CreateFilterRestServlet(RestServlet): + PATTERNS = client_patterns("/user/(?P[^/]*)/filter") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.filtering = hs.get_filtering() + + async def on_POST(self, request, user_id): + + target_user = UserID.from_string(user_id) + requester = await self.auth.get_user_by_req(request) + + if target_user != requester.user: + raise AuthError(403, "Cannot create filters for other users") + + if not self.hs.is_mine(target_user): + raise AuthError(403, "Can only create filters for local users") + + content = parse_json_object_from_request(request) + set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) + + filter_id = await self.filtering.add_user_filter( + user_localpart=target_user.localpart, user_filter=content + ) + + return 200, {"filter_id": str(filter_id)} + + +def register_servlets(hs, http_server): + GetFilterRestServlet(hs).register(http_server) + CreateFilterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py new file mode 100644 index 0000000000..6285680c00 --- /dev/null +++ b/synapse/rest/client/groups.py @@ -0,0 +1,957 @@ +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import wraps +from typing import TYPE_CHECKING, Optional, Tuple + +from twisted.web.server import Request + +from synapse.api.constants import ( + MAX_GROUP_CATEGORYID_LENGTH, + MAX_GROUP_ROLEID_LENGTH, + MAX_GROUPID_LENGTH, +) +from synapse.api.errors import Codes, SynapseError +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest +from synapse.types import GroupID, JsonDict + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +def _validate_group_id(f): + """Wrapper to validate the form of the group ID. + + Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. + """ + + @wraps(f) + def wrapper(self, request: Request, group_id: str, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) + + return f(self, request, group_id, *args, **kwargs) + + return wrapper + + +class GroupServlet(RestServlet): + """Get the group profile""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + group_description = await self.groups_handler.get_group_profile( + group_id, requester_user_id + ) + + return 200, group_description + + @_validate_group_id + async def on_POST( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert_params_in_dict( + content, ("name", "avatar_url", "short_description", "long_description") + ) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot create group profiles." + await self.groups_handler.update_group_profile( + group_id, requester_user_id, content + ) + + return 200, {} + + +class GroupSummaryServlet(RestServlet): + """Get the full group summary""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + get_group_summary = await self.groups_handler.get_group_summary( + group_id, requester_user_id + ) + + return 200, get_group_summary + + +class GroupSummaryRoomsCatServlet(RestServlet): + """Update/delete a rooms entry in the summary. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/summary" + "(/categories/(?P[^/]+))?" + "/rooms/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, + request: SynapseRequest, + group_id: str, + category_id: Optional[str], + room_id: str, + ): + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) + + if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group summaries." + resp = await self.groups_handler.update_group_summary_room( + group_id, + requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + return 200, resp + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, category_id: str, room_id: str + ): + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group profiles." + resp = await self.groups_handler.delete_group_summary_room( + group_id, requester_user_id, room_id=room_id, category_id=category_id + ) + + return 200, resp + + +class GroupCategoryServlet(RestServlet): + """Get/add/update/delete a group category""" + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/categories/(?P[^/]+)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = await self.groups_handler.get_group_category( + group_id, requester_user_id, category_id=category_id + ) + + return 200, category + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + if not category_id: + raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) + + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." + resp = await self.groups_handler.update_group_category( + group_id, requester_user_id, category_id=category_id, content=content + ) + + return 200, resp + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." + resp = await self.groups_handler.delete_group_category( + group_id, requester_user_id, category_id=category_id + ) + + return 200, resp + + +class GroupCategoriesServlet(RestServlet): + """Get all group categories""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = await self.groups_handler.get_group_categories( + group_id, requester_user_id + ) + + return 200, category + + +class GroupRoleServlet(RestServlet): + """Get/add/update/delete a group role""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = await self.groups_handler.get_group_role( + group_id, requester_user_id, role_id=role_id + ) + + return 200, category + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + if not role_id: + raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group roles." + resp = await self.groups_handler.update_group_role( + group_id, requester_user_id, role_id=role_id, content=content + ) + + return 200, resp + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group roles." + resp = await self.groups_handler.delete_group_role( + group_id, requester_user_id, role_id=role_id + ) + + return 200, resp + + +class GroupRolesServlet(RestServlet): + """Get all group roles""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = await self.groups_handler.get_group_roles( + group_id, requester_user_id + ) + + return 200, category + + +class GroupSummaryUsersRoleServlet(RestServlet): + """Update/delete a user's entry in the summary. + + Matches both: + - /groups/:group/summary/users/:room_id + - /groups/:group/summary/roles/:role/users/:user_id + """ + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/summary" + "(/roles/(?P[^/]+))?" + "/users/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, + request: SynapseRequest, + group_id: str, + role_id: Optional[str], + user_id: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) + + if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group summaries." + resp = await self.groups_handler.update_group_summary_user( + group_id, + requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + return 200, resp + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, role_id: str, user_id: str + ): + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group summaries." + resp = await self.groups_handler.delete_group_summary_user( + group_id, requester_user_id, user_id=user_id, role_id=role_id + ) + + return 200, resp + + +class GroupRoomServlet(RestServlet): + """Get all rooms in a group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = await self.groups_handler.get_rooms_in_group( + group_id, requester_user_id + ) + + return 200, result + + +class GroupUsersServlet(RestServlet): + """Get all users in a group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = await self.groups_handler.get_users_in_group( + group_id, requester_user_id + ) + + return 200, result + + +class GroupInvitedUsersServlet(RestServlet): + """Get users invited to a group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + result = await self.groups_handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + return 200, result + + +class GroupSettingJoinPolicyServlet(RestServlet): + """Set group join policy""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group join policy." + result = await self.groups_handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + return 200, result + + +class GroupCreateServlet(RestServlet): + """Create a group""" + + PATTERNS = client_patterns("/create_group$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.server_name = hs.hostname + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + # TODO: Create group on remote server + content = parse_json_object_from_request(request) + localpart = content.pop("localpart") + group_id = GroupID(localpart, self.server_name).to_string() + + if not localpart: + raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) + + if len(group_id) > MAX_GROUPID_LENGTH: + raise SynapseError( + 400, + "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,), + Codes.INVALID_PARAM, + ) + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot create groups." + result = await self.groups_handler.create_group( + group_id, requester_user_id, content + ) + + return 200, result + + +class GroupAdminRoomsServlet(RestServlet): + """Add a room to the group""" + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify rooms in a group." + result = await self.groups_handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + return 200, result + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." + result = await self.groups_handler.remove_room_from_group( + group_id, requester_user_id, room_id + ) + + return 200, result + + +class GroupAdminRoomsConfigServlet(RestServlet): + """Update the config of a room in a group""" + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)" + "/config/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str, room_id: str, config_key: str + ): + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." + result = await self.groups_handler.update_room_in_group( + group_id, requester_user_id, room_id, config_key, content + ) + + return 200, result + + +class GroupAdminUsersInviteServlet(RestServlet): + """Invite a user to the group""" + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/admin/users/invite/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.store = hs.get_datastore() + self.is_mine_id = hs.is_mine_id + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id, user_id + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + config = content.get("config", {}) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot invite users to a group." + result = await self.groups_handler.invite( + group_id, user_id, requester_user_id, config + ) + + return 200, result + + +class GroupAdminUsersKickServlet(RestServlet): + """Kick a user from the group""" + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/admin/users/remove/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id, user_id + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot kick users from a group." + result = await self.groups_handler.remove_user_from_group( + group_id, user_id, requester_user_id, content + ) + + return 200, result + + +class GroupSelfLeaveServlet(RestServlet): + """Leave a joined group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot leave a group for a users." + result = await self.groups_handler.remove_user_from_group( + group_id, requester_user_id, requester_user_id, content + ) + + return 200, result + + +class GroupSelfJoinServlet(RestServlet): + """Attempt to join a group, or knock""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot join a user to a group." + result = await self.groups_handler.join_group( + group_id, requester_user_id, content + ) + + return 200, result + + +class GroupSelfAcceptInviteServlet(RestServlet): + """Accept a group invite""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot accept an invite to a group." + result = await self.groups_handler.accept_invite( + group_id, requester_user_id, content + ) + + return 200, result + + +class GroupSelfUpdatePublicityServlet(RestServlet): + """Update whether we publicise a users membership of a group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + publicise = content["publicise"] + await self.store.update_group_publicity(group_id, requester_user_id, publicise) + + return 200, {} + + +class PublicisedGroupsForUserServlet(RestServlet): + """Get the list of groups a user is advertising""" + + PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + + result = await self.groups_handler.get_publicised_groups_for_user(user_id) + + return 200, result + + +class PublicisedGroupsForUsersServlet(RestServlet): + """Get the list of groups a user is advertising""" + + PATTERNS = client_patterns("/publicised_groups$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + + content = parse_json_object_from_request(request) + user_ids = content["user_ids"] + + result = await self.groups_handler.bulk_get_publicised_groups(user_ids) + + return 200, result + + +class GroupsForUserServlet(RestServlet): + """Get all groups the logged in user is joined to""" + + PATTERNS = client_patterns("/joined_groups$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = await self.groups_handler.get_joined_groups(requester_user_id) + + return 200, result + + +def register_servlets(hs: "HomeServer", http_server): + GroupServlet(hs).register(http_server) + GroupSummaryServlet(hs).register(http_server) + GroupInvitedUsersServlet(hs).register(http_server) + GroupUsersServlet(hs).register(http_server) + GroupRoomServlet(hs).register(http_server) + GroupSettingJoinPolicyServlet(hs).register(http_server) + GroupCreateServlet(hs).register(http_server) + GroupAdminRoomsServlet(hs).register(http_server) + GroupAdminRoomsConfigServlet(hs).register(http_server) + GroupAdminUsersInviteServlet(hs).register(http_server) + GroupAdminUsersKickServlet(hs).register(http_server) + GroupSelfLeaveServlet(hs).register(http_server) + GroupSelfJoinServlet(hs).register(http_server) + GroupSelfAcceptInviteServlet(hs).register(http_server) + GroupsForUserServlet(hs).register(http_server) + GroupCategoryServlet(hs).register(http_server) + GroupCategoriesServlet(hs).register(http_server) + GroupSummaryRoomsCatServlet(hs).register(http_server) + GroupRoleServlet(hs).register(http_server) + GroupRolesServlet(hs).register(http_server) + GroupSelfUpdatePublicityServlet(hs).register(http_server) + GroupSummaryUsersRoleServlet(hs).register(http_server) + PublicisedGroupsForUserServlet(hs).register(http_server) + PublicisedGroupsForUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py new file mode 100644 index 0000000000..12ba0e91db --- /dev/null +++ b/synapse/rest/client/initial_sync.py @@ -0,0 +1,47 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.http.servlet import RestServlet, parse_boolean +from synapse.rest.client._base import client_patterns +from synapse.streams.config import PaginationConfig + + +# TODO: Needs unit testing +class InitialSyncRestServlet(RestServlet): + PATTERNS = client_patterns("/initialSync$", v1=True) + + def __init__(self, hs): + super().__init__() + self.initial_sync_handler = hs.get_initial_sync_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + as_client_event = b"raw" not in request.args + pagination_config = await PaginationConfig.from_request(self.store, request) + include_archived = parse_boolean(request, "archived", default=False) + content = await self.initial_sync_handler.snapshot_all_rooms( + user_id=requester.user.to_string(), + pagin_config=pagination_config, + as_client_event=as_client_event, + include_archived=include_archived, + ) + + return 200, content + + +def register_servlets(hs, http_server): + InitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py new file mode 100644 index 0000000000..d0d9d30d40 --- /dev/null +++ b/synapse/rest/client/keys.py @@ -0,0 +1,344 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, +) +from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.types import StreamToken + +from ._base import client_patterns, interactive_auth_handler + +logger = logging.getLogger(__name__) + + +class KeyUploadServlet(RestServlet): + """ + POST /keys/upload HTTP/1.1 + Content-Type: application/json + + { + "device_keys": { + "user_id": "", + "device_id": "", + "valid_until_ts": , + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ] + "keys": { + ":": "", + }, + "signatures:" { + "" { + ":": "" + } } }, + "one_time_keys": { + ":": "" + }, + } + """ + + PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.device_handler = hs.get_device_handler() + + @trace(opname="upload_keys") + async def on_POST(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + if device_id is not None: + # Providing the device_id should only be done for setting keys + # for dehydrated devices; however, we allow it for any device for + # compatibility with older clients. + if requester.device_id is not None and device_id != requester.device_id: + dehydrated_device = await self.device_handler.get_dehydrated_device( + user_id + ) + if dehydrated_device is not None and device_id != dehydrated_device[0]: + set_tag("error", True) + log_kv( + { + "message": "Client uploading keys for a different device", + "logged_in_id": requester.device_id, + "key_being_uploaded": device_id, + } + ) + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) + else: + device_id = requester.device_id + + if device_id is None: + raise SynapseError( + 400, "To upload keys, you must pass device_id when authenticating" + ) + + result = await self.e2e_keys_handler.upload_keys_for_user( + user_id, device_id, body + ) + return 200, result + + +class KeyQueryServlet(RestServlet): + """ + POST /keys/query HTTP/1.1 + Content-Type: application/json + { + "device_keys": { + "": [""] + } } + + HTTP/1.1 200 OK + { + "device_keys": { + "": { + "": { + "user_id": "", // Duplicated to be signed + "device_id": "", // Duplicated to be signed + "valid_until_ts": , + "algorithms": [ // List of supported algorithms + "m.olm.curve25519-aes-sha2", + ], + "keys": { // Must include a ed25519 signing key + ":": "", + }, + "signatures:" { + // Must be signed with device's ed25519 key + "/": { + ":": "" + } + // Must be signed by this server. + "": { + ":": "" + } } } } } } + """ + + PATTERNS = client_patterns("/keys/query$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + device_id = requester.device_id + timeout = parse_integer(request, "timeout", 10 * 1000) + body = parse_json_object_from_request(request) + result = await self.e2e_keys_handler.query_devices( + body, timeout, user_id, device_id + ) + return 200, result + + +class KeyChangesServlet(RestServlet): + """Returns the list of changes of keys between two stream tokens (may return + spurious extra results, since we currently ignore the `to` param). + + GET /keys/changes?from=...&to=... + + 200 OK + { "changed": ["@foo:example.com"] } + """ + + PATTERNS = client_patterns("/keys/changes$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ + super().__init__() + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + from_token_string = parse_string(request, "from", required=True) + set_tag("from", from_token_string) + + # We want to enforce they do pass us one, but we ignore it and return + # changes after the "to" as well as before. + set_tag("to", parse_string(request, "to")) + + from_token = await StreamToken.from_string(self.store, from_token_string) + + user_id = requester.user.to_string() + + results = await self.device_handler.get_user_ids_changed(user_id, from_token) + + return 200, results + + +class OneTimeKeyServlet(RestServlet): + """ + POST /keys/claim HTTP/1.1 + { + "one_time_keys": { + "": { + "": "" + } } } + + HTTP/1.1 200 OK + { + "one_time_keys": { + "": { + "": { + ":": "" + } } } } + + """ + + PATTERNS = client_patterns("/keys/claim$") + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) + timeout = parse_integer(request, "timeout", 10 * 1000) + body = parse_json_object_from_request(request) + result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) + return 200, result + + +class SigningKeyUploadServlet(RestServlet): + """ + POST /keys/device_signing/upload HTTP/1.1 + Content-Type: application/json + + { + } + """ + + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "add a device signing key to your account", + # Allow skipping of UI auth since this is frequently called directly + # after login and it is silly to ask users to re-auth immediately. + can_skip_ui_auth=True, + ) + + result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) + return 200, result + + +class SignaturesUploadServlet(RestServlet): + """ + POST /keys/signatures/upload HTTP/1.1 + Content-Type: application/json + + { + "@alice:example.com": { + "": { + "user_id": "", + "device_id": "", + "algorithms": [ + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" + ], + "keys": { + ":": "", + }, + "signatures": { + "": { + ":": ">" + } + } + } + } + } + """ + + PATTERNS = client_patterns("/keys/signatures/upload$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + result = await self.e2e_keys_handler.upload_signatures_for_device_keys( + user_id, body + ) + return 200, result + + +def register_servlets(hs, http_server): + KeyUploadServlet(hs).register(http_server) + KeyQueryServlet(hs).register(http_server) + KeyChangesServlet(hs).register(http_server) + OneTimeKeyServlet(hs).register(http_server) + SigningKeyUploadServlet(hs).register(http_server) + SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py new file mode 100644 index 0000000000..7d1bc40658 --- /dev/null +++ b/synapse/rest/client/knock.py @@ -0,0 +1,107 @@ +# Copyright 2020 Sorunome +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +from twisted.web.server import Request + +from synapse.api.constants import Membership +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_strings_from_args, +) +from synapse.http.site import SynapseRequest +from synapse.logging.opentracing import set_tag +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import JsonDict, RoomAlias, RoomID + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class KnockRoomAliasServlet(RestServlet): + """ + POST /knock/{roomIdOrAlias} + """ + + PATTERNS = client_patterns("/knock/(?P[^/]*)") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.txns = HttpTransactionCache(hs) + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + + async def on_POST( + self, + request: SynapseRequest, + room_identifier: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + event_content = None + if "reason" in content: + event_content = {"reason": content["reason"]} + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + + remote_room_hosts = parse_strings_from_args( + args, "server_name", required=False + ) + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id_obj.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + await self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action=Membership.KNOCK, + txn_id=txn_id, + third_party_signed=None, + remote_room_hosts=remote_room_hosts, + content=event_content, + ) + + return 200, {"room_id": room_id} + + def on_PUT(self, request: Request, room_identifier: str, txn_id: str): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_identifier, txn_id + ) + + +def register_servlets(hs, http_server): + KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py new file mode 100644 index 0000000000..0c8d8967b7 --- /dev/null +++ b/synapse/rest/client/login.py @@ -0,0 +1,600 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional + +from typing_extensions import TypedDict + +from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.api.ratelimiting import Ratelimiter +from synapse.api.urls import CLIENT_API_PREFIX +from synapse.appservice import ApplicationService +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http import get_request_uri +from synapse.http.server import HttpServer, finish_request +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_boolean, + parse_bytes_from_args, + parse_json_object_from_request, + parse_string, +) +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns +from synapse.rest.well_known import WellKnownBuilder +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class LoginResponse(TypedDict, total=False): + user_id: str + access_token: str + home_server: str + expires_in_ms: Optional[int] + refresh_token: Optional[str] + device_id: str + well_known: Optional[Dict[str, Any]] + + +class LoginRestServlet(RestServlet): + PATTERNS = client_patterns("/login$", v1=True) + CAS_TYPE = "m.login.cas" + SSO_TYPE = "m.login.sso" + TOKEN_TYPE = "m.login.token" + JWT_TYPE = "org.matrix.login.jwt" + JWT_TYPE_DEPRECATED = "m.login.jwt" + APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" + REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + + # JWT configuration variables. + self.jwt_enabled = hs.config.jwt_enabled + self.jwt_secret = hs.config.jwt_secret + self.jwt_algorithm = hs.config.jwt_algorithm + self.jwt_issuer = hs.config.jwt_issuer + self.jwt_audiences = hs.config.jwt_audiences + + # SSO configuration. + self.saml2_enabled = hs.config.saml2_enabled + self.cas_enabled = hs.config.cas_enabled + self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._msc2918_enabled = hs.config.access_token_lifetime is not None + + self.auth = hs.get_auth() + + self.clock = hs.get_clock() + + self.auth_handler = self.hs.get_auth_handler() + self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + + self._well_known_builder = WellKnownBuilder(hs) + self._address_ratelimiter = Ratelimiter( + store=hs.get_datastore(), + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + store=hs.get_datastore(), + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + + def on_GET(self, request: SynapseRequest): + flows = [] + if self.jwt_enabled: + flows.append({"type": LoginRestServlet.JWT_TYPE}) + flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) + + if self.cas_enabled: + # we advertise CAS for backwards compat, though MSC1721 renamed it + # to SSO. + flows.append({"type": LoginRestServlet.CAS_TYPE}) + + if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: + sso_flow: JsonDict = { + "type": LoginRestServlet.SSO_TYPE, + "identity_providers": [ + _get_auth_flow_dict_for_idp( + idp, + ) + for idp in self._sso_handler.get_identity_providers().values() + ], + } + + if self._msc2858_enabled: + # backwards-compatibility support for clients which don't + # support the stable API yet + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, + # synapse currently only gives out these tokens as part of the + # SSO login flow. + # Generally we don't want to advertise login flows that clients + # don't know how to implement, since they (currently) will always + # fall back to the fallback API if they don't understand one of the + # login flow types returned. + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + + flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) + + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + + return 200, {"flows": flows} + + async def on_POST(self, request: SynapseRequest): + login_submission = parse_json_object_from_request(request) + + if self._msc2918_enabled: + # Check if this login should also issue a refresh token, as per + # MSC2918 + should_issue_refresh_token = parse_boolean( + request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False + ) + else: + should_issue_refresh_token = False + + try: + if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + appservice = self.auth.get_appservice_by_req(request) + + if appservice.is_rate_limited(): + await self._address_ratelimiter.ratelimit( + None, request.getClientIP() + ) + + result = await self._do_appservice_login( + login_submission, + appservice, + should_issue_refresh_token=should_issue_refresh_token, + ) + elif self.jwt_enabled and ( + login_submission["type"] == LoginRestServlet.JWT_TYPE + or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED + ): + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + result = await self._do_jwt_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) + elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + result = await self._do_token_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) + else: + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + result = await self._do_other_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) + except KeyError: + raise SynapseError(400, "Missing JSON keys.") + + well_known_data = self._well_known_builder.get_well_known() + if well_known_data: + result["well_known"] = well_known_data + return 200, result + + async def _do_appservice_login( + self, + login_submission: JsonDict, + appservice: ApplicationService, + should_issue_refresh_token: bool = False, + ): + identifier = login_submission.get("identifier") + logger.info("Got appservice login request with identifier: %r", identifier) + + if not isinstance(identifier, dict): + raise SynapseError( + 400, "Invalid identifier in login submission", Codes.INVALID_PARAM + ) + + # this login flow only supports identifiers of type "m.id.user". + if identifier.get("type") != "m.id.user": + raise SynapseError( + 400, "Unknown login identifier type", Codes.INVALID_PARAM + ) + + user = identifier.get("user") + if not isinstance(user, str): + raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM) + + if user.startswith("@"): + qualified_user_id = user + else: + qualified_user_id = UserID(user, self.hs.hostname).to_string() + + if not appservice.is_interested_in_user(qualified_user_id): + raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) + + return await self._complete_login( + qualified_user_id, + login_submission, + ratelimit=appservice.is_rate_limited(), + should_issue_refresh_token=should_issue_refresh_token, + ) + + async def _do_other_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: + """Handle non-token/saml/jwt logins + + Args: + login_submission: + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. + + Returns: + HTTP response + """ + # Log the request we got, but only certain fields to minimise the chance of + # logging someone's password (even if they accidentally put it in the wrong + # field) + logger.info( + "Got login request with identifier: %r, medium: %r, address: %r, user: %r", + login_submission.get("identifier"), + login_submission.get("medium"), + login_submission.get("address"), + login_submission.get("user"), + ) + canonical_user_id, callback = await self.auth_handler.validate_login( + login_submission, ratelimit=True + ) + result = await self._complete_login( + canonical_user_id, + login_submission, + callback, + should_issue_refresh_token=should_issue_refresh_token, + ) + return result + + async def _complete_login( + self, + user_id: str, + login_submission: JsonDict, + callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, + create_non_existent_users: bool = False, + ratelimit: bool = True, + auth_provider_id: Optional[str] = None, + should_issue_refresh_token: bool = False, + ) -> LoginResponse: + """Called when we've successfully authed the user and now need to + actually login them in (e.g. create devices). This gets called on + all successful logins. + + Applies the ratelimiting for successful login attempts against an + account. + + Args: + user_id: ID of the user to register. + login_submission: Dictionary of login information. + callback: Callback function to run after login. + create_non_existent_users: Whether to create the user if they don't + exist. Defaults to False. + ratelimit: Whether to ratelimit the login request. + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. + + Returns: + result: Dictionary of account information after successful login. + """ + + # Before we actually log them in we check if they've already logged in + # too often. This happens here rather than before as we don't + # necessarily know the user before now. + if ratelimit: + await self._account_ratelimiter.ratelimit(None, user_id.lower()) + + if create_non_existent_users: + canonical_uid = await self.auth_handler.check_user_exists(user_id) + if not canonical_uid: + canonical_uid = await self.registration_handler.register_user( + localpart=UserID.from_string(user_id).localpart + ) + user_id = canonical_uid + + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, + ) + + result = LoginResponse( + user_id=user_id, + access_token=access_token, + home_server=self.hs.hostname, + device_id=device_id, + ) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + if callback is not None: + await callback(result) + + return result + + async def _do_token_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: + """ + Handle the final stage of SSO login. + + Args: + login_submission: The JSON request body. + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. + + Returns: + The body of the JSON response. + """ + token = login_submission["token"] + auth_handler = self.auth_handler + res = await auth_handler.validate_short_term_login_token(token) + + return await self._complete_login( + res.user_id, + login_submission, + self.auth_handler._sso_login_callback, + auth_provider_id=res.auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, + ) + + async def _do_jwt_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: + token = login_submission.get("token", None) + if token is None: + raise LoginError( + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN + ) + + import jwt + + try: + payload = jwt.decode( + token, + self.jwt_secret, + algorithms=[self.jwt_algorithm], + issuer=self.jwt_issuer, + audience=self.jwt_audiences, + ) + except jwt.PyJWTError as e: + # A JWT error occurred, return some info back to the client. + raise LoginError( + 403, + "JWT validation failed: %s" % (str(e),), + errcode=Codes.FORBIDDEN, + ) + + user = payload.get("sub", None) + if user is None: + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) + + user_id = UserID(user, self.hs.hostname).to_string() + result = await self._complete_login( + user_id, + login_submission, + create_non_existent_users=True, + should_issue_refresh_token=should_issue_refresh_token, + ) + return result + + +def _get_auth_flow_dict_for_idp( + idp: SsoIdentityProvider, use_unstable_brands: bool = False +) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + + Args: + idp: the identity provider to describe + use_unstable_brands: whether we should use brand identifiers suitable + for the unstable API + """ + e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} + if idp.idp_icon: + e["icon"] = idp.idp_icon + if idp.idp_brand: + e["brand"] = idp.idp_brand + # use the stable brand identifier if the unstable identifier isn't defined. + if use_unstable_brands and idp.unstable_idp_brand: + e["brand"] = idp.unstable_idp_brand + return e + + +class RefreshTokenServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True + ) + + def __init__(self, hs: "HomeServer"): + self._auth_handler = hs.get_auth_handler() + self._clock = hs.get_clock() + self.access_token_lifetime = hs.config.access_token_lifetime + + async def on_POST( + self, + request: SynapseRequest, + ): + refresh_submission = parse_json_object_from_request(request) + + assert_params_in_dict(refresh_submission, ["refresh_token"]) + token = refresh_submission["refresh_token"] + if not isinstance(token, str): + raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) + + valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + access_token, refresh_token = await self._auth_handler.refresh_token( + token, valid_until_ms + ) + expires_in_ms = valid_until_ms - self._clock.time_msec() + return ( + 200, + { + "access_token": access_token, + "refresh_token": refresh_token, + "expires_in_ms": expires_in_ms, + }, + ) + + +class SsoRedirectServlet(RestServlet): + PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ + re.compile( + "^" + + CLIENT_API_PREFIX + + "/r0/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$" + ) + ] + + def __init__(self, hs: "HomeServer"): + # make sure that the relevant handlers are instantiated, so that they + # register themselves with the main SSOHandler. + if hs.config.cas_enabled: + hs.get_cas_handler() + if hs.config.saml2_enabled: + hs.get_saml_handler() + if hs.config.oidc_enabled: + hs.get_oidc_handler() + self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._public_baseurl = hs.config.public_baseurl + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support: backwards-compat support + # for clients which don't yet support the stable endpoints. + http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) + + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: + if not self._public_baseurl: + raise SynapseError(400, "SSO requires a valid public_baseurl") + + # if this isn't the expected hostname, redirect to the right one, so that we + # get our cookies back. + requested_uri = get_request_uri(request) + baseurl_bytes = self._public_baseurl.encode("utf-8") + if not requested_uri.startswith(baseurl_bytes): + # swap out the incorrect base URL for the right one. + # + # The idea here is to redirect from + # https://foo.bar/whatever/_matrix/... + # to + # https://public.baseurl/_matrix/... + # + i = requested_uri.index(b"/_matrix") + new_uri = baseurl_bytes[:-1] + requested_uri[i:] + logger.info( + "Requested URI %s is not canonical: redirecting to %s", + requested_uri.decode("utf-8", errors="replace"), + new_uri.decode("utf-8", errors="replace"), + ) + request.redirect(new_uri) + finish_request(request) + return + + args: Dict[bytes, List[bytes]] = request.args # type: ignore + client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) + sso_url = await self._sso_handler.handle_redirect_request( + request, + client_redirect_url, + idp_id, + ) + logger.info("Redirecting to %s", sso_url) + request.redirect(sso_url) + finish_request(request) + + +class CasTicketServlet(RestServlet): + PATTERNS = client_patterns("/login/cas/ticket", v1=True) + + def __init__(self, hs): + super().__init__() + self._cas_handler = hs.get_cas_handler() + + async def on_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string(request, "redirectUrl") + ticket = parse_string(request, "ticket", required=True) + + # Maybe get a session ID (if this ticket is from user interactive + # authentication). + session = parse_string(request, "session") + + # Either client_redirect_url or session must be provided. + if not client_redirect_url and not session: + message = "Missing string query parameter redirectUrl or session" + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + + await self._cas_handler.handle_ticket( + request, ticket, client_redirect_url, session + ) + + +def register_servlets(hs, http_server): + LoginRestServlet(hs).register(http_server) + if hs.config.access_token_lifetime is not None: + RefreshTokenServlet(hs).register(http_server) + SsoRedirectServlet(hs).register(http_server) + if hs.config.cas_enabled: + CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py new file mode 100644 index 0000000000..6055cac2bd --- /dev/null +++ b/synapse/rest/client/logout.py @@ -0,0 +1,72 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import RestServlet +from synapse.rest.client._base import client_patterns + +logger = logging.getLogger(__name__) + + +class LogoutRestServlet(RestServlet): + PATTERNS = client_patterns("/logout$", v1=True) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_expired=True) + + if requester.device_id is None: + # The access token wasn't associated with a device. + # Just delete the access token + access_token = self.auth.get_access_token_from_request(request) + await self._auth_handler.delete_access_token(access_token) + else: + await self._device_handler.delete_device( + requester.user.to_string(), requester.device_id + ) + + return 200, {} + + +class LogoutAllRestServlet(RestServlet): + PATTERNS = client_patterns("/logout/all$", v1=True) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_expired=True) + user_id = requester.user.to_string() + + # first delete all of the user's devices + await self._device_handler.delete_all_devices_for_user(user_id) + + # .. and then delete any access tokens which weren't associated with + # devices. + await self._auth_handler.delete_access_tokens_for_user(user_id) + return 200, {} + + +def register_servlets(hs, http_server): + LogoutRestServlet(hs).register(http_server) + LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py new file mode 100644 index 0000000000..0ede643c2d --- /dev/null +++ b/synapse/rest/client/notifications.py @@ -0,0 +1,91 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.events.utils import format_event_for_client_v2_without_room_id +from synapse.http.servlet import RestServlet, parse_integer, parse_string + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class NotificationsServlet(RestServlet): + PATTERNS = client_patterns("/notifications$") + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + from_token = parse_string(request, "from", required=False) + limit = parse_integer(request, "limit", default=50) + only = parse_string(request, "only", required=False) + + limit = min(limit, 500) + + push_actions = await self.store.get_push_actions_for_user( + user_id, from_token, limit, only_highlight=(only == "highlight") + ) + + receipts_by_room = await self.store.get_receipts_for_user_with_orderings( + user_id, "m.read" + ) + + notif_event_ids = [pa["event_id"] for pa in push_actions] + notif_events = await self.store.get_events(notif_event_ids) + + returned_push_actions = [] + + next_token = None + + for pa in push_actions: + returned_pa = { + "room_id": pa["room_id"], + "profile_tag": pa["profile_tag"], + "actions": pa["actions"], + "ts": pa["received_ts"], + "event": ( + await self._event_serializer.serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ) + ), + } + + if pa["room_id"] not in receipts_by_room: + returned_pa["read"] = False + else: + receipt = receipts_by_room[pa["room_id"]] + + returned_pa["read"] = ( + receipt["topological_ordering"], + receipt["stream_ordering"], + ) >= (pa["topological_ordering"], pa["stream_ordering"]) + returned_push_actions.append(returned_pa) + next_token = str(pa["stream_ordering"]) + + return 200, {"notifications": returned_push_actions, "next_token": next_token} + + +def register_servlets(hs, http_server): + NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py new file mode 100644 index 0000000000..e8d2673819 --- /dev/null +++ b/synapse/rest/client/openid.py @@ -0,0 +1,94 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.util.stringutils import random_string + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class IdTokenServlet(RestServlet): + """ + Get a bearer token that may be passed to a third party to confirm ownership + of a matrix user id. + + The format of the response could be made compatible with the format given + in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse + + But instead of returning a signed "id_token" the response contains the + name of the issuing matrix homeserver. This means that for now the third + party will need to check the validity of the "id_token" against the + federation /openid/userinfo endpoint of the homeserver. + + Request: + + POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1 + + {} + + Response: + + HTTP/1.1 200 OK + { + "access_token": "ABDEFGH", + "token_type": "Bearer", + "matrix_server_name": "example.com", + "expires_in": 3600, + } + """ + + PATTERNS = client_patterns("/user/(?P[^/]*)/openid/request_token") + + EXPIRES_MS = 3600 * 1000 + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot request tokens for other users.") + + # Parse the request body to make sure it's JSON, but ignore the contents + # for now. + parse_json_object_from_request(request) + + token = random_string(24) + ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS + + await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) + + return ( + 200, + { + "access_token": token, + "token_type": "Bearer", + "matrix_server_name": self.server_name, + "expires_in": self.EXPIRES_MS // 1000, + }, + ) + + +def register_servlets(hs, http_server): + IdTokenServlet(hs).register(http_server) diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py new file mode 100644 index 0000000000..a83927aee6 --- /dev/null +++ b/synapse/rest/client/password_policy.py @@ -0,0 +1,57 @@ +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PasswordPolicyServlet(RestServlet): + PATTERNS = client_patterns("/password_policy$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + def on_GET(self, request): + if not self.enabled or not self.policy: + return (200, {}) + + policy = {} + + for param in [ + "minimum_length", + "require_digit", + "require_symbol", + "require_lowercase", + "require_uppercase", + ]: + if param in self.policy: + policy["m.%s" % param] = self.policy[param] + + return (200, policy) + + +def register_servlets(hs, http_server): + PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py new file mode 100644 index 0000000000..6c27e5faf9 --- /dev/null +++ b/synapse/rest/client/presence.py @@ -0,0 +1,95 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" This module contains REST servlets to do with presence: /presence/ +""" +import logging + +from synapse.api.errors import AuthError, SynapseError +from synapse.handlers.presence import format_user_presence_state +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client._base import client_patterns +from synapse.types import UserID + +logger = logging.getLogger(__name__) + + +class PresenceStatusRestServlet(RestServlet): + PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + self.auth = hs.get_auth() + + self._use_presence = hs.config.server.use_presence + + async def on_GET(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + + if not self._use_presence: + return 200, {"presence": "offline"} + + if requester.user != user: + allowed = await self.presence_handler.is_visible( + observed_user=user, observer_user=requester.user + ) + + if not allowed: + raise AuthError(403, "You are not allowed to see their presence.") + + state = await self.presence_handler.get_state(target_user=user) + state = format_user_presence_state( + state, self.clock.time_msec(), include_user_id=False + ) + + return 200, state + + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + + if requester.user != user: + raise AuthError(403, "Can only set your own presence state") + + state = {} + + content = parse_json_object_from_request(request) + + try: + state["presence"] = content.pop("presence") + + if "status_msg" in content: + state["status_msg"] = content.pop("status_msg") + if not isinstance(state["status_msg"], str): + raise SynapseError(400, "status_msg must be a string.") + + if content: + raise KeyError() + except SynapseError as e: + raise e + except Exception: + raise SynapseError(400, "Unable to parse state") + + if self._use_presence: + await self.presence_handler.set_state(user, state) + + return 200, {} + + +def register_servlets(hs, http_server): + PresenceStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py new file mode 100644 index 0000000000..5463ed2c4f --- /dev/null +++ b/synapse/rest/client/profile.py @@ -0,0 +1,155 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" This module contains REST servlets to do with profile: /profile/ """ + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client._base import client_patterns +from synapse.types import UserID + + +class ProfileDisplaynameRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = await self.auth.get_user_by_req(request) + requester_user = requester.user + + user = UserID.from_string(user_id) + + await self.profile_handler.check_profile_query_allowed(user, requester_user) + + displayname = await self.profile_handler.get_displayname(user) + + ret = {} + if displayname is not None: + ret["displayname"] = displayname + + return 200, ret + + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = UserID.from_string(user_id) + is_admin = await self.auth.is_server_admin(requester.user) + + content = parse_json_object_from_request(request) + + try: + new_name = content["displayname"] + except Exception: + raise SynapseError( + code=400, + msg="Unable to parse name", + errcode=Codes.BAD_JSON, + ) + + await self.profile_handler.set_displayname(user, requester, new_name, is_admin) + + return 200, {} + + +class ProfileAvatarURLRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = await self.auth.get_user_by_req(request) + requester_user = requester.user + + user = UserID.from_string(user_id) + + await self.profile_handler.check_profile_query_allowed(user, requester_user) + + avatar_url = await self.profile_handler.get_avatar_url(user) + + ret = {} + if avatar_url is not None: + ret["avatar_url"] = avatar_url + + return 200, ret + + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + is_admin = await self.auth.is_server_admin(requester.user) + + content = parse_json_object_from_request(request) + try: + new_avatar_url = content["avatar_url"] + except KeyError: + raise SynapseError( + 400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM + ) + + await self.profile_handler.set_avatar_url( + user, requester, new_avatar_url, is_admin + ) + + return 200, {} + + +class ProfileRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = await self.auth.get_user_by_req(request) + requester_user = requester.user + + user = UserID.from_string(user_id) + + await self.profile_handler.check_profile_query_allowed(user, requester_user) + + displayname = await self.profile_handler.get_displayname(user) + avatar_url = await self.profile_handler.get_avatar_url(user) + + ret = {} + if displayname is not None: + ret["displayname"] = displayname + if avatar_url is not None: + ret["avatar_url"] = avatar_url + + return 200, ret + + +def register_servlets(hs, http_server): + ProfileDisplaynameRestServlet(hs).register(http_server) + ProfileAvatarURLRestServlet(hs).register(http_server) + ProfileRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py new file mode 100644 index 0000000000..702b351d18 --- /dev/null +++ b/synapse/rest/client/push_rule.py @@ -0,0 +1,354 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.errors import ( + NotFoundError, + StoreError, + SynapseError, + UnrecognizedRequestError, +) +from synapse.http.servlet import ( + RestServlet, + parse_json_value_from_request, + parse_string, +) +from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS +from synapse.push.clientformat import format_push_rules_for_user +from synapse.push.rulekinds import PRIORITY_CLASS_MAP +from synapse.rest.client._base import client_patterns +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException + + +class PushRuleRestServlet(RestServlet): + PATTERNS = client_patterns("/(?Ppushrules/.*)$", v1=True) + SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( + "Unrecognised request: You probably wanted a trailing slash" + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None + + self._users_new_default_push_rules = hs.config.users_new_default_push_rules + + async def on_PUT(self, request, path): + if self._is_worker: + raise Exception("Cannot handle PUT /push_rules on worker") + + spec = _rule_spec_from_path(path.split("/")) + try: + priority_class = _priority_class_from_spec(spec) + except InvalidRuleException as e: + raise SynapseError(400, str(e)) + + requester = await self.auth.get_user_by_req(request) + + if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: + raise SynapseError(400, "rule_id may not contain slashes") + + content = parse_json_value_from_request(request) + + user_id = requester.user.to_string() + + if "attr" in spec: + await self.set_rule_attr(user_id, spec, content) + self.notify_user(user_id) + return 200, {} + + if spec["rule_id"].startswith("."): + # Rule ids starting with '.' are reserved for server default rules. + raise SynapseError(400, "cannot add new rule_ids that start with '.'") + + try: + (conditions, actions) = _rule_tuple_from_request_object( + spec["template"], spec["rule_id"], content + ) + except InvalidRuleException as e: + raise SynapseError(400, str(e)) + + before = parse_string(request, "before") + if before: + before = _namespaced_rule_id(spec, before) + + after = parse_string(request, "after") + if after: + after = _namespaced_rule_id(spec, after) + + try: + await self.store.add_push_rule( + user_id=user_id, + rule_id=_namespaced_rule_id_from_spec(spec), + priority_class=priority_class, + conditions=conditions, + actions=actions, + before=before, + after=after, + ) + self.notify_user(user_id) + except InconsistentRuleException as e: + raise SynapseError(400, str(e)) + except RuleNotFoundException as e: + raise SynapseError(400, str(e)) + + return 200, {} + + async def on_DELETE(self, request, path): + if self._is_worker: + raise Exception("Cannot handle DELETE /push_rules on worker") + + spec = _rule_spec_from_path(path.split("/")) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + + try: + await self.store.delete_push_rule(user_id, namespaced_rule_id) + self.notify_user(user_id) + return 200, {} + except StoreError as e: + if e.code == 404: + raise NotFoundError() + else: + raise + + async def on_GET(self, request, path): + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + # we build up the full structure and then decide which bits of it + # to send which means doing unnecessary work sometimes but is + # is probably not going to make a whole lot of difference + rules = await self.store.get_push_rules_for_user(user_id) + + rules = format_push_rules_for_user(requester.user, rules) + + path = path.split("/")[1:] + + if path == []: + # we're a reference impl: pedantry is our job. + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == "": + return 200, rules + elif path[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path[1:]) + return 200, result + else: + raise UnrecognizedRequestError() + + def notify_user(self, user_id): + stream_id = self.store.get_max_push_rules_stream_id() + self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) + + async def set_rule_attr(self, user_id, spec, val): + if spec["attr"] not in ("enabled", "actions"): + # for the sake of potential future expansion, shouldn't report + # 404 in the case of an unknown request so check it corresponds to + # a known attribute first. + raise UnrecognizedRequestError() + + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec["rule_id"] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) + if spec["attr"] == "enabled": + if isinstance(val, dict) and "enabled" in val: + val = val["enabled"] + if not isinstance(val, bool): + # Legacy fallback + # This should *actually* take a dict, but many clients pass + # bools directly, so let's not break them. + raise SynapseError(400, "Value for 'enabled' must be boolean") + return await self.store.set_push_rule_enabled( + user_id, namespaced_rule_id, val, is_default_rule + ) + elif spec["attr"] == "actions": + actions = val.get("actions") + _check_actions(actions) + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec["rule_id"] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if user_id in self._users_new_default_push_rules: + rule_ids = NEW_RULE_IDS + else: + rule_ids = BASE_RULE_IDS + + if namespaced_rule_id not in rule_ids: + raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) + return await self.store.set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule + ) + else: + raise UnrecognizedRequestError() + + +def _rule_spec_from_path(path): + """Turn a sequence of path components into a rule spec + + Args: + path (sequence[unicode]): the URL path components. + + Returns: + dict: rule spec dict, containing scope/template/rule_id entries, + and possibly attr. + + Raises: + UnrecognizedRequestError if the path components cannot be parsed. + """ + if len(path) < 2: + raise UnrecognizedRequestError() + if path[0] != "pushrules": + raise UnrecognizedRequestError() + + scope = path[1] + path = path[2:] + if scope != "global": + raise UnrecognizedRequestError() + + if len(path) == 0: + raise UnrecognizedRequestError() + + template = path[0] + path = path[1:] + + if len(path) == 0 or len(path[0]) == 0: + raise UnrecognizedRequestError() + + rule_id = path[0] + + spec = {"scope": scope, "template": template, "rule_id": rule_id} + + path = path[1:] + + if len(path) > 0 and len(path[0]) > 0: + spec["attr"] = path[0] + + return spec + + +def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): + if rule_template in ["override", "underride"]: + if "conditions" not in req_obj: + raise InvalidRuleException("Missing 'conditions'") + conditions = req_obj["conditions"] + for c in conditions: + if "kind" not in c: + raise InvalidRuleException("Condition without 'kind'") + elif rule_template == "room": + conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}] + elif rule_template == "sender": + conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}] + elif rule_template == "content": + if "pattern" not in req_obj: + raise InvalidRuleException("Content rule missing 'pattern'") + pat = req_obj["pattern"] + + conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}] + else: + raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) + + if "actions" not in req_obj: + raise InvalidRuleException("No actions found") + actions = req_obj["actions"] + + _check_actions(actions) + + return conditions, actions + + +def _check_actions(actions): + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + + for a in actions: + if a in ["notify", "dont_notify", "coalesce"]: + pass + elif isinstance(a, dict) and "set_tweak" in a: + pass + else: + raise InvalidRuleException("Unrecognised action") + + +def _filter_ruleset_with_path(ruleset, path): + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == "": + return ruleset + template_kind = path[0] + if template_kind not in ruleset: + raise UnrecognizedRequestError() + path = path[1:] + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + if path[0] == "": + return ruleset[template_kind] + rule_id = path[0] + + the_rule = None + for r in ruleset[template_kind]: + if r["rule_id"] == rule_id: + the_rule = r + if the_rule is None: + raise NotFoundError + + path = path[1:] + if len(path) == 0: + return the_rule + + attr = path[0] + if attr in the_rule: + # Make sure we return a JSON object as the attribute may be a + # JSON value. + return {attr: the_rule[attr]} + else: + raise UnrecognizedRequestError() + + +def _priority_class_from_spec(spec): + if spec["template"] not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec["template"])) + pc = PRIORITY_CLASS_MAP[spec["template"]] + + return pc + + +def _namespaced_rule_id_from_spec(spec): + return _namespaced_rule_id(spec, spec["rule_id"]) + + +def _namespaced_rule_id(spec, rule_id): + return "global/%s/%s" % (spec["template"], rule_id) + + +class InvalidRuleException(Exception): + pass + + +def register_servlets(hs, http_server): + PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py new file mode 100644 index 0000000000..84619c5e41 --- /dev/null +++ b/synapse/rest/client/pusher.py @@ -0,0 +1,171 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.http.server import respond_with_html_bytes +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, +) +from synapse.push import PusherConfigException +from synapse.rest.client._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PushersRestServlet(RestServlet): + PATTERNS = client_patterns("/pushers$", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + user = requester.user + + pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + + filtered_pushers = [p.as_dict() for p in pushers] + + return 200, {"pushers": filtered_pushers} + + +class PushersSetRestServlet(RestServlet): + PATTERNS = client_patterns("/pushers/set$", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.notifier = hs.get_notifier() + self.pusher_pool = self.hs.get_pusherpool() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + user = requester.user + + content = parse_json_object_from_request(request) + + if ( + "pushkey" in content + and "app_id" in content + and "kind" in content + and content["kind"] is None + ): + await self.pusher_pool.remove_pusher( + content["app_id"], content["pushkey"], user_id=user.to_string() + ) + return 200, {} + + assert_params_in_dict( + content, + [ + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "lang", + "data", + ], + ) + + logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"]) + logger.debug("Got pushers request with body: %r", content) + + append = False + if "append" in content: + append = content["append"] + + if not append: + await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + app_id=content["app_id"], + pushkey=content["pushkey"], + not_user_id=user.to_string(), + ) + + try: + await self.pusher_pool.add_pusher( + user_id=user.to_string(), + access_token=requester.access_token_id, + kind=content["kind"], + app_id=content["app_id"], + app_display_name=content["app_display_name"], + device_display_name=content["device_display_name"], + pushkey=content["pushkey"], + lang=content["lang"], + data=content["data"], + profile_tag=content.get("profile_tag", ""), + ) + except PusherConfigException as pce: + raise SynapseError( + 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM + ) + + self.notifier.on_new_replication_data() + + return 200, {} + + +class PushersRemoveRestServlet(RestServlet): + """ + To allow pusher to be delete by clicking a link (ie. GET request) + """ + + PATTERNS = client_patterns("/pushers/remove$", v1=True) + SUCCESS_HTML = b"You have been unsubscribed" + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.notifier = hs.get_notifier() + self.auth = hs.get_auth() + self.pusher_pool = self.hs.get_pusherpool() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, rights="delete_pusher") + user = requester.user + + app_id = parse_string(request, "app_id", required=True) + pushkey = parse_string(request, "pushkey", required=True) + + try: + await self.pusher_pool.remove_pusher( + app_id=app_id, pushkey=pushkey, user_id=user.to_string() + ) + except StoreError as se: + if se.code != 404: + # This is fine: they're already unsubscribed + raise + + self.notifier.on_new_replication_data() + + respond_with_html_bytes( + request, + 200, + PushersRemoveRestServlet.SUCCESS_HTML, + ) + return None + + +def register_servlets(hs, http_server): + PushersRestServlet(hs).register(http_server) + PushersSetRestServlet(hs).register(http_server) + PushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py new file mode 100644 index 0000000000..027f8b81fa --- /dev/null +++ b/synapse/rest/client/read_marker.py @@ -0,0 +1,74 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.constants import ReadReceiptEventFields +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class ReadMarkerRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$") + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.receipts_handler = hs.get_receipts_handler() + self.read_marker_handler = hs.get_read_marker_handler() + self.presence_handler = hs.get_presence_handler() + + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + await self.presence_handler.bump_presence_active_time(requester.user) + + body = parse_json_object_from_request(request) + read_event_id = body.get("m.read", None) + hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) + + if not isinstance(hidden, bool): + raise SynapseError( + 400, + "Param %s must be a boolean, if given" + % ReadReceiptEventFields.MSC2285_HIDDEN, + Codes.BAD_JSON, + ) + + if read_event_id: + await self.receipts_handler.received_client_receipt( + room_id, + "m.read", + user_id=requester.user.to_string(), + event_id=read_event_id, + hidden=hidden, + ) + + read_marker_event_id = body.get("m.fully_read", None) + if read_marker_event_id: + await self.read_marker_handler.received_client_read_marker( + room_id, + user_id=requester.user.to_string(), + event_id=read_marker_event_id, + ) + + return 200, {} + + +def register_servlets(hs, http_server): + ReadMarkerRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py new file mode 100644 index 0000000000..d9ab836cd8 --- /dev/null +++ b/synapse/rest/client/receipts.py @@ -0,0 +1,71 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.constants import ReadReceiptEventFields +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class ReceiptRestServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)" + "/receipt/(?P[^/]*)" + "/(?P[^/]*)$" + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.receipts_handler = hs.get_receipts_handler() + self.presence_handler = hs.get_presence_handler() + + async def on_POST(self, request, room_id, receipt_type, event_id): + requester = await self.auth.get_user_by_req(request) + + if receipt_type != "m.read": + raise SynapseError(400, "Receipt type must be 'm.read'") + + body = parse_json_object_from_request(request, allow_empty_body=True) + hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) + + if not isinstance(hidden, bool): + raise SynapseError( + 400, + "Param %s must be a boolean, if given" + % ReadReceiptEventFields.MSC2285_HIDDEN, + Codes.BAD_JSON, + ) + + await self.presence_handler.bump_presence_active_time(requester.user) + + await self.receipts_handler.received_client_receipt( + room_id, + receipt_type, + user_id=requester.user.to_string(), + event_id=event_id, + hidden=hidden, + ) + + return 200, {} + + +def register_servlets(hs, http_server): + ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py new file mode 100644 index 0000000000..58b8e8f261 --- /dev/null +++ b/synapse/rest/client/register.py @@ -0,0 +1,879 @@ +# Copyright 2015 - 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import hmac +import logging +import random +from typing import List, Union + +import synapse +import synapse.api.auth +import synapse.types +from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, + UnrecognizedRequestError, +) +from synapse.config import ConfigError +from synapse.config.captcha import CaptchaConfig +from synapse.config.consent import ConsentConfig +from synapse.config.emailconfig import ThreepidBehaviour +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.config.registration import RegistrationConfig +from synapse.config.server import is_threepid_reserved +from synapse.handlers.auth import AuthHandler +from synapse.handlers.ui_auth import UIAuthSessionDataConstants +from synapse.http.server import finish_request, respond_with_html +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_boolean, + parse_json_object_from_request, + parse_string, +) +from synapse.metrics import threepid_send_requests +from synapse.push.mailer import Mailer +from synapse.types import JsonDict +from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.stringutils import assert_valid_client_secret, random_string +from synapse.util.threepids import ( + canonicalise_email, + check_3pid_allowed, + validate_email, +) + +from ._base import client_patterns, interactive_auth_handler + +# We ought to be using hmac.compare_digest() but on older pythons it doesn't +# exist. It's a _really minor_ security flaw to use plain string comparison +# because the timing attack is so obscured by all the other code here it's +# unlikely to make much difference +if hasattr(hmac, "compare_digest"): + compare_digest = hmac.compare_digest +else: + + def compare_digest(a, b): + return a == b + + +logger = logging.getLogger(__name__) + + +class EmailRegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/register/email/requestToken$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.config = hs.config + + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=self.config.email_registration_template_html, + template_text=self.config.email_registration_template_text, + ) + + async def on_POST(self, request): + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.hs.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Email registration has been disabled due to lack of email config" + ) + raise SynapseError( + 400, "Email-based registration has been disabled on this server" + ) + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) + + # Extract params from body + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/account.py) + try: + email = validate_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + if not check_3pid_allowed(self.hs, "email", email): + raise SynapseError( + 403, + "Your email domain is not authorized to register on this server", + Codes.THREEPID_DENIED, + ) + + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) + + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + "email", email + ) + + if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( + self.hs.config.account_threepid_delegate_email, + email, + client_secret, + send_attempt, + next_link, + ) + else: + # Send registration emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_registration_mail, + next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + + threepid_send_requests.labels(type="email", reason="register").observe( + send_attempt + ) + + return 200, ret + + +class MsisdnRegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/register/msisdn/requestToken$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + assert_params_in_dict( + body, ["client_secret", "country", "phone_number", "send_attempt"] + ) + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + country = body["country"] + phone_number = body["phone_number"] + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + msisdn = phone_number_to_msisdn(country, phone_number) + + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, + "Phone numbers are not authorized to register on this server", + Codes.THREEPID_DENIED, + ) + + await self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + "msisdn", msisdn + ) + + if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError( + 400, "Phone number is already in use", Codes.THREEPID_IN_USE + ) + + if not self.hs.config.account_threepid_delegate_msisdn: + logger.warning( + "No upstream msisdn account_threepid_delegate configured on the server to " + "handle this request" + ) + raise SynapseError( + 400, "Registration by phone number is not supported on this homeserver" + ) + + ret = await self.identity_handler.requestMsisdnToken( + self.hs.config.account_threepid_delegate_msisdn, + country, + phone_number, + client_secret, + send_attempt, + next_link, + ) + + threepid_send_requests.labels(type="msisdn", reason="register").observe( + send_attempt + ) + + return 200, ret + + +class RegistrationSubmitTokenServlet(RestServlet): + """Handles registration 3PID validation token submission""" + + PATTERNS = client_patterns( + "/registration/(?P[^/]*)/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self._failure_email_template = ( + self.config.email_registration_template_failure_html + ) + + async def on_GET(self, request, medium): + if medium != "email": + raise SynapseError( + 400, "This medium is currently not supported for registration" + ) + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "User registration via email has been disabled due to lack of email config" + ) + raise SynapseError( + 400, "Email-based registration is disabled on this server" + ) + + sid = parse_string(request, "sid", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + token = parse_string(request, "token", required=True) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warning( + "Not redirecting to next_link as it is a local file: address" + ) + else: + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None + + # Otherwise show the success template + html = self.config.email_registration_template_success_html_content + status_code = 200 + except ThreepidValidationError as e: + status_code = e.code + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html = self._failure_email_template.render(**template_vars) + + respond_with_html(request, status_code, html) + + +class UsernameAvailabilityRestServlet(RestServlet): + PATTERNS = client_patterns("/register/available") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.registration_handler = hs.get_registration_handler() + self.ratelimiter = FederationRateLimiter( + hs.get_clock(), + FederationRateLimitConfig( + # Time window of 2s + window_size=2000, + # Artificially delay requests if rate > sleep_limit/window_size + sleep_limit=1, + # Amount of artificial delay to apply + sleep_msec=1000, + # Error with 429 if more than reject_limit requests are queued + reject_limit=1, + # Allow 1 request at a time + concurrent_requests=1, + ), + ) + + async def on_GET(self, request): + if not self.hs.config.enable_registration: + raise SynapseError( + 403, "Registration has been disabled", errcode=Codes.FORBIDDEN + ) + + ip = request.getClientIP() + with self.ratelimiter.ratelimit(ip) as wait_deferred: + await wait_deferred + + username = parse_string(request, "username", required=True) + + await self.registration_handler.check_username(username) + + return 200, {"available": True} + + +class RegisterRestServlet(RestServlet): + PATTERNS = client_patterns("/register$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() + self.registration_handler = hs.get_registration_handler() + self.identity_handler = hs.get_identity_handler() + self.room_member_handler = hs.get_room_member_handler() + self.macaroon_gen = hs.get_macaroon_generator() + self.ratelimiter = hs.get_registration_ratelimiter() + self.password_policy_handler = hs.get_password_policy_handler() + self.clock = hs.get_clock() + self._registration_enabled = self.hs.config.enable_registration + self._msc2918_enabled = hs.config.access_token_lifetime is not None + + self._registration_flows = _calculate_registration_flows( + hs.config, self.auth_handler + ) + + @interactive_auth_handler + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + client_addr = request.getClientIP() + + await self.ratelimiter.ratelimit(None, client_addr, update=False) + + kind = b"user" + if b"kind" in request.args: + kind = request.args[b"kind"][0] + + if kind == b"guest": + ret = await self._do_guest_registration(body, address=client_addr) + return ret + elif kind != b"user": + raise UnrecognizedRequestError( + "Do not understand membership kind: %s" % (kind.decode("utf8"),) + ) + + if self._msc2918_enabled: + # Check if this registration should also issue a refresh token, as + # per MSC2918 + should_issue_refresh_token = parse_boolean( + request, name="org.matrix.msc2918.refresh_token", default=False + ) + else: + should_issue_refresh_token = False + + # Pull out the provided username and do basic sanity checks early since + # the auth layer will store these in sessions. + desired_username = None + if "username" in body: + if not isinstance(body["username"], str) or len(body["username"]) > 512: + raise SynapseError(400, "Invalid username") + desired_username = body["username"] + + # fork off as soon as possible for ASes which have completely + # different registration flows to normal users + + # == Application Service Registration == + if body.get("type") == APP_SERVICE_REGISTRATION_TYPE: + if not self.auth.has_access_token(request): + raise SynapseError( + 400, + "Appservice token must be provided when using a type of m.login.application_service", + ) + + # Verify the AS + self.auth.get_appservice_by_req(request) + + # Set the desired user according to the AS API (which uses the + # 'user' key not 'username'). Since this is a new addition, we'll + # fallback to 'username' if they gave one. + desired_username = body.get("user", desired_username) + + # XXX we should check that desired_username is valid. Currently + # we give appservices carte blanche for any insanity in mxids, + # because the IRC bridges rely on being able to register stupid + # IDs. + + access_token = self.auth.get_access_token_from_request(request) + + if not isinstance(desired_username, str): + raise SynapseError(400, "Desired Username is missing or not a string") + + result = await self._do_appservice_registration( + desired_username, + access_token, + body, + should_issue_refresh_token=should_issue_refresh_token, + ) + + return 200, result + elif self.auth.has_access_token(request): + raise SynapseError( + 400, + "An access token should not be provided on requests to /register (except if type is m.login.application_service)", + ) + + # == Normal User Registration == (everyone else) + if not self._registration_enabled: + raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) + + # For regular registration, convert the provided username to lowercase + # before attempting to register it. This should mean that people who try + # to register with upper-case in their usernames don't get a nasty surprise. + # + # Note that we treat usernames case-insensitively in login, so they are + # free to carry on imagining that their username is CrAzYh4cKeR if that + # keeps them happy. + if desired_username is not None: + desired_username = desired_username.lower() + + # Check if this account is upgrading from a guest account. + guest_access_token = body.get("guest_access_token", None) + + # Pull out the provided password and do basic sanity checks early. + # + # Note that we remove the password from the body since the auth layer + # will store the body in the session and we don't want a plaintext + # password store there. + password = body.pop("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + if "initial_device_display_name" in body and password is None: + # ignore 'initial_device_display_name' if sent without + # a password to work around a client bug where it sent + # the 'initial_device_display_name' param alone, wiping out + # the original registration params + logger.warning("Ignoring initial_device_display_name without password") + del body["initial_device_display_name"] + + session_id = self.auth_handler.get_session_id(body) + registered_user_id = None + password_hash = None + if session_id: + # if we get a registered user id out of here, it means we previously + # registered a user for this session, so we could just return the + # user here. We carry on and go through the auth checks though, + # for paranoia. + registered_user_id = await self.auth_handler.get_session_data( + session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None + ) + # Extract the previously-hashed password from the session. + password_hash = await self.auth_handler.get_session_data( + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None + ) + + # Ensure that the username is valid. + if desired_username is not None: + await self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, + ) + + # Check if the user-interactive authentication flows are complete, if + # not this will raise a user-interactive auth error. + try: + auth_result, params, session_id = await self.auth_handler.check_ui_auth( + self._registration_flows, + request, + body, + "register a new account", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth. + # + # Hash the password and store it with the session since the client + # is not required to provide the password again. + # + # If a password hash was previously stored we will not attempt to + # re-hash and store it for efficiency. This assumes the password + # does not change throughout the authentication flow, but this + # should be fine since the data is meant to be consistent. + if not password_hash and password: + password_hash = await self.auth_handler.hash(password) + await self.auth_handler.set_session_data( + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, + ) + raise + + # Check that we're not trying to register a denied 3pid. + # + # the user-facing checks will probably already have happened in + # /register/email/requestToken when we requested a 3pid, but that's not + # guaranteed. + if auth_result: + for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: + if login_type in auth_result: + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] + + if not check_3pid_allowed(self.hs, medium, address): + raise SynapseError( + 403, + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", + Codes.THREEPID_DENIED, + ) + + if registered_user_id is not None: + logger.info( + "Already registered user ID %r for this session", registered_user_id + ) + # don't re-register the threepids + registered = False + else: + # If we have a password in this request, prefer it. Otherwise, there + # might be a password hash from an earlier request. + if password: + password_hash = await self.auth_handler.hash(password) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + + desired_username = params.get("username", None) + guest_access_token = params.get("guest_access_token", None) + + if desired_username is not None: + desired_username = desired_username.lower() + + threepid = None + if auth_result: + threepid = auth_result.get(LoginType.EMAIL_IDENTITY) + + # Also check that we're not trying to register a 3pid that's already + # been registered. + # + # This has probably happened in /register/email/requestToken as well, + # but if a user hits this endpoint twice then clicks on each link from + # the two activation emails, they would register the same 3pid twice. + for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: + if login_type in auth_result: + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/account.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) + + existing_user_id = await self.store.get_user_id_by_threepid( + medium, address + ) + + if existing_user_id is not None: + raise SynapseError( + 400, + "%s is already in use" % medium, + Codes.THREEPID_IN_USE, + ) + + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + + registered_user_id = await self.registration_handler.register_user( + localpart=desired_username, + password_hash=password_hash, + guest_access_token=guest_access_token, + threepid=threepid, + address=client_addr, + user_agent_ips=entries, + ) + # Necessary due to auth checks prior to the threepid being + # written to the db + if threepid: + if is_threepid_reserved( + self.hs.config.mau_limits_reserved_threepids, threepid + ): + await self.store.upsert_monthly_active_user(registered_user_id) + + # Remember that the user account has been registered (and the user + # ID it was registered with, since it might not have been specified). + await self.auth_handler.set_session_data( + session_id, + UIAuthSessionDataConstants.REGISTERED_USER_ID, + registered_user_id, + ) + + registered = True + + return_dict = await self._create_registration_details( + registered_user_id, + params, + should_issue_refresh_token=should_issue_refresh_token, + ) + + if registered: + await self.registration_handler.post_registration_actions( + user_id=registered_user_id, + auth_result=auth_result, + access_token=return_dict.get("access_token"), + ) + + return 200, return_dict + + async def _do_appservice_registration( + self, username, as_token, body, should_issue_refresh_token: bool = False + ): + user_id = await self.registration_handler.appservice_register( + username, as_token + ) + return await self._create_registration_details( + user_id, + body, + is_appservice_ghost=True, + should_issue_refresh_token=should_issue_refresh_token, + ) + + async def _create_registration_details( + self, + user_id: str, + params: JsonDict, + is_appservice_ghost: bool = False, + should_issue_refresh_token: bool = False, + ): + """Complete registration of newly-registered user + + Allocates device_id if one was not given; also creates access_token. + + Args: + user_id: full canonical @user:id + params: registration parameters, from which we pull device_id, + initial_device_name and inhibit_login + is_appservice_ghost + should_issue_refresh_token: True if this registration should issue + a refresh token alongside the access token. + Returns: + dictionary for response from /register + """ + result = {"user_id": user_id, "home_server": self.hs.hostname} + if not params.get("inhibit_login", False): + device_id = params.get("device_id") + initial_display_name = params.get("initial_device_display_name") + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, + device_id, + initial_display_name, + is_guest=False, + is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, + ) + + result.update({"access_token": access_token, "device_id": device_id}) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + return result + + async def _do_guest_registration(self, params, address=None): + if not self.hs.config.allow_guest_access: + raise SynapseError(403, "Guest access is disabled") + user_id = await self.registration_handler.register_user( + make_guest=True, address=address + ) + + # we don't allow guests to specify their own device_id, because + # we have nowhere to store it. + device_id = synapse.api.auth.GUEST_DEVICE_ID + initial_display_name = params.get("initial_device_display_name") + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, device_id, initial_display_name, is_guest=True + ) + + result = { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + } + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + return 200, result + + +def _calculate_registration_flows( + # technically `config` has to provide *all* of these interfaces, not just one + config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], + auth_handler: AuthHandler, +) -> List[List[str]]: + """Get a suitable flows list for registration + + Args: + config: server configuration + auth_handler: authorization handler + + Returns: a list of supported flows + """ + # FIXME: need a better error than "no auth flow found" for scenarios + # where we required 3PID for registration but the user didn't give one + require_email = "email" in config.registrations_require_3pid + require_msisdn = "msisdn" in config.registrations_require_3pid + + show_msisdn = True + show_email = True + + if config.disable_msisdn_registration: + show_msisdn = False + require_msisdn = False + + enabled_auth_types = auth_handler.get_enabled_auth_types() + if LoginType.EMAIL_IDENTITY not in enabled_auth_types: + show_email = False + if require_email: + raise ConfigError( + "Configuration requires email address at registration, but email " + "validation is not configured" + ) + + if LoginType.MSISDN not in enabled_auth_types: + show_msisdn = False + if require_msisdn: + raise ConfigError( + "Configuration requires msisdn at registration, but msisdn " + "validation is not configured" + ) + + flows = [] + + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + # Add a dummy step here, otherwise if a client completes + # recaptcha first we'll assume they were going for this flow + # and complete the request, when they could have been trying to + # complete one of the flows with email/msisdn auth. + flows.append([LoginType.DUMMY]) + + # only support the email-only flow if we don't require MSISDN 3PIDs + if show_email and not require_msisdn: + flows.append([LoginType.EMAIL_IDENTITY]) + + # only support the MSISDN-only flow if we don't require email 3PIDs + if show_msisdn and not require_email: + flows.append([LoginType.MSISDN]) + + if show_email and show_msisdn: + # always let users provide both MSISDN & email + flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) + + # Prepend m.login.terms to all flows if we're requiring consent + if config.user_consent_at_registration: + for flow in flows: + flow.insert(0, LoginType.TERMS) + + # Prepend recaptcha to all flows if we're requiring captcha + if config.enable_registration_captcha: + for flow in flows: + flow.insert(0, LoginType.RECAPTCHA) + + return flows + + +def register_servlets(hs, http_server): + EmailRegisterRequestTokenRestServlet(hs).register(http_server) + MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) + UsernameAvailabilityRestServlet(hs).register(http_server) + RegistrationSubmitTokenServlet(hs).register(http_server) + RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py new file mode 100644 index 0000000000..0821cd285f --- /dev/null +++ b/synapse/rest/client/relations.py @@ -0,0 +1,381 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This class implements the proposed relation APIs from MSC 1849. + +Since the MSC has not been approved all APIs here are unstable and may change at +any time to reflect changes in the MSC. +""" + +import logging + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.errors import ShadowBanError, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, +) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.relations import ( + AggregationPaginationToken, + PaginationChunk, + RelationPaginationToken, +) +from synapse.util.stringutils import random_string + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class RelationSendServlet(RestServlet): + """Helper API for sending events that have relation data. + + Example API shape to send a 👍 reaction to a room: + + POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D + {} + + { + "event_id": "$foobar" + } + """ + + PATTERN = ( + "/rooms/(?P[^/]*)/send_relation" + "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.event_creation_handler = hs.get_event_creation_handler() + self.txns = HttpTransactionCache(hs) + + def register(self, http_server): + http_server.register_paths( + "POST", + client_patterns(self.PATTERN + "$", releases=()), + self.on_PUT_or_POST, + self.__class__.__name__, + ) + http_server.register_paths( + "PUT", + client_patterns(self.PATTERN + "/(?P[^/]*)$", releases=()), + self.on_PUT, + self.__class__.__name__, + ) + + def on_PUT(self, request, *args, **kwargs): + return self.txns.fetch_or_execute_request( + request, self.on_PUT_or_POST, request, *args, **kwargs + ) + + async def on_PUT_or_POST( + self, request, room_id, parent_id, relation_type, event_type, txn_id=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + if event_type == EventTypes.Member: + # Add relations to a membership is meaningless, so we just deny it + # at the CS API rather than trying to handle it correctly. + raise SynapseError(400, "Cannot send member events with relations") + + content = parse_json_object_from_request(request) + + aggregation_key = parse_string(request, "key", encoding="utf-8") + + content["m.relates_to"] = { + "event_id": parent_id, + "key": aggregation_key, + "rel_type": relation_type, + } + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) + + return 200, {"event_id": event_id} + + +class RelationPaginationServlet(RestServlet): + """API to paginate relations on an event by topological ordering, optionally + filtered by relation type and event type. + """ + + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/relations/(?P[^/]*)" + "(/(?P[^/]*)(/(?P[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + self.event_handler = hs.get_event_handler() + + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + await self.auth.check_user_in_room_or_world_readable( + room_id, requester.user.to_string(), allow_departed_users=True + ) + + # This gets the original event and checks that a) the event exists and + # b) the user is allowed to view it. + event = await self.event_handler.get_event(requester.user, room_id, parent_id) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") + + if event.internal_metadata.is_redacted(): + # If the event is redacted, return an empty list of relations + pagination_chunk = PaginationChunk(chunk=[]) + else: + # Return the relations + from_token = None + if from_token_str: + from_token = RelationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = RelationPaginationToken.from_string(to_token_str) + + pagination_chunk = await self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = await self.store.get_events_as_list( + [c["event_id"] for c in pagination_chunk.chunk] + ) + + now = self.clock.time_msec() + # We set bundle_aggregations to False when retrieving the original + # event because we want the content before relations were applied to + # it. + original_event = await self._event_serializer.serialize_event( + event, now, bundle_aggregations=False + ) + # Similarly, we don't allow relations to be applied to relations, so we + # return the original relations without any aggregations on top of them + # here. + events = await self._event_serializer.serialize_events( + events, now, bundle_aggregations=False + ) + + return_value = pagination_chunk.to_dict() + return_value["chunk"] = events + return_value["original_event"] = original_event + + return 200, return_value + + +class RelationAggregationPaginationServlet(RestServlet): + """API to paginate aggregation groups of relations, e.g. paginate the + types and counts of the reactions on the events. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id} + + { + chunk: [ + { + "type": "m.reaction", + "key": "👍", + "count": 3 + } + ] + } + """ + + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" + "(/(?P[^/]*)(/(?P[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.event_handler = hs.get_event_handler() + + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + await self.auth.check_user_in_room_or_world_readable( + room_id, + requester.user.to_string(), + allow_departed_users=True, + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + event = await self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type not in (RelationTypes.ANNOTATION, None): + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") + + if event.internal_metadata.is_redacted(): + # If the event is redacted, return an empty list of relations + pagination_chunk = PaginationChunk(chunk=[]) + else: + # Return the relations + from_token = None + if from_token_str: + from_token = AggregationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = AggregationPaginationToken.from_string(to_token_str) + + pagination_chunk = await self.store.get_aggregation_groups_for_event( + event_id=parent_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + return 200, pagination_chunk.to_dict() + + +class RelationAggregationGroupPaginationServlet(RestServlet): + """API to paginate within an aggregation group of relations, e.g. paginate + all the 👍 reactions on an event. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 + + { + chunk: [ + { + "type": "m.reaction", + "content": { + "m.relates_to": { + "rel_type": "m.annotation", + "key": "👍" + } + } + }, + ... + ] + } + """ + + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" + "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)$", + releases=(), + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + self.event_handler = hs.get_event_handler() + + async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + await self.auth.check_user_in_room_or_world_readable( + room_id, + requester.user.to_string(), + allow_departed_users=True, + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + await self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type != RelationTypes.ANNOTATION: + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") + + from_token = None + if from_token_str: + from_token = RelationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = RelationPaginationToken.from_string(to_token_str) + + result = await self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + aggregation_key=key, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = await self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = await self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + return 200, return_value + + +def register_servlets(hs, http_server): + RelationSendServlet(hs).register(http_server) + RelationPaginationServlet(hs).register(http_server) + RelationAggregationPaginationServlet(hs).register(http_server) + RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py new file mode 100644 index 0000000000..07ea39a8a3 --- /dev/null +++ b/synapse/rest/client/report_event.py @@ -0,0 +1,68 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from http import HTTPStatus + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class ReportEventRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/report/(?P[^/]*)$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + async def on_POST(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + body = parse_json_object_from_request(request) + + if not isinstance(body.get("reason", ""), str): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'reason' must be a string", + Codes.BAD_JSON, + ) + if not isinstance(body.get("score", 0), int): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'score' must be an integer", + Codes.BAD_JSON, + ) + + await self.store.add_event_report( + room_id=room_id, + event_id=event_id, + user_id=user_id, + reason=body.get("reason"), + content=body, + received_ts=self.clock.time_msec(), + ) + + return 200, {} + + +def register_servlets(hs, http_server): + ReportEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py new file mode 100644 index 0000000000..ed238b2141 --- /dev/null +++ b/synapse/rest/client/room.py @@ -0,0 +1,1152 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" This module contains REST servlets to do with rooms: /rooms/ """ +import logging +import re +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from urllib import parse as urlparse + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + MissingClientTokenError, + ShadowBanError, + SynapseError, +) +from synapse.api.filtering import Filter +from synapse.events.utils import format_event_for_client_v2 +from synapse.http.servlet import ( + ResolveRoomIdMixin, + RestServlet, + assert_params_in_dict, + parse_boolean, + parse_integer, + parse_json_object_from_request, + parse_string, + parse_strings_from_args, +) +from synapse.http.site import SynapseRequest +from synapse.logging.opentracing import set_tag +from synapse.rest.client._base import client_patterns +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.state import StateFilter +from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID +from synapse.util import json_decoder +from synapse.util.stringutils import parse_and_validate_server_name, random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class TransactionRestServlet(RestServlet): + def __init__(self, hs): + super().__init__() + self.txns = HttpTransactionCache(hs) + + +class RoomCreateRestServlet(TransactionRestServlet): + # No PATTERN; we have custom dispatch rules here + + def __init__(self, hs): + super().__init__(hs) + self._room_creation_handler = hs.get_room_creation_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + PATTERNS = "/createRoom" + register_txn_path(self, PATTERNS, http_server) + + def on_PUT(self, request, txn_id): + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request(request, self.on_POST, request) + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + + info, _ = await self._room_creation_handler.create_room( + requester, self.get_room_config(request) + ) + + return 200, info + + def get_room_config(self, request): + user_supplied_config = parse_json_object_from_request(request) + return user_supplied_config + + +# TODO: Needs unit testing for generic events +class RoomStateEventRestServlet(TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + # /room/$roomid/state/$eventtype + no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" + + # /room/$roomid/state/$eventtype/$statekey + state_key = ( + "/rooms/(?P[^/]*)/state/" + "(?P[^/]*)/(?P[^/]*)$" + ) + + http_server.register_paths( + "GET", + client_patterns(state_key, v1=True), + self.on_GET, + self.__class__.__name__, + ) + http_server.register_paths( + "PUT", + client_patterns(state_key, v1=True), + self.on_PUT, + self.__class__.__name__, + ) + http_server.register_paths( + "GET", + client_patterns(no_state_key, v1=True), + self.on_GET_no_state_key, + self.__class__.__name__, + ) + http_server.register_paths( + "PUT", + client_patterns(no_state_key, v1=True), + self.on_PUT_no_state_key, + self.__class__.__name__, + ) + + def on_GET_no_state_key(self, request, room_id, event_type): + return self.on_GET(request, room_id, event_type, "") + + def on_PUT_no_state_key(self, request, room_id, event_type): + return self.on_PUT(request, room_id, event_type, "") + + async def on_GET(self, request, room_id, event_type, state_key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + format = parse_string( + request, "format", default="content", allowed_values=["content", "event"] + ) + + msg_handler = self.message_handler + data = await msg_handler.get_room_data( + user_id=requester.user.to_string(), + room_id=room_id, + event_type=event_type, + state_key=state_key, + ) + + if not data: + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + + if format == "event": + event = format_event_for_client_v2(data.get_dict()) + return 200, event + elif format == "content": + return 200, data.get_dict()["content"] + + async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + requester = await self.auth.get_user_by_req(request) + + if txn_id: + set_tag("txn_id", txn_id) + + content = parse_json_object_from_request(request) + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + if state_key is not None: + event_dict["state_key"] = state_key + + try: + if event_type == EventTypes.Member: + membership = content.get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + requester, + target=UserID.from_string(state_key), + room_id=room_id, + action=membership, + content=content, + ) + else: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) + + set_tag("event_id", event_id) + ret = {"event_id": event_id} + return 200, ret + + +# TODO: Needs unit testing for generic events + feedback +class RoomSendEventRestServlet(TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + self.event_creation_handler = hs.get_event_creation_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + # /rooms/$roomid/send/$event_type[/$txn_id] + PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" + register_txn_path(self, PATTERNS, http_server, with_get=True) + + async def on_POST(self, request, room_id, event_type, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + content = parse_json_object_from_request(request) + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + if b"ts" in request.args and requester.app_service: + event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) + + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) + + set_tag("event_id", event_id) + return 200, {"event_id": event_id} + + def on_GET(self, request, room_id, event_type, txn_id): + return 200, "Not implemented" + + def on_PUT(self, request, room_id, event_type, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, event_type, txn_id + ) + + +# TODO: Needs unit testing for room ID + alias joins +class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up + self.auth = hs.get_auth() + + def register(self, http_server): + # /join/$room_identifier[/$txn_id] + PATTERNS = "/join/(?P[^/]*)" + register_txn_path(self, PATTERNS, http_server) + + async def on_POST( + self, + request: SynapseRequest, + room_identifier: str, + txn_id: Optional[str] = None, + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + try: + content = parse_json_object_from_request(request) + except Exception: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} + + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "server_name", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) + + await self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + content=content, + third_party_signed=content.get("third_party_signed", None), + ) + + return 200, {"room_id": room_id} + + def on_PUT(self, request, room_identifier, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_identifier, txn_id + ) + + +# TODO: Needs unit testing +class PublicRoomListRestServlet(TransactionRestServlet): + PATTERNS = client_patterns("/publicRooms$", v1=True) + + def __init__(self, hs): + super().__init__(hs) + self.hs = hs + self.auth = hs.get_auth() + + async def on_GET(self, request): + server = parse_string(request, "server") + + try: + await self.auth.get_user_by_req(request, allow_guest=True) + except InvalidClientCredentialsError as e: + # Option to allow servers to require auth when accessing + # /publicRooms via CS API. This is especially helpful in private + # federations. + if not self.hs.config.allow_public_rooms_without_auth: + raise + + # We allow people to not be authed if they're just looking at our + # room list, but require auth when we proxy the request. + # In both cases we call the auth function, as that has the side + # effect of logging who issued this request if an access token was + # provided. + if server: + raise e + + limit: Optional[int] = parse_integer(request, "limit", 0) + since_token = parse_string(request, "since") + + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + + handler = self.hs.get_room_list_handler() + if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, + "Invalid server name: %s" % (server,), + Codes.INVALID_PARAM, + ) + + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) + else: + data = await handler.get_local_public_room_list( + limit=limit, since_token=since_token + ) + + return 200, data + + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) + + server = parse_string(request, "server") + content = parse_json_object_from_request(request) + + limit: Optional[int] = int(content.get("limit", 100)) + since_token = content.get("since", None) + search_filter = content.get("filter", None) + + include_all_networks = content.get("include_all_networks", False) + third_party_instance_id = content.get("third_party_instance_id", None) + + if include_all_networks: + network_tuple = None + if third_party_instance_id is not None: + raise SynapseError( + 400, "Can't use include_all_networks with an explicit network" + ) + elif third_party_instance_id is None: + network_tuple = ThirdPartyInstanceID(None, None) + else: + network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + + handler = self.hs.get_room_list_handler() + if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, + "Invalid server name: %s" % (server,), + Codes.INVALID_PARAM, + ) + + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + + else: + data = await handler.get_local_public_room_list( + limit=limit, + since_token=since_token, + search_filter=search_filter, + network_tuple=network_tuple, + ) + + return 200, data + + +# TODO: Needs unit testing +class RoomMemberListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/members$", v1=True) + + def __init__(self, hs): + super().__init__() + self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + # TODO support Pagination stream API (limit/tokens) + requester = await self.auth.get_user_by_req(request, allow_guest=True) + handler = self.message_handler + + # request the state as of a given event, as identified by a stream token, + # for consistency with /messages etc. + # useful for getting the membership in retrospect as of a given /sync + # response. + at_token_string = parse_string(request, "at") + if at_token_string is None: + at_token = None + else: + at_token = await StreamToken.from_string(self.store, at_token_string) + + # let you filter down on particular memberships. + # XXX: this may not be the best shape for this API - we could pass in a filter + # instead, except filters aren't currently aware of memberships. + # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details. + membership = parse_string(request, "membership") + not_membership = parse_string(request, "not_membership") + + events = await handler.get_state_events( + room_id=room_id, + user_id=requester.user.to_string(), + at_token=at_token, + state_filter=StateFilter.from_types([(EventTypes.Member, None)]), + ) + + chunk = [] + + for event in events: + if (membership and event["content"].get("membership") != membership) or ( + not_membership and event["content"].get("membership") == not_membership + ): + continue + chunk.append(event) + + return 200, {"chunk": chunk} + + +# deprecated in favour of /members?membership=join? +# except it does custom AS logic and has a simpler return format +class JoinedRoomMemberListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/joined_members$", v1=True) + + def __init__(self, hs): + super().__init__() + self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + users_with_profile = await self.message_handler.get_joined_members( + requester, room_id + ) + + return 200, {"joined": users_with_profile} + + +# TODO: Needs better unit testing +class RoomMessageListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/messages$", v1=True) + + def __init__(self, hs): + super().__init__() + self.pagination_handler = hs.get_pagination_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) + as_client_event = b"raw" not in request.args + filter_str = parse_string(request, "filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + if ( + event_filter + and event_filter.filter_json.get("event_format", "client") + == "federation" + ): + as_client_event = False + else: + event_filter = None + + msgs = await self.pagination_handler.get_messages( + room_id=room_id, + requester=requester, + pagin_config=pagination_config, + as_client_event=as_client_event, + event_filter=event_filter, + ) + + return 200, msgs + + +# TODO: Needs unit testing +class RoomStateRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/state$", v1=True) + + def __init__(self, hs): + super().__init__() + self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + # Get all the current state for this room + events = await self.message_handler.get_state_events( + room_id=room_id, + user_id=requester.user.to_string(), + is_guest=requester.is_guest, + ) + return 200, events + + +# TODO: Needs unit testing +class RoomInitialSyncRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/initialSync$", v1=True) + + def __init__(self, hs): + super().__init__() + self.initial_sync_handler = hs.get_initial_sync_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + pagination_config = await PaginationConfig.from_request(self.store, request) + content = await self.initial_sync_handler.room_initial_sync( + room_id=room_id, requester=requester, pagin_config=pagination_config + ) + return 200, content + + +class RoomEventServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/event/(?P[^/]*)$", v1=True + ) + + def __init__(self, hs): + super().__init__() + self.clock = hs.get_clock() + self.event_handler = hs.get_event_handler() + self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + try: + event = await self.event_handler.get_event( + requester.user, room_id, event_id + ) + except AuthError: + # This endpoint is supposed to return a 404 when the requester does + # not have permission to access the event + # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + + time_now = self.clock.time_msec() + if event: + event = await self._event_serializer.serialize_event(event, time_now) + return 200, event + + return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + + +class RoomEventContextServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/context/(?P[^/]*)$", v1=True + ) + + def __init__(self, hs): + super().__init__() + self.clock = hs.get_clock() + self.room_context_handler = hs.get_room_context_handler() + self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + limit = parse_integer(request, "limit", default=10) + + # picking the API shape for symmetry with /messages + filter_str = parse_string(request, "filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) + event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) + else: + event_filter = None + + results = await self.room_context_handler.get_event_context( + requester, room_id, event_id, limit, event_filter + ) + + if not results: + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + + time_now = self.clock.time_msec() + results["events_before"] = await self._event_serializer.serialize_events( + results["events_before"], time_now + ) + results["event"] = await self._event_serializer.serialize_event( + results["event"], time_now + ) + results["events_after"] = await self._event_serializer.serialize_events( + results["events_after"], time_now + ) + results["state"] = await self._event_serializer.serialize_events( + results["state"], + time_now, + # No need to bundle aggregations for state events + bundle_aggregations=False, + ) + + return 200, results + + +class RoomForgetRestServlet(TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + PATTERNS = "/rooms/(?P[^/]*)/forget" + register_txn_path(self, PATTERNS, http_server) + + async def on_POST(self, request, room_id, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=False) + + await self.room_member_handler.forget(user=requester.user, room_id=room_id) + + return 200, {} + + def on_PUT(self, request, room_id, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, txn_id + ) + + +# TODO: Needs unit testing +class RoomMembershipRestServlet(TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + # /rooms/$roomid/[invite|join|leave] + PATTERNS = ( + "/rooms/(?P[^/]*)/" + "(?Pjoin|invite|leave|ban|unban|kick)" + ) + register_txn_path(self, PATTERNS, http_server) + + async def on_POST(self, request, room_id, membership_action, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + if requester.is_guest and membership_action not in { + Membership.JOIN, + Membership.LEAVE, + }: + raise AuthError(403, "Guest access not allowed") + + try: + content = parse_json_object_from_request(request) + except Exception: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} + + if membership_action == "invite" and self._has_3pid_invite_keys(content): + try: + await self.room_member_handler.do_3pid_invite( + room_id, + requester.user, + content["medium"], + content["address"], + content["id_server"], + requester, + txn_id, + content.get("id_access_token"), + ) + except ShadowBanError: + # Pretend the request succeeded. + pass + return 200, {} + + target = requester.user + if membership_action in ["invite", "ban", "unban", "kick"]: + assert_params_in_dict(content, ["user_id"]) + target = UserID.from_string(content["user_id"]) + + event_content = None + if "reason" in content: + event_content = {"reason": content["reason"]} + + try: + await self.room_member_handler.update_membership( + requester=requester, + target=target, + room_id=room_id, + action=membership_action, + txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), + content=event_content, + ) + except ShadowBanError: + # Pretend the request succeeded. + pass + + return_value = {} + + if membership_action == "join": + return_value["room_id"] = room_id + + return 200, return_value + + def _has_3pid_invite_keys(self, content): + for key in {"id_server", "medium", "address"}: + if key not in content: + return False + return True + + def on_PUT(self, request, room_id, membership_action, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, membership_action, txn_id + ) + + +class RoomRedactEventRestServlet(TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + self.event_creation_handler = hs.get_event_creation_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" + register_txn_path(self, PATTERNS, http_server) + + async def on_POST(self, request, room_id, event_id, txn_id=None): + requester = await self.auth.get_user_by_req(request) + content = parse_json_object_from_request(request) + + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) + + set_tag("event_id", event_id) + return 200, {"event_id": event_id} + + def on_PUT(self, request, room_id, event_id, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, event_id, txn_id + ) + + +class RoomTypingRestServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/typing/(?P[^/]*)$", v1=True + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.presence_handler = hs.get_presence_handler() + self.auth = hs.get_auth() + + # If we're not on the typing writer instance we should scream if we get + # requests. + self._is_typing_writer = ( + hs.config.worker.writers.typing == hs.get_instance_name() + ) + + async def on_PUT(self, request, room_id, user_id): + requester = await self.auth.get_user_by_req(request) + + if not self._is_typing_writer: + raise Exception("Got /typing request on instance that is not typing writer") + + room_id = urlparse.unquote(room_id) + target_user = UserID.from_string(urlparse.unquote(user_id)) + + content = parse_json_object_from_request(request) + + await self.presence_handler.bump_presence_active_time(requester.user) + + # Limit timeout to stop people from setting silly typing timeouts. + timeout = min(content.get("timeout", 30000), 120000) + + # Defer getting the typing handler since it will raise on workers. + typing_handler = self.hs.get_typing_writer_handler() + + try: + if content["typing"]: + await typing_handler.started_typing( + target_user=target_user, + requester=requester, + room_id=room_id, + timeout=timeout, + ) + else: + await typing_handler.stopped_typing( + target_user=target_user, requester=requester, room_id=room_id + ) + except ShadowBanError: + # Pretend this worked without error. + pass + + return 200, {} + + +class RoomAliasListServlet(RestServlet): + PATTERNS = [ + re.compile( + r"^/_matrix/client/unstable/org\.matrix\.msc2432" + r"/rooms/(?P[^/]*)/aliases" + ), + ] + list(client_patterns("/rooms/(?P[^/]*)/aliases$", unstable=False)) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.directory_handler = hs.get_directory_handler() + + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + alias_list = await self.directory_handler.get_aliases_for_room( + requester, room_id + ) + + return 200, {"aliases": alias_list} + + +class SearchRestServlet(RestServlet): + PATTERNS = client_patterns("/search$", v1=True) + + def __init__(self, hs): + super().__init__() + self.search_handler = hs.get_search_handler() + self.auth = hs.get_auth() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + + batch = parse_string(request, "next_batch") + results = await self.search_handler.search(requester.user, content, batch) + + return 200, results + + +class JoinedRoomsRestServlet(RestServlet): + PATTERNS = client_patterns("/joined_rooms$", v1=True) + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) + return 200, {"joined_rooms": list(room_ids)} + + +def register_txn_path(servlet, regex_string, http_server, with_get=False): + """Registers a transaction-based path. + + This registers two paths: + PUT regex_string/$txnid + POST regex_string + + Args: + regex_string (str): The regex string to register. Must NOT have a + trailing $ as this string will be appended to. + http_server : The http_server to register paths with. + with_get: True to also register respective GET paths for the PUTs. + """ + http_server.register_paths( + "POST", + client_patterns(regex_string + "$", v1=True), + servlet.on_POST, + servlet.__class__.__name__, + ) + http_server.register_paths( + "PUT", + client_patterns(regex_string + "/(?P[^/]*)$", v1=True), + servlet.on_PUT, + servlet.__class__.__name__, + ) + if with_get: + http_server.register_paths( + "GET", + client_patterns(regex_string + "/(?P[^/]*)$", v1=True), + servlet.on_GET, + servlet.__class__.__name__, + ) + + +class RoomSpaceSummaryRestServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2946" + "/rooms/(?P[^/]*)/spaces$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._room_summary_handler = hs.get_room_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + + max_rooms_per_space = parse_integer(request, "max_rooms_per_space") + if max_rooms_per_space is not None and max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) + + return 200, await self._room_summary_handler.get_space_summary( + requester.user.to_string(), + room_id, + suggested_only=parse_boolean(request, "suggested_only", default=False), + max_rooms_per_space=max_rooms_per_space, + ) + + # TODO When switching to the stable endpoint, remove the POST handler. + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + content = parse_json_object_from_request(request) + + suggested_only = content.get("suggested_only", False) + if not isinstance(suggested_only, bool): + raise SynapseError( + 400, "'suggested_only' must be a boolean", Codes.BAD_JSON + ) + + max_rooms_per_space = content.get("max_rooms_per_space") + if max_rooms_per_space is not None: + if not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON + ) + if max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) + + return 200, await self._room_summary_handler.get_space_summary( + requester.user.to_string(), + room_id, + suggested_only=suggested_only, + max_rooms_per_space=max_rooms_per_space, + ) + + +class RoomHierarchyRestServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2946" + "/rooms/(?P[^/]*)/hierarchy$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._room_summary_handler = hs.get_room_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + + max_depth = parse_integer(request, "max_depth") + if max_depth is not None and max_depth < 0: + raise SynapseError( + 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON + ) + + limit = parse_integer(request, "limit") + if limit is not None and limit <= 0: + raise SynapseError( + 400, "'limit' must be a positive integer", Codes.BAD_JSON + ) + + return 200, await self._room_summary_handler.get_room_hierarchy( + requester.user.to_string(), + room_id, + suggested_only=parse_boolean(request, "suggested_only", default=False), + max_depth=max_depth, + limit=limit, + from_token=parse_string(request, "from"), + ) + + +class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/im.nheko.summary" + "/rooms/(?P[^/]*)/summary$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self._auth = hs.get_auth() + self._room_summary_handler = hs.get_room_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: + try: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + requester_user_id: Optional[str] = requester.user.to_string() + except MissingClientTokenError: + # auth is optional + requester_user_id = None + + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "via", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) + + return 200, await self._room_summary_handler.get_room_summary( + requester_user_id, + room_id, + remote_room_hosts, + ) + + +def register_servlets(hs: "HomeServer", http_server, is_worker=False): + RoomStateEventRestServlet(hs).register(http_server) + RoomMemberListRestServlet(hs).register(http_server) + JoinedRoomMemberListRestServlet(hs).register(http_server) + RoomMessageListRestServlet(hs).register(http_server) + JoinRoomAliasServlet(hs).register(http_server) + RoomMembershipRestServlet(hs).register(http_server) + RoomSendEventRestServlet(hs).register(http_server) + PublicRoomListRestServlet(hs).register(http_server) + RoomStateRestServlet(hs).register(http_server) + RoomRedactEventRestServlet(hs).register(http_server) + RoomTypingRestServlet(hs).register(http_server) + RoomEventContextServlet(hs).register(http_server) + RoomSpaceSummaryRestServlet(hs).register(http_server) + RoomHierarchyRestServlet(hs).register(http_server) + if hs.config.experimental.msc3266_enabled: + RoomSummaryRestServlet(hs).register(http_server) + RoomEventServlet(hs).register(http_server) + JoinedRoomsRestServlet(hs).register(http_server) + RoomAliasListServlet(hs).register(http_server) + SearchRestServlet(hs).register(http_server) + + # Some servlets only get registered for the main process. + if not is_worker: + RoomCreateRestServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) + + +def register_deprecated_servlets(hs, http_server): + RoomInitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py new file mode 100644 index 0000000000..3172aba605 --- /dev/null +++ b/synapse/rest/client/room_batch.py @@ -0,0 +1,441 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.appservice import ApplicationService +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, + parse_strings_from_args, +) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import Requester, UserID, create_requester +from synapse.util.stringutils import random_string + +logger = logging.getLogger(__name__) + + +class RoomBatchSendEventRestServlet(RestServlet): + """ + API endpoint which can insert a chunk of events historically back in time + next to the given `prev_event`. + + `chunk_id` comes from `next_chunk_id `in the response of the batch send + endpoint and is derived from the "insertion" events added to each chunk. + It's not required for the first batch send. + + `state_events_at_start` is used to define the historical state events + needed to auth the events like join events. These events will float + outside of the normal DAG as outlier's and won't be visible in the chat + history which also allows us to insert multiple chunks without having a bunch + of `@mxid joined the room` noise between each chunk. + + `events` is chronological chunk/list of events you want to insert. + There is a reverse-chronological constraint on chunks so once you insert + some messages, you can only insert older ones after that. + tldr; Insert chunks from your most recent history -> oldest history. + + POST /_matrix/client/unstable/org.matrix.msc2716/rooms//batch_send?prev_event=&chunk_id= + { + "events": [ ... ], + "state_events_at_start": [ ... ] + } + """ + + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2716" + "/rooms/(?P[^/]*)/batch_send$" + ), + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.store = hs.get_datastore() + self.state_store = hs.get_storage().state + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) + + async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + ( + most_recent_prev_event_id, + most_recent_prev_event_depth, + ) = await self.store.get_max_depth_of(prev_event_ids) + + # We want to insert the historical event after the `prev_event` but before the successor event + # + # We inherit depth from the successor event instead of the `prev_event` + # because events returned from `/messages` are first sorted by `topological_ordering` + # which is just the `depth` and then tie-break with `stream_ordering`. + # + # We mark these inserted historical events as "backfilled" which gives them a + # negative `stream_ordering`. If we use the same depth as the `prev_event`, + # then our historical event will tie-break and be sorted before the `prev_event` + # when it should come after. + # + # We want to use the successor event depth so they appear after `prev_event` because + # it has a larger `depth` but before the successor event because the `stream_ordering` + # is negative before the successor event. + successor_event_ids = await self.store.get_successor_events( + [most_recent_prev_event_id] + ) + + # If we can't find any successor events, then it's a forward extremity of + # historical messages and we can just inherit from the previous historical + # event which we can already assume has the correct depth where we want + # to insert into. + if not successor_event_ids: + depth = most_recent_prev_event_depth + else: + ( + _, + oldest_successor_depth, + ) = await self.store.get_min_depth_of(successor_event_ids) + + depth = oldest_successor_depth + + return depth + + def _create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ): + """Creates an event dict for an "insertion" event with the proper fields + and a random chunk ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + Tuple of event ID and stream ordering position + """ + + next_chunk_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + + async def _create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user. + + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=False) + + if not requester.app_service: + raise AuthError( + 403, + "Only application services can use the /batchsend endpoint", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["state_events_at_start", "events"]) + + prev_events_from_query = parse_strings_from_args(request.args, "prev_event") + chunk_id_from_query = parse_string(request, "chunk_id") + + if prev_events_from_query is None: + raise SynapseError( + 400, + "prev_event query parameter is required when inserting historical messages back in time", + errcode=Codes.MISSING_PARAM, + ) + + # For the event we are inserting next to (`prev_events_from_query`), + # find the most recent auth events (derived from state events) that + # allowed that message to be sent. We will use that as a base + # to auth our historical messages against. + ( + most_recent_prev_event_id, + _, + ) = await self.store.get_max_depth_of(prev_events_from_query) + # mapping from (type, state_key) -> state_event_id + prev_state_map = await self.state_store.get_state_ids_for_event( + most_recent_prev_event_id + ) + # List of state event ID's + prev_state_ids = list(prev_state_map.values()) + auth_event_ids = prev_state_ids + + state_events_at_start = [] + for state_event in body["state_events_at_start"]: + assert_params_in_dict( + state_event, ["type", "origin_server_ts", "content", "sender"] + ) + + logger.debug( + "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", + state_event, + auth_event_ids, + ) + + event_dict = { + "type": state_event["type"], + "origin_server_ts": state_event["origin_server_ts"], + "content": state_event["content"], + "room_id": room_id, + "sender": state_event["sender"], + "state_key": state_event["state_key"], + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + # Make the state events float off on their own + fake_prev_event_id = "$" + random_string(43) + + # TODO: This is pretty much the same as some other code to handle inserting state in this file + if event_dict["type"] == EventTypes.Member: + membership = event_dict["content"].get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + target=UserID.from_string(event_dict["state_key"]), + room_id=room_id, + action=membership, + content=event_dict["content"], + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + else: + # TODO: Add some complement tests that adds state that is not member joins + # and will use this code path. Maybe we only want to support join state events + # and can get rid of this `else`? + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + event_dict, + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + event_id = event.event_id + + state_events_at_start.append(event_id) + auth_event_ids.append(event_id) + + events_to_create = body["events"] + + inherited_depth = await self._inherit_depth_from_prev_ids( + prev_events_from_query + ) + + # Figure out which chunk to connect to. If they passed in + # chunk_id_from_query let's use it. The chunk ID passed in comes + # from the chunk_id in the "insertion" event from the previous chunk. + last_event_in_chunk = events_to_create[-1] + chunk_id_to_connect_to = chunk_id_from_query + base_insertion_event = None + if chunk_id_from_query: + # All but the first base insertion event should point at a fake + # event, which causes the HS to ask for the state at the start of + # the chunk later. + prev_event_ids = [fake_prev_event_id] + # TODO: Verify the chunk_id_from_query corresponds to an insertion event + pass + # Otherwise, create an insertion event to act as a starting point. + # + # We don't always have an insertion event to start hanging more history + # off of (ideally there would be one in the main DAG, but that's not the + # case if we're wanting to add history to e.g. existing rooms without + # an insertion event), in which case we just create a new insertion event + # that can then get pointed to by a "marker" event later. + else: + prev_event_ids = prev_events_from_query + + base_insertion_event_dict = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_chunk["origin_server_ts"], + ) + base_insertion_event_dict["prev_events"] = prev_event_ids.copy() + + ( + base_insertion_event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + base_insertion_event_dict["sender"], + requester.app_service, + ), + base_insertion_event_dict, + prev_event_ids=base_insertion_event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + + chunk_id_to_connect_to = base_insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ] + + # Connect this current chunk to the insertion event from the previous chunk + chunk_event = { + "type": EventTypes.MSC2716_CHUNK, + "sender": requester.user.to_string(), + "room_id": room_id, + "content": { + EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, + # Since the chunk event is put at the end of the chunk, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_chunk["origin_server_ts"], + } + # Add the chunk event to the end of the chunk (newest-in-time) + events_to_create.append(chunk_event) + + # Add an "insertion" event to the start of each chunk (next to the oldest-in-time + # event in the chunk) so the next chunk can be connected to this one. + insertion_event = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + # Since the insertion event is put at the start of the chunk, + # where the oldest-in-time event is, copy the origin_server_ts from + # the first event we're inserting + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) + # Prepend the insertion event to the start of the chunk (oldest-in-time) + events_to_create = [insertion_event] + events_to_create + + event_ids = [] + events_to_persist = [] + for ev in events_to_create: + assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) + + event_dict = { + "type": ev["type"], + "origin_server_ts": ev["origin_server_ts"], + "content": ev["content"], + "room_id": room_id, + "sender": ev["sender"], # requester.user.to_string(), + "prev_events": prev_event_ids.copy(), + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + event, context = await self.event_creation_handler.create_event( + await self._create_requester_for_user_id_from_app_service( + ev["sender"], requester.app_service + ), + event_dict, + prev_event_ids=event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + logger.debug( + "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", + event, + prev_event_ids, + auth_event_ids, + ) + + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, + ) + + events_to_persist.append((event, context)) + event_id = event.event_id + + event_ids.append(event_id) + prev_event_ids = [event_id] + + # Persist events in reverse-chronological order so they have the + # correct stream_ordering as they are backfilled (which decrements). + # Events are sorted by (topological_ordering, stream_ordering) + # where topological_ordering is just depth. + for (event, context) in reversed(events_to_persist): + ev = await self.event_creation_handler.handle_new_client_event( + await self._create_requester_for_user_id_from_app_service( + event["sender"], requester.app_service + ), + event=event, + context=context, + ) + + # Add the base_insertion_event to the bottom of the list we return + if base_insertion_event is not None: + event_ids.append(base_insertion_event.event_id) + + return 200, { + "state_events": state_events_at_start, + "events": event_ids, + "next_chunk_id": insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ], + } + + def on_GET(self, request, room_id): + return 501, "Not implemented" + + def on_PUT(self, request, room_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id + ) + + +def register_servlets(hs, http_server): + msc2716_enabled = hs.config.experimental.msc2716_enabled + + if msc2716_enabled: + RoomBatchSendEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py new file mode 100644 index 0000000000..263596be86 --- /dev/null +++ b/synapse/rest/client/room_keys.py @@ -0,0 +1,391 @@ +# Copyright 2017, 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_string, +) + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class RoomKeysServlet(RestServlet): + PATTERNS = client_patterns( + "/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + async def on_PUT(self, request, room_id, session_id): + """ + Uploads one or more encrypted E2E room keys for backup purposes. + room_id: the ID of the room the keys are for (optional) + session_id: the ID for the E2E room keys for the room (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /room_keys/version API. + + Each session has: + * first_message_index: a numeric index indicating the oldest message + encrypted by this session. + * forwarded_count: how many times the uploading client claims this key + has been shared (forwarded) + * is_verified: whether the client that uploaded the keys claims they + were sent by a device which they've verified + * session_data: base64-encrypted data describing the session. + + Returns 200 OK on success with body {} + Returns 403 Forbidden if the version in question is not the most recently + created version (i.e. if this is an old client trying to write to a stale backup) + Returns 404 Not Found if the version in question doesn't exist + + The API is designed to be otherwise agnostic to the room_key encryption + algorithm being used. Sessions are merged with existing ones in the + backup using the heuristics: + * is_verified sessions always win over unverified sessions + * older first_message_index always win over newer sessions + * lower forwarded_count always wins over higher forwarded_count + + We trust the clients not to lie and corrupt their own backups. + It also means that if your access_token is stolen, the attacker could + delete your backup. + + POST /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 + Content-Type: application/json + + { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + + Or... + + POST /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 + Content-Type: application/json + + { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + + Or... + + POST /room_keys/keys?version=1 HTTP/1.1 + Content-Type: application/json + + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + version = parse_string(request, "version") + + if session_id: + body = {"sessions": {session_id: body}} + + if room_id: + body = {"rooms": {room_id: body}} + + ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret + + async def on_GET(self, request, room_id, session_id): + """ + Retrieves one or more encrypted E2E room keys for backup purposes. + Symmetric with the PUT version of the API. + + room_id: the ID of the room to retrieve the keys for (optional) + session_id: the ID for the E2E room keys to retrieve the keys for (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /change_secret API. + + Returns as follows: + + GET /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 + { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + + Or... + + GET /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 + { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + + Or... + + GET /room_keys/keys?version=1 HTTP/1.1 + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = parse_string(request, "version", required=True) + + room_keys = await self.e2e_room_keys_handler.get_room_keys( + user_id, version, room_id, session_id + ) + + # Convert room_keys to the right format to return. + if session_id: + # If the client requests a specific session, but that session was + # not backed up, then return an M_NOT_FOUND. + if room_keys["rooms"] == {}: + raise NotFoundError("No room_keys found") + else: + room_keys = room_keys["rooms"][room_id]["sessions"][session_id] + elif room_id: + # If the client requests all sessions from a room, but no sessions + # are found, then return an empty result rather than an error, so + # that clients don't have to handle an error condition, and an + # empty result is valid. (Similarly if the client requests all + # sessions from the backup, but in that case, room_keys is already + # in the right format, so we don't need to do anything about it.) + if room_keys["rooms"] == {}: + room_keys = {"sessions": {}} + else: + room_keys = room_keys["rooms"][room_id] + + return 200, room_keys + + async def on_DELETE(self, request, room_id, session_id): + """ + Deletes one or more encrypted E2E room keys for a user for backup purposes. + + DELETE /room_keys/keys/!abc:matrix.org/c0ff33?version=1 + HTTP/1.1 200 OK + {} + + room_id: the ID of the room whose keys to delete (optional) + session_id: the ID for the E2E session to delete (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /change_secret API. + """ + + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = parse_string(request, "version") + + ret = await self.e2e_room_keys_handler.delete_room_keys( + user_id, version, room_id, session_id + ) + return 200, ret + + +class RoomKeysNewVersionServlet(RestServlet): + PATTERNS = client_patterns("/room_keys/version$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + async def on_POST(self, request): + """ + Create a new backup version for this user's room_keys with the given + info. The version is allocated by the server and returned to the user + in the response. This API is intended to be used whenever the user + changes the encryption key for their backups, ensuring that backups + encrypted with different keys don't collide. + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + The algorithm passed in the version info is a reverse-DNS namespaced + identifier to describe the format of the encrypted backupped keys. + + The auth_data is { user_id: "user_id", nonce: } + encrypted using the algorithm and current encryption key described above. + + POST /room_keys/version + Content-Type: application/json + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + + HTTP/1.1 200 OK + Content-Type: application/json + { + "version": 12345 + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + info = parse_json_object_from_request(request) + + new_version = await self.e2e_room_keys_handler.create_version(user_id, info) + return 200, {"version": new_version} + + # we deliberately don't have a PUT /version, as these things really should + # be immutable to avoid people footgunning + + +class RoomKeysVersionServlet(RestServlet): + PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + async def on_GET(self, request, version): + """ + Retrieve the version information about a given version of the user's + room_keys backup. If the version part is missing, returns info about the + most current backup version (if any) + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + Returns 404 if the given version does not exist. + + GET /room_keys/version/12345 HTTP/1.1 + { + "version": "12345", + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + try: + info = await self.e2e_room_keys_handler.get_version_info(user_id, version) + except SynapseError as e: + if e.code == 404: + raise SynapseError(404, "No backup found", Codes.NOT_FOUND) + return 200, info + + async def on_DELETE(self, request, version): + """ + Delete the information about a given version of the user's + room_keys backup. If the version part is missing, deletes the most + current backup version (if any). Doesn't delete the actual room data. + + DELETE /room_keys/version/12345 HTTP/1.1 + HTTP/1.1 200 OK + {} + """ + if version is None: + raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) + + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + await self.e2e_room_keys_handler.delete_version(user_id, version) + return 200, {} + + async def on_PUT(self, request, version): + """ + Update the information about a given version of the user's room_keys backup. + + POST /room_keys/version/12345 HTTP/1.1 + Content-Type: application/json + { + "algorithm": "m.megolm_backup.v1", + "auth_data": { + "public_key": "abcdefg", + "signatures": { + "ed25519:something": "hijklmnop" + } + }, + "version": "12345" + } + + HTTP/1.1 200 OK + Content-Type: application/json + {} + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + info = parse_json_object_from_request(request) + + if version is None: + raise SynapseError( + 400, "No version specified to update", Codes.MISSING_PARAM + ) + + await self.e2e_room_keys_handler.update_version(user_id, version, info) + return 200, {} + + +def register_servlets(hs, http_server): + RoomKeysServlet(hs).register(http_server) + RoomKeysVersionServlet(hs).register(http_server) + RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py new file mode 100644 index 0000000000..6d1b083acb --- /dev/null +++ b/synapse/rest/client/room_upgrade_rest_servlet.py @@ -0,0 +1,88 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, ShadowBanError, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.util import stringutils + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class RoomUpgradeRestServlet(RestServlet): + """Handler for room upgrade requests. + + Handles requests of the form: + + POST /_matrix/client/r0/rooms/$roomid/upgrade HTTP/1.1 + Content-Type: application/json + + { + "new_version": "2", + } + + Creates a new room and shuts down the old one. Returns the ID of the new room. + + Args: + hs (synapse.server.HomeServer): + """ + + PATTERNS = client_patterns( + # /rooms/$roomid/upgrade + "/rooms/(?P[^/]*)/upgrade$" + ) + + def __init__(self, hs): + super().__init__() + self._hs = hs + self._room_creation_handler = hs.get_room_creation_handler() + self._auth = hs.get_auth() + + async def on_POST(self, request, room_id): + requester = await self._auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + assert_params_in_dict(content, ("new_version",)) + + new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) + if new_version is None: + raise SynapseError( + 400, + "Your homeserver does not support this room version", + Codes.UNSUPPORTED_ROOM_VERSION, + ) + + try: + new_room_id = await self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) + except ShadowBanError: + # Generate a random room ID. + new_room_id = stringutils.random_string(18) + + ret = {"replacement_room": new_room_id} + + return 200, ret + + +def register_servlets(hs, http_server): + RoomUpgradeRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py new file mode 100644 index 0000000000..d537d811d8 --- /dev/null +++ b/synapse/rest/client/sendtodevice.py @@ -0,0 +1,67 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Tuple + +from synapse.http import servlet +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request +from synapse.logging.opentracing import set_tag, trace +from synapse.rest.client.transactions import HttpTransactionCache + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class SendToDeviceRestServlet(servlet.RestServlet): + PATTERNS = client_patterns( + "/sendToDevice/(?P[^/]*)/(?P[^/]*)$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) + self.device_message_handler = hs.get_device_message_handler() + + @trace(opname="sendToDevice") + def on_PUT(self, request, message_type, txn_id): + set_tag("message_type", message_type) + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request( + request, self._put, request, message_type, txn_id + ) + + async def _put(self, request, message_type, txn_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + content = parse_json_object_from_request(request) + assert_params_in_dict(content, ("messages",)) + + await self.device_message_handler.send_device_message( + requester, message_type, content["messages"] + ) + + response: Tuple[int, dict] = (200, {}) + return response + + +def register_servlets(hs, http_server): + SendToDeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py new file mode 100644 index 0000000000..d2e7f04b40 --- /dev/null +++ b/synapse/rest/client/shared_rooms.py @@ -0,0 +1,67 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet +from synapse.types import UserID + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class UserSharedRoomsServlet(RestServlet): + """ + GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/uk.half-shot.msc2666/user/shared_rooms/(?P[^/]*)", + releases=(), # This is an unstable feature + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.user_directory_active = hs.config.update_user_directory + + async def on_GET(self, request, user_id): + + if not self.user_directory_active: + raise SynapseError( + code=400, + msg="The user directory is disabled on this server. Cannot determine shared rooms.", + errcode=Codes.FORBIDDEN, + ) + + UserID.from_string(user_id) + + requester = await self.auth.get_user_by_req(request) + if user_id == requester.user.to_string(): + raise SynapseError( + code=400, + msg="You cannot request a list of shared rooms with yourself", + errcode=Codes.FORBIDDEN, + ) + rooms = await self.store.get_shared_rooms_for_users( + requester.user.to_string(), user_id + ) + + return 200, {"joined": list(rooms)} + + +def register_servlets(hs, http_server): + UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py new file mode 100644 index 0000000000..e18f4d01b3 --- /dev/null +++ b/synapse/rest/client/sync.py @@ -0,0 +1,532 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import logging +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple + +from synapse.api.constants import Membership, PresenceState +from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.events.utils import ( + format_event_for_client_v2_without_room_id, + format_event_raw, +) +from synapse.handlers.presence import format_user_presence_state +from synapse.handlers.sync import KnockedSyncResult, SyncConfig +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, StreamToken +from synapse.util import json_decoder + +from ._base import client_patterns, set_timeline_upper_limit + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class SyncRestServlet(RestServlet): + """ + + GET parameters:: + timeout(int): How long to wait for new events in milliseconds. + since(batch_token): Batch token when asking for incremental deltas. + set_presence(str): What state the device presence should be set to. + default is "online". + filter(filter_id): A filter to apply to the events returned. + + Response JSON:: + { + "next_batch": // batch token for the next /sync + "presence": // presence data for the user. + "rooms": { + "join": { // Joined rooms being updated. + "${room_id}": { // Id of the room being updated + "event_map": // Map of EventID -> event JSON. + "timeline": { // The recent events in the room if gap is "true" + "limited": // Was the per-room event limit exceeded? + // otherwise the next events in the room. + "events": [] // list of EventIDs in the "event_map". + "prev_batch": // back token for getting previous events. + } + "state": {"events": []} // list of EventIDs updating the + // current state to be what it should + // be at the end of the batch. + "ephemeral": {"events": []} // list of event objects + } + }, + "invite": {}, // Invited rooms being updated. + "leave": {} // Archived rooms being updated. + } + } + """ + + PATTERNS = client_patterns("/sync$") + ALLOWED_PRESENCE = {"online", "offline", "unavailable"} + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.sync_handler = hs.get_sync_handler() + self.clock = hs.get_clock() + self.filtering = hs.get_filtering() + self.presence_handler = hs.get_presence_handler() + self._server_notices_sender = hs.get_server_notices_sender() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + + if b"from" in request.args: + # /events used to use 'from', but /sync uses 'since'. + # Lets be helpful and whine if we see a 'from'. + raise SynapseError( + 400, "'from' is not a valid query parameter. Did you mean 'since'?" + ) + + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = requester.user + device_id = requester.device_id + + timeout = parse_integer(request, "timeout", default=0) + since = parse_string(request, "since") + set_presence = parse_string( + request, + "set_presence", + default="online", + allowed_values=self.ALLOWED_PRESENCE, + ) + filter_id = parse_string(request, "filter") + full_state = parse_boolean(request, "full_state", default=False) + + logger.debug( + "/sync: user=%r, timeout=%r, since=%r, " + "set_presence=%r, filter_id=%r, device_id=%r", + user, + timeout, + since, + set_presence, + filter_id, + device_id, + ) + + request_key = (user, timeout, since, filter_id, full_state, device_id) + + if filter_id is None: + filter_collection = DEFAULT_FILTER_COLLECTION + elif filter_id.startswith("{"): + try: + filter_object = json_decoder.decode(filter_id) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) + except Exception: + raise SynapseError(400, "Invalid filter JSON") + self.filtering.check_valid_filter(filter_object) + filter_collection = FilterCollection(filter_object) + else: + try: + filter_collection = await self.filtering.get_user_filter( + user.localpart, filter_id + ) + except StoreError as err: + if err.code != 404: + raise + # fix up the description and errcode to be more useful + raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM) + + sync_config = SyncConfig( + user=user, + filter_collection=filter_collection, + is_guest=requester.is_guest, + request_key=request_key, + device_id=device_id, + ) + + since_token = None + if since is not None: + since_token = await StreamToken.from_string(self.store, since) + + # send any outstanding server notices to the user. + await self._server_notices_sender.on_user_syncing(user.to_string()) + + affect_presence = set_presence != PresenceState.OFFLINE + + if affect_presence: + await self.presence_handler.set_state( + user, {"presence": set_presence}, True + ) + + context = await self.presence_handler.user_syncing( + user.to_string(), affect_presence=affect_presence + ) + with context: + sync_result = await self.sync_handler.wait_for_sync_for_user( + requester, + sync_config, + since_token=since_token, + timeout=timeout, + full_state=full_state, + ) + + # the client may have disconnected by now; don't bother to serialize the + # response if so. + if request._disconnected: + logger.info("Client has disconnected; not serializing response.") + return 200, {} + + time_now = self.clock.time_msec() + response_content = await self.encode_response( + time_now, sync_result, requester.access_token_id, filter_collection + ) + + logger.debug("Event formatting complete") + return 200, response_content + + async def encode_response(self, time_now, sync_result, access_token_id, filter): + logger.debug("Formatting events in sync response") + if filter.event_format == "client": + event_formatter = format_event_for_client_v2_without_room_id + elif filter.event_format == "federation": + event_formatter = format_event_raw + else: + raise Exception("Unknown event format %s" % (filter.event_format,)) + + joined = await self.encode_joined( + sync_result.joined, + time_now, + access_token_id, + filter.event_fields, + event_formatter, + ) + + invited = await self.encode_invited( + sync_result.invited, time_now, access_token_id, event_formatter + ) + + knocked = await self.encode_knocked( + sync_result.knocked, time_now, access_token_id, event_formatter + ) + + archived = await self.encode_archived( + sync_result.archived, + time_now, + access_token_id, + filter.event_fields, + event_formatter, + ) + + logger.debug("building sync response dict") + + response: dict = defaultdict(dict) + response["next_batch"] = await sync_result.next_batch.to_string(self.store) + + if sync_result.account_data: + response["account_data"] = {"events": sync_result.account_data} + if sync_result.presence: + response["presence"] = SyncRestServlet.encode_presence( + sync_result.presence, time_now + ) + + if sync_result.to_device: + response["to_device"] = {"events": sync_result.to_device} + + if sync_result.device_lists.changed: + response["device_lists"]["changed"] = list(sync_result.device_lists.changed) + if sync_result.device_lists.left: + response["device_lists"]["left"] = list(sync_result.device_lists.left) + + # We always include this because https://github.com/vector-im/element-android/issues/3725 + # The spec isn't terribly clear on when this can be omitted and how a client would tell + # the difference between "no keys present" and "nothing changed" in terms of whole field + # absent / individual key type entry absent + # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456 + response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count + + # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md + # states that this field should always be included, as long as the server supports the feature. + response[ + "org.matrix.msc2732.device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types + + if joined: + response["rooms"][Membership.JOIN] = joined + if invited: + response["rooms"][Membership.INVITE] = invited + if knocked: + response["rooms"][Membership.KNOCK] = knocked + if archived: + response["rooms"][Membership.LEAVE] = archived + + if sync_result.groups.join: + response["groups"][Membership.JOIN] = sync_result.groups.join + if sync_result.groups.invite: + response["groups"][Membership.INVITE] = sync_result.groups.invite + if sync_result.groups.leave: + response["groups"][Membership.LEAVE] = sync_result.groups.leave + + return response + + @staticmethod + def encode_presence(events, time_now): + return { + "events": [ + { + "type": "m.presence", + "sender": event.user_id, + "content": format_user_presence_state( + event, time_now, include_user_id=False + ), + } + for event in events + ] + } + + async def encode_joined( + self, rooms, time_now, token_id, event_fields, event_formatter + ): + """ + Encode the joined rooms in a sync result + + Args: + rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync + results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + event_fields(list): List of event fields to include. If empty, + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format + Returns: + dict[str, dict[str, object]]: the joined rooms list, in our + response format + """ + joined = {} + for room in rooms: + joined[room.room_id] = await self.encode_room( + room, + time_now, + token_id, + joined=True, + only_fields=event_fields, + event_formatter=event_formatter, + ) + + return joined + + async def encode_invited(self, rooms, time_now, token_id, event_formatter): + """ + Encode the invited rooms in a sync result + + Args: + rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of + sync results for rooms this user is invited to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + event_formatter (func[dict]): function to convert from federation format + to client format + + Returns: + dict[str, dict[str, object]]: the invited rooms list, in our + response format + """ + invited = {} + for room in rooms: + invite = await self._event_serializer.serialize_event( + room.invite, + time_now, + token_id=token_id, + event_format=event_formatter, + include_stripped_room_state=True, + ) + unsigned = dict(invite.get("unsigned", {})) + invite["unsigned"] = unsigned + invited_state = list(unsigned.pop("invite_room_state", [])) + invited_state.append(invite) + invited[room.room_id] = {"invite_state": {"events": invited_state}} + + return invited + + async def encode_knocked( + self, + rooms: List[KnockedSyncResult], + time_now: int, + token_id: int, + event_formatter: Callable[[Dict], Dict], + ) -> Dict[str, Dict[str, Any]]: + """ + Encode the rooms we've knocked on in a sync result. + + Args: + rooms: list of sync results for rooms this user is knocking on + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs + event_formatter: function to convert from federation format to client format + + Returns: + The list of rooms the user has knocked on, in our response format. + """ + knocked = {} + for room in rooms: + knock = await self._event_serializer.serialize_event( + room.knock, + time_now, + token_id=token_id, + event_format=event_formatter, + include_stripped_room_state=True, + ) + + # Extract the `unsigned` key from the knock event. + # This is where we (cheekily) store the knock state events + unsigned = knock.setdefault("unsigned", {}) + + # Duplicate the dictionary in order to avoid modifying the original + unsigned = dict(unsigned) + + # Extract the stripped room state from the unsigned dict + # This is for clients to get a little bit of information about + # the room they've knocked on, without revealing any sensitive information + knocked_state = list(unsigned.pop("knock_room_state", [])) + + # Append the actual knock membership event itself as well. This provides + # the client with: + # + # * A knock state event that they can use for easier internal tracking + # * The rough timestamp of when the knock occurred contained within the event + knocked_state.append(knock) + + # Build the `knock_state` dictionary, which will contain the state of the + # room that the client has knocked on + knocked[room.room_id] = {"knock_state": {"events": knocked_state}} + + return knocked + + async def encode_archived( + self, rooms, time_now, token_id, event_fields, event_formatter + ): + """ + Encode the archived rooms in a sync result + + Args: + rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + event_fields(list): List of event fields to include. If empty, + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format + Returns: + dict[str, dict[str, object]]: The invited rooms list, in our + response format + """ + joined = {} + for room in rooms: + joined[room.room_id] = await self.encode_room( + room, + time_now, + token_id, + joined=False, + only_fields=event_fields, + event_formatter=event_formatter, + ) + + return joined + + async def encode_room( + self, room, time_now, token_id, joined, only_fields, event_formatter + ): + """ + Args: + room (JoinedSyncResult|ArchivedSyncResult): sync result for a + single room + time_now (int): current time - used as a baseline for age + calculations + token_id (int): ID of the user's auth token - used for namespacing + of transaction IDs + joined (bool): True if the user is joined to this room - will mean + we handle ephemeral events + only_fields(list): Optional. The list of event fields to include. + event_formatter (func[dict]): function to convert from federation format + to client format + Returns: + dict[str, object]: the room, encoded in our response format + """ + + def serialize(events): + return self._event_serializer.serialize_events( + events, + time_now=time_now, + # We don't bundle "live" events, as otherwise clients + # will end up double counting annotations. + bundle_aggregations=False, + token_id=token_id, + event_format=event_formatter, + only_event_fields=only_fields, + ) + + state_dict = room.state + timeline_events = room.timeline.events + + state_events = state_dict.values() + + for event in itertools.chain(state_events, timeline_events): + # We've had bug reports that events were coming down under the + # wrong room. + if event.room_id != room.room_id: + logger.warning( + "Event %r is under room %r instead of %r", + event.event_id, + room.room_id, + event.room_id, + ) + + serialized_state = await serialize(state_events) + serialized_timeline = await serialize(timeline_events) + + account_data = room.account_data + + result = { + "timeline": { + "events": serialized_timeline, + "prev_batch": await room.timeline.prev_batch.to_string(self.store), + "limited": room.timeline.limited, + }, + "state": {"events": serialized_state}, + "account_data": {"events": account_data}, + } + + if joined: + ephemeral_events = room.ephemeral + result["ephemeral"] = {"events": ephemeral_events} + result["unread_notifications"] = room.unread_notifications + result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count + + return result + + +def register_servlets(hs, http_server): + SyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py new file mode 100644 index 0000000000..c14f83be18 --- /dev/null +++ b/synapse/rest/client/tags.py @@ -0,0 +1,85 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class TagListServlet(RestServlet): + """ + GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 + """ + + PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id, room_id): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot get tags for other users.") + + tags = await self.store.get_tags_for_room(user_id, room_id) + + return 200, {"tags": tags} + + +class TagServlet(RestServlet): + """ + PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 + DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.handler = hs.get_account_data_handler() + + async def on_PUT(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot add tags for other users.") + + body = parse_json_object_from_request(request) + + await self.handler.add_tag_to_room(user_id, room_id, tag, body) + + return 200, {} + + async def on_DELETE(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot add tags for other users.") + + await self.handler.remove_tag_from_room(user_id, room_id, tag) + + return 200, {} + + +def register_servlets(hs, http_server): + TagListServlet(hs).register(http_server) + TagServlet(hs).register(http_server) diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py new file mode 100644 index 0000000000..b5c67c9bb6 --- /dev/null +++ b/synapse/rest/client/thirdparty.py @@ -0,0 +1,111 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from synapse.api.constants import ThirdPartyEntityKind +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class ThirdPartyProtocolsServlet(RestServlet): + PATTERNS = client_patterns("/thirdparty/protocols") + + def __init__(self, hs): + super().__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + async def on_GET(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) + + protocols = await self.appservice_handler.get_3pe_protocols() + return 200, protocols + + +class ThirdPartyProtocolServlet(RestServlet): + PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$") + + def __init__(self, hs): + super().__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) + + protocols = await self.appservice_handler.get_3pe_protocols( + only_protocol=protocol + ) + if protocol in protocols: + return 200, protocols[protocol] + else: + return 404, {"error": "Unknown protocol"} + + +class ThirdPartyUserServlet(RestServlet): + PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$") + + def __init__(self, hs): + super().__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) + + fields = request.args + fields.pop(b"access_token", None) + + results = await self.appservice_handler.query_3pe( + ThirdPartyEntityKind.USER, protocol, fields + ) + + return 200, results + + +class ThirdPartyLocationServlet(RestServlet): + PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$") + + def __init__(self, hs): + super().__init__() + + self.auth = hs.get_auth() + self.appservice_handler = hs.get_application_service_handler() + + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) + + fields = request.args + fields.pop(b"access_token", None) + + results = await self.appservice_handler.query_3pe( + ThirdPartyEntityKind.LOCATION, protocol, fields + ) + + return 200, results + + +def register_servlets(hs, http_server): + ThirdPartyProtocolsServlet(hs).register(http_server) + ThirdPartyProtocolServlet(hs).register(http_server) + ThirdPartyUserServlet(hs).register(http_server) + ThirdPartyLocationServlet(hs).register(http_server) diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py new file mode 100644 index 0000000000..b2f858545c --- /dev/null +++ b/synapse/rest/client/tokenrefresh.py @@ -0,0 +1,37 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + + +class TokenRefreshRestServlet(RestServlet): + """ + Exchanges refresh tokens for a pair of an access token and a new refresh + token. + """ + + PATTERNS = client_patterns("/tokenrefresh") + + def __init__(self, hs): + super().__init__() + + async def on_POST(self, request): + raise AuthError(403, "tokenrefresh is no longer supported.") + + +def register_servlets(hs, http_server): + TokenRefreshRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py new file mode 100644 index 0000000000..7e8912f0b9 --- /dev/null +++ b/synapse/rest/client/user_directory.py @@ -0,0 +1,79 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class UserDirectorySearchRestServlet(RestServlet): + PATTERNS = client_patterns("/user_directory/search$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.user_directory_handler = hs.get_user_directory_handler() + + async def on_POST(self, request): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": , # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": , + "display_name": , + "avatar_url": + } + ] + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + if not self.hs.config.user_directory_search_enabled: + return 200, {"limited": False, "results": []} + + body = parse_json_object_from_request(request) + + limit = body.get("limit", 10) + limit = min(limit, 50) + + try: + search_term = body["search_term"] + except Exception: + raise SynapseError(400, "`search_term` is required field") + + results = await self.user_directory_handler.search_users( + user_id, search_term, limit + ) + + return 200, results + + +def register_servlets(hs, http_server): + UserDirectorySearchRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py deleted file mode 100644 index 5e83dba2ed..0000000000 --- a/synapse/rest/client/v1/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py deleted file mode 100644 index ae92a3df8e..0000000000 --- a/synapse/rest/client/v1/directory.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from synapse.api.errors import ( - AuthError, - Codes, - InvalidClientCredentialsError, - NotFoundError, - SynapseError, -) -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.types import RoomAlias - -logger = logging.getLogger(__name__) - - -def register_servlets(hs, http_server): - ClientDirectoryServer(hs).register(http_server) - ClientDirectoryListServer(hs).register(http_server) - ClientAppserviceDirectoryListServer(hs).register(http_server) - - -class ClientDirectoryServer(RestServlet): - PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.directory_handler = hs.get_directory_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, room_alias): - room_alias = RoomAlias.from_string(room_alias) - - res = await self.directory_handler.get_association(room_alias) - - return 200, res - - async def on_PUT(self, request, room_alias): - room_alias = RoomAlias.from_string(room_alias) - - content = parse_json_object_from_request(request) - if "room_id" not in content: - raise SynapseError( - 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON - ) - - logger.debug("Got content: %s", content) - logger.debug("Got room name: %s", room_alias.to_string()) - - room_id = content["room_id"] - servers = content["servers"] if "servers" in content else None - - logger.debug("Got room_id: %s", room_id) - logger.debug("Got servers: %s", servers) - - # TODO(erikj): Check types. - - room = await self.store.get_room(room_id) - if room is None: - raise SynapseError(400, "Room does not exist") - - requester = await self.auth.get_user_by_req(request) - - await self.directory_handler.create_association( - requester, room_alias, room_id, servers - ) - - return 200, {} - - async def on_DELETE(self, request, room_alias): - try: - service = self.auth.get_appservice_by_req(request) - room_alias = RoomAlias.from_string(room_alias) - await self.directory_handler.delete_appservice_association( - service, room_alias - ) - logger.info( - "Application service at %s deleted alias %s", - service.url, - room_alias.to_string(), - ) - return 200, {} - except InvalidClientCredentialsError: - # fallback to default user behaviour if they aren't an AS - pass - - requester = await self.auth.get_user_by_req(request) - user = requester.user - - room_alias = RoomAlias.from_string(room_alias) - - await self.directory_handler.delete_association(requester, room_alias) - - logger.info( - "User %s deleted alias %s", user.to_string(), room_alias.to_string() - ) - - return 200, {} - - -class ClientDirectoryListServer(RestServlet): - PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.directory_handler = hs.get_directory_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, room_id): - room = await self.store.get_room(room_id) - if room is None: - raise NotFoundError("Unknown room") - - return 200, {"visibility": "public" if room["is_public"] else "private"} - - async def on_PUT(self, request, room_id): - requester = await self.auth.get_user_by_req(request) - - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") - - await self.directory_handler.edit_published_room_list( - requester, room_id, visibility - ) - - return 200, {} - - async def on_DELETE(self, request, room_id): - requester = await self.auth.get_user_by_req(request) - - await self.directory_handler.edit_published_room_list( - requester, room_id, "private" - ) - - return 200, {} - - -class ClientAppserviceDirectoryListServer(RestServlet): - PATTERNS = client_patterns( - "/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True - ) - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.directory_handler = hs.get_directory_handler() - self.auth = hs.get_auth() - - def on_PUT(self, request, network_id, room_id): - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") - return self._edit(request, network_id, room_id, visibility) - - def on_DELETE(self, request, network_id, room_id): - return self._edit(request, network_id, room_id, "private") - - async def _edit(self, request, network_id, room_id, visibility): - requester = await self.auth.get_user_by_req(request) - if not requester.app_service: - raise AuthError( - 403, "Only appservices can edit the appservice published room list" - ) - - await self.directory_handler.edit_published_appservice_room_list( - requester.app_service.id, network_id, room_id, visibility - ) - - return 200, {} diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py deleted file mode 100644 index ee7454996e..0000000000 --- a/synapse/rest/client/v1/events.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains REST servlets to do with event streaming, /events.""" -import logging - -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.streams.config import PaginationConfig - -logger = logging.getLogger(__name__) - - -class EventStreamRestServlet(RestServlet): - PATTERNS = client_patterns("/events$", v1=True) - - DEFAULT_LONGPOLL_TIME_MS = 30000 - - def __init__(self, hs): - super().__init__() - self.event_stream_handler = hs.get_event_stream_handler() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - is_guest = requester.is_guest - room_id = None - if is_guest: - if b"room_id" not in request.args: - raise SynapseError(400, "Guest users must specify room_id param") - if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode("ascii") - - pagin_config = await PaginationConfig.from_request(self.store, request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if b"timeout" in request.args: - try: - timeout = int(request.args[b"timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - as_client_event = b"raw" not in request.args - - chunk = await self.event_stream_handler.get_stream( - requester.user.to_string(), - pagin_config, - timeout=timeout, - as_client_event=as_client_event, - affect_presence=(not is_guest), - room_id=room_id, - is_guest=is_guest, - ) - - return 200, chunk - - -class EventRestServlet(RestServlet): - PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) - - def __init__(self, hs): - super().__init__() - self.clock = hs.get_clock() - self.event_handler = hs.get_event_handler() - self.auth = hs.get_auth() - self._event_serializer = hs.get_event_client_serializer() - - async def on_GET(self, request, event_id): - requester = await self.auth.get_user_by_req(request) - event = await self.event_handler.get_event(requester.user, None, event_id) - - time_now = self.clock.time_msec() - if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event - else: - return 404, "Event not found." - - -def register_servlets(hs, http_server): - EventStreamRestServlet(hs).register(http_server) - EventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py deleted file mode 100644 index bef1edc838..0000000000 --- a/synapse/rest/client/v1/initial_sync.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from synapse.http.servlet import RestServlet, parse_boolean -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.streams.config import PaginationConfig - - -# TODO: Needs unit testing -class InitialSyncRestServlet(RestServlet): - PATTERNS = client_patterns("/initialSync$", v1=True) - - def __init__(self, hs): - super().__init__() - self.initial_sync_handler = hs.get_initial_sync_handler() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request) - as_client_event = b"raw" not in request.args - pagination_config = await PaginationConfig.from_request(self.store, request) - include_archived = parse_boolean(request, "archived", default=False) - content = await self.initial_sync_handler.snapshot_all_rooms( - user_id=requester.user.to_string(), - pagin_config=pagination_config, - as_client_event=as_client_event, - include_archived=include_archived, - ) - - return 200, content - - -def register_servlets(hs, http_server): - InitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py deleted file mode 100644 index 11567bf32c..0000000000 --- a/synapse/rest/client/v1/login.py +++ /dev/null @@ -1,600 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional - -from typing_extensions import TypedDict - -from synapse.api.errors import Codes, LoginError, SynapseError -from synapse.api.ratelimiting import Ratelimiter -from synapse.api.urls import CLIENT_API_PREFIX -from synapse.appservice import ApplicationService -from synapse.handlers.sso import SsoIdentityProvider -from synapse.http import get_request_uri -from synapse.http.server import HttpServer, finish_request -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_boolean, - parse_bytes_from_args, - parse_json_object_from_request, - parse_string, -) -from synapse.http.site import SynapseRequest -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.rest.well_known import WellKnownBuilder -from synapse.types import JsonDict, UserID - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class LoginResponse(TypedDict, total=False): - user_id: str - access_token: str - home_server: str - expires_in_ms: Optional[int] - refresh_token: Optional[str] - device_id: str - well_known: Optional[Dict[str, Any]] - - -class LoginRestServlet(RestServlet): - PATTERNS = client_patterns("/login$", v1=True) - CAS_TYPE = "m.login.cas" - SSO_TYPE = "m.login.sso" - TOKEN_TYPE = "m.login.token" - JWT_TYPE = "org.matrix.login.jwt" - JWT_TYPE_DEPRECATED = "m.login.jwt" - APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" - REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - - # JWT configuration variables. - self.jwt_enabled = hs.config.jwt_enabled - self.jwt_secret = hs.config.jwt_secret - self.jwt_algorithm = hs.config.jwt_algorithm - self.jwt_issuer = hs.config.jwt_issuer - self.jwt_audiences = hs.config.jwt_audiences - - # SSO configuration. - self.saml2_enabled = hs.config.saml2_enabled - self.cas_enabled = hs.config.cas_enabled - self.oidc_enabled = hs.config.oidc_enabled - self._msc2858_enabled = hs.config.experimental.msc2858_enabled - self._msc2918_enabled = hs.config.access_token_lifetime is not None - - self.auth = hs.get_auth() - - self.clock = hs.get_clock() - - self.auth_handler = self.hs.get_auth_handler() - self.registration_handler = hs.get_registration_handler() - self._sso_handler = hs.get_sso_handler() - - self._well_known_builder = WellKnownBuilder(hs) - self._address_ratelimiter = Ratelimiter( - store=hs.get_datastore(), - clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, - ) - self._account_ratelimiter = Ratelimiter( - store=hs.get_datastore(), - clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, - ) - - def on_GET(self, request: SynapseRequest): - flows = [] - if self.jwt_enabled: - flows.append({"type": LoginRestServlet.JWT_TYPE}) - flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) - - if self.cas_enabled: - # we advertise CAS for backwards compat, though MSC1721 renamed it - # to SSO. - flows.append({"type": LoginRestServlet.CAS_TYPE}) - - if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow: JsonDict = { - "type": LoginRestServlet.SSO_TYPE, - "identity_providers": [ - _get_auth_flow_dict_for_idp( - idp, - ) - for idp in self._sso_handler.get_identity_providers().values() - ], - } - - if self._msc2858_enabled: - # backwards-compatibility support for clients which don't - # support the stable API yet - sso_flow["org.matrix.msc2858.identity_providers"] = [ - _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) - for idp in self._sso_handler.get_identity_providers().values() - ] - - flows.append(sso_flow) - - # While it's valid for us to advertise this login type generally, - # synapse currently only gives out these tokens as part of the - # SSO login flow. - # Generally we don't want to advertise login flows that clients - # don't know how to implement, since they (currently) will always - # fall back to the fallback API if they don't understand one of the - # login flow types returned. - flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - - flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) - - flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) - - return 200, {"flows": flows} - - async def on_POST(self, request: SynapseRequest): - login_submission = parse_json_object_from_request(request) - - if self._msc2918_enabled: - # Check if this login should also issue a refresh token, as per - # MSC2918 - should_issue_refresh_token = parse_boolean( - request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False - ) - else: - should_issue_refresh_token = False - - try: - if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: - appservice = self.auth.get_appservice_by_req(request) - - if appservice.is_rate_limited(): - await self._address_ratelimiter.ratelimit( - None, request.getClientIP() - ) - - result = await self._do_appservice_login( - login_submission, - appservice, - should_issue_refresh_token=should_issue_refresh_token, - ) - elif self.jwt_enabled and ( - login_submission["type"] == LoginRestServlet.JWT_TYPE - or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED - ): - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_jwt_login( - login_submission, - should_issue_refresh_token=should_issue_refresh_token, - ) - elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_token_login( - login_submission, - should_issue_refresh_token=should_issue_refresh_token, - ) - else: - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_other_login( - login_submission, - should_issue_refresh_token=should_issue_refresh_token, - ) - except KeyError: - raise SynapseError(400, "Missing JSON keys.") - - well_known_data = self._well_known_builder.get_well_known() - if well_known_data: - result["well_known"] = well_known_data - return 200, result - - async def _do_appservice_login( - self, - login_submission: JsonDict, - appservice: ApplicationService, - should_issue_refresh_token: bool = False, - ): - identifier = login_submission.get("identifier") - logger.info("Got appservice login request with identifier: %r", identifier) - - if not isinstance(identifier, dict): - raise SynapseError( - 400, "Invalid identifier in login submission", Codes.INVALID_PARAM - ) - - # this login flow only supports identifiers of type "m.id.user". - if identifier.get("type") != "m.id.user": - raise SynapseError( - 400, "Unknown login identifier type", Codes.INVALID_PARAM - ) - - user = identifier.get("user") - if not isinstance(user, str): - raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM) - - if user.startswith("@"): - qualified_user_id = user - else: - qualified_user_id = UserID(user, self.hs.hostname).to_string() - - if not appservice.is_interested_in_user(qualified_user_id): - raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) - - return await self._complete_login( - qualified_user_id, - login_submission, - ratelimit=appservice.is_rate_limited(), - should_issue_refresh_token=should_issue_refresh_token, - ) - - async def _do_other_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False - ) -> LoginResponse: - """Handle non-token/saml/jwt logins - - Args: - login_submission: - should_issue_refresh_token: True if this login should issue - a refresh token alongside the access token. - - Returns: - HTTP response - """ - # Log the request we got, but only certain fields to minimise the chance of - # logging someone's password (even if they accidentally put it in the wrong - # field) - logger.info( - "Got login request with identifier: %r, medium: %r, address: %r, user: %r", - login_submission.get("identifier"), - login_submission.get("medium"), - login_submission.get("address"), - login_submission.get("user"), - ) - canonical_user_id, callback = await self.auth_handler.validate_login( - login_submission, ratelimit=True - ) - result = await self._complete_login( - canonical_user_id, - login_submission, - callback, - should_issue_refresh_token=should_issue_refresh_token, - ) - return result - - async def _complete_login( - self, - user_id: str, - login_submission: JsonDict, - callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, - create_non_existent_users: bool = False, - ratelimit: bool = True, - auth_provider_id: Optional[str] = None, - should_issue_refresh_token: bool = False, - ) -> LoginResponse: - """Called when we've successfully authed the user and now need to - actually login them in (e.g. create devices). This gets called on - all successful logins. - - Applies the ratelimiting for successful login attempts against an - account. - - Args: - user_id: ID of the user to register. - login_submission: Dictionary of login information. - callback: Callback function to run after login. - create_non_existent_users: Whether to create the user if they don't - exist. Defaults to False. - ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). - should_issue_refresh_token: True if this login should issue - a refresh token alongside the access token. - - Returns: - result: Dictionary of account information after successful login. - """ - - # Before we actually log them in we check if they've already logged in - # too often. This happens here rather than before as we don't - # necessarily know the user before now. - if ratelimit: - await self._account_ratelimiter.ratelimit(None, user_id.lower()) - - if create_non_existent_users: - canonical_uid = await self.auth_handler.check_user_exists(user_id) - if not canonical_uid: - canonical_uid = await self.registration_handler.register_user( - localpart=UserID.from_string(user_id).localpart - ) - user_id = canonical_uid - - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") - ( - device_id, - access_token, - valid_until_ms, - refresh_token, - ) = await self.registration_handler.register_device( - user_id, - device_id, - initial_display_name, - auth_provider_id=auth_provider_id, - should_issue_refresh_token=should_issue_refresh_token, - ) - - result = LoginResponse( - user_id=user_id, - access_token=access_token, - home_server=self.hs.hostname, - device_id=device_id, - ) - - if valid_until_ms is not None: - expires_in_ms = valid_until_ms - self.clock.time_msec() - result["expires_in_ms"] = expires_in_ms - - if refresh_token is not None: - result["refresh_token"] = refresh_token - - if callback is not None: - await callback(result) - - return result - - async def _do_token_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False - ) -> LoginResponse: - """ - Handle the final stage of SSO login. - - Args: - login_submission: The JSON request body. - should_issue_refresh_token: True if this login should issue - a refresh token alongside the access token. - - Returns: - The body of the JSON response. - """ - token = login_submission["token"] - auth_handler = self.auth_handler - res = await auth_handler.validate_short_term_login_token(token) - - return await self._complete_login( - res.user_id, - login_submission, - self.auth_handler._sso_login_callback, - auth_provider_id=res.auth_provider_id, - should_issue_refresh_token=should_issue_refresh_token, - ) - - async def _do_jwt_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False - ) -> LoginResponse: - token = login_submission.get("token", None) - if token is None: - raise LoginError( - 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN - ) - - import jwt - - try: - payload = jwt.decode( - token, - self.jwt_secret, - algorithms=[self.jwt_algorithm], - issuer=self.jwt_issuer, - audience=self.jwt_audiences, - ) - except jwt.PyJWTError as e: - # A JWT error occurred, return some info back to the client. - raise LoginError( - 403, - "JWT validation failed: %s" % (str(e),), - errcode=Codes.FORBIDDEN, - ) - - user = payload.get("sub", None) - if user is None: - raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) - - user_id = UserID(user, self.hs.hostname).to_string() - result = await self._complete_login( - user_id, - login_submission, - create_non_existent_users=True, - should_issue_refresh_token=should_issue_refresh_token, - ) - return result - - -def _get_auth_flow_dict_for_idp( - idp: SsoIdentityProvider, use_unstable_brands: bool = False -) -> JsonDict: - """Return an entry for the login flow dict - - Returns an entry suitable for inclusion in "identity_providers" in the - response to GET /_matrix/client/r0/login - - Args: - idp: the identity provider to describe - use_unstable_brands: whether we should use brand identifiers suitable - for the unstable API - """ - e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} - if idp.idp_icon: - e["icon"] = idp.idp_icon - if idp.idp_brand: - e["brand"] = idp.idp_brand - # use the stable brand identifier if the unstable identifier isn't defined. - if use_unstable_brands and idp.unstable_idp_brand: - e["brand"] = idp.unstable_idp_brand - return e - - -class RefreshTokenServlet(RestServlet): - PATTERNS = client_patterns( - "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True - ) - - def __init__(self, hs: "HomeServer"): - self._auth_handler = hs.get_auth_handler() - self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.access_token_lifetime - - async def on_POST( - self, - request: SynapseRequest, - ): - refresh_submission = parse_json_object_from_request(request) - - assert_params_in_dict(refresh_submission, ["refresh_token"]) - token = refresh_submission["refresh_token"] - if not isinstance(token, str): - raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - - valid_until_ms = self._clock.time_msec() + self.access_token_lifetime - access_token, refresh_token = await self._auth_handler.refresh_token( - token, valid_until_ms - ) - expires_in_ms = valid_until_ms - self._clock.time_msec() - return ( - 200, - { - "access_token": access_token, - "refresh_token": refresh_token, - "expires_in_ms": expires_in_ms, - }, - ) - - -class SsoRedirectServlet(RestServlet): - PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ - re.compile( - "^" - + CLIENT_API_PREFIX - + "/r0/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$" - ) - ] - - def __init__(self, hs: "HomeServer"): - # make sure that the relevant handlers are instantiated, so that they - # register themselves with the main SSOHandler. - if hs.config.cas_enabled: - hs.get_cas_handler() - if hs.config.saml2_enabled: - hs.get_saml_handler() - if hs.config.oidc_enabled: - hs.get_oidc_handler() - self._sso_handler = hs.get_sso_handler() - self._msc2858_enabled = hs.config.experimental.msc2858_enabled - self._public_baseurl = hs.config.public_baseurl - - def register(self, http_server: HttpServer) -> None: - super().register(http_server) - if self._msc2858_enabled: - # expose additional endpoint for MSC2858 support: backwards-compat support - # for clients which don't yet support the stable endpoints. - http_server.register_paths( - "GET", - client_patterns( - "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", - releases=(), - unstable=True, - ), - self.on_GET, - self.__class__.__name__, - ) - - async def on_GET( - self, request: SynapseRequest, idp_id: Optional[str] = None - ) -> None: - if not self._public_baseurl: - raise SynapseError(400, "SSO requires a valid public_baseurl") - - # if this isn't the expected hostname, redirect to the right one, so that we - # get our cookies back. - requested_uri = get_request_uri(request) - baseurl_bytes = self._public_baseurl.encode("utf-8") - if not requested_uri.startswith(baseurl_bytes): - # swap out the incorrect base URL for the right one. - # - # The idea here is to redirect from - # https://foo.bar/whatever/_matrix/... - # to - # https://public.baseurl/_matrix/... - # - i = requested_uri.index(b"/_matrix") - new_uri = baseurl_bytes[:-1] + requested_uri[i:] - logger.info( - "Requested URI %s is not canonical: redirecting to %s", - requested_uri.decode("utf-8", errors="replace"), - new_uri.decode("utf-8", errors="replace"), - ) - request.redirect(new_uri) - finish_request(request) - return - - args: Dict[bytes, List[bytes]] = request.args # type: ignore - client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) - sso_url = await self._sso_handler.handle_redirect_request( - request, - client_redirect_url, - idp_id, - ) - logger.info("Redirecting to %s", sso_url) - request.redirect(sso_url) - finish_request(request) - - -class CasTicketServlet(RestServlet): - PATTERNS = client_patterns("/login/cas/ticket", v1=True) - - def __init__(self, hs): - super().__init__() - self._cas_handler = hs.get_cas_handler() - - async def on_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl") - ticket = parse_string(request, "ticket", required=True) - - # Maybe get a session ID (if this ticket is from user interactive - # authentication). - session = parse_string(request, "session") - - # Either client_redirect_url or session must be provided. - if not client_redirect_url and not session: - message = "Missing string query parameter redirectUrl or session" - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) - - await self._cas_handler.handle_ticket( - request, ticket, client_redirect_url, session - ) - - -def register_servlets(hs, http_server): - LoginRestServlet(hs).register(http_server) - if hs.config.access_token_lifetime is not None: - RefreshTokenServlet(hs).register(http_server) - SsoRedirectServlet(hs).register(http_server) - if hs.config.cas_enabled: - CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py deleted file mode 100644 index 5aa7908d73..0000000000 --- a/synapse/rest/client/v1/logout.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns - -logger = logging.getLogger(__name__) - - -class LogoutRestServlet(RestServlet): - PATTERNS = client_patterns("/logout$", v1=True) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request, allow_expired=True) - - if requester.device_id is None: - # The access token wasn't associated with a device. - # Just delete the access token - access_token = self.auth.get_access_token_from_request(request) - await self._auth_handler.delete_access_token(access_token) - else: - await self._device_handler.delete_device( - requester.user.to_string(), requester.device_id - ) - - return 200, {} - - -class LogoutAllRestServlet(RestServlet): - PATTERNS = client_patterns("/logout/all$", v1=True) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request, allow_expired=True) - user_id = requester.user.to_string() - - # first delete all of the user's devices - await self._device_handler.delete_all_devices_for_user(user_id) - - # .. and then delete any access tokens which weren't associated with - # devices. - await self._auth_handler.delete_access_tokens_for_user(user_id) - return 200, {} - - -def register_servlets(hs, http_server): - LogoutRestServlet(hs).register(http_server) - LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py deleted file mode 100644 index 2b24fe5aa6..0000000000 --- a/synapse/rest/client/v1/presence.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" This module contains REST servlets to do with presence: /presence/ -""" -import logging - -from synapse.api.errors import AuthError, SynapseError -from synapse.handlers.presence import format_user_presence_state -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.types import UserID - -logger = logging.getLogger(__name__) - - -class PresenceStatusRestServlet(RestServlet): - PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.presence_handler = hs.get_presence_handler() - self.clock = hs.get_clock() - self.auth = hs.get_auth() - - self._use_presence = hs.config.server.use_presence - - async def on_GET(self, request, user_id): - requester = await self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - - if not self._use_presence: - return 200, {"presence": "offline"} - - if requester.user != user: - allowed = await self.presence_handler.is_visible( - observed_user=user, observer_user=requester.user - ) - - if not allowed: - raise AuthError(403, "You are not allowed to see their presence.") - - state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state( - state, self.clock.time_msec(), include_user_id=False - ) - - return 200, state - - async def on_PUT(self, request, user_id): - requester = await self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - - if requester.user != user: - raise AuthError(403, "Can only set your own presence state") - - state = {} - - content = parse_json_object_from_request(request) - - try: - state["presence"] = content.pop("presence") - - if "status_msg" in content: - state["status_msg"] = content.pop("status_msg") - if not isinstance(state["status_msg"], str): - raise SynapseError(400, "status_msg must be a string.") - - if content: - raise KeyError() - except SynapseError as e: - raise e - except Exception: - raise SynapseError(400, "Unable to parse state") - - if self._use_presence: - await self.presence_handler.set_state(user, state) - - return 200, {} - - -def register_servlets(hs, http_server): - PresenceStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py deleted file mode 100644 index f42f4b3567..0000000000 --- a/synapse/rest/client/v1/profile.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" This module contains REST servlets to do with profile: /profile/ """ - -from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.types import UserID - - -class ProfileDisplaynameRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.profile_handler = hs.get_profile_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, user_id): - requester_user = None - - if self.hs.config.require_auth_for_profile_requests: - requester = await self.auth.get_user_by_req(request) - requester_user = requester.user - - user = UserID.from_string(user_id) - - await self.profile_handler.check_profile_query_allowed(user, requester_user) - - displayname = await self.profile_handler.get_displayname(user) - - ret = {} - if displayname is not None: - ret["displayname"] = displayname - - return 200, ret - - async def on_PUT(self, request, user_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester.user) - - content = parse_json_object_from_request(request) - - try: - new_name = content["displayname"] - except Exception: - raise SynapseError( - code=400, - msg="Unable to parse name", - errcode=Codes.BAD_JSON, - ) - - await self.profile_handler.set_displayname(user, requester, new_name, is_admin) - - return 200, {} - - -class ProfileAvatarURLRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.profile_handler = hs.get_profile_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, user_id): - requester_user = None - - if self.hs.config.require_auth_for_profile_requests: - requester = await self.auth.get_user_by_req(request) - requester_user = requester.user - - user = UserID.from_string(user_id) - - await self.profile_handler.check_profile_query_allowed(user, requester_user) - - avatar_url = await self.profile_handler.get_avatar_url(user) - - ret = {} - if avatar_url is not None: - ret["avatar_url"] = avatar_url - - return 200, ret - - async def on_PUT(self, request, user_id): - requester = await self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester.user) - - content = parse_json_object_from_request(request) - try: - new_avatar_url = content["avatar_url"] - except KeyError: - raise SynapseError( - 400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM - ) - - await self.profile_handler.set_avatar_url( - user, requester, new_avatar_url, is_admin - ) - - return 200, {} - - -class ProfileRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.profile_handler = hs.get_profile_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, user_id): - requester_user = None - - if self.hs.config.require_auth_for_profile_requests: - requester = await self.auth.get_user_by_req(request) - requester_user = requester.user - - user = UserID.from_string(user_id) - - await self.profile_handler.check_profile_query_allowed(user, requester_user) - - displayname = await self.profile_handler.get_displayname(user) - avatar_url = await self.profile_handler.get_avatar_url(user) - - ret = {} - if displayname is not None: - ret["displayname"] = displayname - if avatar_url is not None: - ret["avatar_url"] = avatar_url - - return 200, ret - - -def register_servlets(hs, http_server): - ProfileDisplaynameRestServlet(hs).register(http_server) - ProfileAvatarURLRestServlet(hs).register(http_server) - ProfileRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py deleted file mode 100644 index be29a0b39e..0000000000 --- a/synapse/rest/client/v1/push_rule.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.api.errors import ( - NotFoundError, - StoreError, - SynapseError, - UnrecognizedRequestError, -) -from synapse.http.servlet import ( - RestServlet, - parse_json_value_from_request, - parse_string, -) -from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS -from synapse.push.clientformat import format_push_rules_for_user -from synapse.push.rulekinds import PRIORITY_CLASS_MAP -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException - - -class PushRuleRestServlet(RestServlet): - PATTERNS = client_patterns("/(?Ppushrules/.*)$", v1=True) - SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( - "Unrecognised request: You probably wanted a trailing slash" - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None - - self._users_new_default_push_rules = hs.config.users_new_default_push_rules - - async def on_PUT(self, request, path): - if self._is_worker: - raise Exception("Cannot handle PUT /push_rules on worker") - - spec = _rule_spec_from_path(path.split("/")) - try: - priority_class = _priority_class_from_spec(spec) - except InvalidRuleException as e: - raise SynapseError(400, str(e)) - - requester = await self.auth.get_user_by_req(request) - - if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: - raise SynapseError(400, "rule_id may not contain slashes") - - content = parse_json_value_from_request(request) - - user_id = requester.user.to_string() - - if "attr" in spec: - await self.set_rule_attr(user_id, spec, content) - self.notify_user(user_id) - return 200, {} - - if spec["rule_id"].startswith("."): - # Rule ids starting with '.' are reserved for server default rules. - raise SynapseError(400, "cannot add new rule_ids that start with '.'") - - try: - (conditions, actions) = _rule_tuple_from_request_object( - spec["template"], spec["rule_id"], content - ) - except InvalidRuleException as e: - raise SynapseError(400, str(e)) - - before = parse_string(request, "before") - if before: - before = _namespaced_rule_id(spec, before) - - after = parse_string(request, "after") - if after: - after = _namespaced_rule_id(spec, after) - - try: - await self.store.add_push_rule( - user_id=user_id, - rule_id=_namespaced_rule_id_from_spec(spec), - priority_class=priority_class, - conditions=conditions, - actions=actions, - before=before, - after=after, - ) - self.notify_user(user_id) - except InconsistentRuleException as e: - raise SynapseError(400, str(e)) - except RuleNotFoundException as e: - raise SynapseError(400, str(e)) - - return 200, {} - - async def on_DELETE(self, request, path): - if self._is_worker: - raise Exception("Cannot handle DELETE /push_rules on worker") - - spec = _rule_spec_from_path(path.split("/")) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - - try: - await self.store.delete_push_rule(user_id, namespaced_rule_id) - self.notify_user(user_id) - return 200, {} - except StoreError as e: - if e.code == 404: - raise NotFoundError() - else: - raise - - async def on_GET(self, request, path): - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - # we build up the full structure and then decide which bits of it - # to send which means doing unnecessary work sometimes but is - # is probably not going to make a whole lot of difference - rules = await self.store.get_push_rules_for_user(user_id) - - rules = format_push_rules_for_user(requester.user, rules) - - path = path.split("/")[1:] - - if path == []: - # we're a reference impl: pedantry is our job. - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - - if path[0] == "": - return 200, rules - elif path[0] == "global": - result = _filter_ruleset_with_path(rules["global"], path[1:]) - return 200, result - else: - raise UnrecognizedRequestError() - - def notify_user(self, user_id): - stream_id = self.store.get_max_push_rules_stream_id() - self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - - async def set_rule_attr(self, user_id, spec, val): - if spec["attr"] not in ("enabled", "actions"): - # for the sake of potential future expansion, shouldn't report - # 404 in the case of an unknown request so check it corresponds to - # a known attribute first. - raise UnrecognizedRequestError() - - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: - raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) - if spec["attr"] == "enabled": - if isinstance(val, dict) and "enabled" in val: - val = val["enabled"] - if not isinstance(val, bool): - # Legacy fallback - # This should *actually* take a dict, but many clients pass - # bools directly, so let's not break them. - raise SynapseError(400, "Value for 'enabled' must be boolean") - return await self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val, is_default_rule - ) - elif spec["attr"] == "actions": - actions = val.get("actions") - _check_actions(actions) - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if user_id in self._users_new_default_push_rules: - rule_ids = NEW_RULE_IDS - else: - rule_ids = BASE_RULE_IDS - - if namespaced_rule_id not in rule_ids: - raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return await self.store.set_push_rule_actions( - user_id, namespaced_rule_id, actions, is_default_rule - ) - else: - raise UnrecognizedRequestError() - - -def _rule_spec_from_path(path): - """Turn a sequence of path components into a rule spec - - Args: - path (sequence[unicode]): the URL path components. - - Returns: - dict: rule spec dict, containing scope/template/rule_id entries, - and possibly attr. - - Raises: - UnrecognizedRequestError if the path components cannot be parsed. - """ - if len(path) < 2: - raise UnrecognizedRequestError() - if path[0] != "pushrules": - raise UnrecognizedRequestError() - - scope = path[1] - path = path[2:] - if scope != "global": - raise UnrecognizedRequestError() - - if len(path) == 0: - raise UnrecognizedRequestError() - - template = path[0] - path = path[1:] - - if len(path) == 0 or len(path[0]) == 0: - raise UnrecognizedRequestError() - - rule_id = path[0] - - spec = {"scope": scope, "template": template, "rule_id": rule_id} - - path = path[1:] - - if len(path) > 0 and len(path[0]) > 0: - spec["attr"] = path[0] - - return spec - - -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): - if rule_template in ["override", "underride"]: - if "conditions" not in req_obj: - raise InvalidRuleException("Missing 'conditions'") - conditions = req_obj["conditions"] - for c in conditions: - if "kind" not in c: - raise InvalidRuleException("Condition without 'kind'") - elif rule_template == "room": - conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}] - elif rule_template == "sender": - conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}] - elif rule_template == "content": - if "pattern" not in req_obj: - raise InvalidRuleException("Content rule missing 'pattern'") - pat = req_obj["pattern"] - - conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}] - else: - raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - - if "actions" not in req_obj: - raise InvalidRuleException("No actions found") - actions = req_obj["actions"] - - _check_actions(actions) - - return conditions, actions - - -def _check_actions(actions): - if not isinstance(actions, list): - raise InvalidRuleException("No actions found") - - for a in actions: - if a in ["notify", "dont_notify", "coalesce"]: - pass - elif isinstance(a, dict) and "set_tweak" in a: - pass - else: - raise InvalidRuleException("Unrecognised action") - - -def _filter_ruleset_with_path(ruleset, path): - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - - if path[0] == "": - return ruleset - template_kind = path[0] - if template_kind not in ruleset: - raise UnrecognizedRequestError() - path = path[1:] - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - if path[0] == "": - return ruleset[template_kind] - rule_id = path[0] - - the_rule = None - for r in ruleset[template_kind]: - if r["rule_id"] == rule_id: - the_rule = r - if the_rule is None: - raise NotFoundError - - path = path[1:] - if len(path) == 0: - return the_rule - - attr = path[0] - if attr in the_rule: - # Make sure we return a JSON object as the attribute may be a - # JSON value. - return {attr: the_rule[attr]} - else: - raise UnrecognizedRequestError() - - -def _priority_class_from_spec(spec): - if spec["template"] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec["template"])) - pc = PRIORITY_CLASS_MAP[spec["template"]] - - return pc - - -def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec["rule_id"]) - - -def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec["template"], rule_id) - - -class InvalidRuleException(Exception): - pass - - -def register_servlets(hs, http_server): - PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py deleted file mode 100644 index 18102eca6c..0000000000 --- a/synapse/rest/client/v1/pusher.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.http.server import respond_with_html_bytes -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, - parse_string, -) -from synapse.push import PusherConfigException -from synapse.rest.client.v2_alpha._base import client_patterns - -logger = logging.getLogger(__name__) - - -class PushersRestServlet(RestServlet): - PATTERNS = client_patterns("/pushers$", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request) - user = requester.user - - pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - - filtered_pushers = [p.as_dict() for p in pushers] - - return 200, {"pushers": filtered_pushers} - - -class PushersSetRestServlet(RestServlet): - PATTERNS = client_patterns("/pushers/set$", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.notifier = hs.get_notifier() - self.pusher_pool = self.hs.get_pusherpool() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) - user = requester.user - - content = parse_json_object_from_request(request) - - if ( - "pushkey" in content - and "app_id" in content - and "kind" in content - and content["kind"] is None - ): - await self.pusher_pool.remove_pusher( - content["app_id"], content["pushkey"], user_id=user.to_string() - ) - return 200, {} - - assert_params_in_dict( - content, - [ - "kind", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "lang", - "data", - ], - ) - - logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"]) - logger.debug("Got pushers request with body: %r", content) - - append = False - if "append" in content: - append = content["append"] - - if not append: - await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( - app_id=content["app_id"], - pushkey=content["pushkey"], - not_user_id=user.to_string(), - ) - - try: - await self.pusher_pool.add_pusher( - user_id=user.to_string(), - access_token=requester.access_token_id, - kind=content["kind"], - app_id=content["app_id"], - app_display_name=content["app_display_name"], - device_display_name=content["device_display_name"], - pushkey=content["pushkey"], - lang=content["lang"], - data=content["data"], - profile_tag=content.get("profile_tag", ""), - ) - except PusherConfigException as pce: - raise SynapseError( - 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM - ) - - self.notifier.on_new_replication_data() - - return 200, {} - - -class PushersRemoveRestServlet(RestServlet): - """ - To allow pusher to be delete by clicking a link (ie. GET request) - """ - - PATTERNS = client_patterns("/pushers/remove$", v1=True) - SUCCESS_HTML = b"You have been unsubscribed" - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.notifier = hs.get_notifier() - self.auth = hs.get_auth() - self.pusher_pool = self.hs.get_pusherpool() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, rights="delete_pusher") - user = requester.user - - app_id = parse_string(request, "app_id", required=True) - pushkey = parse_string(request, "pushkey", required=True) - - try: - await self.pusher_pool.remove_pusher( - app_id=app_id, pushkey=pushkey, user_id=user.to_string() - ) - except StoreError as se: - if se.code != 404: - # This is fine: they're already unsubscribed - raise - - self.notifier.on_new_replication_data() - - respond_with_html_bytes( - request, - 200, - PushersRemoveRestServlet.SUCCESS_HTML, - ) - return None - - -def register_servlets(hs, http_server): - PushersRestServlet(hs).register(http_server) - PushersSetRestServlet(hs).register(http_server) - PushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py deleted file mode 100644 index ba7250ad8e..0000000000 --- a/synapse/rest/client/v1/room.py +++ /dev/null @@ -1,1152 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" This module contains REST servlets to do with rooms: /rooms/ """ -import logging -import re -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from urllib import parse as urlparse - -from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import ( - AuthError, - Codes, - InvalidClientCredentialsError, - MissingClientTokenError, - ShadowBanError, - SynapseError, -) -from synapse.api.filtering import Filter -from synapse.events.utils import format_event_for_client_v2 -from synapse.http.servlet import ( - ResolveRoomIdMixin, - RestServlet, - assert_params_in_dict, - parse_boolean, - parse_integer, - parse_json_object_from_request, - parse_string, - parse_strings_from_args, -) -from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import set_tag -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.storage.state import StateFilter -from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID -from synapse.util import json_decoder -from synapse.util.stringutils import parse_and_validate_server_name, random_string - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class TransactionRestServlet(RestServlet): - def __init__(self, hs): - super().__init__() - self.txns = HttpTransactionCache(hs) - - -class RoomCreateRestServlet(TransactionRestServlet): - # No PATTERN; we have custom dispatch rules here - - def __init__(self, hs): - super().__init__(hs) - self._room_creation_handler = hs.get_room_creation_handler() - self.auth = hs.get_auth() - - def register(self, http_server): - PATTERNS = "/createRoom" - register_txn_path(self, PATTERNS, http_server) - - def on_PUT(self, request, txn_id): - set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request(request, self.on_POST, request) - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) - - info, _ = await self._room_creation_handler.create_room( - requester, self.get_room_config(request) - ) - - return 200, info - - def get_room_config(self, request): - user_supplied_config = parse_json_object_from_request(request) - return user_supplied_config - - -# TODO: Needs unit testing for generic events -class RoomStateEventRestServlet(TransactionRestServlet): - def __init__(self, hs): - super().__init__(hs) - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() - self.message_handler = hs.get_message_handler() - self.auth = hs.get_auth() - - def register(self, http_server): - # /room/$roomid/state/$eventtype - no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" - - # /room/$roomid/state/$eventtype/$statekey - state_key = ( - "/rooms/(?P[^/]*)/state/" - "(?P[^/]*)/(?P[^/]*)$" - ) - - http_server.register_paths( - "GET", - client_patterns(state_key, v1=True), - self.on_GET, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(state_key, v1=True), - self.on_PUT, - self.__class__.__name__, - ) - http_server.register_paths( - "GET", - client_patterns(no_state_key, v1=True), - self.on_GET_no_state_key, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(no_state_key, v1=True), - self.on_PUT_no_state_key, - self.__class__.__name__, - ) - - def on_GET_no_state_key(self, request, room_id, event_type): - return self.on_GET(request, room_id, event_type, "") - - def on_PUT_no_state_key(self, request, room_id, event_type): - return self.on_PUT(request, room_id, event_type, "") - - async def on_GET(self, request, room_id, event_type, state_key): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - format = parse_string( - request, "format", default="content", allowed_values=["content", "event"] - ) - - msg_handler = self.message_handler - data = await msg_handler.get_room_data( - user_id=requester.user.to_string(), - room_id=room_id, - event_type=event_type, - state_key=state_key, - ) - - if not data: - raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) - - if format == "event": - event = format_event_for_client_v2(data.get_dict()) - return 200, event - elif format == "content": - return 200, data.get_dict()["content"] - - async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - requester = await self.auth.get_user_by_req(request) - - if txn_id: - set_tag("txn_id", txn_id) - - content = parse_json_object_from_request(request) - - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - if state_key is not None: - event_dict["state_key"] = state_key - - try: - if event_type == EventTypes.Member: - membership = content.get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - requester, - target=UserID.from_string(state_key), - room_id=room_id, - action=membership, - content=content, - ) - else: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) - event_id = event.event_id - except ShadowBanError: - event_id = "$" + random_string(43) - - set_tag("event_id", event_id) - ret = {"event_id": event_id} - return 200, ret - - -# TODO: Needs unit testing for generic events + feedback -class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): - super().__init__(hs) - self.event_creation_handler = hs.get_event_creation_handler() - self.auth = hs.get_auth() - - def register(self, http_server): - # /rooms/$roomid/send/$event_type[/$txn_id] - PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" - register_txn_path(self, PATTERNS, http_server, with_get=True) - - async def on_POST(self, request, room_id, event_type, txn_id=None): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request) - - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - if b"ts" in request.args and requester.app_service: - event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - - try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) - event_id = event.event_id - except ShadowBanError: - event_id = "$" + random_string(43) - - set_tag("event_id", event_id) - return 200, {"event_id": event_id} - - def on_GET(self, request, room_id, event_type, txn_id): - return 200, "Not implemented" - - def on_PUT(self, request, room_id, event_type, txn_id): - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, event_type, txn_id - ) - - -# TODO: Needs unit testing for room ID + alias joins -class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): - def __init__(self, hs): - super().__init__(hs) - super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up - self.auth = hs.get_auth() - - def register(self, http_server): - # /join/$room_identifier[/$txn_id] - PATTERNS = "/join/(?P[^/]*)" - register_txn_path(self, PATTERNS, http_server) - - async def on_POST( - self, - request: SynapseRequest, - room_identifier: str, - txn_id: Optional[str] = None, - ): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - try: - content = parse_json_object_from_request(request) - except Exception: - # Turns out we used to ignore the body entirely, and some clients - # cheekily send invalid bodies. - content = {} - - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] - args: Dict[bytes, List[bytes]] = request.args # type: ignore - remote_room_hosts = parse_strings_from_args(args, "server_name", required=False) - room_id, remote_room_hosts = await self.resolve_room_id( - room_identifier, - remote_room_hosts, - ) - - await self.room_member_handler.update_membership( - requester=requester, - target=requester.user, - room_id=room_id, - action="join", - txn_id=txn_id, - remote_room_hosts=remote_room_hosts, - content=content, - third_party_signed=content.get("third_party_signed", None), - ) - - return 200, {"room_id": room_id} - - def on_PUT(self, request, room_identifier, txn_id): - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_identifier, txn_id - ) - - -# TODO: Needs unit testing -class PublicRoomListRestServlet(TransactionRestServlet): - PATTERNS = client_patterns("/publicRooms$", v1=True) - - def __init__(self, hs): - super().__init__(hs) - self.hs = hs - self.auth = hs.get_auth() - - async def on_GET(self, request): - server = parse_string(request, "server") - - try: - await self.auth.get_user_by_req(request, allow_guest=True) - except InvalidClientCredentialsError as e: - # Option to allow servers to require auth when accessing - # /publicRooms via CS API. This is especially helpful in private - # federations. - if not self.hs.config.allow_public_rooms_without_auth: - raise - - # We allow people to not be authed if they're just looking at our - # room list, but require auth when we proxy the request. - # In both cases we call the auth function, as that has the side - # effect of logging who issued this request if an access token was - # provided. - if server: - raise e - - limit: Optional[int] = parse_integer(request, "limit", 0) - since_token = parse_string(request, "since") - - if limit == 0: - # zero is a special value which corresponds to no limit. - limit = None - - handler = self.hs.get_room_list_handler() - if server and server != self.hs.config.server_name: - # Ensure the server is valid. - try: - parse_and_validate_server_name(server) - except ValueError: - raise SynapseError( - 400, - "Invalid server name: %s" % (server,), - Codes.INVALID_PARAM, - ) - - data = await handler.get_remote_public_room_list( - server, limit=limit, since_token=since_token - ) - else: - data = await handler.get_local_public_room_list( - limit=limit, since_token=since_token - ) - - return 200, data - - async def on_POST(self, request): - await self.auth.get_user_by_req(request, allow_guest=True) - - server = parse_string(request, "server") - content = parse_json_object_from_request(request) - - limit: Optional[int] = int(content.get("limit", 100)) - since_token = content.get("since", None) - search_filter = content.get("filter", None) - - include_all_networks = content.get("include_all_networks", False) - third_party_instance_id = content.get("third_party_instance_id", None) - - if include_all_networks: - network_tuple = None - if third_party_instance_id is not None: - raise SynapseError( - 400, "Can't use include_all_networks with an explicit network" - ) - elif third_party_instance_id is None: - network_tuple = ThirdPartyInstanceID(None, None) - else: - network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) - - if limit == 0: - # zero is a special value which corresponds to no limit. - limit = None - - handler = self.hs.get_room_list_handler() - if server and server != self.hs.config.server_name: - # Ensure the server is valid. - try: - parse_and_validate_server_name(server) - except ValueError: - raise SynapseError( - 400, - "Invalid server name: %s" % (server,), - Codes.INVALID_PARAM, - ) - - data = await handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - - else: - data = await handler.get_local_public_room_list( - limit=limit, - since_token=since_token, - search_filter=search_filter, - network_tuple=network_tuple, - ) - - return 200, data - - -# TODO: Needs unit testing -class RoomMemberListRestServlet(RestServlet): - PATTERNS = client_patterns("/rooms/(?P[^/]*)/members$", v1=True) - - def __init__(self, hs): - super().__init__() - self.message_handler = hs.get_message_handler() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET(self, request, room_id): - # TODO support Pagination stream API (limit/tokens) - requester = await self.auth.get_user_by_req(request, allow_guest=True) - handler = self.message_handler - - # request the state as of a given event, as identified by a stream token, - # for consistency with /messages etc. - # useful for getting the membership in retrospect as of a given /sync - # response. - at_token_string = parse_string(request, "at") - if at_token_string is None: - at_token = None - else: - at_token = await StreamToken.from_string(self.store, at_token_string) - - # let you filter down on particular memberships. - # XXX: this may not be the best shape for this API - we could pass in a filter - # instead, except filters aren't currently aware of memberships. - # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details. - membership = parse_string(request, "membership") - not_membership = parse_string(request, "not_membership") - - events = await handler.get_state_events( - room_id=room_id, - user_id=requester.user.to_string(), - at_token=at_token, - state_filter=StateFilter.from_types([(EventTypes.Member, None)]), - ) - - chunk = [] - - for event in events: - if (membership and event["content"].get("membership") != membership) or ( - not_membership and event["content"].get("membership") == not_membership - ): - continue - chunk.append(event) - - return 200, {"chunk": chunk} - - -# deprecated in favour of /members?membership=join? -# except it does custom AS logic and has a simpler return format -class JoinedRoomMemberListRestServlet(RestServlet): - PATTERNS = client_patterns("/rooms/(?P[^/]*)/joined_members$", v1=True) - - def __init__(self, hs): - super().__init__() - self.message_handler = hs.get_message_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, room_id): - requester = await self.auth.get_user_by_req(request) - - users_with_profile = await self.message_handler.get_joined_members( - requester, room_id - ) - - return 200, {"joined": users_with_profile} - - -# TODO: Needs better unit testing -class RoomMessageListRestServlet(RestServlet): - PATTERNS = client_patterns("/rooms/(?P[^/]*)/messages$", v1=True) - - def __init__(self, hs): - super().__init__() - self.pagination_handler = hs.get_pagination_handler() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET(self, request, room_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = await PaginationConfig.from_request( - self.store, request, default_limit=10 - ) - as_client_event = b"raw" not in request.args - filter_str = parse_string(request, "filter", encoding="utf-8") - if filter_str: - filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) - if ( - event_filter - and event_filter.filter_json.get("event_format", "client") - == "federation" - ): - as_client_event = False - else: - event_filter = None - - msgs = await self.pagination_handler.get_messages( - room_id=room_id, - requester=requester, - pagin_config=pagination_config, - as_client_event=as_client_event, - event_filter=event_filter, - ) - - return 200, msgs - - -# TODO: Needs unit testing -class RoomStateRestServlet(RestServlet): - PATTERNS = client_patterns("/rooms/(?P[^/]*)/state$", v1=True) - - def __init__(self, hs): - super().__init__() - self.message_handler = hs.get_message_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, room_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - # Get all the current state for this room - events = await self.message_handler.get_state_events( - room_id=room_id, - user_id=requester.user.to_string(), - is_guest=requester.is_guest, - ) - return 200, events - - -# TODO: Needs unit testing -class RoomInitialSyncRestServlet(RestServlet): - PATTERNS = client_patterns("/rooms/(?P[^/]*)/initialSync$", v1=True) - - def __init__(self, hs): - super().__init__() - self.initial_sync_handler = hs.get_initial_sync_handler() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET(self, request, room_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = await PaginationConfig.from_request(self.store, request) - content = await self.initial_sync_handler.room_initial_sync( - room_id=room_id, requester=requester, pagin_config=pagination_config - ) - return 200, content - - -class RoomEventServlet(RestServlet): - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/event/(?P[^/]*)$", v1=True - ) - - def __init__(self, hs): - super().__init__() - self.clock = hs.get_clock() - self.event_handler = hs.get_event_handler() - self._event_serializer = hs.get_event_client_serializer() - self.auth = hs.get_auth() - - async def on_GET(self, request, room_id, event_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - try: - event = await self.event_handler.get_event( - requester.user, room_id, event_id - ) - except AuthError: - # This endpoint is supposed to return a 404 when the requester does - # not have permission to access the event - # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid - raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) - - time_now = self.clock.time_msec() - if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event - - return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) - - -class RoomEventContextServlet(RestServlet): - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/context/(?P[^/]*)$", v1=True - ) - - def __init__(self, hs): - super().__init__() - self.clock = hs.get_clock() - self.room_context_handler = hs.get_room_context_handler() - self._event_serializer = hs.get_event_client_serializer() - self.auth = hs.get_auth() - - async def on_GET(self, request, room_id, event_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - limit = parse_integer(request, "limit", default=10) - - # picking the API shape for symmetry with /messages - filter_str = parse_string(request, "filter", encoding="utf-8") - if filter_str: - filter_json = urlparse.unquote(filter_str) - event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) - else: - event_filter = None - - results = await self.room_context_handler.get_event_context( - requester, room_id, event_id, limit, event_filter - ) - - if not results: - raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) - - time_now = self.clock.time_msec() - results["events_before"] = await self._event_serializer.serialize_events( - results["events_before"], time_now - ) - results["event"] = await self._event_serializer.serialize_event( - results["event"], time_now - ) - results["events_after"] = await self._event_serializer.serialize_events( - results["events_after"], time_now - ) - results["state"] = await self._event_serializer.serialize_events( - results["state"], - time_now, - # No need to bundle aggregations for state events - bundle_aggregations=False, - ) - - return 200, results - - -class RoomForgetRestServlet(TransactionRestServlet): - def __init__(self, hs): - super().__init__(hs) - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - - def register(self, http_server): - PATTERNS = "/rooms/(?P[^/]*)/forget" - register_txn_path(self, PATTERNS, http_server) - - async def on_POST(self, request, room_id, txn_id=None): - requester = await self.auth.get_user_by_req(request, allow_guest=False) - - await self.room_member_handler.forget(user=requester.user, room_id=room_id) - - return 200, {} - - def on_PUT(self, request, room_id, txn_id): - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, txn_id - ) - - -# TODO: Needs unit testing -class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): - super().__init__(hs) - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - - def register(self, http_server): - # /rooms/$roomid/[invite|join|leave] - PATTERNS = ( - "/rooms/(?P[^/]*)/" - "(?Pjoin|invite|leave|ban|unban|kick)" - ) - register_txn_path(self, PATTERNS, http_server) - - async def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if requester.is_guest and membership_action not in { - Membership.JOIN, - Membership.LEAVE, - }: - raise AuthError(403, "Guest access not allowed") - - try: - content = parse_json_object_from_request(request) - except Exception: - # Turns out we used to ignore the body entirely, and some clients - # cheekily send invalid bodies. - content = {} - - if membership_action == "invite" and self._has_3pid_invite_keys(content): - try: - await self.room_member_handler.do_3pid_invite( - room_id, - requester.user, - content["medium"], - content["address"], - content["id_server"], - requester, - txn_id, - content.get("id_access_token"), - ) - except ShadowBanError: - # Pretend the request succeeded. - pass - return 200, {} - - target = requester.user - if membership_action in ["invite", "ban", "unban", "kick"]: - assert_params_in_dict(content, ["user_id"]) - target = UserID.from_string(content["user_id"]) - - event_content = None - if "reason" in content: - event_content = {"reason": content["reason"]} - - try: - await self.room_member_handler.update_membership( - requester=requester, - target=target, - room_id=room_id, - action=membership_action, - txn_id=txn_id, - third_party_signed=content.get("third_party_signed", None), - content=event_content, - ) - except ShadowBanError: - # Pretend the request succeeded. - pass - - return_value = {} - - if membership_action == "join": - return_value["room_id"] = room_id - - return 200, return_value - - def _has_3pid_invite_keys(self, content): - for key in {"id_server", "medium", "address"}: - if key not in content: - return False - return True - - def on_PUT(self, request, room_id, membership_action, txn_id): - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, membership_action, txn_id - ) - - -class RoomRedactEventRestServlet(TransactionRestServlet): - def __init__(self, hs): - super().__init__(hs) - self.event_creation_handler = hs.get_event_creation_handler() - self.auth = hs.get_auth() - - def register(self, http_server): - PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" - register_txn_path(self, PATTERNS, http_server) - - async def on_POST(self, request, room_id, event_id, txn_id=None): - requester = await self.auth.get_user_by_req(request) - content = parse_json_object_from_request(request) - - try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) - event_id = event.event_id - except ShadowBanError: - event_id = "$" + random_string(43) - - set_tag("event_id", event_id) - return 200, {"event_id": event_id} - - def on_PUT(self, request, room_id, event_id, txn_id): - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, event_id, txn_id - ) - - -class RoomTypingRestServlet(RestServlet): - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/typing/(?P[^/]*)$", v1=True - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.presence_handler = hs.get_presence_handler() - self.auth = hs.get_auth() - - # If we're not on the typing writer instance we should scream if we get - # requests. - self._is_typing_writer = ( - hs.config.worker.writers.typing == hs.get_instance_name() - ) - - async def on_PUT(self, request, room_id, user_id): - requester = await self.auth.get_user_by_req(request) - - if not self._is_typing_writer: - raise Exception("Got /typing request on instance that is not typing writer") - - room_id = urlparse.unquote(room_id) - target_user = UserID.from_string(urlparse.unquote(user_id)) - - content = parse_json_object_from_request(request) - - await self.presence_handler.bump_presence_active_time(requester.user) - - # Limit timeout to stop people from setting silly typing timeouts. - timeout = min(content.get("timeout", 30000), 120000) - - # Defer getting the typing handler since it will raise on workers. - typing_handler = self.hs.get_typing_writer_handler() - - try: - if content["typing"]: - await typing_handler.started_typing( - target_user=target_user, - requester=requester, - room_id=room_id, - timeout=timeout, - ) - else: - await typing_handler.stopped_typing( - target_user=target_user, requester=requester, room_id=room_id - ) - except ShadowBanError: - # Pretend this worked without error. - pass - - return 200, {} - - -class RoomAliasListServlet(RestServlet): - PATTERNS = [ - re.compile( - r"^/_matrix/client/unstable/org\.matrix\.msc2432" - r"/rooms/(?P[^/]*)/aliases" - ), - ] + list(client_patterns("/rooms/(?P[^/]*)/aliases$", unstable=False)) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.directory_handler = hs.get_directory_handler() - - async def on_GET(self, request, room_id): - requester = await self.auth.get_user_by_req(request) - - alias_list = await self.directory_handler.get_aliases_for_room( - requester, room_id - ) - - return 200, {"aliases": alias_list} - - -class SearchRestServlet(RestServlet): - PATTERNS = client_patterns("/search$", v1=True) - - def __init__(self, hs): - super().__init__() - self.search_handler = hs.get_search_handler() - self.auth = hs.get_auth() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) - - content = parse_json_object_from_request(request) - - batch = parse_string(request, "next_batch") - results = await self.search_handler.search(requester.user, content, batch) - - return 200, results - - -class JoinedRoomsRestServlet(RestServlet): - PATTERNS = client_patterns("/joined_rooms$", v1=True) - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.auth = hs.get_auth() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) - return 200, {"joined_rooms": list(room_ids)} - - -def register_txn_path(servlet, regex_string, http_server, with_get=False): - """Registers a transaction-based path. - - This registers two paths: - PUT regex_string/$txnid - POST regex_string - - Args: - regex_string (str): The regex string to register. Must NOT have a - trailing $ as this string will be appended to. - http_server : The http_server to register paths with. - with_get: True to also register respective GET paths for the PUTs. - """ - http_server.register_paths( - "POST", - client_patterns(regex_string + "$", v1=True), - servlet.on_POST, - servlet.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_PUT, - servlet.__class__.__name__, - ) - if with_get: - http_server.register_paths( - "GET", - client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_GET, - servlet.__class__.__name__, - ) - - -class RoomSpaceSummaryRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" - "/rooms/(?P[^/]*)/spaces$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - - max_rooms_per_space = parse_integer(request, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=max_rooms_per_space, - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request) - - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - ) - - -class RoomHierarchyRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" - "/rooms/(?P[^/]*)/hierarchy$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - - max_depth = parse_integer(request, "max_depth") - if max_depth is not None and max_depth < 0: - raise SynapseError( - 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON - ) - - limit = parse_integer(request, "limit") - if limit is not None and limit <= 0: - raise SynapseError( - 400, "'limit' must be a positive integer", Codes.BAD_JSON - ) - - return 200, await self._room_summary_handler.get_room_hierarchy( - requester.user.to_string(), - room_id, - suggested_only=parse_boolean(request, "suggested_only", default=False), - max_depth=max_depth, - limit=limit, - from_token=parse_string(request, "from"), - ) - - -class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/im.nheko.summary" - "/rooms/(?P[^/]*)/summary$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_identifier: str - ) -> Tuple[int, JsonDict]: - try: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - requester_user_id: Optional[str] = requester.user.to_string() - except MissingClientTokenError: - # auth is optional - requester_user_id = None - - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] - args: Dict[bytes, List[bytes]] = request.args # type: ignore - remote_room_hosts = parse_strings_from_args(args, "via", required=False) - room_id, remote_room_hosts = await self.resolve_room_id( - room_identifier, - remote_room_hosts, - ) - - return 200, await self._room_summary_handler.get_room_summary( - requester_user_id, - room_id, - remote_room_hosts, - ) - - -def register_servlets(hs: "HomeServer", http_server, is_worker=False): - RoomStateEventRestServlet(hs).register(http_server) - RoomMemberListRestServlet(hs).register(http_server) - JoinedRoomMemberListRestServlet(hs).register(http_server) - RoomMessageListRestServlet(hs).register(http_server) - JoinRoomAliasServlet(hs).register(http_server) - RoomMembershipRestServlet(hs).register(http_server) - RoomSendEventRestServlet(hs).register(http_server) - PublicRoomListRestServlet(hs).register(http_server) - RoomStateRestServlet(hs).register(http_server) - RoomRedactEventRestServlet(hs).register(http_server) - RoomTypingRestServlet(hs).register(http_server) - RoomEventContextServlet(hs).register(http_server) - RoomSpaceSummaryRestServlet(hs).register(http_server) - RoomHierarchyRestServlet(hs).register(http_server) - if hs.config.experimental.msc3266_enabled: - RoomSummaryRestServlet(hs).register(http_server) - RoomEventServlet(hs).register(http_server) - JoinedRoomsRestServlet(hs).register(http_server) - RoomAliasListServlet(hs).register(http_server) - SearchRestServlet(hs).register(http_server) - - # Some servlets only get registered for the main process. - if not is_worker: - RoomCreateRestServlet(hs).register(http_server) - RoomForgetRestServlet(hs).register(http_server) - - -def register_deprecated_servlets(hs, http_server): - RoomInitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py deleted file mode 100644 index c780ffded5..0000000000 --- a/synapse/rest/client/v1/voip.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -import hashlib -import hmac - -from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns - - -class VoipRestServlet(RestServlet): - PATTERNS = client_patterns("/voip/turnServer$", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req( - request, self.hs.config.turn_allow_guests - ) - - turnUris = self.hs.config.turn_uris - turnSecret = self.hs.config.turn_shared_secret - turnUsername = self.hs.config.turn_username - turnPassword = self.hs.config.turn_password - userLifetime = self.hs.config.turn_user_lifetime - - if turnUris and turnSecret and userLifetime: - expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 - username = "%d:%s" % (expiry, requester.user.to_string()) - - mac = hmac.new( - turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1 - ) - # We need to use standard padded base64 encoding here - # encode_base64 because we need to add the standard padding to get the - # same result as the TURN server. - password = base64.b64encode(mac.digest()).decode("ascii") - - elif turnUris and turnUsername and turnPassword and userLifetime: - username = turnUsername - password = turnPassword - - else: - return 200, {} - - return ( - 200, - { - "username": username, - "password": password, - "ttl": userLifetime / 1000, - "uris": turnUris, - }, - ) - - -def register_servlets(hs, http_server): - VoipRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py deleted file mode 100644 index 5e83dba2ed..0000000000 --- a/synapse/rest/client/v2_alpha/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py deleted file mode 100644 index 0443f4571c..0000000000 --- a/synapse/rest/client/v2_alpha/_base.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains base REST classes for constructing client v1 servlets. -""" -import logging -import re -from typing import Iterable, Pattern - -from synapse.api.errors import InteractiveAuthIncompleteError -from synapse.api.urls import CLIENT_API_PREFIX -from synapse.types import JsonDict - -logger = logging.getLogger(__name__) - - -def client_patterns( - path_regex: str, - releases: Iterable[int] = (0,), - unstable: bool = True, - v1: bool = False, -) -> Iterable[Pattern]: - """Creates a regex compiled client path with the correct client path - prefix. - - Args: - path_regex: The regex string to match. This should NOT have a ^ - as this will be prefixed. - releases: An iterable of releases to include this endpoint under. - unstable: If true, include this endpoint under the "unstable" prefix. - v1: If true, include this endpoint under the "api/v1" prefix. - Returns: - An iterable of patterns. - """ - patterns = [] - - if unstable: - unstable_prefix = CLIENT_API_PREFIX + "/unstable" - patterns.append(re.compile("^" + unstable_prefix + path_regex)) - if v1: - v1_prefix = CLIENT_API_PREFIX + "/api/v1" - patterns.append(re.compile("^" + v1_prefix + path_regex)) - for release in releases: - new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) - patterns.append(re.compile("^" + new_prefix + path_regex)) - - return patterns - - -def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: - """ - Enforces a maximum limit of a timeline query. - - Params: - filter_json: The timeline query to modify. - filter_timeline_limit: The maximum limit to allow, passing -1 will - disable enforcing a maximum limit. - """ - if filter_timeline_limit < 0: - return # no upper limits - timeline = filter_json.get("room", {}).get("timeline", {}) - if "limit" in timeline: - filter_json["room"]["timeline"]["limit"] = min( - filter_json["room"]["timeline"]["limit"], filter_timeline_limit - ) - - -def interactive_auth_handler(orig): - """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors - - Takes a on_POST method which returns an Awaitable (errcode, body) response - and adds exception handling to turn a InteractiveAuthIncompleteError into - a 401 response. - - Normal usage is: - - @interactive_auth_handler - async def on_POST(self, request): - # ... - await self.auth_handler.check_auth - """ - - async def wrapped(*args, **kwargs): - try: - return await orig(*args, **kwargs) - except InteractiveAuthIncompleteError as e: - return 401, e.result - - return wrapped diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py deleted file mode 100644 index fb5ad2906e..0000000000 --- a/synapse/rest/client/v2_alpha/account.py +++ /dev/null @@ -1,910 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import random -from http import HTTPStatus -from typing import TYPE_CHECKING -from urllib.parse import urlparse - -from synapse.api.constants import LoginType -from synapse.api.errors import ( - Codes, - InteractiveAuthIncompleteError, - SynapseError, - ThreepidValidationError, -) -from synapse.config.emailconfig import ThreepidBehaviour -from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, - parse_string, -) -from synapse.metrics import threepid_send_requests -from synapse.push.mailer import Mailer -from synapse.util.msisdn import phone_number_to_msisdn -from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed, validate_email - -from ._base import client_patterns, interactive_auth_handler - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -logger = logging.getLogger(__name__) - - -class EmailPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/password/email/requestToken$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.datastore = hs.get_datastore() - self.config = hs.config - self.identity_handler = hs.get_identity_handler() - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.mailer = Mailer( - hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_password_reset_template_html, - template_text=self.config.email_password_reset_template_text, - ) - - async def on_POST(self, request): - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "User password resets have been disabled due to lack of email config" - ) - raise SynapseError( - 400, "Email-based password resets have been disabled on this server" - ) - - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - - # Extract params from body - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # Canonicalise the email address. The addresses are all stored canonicalised - # in the database. This allows the user to reset his password without having to - # know the exact spelling (eg. upper and lower case) of address in the database. - # Stored in the database "foo@bar.com" - # User requests with "FOO@bar.com" would raise a Not Found error - try: - email = validate_email(body["email"]) - except ValueError as e: - raise SynapseError(400, str(e)) - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - if next_link: - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) - - await self.identity_handler.ratelimit_request_token_requests( - request, "email", email - ) - - # The email will be sent to the stored address. - # This avoids a potential account hijack by requesting a password reset to - # an email address which is controlled by the attacker but which, after - # canonicalisation, matches the one in our database. - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "email", email - ) - - if existing_user_id is None: - if self.config.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send password reset emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_password_reset_mail, - next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} - - threepid_send_requests.labels(type="email", reason="password_reset").observe( - send_attempt - ) - - return 200, ret - - -class PasswordRestServlet(RestServlet): - PATTERNS = client_patterns("/account/password$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() - self.password_policy_handler = hs.get_password_policy_handler() - self._set_password_handler = hs.get_set_password_handler() - - @interactive_auth_handler - async def on_POST(self, request): - body = parse_json_object_from_request(request) - - # we do basic sanity checks here because the auth layer will store these - # in sessions. Pull out the new password provided to us. - new_password = body.pop("new_password", None) - if new_password is not None: - if not isinstance(new_password, str) or len(new_password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(new_password) - - # there are two possibilities here. Either the user does not have an - # access token, and needs to do a password reset; or they have one and - # need to validate their identity. - # - # In the first case, we offer a couple of means of identifying - # themselves (email and msisdn, though it's unclear if msisdn actually - # works). - # - # In the second case, we require a password to confirm their identity. - - if self.auth.has_access_token(request): - requester = await self.auth.get_user_by_req(request) - try: - params, session_id = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "modify your account password", - ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, - ) - raise - user_id = requester.user.to_string() - else: - requester = None - try: - result, params, session_id = await self.auth_handler.check_ui_auth( - [[LoginType.EMAIL_IDENTITY]], - request, - body, - "modify your account password", - ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, - ) - raise - - if LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if "medium" not in threepid or "address" not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid["medium"] == "email": - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - try: - threepid["address"] = validate_email(threepid["address"]) - except ValueError as e: - raise SynapseError(400, str(e)) - # if using email, we must know about the email they're authing with! - threepid_user_id = await self.datastore.get_user_id_by_threepid( - threepid["medium"], threepid["address"] - ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id - else: - logger.error("Auth succeeded but no known type! %r", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) - - # If we have a password in this request, prefer it. Otherwise, use the - # password hash from an earlier request. - if new_password: - password_hash = await self.auth_handler.hash(new_password) - elif session_id is not None: - password_hash = await self.auth_handler.get_session_data( - session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None - ) - else: - # UI validation was skipped, but the request did not include a new - # password. - password_hash = None - if not password_hash: - raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - - logout_devices = params.get("logout_devices", True) - - await self._set_password_handler.set_password( - user_id, password_hash, logout_devices, requester - ) - - return 200, {} - - -class DeactivateAccountRestServlet(RestServlet): - PATTERNS = client_patterns("/account/deactivate$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self._deactivate_account_handler = hs.get_deactivate_account_handler() - - @interactive_auth_handler - async def on_POST(self, request): - body = parse_json_object_from_request(request) - erase = body.get("erase", False) - if not isinstance(erase, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'erase' must be a boolean, if given", - Codes.BAD_JSON, - ) - - requester = await self.auth.get_user_by_req(request) - - # allow ASes to deactivate their own users - if requester.app_service: - await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, requester - ) - return 200, {} - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "deactivate your account", - ) - result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), - erase, - requester, - id_server=body.get("id_server"), - ) - if result: - id_server_unbind_result = "success" - else: - id_server_unbind_result = "no-support" - - return 200, {"id_server_unbind_result": id_server_unbind_result} - - -class EmailThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/email/requestToken$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.config = hs.config - self.identity_handler = hs.get_identity_handler() - self.store = self.hs.get_datastore() - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.mailer = Mailer( - hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_add_threepid_template_html, - template_text=self.config.email_add_threepid_template_text, - ) - - async def on_POST(self, request): - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Adding an email to your account is disabled on this server" - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # Canonicalise the email address. The addresses are all stored canonicalised - # in the database. - # This ensures that the validation email is sent to the canonicalised address - # as it will later be entered into the database. - # Otherwise the email will be sent to "FOO@bar.com" and stored as - # "foo@bar.com" in database. - try: - email = validate_email(body["email"]) - except ValueError as e: - raise SynapseError(400, str(e)) - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - if not check_3pid_allowed(self.hs, "email", email): - raise SynapseError( - 403, - "Your email domain is not authorized on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "email", email - ) - - if next_link: - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) - - existing_user_id = await self.store.get_user_id_by_threepid("email", email) - - if existing_user_id is not None: - if self.config.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send threepid validation emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_add_threepid_mail, - next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} - - threepid_send_requests.labels(type="email", reason="add_threepid").observe( - send_attempt - ) - - return 200, ret - - -class MsisdnThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - - def __init__(self, hs: "HomeServer"): - self.hs = hs - super().__init__() - self.store = self.hs.get_datastore() - self.identity_handler = hs.get_identity_handler() - - async def on_POST(self, request): - body = parse_json_object_from_request(request) - assert_params_in_dict( - body, ["client_secret", "country", "phone_number", "send_attempt"] - ) - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - country = body["country"] - phone_number = body["phone_number"] - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - msisdn = phone_number_to_msisdn(country, phone_number) - - if not check_3pid_allowed(self.hs, "msisdn", msisdn): - raise SynapseError( - 403, - "Account phone numbers are not authorized on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "msisdn", msisdn - ) - - if next_link: - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) - - existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) - - if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - - if not self.hs.config.account_threepid_delegate_msisdn: - logger.warning( - "No upstream msisdn account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, - "Adding phone numbers to user account is not supported by this homeserver", - ) - - ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, - country, - phone_number, - client_secret, - send_attempt, - next_link, - ) - - threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe( - send_attempt - ) - - return 200, ret - - -class AddThreepidEmailSubmitTokenServlet(RestServlet): - """Handles 3PID validation token submission for adding an email to a user's account""" - - PATTERNS = client_patterns( - "/add_threepid/email/submit_token$", releases=(), unstable=True - ) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self._failure_email_template = ( - self.config.email_add_threepid_template_failure_html - ) - - async def on_GET(self, request): - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Adding an email to your account is disabled on this server" - ) - elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - raise SynapseError( - 400, - "This homeserver is not validating threepids. Use an identity server " - "instead.", - ) - - sid = parse_string(request, "sid", required=True) - token = parse_string(request, "token", required=True) - client_secret = parse_string(request, "client_secret", required=True) - assert_valid_client_secret(client_secret) - - # Attempt to validate a 3PID session - try: - # Mark the session as valid - next_link = await self.store.validate_threepid_session( - sid, client_secret, token, self.clock.time_msec() - ) - - # Perform a 302 redirect if next_link is set - if next_link: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None - - # Otherwise show the success template - html = self.config.email_add_threepid_template_success_html_content - status_code = 200 - except ThreepidValidationError as e: - status_code = e.code - - # Show a failure page with a reason - template_vars = {"failure_reason": e.msg} - html = self._failure_email_template.render(**template_vars) - - respond_with_html(request, status_code, html) - - -class AddThreepidMsisdnSubmitTokenServlet(RestServlet): - """Handles 3PID validation token submission for adding a phone number to a user's - account - """ - - PATTERNS = client_patterns( - "/add_threepid/msisdn/submit_token$", releases=(), unstable=True - ) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastore() - self.identity_handler = hs.get_identity_handler() - - async def on_POST(self, request): - if not self.config.account_threepid_delegate_msisdn: - raise SynapseError( - 400, - "This homeserver is not validating phone numbers. Use an identity server " - "instead.", - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["client_secret", "sid", "token"]) - assert_valid_client_secret(body["client_secret"]) - - # Proxy submit_token request to msisdn threepid delegate - response = await self.identity_handler.proxy_msisdn_submit_token( - self.config.account_threepid_delegate_msisdn, - body["client_secret"], - body["sid"], - body["token"], - ) - return 200, response - - -class ThreepidRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request) - - threepids = await self.datastore.user_get_threepids(requester.user.to_string()) - - return 200, {"threepids": threepids} - - async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") - if threepid_creds is None: - raise SynapseError( - 400, "Missing param three_pid_creds", Codes.MISSING_PARAM - ) - assert_params_in_dict(threepid_creds, ["client_secret", "sid"]) - - sid = threepid_creds["sid"] - client_secret = threepid_creds["client_secret"] - assert_valid_client_secret(client_secret) - - validation_session = await self.identity_handler.validate_threepid_session( - client_secret, sid - ) - if validation_session: - await self.auth_handler.add_threepid( - user_id, - validation_session["medium"], - validation_session["address"], - validation_session["validated_at"], - ) - return 200, {} - - raise SynapseError( - 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED - ) - - -class ThreepidAddRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/add$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - - @interactive_auth_handler - async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "sid"]) - sid = body["sid"] - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "add a third-party identifier to your account", - ) - - validation_session = await self.identity_handler.validate_threepid_session( - client_secret, sid - ) - if validation_session: - await self.auth_handler.add_threepid( - user_id, - validation_session["medium"], - validation_session["address"], - validation_session["validated_at"], - ) - return 200, {} - - raise SynapseError( - 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED - ) - - -class ThreepidBindRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/bind$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.auth = hs.get_auth() - - async def on_POST(self, request): - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) - id_server = body["id_server"] - sid = body["sid"] - id_access_token = body.get("id_access_token") # optional - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - await self.identity_handler.bind_threepid( - client_secret, sid, user_id, id_server, id_access_token - ) - - return 200, {} - - -class ThreepidUnbindRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/unbind$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.auth = hs.get_auth() - self.datastore = self.hs.get_datastore() - - async def on_POST(self, request): - """Unbind the given 3pid from a specific identity server, or identity servers that are - known to have this 3pid bound - """ - requester = await self.auth.get_user_by_req(request) - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["medium", "address"]) - - medium = body.get("medium") - address = body.get("address") - id_server = body.get("id_server") - - # Attempt to unbind the threepid from an identity server. If id_server is None, try to - # unbind from all identity servers this threepid has been added to in the past - result = await self.identity_handler.try_unbind_threepid( - requester.user.to_string(), - {"address": address, "medium": medium, "id_server": id_server}, - ) - return 200, {"id_server_unbind_result": "success" if result else "no-support"} - - -class ThreepidDeleteRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/delete$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - - async def on_POST(self, request): - if not self.hs.config.enable_3pid_changes: - raise SynapseError( - 400, "3PID changes are disabled on this server", Codes.FORBIDDEN - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["medium", "address"]) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - try: - ret = await self.auth_handler.delete_threepid( - user_id, body["medium"], body["address"], body.get("id_server") - ) - except Exception: - # NB. This endpoint should succeed if there is nothing to - # delete, so it should only throw if something is wrong - # that we ought to care about. - logger.exception("Failed to remove threepid") - raise SynapseError(500, "Failed to remove threepid") - - if ret: - id_server_unbind_result = "success" - else: - id_server_unbind_result = "no-support" - - return 200, {"id_server_unbind_result": id_server_unbind_result} - - -def assert_valid_next_link(hs: "HomeServer", next_link: str): - """ - Raises a SynapseError if a given next_link value is invalid - - next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config - option is either empty or contains a domain that matches the one in the given next_link - - Args: - hs: The homeserver object - next_link: The next_link value given by the client - - Raises: - SynapseError: If the next_link is invalid - """ - valid = True - - # Parse the contents of the URL - next_link_parsed = urlparse(next_link) - - # Scheme must not point to the local drive - if next_link_parsed.scheme == "file": - valid = False - - # If the domain whitelist is set, the domain must be in it - if ( - valid - and hs.config.next_link_domain_whitelist is not None - and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist - ): - valid = False - - if not valid: - raise SynapseError( - 400, - "'next_link' domain not included in whitelist, or not http(s)", - errcode=Codes.INVALID_PARAM, - ) - - -class WhoamiRestServlet(RestServlet): - PATTERNS = client_patterns("/account/whoami$") - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request) - - response = {"user_id": requester.user.to_string()} - - # Appservices and similar accounts do not have device IDs - # that we can report on, so exclude them for compliance. - if requester.device_id is not None: - response["device_id"] = requester.device_id - - return 200, response - - -def register_servlets(hs, http_server): - EmailPasswordRequestTokenRestServlet(hs).register(http_server) - PasswordRestServlet(hs).register(http_server) - DeactivateAccountRestServlet(hs).register(http_server) - EmailThreepidRequestTokenRestServlet(hs).register(http_server) - MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) - AddThreepidEmailSubmitTokenServlet(hs).register(http_server) - AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) - ThreepidRestServlet(hs).register(http_server) - ThreepidAddRestServlet(hs).register(http_server) - ThreepidBindRestServlet(hs).register(http_server) - ThreepidUnbindRestServlet(hs).register(http_server) - ThreepidDeleteRestServlet(hs).register(http_server) - WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py deleted file mode 100644 index 7517e9304e..0000000000 --- a/synapse/rest/client/v2_alpha/account_data.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import AuthError, NotFoundError, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class AccountDataServlet(RestServlet): - """ - PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 - GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 - """ - - PATTERNS = client_patterns( - "/user/(?P[^/]*)/account_data/(?P[^/]*)" - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.handler = hs.get_account_data_handler() - - async def on_PUT(self, request, user_id, account_data_type): - requester = await self.auth.get_user_by_req(request) - if user_id != requester.user.to_string(): - raise AuthError(403, "Cannot add account data for other users.") - - body = parse_json_object_from_request(request) - - await self.handler.add_account_data_for_user(user_id, account_data_type, body) - - return 200, {} - - async def on_GET(self, request, user_id, account_data_type): - requester = await self.auth.get_user_by_req(request) - if user_id != requester.user.to_string(): - raise AuthError(403, "Cannot get account data for other users.") - - event = await self.store.get_global_account_data_by_type_for_user( - account_data_type, user_id - ) - - if event is None: - raise NotFoundError("Account data not found") - - return 200, event - - -class RoomAccountDataServlet(RestServlet): - """ - PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 - GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 - """ - - PATTERNS = client_patterns( - "/user/(?P[^/]*)" - "/rooms/(?P[^/]*)" - "/account_data/(?P[^/]*)" - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.handler = hs.get_account_data_handler() - - async def on_PUT(self, request, user_id, room_id, account_data_type): - requester = await self.auth.get_user_by_req(request) - if user_id != requester.user.to_string(): - raise AuthError(403, "Cannot add account data for other users.") - - body = parse_json_object_from_request(request) - - if account_data_type == "m.fully_read": - raise SynapseError( - 405, - "Cannot set m.fully_read through this API." - " Use /rooms/!roomId:server.name/read_markers", - ) - - await self.handler.add_account_data_to_room( - user_id, room_id, account_data_type, body - ) - - return 200, {} - - async def on_GET(self, request, user_id, room_id, account_data_type): - requester = await self.auth.get_user_by_req(request) - if user_id != requester.user.to_string(): - raise AuthError(403, "Cannot get account data for other users.") - - event = await self.store.get_account_data_for_room_and_type( - user_id, room_id, account_data_type - ) - - if event is None: - raise NotFoundError("Room account data not found") - - return 200, event - - -def register_servlets(hs, http_server): - AccountDataServlet(hs).register(http_server) - RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py deleted file mode 100644 index 3ebe401861..0000000000 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import SynapseError -from synapse.http.server import respond_with_html -from synapse.http.servlet import RestServlet - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class AccountValidityRenewServlet(RestServlet): - PATTERNS = client_patterns("/account_validity/renew$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - - self.hs = hs - self.account_activity_handler = hs.get_account_validity_handler() - self.auth = hs.get_auth() - self.account_renewed_template = ( - hs.config.account_validity.account_validity_account_renewed_template - ) - self.account_previously_renewed_template = ( - hs.config.account_validity.account_validity_account_previously_renewed_template - ) - self.invalid_token_template = ( - hs.config.account_validity.account_validity_invalid_token_template - ) - - async def on_GET(self, request): - if b"token" not in request.args: - raise SynapseError(400, "Missing renewal token") - renewal_token = request.args[b"token"][0] - - ( - token_valid, - token_stale, - expiration_ts, - ) = await self.account_activity_handler.renew_account( - renewal_token.decode("utf8") - ) - - if token_valid: - status_code = 200 - response = self.account_renewed_template.render(expiration_ts=expiration_ts) - elif token_stale: - status_code = 200 - response = self.account_previously_renewed_template.render( - expiration_ts=expiration_ts - ) - else: - status_code = 404 - response = self.invalid_token_template.render(expiration_ts=expiration_ts) - - respond_with_html(request, status_code, response) - - -class AccountValiditySendMailServlet(RestServlet): - PATTERNS = client_patterns("/account_validity/send_mail$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - - self.hs = hs - self.account_activity_handler = hs.get_account_validity_handler() - self.auth = hs.get_auth() - self.account_validity_renew_by_email_enabled = ( - hs.config.account_validity.account_validity_renew_by_email_enabled - ) - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request, allow_expired=True) - user_id = requester.user.to_string() - await self.account_activity_handler.send_renewal_email_to_user(user_id) - - return 200, {} - - -def register_servlets(hs, http_server): - AccountValidityRenewServlet(hs).register(http_server) - AccountValiditySendMailServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py deleted file mode 100644 index 6ea1b50a62..0000000000 --- a/synapse/rest/client/v2_alpha/auth.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING - -from synapse.api.constants import LoginType -from synapse.api.errors import SynapseError -from synapse.api.urls import CLIENT_API_PREFIX -from synapse.http.server import respond_with_html -from synapse.http.servlet import RestServlet, parse_string - -from ._base import client_patterns - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class AuthRestServlet(RestServlet): - """ - Handles Client / Server API authentication in any situations where it - cannot be handled in the normal flow (with requests to the same endpoint). - Current use is for web fallback auth. - """ - - PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self.registration_handler = hs.get_registration_handler() - self.recaptcha_template = hs.config.recaptcha_template - self.terms_template = hs.config.terms_template - self.success_template = hs.config.fallback_success_template - - async def on_GET(self, request, stagetype): - session = parse_string(request, "session") - if not session: - raise SynapseError(400, "No session supplied") - - if stagetype == LoginType.RECAPTCHA: - html = self.recaptcha_template.render( - session=session, - myurl="%s/r0/auth/%s/fallback/web" - % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - sitekey=self.hs.config.recaptcha_public_key, - ) - elif stagetype == LoginType.TERMS: - html = self.terms_template.render( - session=session, - terms_url="%s_matrix/consent?v=%s" - % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), - myurl="%s/r0/auth/%s/fallback/web" - % (CLIENT_API_PREFIX, LoginType.TERMS), - ) - - elif stagetype == LoginType.SSO: - # Display a confirmation page which prompts the user to - # re-authenticate with their SSO provider. - html = await self.auth_handler.start_sso_ui_auth(request, session) - - else: - raise SynapseError(404, "Unknown auth stage type") - - # Render the HTML and return. - respond_with_html(request, 200, html) - return None - - async def on_POST(self, request, stagetype): - - session = parse_string(request, "session") - if not session: - raise SynapseError(400, "No session supplied") - - if stagetype == LoginType.RECAPTCHA: - response = parse_string(request, "g-recaptcha-response") - - if not response: - raise SynapseError(400, "No captcha response supplied") - - authdict = {"response": response, "session": session} - - success = await self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, authdict, request.getClientIP() - ) - - if success: - html = self.success_template.render() - else: - html = self.recaptcha_template.render( - session=session, - myurl="%s/r0/auth/%s/fallback/web" - % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - sitekey=self.hs.config.recaptcha_public_key, - ) - elif stagetype == LoginType.TERMS: - authdict = {"session": session} - - success = await self.auth_handler.add_oob_auth( - LoginType.TERMS, authdict, request.getClientIP() - ) - - if success: - html = self.success_template.render() - else: - html = self.terms_template.render( - session=session, - terms_url="%s_matrix/consent?v=%s" - % ( - self.hs.config.public_baseurl, - self.hs.config.user_consent_version, - ), - myurl="%s/r0/auth/%s/fallback/web" - % (CLIENT_API_PREFIX, LoginType.TERMS), - ) - elif stagetype == LoginType.SSO: - # The SSO fallback workflow should not post here, - raise SynapseError(404, "Fallback SSO auth does not support POST requests.") - else: - raise SynapseError(404, "Unknown auth stage type") - - # Render the HTML and return. - respond_with_html(request, 200, html) - return None - - -def register_servlets(hs, http_server): - AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py deleted file mode 100644 index 88e3aac797..0000000000 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2019 New Vector -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import TYPE_CHECKING, Tuple - -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES -from synapse.http.servlet import RestServlet -from synapse.http.site import SynapseRequest -from synapse.types import JsonDict - -from ._base import client_patterns - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class CapabilitiesRestServlet(RestServlet): - """End point to expose the capabilities of the server.""" - - PATTERNS = client_patterns("/capabilities$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.config = hs.config - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) - change_password = self.auth_handler.can_change_password() - - response = { - "capabilities": { - "m.room_versions": { - "default": self.config.default_room_version.identifier, - "available": { - v.identifier: v.disposition - for v in KNOWN_ROOM_VERSIONS.values() - }, - }, - "m.change_password": {"enabled": change_password}, - } - } - - if self.config.experimental.msc3244_enabled: - response["capabilities"]["m.room_versions"][ - "org.matrix.msc3244.room_capabilities" - ] = MSC3244_CAPABILITIES - - return 200, response - - -def register_servlets(hs: "HomeServer", http_server): - CapabilitiesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py deleted file mode 100644 index 8b9674db06..0000000000 --- a/synapse/rest/client/v2_alpha/devices.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api import errors -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) -from synapse.http.site import SynapseRequest - -from ._base import client_patterns, interactive_auth_handler - -logger = logging.getLogger(__name__) - - -class DevicesRestServlet(RestServlet): - PATTERNS = client_patterns("/devices$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - devices = await self.device_handler.get_devices_by_user( - requester.user.to_string() - ) - return 200, {"devices": devices} - - -class DeleteDevicesRestServlet(RestServlet): - """ - API for bulk deletion of devices. Accepts a JSON object with a devices - key which lists the device_ids to delete. Requires user interactive auth. - """ - - PATTERNS = client_patterns("/delete_devices") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - self.auth_handler = hs.get_auth_handler() - - @interactive_auth_handler - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) - - try: - body = parse_json_object_from_request(request) - except errors.SynapseError as e: - if e.errcode == errors.Codes.NOT_JSON: - # DELETE - # deal with older clients which didn't pass a JSON dict - # the same as those that pass an empty dict - body = {} - else: - raise e - - assert_params_in_dict(body, ["devices"]) - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "remove device(s) from your account", - # Users might call this multiple times in a row while cleaning up - # devices, allow a single UI auth session to be re-used. - can_skip_ui_auth=True, - ) - - await self.device_handler.delete_devices( - requester.user.to_string(), body["devices"] - ) - return 200, {} - - -class DeviceRestServlet(RestServlet): - PATTERNS = client_patterns("/devices/(?P[^/]*)$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - self.auth_handler = hs.get_auth_handler() - - async def on_GET(self, request, device_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - device = await self.device_handler.get_device( - requester.user.to_string(), device_id - ) - return 200, device - - @interactive_auth_handler - async def on_DELETE(self, request, device_id): - requester = await self.auth.get_user_by_req(request) - - try: - body = parse_json_object_from_request(request) - - except errors.SynapseError as e: - if e.errcode == errors.Codes.NOT_JSON: - # deal with older clients which didn't pass a JSON dict - # the same as those that pass an empty dict - body = {} - else: - raise - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "remove a device from your account", - # Users might call this multiple times in a row while cleaning up - # devices, allow a single UI auth session to be re-used. - can_skip_ui_auth=True, - ) - - await self.device_handler.delete_device(requester.user.to_string(), device_id) - return 200, {} - - async def on_PUT(self, request, device_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - body = parse_json_object_from_request(request) - await self.device_handler.update_device( - requester.user.to_string(), device_id, body - ) - return 200, {} - - -class DehydratedDeviceServlet(RestServlet): - """Retrieve or store a dehydrated device. - - GET /org.matrix.msc2697.v2/dehydrated_device - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "device_id": "dehydrated_device_id", - "device_data": { - "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", - "account": "dehydrated_device" - } - } - - PUT /org.matrix.msc2697/dehydrated_device - Content-Type: application/json - - { - "device_data": { - "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", - "account": "dehydrated_device" - } - } - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "device_id": "dehydrated_device_id" - } - - """ - - PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - - async def on_GET(self, request: SynapseRequest): - requester = await self.auth.get_user_by_req(request) - dehydrated_device = await self.device_handler.get_dehydrated_device( - requester.user.to_string() - ) - if dehydrated_device is not None: - (device_id, device_data) = dehydrated_device - result = {"device_id": device_id, "device_data": device_data} - return (200, result) - else: - raise errors.NotFoundError("No dehydrated device available") - - async def on_PUT(self, request: SynapseRequest): - submission = parse_json_object_from_request(request) - requester = await self.auth.get_user_by_req(request) - - if "device_data" not in submission: - raise errors.SynapseError( - 400, - "device_data missing", - errcode=errors.Codes.MISSING_PARAM, - ) - elif not isinstance(submission["device_data"], dict): - raise errors.SynapseError( - 400, - "device_data must be an object", - errcode=errors.Codes.INVALID_PARAM, - ) - - device_id = await self.device_handler.store_dehydrated_device( - requester.user.to_string(), - submission["device_data"], - submission.get("initial_device_display_name", None), - ) - return 200, {"device_id": device_id} - - -class ClaimDehydratedDeviceServlet(RestServlet): - """Claim a dehydrated device. - - POST /org.matrix.msc2697.v2/dehydrated_device/claim - Content-Type: application/json - - { - "device_id": "dehydrated_device_id" - } - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "success": true, - } - - """ - - PATTERNS = client_patterns( - "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() - ) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - - async def on_POST(self, request: SynapseRequest): - requester = await self.auth.get_user_by_req(request) - - submission = parse_json_object_from_request(request) - - if "device_id" not in submission: - raise errors.SynapseError( - 400, - "device_id missing", - errcode=errors.Codes.MISSING_PARAM, - ) - elif not isinstance(submission["device_id"], str): - raise errors.SynapseError( - 400, - "device_id must be a string", - errcode=errors.Codes.INVALID_PARAM, - ) - - result = await self.device_handler.rehydrate_device( - requester.user.to_string(), - self.auth.get_access_token_from_request(request), - submission["device_id"], - ) - - return (200, result) - - -def register_servlets(hs, http_server): - DeleteDevicesRestServlet(hs).register(http_server) - DevicesRestServlet(hs).register(http_server) - DeviceRestServlet(hs).register(http_server) - DehydratedDeviceServlet(hs).register(http_server) - ClaimDehydratedDeviceServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py deleted file mode 100644 index 411667a9c8..0000000000 --- a/synapse/rest/client/v2_alpha/filter.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID - -from ._base import client_patterns, set_timeline_upper_limit - -logger = logging.getLogger(__name__) - - -class GetFilterRestServlet(RestServlet): - PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.filtering = hs.get_filtering() - - async def on_GET(self, request, user_id, filter_id): - target_user = UserID.from_string(user_id) - requester = await self.auth.get_user_by_req(request) - - if target_user != requester.user: - raise AuthError(403, "Cannot get filters for other users") - - if not self.hs.is_mine(target_user): - raise AuthError(403, "Can only get filters for local users") - - try: - filter_id = int(filter_id) - except Exception: - raise SynapseError(400, "Invalid filter_id") - - try: - filter_collection = await self.filtering.get_user_filter( - user_localpart=target_user.localpart, filter_id=filter_id - ) - except StoreError as e: - if e.code != 404: - raise - raise NotFoundError("No such filter") - - return 200, filter_collection.get_filter_json() - - -class CreateFilterRestServlet(RestServlet): - PATTERNS = client_patterns("/user/(?P[^/]*)/filter") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.filtering = hs.get_filtering() - - async def on_POST(self, request, user_id): - - target_user = UserID.from_string(user_id) - requester = await self.auth.get_user_by_req(request) - - if target_user != requester.user: - raise AuthError(403, "Cannot create filters for other users") - - if not self.hs.is_mine(target_user): - raise AuthError(403, "Can only create filters for local users") - - content = parse_json_object_from_request(request) - set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) - - filter_id = await self.filtering.add_user_filter( - user_localpart=target_user.localpart, user_filter=content - ) - - return 200, {"filter_id": str(filter_id)} - - -def register_servlets(hs, http_server): - GetFilterRestServlet(hs).register(http_server) - CreateFilterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py deleted file mode 100644 index 6285680c00..0000000000 --- a/synapse/rest/client/v2_alpha/groups.py +++ /dev/null @@ -1,957 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from functools import wraps -from typing import TYPE_CHECKING, Optional, Tuple - -from twisted.web.server import Request - -from synapse.api.constants import ( - MAX_GROUP_CATEGORYID_LENGTH, - MAX_GROUP_ROLEID_LENGTH, - MAX_GROUPID_LENGTH, -) -from synapse.api.errors import Codes, SynapseError -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) -from synapse.http.site import SynapseRequest -from synapse.types import GroupID, JsonDict - -from ._base import client_patterns - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -def _validate_group_id(f): - """Wrapper to validate the form of the group ID. - - Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. - """ - - @wraps(f) - def wrapper(self, request: Request, group_id: str, *args, **kwargs): - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) - - return f(self, request, group_id, *args, **kwargs) - - return wrapper - - -class GroupServlet(RestServlet): - """Get the group profile""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - group_description = await self.groups_handler.get_group_profile( - group_id, requester_user_id - ) - - return 200, group_description - - @_validate_group_id - async def on_POST( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert_params_in_dict( - content, ("name", "avatar_url", "short_description", "long_description") - ) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot create group profiles." - await self.groups_handler.update_group_profile( - group_id, requester_user_id, content - ) - - return 200, {} - - -class GroupSummaryServlet(RestServlet): - """Get the full group summary""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - get_group_summary = await self.groups_handler.get_group_summary( - group_id, requester_user_id - ) - - return 200, get_group_summary - - -class GroupSummaryRoomsCatServlet(RestServlet): - """Update/delete a rooms entry in the summary. - - Matches both: - - /groups/:group/summary/rooms/:room_id - - /groups/:group/summary/categories/:category/rooms/:room_id - """ - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/summary" - "(/categories/(?P[^/]+))?" - "/rooms/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, - request: SynapseRequest, - group_id: str, - category_id: Optional[str], - room_id: str, - ): - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) - - if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, category_id: str, room_id: str - ): - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group profiles." - resp = await self.groups_handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class GroupCategoryServlet(RestServlet): - """Get/add/update/delete a group category""" - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/categories/(?P[^/]+)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_category( - group_id, requester_user_id, category_id=category_id - ) - - return 200, category - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if not category_id: - raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - resp = await self.groups_handler.update_group_category( - group_id, requester_user_id, category_id=category_id, content=content - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - resp = await self.groups_handler.delete_group_category( - group_id, requester_user_id, category_id=category_id - ) - - return 200, resp - - -class GroupCategoriesServlet(RestServlet): - """Get all group categories""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_categories( - group_id, requester_user_id - ) - - return 200, category - - -class GroupRoleServlet(RestServlet): - """Get/add/update/delete a group role""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_role( - group_id, requester_user_id, role_id=role_id - ) - - return 200, category - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if not role_id: - raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group roles." - resp = await self.groups_handler.update_group_role( - group_id, requester_user_id, role_id=role_id, content=content - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group roles." - resp = await self.groups_handler.delete_group_role( - group_id, requester_user_id, role_id=role_id - ) - - return 200, resp - - -class GroupRolesServlet(RestServlet): - """Get all group roles""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_roles( - group_id, requester_user_id - ) - - return 200, category - - -class GroupSummaryUsersRoleServlet(RestServlet): - """Update/delete a user's entry in the summary. - - Matches both: - - /groups/:group/summary/users/:room_id - - /groups/:group/summary/roles/:role/users/:user_id - """ - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/summary" - "(/roles/(?P[^/]+))?" - "/users/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, - request: SynapseRequest, - group_id: str, - role_id: Optional[str], - user_id: str, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) - - if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.update_group_summary_user( - group_id, - requester_user_id, - user_id=user_id, - role_id=role_id, - content=content, - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, role_id: str, user_id: str - ): - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.delete_group_summary_user( - group_id, requester_user_id, user_id=user_id, role_id=role_id - ) - - return 200, resp - - -class GroupRoomServlet(RestServlet): - """Get all rooms in a group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_rooms_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupUsersServlet(RestServlet): - """Get all users in a group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_users_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupInvitedUsersServlet(RestServlet): - """Get users invited to a group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_invited_users_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupSettingJoinPolicyServlet(RestServlet): - """Set group join policy""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group join policy." - result = await self.groups_handler.set_group_join_policy( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupCreateServlet(RestServlet): - """Create a group""" - - PATTERNS = client_patterns("/create_group$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - self.server_name = hs.hostname - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - # TODO: Create group on remote server - content = parse_json_object_from_request(request) - localpart = content.pop("localpart") - group_id = GroupID(localpart, self.server_name).to_string() - - if not localpart: - raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) - - if len(group_id) > MAX_GROUPID_LENGTH: - raise SynapseError( - 400, - "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,), - Codes.INVALID_PARAM, - ) - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot create groups." - result = await self.groups_handler.create_group( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupAdminRoomsServlet(RestServlet): - """Add a room to the group""" - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify rooms in a group." - result = await self.groups_handler.add_room_to_group( - group_id, requester_user_id, room_id, content - ) - - return 200, result - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - result = await self.groups_handler.remove_room_from_group( - group_id, requester_user_id, room_id - ) - - return 200, result - - -class GroupAdminRoomsConfigServlet(RestServlet): - """Update the config of a room in a group""" - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)" - "/config/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, room_id: str, config_key: str - ): - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - result = await self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content - ) - - return 200, result - - -class GroupAdminUsersInviteServlet(RestServlet): - """Invite a user to the group""" - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/admin/users/invite/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - self.store = hs.get_datastore() - self.is_mine_id = hs.is_mine_id - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id, user_id - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - config = content.get("config", {}) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot invite users to a group." - result = await self.groups_handler.invite( - group_id, user_id, requester_user_id, config - ) - - return 200, result - - -class GroupAdminUsersKickServlet(RestServlet): - """Kick a user from the group""" - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/admin/users/remove/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id, user_id - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot kick users from a group." - result = await self.groups_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfLeaveServlet(RestServlet): - """Leave a joined group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/self/leave$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot leave a group for a users." - result = await self.groups_handler.remove_user_from_group( - group_id, requester_user_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfJoinServlet(RestServlet): - """Attempt to join a group, or knock""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot join a user to a group." - result = await self.groups_handler.join_group( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfAcceptInviteServlet(RestServlet): - """Accept a group invite""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot accept an invite to a group." - result = await self.groups_handler.accept_invite( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfUpdatePublicityServlet(RestServlet): - """Update whether we publicise a users membership of a group""" - - PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastore() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - publicise = content["publicise"] - await self.store.update_group_publicity(group_id, requester_user_id, publicise) - - return 200, {} - - -class PublicisedGroupsForUserServlet(RestServlet): - """Get the list of groups a user is advertising""" - - PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastore() - self.groups_handler = hs.get_groups_local_handler() - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) - - result = await self.groups_handler.get_publicised_groups_for_user(user_id) - - return 200, result - - -class PublicisedGroupsForUsersServlet(RestServlet): - """Get the list of groups a user is advertising""" - - PATTERNS = client_patterns("/publicised_groups$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastore() - self.groups_handler = hs.get_groups_local_handler() - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) - - content = parse_json_object_from_request(request) - user_ids = content["user_ids"] - - result = await self.groups_handler.bulk_get_publicised_groups(user_ids) - - return 200, result - - -class GroupsForUserServlet(RestServlet): - """Get all groups the logged in user is joined to""" - - PATTERNS = client_patterns("/joined_groups$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_joined_groups(requester_user_id) - - return 200, result - - -def register_servlets(hs: "HomeServer", http_server): - GroupServlet(hs).register(http_server) - GroupSummaryServlet(hs).register(http_server) - GroupInvitedUsersServlet(hs).register(http_server) - GroupUsersServlet(hs).register(http_server) - GroupRoomServlet(hs).register(http_server) - GroupSettingJoinPolicyServlet(hs).register(http_server) - GroupCreateServlet(hs).register(http_server) - GroupAdminRoomsServlet(hs).register(http_server) - GroupAdminRoomsConfigServlet(hs).register(http_server) - GroupAdminUsersInviteServlet(hs).register(http_server) - GroupAdminUsersKickServlet(hs).register(http_server) - GroupSelfLeaveServlet(hs).register(http_server) - GroupSelfJoinServlet(hs).register(http_server) - GroupSelfAcceptInviteServlet(hs).register(http_server) - GroupsForUserServlet(hs).register(http_server) - GroupCategoryServlet(hs).register(http_server) - GroupCategoriesServlet(hs).register(http_server) - GroupSummaryRoomsCatServlet(hs).register(http_server) - GroupRoleServlet(hs).register(http_server) - GroupRolesServlet(hs).register(http_server) - GroupSelfUpdatePublicityServlet(hs).register(http_server) - GroupSummaryUsersRoleServlet(hs).register(http_server) - PublicisedGroupsForUserServlet(hs).register(http_server) - PublicisedGroupsForUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py deleted file mode 100644 index d0d9d30d40..0000000000 --- a/synapse/rest/client/v2_alpha/keys.py +++ /dev/null @@ -1,344 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2019 New Vector Ltd -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import SynapseError -from synapse.http.servlet import ( - RestServlet, - parse_integer, - parse_json_object_from_request, - parse_string, -) -from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.types import StreamToken - -from ._base import client_patterns, interactive_auth_handler - -logger = logging.getLogger(__name__) - - -class KeyUploadServlet(RestServlet): - """ - POST /keys/upload HTTP/1.1 - Content-Type: application/json - - { - "device_keys": { - "user_id": "", - "device_id": "", - "valid_until_ts": , - "algorithms": [ - "m.olm.curve25519-aes-sha2", - ] - "keys": { - ":": "", - }, - "signatures:" { - "" { - ":": "" - } } }, - "one_time_keys": { - ":": "" - }, - } - """ - - PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_keys_handler = hs.get_e2e_keys_handler() - self.device_handler = hs.get_device_handler() - - @trace(opname="upload_keys") - async def on_POST(self, request, device_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - if device_id is not None: - # Providing the device_id should only be done for setting keys - # for dehydrated devices; however, we allow it for any device for - # compatibility with older clients. - if requester.device_id is not None and device_id != requester.device_id: - dehydrated_device = await self.device_handler.get_dehydrated_device( - user_id - ) - if dehydrated_device is not None and device_id != dehydrated_device[0]: - set_tag("error", True) - log_kv( - { - "message": "Client uploading keys for a different device", - "logged_in_id": requester.device_id, - "key_being_uploaded": device_id, - } - ) - logger.warning( - "Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, - device_id, - ) - else: - device_id = requester.device_id - - if device_id is None: - raise SynapseError( - 400, "To upload keys, you must pass device_id when authenticating" - ) - - result = await self.e2e_keys_handler.upload_keys_for_user( - user_id, device_id, body - ) - return 200, result - - -class KeyQueryServlet(RestServlet): - """ - POST /keys/query HTTP/1.1 - Content-Type: application/json - { - "device_keys": { - "": [""] - } } - - HTTP/1.1 200 OK - { - "device_keys": { - "": { - "": { - "user_id": "", // Duplicated to be signed - "device_id": "", // Duplicated to be signed - "valid_until_ts": , - "algorithms": [ // List of supported algorithms - "m.olm.curve25519-aes-sha2", - ], - "keys": { // Must include a ed25519 signing key - ":": "", - }, - "signatures:" { - // Must be signed with device's ed25519 key - "/": { - ":": "" - } - // Must be signed by this server. - "": { - ":": "" - } } } } } } - """ - - PATTERNS = client_patterns("/keys/query$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - device_id = requester.device_id - timeout = parse_integer(request, "timeout", 10 * 1000) - body = parse_json_object_from_request(request) - result = await self.e2e_keys_handler.query_devices( - body, timeout, user_id, device_id - ) - return 200, result - - -class KeyChangesServlet(RestServlet): - """Returns the list of changes of keys between two stream tokens (may return - spurious extra results, since we currently ignore the `to` param). - - GET /keys/changes?from=...&to=... - - 200 OK - { "changed": ["@foo:example.com"] } - """ - - PATTERNS = client_patterns("/keys/changes$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ - super().__init__() - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - from_token_string = parse_string(request, "from", required=True) - set_tag("from", from_token_string) - - # We want to enforce they do pass us one, but we ignore it and return - # changes after the "to" as well as before. - set_tag("to", parse_string(request, "to")) - - from_token = await StreamToken.from_string(self.store, from_token_string) - - user_id = requester.user.to_string() - - results = await self.device_handler.get_user_ids_changed(user_id, from_token) - - return 200, results - - -class OneTimeKeyServlet(RestServlet): - """ - POST /keys/claim HTTP/1.1 - { - "one_time_keys": { - "": { - "": "" - } } } - - HTTP/1.1 200 OK - { - "one_time_keys": { - "": { - "": { - ":": "" - } } } } - - """ - - PATTERNS = client_patterns("/keys/claim$") - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - async def on_POST(self, request): - await self.auth.get_user_by_req(request, allow_guest=True) - timeout = parse_integer(request, "timeout", 10 * 1000) - body = parse_json_object_from_request(request) - result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) - return 200, result - - -class SigningKeyUploadServlet(RestServlet): - """ - POST /keys/device_signing/upload HTTP/1.1 - Content-Type: application/json - - { - } - """ - - PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.e2e_keys_handler = hs.get_e2e_keys_handler() - self.auth_handler = hs.get_auth_handler() - - @interactive_auth_handler - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "add a device signing key to your account", - # Allow skipping of UI auth since this is frequently called directly - # after login and it is silly to ask users to re-auth immediately. - can_skip_ui_auth=True, - ) - - result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) - return 200, result - - -class SignaturesUploadServlet(RestServlet): - """ - POST /keys/signatures/upload HTTP/1.1 - Content-Type: application/json - - { - "@alice:example.com": { - "": { - "user_id": "", - "device_id": "", - "algorithms": [ - "m.olm.curve25519-aes-sha2", - "m.megolm.v1.aes-sha2" - ], - "keys": { - ":": "", - }, - "signatures": { - "": { - ":": ">" - } - } - } - } - } - """ - - PATTERNS = client_patterns("/keys/signatures/upload$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - result = await self.e2e_keys_handler.upload_signatures_for_device_keys( - user_id, body - ) - return 200, result - - -def register_servlets(hs, http_server): - KeyUploadServlet(hs).register(http_server) - KeyQueryServlet(hs).register(http_server) - KeyChangesServlet(hs).register(http_server) - OneTimeKeyServlet(hs).register(http_server) - SigningKeyUploadServlet(hs).register(http_server) - SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/v2_alpha/knock.py deleted file mode 100644 index 7d1bc40658..0000000000 --- a/synapse/rest/client/v2_alpha/knock.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2020 Sorunome -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple - -from twisted.web.server import Request - -from synapse.api.constants import Membership -from synapse.api.errors import SynapseError -from synapse.http.servlet import ( - RestServlet, - parse_json_object_from_request, - parse_strings_from_args, -) -from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import set_tag -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import JsonDict, RoomAlias, RoomID - -if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class KnockRoomAliasServlet(RestServlet): - """ - POST /knock/{roomIdOrAlias} - """ - - PATTERNS = client_patterns("/knock/(?P[^/]*)") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.txns = HttpTransactionCache(hs) - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - - async def on_POST( - self, - request: SynapseRequest, - room_identifier: str, - txn_id: Optional[str] = None, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - - content = parse_json_object_from_request(request) - event_content = None - if "reason" in content: - event_content = {"reason": content["reason"]} - - if RoomID.is_valid(room_identifier): - room_id = room_identifier - - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] - args: Dict[bytes, List[bytes]] = request.args # type: ignore - - remote_room_hosts = parse_strings_from_args( - args, "server_name", required=False - ) - elif RoomAlias.is_valid(room_identifier): - handler = self.room_member_handler - room_alias = RoomAlias.from_string(room_identifier) - room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias) - room_id = room_id_obj.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) - - await self.room_member_handler.update_membership( - requester=requester, - target=requester.user, - room_id=room_id, - action=Membership.KNOCK, - txn_id=txn_id, - third_party_signed=None, - remote_room_hosts=remote_room_hosts, - content=event_content, - ) - - return 200, {"room_id": room_id} - - def on_PUT(self, request: Request, room_identifier: str, txn_id: str): - set_tag("txn_id", txn_id) - - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_identifier, txn_id - ) - - -def register_servlets(hs, http_server): - KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py deleted file mode 100644 index 0ede643c2d..0000000000 --- a/synapse/rest/client/v2_alpha/notifications.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.events.utils import format_event_for_client_v2_without_room_id -from synapse.http.servlet import RestServlet, parse_integer, parse_string - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class NotificationsServlet(RestServlet): - PATTERNS = client_patterns("/notifications$") - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - from_token = parse_string(request, "from", required=False) - limit = parse_integer(request, "limit", default=50) - only = parse_string(request, "only", required=False) - - limit = min(limit, 500) - - push_actions = await self.store.get_push_actions_for_user( - user_id, from_token, limit, only_highlight=(only == "highlight") - ) - - receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, "m.read" - ) - - notif_event_ids = [pa["event_id"] for pa in push_actions] - notif_events = await self.store.get_events(notif_event_ids) - - returned_push_actions = [] - - next_token = None - - for pa in push_actions: - returned_pa = { - "room_id": pa["room_id"], - "profile_tag": pa["profile_tag"], - "actions": pa["actions"], - "ts": pa["received_ts"], - "event": ( - await self._event_serializer.serialize_event( - notif_events[pa["event_id"]], - self.clock.time_msec(), - event_format=format_event_for_client_v2_without_room_id, - ) - ), - } - - if pa["room_id"] not in receipts_by_room: - returned_pa["read"] = False - else: - receipt = receipts_by_room[pa["room_id"]] - - returned_pa["read"] = ( - receipt["topological_ordering"], - receipt["stream_ordering"], - ) >= (pa["topological_ordering"], pa["stream_ordering"]) - returned_push_actions.append(returned_pa) - next_token = str(pa["stream_ordering"]) - - return 200, {"notifications": returned_push_actions, "next_token": next_token} - - -def register_servlets(hs, http_server): - NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py deleted file mode 100644 index e8d2673819..0000000000 --- a/synapse/rest/client/v2_alpha/openid.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from synapse.api.errors import AuthError -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.util.stringutils import random_string - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class IdTokenServlet(RestServlet): - """ - Get a bearer token that may be passed to a third party to confirm ownership - of a matrix user id. - - The format of the response could be made compatible with the format given - in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse - - But instead of returning a signed "id_token" the response contains the - name of the issuing matrix homeserver. This means that for now the third - party will need to check the validity of the "id_token" against the - federation /openid/userinfo endpoint of the homeserver. - - Request: - - POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1 - - {} - - Response: - - HTTP/1.1 200 OK - { - "access_token": "ABDEFGH", - "token_type": "Bearer", - "matrix_server_name": "example.com", - "expires_in": 3600, - } - """ - - PATTERNS = client_patterns("/user/(?P[^/]*)/openid/request_token") - - EXPIRES_MS = 3600 * 1000 - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.clock = hs.get_clock() - self.server_name = hs.config.server_name - - async def on_POST(self, request, user_id): - requester = await self.auth.get_user_by_req(request) - if user_id != requester.user.to_string(): - raise AuthError(403, "Cannot request tokens for other users.") - - # Parse the request body to make sure it's JSON, but ignore the contents - # for now. - parse_json_object_from_request(request) - - token = random_string(24) - ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS - - await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) - - return ( - 200, - { - "access_token": token, - "token_type": "Bearer", - "matrix_server_name": self.server_name, - "expires_in": self.EXPIRES_MS // 1000, - }, - ) - - -def register_servlets(hs, http_server): - IdTokenServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py deleted file mode 100644 index a83927aee6..0000000000 --- a/synapse/rest/client/v2_alpha/password_policy.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.http.servlet import RestServlet - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class PasswordPolicyServlet(RestServlet): - PATTERNS = client_patterns("/password_policy$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - - self.policy = hs.config.password_policy - self.enabled = hs.config.password_policy_enabled - - def on_GET(self, request): - if not self.enabled or not self.policy: - return (200, {}) - - policy = {} - - for param in [ - "minimum_length", - "require_digit", - "require_symbol", - "require_lowercase", - "require_uppercase", - ]: - if param in self.policy: - policy["m.%s" % param] = self.policy[param] - - return (200, policy) - - -def register_servlets(hs, http_server): - PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py deleted file mode 100644 index 027f8b81fa..0000000000 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.constants import ReadReceiptEventFields -from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class ReadMarkerRestServlet(RestServlet): - PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$") - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.receipts_handler = hs.get_receipts_handler() - self.read_marker_handler = hs.get_read_marker_handler() - self.presence_handler = hs.get_presence_handler() - - async def on_POST(self, request, room_id): - requester = await self.auth.get_user_by_req(request) - - await self.presence_handler.bump_presence_active_time(requester.user) - - body = parse_json_object_from_request(request) - read_event_id = body.get("m.read", None) - hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) - - if not isinstance(hidden, bool): - raise SynapseError( - 400, - "Param %s must be a boolean, if given" - % ReadReceiptEventFields.MSC2285_HIDDEN, - Codes.BAD_JSON, - ) - - if read_event_id: - await self.receipts_handler.received_client_receipt( - room_id, - "m.read", - user_id=requester.user.to_string(), - event_id=read_event_id, - hidden=hidden, - ) - - read_marker_event_id = body.get("m.fully_read", None) - if read_marker_event_id: - await self.read_marker_handler.received_client_read_marker( - room_id, - user_id=requester.user.to_string(), - event_id=read_marker_event_id, - ) - - return 200, {} - - -def register_servlets(hs, http_server): - ReadMarkerRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py deleted file mode 100644 index d9ab836cd8..0000000000 --- a/synapse/rest/client/v2_alpha/receipts.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.constants import ReadReceiptEventFields -from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class ReceiptRestServlet(RestServlet): - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)" - "/receipt/(?P[^/]*)" - "/(?P[^/]*)$" - ) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.receipts_handler = hs.get_receipts_handler() - self.presence_handler = hs.get_presence_handler() - - async def on_POST(self, request, room_id, receipt_type, event_id): - requester = await self.auth.get_user_by_req(request) - - if receipt_type != "m.read": - raise SynapseError(400, "Receipt type must be 'm.read'") - - body = parse_json_object_from_request(request, allow_empty_body=True) - hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) - - if not isinstance(hidden, bool): - raise SynapseError( - 400, - "Param %s must be a boolean, if given" - % ReadReceiptEventFields.MSC2285_HIDDEN, - Codes.BAD_JSON, - ) - - await self.presence_handler.bump_presence_active_time(requester.user) - - await self.receipts_handler.received_client_receipt( - room_id, - receipt_type, - user_id=requester.user.to_string(), - event_id=event_id, - hidden=hidden, - ) - - return 200, {} - - -def register_servlets(hs, http_server): - ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py deleted file mode 100644 index 4d31584acd..0000000000 --- a/synapse/rest/client/v2_alpha/register.py +++ /dev/null @@ -1,879 +0,0 @@ -# Copyright 2015 - 2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import hmac -import logging -import random -from typing import List, Union - -import synapse -import synapse.api.auth -import synapse.types -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType -from synapse.api.errors import ( - Codes, - InteractiveAuthIncompleteError, - SynapseError, - ThreepidValidationError, - UnrecognizedRequestError, -) -from synapse.config import ConfigError -from synapse.config.captcha import CaptchaConfig -from synapse.config.consent import ConsentConfig -from synapse.config.emailconfig import ThreepidBehaviour -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.config.registration import RegistrationConfig -from synapse.config.server import is_threepid_reserved -from synapse.handlers.auth import AuthHandler -from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_boolean, - parse_json_object_from_request, - parse_string, -) -from synapse.metrics import threepid_send_requests -from synapse.push.mailer import Mailer -from synapse.types import JsonDict -from synapse.util.msisdn import phone_number_to_msisdn -from synapse.util.ratelimitutils import FederationRateLimiter -from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import ( - canonicalise_email, - check_3pid_allowed, - validate_email, -) - -from ._base import client_patterns, interactive_auth_handler - -# We ought to be using hmac.compare_digest() but on older pythons it doesn't -# exist. It's a _really minor_ security flaw to use plain string comparison -# because the timing attack is so obscured by all the other code here it's -# unlikely to make much difference -if hasattr(hmac, "compare_digest"): - compare_digest = hmac.compare_digest -else: - - def compare_digest(a, b): - return a == b - - -logger = logging.getLogger(__name__) - - -class EmailRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/register/email/requestToken$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.config = hs.config - - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.mailer = Mailer( - hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_registration_template_html, - template_text=self.config.email_registration_template_text, - ) - - async def on_POST(self, request): - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.hs.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Email registration has been disabled due to lack of email config" - ) - raise SynapseError( - 400, "Email-based registration has been disabled on this server" - ) - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - - # Extract params from body - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See on_POST in EmailThreepidRequestTokenRestServlet - # in synapse/rest/client/v2_alpha/account.py) - try: - email = validate_email(body["email"]) - except ValueError as e: - raise SynapseError(400, str(e)) - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - if not check_3pid_allowed(self.hs, "email", email): - raise SynapseError( - 403, - "Your email domain is not authorized to register on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "email", email - ) - - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "email", email - ) - - if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send registration emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_registration_mail, - next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} - - threepid_send_requests.labels(type="email", reason="register").observe( - send_attempt - ) - - return 200, ret - - -class MsisdnRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/register/msisdn/requestToken$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - - async def on_POST(self, request): - body = parse_json_object_from_request(request) - - assert_params_in_dict( - body, ["client_secret", "country", "phone_number", "send_attempt"] - ) - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - country = body["country"] - phone_number = body["phone_number"] - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - msisdn = phone_number_to_msisdn(country, phone_number) - - if not check_3pid_allowed(self.hs, "msisdn", msisdn): - raise SynapseError( - 403, - "Phone numbers are not authorized to register on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "msisdn", msisdn - ) - - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "msisdn", msisdn - ) - - if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError( - 400, "Phone number is already in use", Codes.THREEPID_IN_USE - ) - - if not self.hs.config.account_threepid_delegate_msisdn: - logger.warning( - "No upstream msisdn account_threepid_delegate configured on the server to " - "handle this request" - ) - raise SynapseError( - 400, "Registration by phone number is not supported on this homeserver" - ) - - ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, - country, - phone_number, - client_secret, - send_attempt, - next_link, - ) - - threepid_send_requests.labels(type="msisdn", reason="register").observe( - send_attempt - ) - - return 200, ret - - -class RegistrationSubmitTokenServlet(RestServlet): - """Handles registration 3PID validation token submission""" - - PATTERNS = client_patterns( - "/registration/(?P[^/]*)/submit_token$", releases=(), unstable=True - ) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastore() - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self._failure_email_template = ( - self.config.email_registration_template_failure_html - ) - - async def on_GET(self, request, medium): - if medium != "email": - raise SynapseError( - 400, "This medium is currently not supported for registration" - ) - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "User registration via email has been disabled due to lack of email config" - ) - raise SynapseError( - 400, "Email-based registration is disabled on this server" - ) - - sid = parse_string(request, "sid", required=True) - client_secret = parse_string(request, "client_secret", required=True) - assert_valid_client_secret(client_secret) - token = parse_string(request, "token", required=True) - - # Attempt to validate a 3PID session - try: - # Mark the session as valid - next_link = await self.store.validate_threepid_session( - sid, client_secret, token, self.clock.time_msec() - ) - - # Perform a 302 redirect if next_link is set - if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None - - # Otherwise show the success template - html = self.config.email_registration_template_success_html_content - status_code = 200 - except ThreepidValidationError as e: - status_code = e.code - - # Show a failure page with a reason - template_vars = {"failure_reason": e.msg} - html = self._failure_email_template.render(**template_vars) - - respond_with_html(request, status_code, html) - - -class UsernameAvailabilityRestServlet(RestServlet): - PATTERNS = client_patterns("/register/available") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.registration_handler = hs.get_registration_handler() - self.ratelimiter = FederationRateLimiter( - hs.get_clock(), - FederationRateLimitConfig( - # Time window of 2s - window_size=2000, - # Artificially delay requests if rate > sleep_limit/window_size - sleep_limit=1, - # Amount of artificial delay to apply - sleep_msec=1000, - # Error with 429 if more than reject_limit requests are queued - reject_limit=1, - # Allow 1 request at a time - concurrent_requests=1, - ), - ) - - async def on_GET(self, request): - if not self.hs.config.enable_registration: - raise SynapseError( - 403, "Registration has been disabled", errcode=Codes.FORBIDDEN - ) - - ip = request.getClientIP() - with self.ratelimiter.ratelimit(ip) as wait_deferred: - await wait_deferred - - username = parse_string(request, "username", required=True) - - await self.registration_handler.check_username(username) - - return 200, {"available": True} - - -class RegisterRestServlet(RestServlet): - PATTERNS = client_patterns("/register$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - - self.hs = hs - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.auth_handler = hs.get_auth_handler() - self.registration_handler = hs.get_registration_handler() - self.identity_handler = hs.get_identity_handler() - self.room_member_handler = hs.get_room_member_handler() - self.macaroon_gen = hs.get_macaroon_generator() - self.ratelimiter = hs.get_registration_ratelimiter() - self.password_policy_handler = hs.get_password_policy_handler() - self.clock = hs.get_clock() - self._registration_enabled = self.hs.config.enable_registration - self._msc2918_enabled = hs.config.access_token_lifetime is not None - - self._registration_flows = _calculate_registration_flows( - hs.config, self.auth_handler - ) - - @interactive_auth_handler - async def on_POST(self, request): - body = parse_json_object_from_request(request) - - client_addr = request.getClientIP() - - await self.ratelimiter.ratelimit(None, client_addr, update=False) - - kind = b"user" - if b"kind" in request.args: - kind = request.args[b"kind"][0] - - if kind == b"guest": - ret = await self._do_guest_registration(body, address=client_addr) - return ret - elif kind != b"user": - raise UnrecognizedRequestError( - "Do not understand membership kind: %s" % (kind.decode("utf8"),) - ) - - if self._msc2918_enabled: - # Check if this registration should also issue a refresh token, as - # per MSC2918 - should_issue_refresh_token = parse_boolean( - request, name="org.matrix.msc2918.refresh_token", default=False - ) - else: - should_issue_refresh_token = False - - # Pull out the provided username and do basic sanity checks early since - # the auth layer will store these in sessions. - desired_username = None - if "username" in body: - if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") - desired_username = body["username"] - - # fork off as soon as possible for ASes which have completely - # different registration flows to normal users - - # == Application Service Registration == - if body.get("type") == APP_SERVICE_REGISTRATION_TYPE: - if not self.auth.has_access_token(request): - raise SynapseError( - 400, - "Appservice token must be provided when using a type of m.login.application_service", - ) - - # Verify the AS - self.auth.get_appservice_by_req(request) - - # Set the desired user according to the AS API (which uses the - # 'user' key not 'username'). Since this is a new addition, we'll - # fallback to 'username' if they gave one. - desired_username = body.get("user", desired_username) - - # XXX we should check that desired_username is valid. Currently - # we give appservices carte blanche for any insanity in mxids, - # because the IRC bridges rely on being able to register stupid - # IDs. - - access_token = self.auth.get_access_token_from_request(request) - - if not isinstance(desired_username, str): - raise SynapseError(400, "Desired Username is missing or not a string") - - result = await self._do_appservice_registration( - desired_username, - access_token, - body, - should_issue_refresh_token=should_issue_refresh_token, - ) - - return 200, result - elif self.auth.has_access_token(request): - raise SynapseError( - 400, - "An access token should not be provided on requests to /register (except if type is m.login.application_service)", - ) - - # == Normal User Registration == (everyone else) - if not self._registration_enabled: - raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) - - # For regular registration, convert the provided username to lowercase - # before attempting to register it. This should mean that people who try - # to register with upper-case in their usernames don't get a nasty surprise. - # - # Note that we treat usernames case-insensitively in login, so they are - # free to carry on imagining that their username is CrAzYh4cKeR if that - # keeps them happy. - if desired_username is not None: - desired_username = desired_username.lower() - - # Check if this account is upgrading from a guest account. - guest_access_token = body.get("guest_access_token", None) - - # Pull out the provided password and do basic sanity checks early. - # - # Note that we remove the password from the body since the auth layer - # will store the body in the session and we don't want a plaintext - # password store there. - password = body.pop("password", None) - if password is not None: - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(password) - - if "initial_device_display_name" in body and password is None: - # ignore 'initial_device_display_name' if sent without - # a password to work around a client bug where it sent - # the 'initial_device_display_name' param alone, wiping out - # the original registration params - logger.warning("Ignoring initial_device_display_name without password") - del body["initial_device_display_name"] - - session_id = self.auth_handler.get_session_id(body) - registered_user_id = None - password_hash = None - if session_id: - # if we get a registered user id out of here, it means we previously - # registered a user for this session, so we could just return the - # user here. We carry on and go through the auth checks though, - # for paranoia. - registered_user_id = await self.auth_handler.get_session_data( - session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None - ) - # Extract the previously-hashed password from the session. - password_hash = await self.auth_handler.get_session_data( - session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None - ) - - # Ensure that the username is valid. - if desired_username is not None: - await self.registration_handler.check_username( - desired_username, - guest_access_token=guest_access_token, - assigned_user_id=registered_user_id, - ) - - # Check if the user-interactive authentication flows are complete, if - # not this will raise a user-interactive auth error. - try: - auth_result, params, session_id = await self.auth_handler.check_ui_auth( - self._registration_flows, - request, - body, - "register a new account", - ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth. - # - # Hash the password and store it with the session since the client - # is not required to provide the password again. - # - # If a password hash was previously stored we will not attempt to - # re-hash and store it for efficiency. This assumes the password - # does not change throughout the authentication flow, but this - # should be fine since the data is meant to be consistent. - if not password_hash and password: - password_hash = await self.auth_handler.hash(password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, - ) - raise - - # Check that we're not trying to register a denied 3pid. - # - # the user-facing checks will probably already have happened in - # /register/email/requestToken when we requested a 3pid, but that's not - # guaranteed. - if auth_result: - for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: - if login_type in auth_result: - medium = auth_result[login_type]["medium"] - address = auth_result[login_type]["address"] - - if not check_3pid_allowed(self.hs, medium, address): - raise SynapseError( - 403, - "Third party identifiers (email/phone numbers)" - + " are not authorized on this server", - Codes.THREEPID_DENIED, - ) - - if registered_user_id is not None: - logger.info( - "Already registered user ID %r for this session", registered_user_id - ) - # don't re-register the threepids - registered = False - else: - # If we have a password in this request, prefer it. Otherwise, there - # might be a password hash from an earlier request. - if password: - password_hash = await self.auth_handler.hash(password) - if not password_hash: - raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - - desired_username = params.get("username", None) - guest_access_token = params.get("guest_access_token", None) - - if desired_username is not None: - desired_username = desired_username.lower() - - threepid = None - if auth_result: - threepid = auth_result.get(LoginType.EMAIL_IDENTITY) - - # Also check that we're not trying to register a 3pid that's already - # been registered. - # - # This has probably happened in /register/email/requestToken as well, - # but if a user hits this endpoint twice then clicks on each link from - # the two activation emails, they would register the same 3pid twice. - for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: - if login_type in auth_result: - medium = auth_result[login_type]["medium"] - address = auth_result[login_type]["address"] - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See on_POST in EmailThreepidRequestTokenRestServlet - # in synapse/rest/client/v2_alpha/account.py) - if medium == "email": - try: - address = canonicalise_email(address) - except ValueError as e: - raise SynapseError(400, str(e)) - - existing_user_id = await self.store.get_user_id_by_threepid( - medium, address - ) - - if existing_user_id is not None: - raise SynapseError( - 400, - "%s is already in use" % medium, - Codes.THREEPID_IN_USE, - ) - - entries = await self.store.get_user_agents_ips_to_ui_auth_session( - session_id - ) - - registered_user_id = await self.registration_handler.register_user( - localpart=desired_username, - password_hash=password_hash, - guest_access_token=guest_access_token, - threepid=threepid, - address=client_addr, - user_agent_ips=entries, - ) - # Necessary due to auth checks prior to the threepid being - # written to the db - if threepid: - if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid - ): - await self.store.upsert_monthly_active_user(registered_user_id) - - # Remember that the user account has been registered (and the user - # ID it was registered with, since it might not have been specified). - await self.auth_handler.set_session_data( - session_id, - UIAuthSessionDataConstants.REGISTERED_USER_ID, - registered_user_id, - ) - - registered = True - - return_dict = await self._create_registration_details( - registered_user_id, - params, - should_issue_refresh_token=should_issue_refresh_token, - ) - - if registered: - await self.registration_handler.post_registration_actions( - user_id=registered_user_id, - auth_result=auth_result, - access_token=return_dict.get("access_token"), - ) - - return 200, return_dict - - async def _do_appservice_registration( - self, username, as_token, body, should_issue_refresh_token: bool = False - ): - user_id = await self.registration_handler.appservice_register( - username, as_token - ) - return await self._create_registration_details( - user_id, - body, - is_appservice_ghost=True, - should_issue_refresh_token=should_issue_refresh_token, - ) - - async def _create_registration_details( - self, - user_id: str, - params: JsonDict, - is_appservice_ghost: bool = False, - should_issue_refresh_token: bool = False, - ): - """Complete registration of newly-registered user - - Allocates device_id if one was not given; also creates access_token. - - Args: - user_id: full canonical @user:id - params: registration parameters, from which we pull device_id, - initial_device_name and inhibit_login - is_appservice_ghost - should_issue_refresh_token: True if this registration should issue - a refresh token alongside the access token. - Returns: - dictionary for response from /register - """ - result = {"user_id": user_id, "home_server": self.hs.hostname} - if not params.get("inhibit_login", False): - device_id = params.get("device_id") - initial_display_name = params.get("initial_device_display_name") - ( - device_id, - access_token, - valid_until_ms, - refresh_token, - ) = await self.registration_handler.register_device( - user_id, - device_id, - initial_display_name, - is_guest=False, - is_appservice_ghost=is_appservice_ghost, - should_issue_refresh_token=should_issue_refresh_token, - ) - - result.update({"access_token": access_token, "device_id": device_id}) - - if valid_until_ms is not None: - expires_in_ms = valid_until_ms - self.clock.time_msec() - result["expires_in_ms"] = expires_in_ms - - if refresh_token is not None: - result["refresh_token"] = refresh_token - - return result - - async def _do_guest_registration(self, params, address=None): - if not self.hs.config.allow_guest_access: - raise SynapseError(403, "Guest access is disabled") - user_id = await self.registration_handler.register_user( - make_guest=True, address=address - ) - - # we don't allow guests to specify their own device_id, because - # we have nowhere to store it. - device_id = synapse.api.auth.GUEST_DEVICE_ID - initial_display_name = params.get("initial_device_display_name") - ( - device_id, - access_token, - valid_until_ms, - refresh_token, - ) = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=True - ) - - result = { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - } - - if valid_until_ms is not None: - expires_in_ms = valid_until_ms - self.clock.time_msec() - result["expires_in_ms"] = expires_in_ms - - if refresh_token is not None: - result["refresh_token"] = refresh_token - - return 200, result - - -def _calculate_registration_flows( - # technically `config` has to provide *all* of these interfaces, not just one - config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], - auth_handler: AuthHandler, -) -> List[List[str]]: - """Get a suitable flows list for registration - - Args: - config: server configuration - auth_handler: authorization handler - - Returns: a list of supported flows - """ - # FIXME: need a better error than "no auth flow found" for scenarios - # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registrations_require_3pid - require_msisdn = "msisdn" in config.registrations_require_3pid - - show_msisdn = True - show_email = True - - if config.disable_msisdn_registration: - show_msisdn = False - require_msisdn = False - - enabled_auth_types = auth_handler.get_enabled_auth_types() - if LoginType.EMAIL_IDENTITY not in enabled_auth_types: - show_email = False - if require_email: - raise ConfigError( - "Configuration requires email address at registration, but email " - "validation is not configured" - ) - - if LoginType.MSISDN not in enabled_auth_types: - show_msisdn = False - if require_msisdn: - raise ConfigError( - "Configuration requires msisdn at registration, but msisdn " - "validation is not configured" - ) - - flows = [] - - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - # Add a dummy step here, otherwise if a client completes - # recaptcha first we'll assume they were going for this flow - # and complete the request, when they could have been trying to - # complete one of the flows with email/msisdn auth. - flows.append([LoginType.DUMMY]) - - # only support the email-only flow if we don't require MSISDN 3PIDs - if show_email and not require_msisdn: - flows.append([LoginType.EMAIL_IDENTITY]) - - # only support the MSISDN-only flow if we don't require email 3PIDs - if show_msisdn and not require_email: - flows.append([LoginType.MSISDN]) - - if show_email and show_msisdn: - # always let users provide both MSISDN & email - flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) - - # Prepend m.login.terms to all flows if we're requiring consent - if config.user_consent_at_registration: - for flow in flows: - flow.insert(0, LoginType.TERMS) - - # Prepend recaptcha to all flows if we're requiring captcha - if config.enable_registration_captcha: - for flow in flows: - flow.insert(0, LoginType.RECAPTCHA) - - return flows - - -def register_servlets(hs, http_server): - EmailRegisterRequestTokenRestServlet(hs).register(http_server) - MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) - UsernameAvailabilityRestServlet(hs).register(http_server) - RegistrationSubmitTokenServlet(hs).register(http_server) - RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py deleted file mode 100644 index 0821cd285f..0000000000 --- a/synapse/rest/client/v2_alpha/relations.py +++ /dev/null @@ -1,381 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This class implements the proposed relation APIs from MSC 1849. - -Since the MSC has not been approved all APIs here are unstable and may change at -any time to reflect changes in the MSC. -""" - -import logging - -from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import ShadowBanError, SynapseError -from synapse.http.servlet import ( - RestServlet, - parse_integer, - parse_json_object_from_request, - parse_string, -) -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.util.stringutils import random_string - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class RelationSendServlet(RestServlet): - """Helper API for sending events that have relation data. - - Example API shape to send a 👍 reaction to a room: - - POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D - {} - - { - "event_id": "$foobar" - } - """ - - PATTERN = ( - "/rooms/(?P[^/]*)/send_relation" - "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.event_creation_handler = hs.get_event_creation_handler() - self.txns = HttpTransactionCache(hs) - - def register(self, http_server): - http_server.register_paths( - "POST", - client_patterns(self.PATTERN + "$", releases=()), - self.on_PUT_or_POST, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(self.PATTERN + "/(?P[^/]*)$", releases=()), - self.on_PUT, - self.__class__.__name__, - ) - - def on_PUT(self, request, *args, **kwargs): - return self.txns.fetch_or_execute_request( - request, self.on_PUT_or_POST, request, *args, **kwargs - ) - - async def on_PUT_or_POST( - self, request, room_id, parent_id, relation_type, event_type, txn_id=None - ): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if event_type == EventTypes.Member: - # Add relations to a membership is meaningless, so we just deny it - # at the CS API rather than trying to handle it correctly. - raise SynapseError(400, "Cannot send member events with relations") - - content = parse_json_object_from_request(request) - - aggregation_key = parse_string(request, "key", encoding="utf-8") - - content["m.relates_to"] = { - "event_id": parent_id, - "key": aggregation_key, - "rel_type": relation_type, - } - - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) - event_id = event.event_id - except ShadowBanError: - event_id = "$" + random_string(43) - - return 200, {"event_id": event_id} - - -class RelationPaginationServlet(RestServlet): - """API to paginate relations on an event by topological ordering, optionally - filtered by relation type and event type. - """ - - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/relations/(?P[^/]*)" - "(/(?P[^/]*)(/(?P[^/]*))?)?$", - releases=(), - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() - - async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True - ) - - # This gets the original event and checks that a) the event exists and - # b) the user is allowed to view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = RelationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = RelationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_relations_for_event( - event_id=parent_id, - relation_type=relation_type, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - events = await self.store.get_events_as_list( - [c["event_id"] for c in pagination_chunk.chunk] - ) - - now = self.clock.time_msec() - # We set bundle_aggregations to False when retrieving the original - # event because we want the content before relations were applied to - # it. - original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False - ) - # Similarly, we don't allow relations to be applied to relations, so we - # return the original relations without any aggregations on top of them - # here. - events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=False - ) - - return_value = pagination_chunk.to_dict() - return_value["chunk"] = events - return_value["original_event"] = original_event - - return 200, return_value - - -class RelationAggregationPaginationServlet(RestServlet): - """API to paginate aggregation groups of relations, e.g. paginate the - types and counts of the reactions on the events. - - Example request and response: - - GET /rooms/{room_id}/aggregations/{parent_id} - - { - chunk: [ - { - "type": "m.reaction", - "key": "👍", - "count": 3 - } - ] - } - """ - - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" - "(/(?P[^/]*)(/(?P[^/]*))?)?$", - releases=(), - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.event_handler = hs.get_event_handler() - - async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. - event = await self.event_handler.get_event(requester.user, room_id, parent_id) - - if relation_type not in (RelationTypes.ANNOTATION, None): - raise SynapseError(400, "Relation type must be 'annotation'") - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - return 200, pagination_chunk.to_dict() - - -class RelationAggregationGroupPaginationServlet(RestServlet): - """API to paginate within an aggregation group of relations, e.g. paginate - all the 👍 reactions on an event. - - Example request and response: - - GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 - - { - chunk: [ - { - "type": "m.reaction", - "content": { - "m.relates_to": { - "rel_type": "m.annotation", - "key": "👍" - } - } - }, - ... - ] - } - """ - - PATTERNS = client_patterns( - "/rooms/(?P[^/]*)/aggregations/(?P[^/]*)" - "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)$", - releases=(), - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - self.event_handler = hs.get_event_handler() - - async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - await self.auth.check_user_in_room_or_world_readable( - room_id, - requester.user.to_string(), - allow_departed_users=True, - ) - - # This checks that a) the event exists and b) the user is allowed to - # view it. - await self.event_handler.get_event(requester.user, room_id, parent_id) - - if relation_type != RelationTypes.ANNOTATION: - raise SynapseError(400, "Relation type must be 'annotation'") - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - from_token = None - if from_token_str: - from_token = RelationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = RelationPaginationToken.from_string(to_token_str) - - result = await self.store.get_relations_for_event( - event_id=parent_id, - relation_type=relation_type, - event_type=event_type, - aggregation_key=key, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - events = await self.store.get_events_as_list( - [c["event_id"] for c in result.chunk] - ) - - now = self.clock.time_msec() - events = await self._event_serializer.serialize_events(events, now) - - return_value = result.to_dict() - return_value["chunk"] = events - - return 200, return_value - - -def register_servlets(hs, http_server): - RelationSendServlet(hs).register(http_server) - RelationPaginationServlet(hs).register(http_server) - RelationAggregationPaginationServlet(hs).register(http_server) - RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py deleted file mode 100644 index 07ea39a8a3..0000000000 --- a/synapse/rest/client/v2_alpha/report_event.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from http import HTTPStatus - -from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class ReportEventRestServlet(RestServlet): - PATTERNS = client_patterns("/rooms/(?P[^/]*)/report/(?P[^/]*)$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastore() - - async def on_POST(self, request, room_id, event_id): - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - body = parse_json_object_from_request(request) - - if not isinstance(body.get("reason", ""), str): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'reason' must be a string", - Codes.BAD_JSON, - ) - if not isinstance(body.get("score", 0), int): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'score' must be an integer", - Codes.BAD_JSON, - ) - - await self.store.add_event_report( - room_id=room_id, - event_id=event_id, - user_id=user_id, - reason=body.get("reason"), - content=body, - received_ts=self.clock.time_msec(), - ) - - return 200, {} - - -def register_servlets(hs, http_server): - ReportEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/room.py b/synapse/rest/client/v2_alpha/room.py deleted file mode 100644 index 3172aba605..0000000000 --- a/synapse/rest/client/v2_alpha/room.py +++ /dev/null @@ -1,441 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import re - -from synapse.api.constants import EventContentFields, EventTypes -from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.appservice import ApplicationService -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, - parse_string, - parse_strings_from_args, -) -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import Requester, UserID, create_requester -from synapse.util.stringutils import random_string - -logger = logging.getLogger(__name__) - - -class RoomBatchSendEventRestServlet(RestServlet): - """ - API endpoint which can insert a chunk of events historically back in time - next to the given `prev_event`. - - `chunk_id` comes from `next_chunk_id `in the response of the batch send - endpoint and is derived from the "insertion" events added to each chunk. - It's not required for the first batch send. - - `state_events_at_start` is used to define the historical state events - needed to auth the events like join events. These events will float - outside of the normal DAG as outlier's and won't be visible in the chat - history which also allows us to insert multiple chunks without having a bunch - of `@mxid joined the room` noise between each chunk. - - `events` is chronological chunk/list of events you want to insert. - There is a reverse-chronological constraint on chunks so once you insert - some messages, you can only insert older ones after that. - tldr; Insert chunks from your most recent history -> oldest history. - - POST /_matrix/client/unstable/org.matrix.msc2716/rooms//batch_send?prev_event=&chunk_id= - { - "events": [ ... ], - "state_events_at_start": [ ... ] - } - """ - - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2716" - "/rooms/(?P[^/]*)/batch_send$" - ), - ) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.store = hs.get_datastore() - self.state_store = hs.get_storage().state - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs) - - async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = await self.store.get_max_depth_of(prev_event_ids) - - # We want to insert the historical event after the `prev_event` but before the successor event - # - # We inherit depth from the successor event instead of the `prev_event` - # because events returned from `/messages` are first sorted by `topological_ordering` - # which is just the `depth` and then tie-break with `stream_ordering`. - # - # We mark these inserted historical events as "backfilled" which gives them a - # negative `stream_ordering`. If we use the same depth as the `prev_event`, - # then our historical event will tie-break and be sorted before the `prev_event` - # when it should come after. - # - # We want to use the successor event depth so they appear after `prev_event` because - # it has a larger `depth` but before the successor event because the `stream_ordering` - # is negative before the successor event. - successor_event_ids = await self.store.get_successor_events( - [most_recent_prev_event_id] - ) - - # If we can't find any successor events, then it's a forward extremity of - # historical messages and we can just inherit from the previous historical - # event which we can already assume has the correct depth where we want - # to insert into. - if not successor_event_ids: - depth = most_recent_prev_event_depth - else: - ( - _, - oldest_successor_depth, - ) = await self.store.get_min_depth_of(successor_event_ids) - - depth = oldest_successor_depth - - return depth - - def _create_insertion_event_dict( - self, sender: str, room_id: str, origin_server_ts: int - ): - """Creates an event dict for an "insertion" event with the proper fields - and a random chunk ID. - - Args: - sender: The event author MXID - room_id: The room ID that the event belongs to - origin_server_ts: Timestamp when the event was sent - - Returns: - Tuple of event ID and stream ordering position - """ - - next_chunk_id = random_string(8) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": sender, - "room_id": room_id, - "content": { - EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, - "origin_server_ts": origin_server_ts, - } - - return insertion_event - - async def _create_requester_for_user_id_from_app_service( - self, user_id: str, app_service: ApplicationService - ) -> Requester: - """Creates a new requester for the given user_id - and validates that the app service is allowed to control - the given user. - - Args: - user_id: The author MXID that the app service is controlling - app_service: The app service that controls the user - - Returns: - Requester object - """ - - await self.auth.validate_appservice_can_control_user_id(app_service, user_id) - - return create_requester(user_id, app_service=app_service) - - async def on_POST(self, request, room_id): - requester = await self.auth.get_user_by_req(request, allow_guest=False) - - if not requester.app_service: - raise AuthError( - 403, - "Only application services can use the /batchsend endpoint", - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["state_events_at_start", "events"]) - - prev_events_from_query = parse_strings_from_args(request.args, "prev_event") - chunk_id_from_query = parse_string(request, "chunk_id") - - if prev_events_from_query is None: - raise SynapseError( - 400, - "prev_event query parameter is required when inserting historical messages back in time", - errcode=Codes.MISSING_PARAM, - ) - - # For the event we are inserting next to (`prev_events_from_query`), - # find the most recent auth events (derived from state events) that - # allowed that message to be sent. We will use that as a base - # to auth our historical messages against. - ( - most_recent_prev_event_id, - _, - ) = await self.store.get_max_depth_of(prev_events_from_query) - # mapping from (type, state_key) -> state_event_id - prev_state_map = await self.state_store.get_state_ids_for_event( - most_recent_prev_event_id - ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - - state_events_at_start = [] - for state_event in body["state_events_at_start"]: - assert_params_in_dict( - state_event, ["type", "origin_server_ts", "content", "sender"] - ) - - logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", - state_event, - auth_event_ids, - ) - - event_dict = { - "type": state_event["type"], - "origin_server_ts": state_event["origin_server_ts"], - "content": state_event["content"], - "room_id": room_id, - "sender": state_event["sender"], - "state_key": state_event["state_key"], - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - # Make the state events float off on their own - fake_prev_event_id = "$" + random_string(43) - - # TODO: This is pretty much the same as some other code to handle inserting state in this file - if event_dict["type"] == EventTypes.Member: - membership = event_dict["content"].get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - target=UserID.from_string(event_dict["state_key"]), - room_id=room_id, - action=membership, - content=event_dict["content"], - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - else: - # TODO: Add some complement tests that adds state that is not member joins - # and will use this code path. Maybe we only want to support join state events - # and can get rid of this `else`? - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - event_dict, - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - event_id = event.event_id - - state_events_at_start.append(event_id) - auth_event_ids.append(event_id) - - events_to_create = body["events"] - - inherited_depth = await self._inherit_depth_from_prev_ids( - prev_events_from_query - ) - - # Figure out which chunk to connect to. If they passed in - # chunk_id_from_query let's use it. The chunk ID passed in comes - # from the chunk_id in the "insertion" event from the previous chunk. - last_event_in_chunk = events_to_create[-1] - chunk_id_to_connect_to = chunk_id_from_query - base_insertion_event = None - if chunk_id_from_query: - # All but the first base insertion event should point at a fake - # event, which causes the HS to ask for the state at the start of - # the chunk later. - prev_event_ids = [fake_prev_event_id] - # TODO: Verify the chunk_id_from_query corresponds to an insertion event - pass - # Otherwise, create an insertion event to act as a starting point. - # - # We don't always have an insertion event to start hanging more history - # off of (ideally there would be one in the main DAG, but that's not the - # case if we're wanting to add history to e.g. existing rooms without - # an insertion event), in which case we just create a new insertion event - # that can then get pointed to by a "marker" event later. - else: - prev_event_ids = prev_events_from_query - - base_insertion_event_dict = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - origin_server_ts=last_event_in_chunk["origin_server_ts"], - ) - base_insertion_event_dict["prev_events"] = prev_event_ids.copy() - - ( - base_insertion_event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - base_insertion_event_dict["sender"], - requester.app_service, - ), - base_insertion_event_dict, - prev_event_ids=base_insertion_event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - - chunk_id_to_connect_to = base_insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ] - - # Connect this current chunk to the insertion event from the previous chunk - chunk_event = { - "type": EventTypes.MSC2716_CHUNK, - "sender": requester.user.to_string(), - "room_id": room_id, - "content": { - EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, - EventContentFields.MSC2716_HISTORICAL: True, - }, - # Since the chunk event is put at the end of the chunk, - # where the newest-in-time event is, copy the origin_server_ts from - # the last event we're inserting - "origin_server_ts": last_event_in_chunk["origin_server_ts"], - } - # Add the chunk event to the end of the chunk (newest-in-time) - events_to_create.append(chunk_event) - - # Add an "insertion" event to the start of each chunk (next to the oldest-in-time - # event in the chunk) so the next chunk can be connected to this one. - insertion_event = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - # Since the insertion event is put at the start of the chunk, - # where the oldest-in-time event is, copy the origin_server_ts from - # the first event we're inserting - origin_server_ts=events_to_create[0]["origin_server_ts"], - ) - # Prepend the insertion event to the start of the chunk (oldest-in-time) - events_to_create = [insertion_event] + events_to_create - - event_ids = [] - events_to_persist = [] - for ev in events_to_create: - assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - - event_dict = { - "type": ev["type"], - "origin_server_ts": ev["origin_server_ts"], - "content": ev["content"], - "room_id": room_id, - "sender": ev["sender"], # requester.user.to_string(), - "prev_events": prev_event_ids.copy(), - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - event, context = await self.event_creation_handler.create_event( - await self._create_requester_for_user_id_from_app_service( - ev["sender"], requester.app_service - ), - event_dict, - prev_event_ids=event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", - event, - prev_event_ids, - auth_event_ids, - ) - - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) - - events_to_persist.append((event, context)) - event_id = event.event_id - - event_ids.append(event_id) - prev_event_ids = [event_id] - - # Persist events in reverse-chronological order so they have the - # correct stream_ordering as they are backfilled (which decrements). - # Events are sorted by (topological_ordering, stream_ordering) - # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): - ev = await self.event_creation_handler.handle_new_client_event( - await self._create_requester_for_user_id_from_app_service( - event["sender"], requester.app_service - ), - event=event, - context=context, - ) - - # Add the base_insertion_event to the bottom of the list we return - if base_insertion_event is not None: - event_ids.append(base_insertion_event.event_id) - - return 200, { - "state_events": state_events_at_start, - "events": event_ids, - "next_chunk_id": insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ], - } - - def on_GET(self, request, room_id): - return 501, "Not implemented" - - def on_PUT(self, request, room_id): - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id - ) - - -def register_servlets(hs, http_server): - msc2716_enabled = hs.config.experimental.msc2716_enabled - - if msc2716_enabled: - RoomBatchSendEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py deleted file mode 100644 index 263596be86..0000000000 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ /dev/null @@ -1,391 +0,0 @@ -# Copyright 2017, 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import Codes, NotFoundError, SynapseError -from synapse.http.servlet import ( - RestServlet, - parse_json_object_from_request, - parse_string, -) - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class RoomKeysServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$" - ) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - - async def on_PUT(self, request, room_id, session_id): - """ - Uploads one or more encrypted E2E room keys for backup purposes. - room_id: the ID of the room the keys are for (optional) - session_id: the ID for the E2E room keys for the room (optional) - version: the version of the user's backup which this data is for. - the version must already have been created via the /room_keys/version API. - - Each session has: - * first_message_index: a numeric index indicating the oldest message - encrypted by this session. - * forwarded_count: how many times the uploading client claims this key - has been shared (forwarded) - * is_verified: whether the client that uploaded the keys claims they - were sent by a device which they've verified - * session_data: base64-encrypted data describing the session. - - Returns 200 OK on success with body {} - Returns 403 Forbidden if the version in question is not the most recently - created version (i.e. if this is an old client trying to write to a stale backup) - Returns 404 Not Found if the version in question doesn't exist - - The API is designed to be otherwise agnostic to the room_key encryption - algorithm being used. Sessions are merged with existing ones in the - backup using the heuristics: - * is_verified sessions always win over unverified sessions - * older first_message_index always win over newer sessions - * lower forwarded_count always wins over higher forwarded_count - - We trust the clients not to lie and corrupt their own backups. - It also means that if your access_token is stolen, the attacker could - delete your backup. - - POST /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 - Content-Type: application/json - - { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - - Or... - - POST /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 - Content-Type: application/json - - { - "sessions": { - "c0ff33": { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - } - } - - Or... - - POST /room_keys/keys?version=1 HTTP/1.1 - Content-Type: application/json - - { - "rooms": { - "!abc:matrix.org": { - "sessions": { - "c0ff33": { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - } - } - } - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - version = parse_string(request, "version") - - if session_id: - body = {"sessions": {session_id: body}} - - if room_id: - body = {"rooms": {room_id: body}} - - ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, ret - - async def on_GET(self, request, room_id, session_id): - """ - Retrieves one or more encrypted E2E room keys for backup purposes. - Symmetric with the PUT version of the API. - - room_id: the ID of the room to retrieve the keys for (optional) - session_id: the ID for the E2E room keys to retrieve the keys for (optional) - version: the version of the user's backup which this data is for. - the version must already have been created via the /change_secret API. - - Returns as follows: - - GET /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 - { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - - Or... - - GET /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 - { - "sessions": { - "c0ff33": { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - } - } - - Or... - - GET /room_keys/keys?version=1 HTTP/1.1 - { - "rooms": { - "!abc:matrix.org": { - "sessions": { - "c0ff33": { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - } - } - } - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - version = parse_string(request, "version", required=True) - - room_keys = await self.e2e_room_keys_handler.get_room_keys( - user_id, version, room_id, session_id - ) - - # Convert room_keys to the right format to return. - if session_id: - # If the client requests a specific session, but that session was - # not backed up, then return an M_NOT_FOUND. - if room_keys["rooms"] == {}: - raise NotFoundError("No room_keys found") - else: - room_keys = room_keys["rooms"][room_id]["sessions"][session_id] - elif room_id: - # If the client requests all sessions from a room, but no sessions - # are found, then return an empty result rather than an error, so - # that clients don't have to handle an error condition, and an - # empty result is valid. (Similarly if the client requests all - # sessions from the backup, but in that case, room_keys is already - # in the right format, so we don't need to do anything about it.) - if room_keys["rooms"] == {}: - room_keys = {"sessions": {}} - else: - room_keys = room_keys["rooms"][room_id] - - return 200, room_keys - - async def on_DELETE(self, request, room_id, session_id): - """ - Deletes one or more encrypted E2E room keys for a user for backup purposes. - - DELETE /room_keys/keys/!abc:matrix.org/c0ff33?version=1 - HTTP/1.1 200 OK - {} - - room_id: the ID of the room whose keys to delete (optional) - session_id: the ID for the E2E session to delete (optional) - version: the version of the user's backup which this data is for. - the version must already have been created via the /change_secret API. - """ - - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - version = parse_string(request, "version") - - ret = await self.e2e_room_keys_handler.delete_room_keys( - user_id, version, room_id, session_id - ) - return 200, ret - - -class RoomKeysNewVersionServlet(RestServlet): - PATTERNS = client_patterns("/room_keys/version$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - - async def on_POST(self, request): - """ - Create a new backup version for this user's room_keys with the given - info. The version is allocated by the server and returned to the user - in the response. This API is intended to be used whenever the user - changes the encryption key for their backups, ensuring that backups - encrypted with different keys don't collide. - - It takes out an exclusive lock on this user's room_key backups, to ensure - clients only upload to the current backup. - - The algorithm passed in the version info is a reverse-DNS namespaced - identifier to describe the format of the encrypted backupped keys. - - The auth_data is { user_id: "user_id", nonce: } - encrypted using the algorithm and current encryption key described above. - - POST /room_keys/version - Content-Type: application/json - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" - } - - HTTP/1.1 200 OK - Content-Type: application/json - { - "version": 12345 - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - info = parse_json_object_from_request(request) - - new_version = await self.e2e_room_keys_handler.create_version(user_id, info) - return 200, {"version": new_version} - - # we deliberately don't have a PUT /version, as these things really should - # be immutable to avoid people footgunning - - -class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - - async def on_GET(self, request, version): - """ - Retrieve the version information about a given version of the user's - room_keys backup. If the version part is missing, returns info about the - most current backup version (if any) - - It takes out an exclusive lock on this user's room_key backups, to ensure - clients only upload to the current backup. - - Returns 404 if the given version does not exist. - - GET /room_keys/version/12345 HTTP/1.1 - { - "version": "12345", - "algorithm": "m.megolm_backup.v1", - "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - - try: - info = await self.e2e_room_keys_handler.get_version_info(user_id, version) - except SynapseError as e: - if e.code == 404: - raise SynapseError(404, "No backup found", Codes.NOT_FOUND) - return 200, info - - async def on_DELETE(self, request, version): - """ - Delete the information about a given version of the user's - room_keys backup. If the version part is missing, deletes the most - current backup version (if any). Doesn't delete the actual room data. - - DELETE /room_keys/version/12345 HTTP/1.1 - HTTP/1.1 200 OK - {} - """ - if version is None: - raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - - await self.e2e_room_keys_handler.delete_version(user_id, version) - return 200, {} - - async def on_PUT(self, request, version): - """ - Update the information about a given version of the user's room_keys backup. - - POST /room_keys/version/12345 HTTP/1.1 - Content-Type: application/json - { - "algorithm": "m.megolm_backup.v1", - "auth_data": { - "public_key": "abcdefg", - "signatures": { - "ed25519:something": "hijklmnop" - } - }, - "version": "12345" - } - - HTTP/1.1 200 OK - Content-Type: application/json - {} - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - info = parse_json_object_from_request(request) - - if version is None: - raise SynapseError( - 400, "No version specified to update", Codes.MISSING_PARAM - ) - - await self.e2e_room_keys_handler.update_version(user_id, version, info) - return 200, {} - - -def register_servlets(hs, http_server): - RoomKeysServlet(hs).register(http_server) - RoomKeysVersionServlet(hs).register(http_server) - RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py deleted file mode 100644 index 6d1b083acb..0000000000 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import Codes, ShadowBanError, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) -from synapse.util import stringutils - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class RoomUpgradeRestServlet(RestServlet): - """Handler for room upgrade requests. - - Handles requests of the form: - - POST /_matrix/client/r0/rooms/$roomid/upgrade HTTP/1.1 - Content-Type: application/json - - { - "new_version": "2", - } - - Creates a new room and shuts down the old one. Returns the ID of the new room. - - Args: - hs (synapse.server.HomeServer): - """ - - PATTERNS = client_patterns( - # /rooms/$roomid/upgrade - "/rooms/(?P[^/]*)/upgrade$" - ) - - def __init__(self, hs): - super().__init__() - self._hs = hs - self._room_creation_handler = hs.get_room_creation_handler() - self._auth = hs.get_auth() - - async def on_POST(self, request, room_id): - requester = await self._auth.get_user_by_req(request) - - content = parse_json_object_from_request(request) - assert_params_in_dict(content, ("new_version",)) - - new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) - if new_version is None: - raise SynapseError( - 400, - "Your homeserver does not support this room version", - Codes.UNSUPPORTED_ROOM_VERSION, - ) - - try: - new_room_id = await self._room_creation_handler.upgrade_room( - requester, room_id, new_version - ) - except ShadowBanError: - # Generate a random room ID. - new_room_id = stringutils.random_string(18) - - ret = {"replacement_room": new_room_id} - - return 200, ret - - -def register_servlets(hs, http_server): - RoomUpgradeRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py deleted file mode 100644 index d537d811d8..0000000000 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Tuple - -from synapse.http import servlet -from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request -from synapse.logging.opentracing import set_tag, trace -from synapse.rest.client.transactions import HttpTransactionCache - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class SendToDeviceRestServlet(servlet.RestServlet): - PATTERNS = client_patterns( - "/sendToDevice/(?P[^/]*)/(?P[^/]*)$" - ) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs) - self.device_message_handler = hs.get_device_message_handler() - - @trace(opname="sendToDevice") - def on_PUT(self, request, message_type, txn_id): - set_tag("message_type", message_type) - set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self._put, request, message_type, txn_id - ) - - async def _put(self, request, message_type, txn_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - content = parse_json_object_from_request(request) - assert_params_in_dict(content, ("messages",)) - - await self.device_message_handler.send_device_message( - requester, message_type, content["messages"] - ) - - response: Tuple[int, dict] = (200, {}) - return response - - -def register_servlets(hs, http_server): - SendToDeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py deleted file mode 100644 index d2e7f04b40..0000000000 --- a/synapse/rest/client/v2_alpha/shared_rooms.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import RestServlet -from synapse.types import UserID - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class UserSharedRoomsServlet(RestServlet): - """ - GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 - """ - - PATTERNS = client_patterns( - "/uk.half-shot.msc2666/user/shared_rooms/(?P[^/]*)", - releases=(), # This is an unstable feature - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.user_directory_active = hs.config.update_user_directory - - async def on_GET(self, request, user_id): - - if not self.user_directory_active: - raise SynapseError( - code=400, - msg="The user directory is disabled on this server. Cannot determine shared rooms.", - errcode=Codes.FORBIDDEN, - ) - - UserID.from_string(user_id) - - requester = await self.auth.get_user_by_req(request) - if user_id == requester.user.to_string(): - raise SynapseError( - code=400, - msg="You cannot request a list of shared rooms with yourself", - errcode=Codes.FORBIDDEN, - ) - rooms = await self.store.get_shared_rooms_for_users( - requester.user.to_string(), user_id - ) - - return 200, {"joined": list(rooms)} - - -def register_servlets(hs, http_server): - UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py deleted file mode 100644 index e18f4d01b3..0000000000 --- a/synapse/rest/client/v2_alpha/sync.py +++ /dev/null @@ -1,532 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import itertools -import logging -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple - -from synapse.api.constants import Membership, PresenceState -from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection -from synapse.events.utils import ( - format_event_for_client_v2_without_room_id, - format_event_raw, -) -from synapse.handlers.presence import format_user_presence_state -from synapse.handlers.sync import KnockedSyncResult, SyncConfig -from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string -from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, StreamToken -from synapse.util import json_decoder - -from ._base import client_patterns, set_timeline_upper_limit - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class SyncRestServlet(RestServlet): - """ - - GET parameters:: - timeout(int): How long to wait for new events in milliseconds. - since(batch_token): Batch token when asking for incremental deltas. - set_presence(str): What state the device presence should be set to. - default is "online". - filter(filter_id): A filter to apply to the events returned. - - Response JSON:: - { - "next_batch": // batch token for the next /sync - "presence": // presence data for the user. - "rooms": { - "join": { // Joined rooms being updated. - "${room_id}": { // Id of the room being updated - "event_map": // Map of EventID -> event JSON. - "timeline": { // The recent events in the room if gap is "true" - "limited": // Was the per-room event limit exceeded? - // otherwise the next events in the room. - "events": [] // list of EventIDs in the "event_map". - "prev_batch": // back token for getting previous events. - } - "state": {"events": []} // list of EventIDs updating the - // current state to be what it should - // be at the end of the batch. - "ephemeral": {"events": []} // list of event objects - } - }, - "invite": {}, // Invited rooms being updated. - "leave": {} // Archived rooms being updated. - } - } - """ - - PATTERNS = client_patterns("/sync$") - ALLOWED_PRESENCE = {"online", "offline", "unavailable"} - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.sync_handler = hs.get_sync_handler() - self.clock = hs.get_clock() - self.filtering = hs.get_filtering() - self.presence_handler = hs.get_presence_handler() - self._server_notices_sender = hs.get_server_notices_sender() - self._event_serializer = hs.get_event_client_serializer() - - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - # This will always be set by the time Twisted calls us. - assert request.args is not None - - if b"from" in request.args: - # /events used to use 'from', but /sync uses 'since'. - # Lets be helpful and whine if we see a 'from'. - raise SynapseError( - 400, "'from' is not a valid query parameter. Did you mean 'since'?" - ) - - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user = requester.user - device_id = requester.device_id - - timeout = parse_integer(request, "timeout", default=0) - since = parse_string(request, "since") - set_presence = parse_string( - request, - "set_presence", - default="online", - allowed_values=self.ALLOWED_PRESENCE, - ) - filter_id = parse_string(request, "filter") - full_state = parse_boolean(request, "full_state", default=False) - - logger.debug( - "/sync: user=%r, timeout=%r, since=%r, " - "set_presence=%r, filter_id=%r, device_id=%r", - user, - timeout, - since, - set_presence, - filter_id, - device_id, - ) - - request_key = (user, timeout, since, filter_id, full_state, device_id) - - if filter_id is None: - filter_collection = DEFAULT_FILTER_COLLECTION - elif filter_id.startswith("{"): - try: - filter_object = json_decoder.decode(filter_id) - set_timeline_upper_limit( - filter_object, self.hs.config.filter_timeline_limit - ) - except Exception: - raise SynapseError(400, "Invalid filter JSON") - self.filtering.check_valid_filter(filter_object) - filter_collection = FilterCollection(filter_object) - else: - try: - filter_collection = await self.filtering.get_user_filter( - user.localpart, filter_id - ) - except StoreError as err: - if err.code != 404: - raise - # fix up the description and errcode to be more useful - raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM) - - sync_config = SyncConfig( - user=user, - filter_collection=filter_collection, - is_guest=requester.is_guest, - request_key=request_key, - device_id=device_id, - ) - - since_token = None - if since is not None: - since_token = await StreamToken.from_string(self.store, since) - - # send any outstanding server notices to the user. - await self._server_notices_sender.on_user_syncing(user.to_string()) - - affect_presence = set_presence != PresenceState.OFFLINE - - if affect_presence: - await self.presence_handler.set_state( - user, {"presence": set_presence}, True - ) - - context = await self.presence_handler.user_syncing( - user.to_string(), affect_presence=affect_presence - ) - with context: - sync_result = await self.sync_handler.wait_for_sync_for_user( - requester, - sync_config, - since_token=since_token, - timeout=timeout, - full_state=full_state, - ) - - # the client may have disconnected by now; don't bother to serialize the - # response if so. - if request._disconnected: - logger.info("Client has disconnected; not serializing response.") - return 200, {} - - time_now = self.clock.time_msec() - response_content = await self.encode_response( - time_now, sync_result, requester.access_token_id, filter_collection - ) - - logger.debug("Event formatting complete") - return 200, response_content - - async def encode_response(self, time_now, sync_result, access_token_id, filter): - logger.debug("Formatting events in sync response") - if filter.event_format == "client": - event_formatter = format_event_for_client_v2_without_room_id - elif filter.event_format == "federation": - event_formatter = format_event_raw - else: - raise Exception("Unknown event format %s" % (filter.event_format,)) - - joined = await self.encode_joined( - sync_result.joined, - time_now, - access_token_id, - filter.event_fields, - event_formatter, - ) - - invited = await self.encode_invited( - sync_result.invited, time_now, access_token_id, event_formatter - ) - - knocked = await self.encode_knocked( - sync_result.knocked, time_now, access_token_id, event_formatter - ) - - archived = await self.encode_archived( - sync_result.archived, - time_now, - access_token_id, - filter.event_fields, - event_formatter, - ) - - logger.debug("building sync response dict") - - response: dict = defaultdict(dict) - response["next_batch"] = await sync_result.next_batch.to_string(self.store) - - if sync_result.account_data: - response["account_data"] = {"events": sync_result.account_data} - if sync_result.presence: - response["presence"] = SyncRestServlet.encode_presence( - sync_result.presence, time_now - ) - - if sync_result.to_device: - response["to_device"] = {"events": sync_result.to_device} - - if sync_result.device_lists.changed: - response["device_lists"]["changed"] = list(sync_result.device_lists.changed) - if sync_result.device_lists.left: - response["device_lists"]["left"] = list(sync_result.device_lists.left) - - # We always include this because https://github.com/vector-im/element-android/issues/3725 - # The spec isn't terribly clear on when this can be omitted and how a client would tell - # the difference between "no keys present" and "nothing changed" in terms of whole field - # absent / individual key type entry absent - # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456 - response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count - - # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md - # states that this field should always be included, as long as the server supports the feature. - response[ - "org.matrix.msc2732.device_unused_fallback_key_types" - ] = sync_result.device_unused_fallback_key_types - - if joined: - response["rooms"][Membership.JOIN] = joined - if invited: - response["rooms"][Membership.INVITE] = invited - if knocked: - response["rooms"][Membership.KNOCK] = knocked - if archived: - response["rooms"][Membership.LEAVE] = archived - - if sync_result.groups.join: - response["groups"][Membership.JOIN] = sync_result.groups.join - if sync_result.groups.invite: - response["groups"][Membership.INVITE] = sync_result.groups.invite - if sync_result.groups.leave: - response["groups"][Membership.LEAVE] = sync_result.groups.leave - - return response - - @staticmethod - def encode_presence(events, time_now): - return { - "events": [ - { - "type": "m.presence", - "sender": event.user_id, - "content": format_user_presence_state( - event, time_now, include_user_id=False - ), - } - for event in events - ] - } - - async def encode_joined( - self, rooms, time_now, token_id, event_fields, event_formatter - ): - """ - Encode the joined rooms in a sync result - - Args: - rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync - results for rooms this user is joined to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing - of transaction IDs - event_fields(list): List of event fields to include. If empty, - all fields will be returned. - event_formatter (func[dict]): function to convert from federation format - to client format - Returns: - dict[str, dict[str, object]]: the joined rooms list, in our - response format - """ - joined = {} - for room in rooms: - joined[room.room_id] = await self.encode_room( - room, - time_now, - token_id, - joined=True, - only_fields=event_fields, - event_formatter=event_formatter, - ) - - return joined - - async def encode_invited(self, rooms, time_now, token_id, event_formatter): - """ - Encode the invited rooms in a sync result - - Args: - rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of - sync results for rooms this user is invited to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing - of transaction IDs - event_formatter (func[dict]): function to convert from federation format - to client format - - Returns: - dict[str, dict[str, object]]: the invited rooms list, in our - response format - """ - invited = {} - for room in rooms: - invite = await self._event_serializer.serialize_event( - room.invite, - time_now, - token_id=token_id, - event_format=event_formatter, - include_stripped_room_state=True, - ) - unsigned = dict(invite.get("unsigned", {})) - invite["unsigned"] = unsigned - invited_state = list(unsigned.pop("invite_room_state", [])) - invited_state.append(invite) - invited[room.room_id] = {"invite_state": {"events": invited_state}} - - return invited - - async def encode_knocked( - self, - rooms: List[KnockedSyncResult], - time_now: int, - token_id: int, - event_formatter: Callable[[Dict], Dict], - ) -> Dict[str, Dict[str, Any]]: - """ - Encode the rooms we've knocked on in a sync result. - - Args: - rooms: list of sync results for rooms this user is knocking on - time_now: current time - used as a baseline for age calculations - token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_formatter: function to convert from federation format to client format - - Returns: - The list of rooms the user has knocked on, in our response format. - """ - knocked = {} - for room in rooms: - knock = await self._event_serializer.serialize_event( - room.knock, - time_now, - token_id=token_id, - event_format=event_formatter, - include_stripped_room_state=True, - ) - - # Extract the `unsigned` key from the knock event. - # This is where we (cheekily) store the knock state events - unsigned = knock.setdefault("unsigned", {}) - - # Duplicate the dictionary in order to avoid modifying the original - unsigned = dict(unsigned) - - # Extract the stripped room state from the unsigned dict - # This is for clients to get a little bit of information about - # the room they've knocked on, without revealing any sensitive information - knocked_state = list(unsigned.pop("knock_room_state", [])) - - # Append the actual knock membership event itself as well. This provides - # the client with: - # - # * A knock state event that they can use for easier internal tracking - # * The rough timestamp of when the knock occurred contained within the event - knocked_state.append(knock) - - # Build the `knock_state` dictionary, which will contain the state of the - # room that the client has knocked on - knocked[room.room_id] = {"knock_state": {"events": knocked_state}} - - return knocked - - async def encode_archived( - self, rooms, time_now, token_id, event_fields, event_formatter - ): - """ - Encode the archived rooms in a sync result - - Args: - rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of - sync results for rooms this user is joined to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing - of transaction IDs - event_fields(list): List of event fields to include. If empty, - all fields will be returned. - event_formatter (func[dict]): function to convert from federation format - to client format - Returns: - dict[str, dict[str, object]]: The invited rooms list, in our - response format - """ - joined = {} - for room in rooms: - joined[room.room_id] = await self.encode_room( - room, - time_now, - token_id, - joined=False, - only_fields=event_fields, - event_formatter=event_formatter, - ) - - return joined - - async def encode_room( - self, room, time_now, token_id, joined, only_fields, event_formatter - ): - """ - Args: - room (JoinedSyncResult|ArchivedSyncResult): sync result for a - single room - time_now (int): current time - used as a baseline for age - calculations - token_id (int): ID of the user's auth token - used for namespacing - of transaction IDs - joined (bool): True if the user is joined to this room - will mean - we handle ephemeral events - only_fields(list): Optional. The list of event fields to include. - event_formatter (func[dict]): function to convert from federation format - to client format - Returns: - dict[str, object]: the room, encoded in our response format - """ - - def serialize(events): - return self._event_serializer.serialize_events( - events, - time_now=time_now, - # We don't bundle "live" events, as otherwise clients - # will end up double counting annotations. - bundle_aggregations=False, - token_id=token_id, - event_format=event_formatter, - only_event_fields=only_fields, - ) - - state_dict = room.state - timeline_events = room.timeline.events - - state_events = state_dict.values() - - for event in itertools.chain(state_events, timeline_events): - # We've had bug reports that events were coming down under the - # wrong room. - if event.room_id != room.room_id: - logger.warning( - "Event %r is under room %r instead of %r", - event.event_id, - room.room_id, - event.room_id, - ) - - serialized_state = await serialize(state_events) - serialized_timeline = await serialize(timeline_events) - - account_data = room.account_data - - result = { - "timeline": { - "events": serialized_timeline, - "prev_batch": await room.timeline.prev_batch.to_string(self.store), - "limited": room.timeline.limited, - }, - "state": {"events": serialized_state}, - "account_data": {"events": account_data}, - } - - if joined: - ephemeral_events = room.ephemeral - result["ephemeral"] = {"events": ephemeral_events} - result["unread_notifications"] = room.unread_notifications - result["summary"] = room.summary - result["org.matrix.msc2654.unread_count"] = room.unread_count - - return result - - -def register_servlets(hs, http_server): - SyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py deleted file mode 100644 index c14f83be18..0000000000 --- a/synapse/rest/client/v2_alpha/tags.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import AuthError -from synapse.http.servlet import RestServlet, parse_json_object_from_request - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class TagListServlet(RestServlet): - """ - GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 - """ - - PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET(self, request, user_id, room_id): - requester = await self.auth.get_user_by_req(request) - if user_id != requester.user.to_string(): - raise AuthError(403, "Cannot get tags for other users.") - - tags = await self.store.get_tags_for_room(user_id, room_id) - - return 200, {"tags": tags} - - -class TagServlet(RestServlet): - """ - PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 - DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 - """ - - PATTERNS = client_patterns( - "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.handler = hs.get_account_data_handler() - - async def on_PUT(self, request, user_id, room_id, tag): - requester = await self.auth.get_user_by_req(request) - if user_id != requester.user.to_string(): - raise AuthError(403, "Cannot add tags for other users.") - - body = parse_json_object_from_request(request) - - await self.handler.add_tag_to_room(user_id, room_id, tag, body) - - return 200, {} - - async def on_DELETE(self, request, user_id, room_id, tag): - requester = await self.auth.get_user_by_req(request) - if user_id != requester.user.to_string(): - raise AuthError(403, "Cannot add tags for other users.") - - await self.handler.remove_tag_from_room(user_id, room_id, tag) - - return 200, {} - - -def register_servlets(hs, http_server): - TagListServlet(hs).register(http_server) - TagServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py deleted file mode 100644 index b5c67c9bb6..0000000000 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from synapse.api.constants import ThirdPartyEntityKind -from synapse.http.servlet import RestServlet - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class ThirdPartyProtocolsServlet(RestServlet): - PATTERNS = client_patterns("/thirdparty/protocols") - - def __init__(self, hs): - super().__init__() - - self.auth = hs.get_auth() - self.appservice_handler = hs.get_application_service_handler() - - async def on_GET(self, request): - await self.auth.get_user_by_req(request, allow_guest=True) - - protocols = await self.appservice_handler.get_3pe_protocols() - return 200, protocols - - -class ThirdPartyProtocolServlet(RestServlet): - PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$") - - def __init__(self, hs): - super().__init__() - - self.auth = hs.get_auth() - self.appservice_handler = hs.get_application_service_handler() - - async def on_GET(self, request, protocol): - await self.auth.get_user_by_req(request, allow_guest=True) - - protocols = await self.appservice_handler.get_3pe_protocols( - only_protocol=protocol - ) - if protocol in protocols: - return 200, protocols[protocol] - else: - return 404, {"error": "Unknown protocol"} - - -class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$") - - def __init__(self, hs): - super().__init__() - - self.auth = hs.get_auth() - self.appservice_handler = hs.get_application_service_handler() - - async def on_GET(self, request, protocol): - await self.auth.get_user_by_req(request, allow_guest=True) - - fields = request.args - fields.pop(b"access_token", None) - - results = await self.appservice_handler.query_3pe( - ThirdPartyEntityKind.USER, protocol, fields - ) - - return 200, results - - -class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$") - - def __init__(self, hs): - super().__init__() - - self.auth = hs.get_auth() - self.appservice_handler = hs.get_application_service_handler() - - async def on_GET(self, request, protocol): - await self.auth.get_user_by_req(request, allow_guest=True) - - fields = request.args - fields.pop(b"access_token", None) - - results = await self.appservice_handler.query_3pe( - ThirdPartyEntityKind.LOCATION, protocol, fields - ) - - return 200, results - - -def register_servlets(hs, http_server): - ThirdPartyProtocolsServlet(hs).register(http_server) - ThirdPartyProtocolServlet(hs).register(http_server) - ThirdPartyUserServlet(hs).register(http_server) - ThirdPartyLocationServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py deleted file mode 100644 index b2f858545c..0000000000 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.api.errors import AuthError -from synapse.http.servlet import RestServlet - -from ._base import client_patterns - - -class TokenRefreshRestServlet(RestServlet): - """ - Exchanges refresh tokens for a pair of an access token and a new refresh - token. - """ - - PATTERNS = client_patterns("/tokenrefresh") - - def __init__(self, hs): - super().__init__() - - async def on_POST(self, request): - raise AuthError(403, "tokenrefresh is no longer supported.") - - -def register_servlets(hs, http_server): - TokenRefreshRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py deleted file mode 100644 index 7e8912f0b9..0000000000 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class UserDirectorySearchRestServlet(RestServlet): - PATTERNS = client_patterns("/user_directory/search$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.user_directory_handler = hs.get_user_directory_handler() - - async def on_POST(self, request): - """Searches for users in directory - - Returns: - dict of the form:: - - { - "limited": , # whether there were more results or not - "results": [ # Ordered by best match first - { - "user_id": , - "display_name": , - "avatar_url": - } - ] - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - - if not self.hs.config.user_directory_search_enabled: - return 200, {"limited": False, "results": []} - - body = parse_json_object_from_request(request) - - limit = body.get("limit", 10) - limit = min(limit, 50) - - try: - search_term = body["search_term"] - except Exception: - raise SynapseError(400, "`search_term` is required field") - - results = await self.user_directory_handler.search_users( - user_id, search_term, limit - ) - - return 200, results - - -def register_servlets(hs, http_server): - UserDirectorySearchRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py new file mode 100644 index 0000000000..f53020520d --- /dev/null +++ b/synapse/rest/client/voip.py @@ -0,0 +1,73 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import hashlib +import hmac + +from synapse.http.servlet import RestServlet +from synapse.rest.client._base import client_patterns + + +class VoipRestServlet(RestServlet): + PATTERNS = client_patterns("/voip/turnServer$", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req( + request, self.hs.config.turn_allow_guests + ) + + turnUris = self.hs.config.turn_uris + turnSecret = self.hs.config.turn_shared_secret + turnUsername = self.hs.config.turn_username + turnPassword = self.hs.config.turn_password + userLifetime = self.hs.config.turn_user_lifetime + + if turnUris and turnSecret and userLifetime: + expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 + username = "%d:%s" % (expiry, requester.user.to_string()) + + mac = hmac.new( + turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1 + ) + # We need to use standard padded base64 encoding here + # encode_base64 because we need to add the standard padding to get the + # same result as the TURN server. + password = base64.b64encode(mac.digest()).decode("ascii") + + elif turnUris and turnUsername and turnPassword and userLifetime: + username = turnUsername + password = turnPassword + + else: + return 200, {} + + return ( + 200, + { + "username": username, + "password": password, + "ttl": userLifetime / 1000, + "uris": turnUris, + }, + ) + + +def register_servlets(hs, http_server): + VoipRestServlet(hs).register(http_server) diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index 5527e278db..d66aeb00eb 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -1,6 +1,6 @@ import synapse from synapse.app.phone_stats_home import start_phone_stats_home -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest from tests.unittest import HomeserverTestCase diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 3f41e99950..6b87f571b8 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -22,7 +22,7 @@ from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.module_api import ModuleApi from synapse.rest import admin -from synapse.rest.client.v1 import login, presence, room +from synapse.rest.client import login, presence, room from synapse.types import JsonDict, StreamToken, create_requester from tests.handlers.test_sync import generate_sync_config diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 48e98aac79..ca27388ae8 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -14,7 +14,7 @@ from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest from tests.test_utils.event_injection import create_event diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 1a809b2a6a..7b486aba4a 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -16,7 +16,7 @@ from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID from tests import unittest diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 802c5ad299..f0aa8ed9db 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -6,7 +6,7 @@ from synapse.events import EventBase from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b00dd143d6..65b18fbd7a 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.api.constants import RoomEncryptionAlgorithms from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.types import JsonDict, ReadReceipt from tests.test_utils import make_awaitable diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 1737891564..0b60cc4261 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -19,7 +19,7 @@ from parameterized import parameterized from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index aab44bce4a..383214ab50 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.events import builder from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import RoomAlias diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 18a734daf4..59de1142b1 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -15,12 +15,10 @@ from collections import Counter from unittest.mock import Mock -import synapse.api.errors -import synapse.handlers.admin import synapse.rest.admin import synapse.storage from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 7a8041ab44..a0a48b564e 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -19,7 +19,7 @@ import synapse import synapse.api.errors from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig -from synapse.rest.client.v1 import directory, login, room +from synapse.rest.client import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 4140fcefc2..c72a8972a3 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -22,7 +22,7 @@ from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.util.stringutils import random_string from tests import unittest diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index a8a9fc5b62..8a8d369fac 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import create_requester from synapse.util.stringutils import random_string diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 32651db096..38e6d9f536 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,8 +20,7 @@ from unittest.mock import Mock from twisted.internet import defer import synapse -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import devices +from synapse.rest.client import devices, login from synapse.types import JsonDict from tests import unittest diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 29845a80da..0a52bc8b72 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -33,7 +33,7 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest import admin -from synapse.rest.client.v1 import room +from synapse.rest.client import room from synapse.types import UserID, get_domain_from_id from tests import unittest diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 732d746e38..ac800afa7d 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -28,7 +28,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import JsonDict, UserID diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e4059acda3..1ba4c05b9b 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main import stats from tests import unittest diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 549876dc85..e44bf2b3b1 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -18,8 +18,7 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import user_directory +from synapse.rest.client import login, room, user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 0b817cc701..7dd519cd44 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -20,7 +20,7 @@ from synapse.events import EventBase from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.rest import admin -from synapse.rest.client.v1 import login, presence, room +from synapse.rest.client import login, presence, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence diff --git a/tests/push/test_email.py b/tests/push/test_email.py index a487706758..e0a3342088 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.api.errors import Codes, SynapseError -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.unittest import HomeserverTestCase diff --git a/tests/push/test_http.py b/tests/push/test_http.py index ffd75b1491..c068d329a9 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -18,8 +18,7 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import receipts +from synapse.rest.client import login, receipts, room from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 666008425a..f198a94887 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -24,7 +24,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamRow, ) from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.replication._base import BaseStreamTestCase from tests.test_utils.event_injection import inject_event, inject_member_event diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index 1346e0e160..43a16bb141 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from synapse.rest.client.v2_alpha import register +from synapse.rest.client import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeChannel, make_request diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index b9751efdc5..995097d72c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from synapse.rest.client.v2_alpha import register +from synapse.rest.client import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index a0c710f855..af5dfca752 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests.replication._base import BaseMultiWorkerStreamTestCase diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index ffa425328f..ac419f0db3 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -22,7 +22,7 @@ from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.server import HomeServer from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 1e4e3821b9..4094a75f36 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.internet import defer from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.replication._base import BaseMultiWorkerStreamTestCase diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index f3615af97e..0a6e4795ee 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -16,8 +16,7 @@ from unittest.mock import patch from synapse.api.room_versions import RoomVersion from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index a7c6e595b9..bfa638fb4b 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -24,8 +24,7 @@ import synapse.rest.admin from synapse.http.server import JsonResource from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import groups +from synapse.rest.client import groups, login, room from tests import unittest from tests.server import FakeSite, make_request diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 120730b764..c4afe5c3d9 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -17,7 +17,7 @@ import urllib.parse import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index f15d1cf6f7..e9ef89731f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -16,8 +16,7 @@ import json import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import report_event +from synapse.rest.client import login, report_event, room from tests import unittest diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 7198fd293f..972d60570c 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -20,7 +20,7 @@ from parameterized import parameterized import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths from tests import unittest diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 17ec8bfd3b..c9d4731017 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -22,7 +22,7 @@ from parameterized import parameterized_class import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes -from synapse.rest.client.v1 import directory, events, login, room +from synapse.rest.client import directory, events, login, room from tests import unittest diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 79cac4266b..5cd82209c4 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a736ec4754..ef77275238 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -27,8 +27,7 @@ import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions -from synapse.rest.client.v1 import login, logout, profile, room -from synapse.rest.client.v2_alpha import devices, sync +from synapse.rest.client import devices, login, logout, profile, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.types import JsonDict, UserID diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 53cbc8ddab..4e1c49c28b 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -14,7 +14,7 @@ import synapse.rest.admin from synapse.api.errors import Codes, SynapseError -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 5cc62a910a..65c58ce70a 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -16,7 +16,7 @@ import os import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.rest.consent import consent_resource from tests import unittest diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index eec0fc01f9..3d7aa8ec86 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin -from synapse.rest.client.v1 import room +from synapse.rest.client import room from tests import unittest diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 478296ba0e..ca2e8ff8ef 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -15,7 +15,7 @@ import json import synapse.rest.admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index ba5ad47df5..91d0762cb0 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index dfd85221d0..433d715f69 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index e1a6e73e17..b58452195a 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -15,7 +15,7 @@ from unittest.mock import Mock from synapse.api.constants import EventTypes from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.visibility import filter_events_for_client from tests import unittest diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 288ee12888..6a0d9a82be 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -16,8 +16,13 @@ from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.rest.client import ( + directory, + login, + profile, + room, + room_upgrade_rest_servlet, +) from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 28dd47a28b..0ae4029640 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -19,7 +19,7 @@ from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import Requester, StateMap from synapse.util.frozenutils import unfreeze diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 8ed470490b..d2181ea907 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -15,7 +15,7 @@ import json from synapse.rest import admin -from synapse.rest.client.v1 import directory, login, room +from synapse.rest.client import directory, login, room from synapse.types import RoomAlias from synapse.util.stringutils import random_string diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 2789d51546..a90294003e 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import synapse.rest.admin -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client import events, login, room from tests import unittest diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 7eba69642a..eba3552b19 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -24,9 +24,8 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import devices, register -from synapse.rest.client.v2_alpha.account import WhoamiRestServlet +from synapse.rest.client import devices, login, logout, register +from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 597e4c67de..1d152352d1 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.internet import defer from synapse.handlers.presence import PresenceHandler -from synapse.rest.client.v1 import presence +from synapse.rest.client import presence from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 165ad33fb7..2860579c2e 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,7 +14,7 @@ """Tests REST events for /profile paths.""" from synapse.rest import admin -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client import login, profile, room from tests import unittest diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py index d077616082..d0ce91ccd9 100644 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ b/tests/rest/client/v1/test_push_rule_attrs.py @@ -13,7 +13,7 @@ # limitations under the License. import synapse from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, push_rule, room +from synapse.rest.client import login, push_rule, room from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 1a9528ec20..0c9cbb9aff 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,8 +29,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin -from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import account +from synapse.rest.client import account, directory, login, profile, room from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 44e22ca999..b54b004733 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -17,7 +17,7 @@ from unittest.mock import Mock -from synapse.rest.client.v1 import room +from synapse.rest.client import room from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index e7e617e9df..b946fca8b3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -25,8 +25,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import account, register +from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 6b90f838b6..cf5cfb910c 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -19,8 +19,7 @@ from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account, auth, devices, register +from synapse.rest.client import account, auth, devices, login, register from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index f80f48a455..ad83b3d2ff 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -13,8 +13,7 @@ # limitations under the License. import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import capabilities +from synapse.rest.client import capabilities, login from tests import unittest from tests.unittest import override_config diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index c7e47725b7..475c6bed3d 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,7 +15,7 @@ from twisted.internet import defer from synapse.api.errors import Codes -from synapse.rest.client.v2_alpha import filter +from synapse.rest.client import filter from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index 6f07ff6cbb..3cf5871899 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -17,8 +17,7 @@ import json from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account, password_policy, register +from synapse.rest.client import account, login, password_policy, register from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index a52e5e608a..fecda037a5 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -23,8 +23,7 @@ import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import account, account_validity, register, sync +from synapse.rest.client import account, account_validity, login, logout, register, sync from tests import unittest from tests.unittest import override_config diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 2e2f94742e..02b5e9a8d0 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -19,8 +19,7 @@ from typing import Optional from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import register, relations +from synapse.rest.client import login, register, relations, room from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py index a76a6fef1e..ee6b0b9ebf 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -15,8 +15,7 @@ import json import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import report_event +from synapse.rest.client import login, report_event, room from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/v2_alpha/test_sendtodevice.py index c9c99cc5d7..6db7062a8e 100644 --- a/tests/rest/client/v2_alpha/test_sendtodevice.py +++ b/tests/rest/client/v2_alpha/test_sendtodevice.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import sendtodevice, sync +from synapse.rest.client import login, sendtodevice, sync from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index cedb9614a8..283eccd53f 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import shared_rooms +from synapse.rest.client import login, room, shared_rooms from tests import unittest from tests.server import FakeChannel diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 15748ed4fd..95be369d4b 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -21,8 +21,7 @@ from synapse.api.constants import ( ReadReceiptEventFields, RelationTypes, ) -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import knock, read_marker, receipts, sync +from synapse.rest.client import knock, login, read_marker, receipts, room, sync from tests import unittest from tests.federation.transport.test_knocking import ( diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py index 5f3f15fc57..72f976d8e2 100644 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ b/tests/rest/client/v2_alpha/test_upgrade_room.py @@ -15,8 +15,7 @@ from typing import Optional from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.rest.client import login, room, room_upgrade_rest_servlet from tests import unittest from tests.server import FakeChannel diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2d6b49692e..6085444b9d 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -30,7 +30,7 @@ from twisted.internet.defer import Deferred from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index ac98259b7e..58b399a043 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -15,8 +15,7 @@ import os import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests import unittest diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3245aa91ca..8701b5f7e3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -19,8 +19,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index d05d367685..a649e8c618 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -15,7 +15,7 @@ import json from synapse.logging.context import LoggingContext from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util.async_helpers import yieldable_gather_results diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 77c4fe721c..da98733ce8 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -17,7 +17,7 @@ from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage import prepare_database from synapse.types import UserID, create_requester diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index e57fce9694..1c2df54ecc 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import synapse.rest.admin from synapse.http.site import XForwardedForRequest -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest from tests.server import make_request diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index d87f124c26..93136f0717 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main.events import _LinkMap from synapse.types import create_requester diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 617bc8091f..f462a8b1c7 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -17,7 +17,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.federation.federation_base import event_from_pdu_json from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.unittest import HomeserverTestCase diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index e5574063f1..22a77c3ccc 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.errors import NotFoundError, SynapseError -from synapse.rest.client.v1 import room +from synapse.rest.client import room from tests.unittest import HomeserverTestCase diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 9fa968f6bb..c72dc40510 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -15,7 +15,7 @@ from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest diff --git a/tests/test_mau.py b/tests/test_mau.py index fa6ef92b3b..66111eb367 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,7 +17,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService -from synapse.rest.client.v2_alpha import register, sync +from synapse.rest.client import register, sync from tests import unittest from tests.unittest import override_config diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 0df480db9f..67dcf567cd 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactorClock -from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.rest.client.register import register_servlets from synapse.util import Clock from tests import unittest -- cgit 1.5.1 From 5f7b1e1f276fdd25304ff06076e1cd77cf3a9640 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 17 Aug 2021 13:13:11 +0100 Subject: Make `PeriodicallyFlushingMemoryHandler` the default logging handler. (#10518) --- changelog.d/10518.feature | 1 + docker/conf/log.config | 27 ++++++++++++++++++++------- docs/sample_log_config.yaml | 27 ++++++++++++++++++++------- synapse/config/logger.py | 27 ++++++++++++++++++++------- 4 files changed, 61 insertions(+), 21 deletions(-) create mode 100644 changelog.d/10518.feature (limited to 'synapse') diff --git a/changelog.d/10518.feature b/changelog.d/10518.feature new file mode 100644 index 0000000000..112e4d105c --- /dev/null +++ b/changelog.d/10518.feature @@ -0,0 +1 @@ +The default logging handler for new installations is now `PeriodicallyFlushingMemoryHandler`, a buffered logging handler which periodically flushes itself. diff --git a/docker/conf/log.config b/docker/conf/log.config index a994626926..7a216a36a0 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -18,18 +18,31 @@ handlers: backupCount: 6 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # there will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expensive of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately. + # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + # Default value: 5 seconds + period: 5 {% endif %} console: diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 669e600081..2485ad25ed 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -24,18 +24,31 @@ handlers: backupCount: 3 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expensive of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately. + # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + # Default value: 5 seconds + period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index ad4e6e61c3..4a398a7932 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -67,18 +67,31 @@ handlers: backupCount: 3 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expensive of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately. + # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + # Default value: 5 seconds + period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. -- cgit 1.5.1 From c4cf0c047329e125f0940281fd53688474d26581 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 17 Aug 2021 08:19:12 -0400 Subject: Attempt to pull from the legacy spaces summary API over federation. (#10583) If the new /hierarchy API does not exist on all destinations, fallback to querying the /spaces API and translating the results. This is a backwards compatibility hack since not all of the federated homeservers will update at the same time. --- changelog.d/10583.feature | 1 + synapse/federation/federation_client.py | 64 ++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10583.feature (limited to 'synapse') diff --git a/changelog.d/10583.feature b/changelog.d/10583.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10583.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 0af953a5d6..29979414e3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1364,13 +1364,59 @@ class FederationClient(FederationBase): return room, children, inaccessible_children - # TODO Fallback to the old federation API and translate the results. - return await self._try_destination_list( - "fetch room hierarchy", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) + try: + return await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) + except SynapseError as e: + # Fallback to the old federation API and translate the results if + # no servers implement the new API. + # + # The algorithm below is a bit inefficient as it only attempts to + # get information for the requested room, but the legacy API may + # return additional layers. + if e.code == 502: + legacy_result = await self.get_space_summary( + destinations, + room_id, + suggested_only, + max_rooms_per_space=None, + exclude_rooms=[], + ) + + # Find the requested room in the response (and remove it). + for _i, room in enumerate(legacy_result.rooms): + if room.get("room_id") == room_id: + break + else: + # The requested room was not returned, nothing we can do. + raise + requested_room = legacy_result.rooms.pop(_i) + + # Find any children events of the requested room. + children_events = [] + children_room_ids = set() + for event in legacy_result.events: + if event.room_id == room_id: + children_events.append(event.data) + children_room_ids.add(event.state_key) + # And add them under the requested room. + requested_room["children_state"] = children_events + + # Find the children rooms. + children = [] + for room in legacy_result.rooms: + if room.get("room_id") in children_room_ids: + children.append(room) + + # It isn't clear from the response whether some of the rooms are + # not accessible. + return requested_room, children, () + + raise @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -1430,7 +1476,7 @@ class FederationSpaceSummaryEventResult: class FederationSpaceSummaryResult: """Represents the data returned by a successful get_space_summary call.""" - rooms: Sequence[JsonDict] + rooms: List[JsonDict] events: Sequence[FederationSpaceSummaryEventResult] @classmethod @@ -1444,7 +1490,7 @@ class FederationSpaceSummaryResult: ValueError if d is not a valid /spaces/ response """ rooms = d.get("rooms") - if not isinstance(rooms, Sequence): + if not isinstance(rooms, List): raise ValueError("'rooms' must be a list") if any(not isinstance(r, dict) for r in rooms): raise ValueError("Invalid room in 'rooms' list") -- cgit 1.5.1 From 56397599809e131174daaeb4c6dc18fde9db6c3f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 17 Aug 2021 14:45:24 +0200 Subject: Centralise the custom template directory (#10596) Several configuration sections are using separate settings for custom template directories, which can be confusing. This PR adds a new top-level configuration for a custom template directory which is then used for every module. The only exception is the consent templates, since the consent template directory require a specific hierarchy, so it's probably better that it stays separate from everything else. --- changelog.d/10596.removal | 1 + docs/SUMMARY.md | 1 + docs/sample_config.yaml | 225 ++-------------------- docs/templates.md | 239 ++++++++++++++++++++++++ docs/upgrade.md | 11 ++ synapse/config/account_validity.py | 7 +- synapse/config/emailconfig.py | 71 +++---- synapse/config/server.py | 25 +++ synapse/config/sso.py | 173 +---------------- synapse/module_api/__init__.py | 3 +- synapse/rest/synapse/client/new_user_consent.py | 2 + synapse/rest/synapse/client/pick_username.py | 2 + 12 files changed, 342 insertions(+), 418 deletions(-) create mode 100644 changelog.d/10596.removal create mode 100644 docs/templates.md (limited to 'synapse') diff --git a/changelog.d/10596.removal b/changelog.d/10596.removal new file mode 100644 index 0000000000..e69f632db4 --- /dev/null +++ b/changelog.d/10596.removal @@ -0,0 +1 @@ +The `template_dir` configuration settings in the `sso`, `account_validity` and `email` sections of the configuration file are now deprecated in favour of the global `templates.custom_template_directory` setting. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 3d320a1c43..56e0141c2b 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -21,6 +21,7 @@ - [Homeserver Sample Config File](usage/configuration/homeserver_sample_config.md) - [Logging Sample Config File](usage/configuration/logging_sample_config.md) - [Structured Logging](structured_logging.md) + - [Templates](templates.md) - [User Authentication](usage/configuration/user_authentication/README.md) - [Single-Sign On]() - [OpenID Connect](openid.md) diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index aeebcaf45f..3ec76d5abf 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -565,6 +565,19 @@ retention: # #next_link_domain_whitelist: ["matrix.org"] +# Templates to use when generating email or HTML page contents. +# +templates: + # Directory in which Synapse will try to find template files to use to generate + # email or HTML page contents. + # If not set, or a file is not found within the template directory, a default + # template from within the Synapse package will be used. + # + # See https://matrix-org.github.io/synapse/latest/templates.html for more + # information about using custom templates. + # + #custom_template_directory: /path/to/custom/templates/ + ## TLS ## @@ -1895,6 +1908,9 @@ cas_config: # Additional settings to use with single-sign on systems such as OpenID Connect, # SAML2 and CAS. # +# Server admins can configure custom templates for pages related to SSO. See +# https://matrix-org.github.io/synapse/latest/templates.html for more information. +# sso: # A list of client URLs which are whitelisted so that the user does not # have to confirm giving access to their account to the URL. Any client @@ -1927,169 +1943,6 @@ sso: # #update_profile_information: true - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to prompt the user to choose an Identity Provider during - # login: 'sso_login_idp_picker.html'. - # - # This is only used if multiple SSO Identity Providers are configured. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL that the user will be redirected to after - # login. - # - # * server_name: the homeserver's name. - # - # * providers: a list of available Identity Providers. Each element is - # an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # The rendered HTML page should contain a form which submits its results - # back as a GET request, with the following query parameters: - # - # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed - # to the template) - # - # * idp: the 'idp_id' of the chosen IDP. - # - # * HTML page to prompt new users to enter a userid and confirm other - # details: 'sso_auth_account_details.html'. This is only shown if the - # SSO implementation (with any user_mapping_provider) does not return - # a localpart. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * idp: details of the SSO Identity Provider that the user logged in - # with: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * user_attributes: an object containing details about the user that - # we received from the IdP. May have the following attributes: - # - # * display_name: the user's display_name - # * emails: a list of email addresses - # - # The template should render a form which submits the following fields: - # - # * username: the localpart of the user's chosen user id - # - # * HTML page allowing the user to consent to the server's terms and - # conditions. This is only shown for new users, and only if - # `user_consent.require_at_registration` is set. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * user_id: the user's matrix proposed ID. - # - # * user_profile.display_name: the user's proposed display name, if any. - # - # * consent_version: the version of the terms that the user will be - # shown - # - # * terms_url: a link to the page showing the terms. - # - # The template should render a form which submits the following fields: - # - # * accepted_version: the version of the terms accepted by the user - # (ie, 'consent_version' from the input variables). - # - # * HTML page for a confirmation step before redirecting back to the client - # with the login token: 'sso_redirect_confirm.html'. - # - # When rendering, this template is given the following variables: - # - # * redirect_url: the URL the user is about to be redirected to. - # - # * display_url: the same as `redirect_url`, but with the query - # parameters stripped. The intention is to have a - # human-readable URL to show to users, not to use it as - # the final address to redirect to. - # - # * server_name: the homeserver's name. - # - # * new_user: a boolean indicating whether this is the user's first time - # logging in. - # - # * user_id: the user's matrix ID. - # - # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. - # None if the user has not set an avatar. - # - # * user_profile.display_name: the user's display name. None if the user - # has not set a display name. - # - # * HTML page which notifies the user that they are authenticating to confirm - # an operation on their account during the user interactive authentication - # process: 'sso_auth_confirm.html'. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL the user is about to be redirected to. - # - # * description: the operation which the user is being asked to confirm - # - # * idp: details of the Identity Provider that we will use to confirm - # the user's identity: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * HTML page shown after a successful user interactive authentication session: - # 'sso_auth_success.html'. - # - # Note that this page must include the JavaScript which notifies of a successful authentication - # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). - # - # This template has no additional variables. - # - # * HTML page shown after a user-interactive authentication session which - # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. - # - # When rendering, this template is given the following variables: - # * server_name: the homeserver's name. - # * user_id_to_verify: the MXID of the user that we are trying to - # validate. - # - # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) - # attempts to login: 'sso_account_deactivated.html'. - # - # This template has no additional variables. - # - # * HTML page to display to users if something goes wrong during the - # OpenID Connect authentication process: 'sso_error.html'. - # - # When rendering, this template is given two variables: - # * error: the technical name of the error - # * error_description: a human-readable message for the error - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # JSON web token integration. The following settings can be used to make # Synapse JSON web tokens for authentication, instead of its internal @@ -2220,6 +2073,9 @@ ui_auth: # Configuration for sending emails from Synapse. # +# Server admins can configure custom templates for email content. See +# https://matrix-org.github.io/synapse/latest/templates.html for more information. +# email: # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. # @@ -2296,49 +2152,6 @@ email: # #invite_client_location: https://app.element.io - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * The contents of email notifications of missed events: 'notif_mail.html' and - # 'notif_mail.txt'. - # - # * The contents of account expiry notice emails: 'notice_expiry.html' and - # 'notice_expiry.txt'. - # - # * The contents of password reset emails sent by the homeserver: - # 'password_reset.html' and 'password_reset.txt' - # - # * An HTML page that a user will see when they follow the link in the password - # reset email. The user will be asked to confirm the action before their - # password is reset: 'password_reset_confirmation.html' - # - # * HTML pages for success and failure that a user will see when they confirm - # the password reset flow using the page above: 'password_reset_success.html' - # and 'password_reset_failure.html' - # - # * The contents of address verification emails sent during registration: - # 'registration.html' and 'registration.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent during registration: - # 'registration_success.html' and 'registration_failure.html' - # - # * The contents of address verification emails sent when an address is added - # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent when an address is added - # to a Matrix account: 'add_threepid_success.html' and - # 'add_threepid_failure.html' - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # Subjects to use when sending emails from Synapse. # # The placeholder '%(app)s' will be replaced with the value of the 'app_name' diff --git a/docs/templates.md b/docs/templates.md new file mode 100644 index 0000000000..a240f58b54 --- /dev/null +++ b/docs/templates.md @@ -0,0 +1,239 @@ +# Templates + +Synapse uses parametrised templates to generate the content of emails it sends and +webpages it shows to users. + +By default, Synapse will use the templates listed [here](https://github.com/matrix-org/synapse/tree/master/synapse/res/templates). +Server admins can configure an additional directory for Synapse to look for templates +in, allowing them to specify custom templates: + +```yaml +templates: + custom_templates_directory: /path/to/custom/templates/ +``` + +If this setting is not set, or the files named below are not found within the directory, +default templates from within the Synapse package will be used. + +Templates that are given variables when being rendered are rendered using [Jinja 2](https://jinja.palletsprojects.com/en/2.11.x/). +Templates rendered by Jinja 2 can also access two functions on top of the functions +already available as part of Jinja 2: + +```python +format_ts(value: int, format: str) -> str +``` + +Formats a timestamp in milliseconds. + +Example: `reason.last_sent_ts|format_ts("%c")` + +```python +mxc_to_http(value: str, width: int, height: int, resize_method: str = "crop") -> str +``` + +Turns a `mxc://` URL for media content into an HTTP(S) one using the homeserver's +`public_baseurl` configuration setting as the URL's base. + +Example: `message.sender_avatar_url|mxc_to_http(32,32)` + + +## Email templates + +Below are the templates Synapse will look for when generating the content of an email: + +* `notif_mail.html` and `notif_mail.txt`: The contents of email notifications of missed + events. + When rendering, this template is given the following variables: + * `user_display_name`: the display name for the user receiving the notification + * `unsubscribe_link`: the link users can click to unsubscribe from email notifications + * `summary_text`: a summary of the notification(s). The text used can be customised + by configuring the various settings in the `email.subjects` section of the + configuration file. + * `rooms`: a list of rooms containing events to include in the email. Each element is + an object with the following attributes: + * `title`: a human-readable name for the room + * `hash`: a hash of the ID of the room + * `invite`: a boolean, which is `True` if the room is an invite the user hasn't + accepted yet, `False` otherwise + * `notifs`: a list of events, or an empty list if `invite` is `True`. Each element + is an object with the following attributes: + * `link`: a `matrix.to` link to the event + * `ts`: the time in milliseconds at which the event was received + * `messages`: a list of messages containing one message before the event, the + message in the event, and one message after the event. Each element is an + object with the following attributes: + * `event_type`: the type of the event + * `is_historical`: a boolean, which is `False` if the message is the one + that triggered the notification, `True` otherwise + * `id`: the ID of the event + * `ts`: the time in milliseconds at which the event was sent + * `sender_name`: the display name for the event's sender + * `sender_avatar_url`: the avatar URL (as a `mxc://` URL) for the event's + sender + * `sender_hash`: a hash of the user ID of the sender + * `link`: a `matrix.to` link to the room + * `reason`: information on the event that triggered the email to be sent. It's an + object with the following attributes: + * `room_id`: the ID of the room the event was sent in + * `room_name`: a human-readable name for the room the event was sent in + * `now`: the current time in milliseconds + * `received_at`: the time in milliseconds at which the event was received + * `delay_before_mail_ms`: the amount of time in milliseconds Synapse always waits + before ever emailing about a notification (to give the user a chance to respond + to other push or notice the window) + * `last_sent_ts`: the time in milliseconds at which a notification was last sent + for an event in this room + * `throttle_ms`: the minimum amount of time in milliseconds between two + notifications can be sent for this room +* `password_reset.html` and `password_reset.txt`: The contents of password reset emails + sent by the homeserver. + When rendering, these templates are given a `link` variable which contains the link the + user must click in order to reset their password. +* `registration.html` and `registration.txt`: The contents of address verification emails + sent during registration. + When rendering, these templates are given a `link` variable which contains the link the + user must click in order to validate their email address. +* `add_threepid.html` and `add_threepid.txt`: The contents of address verification emails + sent when an address is added to a Matrix account. + When rendering, these templates are given a `link` variable which contains the link the + user must click in order to validate their email address. + + +## HTML page templates for registration and password reset + +Below are the templates Synapse will look for when generating pages related to +registration and password reset: + +* `password_reset_confirmation.html`: An HTML page that a user will see when they follow + the link in the password reset email. The user will be asked to confirm the action + before their password is reset. + When rendering, this template is given the following variables: + * `sid`: the session ID for the password reset + * `token`: the token for the password reset + * `client_secret`: the client secret for the password reset +* `password_reset_success.html` and `password_reset_failure.html`: HTML pages for success + and failure that a user will see when they confirm the password reset flow using the + page above. + When rendering, `password_reset_success.html` is given no variable, and + `password_reset_failure.html` is given a `failure_reason`, which contains the reason + for the password reset failure. +* `registration_success.html` and `registration_failure.html`: HTML pages for success and + failure that a user will see when they follow the link in an address verification email + sent during registration. + When rendering, `registration_success.html` is given no variable, and + `registration_failure.html` is given a `failure_reason`, which contains the reason + for the registration failure. +* `add_threepid_success.html` and `add_threepid_failure.html`: HTML pages for success and + failure that a user will see when they follow the link in an address verification email + sent when an address is added to a Matrix account. + When rendering, `add_threepid_success.html` is given no variable, and + `add_threepid_failure.html` is given a `failure_reason`, which contains the reason + for the registration failure. + + +## HTML page templates for Single Sign-On (SSO) + +Below are the templates Synapse will look for when generating pages related to SSO: + +* `sso_login_idp_picker.html`: HTML page to prompt the user to choose an + Identity Provider during login. + This is only used if multiple SSO Identity Providers are configured. + When rendering, this template is given the following variables: + * `redirect_url`: the URL that the user will be redirected to after + login. + * `server_name`: the homeserver's name. + * `providers`: a list of available Identity Providers. Each element is + an object with the following attributes: + * `idp_id`: unique identifier for the IdP + * `idp_name`: user-facing name for the IdP + * `idp_icon`: if specified in the IdP config, an MXC URI for an icon + for the IdP + * `idp_brand`: if specified in the IdP config, a textual identifier + for the brand of the IdP + The rendered HTML page should contain a form which submits its results + back as a GET request, with the following query parameters: + * `redirectUrl`: the client redirect URI (ie, the `redirect_url` passed + to the template) + * `idp`: the 'idp_id' of the chosen IDP. +* `sso_auth_account_details.html`: HTML page to prompt new users to enter a + userid and confirm other details. This is only shown if the + SSO implementation (with any `user_mapping_provider`) does not return + a localpart. + When rendering, this template is given the following variables: + * `server_name`: the homeserver's name. + * `idp`: details of the SSO Identity Provider that the user logged in + with: an object with the following attributes: + * `idp_id`: unique identifier for the IdP + * `idp_name`: user-facing name for the IdP + * `idp_icon`: if specified in the IdP config, an MXC URI for an icon + for the IdP + * `idp_brand`: if specified in the IdP config, a textual identifier + for the brand of the IdP + * `user_attributes`: an object containing details about the user that + we received from the IdP. May have the following attributes: + * display_name: the user's display_name + * emails: a list of email addresses + The template should render a form which submits the following fields: + * `username`: the localpart of the user's chosen user id +* `sso_new_user_consent.html`: HTML page allowing the user to consent to the + server's terms and conditions. This is only shown for new users, and only if + `user_consent.require_at_registration` is set. + When rendering, this template is given the following variables: + * `server_name`: the homeserver's name. + * `user_id`: the user's matrix proposed ID. + * `user_profile.display_name`: the user's proposed display name, if any. + * consent_version: the version of the terms that the user will be + shown + * `terms_url`: a link to the page showing the terms. + The template should render a form which submits the following fields: + * `accepted_version`: the version of the terms accepted by the user + (ie, 'consent_version' from the input variables). +* `sso_redirect_confirm.html`: HTML page for a confirmation step before redirecting back + to the client with the login token. + When rendering, this template is given the following variables: + * `redirect_url`: the URL the user is about to be redirected to. + * `display_url`: the same as `redirect_url`, but with the query + parameters stripped. The intention is to have a + human-readable URL to show to users, not to use it as + the final address to redirect to. + * `server_name`: the homeserver's name. + * `new_user`: a boolean indicating whether this is the user's first time + logging in. + * `user_id`: the user's matrix ID. + * `user_profile.avatar_url`: an MXC URI for the user's avatar, if any. + `None` if the user has not set an avatar. + * `user_profile.display_name`: the user's display name. `None` if the user + has not set a display name. +* `sso_auth_confirm.html`: HTML page which notifies the user that they are authenticating + to confirm an operation on their account during the user interactive authentication + process. + When rendering, this template is given the following variables: + * `redirect_url`: the URL the user is about to be redirected to. + * `description`: the operation which the user is being asked to confirm + * `idp`: details of the Identity Provider that we will use to confirm + the user's identity: an object with the following attributes: + * `idp_id`: unique identifier for the IdP + * `idp_name`: user-facing name for the IdP + * `idp_icon`: if specified in the IdP config, an MXC URI for an icon + for the IdP + * `idp_brand`: if specified in the IdP config, a textual identifier + for the brand of the IdP +* `sso_auth_success.html`: HTML page shown after a successful user interactive + authentication session. + Note that this page must include the JavaScript which notifies of a successful + authentication (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). + This template has no additional variables. +* `sso_auth_bad_user.html`: HTML page shown after a user-interactive authentication + session which does not map correctly onto the expected user. + When rendering, this template is given the following variables: + * `server_name`: the homeserver's name. + * `user_id_to_verify`: the MXID of the user that we are trying to + validate. +* `sso_account_deactivated.html`: HTML page shown during single sign-on if a deactivated + user (according to Synapse's database) attempts to login. + This template has no additional variables. +* `sso_error.html`: HTML page to display to users if something goes wrong during the + OpenID Connect authentication process. + When rendering, this template is given two variables: + * `error`: the technical name of the error + * `error_description`: a human-readable message for the error diff --git a/docs/upgrade.md b/docs/upgrade.md index 8831c9d6cf..1c459d8e2b 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -112,6 +112,17 @@ environment variable. See [using a forward proxy with Synapse documentation](setup/forward_proxy.md) for details. +## Deprecation of `template_dir` + +The `template_dir` settings in the `sso`, `account_validity` and `email` sections of the +configuration file are now deprecated. Server admins should use the new +`templates.custom_template_directory` setting in the configuration file and use one single +custom template directory for all aforementioned features. Template file names remain +unchanged. See [the related documentation](https://matrix-org.github.io/synapse/latest/templates.html) +for more information and examples. + +We plan to remove support for these settings in October 2021. + # Upgrading to v1.39.0 diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 9acce5996e..52e63ab1f6 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -78,6 +78,11 @@ class AccountValidityConfig(Config): ) # Read and store template content + custom_template_directories = ( + self.root.server.custom_template_directory, + account_validity_template_dir, + ) + ( self.account_validity_account_renewed_template, self.account_validity_account_previously_renewed_template, @@ -88,5 +93,5 @@ class AccountValidityConfig(Config): "account_previously_renewed.html", invalid_token_template_filename, ], - (td for td in (account_validity_template_dir,) if td), + (td for td in custom_template_directories if td), ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index fc74b4a8b9..4477419196 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -258,7 +258,12 @@ class EmailConfig(Config): add_threepid_template_success_html, ], ( - td for td in (template_dir,) if td + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td ), # Filter out template_dir if not provided ) @@ -299,7 +304,14 @@ class EmailConfig(Config): self.email_notif_template_text, ) = self.read_templates( [notif_template_html, notif_template_text], - (td for td in (template_dir,) if td), + ( + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td + ), # Filter out template_dir if not provided ) self.email_notif_for_new_users = email_config.get( @@ -322,7 +334,14 @@ class EmailConfig(Config): self.account_validity_template_text, ) = self.read_templates( [expiry_template_html, expiry_template_text], - (td for td in (template_dir,) if td), + ( + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td + ), # Filter out template_dir if not provided ) subjects_config = email_config.get("subjects", {}) @@ -354,6 +373,9 @@ class EmailConfig(Config): """\ # Configuration for sending emails from Synapse. # + # Server admins can configure custom templates for email content. See + # https://matrix-org.github.io/synapse/latest/templates.html for more information. + # email: # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. # @@ -430,49 +452,6 @@ class EmailConfig(Config): # #invite_client_location: https://app.element.io - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * The contents of email notifications of missed events: 'notif_mail.html' and - # 'notif_mail.txt'. - # - # * The contents of account expiry notice emails: 'notice_expiry.html' and - # 'notice_expiry.txt'. - # - # * The contents of password reset emails sent by the homeserver: - # 'password_reset.html' and 'password_reset.txt' - # - # * An HTML page that a user will see when they follow the link in the password - # reset email. The user will be asked to confirm the action before their - # password is reset: 'password_reset_confirmation.html' - # - # * HTML pages for success and failure that a user will see when they confirm - # the password reset flow using the page above: 'password_reset_success.html' - # and 'password_reset_failure.html' - # - # * The contents of address verification emails sent during registration: - # 'registration.html' and 'registration.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent during registration: - # 'registration_success.html' and 'registration_failure.html' - # - # * The contents of address verification emails sent when an address is added - # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent when an address is added - # to a Matrix account: 'add_threepid_success.html' and - # 'add_threepid_failure.html' - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # Subjects to use when sending emails from Synapse. # # The placeholder '%%(app)s' will be replaced with the value of the 'app_name' diff --git a/synapse/config/server.py b/synapse/config/server.py index 187b4301a0..8494795919 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -710,6 +710,18 @@ class ServerConfig(Config): # Turn the list into a set to improve lookup speed. self.next_link_domain_whitelist = set(next_link_domain_whitelist) + templates_config = config.get("templates") or {} + if not isinstance(templates_config, dict): + raise ConfigError("The 'templates' section must be a dictionary") + + self.custom_template_directory = templates_config.get( + "custom_template_directory" + ) + if self.custom_template_directory is not None and not isinstance( + self.custom_template_directory, str + ): + raise ConfigError("'custom_template_directory' must be a string") + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -1284,6 +1296,19 @@ class ServerConfig(Config): # all domains. # #next_link_domain_whitelist: ["matrix.org"] + + # Templates to use when generating email or HTML page contents. + # + templates: + # Directory in which Synapse will try to find template files to use to generate + # email or HTML page contents. + # If not set, or a file is not found within the template directory, a default + # template from within the Synapse package will be used. + # + # See https://matrix-org.github.io/synapse/latest/templates.html for more + # information about using custom templates. + # + #custom_template_directory: /path/to/custom/templates/ """ % locals() ) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 4b590e0535..fe1177ab81 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -45,6 +45,11 @@ class SSOConfig(Config): self.sso_template_dir = sso_config.get("template_dir") # Read templates from disk + custom_template_directories = ( + self.root.server.custom_template_directory, + self.sso_template_dir, + ) + ( self.sso_login_idp_picker_template, self.sso_redirect_confirm_template, @@ -63,7 +68,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - (td for td in (self.sso_template_dir,) if td), + (td for td in custom_template_directories if td), ) # These templates have no placeholders, so render them here @@ -94,6 +99,9 @@ class SSOConfig(Config): # Additional settings to use with single-sign on systems such as OpenID Connect, # SAML2 and CAS. # + # Server admins can configure custom templates for pages related to SSO. See + # https://matrix-org.github.io/synapse/latest/templates.html for more information. + # sso: # A list of client URLs which are whitelisted so that the user does not # have to confirm giving access to their account to the URL. Any client @@ -125,167 +133,4 @@ class SSOConfig(Config): # information when first signing in. Defaults to false. # #update_profile_information: true - - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to prompt the user to choose an Identity Provider during - # login: 'sso_login_idp_picker.html'. - # - # This is only used if multiple SSO Identity Providers are configured. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL that the user will be redirected to after - # login. - # - # * server_name: the homeserver's name. - # - # * providers: a list of available Identity Providers. Each element is - # an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # The rendered HTML page should contain a form which submits its results - # back as a GET request, with the following query parameters: - # - # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed - # to the template) - # - # * idp: the 'idp_id' of the chosen IDP. - # - # * HTML page to prompt new users to enter a userid and confirm other - # details: 'sso_auth_account_details.html'. This is only shown if the - # SSO implementation (with any user_mapping_provider) does not return - # a localpart. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * idp: details of the SSO Identity Provider that the user logged in - # with: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * user_attributes: an object containing details about the user that - # we received from the IdP. May have the following attributes: - # - # * display_name: the user's display_name - # * emails: a list of email addresses - # - # The template should render a form which submits the following fields: - # - # * username: the localpart of the user's chosen user id - # - # * HTML page allowing the user to consent to the server's terms and - # conditions. This is only shown for new users, and only if - # `user_consent.require_at_registration` is set. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * user_id: the user's matrix proposed ID. - # - # * user_profile.display_name: the user's proposed display name, if any. - # - # * consent_version: the version of the terms that the user will be - # shown - # - # * terms_url: a link to the page showing the terms. - # - # The template should render a form which submits the following fields: - # - # * accepted_version: the version of the terms accepted by the user - # (ie, 'consent_version' from the input variables). - # - # * HTML page for a confirmation step before redirecting back to the client - # with the login token: 'sso_redirect_confirm.html'. - # - # When rendering, this template is given the following variables: - # - # * redirect_url: the URL the user is about to be redirected to. - # - # * display_url: the same as `redirect_url`, but with the query - # parameters stripped. The intention is to have a - # human-readable URL to show to users, not to use it as - # the final address to redirect to. - # - # * server_name: the homeserver's name. - # - # * new_user: a boolean indicating whether this is the user's first time - # logging in. - # - # * user_id: the user's matrix ID. - # - # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. - # None if the user has not set an avatar. - # - # * user_profile.display_name: the user's display name. None if the user - # has not set a display name. - # - # * HTML page which notifies the user that they are authenticating to confirm - # an operation on their account during the user interactive authentication - # process: 'sso_auth_confirm.html'. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL the user is about to be redirected to. - # - # * description: the operation which the user is being asked to confirm - # - # * idp: details of the Identity Provider that we will use to confirm - # the user's identity: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * HTML page shown after a successful user interactive authentication session: - # 'sso_auth_success.html'. - # - # Note that this page must include the JavaScript which notifies of a successful authentication - # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). - # - # This template has no additional variables. - # - # * HTML page shown after a user-interactive authentication session which - # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. - # - # When rendering, this template is given the following variables: - # * server_name: the homeserver's name. - # * user_id_to_verify: the MXID of the user that we are trying to - # validate. - # - # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) - # attempts to login: 'sso_account_deactivated.html'. - # - # This template has no additional variables. - # - # * HTML page to display to users if something goes wrong during the - # OpenID Connect authentication process: 'sso_error.html'. - # - # When rendering, this template is given two variables: - # * error: the technical name of the error - # * error_description: a human-readable message for the error - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" """ diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 82725853bc..2f99d31c42 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -91,6 +91,7 @@ class ModuleApi: self._state = hs.get_state_handler() self._clock: Clock = hs.get_clock() self._send_email_handler = hs.get_send_email_handler() + self.custom_template_dir = hs.config.server.custom_template_directory try: app_name = self._hs.config.email_app_name @@ -679,7 +680,7 @@ class ModuleApi: """ return self._hs.config.read_templates( filenames, - (td for td in (custom_template_directory,) if td), + (td for td in (self.custom_template_dir, custom_template_directory) if td), ) diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 488b97b32e..fc62a09b7f 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -46,6 +46,8 @@ class NewUserConsentResource(DirectServeHtmlResource): self._consent_version = hs.config.consent.user_consent_version def template_search_dirs(): + if hs.config.server.custom_template_directory: + yield hs.config.server.custom_template_directory if hs.config.sso.sso_template_dir: yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index ab24ec0a8e..c15b83c387 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -74,6 +74,8 @@ class AccountDetailsResource(DirectServeHtmlResource): self._sso_handler = hs.get_sso_handler() def template_search_dirs(): + if hs.config.server.custom_template_directory: + yield hs.config.server.custom_template_directory if hs.config.sso.sso_template_dir: yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir -- cgit 1.5.1 From 84469bdac773ddb79cfc99f31bbac78d27450682 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 17 Aug 2021 14:02:50 +0100 Subject: Remove the unused public_room_list_stream (#10565) Co-authored-by: Patrick Cloke --- changelog.d/10565.misc | 1 + synapse/app/admin_cmd.py | 2 - synapse/app/generic_worker.py | 4 +- synapse/replication/slave/storage/room.py | 37 ----- synapse/replication/tcp/streams/__init__.py | 3 - synapse/replication/tcp/streams/_base.py | 25 ---- synapse/storage/databases/main/__init__.py | 4 +- synapse/storage/databases/main/room.py | 215 +++++----------------------- synapse/storage/schema/__init__.py | 7 +- 9 files changed, 48 insertions(+), 250 deletions(-) create mode 100644 changelog.d/10565.misc delete mode 100644 synapse/replication/slave/storage/room.py (limited to 'synapse') diff --git a/changelog.d/10565.misc b/changelog.d/10565.misc new file mode 100644 index 0000000000..06796b61ab --- /dev/null +++ b/changelog.d/10565.misc @@ -0,0 +1 @@ +Remove the unused public rooms replication stream. \ No newline at end of file diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 3234d9ebba..7396db93c6 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -38,7 +38,6 @@ from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore from synapse.server import HomeServer from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -58,7 +57,6 @@ class AdminCmdSlavedStore( SlavedPushRuleStore, SlavedEventStore, SlavedClientIpStore, - RoomStore, BaseSlavedStore, ): pass diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d7b425a7ab..845e6a8220 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,7 +64,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client import ( account_data, @@ -114,6 +113,7 @@ from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.databases.main.presence import PresenceStore +from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore @@ -237,7 +237,7 @@ class GenericWorkerSlavedStore( ClientIpWorkerStore, SlavedEventStore, SlavedKeyStore, - RoomStore, + RoomWorkerStore, DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py deleted file mode 100644 index 8cc6de3f46..0000000000 --- a/synapse/replication/slave/storage/room.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.replication.tcp.streams import PublicRoomsStream -from synapse.storage.database import DatabasePool -from synapse.storage.databases.main.room import RoomWorkerStore - -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker - - -class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - self._public_room_id_gen = SlavedIdTracker( - db_conn, "public_room_list_stream", "stream_id" - ) - - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == PublicRoomsStream.NAME: - self._public_room_id_gen.advance(instance_name, token) - - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 4c0023c68a..f41eabd85e 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -32,7 +32,6 @@ from synapse.replication.tcp.streams._base import ( GroupServerStream, PresenceFederationStream, PresenceStream, - PublicRoomsStream, PushersStream, PushRulesStream, ReceiptsStream, @@ -57,7 +56,6 @@ STREAMS_MAP = { PushRulesStream, PushersStream, CachesStream, - PublicRoomsStream, DeviceListsStream, ToDeviceStream, FederationStream, @@ -79,7 +77,6 @@ __all__ = [ "PushRulesStream", "PushersStream", "CachesStream", - "PublicRoomsStream", "DeviceListsStream", "ToDeviceStream", "TagAccountDataStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 3716c41bea..9b905aba9d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -447,31 +447,6 @@ class CachesStream(Stream): ) -class PublicRoomsStream(Stream): - """The public rooms list changed""" - - PublicRoomsStreamRow = namedtuple( - "PublicRoomsStreamRow", - ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional - ), - ) - - NAME = "public_rooms" - ROW_TYPE = PublicRoomsStreamRow - - def __init__(self, hs): - store = hs.get_datastore() - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_current_public_room_stream_id), - store.get_all_new_public_rooms, - ) - - class DeviceListsStream(Stream): """Either a user has updated their devices or a remote server needs to be told about a device update. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 8d9f07111d..01b918e12e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -127,9 +127,6 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._public_room_id_gen = StreamIdGenerator( - db_conn, "public_room_list_stream", "stream_id" - ) self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", @@ -170,6 +167,7 @@ class DataStore( sequence_name="cache_invalidation_stream_seq", writers=[], ) + else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 443e5f3315..c7a1c1e8d9 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -890,55 +890,6 @@ class RoomWorkerStore(SQLBaseStore): return total_media_quarantined - async def get_all_new_public_rooms( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for public rooms replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - if last_id == current_id: - return [], current_id, False - - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db_pool.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False ) -> Dict[str, dict]: @@ -1410,34 +1361,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): StoreError if the room could not be stored. """ try: - - def store_room_txn(txn, next_id): - self.db_pool.simple_insert_txn( - txn, - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - "room_version": room_version.identifier, - "has_auth_chain_index": True, - }, - ) - if is_public: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - }, - ) - - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "store_room_txn", store_room_txn, next_id - ) + await self.db_pool.simple_insert( + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + "room_version": room_version.identifier, + "has_auth_chain_index": True, + }, + desc="store_room", + ) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -1470,49 +1404,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - async def set_room_is_public(self, room_id, is_public): - def set_room_is_public_txn(txn, next_id): - self.db_pool.simple_update_one_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - updatevalues={"is_public": is_public}, - ) - - entries = self.db_pool.simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": None, - "network_id": None, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": None, - "network_id": None, - }, - ) + async def set_room_is_public(self, room_id: str, is_public: bool) -> None: + await self.db_pool.simple_update_one( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="set_room_is_public", + ) - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "set_room_is_public", set_room_is_public_txn, next_id - ) self.hs.get_notifier().on_new_replication_data() async def set_room_is_public_appservice( @@ -1533,68 +1432,33 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): list. """ - def set_room_is_public_appservice_txn(txn, next_id): - if is_public: - try: - self.db_pool.simple_insert_txn( - txn, - table="appservice_room_list", - values={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - except self.database_engine.module.IntegrityError: - # We've already inserted, nothing to do. - return - else: - self.db_pool.simple_delete_txn( - txn, - table="appservice_room_list", - keyvalues={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - - entries = self.db_pool.simple_select_list_txn( - txn, - table="public_room_list_stream", + if is_public: + await self.db_pool.simple_upsert( + table="appservice_room_list", keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, "room_id": room_id, + }, + values={}, + insertion_values={ "appservice_id": appservice_id, "network_id": network_id, + "room_id": room_id, }, - retcols=("stream_id", "visibility"), + desc="set_room_is_public_appservice_true", ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": appservice_id, - "network_id": network_id, - }, - ) - - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "set_room_is_public_appservice", - set_room_is_public_appservice_txn, - next_id, + else: + await self.db_pool.simple_delete( + table="appservice_room_list", + keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + desc="set_room_is_public_appservice_false", ) + self.hs.get_notifier().on_new_replication_data() async def add_event_report( @@ -1787,9 +1651,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "get_event_reports_paginate", _get_event_reports_paginate_txn ) - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 7e0687e197..a5bc0ee8a5 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 62 +SCHEMA_VERSION = 63 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -25,6 +25,11 @@ for more information on how this works. Changes in SCHEMA_VERSION = 61: - The `user_stats_historical` and `room_stats_historical` tables are not written and are not read (previously, they were written but not read). + +Changes in SCHEMA_VERSION = 63: + - The `public_room_list_stream` table is not written nor read to + (previously, it was written and read to, but not for any significant purpose). + https://github.com/matrix-org/synapse/pull/10565 """ -- cgit 1.5.1