From 406f7bfa171147f662c2f74d132334527f4129a9 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Tue, 12 Oct 2021 10:44:59 +0100 Subject: Add an approximate difference method to StateFilters (#10825) --- synapse/storage/state.py | 172 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 171 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 5e86befde4..b5ba1560d1 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -15,9 +15,11 @@ import logging from typing import ( TYPE_CHECKING, Awaitable, + Collection, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -29,7 +31,7 @@ from frozendict import frozendict from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad @@ -134,6 +136,23 @@ class StateFilter: include_others=True, ) + @staticmethod + def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool): + """ + Returns a (frozen) StateFilter with the same contents as the parameters + specified here, which can be made of mutable types. + """ + types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} + for state_types, state_keys in types.items(): + if state_keys is not None: + types_with_frozen_values[state_types] = frozenset(state_keys) + else: + types_with_frozen_values[state_types] = None + + return StateFilter( + frozendict(types_with_frozen_values), include_others=include_others + ) + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the @@ -356,6 +375,157 @@ class StateFilter: return member_filter, non_member_filter + def _decompose_into_four_parts( + self, + ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: + """ + Decomposes this state filter into 4 constituent parts, which can be + thought of as this: + all? - minus_wildcards + plus_wildcards + plus_state_keys + + where + * all represents ALL state + * minus_wildcards represents entire state types to remove + * plus_wildcards represents entire state types to add + * plus_state_keys represents individual state keys to add + + See `recompose_from_four_parts` for the other direction of this + correspondence. + """ + is_all = self.include_others + excluded_types: Set[str] = {t for t in self.types if is_all} + wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} + concrete_keys: Set[StateKey] = set(self.concrete_types()) + + return (is_all, excluded_types), (wildcard_types, concrete_keys) + + @staticmethod + def _recompose_from_four_parts( + all_part: bool, + minus_wildcards: Set[str], + plus_wildcards: Set[str], + plus_state_keys: Set[StateKey], + ) -> "StateFilter": + """ + Recomposes a state filter from 4 parts. + + See `decompose_into_four_parts` (the other direction of this + correspondence) for descriptions on each of the parts. + """ + + # {state type -> set of state keys OR None for wildcard} + # (The same structure as that of a StateFilter.) + new_types: Dict[str, Optional[Set[str]]] = {} + + # if we start with all, insert the excluded statetypes as empty sets + # to prevent them from being included + if all_part: + new_types.update({state_type: set() for state_type in minus_wildcards}) + + # insert the plus wildcards + new_types.update({state_type: None for state_type in plus_wildcards}) + + # insert the specific state keys + for state_type, state_key in plus_state_keys: + if state_type in new_types: + entry = new_types[state_type] + if entry is not None: + entry.add(state_key) + elif not all_part: + # don't insert if the entire type is already included by + # include_others as this would actually shrink the state allowed + # by this filter. + new_types[state_type] = {state_key} + + return StateFilter.freeze(new_types, include_others=all_part) + + def approx_difference(self, other: "StateFilter") -> "StateFilter": + """ + Returns a state filter which represents `self - other`. + + This is useful for determining what state remains to be pulled out of the + database if we want the state included by `self` but already have the state + included by `other`. + + The returned state filter + - MUST include all state events that are included by this filter (`self`) + unless they are included by `other`; + - MUST NOT include state events not included by this filter (`self`); and + - MAY be an over-approximation: the returned state filter + MAY additionally include some state events from `other`. + + This implementation attempts to return the narrowest such state filter. + In the case that `self` contains wildcards for state types where + `other` contains specific state keys, an approximation must be made: + the returned state filter keeps the wildcard, as state filters are not + able to express 'all state keys except some given examples'. + e.g. + StateFilter(m.room.member -> None (wildcard)) + minus + StateFilter(m.room.member -> {'@wombat:example.org'}) + is approximated as + StateFilter(m.room.member -> None (wildcard)) + """ + + # We first transform self and other into an alternative representation: + # - whether or not they include all events to begin with ('all') + # - if so, which event types are excluded? ('excludes') + # - which entire event types to include ('wildcards') + # - which concrete state keys to include ('concrete state keys') + (self_all, self_excludes), ( + self_wildcards, + self_concrete_keys, + ) = self._decompose_into_four_parts() + (other_all, other_excludes), ( + other_wildcards, + other_concrete_keys, + ) = other._decompose_into_four_parts() + + # Start with an estimate of the difference based on self + new_all = self_all + # Wildcards from the other can be added to the exclusion filter + new_excludes = self_excludes | other_wildcards + # We remove wildcards that appeared as wildcards in the other + new_wildcards = self_wildcards - other_wildcards + # We filter out the concrete state keys that appear in the other + # as wildcards or concrete state keys. + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in self_concrete_keys + if state_type not in other_wildcards + } - other_concrete_keys + + if other_all: + if self_all: + # If self starts with all, then we add as wildcards any + # types which appear in the other's exclusion filter (but + # aren't in the self exclusion filter). This is as the other + # filter will return everything BUT the types in its exclusion, so + # we need to add those excluded types that also match the self + # filter as wildcard types in the new filter. + new_wildcards |= other_excludes.difference(self_excludes) + + # If other is an `include_others` then the difference isn't. + new_all = False + # (We have no need for excludes when we don't start with all, as there + # is nothing to exclude.) + new_excludes = set() + + # We also filter out all state types that aren't in the exclusion + # list of the other. + new_wildcards &= other_excludes + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in new_concrete_keys + if state_type in other_excludes + } + + # Transform our newly-constructed state filter from the alternative + # representation back into the normal StateFilter representation. + return StateFilter._recompose_from_four_parts( + new_all, new_excludes, new_wildcards, new_concrete_keys + ) + class StateGroupStorage: """High level interface to fetching state for event.""" -- cgit 1.5.1 From 6b18eb443054087c4a8153b19b3cc4d3b731d324 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 12 Oct 2021 11:23:46 +0100 Subject: Fix opentracing and Prometheus metrics for replication requests (#10996) This commit fixes two bugs to do with decorators not instrumenting `ReplicationEndpoint`'s `send_request` correctly. There are two decorators on `send_request`: Prometheus' `Gauge.track_inprogress()` and Synapse's `opentracing.trace`. `Gauge.track_inprogress()` does not have any support for async functions when used as a decorator. Since async functions behave like regular functions that return coroutines, only the creation of the coroutine was covered by the metric and none of the actual body of `send_request`. `Gauge.track_inprogress()` returns a regular, non-async function wrapping `send_request`, which is the source of the next bug. The `opentracing.trace` decorator would normally handle async functions correctly, but since the wrapped `send_request` is a non-async function, the decorator ends up suffering from the same issue as `Gauge.track_inprogress()`: the opentracing span only measures the creation of the coroutine and none of the actual function body. Using `Gauge.track_inprogress()` as a context manager instead of a decorator resolves both bugs. --- changelog.d/10996.misc | 1 + synapse/logging/opentracing.py | 8 ++ synapse/replication/http/_base.py | 154 +++++++++++++++++++------------------- 3 files changed, 87 insertions(+), 76 deletions(-) create mode 100644 changelog.d/10996.misc (limited to 'synapse') diff --git a/changelog.d/10996.misc b/changelog.d/10996.misc new file mode 100644 index 0000000000..c830d7ec2c --- /dev/null +++ b/changelog.d/10996.misc @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.21.0 that causes opentracing and Prometheus metrics for replication requests to be measured incorrectly. diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 5276c4bfcc..20d23a4260 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -807,6 +807,14 @@ def trace(func=None, opname=None): result.addCallbacks(call_back, err_back) else: + if inspect.isawaitable(result): + logger.error( + "@trace may not have wrapped %s correctly! " + "The function is not async but returned a %s.", + func.__qualname__, + type(result).__name__, + ) + scope.__exit__(None, None, None) return result diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index f1b78d09f9..e047ec74d8 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -182,85 +182,87 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): ) @trace(opname="outgoing_replication_request") - @outgoing_gauge.track_inprogress() async def send_request(*, instance_name="master", **kwargs): - if instance_name == local_instance_name: - raise Exception("Trying to send HTTP request to self") - if instance_name == "master": - host = master_host - port = master_port - elif instance_name in instance_map: - host = instance_map[instance_name].host - port = instance_map[instance_name].port - else: - raise Exception( - "Instance %r not in 'instance_map' config" % (instance_name,) + with outgoing_gauge.track_inprogress(): + if instance_name == local_instance_name: + raise Exception("Trying to send HTTP request to self") + if instance_name == "master": + host = master_host + port = master_port + elif instance_name in instance_map: + host = instance_map[instance_name].host + port = instance_map[instance_name].port + else: + raise Exception( + "Instance %r not in 'instance_map' config" % (instance_name,) + ) + + data = await cls._serialize_payload(**kwargs) + + url_args = [ + urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS + ] + + if cls.CACHE: + txn_id = random_string(10) + url_args.append(txn_id) + + if cls.METHOD == "POST": + request_func = client.post_json_get_json + elif cls.METHOD == "PUT": + request_func = client.put_json + elif cls.METHOD == "GET": + request_func = client.get_json + else: + # We have already asserted in the constructor that a + # compatible was picked, but lets be paranoid. + raise Exception( + "Unknown METHOD on %s replication endpoint" % (cls.NAME,) + ) + + uri = "http://%s:%s/_synapse/replication/%s/%s" % ( + host, + port, + cls.NAME, + "/".join(url_args), ) - data = await cls._serialize_payload(**kwargs) - - url_args = [ - urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS - ] - - if cls.CACHE: - txn_id = random_string(10) - url_args.append(txn_id) - - if cls.METHOD == "POST": - request_func = client.post_json_get_json - elif cls.METHOD == "PUT": - request_func = client.put_json - elif cls.METHOD == "GET": - request_func = client.get_json - else: - # We have already asserted in the constructor that a - # compatible was picked, but lets be paranoid. - raise Exception( - "Unknown METHOD on %s replication endpoint" % (cls.NAME,) - ) - - uri = "http://%s:%s/_synapse/replication/%s/%s" % ( - host, - port, - cls.NAME, - "/".join(url_args), - ) - - try: - # We keep retrying the same request for timeouts. This is so that we - # have a good idea that the request has either succeeded or failed on - # the master, and so whether we should clean up or not. - while True: - headers: Dict[bytes, List[bytes]] = {} - # Add an authorization header, if configured. - if replication_secret: - headers[b"Authorization"] = [b"Bearer " + replication_secret] - opentracing.inject_header_dict(headers, check_destination=False) - try: - result = await request_func(uri, data, headers=headers) - break - except RequestTimedOutError: - if not cls.RETRY_ON_TIMEOUT: - raise - - logger.warning("%s request timed out; retrying", cls.NAME) - - # If we timed out we probably don't need to worry about backing - # off too much, but lets just wait a little anyway. - await clock.sleep(1) - except HttpResponseException as e: - # We convert to SynapseError as we know that it was a SynapseError - # on the main process that we should send to the client. (And - # importantly, not stack traces everywhere) - _outgoing_request_counter.labels(cls.NAME, e.code).inc() - raise e.to_synapse_error() - except Exception as e: - _outgoing_request_counter.labels(cls.NAME, "ERR").inc() - raise SynapseError(502, "Failed to talk to main process") from e - - _outgoing_request_counter.labels(cls.NAME, 200).inc() - return result + try: + # We keep retrying the same request for timeouts. This is so that we + # have a good idea that the request has either succeeded or failed + # on the master, and so whether we should clean up or not. + while True: + headers: Dict[bytes, List[bytes]] = {} + # Add an authorization header, if configured. + if replication_secret: + headers[b"Authorization"] = [ + b"Bearer " + replication_secret + ] + opentracing.inject_header_dict(headers, check_destination=False) + try: + result = await request_func(uri, data, headers=headers) + break + except RequestTimedOutError: + if not cls.RETRY_ON_TIMEOUT: + raise + + logger.warning("%s request timed out; retrying", cls.NAME) + + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. + await clock.sleep(1) + except HttpResponseException as e: + # We convert to SynapseError as we know that it was a SynapseError + # on the main process that we should send to the client. (And + # importantly, not stack traces everywhere) + _outgoing_request_counter.labels(cls.NAME, e.code).inc() + raise e.to_synapse_error() + except Exception as e: + _outgoing_request_counter.labels(cls.NAME, "ERR").inc() + raise SynapseError(502, "Failed to talk to main process") from e + + _outgoing_request_counter.labels(cls.NAME, 200).inc() + return result return send_request -- cgit 1.5.1 From b8b905c4ea8a0d922d34d469f7d220f53def1b53 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 12 Oct 2021 11:24:05 +0100 Subject: Fix inconsistent behavior of `get_last_client_by_ip` (#10970) Make `get_last_client_by_ip` return the same dictionary structure regardless of whether the data has been persisted to the database. This change will allow slightly cleaner type hints to be applied later on. --- changelog.d/10970.misc | 1 + synapse/storage/databases/main/client_ips.py | 13 ++++++--- tests/storage/test_client_ips.py | 43 ++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10970.misc (limited to 'synapse') diff --git a/changelog.d/10970.misc b/changelog.d/10970.misc new file mode 100644 index 0000000000..bb75ea79a6 --- /dev/null +++ b/changelog.d/10970.misc @@ -0,0 +1 @@ +Fix inconsistent behavior of `get_last_client_by_ip` when reporting data that has not been stored in the database yet. diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c77acc7c84..6c1ef09049 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -538,15 +538,20 @@ class ClientIpStore(ClientIpWorkerStore): """ ret = await super().get_last_client_ip_by_device(user_id, device_id) - # Update what is retrieved from the database with data which is pending insertion. + # Update what is retrieved from the database with data which is pending + # insertion, as if it has already been stored in the database. for key in self._batch_row_update: - uid, access_token, ip = key + uid, _access_token, ip = key if uid == user_id: user_agent, did, last_seen = self._batch_row_update[key] + + if did is None: + # These updates don't make it to the `devices` table + continue + if not device_id or did == device_id: - ret[(user_id, device_id)] = { + ret[(user_id, did)] = { "user_id": user_id, - "access_token": access_token, "ip": ip, "user_agent": user_agent, "device_id": did, diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index dada4f98c9..0e4013ebea 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -146,6 +146,49 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ], ) + @parameterized.expand([(False,), (True,)]) + def test_get_last_client_ip_by_device(self, after_persisting: bool): + """Test `get_last_client_ip_by_device` for persisted and unpersisted data""" + self.reactor.advance(12345678) + + user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success( + self.store.store_device( + user_id, + device_id, + "display name", + ) + ) + self.get_success( + self.store.insert_client_ip( + user_id, "access_token", "ip", "user_agent", device_id + ) + ) + + if after_persisting: + # Trigger the storage loop + self.reactor.advance(10) + + result = self.get_success( + self.store.get_last_client_ip_by_device(user_id, device_id) + ) + + self.assertEqual( + result, + { + (user_id, device_id): { + "user_id": user_id, + "device_id": device_id, + "ip": "ip", + "user_agent": "user_agent", + "last_seen": 12345678000, + }, + }, + ) + @parameterized.expand([(False,), (True,)]) def test_get_user_ip_and_agents(self, after_persisting: bool): """Test `get_user_ip_and_agents` for persisted and unpersisted data""" -- cgit 1.5.1 From 36224e056a0ba91b4541607c5ad5cd5152d0e672 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 12 Oct 2021 13:50:34 +0100 Subject: Add type hints to `synapse.storage.databases.main.client_ips` (#10972) --- changelog.d/10972.misc | 1 + mypy.ini | 4 + synapse/handlers/device.py | 15 ++- synapse/module_api/__init__.py | 6 +- synapse/storage/databases/main/client_ips.py | 140 +++++++++++++++++++-------- 5 files changed, 121 insertions(+), 45 deletions(-) create mode 100644 changelog.d/10972.misc (limited to 'synapse') diff --git a/changelog.d/10972.misc b/changelog.d/10972.misc new file mode 100644 index 0000000000..f66a7beaf0 --- /dev/null +++ b/changelog.d/10972.misc @@ -0,0 +1 @@ +Add type hints to `synapse.storage.databases.main.client_ips`. diff --git a/mypy.ini b/mypy.ini index a7019e2bd4..174a6edae6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -53,6 +53,7 @@ files = synapse/storage/_base.py, synapse/storage/background_updates.py, synapse/storage/databases/main/appservice.py, + synapse/storage/databases/main/client_ips.py, synapse/storage/databases/main/events.py, synapse/storage/databases/main/keys.py, synapse/storage/databases/main/pusher.py, @@ -108,6 +109,9 @@ disallow_untyped_defs = True [mypy-synapse.state.*] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.client_ips] +disallow_untyped_defs = True + [mypy-synapse.storage.util.*] disallow_untyped_defs = True diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 75e6019760..6eafbea25d 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -14,7 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, +) from synapse.api import errors from synapse.api.constants import EventTypes @@ -595,7 +606,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips( - device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] + device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 8ae21bc43c..b2a228c231 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -773,9 +773,9 @@ class ModuleApi: # Sanitize some of the data. We don't want to return tokens. return [ UserIpAndAgent( - ip=str(data["ip"]), - user_agent=str(data["user_agent"]), - last_seen=int(data["last_seen"]), + ip=data["ip"], + user_agent=data["user_agent"], + last_seen=data["last_seen"], ) for data in raw_data ] diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 6c1ef09049..b81d9218ce 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -13,14 +13,26 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast + +from typing_extensions import TypedDict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.types import UserID +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore +from synapse.storage.types import Connection +from synapse.types import JsonDict, UserID from synapse.util.caches.lrucache import LruCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller @@ -29,8 +41,31 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 120 * 1000 +class DeviceLastConnectionInfo(TypedDict): + """Metadata for the last connection seen for a user and device combination""" + + # These types must match the columns in the `devices` table + user_id: str + device_id: str + + ip: Optional[str] + user_agent: Optional[str] + last_seen: Optional[int] + + +class LastConnectionInfo(TypedDict): + """Metadata for the last connection seen for an access token and IP combination""" + + # These types must match the columns in the `user_ips` table + access_token: str + ip: str + + user_agent: str + last_seen: int + + class ClientIpBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): "devices_last_seen", self._devices_last_seen_update ) - async def _remove_user_ip_nonunique(self, progress, batch_size): - def f(conn): + async def _remove_user_ip_nonunique( + self, progress: JsonDict, batch_size: int + ) -> int: + def f(conn: LoggingDatabaseConnection) -> None: txn = conn.cursor() txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() @@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): ) return 1 - async def _analyze_user_ip(self, progress, batch_size): + async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int: # Background update to analyze user_ips table before we run the # deduplication background update. The table may not have been analyzed # for ages due to the table locks. # # This will lock out the naive upserts to user_ips while it happens, but # the analyze should be quick (28GB table takes ~10s) - def user_ips_analyze(txn): + def user_ips_analyze(txn: LoggingTransaction) -> None: txn.execute("ANALYZE user_ips") await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze) @@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return 1 - async def _remove_user_ip_dupes(self, progress, batch_size): + async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int: # This works function works by scanning the user_ips table in batches # based on `last_seen`. For each row in a batch it searches the rest of # the table to see if there are any duplicates, if there are then they # are removed and replaced with a suitable row. # Fetch the start of the batch - begin_last_seen = progress.get("last_seen", 0) + begin_last_seen: int = progress.get("last_seen", 0) - def get_last_seen(txn): + def get_last_seen(txn: LoggingTransaction) -> Optional[int]: txn.execute( """ SELECT last_seen FROM user_ips @@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): """, (begin_last_seen, batch_size), ) - row = txn.fetchone() + row = cast(Optional[Tuple[int]], txn.fetchone()) if row: return row[0] else: @@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): end_last_seen, ) - def remove(txn): + def remove(txn: LoggingTransaction) -> None: # This works by looking at all entries in the given time span, and # then for each (user_id, access_token, ip) tuple in that range # checking for any duplicates in the rest of the table (via a join). @@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # Define the search space, which requires handling the last batch in # a different way + args: Tuple[int, ...] if last: clause = "? <= last_seen" args = (begin_last_seen,) else: + assert end_last_seen is not None clause = "? <= last_seen AND last_seen < ?" args = (begin_last_seen, end_last_seen) @@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): ), args, ) - res = txn.fetchall() + res = cast( + List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall() + ) # We've got some duplicates for i in res: @@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return batch_size - async def _devices_last_seen_update(self, progress, batch_size): + async def _devices_last_seen_update( + self, progress: JsonDict, batch_size: int + ) -> int: """Background update to insert last seen info into devices table""" - last_user_id = progress.get("last_user_id", "") - last_device_id = progress.get("last_device_id", "") + last_user_id: str = progress.get("last_user_id", "") + last_device_id: str = progress.get("last_device_id", "") - def _devices_last_seen_update_txn(txn): + def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int: # This consists of two queries: # # 1. The sub-query searches for the next N devices and joins @@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): # we'll just end up updating the same device row multiple # times, which is fine. + where_args: List[Union[str, int]] where_clause, where_args = make_tuple_comparison_clause( [("user_id", last_user_id), ("device_id", last_device_id)], ) @@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): } txn.execute(sql, where_args + [batch_size]) - rows = txn.fetchall() + rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) if not rows: return 0 @@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.server.user_ips_max_age @@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) @wrap_as_background_process("prune_old_user_ips") - async def _prune_old_user_ips(self): + async def _prune_old_user_ips(self) -> None: """Removes entries in user IPs older than the configured period.""" if self.user_ips_max_age is None: @@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): ) """ - timestamp = self.clock.time_msec() - self.user_ips_max_age + timestamp = self._clock.time_msec() - self.user_ips_max_age - def _prune_old_user_ips_txn(txn): + def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None: txn.execute(sql, (timestamp,)) await self.db_pool.runInteraction( @@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): async def get_last_client_ip_by_device( self, user_id: str, device_id: Optional[str] - ) -> Dict[Tuple[str, str], dict]: + ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on. The result might be slightly out of date as client IPs are inserted in batches. @@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = await self.db_pool.simple_select_list( - table="devices", - keyvalues=keyvalues, - retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + res = cast( + List[DeviceLastConnectionInfo], + await self.db_pool.simple_select_list( + table="devices", + keyvalues=keyvalues, + retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), + ), ) return {(d["user_id"], d["device_id"]): d for d in res} -class ClientIpStore(ClientIpWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): +class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): - self.client_ip_last_seen = LruCache( + # (user_id, access_token, ip,) -> last_seen + self.client_ip_last_seen = LruCache[Tuple[str, str, str], int]( cache_name="client_ip_last_seen", max_size=50000 ) super().__init__(database, db_conn, hs) # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen) - self._batch_row_update = {} + self._batch_row_update: Dict[ + Tuple[str, str, str], Tuple[str, Optional[str], int] + ] = {} self._client_ip_looper = self._clock.looping_call( self._update_client_ips_batch, 5 * 1000 @@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore): ) async def insert_client_ip( - self, user_id, access_token, ip, user_agent, device_id, now=None - ): + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: Optional[str], + now: Optional[int] = None, + ) -> None: if not now: now = int(self._clock.time_msec()) key = (user_id, access_token, ip) @@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore): "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) - def _update_client_ips_batch_txn(self, txn, to_update): + def _update_client_ips_batch_txn( + self, + txn: LoggingTransaction, + to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]], + ) -> None: if "user_ips" in self.db_pool._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): @@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore): async def get_last_client_ip_by_device( self, user_id: str, device_id: Optional[str] - ) -> Dict[Tuple[str, str], dict]: + ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]: """For each device_id listed, give the user_ip it was last seen on Args: @@ -561,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore): async def get_user_ip_and_agents( self, user: UserID, since_ts: int = 0 - ) -> List[Dict[str, Union[str, int]]]: + ) -> List[LastConnectionInfo]: """ Fetch IP/User Agent connection since a given timestamp. """ user_id = user.to_string() - results = {} + results: Dict[Tuple[str, str], Tuple[str, int]] = {} for key in self._batch_row_update: ( @@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore): if last_seen >= since_ts: results[(access_token, ip)] = (user_agent, last_seen) - def get_recent(txn): + def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: txn.execute( """ SELECT access_token, ip, user_agent, last_seen FROM user_ips @@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore): """, (since_ts, user_id), ) - return txn.fetchall() + return cast(List[Tuple[str, str, str, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( desc="get_user_ip_and_agents", func=get_recent -- cgit 1.5.1 From 8eaffe013cd37cdf9ec34875fc13d9b1249919e7 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Tue, 12 Oct 2021 18:19:21 +0100 Subject: Update `_wrap_in_base_path` type hints to preserve function arguments (#11055) --- changelog.d/11055.misc | 1 + synapse/rest/media/v1/filepath.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) create mode 100644 changelog.d/11055.misc (limited to 'synapse') diff --git a/changelog.d/11055.misc b/changelog.d/11055.misc new file mode 100644 index 0000000000..27688c3214 --- /dev/null +++ b/changelog.d/11055.misc @@ -0,0 +1 @@ +Improve type hints for `_wrap_in_base_path` decorator used by `MediaFilePaths`. diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 08bd85f664..eb66b749a2 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -16,12 +16,15 @@ import functools import os import re -from typing import Any, Callable, List +from typing import Any, Callable, List, TypeVar, cast NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") -def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]: +F = TypeVar("F", bound=Callable[..., str]) + + +def _wrap_in_base_path(func: F) -> F: """Takes a function that returns a relative path and turns it into an absolute path based on the location of the primary media store """ @@ -31,7 +34,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]: path = func(self, *args, **kwargs) return os.path.join(self.base_path, path) - return _wrapped + return cast(F, _wrapped) class MediaFilePaths: -- cgit 1.5.1 From 2a2b189130c6ad040d00548d3dfc7030a9618a57 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Wed, 13 Oct 2021 08:42:41 +0100 Subject: Mark Module API error imports as re-exported and mark Synapse as containing type annotations (#11054) --- MANIFEST.in | 1 + changelog.d/11054.misc | 1 + synapse/module_api/errors.py | 11 +++++++++-- synapse/py.typed | 0 4 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11054.misc create mode 100644 synapse/py.typed (limited to 'synapse') diff --git a/MANIFEST.in b/MANIFEST.in index 44d5cc7618..c24786c3b3 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,6 +8,7 @@ include demo/demo.tls.dh include demo/*.py include demo/*.sh +include synapse/py.typed recursive-include synapse/storage *.sql recursive-include synapse/storage *.sql.postgres recursive-include synapse/storage *.sql.sqlite diff --git a/changelog.d/11054.misc b/changelog.d/11054.misc new file mode 100644 index 0000000000..1103368fec --- /dev/null +++ b/changelog.d/11054.misc @@ -0,0 +1 @@ +Mark the Synapse package as containing type annotations and fix export declarations so that Synapse pluggable modules may be type checked against Synapse. diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py index 98ea911a81..1db900e41f 100644 --- a/synapse/module_api/errors.py +++ b/synapse/module_api/errors.py @@ -14,9 +14,16 @@ """Exception types which are exposed as part of the stable module API""" -from synapse.api.errors import ( # noqa: F401 +from synapse.api.errors import ( InvalidClientCredentialsError, RedirectException, SynapseError, ) -from synapse.config._base import ConfigError # noqa: F401 +from synapse.config._base import ConfigError + +__all__ = [ + "InvalidClientCredentialsError", + "RedirectException", + "SynapseError", + "ConfigError", +] diff --git a/synapse/py.typed b/synapse/py.typed new file mode 100644 index 0000000000..e69de29bb2 -- cgit 1.5.1 From 732bbf6737813b75e0cf9a255cae73f529c981ec Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Oct 2021 07:00:07 -0400 Subject: Be more lenient when parsing the version for oEmbed responses. (#11065) --- changelog.d/11065.misc | 1 + mypy.ini | 1 + synapse/rest/media/v1/oembed.py | 13 ++++--- synapse/rest/media/v1/preview_url_resource.py | 2 +- tests/rest/media/v1/test_oembed.py | 51 +++++++++++++++++++++++++++ 5 files changed, 60 insertions(+), 8 deletions(-) create mode 100644 changelog.d/11065.misc create mode 100644 tests/rest/media/v1/test_oembed.py (limited to 'synapse') diff --git a/changelog.d/11065.misc b/changelog.d/11065.misc new file mode 100644 index 0000000000..c6f37fc52b --- /dev/null +++ b/changelog.d/11065.misc @@ -0,0 +1 @@ +Be more lenient when parsing oEmbed response versions. diff --git a/mypy.ini b/mypy.ini index 22768a037d..93757cd95d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -90,6 +90,7 @@ files = tests/rest/client/test_login.py, tests/rest/client/test_auth.py, tests/rest/media/v1/test_filepath.py, + tests/rest/media/v1/test_oembed.py, tests/storage/test_state.py, tests/storage/test_user_directory.py, tests/util/test_itertools.py, diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index 78b1603f19..2a59552c20 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, List, Optional import attr -from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict from synapse.util import json_decoder @@ -48,7 +47,7 @@ class OEmbedProvider: requesting/parsing oEmbed content. """ - def __init__(self, hs: "HomeServer", client: SimpleHttpClient): + def __init__(self, hs: "HomeServer"): self._oembed_patterns = {} for oembed_endpoint in hs.config.oembed.oembed_patterns: api_endpoint = oembed_endpoint.api_endpoint @@ -69,7 +68,6 @@ class OEmbedProvider: # Iterate through each URL pattern and point it to the endpoint. for pattern in oembed_endpoint.url_patterns: self._oembed_patterns[pattern] = api_endpoint - self._client = client def get_oembed_url(self, url: str) -> Optional[str]: """ @@ -139,10 +137,11 @@ class OEmbedProvider: # oEmbed responses *must* be UTF-8 according to the spec. oembed = json_decoder.decode(raw_body.decode("utf-8")) - # Ensure there's a version of 1.0. - oembed_version = oembed["version"] - if oembed_version != "1.0": - raise RuntimeError(f"Invalid version: {oembed_version}") + # The version is a required string field, but not always provided, + # or sometimes provided as a float. Be lenient. + oembed_version = oembed.get("version", "1.0") + if oembed_version != "1.0" and oembed_version != 1: + raise RuntimeError(f"Invalid oEmbed version: {oembed_version}") # Ensure the cache age is None or an int. cache_age = oembed.get("cache_age") diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 1fe0fc8aa9..5bddd21ef1 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -140,7 +140,7 @@ class PreviewUrlResource(DirectServeJsonResource): self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage - self._oembed = OEmbedProvider(hs, self.client) + self._oembed = OEmbedProvider(hs) # We run the background jobs if we're the instance specified (or no # instance is specified, where we assume there is only one instance diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py new file mode 100644 index 0000000000..048d0ca44a --- /dev/null +++ b/tests/rest/media/v1/test_oembed.py @@ -0,0 +1,51 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.media.v1.oembed import OEmbedProvider +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class OEmbedTests(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + self.oembed = OEmbedProvider(homeserver) + + def parse_response(self, response: JsonDict): + return self.oembed.parse_oembed_response( + "https://test", json.dumps(response).encode("utf-8") + ) + + def test_version(self): + """Accept versions that are similar to 1.0 as a string or int (or missing).""" + for version in ("1.0", 1.0, 1): + result = self.parse_response({"version": version, "type": "link"}) + # An empty Open Graph response is an error, ensure the URL is included. + self.assertIn("og:url", result.open_graph_result) + + # A missing version should be treated as 1.0. + result = self.parse_response({"type": "link"}) + self.assertIn("og:url", result.open_graph_result) + + # Invalid versions should be rejected. + for version in ("2.0", "1", 1.1, 0, None, {}, []): + result = self.parse_response({"version": version, "type": "link"}) + # An empty Open Graph response is an error, ensure the URL is included. + self.assertEqual({}, result.open_graph_result) -- cgit 1.5.1 From cdd308845ba22fef22a39ed5bf904b438e48b491 Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Wed, 13 Oct 2021 12:21:52 +0100 Subject: Port the Password Auth Providers module interface to the new generic interface (#10548) Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> Co-authored-by: Brendan Abolivier --- changelog.d/10548.feature | 1 + docs/SUMMARY.md | 1 + docs/modules/password_auth_provider_callbacks.md | 153 +++++++ docs/modules/porting_legacy_module.md | 3 + docs/password_auth_providers.md | 6 + docs/sample_config.yaml | 28 -- synapse/app/_base.py | 2 + synapse/config/password_auth_providers.py | 53 +-- synapse/handlers/auth.py | 528 +++++++++++++++++------ synapse/module_api/__init__.py | 9 + synapse/server.py | 6 +- synapse/storage/prepare_database.py | 2 + tests/handlers/test_password_providers.py | 223 ++++++++-- 13 files changed, 790 insertions(+), 225 deletions(-) create mode 100644 changelog.d/10548.feature create mode 100644 docs/modules/password_auth_provider_callbacks.md (limited to 'synapse') diff --git a/changelog.d/10548.feature b/changelog.d/10548.feature new file mode 100644 index 0000000000..263a811faf --- /dev/null +++ b/changelog.d/10548.feature @@ -0,0 +1 @@ +Port the Password Auth Providers module interface to the new generic interface. \ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index bdb44543b8..35412ea92c 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -43,6 +43,7 @@ - [Third-party rules callbacks](modules/third_party_rules_callbacks.md) - [Presence router callbacks](modules/presence_router_callbacks.md) - [Account validity callbacks](modules/account_validity_callbacks.md) + - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md) - [Porting a legacy module to the new interface](modules/porting_legacy_module.md) - [Workers](workers.md) - [Using `synctl` with Workers](synctl_workers.md) diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md new file mode 100644 index 0000000000..36417dd39e --- /dev/null +++ b/docs/modules/password_auth_provider_callbacks.md @@ -0,0 +1,153 @@ +# Password auth provider callbacks + +Password auth providers offer a way for server administrators to integrate +their Synapse installation with an external authentication system. The callbacks can be +registered by using the Module API's `register_password_auth_provider_callbacks` method. + +## Callbacks + +### `auth_checkers` + +``` + auth_checkers: Dict[Tuple[str,Tuple], Callable] +``` + +A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a +tuple of field names (such as `("password", "secret_thing")`) to authentication checking +callbacks, which should be of the following form: + +```python +async def check_auth( + user: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", +) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] + ] +] +``` + +The login type and field names should be provided by the user in the +request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types) +defines some types, however user defined ones are also allowed. + +The callback is passed the `user` field provided by the client (which might not be in +`@username:server` form), the login type, and a dictionary of login secrets passed by +the client. + +If the authentication is successful, the module must return the user's Matrix ID (e.g. +`@alice:example.com`) and optionally a callback to be called with the response to the +`/login` request. If the module doesn't wish to return a callback, it must return `None` +instead. + +If the authentication is unsuccessful, the module must return `None`. + +### `check_3pid_auth` + +```python +async def check_3pid_auth( + medium: str, + address: str, + password: str, +) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]] + ] +] +``` + +Called when a user attempts to register or log in with a third party identifier, +such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`) +and the user's password. + +If the authentication is successful, the module must return the user's Matrix ID (e.g. +`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request. +If the module doesn't wish to return a callback, it must return None instead. + +If the authentication is unsuccessful, the module must return None. + +### `on_logged_out` + +```python +async def on_logged_out( + user_id: str, + device_id: Optional[str], + access_token: str +) -> None +``` +Called during a logout request for a user. It is passed the qualified user ID, the ID of the +deactivated device (if any: access tokens are occasionally created without an associated +device ID), and the (now deactivated) access token. + +## Example + +The example module below implements authentication checkers for two different login types: +- `my.login.type` + - Expects a `my_field` field to be sent to `/login` + - Is checked by the method: `self.check_my_login` +- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) + - Expects a `password` field to be sent to `/login` + - Is checked by the method: `self.check_pass` + + +```python +from typing import Awaitable, Callable, Optional, Tuple + +import synapse +from synapse import module_api + + +class MyAuthProvider: + def __init__(self, config: dict, api: module_api): + + self.api = api + + self.credentials = { + "bob": "building", + "@scoop:matrix.org": "digging", + } + + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("my.login_type", ("my_field",)): self.check_my_login, + ("m.login.password", ("password",)): self.check_pass, + }, + ) + + async def check_my_login( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + if login_type != "my.login_type": + return None + + if self.credentials.get(username) == login_dict.get("my_field"): + return self.api.get_qualified_user_id(username) + + async def check_pass( + self, + username: str, + login_type: str, + login_dict: "synapse.module_api.JsonDict", + ) -> Optional[ + Tuple[ + str, + Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]], + ] + ]: + if login_type != "m.login.password": + return None + + if self.credentials.get(username) == login_dict.get("password"): + return self.api.get_qualified_user_id(username) +``` diff --git a/docs/modules/porting_legacy_module.md b/docs/modules/porting_legacy_module.md index a7a251e535..89084eb7b3 100644 --- a/docs/modules/porting_legacy_module.md +++ b/docs/modules/porting_legacy_module.md @@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for more info). +There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any +changes to the database should now be made by the module using the module API class. + The module's author should also update any example in the module's configuration to only use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules) for more info). diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md index d2cdb9b2f4..d7beacfff3 100644 --- a/docs/password_auth_providers.md +++ b/docs/password_auth_providers.md @@ -1,3 +1,9 @@ +

+This page of the Synapse documentation is now deprecated. For up to date +documentation on setting up or writing a password auth provider module, please see +this page. +

+ # Password auth provider modules Password auth providers offer a way for server administrators to diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 166cec38d3..7bfaed483b 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2260,34 +2260,6 @@ email: #email_validation: "[%(server_name)s] Validate your email" -# Password providers allow homeserver administrators to integrate -# their Synapse installation with existing authentication methods -# ex. LDAP, external tokens, etc. -# -# For more information and known implementations, please see -# https://matrix-org.github.io/synapse/latest/password_auth_providers.html -# -# Note: instances wishing to use SAML or CAS authentication should -# instead use the `saml2_config` or `cas_config` options, -# respectively. -# -password_providers: -# # Example config for an LDAP auth provider -# - module: "ldap_auth_provider.LdapAuthProvider" -# config: -# enabled: true -# uri: "ldap://ldap.example.com:389" -# start_tls: true -# base: "ou=users,dc=example,dc=com" -# attributes: -# uid: "cn" -# mail: "email" -# name: "givenName" -# #bind_dn: -# #bind_password: -# #filter: "(objectClass=posixAccount)" - - ## Push ## diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 4a204a5823..bb4d53d778 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -42,6 +42,7 @@ from synapse.crypto import context_factory from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules +from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.logging.context import PreserveLoggingContext from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats @@ -379,6 +380,7 @@ async def start(hs: "HomeServer"): load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) load_legacy_presence_router(hs) + load_legacy_password_auth_providers(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 83994df798..f980102b45 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config): section = "authproviders" def read_config(self, config, **kwargs): + """Parses the old password auth providers config. The config format looks like this: + + password_providers: + # Example config for an LDAP auth provider + - module: "ldap_auth_provider.LdapAuthProvider" + config: + enabled: true + uri: "ldap://ldap.example.com:389" + start_tls: true + base: "ou=users,dc=example,dc=com" + attributes: + uid: "cn" + mail: "email" + name: "givenName" + #bind_dn: + #bind_password: + #filter: "(objectClass=posixAccount)" + + We expect admins to use modules for this feature (which is why it doesn't appear + in the sample config file), but we want to keep support for it around for a bit + for backwards compatibility. + """ + self.password_providers: List[Tuple[Type, Any]] = [] providers = [] @@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config): ) self.password_providers.append((provider_class, provider_config)) - - def generate_config_section(self, **kwargs): - return """\ - # Password providers allow homeserver administrators to integrate - # their Synapse installation with existing authentication methods - # ex. LDAP, external tokens, etc. - # - # For more information and known implementations, please see - # https://matrix-org.github.io/synapse/latest/password_auth_providers.html - # - # Note: instances wishing to use SAML or CAS authentication should - # instead use the `saml2_config` or `cas_config` options, - # respectively. - # - password_providers: - # # Example config for an LDAP auth provider - # - module: "ldap_auth_provider.LdapAuthProvider" - # config: - # enabled: true - # uri: "ldap://ldap.example.com:389" - # start_tls: true - # base: "ou=users,dc=example,dc=com" - # attributes: - # uid: "cn" - # mail: "email" - # name: "givenName" - # #bind_dn: - # #bind_password: - # #filter: "(objectClass=posixAccount)" - """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f4612a5b92..ebe75a9e9b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -200,46 +200,13 @@ class AuthHandler: self.bcrypt_rounds = hs.config.registration.bcrypt_rounds - # we can't use hs.get_module_api() here, because to do so will create an - # import loop. - # - # TODO: refactor this class to separate the lower-level stuff that - # ModuleApi can use from the higher-level stuff that uses ModuleApi, as - # better way to break the loop - account_handler = ModuleApi(hs, self) - - self.password_providers = [ - PasswordProvider.load(module, config, account_handler) - for module, config in hs.config.authproviders.password_providers - ] - - logger.info("Extra password_providers: %s", self.password_providers) + self.password_auth_provider = hs.get_password_auth_provider() self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.auth.password_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled - # start out by assuming PASSWORD is enabled; we will remove it later if not. - login_types = set() - if self._password_localdb_enabled: - login_types.add(LoginType.PASSWORD) - - for provider in self.password_providers: - login_types.update(provider.get_supported_login_types().keys()) - - if not self._password_enabled: - login_types.discard(LoginType.PASSWORD) - - # Some clients just pick the first type in the list. In this case, we want - # them to use PASSWORD (rather than token or whatever), so we want to make sure - # that comes first, where it's present. - self._supported_login_types = [] - if LoginType.PASSWORD in login_types: - self._supported_login_types.append(LoginType.PASSWORD) - login_types.remove(LoginType.PASSWORD) - self._supported_login_types.extend(login_types) - # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( @@ -427,11 +394,10 @@ class AuthHandler: ui_auth_types.add(LoginType.PASSWORD) # also allow auth from password providers - for provider in self.password_providers: - for t in provider.get_supported_login_types().keys(): - if t == LoginType.PASSWORD and not self._password_enabled: - continue - ui_auth_types.add(t) + for t in self.password_auth_provider.get_supported_login_types().keys(): + if t == LoginType.PASSWORD and not self._password_enabled: + continue + ui_auth_types.add(t) # if sso is enabled, allow the user to log in via SSO iff they have a mapping # from sso to mxid. @@ -1038,7 +1004,25 @@ class AuthHandler: Returns: login types """ - return self._supported_login_types + # Load any login types registered by modules + # This is stored in the password_auth_provider so this doesn't trigger + # any callbacks + types = list(self.password_auth_provider.get_supported_login_types().keys()) + + # This list should include PASSWORD if (either _password_localdb_enabled is + # true or if one of the modules registered it) AND _password_enabled is true + # Also: + # Some clients just pick the first type in the list. In this case, we want + # them to use PASSWORD (rather than token or whatever), so we want to make sure + # that comes first, where it's present. + if LoginType.PASSWORD in types: + types.remove(LoginType.PASSWORD) + if self._password_enabled: + types.insert(0, LoginType.PASSWORD) + elif self._password_localdb_enabled and self._password_enabled: + types.insert(0, LoginType.PASSWORD) + + return types async def validate_login( self, @@ -1217,15 +1201,20 @@ class AuthHandler: known_login_type = False - for provider in self.password_providers: - supported_login_types = provider.get_supported_login_types() - if login_type not in supported_login_types: - # this password provider doesn't understand this login type - continue - + # Check if login_type matches a type registered by one of the modules + # We don't need to remove LoginType.PASSWORD from the list if password login is + # disabled, since if that were the case then by this point we know that the + # login_type is not LoginType.PASSWORD + supported_login_types = self.password_auth_provider.get_supported_login_types() + # check if the login type being used is supported by a module + if login_type in supported_login_types: + # Make a note that this login type is supported by the server known_login_type = True + # Get all the fields expected for this login types login_fields = supported_login_types[login_type] + # go through the login submission and keep track of which required fields are + # provided/not provided missing_fields = [] login_dict = {} for f in login_fields: @@ -1233,6 +1222,7 @@ class AuthHandler: missing_fields.append(f) else: login_dict[f] = login_submission[f] + # raise an error if any of the expected fields for that login type weren't provided if missing_fields: raise SynapseError( 400, @@ -1240,10 +1230,15 @@ class AuthHandler: % (login_type, missing_fields), ) - result = await provider.check_auth(username, login_type, login_dict) + # call all of the check_auth hooks for that login_type + # it will return a result once the first success is found (or None otherwise) + result = await self.password_auth_provider.check_auth( + username, login_type, login_dict + ) if result: return result + # if no module managed to authenticate the user, then fallback to built in password based auth if login_type == LoginType.PASSWORD and self._password_localdb_enabled: known_login_type = True @@ -1282,11 +1277,16 @@ class AuthHandler: completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. """ - for provider in self.password_providers: - result = await provider.check_3pid_auth(medium, address, password) - if result: - return result + # call all of the check_3pid_auth callbacks + # Result will be from the first callback that returns something other than None + # If all the callbacks return None, then result is also set to None + result = await self.password_auth_provider.check_3pid_auth( + medium, address, password + ) + if result: + return result + # if result is None then return (None, None) return None, None async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: @@ -1365,13 +1365,12 @@ class AuthHandler: user_info = await self.auth.get_user_by_access_token(access_token) await self.store.delete_access_token(access_token) - # see if any of our auth providers want to know about this - for provider in self.password_providers: - await provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, - access_token=access_token, - ) + # see if any modules want to know about this + await self.password_auth_provider.on_logged_out( + user_id=user_info.user_id, + device_id=user_info.device_id, + access_token=access_token, + ) # delete pushers associated with this access token if user_info.token_id is not None: @@ -1398,12 +1397,11 @@ class AuthHandler: user_id, except_token_id=except_token_id, device_id=device_id ) - # see if any of our auth providers want to know about this - for provider in self.password_providers: - for token, _, device_id in tokens_and_devices: - await provider.on_logged_out( - user_id=user_id, device_id=device_id, access_token=token - ) + # see if any modules want to know about this + for token, _, device_id in tokens_and_devices: + await self.password_auth_provider.on_logged_out( + user_id=user_id, device_id=device_id, access_token=token + ) # delete pushers associated with the access tokens await self.hs.get_pusherpool().remove_pushers_by_access_token( @@ -1811,40 +1809,228 @@ class MacaroonGenerator: return macaroon -class PasswordProvider: - """Wrapper for a password auth provider module +def load_legacy_password_auth_providers(hs: "HomeServer") -> None: + module_api = hs.get_module_api() + for module, config in hs.config.authproviders.password_providers: + load_single_legacy_password_auth_provider( + module=module, config=config, api=module_api + ) - This class abstracts out all of the backwards-compatibility hacks for - password providers, to provide a consistent interface. - """ - @classmethod - def load( - cls, module: Type, config: JsonDict, module_api: ModuleApi - ) -> "PasswordProvider": - try: - pp = module(config=config, account_handler=module_api) - except Exception as e: - logger.error("Error while initializing %r: %s", module, e) - raise - return cls(pp, module_api) +def load_single_legacy_password_auth_provider( + module: Type, config: JsonDict, api: ModuleApi +) -> None: + try: + provider = module(config=config, account_handler=api) + except Exception as e: + logger.error("Error while initializing %r: %s", module, e) + raise + + # The known hooks. If a module implements a method who's name appears in this set + # we'll want to register it + password_auth_provider_methods = { + "check_3pid_auth", + "on_logged_out", + } + + # All methods that the module provides should be async, but this wasn't enforced + # in the old module system, so we wrap them if needed + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + # We need to wrap check_password because its old form would return a boolean + # but we now want it to behave just like check_auth() and return the matrix id of + # the user if authentication succeeded or None otherwise + if f.__name__ == "check_password": + + async def wrapped_check_password( + username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + matrix_user_id = api.get_qualified_user_id(username) + password = login_dict["password"] + + is_valid = await f(matrix_user_id, password) + + if is_valid: + return matrix_user_id, None + + return None - def __init__(self, pp: "PasswordProvider", module_api: ModuleApi): - self._pp = pp - self._module_api = module_api + return wrapped_check_password + + # We need to wrap check_auth as in the old form it could return + # just a str, but now it must return Optional[Tuple[str, Optional[Callable]] + if f.__name__ == "check_auth": + + async def wrapped_check_auth( + username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + result = await f(username, login_type, login_dict) + + if isinstance(result, str): + return result, None + + return result + + return wrapped_check_auth + + # We need to wrap check_3pid_auth as in the old form it could return + # just a str, but now it must return Optional[Tuple[str, Optional[Callable]] + if f.__name__ == "check_3pid_auth": + + async def wrapped_check_3pid_auth( + medium: str, address: str, password: str + ) -> Optional[Tuple[str, Optional[Callable]]]: + # We've already made sure f is not None above, but mypy doesn't do well + # across function boundaries so we need to tell it f is definitely not + # None. + assert f is not None + + result = await f(medium, address, password) + + if isinstance(result, str): + return result, None + + return result - self._supported_login_types = {} + return wrapped_check_3pid_auth - # grandfather in check_password support - if hasattr(self._pp, "check_password"): - self._supported_login_types[LoginType.PASSWORD] = ("password",) + def run(*args: Tuple, **kwargs: Dict) -> Awaitable: + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None - g = getattr(self._pp, "get_supported_login_types", None) - if g: - self._supported_login_types.update(g()) + return maybe_awaitable(f(*args, **kwargs)) - def __str__(self) -> str: - return str(self._pp) + return run + + # populate hooks with the implemented methods, wrapped with async_wrapper + hooks = { + hook: async_wrapper(getattr(provider, hook, None)) + for hook in password_auth_provider_methods + } + + supported_login_types = {} + # call get_supported_login_types and add that to the dict + g = getattr(provider, "get_supported_login_types", None) + if g is not None: + # Note the old module style also called get_supported_login_types at loading time + # and it is synchronous + supported_login_types.update(g()) + + auth_checkers = {} + # Legacy modules have a check_auth method which expects to be called with one of + # the keys returned by get_supported_login_types. New style modules register a + # dictionary of login_type->check_auth_method mappings + check_auth = async_wrapper(getattr(provider, "check_auth", None)) + if check_auth is not None: + for login_type, fields in supported_login_types.items(): + # need tuple(fields) since fields can be any Iterable type (so may not be hashable) + auth_checkers[(login_type, tuple(fields))] = check_auth + + # if it has a "check_password" method then it should handle all auth checks + # with login type of LoginType.PASSWORD + check_password = async_wrapper(getattr(provider, "check_password", None)) + if check_password is not None: + # need to use a tuple here for ("password",) not a list since lists aren't hashable + auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password + + api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers) + + +CHECK_3PID_AUTH_CALLBACK = Callable[ + [str, str, str], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] +ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable] +CHECK_AUTH_CALLBACK = Callable[ + [str, str, JsonDict], + Awaitable[ + Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]] + ], +] + + +class PasswordAuthProvider: + """ + A class that the AuthHandler calls when authenticating users + It allows modules to provide alternative methods for authentication + """ + + def __init__(self) -> None: + # lists of callbacks + self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = [] + self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = [] + + # Mapping from login type to login parameters + self._supported_login_types: Dict[str, Iterable[str]] = {} + + # Mapping from login type to auth checker callbacks + self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {} + + def register_password_auth_provider_callbacks( + self, + check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, + on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, + auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None, + ) -> None: + # Register check_3pid_auth callback + if check_3pid_auth is not None: + self.check_3pid_auth_callbacks.append(check_3pid_auth) + + # register on_logged_out callback + if on_logged_out is not None: + self.on_logged_out_callbacks.append(on_logged_out) + + if auth_checkers is not None: + # register a new supported login_type + # Iterate through all of the types being registered + for (login_type, fields), callback in auth_checkers.items(): + # Note: fields may be empty here. This would allow a modules auth checker to + # be called with just 'login_type' and no password or other secrets + + # Need to check that all the field names are strings or may get nasty errors later + for f in fields: + if not isinstance(f, str): + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but all parameter names must be strings" + % (login_type, fields) + ) + + # 2 modules supporting the same login type must expect the same fields + # e.g. 1 can't expect "pass" if the other expects "password" + # so throw an exception if that happens + if login_type not in self._supported_login_types.get(login_type, []): + self._supported_login_types[login_type] = fields + else: + fields_currently_supported = self._supported_login_types.get( + login_type + ) + if fields_currently_supported != fields: + raise RuntimeError( + "A module tried to register support for login type: %s with parameters %s" + " but another module had already registered support for that type with parameters %s" + % (login_type, fields, fields_currently_supported) + ) + + # Add the new method to the list of auth_checker_callbacks for this login type + self.auth_checker_callbacks.setdefault(login_type, []).append(callback) def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -1852,20 +2038,15 @@ class PasswordProvider: Returns a map from a login type identifier (such as m.login.password) to an iterable giving the fields which must be provided by the user in the submission to the /login API. - - This wrapper adds m.login.password to the list if the underlying password - provider supports the check_password() api. """ + return self._supported_login_types async def check_auth( self, username: str, login_type: str, login_dict: JsonDict - ) -> Optional[Tuple[str, Optional[Callable]]]: + ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: """Check if the user has presented valid login credentials - This wrapper also calls check_password() if the underlying password provider - supports the check_password() api and the login type is m.login.password. - Args: username: user id presented by the client. Either an MXID or an unqualified username. @@ -1879,63 +2060,130 @@ class PasswordProvider: user, and `callback` is an optional callback which will be called with the result from the /login call (including access_token, device_id, etc.) """ - # first grandfather in a call to check_password - if login_type == LoginType.PASSWORD: - check_password = getattr(self._pp, "check_password", None) - if check_password: - qualified_user_id = self._module_api.get_qualified_user_id(username) - is_valid = await check_password( - qualified_user_id, login_dict["password"] - ) - if is_valid: - return qualified_user_id, None - check_auth = getattr(self._pp, "check_auth", None) - if not check_auth: - return None - result = await check_auth(username, login_type, login_dict) + # Go through all callbacks for the login type until one returns with a value + # other than None (i.e. until a callback returns a success) + for callback in self.auth_checker_callbacks[login_type]: + try: + result = await callback(username, login_type, login_dict) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - return result, None + if result is not None: + # Check that the callback returned a Tuple[str, Optional[Callable]] + # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks + # result is always the right type, but as it is 3rd party code it might not be + + if not isinstance(result, tuple) or len(result) != 2: + logger.warning( + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue - return result + # pull out the two parts of the tuple so we can do type checking + str_result, callback_result = result + + # the 1st item in the tuple should be a str + if not isinstance(str_result, str): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # the second should be Optional[Callable] + if callback_result is not None: + if not callable(callback_result): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # The result is a (str, Optional[callback]) tuple so return the successful result + return result + + # If this point has been reached then none of the callbacks successfully authenticated + # the user so return None + return None async def check_3pid_auth( self, medium: str, address: str, password: str - ) -> Optional[Tuple[str, Optional[Callable]]]: - g = getattr(self._pp, "check_3pid_auth", None) - if not g: - return None - + ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]: # This function is able to return a deferred that either # resolves None, meaning authentication failure, or upon # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = await g(medium, address, password) - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - return result, None + for callback in self.check_3pid_auth_callbacks: + try: + result = await callback(medium, address, password) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue - return result + if result is not None: + # Check that the callback returned a Tuple[str, Optional[Callable]] + # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks + # result is always the right type, but as it is 3rd party code it might not be + + if not isinstance(result, tuple) or len(result) != 2: + logger.warning( + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # pull out the two parts of the tuple so we can do type checking + str_result, callback_result = result + + # the 1st item in the tuple should be a str + if not isinstance(str_result, str): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # the second should be Optional[Callable] + if callback_result is not None: + if not callable(callback_result): + logger.warning( # type: ignore[unreachable] + "Wrong type returned by module API callback %s: %s, expected" + " Optional[Tuple[str, Optional[Callable]]]", + callback, + result, + ) + continue + + # The result is a (str, Optional[callback]) tuple so return the successful result + return result + + # If this point has been reached then none of the callbacks successfully authenticated + # the user so return None + return None async def on_logged_out( self, user_id: str, device_id: Optional[str], access_token: str ) -> None: - g = getattr(self._pp, "on_logged_out", None) - if not g: - return - # This might return an awaitable, if it does block the log out - # until it completes. - await maybe_awaitable( - g( - user_id=user_id, - device_id=device_id, - access_token=access_token, - ) - ) + # call all of the on_logged_out callbacks + for callback in self.on_logged_out_callbacks: + try: + callback(user_id, device_id, access_token) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b2a228c231..ab7ef8f950 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.rest.client.login import LoginResponse from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -83,6 +84,8 @@ __all__ = [ "DirectServeJsonResource", "ModuleApi", "PRESENCE_ALL_USERS", + "LoginResponse", + "JsonDict", ] logger = logging.getLogger(__name__) @@ -139,6 +142,7 @@ class ModuleApi: self._spam_checker = hs.get_spam_checker() self._account_validity_handler = hs.get_account_validity_handler() self._third_party_event_rules = hs.get_third_party_event_rules() + self._password_auth_provider = hs.get_password_auth_provider() self._presence_router = hs.get_presence_router() ################################################################################# @@ -164,6 +168,11 @@ class ModuleApi: """Registers callbacks for presence router capabilities.""" return self._presence_router.register_presence_router_callbacks + @property + def register_password_auth_provider_callbacks(self): + """Registers callbacks for password auth provider capabilities.""" + return self._password_auth_provider.register_password_auth_provider_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. diff --git a/synapse/server.py b/synapse/server.py index 5bc045d615..a64c846d1c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.admin import AdminHandler from synapse.handlers.appservice import ApplicationServicesHandler -from synapse.handlers.auth import AuthHandler, MacaroonGenerator +from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider from synapse.handlers.cas import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler @@ -687,6 +687,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_third_party_event_rules(self) -> ThirdPartyEventRules: return ThirdPartyEventRules(self) + @cache_in_self + def get_password_auth_provider(self) -> PasswordAuthProvider: + return PasswordAuthProvider() + @cache_in_self def get_room_member_handler(self) -> RoomMemberHandler: if self.config.worker.worker_app: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 11ca47ea28..1629d2a53c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -549,6 +549,8 @@ def _apply_module_schemas( database_engine: config: application config """ + # This is the old way for password_auth_provider modules to make changes + # to the database. This should instead be done using the module API for (mod, _config) in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 38e6d9f536..7dd4a5a367 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,6 +20,8 @@ from unittest.mock import Mock from twisted.internet import defer import synapse +from synapse.handlers.auth import load_legacy_password_auth_providers +from synapse.module_api import ModuleApi from synapse.rest.client import devices, login from synapse.types import JsonDict @@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi mock_password_provider = Mock() -class PasswordOnlyAuthProvider: - """A password_provider which only implements `check_password`.""" +class LegacyPasswordOnlyAuthProvider: + """A legacy password_provider which only implements `check_password`.""" @staticmethod def parse_config(self): @@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider: return mock_password_provider.check_password(*args) -class CustomAuthProvider: - """A password_provider which implements a custom login type.""" +class LegacyCustomAuthProvider: + """A legacy password_provider which implements a custom login type.""" @staticmethod def parse_config(self): @@ -67,7 +69,23 @@ class CustomAuthProvider: return mock_password_provider.check_auth(*args) -class PasswordCustomAuthProvider: +class CustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + ) + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +class LegacyPasswordCustomAuthProvider: """A password_provider which implements password login via `check_auth`, as well as a custom type.""" @@ -85,8 +103,32 @@ class PasswordCustomAuthProvider: return mock_password_provider.check_auth(*args) -def providers_config(*providers: Type[Any]) -> dict: - """Returns a config dict that will enable the given password auth providers""" +class PasswordCustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type. + as well as a password login""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("test.login_type", ("test_field",)): self.check_auth, + ("m.login.password", ("password",)): self.check_auth, + }, + ) + pass + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + def check_pass(self, *args): + return mock_password_provider.check_password(*args) + + +def legacy_providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given legacy password auth providers""" return { "password_providers": [ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} @@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict: } +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given modules""" + return { + "modules": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + class PasswordAuthProviderTests(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() super().setUp() - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_login(self): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + load_legacy_password_auth_providers(hs) + + return hs + + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_progiver_login_legacy(self): + self.password_only_auth_provider_login_test_body() + + def password_only_auth_provider_login_test_body(self): # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "@ USER🙂NAME :test", " pASS😢word " ) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth_legacy(self): + self.password_only_auth_provider_ui_auth_test_body() + + def password_only_auth_provider_ui_auth_test_body(self): """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_login(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_login_legacy(self): + self.local_user_fallback_login_test_body() + + def local_user_fallback_login_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth_legacy(self): + self.local_user_fallback_ui_auth_test_body() + + def local_user_fallback_ui_auth_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login(self): + def test_no_local_user_fallback_login_legacy(self): + self.no_local_user_fallback_login_test_body() + + def no_local_user_fallback_login_test_body(self): """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth(self): + def test_no_local_user_fallback_ui_auth_legacy(self): + self.no_local_user_fallback_ui_auth_test_body() + + def no_local_user_fallback_ui_auth_test_body(self): """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"enabled": False}, } ) - def test_password_auth_disabled(self): + def test_password_auth_disabled_legacy(self): + self.password_auth_disabled_test_body() + + def password_auth_disabled_test_body(self): """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_password.assert_not_called() + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_login_legacy(self): + self.custom_auth_provider_login_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_login(self): + self.custom_auth_provider_login_test_body() + + def custom_auth_provider_login_test_body(self): # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) @@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # in these cases, but at least we can guard against the API changing # unexpectedly mock_password_provider.check_auth.return_value = defer.succeed( - "@ MALFORMED! :bz" + ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") self.assertEqual(channel.code, 200, channel.result) @@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): " USER🙂NAME ", "test.login_type", {"test_field": " abc "} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_ui_auth_legacy(self): + self.custom_auth_provider_ui_auth_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_ui_auth(self): + self.custom_auth_provider_ui_auth_test_body() + + def custom_auth_provider_ui_auth_test_body(self): # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # and finally, succeed mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "localuser", "test.login_type", {"test_field": "foo"} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_callback_legacy(self): + self.custom_auth_provider_callback_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_callback(self): + self.custom_auth_provider_callback_test_body() + + def custom_auth_provider_callback_test_body(self): callback = Mock(return_value=defer.succeed(None)) mock_password_provider.check_auth.return_value = defer.succeed( @@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): for p in ["user_id", "access_token", "device_id", "home_server"]: self.assertIn(p, call_args[0]) + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_custom_auth_password_disabled_legacy(self): + self.custom_auth_password_disabled_test_body() + @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) def test_custom_auth_password_disabled(self): + self.custom_auth_password_disabled_test_body() + + def custom_auth_password_disabled_test_body(self): """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + @override_config( { **providers_config(CustomAuthProvider), @@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_custom_auth_password_disabled_localdb_enabled(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + + def custom_auth_password_disabled_localdb_enabled_test_body(self): """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login_legacy(self): + self.password_custom_auth_password_disabled_login_test_body() + @override_config( { **providers_config(PasswordCustomAuthProvider), @@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_password_custom_auth_password_disabled_login(self): + self.password_custom_auth_password_disabled_login_test_body() + + def password_custom_auth_password_disabled_login_test_body(self): """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( { @@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_password_custom_auth_password_disabled_ui_auth(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() + + def password_custom_auth_password_disabled_ui_auth_test_body(self): """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, 200, channel.result) @@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "Password login has been disabled.", channel.json_body["error"] ) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() mock_password_provider.reset_mock() # successful auth @@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_auth.assert_called_once_with( "localuser", "test.login_type", {"test_field": "x"} ) + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback_legacy(self): + self.custom_auth_no_local_user_fallback_test_body() @override_config( { @@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_custom_auth_no_local_user_fallback(self): + self.custom_auth_no_local_user_fallback_test_body() + + def custom_auth_no_local_user_fallback_test_body(self): """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass") -- cgit 1.5.1 From 1f9d0b8a7a6a777d59fe3217724f3e2ddb94a9b2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 13 Oct 2021 07:24:07 -0400 Subject: Add type hints to synapse.events.*. (#11066) Except `synapse/events/__init__.py`, which will be done in a follow-up. --- changelog.d/11066.misc | 1 + mypy.ini | 6 ++ synapse/events/builder.py | 4 +- synapse/events/presence_router.py | 21 +++---- synapse/events/snapshot.py | 110 +++++++++++++++++++---------------- synapse/events/spamcheck.py | 25 ++++---- synapse/events/third_party_rules.py | 25 ++++---- synapse/events/utils.py | 113 +++++++++++++++++++++--------------- synapse/events/validator.py | 18 +++--- synapse/handlers/room.py | 22 ++++++- synapse/rest/client/relations.py | 8 +-- 11 files changed, 208 insertions(+), 145 deletions(-) create mode 100644 changelog.d/11066.misc (limited to 'synapse') diff --git a/changelog.d/11066.misc b/changelog.d/11066.misc new file mode 100644 index 0000000000..1e337bee54 --- /dev/null +++ b/changelog.d/11066.misc @@ -0,0 +1 @@ +Add type hints to `synapse.events`. diff --git a/mypy.ini b/mypy.ini index 93757cd95d..2cdd552f46 100644 --- a/mypy.ini +++ b/mypy.ini @@ -22,8 +22,11 @@ files = synapse/crypto, synapse/event_auth.py, synapse/events/builder.py, + synapse/events/presence_router.py, + synapse/events/snapshot.py, synapse/events/spamcheck.py, synapse/events/third_party_rules.py, + synapse/events/utils.py, synapse/events/validator.py, synapse/federation, synapse/groups, @@ -96,6 +99,9 @@ files = tests/util/test_itertools.py, tests/util/test_stream_change_cache.py +[mypy-synapse.events.*] +disallow_untyped_defs = True + [mypy-synapse.handlers.*] disallow_untyped_defs = True diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 50f2a4c1f4..4f409f31e1 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -90,13 +90,13 @@ class EventBuilder: ) @property - def state_key(self): + def state_key(self) -> str: if self._state_key is not None: return self._state_key raise AttributeError("state_key") - def is_state(self): + def is_state(self) -> bool: return self._state_key is not None async def build( diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index 68b8b19024..a58f313e8b 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -14,6 +14,7 @@ import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, @@ -33,14 +34,13 @@ if TYPE_CHECKING: GET_USERS_FOR_STATES_CALLBACK = Callable[ [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]] ] -GET_INTERESTED_USERS_CALLBACK = Callable[ - [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]] -] +# This must either return a set of strings or the constant PresenceRouter.ALL_USERS. +GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]] logger = logging.getLogger(__name__) -def load_legacy_presence_router(hs: "HomeServer"): +def load_legacy_presence_router(hs: "HomeServer") -> None: """Wrapper that loads a presence router module configured using the old configuration, and registers the hooks they implement. """ @@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"): if f is None: return None - def run(*args, **kwargs): - # mypy doesn't do well across function boundaries so we need to tell it - # f is definitely not None. + def run(*args: Any, **kwargs: Any) -> Awaitable: + # Assertion required because mypy can't prove we won't change `f` + # back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None return maybe_awaitable(f(*args, **kwargs)) @@ -104,7 +105,7 @@ class PresenceRouter: self, get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, - ): + ) -> None: # PresenceRouter modules are required to implement both of these methods # or neither of them as they are assumed to act in a complementary manner paired_methods = [get_users_for_states, get_interested_users] @@ -142,7 +143,7 @@ class PresenceRouter: # Don't include any extra destinations for presence updates return {} - users_for_states = {} + users_for_states: Dict[str, Set[UserPresenceState]] = {} # run all the callbacks for get_users_for_states and combine the results for callback in self._get_users_for_states_callbacks: try: @@ -171,7 +172,7 @@ class PresenceRouter: return users_for_states - async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]: + async def get_interested_users(self, user_id: str) -> Union[Set[str], str]: """ Retrieve a list of users that `user_id` is interested in receiving the presence of. This will be in addition to those they share a room with. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 5ba01eeef9..d7527008c4 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -11,17 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import attr from frozendict import frozendict +from twisted.internet.defer import Deferred + from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.types import StateMap +from synapse.types import JsonDict, StateMap if TYPE_CHECKING: + from synapse.storage import Storage from synapse.storage.databases.main import DataStore @@ -112,13 +115,13 @@ class EventContext: @staticmethod def with_state( - state_group, - state_group_before_event, - current_state_ids, - prev_state_ids, - prev_group=None, - delta_ids=None, - ): + state_group: Optional[int], + state_group_before_event: Optional[int], + current_state_ids: Optional[StateMap[str]], + prev_state_ids: Optional[StateMap[str]], + prev_group: Optional[int] = None, + delta_ids: Optional[StateMap[str]] = None, + ) -> "EventContext": return EventContext( current_state_ids=current_state_ids, prev_state_ids=prev_state_ids, @@ -129,22 +132,22 @@ class EventContext: ) @staticmethod - def for_outlier(): + def for_outlier() -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" return EventContext( current_state_ids={}, prev_state_ids={}, ) - async def serialize(self, event: EventBase, store: "DataStore") -> dict: + async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` Args: - event (FrozenEvent): The event that this context relates to + event: The event that this context relates to Returns: - dict + The serialized event. """ # We don't serialize the full state dicts, instead they get pulled out @@ -170,17 +173,16 @@ class EventContext: } @staticmethod - def deserialize(storage, input): + def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": """Converts a dict that was produced by `serialize` back into a EventContext. Args: - storage (Storage): Used to convert AS ID to AS object and fetch - state. - input (dict): A dict produced by `serialize` + storage: Used to convert AS ID to AS object and fetch state. + input: A dict produced by `serialize` Returns: - EventContext + The event context. """ context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the @@ -241,22 +243,25 @@ class EventContext: await self._ensure_fetched() return self._current_state_ids - async def get_prev_state_ids(self): + async def get_prev_state_ids(self) -> StateMap[str]: """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). Returns: - dict[(str, str), str]|None: Returns None if state_group - is None, which happens when the associated event is an outlier. - Maps a (type, state_key) to the event ID of the state event matching - this tuple. + Returns {} if state_group is None, which happens when the associated + event is an outlier. + + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ await self._ensure_fetched() + # There *should* be previous state IDs now. + assert self._prev_state_ids is not None return self._prev_state_ids - def get_cached_current_state_ids(self): + def get_cached_current_state_ids(self) -> Optional[StateMap[str]]: """Gets the current state IDs if we have them already cached. It is an error to access this for a rejected event, since rejected state should @@ -264,16 +269,17 @@ class EventContext: ``rejected`` is set. Returns: - dict[(str, str), str]|None: Returns None if we haven't cached the - state or if state_group is None, which happens when the associated - event is an outlier. + Returns None if we haven't cached the state or if state_group is None + (which happens when the associated event is an outlier). + + Otherwise, returns the the current state IDs. """ if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") return self._current_state_ids - async def _ensure_fetched(self): + async def _ensure_fetched(self) -> None: return None @@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext): Attributes: - _storage (Storage) + _storage - _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have - been calculated. None if we haven't started calculating yet + _fetching_state_deferred: Resolves when *_state_ids have been calculated. + None if we haven't started calculating yet - _event_type (str): The type of the event the context is associated with. + _event_type: The type of the event the context is associated with. - _event_state_key (str): The state_key of the event the context is - associated with. + _event_state_key: The state_key of the event the context is associated with. - _prev_state_id (str|None): If the event associated with the context is - a state event, then `_prev_state_id` is the event_id of the state - that was replaced. + _prev_state_id: If the event associated with the context is a state event, + then `_prev_state_id` is the event_id of the state that was replaced. """ # This needs to have a default as we're inheriting - _storage = attr.ib(default=None) - _prev_state_id = attr.ib(default=None) - _event_type = attr.ib(default=None) - _event_state_key = attr.ib(default=None) - _fetching_state_deferred = attr.ib(default=None) + _storage: "Storage" = attr.ib(default=None) + _prev_state_id: Optional[str] = attr.ib(default=None) + _event_type: str = attr.ib(default=None) + _event_state_key: Optional[str] = attr.ib(default=None) + _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None) - async def _ensure_fetched(self): + async def _ensure_fetched(self) -> None: if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background(self._fill_out_state) - return await make_deferred_yieldable(self._fetching_state_deferred) + await make_deferred_yieldable(self._fetching_state_deferred) - async def _fill_out_state(self): + async def _fill_out_state(self) -> None: """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = await self._storage.state.get_state_ids_for_group( + current_state_ids = await self._storage.state.get_state_ids_for_group( self.state_group ) + # Set this separately so mypy knows current_state_ids is not None. + self._current_state_ids = current_state_ids if self._event_state_key is not None: - self._prev_state_ids = dict(self._current_state_ids) + self._prev_state_ids = dict(current_state_ids) key = (self._event_type, self._event_state_key) if self._prev_state_id: @@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext): else: self._prev_state_ids.pop(key, None) else: - self._prev_state_ids = self._current_state_ids + self._prev_state_ids = current_state_ids -def _encode_state_dict(state_dict): +def _encode_state_dict( + state_dict: Optional[StateMap[str]], +) -> Optional[List[Tuple[str, str, str]]]: """Since dicts of (type, state_key) -> event_id cannot be serialized in JSON we need to convert them to a form that can. """ @@ -345,7 +353,9 @@ def _encode_state_dict(state_dict): return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()] -def _decode_state_dict(input): +def _decode_state_dict( + input: Optional[List[Tuple[str, str, str]]] +) -> Optional[StateMap[str]]: """Decodes a state dict encoded using `_encode_state_dict` above""" if input is None: return None diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index ae4c8ab257..3134beb8d3 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[ ] -def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): +def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: """Wrapper that loads spam checkers configured using the old configuration, and registers the spam checker hooks they implement. """ @@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): request_info: Collection[Tuple[str, str]], auth_provider_id: Optional[str], ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]: - # We've already made sure f is not None above, but mypy doesn't - # do well across function boundaries so we need to tell it f is - # definitely not None. + # Assertion required because mypy can't prove we won't + # change `f` back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None return f( @@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): "Bad signature for callback check_registration_for_spam", ) - def run(*args, **kwargs): - # mypy doesn't do well across function boundaries so we need to tell it - # wrapped_func is definitely not None. + def run(*args: Any, **kwargs: Any) -> Awaitable: + # Assertion required because mypy can't prove we won't change `f` + # back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert wrapped_func is not None return maybe_awaitable(wrapped_func(*args, **kwargs)) @@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): class SpamChecker: - def __init__(self): + def __init__(self) -> None: self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] @@ -209,7 +210,7 @@ class SpamChecker: CHECK_REGISTRATION_FOR_SPAM_CALLBACK ] = None, check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, - ): + ) -> None: """Register callbacks from module for each hook.""" if check_event_for_spam is not None: self._check_event_for_spam_callbacks.append(check_event_for_spam) @@ -275,7 +276,9 @@ class SpamChecker: return False - async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool): + async def user_may_join_room( + self, user_id: str, room_id: str, is_invited: bool + ) -> bool: """Checks if a given users is allowed to join a room. Not called when a user creates a room. @@ -285,7 +288,7 @@ class SpamChecker: is_invited: Whether the user is invited into the room Returns: - bool: Whether the user may join the room + Whether the user may join the room """ for callback in self._user_may_join_room_callbacks: if await callback(user_id, room_id, is_invited) is False: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 976d9fa446..2a6dabdab6 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.events import EventBase @@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ ] -def load_legacy_third_party_event_rules(hs: "HomeServer"): +def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: """Wrapper that loads a third party event rules module configured using the old configuration, and registers the hooks they implement. """ @@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): event: EventBase, state_events: StateMap[EventBase], ) -> Tuple[bool, Optional[dict]]: - # We've already made sure f is not None above, but mypy doesn't do well - # across function boundaries so we need to tell it f is definitely not - # None. + # Assertion required because mypy can't prove we won't change + # `f` back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None res = await f(event, state_events) @@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): async def wrap_on_create_room( requester: Requester, config: dict, is_requester_admin: bool ) -> None: - # We've already made sure f is not None above, but mypy doesn't do well - # across function boundaries so we need to tell it f is definitely not - # None. + # Assertion required because mypy can't prove we won't change + # `f` back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None res = await f(requester, config, is_requester_admin) @@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): return wrap_on_create_room - def run(*args, **kwargs): - # mypy doesn't do well across function boundaries so we need to tell it - # f is definitely not None. + def run(*args: Any, **kwargs: Any) -> Awaitable: + # Assertion required because mypy can't prove we won't change `f` + # back to `None`. See + # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions assert f is not None return maybe_awaitable(f(*args, **kwargs)) @@ -162,7 +163,7 @@ class ThirdPartyEventRules: check_visibility_can_be_modified: Optional[ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, - ): + ) -> None: """Register callbacks from modules for each hook.""" if check_event_allowed is not None: self._check_event_allowed_callbacks.append(check_event_allowed) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 520edbbf61..23bd24d963 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -13,18 +13,32 @@ # limitations under the License. import collections.abc import re -from typing import Any, Mapping, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Union, +) from frozendict import frozendict from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion +from synapse.types import JsonDict from synapse.util.async_helpers import yieldable_gather_results from synapse.util.frozenutils import unfreeze from . import EventBase +if TYPE_CHECKING: + from synapse.server import HomeServer + # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (? EventBase: return pruned_event -def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: +def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict: """Redacts the event_dict in the same way as `prune_event`, except it operates on dicts rather than event objects @@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: new_content = {} - def add_fields(*fields): + def add_fields(*fields: str) -> None: for field in fields: if field in event_dict["content"]: new_content[field] = event_dict["content"][field] @@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: allowed_fields["content"] = new_content - unsigned = {} + unsigned: JsonDict = {} allowed_fields["unsigned"] = unsigned event_unsigned = event_dict.get("unsigned", {}) @@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: return allowed_fields -def _copy_field(src, dst, field): +def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None: """Copy the field in 'src' to 'dst'. For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] then dst={"foo":{"bar":5}}. Args: - src(dict): The dict to read from. - dst(dict): The dict to modify. - field(list): List of keys to drill down to in 'src'. + src: The dict to read from. + dst: The dict to modify. + field: List of keys to drill down to in 'src'. """ if len(field) == 0: # this should be impossible return @@ -205,7 +219,7 @@ def _copy_field(src, dst, field): sub_out_dict[key_to_move] = sub_dict[key_to_move] -def only_fields(dictionary, fields): +def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict: """Return a new dict with only the fields in 'dictionary' which are present in 'fields'. @@ -215,11 +229,11 @@ def only_fields(dictionary, fields): A literal '.' character in a field name may be escaped using a '\'. Args: - dictionary(dict): The dictionary to read from. - fields(list): A list of fields to copy over. Only shallow refs are + dictionary: The dictionary to read from. + fields: A list of fields to copy over. Only shallow refs are taken. Returns: - dict: A new dictionary with only the given fields. If fields was empty, + A new dictionary with only the given fields. If fields was empty, the same dictionary is returned. """ if len(fields) == 0: @@ -235,17 +249,17 @@ def only_fields(dictionary, fields): [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields ] - output = {} + output: JsonDict = {} for field_array in split_fields: _copy_field(dictionary, output, field_array) return output -def format_event_raw(d): +def format_event_raw(d: JsonDict) -> JsonDict: return d -def format_event_for_client_v1(d): +def format_event_for_client_v1(d: JsonDict) -> JsonDict: d = format_event_for_client_v2(d) sender = d.get("sender") @@ -267,7 +281,7 @@ def format_event_for_client_v1(d): return d -def format_event_for_client_v2(d): +def format_event_for_client_v2(d: JsonDict) -> JsonDict: drop_keys = ( "auth_events", "prev_events", @@ -282,37 +296,37 @@ def format_event_for_client_v2(d): return d -def format_event_for_client_v2_without_room_id(d): +def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: d = format_event_for_client_v2(d) d.pop("room_id", None) return d def serialize_event( - e, - time_now_ms, - as_client_event=True, - event_format=format_event_for_client_v1, - token_id=None, - only_event_fields=None, - include_stripped_room_state=False, -): + e: Union[JsonDict, EventBase], + time_now_ms: int, + as_client_event: bool = True, + event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, + token_id: Optional[str] = None, + only_event_fields: Optional[List[str]] = None, + include_stripped_room_state: bool = False, +) -> JsonDict: """Serialize event for clients Args: - e (EventBase) - time_now_ms (int) - as_client_event (bool) + e + time_now_ms + as_client_event event_format token_id only_event_fields - include_stripped_room_state (bool): Some events can have stripped room state + include_stripped_room_state: Some events can have stripped room state stored in the `unsigned` field. This is required for invite and knock functionality. If this option is False, that state will be removed from the event before it is returned. Otherwise, it will be kept. Returns: - dict + The serialized event dictionary. """ # FIXME(erikj): To handle the case of presence events and the like @@ -369,25 +383,29 @@ class EventClientSerializer: clients. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.experimental_msc1849_support_enabled = ( hs.config.server.experimental_msc1849_support_enabled ) async def serialize_event( - self, event, time_now, bundle_aggregations=True, **kwargs - ): + self, + event: Union[JsonDict, EventBase], + time_now: int, + bundle_aggregations: bool = True, + **kwargs: Any, + ) -> JsonDict: """Serializes a single event. Args: - event (EventBase) - time_now (int): The current time in milliseconds - bundle_aggregations (bool): Whether to bundle in related events + event + time_now: The current time in milliseconds + bundle_aggregations: Whether to bundle in related events **kwargs: Arguments to pass to `serialize_event` Returns: - dict: The serialized event + The serialized event """ # To handle the case of presence events and the like if not isinstance(event, EventBase): @@ -448,25 +466,27 @@ class EventClientSerializer: return serialized_event - def serialize_events(self, events, time_now, **kwargs): + async def serialize_events( + self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any + ) -> List[JsonDict]: """Serializes multiple events. Args: - event (iter[EventBase]) - time_now (int): The current time in milliseconds + event + time_now: The current time in milliseconds **kwargs: Arguments to pass to `serialize_event` Returns: - Deferred[list[dict]]: The list of serialized events + The list of serialized events """ - return yieldable_gather_results( + return await yieldable_gather_results( self.serialize_event, events, time_now=time_now, **kwargs ) def copy_power_levels_contents( old_power_levels: Mapping[str, Union[int, Mapping[str, int]]] -): +) -> Dict[str, Union[int, Dict[str, int]]]: """Copy the content of a power_levels event, unfreezing frozendicts along the way Raises: @@ -475,7 +495,7 @@ def copy_power_levels_contents( if not isinstance(old_power_levels, collections.abc.Mapping): raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) - power_levels = {} + power_levels: Dict[str, Union[int, Dict[str, int]]] = {} for k, v in old_power_levels.items(): if isinstance(v, int): @@ -483,7 +503,8 @@ def copy_power_levels_contents( continue if isinstance(v, collections.abc.Mapping): - power_levels[k] = h = {} + h: Dict[str, int] = {} + power_levels[k] = h for k1, v1 in v.items(): # we should only have one level of nesting if not isinstance(v1, int): @@ -498,7 +519,7 @@ def copy_power_levels_contents( return power_levels -def validate_canonicaljson(value: Any): +def validate_canonicaljson(value: Any) -> None: """ Ensure that the JSON object is valid according to the rules of canonical JSON. diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 6eb6544c4c..4d459c17f1 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections.abc -from typing import Union +from typing import Iterable, Union import jsonschema @@ -28,11 +28,11 @@ from synapse.events.utils import ( validate_canonicaljson, ) from synapse.federation.federation_server import server_matches_acl_event -from synapse.types import EventID, RoomID, UserID +from synapse.types import EventID, JsonDict, RoomID, UserID class EventValidator: - def validate_new(self, event: EventBase, config: HomeServerConfig): + def validate_new(self, event: EventBase, config: HomeServerConfig) -> None: """Validates the event has roughly the right format Args: @@ -116,7 +116,7 @@ class EventValidator: errcode=Codes.BAD_JSON, ) - def _validate_retention(self, event: EventBase): + def _validate_retention(self, event: EventBase) -> None: """Checks that an event that defines the retention policy for a room respects the format enforced by the spec. @@ -156,7 +156,7 @@ class EventValidator: errcode=Codes.BAD_JSON, ) - def validate_builder(self, event: Union[EventBase, EventBuilder]): + def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None: """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the fields an event would have @@ -204,14 +204,14 @@ class EventValidator: self._ensure_state_event(event) - def _ensure_strings(self, d, keys): + def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None: for s in keys: if s not in d: raise SynapseError(400, "'%s' not in content" % (s,)) if not isinstance(d[s], str): raise SynapseError(400, "'%s' not a string type" % (s,)) - def _ensure_state_event(self, event): + def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None: if not event.is_state(): raise SynapseError(400, "'%s' must be state events" % (event.type,)) @@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = { } -def _create_power_level_validator(): +# This could return something newer than Draft 7, but that's the current "latest" +# validator. +def _create_power_level_validator() -> jsonschema.Draft7Validator: validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA) # by default jsonschema does not consider a frozendict to be an object so diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7072bca1fc..6f39e9446f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -465,17 +465,35 @@ class RoomCreationHandler: # the room has been created # Calculate the minimum power level needed to clone the room event_power_levels = power_levels.get("events", {}) + if not isinstance(event_power_levels, dict): + event_power_levels = {} state_default = power_levels.get("state_default", 50) + try: + state_default_int = int(state_default) # type: ignore[arg-type] + except (TypeError, ValueError): + state_default_int = 50 ban = power_levels.get("ban", 50) - needed_power_level = max(state_default, ban, max(event_power_levels.values())) + try: + ban = int(ban) # type: ignore[arg-type] + except (TypeError, ValueError): + ban = 50 + needed_power_level = max( + state_default_int, ban, max(event_power_levels.values()) + ) # Get the user's current power level, this matches the logic in get_user_power_level, # but without the entire state map. user_power_levels = power_levels.setdefault("users", {}) + if not isinstance(user_power_levels, dict): + user_power_levels = {} users_default = power_levels.get("users_default", 0) current_power_level = user_power_levels.get(user_id, users_default) + try: + current_power_level_int = int(current_power_level) # type: ignore[arg-type] + except (TypeError, ValueError): + current_power_level_int = 0 # Raise the requester's power level in the new room if necessary - if current_power_level < needed_power_level: + if current_power_level_int < needed_power_level: user_power_levels[user_id] = needed_power_level await self._send_events_for_new_room( diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 0b0711c03c..d695c18be2 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet): # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. - events = await self._event_serializer.serialize_events( + serialized_events = await self._event_serializer.serialize_events( events, now, bundle_aggregations=False ) return_value = pagination_chunk.to_dict() - return_value["chunk"] = events + return_value["chunk"] = serialized_events return_value["original_event"] = original_event return 200, return_value @@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) now = self.clock.time_msec() - events = await self._event_serializer.serialize_events(events, now) + serialized_events = await self._event_serializer.serialize_events(events, now) return_value = result.to_dict() - return_value["chunk"] = events + return_value["chunk"] = serialized_events return 200, return_value -- cgit 1.5.1 From b59f3281d5bd2c5e3b717a37a40023afc766370c Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Wed, 13 Oct 2021 13:41:24 +0100 Subject: Remove dead code from `MediaFilePaths` (#11056) --- changelog.d/11056.misc | 1 + synapse/rest/media/v1/filepath.py | 17 ----------------- 2 files changed, 1 insertion(+), 17 deletions(-) create mode 100644 changelog.d/11056.misc (limited to 'synapse') diff --git a/changelog.d/11056.misc b/changelog.d/11056.misc new file mode 100644 index 0000000000..dd701ed177 --- /dev/null +++ b/changelog.d/11056.misc @@ -0,0 +1 @@ +Remove dead code from `MediaFilePaths`. diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index eb66b749a2..bec77088ee 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -48,23 +48,6 @@ class MediaFilePaths: def __init__(self, primary_base_path: str): self.base_path = primary_base_path - def default_thumbnail_rel( - self, - default_top_level: str, - default_sub_type: str, - width: int, - height: int, - content_type: str, - method: str, - ) -> str: - top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) - return os.path.join( - "default_thumbnails", default_top_level, default_sub_type, file_name - ) - - default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) - def local_media_filepath_rel(self, media_id: str) -> str: return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) -- cgit 1.5.1 From 317e9e415c378d7b89b5f140b42e45db4583b35a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 13 Oct 2021 13:50:00 +0100 Subject: Rearrange the user_directory's `_handle_deltas` function (#11035) * Pull out `_handle_room_membership_event` * Discard excluded users early * Rearrange logic so the change is membership is effectively switched over. See PR for rationale. --- changelog.d/11035.misc | 1 + synapse/handlers/user_directory.py | 135 +++++++++++++++++++++---------------- 2 files changed, 79 insertions(+), 57 deletions(-) create mode 100644 changelog.d/11035.misc (limited to 'synapse') diff --git a/changelog.d/11035.misc b/changelog.d/11035.misc new file mode 100644 index 0000000000..6b45b7e9bd --- /dev/null +++ b/changelog.d/11035.misc @@ -0,0 +1 @@ +Rearrange the internal workings of the incremental user directory updates. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 8810f048ba..52b2de388f 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -196,63 +196,12 @@ class UserDirectoryHandler(StateDeltasHandler): room_id, prev_event_id, event_id, typ ) elif typ == EventTypes.Member: - change = await self._get_key_change( + await self._handle_room_membership_event( + room_id, prev_event_id, event_id, - key_name="membership", - public_value=Membership.JOIN, + state_key, ) - - is_remote = not self.is_mine_id(state_key) - if change is MatchChange.now_false: - # Need to check if the server left the room entirely, if so - # we might need to remove all the users in that room - is_in_room = await self.store.is_host_joined( - room_id, self.server_name - ) - if not is_in_room: - logger.debug("Server left room: %r", room_id) - # Fetch all the users that we marked as being in user - # directory due to being in the room and then check if - # need to remove those users or not - user_ids = await self.store.get_users_in_dir_due_to_room( - room_id - ) - - for user_id in user_ids: - await self._handle_remove_user(room_id, user_id) - continue - else: - logger.debug("Server is still in room: %r", room_id) - - include_in_dir = ( - is_remote - or await self.store.should_include_local_user_in_dir(state_key) - ) - if include_in_dir: - if change is MatchChange.no_change: - # Handle any profile changes for remote users. - # (For local users we are not forced to scan membership - # events; instead the rest of the application calls - # `handle_local_profile_change`.) - if is_remote: - await self._handle_profile_change( - state_key, room_id, prev_event_id, event_id - ) - continue - - if change is MatchChange.now_true: # The user joined - # This may be the first time we've seen a remote user. If - # so, ensure we have a directory entry for them. (We don't - # need to do this for local users: their directory entry - # is created at the point of registration. - if is_remote: - await self._upsert_directory_entry_for_remote_user( - state_key, event_id - ) - await self._track_user_joined_room(room_id, state_key) - else: # The user left - await self._handle_remove_user(room_id, state_key) else: logger.debug("Ignoring irrelevant type: %r", typ) @@ -326,6 +275,72 @@ class UserDirectoryHandler(StateDeltasHandler): for user_id in users_in_room: await self._track_user_joined_room(room_id, user_id) + async def _handle_room_membership_event( + self, + room_id: str, + prev_event_id: str, + event_id: str, + state_key: str, + ) -> None: + """Process a single room membershp event. + + We have to do two things: + + 1. Update the room-sharing tables. + This applies to remote users and non-excluded local users. + 2. Update the user_directory and user_directory_search tables. + This applies to remote users only, because we only become aware of + the (and any profile changes) by listening to these events. + The rest of the application knows exactly when local users are + created or their profile changed---it will directly call methods + on this class. + """ + joined = await self._get_key_change( + prev_event_id, + event_id, + key_name="membership", + public_value=Membership.JOIN, + ) + + # Both cases ignore excluded local users, so start by discarding them. + is_remote = not self.is_mine_id(state_key) + if not is_remote and not await self.store.should_include_local_user_in_dir( + state_key + ): + return + + if joined is MatchChange.now_false: + # Need to check if the server left the room entirely, if so + # we might need to remove all the users in that room + is_in_room = await self.store.is_host_joined(room_id, self.server_name) + if not is_in_room: + logger.debug("Server left room: %r", room_id) + # Fetch all the users that we marked as being in user + # directory due to being in the room and then check if + # need to remove those users or not + user_ids = await self.store.get_users_in_dir_due_to_room(room_id) + + for user_id in user_ids: + await self._handle_remove_user(room_id, user_id) + else: + logger.debug("Server is still in room: %r", room_id) + await self._handle_remove_user(room_id, state_key) + elif joined is MatchChange.no_change: + # Handle any profile changes for remote users. + # (For local users the rest of the application calls + # `handle_local_profile_change`.) + if is_remote: + await self._handle_possible_remote_profile_change( + state_key, room_id, prev_event_id, event_id + ) + elif joined is MatchChange.now_true: # The user joined + # This may be the first time we've seen a remote user. If + # so, ensure we have a directory entry for them. (For local users, + # the rest of the application calls `handle_local_profile_change`.) + if is_remote: + await self._upsert_directory_entry_for_remote_user(state_key, event_id) + await self._track_user_joined_room(room_id, state_key) + async def _upsert_directory_entry_for_remote_user( self, user_id: str, event_id: str ) -> None: @@ -386,7 +401,12 @@ class UserDirectoryHandler(StateDeltasHandler): await self.store.add_users_who_share_private_room(room_id, to_insert) async def _handle_remove_user(self, room_id: str, user_id: str) -> None: - """Called when we might need to remove user from directory + """Called when when someone leaves a room. The user may be local or remote. + + (If the person who left was the last local user in this room, the server + is no longer in the room. We call this function to forget that the remaining + remote users are in the room, even though they haven't left. So the name is + a little misleading!) Args: room_id: The room ID that user left or stopped being public that @@ -403,7 +423,7 @@ class UserDirectoryHandler(StateDeltasHandler): if len(rooms_user_is_in) == 0: await self.store.remove_from_user_dir(user_id) - async def _handle_profile_change( + async def _handle_possible_remote_profile_change( self, user_id: str, room_id: str, @@ -411,7 +431,8 @@ class UserDirectoryHandler(StateDeltasHandler): event_id: Optional[str], ) -> None: """Check member event changes for any profile changes and update the - database if there are. + database if there are. This is intended for remote users only. The caller + is responsible for checking that the given user is remote. """ if not prev_event_id or not event_id: return -- cgit 1.5.1 From 35d6b914eb98e218007b7977a85f5022d25072fb Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 13 Oct 2021 17:44:00 -0500 Subject: Resolve and share `state_groups` for all historical events in batch (MSC2716) (#10975) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolve and share `state_groups` for all historical events in batch. This also helps for showing the appropriate avatar/displayname in Element and will work whenever `/messages` has one of the historical messages as the first message in the batch. This does have the flaw where if you just insert a single historical event somewhere, it probably won't resolve the state correctly from `/messages` or `/context` since it will grab a non historical event above or below with resolved state which never included the historical state back then. For the same reasions, this also does not work in Element between the transition from actual messages to historical messages. In the Gitter case, this isn't really a problem since all of the historical messages are in one big lump at the beginning of the room. For a future iteration, might be good to look at `/messages` and `/context` to additionally add the `state` for any historical messages in that batch. --- How are the `state_groups` shared? To illustrate the `state_group` sharing, see this example: **Before** (new `state_group` for every event 😬, very inefficient): ``` # Tests from https://github.com/matrix-org/complement/pull/206 $ COMPLEMENT_ALWAYS_PRINT_SERVER_LOGS=1 COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh TestBackfillingHistory/parallel/should_resolve_member_state_events_for_historical_events create_new_client_event m.room.member event=$_JXfwUDIWS6xKGG4SmZXjSFrizhARM7QblhATVWWUcA state_group=None create_new_client_event org.matrix.msc2716.insertion event=$1ZBfmBKEjg94d-vGYymKrVYeghwBOuGJ3wubU1-I9y0 state_group=9 create_new_client_event org.matrix.msc2716.insertion event=$Mq2JvRetTyclPuozRI682SAjYp3GqRuPc8_cH5-ezPY state_group=10 create_new_client_event m.room.message event=$MfmY4rBQkxrIp8jVwVMTJ4PKnxSigpG9E2cn7S0AtTo state_group=11 create_new_client_event m.room.message event=$uYOv6V8wiF7xHwOMt-60d1AoOIbqLgrDLz6ZIQDdWUI state_group=12 create_new_client_event m.room.message event=$PAbkJRMxb0bX4A6av463faiAhxkE3FEObM1xB4D0UG4 state_group=13 create_new_client_event org.matrix.msc2716.batch event=$Oy_S7AWN7rJQe_MYwGPEy6RtbYklrI-tAhmfiLrCaKI state_group=14 ``` **After** (all events in batch sharing `state_group=10`) (the base insertion event has `state_group=8` which matches the `prev_event` we're inserting next to): ``` # Tests from https://github.com/matrix-org/complement/pull/206 $ COMPLEMENT_ALWAYS_PRINT_SERVER_LOGS=1 COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh TestBackfillingHistory/parallel/should_resolve_member_state_events_for_historical_events create_new_client_event m.room.member event=$PWomJ8PwENYEYuVNoG30gqtybuQQSZ55eldBUSs0i0U state_group=None create_new_client_event org.matrix.msc2716.insertion event=$e_mCU7Eah9ABF6nQU7lu4E1RxIWccNF05AKaTT5m3lw state_group=9 create_new_client_event org.matrix.msc2716.insertion event=$ui7A3_GdXIcJq0C8GpyrF8X7B3DTjMd_WGCjogax7xU state_group=10 create_new_client_event m.room.message event=$EnTIM5rEGVezQJiYl62uFBl6kJ7B-sMxWqe2D_4FX1I state_group=10 create_new_client_event m.room.message event=$LGx5jGONnBPuNhAuZqHeEoXChd9ryVkuTZatGisOPjk state_group=10 create_new_client_event m.room.message event=$wW0zwoN50lbLu1KoKbybVMxLbKUj7GV_olozIc5i3M0 state_group=10 create_new_client_event org.matrix.msc2716.batch event=$5ZB6dtzqFBCEuMRgpkU201Qhx3WtXZGTz_YgldL6JrQ state_group=10 ``` --- changelog.d/10975.feature | 1 + synapse/handlers/message.py | 57 +++++++++++++--------- synapse/handlers/room_batch.py | 40 +++++++++++---- synapse/rest/client/room_batch.py | 15 +++--- synapse/storage/databases/main/events.py | 10 ++-- synapse/storage/databases/main/room_batch.py | 13 +++++ synapse/storage/schema/__init__.py | 6 ++- .../delta/65/01msc2716_insertion_event_edges.sql | 19 ++++++++ 8 files changed, 114 insertions(+), 47 deletions(-) create mode 100644 changelog.d/10975.feature create mode 100644 synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql (limited to 'synapse') diff --git a/changelog.d/10975.feature b/changelog.d/10975.feature new file mode 100644 index 0000000000..167426e1fc --- /dev/null +++ b/changelog.d/10975.feature @@ -0,0 +1 @@ +Resolve and share `state_groups` for all [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical events in batch. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4de9f4b828..2e024b551f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -607,29 +607,6 @@ class EventCreationHandler: builder.internal_metadata.historical = historical - # Strip down the auth_event_ids to only what we need to auth the event. - # For example, we don't need extra m.room.member that don't match event.sender - if auth_event_ids is not None: - # If auth events are provided, prev events must be also. - assert prev_event_ids is not None - - temp_event = await builder.build( - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - depth=depth, - ) - auth_events = await self.store.get_events_as_list(auth_event_ids) - # Create a StateMap[str] - auth_event_state_map = { - (e.type, e.state_key): e.event_id for e in auth_events - } - # Actually strip down and use the necessary auth events - auth_event_ids = self._event_auth_handler.compute_auth_events( - event=temp_event, - current_state_ids=auth_event_state_map, - for_verification=False, - ) - event, context = await self.create_new_client_event( builder=builder, requester=requester, @@ -936,6 +913,33 @@ class EventCreationHandler: Tuple of created event, context """ + # Strip down the auth_event_ids to only what we need to auth the event. + # For example, we don't need extra m.room.member that don't match event.sender + full_state_ids_at_event = None + if auth_event_ids is not None: + # If auth events are provided, prev events must be also. + assert prev_event_ids is not None + + # Copy the full auth state before it stripped down + full_state_ids_at_event = auth_event_ids.copy() + + temp_event = await builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + depth=depth, + ) + auth_events = await self.store.get_events_as_list(auth_event_ids) + # Create a StateMap[str] + auth_event_state_map = { + (e.type, e.state_key): e.event_id for e in auth_events + } + # Actually strip down and use the necessary auth events + auth_event_ids = self._event_auth_handler.compute_auth_events( + event=temp_event, + current_state_ids=auth_event_state_map, + for_verification=False, + ) + if prev_event_ids is not None: assert ( len(prev_event_ids) <= 10 @@ -965,6 +969,13 @@ class EventCreationHandler: if builder.internal_metadata.outlier: event.internal_metadata.outlier = True context = EventContext.for_outlier() + elif ( + event.type == EventTypes.MSC2716_INSERTION + and full_state_ids_at_event + and builder.internal_metadata.is_historical() + ): + old_state = await self.store.get_events_as_list(full_state_ids_at_event) + context = await self.state.compute_event_context(event, old_state=old_state) else: context = await self.state.compute_event_context(event) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 51dd4e7555..2f5a3e4d19 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -13,6 +13,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def generate_fake_event_id() -> str: + return "$fake_" + random_string(43) + + class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs @@ -177,6 +181,11 @@ class RoomBatchHandler: state_event_ids_at_start = [] auth_event_ids = initial_auth_event_ids.copy() + + # Make the state events float off on their own so we don't have a + # bunch of `@mxid joined the room` noise between each batch + prev_event_id_for_state_chain = generate_fake_event_id() + for state_event in state_events_at_start: assert_params_in_dict( state_event, ["type", "origin_server_ts", "content", "sender"] @@ -200,10 +209,6 @@ class RoomBatchHandler: # Mark all events as historical event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - # Make the state events float off on their own so we don't have a - # bunch of `@mxid joined the room` noise between each batch - fake_prev_event_id = "$" + random_string(43) - # TODO: This is pretty much the same as some other code to handle inserting state in this file if event_dict["type"] == EventTypes.Member: membership = event_dict["content"].get("membership", None) @@ -216,7 +221,7 @@ class RoomBatchHandler: action=membership, content=event_dict["content"], outlier=True, - prev_event_ids=[fake_prev_event_id], + prev_event_ids=[prev_event_id_for_state_chain], # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same # reference and also update in the event when we append later. @@ -235,7 +240,7 @@ class RoomBatchHandler: ), event_dict, outlier=True, - prev_event_ids=[fake_prev_event_id], + prev_event_ids=[prev_event_id_for_state_chain], # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same # reference and also update in the event when we append later. @@ -245,6 +250,8 @@ class RoomBatchHandler: state_event_ids_at_start.append(event_id) auth_event_ids.append(event_id) + # Connect all the state in a floating chain + prev_event_id_for_state_chain = event_id return state_event_ids_at_start @@ -289,6 +296,10 @@ class RoomBatchHandler: for ev in events_to_create: assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) + assert self.hs.is_mine_id(ev["sender"]), "User must be our own: %s" % ( + ev["sender"], + ) + event_dict = { "type": ev["type"], "origin_server_ts": ev["origin_server_ts"], @@ -311,6 +322,19 @@ class RoomBatchHandler: historical=True, depth=inherited_depth, ) + + assert context._state_group + + # Normally this is done when persisting the event but we have to + # pre-emptively do it here because we create all the events first, + # then persist them in another pass below. And we want to share + # state_groups across the whole batch so this lookup needs to work + # for the next event in the batch in this loop. + await self.store.store_state_group_id_for_event_id( + event_id=event.event_id, + state_group_id=context._state_group, + ) + logger.debug( "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", event, @@ -318,10 +342,6 @@ class RoomBatchHandler: auth_event_ids, ) - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) - events_to_persist.append((event, context)) event_id = event.event_id diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 38ad4c2447..99f8156ad0 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -32,7 +32,6 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict -from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -160,11 +159,6 @@ class RoomBatchSendEventRestServlet(RestServlet): base_insertion_event = None if batch_id_from_query: batch_id_to_connect_to = batch_id_from_query - # All but the first base insertion event should point at a fake - # event, which causes the HS to ask for the state at the start of - # the batch later. - fake_prev_event_id = "$" + random_string(43) - prev_event_ids = [fake_prev_event_id] # Otherwise, create an insertion event to act as a starting point. # # We don't always have an insertion event to start hanging more history @@ -173,8 +167,6 @@ class RoomBatchSendEventRestServlet(RestServlet): # an insertion event), in which case we just create a new insertion event # that can then get pointed to by a "marker" event later. else: - prev_event_ids = prev_event_ids_from_query - base_insertion_event_dict = ( self.room_batch_handler.create_insertion_event_dict( sender=requester.user.to_string(), @@ -182,7 +174,7 @@ class RoomBatchSendEventRestServlet(RestServlet): origin_server_ts=last_event_in_batch["origin_server_ts"], ) ) - base_insertion_event_dict["prev_events"] = prev_event_ids.copy() + base_insertion_event_dict["prev_events"] = prev_event_ids_from_query.copy() ( base_insertion_event, @@ -203,6 +195,11 @@ class RoomBatchSendEventRestServlet(RestServlet): EventContentFields.MSC2716_NEXT_BATCH_ID ] + # Also connect the historical event chain to the end of the floating + # state chain, which causes the HS to ask for the state at the start of + # the batch later. + prev_event_ids = [state_event_ids_at_start[-1]] + # Create and persist all of the historical events as well as insertion # and batch meta events to make the batch navigable in the DAG. event_ids, next_batch_id = await self.room_batch_handler.handle_batch_of_events( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 19f55c19c5..37439f8562 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2069,12 +2069,14 @@ class PersistEventsStore: state_groups[event.event_id] = context.state_group - self.db_pool.simple_insert_many_txn( + self.db_pool.simple_upsert_many_txn( txn, table="event_to_state_groups", - values=[ - {"state_group": state_group_id, "event_id": event_id} - for event_id, state_group_id in state_groups.items() + key_names=["event_id"], + key_values=[[event_id] for event_id, _ in state_groups.items()], + value_names=["state_group"], + value_values=[ + [state_group_id] for _, state_group_id in state_groups.items() ], ) diff --git a/synapse/storage/databases/main/room_batch.py b/synapse/storage/databases/main/room_batch.py index 300a563c9e..dcbce8fdcf 100644 --- a/synapse/storage/databases/main/room_batch.py +++ b/synapse/storage/databases/main/room_batch.py @@ -36,3 +36,16 @@ class RoomBatchStore(SQLBaseStore): retcol="event_id", allow_none=True, ) + + async def store_state_group_id_for_event_id( + self, event_id: str, state_group_id: int + ) -> Optional[str]: + { + await self.db_pool.simple_upsert( + table="event_to_state_groups", + keyvalues={"event_id": event_id}, + values={"state_group": state_group_id, "event_id": event_id}, + # Unique constraint on event_id so we don't have to lock + lock=False, + ) + } diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 1aee741a8b..a1d2332326 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 64 # remember to update the list below when updating +SCHEMA_VERSION = 65 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -41,6 +41,10 @@ Changes in SCHEMA_VERSION = 63: Changes in SCHEMA_VERSION = 64: - MSC2716: Rename related tables and columns from "chunks" to "batches". + +Changes in SCHEMA_VERSION = 65: + - MSC2716: Remove unique event_id constraint from insertion_event_edges + because an insertion event can have multiple edges. """ diff --git a/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql b/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql new file mode 100644 index 0000000000..98b25daf45 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/01msc2716_insertion_event_edges.sql @@ -0,0 +1,19 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Recreate the insertion_event_edges event_id index without the unique constraint +-- because an insertion event can have multiple edges. +DROP INDEX insertion_event_edges_event_id; +CREATE INDEX IF NOT EXISTS insertion_event_edges_event_id ON insertion_event_edges(event_id); -- cgit 1.5.1 From e2f0b49b3fa9fd87cd24ac6bdc46a94db532ba89 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 14 Oct 2021 10:17:20 -0400 Subject: Attempt different character encodings when previewing a URL. (#11077) This follows similar logic to BeautifulSoup where we attempt different character encodings until we find one which works. --- changelog.d/11077.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 80 +++++++++++++-------------- tests/test_preview.py | 66 +++++++++++++--------- 3 files changed, 80 insertions(+), 67 deletions(-) create mode 100644 changelog.d/11077.bugfix (limited to 'synapse') diff --git a/changelog.d/11077.bugfix b/changelog.d/11077.bugfix new file mode 100644 index 0000000000..dc35c86440 --- /dev/null +++ b/changelog.d/11077.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when attempting to preview URLs which are in the `windows-1252` character encoding. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 5bddd21ef1..7ee91a0c05 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -295,8 +295,7 @@ class PreviewUrlResource(DirectServeJsonResource): with open(media_info.filename, "rb") as file: body = file.read() - encoding = get_html_media_encoding(body, media_info.media_type) - tree = decode_body(body, encoding) + tree = decode_body(body, media_info.uri, media_info.media_type) if tree is not None: # Check if this HTML document points to oEmbed information and # defer to that. @@ -632,16 +631,19 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") -def get_html_media_encoding(body: bytes, content_type: str) -> str: +def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]: """ - Get the encoding of the body based on the (presumably) HTML body or media_type. + Get potential encoding of the body based on the (presumably) HTML body or the content-type header. The precedence used for finding a character encoding is: - 1. meta tag with a charset declared. + 1. tag with a charset declared. 2. The XML document's character encoding attribute. 3. The Content-Type header. - 4. Fallback to UTF-8. + 4. Fallback to utf-8. + 5. Fallback to windows-1252. + + This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector. Args: body: The HTML document, as bytes. @@ -653,36 +655,39 @@ def get_html_media_encoding(body: bytes, content_type: str) -> str: # Limit searches to the first 1kb, since it ought to be at the top. body_start = body[:1024] - # Let's try and figure out if it has an encoding set in a meta tag. + # Check if it has an encoding set in a meta tag. match = _charset_match.search(body_start) if match: - return match.group(1).decode("ascii") + yield match.group(1).decode("ascii") # TODO Support - # If we didn't find a match, see if it an XML document with an encoding. + # Check if it has an XML document with an encoding. match = _xml_encoding_match.match(body_start) if match: - return match.group(1).decode("ascii") + yield match.group(1).decode("ascii") - # If we don't find a match, we'll look at the HTTP Content-Type, and - # if that doesn't exist, we'll fall back to UTF-8. - content_match = _content_type_match.match(content_type) - if content_match: - return content_match.group(1) + # Check the HTTP Content-Type header for a character set. + if content_type: + content_match = _content_type_match.match(content_type) + if content_match: + yield content_match.group(1) - return "utf-8" + # Finally, fallback to UTF-8, then windows-1252. + yield "utf-8" + yield "windows-1252" def decode_body( - body: bytes, request_encoding: Optional[str] = None + body: bytes, uri: str, content_type: Optional[str] = None ) -> Optional["etree.Element"]: """ This uses lxml to parse the HTML document. Args: body: The HTML document, as bytes. - request_encoding: The character encoding of the body, as a string. + uri: The URI used to download the body. + content_type: The Content-Type header. Returns: The parsed HTML body, or None if an error occurred during processed. @@ -691,32 +696,25 @@ def decode_body( if not body: return None + for encoding in get_html_media_encodings(body, content_type): + try: + body_str = body.decode(encoding) + except Exception: + pass + else: + break + else: + logger.warning("Unable to decode HTML body for %s", uri) + return None + from lxml import etree - # Create an HTML parser. If this fails, log and return no metadata. - try: - parser = etree.HTMLParser(recover=True, encoding=request_encoding) - except LookupError: - # blindly consider the encoding as utf-8. - parser = etree.HTMLParser(recover=True, encoding="utf-8") - except Exception as e: - logger.warning("Unable to create HTML parser: %s" % (e,)) - return None + # Create an HTML parser. + parser = etree.HTMLParser(recover=True, encoding="utf-8") - def _attempt_decode_body( - body_attempt: Union[bytes, str] - ) -> Optional["etree.Element"]: - # Attempt to parse the body. Returns None if the body was successfully - # parsed, but no tree was found. - return etree.fromstring(body_attempt, parser) - - # Attempt to parse the body. If this fails, log and return no metadata. - try: - return _attempt_decode_body(body) - except UnicodeDecodeError: - # blindly try decoding the body as utf-8, which seems to fix - # the charset mismatches on https://google.com - return _attempt_decode_body(body.decode("utf-8", "ignore")) + # Attempt to parse the body. Returns None if the body was successfully + # parsed, but no tree was found. + return etree.fromstring(body_str, parser) def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: diff --git a/tests/test_preview.py b/tests/test_preview.py index 09e017b4d9..c6789017bc 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -15,7 +15,7 @@ from synapse.rest.media.v1.preview_url_resource import ( _calc_og, decode_body, - get_html_media_encoding, + get_html_media_encodings, summarize_paragraphs, ) @@ -159,7 +159,7 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -175,7 +175,7 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -194,7 +194,7 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual( @@ -216,7 +216,7 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -230,7 +230,7 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -245,7 +245,7 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) @@ -260,7 +260,7 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -268,13 +268,13 @@ class CalcOgTestCase(unittest.TestCase): def test_empty(self): """Test a body with no data in it.""" html = b"" - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) def test_no_tree(self): """A valid body with no tree in it.""" html = b"\x00" - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) def test_invalid_encoding(self): @@ -287,7 +287,7 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html, "invalid-encoding") + tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -302,15 +302,29 @@ class CalcOgTestCase(unittest.TestCase): """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) + def test_windows_1252(self): + """A body which uses windows-1252, but doesn't declare that.""" + html = b""" + + \xf3 + + Some text. + + + """ + tree = decode_body(html, "http://example.com/test.html") + og = _calc_og(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) + class MediaEncodingTestCase(unittest.TestCase): def test_meta_charset(self): """A character encoding is found via the meta tag.""" - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" @@ -319,10 +333,10 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "ascii") + self.assertEqual(list(encodings), ["ascii", "utf-8", "windows-1252"]) # A less well-formed version. - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" < meta charset = ascii> @@ -331,11 +345,11 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "ascii") + self.assertEqual(list(encodings), ["ascii", "utf-8", "windows-1252"]) def test_meta_charset_underscores(self): """A character encoding contains underscore.""" - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" @@ -344,11 +358,11 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "Shift_JIS") + self.assertEqual(list(encodings), ["Shift_JIS", "utf-8", "windows-1252"]) def test_xml_encoding(self): """A character encoding is found via the meta tag.""" - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" @@ -356,11 +370,11 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "ascii") + self.assertEqual(list(encodings), ["ascii", "utf-8", "windows-1252"]) def test_meta_xml_encoding(self): """Meta tags take precedence over XML encoding.""" - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" @@ -370,7 +384,7 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "UTF-16") + self.assertEqual(list(encodings), ["UTF-16", "ascii", "utf-8", "windows-1252"]) def test_content_type(self): """A character encoding is found via the Content-Type header.""" @@ -384,10 +398,10 @@ class MediaEncodingTestCase(unittest.TestCase): 'text/html; charset=ascii";', ) for header in headers: - encoding = get_html_media_encoding(b"", header) - self.assertEqual(encoding, "ascii") + encodings = get_html_media_encodings(b"", header) + self.assertEqual(list(encodings), ["ascii", "utf-8", "windows-1252"]) def test_fallback(self): """A character encoding cannot be found in the body or header.""" - encoding = get_html_media_encoding(b"", "text/html") - self.assertEqual(encoding, "utf-8") + encodings = get_html_media_encodings(b"", "text/html") + self.assertEqual(list(encodings), ["utf-8", "windows-1252"]) -- cgit 1.5.1 From efd0074ab76a9449087e38afadcc5a1b4d5a2813 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 14 Oct 2021 14:51:44 -0400 Subject: Ensure each charset is attempted only once during media preview. (#11089) There's no point in trying more than once since it is guaranteed to continually fail. --- changelog.d/11089.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 34 +++++++++++++++++---- tests/test_preview.py | 43 ++++++++++++++++++++++----- 3 files changed, 64 insertions(+), 14 deletions(-) create mode 100644 changelog.d/11089.bugfix (limited to 'synapse') diff --git a/changelog.d/11089.bugfix b/changelog.d/11089.bugfix new file mode 100644 index 0000000000..dc35c86440 --- /dev/null +++ b/changelog.d/11089.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when attempting to preview URLs which are in the `windows-1252` character encoding. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 7ee91a0c05..278fd901e2 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import codecs import datetime import errno import fnmatch @@ -22,7 +23,7 @@ import re import shutil import sys import traceback -from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union from urllib import parse as urlparse import attr @@ -631,6 +632,14 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") +def _normalise_encoding(encoding: str) -> Optional[str]: + """Use the Python codec's name as the normalised entry.""" + try: + return codecs.lookup(encoding).name + except LookupError: + return None + + def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]: """ Get potential encoding of the body based on the (presumably) HTML body or the content-type header. @@ -652,30 +661,43 @@ def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterab Returns: The character encoding of the body, as a string. """ + # There's no point in returning an encoding more than once. + attempted_encodings: Set[str] = set() + # Limit searches to the first 1kb, since it ought to be at the top. body_start = body[:1024] # Check if it has an encoding set in a meta tag. match = _charset_match.search(body_start) if match: - yield match.group(1).decode("ascii") + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding: + attempted_encodings.add(encoding) + yield encoding # TODO Support # Check if it has an XML document with an encoding. match = _xml_encoding_match.match(body_start) if match: - yield match.group(1).decode("ascii") + encoding = _normalise_encoding(match.group(1).decode("ascii")) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding # Check the HTTP Content-Type header for a character set. if content_type: content_match = _content_type_match.match(content_type) if content_match: - yield content_match.group(1) + encoding = _normalise_encoding(content_match.group(1)) + if encoding and encoding not in attempted_encodings: + attempted_encodings.add(encoding) + yield encoding # Finally, fallback to UTF-8, then windows-1252. - yield "utf-8" - yield "windows-1252" + for fallback in ("utf-8", "cp1252"): + if fallback not in attempted_encodings: + yield fallback def decode_body( diff --git a/tests/test_preview.py b/tests/test_preview.py index c6789017bc..9a576f9a4e 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -307,7 +307,7 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) def test_windows_1252(self): - """A body which uses windows-1252, but doesn't declare that.""" + """A body which uses cp1252, but doesn't declare that.""" html = b""" \xf3 @@ -333,7 +333,7 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "windows-1252"]) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) # A less well-formed version. encodings = get_html_media_encodings( @@ -345,7 +345,7 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "windows-1252"]) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) def test_meta_charset_underscores(self): """A character encoding contains underscore.""" @@ -358,7 +358,7 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(list(encodings), ["Shift_JIS", "utf-8", "windows-1252"]) + self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) def test_xml_encoding(self): """A character encoding is found via the meta tag.""" @@ -370,7 +370,7 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(list(encodings), ["ascii", "utf-8", "windows-1252"]) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) def test_meta_xml_encoding(self): """Meta tags take precedence over XML encoding.""" @@ -384,7 +384,7 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(list(encodings), ["UTF-16", "ascii", "utf-8", "windows-1252"]) + self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) def test_content_type(self): """A character encoding is found via the Content-Type header.""" @@ -399,9 +399,36 @@ class MediaEncodingTestCase(unittest.TestCase): ) for header in headers: encodings = get_html_media_encodings(b"", header) - self.assertEqual(list(encodings), ["ascii", "utf-8", "windows-1252"]) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) def test_fallback(self): """A character encoding cannot be found in the body or header.""" encodings = get_html_media_encodings(b"", "text/html") - self.assertEqual(list(encodings), ["utf-8", "windows-1252"]) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + def test_duplicates(self): + """Ensure each encoding is only attempted once.""" + encodings = get_html_media_encodings( + b""" + + + + + + """, + 'text/html; charset="UTF_8"', + ) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + def test_unknown_invalid(self): + """A character encoding should be ignored if it is unknown or invalid.""" + encodings = get_html_media_encodings( + b""" + + + + + """, + 'text/html; charset="invalid"', + ) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) -- cgit 1.5.1 From daf498e099394e206709bbc7a330be4a989e31d5 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 14 Oct 2021 18:53:45 -0500 Subject: Fix 500 error on `/messages` when we accumulate more than 5 backward extremities (#11027) Found while working on the Gitter backfill script and noticed it only happened after we sent 7 batches, https://gitlab.com/gitterHQ/webapp/-/merge_requests/2229#note_665906390 When there are more than 5 backward extremities for a given depth, backfill will throw an error because we sliced the extremity list to 5 but then try to iterate over the full list. This causes us to look for state that we never fetched and we get a `KeyError`. Before when calling `/messages` when there are more than 5 backward extremities: ``` Traceback (most recent call last): File "/usr/local/lib/python3.8/site-packages/synapse/http/server.py", line 258, in _async_render_wrapper callback_return = await self._async_render(request) File "/usr/local/lib/python3.8/site-packages/synapse/http/server.py", line 446, in _async_render callback_return = await raw_callback_return File "/usr/local/lib/python3.8/site-packages/synapse/rest/client/room.py", line 580, in on_GET msgs = await self.pagination_handler.get_messages( File "/usr/local/lib/python3.8/site-packages/synapse/handlers/pagination.py", line 396, in get_messages await self.hs.get_federation_handler().maybe_backfill( File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 133, in maybe_backfill return await self._maybe_backfill_inner(room_id, current_depth, limit) File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 386, in _maybe_backfill_inner likely_extremeties_domains = get_domains_from_state(states[e_id]) KeyError: '$zpFflMEBtZdgcMQWTakaVItTLMjLFdKcRWUPHbbSZJl' ``` --- changelog.d/11027.bugfix | 1 + synapse/handlers/federation.py | 24 +++++++------- synapse/handlers/federation_event.py | 2 +- tests/handlers/test_federation.py | 64 ++++++++++++++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11027.bugfix (limited to 'synapse') diff --git a/changelog.d/11027.bugfix b/changelog.d/11027.bugfix new file mode 100644 index 0000000000..ae6cc44470 --- /dev/null +++ b/changelog.d/11027.bugfix @@ -0,0 +1 @@ +Fix 500 error on `/messages` when the server accumulates more than 5 backwards extremities at a given depth for a room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3e341bd287..e072efad16 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -238,18 +238,10 @@ class FederationHandler: ) return False - logger.debug( - "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s", - room_id, - current_depth, - max_depth, - sorted_extremeties_tuple, - ) - # We ignore extremities that have a greater depth than our current depth # as: # 1. we don't really care about getting events that have happened - # before our current position; and + # after our current position; and # 2. we have likely previously tried and failed to backfill from that # extremity, so to avoid getting "stuck" requesting the same # backfill repeatedly we drop those extremities. @@ -257,9 +249,19 @@ class FederationHandler: t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth ] + logger.debug( + "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems: %s filtered_sorted_extremeties_tuple: %s", + room_id, + current_depth, + limit, + max_depth, + sorted_extremeties_tuple, + filtered_sorted_extremeties_tuple, + ) + # However, we need to check that the filtered extremities are non-empty. # If they are empty then either we can a) bail or b) still attempt to - # backill. We opt to try backfilling anyway just in case we do get + # backfill. We opt to try backfilling anyway just in case we do get # relevant events. if filtered_sorted_extremeties_tuple: sorted_extremeties_tuple = filtered_sorted_extremeties_tuple @@ -389,7 +391,7 @@ class FederationHandler: for key, state_dict in states.items() } - for e_id, _ in sorted_extremeties_tuple: + for e_id in event_ids: likely_extremeties_domains = get_domains_from_state(states[e_id]) success = await try_backfill( diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f640b417b3..0e455678aa 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -392,7 +392,7 @@ class FederationEventHandler: @log_function async def backfill( - self, dest: str, room_id: str, limit: int, extremities: List[str] + self, dest: str, room_id: str, limit: int, extremities: Iterable[str] ) -> None: """Trigger a backfill request to `dest` for the given `room_id` diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 936ebf3dde..e1557566e4 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -23,6 +23,7 @@ from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room +from synapse.types import create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -30,6 +31,10 @@ from tests import unittest logger = logging.getLogger(__name__) +def generate_fake_event_id() -> str: + return "$fake_" + random_string(43) + + class FederationTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, @@ -198,6 +203,65 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) + def test_backfill_with_many_backward_extremities(self): + """ + Check that we can backfill with many backward extremities. + The goal is to make sure that when we only use a portion + of backwards extremities(the magic number is more than 5), + no errors are thrown. + + Regression test, see #11027 + """ + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + requester = create_requester(user_id) + + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + ev1 = self.helper.send(room_id, "first message", tok=tok) + + # Create "many" backward extremities. The magic number we're trying to + # create more than is 5 which corresponds to the number of backward + # extremities we slice off in `_maybe_backfill_inner` + for _ in range(0, 8): + event_handler = self.hs.get_event_creation_handler() + event, context = self.get_success( + event_handler.create_event( + requester, + { + "type": "m.room.message", + "content": { + "msgtype": "m.text", + "body": "message connected to fake event", + }, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=[ + ev1["event_id"], + # We're creating an backward extremity each time thanks + # to this fake event + generate_fake_event_id(), + ], + ) + ) + self.get_success( + event_handler.handle_new_client_event(requester, event, context) + ) + + current_depth = 1 + limit = 100 + with LoggingContext("receive_pdu"): + # Make sure backfill still works + d = run_in_background( + self.hs.get_federation_handler().maybe_backfill, + room_id, + current_depth, + limit, + ) + self.get_success(d) + def test_backfill_floating_outlier_membership_auth(self): """ As the local homeserver, check that we can properly process a federated -- cgit 1.5.1 From 6a67f3786a73f72739ebe8e5aca372c39626d768 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 15 Oct 2021 13:10:58 +0100 Subject: Fix logging context warnings when losing replication connection (#10984) Instead of triggering `__exit__` manually on the replication handler's logging context, use it as a context manager so that there is an `__enter__` call to balance the `__exit__`. --- changelog.d/10984.misc | 1 + synapse/replication/tcp/protocol.py | 18 +++++++++++++----- synapse/replication/tcp/redis.py | 18 +++++++++++++----- 3 files changed, 27 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10984.misc (limited to 'synapse') diff --git a/changelog.d/10984.misc b/changelog.d/10984.misc new file mode 100644 index 0000000000..86c4081cc4 --- /dev/null +++ b/changelog.d/10984.misc @@ -0,0 +1 @@ +Fix spurious warnings about losing the logging context on the `ReplicationCommandHandler` when losing the replication connection. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 8c80153ab6..7bae36db16 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -182,9 +182,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. - self._logging_context = BackgroundProcessLoggingContext( - "replication-conn", self.conn_id - ) + with PreserveLoggingContext(): + # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to + # capture the sentinel context as its containing context and won't prevent + # GC of / unintentionally reactivate what would be the current context. + self._logging_context = BackgroundProcessLoggingContext( + "replication-conn", self.conn_id + ) def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -434,8 +438,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: self.transport.unregisterProducer() - # mark the logging context as finished - self._logging_context.__exit__(None, None, None) + # mark the logging context as finished by triggering `__exit__()` + with PreserveLoggingContext(): + with self._logging_context: + pass + # the sentinel context is now active, which may not be correct. + # PreserveLoggingContext() will restore the correct logging context. def __str__(self): addr = None diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 062fe2f33e..8d28bd3f3f 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -100,9 +100,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. - self._logging_context = BackgroundProcessLoggingContext( - "replication_command_handler" - ) + with PreserveLoggingContext(): + # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to + # capture the sentinel context as its containing context and won't prevent + # GC of / unintentionally reactivate what would be the current context. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler" + ) def connectionMade(self): logger.info("Connected to redis") @@ -182,8 +186,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): super().connectionLost(reason) self.synapse_handler.lost_connection(self) - # mark the logging context as finished - self._logging_context.__exit__(None, None, None) + # mark the logging context as finished by triggering `__exit__()` + with PreserveLoggingContext(): + with self._logging_context: + pass + # the sentinel context is now active, which may not be correct. + # PreserveLoggingContext() will restore the correct logging context. def send_command(self, cmd: Command): """Send a command if connection has been established. -- cgit 1.5.1 From 55731333488bfd53ece117938dde1cef710eef68 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 Oct 2021 10:30:48 -0400 Subject: Move experimental & retention config out of the server module. (#11070) --- changelog.d/11070.misc | 1 + docs/sample_config.yaml | 83 ++++++------ synapse/config/_base.pyi | 2 + synapse/config/experimental.py | 3 + synapse/config/homeserver.py | 2 + synapse/config/retention.py | 226 +++++++++++++++++++++++++++++++++ synapse/config/server.py | 201 ----------------------------- synapse/events/utils.py | 6 +- synapse/handlers/pagination.py | 13 +- synapse/storage/databases/main/room.py | 8 +- 10 files changed, 290 insertions(+), 255 deletions(-) create mode 100644 changelog.d/11070.misc create mode 100644 synapse/config/retention.py (limited to 'synapse') diff --git a/changelog.d/11070.misc b/changelog.d/11070.misc new file mode 100644 index 0000000000..52b23f9671 --- /dev/null +++ b/changelog.d/11070.misc @@ -0,0 +1 @@ +Create a separate module for the retention configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7bfaed483b..b90ed62d61 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -472,6 +472,48 @@ limit_remote_rooms: # #user_ips_max_age: 14d +# Inhibits the /requestToken endpoints from returning an error that might leak +# information about whether an e-mail address is in use or not on this +# homeserver. +# Note that for some endpoints the error situation is the e-mail already being +# used, and for others the error is entering the e-mail being unused. +# If this option is enabled, instead of returning an error, these endpoints will +# act as if no error happened and return a fake session ID ('sid') to clients. +# +#request_token_inhibit_3pid_errors: true + +# A list of domains that the domain portion of 'next_link' parameters +# must match. +# +# This parameter is optionally provided by clients while requesting +# validation of an email or phone number, and maps to a link that +# users will be automatically redirected to after validation +# succeeds. Clients can make use this parameter to aid the validation +# process. +# +# The whitelist is applied whether the homeserver or an +# identity server is handling validation. +# +# The default value is no whitelist functionality; all domains are +# allowed. Setting this value to an empty list will instead disallow +# all domains. +# +#next_link_domain_whitelist: ["matrix.org"] + +# Templates to use when generating email or HTML page contents. +# +templates: + # Directory in which Synapse will try to find template files to use to generate + # email or HTML page contents. + # If not set, or a file is not found within the template directory, a default + # template from within the Synapse package will be used. + # + # See https://matrix-org.github.io/synapse/latest/templates.html for more + # information about using custom templates. + # + #custom_template_directory: /path/to/custom/templates/ + + # Message retention policy at the server level. # # Room admins and mods can define a retention period for their rooms using the @@ -541,47 +583,6 @@ retention: # - shortest_max_lifetime: 3d # interval: 1d -# Inhibits the /requestToken endpoints from returning an error that might leak -# information about whether an e-mail address is in use or not on this -# homeserver. -# Note that for some endpoints the error situation is the e-mail already being -# used, and for others the error is entering the e-mail being unused. -# If this option is enabled, instead of returning an error, these endpoints will -# act as if no error happened and return a fake session ID ('sid') to clients. -# -#request_token_inhibit_3pid_errors: true - -# A list of domains that the domain portion of 'next_link' parameters -# must match. -# -# This parameter is optionally provided by clients while requesting -# validation of an email or phone number, and maps to a link that -# users will be automatically redirected to after validation -# succeeds. Clients can make use this parameter to aid the validation -# process. -# -# The whitelist is applied whether the homeserver or an -# identity server is handling validation. -# -# The default value is no whitelist functionality; all domains are -# allowed. Setting this value to an empty list will instead disallow -# all domains. -# -#next_link_domain_whitelist: ["matrix.org"] - -# Templates to use when generating email or HTML page contents. -# -templates: - # Directory in which Synapse will try to find template files to use to generate - # email or HTML page contents. - # If not set, or a file is not found within the template directory, a default - # template from within the Synapse package will be used. - # - # See https://matrix-org.github.io/synapse/latest/templates.html for more - # information about using custom templates. - # - #custom_template_directory: /path/to/custom/templates/ - ## TLS ## diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 06fbd1166b..c1d9069798 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -26,6 +26,7 @@ from synapse.config import ( redis, registration, repository, + retention, room_directory, saml2, server, @@ -91,6 +92,7 @@ class RootConfig: modules: modules.ModulesConfig caches: cache.CacheConfig federation: federation.FederationConfig + retention: retention.RetentionConfig config_classes: List = ... def __init__(self) -> None: ... diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 7b0381c06a..b013a3918c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -24,6 +24,9 @@ class ExperimentalConfig(Config): def read_config(self, config: JsonDict, **kwargs): experimental = config.get("experimental_features") or {} + # Whether to enable experimental MSC1849 (aka relations) support + self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True) + # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 442f1b9ac0..001605c265 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -38,6 +38,7 @@ from .ratelimiting import RatelimitConfig from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig +from .retention import RetentionConfig from .room import RoomConfig from .room_directory import RoomDirectoryConfig from .saml2 import SAML2Config @@ -59,6 +60,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ModulesConfig, ServerConfig, + RetentionConfig, TlsConfig, FederationConfig, CacheConfig, diff --git a/synapse/config/retention.py b/synapse/config/retention.py new file mode 100644 index 0000000000..aed9bf458f --- /dev/null +++ b/synapse/config/retention.py @@ -0,0 +1,226 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import List, Optional + +import attr + +from synapse.config._base import Config, ConfigError + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RetentionPurgeJob: + """Object describing the configuration of the manhole""" + + interval: int + shortest_max_lifetime: Optional[int] + longest_max_lifetime: Optional[int] + + +class RetentionConfig(Config): + section = "retention" + + def read_config(self, config, **kwargs): + retention_config = config.get("retention") + if retention_config is None: + retention_config = {} + + self.retention_enabled = retention_config.get("enabled", False) + + retention_default_policy = retention_config.get("default_policy") + + if retention_default_policy is not None: + self.retention_default_min_lifetime = retention_default_policy.get( + "min_lifetime" + ) + if self.retention_default_min_lifetime is not None: + self.retention_default_min_lifetime = self.parse_duration( + self.retention_default_min_lifetime + ) + + self.retention_default_max_lifetime = retention_default_policy.get( + "max_lifetime" + ) + if self.retention_default_max_lifetime is not None: + self.retention_default_max_lifetime = self.parse_duration( + self.retention_default_max_lifetime + ) + + if ( + self.retention_default_min_lifetime is not None + and self.retention_default_max_lifetime is not None + and ( + self.retention_default_min_lifetime + > self.retention_default_max_lifetime + ) + ): + raise ConfigError( + "The default retention policy's 'min_lifetime' can not be greater" + " than its 'max_lifetime'" + ) + else: + self.retention_default_min_lifetime = None + self.retention_default_max_lifetime = None + + if self.retention_enabled: + logger.info( + "Message retention policies support enabled with the following default" + " policy: min_lifetime = %s ; max_lifetime = %s", + self.retention_default_min_lifetime, + self.retention_default_max_lifetime, + ) + + self.retention_allowed_lifetime_min = retention_config.get( + "allowed_lifetime_min" + ) + if self.retention_allowed_lifetime_min is not None: + self.retention_allowed_lifetime_min = self.parse_duration( + self.retention_allowed_lifetime_min + ) + + self.retention_allowed_lifetime_max = retention_config.get( + "allowed_lifetime_max" + ) + if self.retention_allowed_lifetime_max is not None: + self.retention_allowed_lifetime_max = self.parse_duration( + self.retention_allowed_lifetime_max + ) + + if ( + self.retention_allowed_lifetime_min is not None + and self.retention_allowed_lifetime_max is not None + and self.retention_allowed_lifetime_min + > self.retention_allowed_lifetime_max + ): + raise ConfigError( + "Invalid retention policy limits: 'allowed_lifetime_min' can not be" + " greater than 'allowed_lifetime_max'" + ) + + self.retention_purge_jobs: List[RetentionPurgeJob] = [] + for purge_job_config in retention_config.get("purge_jobs", []): + interval_config = purge_job_config.get("interval") + + if interval_config is None: + raise ConfigError( + "A retention policy's purge jobs configuration must have the" + " 'interval' key set." + ) + + interval = self.parse_duration(interval_config) + + shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") + + if shortest_max_lifetime is not None: + shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) + + longest_max_lifetime = purge_job_config.get("longest_max_lifetime") + + if longest_max_lifetime is not None: + longest_max_lifetime = self.parse_duration(longest_max_lifetime) + + if ( + shortest_max_lifetime is not None + and longest_max_lifetime is not None + and shortest_max_lifetime > longest_max_lifetime + ): + raise ConfigError( + "A retention policy's purge jobs configuration's" + " 'shortest_max_lifetime' value can not be greater than its" + " 'longest_max_lifetime' value." + ) + + self.retention_purge_jobs.append( + RetentionPurgeJob(interval, shortest_max_lifetime, longest_max_lifetime) + ) + + if not self.retention_purge_jobs: + self.retention_purge_jobs = [ + RetentionPurgeJob(self.parse_duration("1d"), None, None) + ] + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Message retention policy at the server level. + # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, and the state of a room contains a + # 'm.room.retention' event in its state which contains a 'min_lifetime' or a + # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy + # to these limits when running purge jobs. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a more frequent basis than for the rest of the rooms + # (e.g. every 12h), but not want that purge to be performed by a job that's + # iterating over every room it knows, which could be heavy on the server. + # + # If any purge job is configured, it is strongly recommended to have at least + # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' + # set, or one job without 'shortest_max_lifetime' and one job without + # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if + # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a + # room's policy to these values is done after the policies are retrieved from + # Synapse's database (which is done using the range specified in a purge job's + # configuration). + # + #purge_jobs: + # - longest_max_lifetime: 3d + # interval: 12h + # - shortest_max_lifetime: 3d + # interval: 1d + """ diff --git a/synapse/config/server.py b/synapse/config/server.py index 818b806357..ed094bdc44 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -225,15 +225,6 @@ class ManholeConfig: pub_key: Optional[Key] -@attr.s(slots=True, frozen=True, auto_attribs=True) -class RetentionConfig: - """Object describing the configuration of the manhole""" - - interval: int - shortest_max_lifetime: Optional[int] - longest_max_lifetime: Optional[int] - - @attr.s(frozen=True) class LimitRemoteRoomsConfig: enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False) @@ -376,11 +367,6 @@ class ServerConfig(Config): # (other than those sent by local server admins) self.block_non_admin_invites = config.get("block_non_admin_invites", False) - # Whether to enable experimental MSC1849 (aka relations) support - self.experimental_msc1849_support_enabled = config.get( - "experimental_msc1849_support_enabled", True - ) - # Options to control access by tracking MAU self.limit_usage_by_mau = config.get("limit_usage_by_mau", False) self.max_mau_value = 0 @@ -466,124 +452,6 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) - retention_config = config.get("retention") - if retention_config is None: - retention_config = {} - - self.retention_enabled = retention_config.get("enabled", False) - - retention_default_policy = retention_config.get("default_policy") - - if retention_default_policy is not None: - self.retention_default_min_lifetime = retention_default_policy.get( - "min_lifetime" - ) - if self.retention_default_min_lifetime is not None: - self.retention_default_min_lifetime = self.parse_duration( - self.retention_default_min_lifetime - ) - - self.retention_default_max_lifetime = retention_default_policy.get( - "max_lifetime" - ) - if self.retention_default_max_lifetime is not None: - self.retention_default_max_lifetime = self.parse_duration( - self.retention_default_max_lifetime - ) - - if ( - self.retention_default_min_lifetime is not None - and self.retention_default_max_lifetime is not None - and ( - self.retention_default_min_lifetime - > self.retention_default_max_lifetime - ) - ): - raise ConfigError( - "The default retention policy's 'min_lifetime' can not be greater" - " than its 'max_lifetime'" - ) - else: - self.retention_default_min_lifetime = None - self.retention_default_max_lifetime = None - - if self.retention_enabled: - logger.info( - "Message retention policies support enabled with the following default" - " policy: min_lifetime = %s ; max_lifetime = %s", - self.retention_default_min_lifetime, - self.retention_default_max_lifetime, - ) - - self.retention_allowed_lifetime_min = retention_config.get( - "allowed_lifetime_min" - ) - if self.retention_allowed_lifetime_min is not None: - self.retention_allowed_lifetime_min = self.parse_duration( - self.retention_allowed_lifetime_min - ) - - self.retention_allowed_lifetime_max = retention_config.get( - "allowed_lifetime_max" - ) - if self.retention_allowed_lifetime_max is not None: - self.retention_allowed_lifetime_max = self.parse_duration( - self.retention_allowed_lifetime_max - ) - - if ( - self.retention_allowed_lifetime_min is not None - and self.retention_allowed_lifetime_max is not None - and self.retention_allowed_lifetime_min - > self.retention_allowed_lifetime_max - ): - raise ConfigError( - "Invalid retention policy limits: 'allowed_lifetime_min' can not be" - " greater than 'allowed_lifetime_max'" - ) - - self.retention_purge_jobs: List[RetentionConfig] = [] - for purge_job_config in retention_config.get("purge_jobs", []): - interval_config = purge_job_config.get("interval") - - if interval_config is None: - raise ConfigError( - "A retention policy's purge jobs configuration must have the" - " 'interval' key set." - ) - - interval = self.parse_duration(interval_config) - - shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") - - if shortest_max_lifetime is not None: - shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) - - longest_max_lifetime = purge_job_config.get("longest_max_lifetime") - - if longest_max_lifetime is not None: - longest_max_lifetime = self.parse_duration(longest_max_lifetime) - - if ( - shortest_max_lifetime is not None - and longest_max_lifetime is not None - and shortest_max_lifetime > longest_max_lifetime - ): - raise ConfigError( - "A retention policy's purge jobs configuration's" - " 'shortest_max_lifetime' value can not be greater than its" - " 'longest_max_lifetime' value." - ) - - self.retention_purge_jobs.append( - RetentionConfig(interval, shortest_max_lifetime, longest_max_lifetime) - ) - - if not self.retention_purge_jobs: - self.retention_purge_jobs = [ - RetentionConfig(self.parse_duration("1d"), None, None) - ] - self.listeners = [parse_listener_def(x) for x in config.get("listeners", [])] # no_tls is not really supported any more, but let's grandfather it in @@ -1255,75 +1123,6 @@ class ServerConfig(Config): # #user_ips_max_age: 14d - # Message retention policy at the server level. - # - # Room admins and mods can define a retention period for their rooms using the - # 'm.room.retention' state event, and server admins can cap this period by setting - # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. - # - # If this feature is enabled, Synapse will regularly look for and purge events - # which are older than the room's maximum retention period. Synapse will also - # filter events received over federation so that events that should have been - # purged are ignored and not stored again. - # - retention: - # The message retention policies feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # Default retention policy. If set, Synapse will apply it to rooms that lack the - # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't - # matter much because Synapse doesn't take it into account yet. - # - #default_policy: - # min_lifetime: 1d - # max_lifetime: 1y - - # Retention policy limits. If set, and the state of a room contains a - # 'm.room.retention' event in its state which contains a 'min_lifetime' or a - # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy - # to these limits when running purge jobs. - # - #allowed_lifetime_min: 1d - #allowed_lifetime_max: 1y - - # Server admins can define the settings of the background jobs purging the - # events which lifetime has expired under the 'purge_jobs' section. - # - # If no configuration is provided, a single job will be set up to delete expired - # events in every room daily. - # - # Each job's configuration defines which range of message lifetimes the job - # takes care of. For example, if 'shortest_max_lifetime' is '2d' and - # 'longest_max_lifetime' is '3d', the job will handle purging expired events in - # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and - # lower than or equal to 3 days. Both the minimum and the maximum value of a - # range are optional, e.g. a job with no 'shortest_max_lifetime' and a - # 'longest_max_lifetime' of '3d' will handle every room with a retention policy - # which 'max_lifetime' is lower than or equal to three days. - # - # The rationale for this per-job configuration is that some rooms might have a - # retention policy with a low 'max_lifetime', where history needs to be purged - # of outdated messages on a more frequent basis than for the rest of the rooms - # (e.g. every 12h), but not want that purge to be performed by a job that's - # iterating over every room it knows, which could be heavy on the server. - # - # If any purge job is configured, it is strongly recommended to have at least - # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' - # set, or one job without 'shortest_max_lifetime' and one job without - # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if - # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a - # room's policy to these values is done after the policies are retrieved from - # Synapse's database (which is done using the range specified in a purge job's - # configuration). - # - #purge_jobs: - # - longest_max_lifetime: 3d - # interval: 12h - # - shortest_max_lifetime: 3d - # interval: 1d - # Inhibits the /requestToken endpoints from returning an error that might leak # information about whether an e-mail address is in use or not on this # homeserver. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 23bd24d963..3f3eba86a8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -385,9 +385,7 @@ class EventClientSerializer: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - self.experimental_msc1849_support_enabled = ( - hs.config.server.experimental_msc1849_support_enabled - ) + self._msc1849_enabled = hs.config.experimental.msc1849_enabled async def serialize_event( self, @@ -418,7 +416,7 @@ class EventClientSerializer: # we need to bundle in with the event. # Do not bundle relations if the event has been redacted if not event.internal_metadata.is_redacted() and ( - self.experimental_msc1849_support_enabled and bundle_aggregations + self._msc1849_enabled and bundle_aggregations ): annotations = await self.store.get_aggregation_groups_for_event(event_id) references = await self.store.get_relations_for_event( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 176e4dfdd4..60ff896386 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -86,19 +86,22 @@ class PaginationHandler: self._event_serializer = hs.get_event_client_serializer() self._retention_default_max_lifetime = ( - hs.config.server.retention_default_max_lifetime + hs.config.retention.retention_default_max_lifetime ) self._retention_allowed_lifetime_min = ( - hs.config.server.retention_allowed_lifetime_min + hs.config.retention.retention_allowed_lifetime_min ) self._retention_allowed_lifetime_max = ( - hs.config.server.retention_allowed_lifetime_max + hs.config.retention.retention_allowed_lifetime_max ) - if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled: + if ( + hs.config.worker.run_background_tasks + and hs.config.retention.retention_enabled + ): # Run the purge jobs described in the configuration file. - for job in hs.config.server.retention_purge_jobs: + for job in hs.config.retention.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) self.clock.looping_call( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index d69eaf80ce..835d7889cb 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore): # policy. if not ret: return { - "min_lifetime": self.config.server.retention_default_min_lifetime, - "max_lifetime": self.config.server.retention_default_max_lifetime, + "min_lifetime": self.config.retention.retention_default_min_lifetime, + "max_lifetime": self.config.retention.retention_default_max_lifetime, } row = ret[0] @@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore): # The default values will be None if no default policy has been defined, or if one # of the attributes is missing from the default policy. if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.server.retention_default_min_lifetime + row["min_lifetime"] = self.config.retention.retention_default_min_lifetime if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.server.retention_default_max_lifetime + row["max_lifetime"] = self.config.retention.retention_default_max_lifetime return row -- cgit 1.5.1 From e09be0c87a8c0da1381e7986cc940d1411705081 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 15 Oct 2021 15:53:05 +0100 Subject: Correctly exclude users when making a room public or private (#11075) Co-authored-by: Patrick Cloke --- changelog.d/11075.bugfix | 1 + synapse/handlers/user_directory.py | 11 ++- tests/handlers/test_user_directory.py | 142 +++++++++++++++++++++++++--------- tests/storage/test_user_directory.py | 77 ++++++++---------- 4 files changed, 148 insertions(+), 83 deletions(-) create mode 100644 changelog.d/11075.bugfix (limited to 'synapse') diff --git a/changelog.d/11075.bugfix b/changelog.d/11075.bugfix new file mode 100644 index 0000000000..9b24971c5a --- /dev/null +++ b/changelog.d/11075.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where users excluded from the user directory were added into the directory if they belonged to a room which became public or private. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 52b2de388f..99f23ed967 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -266,14 +266,17 @@ class UserDirectoryHandler(StateDeltasHandler): for user_id in users_in_room: await self.store.remove_user_who_share_room(user_id, room_id) - # Then, re-add them to the tables. + # Then, re-add all remote users and some local users to the tables. # NOTE: this is not the most efficient method, as _track_user_joined_room sets # up local_user -> other_user and other_user_whos_local -> local_user, # which when ran over an entire room, will result in the same values # being added multiple times. The batching upserts shouldn't make this # too bad, though. for user_id in users_in_room: - await self._track_user_joined_room(room_id, user_id) + if not self.is_mine_id( + user_id + ) or await self.store.should_include_local_user_in_dir(user_id): + await self._track_user_joined_room(room_id, user_id) async def _handle_room_membership_event( self, @@ -364,8 +367,8 @@ class UserDirectoryHandler(StateDeltasHandler): """Someone's just joined a room. Update `users_in_public_rooms` or `users_who_share_private_rooms` as appropriate. - The caller is responsible for ensuring that the given user is not excluded - from the user directory. + The caller is responsible for ensuring that the given user should be + included in the user directory. """ is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 0120b4688b..e0635c8898 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): tok=alice_token, ) - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) - in_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() + # The user directory should reflect the room memberships above. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) - self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)}) self.assertEqual( - set(in_public), {(alice, public), (bob, public), (alice, public2)} - ) - self.assertEqual( - self.user_dir_helper._compress_shared(in_private), + in_private, {(alice, bob, private), (bob, alice, private)}, ) @@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) self.assertEqual(set(in_public), {(user1, room), (user2, room)}) + def test_excludes_users_when_making_room_public(self) -> None: + # Create a regular user and a support user. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + # Make a public and private room containing Alice and the support user + public, initially_private = self._create_rooms_and_inject_memberships( + alice, alice_token, support + ) + self._check_only_one_user_in_directory(alice, public) + + # Alice makes the private room public. + self.helper.send_state( + initially_private, + "m.room.join_rules", + {"join_rule": "public"}, + tok=alice_token, + ) + + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice}) + self.assertEqual(in_public, {(alice, public), (alice, initially_private)}) + self.assertEqual(in_private, set()) + + def test_switching_from_private_to_public_to_private(self) -> None: + """Check we update the room sharing tables when switching a room + from private to public, then back again to private.""" + # Alice and Bob share a private room. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + room = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.helper.invite(room, alice, bob, tok=alice_token) + self.helper.join(room, bob, tok=bob_token) + + # The user directory should reflect this. + def check_user_dir_for_private_room() -> None: + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, set()) + self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)}) + + check_user_dir_for_private_room() + + # Alice makes the room public. + self.helper.send_state( + room, + "m.room.join_rules", + {"join_rule": "public"}, + tok=alice_token, + ) + + # The user directory should be updated accordingly + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, room), (bob, room)}) + self.assertEqual(in_private, set()) + + # Alice makes the room private. + self.helper.send_state( + room, + "m.room.join_rules", + {"join_rule": "invite"}, + tok=alice_token, + ) + + # The user directory should be updated accordingly + check_user_dir_for_private_room() + def _create_rooms_and_inject_memberships( self, creator: str, token: str, joiner: str ) -> Tuple[str, str]: @@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): return public_room, private_room def _check_only_one_user_in_directory(self, user: str, public: str) -> None: - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) - in_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) + """Check that the user directory DB tables show that: + - only one user is in the user directory + - they belong to exactly one public room + - they don't share a private room with anyone. + """ + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) self.assertEqual(users, {user}) - self.assertEqual(set(in_public), {(user, public)}) - self.assertEqual(in_private, []) + self.assertEqual(in_public, {(user, public)}) + self.assertEqual(in_private, set()) def test_handle_local_profile_change_with_support_user(self) -> None: support_user_id = "@support:test" @@ -581,11 +662,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # We get one search result when searching for user2 by user1. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -610,8 +688,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, set()) + self.assertEqual(public_users, set()) # User1 now gets no search results for any of the other users. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -645,11 +723,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # We get one search result when searching for user2 by user1. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -704,11 +779,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # Configure a spam checker. spam_checker = self.hs.get_spam_checker() @@ -740,8 +812,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) # No users share rooms - self.assertEqual(public_users, []) - self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) + self.assertEqual(public_users, set()) + self.assertEqual(shares_private, set()) # Despite not sharing a room, search_all_users means we get a search # result. diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index be3ed64f5e..37cf7bb232 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Set, Tuple +from typing import Any, Dict, Set, Tuple from unittest import mock from unittest.mock import Mock, patch @@ -42,18 +42,7 @@ class GetUserDirectoryTables: def __init__(self, store: DataStore): self.store = store - def _compress_shared( - self, shared: List[Dict[str, str]] - ) -> Set[Tuple[str, str, str]]: - """ - Compress a list of users who share rooms dicts to a list of tuples. - """ - r = set() - for i in shared: - r.add((i["user_id"], i["other_user_id"], i["room_id"])) - return r - - async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: + async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]: """Fetch the entire `users_in_public_rooms` table. Returns a list of tuples (user_id, room_id) where room_id is public and @@ -63,24 +52,27 @@ class GetUserDirectoryTables: "users_in_public_rooms", None, ("user_id", "room_id") ) - retval = [] + retval = set() for i in r: - retval.append((i["user_id"], i["room_id"])) + retval.add((i["user_id"], i["room_id"])) return retval - async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]: + async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]: """Fetch the entire `users_who_share_private_rooms` table. - Returns a dict containing "user_id", "other_user_id" and "room_id" keys. - The dicts can be flattened to Tuples with the `_compress_shared` method. - (This seems a little awkward---maybe we could clean this up.) + Returns a set of tuples (user_id, other_user_id, room_id) corresponding + to the rows of `users_who_share_private_rooms`. """ - return await self.store.db_pool.simple_select_list( + rows = await self.store.db_pool.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], ) + rv = set() + for row in rows: + rv.add((row["user_id"], row["other_user_id"], row["room_id"])) + return rv async def get_users_in_user_directory(self) -> Set[str]: """Fetch the set of users in the `user_directory` table. @@ -113,6 +105,16 @@ class GetUserDirectoryTables: for row in rows } + async def get_tables( + self, + ) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]: + """Multiple tests want to inspect these tables, so expose them together.""" + return ( + await self.get_users_in_user_directory(), + await self.get_users_in_public_rooms(), + await self.get_users_who_share_private_rooms(), + ) + class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): """Ensure that rebuilding the directory writes the correct data to the DB. @@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): ) # Nothing updated yet - self.assertEqual(shares_private, []) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, set()) + self.assertEqual(public_users, set()) # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False @@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): # Do the initial population of the user directory via the background update self._purge_and_rebuild_user_dir() - shares_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - public_users = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), {(u1, room), (u2, room)}) - + self.assertEqual(in_public, {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u3, private_room), (u3, u1, private_room)}, - ) - + self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)}) # All three should have entries in the directory - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) self.assertEqual(users, {u1, u2, u3}) # The next four tests (test_population_excludes_*) all set up @@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): self, normal_user: str, public_room: str, private_room: str ) -> None: # After rebuilding the directory, we should only see the normal user. - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - self.assertEqual(users, {normal_user}) - in_public_rooms = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) - self.assertEqual(set(in_public_rooms), {(normal_user, public_room)}) - in_private_rooms = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - self.assertEqual(in_private_rooms, []) + self.assertEqual(users, {normal_user}) + self.assertEqual(in_public, {(normal_user, public_room)}) + self.assertEqual(in_private, set()) def test_population_excludes_support_user(self) -> None: # Create a normal and support user. -- cgit 1.5.1 From 37b845dabc687b1e0d4bc84bf5933db10db641d5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 18 Oct 2021 14:20:04 +0100 Subject: Don't remove local users from dir when the leave their last room (#11103) --- changelog.d/11103.bugfix | 1 + synapse/handlers/user_directory.py | 13 +++++---- tests/handlers/test_user_directory.py | 50 +++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 changelog.d/11103.bugfix (limited to 'synapse') diff --git a/changelog.d/11103.bugfix b/changelog.d/11103.bugfix new file mode 100644 index 0000000000..3498f04a45 --- /dev/null +++ b/changelog.d/11103.bugfix @@ -0,0 +1 @@ +Fix local users who left all their rooms being removed from the user directory, even if the "search_all_users" config option was enabled. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 99f23ed967..991fee7e58 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -415,16 +415,19 @@ class UserDirectoryHandler(StateDeltasHandler): room_id: The room ID that user left or stopped being public that user_id """ - logger.debug("Removing user %r", user_id) + logger.debug("Removing user %r from room %r", user_id, room_id) # Remove user from sharing tables await self.store.remove_user_who_share_room(user_id, room_id) - # Are they still in any rooms? If not, remove them entirely. - rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id) + # Additionally, if they're a remote user and we're no longer joined + # to any rooms they're in, remove them from the user directory. + if not self.is_mine_id(user_id): + rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id) - if len(rooms_user_is_in) == 0: - await self.store.remove_from_user_dir(user_id) + if len(rooms_user_is_in) == 0: + logger.debug("Removing user %r from directory", user_id) + await self.store.remove_from_user_dir(user_id) async def _handle_possible_remote_profile_change( self, diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index e0635c8898..b9ad92b977 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -914,6 +914,56 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.hs.get_storage().persistence.persist_event(event, context) ) + def test_local_user_leaving_room_remains_in_user_directory(self) -> None: + """We've chosen to simplify the user directory's implementation by + always including local users. Ensure this invariant is maintained when + a local user + - leaves a room, and + - leaves the last room they're in which is visible to this server. + + This is user-visible if the "search_all_users" config option is on: the + local user who left a room would no longer be searchable if this test fails! + """ + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + # Alice makes two public rooms, which Bob joins. + room1 = self.helper.create_room_as(alice, is_public=True, tok=alice_token) + room2 = self.helper.create_room_as(alice, is_public=True, tok=alice_token) + self.helper.join(room1, bob, tok=bob_token) + self.helper.join(room2, bob, tok=bob_token) + + # The user directory tables are updated. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual( + in_public, {(alice, room1), (alice, room2), (bob, room1), (bob, room2)} + ) + self.assertEqual(in_private, set()) + + # Alice leaves one room. She should still be in the directory. + self.helper.leave(room1, alice, tok=alice_token) + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, room2), (bob, room1), (bob, room2)}) + self.assertEqual(in_private, set()) + + # Alice leaves the other. She should still be in the directory. + self.helper.leave(room2, alice, tok=alice_token) + self.wait_for_background_updates() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(bob, room1), (bob, room2)}) + self.assertEqual(in_private, set()) + class TestUserDirSearchDisabled(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From 7d70582eb0e0b9656e64e827eca6dfc2533b8ae1 Mon Sep 17 00:00:00 2001 From: Hillery Shay Date: Mon, 18 Oct 2021 08:14:12 -0700 Subject: Fix broken export-data admin command and add a test for it to CI (#11078) Fix broken export-data admin command and add a test for it to CI --- .ci/scripts/test_export_data_command.sh | 57 +++++++++++++++++++++++++++++++++ .github/workflows/tests.yml | 29 +++++++++++++++++ changelog.d/11078.bugfix | 1 + synapse/app/admin_cmd.py | 14 ++++---- 4 files changed, 93 insertions(+), 8 deletions(-) create mode 100755 .ci/scripts/test_export_data_command.sh create mode 100644 changelog.d/11078.bugfix (limited to 'synapse') diff --git a/.ci/scripts/test_export_data_command.sh b/.ci/scripts/test_export_data_command.sh new file mode 100755 index 0000000000..75f5811d10 --- /dev/null +++ b/.ci/scripts/test_export_data_command.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +# Test for the export-data admin command against sqlite and postgres + +set -xe +cd `dirname $0`/../.. + +echo "--- Install dependencies" + +# Install dependencies for this test. +pip install psycopg2 + +# Install Synapse itself. This won't update any libraries. +pip install -e . + +echo "--- Generate the signing key" + +# Generate the server's signing key. +python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml + +echo "--- Prepare test database" + +# Make sure the SQLite3 database is using the latest schema and has no pending background update. +scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates + +# Run the export-data command on the sqlite test database +python -m synapse.app.admin_cmd -c .ci/sqlite-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \ +--output-directory /tmp/export_data + +# Test that the output directory exists and contains the rooms directory +dir="/tmp/export_data/rooms" +if [ -d "$dir" ]; then + echo "Command successful, this test passes" +else + echo "No output directories found, the command fails against a sqlite database." + exit 1 +fi + +# Create the PostgreSQL database. +.ci/scripts/postgres_exec.py "CREATE DATABASE synapse" + +# Port the SQLite databse to postgres so we can check command works against postgres +echo "+++ Port SQLite3 databse to postgres" +scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml + +# Run the export-data command on postgres database +python -m synapse.app.admin_cmd -c .ci/postgres-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \ +--output-directory /tmp/export_data2 + +# Test that the output directory exists and contains the rooms directory +dir2="/tmp/export_data2/rooms" +if [ -d "$dir2" ]; then + echo "Command successful, this test passes" +else + echo "No output directories found, the command fails against a postgres database." + exit 1 +fi diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9e302bf446..8d7e8cafd9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -253,6 +253,35 @@ jobs: /logs/results.tap /logs/**/*.log* + export-data: + if: ${{ !failure() && !cancelled() }} # Allow previous steps to be skipped, but not fail + needs: [linting-done, portdb] + runs-on: ubuntu-latest + env: + TOP: ${{ github.workspace }} + + services: + postgres: + image: postgres + ports: + - 5432:5432 + env: + POSTGRES_PASSWORD: "postgres" + POSTGRES_INITDB_ARGS: "--lc-collate C --lc-ctype C --encoding UTF8" + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@v2 + - run: sudo apt-get -qq install xmlsec1 + - uses: actions/setup-python@v2 + with: + python-version: "3.9" + - run: .ci/scripts/test_export_data_command.sh + portdb: if: ${{ !failure() && !cancelled() }} # Allow previous steps to be skipped, but not fail needs: linting-done diff --git a/changelog.d/11078.bugfix b/changelog.d/11078.bugfix new file mode 100644 index 0000000000..cc813babe4 --- /dev/null +++ b/changelog.d/11078.bugfix @@ -0,0 +1 @@ +Fix broken export-data admin command and add test script checking the command to CI. \ No newline at end of file diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 13d20af457..b156b93bf3 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -39,6 +39,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.server import HomeServer +from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -58,6 +59,7 @@ class AdminCmdSlavedStore( SlavedEventStore, SlavedClientIpStore, BaseSlavedStore, + RoomWorkerStore, ): pass @@ -185,11 +187,7 @@ def start(config_options): # a full worker config. config.worker.worker_app = "synapse.app.admin_cmd" - if ( - not config.worker.worker_daemonize - and not config.worker.worker_log_file - and not config.worker.worker_log_config - ): + if not config.worker.worker_daemonize and not config.worker.worker_log_config: # Since we're meant to be run as a "command" let's not redirect stdio # unless we've actually set log config. config.logging.no_redirect_stdio = True @@ -198,9 +196,9 @@ def start(config_options): config.server.update_user_directory = False config.worker.run_background_tasks = False config.worker.start_pushers = False - config.pusher_shard_config.instances = [] + config.worker.pusher_shard_config.instances = [] config.worker.send_federation = False - config.federation_shard_config.instances = [] + config.worker.federation_shard_config.instances = [] synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts @@ -221,7 +219,7 @@ def start(config_options): async def run(): with LoggingContext("command"): - _base.start(ss) + await _base.start(ss) await args.func(ss, args) _base.start_worker_reactor( -- cgit 1.5.1 From e8f24b6c3566f3fc902b9ec0d6c483821ac37cb7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 18 Oct 2021 18:17:15 +0200 Subject: `_run_push_actions_and_persist_event`: handle no min_depth (#11014) Make sure that we correctly handle rooms where we do not yet have a `min_depth`, and also add some comments and logging. --- changelog.d/11014.misc | 1 + synapse/handlers/federation_event.py | 28 ++++++++++++++-------- synapse/storage/databases/main/event_federation.py | 2 +- 3 files changed, 20 insertions(+), 11 deletions(-) create mode 100644 changelog.d/11014.misc (limited to 'synapse') diff --git a/changelog.d/11014.misc b/changelog.d/11014.misc new file mode 100644 index 0000000000..4b99ea354f --- /dev/null +++ b/changelog.d/11014.misc @@ -0,0 +1 @@ +Add some extra logging to the event persistence code. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 0e455678aa..b8ce0006bb 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -214,7 +214,7 @@ class FederationEventHandler: if missing_prevs: # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) + min_depth = await self._store.get_min_depth(pdu.room_id) logger.debug("min_depth: %d", min_depth) if min_depth is not None and pdu.depth > min_depth: @@ -1696,16 +1696,27 @@ class FederationEventHandler: # persist_events_and_notify directly.) assert not event.internal_metadata.outlier - try: - if ( - not backfilled - and not context.rejected - and (await self._store.get_min_depth(event.room_id)) <= event.depth - ): + if not backfilled and not context.rejected: + min_depth = await self._store.get_min_depth(event.room_id) + if min_depth is None or min_depth > event.depth: + # XXX richvdh 2021/10/07: I don't really understand what this + # condition is doing. I think it's trying not to send pushes + # for events that predate our join - but that's not really what + # min_depth means, and anyway ancient events are a more general + # problem. + # + # for now I'm just going to log about it. + logger.info( + "Skipping push actions for old event with depth %s < %s", + event.depth, + min_depth, + ) + else: await self._action_generator.handle_push_actions_for_event( event, context ) + try: await self.persist_events_and_notify( event.room_id, [(event, context)], backfilled=backfilled ) @@ -1837,6 +1848,3 @@ class FederationEventHandler: len(ev.auth_event_ids()), ) raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") - - async def get_min_depth_for_context(self, context: str) -> int: - return await self._store.get_min_depth(context) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 10184d6ae7..ba9f71a230 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -906,7 +906,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas desc="get_latest_event_ids_in_room", ) - async def get_min_depth(self, room_id: str) -> int: + async def get_min_depth(self, room_id: str) -> Optional[int]: """For the given room, get the minimum depth we have seen for it.""" return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id -- cgit 1.5.1 From a5d2ea3d08f780cdb746ea7101824513a9ec9610 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 18 Oct 2021 19:28:30 +0200 Subject: Check *all* auth events for room id and rejection (#11009) This fixes a bug where we would accept an event whose `auth_events` include rejected events, if the rejected event was shadowed by another `auth_event` with same `(type, state_key)`. The approach is to pass a list of auth events into `check_auth_rules_for_event` instead of a dict, which of course means updating the call sites. This is an extension of #10956. --- changelog.d/11009.bugfix | 1 + synapse/event_auth.py | 33 ++++----- synapse/handlers/event_auth.py | 3 +- synapse/handlers/federation.py | 10 +-- synapse/handlers/federation_event.py | 16 ++-- synapse/state/v1.py | 4 +- synapse/state/v2.py | 2 +- tests/test_event_auth.py | 138 +++++++++++++++++++++++------------ 8 files changed, 122 insertions(+), 85 deletions(-) create mode 100644 changelog.d/11009.bugfix (limited to 'synapse') diff --git a/changelog.d/11009.bugfix b/changelog.d/11009.bugfix new file mode 100644 index 0000000000..13b8e5983b --- /dev/null +++ b/changelog.d/11009.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which meant that events received over federation were sometimes incorrectly accepted into the room state. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index ca0293a3dc..e885961698 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -113,7 +113,7 @@ def validate_event_for_room_version( def check_auth_rules_for_event( - room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase] + room_version_obj: RoomVersion, event: EventBase, auth_events: Iterable[EventBase] ) -> None: """Check that an event complies with the auth rules @@ -137,8 +137,6 @@ def check_auth_rules_for_event( Raises: AuthError if the checks fail """ - assert isinstance(auth_events, dict) - # We need to ensure that the auth events are actually for the same room, to # stop people from using powers they've been granted in other rooms for # example. @@ -147,7 +145,7 @@ def check_auth_rules_for_event( # the state res algorithm isn't silly enough to give us events from different rooms. # Still, it's easier to do it anyway. room_id = event.room_id - for auth_event in auth_events.values(): + for auth_event in auth_events: if auth_event.room_id != room_id: raise AuthError( 403, @@ -186,8 +184,10 @@ def check_auth_rules_for_event( logger.debug("Allowing! %s", event) return + auth_dict = {(e.type, e.state_key): e for e in auth_events} + # 3. If event does not have a m.room.create in its auth_events, reject. - creation_event = auth_events.get((EventTypes.Create, ""), None) + creation_event = auth_dict.get((EventTypes.Create, ""), None) if not creation_event: raise AuthError(403, "No create event in auth events") @@ -195,7 +195,7 @@ def check_auth_rules_for_event( creating_domain = get_domain_from_id(event.room_id) originating_domain = get_domain_from_id(event.sender) if creating_domain != originating_domain: - if not _can_federate(event, auth_events): + if not _can_federate(event, auth_dict): raise AuthError(403, "This room has been marked as unfederatable.") # 4. If type is m.room.aliases @@ -217,23 +217,20 @@ def check_auth_rules_for_event( logger.debug("Allowing! %s", event) return - if logger.isEnabledFor(logging.DEBUG): - logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) - # 5. If type is m.room.membership if event.type == EventTypes.Member: - _is_membership_change_allowed(room_version_obj, event, auth_events) + _is_membership_change_allowed(room_version_obj, event, auth_dict) logger.debug("Allowing! %s", event) return - _check_event_sender_in_room(event, auth_events) + _check_event_sender_in_room(event, auth_dict) # Special case to allow m.room.third_party_invite events wherever # a user is allowed to issue invites. Fixes # https://github.com/vector-im/vector-web/issues/1208 hopefully if event.type == EventTypes.ThirdPartyInvite: - user_level = get_user_power_level(event.user_id, auth_events) - invite_level = get_named_level(auth_events, "invite", 0) + user_level = get_user_power_level(event.user_id, auth_dict) + invite_level = get_named_level(auth_dict, "invite", 0) if user_level < invite_level: raise AuthError(403, "You don't have permission to invite users") @@ -241,20 +238,20 @@ def check_auth_rules_for_event( logger.debug("Allowing! %s", event) return - _can_send_event(event, auth_events) + _can_send_event(event, auth_dict) if event.type == EventTypes.PowerLevels: - _check_power_levels(room_version_obj, event, auth_events) + _check_power_levels(room_version_obj, event, auth_dict) if event.type == EventTypes.Redaction: - check_redaction(room_version_obj, event, auth_events) + check_redaction(room_version_obj, event, auth_dict) if ( event.type == EventTypes.MSC2716_INSERTION or event.type == EventTypes.MSC2716_BATCH or event.type == EventTypes.MSC2716_MARKER ): - check_historical(room_version_obj, event, auth_events) + check_historical(room_version_obj, event, auth_dict) logger.debug("Allowing! %s", event) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index d089c56286..365063ebdf 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -55,8 +55,7 @@ class EventAuthHandler: """Check an event passes the auth rules at its own auth events""" auth_event_ids = event.auth_event_ids() auth_events_by_id = await self._store.get_events(auth_event_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()} - check_auth_rules_for_event(room_version_obj, event, auth_events) + check_auth_rules_for_event(room_version_obj, event, auth_events_by_id.values()) def compute_auth_events( self, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e072efad16..69f1ef3afa 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1167,13 +1167,11 @@ class FederationHandler: logger.info("Failed to find auth event %r", e_id) for e in itertools.chain(auth_events, state, [event]): - auth_for_e = { - (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] - for e_id in e.auth_event_ids() - if e_id in event_map - } + auth_for_e = [ + event_map[e_id] for e_id in e.auth_event_ids() if e_id in event_map + ] if create_event: - auth_for_e[(EventTypes.Create, "")] = create_event + auth_for_e.append(create_event) try: validate_event_for_room_version(room_version, e) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index b8ce0006bb..1705432d7c 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1203,7 +1203,7 @@ class FederationEventHandler: def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: with nested_logging_context(suffix=event.event_id): - auth = {} + auth = [] for auth_event_id in event.auth_event_ids(): ae = persisted_events.get(auth_event_id) if not ae: @@ -1216,7 +1216,7 @@ class FederationEventHandler: # exist, which means it is premature to reject `event`. Instead we # just ignore it for now. return None - auth[(ae.type, ae.state_key)] = ae + auth.append(ae) context = EventContext.for_outlier() try: @@ -1305,7 +1305,9 @@ class FederationEventHandler: auth_events_for_auth = calculated_auth_event_map try: - check_auth_rules_for_event(room_version_obj, event, auth_events_for_auth) + check_auth_rules_for_event( + room_version_obj, event, auth_events_for_auth.values() + ) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR @@ -1403,11 +1405,9 @@ class FederationEventHandler: current_state_ids_list = [ e for k, e in current_state_ids.items() if k in auth_types ] - - auth_events_map = await self._store.get_events(current_state_ids_list) - current_auth_events = { - (e.type, e.state_key): e for e in auth_events_map.values() - } + current_auth_events = await self._store.get_events_as_list( + current_state_ids_list + ) try: check_auth_rules_for_event(room_version_obj, event, current_auth_events) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index ffe6207a3c..6edadea550 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -332,7 +332,7 @@ def _resolve_auth_events( event_auth.check_auth_rules_for_event( RoomVersions.V1, event, - auth_events, + auth_events.values(), ) prev_event = event except AuthError: @@ -350,7 +350,7 @@ def _resolve_normal_events( event_auth.check_auth_rules_for_event( RoomVersions.V1, event, - auth_events, + auth_events.values(), ) return event except AuthError: diff --git a/synapse/state/v2.py b/synapse/state/v2.py index bd18eefd58..c618df2fde 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -549,7 +549,7 @@ async def _iterative_auth_checks( event_auth.check_auth_rules_for_event( room_version, event, - auth_events, + auth_events.values(), ) resolved_state[(event.type, event.state_key)] = event_id diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index cf407c51cf..e2c506e5a4 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -24,6 +24,47 @@ from synapse.types import JsonDict, get_domain_from_id class EventAuthTestCase(unittest.TestCase): + def test_rejected_auth_events(self): + """ + Events that refer to rejected events in their auth events are rejected + """ + creator = "@creator:example.com" + auth_events = [ + _create_event(creator), + _join_event(creator), + ] + + # creator should be able to send state + event_auth.check_auth_rules_for_event( + RoomVersions.V9, + _random_state_event(creator), + auth_events, + ) + + # ... but a rejected join_rules event should cause it to be rejected + rejected_join_rules = _join_rules_event(creator, "public") + rejected_join_rules.rejected_reason = "stinky" + auth_events.append(rejected_join_rules) + + self.assertRaises( + AuthError, + event_auth.check_auth_rules_for_event, + RoomVersions.V9, + _random_state_event(creator), + auth_events, + ) + + # ... even if there is *also* a good join rules + auth_events.append(_join_rules_event(creator, "public")) + + self.assertRaises( + AuthError, + event_auth.check_auth_rules_for_event, + RoomVersions.V9, + _random_state_event(creator), + auth_events, + ) + def test_random_users_cannot_send_state_before_first_pl(self): """ Check that, before the first PL lands, the creator is the only user @@ -31,11 +72,11 @@ class EventAuthTestCase(unittest.TestCase): """ creator = "@creator:example.com" joiner = "@joiner:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.member", joiner): _join_event(joiner), - } + auth_events = [ + _create_event(creator), + _join_event(creator), + _join_event(joiner), + ] # creator should be able to send state event_auth.check_auth_rules_for_event( @@ -62,15 +103,15 @@ class EventAuthTestCase(unittest.TestCase): pleb = "@joiner:example.com" king = "@joiner2:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.power_levels", ""): _power_levels_event( + auth_events = [ + _create_event(creator), + _join_event(creator), + _power_levels_event( creator, {"state_default": "30", "users": {pleb: "29", king: "30"}} ), - ("m.room.member", pleb): _join_event(pleb), - ("m.room.member", king): _join_event(king), - } + _join_event(pleb), + _join_event(king), + ] # pleb should not be able to send state self.assertRaises( @@ -92,10 +133,10 @@ class EventAuthTestCase(unittest.TestCase): """Alias events have special behavior up through room version 6.""" creator = "@creator:example.com" other = "@other:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - } + auth_events = [ + _create_event(creator), + _join_event(creator), + ] # creator should be able to send aliases event_auth.check_auth_rules_for_event( @@ -131,10 +172,10 @@ class EventAuthTestCase(unittest.TestCase): """After MSC2432, alias events have no special behavior.""" creator = "@creator:example.com" other = "@other:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - } + auth_events = [ + _create_event(creator), + _join_event(creator), + ] # creator should be able to send aliases event_auth.check_auth_rules_for_event( @@ -170,14 +211,14 @@ class EventAuthTestCase(unittest.TestCase): creator = "@creator:example.com" pleb = "@joiner:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.power_levels", ""): _power_levels_event( + auth_events = [ + _create_event(creator), + _join_event(creator), + _power_levels_event( creator, {"state_default": "30", "users": {pleb: "30"}} ), - ("m.room.member", pleb): _join_event(pleb), - } + _join_event(pleb), + ] # pleb should be able to modify the notifications power level. event_auth.check_auth_rules_for_event( @@ -211,7 +252,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user cannot be force-joined to a room. @@ -219,7 +260,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), - auth_events, + auth_events.values(), ) # Banned should be rejected. @@ -228,7 +269,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user who left can re-join. @@ -236,7 +277,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can send a join if they're in the room. @@ -244,7 +285,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can accept an invite. @@ -254,7 +295,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) def test_join_rules_invite(self): @@ -275,7 +316,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user cannot be force-joined to a room. @@ -283,7 +324,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), - auth_events, + auth_events.values(), ) # Banned should be rejected. @@ -292,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user who left cannot re-join. @@ -301,7 +342,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can send a join if they're in the room. @@ -309,7 +350,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can accept an invite. @@ -319,7 +360,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) def test_join_rules_msc3083_restricted(self): @@ -347,7 +388,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A properly formatted join event should work. @@ -360,7 +401,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, - auth_events, + auth_events.values(), ) # A join issued by a specific user works (i.e. the power level checks @@ -380,7 +421,7 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@inviter:foo.test" }, ), - pl_auth_events, + pl_auth_events.values(), ) # A join which is missing an authorised server is rejected. @@ -388,7 +429,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), - auth_events, + auth_events.values(), ) # An join authorised by a user who is not in the room is rejected. @@ -405,7 +446,7 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@other:example.com" }, ), - auth_events, + auth_events.values(), ) # A user cannot be force-joined to a room. (This uses an event which @@ -421,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@inviter:foo.test" }, ), - auth_events, + auth_events.values(), ) # Banned should be rejected. @@ -430,7 +471,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, - auth_events, + auth_events.values(), ) # A user who left can re-join. @@ -438,7 +479,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, - auth_events, + auth_events.values(), ) # A user can send a join if they're in the room. (This doesn't need to @@ -447,7 +488,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can accept an invite. (This doesn't need to be authorised since @@ -458,7 +499,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), - auth_events, + auth_events.values(), ) @@ -473,6 +514,7 @@ def _create_event(user_id: str) -> EventBase: "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), "type": "m.room.create", + "state_key": "", "sender": user_id, "content": {"creator": user_id}, } -- cgit 1.5.1 From cc33d9eee205ab57ce562ac410c8912c14343134 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 18 Oct 2021 19:29:37 +0200 Subject: Check auth on received events' auth_events (#11001) Currently, when we receive an event whose auth_events differ from those we expect, we state-resolve between the two state sets, and check that the event passes auth based on the resolved state. This means that it's possible for us to accept events which don't pass auth at their declared auth_events (or where the auth events themselves were rejected), leading to problems down the line like #10083. This change means we will: * ignore any events where we cannot find the auth events * reject any events whose auth events were rejected * reject any events which do not pass auth at their declared auth_events. Together with a whole raft of previous work, this is a partial fix to #9595. Fixes #6643. Based on #11009. --- changelog.d/11001.bugfix | 1 + synapse/handlers/federation_event.py | 99 +++++++++++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11001.bugfix (limited to 'synapse') diff --git a/changelog.d/11001.bugfix b/changelog.d/11001.bugfix new file mode 100644 index 0000000000..f51ffb3481 --- /dev/null +++ b/changelog.d/11001.bugfix @@ -0,0 +1 @@ + Fix a long-standing bug which meant that events received over federation were sometimes incorrectly accepted into the room state. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 1705432d7c..af2c88394d 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1256,6 +1256,10 @@ class FederationEventHandler: Returns: The updated context object. + + Raises: + AuthError if we were unable to find copies of the event's auth events. + (Most other failures just cause us to set `context.rejected`.) """ # This method should only be used for non-outliers assert not event.internal_metadata.outlier @@ -1272,7 +1276,26 @@ class FederationEventHandler: context.rejected = RejectedReason.AUTH_ERROR return context - # calculate what the auth events *should* be, to use as a basis for auth. + # next, check that we have all of the event's auth events. + # + # Note that this can raise AuthError, which we want to propagate to the + # caller rather than swallow with `context.rejected` (since we cannot be + # certain that there is a permanent problem with the event). + claimed_auth_events = await self._load_or_fetch_auth_events_for_event( + origin, event + ) + + # ... and check that the event passes auth at those auth events. + try: + check_auth_rules_for_event(room_version_obj, event, claimed_auth_events) + except AuthError as e: + logger.warning( + "While checking auth of %r against auth_events: %s", event, e + ) + context.rejected = RejectedReason.AUTH_ERROR + return context + + # now check auth against what we think the auth events *should* be. prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True @@ -1472,6 +1495,9 @@ class FederationEventHandler: # if we have missing events, we need to fetch those events from somewhere. # # we start by checking if they are in the store, and then try calling /event_auth/. + # + # TODO: this code is now redundant, since it should be impossible for us to + # get here without already having the auth events. if missing_auth: have_events = await self._store.have_seen_events( event.room_id, missing_auth @@ -1575,7 +1601,7 @@ class FederationEventHandler: logger.info( "After state res: updating auth_events with new state %s", { - (d.type, d.state_key): d.event_id + d for d in new_state.values() if auth_events.get((d.type, d.state_key)) != d }, @@ -1589,6 +1615,75 @@ class FederationEventHandler: return context, auth_events + async def _load_or_fetch_auth_events_for_event( + self, destination: str, event: EventBase + ) -> Collection[EventBase]: + """Fetch this event's auth_events, from database or remote + + Loads any of the auth_events that we already have from the database/cache. If + there are any that are missing, calls /event_auth to get the complete auth + chain for the event (and then attempts to load the auth_events again). + + If any of the auth_events cannot be found, raises an AuthError. This can happen + for a number of reasons; eg: the events don't exist, or we were unable to talk + to `destination`, or we couldn't validate the signature on the event (which + in turn has multiple potential causes). + + Args: + destination: where to send the /event_auth request. Typically the server + that sent us `event` in the first place. + event: the event whose auth_events we want + + Returns: + all of the events in `event.auth_events`, after deduplication + + Raises: + AuthError if we were unable to fetch the auth_events for any reason. + """ + event_auth_event_ids = set(event.auth_event_ids()) + event_auth_events = await self._store.get_events( + event_auth_event_ids, allow_rejected=True + ) + missing_auth_event_ids = event_auth_event_ids.difference( + event_auth_events.keys() + ) + if not missing_auth_event_ids: + return event_auth_events.values() + + logger.info( + "Event %s refers to unknown auth events %s: fetching auth chain", + event, + missing_auth_event_ids, + ) + try: + await self._get_remote_auth_chain_for_event( + destination, event.room_id, event.event_id + ) + except Exception as e: + logger.warning("Failed to get auth chain for %s: %s", event, e) + # in this case, it's very likely we still won't have all the auth + # events - but we pick that up below. + + # try to fetch the auth events we missed list time. + extra_auth_events = await self._store.get_events( + missing_auth_event_ids, allow_rejected=True + ) + missing_auth_event_ids.difference_update(extra_auth_events.keys()) + event_auth_events.update(extra_auth_events) + if not missing_auth_event_ids: + return event_auth_events.values() + + # we still don't have all the auth events. + logger.warning( + "Missing auth events for %s: %s", + event, + shortstr(missing_auth_event_ids), + ) + # the fact we can't find the auth event doesn't mean it doesn't + # exist, which means it is premature to store `event` as rejected. + # instead we raise an AuthError, which will make the caller ignore it. + raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found") + async def _get_remote_auth_chain_for_event( self, destination: str, room_id: str, event_id: str ) -> None: -- cgit 1.5.1 From 3ab55d43bd66b377c1ed94a40931eba98dd07b01 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 18 Oct 2021 15:01:10 -0400 Subject: Add missing type hints to synapse.api. (#11109) * Convert UserPresenceState to attrs. * Remove args/kwargs from error classes and explicitly pass msg/errorcode. --- changelog.d/11109.misc | 1 + mypy.ini | 3 ++ synapse/api/auth.py | 14 ++++-- synapse/api/errors.py | 69 +++++++++----------------- synapse/api/filtering.py | 18 +++---- synapse/api/presence.py | 51 ++++++++++--------- synapse/api/ratelimiting.py | 4 +- synapse/api/urls.py | 13 ++--- synapse/handlers/presence.py | 2 +- synapse/storage/databases/main/registration.py | 8 +-- 10 files changed, 84 insertions(+), 99 deletions(-) create mode 100644 changelog.d/11109.misc (limited to 'synapse') diff --git a/changelog.d/11109.misc b/changelog.d/11109.misc new file mode 100644 index 0000000000..d83936ccc4 --- /dev/null +++ b/changelog.d/11109.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.api` module. diff --git a/mypy.ini b/mypy.ini index cb4489eb37..14d8bb8eaf 100644 --- a/mypy.ini +++ b/mypy.ini @@ -100,6 +100,9 @@ files = tests/util/test_itertools.py, tests/util/test_stream_change_cache.py +[mypy-synapse.api.*] +disallow_untyped_defs = True + [mypy-synapse.events.*] disallow_untyped_defs = True diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e6ca9232ee..44883c6663 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -245,7 +245,7 @@ class Auth: async def validate_appservice_can_control_user_id( self, app_service: ApplicationService, user_id: str - ): + ) -> None: """Validates that the app service is allowed to control the given user. @@ -618,5 +618,13 @@ class Auth: % (user_id, room_id), ) - async def check_auth_blocking(self, *args, **kwargs) -> None: - await self._auth_blocking.check_auth_blocking(*args, **kwargs) + async def check_auth_blocking( + self, + user_id: Optional[str] = None, + threepid: Optional[dict] = None, + user_type: Optional[str] = None, + requester: Optional[Requester] = None, + ) -> None: + await self._auth_blocking.check_auth_blocking( + user_id=user_id, threepid=threepid, user_type=user_type, requester=requester + ) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 9480f448d7..685d1c25cf 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -18,7 +18,7 @@ import logging import typing from http import HTTPStatus -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from twisted.web import http @@ -143,7 +143,7 @@ class SynapseError(CodeMessageException): super().__init__(code, msg) self.errcode = errcode - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode) @@ -175,7 +175,7 @@ class ProxiedRequestError(SynapseError): else: self._additional_fields = dict(additional_fields) - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, **self._additional_fields) @@ -196,7 +196,7 @@ class ConsentNotGivenError(SynapseError): ) self._consent_uri = consent_uri - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri) @@ -262,14 +262,10 @@ class InteractiveAuthIncompleteError(Exception): class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.UNRECOGNIZED - if len(args) == 0: - message = "Unrecognized request" - else: - message = args[0] - super().__init__(400, message, **kwargs) + def __init__( + self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED + ): + super().__init__(400, msg, errcode) class NotFoundError(SynapseError): @@ -284,10 +280,8 @@ class AuthError(SynapseError): other poorly-defined times. """ - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.FORBIDDEN - super().__init__(*args, **kwargs) + def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN): + super().__init__(code, msg, errcode) class InvalidClientCredentialsError(SynapseError): @@ -321,7 +315,7 @@ class InvalidClientTokenError(InvalidClientCredentialsError): super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN") self._soft_logout = soft_logout - def error_dict(self): + def error_dict(self) -> "JsonDict": d = super().error_dict() d["soft_logout"] = self._soft_logout return d @@ -345,7 +339,7 @@ class ResourceLimitError(SynapseError): self.limit_type = limit_type super().__init__(code, msg, errcode=errcode) - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error( self.msg, self.errcode, @@ -357,32 +351,17 @@ class ResourceLimitError(SynapseError): class EventSizeError(SynapseError): """An error raised when an event is too big.""" - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.TOO_LARGE - super().__init__(413, *args, **kwargs) - - -class EventStreamError(SynapseError): - """An error raised when there a problem with the event stream.""" - - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.BAD_PAGINATION - super().__init__(*args, **kwargs) + def __init__(self, msg: str): + super().__init__(413, msg, Codes.TOO_LARGE) class LoginError(SynapseError): """An error raised when there was a problem logging in.""" - pass - class StoreError(SynapseError): """An error raised when there was a problem storing some data.""" - pass - class InvalidCaptchaError(SynapseError): def __init__( @@ -395,7 +374,7 @@ class InvalidCaptchaError(SynapseError): super().__init__(code, msg, errcode) self.error_url = error_url - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, error_url=self.error_url) @@ -412,7 +391,7 @@ class LimitExceededError(SynapseError): super().__init__(code, msg, errcode) self.retry_after_ms = retry_after_ms - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) @@ -443,10 +422,8 @@ class UnsupportedRoomVersionError(SynapseError): class ThreepidValidationError(SynapseError): """An error raised when there was a problem authorising an event.""" - def __init__(self, *args, **kwargs): - if "errcode" not in kwargs: - kwargs["errcode"] = Codes.FORBIDDEN - super().__init__(*args, **kwargs) + def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN): + super().__init__(400, msg, errcode) class IncompatibleRoomVersionError(SynapseError): @@ -466,7 +443,7 @@ class IncompatibleRoomVersionError(SynapseError): self._room_version = room_version - def error_dict(self): + def error_dict(self) -> "JsonDict": return cs_error(self.msg, self.errcode, room_version=self._room_version) @@ -494,7 +471,7 @@ class RequestSendFailed(RuntimeError): errors (like programming errors). """ - def __init__(self, inner_exception, can_retry): + def __init__(self, inner_exception: BaseException, can_retry: bool): super().__init__( "Failed to send request: %s: %s" % (type(inner_exception).__name__, inner_exception) @@ -503,7 +480,7 @@ class RequestSendFailed(RuntimeError): self.can_retry = can_retry -def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs): +def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict": """Utility method for constructing an error response for client-server interactions. @@ -551,7 +528,7 @@ class FederationError(RuntimeError): msg = "%s %s: %s" % (level, code, reason) super().__init__(msg) - def get_dict(self): + def get_dict(self) -> "JsonDict": return { "level": self.level, "code": self.code, @@ -580,7 +557,7 @@ class HttpResponseException(CodeMessageException): super().__init__(code, msg) self.response = response - def to_synapse_error(self): + def to_synapse_error(self) -> SynapseError: """Make a SynapseError based on an HTTPResponseException This is useful when a proxied request has failed, and we need to diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 20e91a115d..bc550ae646 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -231,24 +231,24 @@ class FilterCollection: def include_redundant_members(self) -> bool: return self._room_state_filter.include_redundant_members() - def filter_presence(self, events): + def filter_presence( + self, events: Iterable[UserPresenceState] + ) -> List[UserPresenceState]: return self._presence_filter.filter(events) - def filter_account_data(self, events): + def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: return self._account_data.filter(events) - def filter_room_state(self, events): + def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: return self._room_state_filter.filter(self._room_filter.filter(events)) - def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: + def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]: return self._room_timeline_filter.filter(self._room_filter.filter(events)) - def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: + def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]: return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) - def filter_room_account_data( - self, events: Iterable[FilterEvent] - ) -> List[FilterEvent]: + def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: return self._room_account_data.filter(self._room_filter.filter(events)) def blocks_all_presence(self) -> bool: @@ -309,7 +309,7 @@ class Filter: # except for presence which actually gets passed around as its own # namedtuple type. if isinstance(event, UserPresenceState): - sender = event.user_id + sender: Optional[str] = event.user_id room_id = None ev_type = "m.presence" contains_url = False diff --git a/synapse/api/presence.py b/synapse/api/presence.py index a3bf0348d1..b80aa83cb3 100644 --- a/synapse/api/presence.py +++ b/synapse/api/presence.py @@ -12,49 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from typing import Any, Optional + +import attr from synapse.api.constants import PresenceState +from synapse.types import JsonDict -class UserPresenceState( - namedtuple( - "UserPresenceState", - ( - "user_id", - "state", - "last_active_ts", - "last_federation_update_ts", - "last_user_sync_ts", - "status_msg", - "currently_active", - ), - ) -): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class UserPresenceState: """Represents the current presence state of the user. - user_id (str) - last_active (int): Time in msec that the user last interacted with server. - last_federation_update (int): Time in msec since either a) we sent a presence + user_id + last_active: Time in msec that the user last interacted with server. + last_federation_update: Time in msec since either a) we sent a presence update to other servers or b) we received a presence update, depending on if is a local user or not. - last_user_sync (int): Time in msec that the user last *completed* a sync + last_user_sync: Time in msec that the user last *completed* a sync (or event stream). - status_msg (str): User set status message. + status_msg: User set status message. """ - def as_dict(self): - return dict(self._asdict()) + user_id: str + state: str + last_active_ts: int + last_federation_update_ts: int + last_user_sync_ts: int + status_msg: Optional[str] + currently_active: bool + + def as_dict(self) -> JsonDict: + return attr.asdict(self) @staticmethod - def from_dict(d): + def from_dict(d: JsonDict) -> "UserPresenceState": return UserPresenceState(**d) - def copy_and_replace(self, **kwargs): - return self._replace(**kwargs) + def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState": + return attr.evolve(self, **kwargs) @classmethod - def default(cls, user_id): + def default(cls, user_id: str) -> "UserPresenceState": """Returns a default presence state.""" return cls( user_id=user_id, diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index e8964097d3..849c18ceda 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -161,7 +161,7 @@ class Ratelimiter: return allowed, time_allowed - def _prune_message_counts(self, time_now_s: float): + def _prune_message_counts(self, time_now_s: float) -> None: """Remove message count entries that have not exceeded their defined rate_hz limit @@ -190,7 +190,7 @@ class Ratelimiter: update: bool = True, n_actions: int = 1, _time_now_s: Optional[float] = None, - ): + ) -> None: """Checks if an action can be performed. If not, raises a LimitExceededError Checks if the user has ratelimiting disabled in the database by looking diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 032c69b210..6e84b1524f 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -19,6 +19,7 @@ from hashlib import sha256 from urllib.parse import urlencode from synapse.config import ConfigError +from synapse.config.homeserver import HomeServerConfig SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client" CLIENT_API_PREFIX = "/_matrix/client" @@ -34,11 +35,7 @@ LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" class ConsentURIBuilder: - def __init__(self, hs_config): - """ - Args: - hs_config (synapse.config.homeserver.HomeServerConfig): - """ + def __init__(self, hs_config: HomeServerConfig): if hs_config.key.form_secret is None: raise ConfigError("form_secret not set in config") if hs_config.server.public_baseurl is None: @@ -47,15 +44,15 @@ class ConsentURIBuilder: self._hmac_secret = hs_config.key.form_secret.encode("utf-8") self._public_baseurl = hs_config.server.public_baseurl - def build_user_consent_uri(self, user_id): + def build_user_consent_uri(self, user_id: str) -> str: """Build a URI which we can give to the user to do their privacy policy consent Args: - user_id (str): mxid or username of user + user_id: mxid or username of user Returns - (str) the URI where the user can do consent + The URI where the user can do consent """ mac = hmac.new( key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256 diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 404afb9402..b5968e047b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1489,7 +1489,7 @@ def format_user_presence_state( The "user_id" is optional so that this function can be used to format presence updates for client /sync responses and for federation /send requests. """ - content = {"presence": state.state} + content: JsonDict = {"presence": state.state} if include_user_id: content["user_id"] = state.user_id if state.last_active_ts: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 181841ee06..0ab56d8a07 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -2237,7 +2237,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): # accident. row = {"client_secret": None, "validated_at": None} else: - raise ThreepidValidationError(400, "Unknown session_id") + raise ThreepidValidationError("Unknown session_id") retrieved_client_secret = row["client_secret"] validated_at = row["validated_at"] @@ -2252,14 +2252,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): if not row: raise ThreepidValidationError( - 400, "Validation token not found or has expired" + "Validation token not found or has expired" ) expires = row["expires"] next_link = row["next_link"] if retrieved_client_secret != client_secret: raise ThreepidValidationError( - 400, "This client_secret does not match the provided session_id" + "This client_secret does not match the provided session_id" ) # If the session is already validated, no need to revalidate @@ -2268,7 +2268,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): if expires <= current_ts: raise ThreepidValidationError( - 400, "This token has expired. Please request a new one" + "This token has expired. Please request a new one" ) # Looks good. Validate the session -- cgit 1.5.1 From d85bc9a4a7c853c4ca0499f8c4e51d8644c3fcfa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 19 Oct 2021 11:21:50 +0200 Subject: Include rejected status when we log events. (#11008) If we find ourselves dealing with rejected events, we proably want to know about it. Let's include it in the stringification of the event so that it gets logged. --- changelog.d/11008.misc | 1 + synapse/events/__init__.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) create mode 100644 changelog.d/11008.misc (limited to 'synapse') diff --git a/changelog.d/11008.misc b/changelog.d/11008.misc new file mode 100644 index 0000000000..a67d95d66f --- /dev/null +++ b/changelog.d/11008.misc @@ -0,0 +1 @@ +Include rejected status when we log events. diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 49190459c8..157669ea88 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -348,12 +348,16 @@ class EventBase(metaclass=abc.ABCMeta): return self.__repr__() def __repr__(self): - return "<%s event_id=%r, type=%r, state_key=%r, outlier=%s>" % ( - self.__class__.__name__, - self.event_id, - self.get("type", None), - self.get("state_key", None), - self.internal_metadata.is_outlier(), + rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" + + return ( + f"<{self.__class__.__name__} " + f"{rejection}" + f"event_id={self.event_id}, " + f"type={self.get('type')}, " + f"state_key={self.get('state_key')}, " + f"outlier={self.internal_metadata.is_outlier()}" + ">" ) -- cgit 1.5.1 From 0170774b1906c901b214acd63ab4936c177db5a3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 19 Oct 2021 11:23:55 +0200 Subject: Rename `_auth_and_persist_fetched_events` (#11116) ... to `_auth_and_persist_outliers`, since that reflects its purpose better. --- changelog.d/11116.misc | 1 + synapse/handlers/federation_event.py | 23 +++++++++-------------- 2 files changed, 10 insertions(+), 14 deletions(-) create mode 100644 changelog.d/11116.misc (limited to 'synapse') diff --git a/changelog.d/11116.misc b/changelog.d/11116.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/11116.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index af2c88394d..22d364800b 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1116,14 +1116,12 @@ class FederationEventHandler: await concurrently_execute(get_event, event_ids, 5) logger.info("Fetched %i events of %i requested", len(events), len(event_ids)) - await self._auth_and_persist_fetched_events(destination, room_id, events) + await self._auth_and_persist_outliers(room_id, events) - async def _auth_and_persist_fetched_events( - self, origin: str, room_id: str, events: Iterable[EventBase] + async def _auth_and_persist_outliers( + self, room_id: str, events: Iterable[EventBase] ) -> None: - """Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event - - The events to be persisted must be outliers. + """Persist a batch of outlier events fetched from remote servers. We first sort the events to make sure that we process each event's auth_events before the event itself, and then auth and persist them. @@ -1131,7 +1129,6 @@ class FederationEventHandler: Notifies about the events where appropriate. Params: - origin: where the events came from room_id: the room that the events are meant to be in (though this has not yet been checked) events: the events that have been fetched @@ -1167,15 +1164,15 @@ class FederationEventHandler: shortstr(e.event_id for e in roots), ) - await self._auth_and_persist_fetched_events_inner(origin, room_id, roots) + await self._auth_and_persist_outliers_inner(room_id, roots) for ev in roots: del event_map[ev.event_id] - async def _auth_and_persist_fetched_events_inner( - self, origin: str, room_id: str, fetched_events: Collection[EventBase] + async def _auth_and_persist_outliers_inner( + self, room_id: str, fetched_events: Collection[EventBase] ) -> None: - """Helper for _auth_and_persist_fetched_events + """Helper for _auth_and_persist_outliers Persists a batch of events where we have (theoretically) already persisted all of their auth events. @@ -1719,9 +1716,7 @@ class FederationEventHandler: for s in seen_remotes: remote_event_map.pop(s, None) - await self._auth_and_persist_fetched_events( - destination, room_id, remote_event_map.values() - ) + await self._auth_and_persist_outliers(room_id, remote_event_map.values()) async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] -- cgit 1.5.1 From f3efa0036bf3ec716b855839bad75702d31f7352 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 19 Oct 2021 11:24:09 +0200 Subject: Move _persist_auth_tree into FederationEventHandler (#11115) This is just a lift-and-shift, because it fits more naturally here. We do rename it to `process_remote_join` at the same time though. --- changelog.d/11115.misc | 1 + synapse/handlers/federation.py | 128 ++--------------------------------- synapse/handlers/federation_event.py | 116 ++++++++++++++++++++++++++++++- 3 files changed, 120 insertions(+), 125 deletions(-) create mode 100644 changelog.d/11115.misc (limited to 'synapse') diff --git a/changelog.d/11115.misc b/changelog.d/11115.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/11115.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 69f1ef3afa..3112cc88b1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -15,7 +15,6 @@ """Contains handlers for federation events.""" -import itertools import logging from http import HTTPStatus from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union @@ -27,12 +26,7 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import ( - EventContentFields, - EventTypes, - Membership, - RejectedReason, -) +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, CodeMessageException, @@ -43,12 +37,9 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.crypto.event_signing import compute_event_signature -from synapse.event_auth import ( - check_auth_rules_for_event, - validate_event_for_room_version, -) +from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator @@ -519,7 +510,7 @@ class FederationHandler: auth_events=auth_chain, ) - max_stream_id = await self._persist_auth_tree( + max_stream_id = await self._federation_event_handler.process_remote_join( origin, room_id, auth_chain, state, event, room_version_obj ) @@ -1095,117 +1086,6 @@ class FederationHandler: else: return None - async def _persist_auth_tree( - self, - origin: str, - room_id: str, - auth_events: List[EventBase], - state: List[EventBase], - event: EventBase, - room_version: RoomVersion, - ) -> int: - """Checks the auth chain is valid (and passes auth checks) for the - state and event. Then persists the auth chain and state atomically. - Persists the event separately. Notifies about the persisted events - where appropriate. - - Will attempt to fetch missing auth events. - - Args: - origin: Where the events came from - room_id, - auth_events - state - event - room_version: The room version we expect this room to have, and - will raise if it doesn't match the version in the create event. - """ - events_to_context = {} - for e in itertools.chain(auth_events, state): - e.internal_metadata.outlier = True - events_to_context[e.event_id] = EventContext.for_outlier() - - event_map = { - e.event_id: e for e in itertools.chain(auth_events, state, [event]) - } - - create_event = None - for e in auth_events: - if (e.type, e.state_key) == (EventTypes.Create, ""): - create_event = e - break - - if create_event is None: - # If the state doesn't have a create event then the room is - # invalid, and it would fail auth checks anyway. - raise SynapseError(400, "No create event in state") - - room_version_id = create_event.content.get( - "room_version", RoomVersions.V1.identifier - ) - - if room_version.identifier != room_version_id: - raise SynapseError(400, "Room version mismatch") - - missing_auth_events = set() - for e in itertools.chain(auth_events, state, [event]): - for e_id in e.auth_event_ids(): - if e_id not in event_map: - missing_auth_events.add(e_id) - - for e_id in missing_auth_events: - m_ev = await self.federation_client.get_pdu( - [origin], - e_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - if m_ev and m_ev.event_id == e_id: - event_map[e_id] = m_ev - else: - logger.info("Failed to find auth event %r", e_id) - - for e in itertools.chain(auth_events, state, [event]): - auth_for_e = [ - event_map[e_id] for e_id in e.auth_event_ids() if e_id in event_map - ] - if create_event: - auth_for_e.append(create_event) - - try: - validate_event_for_room_version(room_version, e) - check_auth_rules_for_event(room_version, e, auth_for_e) - except SynapseError as err: - # we may get SynapseErrors here as well as AuthErrors. For - # instance, there are a couple of (ancient) events in some - # rooms whose senders do not have the correct sigil; these - # cause SynapseErrors in auth.check. We don't want to give up - # the attempt to federate altogether in such cases. - - logger.warning("Rejecting %s because %s", e.event_id, err.msg) - - if e == event: - raise - events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR - - if auth_events or state: - await self._federation_event_handler.persist_events_and_notify( - room_id, - [ - (e, events_to_context[e.event_id]) - for e in itertools.chain(auth_events, state) - ], - ) - - new_event_context = await self.state_handler.compute_event_context( - event, old_state=state - ) - - return await self._federation_event_handler.persist_events_and_notify( - room_id, [(event, new_event_context)] - ) - async def on_get_missing_events( self, origin: str, diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 22d364800b..5a2f2e5ebb 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging from http import HTTPStatus from typing import ( @@ -45,7 +46,7 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from synapse.event_auth import ( auth_types_for_event, check_auth_rules_for_event, @@ -390,6 +391,119 @@ class FederationEventHandler: prev_member_event, ) + async def process_remote_join( + self, + origin: str, + room_id: str, + auth_events: List[EventBase], + state: List[EventBase], + event: EventBase, + room_version: RoomVersion, + ) -> int: + """Persists the events returned by a send_join + + Checks the auth chain is valid (and passes auth checks) for the + state and event. Then persists the auth chain and state atomically. + Persists the event separately. Notifies about the persisted events + where appropriate. + + Will attempt to fetch missing auth events. + + Args: + origin: Where the events came from + room_id, + auth_events + state + event + room_version: The room version we expect this room to have, and + will raise if it doesn't match the version in the create event. + """ + events_to_context = {} + for e in itertools.chain(auth_events, state): + e.internal_metadata.outlier = True + events_to_context[e.event_id] = EventContext.for_outlier() + + event_map = { + e.event_id: e for e in itertools.chain(auth_events, state, [event]) + } + + create_event = None + for e in auth_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + if create_event is None: + # If the state doesn't have a create event then the room is + # invalid, and it would fail auth checks anyway. + raise SynapseError(400, "No create event in state") + + room_version_id = create_event.content.get( + "room_version", RoomVersions.V1.identifier + ) + + if room_version.identifier != room_version_id: + raise SynapseError(400, "Room version mismatch") + + missing_auth_events = set() + for e in itertools.chain(auth_events, state, [event]): + for e_id in e.auth_event_ids(): + if e_id not in event_map: + missing_auth_events.add(e_id) + + for e_id in missing_auth_events: + m_ev = await self._federation_client.get_pdu( + [origin], + e_id, + room_version=room_version, + outlier=True, + timeout=10000, + ) + if m_ev and m_ev.event_id == e_id: + event_map[e_id] = m_ev + else: + logger.info("Failed to find auth event %r", e_id) + + for e in itertools.chain(auth_events, state, [event]): + auth_for_e = [ + event_map[e_id] for e_id in e.auth_event_ids() if e_id in event_map + ] + if create_event: + auth_for_e.append(create_event) + + try: + validate_event_for_room_version(room_version, e) + check_auth_rules_for_event(room_version, e, auth_for_e) + except SynapseError as err: + # we may get SynapseErrors here as well as AuthErrors. For + # instance, there are a couple of (ancient) events in some + # rooms whose senders do not have the correct sigil; these + # cause SynapseErrors in auth.check. We don't want to give up + # the attempt to federate altogether in such cases. + + logger.warning("Rejecting %s because %s", e.event_id, err.msg) + + if e == event: + raise + events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR + + if auth_events or state: + await self.persist_events_and_notify( + room_id, + [ + (e, events_to_context[e.event_id]) + for e in itertools.chain(auth_events, state) + ], + ) + + new_event_context = await self._state_handler.compute_event_context( + event, old_state=state + ) + + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) + @log_function async def backfill( self, dest: str, room_id: str, limit: int, extremities: Iterable[str] -- cgit 1.5.1 From 0dd0c40329cf620590b781b13d5b79332581fea7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 19 Oct 2021 10:29:03 -0400 Subject: Add missing type hints to event fetching. (#11121) Updates the event rows returned from the database to be attrs classes instead of dictionaries. --- changelog.d/11121.misc | 1 + synapse/storage/databases/main/events_worker.py | 142 ++++++++++++++---------- 2 files changed, 82 insertions(+), 61 deletions(-) create mode 100644 changelog.d/11121.misc (limited to 'synapse') diff --git a/changelog.d/11121.misc b/changelog.d/11121.misc new file mode 100644 index 0000000000..916beeaacb --- /dev/null +++ b/changelog.d/11121.misc @@ -0,0 +1 @@ +Add type hints for event fetching. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4a1a2f4a6a..ae37901be9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -55,8 +55,9 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine +from synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id @@ -86,6 +87,47 @@ class _EventCacheEntry: redacted_event: Optional[EventBase] +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventRow: + """ + An event, as pulled from the database. + + Properties: + event_id: The event ID of the event. + + stream_ordering: stream ordering for this event + + json: json-encoded event structure + + internal_metadata: json-encoded internal metadata dict + + format_version: The format of the event. Hopefully one of EventFormatVersions. + 'None' means the event predates EventFormatVersions (so the event is format V1). + + room_version_id: The version of the room which contains the event. Hopefully + one of RoomVersions. + + Due to historical reasons, there may be a few events in the database which + do not have an associated room; in this case None will be returned here. + + rejected_reason: if the event was rejected, the reason why. + + redactions: a list of event-ids which (claim to) redact this event. + + outlier: True if this event is an outlier. + """ + + event_id: str + stream_ordering: int + json: str + internal_metadata: str + format_version: Optional[int] + room_version_id: Optional[int] + rejected_reason: Optional[str] + redactions: List[str] + outlier: bool + + class EventRedactBehaviour(Names): """ What to do when retrieving a redacted event from the database. @@ -686,7 +728,7 @@ class EventsWorkerStore(SQLBaseStore): for e in state_to_include.values() ] - def _do_fetch(self, conn): + def _do_fetch(self, conn: Connection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ @@ -713,13 +755,15 @@ class EventsWorkerStore(SQLBaseStore): self._fetch_event_list(conn, event_list) - def _fetch_event_list(self, conn, event_list): + def _fetch_event_list( + self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] + ) -> None: """Handle a load of requests from the _event_fetch_list queue Args: - conn (twisted.enterprise.adbapi.Connection): database connection + conn: database connection - event_list (list[Tuple[list[str], Deferred]]): + event_list: The fetch requests. Each entry consists of a list of event ids to be fetched, and a deferred to be completed once the events have been fetched. @@ -788,7 +832,7 @@ class EventsWorkerStore(SQLBaseStore): row = row_map.get(event_id) fetched_events[event_id] = row if row: - redaction_ids.update(row["redactions"]) + redaction_ids.update(row.redactions) events_to_fetch = redaction_ids.difference(fetched_events.keys()) if events_to_fetch: @@ -799,32 +843,32 @@ class EventsWorkerStore(SQLBaseStore): for event_id, row in fetched_events.items(): if not row: continue - assert row["event_id"] == event_id + assert row.event_id == event_id - rejected_reason = row["rejected_reason"] + rejected_reason = row.rejected_reason # If the event or metadata cannot be parsed, log the error and act # as if the event is unknown. try: - d = db_to_json(row["json"]) + d = db_to_json(row.json) except ValueError: logger.error("Unable to parse json from event: %s", event_id) continue try: - internal_metadata = db_to_json(row["internal_metadata"]) + internal_metadata = db_to_json(row.internal_metadata) except ValueError: logger.error( "Unable to parse internal_metadata from event: %s", event_id ) continue - format_version = row["format_version"] + format_version = row.format_version if format_version is None: # This means that we stored the event before we had the concept # of a event format version, so it must be a V1 event. format_version = EventFormatVersions.V1 - room_version_id = row["room_version_id"] + room_version_id = row.room_version_id if not room_version_id: # this should only happen for out-of-band membership events which @@ -889,8 +933,8 @@ class EventsWorkerStore(SQLBaseStore): internal_metadata_dict=internal_metadata, rejected_reason=rejected_reason, ) - original_ev.internal_metadata.stream_ordering = row["stream_ordering"] - original_ev.internal_metadata.outlier = row["outlier"] + original_ev.internal_metadata.stream_ordering = row.stream_ordering + original_ev.internal_metadata.outlier = row.outlier event_map[event_id] = original_ev @@ -898,7 +942,7 @@ class EventsWorkerStore(SQLBaseStore): # the cache entries. result_map = {} for event_id, original_ev in event_map.items(): - redactions = fetched_events[event_id]["redactions"] + redactions = fetched_events[event_id].redactions redacted_event = self._maybe_redact_event_row( original_ev, redactions, event_map ) @@ -912,17 +956,17 @@ class EventsWorkerStore(SQLBaseStore): return result_map - async def _enqueue_events(self, events): + async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: """Fetches events from the database using the _event_fetch_list. This allows batch and bulk fetching of events - it allows us to fetch events without having to create a new transaction for each request for events. Args: - events (Iterable[str]): events to be fetched. + events: events to be fetched. Returns: - Dict[str, Dict]: map from event id to row data from the database. - May contain events that weren't requested. + A map from event id to row data from the database. May contain events + that weren't requested. """ events_d = defer.Deferred() @@ -949,43 +993,19 @@ class EventsWorkerStore(SQLBaseStore): return row_map - def _fetch_event_rows(self, txn, event_ids): + def _fetch_event_rows( + self, txn: LoggingTransaction, event_ids: Iterable[str] + ) -> Dict[str, _EventRow]: """Fetch event rows from the database Events which are not found are omitted from the result. - The returned per-event dicts contain the following keys: - - * event_id (str) - - * stream_ordering (int): stream ordering for this event - - * json (str): json-encoded event structure - - * internal_metadata (str): json-encoded internal metadata dict - - * format_version (int|None): The format of the event. Hopefully one - of EventFormatVersions. 'None' means the event predates - EventFormatVersions (so the event is format V1). - - * room_version_id (str|None): The version of the room which contains the event. - Hopefully one of RoomVersions. - - Due to historical reasons, there may be a few events in the database which - do not have an associated room; in this case None will be returned here. - - * rejected_reason (str|None): if the event was rejected, the reason - why. - - * redactions (List[str]): a list of event-ids which (claim to) redact - this event. - Args: - txn (twisted.enterprise.adbapi.Connection): - event_ids (Iterable[str]): event IDs to fetch + txn: The database transaction. + event_ids: event IDs to fetch Returns: - Dict[str, Dict]: a map from event id to event info. + A map from event id to event info. """ event_dict = {} for evs in batch_iter(event_ids, 200): @@ -1013,17 +1033,17 @@ class EventsWorkerStore(SQLBaseStore): for row in txn: event_id = row[0] - event_dict[event_id] = { - "event_id": event_id, - "stream_ordering": row[1], - "internal_metadata": row[2], - "json": row[3], - "format_version": row[4], - "room_version_id": row[5], - "rejected_reason": row[6], - "redactions": [], - "outlier": row[7], - } + event_dict[event_id] = _EventRow( + event_id=event_id, + stream_ordering=row[1], + internal_metadata=row[2], + json=row[3], + format_version=row[4], + room_version_id=row[5], + rejected_reason=row[6], + redactions=[], + outlier=row[7], + ) # check for redactions redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " @@ -1035,7 +1055,7 @@ class EventsWorkerStore(SQLBaseStore): for (redacter, redacted) in txn: d = event_dict.get(redacted) if d: - d["redactions"].append(redacter) + d.redactions.append(redacter) return event_dict -- cgit 1.5.1 From 522489fbcdfd4f52f1ebbbf5670c474cbf3e29be Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Wed, 20 Oct 2021 12:00:03 +0100 Subject: 1.45.1 --- CHANGES.md | 9 +++++++++ changelog.d/11127.bugfix | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 4 files changed, 16 insertions(+), 2 deletions(-) delete mode 100644 changelog.d/11127.bugfix (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index 435387d7b0..4b7dd69c96 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,12 @@ +Synapse 1.45.1 (2021-10-20) +=========================== + +Bugfixes +-------- + +- Revert change to counting of deactivated users towards the monthly active users limit, introduced in 1.45.0rc1 ([\#10947](https://github.com/matrix-org/synapse/issues/10947)). ([\#11127](https://github.com/matrix-org/synapse/issues/11127)) + + Synapse 1.45.0 (2021-10-19) =========================== diff --git a/changelog.d/11127.bugfix b/changelog.d/11127.bugfix deleted file mode 100644 index 54417a9975..0000000000 --- a/changelog.d/11127.bugfix +++ /dev/null @@ -1 +0,0 @@ -Revert change to counting of deactivated users towards the monthly active users limit ([\#10947](https://github.com/matrix-org/synapse/issues/10947)). diff --git a/debian/changelog b/debian/changelog index 5fefb2f2ac..1ee81f2a34 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.45.1) stable; urgency=medium + + * New synapse release 1.45.1. + + -- Synapse Packaging team Wed, 20 Oct 2021 11:58:27 +0100 + matrix-synapse-py3 (1.45.0) stable; urgency=medium * New synapse release 1.45.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index 97452f34fe..2687d932ea 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.45.0" +__version__ = "1.45.1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From 2c61a318cc46ec38e64d6a497f6077d23b9341bf Mon Sep 17 00:00:00 2001 From: Aaron R Date: Wed, 20 Oct 2021 09:41:48 -0500 Subject: Show error when timestamp in seconds is provided to the /purge_media_cache API (#11101) --- changelog.d/11101.bugfix | 1 + docs/admin_api/media_admin_api.md | 6 +-- synapse/rest/admin/media.py | 33 +++++++++--- tests/rest/admin/test_media.py | 106 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 133 insertions(+), 13 deletions(-) create mode 100644 changelog.d/11101.bugfix (limited to 'synapse') diff --git a/changelog.d/11101.bugfix b/changelog.d/11101.bugfix new file mode 100644 index 0000000000..0de507848f --- /dev/null +++ b/changelog.d/11101.bugfix @@ -0,0 +1 @@ +Show an error when timestamp in seconds is provided to the `/purge_media_cache` Admin API. \ No newline at end of file diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index ea05bd6e44..60b8bc7379 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -257,9 +257,9 @@ POST /_synapse/admin/v1/media//delete?before_ts= URL Parameters * `server_name`: string - The name of your local server (e.g `matrix.org`). -* `before_ts`: string representing a positive integer - Unix timestamp in ms. +* `before_ts`: string representing a positive integer - Unix timestamp in milliseconds. Files that were last used before this timestamp will be deleted. It is the timestamp of -last access and not the timestamp creation. +last access, not the timestamp when the file was created. * `size_gt`: Optional - string representing a positive integer - Size of the media in bytes. Files that are larger will be deleted. Defaults to `0`. * `keep_profiles`: Optional - string representing a boolean - Switch to also delete files @@ -302,7 +302,7 @@ POST /_synapse/admin/v1/purge_media_cache?before_ts= URL Parameters -* `unix_timestamp_in_ms`: string representing a positive integer - Unix timestamp in ms. +* `unix_timestamp_in_ms`: string representing a positive integer - Unix timestamp in milliseconds. All cached media that was last accessed before this timestamp will be removed. Response: diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 8ce443049e..30a687d234 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -40,7 +40,7 @@ class QuarantineMediaInRoom(RestServlet): """ PATTERNS = [ - *admin_patterns("/room/(?P[^/]+)/media/quarantine"), + *admin_patterns("/room/(?P[^/]+)/media/quarantine$"), # This path kept around for legacy reasons *admin_patterns("/quarantine_media/(?P[^/]+)"), ] @@ -70,7 +70,7 @@ class QuarantineMediaByUser(RestServlet): this server. """ - PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine") + PATTERNS = admin_patterns("/user/(?P[^/]+)/media/quarantine$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -199,7 +199,7 @@ class UnprotectMediaByID(RestServlet): class ListMediaInRoom(RestServlet): """Lists all of the media in a given room.""" - PATTERNS = admin_patterns("/room/(?P[^/]+)/media") + PATTERNS = admin_patterns("/room/(?P[^/]+)/media$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -219,7 +219,7 @@ class ListMediaInRoom(RestServlet): class PurgeMediaCacheRestServlet(RestServlet): - PATTERNS = admin_patterns("/purge_media_cache") + PATTERNS = admin_patterns("/purge_media_cache$") def __init__(self, hs: "HomeServer"): self.media_repository = hs.get_media_repository() @@ -231,6 +231,20 @@ class PurgeMediaCacheRestServlet(RestServlet): before_ts = parse_integer(request, "before_ts", required=True) logger.info("before_ts: %r", before_ts) + if before_ts < 0: + raise SynapseError( + 400, + "Query parameter before_ts must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds + raise SynapseError( + 400, + "Query parameter before_ts you provided is from the year 1970. " + + "Double check that you are providing a timestamp in milliseconds.", + errcode=Codes.INVALID_PARAM, + ) + ret = await self.media_repository.delete_old_remote_media(before_ts) return 200, ret @@ -271,7 +285,7 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P[^/]+)/delete") + PATTERNS = admin_patterns("/media/(?P[^/]+)/delete$") def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -291,7 +305,14 @@ class DeleteMediaByDateSize(RestServlet): if before_ts < 0: raise SynapseError( 400, - "Query parameter before_ts must be a string representing a positive integer.", + "Query parameter before_ts must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds + raise SynapseError( + 400, + "Query parameter before_ts you provided is from the year 1970. " + + "Double check that you are providing a timestamp in milliseconds.", errcode=Codes.INVALID_PARAM, ) if size_gt < 0: diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index ce30a19213..db0e78c039 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -27,6 +27,9 @@ from tests import unittest from tests.server import FakeSite, make_request from tests.test_utils import SMALL_PNG +VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds +INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds + class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): @@ -203,6 +206,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + # Move clock up to somewhat realistic time + self.reactor.advance(1000000000) + def test_no_auth(self): """ Try to delete media without authentication. @@ -237,7 +243,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - url + "?before_ts=1234", + url + f"?before_ts={VALID_TIMESTAMP}", access_token=self.admin_user_tok, ) @@ -273,13 +279,27 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual( - "Query parameter before_ts must be a string representing a positive integer.", + "Query parameter before_ts must be a positive integer.", channel.json_body["error"], ) channel = self.make_request( "POST", - self.url + "?before_ts=1234&size_gt=-1234", + self.url + f"?before_ts={INVALID_TIMESTAMP_IN_S}", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter before_ts you provided is from the year 1970. " + + "Double check that you are providing a timestamp in milliseconds.", + channel.json_body["error"], + ) + + channel = self.make_request( + "POST", + self.url + f"?before_ts={VALID_TIMESTAMP}&size_gt=-1234", access_token=self.admin_user_tok, ) @@ -292,7 +312,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", - self.url + "?before_ts=1234&keep_profiles=not_bool", + self.url + f"?before_ts={VALID_TIMESTAMP}&keep_profiles=not_bool", access_token=self.admin_user_tok, ) @@ -767,3 +787,81 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): media_info = self.get_success(self.store.get_local_media(self.media_id)) self.assertFalse(media_info["safe_from_quarantine"]) + + +class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media.media_store_path) + self.url = "/_synapse/admin/v1/purge_media_cache" + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + + channel = self.make_request("POST", self.url, b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_not_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + channel = self.make_request( + "POST", + self.url, + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + channel = self.make_request( + "POST", + self.url + "?before_ts=-1234", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter before_ts must be a positive integer.", + channel.json_body["error"], + ) + + channel = self.make_request( + "POST", + self.url + f"?before_ts={INVALID_TIMESTAMP_IN_S}", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter before_ts you provided is from the year 1970. " + + "Double check that you are providing a timestamp in milliseconds.", + channel.json_body["error"], + ) -- cgit 1.5.1 From 0930e9ae124265165df2cccdbf051de63c334436 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Oct 2021 19:22:40 +0200 Subject: Clean up `_update_auth_events_and_context_for_auth` (#11122) Remove some redundant code, and generally simplify. --- changelog.d/11122.misc | 1 + synapse/handlers/federation_event.py | 151 +++++++++-------------------------- 2 files changed, 38 insertions(+), 114 deletions(-) create mode 100644 changelog.d/11122.misc (limited to 'synapse') diff --git a/changelog.d/11122.misc b/changelog.d/11122.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/11122.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 5a2f2e5ebb..3431a80ab4 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -65,7 +65,6 @@ from synapse.replication.http.federation import ( from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( - MutableStateMap, PersistedEventPosition, RoomStreamToken, StateMap, @@ -1417,13 +1416,8 @@ class FederationEventHandler: } try: - ( - context, - auth_events_for_auth, - ) = await self._update_auth_events_and_context_for_auth( - origin, + updated_auth_events = await self._update_auth_events_for_auth( event, - context, calculated_auth_event_map=calculated_auth_event_map, ) except Exception: @@ -1436,6 +1430,14 @@ class FederationEventHandler: "Ignoring failure and continuing processing of event.", event.event_id, ) + updated_auth_events = None + + if updated_auth_events: + context = await self._update_context_for_auth_events( + event, context, updated_auth_events + ) + auth_events_for_auth = updated_auth_events + else: auth_events_for_auth = calculated_auth_event_map try: @@ -1560,13 +1562,11 @@ class FederationEventHandler: soft_failed_event_counter.inc() event.internal_metadata.soft_failed = True - async def _update_auth_events_and_context_for_auth( + async def _update_auth_events_for_auth( self, - origin: str, event: EventBase, - context: EventContext, calculated_auth_event_map: StateMap[EventBase], - ) -> Tuple[EventContext, StateMap[EventBase]]: + ) -> Optional[StateMap[EventBase]]: """Helper for _check_event_auth. See there for docs. Checks whether a given event has the expected auth events. If it @@ -1579,96 +1579,27 @@ class FederationEventHandler: processing of the event. Args: - origin: event: - context: calculated_auth_event_map: Our calculated auth_events based on the state of the room at the event's position in the DAG. Returns: - updated context, updated auth event map + updated auth event map, or None if no changes are needed. + """ assert not event.internal_metadata.outlier - # take a copy of calculated_auth_event_map before we modify it. - auth_events: MutableStateMap[EventBase] = dict(calculated_auth_event_map) - + # check for events which are in the event's claimed auth_events, but not + # in our calculated event map. event_auth_events = set(event.auth_event_ids()) - - # missing_auth is the set of the event's auth_events which we don't yet have - # in auth_events. - missing_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - - # if we have missing events, we need to fetch those events from somewhere. - # - # we start by checking if they are in the store, and then try calling /event_auth/. - # - # TODO: this code is now redundant, since it should be impossible for us to - # get here without already having the auth events. - if missing_auth: - have_events = await self._store.have_seen_events( - event.room_id, missing_auth - ) - logger.debug("Events %s are in the store", have_events) - missing_auth.difference_update(have_events) - - # missing_auth is now the set of event_ids which: - # a. are listed in event.auth_events, *and* - # b. are *not* part of our calculated auth events based on room state, *and* - # c. are *not* yet in our database. - - if missing_auth: - # If we don't have all the auth events, we need to get them. - logger.info("auth_events contains unknown events: %s", missing_auth) - try: - await self._get_remote_auth_chain_for_event( - origin, event.room_id, event.event_id - ) - except Exception: - logger.exception("Failed to get auth chain") - else: - # load any auth events we might have persisted from the database. This - # has the side-effect of correctly setting the rejected_reason on them. - auth_events.update( - { - (ae.type, ae.state_key): ae - for ae in await self._store.get_events_as_list( - missing_auth, allow_rejected=True - ) - } - ) - - # auth_events now contains - # 1. our *calculated* auth events based on the room state, plus: - # 2. any events which: - # a. are listed in `event.auth_events`, *and* - # b. are not part of our calculated auth events, *and* - # c. were not in our database before the call to /event_auth - # d. have since been added to our database (most likely by /event_auth). - different_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() + e.event_id for e in calculated_auth_event_map.values() ) - # different_auth is the set of events which *are* in `event.auth_events`, but - # which are *not* in `auth_events`. Comparing with (2.) above, this means - # exclusively the set of `event.auth_events` which we already had in our - # database before any call to /event_auth. - # - # I'm reasonably sure that the fact that events returned by /event_auth are - # blindly added to auth_events (and hence excluded from different_auth) is a bug - # - though it's a very long-standing one (see - # https://github.com/matrix-org/synapse/commit/78015948a7febb18e000651f72f8f58830a55b93#diff-0bc92da3d703202f5b9be2d3f845e375f5b1a6bc6ba61705a8af9be1121f5e42R786 - # from Jan 2015 which seems to add it, though it actually just moves it from - # elsewhere (before that, it gets lost in a mess of huge "various bug fixes" - # PRs). - if not different_auth: - return context, auth_events + return None logger.info( "auth_events refers to events which are not in our calculated auth " @@ -1680,27 +1611,18 @@ class FederationEventHandler: # necessary? different_events = await self._store.get_events_as_list(different_auth) + # double-check they're all in the same room - we should already have checked + # this but it doesn't hurt to check again. for d in different_events: - if d.room_id != event.room_id: - logger.warning( - "Event %s refers to auth_event %s which is in a different room", - event.event_id, - d.event_id, - ) - - # don't attempt to resolve the claimed auth events against our own - # in this case: just use our own auth events. - # - # XXX: should we reject the event in this case? It feels like we should, - # but then shouldn't we also do so if we've failed to fetch any of the - # auth events? - return context, auth_events + assert ( + d.room_id == event.room_id + ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room" # now we state-resolve between our own idea of the auth events, and the remote's # idea of them. - local_state = auth_events.values() - remote_auth_events = dict(auth_events) + local_state = calculated_auth_event_map.values() + remote_auth_events = dict(calculated_auth_event_map) remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) remote_state = remote_auth_events.values() @@ -1708,23 +1630,24 @@ class FederationEventHandler: new_state = await self._state_handler.resolve_events( room_version, (local_state, remote_state), event ) + different_state = { + (d.type, d.state_key): d + for d in new_state.values() + if calculated_auth_event_map.get((d.type, d.state_key)) != d + } + if not different_state: + logger.info("State res returned no new state") + return None logger.info( "After state res: updating auth_events with new state %s", - { - d - for d in new_state.values() - if auth_events.get((d.type, d.state_key)) != d - }, + different_state.values(), ) - auth_events.update(new_state) - - context = await self._update_context_for_auth_events( - event, context, auth_events - ) - - return context, auth_events + # take a copy of calculated_auth_event_map before we modify it. + auth_events = dict(calculated_auth_event_map) + auth_events.update(different_state) + return auth_events async def _load_or_fetch_auth_events_for_event( self, destination: str, event: EventBase -- cgit 1.5.1 From 62db603fa0cae4813e119291b606bff290461b2b Mon Sep 17 00:00:00 2001 From: Robert Edström <108799+Legogris@users.noreply.github.com> Date: Wed, 20 Oct 2021 17:43:49 +0000 Subject: Consider IP whitelist for identity server resolution (#11120) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Robert Edström --- changelog.d/11120.bugfix | 1 + synapse/handlers/identity.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11120.bugfix (limited to 'synapse') diff --git a/changelog.d/11120.bugfix b/changelog.d/11120.bugfix new file mode 100644 index 0000000000..6b39e3e89d --- /dev/null +++ b/changelog.d/11120.bugfix @@ -0,0 +1 @@ +Identity server connection is no longer ignoring `ip_range_whitelist`. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 9c319b5383..7ef8698a5e 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -54,7 +54,9 @@ class IdentityHandler: self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. self.blacklisting_http_client = SimpleHttpClient( - hs, ip_blacklist=hs.config.server.federation_ip_range_blacklist + hs, + ip_blacklist=hs.config.server.federation_ip_range_blacklist, + ip_whitelist=hs.config.server.federation_ip_range_whitelist, ) self.federation_http_client = hs.get_federation_http_client() self.hs = hs -- cgit 1.5.1 From ef7fe09778ad672d6ed80fb2206cfbc11e6a9a5e Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 21 Oct 2021 10:52:32 +0200 Subject: Fix setting a user's external_id via the admin API returns 500 and deletes users existing external mappings if that external ID is already mapped (#11051) Fixes #10846 --- changelog.d/11051.bugfix | 1 + synapse/rest/admin/users.py | 47 +++--- synapse/storage/databases/main/registration.py | 95 ++++++++++- tests/rest/admin/test_user.py | 215 ++++++++++++++++++++++++- 4 files changed, 321 insertions(+), 37 deletions(-) create mode 100644 changelog.d/11051.bugfix (limited to 'synapse') diff --git a/changelog.d/11051.bugfix b/changelog.d/11051.bugfix new file mode 100644 index 0000000000..63126843d2 --- /dev/null +++ b/changelog.d/11051.bugfix @@ -0,0 +1 @@ +Fix a bug where setting a user's external_id via the admin API returns 500 and deletes users existing external mappings if that external ID is already mapped. \ No newline at end of file diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index f20aa65301..c0bebc3cf0 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -35,6 +35,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.rest.client._base import client_patterns +from synapse.storage.databases.main.registration import ExternalIDReuseException from synapse.storage.databases.main.stats import UserSortOrder from synapse.types import JsonDict, UserID @@ -228,12 +229,12 @@ class UserRestServletV2(RestServlet): if not isinstance(deactivate, bool): raise SynapseError(400, "'deactivated' parameter is not of type boolean") - # convert List[Dict[str, str]] into Set[Tuple[str, str]] + # convert List[Dict[str, str]] into List[Tuple[str, str]] if external_ids is not None: - new_external_ids = { + new_external_ids = [ (external_id["auth_provider"], external_id["external_id"]) for external_id in external_ids - } + ] # convert List[Dict[str, str]] into Set[Tuple[str, str]] if threepids is not None: @@ -275,28 +276,13 @@ class UserRestServletV2(RestServlet): ) if external_ids is not None: - # get changed external_ids (added and removed) - cur_external_ids = set( - await self.store.get_external_ids_by_user(user_id) - ) - add_external_ids = new_external_ids - cur_external_ids - del_external_ids = cur_external_ids - new_external_ids - - # remove old external_ids - for auth_provider, external_id in del_external_ids: - await self.store.remove_user_external_id( - auth_provider, - external_id, - user_id, - ) - - # add new external_ids - for auth_provider, external_id in add_external_ids: - await self.store.record_user_external_id( - auth_provider, - external_id, + try: + await self.store.replace_user_external_id( + new_external_ids, user_id, ) + except ExternalIDReuseException: + raise SynapseError(409, "External id is already in use.") if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( @@ -384,12 +370,15 @@ class UserRestServletV2(RestServlet): ) if external_ids is not None: - for auth_provider, external_id in new_external_ids: - await self.store.record_user_external_id( - auth_provider, - external_id, - user_id, - ) + try: + for auth_provider, external_id in new_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + except ExternalIDReuseException: + raise SynapseError(409, "External id is already in use.") if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 0ab56d8a07..37d47aa823 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -23,7 +23,11 @@ import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Cursor @@ -40,6 +44,13 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) +class ExternalIDReuseException(Exception): + """Exception if writing an external id for a user fails, + because this external id is given to an other user.""" + + pass + + @attr.s(frozen=True, slots=True) class TokenLookupResult: """Result of looking up an access token. @@ -588,24 +599,44 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): auth_provider: identifier for the remote auth provider external_id: id on that system user_id: complete mxid that it is mapped to + Raises: + ExternalIDReuseException if the new external_id could not be mapped. """ - await self.db_pool.simple_insert( + + try: + await self.db_pool.runInteraction( + "record_user_external_id", + self._record_user_external_id_txn, + auth_provider, + external_id, + user_id, + ) + except self.database_engine.module.IntegrityError: + raise ExternalIDReuseException() + + def _record_user_external_id_txn( + self, + txn: LoggingTransaction, + auth_provider: str, + external_id: str, + user_id: str, + ) -> None: + + self.db_pool.simple_insert_txn( + txn, table="user_external_ids", values={ "auth_provider": auth_provider, "external_id": external_id, "user_id": user_id, }, - desc="record_user_external_id", ) async def remove_user_external_id( self, auth_provider: str, external_id: str, user_id: str ) -> None: """Remove a mapping from an external user id to a mxid - If the mapping is not found, this method does nothing. - Args: auth_provider: identifier for the remote auth provider external_id: id on that system @@ -621,6 +652,60 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="remove_user_external_id", ) + async def replace_user_external_id( + self, + record_external_ids: List[Tuple[str, str]], + user_id: str, + ) -> None: + """Replace mappings from external user ids to a mxid in a single transaction. + All mappings are deleted and the new ones are created. + + Args: + record_external_ids: + List with tuple of auth_provider and external_id to record + user_id: complete mxid that it is mapped to + Raises: + ExternalIDReuseException if the new external_id could not be mapped. + """ + + def _remove_user_external_ids_txn( + txn: LoggingTransaction, + user_id: str, + ) -> None: + """Remove all mappings from external user ids to a mxid + If these mappings are not found, this method does nothing. + + Args: + user_id: complete mxid that it is mapped to + """ + + self.db_pool.simple_delete_txn( + txn, + table="user_external_ids", + keyvalues={"user_id": user_id}, + ) + + def _replace_user_external_id_txn( + txn: LoggingTransaction, + ): + _remove_user_external_ids_txn(txn, user_id) + + for auth_provider, external_id in record_external_ids: + self._record_user_external_id_txn( + txn, + auth_provider, + external_id, + user_id, + ) + + try: + await self.db_pool.runInteraction( + "replace_user_external_id", + _replace_user_external_id_txn, + ) + except self.database_engine.module.IntegrityError: + raise ExternalIDReuseException() + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index c9e2754b09..839442ddba 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1180,9 +1180,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.other_user, device_id=None, valid_until_ms=None ) ) - self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( - self.other_user - ) + self.url_prefix = "/_synapse/admin/v2/users/%s" + self.url_other_user = self.url_prefix % self.other_user def test_requester_is_no_admin(self): """ @@ -1738,6 +1737,93 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, len(channel.json_body["threepids"])) self._check_fields(channel.json_body) + def test_set_duplicate_threepid(self): + """ + Test setting the same threepid for a second user. + First user loses and second user gets mapping of this threepid. + """ + + # create a user to set a threepid + first_user = self.register_user("first_user", "pass") + url_first_user = self.url_prefix % first_user + + # Add threepid to first user + channel = self.make_request( + "PUT", + url_first_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob1@bob.bob"}, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(first_user, channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["threepids"])) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob1@bob.bob", channel.json_body["threepids"][0]["address"]) + self._check_fields(channel.json_body) + + # Add threepids to other user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob2@bob.bob"}, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["threepids"])) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self._check_fields(channel.json_body) + + # Add two new threepids to other user + # one is used by first_user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob1@bob.bob"}, + {"medium": "email", "address": "bob3@bob.bob"}, + ], + }, + ) + + # other user has this two threepids + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) + # result does not always have the same sort order, therefore it becomes sorted + sorted_result = sorted( + channel.json_body["threepids"], key=lambda k: k["address"] + ) + self.assertEqual("email", sorted_result[0]["medium"]) + self.assertEqual("bob1@bob.bob", sorted_result[0]["address"]) + self.assertEqual("email", sorted_result[1]["medium"]) + self.assertEqual("bob3@bob.bob", sorted_result[1]["address"]) + self._check_fields(channel.json_body) + + # first_user has no threepid anymore + channel = self.make_request( + "GET", + url_first_user, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(first_user, channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self._check_fields(channel.json_body) + def test_set_external_id(self): """ Test setting external id for an other user. @@ -1836,6 +1922,129 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual(0, len(channel.json_body["external_ids"])) + def test_set_duplicate_external_id(self): + """ + Test that setting the same external id for a second user fails and + external id from user must not be changed. + """ + + # create a user to use an external id + first_user = self.register_user("first_user", "pass") + url_first_user = self.url_prefix % first_user + + # Add an external id to first user + channel = self.make_request( + "PUT", + url_first_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider", + }, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(first_user, channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["external_ids"])) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider", channel.json_body["external_ids"][0]["auth_provider"] + ) + self._check_fields(channel.json_body) + + # Add an external id to other user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id2", + "auth_provider": "auth_provider", + }, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["external_ids"])) + self.assertEqual( + "external_id2", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider", channel.json_body["external_ids"][0]["auth_provider"] + ) + self._check_fields(channel.json_body) + + # Add two new external_ids to other user + # one is used by first + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider", + }, + { + "external_id": "external_id3", + "auth_provider": "auth_provider", + }, + ], + }, + ) + + # must fail + self.assertEqual(409, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("External id is already in use.", channel.json_body["error"]) + + # other user must not changed + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["external_ids"])) + self.assertEqual( + "external_id2", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider", channel.json_body["external_ids"][0]["auth_provider"] + ) + self._check_fields(channel.json_body) + + # first user must not changed + channel = self.make_request( + "GET", + url_first_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(first_user, channel.json_body["name"]) + self.assertEqual(1, len(channel.json_body["external_ids"])) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider", channel.json_body["external_ids"][0]["auth_provider"] + ) + self._check_fields(channel.json_body) + def test_deactivate_user(self): """ Test deactivating another user. -- cgit 1.5.1 From 0f9adc99ada1f66f4897c8164dcf509a955e5584 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 21 Oct 2021 09:07:07 -0400 Subject: Add missing type hints to synapse.crypto. (#11146) And require type hints for this module. --- changelog.d/11146.misc | 1 + mypy.ini | 3 +++ synapse/crypto/context_factory.py | 40 +++++++++++++++++++++++++-------------- synapse/crypto/event_signing.py | 2 +- synapse/crypto/keyring.py | 8 +++++--- 5 files changed, 36 insertions(+), 18 deletions(-) create mode 100644 changelog.d/11146.misc (limited to 'synapse') diff --git a/changelog.d/11146.misc b/changelog.d/11146.misc new file mode 100644 index 0000000000..6ce1c9f9f5 --- /dev/null +++ b/changelog.d/11146.misc @@ -0,0 +1 @@ +Add missing type hints to `synapse.crypto`. diff --git a/mypy.ini b/mypy.ini index 14d8bb8eaf..c5f44aea39 100644 --- a/mypy.ini +++ b/mypy.ini @@ -103,6 +103,9 @@ files = [mypy-synapse.api.*] disallow_untyped_defs = True +[mypy-synapse.crypto.*] +disallow_untyped_defs = True + [mypy-synapse.events.*] disallow_untyped_defs = True diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 2a6110eb10..7855f3498b 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -29,9 +29,12 @@ from twisted.internet.ssl import ( TLSVersion, platformTrust, ) +from twisted.protocols.tls import TLSMemoryBIOProtocol from twisted.python.failure import Failure from twisted.web.iweb import IPolicyForHTTPS +from synapse.config.homeserver import HomeServerConfig + logger = logging.getLogger(__name__) @@ -51,7 +54,7 @@ class ServerContextFactory(ContextFactory): per https://github.com/matrix-org/synapse/issues/1691 """ - def __init__(self, config): + def __init__(self, config: HomeServerConfig): # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version, # switch to those (see https://github.com/pyca/cryptography/issues/5379). # @@ -64,7 +67,7 @@ class ServerContextFactory(ContextFactory): self.configure_context(self._context, config) @staticmethod - def configure_context(context, config): + def configure_context(context: SSL.Context, config: HomeServerConfig) -> None: try: _ecCurve = crypto.get_elliptic_curve(_defaultCurveName) context.set_tmp_ecdh(_ecCurve) @@ -75,14 +78,15 @@ class ServerContextFactory(ContextFactory): SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1 ) context.use_certificate_chain_file(config.tls.tls_certificate_file) + assert config.tls.tls_private_key is not None context.use_privatekey(config.tls.tls_private_key) # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ context.set_cipher_list( - "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM" + b"ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM" ) - def getContext(self): + def getContext(self) -> SSL.Context: return self._context @@ -98,7 +102,7 @@ class FederationPolicyForHTTPS: constructs an SSLClientConnectionCreator factory accordingly. """ - def __init__(self, config): + def __init__(self, config: HomeServerConfig): self._config = config # Check if we're using a custom list of a CA certificates @@ -131,7 +135,7 @@ class FederationPolicyForHTTPS: self._config.tls.federation_certificate_verification_whitelist ) - def get_options(self, host: bytes): + def get_options(self, host: bytes) -> IOpenSSLClientConnectionCreator: # IPolicyForHTTPS.get_options takes bytes, but we want to compare # against the str whitelist. The hostnames in the whitelist are already # IDNA-encoded like the hosts will be here. @@ -153,7 +157,9 @@ class FederationPolicyForHTTPS: return SSLClientConnectionCreator(host, ssl_context, should_verify) - def creatorForNetloc(self, hostname, port): + def creatorForNetloc( + self, hostname: bytes, port: int + ) -> IOpenSSLClientConnectionCreator: """Implements the IPolicyForHTTPS interface so that this can be passed directly to agents. """ @@ -169,16 +175,18 @@ class RegularPolicyForHTTPS: trust root. """ - def __init__(self): + def __init__(self) -> None: trust_root = platformTrust() self._ssl_context = CertificateOptions(trustRoot=trust_root).getContext() self._ssl_context.set_info_callback(_context_info_cb) - def creatorForNetloc(self, hostname, port): + def creatorForNetloc( + self, hostname: bytes, port: int + ) -> IOpenSSLClientConnectionCreator: return SSLClientConnectionCreator(hostname, self._ssl_context, True) -def _context_info_cb(ssl_connection, where, ret): +def _context_info_cb(ssl_connection: SSL.Connection, where: int, ret: int) -> None: """The 'information callback' for our openssl context objects. Note: Once this is set as the info callback on a Context object, the Context should @@ -204,11 +212,13 @@ class SSLClientConnectionCreator: Replaces twisted.internet.ssl.ClientTLSOptions """ - def __init__(self, hostname: bytes, ctx, verify_certs: bool): + def __init__(self, hostname: bytes, ctx: SSL.Context, verify_certs: bool): self._ctx = ctx self._verifier = ConnectionVerifier(hostname, verify_certs) - def clientConnectionForTLS(self, tls_protocol): + def clientConnectionForTLS( + self, tls_protocol: TLSMemoryBIOProtocol + ) -> SSL.Connection: context = self._ctx connection = SSL.Connection(context, None) @@ -219,7 +229,7 @@ class SSLClientConnectionCreator: # ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the # tls_protocol so that the SSL context's info callback has something to # call to do the cert verification. - tls_protocol._synapse_tls_verifier = self._verifier + tls_protocol._synapse_tls_verifier = self._verifier # type: ignore[attr-defined] return connection @@ -244,7 +254,9 @@ class ConnectionVerifier: self._hostnameBytes = hostname self._hostnameASCII = self._hostnameBytes.decode("ascii") - def verify_context_info_cb(self, ssl_connection, where): + def verify_context_info_cb( + self, ssl_connection: SSL.Connection, where: int + ) -> None: if where & SSL.SSL_CB_HANDSHAKE_START and not self._is_ip_address: ssl_connection.set_tlsext_host_name(self._hostnameBytes) diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index 0f2b632e47..7520647d1e 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -100,7 +100,7 @@ def compute_content_hash( def compute_event_reference_hash( - event, hash_algorithm: Hasher = hashlib.sha256 + event: EventBase, hash_algorithm: Hasher = hashlib.sha256 ) -> Tuple[str, bytes]: """Computes the event reference hash. This is the hash of the redacted event. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index e1e13a2412..8628e951c4 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -87,7 +87,7 @@ class VerifyJsonRequest: server_name: str, json_object: JsonDict, minimum_valid_until_ms: int, - ): + ) -> "VerifyJsonRequest": """Create a VerifyJsonRequest to verify all signatures on a signed JSON object for the given server. """ @@ -104,7 +104,7 @@ class VerifyJsonRequest: server_name: str, event: EventBase, minimum_valid_until_ms: int, - ): + ) -> "VerifyJsonRequest": """Create a VerifyJsonRequest to verify all signatures on an event object for the given server. """ @@ -449,7 +449,9 @@ class StoreKeyFetcher(KeyFetcher): self.store = hs.get_datastore() - async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]): + async def _fetch_keys( + self, keys_to_fetch: List[_FetchKeyRequest] + ) -> Dict[str, Dict[str, FetchKeyResult]]: key_ids_to_fetch = ( (queue_value.server_name, key_id) for queue_value in keys_to_fetch -- cgit 1.5.1 From 6408372234eef2d72a13ee838c07199751c56378 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 21 Oct 2021 17:42:25 +0100 Subject: Improve docstrings for methods related to sending EDUs to application services (#11138) --- changelog.d/11138.misc | 1 + synapse/handlers/appservice.py | 94 ++++++++++++++++++++++++++++++++++++------ synapse/handlers/device.py | 4 ++ synapse/handlers/presence.py | 34 ++++++++++++--- synapse/handlers/receipts.py | 8 +++- synapse/handlers/typing.py | 12 ++++-- synapse/notifier.py | 18 +++++++- 7 files changed, 148 insertions(+), 23 deletions(-) create mode 100644 changelog.d/11138.misc (limited to 'synapse') diff --git a/changelog.d/11138.misc b/changelog.d/11138.misc new file mode 100644 index 0000000000..79b7776975 --- /dev/null +++ b/changelog.d/11138.misc @@ -0,0 +1 @@ +Add docstrings and comments to the application service ephemeral event sending code. \ No newline at end of file diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 163278708c..36c206dae6 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -185,19 +185,26 @@ class ApplicationServicesHandler: new_token: Optional[int], users: Optional[Collection[Union[str, UserID]]] = None, ) -> None: - """This is called by the notifier in the background - when a ephemeral event handled by the homeserver. - - This will determine which appservices - are interested in the event, and submit them. + """ + This is called by the notifier in the background when an ephemeral event is handled + by the homeserver. - Events will only be pushed to appservices - that have opted into ephemeral events + This will determine which appservices are interested in the event, and submit them. Args: stream_key: The stream the event came from. - new_token: The latest stream token - users: The user(s) involved with the event. + + `stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other + value for `stream_key` will cause this function to return early. + + Ephemeral events will only be pushed to appservices that have opted into + them. + + Appservices will only receive ephemeral events that fall within their + registered user and room namespaces. + + new_token: The latest stream token. + users: The users that should be informed of the new event, if any. """ if not self.notify_appservices: return @@ -232,21 +239,32 @@ class ApplicationServicesHandler: for service in services: # Only handle typing if we have the latest token if stream_key == "typing_key" and new_token is not None: + # Note that we don't persist the token (via set_type_stream_id_for_appservice) + # for typing_key due to performance reasons and due to their highly + # ephemeral nature. + # + # Instead we simply grab the latest typing updates in _handle_typing + # and, if they apply to this application service, send it off. events = await self._handle_typing(service, new_token) if events: self.scheduler.submit_ephemeral_events_for_as(service, events) - # We don't persist the token for typing_key for performance reasons + elif stream_key == "receipt_key": events = await self._handle_receipts(service) if events: self.scheduler.submit_ephemeral_events_for_as(service, events) + + # Persist the latest handled stream token for this appservice await self.store.set_type_stream_id_for_appservice( service, "read_receipt", new_token ) + elif stream_key == "presence_key": events = await self._handle_presence(service, users) if events: self.scheduler.submit_ephemeral_events_for_as(service, events) + + # Persist the latest handled stream token for this appservice await self.store.set_type_stream_id_for_appservice( service, "presence", new_token ) @@ -254,18 +272,54 @@ class ApplicationServicesHandler: async def _handle_typing( self, service: ApplicationService, new_token: int ) -> List[JsonDict]: + """ + Return the typing events since the given stream token that the given application + service should receive. + + First fetch all typing events between the given typing stream token (non-inclusive) + and the latest typing event stream token (inclusive). Then return only those typing + events that the given application service may be interested in. + + Args: + service: The application service to check for which events it should receive. + new_token: A typing event stream token. + + Returns: + A list of JSON dictionaries containing data derived from the typing events that + should be sent to the given application service. + """ typing_source = self.event_sources.sources.typing # Get the typing events from just before current typing, _ = await typing_source.get_new_events_as( service=service, # For performance reasons, we don't persist the previous - # token in the DB and instead fetch the latest typing information + # token in the DB and instead fetch the latest typing event # for appservices. + # TODO: It'd likely be more efficient to simply fetch the + # typing event with the given 'new_token' stream token and + # check if the given service was interested, rather than + # iterating over all typing events and only grabbing the + # latest few. from_key=new_token - 1, ) return typing async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: + """ + Return the latest read receipts that the given application service should receive. + + First fetch all read receipts between the last receipt stream token that this + application service should have previously received (non-inclusive) and the + latest read receipt stream token (inclusive). Then from that set, return only + those read receipts that the given application service may be interested in. + + Args: + service: The application service to check for which events it should receive. + + Returns: + A list of JSON dictionaries containing data derived from the read receipts that + should be sent to the given application service. + """ from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) @@ -278,6 +332,22 @@ class ApplicationServicesHandler: async def _handle_presence( self, service: ApplicationService, users: Collection[Union[str, UserID]] ) -> List[JsonDict]: + """ + Return the latest presence updates that the given application service should receive. + + First, filter the given users list to those that the application service is + interested in. Then retrieve the latest presence updates since the + the last-known previously received presence stream token for the given + application service. Return those presence updates. + + Args: + service: The application service that ephemeral events are being sent to. + users: The users that should receive the presence update. + + Returns: + A list of json dictionaries containing data derived from the presence events + that should be sent to the given application service. + """ events: List[JsonDict] = [] presence_source = self.event_sources.sources.presence from_key = await self.store.get_type_stream_id_for_appservice( @@ -290,9 +360,9 @@ class ApplicationServicesHandler: interested = await service.is_interested_in_presence(user, self.store) if not interested: continue + presence_events, _ = await presence_source.get_new_events( user=user, - service=service, from_key=from_key, ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 6eafbea25d..68b446eb66 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -454,6 +454,10 @@ class DeviceHandler(DeviceWorkerHandler): ) -> None: """Notify that a user's device(s) has changed. Pokes the notifier, and remote servers if the user is local. + + Args: + user_id: The Matrix ID of the user who's device list has been updated. + device_ids: The device IDs that have changed. """ if not device_ids: # No changes to notify about, so this is a no-op. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index b5968e047b..fdab50da37 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -52,7 +52,6 @@ import synapse.metrics from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.appservice import ApplicationService from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background from synapse.logging.utils import log_function @@ -1483,11 +1482,37 @@ def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> def format_user_presence_state( state: UserPresenceState, now: int, include_user_id: bool = True ) -> JsonDict: - """Convert UserPresenceState to a format that can be sent down to clients + """Convert UserPresenceState to a JSON format that can be sent down to clients and to other servers. - The "user_id" is optional so that this function can be used to format presence - updates for client /sync responses and for federation /send requests. + Args: + state: The user presence state to format. + now: The current timestamp since the epoch in ms. + include_user_id: Whether to include `user_id` in the returned dictionary. + As this function can be used both to format presence updates for client /sync + responses and for federation /send requests, only the latter needs the include + the `user_id` field. + + Returns: + A JSON dictionary with the following keys: + * presence: The presence state as a str. + * user_id: Optional. Included if `include_user_id` is truthy. The canonical + Matrix ID of the user. + * last_active_ago: Optional. Included if `last_active_ts` is set on `state`. + The timestamp that the user was last active. + * status_msg: Optional. Included if `status_msg` is set on `state`. The user's + status. + * currently_active: Optional. Included only if `state.state` is "online". + + Example: + + { + "presence": "online", + "user_id": "@alice:example.com", + "last_active_ago": 16783813918, + "status_msg": "Hello world!", + "currently_active": True + } """ content: JsonDict = {"presence": state.state} if include_user_id: @@ -1526,7 +1551,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): is_guest: bool = False, explicit_room_id: Optional[str] = None, include_offline: bool = True, - service: Optional[ApplicationService] = None, ) -> Tuple[List[UserPresenceState], int]: # The process for getting presence events are: # 1. Get the rooms the user is in. diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 374e961e3b..4911a11535 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -241,12 +241,18 @@ class ReceiptEventSource(EventSource[int, JsonDict]): async def get_new_events_as( self, from_key: int, service: ApplicationService ) -> Tuple[List[JsonDict], int]: - """Returns a set of new receipt events that an appservice + """Returns a set of new read receipt events that an appservice may be interested in. Args: from_key: the stream position at which events should be fetched from service: The appservice which may be interested + + Returns: + A two-tuple containing the following: + * A list of json dictionaries derived from read receipts that the + appservice may be interested in. + * The current read receipt stream token. """ from_key = int(from_key) to_key = self.get_current_key() diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d10e9b8ec4..c411d69924 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -465,17 +465,23 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): may be interested in. Args: - from_key: the stream position at which events should be fetched from - service: The appservice which may be interested + from_key: the stream position at which events should be fetched from. + service: The appservice which may be interested. + + Returns: + A two-tuple containing the following: + * A list of json dictionaries derived from typing events that the + appservice may be interested in. + * The latest known room serial. """ with Measure(self.clock, "typing.get_new_events_as"): - from_key = int(from_key) handler = self.get_typing_handler() events = [] for room_id in handler._room_serials.keys(): if handler._room_serials[room_id] <= from_key: continue + if not await service.matches_user_in_member_list( room_id, handler.store ): diff --git a/synapse/notifier.py b/synapse/notifier.py index 1a9f84ba45..1acd899fab 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -379,7 +379,14 @@ class Notifier: stream_key: str, new_token: Union[int, RoomStreamToken], users: Optional[Collection[Union[str, UserID]]] = None, - ): + ) -> None: + """Notify application services of ephemeral event activity. + + Args: + stream_key: The stream the event came from. + new_token: The value of the new stream token. + users: The users that should be informed of the new event, if any. + """ try: stream_token = None if isinstance(new_token, int): @@ -402,10 +409,17 @@ class Notifier: new_token: Union[int, RoomStreamToken], users: Optional[Collection[Union[str, UserID]]] = None, rooms: Optional[Collection[str]] = None, - ): + ) -> None: """Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. + + Args: + stream_key: The stream the event came from. + new_token: The value of the new stream token. + users: The users that should be informed of the new event. + rooms: A collection of room IDs for which each joined member will be + informed of the new event. """ users = users or [] rooms = rooms or [] -- cgit 1.5.1 From 2d91b6256e53a9e60027880b0407bd77cb653ad1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 21 Oct 2021 17:48:59 +0100 Subject: Fix adding excluded users to the private room sharing tables when joining a room (#11143) * We only need to fetch users in private rooms * Filter out `user_id` at the top * Discard excluded users in the top loop We weren't doing this in the "First, if they're our user" branch so this is a bugfix. * The caller must check that `user_id` is included This is in the docstring. There are two call sites: - one in `_handle_room_publicity_change`, which explicitly checks before calling; - and another in `_handle_room_membership_event`, which returns early if the user is excluded. So this change is safe. * Test joining a private room with an excluded user * Tweak an existing test * Changelog * test docstring * lint --- changelog.d/11143.misc | 1 + synapse/handlers/user_directory.py | 28 +++++++-------- tests/handlers/test_user_directory.py | 67 +++++++++++++++++++++++++++-------- 3 files changed, 67 insertions(+), 29 deletions(-) create mode 100644 changelog.d/11143.misc (limited to 'synapse') diff --git a/changelog.d/11143.misc b/changelog.d/11143.misc new file mode 100644 index 0000000000..496e44a9c0 --- /dev/null +++ b/changelog.d/11143.misc @@ -0,0 +1 @@ +Fix a long-standing bug where users excluded from the directory could still be added to the `users_who_share_private_rooms` table after a regular user joins a private room. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 991fee7e58..a0eb45446f 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -373,31 +373,29 @@ class UserDirectoryHandler(StateDeltasHandler): is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) - other_users_in_room = await self.store.get_users_in_room(room_id) - if is_public: await self.store.add_users_in_public_rooms(room_id, (user_id,)) else: + users_in_room = await self.store.get_users_in_room(room_id) + other_users_in_room = [ + other + for other in users_in_room + if other != user_id + and ( + not self.is_mine_id(other) + or await self.store.should_include_local_user_in_dir(other) + ) + ] to_insert = set() # First, if they're our user then we need to update for every user if self.is_mine_id(user_id): - if await self.store.should_include_local_user_in_dir(user_id): - for other_user_id in other_users_in_room: - if user_id == other_user_id: - continue - - to_insert.add((user_id, other_user_id)) + for other_user_id in other_users_in_room: + to_insert.add((user_id, other_user_id)) # Next we need to update for every local user in the room for other_user_id in other_users_in_room: - if user_id == other_user_id: - continue - - include_other_user = self.is_mine_id( - other_user_id - ) and await self.store.should_include_local_user_in_dir(other_user_id) - if include_other_user: + if self.is_mine_id(other_user_id): to_insert.add((other_user_id, user_id)) if to_insert: diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index b9ad92b977..70c621b825 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -646,22 +646,20 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): u2_token = self.login(u2, "pass") u3 = self.register_user("user3", "pass") - # We do not add users to the directory until they join a room. + # u1 can't see u2 until they share a private room, or u1 is in a public room. s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) + # Get u1 and u2 into a private room. room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) self.helper.invite(room, src=u1, targ=u2, tok=u1_token) self.helper.join(room, user=u2, tok=u2_token) # Check we have populated the database correctly. - shares_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - public_users = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + users, public_users, shares_private = self.get_success( + self.user_dir_helper.get_tables() ) - + self.assertEqual(users, {u1, u2, u3}) self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) self.assertEqual(public_users, set()) @@ -680,14 +678,11 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # User 2 then leaves. self.helper.leave(room, user=u2, tok=u2_token) - # Check we have removed the values. - shares_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - public_users = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + # Check this is reflected in the DB. + users, public_users, shares_private = self.get_success( + self.user_dir_helper.get_tables() ) - + self.assertEqual(users, {u1, u2, u3}) self.assertEqual(shares_private, set()) self.assertEqual(public_users, set()) @@ -698,6 +693,50 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) + def test_joining_private_room_with_excluded_user(self) -> None: + """ + When a user excluded from the user directory, E say, joins a private + room, E will not appear in the `users_who_share_private_rooms` table. + + When a normal user, U say, joins a private room containing E, then + U will appear in the `users_who_share_private_rooms` table, but E will + not. + """ + # Setup a support and two normal users. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + # Alice makes a room. Inject the support user into the room. + room = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.get_success(inject_member_event(self.hs, room, support, "join")) + # Check the DB state. The support user should not be in the directory. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, set()) + self.assertEqual(in_private, set()) + + # Then invite Bob, who accepts. + self.helper.invite(room, alice, bob, tok=alice_token) + self.helper.join(room, bob, tok=bob_token) + + # Check the DB state. The support user should not be in the directory. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, set()) + self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)}) + def test_spam_checker(self) -> None: """ A user which fails the spam checks will not appear in search results. -- cgit 1.5.1 From ba00e20234eadae66f105f8bda64e39beed9a92d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 21 Oct 2021 14:39:16 -0400 Subject: Add a thread relation type per MSC3440. (#11088) Adds experimental support for MSC3440's `io.element.thread` relation type (and the aggregation for it). --- changelog.d/11088.feature | 1 + synapse/api/constants.py | 1 + synapse/config/experimental.py | 2 + synapse/events/utils.py | 17 +++++++++ synapse/rest/client/relations.py | 3 +- synapse/storage/databases/main/events.py | 4 ++ synapse/storage/databases/main/relations.py | 59 ++++++++++++++++++++++++++++- tests/rest/client/test_relations.py | 40 ++++++++++++++++--- 8 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 changelog.d/11088.feature (limited to 'synapse') diff --git a/changelog.d/11088.feature b/changelog.d/11088.feature new file mode 100644 index 0000000000..76b0d28084 --- /dev/null +++ b/changelog.d/11088.feature @@ -0,0 +1 @@ +Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a31f037748..a33ac34161 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -176,6 +176,7 @@ class RelationTypes: ANNOTATION = "m.annotation" REPLACE = "m.replace" REFERENCE = "m.reference" + THREAD = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b013a3918c..8b098ad48d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -26,6 +26,8 @@ class ExperimentalConfig(Config): # Whether to enable experimental MSC1849 (aka relations) support self.msc1849_enabled = config.get("experimental_msc1849_support_enabled", True) + # MSC3440 (thread relation) + self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False) # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 3f3eba86a8..6fa631aa1d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -386,6 +386,7 @@ class EventClientSerializer: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self._msc1849_enabled = hs.config.experimental.msc1849_enabled + self._msc3440_enabled = hs.config.experimental.msc3440_enabled async def serialize_event( self, @@ -462,6 +463,22 @@ class EventClientSerializer: "sender": edit.sender, } + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + ( + thread_count, + latest_thread_event, + ) = await self.store.get_thread_summary(event_id) + if latest_thread_event: + r = serialized_event["unsigned"].setdefault("m.relations", {}) + r[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": await self.serialize_event( + latest_thread_event, time_now, bundle_aggregations=False + ), + "count": thread_count, + } + return serialized_event async def serialize_events( diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index d695c18be2..58f6699073 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -128,9 +128,10 @@ class RelationSendServlet(RestServlet): content["m.relates_to"] = { "event_id": parent_id, - "key": aggregation_key, "rel_type": relation_type, } + if aggregation_key is not None: + content["m.relates_to"]["key"] = aggregation_key event_dict = { "type": event_type, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 37439f8562..8d9086ecf0 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1710,6 +1710,7 @@ class PersistEventsStore: RelationTypes.ANNOTATION, RelationTypes.REFERENCE, RelationTypes.REPLACE, + RelationTypes.THREAD, ): # Unknown relation type return @@ -1740,6 +1741,9 @@ class PersistEventsStore: if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + if rel_type == RelationTypes.THREAD: + txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. Part of MSC2716. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2bbf6d6a95..40760fbd1b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Optional +from typing import Optional, Tuple import attr @@ -269,6 +269,63 @@ class RelationsWorkerStore(SQLBaseStore): return await self.get_event(edit_id, allow_none=True) + @cached() + async def get_thread_summary( + self, event_id: str + ) -> Tuple[int, Optional[EventBase]]: + """Get the number of threaded replies, the senders of those replies, and + the latest reply (if any) for the given event. + + Args: + event_id: The original event ID + + Returns: + The number of items in the thread and the most recent response, if any. + """ + + def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]: + # Fetch the count of threaded events and the latest event ID. + # TODO Should this only allow m.room.message events. + sql = """ + SELECT event_id + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND relation_type = ? + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT 1 + """ + + txn.execute(sql, (event_id, RelationTypes.THREAD)) + row = txn.fetchone() + if row is None: + return 0, None + + latest_event_id = row[0] + + sql = """ + SELECT COALESCE(COUNT(event_id), 0) + FROM event_relations + WHERE + relates_to_id = ? + AND relation_type = ? + """ + txn.execute(sql, (event_id, RelationTypes.THREAD)) + count = txn.fetchone()[0] + + return count, latest_event_id + + count, latest_event_id = await self.db_pool.runInteraction( + "get_thread_summary", _get_thread_summary_txn + ) + + latest_event = None + if latest_event_id: + latest_event = await self.get_event(latest_event_id, allow_none=True) + + return count, latest_event + async def has_user_annotated_event( self, parent_id: str, event_type: str, aggregation_key: str, sender: str ) -> bool: diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 3c7d49f0b4..78c2fb86b9 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -101,10 +101,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_basic_paginate_relations(self): """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] @@ -141,8 +141,10 @@ class RelationsTestCase(unittest.HomeserverTestCase): """ expected_event_ids = [] - for _ in range(10): - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + for idx in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) + ) self.assertEquals(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) @@ -386,8 +388,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(400, channel.code, channel.json_body) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_aggregation_get_event(self): - """Test that annotations and references get correctly bundled when + """Test that annotations, references, and threads get correctly bundled when getting the parent event. """ @@ -410,6 +413,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + thread_2 = channel.json_body["event_id"] + channel = self.make_request( "GET", "/rooms/%s/event/%s" % (self.room, self.parent_id), @@ -429,6 +439,25 @@ class RelationsTestCase(unittest.HomeserverTestCase): RelationTypes.REFERENCE: { "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] }, + RelationTypes.THREAD: { + "count": 2, + "latest_event": { + "age": 100, + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "origin_server_ts": 1600, + "room_id": self.room, + "sender": self.user_id, + "type": "m.room.test", + "unsigned": {"age": 100}, + "user_id": self.user_id, + }, + }, }, ) @@ -559,7 +588,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): { "m.relates_to": { "event_id": self.parent_id, - "key": None, "rel_type": "m.reference", } }, -- cgit 1.5.1 From b9ce53e8785d6f0dba6a3efcd708e4a185c32465 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Fri, 22 Oct 2021 13:00:52 +0300 Subject: Fix synapse.config module "read" command (#11145) `synapse.config.__main__` has the possibility to read a config item. This can be used to conveniently also validate the config is valid before trying to start Synapse. The "read" command broke in https://github.com/matrix-org/synapse/pull/10916 as it now requires passing in "server.server_name" for example. Also made the read command optional so one can just call this with just the confirm file reference and get a "Config parses OK" if things are ok. Signed-off-by: Jason Robinson Co-authored-by: Brendan Abolivier --- changelog.d/11145.bugfix | 1 + synapse/config/__main__.py | 46 ++++++++++++++++++++-------- tests/config/test___main__.py | 31 +++++++++++++++++++ tests/config/test_load.py | 70 ++++++++++--------------------------------- tests/config/utils.py | 58 +++++++++++++++++++++++++++++++++++ 5 files changed, 138 insertions(+), 68 deletions(-) create mode 100644 changelog.d/11145.bugfix create mode 100644 tests/config/test___main__.py create mode 100644 tests/config/utils.py (limited to 'synapse') diff --git a/changelog.d/11145.bugfix b/changelog.d/11145.bugfix new file mode 100644 index 0000000000..f369feac42 --- /dev/null +++ b/changelog.d/11145.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.45.0 breaking the configuration file parsing script. diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index b5b6735a8f..c555f5f914 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,25 +12,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys + from synapse.config._base import ConfigError +from synapse.config.homeserver import HomeServerConfig -if __name__ == "__main__": - import sys - from synapse.config.homeserver import HomeServerConfig +def main(args): + action = args[1] if len(args) > 1 and args[1] == "read" else None + # If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]` + # will be the key to read. + # We'll want to rework this code if we want to support more actions than just `read`. + load_config_args = args[3:] if action else args[1:] - action = sys.argv[1] + try: + config = HomeServerConfig.load_config("", load_config_args) + except ConfigError as e: + sys.stderr.write("\n" + str(e) + "\n") + sys.exit(1) + + print("Config parses OK!") if action == "read": - key = sys.argv[2] + key = args[2] + key_parts = key.split(".") + + value = config try: - config = HomeServerConfig.load_config("", sys.argv[3:]) - except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") + while len(key_parts): + value = getattr(value, key_parts[0]) + key_parts.pop(0) + + print(f"\n{key}: {value}") + except AttributeError: + print( + f"\nNo '{key}' key could be found in the provided configuration file." + ) sys.exit(1) - print(getattr(config, key)) - sys.exit(0) - else: - sys.stderr.write("Unknown command %r\n" % (action,)) - sys.exit(1) + +if __name__ == "__main__": + main(sys.argv) diff --git a/tests/config/test___main__.py b/tests/config/test___main__.py new file mode 100644 index 0000000000..b1c73d3612 --- /dev/null +++ b/tests/config/test___main__.py @@ -0,0 +1,31 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config.__main__ import main + +from tests.config.utils import ConfigFileTestCase + + +class ConfigMainFileTestCase(ConfigFileTestCase): + def test_executes_without_an_action(self): + self.generate_config() + main(["", "-c", self.config_file]) + + def test_read__error_if_key_not_found(self): + self.generate_config() + with self.assertRaises(SystemExit): + main(["", "read", "foo.bar.hello", "-c", self.config_file]) + + def test_read__passes_if_key_found(self): + self.generate_config() + main(["", "read", "server.server_name", "-c", self.config_file]) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 59635de205..765258c47a 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -1,4 +1,5 @@ # Copyright 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,43 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os.path -import shutil -import tempfile -from contextlib import redirect_stdout -from io import StringIO - import yaml from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig -from tests import unittest - - -class ConfigLoadingTestCase(unittest.TestCase): - def setUp(self): - self.dir = tempfile.mkdtemp() - self.file = os.path.join(self.dir, "homeserver.yaml") +from tests.config.utils import ConfigFileTestCase - def tearDown(self): - shutil.rmtree(self.dir) +class ConfigLoadingFileTestCase(ConfigFileTestCase): def test_load_fails_if_server_name_missing(self): self.generate_config_and_remove_lines_containing("server_name") with self.assertRaises(ConfigError): - HomeServerConfig.load_config("", ["-c", self.file]) + HomeServerConfig.load_config("", ["-c", self.config_file]) with self.assertRaises(ConfigError): - HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) def test_generates_and_loads_macaroon_secret_key(self): self.generate_config() - with open(self.file) as f: + with open(self.config_file) as f: raw = yaml.safe_load(f) self.assertIn("macaroon_secret_key", raw) - config = HomeServerConfig.load_config("", ["-c", self.file]) + config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertTrue( hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", @@ -58,7 +46,7 @@ class ConfigLoadingTestCase(unittest.TestCase): "was: %r" % (config.key.macaroon_secret_key,) ) - config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) self.assertTrue( hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", @@ -71,9 +59,9 @@ class ConfigLoadingTestCase(unittest.TestCase): def test_load_succeeds_if_macaroon_secret_key_missing(self): self.generate_config_and_remove_lines_containing("macaroon") - config1 = HomeServerConfig.load_config("", ["-c", self.file]) - config2 = HomeServerConfig.load_config("", ["-c", self.file]) - config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + config1 = HomeServerConfig.load_config("", ["-c", self.config_file]) + config2 = HomeServerConfig.load_config("", ["-c", self.config_file]) + config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) self.assertEqual( config1.key.macaroon_secret_key, config2.key.macaroon_secret_key ) @@ -87,15 +75,15 @@ class ConfigLoadingTestCase(unittest.TestCase): ["enable_registration: true", "disable_registration: true"] ) # Check that disable_registration clobbers enable_registration. - config = HomeServerConfig.load_config("", ["-c", self.file]) + config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.registration.enable_registration) - config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) + config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file]) self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( - "", ["-c", self.file, "--enable-registration"] + "", ["-c", self.config_file, "--enable-registration"] ) self.assertTrue(config.registration.enable_registration) @@ -104,33 +92,5 @@ class ConfigLoadingTestCase(unittest.TestCase): self.add_lines_to_config(["enable_metrics: true"]) # The default Metrics Flags are off by default. - config = HomeServerConfig.load_config("", ["-c", self.file]) + config = HomeServerConfig.load_config("", ["-c", self.config_file]) self.assertFalse(config.metrics.metrics_flags.known_servers) - - def generate_config(self): - with redirect_stdout(StringIO()): - HomeServerConfig.load_or_generate_config( - "", - [ - "--generate-config", - "-c", - self.file, - "--report-stats=yes", - "-H", - "lemurs.win", - ], - ) - - def generate_config_and_remove_lines_containing(self, needle): - self.generate_config() - - with open(self.file) as f: - contents = f.readlines() - contents = [line for line in contents if needle not in line] - with open(self.file, "w") as f: - f.write("".join(contents)) - - def add_lines_to_config(self, lines): - with open(self.file, "a") as f: - for line in lines: - f.write(line + "\n") diff --git a/tests/config/utils.py b/tests/config/utils.py new file mode 100644 index 0000000000..94c18a052b --- /dev/null +++ b/tests/config/utils.py @@ -0,0 +1,58 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +import tempfile +import unittest +from contextlib import redirect_stdout +from io import StringIO + +from synapse.config.homeserver import HomeServerConfig + + +class ConfigFileTestCase(unittest.TestCase): + def setUp(self): + self.dir = tempfile.mkdtemp() + self.config_file = os.path.join(self.dir, "homeserver.yaml") + + def tearDown(self): + shutil.rmtree(self.dir) + + def generate_config(self): + with redirect_stdout(StringIO()): + HomeServerConfig.load_or_generate_config( + "", + [ + "--generate-config", + "-c", + self.config_file, + "--report-stats=yes", + "-H", + "lemurs.win", + ], + ) + + def generate_config_and_remove_lines_containing(self, needle): + self.generate_config() + + with open(self.config_file) as f: + contents = f.readlines() + contents = [line for line in contents if needle not in line] + with open(self.config_file, "w") as f: + f.write("".join(contents)) + + def add_lines_to_config(self, lines): + with open(self.config_file, "a") as f: + for line in lines: + f.write(line + "\n") -- cgit 1.5.1 From 2b82ec425fccb0ef626242779f7ccd4d77a0685c Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Fri, 22 Oct 2021 18:15:41 +0100 Subject: Add type hints for most `HomeServer` parameters (#11095) --- changelog.d/11095.misc | 1 + synapse/app/_base.py | 8 +++---- synapse/app/admin_cmd.py | 4 ++-- synapse/app/generic_worker.py | 4 ++-- synapse/app/homeserver.py | 2 +- synapse/app/phone_stats_home.py | 8 +++++-- synapse/appservice/api.py | 3 ++- synapse/config/logger.py | 9 ++++++- synapse/federation/federation_base.py | 7 +++++- synapse/federation/federation_server.py | 9 +++---- synapse/http/matrixfederationclient.py | 8 +++++-- synapse/http/server.py | 19 +++++++++------ synapse/replication/http/__init__.py | 9 +++++-- synapse/replication/http/_base.py | 8 ++++--- synapse/replication/http/account_data.py | 14 +++++++---- synapse/replication/http/devices.py | 8 +++++-- synapse/replication/http/federation.py | 16 ++++++++----- synapse/replication/http/login.py | 8 +++++-- synapse/replication/http/membership.py | 6 ++--- synapse/replication/http/presence.py | 2 +- synapse/replication/http/push.py | 2 +- synapse/replication/http/register.py | 10 +++++--- synapse/replication/http/send_event.py | 8 +++++-- synapse/replication/http/streams.py | 8 +++++-- synapse/replication/slave/storage/_base.py | 7 ++++-- synapse/replication/slave/storage/client_ips.py | 7 +++++- synapse/replication/slave/storage/devices.py | 7 +++++- synapse/replication/slave/storage/events.py | 6 ++++- synapse/replication/slave/storage/filtering.py | 7 +++++- synapse/replication/slave/storage/groups.py | 7 +++++- synapse/replication/tcp/external_cache.py | 9 ++++++- synapse/replication/tcp/handler.py | 6 ++++- synapse/replication/tcp/resource.py | 8 +++++-- synapse/replication/tcp/streams/_base.py | 20 ++++++++-------- synapse/rest/admin/devices.py | 2 +- synapse/server.py | 11 ++++++--- synapse/storage/database.py | 6 ++++- synapse/storage/databases/__init__.py | 28 +++++++++++++++++----- synapse/storage/databases/main/__init__.py | 7 ++++-- synapse/storage/databases/main/account_data.py | 7 ++++-- synapse/storage/databases/main/cache.py | 7 ++++-- synapse/storage/databases/main/deviceinbox.py | 9 ++++--- synapse/storage/databases/main/devices.py | 21 ++++++++++++---- synapse/storage/databases/main/event_federation.py | 9 ++++--- .../storage/databases/main/event_push_actions.py | 9 ++++--- .../storage/databases/main/events_bg_updates.py | 7 ++++-- synapse/storage/databases/main/media_repository.py | 9 ++++--- synapse/storage/databases/main/metrics.py | 7 ++++-- .../storage/databases/main/monthly_active_users.py | 9 ++++--- synapse/storage/databases/main/push_rule.py | 7 ++++-- synapse/storage/databases/main/receipts.py | 7 ++++-- synapse/storage/databases/main/room.py | 11 +++++---- synapse/storage/databases/main/roommember.py | 7 +++--- synapse/storage/databases/main/search.py | 9 ++++--- synapse/storage/databases/main/state.py | 11 +++++---- synapse/storage/databases/main/stats.py | 7 ++++-- synapse/storage/databases/main/transactions.py | 7 ++++-- synapse/storage/persist_events.py | 6 ++++- 58 files changed, 342 insertions(+), 143 deletions(-) create mode 100644 changelog.d/11095.misc (limited to 'synapse') diff --git a/changelog.d/11095.misc b/changelog.d/11095.misc new file mode 100644 index 0000000000..786e90b595 --- /dev/null +++ b/changelog.d/11095.misc @@ -0,0 +1 @@ +Add type hints to most `HomeServer` parameters. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index bb4d53d778..2ca2e051e4 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -294,7 +294,7 @@ def listen_ssl( return r -def refresh_certificate(hs): +def refresh_certificate(hs: "HomeServer"): """ Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. @@ -419,11 +419,11 @@ async def start(hs: "HomeServer"): atexit.register(gc.freeze) -def setup_sentry(hs): +def setup_sentry(hs: "HomeServer"): """Enable sentry integration, if enabled in configuration Args: - hs (synapse.server.HomeServer) + hs """ if not hs.config.metrics.sentry_enabled: @@ -449,7 +449,7 @@ def setup_sentry(hs): scope.set_tag("worker_name", name) -def setup_sdnotify(hs): +def setup_sdnotify(hs: "HomeServer"): """Adds process state hooks to tell systemd what we are up to.""" # Tell systemd our state, if we're using it. This will silently fail if diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index b156b93bf3..2fc848596d 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer): DATASTORE_CLASS = AdminCmdSlavedStore -async def export_data_command(hs, args): +async def export_data_command(hs: HomeServer, args): """Export data for a user. Args: - hs (HomeServer) + hs args (argparse.Namespace) """ diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 7489f31d9a..51eadf122d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - def __init__(self, hs): + def __init__(self, hs: HomeServer): """ Args: - hs (synapse.server.HomeServer): server + hs: server """ super().__init__() self.auth = hs.get_auth() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 422f03cc04..93e2299266 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]: e = e.__cause__ -def run(hs): +def run(hs: HomeServer): PROFILE_SYNAPSE = False if PROFILE_SYNAPSE: diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index fcd01e833c..126450e17a 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,11 +15,15 @@ import logging import math import resource import sys +from typing import TYPE_CHECKING from prometheus_client import Gauge from synapse.metrics.background_process_metrics import wrap_as_background_process +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger("synapse.app.homeserver") # Contains the list of processes we will be monitoring @@ -41,7 +45,7 @@ registered_reserved_users_mau_gauge = Gauge( @wrap_as_background_process("phone_stats_home") -async def phone_stats_home(hs, stats, stats_process=_stats_process): +async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process): logger.info("Gathering stats for reporting") now = int(hs.get_clock().time()) uptime = int(now - hs.start_time) @@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): logger.warning("Error reporting stats: %s", e) -def start_phone_stats_home(hs): +def start_phone_stats_home(hs: "HomeServer"): """ Start the background tasks which report phone home stats. """ diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 935f24263c..d08f6bbd7f 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -27,6 +27,7 @@ from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: from synapse.appservice import ApplicationService + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient): pushing. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.clock = hs.get_clock() diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 0a08231e5a..5252e61a99 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -18,6 +18,7 @@ import os import sys import threading from string import Template +from typing import TYPE_CHECKING import yaml from zope.interface import implementer @@ -38,6 +39,9 @@ from synapse.util.versionstring import get_version_string from ._base import Config, ConfigError +if TYPE_CHECKING: + from synapse.server import HomeServer + DEFAULT_LOG_CONFIG = Template( """\ # Log configuration for Synapse. @@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path): def setup_logging( - hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner + hs: "HomeServer", + config, + use_worker_options=False, + logBeginner: LogBeginner = globalLogBeginner, ) -> None: """ Set up the logging subsystem. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 0cd424e12a..f56344a3b9 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -14,6 +14,7 @@ # limitations under the License. import logging from collections import namedtuple +from typing import TYPE_CHECKING from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -25,11 +26,15 @@ from synapse.events.utils import prune_event, validate_canonicaljson from synapse.http.servlet import assert_params_in_dict from synapse.types import JsonDict, get_domain_from_id +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) class FederationBase: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.server_name = hs.hostname diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d8c0b86f23..0d66034f44 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -467,7 +467,7 @@ class FederationServer(FederationBase): async def on_room_state_request( self, origin: str, room_id: str, event_id: Optional[str] - ) -> Tuple[int, Dict[str, Any]]: + ) -> Tuple[int, JsonDict]: origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -481,7 +481,7 @@ class FederationServer(FederationBase): # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. with (await self._server_linearizer.queue((origin, room_id))): - resp = dict( + resp: JsonDict = dict( await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, @@ -1061,11 +1061,12 @@ class FederationServer(FederationBase): origin, event = next - lock = await self.store.try_acquire_lock( + new_lock = await self.store.try_acquire_lock( _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id ) - if not lock: + if not new_lock: return + lock = new_lock def __str__(self) -> str: return "" % self.server_name diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 4f59224686..203d723d41 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -21,6 +21,7 @@ import typing import urllib.parse from io import BytesIO, StringIO from typing import ( + TYPE_CHECKING, Callable, Dict, Generic, @@ -73,6 +74,9 @@ from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter( @@ -319,7 +323,7 @@ class MatrixFederationHttpClient: requests. """ - def __init__(self, hs, tls_client_options_factory): + def __init__(self, hs: "HomeServer", tls_client_options_factory): self.hs = hs self.signing_key = hs.signing_key self.server_name = hs.hostname @@ -711,7 +715,7 @@ class MatrixFederationHttpClient: Returns: A list of headers to be added as "Authorization:" headers """ - request = { + request: JsonDict = { "method": method.decode("ascii"), "uri": url_bytes.decode("ascii"), "origin": self.server_name, diff --git a/synapse/http/server.py b/synapse/http/server.py index 897ba5e453..1af0d9a31d 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,6 +22,7 @@ import urllib from http import HTTPStatus from inspect import isawaitable from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -61,6 +62,9 @@ from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.iterutils import chunk_seq +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) HTML_ERROR_TEMPLATE = """ @@ -343,6 +347,11 @@ class DirectServeJsonResource(_AsyncResource): return_json_error(f, request) +_PathEntry = collections.namedtuple( + "_PathEntry", ["pattern", "callback", "servlet_classname"] +) + + class JsonResource(DirectServeJsonResource): """This implements the HttpServer interface and provides JSON support for Resources. @@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource): isLeaf = True - _PathEntry = collections.namedtuple( - "_PathEntry", ["pattern", "callback", "servlet_classname"] - ) - - def __init__(self, hs, canonical_json=True, extract_context=False): + def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() - self.path_regexs = {} + self.path_regexs: Dict[bytes, List[_PathEntry]] = {} self.hs = hs def register_paths(self, method, path_patterns, callback, servlet_classname): @@ -391,7 +396,7 @@ class JsonResource(DirectServeJsonResource): for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) self.path_regexs.setdefault(method, []).append( - self._PathEntry(path_pattern, callback, servlet_classname) + _PathEntry(path_pattern, callback, servlet_classname) ) def _get_handler_for_request( diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index ba8114ac9e..1457d9d59b 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.http.server import JsonResource from synapse.replication.http import ( account_data, @@ -26,16 +28,19 @@ from synapse.replication.http import ( streams, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + REPLICATION_PREFIX = "/_synapse/replication" class ReplicationRestResource(JsonResource): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # We enable extracting jaeger contexts here as these are internal APIs. super().__init__(hs, canonical_json=False, extract_context=True) self.register_servlets(hs) - def register_servlets(self, hs): + def register_servlets(self, hs: "HomeServer"): send_event.register_servlets(hs, self) federation.register_servlets(hs, self) presence.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index e047ec74d8..585332b244 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -17,7 +17,7 @@ import logging import re import urllib from inspect import signature -from typing import TYPE_CHECKING, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -156,7 +156,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): pass @classmethod - def make_client(cls, hs): + def make_client(cls, hs: "HomeServer"): """Create a client that makes requests. Returns a callable that accepts the same parameters as @@ -208,7 +208,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): url_args.append(txn_id) if cls.METHOD == "POST": - request_func = client.post_json_get_json + request_func: Callable[ + ..., Awaitable[Any] + ] = client.post_json_get_json elif cls.METHOD == "PUT": request_func = client.put_json elif cls.METHOD == "GET": diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 70e951af63..5f0f225aa9 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +41,7 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "account_data_type") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -78,7 +82,7 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "room_id", "account_data_type") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -119,7 +123,7 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id", "room_id", "tag") CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -162,7 +166,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): ) CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.handler = hs.get_account_data_handler() @@ -183,7 +187,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationUserAccountDataRestServlet(hs).register(http_server) ReplicationRoomAccountDataRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 5a5818ef61..42dffb39cb 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,9 +13,13 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -51,7 +55,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): PATH_ARGS = ("user_id",) CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.device_list_updater = hs.get_device_handler().device_list_updater @@ -68,5 +72,5 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): return 200, user_devices -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationUserDevicesResyncRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index a0b3145f4e..5ed535c90d 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -21,6 +22,9 @@ from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): NAME = "fed_send_events" PATH_ARGS = () - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -151,7 +155,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): NAME = "fed_send_edu" PATH_ARGS = ("edu_type",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -194,7 +198,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): # This is a query, so let's not bother caching CACHE = False - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -238,7 +242,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): NAME = "fed_cleanup_room" PATH_ARGS = ("room_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -273,7 +277,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): NAME = "store_room_on_outlier_membership" PATH_ARGS = ("room_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() @@ -289,7 +293,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 550bd5c95f..0db419ea57 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,7 +34,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): NAME = "device_check_registered" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.registration_handler = hs.get_registration_handler() @@ -82,5 +86,5 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): return 200, res -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): RegisterDeviceReplicationServlet(hs).register(http_server) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 34206c5060..7371c240b2 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): NAME = "remote_join" PATH_ARGS = ("room_id", "user_id") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.federation_handler = hs.get_federation_handler() @@ -320,7 +320,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): PATH_ARGS = ("room_id", "user_id", "change") CACHE = False # No point caching as should return instantly. - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.registeration_handler = hs.get_registration_handler() @@ -360,7 +360,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index bb00247953..63143085d5 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -117,6 +117,6 @@ class ReplicationPresenceSetState(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationBumpPresenceActiveTime(hs).register(http_server) ReplicationPresenceSetState(hs).register(http_server) diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py index 139427cb1f..6c8db3061e 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py @@ -67,5 +67,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRemovePusherRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index d6dd7242eb..7adfbb666f 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -13,10 +13,14 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -26,7 +30,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): NAME = "register_user" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -100,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): NAME = "post_register" PATH_ARGS = ("user_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.store = hs.get_datastore() self.registration_handler = hs.get_registration_handler() @@ -130,6 +134,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationRegisterServlet(hs).register(http_server) ReplicationPostRegisterActionsServlet(hs).register(http_server) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index fae5ffa451..9f6851d059 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -22,6 +23,9 @@ from synapse.replication.http._base import ReplicationEndpoint from synapse.types import Requester, UserID from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -57,7 +61,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): NAME = "send_event" PATH_ARGS = ("event_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() @@ -135,5 +139,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationSendEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index 9afa147d00..3223bc2432 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -13,11 +13,15 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer from synapse.replication.http._base import ReplicationEndpoint +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -46,7 +50,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): PATH_ARGS = ("stream_name",) METHOD = "GET" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._instance_name = hs.get_instance_name() @@ -74,5 +78,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index e460dd85cd..7ecb446e7c 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -13,18 +13,21 @@ # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.storage.database import DatabasePool from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen: Optional[ diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 436d39c320..61cd7e5228 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.lrucache import LruCache from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.client_ip_last_seen: LruCache[tuple, int] = LruCache( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 26bdead565..0a58296089 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream @@ -20,9 +22,12 @@ from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d4d3f8c448..63ed50caa5 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING from synapse.storage.database import DatabasePool from synapse.storage.databases.main.event_federation import EventFederationWorkerStore @@ -30,6 +31,9 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -54,7 +58,7 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 37875bc973..90284c202d 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.storage.database import DatabasePool from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index e9bdc38470..497e16c69e 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import GroupServerStream @@ -19,9 +21,12 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index b402f82810..aaf91e5e02 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -21,6 +21,8 @@ from synapse.logging.context import make_deferred_yieldable from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.server import HomeServer set_counter = Counter( @@ -59,7 +61,12 @@ class ExternalCache: """ def __init__(self, hs: "HomeServer"): - self._redis_connection = hs.get_outbound_redis_connection() + if hs.config.redis.redis_enabled: + self._redis_connection: Optional[ + "RedisProtocol" + ] = hs.get_outbound_redis_connection() + else: + self._redis_connection = None def _get_redis_key(self, cache_name: str, key: str) -> str: return "cache_v1:%s:%s" % (cache_name, key) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 6aa9318027..06fd06fdf3 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -294,7 +294,7 @@ class ReplicationCommandHandler: # This shouldn't be possible raise Exception("Unrecognised command %s in stream queue", cmd.NAME) - def start_replication(self, hs): + def start_replication(self, hs: "HomeServer"): """Helper method to start a replication connection to the remote server using TCP. """ @@ -321,6 +321,8 @@ class ReplicationCommandHandler: hs.config.redis.redis_host, # type: ignore[arg-type] hs.config.redis.redis_port, self._factory, + timeout=30, + bindAddress=None, ) else: client_name = hs.get_instance_name() @@ -331,6 +333,8 @@ class ReplicationCommandHandler: host, # type: ignore[arg-type] port, self._factory, + timeout=30, + bindAddress=None, ) def get_streams(self) -> Dict[str, Stream]: diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 80f9b23bfd..55326877fd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -16,6 +16,7 @@ import logging import random +from typing import TYPE_CHECKING from prometheus_client import Counter @@ -27,6 +28,9 @@ from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol from synapse.replication.tcp.streams import EventsStream from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + stream_updates_counter = Counter( "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] ) @@ -37,7 +41,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): """Factory for new replication connections.""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.command_handler = hs.get_tcp_replication() self.clock = hs.get_clock() self.server_name = hs.config.server.server_name @@ -65,7 +69,7 @@ class ReplicationStreamer: data is available it will propagate to all connected clients. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.clock = hs.get_clock() self.notifier = hs.get_notifier() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9b905aba9d..c8b188ae4e 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -241,7 +241,7 @@ class BackfillStream(Stream): NAME = "backfill" ROW_TYPE = BackfillStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -363,7 +363,7 @@ class ReceiptsStream(Stream): NAME = "receipts" ROW_TYPE = ReceiptsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -380,7 +380,7 @@ class PushRulesStream(Stream): NAME = "push_rules" ROW_TYPE = PushRulesStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( @@ -405,7 +405,7 @@ class PushersStream(Stream): NAME = "pushers" ROW_TYPE = PushersStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( @@ -438,7 +438,7 @@ class CachesStream(Stream): NAME = "caches" ROW_TYPE = CachesStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -459,7 +459,7 @@ class DeviceListsStream(Stream): NAME = "device_lists" ROW_TYPE = DeviceListsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -476,7 +476,7 @@ class ToDeviceStream(Stream): NAME = "to_device" ROW_TYPE = ToDeviceStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -495,7 +495,7 @@ class TagAccountDataStream(Stream): NAME = "tag_account_data" ROW_TYPE = TagAccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -582,7 +582,7 @@ class GroupServerStream(Stream): NAME = "groups" ROW_TYPE = GroupsStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), @@ -599,7 +599,7 @@ class UserSignatureStream(Stream): NAME = "user_signature" ROW_TYPE = UserSignatureStreamRow - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index a6fa03c90f..80fbf32f17 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -110,7 +110,7 @@ class DevicesRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): """ Args: - hs (synapse.server.HomeServer): server + hs: server """ self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/server.py b/synapse/server.py index a64c846d1c..0fbf36ba99 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -800,9 +800,14 @@ class HomeServer(metaclass=abc.ABCMeta): return ExternalCache(self) @cache_in_self - def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: - if not self.config.redis.redis_enabled: - return None + def get_outbound_redis_connection(self) -> "RedisProtocol": + """ + The Redis connection used for replication. + + Raises: + AssertionError: if Redis is not enabled in the homeserver config. + """ + assert self.config.redis.redis_enabled # We only want to import redis module if we're using it, as we have # `txredisapi` as an optional dependency. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index f5a8f90a0f..fa4e89d35c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -19,6 +19,7 @@ from collections import defaultdict from sys import intern from time import monotonic as monotonic_time from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor +if TYPE_CHECKING: + from synapse.server import HomeServer + # python 3 does not have a maximum int value MAX_TXN_ID = 2 ** 63 - 1 @@ -392,7 +396,7 @@ class DatabasePool: def __init__( self, - hs, + hs: "HomeServer", database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine, ): diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 20b755056b..cfe887b7f7 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -13,33 +13,49 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar +from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.state import StateGroupDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) -class Databases: +DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True) + + +class Databases(Generic[DataStoreT]): """The various databases. These are low level interfaces to physical databases. Attributes: - main (DataStore) + databases + main + state + persist_events """ - def __init__(self, main_store_class, hs): + databases: List[DatabasePool] + main: DataStoreT + state: StateGroupDataStore + persist_events: Optional[PersistEventsStore] + + def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # Note we pass in the main store class here as workers use a different main # store. self.databases = [] - main = None - state = None - persist_events = None + main: Optional[DataStoreT] = None + state: Optional[StateGroupDataStore] = None + persist_events: Optional[PersistEventsStore] = None for database_config in hs.config.database.databases: db_name = database_config.name diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 5c21402dea..259cae5b37 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import DatabasePool @@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -126,7 +129,7 @@ class DataStore( LockStore, SessionStore, ): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 70ca3e09f7..f8bec266ac 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -28,6 +28,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore): `get_max_account_data_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index c57ae5ef15..36e8422fc6 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -15,7 +15,7 @@ import itertools import logging -from typing import Any, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple from synapse.api.constants import EventTypes from synapse.replication.tcp.streams import BackfillStream, CachesStream @@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" class CacheInvalidationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 3154906d45..8143168107 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -26,11 +26,14 @@ from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class DeviceInboxWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() @@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 6464520386..a01bf2c5b7 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -15,7 +15,17 @@ # limitations under the License. import abc import logging -from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( @@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( @@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" class DeviceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ba9f71a230..ef5d1ef01e 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,7 +14,7 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple from prometheus_client import Counter, Gauge @@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter +if TYPE_CHECKING: + from synapse.server import HomeServer + oldest_pdu_in_federation_staging = Gauge( "synapse_federation_server_oldest_inbound_pdu_in_staging", "The age in seconds since we received the oldest pdu in the federation staging area", @@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception): class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: @@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 97b3e92d3f..d957e770dc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import attr @@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn @@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 1afc59fafb..fc49112063 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr @@ -26,6 +26,9 @@ from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.types import Cursor from synapse.types import JsonDict +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -76,7 +79,7 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 2fa945d171..717487be28 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,11 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool +if TYPE_CHECKING: + from synapse.server import HomeServer + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( "media_repository_drop_index_wo_method" ) @@ -43,7 +46,7 @@ class MediaSortOrder(Enum): class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_index_update( @@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index dac3d14da8..d901933ae4 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -14,7 +14,7 @@ import calendar import logging import time -from typing import Dict +from typing import TYPE_CHECKING, Dict from synapse.metrics import GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Collect metrics on the number of forward extremities that exist. @@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): stats and prometheus metrics. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Read the extrems every 60 minutes diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index a14ac03d4b..b5284e4f67 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,13 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # Number of msec of granularity to store the monthly_active_user timestamp @@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs @@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._mau_stats_only = hs.config.server.mau_stats_only diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index fc720f5947..fa782023d4 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -14,7 +14,7 @@ # limitations under the License. import abc import logging -from typing import Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import NotFoundError, StoreError from synapse.push.baserules import list_with_base_rules @@ -33,6 +33,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -75,7 +78,7 @@ class PushRulesWorkerStore( `get_max_push_rules_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.worker_app is None: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 01a4281301..c99f8aebdb 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from twisted.internet import defer @@ -29,11 +29,14 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ReceiptsWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): self._instance_name = hs.get_instance_name() if isinstance(database.engine, PostgresEngine): diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 835d7889cb..f879bbe7c7 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -17,7 +17,7 @@ import collections import logging from abc import abstractmethod from enum import Enum -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError @@ -32,6 +32,9 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.stringutils import MXC_REGEX +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -69,7 +72,7 @@ class RoomSortOrder(Enum): class RoomWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config @@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( class RoomBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config @@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ddb162a4fc..4b288bb2e7 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.metrics import Measure if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.state import _StateCacheEntry logger = logging.getLogger(__name__) @@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) # Used by `_get_joined_hosts` to ensure only one thing mutates the cache @@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): class RoomMemberBackgroundUpdateStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile @@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def forget(self, user_id: str, room_id: str) -> None: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index c85383c975..7fe233767f 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -15,7 +15,7 @@ import logging import re from collections import namedtuple -from typing import Collection, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set from synapse.api.errors import SynapseError from synapse.events import EventBase @@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) SearchEntry = namedtuple( @@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if not hs.config.server.enable_search: @@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def search_msgs(self, room_ids, search_term, keys): diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a8e8dd4577..fa2c3b1feb 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -15,7 +15,7 @@ import collections.abc import logging from collections import namedtuple -from typing import Iterable, Optional, Set +from typing import TYPE_CHECKING, Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError @@ -30,6 +30,9 @@ from synapse.types import StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -53,7 +56,7 @@ class _GetStateGroupDelta( class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers.""" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) async def get_room_version(self, room_id: str) -> RoomVersion: @@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index e20033bb28..5d7b59d861 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -16,7 +16,7 @@ import logging from enum import Enum from itertools import chain -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing_extensions import Counter @@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # these fields track absolutes (e.g. total number of rooms on the server) @@ -93,7 +96,7 @@ class UserSortOrder(Enum): class StatsStore(StateDeltasStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.server_name = hs.hostname diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 860146cd1b..d7dc1f73ac 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -14,7 +14,7 @@ import logging from collections import namedtuple -from typing import Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.types import JsonDict from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + db_binary_type = memoryview logger = logging.getLogger(__name__) @@ -57,7 +60,7 @@ class DestinationRetryTimings: class TransactionWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 0e8270746d..402f134d89 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -18,6 +18,7 @@ import itertools import logging from collections import deque from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -56,6 +57,9 @@ from synapse.types import ( from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) # The number of times we are recalculating the current state @@ -272,7 +276,7 @@ class EventsPersistenceStorage: current state and forward extremity changes. """ - def __init__(self, hs, stores: Databases): + def __init__(self, hs: "HomeServer", stores: Databases): # We ultimately want to split out the state store from the main store, # so we use separate variables here even though they point to the same # store for now. -- cgit 1.5.1 From 85a09f8b8ba7c8023c0d28a526d32111fc704197 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 25 Oct 2021 13:01:04 +0100 Subject: Fix module API's `get_user_ip_and_agents` function when run on workers (#11112) --- changelog.d/11112.bugfix | 1 + synapse/module_api/__init__.py | 6 +- synapse/storage/databases/main/client_ips.py | 124 ++++++++++++++++++--------- 3 files changed, 91 insertions(+), 40 deletions(-) create mode 100644 changelog.d/11112.bugfix (limited to 'synapse') diff --git a/changelog.d/11112.bugfix b/changelog.d/11112.bugfix new file mode 100644 index 0000000000..c8e22da8cf --- /dev/null +++ b/changelog.d/11112.bugfix @@ -0,0 +1 @@ +Fix a bug which caused the module API's `get_user_ip_and_agents` function to always fail on workers. `get_user_ip_and_agents` was introduced in 1.44.0 and did not function correctly on worker processes at the time. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index ab7ef8f950..d37252b6b3 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -46,6 +46,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client.login import LoginResponse +from synapse.storage import DataStore from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter @@ -61,6 +62,7 @@ from synapse.util import Clock from synapse.util.caches.descriptors import cached if TYPE_CHECKING: + from synapse.app.generic_worker import GenericWorkerSlavedStore from synapse.server import HomeServer """ @@ -111,7 +113,9 @@ class ModuleApi: def __init__(self, hs: "HomeServer", auth_handler): self._hs = hs - self._store = hs.get_datastore() + # TODO: Fix this type hint once the types for the data stores have been ironed + # out. + self._store: Union[DataStore, "GenericWorkerSlavedStore"] = hs.get_datastore() self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index b81d9218ce..1dc7f0ebe3 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): return {(d["user_id"], d["device_id"]): d for d in res} + async def get_user_ip_and_agents( + self, user: UserID, since_ts: int = 0 + ) -> List[LastConnectionInfo]: + """Fetch the IPs and user agents for a user since the given timestamp. + + The result might be slightly out of date as client IPs are inserted in batches. + + Args: + user: The user for which to fetch IP addresses and user agents. + since_ts: The timestamp after which to fetch IP addresses and user agents, + in milliseconds. + + Returns: + A list of dictionaries, each containing: + * `access_token`: The access token used. + * `ip`: The IP address used. + * `user_agent`: The last user agent seen for this access token and IP + address combination. + * `last_seen`: The timestamp at which this access token and IP address + combination was last seen, in milliseconds. + + Only the latest user agent for each access token and IP address combination + is available. + """ + user_id = user.to_string() + + def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: + txn.execute( + """ + SELECT access_token, ip, user_agent, last_seen FROM user_ips + WHERE last_seen >= ? AND user_id = ? + ORDER BY last_seen + DESC + """, + (since_ts, user_id), + ) + return cast(List[Tuple[str, str, str, int]], txn.fetchall()) + + rows = await self.db_pool.runInteraction( + desc="get_user_ip_and_agents", func=get_recent + ) + + return [ + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for access_token, ip, user_agent, last_seen in rows + ] + class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): @@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore): async def get_user_ip_and_agents( self, user: UserID, since_ts: int = 0 ) -> List[LastConnectionInfo]: + """Fetch the IPs and user agents for a user since the given timestamp. + + Args: + user: The user for which to fetch IP addresses and user agents. + since_ts: The timestamp after which to fetch IP addresses and user agents, + in milliseconds. + + Returns: + A list of dictionaries, each containing: + * `access_token`: The access token used. + * `ip`: The IP address used. + * `user_agent`: The last user agent seen for this access token and IP + address combination. + * `last_seen`: The timestamp at which this access token and IP address + combination was last seen, in milliseconds. + + Only the latest user agent for each access token and IP address combination + is available. """ - Fetch IP/User Agent connection since a given timestamp. - """ - user_id = user.to_string() - results: Dict[Tuple[str, str], Tuple[str, int]] = {} + results: Dict[Tuple[str, str], LastConnectionInfo] = { + (connection["access_token"], connection["ip"]): connection + for connection in await super().get_user_ip_and_agents(user, since_ts) + } + # Overlay data that is pending insertion on top of the results from the + # database. + user_id = user.to_string() for key in self._batch_row_update: - ( - uid, - access_token, - ip, - ) = key + uid, access_token, ip = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] if last_seen >= since_ts: - results[(access_token, ip)] = (user_agent, last_seen) - - def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]: - txn.execute( - """ - SELECT access_token, ip, user_agent, last_seen FROM user_ips - WHERE last_seen >= ? AND user_id = ? - ORDER BY last_seen - DESC - """, - (since_ts, user_id), - ) - return cast(List[Tuple[str, str, str, int]], txn.fetchall()) - - rows = await self.db_pool.runInteraction( - desc="get_user_ip_and_agents", func=get_recent - ) + results[(access_token, ip)] = { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } - results.update( - ((access_token, ip), (user_agent, last_seen)) - for access_token, ip, user_agent, last_seen in rows - ) - return [ - { - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "last_seen": last_seen, - } - for (access_token, ip), (user_agent, last_seen) in results.items() - ] + return list(results.values()) -- cgit 1.5.1 From da957a60e8958b08a52bd1404a89cf9bbcd033e0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 25 Oct 2021 16:21:09 +0200 Subject: Ensure that we correctly auth events returned by `send_join` (#11012) This is the final piece of the jigsaw for #9595. As with other changes before this one (eg #10771), we need to make sure that we auth the auth events in the right order, and actually check that their predecessors haven't been rejected. To do this I've reused the existing code we use when persisting outliers elsewhere. I've removed the code for attempting to fetch missing auth_events - the events should have been present in the send_join response, so the likely reason they are missing is that we couldn't verify them, so requesting them again is unlikely to help. Instead, we simply drop any state which relies on those auth events, as we do at a backwards-extremity. See also matrix-org/complement#216 for a test for this. --- changelog.d/11012.bugfix | 1 + synapse/handlers/federation_event.py | 146 ++++++++++++++--------------------- 2 files changed, 61 insertions(+), 86 deletions(-) create mode 100644 changelog.d/11012.bugfix (limited to 'synapse') diff --git a/changelog.d/11012.bugfix b/changelog.d/11012.bugfix new file mode 100644 index 0000000000..13b8e5983b --- /dev/null +++ b/changelog.d/11012.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which meant that events received over federation were sometimes incorrectly accepted into the room state. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3431a80ab4..9584d5bd46 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -361,6 +361,7 @@ class FederationEventHandler: # need to. await self._event_creation_handler.cache_joined_hosts_for_event(event, context) + await self._check_for_soft_fail(event, None, origin=origin) await self._run_push_actions_and_persist_event(event, context) return event, context @@ -402,29 +403,28 @@ class FederationEventHandler: """Persists the events returned by a send_join Checks the auth chain is valid (and passes auth checks) for the - state and event. Then persists the auth chain and state atomically. - Persists the event separately. Notifies about the persisted events - where appropriate. - - Will attempt to fetch missing auth events. + state and event. Then persists all of the events. + Notifies about the persisted events where appropriate. Args: origin: Where the events came from - room_id, + room_id: auth_events state event room_version: The room version we expect this room to have, and will raise if it doesn't match the version in the create event. + + Returns: + The stream ID after which all events have been persisted. + + Raises: + SynapseError if the response is in some way invalid. """ - events_to_context = {} for e in itertools.chain(auth_events, state): e.internal_metadata.outlier = True - events_to_context[e.event_id] = EventContext.for_outlier() - event_map = { - e.event_id: e for e in itertools.chain(auth_events, state, [event]) - } + event_map = {e.event_id: e for e in itertools.chain(auth_events, state)} create_event = None for e in auth_events: @@ -444,64 +444,36 @@ class FederationEventHandler: if room_version.identifier != room_version_id: raise SynapseError(400, "Room version mismatch") - missing_auth_events = set() - for e in itertools.chain(auth_events, state, [event]): - for e_id in e.auth_event_ids(): - if e_id not in event_map: - missing_auth_events.add(e_id) - - for e_id in missing_auth_events: - m_ev = await self._federation_client.get_pdu( - [origin], - e_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - if m_ev and m_ev.event_id == e_id: - event_map[e_id] = m_ev - else: - logger.info("Failed to find auth event %r", e_id) - - for e in itertools.chain(auth_events, state, [event]): - auth_for_e = [ - event_map[e_id] for e_id in e.auth_event_ids() if e_id in event_map - ] - if create_event: - auth_for_e.append(create_event) - - try: - validate_event_for_room_version(room_version, e) - check_auth_rules_for_event(room_version, e, auth_for_e) - except SynapseError as err: - # we may get SynapseErrors here as well as AuthErrors. For - # instance, there are a couple of (ancient) events in some - # rooms whose senders do not have the correct sigil; these - # cause SynapseErrors in auth.check. We don't want to give up - # the attempt to federate altogether in such cases. - - logger.warning("Rejecting %s because %s", e.event_id, err.msg) - - if e == event: - raise - events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR - - if auth_events or state: - await self.persist_events_and_notify( - room_id, - [ - (e, events_to_context[e.event_id]) - for e in itertools.chain(auth_events, state) - ], + # filter out any events we have already seen + seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) + for s in seen_remotes: + event_map.pop(s, None) + + # persist the auth chain and state events. + # + # any invalid events here will be marked as rejected, and we'll carry on. + # + # any events whose auth events are missing (ie, not in the send_join response, + # and not already in our db) will just be ignored. This is correct behaviour, + # because the reason that auth_events are missing might be due to us being + # unable to validate their signatures. The fact that we can't validate their + # signatures right now doesn't mean that we will *never* be able to, so it + # is premature to reject them. + # + await self._auth_and_persist_outliers(room_id, event_map.values()) + + # and now persist the join event itself. + logger.info("Peristing join-via-remote %s", event) + with nested_logging_context(suffix=event.event_id): + context = await self._state_handler.compute_event_context( + event, old_state=state ) - new_event_context = await self._state_handler.compute_event_context( - event, old_state=state - ) + context = await self._check_event_auth(origin, event, context) + if context.rejected: + raise SynapseError(400, "Join event was rejected") - return await self.persist_events_and_notify( - room_id, [(event, new_event_context)] - ) + return await self.persist_events_and_notify(room_id, [(event, context)]) @log_function async def backfill( @@ -974,9 +946,15 @@ class FederationEventHandler: ) -> None: """Called when we have a new non-outlier event. - This is called when we have a new event to add to the room DAG - either directly - via a /send request, retrieved via get_missing_events after a /send request, or - backfilled after a client request. + This is called when we have a new event to add to the room DAG. This can be + due to: + * events received directly via a /send request + * events retrieved via get_missing_events after a /send request + * events backfilled after a client request. + + It's not currently used for events received from incoming send_{join,knock,leave} + requests (which go via on_send_membership_event), nor for joins created by a + remote join dance (which go via process_remote_join). We need to do auth checks and put it through the StateHandler. @@ -1012,11 +990,19 @@ class FederationEventHandler: logger.exception("Unexpected AuthError from _check_event_auth") raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + if not backfilled and not context.rejected: + # For new (non-backfilled and non-outlier) events we check if the event + # passes auth based on the current state. If it doesn't then we + # "soft-fail" the event. + await self._check_for_soft_fail(event, state, origin=origin) + await self._run_push_actions_and_persist_event(event, context, backfilled) - if backfilled: + if backfilled or context.rejected: return + await self._maybe_kick_guest_users(event) + # For encrypted messages we check that we know about the sending device, # if we don't then we mark the device cache for that user as stale. if event.type == EventTypes.Encrypted: @@ -1317,14 +1303,14 @@ class FederationEventHandler: for auth_event_id in event.auth_event_ids(): ae = persisted_events.get(auth_event_id) if not ae: + # the fact we can't find the auth event doesn't mean it doesn't + # exist, which means it is premature to reject `event`. Instead we + # just ignore it for now. logger.warning( - "Event %s relies on auth_event %s, which could not be found.", + "Dropping event %s, which relies on auth_event %s, which could not be found", event, auth_event_id, ) - # the fact we can't find the auth event doesn't mean it doesn't - # exist, which means it is premature to reject `event`. Instead we - # just ignore it for now. return None auth.append(ae) @@ -1447,10 +1433,6 @@ class FederationEventHandler: except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR - return context - - await self._check_for_soft_fail(event, state, backfilled, origin=origin) - await self._maybe_kick_guest_users(event) return context @@ -1470,7 +1452,6 @@ class FederationEventHandler: self, event: EventBase, state: Optional[Iterable[EventBase]], - backfilled: bool, origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1479,15 +1460,8 @@ class FederationEventHandler: Args: event state: The state at the event if we don't have all the event's prev events - backfilled: Whether the event is from backfill origin: The host the event originates from. """ - # For new (non-backfilled and non-outlier) events we check if the event - # passes auth based on the current state. If it doesn't then we - # "soft-fail" the event. - if backfilled or event.internal_metadata.is_outlier(): - return - extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) -- cgit 1.5.1 From 4387b791e01eb1a207fe44fecbc901eead8eb4db Mon Sep 17 00:00:00 2001 From: AndrewFerr Date: Mon, 25 Oct 2021 10:24:49 -0400 Subject: Don't set new room alias before potential 403 (#10930) Fixes: #10929 Signed-off-by: Andrew Ferrazzutti --- changelog.d/10930.bugfix | 1 + synapse/handlers/directory.py | 4 +- synapse/handlers/room.py | 18 +++---- tests/handlers/test_directory.py | 102 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 113 insertions(+), 12 deletions(-) create mode 100644 changelog.d/10930.bugfix (limited to 'synapse') diff --git a/changelog.d/10930.bugfix b/changelog.d/10930.bugfix new file mode 100644 index 0000000000..756bfe9107 --- /dev/null +++ b/changelog.d/10930.bugfix @@ -0,0 +1 @@ +Newly-created public rooms are now only assigned an alias if the room's creation has not been blocked by permission settings. Contributed by @AndrewFerr. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 14ed7d9879..8567cb0e00 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -145,7 +145,7 @@ class DirectoryHandler: if not self.config.roomdirectory.is_alias_creation_allowed( user_id, room_id, room_alias_str ): - # Lets just return a generic message, as there may be all sorts of + # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") @@ -461,7 +461,7 @@ class DirectoryHandler: if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases ): - # Lets just return a generic message, as there may be all sorts of + # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6f39e9446f..cf01d58ea1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -773,6 +773,15 @@ class RoomCreationHandler: if not allowed_by_third_party_rules: raise SynapseError(403, "Room visibility value not allowed.") + if is_public: + if not self.config.roomdirectory.is_publishing_room_allowed( + user_id, room_id, room_alias + ): + # Let's just return a generic message, as there may be all sorts of + # reasons why we said no. TODO: Allow configurable error messages + # per alias creation rule? + raise SynapseError(403, "Not allowed to publish room") + directory_handler = self.hs.get_directory_handler() if room_alias: await directory_handler.create_association( @@ -783,15 +792,6 @@ class RoomCreationHandler: check_membership=False, ) - if is_public: - if not self.config.roomdirectory.is_publishing_room_allowed( - user_id, room_id, room_alias - ): - # Lets just return a generic message, as there may be all sorts of - # reasons why we said no. TODO: Allow configurable error messages - # per alias creation rule? - raise SynapseError(403, "Not allowed to publish room") - preset_config = config.get( "preset", RoomCreationPreset.PRIVATE_CHAT diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 6a2e76ca4a..be008227df 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -15,8 +15,8 @@ from unittest.mock import Mock -import synapse import synapse.api.errors +import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig from synapse.rest.client import directory, login, room @@ -432,6 +432,106 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.result) +class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): + data = {"room_alias_name": "unofficial_test"} + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + directory.register_servlets, + room.register_servlets, + ] + hijack_auth = False + + def prepare(self, reactor, clock, hs): + self.allowed_user_id = self.register_user("allowed", "pass") + self.allowed_access_token = self.login("allowed", "pass") + + self.denied_user_id = self.register_user("denied", "pass") + self.denied_access_token = self.login("denied", "pass") + + # This time we add custom room list publication rules + config = {} + config["alias_creation_rules"] = [] + config["room_list_publication_rules"] = [ + {"user_id": "*", "alias": "*", "action": "deny"}, + {"user_id": self.allowed_user_id, "alias": "*", "action": "allow"}, + ] + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.hs.config.roomdirectory.is_publishing_room_allowed = ( + rd_config.is_publishing_room_allowed + ) + + return hs + + def test_denied_without_publication_permission(self): + """ + Try to create a room, register an alias for it, and publish it, + as a user without permission to publish rooms. + (This is used as both a standalone test & as a helper function.) + """ + self.helper.create_room_as( + self.denied_user_id, + tok=self.denied_access_token, + extra_content=self.data, + is_public=True, + expect_code=403, + ) + + def test_allowed_when_creating_private_room(self): + """ + Try to create a room, register an alias for it, and NOT publish it, + as a user without permission to publish rooms. + (This is used as both a standalone test & as a helper function.) + """ + self.helper.create_room_as( + self.denied_user_id, + tok=self.denied_access_token, + extra_content=self.data, + is_public=False, + expect_code=200, + ) + + def test_allowed_with_publication_permission(self): + """ + Try to create a room, register an alias for it, and publish it, + as a user WITH permission to publish rooms. + (This is used as both a standalone test & as a helper function.) + """ + self.helper.create_room_as( + self.allowed_user_id, + tok=self.allowed_access_token, + extra_content=self.data, + is_public=False, + expect_code=200, + ) + + def test_can_create_as_private_room_after_rejection(self): + """ + After failing to publish a room with an alias as a user without publish permission, + retry as the same user, but without publishing the room. + + This should pass, but used to fail because the alias was registered by the first + request, even though the room creation was denied. + """ + self.test_denied_without_publication_permission() + self.test_allowed_when_creating_private_room() + + def test_can_create_with_permission_after_rejection(self): + """ + After failing to publish a room with an alias as a user without publish permission, + retry as someone with permission, using the same alias. + + This also used to fail because of the alias having been registered by the first + request, leaving it unavailable for any other user's new rooms. + """ + self.test_denied_without_publication_permission() + self.test_allowed_with_publication_permission() + + class TestRoomListSearchDisabled(unittest.HomeserverTestCase): user_id = "@test:test" -- cgit 1.5.1 From c1510c97b56060b7ab470b11264ed10dad445e14 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 25 Oct 2021 18:45:19 +0200 Subject: Fix cyclic import in the module API (#11180) Introduced in #10548 See https://github.com/matrix-org/synapse-email-account-validity/runs/3979337154?check_suite_focus=true for an example of a module's CI choking over this issue. --- changelog.d/11180.feature | 1 + synapse/handlers/auth.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11180.feature (limited to 'synapse') diff --git a/changelog.d/11180.feature b/changelog.d/11180.feature new file mode 100644 index 0000000000..82c40bf1b2 --- /dev/null +++ b/changelog.d/11180.feature @@ -0,0 +1 @@ +Port the Password Auth Providers module interface to the new generic interface. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ebe75a9e9b..d508d7d32a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -62,7 +62,6 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.module_api import ModuleApi from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils @@ -73,6 +72,7 @@ from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: + from synapse.module_api import ModuleApi from synapse.rest.client.login import LoginResponse from synapse.server import HomeServer @@ -1818,7 +1818,9 @@ def load_legacy_password_auth_providers(hs: "HomeServer") -> None: def load_single_legacy_password_auth_provider( - module: Type, config: JsonDict, api: ModuleApi + module: Type, + config: JsonDict, + api: "ModuleApi", ) -> None: try: provider = module(config=config, account_handler=api) -- cgit 1.5.1 From 63cbdd8af081839f245915a18ed57f1a44f1a5f4 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Tue, 26 Oct 2021 12:01:06 +0300 Subject: Enable changing user type via users admin API (#11174) Users admin API can now also modify user type in addition to allowing it to be set on user creation. Signed-off-by: Jason Robinson Co-authored-by: Brendan Abolivier --- changelog.d/11174.feature | 1 + docs/admin_api/user_admin_api.md | 9 ++++- synapse/rest/admin/users.py | 3 ++ synapse/storage/databases/main/registration.py | 18 +++++++++ tests/rest/admin/test_user.py | 51 ++++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11174.feature (limited to 'synapse') diff --git a/changelog.d/11174.feature b/changelog.d/11174.feature new file mode 100644 index 0000000000..8eecd92681 --- /dev/null +++ b/changelog.d/11174.feature @@ -0,0 +1 @@ +Users admin API can now also modify user type in addition to allowing it to be set on user creation. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 534f8400ba..f03539c9f0 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -50,7 +50,8 @@ It returns a JSON body like the following: "auth_provider": "", "external_id": "" } - ] + ], + "user_type": null } ``` @@ -97,7 +98,8 @@ with a body of: ], "avatar_url": "", "admin": false, - "deactivated": false + "deactivated": false, + "user_type": null } ``` @@ -135,6 +137,9 @@ Body parameters: unchanged on existing accounts and set to `false` for new accounts. A user cannot be erased by deactivating with this API. For details on deactivating users see [Deactivate Account](#deactivate-account). +- `user_type` - string or null, optional. If provided, the user type will be + adjusted. If `null` given, the user type will be cleared. Other + allowed options are: `bot` and `support`. If the user already exists then optional parameters default to the current value. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c0bebc3cf0..d14fafbbc9 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -326,6 +326,9 @@ class UserRestServletV2(RestServlet): target_user.to_string() ) + if "user_type" in body: + await self.store.set_user_type(target_user, user_type) + user = await self.admin_handler.get_user(target_user) assert user is not None diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 37d47aa823..6c7d6ba508 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -499,6 +499,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None: + """Sets the user type. + + Args: + user: user ID of the user. + user_type: type of the user or None for a user without a type. + """ + + def set_user_type_txn(txn): + self.db_pool.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"user_type": user_type} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + await self.db_pool.runInteraction("set_user_type", set_user_type_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 839442ddba..25e8d6cf27 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2270,6 +2270,57 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) + def test_set_user_type(self): + """ + Test changing user type. + """ + + # Set to support type + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"user_type": UserTypes.SUPPORT}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) + + # Change back to a regular user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"user_type": None}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertIsNone(channel.json_body["user_type"]) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertIsNone(channel.json_body["user_type"]) + def test_accidental_deactivation_prevention(self): """ Ensure an account can't accidentally be deactivated by using a str value -- cgit 1.5.1 From 8c8e36af0d6c3855de7bd786be14b85f5dae4ea7 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 26 Oct 2021 11:09:10 +0200 Subject: Document the version each module API method was added to Synapse (#11183) --- changelog.d/11183.doc | 1 + synapse/module_api/__init__.py | 99 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 89 insertions(+), 11 deletions(-) create mode 100644 changelog.d/11183.doc (limited to 'synapse') diff --git a/changelog.d/11183.doc b/changelog.d/11183.doc new file mode 100644 index 0000000000..a171a107af --- /dev/null +++ b/changelog.d/11183.doc @@ -0,0 +1 @@ +Document the version of Synapse that introduced each module API method. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d37252b6b3..d707a9325d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -154,27 +154,42 @@ class ModuleApi: @property def register_spam_checker_callbacks(self): - """Registers callbacks for spam checking capabilities.""" + """Registers callbacks for spam checking capabilities. + + Added in Synapse v1.37.0. + """ return self._spam_checker.register_callbacks @property def register_account_validity_callbacks(self): - """Registers callbacks for account validity capabilities.""" + """Registers callbacks for account validity capabilities. + + Added in Synapse v1.39.0. + """ return self._account_validity_handler.register_account_validity_callbacks @property def register_third_party_rules_callbacks(self): - """Registers callbacks for third party event rules capabilities.""" + """Registers callbacks for third party event rules capabilities. + + Added in Synapse v1.39.0. + """ return self._third_party_event_rules.register_third_party_rules_callbacks @property def register_presence_router_callbacks(self): - """Registers callbacks for presence router capabilities.""" + """Registers callbacks for presence router capabilities. + + Added in Synapse v1.42.0. + """ return self._presence_router.register_presence_router_callbacks @property def register_password_auth_provider_callbacks(self): - """Registers callbacks for password auth provider capabilities.""" + """Registers callbacks for password auth provider capabilities. + + Added in Synapse v1.46.0. + """ return self._password_auth_provider.register_password_auth_provider_callbacks def register_web_resource(self, path: str, resource: IResource): @@ -185,6 +200,8 @@ class ModuleApi: If multiple modules register a resource for the same path, the module that appears the highest in the configuration file takes priority. + Added in Synapse v1.37.0. + Args: path: The path to register the resource for. resource: The resource to attach to this path. @@ -199,6 +216,8 @@ class ModuleApi: """Allows making outbound HTTP requests to remote resources. An instance of synapse.http.client.SimpleHttpClient + + Added in Synapse v1.22.0. """ return self._http_client @@ -208,22 +227,32 @@ class ModuleApi: public room list. An instance of synapse.module_api.PublicRoomListManager + + Added in Synapse v1.22.0. """ return self._public_room_list_manager @property def public_baseurl(self) -> str: - """The configured public base URL for this homeserver.""" + """The configured public base URL for this homeserver. + + Added in Synapse v1.39.0. + """ return self._hs.config.server.public_baseurl @property def email_app_name(self) -> str: - """The application name configured in the homeserver's configuration.""" + """The application name configured in the homeserver's configuration. + + Added in Synapse v1.39.0. + """ return self._hs.config.email.email_app_name async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: """Get user info by user_id + Added in Synapse v1.41.0. + Args: user_id: Fully qualified user id. Returns: @@ -239,6 +268,8 @@ class ModuleApi: ) -> Requester: """Check the access_token provided for a request + Added in Synapse v1.39.0. + Args: req: Incoming HTTP request allow_guest: True if guest users should be allowed. If this @@ -264,6 +295,8 @@ class ModuleApi: async def is_user_admin(self, user_id: str) -> bool: """Checks if a user is a server admin. + Added in Synapse v1.39.0. + Args: user_id: The Matrix ID of the user to check. @@ -278,6 +311,8 @@ class ModuleApi: Takes a user id provided by the user and adds the @ and :domain to qualify it, if necessary + Added in Synapse v0.25.0. + Args: username (str): provided user id @@ -291,6 +326,8 @@ class ModuleApi: async def get_profile_for_user(self, localpart: str) -> ProfileInfo: """Look up the profile info for the user with the given localpart. + Added in Synapse v1.39.0. + Args: localpart: The localpart to look up profile information for. @@ -303,6 +340,8 @@ class ModuleApi: """Look up the threepids (email addresses and phone numbers) associated with the given Matrix user ID. + Added in Synapse v1.39.0. + Args: user_id: The Matrix user ID to look up threepids for. @@ -317,6 +356,8 @@ class ModuleApi: def check_user_exists(self, user_id): """Check if user exists. + Added in Synapse v0.25.0. + Args: user_id (str): Complete @user:id @@ -336,6 +377,8 @@ class ModuleApi: return that device to the user. Prefer separate calls to register_user and register_device. + Added in Synapse v0.25.0. + Args: localpart (str): The localpart of the new user. displayname (str|None): The displayname of the new user. @@ -356,6 +399,8 @@ class ModuleApi: ): """Registers a new user with given localpart and optional displayname, emails. + Added in Synapse v1.2.0. + Args: localpart (str): The localpart of the new user. displayname (str|None): The displayname of the new user. @@ -379,6 +424,8 @@ class ModuleApi: def register_device(self, user_id, device_id=None, initial_display_name=None): """Register a device for a user and generate an access token. + Added in Synapse v1.2.0. + Args: user_id (str): full canonical @user:id device_id (str|None): The device ID to check, or None to generate @@ -402,6 +449,8 @@ class ModuleApi: ) -> defer.Deferred: """Record a mapping from an external user id to a mxid + Added in Synapse v1.9.0. + Args: auth_provider: identifier for the remote auth provider external_id: id on that system @@ -421,6 +470,8 @@ class ModuleApi: ) -> str: """Generate a login token suitable for m.login.token authentication + Added in Synapse v1.9.0. + Args: user_id: gives the ID of the user that the token is for @@ -440,6 +491,8 @@ class ModuleApi: def invalidate_access_token(self, access_token): """Invalidate an access token for a user + Added in Synapse v0.25.0. + Args: access_token(str): access token @@ -470,6 +523,8 @@ class ModuleApi: def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection + Added in Synapse v0.25.0. + Args: desc (str): description for the transaction, for metrics etc func (func): function to be run. Passed a database cursor object @@ -493,6 +548,8 @@ class ModuleApi: This is deprecated in favor of complete_sso_login_async. + Added in Synapse v1.11.1. + Args: registered_user_id: The MXID that has been registered as a previous step of of this SSO login. @@ -519,6 +576,8 @@ class ModuleApi: want their access token sent to `client_redirect_url`, or redirect them to that URL with a token directly if the URL matches with one of the whitelisted clients. + Added in Synapse v1.13.0. + Args: registered_user_id: The MXID that has been registered as a previous step of of this SSO login. @@ -547,6 +606,8 @@ class ModuleApi: (This is exposed for compatibility with the old SpamCheckerApi. We should probably deprecate it and replace it with an async method in a subclass.) + Added in Synapse v1.22.0. + Args: room_id: The room ID to get state events in. types: The event type and state key (using None @@ -567,6 +628,8 @@ class ModuleApi: async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase: """Create and send an event into a room. Membership events are currently not supported. + Added in Synapse v1.22.0. + Args: event_dict: A dictionary representing the event to send. Required keys are `type`, `room_id`, `sender` and `content`. @@ -607,6 +670,8 @@ class ModuleApi: Note that this method can only be run on the process that is configured to write to the presence stream. By default this is the main process. + + Added in Synapse v1.32.0. """ if self._hs._instance_name not in self._hs.config.worker.writers.presence: raise Exception( @@ -661,6 +726,8 @@ class ModuleApi: Waits `msec` initially before calling `f` for the first time. + Added in Synapse v1.39.0. + Args: f: The function to call repeatedly. f can be either synchronous or asynchronous, and must follow Synapse's logcontext rules. @@ -700,6 +767,8 @@ class ModuleApi: ): """Send an email on behalf of the homeserver. + Added in Synapse v1.39.0. + Args: recipient: The email address for the recipient. subject: The email's subject. @@ -723,6 +792,8 @@ class ModuleApi: By default, Synapse will look for these templates in its configured template directory, but another directory to search in can be provided. + Added in Synapse v1.39.0. + Args: filenames: The name of the template files to look for. custom_template_directory: An additional directory to look for the files in. @@ -740,13 +811,13 @@ class ModuleApi: """ Checks whether an ID (user id, room, ...) comes from this homeserver. + Added in Synapse v1.44.0. + Args: id: any Matrix id (e.g. user id, room id, ...), either as a raw id, e.g. string "@user:example.com" or as a parsed UserID, RoomID, ... Returns: True if id comes from this homeserver, False otherwise. - - Added in Synapse v1.44.0. """ if isinstance(id, DomainSpecificString): return self._hs.is_mine(id) @@ -759,6 +830,8 @@ class ModuleApi: """ Return the list of user IPs and agents for a user. + Added in Synapse v1.44.0. + Args: user_id: the id of a user, local or remote since_ts: a timestamp in seconds since the epoch, @@ -767,8 +840,6 @@ class ModuleApi: The list of all UserIpAndAgent that the user has used to connect to this homeserver since `since_ts`. If the user is remote, this list is empty. - - Added in Synapse v1.44.0. """ # Don't hit the db if this is not a local user. is_mine = False @@ -807,6 +878,8 @@ class PublicRoomListManager: async def room_is_in_public_room_list(self, room_id: str) -> bool: """Checks whether a room is in the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. @@ -823,6 +896,8 @@ class PublicRoomListManager: async def add_room_to_public_room_list(self, room_id: str) -> None: """Publishes a room to the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. """ @@ -831,6 +906,8 @@ class PublicRoomListManager: async def remove_room_from_public_room_list(self, room_id: str) -> None: """Removes a room from the public room list. + Added in Synapse v1.22.0. + Args: room_id: The ID of the room. """ -- cgit 1.5.1 From d52c58dfa3f548b489dae0b1945cf733d4a6538c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Oct 2021 07:38:45 -0400 Subject: Add a background update for updating MSC3440 relation threads. (#11181) --- changelog.d/11181.feature | 1 + .../storage/databases/main/events_bg_updates.py | 85 +++++++++++++++++++++- .../schema/main/delta/65/02_thread_relations.sql | 18 +++++ 3 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 changelog.d/11181.feature create mode 100644 synapse/storage/schema/main/delta/65/02_thread_relations.sql (limited to 'synapse') diff --git a/changelog.d/11181.feature b/changelog.d/11181.feature new file mode 100644 index 0000000000..76b0d28084 --- /dev/null +++ b/changelog.d/11181.feature @@ -0,0 +1 @@ +Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index fc49112063..f92d824876 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -17,11 +17,15 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import attr -from synapse.api.constants import EventContentFields +from synapse.api.constants import EventContentFields, RelationTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, make_tuple_comparison_clause +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.types import Cursor from synapse.types import JsonDict @@ -167,6 +171,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._purged_chain_cover_index, ) + self.db_pool.updates.register_background_update_handler( + "event_thread_relation", self._event_thread_relation + ) + ################################################################################ # bg updates for replacing stream_ordering with a BIGINT @@ -1091,6 +1099,79 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return result + async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int: + """Background update handler which will store thread relations for existing events.""" + last_event_id = progress.get("last_event_id", "") + + def _event_thread_relation_txn(txn: LoggingTransaction) -> int: + txn.execute( + """ + SELECT event_id, json FROM event_json + LEFT JOIN event_relations USING (event_id) + WHERE event_id > ? AND relates_to_id IS NULL + ORDER BY event_id LIMIT ? + """, + (last_event_id, batch_size), + ) + + results = list(txn) + missing_thread_relations = [] + for (event_id, event_json_raw) in results: + try: + event_json = db_to_json(event_json_raw) + except Exception as e: + logger.warning( + "Unable to load event %s (no relations will be updated): %s", + event_id, + e, + ) + continue + + # If there's no relation (or it is not a thread), skip! + relates_to = event_json["content"].get("m.relates_to") + if not relates_to or not isinstance(relates_to, dict): + continue + if relates_to.get("rel_type") != RelationTypes.THREAD: + continue + + # Get the parent ID. + parent_id = relates_to.get("event_id") + if not isinstance(parent_id, str): + continue + + missing_thread_relations.append((event_id, parent_id)) + + # Insert the missing data. + self.db_pool.simple_insert_many_txn( + txn=txn, + table="event_relations", + values=[ + { + "event_id": event_id, + "relates_to_Id": parent_id, + "relation_type": RelationTypes.THREAD, + } + for event_id, parent_id in missing_thread_relations + ], + ) + + if results: + latest_event_id = results[-1][0] + self.db_pool.updates._background_update_progress_txn( + txn, "event_thread_relation", {"last_event_id": latest_event_id} + ) + + return len(results) + + num_rows = await self.db_pool.runInteraction( + desc="event_thread_relation", func=_event_thread_relation_txn + ) + + if not num_rows: + await self.db_pool.updates._end_background_update("event_thread_relation") + + return num_rows + async def _background_populate_stream_ordering2( self, progress: JsonDict, batch_size: int ) -> int: diff --git a/synapse/storage/schema/main/delta/65/02_thread_relations.sql b/synapse/storage/schema/main/delta/65/02_thread_relations.sql new file mode 100644 index 0000000000..d60517f7b4 --- /dev/null +++ b/synapse/storage/schema/main/delta/65/02_thread_relations.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Check old events for thread relations. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6502, 'event_thread_relation', '{}'); -- cgit 1.5.1 From 7004f43da143f5d1d35c742add1238c51e62ca19 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Oct 2021 13:45:38 +0100 Subject: Move DNS lookups into separate thread pool (#11177) This is to stop large bursts of lookups starving out other users of the thread pools. Fixes #11049. --- changelog.d/11177.bugfix | 1 + synapse/app/_base.py | 13 ++++- synapse/util/gai_resolver.py | 136 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 changelog.d/11177.bugfix create mode 100644 synapse/util/gai_resolver.py (limited to 'synapse') diff --git a/changelog.d/11177.bugfix b/changelog.d/11177.bugfix new file mode 100644 index 0000000000..ca5bc0df28 --- /dev/null +++ b/changelog.d/11177.bugfix @@ -0,0 +1 @@ +Fix a performance regression introduced in v1.44.0 which could cause client requests to time out when making large numbers of outbound requests. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2ca2e051e4..03627cdcba 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -31,6 +31,7 @@ import twisted from twisted.internet import defer, error, reactor from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.python.threadpool import ThreadPool import synapse from synapse.api.constants import MAX_PDU_SIZE @@ -48,6 +49,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.metrics.jemalloc import setup_jemalloc_stats from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process +from synapse.util.gai_resolver import GAIResolver from synapse.util.rlimit import change_resource_limit from synapse.util.versionstring import get_version_string @@ -338,9 +340,18 @@ async def start(hs: "HomeServer"): Args: hs: homeserver instance """ + reactor = hs.get_reactor() + + # We want to use a separate thread pool for the resolver so that large + # numbers of DNS requests don't starve out other users of the threadpool. + resolver_threadpool = ThreadPool(name="gai_resolver") + resolver_threadpool.start() + reactor.installNameResolver( + GAIResolver(reactor, getThreadPool=lambda: resolver_threadpool) + ) + # Set up the SIGHUP machinery. if hasattr(signal, "SIGHUP"): - reactor = hs.get_reactor() @wrap_as_background_process("sighup") def handle_sighup(*args, **kwargs): diff --git a/synapse/util/gai_resolver.py b/synapse/util/gai_resolver.py new file mode 100644 index 0000000000..a447ce4e55 --- /dev/null +++ b/synapse/util/gai_resolver.py @@ -0,0 +1,136 @@ +# This is a direct lift from +# https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/_resolver.py. +# We copy it here as we need to instantiate `GAIResolver` manually, but it is a +# private class. + + +from socket import ( + AF_INET, + AF_INET6, + AF_UNSPEC, + SOCK_DGRAM, + SOCK_STREAM, + gaierror, + getaddrinfo, +) + +from zope.interface import implementer + +from twisted.internet.address import IPv4Address, IPv6Address +from twisted.internet.interfaces import IHostnameResolver, IHostResolution +from twisted.internet.threads import deferToThreadPool + + +@implementer(IHostResolution) +class HostResolution: + """ + The in-progress resolution of a given hostname. + """ + + def __init__(self, name): + """ + Create a L{HostResolution} with the given name. + """ + self.name = name + + def cancel(self): + # IHostResolution.cancel + raise NotImplementedError() + + +_any = frozenset([IPv4Address, IPv6Address]) + +_typesToAF = { + frozenset([IPv4Address]): AF_INET, + frozenset([IPv6Address]): AF_INET6, + _any: AF_UNSPEC, +} + +_afToType = { + AF_INET: IPv4Address, + AF_INET6: IPv6Address, +} + +_transportToSocket = { + "TCP": SOCK_STREAM, + "UDP": SOCK_DGRAM, +} + +_socktypeToType = { + SOCK_STREAM: "TCP", + SOCK_DGRAM: "UDP", +} + + +@implementer(IHostnameResolver) +class GAIResolver: + """ + L{IHostnameResolver} implementation that resolves hostnames by calling + L{getaddrinfo} in a thread. + """ + + def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo): + """ + Create a L{GAIResolver}. + @param reactor: the reactor to schedule result-delivery on + @type reactor: L{IReactorThreads} + @param getThreadPool: a function to retrieve the thread pool to use for + scheduling name resolutions. If not supplied, the use the given + C{reactor}'s thread pool. + @type getThreadPool: 0-argument callable returning a + L{twisted.python.threadpool.ThreadPool} + @param getaddrinfo: a reference to the L{getaddrinfo} to use - mainly + parameterized for testing. + @type getaddrinfo: callable with the same signature as L{getaddrinfo} + """ + self._reactor = reactor + self._getThreadPool = ( + reactor.getThreadPool if getThreadPool is None else getThreadPool + ) + self._getaddrinfo = getaddrinfo + + def resolveHostName( + self, + resolutionReceiver, + hostName, + portNumber=0, + addressTypes=None, + transportSemantics="TCP", + ): + """ + See L{IHostnameResolver.resolveHostName} + @param resolutionReceiver: see interface + @param hostName: see interface + @param portNumber: see interface + @param addressTypes: see interface + @param transportSemantics: see interface + @return: see interface + """ + pool = self._getThreadPool() + addressFamily = _typesToAF[ + _any if addressTypes is None else frozenset(addressTypes) + ] + socketType = _transportToSocket[transportSemantics] + + def get(): + try: + return self._getaddrinfo( + hostName, portNumber, addressFamily, socketType + ) + except gaierror: + return [] + + d = deferToThreadPool(self._reactor, pool, get) + resolution = HostResolution(hostName) + resolutionReceiver.resolutionBegan(resolution) + + @d.addCallback + def deliverResults(result): + for family, socktype, _proto, _cannoname, sockaddr in result: + addrType = _afToType[family] + resolutionReceiver.addressResolved( + addrType(_socktypeToType.get(socktype, "TCP"), *sockaddr) + ) + resolutionReceiver.resolutionComplete() + + return resolution -- cgit 1.5.1 From cc75a6b1b20f599c6ec6699fb77c8a72b87d1ec2 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Tue, 26 Oct 2021 14:04:51 +0100 Subject: 1.46.0rc1 --- CHANGES.md | 74 +++++++++++++++++++++++++++++++++++++++++++++++ changelog.d/10548.feature | 1 - changelog.d/10930.bugfix | 1 - changelog.d/10972.misc | 1 - changelog.d/10975.feature | 1 - changelog.d/10984.misc | 1 - changelog.d/11001.bugfix | 1 - changelog.d/11008.misc | 1 - changelog.d/11009.bugfix | 1 - changelog.d/11012.bugfix | 1 - changelog.d/11014.misc | 1 - changelog.d/11024.misc | 1 - changelog.d/11027.bugfix | 1 - changelog.d/11035.misc | 1 - changelog.d/11048.misc | 1 - changelog.d/11051.bugfix | 1 - changelog.d/11054.misc | 1 - changelog.d/11055.misc | 1 - changelog.d/11056.misc | 1 - changelog.d/11057.misc | 1 - changelog.d/11065.misc | 1 - changelog.d/11066.misc | 1 - changelog.d/11068.misc | 1 - changelog.d/11069.doc | 1 - changelog.d/11070.misc | 1 - changelog.d/11071.misc | 1 - changelog.d/11075.bugfix | 1 - changelog.d/11076.misc | 1 - changelog.d/11077.bugfix | 1 - changelog.d/11078.bugfix | 1 - changelog.d/11083.doc | 1 - changelog.d/11088.feature | 1 - changelog.d/11089.bugfix | 1 - changelog.d/11092.doc | 1 - changelog.d/11093.doc | 1 - changelog.d/11095.misc | 1 - changelog.d/11096.doc | 1 - changelog.d/11101.bugfix | 1 - changelog.d/11103.bugfix | 1 - changelog.d/11109.misc | 1 - changelog.d/11112.bugfix | 1 - changelog.d/11115.misc | 1 - changelog.d/11116.misc | 1 - changelog.d/11118.doc | 1 - changelog.d/11120.bugfix | 1 - changelog.d/11121.misc | 1 - changelog.d/11122.misc | 1 - changelog.d/11132.doc | 1 - changelog.d/11138.misc | 1 - changelog.d/11139.misc | 1 - changelog.d/11143.misc | 1 - changelog.d/11144.misc | 1 - changelog.d/11145.bugfix | 1 - changelog.d/11146.misc | 1 - changelog.d/11174.feature | 1 - changelog.d/11177.bugfix | 1 - changelog.d/11180.feature | 1 - changelog.d/11181.feature | 1 - changelog.d/11183.doc | 1 - debian/changelog | 6 ++++ synapse/__init__.py | 2 +- 61 files changed, 81 insertions(+), 59 deletions(-) delete mode 100644 changelog.d/10548.feature delete mode 100644 changelog.d/10930.bugfix delete mode 100644 changelog.d/10972.misc delete mode 100644 changelog.d/10975.feature delete mode 100644 changelog.d/10984.misc delete mode 100644 changelog.d/11001.bugfix delete mode 100644 changelog.d/11008.misc delete mode 100644 changelog.d/11009.bugfix delete mode 100644 changelog.d/11012.bugfix delete mode 100644 changelog.d/11014.misc delete mode 100644 changelog.d/11024.misc delete mode 100644 changelog.d/11027.bugfix delete mode 100644 changelog.d/11035.misc delete mode 100644 changelog.d/11048.misc delete mode 100644 changelog.d/11051.bugfix delete mode 100644 changelog.d/11054.misc delete mode 100644 changelog.d/11055.misc delete mode 100644 changelog.d/11056.misc delete mode 100644 changelog.d/11057.misc delete mode 100644 changelog.d/11065.misc delete mode 100644 changelog.d/11066.misc delete mode 100644 changelog.d/11068.misc delete mode 100644 changelog.d/11069.doc delete mode 100644 changelog.d/11070.misc delete mode 100644 changelog.d/11071.misc delete mode 100644 changelog.d/11075.bugfix delete mode 100644 changelog.d/11076.misc delete mode 100644 changelog.d/11077.bugfix delete mode 100644 changelog.d/11078.bugfix delete mode 100644 changelog.d/11083.doc delete mode 100644 changelog.d/11088.feature delete mode 100644 changelog.d/11089.bugfix delete mode 100644 changelog.d/11092.doc delete mode 100644 changelog.d/11093.doc delete mode 100644 changelog.d/11095.misc delete mode 100644 changelog.d/11096.doc delete mode 100644 changelog.d/11101.bugfix delete mode 100644 changelog.d/11103.bugfix delete mode 100644 changelog.d/11109.misc delete mode 100644 changelog.d/11112.bugfix delete mode 100644 changelog.d/11115.misc delete mode 100644 changelog.d/11116.misc delete mode 100644 changelog.d/11118.doc delete mode 100644 changelog.d/11120.bugfix delete mode 100644 changelog.d/11121.misc delete mode 100644 changelog.d/11122.misc delete mode 100644 changelog.d/11132.doc delete mode 100644 changelog.d/11138.misc delete mode 100644 changelog.d/11139.misc delete mode 100644 changelog.d/11143.misc delete mode 100644 changelog.d/11144.misc delete mode 100644 changelog.d/11145.bugfix delete mode 100644 changelog.d/11146.misc delete mode 100644 changelog.d/11174.feature delete mode 100644 changelog.d/11177.bugfix delete mode 100644 changelog.d/11180.feature delete mode 100644 changelog.d/11181.feature delete mode 100644 changelog.d/11183.doc (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index 92e6c6873e..88f8b5e01d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,77 @@ +Synapse 1.46.0rc1 (2021-10-26) +============================== + +Features +-------- + +- Port the Password Auth Providers module interface to the new generic interface. ([\#10548](https://github.com/matrix-org/synapse/issues/10548), [\#11180](https://github.com/matrix-org/synapse/issues/11180)) +- Resolve and share `state_groups` for all [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical events in batch. ([\#10975](https://github.com/matrix-org/synapse/issues/10975)) +- Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). ([\#11088](https://github.com/matrix-org/synapse/issues/11088), [\#11181](https://github.com/matrix-org/synapse/issues/11181)) +- Users admin API can now also modify user type in addition to allowing it to be set on user creation. ([\#11174](https://github.com/matrix-org/synapse/issues/11174)) + + +Bugfixes +-------- + +- Newly-created public rooms are now only assigned an alias if the room's creation has not been blocked by permission settings. Contributed by @AndrewFerr. ([\#10930](https://github.com/matrix-org/synapse/issues/10930)) +- Fix a long-standing bug which meant that events received over federation were sometimes incorrectly accepted into the room state. ([\#11001](https://github.com/matrix-org/synapse/issues/11001), [\#11009](https://github.com/matrix-org/synapse/issues/11009), [\#11012](https://github.com/matrix-org/synapse/issues/11012)) +- Fix 500 error on `/messages` when the server accumulates more than 5 backwards extremities at a given depth for a room. ([\#11027](https://github.com/matrix-org/synapse/issues/11027)) +- Fix a bug where setting a user's external_id via the admin API returns 500 and deletes users existing external mappings if that external ID is already mapped. ([\#11051](https://github.com/matrix-org/synapse/issues/11051)) +- Fix a long-standing bug where users excluded from the user directory were added into the directory if they belonged to a room which became public or private. ([\#11075](https://github.com/matrix-org/synapse/issues/11075)) +- Fix a long-standing bug when attempting to preview URLs which are in the `windows-1252` character encoding. ([\#11077](https://github.com/matrix-org/synapse/issues/11077), [\#11089](https://github.com/matrix-org/synapse/issues/11089)) +- Fix broken export-data admin command and add test script checking the command to CI. ([\#11078](https://github.com/matrix-org/synapse/issues/11078)) +- Show an error when timestamp in seconds is provided to the `/purge_media_cache` Admin API. ([\#11101](https://github.com/matrix-org/synapse/issues/11101)) +- Fix local users who left all their rooms being removed from the user directory, even if the "search_all_users" config option was enabled. ([\#11103](https://github.com/matrix-org/synapse/issues/11103)) +- Fix a bug which caused the module API's `get_user_ip_and_agents` function to always fail on workers. `get_user_ip_and_agents` was introduced in 1.44.0 and did not function correctly on worker processes at the time. ([\#11112](https://github.com/matrix-org/synapse/issues/11112)) +- Identity server connection is no longer ignoring `ip_range_whitelist`. ([\#11120](https://github.com/matrix-org/synapse/issues/11120)) +- Fix a bug introduced in Synapse v1.45.0 breaking the configuration file parsing script. ([\#11145](https://github.com/matrix-org/synapse/issues/11145)) +- Fix a performance regression introduced in v1.44.0 which could cause client requests to time out when making large numbers of outbound requests. ([\#11177](https://github.com/matrix-org/synapse/issues/11177)) + + +Improved Documentation +---------------------- + +- Fix broken links relating to module API deprecation in the upgrade notes. ([\#11069](https://github.com/matrix-org/synapse/issues/11069)) +- Add more information about what happens when a user is deactivated. ([\#11083](https://github.com/matrix-org/synapse/issues/11083)) +- Clarify the the sample log config can be copied from the documentation without issue. ([\#11092](https://github.com/matrix-org/synapse/issues/11092)) +- Update the admin API documentation with an updated list of the characters allowed in registration tokens. ([\#11093](https://github.com/matrix-org/synapse/issues/11093)) +- Document Synapse's behaviour when dealing with multiple modules registering the same callbacks and/or handlers for the same HTTP endpoints. ([\#11096](https://github.com/matrix-org/synapse/issues/11096)) +- Fix instances of `[example]{.title-ref}` in the upgrade documentation as a result of prior RST to Markdown conversion. ([\#11118](https://github.com/matrix-org/synapse/issues/11118)) +- Document the version of Synapse each module callback was introduced in. ([\#11132](https://github.com/matrix-org/synapse/issues/11132)) +- Document the version of Synapse that introduced each module API method. ([\#11183](https://github.com/matrix-org/synapse/issues/11183)) + + +Internal Changes +---------------- + +- Add type hints to `synapse.storage.databases.main.client_ips`. ([\#10972](https://github.com/matrix-org/synapse/issues/10972)) +- Fix spurious warnings about losing the logging context on the `ReplicationCommandHandler` when losing the replication connection. ([\#10984](https://github.com/matrix-org/synapse/issues/10984)) +- Include rejected status when we log events. ([\#11008](https://github.com/matrix-org/synapse/issues/11008)) +- Add some extra logging to the event persistence code. ([\#11014](https://github.com/matrix-org/synapse/issues/11014)) +- Add support for Ubuntu 21.10 "Impish Indri". ([\#11024](https://github.com/matrix-org/synapse/issues/11024)) +- Rearrange the internal workings of the incremental user directory updates. ([\#11035](https://github.com/matrix-org/synapse/issues/11035)) +- Simplify the user admin API tests. ([\#11048](https://github.com/matrix-org/synapse/issues/11048)) +- Mark the Synapse package as containing type annotations and fix export declarations so that Synapse pluggable modules may be type checked against Synapse. ([\#11054](https://github.com/matrix-org/synapse/issues/11054)) +- Improve type hints for `_wrap_in_base_path` decorator used by `MediaFilePaths`. ([\#11055](https://github.com/matrix-org/synapse/issues/11055)) +- Remove dead code from `MediaFilePaths`. ([\#11056](https://github.com/matrix-org/synapse/issues/11056)) +- Add tests for `MediaFilePaths` class. ([\#11057](https://github.com/matrix-org/synapse/issues/11057)) +- Be more lenient when parsing oEmbed response versions. ([\#11065](https://github.com/matrix-org/synapse/issues/11065)) +- Add type hints to `synapse.events`. ([\#11066](https://github.com/matrix-org/synapse/issues/11066)) +- Always dump logs from unit tests during CI runs. ([\#11068](https://github.com/matrix-org/synapse/issues/11068)) +- Create a separate module for the retention configuration. ([\#11070](https://github.com/matrix-org/synapse/issues/11070)) +- Add a test for the workaround introduced in [\#11042](https://github.com/matrix-org/synapse/pull/11042) concerning the behaviour of third-party rule modules and `SynapseError`s. ([\#11071](https://github.com/matrix-org/synapse/issues/11071)) +- Fix type hints in the relations tests. ([\#11076](https://github.com/matrix-org/synapse/issues/11076)) +- Add type hints to most `HomeServer` parameters. ([\#11095](https://github.com/matrix-org/synapse/issues/11095)) +- Add missing type hints to `synapse.api` module. ([\#11109](https://github.com/matrix-org/synapse/issues/11109)) +- Clean up some of the federation event authentication code for clarity. ([\#11115](https://github.com/matrix-org/synapse/issues/11115), [\#11116](https://github.com/matrix-org/synapse/issues/11116), [\#11122](https://github.com/matrix-org/synapse/issues/11122)) +- Add type hints for event fetching. ([\#11121](https://github.com/matrix-org/synapse/issues/11121)) +- Add docstrings and comments to the application service ephemeral event sending code. ([\#11138](https://github.com/matrix-org/synapse/issues/11138)) +- Update the `sign_json` script to support inline configuration of the signing key. ([\#11139](https://github.com/matrix-org/synapse/issues/11139)) +- Fix a long-standing bug where users excluded from the directory could still be added to the `users_who_share_private_rooms` table after a regular user joins a private room. ([\#11143](https://github.com/matrix-org/synapse/issues/11143)) +- Fix broken link in the docker image README. ([\#11144](https://github.com/matrix-org/synapse/issues/11144)) +- Add missing type hints to `synapse.crypto`. ([\#11146](https://github.com/matrix-org/synapse/issues/11146)) + + Synapse 1.45.1 (2021-10-20) =========================== diff --git a/changelog.d/10548.feature b/changelog.d/10548.feature deleted file mode 100644 index 263a811faf..0000000000 --- a/changelog.d/10548.feature +++ /dev/null @@ -1 +0,0 @@ -Port the Password Auth Providers module interface to the new generic interface. \ No newline at end of file diff --git a/changelog.d/10930.bugfix b/changelog.d/10930.bugfix deleted file mode 100644 index 756bfe9107..0000000000 --- a/changelog.d/10930.bugfix +++ /dev/null @@ -1 +0,0 @@ -Newly-created public rooms are now only assigned an alias if the room's creation has not been blocked by permission settings. Contributed by @AndrewFerr. diff --git a/changelog.d/10972.misc b/changelog.d/10972.misc deleted file mode 100644 index f66a7beaf0..0000000000 --- a/changelog.d/10972.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `synapse.storage.databases.main.client_ips`. diff --git a/changelog.d/10975.feature b/changelog.d/10975.feature deleted file mode 100644 index 167426e1fc..0000000000 --- a/changelog.d/10975.feature +++ /dev/null @@ -1 +0,0 @@ -Resolve and share `state_groups` for all [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical events in batch. diff --git a/changelog.d/10984.misc b/changelog.d/10984.misc deleted file mode 100644 index 86c4081cc4..0000000000 --- a/changelog.d/10984.misc +++ /dev/null @@ -1 +0,0 @@ -Fix spurious warnings about losing the logging context on the `ReplicationCommandHandler` when losing the replication connection. diff --git a/changelog.d/11001.bugfix b/changelog.d/11001.bugfix deleted file mode 100644 index f51ffb3481..0000000000 --- a/changelog.d/11001.bugfix +++ /dev/null @@ -1 +0,0 @@ - Fix a long-standing bug which meant that events received over federation were sometimes incorrectly accepted into the room state. diff --git a/changelog.d/11008.misc b/changelog.d/11008.misc deleted file mode 100644 index a67d95d66f..0000000000 --- a/changelog.d/11008.misc +++ /dev/null @@ -1 +0,0 @@ -Include rejected status when we log events. diff --git a/changelog.d/11009.bugfix b/changelog.d/11009.bugfix deleted file mode 100644 index 13b8e5983b..0000000000 --- a/changelog.d/11009.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug which meant that events received over federation were sometimes incorrectly accepted into the room state. diff --git a/changelog.d/11012.bugfix b/changelog.d/11012.bugfix deleted file mode 100644 index 13b8e5983b..0000000000 --- a/changelog.d/11012.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug which meant that events received over federation were sometimes incorrectly accepted into the room state. diff --git a/changelog.d/11014.misc b/changelog.d/11014.misc deleted file mode 100644 index 4b99ea354f..0000000000 --- a/changelog.d/11014.misc +++ /dev/null @@ -1 +0,0 @@ -Add some extra logging to the event persistence code. diff --git a/changelog.d/11024.misc b/changelog.d/11024.misc deleted file mode 100644 index 51ad800d4d..0000000000 --- a/changelog.d/11024.misc +++ /dev/null @@ -1 +0,0 @@ -Add support for Ubuntu 21.10 "Impish Indri". \ No newline at end of file diff --git a/changelog.d/11027.bugfix b/changelog.d/11027.bugfix deleted file mode 100644 index ae6cc44470..0000000000 --- a/changelog.d/11027.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix 500 error on `/messages` when the server accumulates more than 5 backwards extremities at a given depth for a room. diff --git a/changelog.d/11035.misc b/changelog.d/11035.misc deleted file mode 100644 index 6b45b7e9bd..0000000000 --- a/changelog.d/11035.misc +++ /dev/null @@ -1 +0,0 @@ -Rearrange the internal workings of the incremental user directory updates. \ No newline at end of file diff --git a/changelog.d/11048.misc b/changelog.d/11048.misc deleted file mode 100644 index 22d3c956f5..0000000000 --- a/changelog.d/11048.misc +++ /dev/null @@ -1 +0,0 @@ -Simplify the user admin API tests. \ No newline at end of file diff --git a/changelog.d/11051.bugfix b/changelog.d/11051.bugfix deleted file mode 100644 index 63126843d2..0000000000 --- a/changelog.d/11051.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug where setting a user's external_id via the admin API returns 500 and deletes users existing external mappings if that external ID is already mapped. \ No newline at end of file diff --git a/changelog.d/11054.misc b/changelog.d/11054.misc deleted file mode 100644 index 1103368fec..0000000000 --- a/changelog.d/11054.misc +++ /dev/null @@ -1 +0,0 @@ -Mark the Synapse package as containing type annotations and fix export declarations so that Synapse pluggable modules may be type checked against Synapse. diff --git a/changelog.d/11055.misc b/changelog.d/11055.misc deleted file mode 100644 index 27688c3214..0000000000 --- a/changelog.d/11055.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints for `_wrap_in_base_path` decorator used by `MediaFilePaths`. diff --git a/changelog.d/11056.misc b/changelog.d/11056.misc deleted file mode 100644 index dd701ed177..0000000000 --- a/changelog.d/11056.misc +++ /dev/null @@ -1 +0,0 @@ -Remove dead code from `MediaFilePaths`. diff --git a/changelog.d/11057.misc b/changelog.d/11057.misc deleted file mode 100644 index 4d412d3e9b..0000000000 --- a/changelog.d/11057.misc +++ /dev/null @@ -1 +0,0 @@ -Add tests for `MediaFilePaths` class. diff --git a/changelog.d/11065.misc b/changelog.d/11065.misc deleted file mode 100644 index c6f37fc52b..0000000000 --- a/changelog.d/11065.misc +++ /dev/null @@ -1 +0,0 @@ -Be more lenient when parsing oEmbed response versions. diff --git a/changelog.d/11066.misc b/changelog.d/11066.misc deleted file mode 100644 index 1e337bee54..0000000000 --- a/changelog.d/11066.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `synapse.events`. diff --git a/changelog.d/11068.misc b/changelog.d/11068.misc deleted file mode 100644 index 1fe69aecde..0000000000 --- a/changelog.d/11068.misc +++ /dev/null @@ -1 +0,0 @@ -Always dump logs from unit tests during CI runs. diff --git a/changelog.d/11069.doc b/changelog.d/11069.doc deleted file mode 100644 index dae4ae1777..0000000000 --- a/changelog.d/11069.doc +++ /dev/null @@ -1 +0,0 @@ -Fix broken links relating to module API deprecation in the upgrade notes. diff --git a/changelog.d/11070.misc b/changelog.d/11070.misc deleted file mode 100644 index 52b23f9671..0000000000 --- a/changelog.d/11070.misc +++ /dev/null @@ -1 +0,0 @@ -Create a separate module for the retention configuration. diff --git a/changelog.d/11071.misc b/changelog.d/11071.misc deleted file mode 100644 index 33a11abdd5..0000000000 --- a/changelog.d/11071.misc +++ /dev/null @@ -1 +0,0 @@ -Add a test for the workaround introduced in [\#11042](https://github.com/matrix-org/synapse/pull/11042) concerning the behaviour of third-party rule modules and `SynapseError`s. diff --git a/changelog.d/11075.bugfix b/changelog.d/11075.bugfix deleted file mode 100644 index 9b24971c5a..0000000000 --- a/changelog.d/11075.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where users excluded from the user directory were added into the directory if they belonged to a room which became public or private. \ No newline at end of file diff --git a/changelog.d/11076.misc b/changelog.d/11076.misc deleted file mode 100644 index c581a86e47..0000000000 --- a/changelog.d/11076.misc +++ /dev/null @@ -1 +0,0 @@ -Fix type hints in the relations tests. diff --git a/changelog.d/11077.bugfix b/changelog.d/11077.bugfix deleted file mode 100644 index dc35c86440..0000000000 --- a/changelog.d/11077.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug when attempting to preview URLs which are in the `windows-1252` character encoding. diff --git a/changelog.d/11078.bugfix b/changelog.d/11078.bugfix deleted file mode 100644 index cc813babe4..0000000000 --- a/changelog.d/11078.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix broken export-data admin command and add test script checking the command to CI. \ No newline at end of file diff --git a/changelog.d/11083.doc b/changelog.d/11083.doc deleted file mode 100644 index 245dd3758d..0000000000 --- a/changelog.d/11083.doc +++ /dev/null @@ -1 +0,0 @@ -Add more information about what happens when a user is deactivated. \ No newline at end of file diff --git a/changelog.d/11088.feature b/changelog.d/11088.feature deleted file mode 100644 index 76b0d28084..0000000000 --- a/changelog.d/11088.feature +++ /dev/null @@ -1 +0,0 @@ -Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/changelog.d/11089.bugfix b/changelog.d/11089.bugfix deleted file mode 100644 index dc35c86440..0000000000 --- a/changelog.d/11089.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug when attempting to preview URLs which are in the `windows-1252` character encoding. diff --git a/changelog.d/11092.doc b/changelog.d/11092.doc deleted file mode 100644 index 916c2b3476..0000000000 --- a/changelog.d/11092.doc +++ /dev/null @@ -1 +0,0 @@ -Clarify the the sample log config can be copied from the documentation without issue. diff --git a/changelog.d/11093.doc b/changelog.d/11093.doc deleted file mode 100644 index 70fca0bdce..0000000000 --- a/changelog.d/11093.doc +++ /dev/null @@ -1 +0,0 @@ -Update the admin API documentation with an updated list of the characters allowed in registration tokens. diff --git a/changelog.d/11095.misc b/changelog.d/11095.misc deleted file mode 100644 index 786e90b595..0000000000 --- a/changelog.d/11095.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to most `HomeServer` parameters. diff --git a/changelog.d/11096.doc b/changelog.d/11096.doc deleted file mode 100644 index d8e7424289..0000000000 --- a/changelog.d/11096.doc +++ /dev/null @@ -1 +0,0 @@ -Document Synapse's behaviour when dealing with multiple modules registering the same callbacks and/or handlers for the same HTTP endpoints. diff --git a/changelog.d/11101.bugfix b/changelog.d/11101.bugfix deleted file mode 100644 index 0de507848f..0000000000 --- a/changelog.d/11101.bugfix +++ /dev/null @@ -1 +0,0 @@ -Show an error when timestamp in seconds is provided to the `/purge_media_cache` Admin API. \ No newline at end of file diff --git a/changelog.d/11103.bugfix b/changelog.d/11103.bugfix deleted file mode 100644 index 3498f04a45..0000000000 --- a/changelog.d/11103.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix local users who left all their rooms being removed from the user directory, even if the "search_all_users" config option was enabled. \ No newline at end of file diff --git a/changelog.d/11109.misc b/changelog.d/11109.misc deleted file mode 100644 index d83936ccc4..0000000000 --- a/changelog.d/11109.misc +++ /dev/null @@ -1 +0,0 @@ -Add missing type hints to `synapse.api` module. diff --git a/changelog.d/11112.bugfix b/changelog.d/11112.bugfix deleted file mode 100644 index c8e22da8cf..0000000000 --- a/changelog.d/11112.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug which caused the module API's `get_user_ip_and_agents` function to always fail on workers. `get_user_ip_and_agents` was introduced in 1.44.0 and did not function correctly on worker processes at the time. diff --git a/changelog.d/11115.misc b/changelog.d/11115.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/11115.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/11116.misc b/changelog.d/11116.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/11116.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/11118.doc b/changelog.d/11118.doc deleted file mode 100644 index 3c2187f3b1..0000000000 --- a/changelog.d/11118.doc +++ /dev/null @@ -1 +0,0 @@ -Fix instances of `[example]{.title-ref}` in the upgrade documentation as a result of prior RST to Markdown conversion. diff --git a/changelog.d/11120.bugfix b/changelog.d/11120.bugfix deleted file mode 100644 index 6b39e3e89d..0000000000 --- a/changelog.d/11120.bugfix +++ /dev/null @@ -1 +0,0 @@ -Identity server connection is no longer ignoring `ip_range_whitelist`. diff --git a/changelog.d/11121.misc b/changelog.d/11121.misc deleted file mode 100644 index 916beeaacb..0000000000 --- a/changelog.d/11121.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints for event fetching. diff --git a/changelog.d/11122.misc b/changelog.d/11122.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/11122.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/11132.doc b/changelog.d/11132.doc deleted file mode 100644 index 4f38be5b27..0000000000 --- a/changelog.d/11132.doc +++ /dev/null @@ -1 +0,0 @@ -Document the version of Synapse each module callback was introduced in. diff --git a/changelog.d/11138.misc b/changelog.d/11138.misc deleted file mode 100644 index 79b7776975..0000000000 --- a/changelog.d/11138.misc +++ /dev/null @@ -1 +0,0 @@ -Add docstrings and comments to the application service ephemeral event sending code. \ No newline at end of file diff --git a/changelog.d/11139.misc b/changelog.d/11139.misc deleted file mode 100644 index 86a9189200..0000000000 --- a/changelog.d/11139.misc +++ /dev/null @@ -1 +0,0 @@ -Update the `sign_json` script to support inline configuration of the signing key. diff --git a/changelog.d/11143.misc b/changelog.d/11143.misc deleted file mode 100644 index 496e44a9c0..0000000000 --- a/changelog.d/11143.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where users excluded from the directory could still be added to the `users_who_share_private_rooms` table after a regular user joins a private room. \ No newline at end of file diff --git a/changelog.d/11144.misc b/changelog.d/11144.misc deleted file mode 100644 index b5db109e2b..0000000000 --- a/changelog.d/11144.misc +++ /dev/null @@ -1 +0,0 @@ -Fix broken link in the docker image README. diff --git a/changelog.d/11145.bugfix b/changelog.d/11145.bugfix deleted file mode 100644 index f369feac42..0000000000 --- a/changelog.d/11145.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse v1.45.0 breaking the configuration file parsing script. diff --git a/changelog.d/11146.misc b/changelog.d/11146.misc deleted file mode 100644 index 6ce1c9f9f5..0000000000 --- a/changelog.d/11146.misc +++ /dev/null @@ -1 +0,0 @@ -Add missing type hints to `synapse.crypto`. diff --git a/changelog.d/11174.feature b/changelog.d/11174.feature deleted file mode 100644 index 8eecd92681..0000000000 --- a/changelog.d/11174.feature +++ /dev/null @@ -1 +0,0 @@ -Users admin API can now also modify user type in addition to allowing it to be set on user creation. diff --git a/changelog.d/11177.bugfix b/changelog.d/11177.bugfix deleted file mode 100644 index ca5bc0df28..0000000000 --- a/changelog.d/11177.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a performance regression introduced in v1.44.0 which could cause client requests to time out when making large numbers of outbound requests. diff --git a/changelog.d/11180.feature b/changelog.d/11180.feature deleted file mode 100644 index 82c40bf1b2..0000000000 --- a/changelog.d/11180.feature +++ /dev/null @@ -1 +0,0 @@ -Port the Password Auth Providers module interface to the new generic interface. diff --git a/changelog.d/11181.feature b/changelog.d/11181.feature deleted file mode 100644 index 76b0d28084..0000000000 --- a/changelog.d/11181.feature +++ /dev/null @@ -1 +0,0 @@ -Experimental support for the thread relation defined in [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). diff --git a/changelog.d/11183.doc b/changelog.d/11183.doc deleted file mode 100644 index a171a107af..0000000000 --- a/changelog.d/11183.doc +++ /dev/null @@ -1 +0,0 @@ -Document the version of Synapse that introduced each module API method. diff --git a/debian/changelog b/debian/changelog index 1ee81f2a34..ea96676f74 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.46.0~rc1) stable; urgency=medium + + * New synapse release 1.46.0~rc1. + + -- Synapse Packaging team Tue, 26 Oct 2021 14:04:04 +0100 + matrix-synapse-py3 (1.45.1) stable; urgency=medium * New synapse release 1.45.1. diff --git a/synapse/__init__.py b/synapse/__init__.py index 2687d932ea..355b36fc63 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.45.1" +__version__ = "1.46.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1