From deedb917325ea9ce8085df45dd925b8d583fd661 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 8 Sep 2020 14:26:54 +0100 Subject: Fix `MultiWriterIdGenerator.current_position`. (#8257) It did not correctly handle IDs finishing being persisted out of order, resulting in the `current_position` lagging until new IDs are persisted. --- changelog.d/8257.misc | 1 + synapse/storage/util/id_generators.py | 43 +++++++++++++++++++++++++----- tests/storage/test_id_generators.py | 50 +++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8257.misc diff --git a/changelog.d/8257.misc b/changelog.d/8257.misc new file mode 100644 index 0000000000..47ac583eb4 --- /dev/null +++ b/changelog.d/8257.misc @@ -0,0 +1 @@ +Fix non-user visible bug in implementation of `MultiWriterIdGenerator.get_current_token_for_writer`. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b7eb4f8ac9..2a66b3ad4e 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -224,6 +224,10 @@ class MultiWriterIdGenerator: # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + # Set of local IDs that we've processed that are larger than the current + # position, due to there being smaller unpersisted IDs. + self._finished_ids = set() # type: Set[int] + # We track the max position where we know everything before has been # persisted. This is done by a) looking at the min across all instances # and b) noting that if we have seen a run of persisted positions @@ -348,17 +352,44 @@ class MultiWriterIdGenerator: def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the - current poistion if possible. + current position if possible. """ with self._lock: self._unfinished_ids.discard(next_id) + self._finished_ids.add(next_id) + + new_cur = None + + if self._unfinished_ids: + # If there are unfinished IDs then the new position will be the + # largest finished ID less than the minimum unfinished ID. + + finished = set() + + min_unfinshed = min(self._unfinished_ids) + for s in self._finished_ids: + if s < min_unfinshed: + if new_cur is None or new_cur < s: + new_cur = s + else: + finished.add(s) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids = finished + else: + # There are no unfinished IDs so the new position is simply the + # largest finished one. + new_cur = max(self._finished_ids) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids.clear() - # Figure out if its safe to advance the position by checking there - # aren't any lower allocated IDs that are yet to finish. - if all(c > next_id for c in self._unfinished_ids): + if new_cur: curr = self._current_positions.get(self._instance_name, 0) - self._current_positions[self._instance_name] = max(curr, next_id) + self._current_positions[self._instance_name] = max(curr, new_cur) self._add_persisted_position(next_id) @@ -428,7 +459,7 @@ class MultiWriterIdGenerator: # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). - min_curr = min(self._current_positions.values()) + min_curr = min(self._current_positions.values(), default=0) self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index f0a8e32f1e..20636fc400 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -122,6 +122,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) + def test_out_of_order_finish(self): + """Test that IDs persisted out of order are correctly handled + """ + + # Prefill table with 7 rows written by 'master' + self._insert_rows("master", 7) + + id_gen = self._create_id_generator() + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx1 = self.get_success(id_gen.get_next()) + ctx2 = self.get_success(id_gen.get_next()) + ctx3 = self.get_success(id_gen.get_next()) + ctx4 = self.get_success(id_gen.get_next()) + + s1 = ctx1.__enter__() + s2 = ctx2.__enter__() + s3 = ctx3.__enter__() + s4 = ctx4.__enter__() + + self.assertEqual(s1, 8) + self.assertEqual(s2, 9) + self.assertEqual(s3, 10) + self.assertEqual(s4, 11) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx2.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 7}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) + + ctx1.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) + + ctx4.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 9}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) + + ctx3.__exit__(None, None, None) + + self.assertEqual(id_gen.get_positions(), {"master": 11}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) + def test_multi_instance(self): """Test that reads and writes from multiple processes are handled correctly. -- cgit 1.5.1 From 703e2b8a960a0845a5fbe14c34c45ae802300a4d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 8 Sep 2020 14:52:51 +0100 Subject: Use the right constructor for log records (#8278) Update `log_function` to use the right factory to create log records, to make sure that they have `request` attributes. Fixes: #8267. --- changelog.d/8278.bugfix | 1 + synapse/logging/utils.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8278.bugfix diff --git a/changelog.d/8278.bugfix b/changelog.d/8278.bugfix new file mode 100644 index 0000000000..50e40ca2a9 --- /dev/null +++ b/changelog.d/8278.bugfix @@ -0,0 +1 @@ +Fix a bug which cause the logging system to report errors, if `DEBUG` was enabled and no `context` filter was applied. diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index fea774e2e5..becf66dd86 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -29,11 +29,11 @@ def _log_debug_as_f(f, msg, msg_args): lineno = f.__code__.co_firstlineno pathname = f.__code__.co_filename - record = logging.LogRecord( + record = logger.makeRecord( name=name, level=logging.DEBUG, - pathname=pathname, - lineno=lineno, + fn=pathname, + lno=lineno, msg=msg, args=msg_args, exc_info=None, -- cgit 1.5.1 From 0f545e6b9670fd780579445ff68dba95a8e08545 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 8 Sep 2020 15:00:17 +0100 Subject: Clean up types for PaginationConfig (#8250) This removes `SourcePaginationConfig` and `get_pagination_rows`. The reasoning behind this is that these generic classes/functions erased the types of the IDs it used (i.e. instead of passing around `StreamToken` it'd pass in e.g. `token.room_key`, which don't have uniform types). --- changelog.d/8250.misc | 1 + synapse/handlers/initial_sync.py | 11 ++++---- synapse/handlers/pagination.py | 42 ++++++++++++++------------- synapse/handlers/presence.py | 3 -- synapse/handlers/receipts.py | 15 ---------- synapse/notifier.py | 5 ++-- synapse/streams/config.py | 61 ++++++++++++++-------------------------- 7 files changed, 52 insertions(+), 86 deletions(-) create mode 100644 changelog.d/8250.misc diff --git a/changelog.d/8250.misc b/changelog.d/8250.misc new file mode 100644 index 0000000000..b6896a9300 --- /dev/null +++ b/changelog.d/8250.misc @@ -0,0 +1 @@ +Clean up type hints for `PaginationConfig`. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d5ddc583ad..ddb8f0712b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -116,14 +116,13 @@ class InitialSyncHandler(BaseHandler): now_token = self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = await presence_stream.get_pagination_rows( - user, pagination_config.get_source_config("presence"), None + presence, _ = await presence_stream.get_new_events( + user, from_key=None, include_offline=False ) - receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = await receipt_stream.get_pagination_rows( - user, pagination_config.get_source_config("receipt"), None + joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] + receipt = await self.store.get_linearized_receipts_for_rooms( + joined_rooms, to_key=int(now_token.receipt_key), ) tags_by_room = await self.store.get_tags_for_user(user_id) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 34ed0e2921..195a1fd77e 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -335,20 +335,16 @@ class PaginationHandler: user_id = requester.user.to_string() if pagin_config.from_token: - room_token = pagin_config.from_token.room_key + from_token = pagin_config.from_token else: - pagin_config.from_token = ( - self.hs.get_event_sources().get_current_token_for_pagination() - ) - room_token = pagin_config.from_token.room_key - - room_token = RoomStreamToken.parse(room_token) + from_token = self.hs.get_event_sources().get_current_token_for_pagination() - pagin_config.from_token = pagin_config.from_token.copy_and_replace( - "room_key", str(room_token) - ) + if pagin_config.limit is None: + # This shouldn't happen as we've set a default limit before this + # gets called. + raise Exception("limit not set") - source_config = pagin_config.get_source_config("room") + room_token = RoomStreamToken.parse(from_token.room_key) with await self.pagination_lock.read(room_id): ( @@ -358,7 +354,7 @@ class PaginationHandler: room_id, user_id, allow_departed_users=True ) - if source_config.direction == "b": + if pagin_config.direction == "b": # if we're going backwards, we might need to backfill. This # requires that we have a topo token. if room_token.topological: @@ -381,22 +377,28 @@ class PaginationHandler: member_event_id ) if RoomStreamToken.parse(leave_token).topological < max_topo: - source_config.from_key = str(leave_token) + from_token = from_token.copy_and_replace( + "room_key", leave_token + ) await self.hs.get_handlers().federation_handler.maybe_backfill( room_id, max_topo ) + to_room_key = None + if pagin_config.to_token: + to_room_key = pagin_config.to_token.room_key + events, next_key = await self.store.paginate_room_events( room_id=room_id, - from_key=source_config.from_key, - to_key=source_config.to_key, - direction=source_config.direction, - limit=source_config.limit, + from_key=from_token.room_key, + to_key=to_room_key, + direction=pagin_config.direction, + limit=pagin_config.limit, event_filter=event_filter, ) - next_token = pagin_config.from_token.copy_and_replace("room_key", next_key) + next_token = from_token.copy_and_replace("room_key", next_key) if events: if event_filter: @@ -409,7 +411,7 @@ class PaginationHandler: if not events: return { "chunk": [], - "start": pagin_config.from_token.to_string(), + "start": from_token.to_string(), "end": next_token.to_string(), } @@ -438,7 +440,7 @@ class PaginationHandler: events, time_now, as_client_event=as_client_event ) ), - "start": pagin_config.from_token.to_string(), + "start": from_token.to_string(), "end": next_token.to_string(), } diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 91a3aec1cc..1000ac95ff 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1108,9 +1108,6 @@ class PresenceEventSource: def get_current_key(self): return self.store.get_current_presence_token() - async def get_pagination_rows(self, user, pagination_config, key): - return await self.get_new_events(user, from_key=None, include_offline=False) - @cached(num_args=2, cache_context=True) async def _get_interested_in(self, user, explicit_room_id, cache_context): """Returns the set of users that the given user should see presence diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 2cc6c2eb68..bdd8e52edd 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -142,18 +142,3 @@ class ReceiptEventSource: def get_current_key(self, direction="f"): return self.store.get_max_receipt_stream_id() - - async def get_pagination_rows(self, user, config, key): - to_key = int(config.from_key) - - if config.to_key: - from_key = int(config.to_key) - else: - from_key = None - - room_ids = await self.store.get_rooms_for_user(user.to_string()) - events = await self.store.get_linearized_receipts_for_rooms( - room_ids, from_key=from_key, to_key=to_key - ) - - return (events, to_key) diff --git a/synapse/notifier.py b/synapse/notifier.py index b7f4041306..71f2370874 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -432,8 +432,9 @@ class Notifier: If explicit_room_id is set, that room will be polled for events only if it is world readable or the user has joined the room. """ - from_token = pagination_config.from_token - if not from_token: + if pagination_config.from_token: + from_token = pagination_config.from_token + else: from_token = self.event_sources.get_current_token() limit = pagination_config.limit diff --git a/synapse/streams/config.py b/synapse/streams/config.py index d97dc4d101..0bdf846edf 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -14,9 +14,13 @@ # limitations under the License. import logging +from typing import Optional + +import attr from synapse.api.errors import SynapseError from synapse.http.servlet import parse_integer, parse_string +from synapse.http.site import SynapseRequest from synapse.types import StreamToken logger = logging.getLogger(__name__) @@ -25,38 +29,22 @@ logger = logging.getLogger(__name__) MAX_LIMIT = 1000 -class SourcePaginationConfig: - - """A configuration object which stores pagination parameters for a - specific event source.""" - - def __init__(self, from_key=None, to_key=None, direction="f", limit=None): - self.from_key = from_key - self.to_key = to_key - self.direction = "f" if direction == "f" else "b" - self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None - - def __repr__(self): - return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % ( - self.from_key, - self.to_key, - self.direction, - self.limit, - ) - - +@attr.s(slots=True) class PaginationConfig: - """A configuration object which stores pagination parameters.""" - def __init__(self, from_token=None, to_token=None, direction="f", limit=None): - self.from_token = from_token - self.to_token = to_token - self.direction = "f" if direction == "f" else "b" - self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None + from_token = attr.ib(type=Optional[StreamToken]) + to_token = attr.ib(type=Optional[StreamToken]) + direction = attr.ib(type=str) + limit = attr.ib(type=Optional[int]) @classmethod - def from_request(cls, request, raise_invalid_params=True, default_limit=None): + def from_request( + cls, + request: SynapseRequest, + raise_invalid_params: bool = True, + default_limit: Optional[int] = None, + ) -> "PaginationConfig": direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) from_tok = parse_string(request, "from") @@ -78,8 +66,11 @@ class PaginationConfig: limit = parse_integer(request, "limit", default=default_limit) - if limit and limit < 0: - raise SynapseError(400, "Limit must be 0 or above") + if limit: + if limit < 0: + raise SynapseError(400, "Limit must be 0 or above") + + limit = min(int(limit), MAX_LIMIT) try: return PaginationConfig(from_tok, to_tok, direction, limit) @@ -87,20 +78,10 @@ class PaginationConfig: logger.exception("Failed to create pagination config") raise SynapseError(400, "Invalid request.") - def __repr__(self): + def __repr__(self) -> str: return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % ( self.from_token, self.to_token, self.direction, self.limit, ) - - def get_source_config(self, source_name): - keyname = "%s_key" % source_name - - return SourcePaginationConfig( - from_key=getattr(self.from_token, keyname), - to_key=getattr(self.to_token, keyname) if self.to_token else None, - direction=self.direction, - limit=self.limit, - ) -- cgit 1.5.1 From 094896a69d44a69946df099da59adec0b52da0ac Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 8 Sep 2020 16:03:09 +0100 Subject: Add a config option for validating 'next_link' parameters against a domain whitelist (#8275) This is a config option ported over from DINUM's Sydent: https://github.com/matrix-org/sydent/pull/285 They've switched to validating 3PIDs via Synapse rather than Sydent, and would like to retain this functionality. This original purpose for this change is phishing prevention. This solution could also potentially be replaced by a similar one to https://github.com/matrix-org/synapse/pull/8004, but across all `*/submit_token` endpoint. This option may still be useful to enterprise even with that safeguard in place though, if they want to be absolutely sure that their employees don't follow links to other domains. --- changelog.d/8275.feature | 1 + docs/sample_config.yaml | 18 +++++ synapse/config/server.py | 33 ++++++++- synapse/rest/client/v2_alpha/account.py | 66 +++++++++++++++--- tests/rest/client/v2_alpha/test_account.py | 103 +++++++++++++++++++++++++++-- 5 files changed, 204 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8275.feature diff --git a/changelog.d/8275.feature b/changelog.d/8275.feature new file mode 100644 index 0000000000..17549c3df3 --- /dev/null +++ b/changelog.d/8275.feature @@ -0,0 +1 @@ +Add a config option to specify a whitelist of domains that a user can be redirected to after validating their email or phone number. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 3528d9e11f..994b0a62c4 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -432,6 +432,24 @@ retention: # #request_token_inhibit_3pid_errors: true +# A list of domains that the domain portion of 'next_link' parameters +# must match. +# +# This parameter is optionally provided by clients while requesting +# validation of an email or phone number, and maps to a link that +# users will be automatically redirected to after validation +# succeeds. Clients can make use this parameter to aid the validation +# process. +# +# The whitelist is applied whether the homeserver or an +# identity server is handling validation. +# +# The default value is no whitelist functionality; all domains are +# allowed. Setting this value to an empty list will instead disallow +# all domains. +# +#next_link_domain_whitelist: ["matrix.org"] + ## TLS ## diff --git a/synapse/config/server.py b/synapse/config/server.py index e85c6a0840..532b910470 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Set import attr import yaml @@ -542,6 +542,19 @@ class ServerConfig(Config): users_new_default_push_rules ) # type: set + # Whitelist of domain names that given next_link parameters must have + next_link_domain_whitelist = config.get( + "next_link_domain_whitelist" + ) # type: Optional[List[str]] + + self.next_link_domain_whitelist = None # type: Optional[Set[str]] + if next_link_domain_whitelist is not None: + if not isinstance(next_link_domain_whitelist, list): + raise ConfigError("'next_link_domain_whitelist' must be a list") + + # Turn the list into a set to improve lookup speed. + self.next_link_domain_whitelist = set(next_link_domain_whitelist) + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -1014,6 +1027,24 @@ class ServerConfig(Config): # act as if no error happened and return a fake session ID ('sid') to clients. # #request_token_inhibit_3pid_errors: true + + # A list of domains that the domain portion of 'next_link' parameters + # must match. + # + # This parameter is optionally provided by clients while requesting + # validation of an email or phone number, and maps to a link that + # users will be automatically redirected to after validation + # succeeds. Clients can make use this parameter to aid the validation + # process. + # + # The whitelist is applied whether the homeserver or an + # identity server is handling validation. + # + # The default value is no whitelist functionality; all domains are + # allowed. Setting this value to an empty list will instead disallow + # all domains. + # + #next_link_domain_whitelist: ["matrix.org"] """ % locals() ) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 3481477731..455051ac46 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -17,6 +17,11 @@ import logging import random from http import HTTPStatus +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer from synapse.api.constants import LoginType from synapse.api.errors import ( @@ -98,6 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -446,6 +454,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: @@ -517,6 +528,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: @@ -603,15 +617,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Perform a 302 redirect if next_link is set if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None # Otherwise show the success template html = self.config.email_add_threepid_template_success_html_content @@ -875,6 +884,45 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} +def assert_valid_next_link(hs: "HomeServer", next_link: str): + """ + Raises a SynapseError if a given next_link value is invalid + + next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config + option is either empty or contains a domain that matches the one in the given next_link + + Args: + hs: The homeserver object + next_link: The next_link value given by the client + + Raises: + SynapseError: If the next_link is invalid + """ + valid = True + + # Parse the contents of the URL + next_link_parsed = urlparse(next_link) + + # Scheme must not point to the local drive + if next_link_parsed.scheme == "file": + valid = False + + # If the domain whitelist is set, the domain must be in it + if ( + valid + and hs.config.next_link_domain_whitelist is not None + and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist + ): + valid = False + + if not valid: + raise SynapseError( + 400, + "'next_link' domain not included in whitelist, or not http(s)", + errcode=Codes.INVALID_PARAM, + ) + + class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 152a5182fa..0a51aeff92 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -14,11 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import os import re from email.parser import Parser +from typing import Optional import pkg_resources @@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from tests import unittest +from tests.unittest import override_config class PasswordResetTestCase(unittest.HomeserverTestCase): @@ -668,16 +669,104 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertFalse(channel.json_body["threepids"]) - def _request_token(self, email, client_secret): + @override_config({"next_link_domain_whitelist": None}) + def test_next_link(self): + """Tests a valid next_link parameter value with no whitelist (good case)""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/good/site", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_exotic_protocol(self): + """Tests using a esoteric protocol as a next_link parameter value. + Someone may be hosting a client on IPFS etc. + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="some-protocol://abcdefghijklmopqrstuvwxyz", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_file_uri(self): + """Tests next_link parameters cannot be file URI""" + # Attempt to use a next_link value that points to the local disk + self._request_token( + "something@example.com", + "some_secret", + next_link="file:///host/path", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) + def test_next_link_domain_whitelist(self): + """Tests next_link parameters must fit the whitelist if provided""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/some/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.org/some/also/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://bad.example.org/some/bad/page", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": []}) + def test_empty_next_link_domain_whitelist(self): + """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially + disallowed + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/page", + expect_code=400, + ) + + def _request_token( + self, + email: str, + client_secret: str, + next_link: Optional[str] = None, + expect_code: int = 200, + ) -> str: + """Request a validation token to add an email address to a user's account + + Args: + email: The email address to validate + client_secret: A secret string + next_link: A link to redirect the user to after validation + expect_code: Expected return code of the call + + Returns: + The ID of the new threepid validation session + """ + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link: + body["next_link"] = next_link + request, channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, + "POST", b"account/3pid/email/requestToken", body, ) self.render(request) - self.assertEquals(200, channel.code, channel.result) + self.assertEquals(expect_code, channel.code, channel.result) - return channel.json_body["sid"] + return channel.json_body.get("sid") def _request_token_invalid_email( self, email, expected_errcode, expected_error, client_secret="foobar", -- cgit 1.5.1 From 63c0e9e1954fc7fc10a2575c54aecc8944de60f3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 8 Sep 2020 16:48:15 +0100 Subject: Add types to StreamToken and RoomStreamToken (#8279) The intention here is to change `StreamToken.room_key` to be a `RoomStreamToken` in a future PR, but that is a big enough change without this refactoring too. --- changelog.d/8279.misc | 1 + synapse/handlers/sync.py | 5 +- synapse/storage/databases/main/devices.py | 7 +- synapse/storage/databases/main/stream.py | 21 +++-- synapse/types.py | 152 +++++++++++++++--------------- 5 files changed, 95 insertions(+), 91 deletions(-) create mode 100644 changelog.d/8279.misc diff --git a/changelog.d/8279.misc b/changelog.d/8279.misc new file mode 100644 index 0000000000..99f669001f --- /dev/null +++ b/changelog.d/8279.misc @@ -0,0 +1 @@ +Add type hints to `StreamToken` and `RoomStreamToken` classes. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e2ddb628ff..cc47e8b62c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1310,12 +1310,11 @@ class SyncHandler: presence_source = self.event_sources.sources["presence"] since_token = sync_result_builder.since_token + presence_key = None + include_offline = False if since_token and not sync_result_builder.full_state: presence_key = since_token.presence_key include_offline = True - else: - presence_key = None - include_offline = False presence, presence_key = await presence_source.get_new_events( user=user, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index add4e3ea0e..306fc6947c 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore): } async def get_users_whose_devices_changed( - self, from_key: str, user_ids: Iterable[str] + self, from_key: int, user_ids: Iterable[str] ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. @@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore): Returns: The set of user_ids whose devices have changed since `from_key` """ - from_key = int(from_key) # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. @@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore): ) async def get_users_whose_signatures_changed( - self, user_id: str, from_key: str + self, user_id: str, from_key: int ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. @@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: A set of user IDs with updated signatures. """ - from_key = int(from_key) + if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): sql = """ SELECT DISTINCT user_ids FROM user_signature_stream diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index be6df8a6d1..08a13a8b47 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -79,8 +79,8 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], - from_token: Optional[Tuple[int, int]], - to_token: Optional[Tuple[int, int]], + from_token: Optional[Tuple[Optional[int], int]], + to_token: Optional[Tuple[Optional[int], int]], engine: BaseDatabaseEngine, ) -> str: """Creates an SQL expression to bound the columns by the pagination @@ -535,13 +535,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if limit == 0: return [], end_token - end_token = RoomStreamToken.parse(end_token) + parsed_end_token = RoomStreamToken.parse(end_token) rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, - from_token=end_token, + from_token=parsed_end_token, limit=limit, ) @@ -989,8 +989,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token, - to_token=to_token, + from_token=from_token.as_tuple(), + to_token=to_token.as_tuple() if to_token else None, engine=self.database_engine, ) @@ -1083,16 +1083,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and `to_key`). """ - from_key = RoomStreamToken.parse(from_key) + parsed_from_key = RoomStreamToken.parse(from_key) + parsed_to_key = None if to_key: - to_key = RoomStreamToken.parse(to_key) + parsed_to_key = RoomStreamToken.parse(to_key) rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, - from_key, - to_key, + parsed_from_key, + parsed_to_key, direction, limit, event_filter, diff --git a/synapse/types.py b/synapse/types.py index f7de48f148..ba45335038 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,7 @@ import re import string import sys from collections import namedtuple -from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -362,22 +362,79 @@ def map_username_to_mxid_localpart(username, case_sensitive=False): return username.decode("ascii") -class StreamToken( - namedtuple( - "Token", - ( - "room_key", - "presence_key", - "typing_key", - "receipt_key", - "account_data_key", - "push_rules_key", - "to_device_key", - "device_list_key", - "groups_key", - ), +@attr.s(frozen=True, slots=True) +class RoomStreamToken: + """Tokens are positions between events. The token "s1" comes after event 1. + + s0 s1 + | | + [0] V [1] V [2] + + Tokens can either be a point in the live event stream or a cursor going + through historic events. + + When traversing the live event stream events are ordered by when they + arrived at the homeserver. + + When traversing historic events the events are ordered by their depth in + the event graph "topological_ordering" and then by when they arrived at the + homeserver "stream_ordering". + + Live tokens start with an "s" followed by the "stream_ordering" id of the + event it comes after. Historic tokens start with a "t" followed by the + "topological_ordering" id of the event it comes after, followed by "-", + followed by the "stream_ordering" id of the event it comes after. + """ + + topological = attr.ib( + type=Optional[int], + validator=attr.validators.optional(attr.validators.instance_of(int)), ) -): + stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) + + @classmethod + def parse(cls, string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + if string[0] == "t": + parts = string[1:].split("-", 1) + return cls(topological=int(parts[0]), stream=int(parts[1])) + except Exception: + pass + raise SynapseError(400, "Invalid token %r" % (string,)) + + @classmethod + def parse_stream_token(cls, string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + except Exception: + pass + raise SynapseError(400, "Invalid token %r" % (string,)) + + def as_tuple(self) -> Tuple[Optional[int], int]: + return (self.topological, self.stream) + + def __str__(self) -> str: + if self.topological is not None: + return "t%d-%d" % (self.topological, self.stream) + else: + return "s%d" % (self.stream,) + + +@attr.s(slots=True, frozen=True) +class StreamToken: + room_key = attr.ib(type=str) + presence_key = attr.ib(type=int) + typing_key = attr.ib(type=int) + receipt_key = attr.ib(type=int) + account_data_key = attr.ib(type=int) + push_rules_key = attr.ib(type=int) + to_device_key = attr.ib(type=int) + device_list_key = attr.ib(type=int) + groups_key = attr.ib(type=int) + _SEPARATOR = "_" START = None # type: StreamToken @@ -385,15 +442,15 @@ class StreamToken( def from_string(cls, string): try: keys = string.split(cls._SEPARATOR) - while len(keys) < len(cls._fields): + while len(keys) < len(attr.fields(cls)): # i.e. old token from before receipt_key keys.append("0") - return cls(*keys) + return cls(keys[0], *(int(k) for k in keys[1:])) except Exception: raise SynapseError(400, "Invalid Token") def to_string(self): - return self._SEPARATOR.join([str(k) for k in self]) + return self._SEPARATOR.join([str(k) for k in attr.astuple(self)]) @property def room_stream_id(self): @@ -435,63 +492,10 @@ class StreamToken( return self def copy_and_replace(self, key, new_value): - return self._replace(**{key: new_value}) - - -StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))) - - -class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): - """Tokens are positions between events. The token "s1" comes after event 1. - - s0 s1 - | | - [0] V [1] V [2] - - Tokens can either be a point in the live event stream or a cursor going - through historic events. - - When traversing the live event stream events are ordered by when they - arrived at the homeserver. - - When traversing historic events the events are ordered by their depth in - the event graph "topological_ordering" and then by when they arrived at the - homeserver "stream_ordering". - - Live tokens start with an "s" followed by the "stream_ordering" id of the - event it comes after. Historic tokens start with a "t" followed by the - "topological_ordering" id of the event it comes after, followed by "-", - followed by the "stream_ordering" id of the event it comes after. - """ + return attr.evolve(self, **{key: new_value}) - __slots__ = [] # type: list - - @classmethod - def parse(cls, string): - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - if string[0] == "t": - parts = string[1:].split("-", 1) - return cls(topological=int(parts[0]), stream=int(parts[1])) - except Exception: - pass - raise SynapseError(400, "Invalid token %r" % (string,)) - @classmethod - def parse_stream_token(cls, string): - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - except Exception: - pass - raise SynapseError(400, "Invalid token %r" % (string,)) - - def __str__(self): - if self.topological is not None: - return "t%d-%d" % (self.topological, self.stream) - else: - return "s%d" % (self.stream,) +StreamToken.START = StreamToken.from_string("s0_0") class ThirdPartyInstanceID( -- cgit 1.5.1 From 560f3b8609a3d1d566f33eeab029a4e96fe3ee02 Mon Sep 17 00:00:00 2001 From: "DeepBlueV7.X" Date: Tue, 8 Sep 2020 16:19:50 +0000 Subject: Include method in thumbnail media name (#7124) This fixes an issue where different methods (crop/scale) overwrite each other. This first tries the new path. If that fails and we are looking for a remote thumbnail, it tries the old path. If that still isn't found, it continues as normal. This should probably be removed in the future, after some of the newer thumbnails were generated with the new path on most deployments. Then the overhead should be minimal if the other thumbnails need to be regenerated. Signed-off-by: Nicolas Werner --- changelog.d/7124.bugfix | 1 + synapse/rest/media/v1/filepath.py | 19 +++++++- synapse/rest/media/v1/media_storage.py | 28 +++++++++++ synapse/storage/databases/main/media_repository.py | 57 ++++++++++++++++++++++ ...add_method_to_thumbnail_constraint.sql.postgres | 33 +++++++++++++ ...07add_method_to_thumbnail_constraint.sql.sqlite | 44 +++++++++++++++++ 6 files changed, 181 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7124.bugfix create mode 100644 synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres create mode 100644 synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite diff --git a/changelog.d/7124.bugfix b/changelog.d/7124.bugfix new file mode 100644 index 0000000000..8fd177780d --- /dev/null +++ b/changelog.d/7124.bugfix @@ -0,0 +1 @@ +Fix a bug in the media repository where remote thumbnails with the same size but different crop methods would overwrite each other. Contributed by @deepbluev7. diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index d2826374a7..7447eeaebe 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -80,7 +80,7 @@ class MediaFilePaths: self, server_name, file_id, width, height, content_type, method ): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( "remote_thumbnail", server_name, @@ -92,6 +92,23 @@ class MediaFilePaths: remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + # Legacy path that was used to store thumbnails previously. + # Should be removed after some time, when most of the thumbnails are stored + # using the new path. + def remote_media_thumbnail_rel_legacy( + self, server_name, file_id, width, height, content_type + ): + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + return os.path.join( + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], + file_name, + ) + def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( self.base_path, diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 3a352b5631..5681677fc9 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -147,6 +147,20 @@ class MediaStorage: if os.path.exists(local_path): return FileResponder(open(local_path, "rb")) + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return FileResponder(open(legacy_local_path, "rb")) + for provider in self.storage_providers: res = await provider.fetch(path, file_info) # type: Any if res: @@ -170,6 +184,20 @@ class MediaStorage: if os.path.exists(local_path): return local_path + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return legacy_local_path + dirname = os.path.dirname(local_path) if not os.path.exists(dirname): os.makedirs(dirname) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 86557d5512..1d76c761a6 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -17,6 +17,10 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool +BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( + "media_repository_drop_index_wo_method" +) + class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -32,6 +36,59 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): where_clause="url_cache IS NOT NULL", ) + # The following the updates add the method to the unique constraint of + # the thumbnail databases. That fixes an issue, where thumbnails of the + # same resolution, but different methods could overwrite one another. + # This can happen with custom thumbnail configs or with dynamic thumbnailing. + self.db_pool.updates.register_background_index_update( + update_name="local_media_repository_thumbnails_method_idx", + index_name="local_media_repository_thumbn_media_id_width_height_method_key", + table="local_media_repository_thumbnails", + columns=[ + "media_id", + "thumbnail_width", + "thumbnail_height", + "thumbnail_type", + "thumbnail_method", + ], + unique=True, + ) + + self.db_pool.updates.register_background_index_update( + update_name="remote_media_repository_thumbnails_method_idx", + index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key", + table="remote_media_cache_thumbnails", + columns=[ + "media_origin", + "media_id", + "thumbnail_width", + "thumbnail_height", + "thumbnail_type", + "thumbnail_method", + ], + unique=True, + ) + + self.db_pool.updates.register_background_update_handler( + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD, + self._drop_media_index_without_method, + ) + + async def _drop_media_index_without_method(self, progress, batch_size): + def f(txn): + txn.execute( + "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key" + ) + txn.execute( + "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key" + ) + + await self.db_pool.runInteraction("drop_media_indices_without_method", f) + await self.db_pool.updates._end_background_update( + BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD + ) + return 1 + class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres new file mode 100644 index 0000000000..b64926e9c9 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres @@ -0,0 +1,33 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This adds the method to the unique key constraint of the thumbnail databases. + * Otherwise you can't have a scaled and a cropped thumbnail with the same + * resolution, which happens quite often with dynamic thumbnailing. + * This is the postgres specific migration modifying the table with a background + * migration. + */ + +-- add new index that includes method to local media +INSERT INTO background_updates (update_name, progress_json) VALUES + ('local_media_repository_thumbnails_method_idx', '{}'); + +-- add new index that includes method to remote media +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx'); + +-- drop old index +INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES + ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx'); + diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite new file mode 100644 index 0000000000..1d0c04b53a --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite @@ -0,0 +1,44 @@ +/* Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This adds the method to the unique key constraint of the thumbnail databases. + * Otherwise you can't have a scaled and a cropped thumbnail with the same + * resolution, which happens quite often with dynamic thumbnailing. + * This is a sqlite specific migration, since sqlite can't modify the unique + * constraint of a table without recreating it. + */ + +CREATE TABLE local_media_repository_thumbnails_new ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) ); + +INSERT INTO local_media_repository_thumbnails_new + SELECT media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method, thumbnail_length + FROM local_media_repository_thumbnails; + +DROP TABLE local_media_repository_thumbnails; + +ALTER TABLE local_media_repository_thumbnails_new RENAME TO local_media_repository_thumbnails; + +CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id); + + + +CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails_new ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) ); + +INSERT INTO remote_media_cache_thumbnails_new + SELECT media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_method, thumbnail_type, thumbnail_length, filesystem_id + FROM remote_media_cache_thumbnails; + +DROP TABLE remote_media_cache_thumbnails; + +ALTER TABLE remote_media_cache_thumbnails_new RENAME TO remote_media_cache_thumbnails; -- cgit 1.5.1 From 1553adc83122ac245f523524ae1583cd556ed121 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 8 Sep 2020 17:43:31 +0100 Subject: Fix mypy error on develop (#8282) --- changelog.d/8282.misc | 1 + synapse/handlers/pagination.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) create mode 100644 changelog.d/8282.misc diff --git a/changelog.d/8282.misc b/changelog.d/8282.misc new file mode 100644 index 0000000000..b6896a9300 --- /dev/null +++ b/changelog.d/8282.misc @@ -0,0 +1 @@ +Clean up type hints for `PaginationConfig`. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 195a1fd77e..ec17d3d888 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -373,12 +373,15 @@ class PaginationHandler: # case "JOIN" would have been returned. assert member_event_id - leave_token = await self.store.get_topological_token_for_event( + leave_token_str = await self.store.get_topological_token_for_event( member_event_id ) - if RoomStreamToken.parse(leave_token).topological < max_topo: + leave_token = RoomStreamToken.parse(leave_token_str) + assert leave_token.topological is not None + + if leave_token.topological < max_topo: from_token = from_token.copy_and_replace( - "room_key", leave_token + "room_key", leave_token_str ) await self.hs.get_handlers().federation_handler.maybe_backfill( -- cgit 1.5.1 From e45b834119468272816c6558ebadb5cc286f148b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 8 Sep 2020 16:50:51 -0400 Subject: Add types to async_helpers (#8260) --- changelog.d/8260.misc | 1 + mypy.ini | 3 +- synapse/util/async_helpers.py | 135 ++++++++++++++++++++++++++---------------- 3 files changed, 88 insertions(+), 51 deletions(-) create mode 100644 changelog.d/8260.misc diff --git a/changelog.d/8260.misc b/changelog.d/8260.misc new file mode 100644 index 0000000000..164eea8b59 --- /dev/null +++ b/changelog.d/8260.misc @@ -0,0 +1 @@ +Add type hints to `synapse.util.async_helpers`. diff --git a/mypy.ini b/mypy.ini index 7764f17856..460392377e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -34,7 +34,7 @@ files = synapse/http/federation/well_known_resolver.py, synapse/http/server.py, synapse/http/site.py, - synapse/logging/, + synapse/logging, synapse/metrics, synapse/module_api, synapse/notifier.py, @@ -54,6 +54,7 @@ files = synapse/storage/util, synapse/streams, synapse/types.py, + synapse/util/async_helpers.py, synapse/util/caches/descriptors.py, synapse/util/caches/stream_change_cache.py, synapse/util/metrics.py, diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index bb57e27beb..67ce9a5f39 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -17,13 +17,25 @@ import collections import logging from contextlib import contextmanager -from typing import Dict, Sequence, Set, Union +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + TypeVar, + Union, +) import attr from typing_extensions import ContextManager from twisted.internet import defer from twisted.internet.defer import CancelledError +from twisted.internet.interfaces import IReactorTime from twisted.python import failure from synapse.logging.context import ( @@ -54,7 +66,7 @@ class ObservableDeferred: __slots__ = ["_deferred", "_observers", "_result"] - def __init__(self, deferred, consumeErrors=False): + def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False): object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_result", None) object.__setattr__(self, "_observers", set()) @@ -111,25 +123,25 @@ class ObservableDeferred: success, res = self._result return defer.succeed(res) if success else defer.fail(res) - def observers(self): + def observers(self) -> List[defer.Deferred]: return self._observers - def has_called(self): + def has_called(self) -> bool: return self._result is not None - def has_succeeded(self): + def has_succeeded(self) -> bool: return self._result is not None and self._result[0] is True - def get_result(self): + def get_result(self) -> Any: return self._result[1] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self._deferred, name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: setattr(self._deferred, name, value) - def __repr__(self): + def __repr__(self) -> str: return "" % ( id(self), self._result, @@ -137,18 +149,20 @@ class ObservableDeferred: ) -def concurrently_execute(func, args, limit): - """Executes the function with each argument conncurrently while limiting +def concurrently_execute( + func: Callable, args: Iterable[Any], limit: int +) -> defer.Deferred: + """Executes the function with each argument concurrently while limiting the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred or coroutine. - args (Iterable): List of arguments to pass to func, each invocation of func + func: Function to execute, should return a deferred or coroutine. + args: List of arguments to pass to func, each invocation of func gets a single argument. - limit (int): Maximum number of conccurent executions. + limit: Maximum number of conccurent executions. Returns: - deferred: Resolved when all function invocations have finished. + Deferred[list]: Resolved when all function invocations have finished. """ it = iter(args) @@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit): ).addErrback(unwrapFirstError) -def yieldable_gather_results(func, iter, *args, **kwargs): +def yieldable_gather_results( + func: Callable, iter: Iterable, *args: Any, **kwargs: Any +) -> defer.Deferred: """Executes the function with each argument concurrently. Args: - func (func): Function to execute that returns a Deferred - iter (iter): An iterable that yields items that get passed as the first + func: Function to execute that returns a Deferred + iter: An iterable that yields items that get passed as the first argument to the function *args: Arguments to be passed to each call to func + **kwargs: Keyword arguments to be passed to each call to func Returns Deferred[list]: Resolved when all functions have been invoked, or errors if @@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs): ).addErrback(unwrapFirstError) +@attr.s(slots=True) +class _LinearizerEntry: + # The number of things executing. + count = attr.ib(type=int) + # Deferreds for the things blocked from executing. + deferreds = attr.ib(type=collections.OrderedDict) + + class Linearizer: """Limits concurrent access to resources based on a key. Useful to ensure only a few things happen at a time on a given resource. Example: - with (yield limiter.queue("test_key")): + with await limiter.queue("test_key"): # do some work. """ - def __init__(self, name=None, max_count=1, clock=None): + def __init__( + self, + name: Optional[str] = None, + max_count: int = 1, + clock: Optional[Clock] = None, + ): """ Args: - max_count(int): The maximum number of concurrent accesses + max_count: The maximum number of concurrent accesses """ if name is None: - self.name = id(self) + self.name = id(self) # type: Union[str, int] else: self.name = name @@ -216,15 +246,10 @@ class Linearizer: self._clock = clock self.max_count = max_count - # key_to_defer is a map from the key to a 2 element list where - # the first element is the number of things executing, and - # the second element is an OrderedDict, where the keys are deferreds for the - # things blocked from executing. - self.key_to_defer = ( - {} - ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]] + # key_to_defer is a map from the key to a _LinearizerEntry. + self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] - def is_queued(self, key) -> bool: + def is_queued(self, key: Hashable) -> bool: """Checks whether there is a process queued up waiting """ entry = self.key_to_defer.get(key) @@ -234,25 +259,27 @@ class Linearizer: # There are waiting deferreds only in the OrderedDict of deferreds is # non-empty. - return bool(entry[1]) + return bool(entry.deferreds) - def queue(self, key): + def queue(self, key: Hashable) -> defer.Deferred: # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not # propagated inside inlineCallbacks until Twisted 18.7) - entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()]) + entry = self.key_to_defer.setdefault( + key, _LinearizerEntry(0, collections.OrderedDict()) + ) # If the number of things executing is greater than the maximum # then add a deferred to the list of blocked items # When one of the things currently executing finishes it will callback # this item so that it can continue executing. - if entry[0] >= self.max_count: + if entry.count >= self.max_count: res = self._await_lock(key) else: logger.debug( "Acquired uncontended linearizer lock %r for key %r", self.name, key ) - entry[0] += 1 + entry.count += 1 res = defer.succeed(None) # once we successfully get the lock, we need to return a context manager which @@ -267,15 +294,15 @@ class Linearizer: # We've finished executing so check if there are any things # blocked waiting to execute and start one of them - entry[0] -= 1 + entry.count -= 1 - if entry[1]: - (next_def, _) = entry[1].popitem(last=False) + if entry.deferreds: + (next_def, _) = entry.deferreds.popitem(last=False) # we need to run the next thing in the sentinel context. with PreserveLoggingContext(): next_def.callback(None) - elif entry[0] == 0: + elif entry.count == 0: # We were the last thing for this key: remove it from the # map. del self.key_to_defer[key] @@ -283,7 +310,7 @@ class Linearizer: res.addCallback(_ctx_manager) return res - def _await_lock(self, key): + def _await_lock(self, key: Hashable) -> defer.Deferred: """Helper for queue: adds a deferred to the queue Assumes that we've already checked that we've reached the limit of the number @@ -298,11 +325,11 @@ class Linearizer: logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer = make_deferred_yieldable(defer.Deferred()) - entry[1][new_defer] = 1 + entry.deferreds[new_defer] = 1 def cb(_r): logger.debug("Acquired linearizer lock %r for key %r", self.name, key) - entry[0] += 1 + entry.count += 1 # if the code holding the lock completes synchronously, then it # will recursively run the next claimant on the list. That can @@ -331,7 +358,7 @@ class Linearizer: ) # we just have to take ourselves back out of the queue. - del entry[1][new_defer] + del entry.deferreds[new_defer] return e new_defer.addCallbacks(cb, eb) @@ -419,14 +446,22 @@ class ReadWriteLock: return _ctx_manager() -def _cancelled_to_timed_out_error(value, timeout): +R = TypeVar("R") + + +def _cancelled_to_timed_out_error(value: R, timeout: float) -> R: if isinstance(value, failure.Failure): value.trap(CancelledError) raise defer.TimeoutError(timeout, "Deferred") return value -def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): +def timeout_deferred( + deferred: defer.Deferred, + timeout: float, + reactor: IReactorTime, + on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None, +) -> defer.Deferred: """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new deferred that wraps and times out the given deferred, correctly handling @@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred Args: - deferred (Deferred) - timeout (float): Timeout in seconds - reactor (twisted.interfaces.IReactorTime): The twisted reactor to use - on_timeout_cancel (callable): A callable which is called immediately + deferred: The Deferred to potentially timeout. + timeout: Timeout in seconds + reactor: The twisted reactor to use + on_timeout_cancel: A callable which is called immediately after the deferred times out, and not if this deferred is otherwise cancelled before the timeout. @@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None): CancelledError Failure into a defer.TimeoutError. Returns: - Deferred + A new Deferred. """ new_d = defer.Deferred() -- cgit 1.5.1