Diffstat (limited to 'synapse')
 -rw-r--r--   synapse/config/spam_checker.py                 |   4
 -rw-r--r--   synapse/handlers/pagination.py                 |   5
 -rw-r--r--   synapse/handlers/relations.py                  | 117
 -rw-r--r--   synapse/handlers/sync.py                       |  30
 -rw-r--r--   synapse/push/bulk_push_rule_evaluator.py       |   2
 -rw-r--r--   synapse/rest/client/relations.py               |  75
 -rw-r--r--   synapse/rest/media/v1/preview_html.py          |  39
 -rw-r--r--   synapse/rest/media/v1/preview_url_resource.py  |  23
 -rw-r--r--   synapse/server.py                              |   5
 -rw-r--r--   synapse/storage/database.py                    |  61
 -rw-r--r--   synapse/storage/databases/main/account_data.py |  41
 -rw-r--r--   synapse/storage/databases/main/cache.py        |  57
 -rw-r--r--   synapse/visibility.py                          |  18
13 files changed, 266 insertions, 211 deletions
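
The largest change below moves relation pagination out of the REST layer and into a new RelationsHandler, which now owns the auth/visibility checks and the serialization that the two servlets previously duplicated. As a rough sketch of the resulting call path (a condensation of the diff, not runnable on its own; servlet plumbing such as parse_integer/parse_string and the aggregation-key/limit/token arguments are elided):

    # Trimmed-down servlet illustrating the delegation pattern used in the diff.
    class RelationPaginationServlet(RestServlet):
        def __init__(self, hs: "HomeServer"):
            super().__init__()
            self.auth = hs.get_auth()
            # The handler is cached on the HomeServer via get_relations_handler().
            self._relations_handler = hs.get_relations_handler()

        async def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
            requester = await self.auth.get_user_by_req(request, allow_guest=True)
            # Membership/visibility checks, the parent-event lookup and event
            # serialization now all happen inside the handler.
            result = await self._relations_handler.get_relations(
                requester=requester,
                event_id=parent_id,
                room_id=room_id,
                relation_type=relation_type,
                event_type=event_type,
            )
            return 200, result
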
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index a233a9ce03..4c52103b1c 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
 LEGACY_SPAM_CHECKER_WARNING = """
 This server is using a spam checker module that is implementing the deprecated spam
 checker interface. Please check with the module's maintainer to see if a new version
-supporting Synapse's generic modules system is available.
-For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
+supporting Synapse's generic modules system is available. For more information, please
+see https://matrix-org.github.io/synapse/latest/modules/index.html
 ---------------------------------------------------------------------------------------"""
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 60059fec3e..41679f7f86 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
 
 import attr
 
@@ -422,7 +422,7 @@ class PaginationHandler:
         pagin_config: PaginationConfig,
         as_client_event: bool = True,
         event_filter: Optional[Filter] = None,
-    ) -> Dict[str, Any]:
+    ) -> JsonDict:
         """Get messages in a room.
 
         Args:
@@ -431,6 +431,7 @@ class PaginationHandler:
             pagin_config: The pagination config rules to apply, if any.
            as_client_event: True to get events in client-server format.
            event_filter: Filter to apply to results or None
+
         Returns:
             Pagination API results
         """
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
new file mode 100644
index 0000000000..8e475475ad
--- /dev/null
+++ b/synapse/handlers/relations.py
@@ -0,0 +1,117 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+from synapse.api.errors import SynapseError
+from synapse.types import JsonDict, Requester, StreamToken
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class RelationsHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._main_store = hs.get_datastores().main
+        self._auth = hs.get_auth()
+        self._clock = hs.get_clock()
+        self._event_handler = hs.get_event_handler()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    async def get_relations(
+        self,
+        requester: Requester,
+        event_id: str,
+        room_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+        aggregation_key: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> JsonDict:
+        """Get related events of a event, ordered by topological ordering.
+
+        TODO Accept a PaginationConfig instead of individual pagination parameters.
+
+        Args:
+            requester: The user requesting the relations.
+            event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
+            relation_type: Only fetch events with this relation type, if given.
+            event_type: Only fetch events with this event type, if given.
+            aggregation_key: Only fetch events with this aggregation key, if given.
+            limit: Only fetch the most recent `limit` events.
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`).
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        await self._auth.check_user_in_room_or_world_readable(
+            room_id, user_id, allow_departed_users=True
+        )
+
+        # This gets the original event and checks that a) the event exists and
+        # b) the user is allowed to view it.
+        event = await self._event_handler.get_event(requester.user, room_id, event_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")
+
+        pagination_chunk = await self._main_store.get_relations_for_event(
+            event_id=event_id,
+            event=event,
+            room_id=room_id,
+            relation_type=relation_type,
+            event_type=event_type,
+            aggregation_key=aggregation_key,
+            limit=limit,
+            direction=direction,
+            from_token=from_token,
+            to_token=to_token,
+        )
+
+        events = await self._main_store.get_events_as_list(
+            [c["event_id"] for c in pagination_chunk.chunk]
+        )
+
+        now = self._clock.time_msec()
+        # Do not bundle aggregations when retrieving the original event because
+        # we want the content before relations are applied to it.
+        original_event = self._event_serializer.serialize_event(
+            event, now, bundle_aggregations=None
+        )
+        # The relations returned for the requested event do include their
+        # bundled aggregations.
+        aggregations = await self._main_store.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value = await pagination_chunk.to_dict(self._main_store)
+        return_value["chunk"] = serialized_events
+        return_value["original_event"] = original_event
+
+        return return_value
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0aa3052fd6..c9d6a18bd7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,7 +28,7 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -1601,7 +1601,7 @@ class SyncHandler:
             return set(), set(), set(), set()
 
         # 3. Work out which rooms need reporting in the sync response.
-        ignored_users = await self._get_ignored_users(user_id)
+        ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
                 sync_result_builder, ignored_users
             )
@@ -1627,7 +1627,6 @@
                 logger.debug("Generating room entry for %s", room_entry.room_id)
                 await self._generate_room_entry(
                     sync_result_builder,
-                    ignored_users,
                     room_entry,
                     ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
                     tags=tags_by_room.get(room_entry.room_id),
@@ -1657,29 +1656,6 @@
             newly_left_users,
         )
 
-    async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
-        """Retrieve the users ignored by the given user from their global account_data.
-
-        Returns an empty set if
-        - there is no global account_data entry for ignored_users
-        - there is such an entry, but it's not a JSON object.
-        """
-        # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
-        ignored_account_data = (
-            await self.store.get_global_account_data_by_type_for_user(
-                user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
-            )
-        )
-
-        # If there is ignored users account data and it matches the proper type,
-        # then use it.
-        ignored_users: FrozenSet[str] = frozenset()
-        if ignored_account_data:
-            ignored_users_data = ignored_account_data.get("ignored_users", {})
-            if isinstance(ignored_users_data, dict):
-                ignored_users = frozenset(ignored_users_data.keys())
-        return ignored_users
-
     async def _have_rooms_changed(
         self, sync_result_builder: "SyncResultBuilder"
     ) -> bool:
@@ -2022,7 +1998,6 @@
     async def _generate_room_entry(
         self,
         sync_result_builder: "SyncResultBuilder",
-        ignored_users: FrozenSet[str],
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
         tags: Optional[Dict[str, Dict[str, Any]]],
@@ -2051,7 +2026,6 @@
         Args:
             sync_result_builder
-            ignored_users: Set of users ignored by user.
             room_builder
             ephemeral: List of new ephemeral events for room
             tags: List of *all* tags for room, or None if there has been
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8140afcb6b..030898e4d0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -213,7 +213,7 @@ class BulkPushRuleEvaluator:
         if not event.is_state():
             ignorers = await self.store.ignored_by(event.sender)
         else:
-            ignorers = set()
+            ignorers = frozenset()
 
         for uid, rules in rules_by_user.items():
             if event.sender == uid:
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index d9a6be43f7..c16078b187 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
+        self._relations_handler = hs.get_relations_handler()
 
     async def on_GET(
         self,
@@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id, requester.user.to_string(), allow_departed_users=True
-        )
-
-        # This gets the original event and checks that a) the event exists and
-        # b) the user is allowed to view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
         limit = parse_integer(request, "limit", default=5)
         direction = parse_string(
             request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -90,9 +78,9 @@
         if to_token_str:
             to_token = await StreamToken.from_string(self.store, to_token_str)
 
-        pagination_chunk = await self.store.get_relations_for_event(
+        result = await self._relations_handler.get_relations(
+            requester=requester,
             event_id=parent_id,
-            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
@@ -102,30 +90,7 @@
             to_token=to_token,
         )
 
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in pagination_chunk.chunk]
-        )
-
-        now = self.clock.time_msec()
-        # Do not bundle aggregations when retrieving the original event because
-        # we want the content before relations are applied to it.
-        original_event = self._event_serializer.serialize_event(
-            event, now, bundle_aggregations=None
-        )
-        # The relations returned for the requested event do include their
-        # bundled aggregations.
-        aggregations = await self.store.get_bundled_aggregations(
-            events, requester.user.to_string()
-        )
-        serialized_events = self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=aggregations
-        )
-
-        return_value = await pagination_chunk.to_dict(self.store)
-        return_value["chunk"] = serialized_events
-        return_value["original_event"] = original_event
-
-        return 200, return_value
+        return 200, result
 
 
 class RelationAggregationPaginationServlet(RestServlet):
@@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
+        self._relations_handler = hs.get_relations_handler()
 
     async def on_GET(
         self,
@@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id,
-            requester.user.to_string(),
-            allow_departed_users=True,
-        )
-
-        # This checks that a) the event exists and b) the user is allowed to
-        # view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
         if relation_type != RelationTypes.ANNOTATION:
             raise SynapseError(400, "Relation type must be 'annotation'")
 
@@ -286,9 +237,9 @@
         if to_token_str:
             to_token = await StreamToken.from_string(self.store, to_token_str)
 
-        result = await self.store.get_relations_for_event(
+        result = await self._relations_handler.get_relations(
+            requester=requester,
             event_id=parent_id,
-            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
@@ -298,17 +249,7 @@
             to_token=to_token,
         )
 
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in result.chunk]
-        )
-
-        now = self.clock.time_msec()
-        serialized_events = self._event_serializer.serialize_events(events, now)
-
-        return_value = await result.to_dict(self.store)
-        return_value["chunk"] = serialized_events
-
-        return 200, return_value
+        return 200, result
 
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index 872a9e72e8..4cc9c66fbe 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -16,7 +16,6 @@ import itertools
 import logging
 import re
 from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
-from urllib import parse as urlparse
 
 if TYPE_CHECKING:
     from lxml import etree
@@ -144,9 +143,7 @@ def decode_body(
     return etree.fromstring(body, parser)
 
 
-def parse_html_to_open_graph(
-    tree: "etree.Element", media_uri: str
-) -> Dict[str, Optional[str]]:
+def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
     """
     Parse the HTML document into an Open Graph response.
 
@@ -155,7 +152,6 @@ def parse_html_to_open_graph(
     Args:
         tree: The parsed HTML document.
-        media_url: The URI used to download the body.
 
     Returns:
         The Open Graph response as a dictionary.
@@ -209,7 +205,7 @@
         "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
     )
     if meta_image:
-        og["og:image"] = rebase_url(meta_image[0], media_uri)
+        og["og:image"] = meta_image[0]
     else:
         # TODO: consider inlined CSS styles as well as width & height attribs
         images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@@ -320,37 +316,6 @@ def _iterate_over_text(
     )
 
 
-def rebase_url(url: str, base: str) -> str:
-    """
-    Resolves a potentially relative `url` against an absolute `base` URL.
-
-    For example:
-
-    >>> rebase_url("subpage", "https://example.com/foo/")
-    'https://example.com/foo/subpage'
-    >>> rebase_url("sibling", "https://example.com/foo")
-    'https://example.com/sibling'
-    >>> rebase_url("/bar", "https://example.com/foo/")
-    'https://example.com/bar'
-    >>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
-    'https://alice.com/a'
-    """
-    base_parts = urlparse.urlparse(base)
-    # Convert the parsed URL to a list for (potential) modification.
-    url_parts = list(urlparse.urlparse(url))
-    # Add a scheme, if one does not exist.
-    if not url_parts[0]:
-        url_parts[0] = base_parts.scheme or "http"
-    # Fix up the hostname, if this is not a data URL.
-    if url_parts[0] != "data" and not url_parts[1]:
-        url_parts[1] = base_parts.netloc
-    # If the path does not start with a /, nest it under the base path's last
-    # directory.
-    if not url_parts[2].startswith("/"):
-        url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
-    return urlparse.urlunparse(url_parts)
-
-
 def summarize_paragraphs(
     text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
 ) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 14ea88b240..d47af8ead6 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import shutil
 import sys
 import traceback
 from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
-from urllib import parse as urlparse
+from urllib.parse import urljoin, urlparse, urlsplit
 from urllib.request import urlopen
 
 import attr
@@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.media.v1._base import get_filename_from_headers
 from synapse.rest.media.v1.media_storage import MediaStorage
 from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.rest.media.v1.preview_html import (
-    decode_body,
-    parse_html_to_open_graph,
-    rebase_url,
-)
+from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
 from synapse.types import JsonDict, UserID
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
@@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
         ts = self.clock.time_msec()
 
         # XXX: we could move this into _do_preview if we wanted.
-        url_tuple = urlparse.urlsplit(url)
+        url_tuple = urlsplit(url)
         for entry in self.url_preview_url_blacklist:
             match = True
             for attrib in entry:
@@ -322,7 +318,7 @@
 
             # Parse Open Graph information from the HTML in case the oEmbed
             # response failed or is incomplete.
-            og_from_html = parse_html_to_open_graph(tree, media_info.uri)
+            og_from_html = parse_html_to_open_graph(tree)
 
             # Compile the Open Graph response by using the scraped
             # information from the HTML and overlaying any information
@@ -588,12 +584,17 @@
         if "og:image" not in og or not og["og:image"]:
             return
 
+        # The image URL from the HTML might be relative to the previewed page,
+        # convert it to an URL which can be requested directly.
+        image_url = og["og:image"]
+        url_parts = urlparse(image_url)
+        if url_parts.scheme != "data":
+            image_url = urljoin(media_info.uri, image_url)
+
         # FIXME: it might be cleaner to use the same flow as the main /preview_url
         # request itself and benefit from the same caching etc. But for now we
         # just rely on the caching on the master request to speed things up.
-        image_info = await self._handle_url(
-            rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
-        )
+        image_info = await self._handle_url(image_url, user, allow_data_urls=True)
 
         if _is_media(image_info.media_type):
             # TODO: make sure we don't choke on white-on-transparent images
diff --git a/synapse/server.py b/synapse/server.py
index 2fcf18a7a6..380369db92 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.receipts import ReceiptsHandler
 from synapse.handlers.register import RegistrationHandler
+from synapse.handlers.relations import RelationsHandler
 from synapse.handlers.room import (
     RoomContextHandler,
     RoomCreationHandler,
@@ -720,6 +721,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return PaginationHandler(self)
 
     @cache_in_self
+    def get_relations_handler(self) -> RelationsHandler:
+        return RelationsHandler(self)
+
+    @cache_in_self
     def get_room_context_handler(self) -> RoomContextHandler:
         return RoomContextHandler(self)
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..9749f0c06e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
 from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
+from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -732,34 +734,45 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks: List[_CallbackListEntry] = []
-        exception_callbacks: List[_CallbackListEntry] = []
 
-        if not current_context():
-            logger.warning("Starting db txn '%s' from sentinel context", desc)
+        async def _runInteraction() -> R:
+            after_callbacks: List[_CallbackListEntry] = []
+            exception_callbacks: List[_CallbackListEntry] = []
 
-        try:
-            with opentracing.start_active_span(f"db.{desc}"):
-                result = await self.runWithConnection(
-                    self.new_transaction,
-                    desc,
-                    after_callbacks,
-                    exception_callbacks,
-                    func,
-                    *args,
-                    db_autocommit=db_autocommit,
-                    isolation_level=isolation_level,
-                    **kwargs,
-                )
+            if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc) - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise + try: + with opentracing.start_active_span(f"db.{desc}"): + result = await self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + db_autocommit=db_autocommit, + isolation_level=isolation_level, + **kwargs, + ) - return cast(R, result) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + + return cast(R, result) + except Exception: + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + # To handle cancellation, we ensure that `after_callback`s and + # `exception_callback`s are always run, since the transaction will complete + # on another thread regardless of cancellation. + # + # We also wait until everything above is done before releasing the + # `CancelledError`, so that logging contexts won't get used after they have been + # finished. + return await delay_cancellation(defer.ensureDeferred(_runInteraction())) async def runWithConnection( self, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 52146aacc8..9af9f4f18e 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,7 +14,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Tuple, + cast, +) from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) @cached(max_entries=5000, iterable=True) - async def ignored_by(self, user_id: str) -> Set[str]: + async def ignored_by(self, user_id: str) -> FrozenSet[str]: """ Get users which ignore the given user. @@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) Return: The user IDs which ignore the given user. """ - return set( + return frozenset( await self.db_pool.simple_select_onecol( table="ignored_users", keyvalues={"ignored_user_id": user_id}, @@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) ) + @cached(max_entries=5000, iterable=True) + async def ignored_users(self, user_id: str) -> FrozenSet[str]: + """ + Get users which the given user ignores. + + Params: + user_id: The user ID which is making the request. + + Return: + The user IDs which are ignored by the given user. + """ + return frozenset( + await self.db_pool.simple_select_onecol( + table="ignored_users", + keyvalues={"ignorer_user_id": user_id}, + retcol="ignored_user_id", + desc="ignored_users", + ) + ) + def process_replication_rows( self, stream_name: str, @@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) else: currently_ignored_users = set() + # If the data has not changed, nothing to do. 
+        if previously_ignored_users == currently_ignored_users:
+            return
+
         # Delete entries which are no longer ignored.
         self.db_pool.simple_delete_many_txn(
             txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         # Invalidate the cache for any ignored users which were added or removed.
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+        self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
 
     async def purge_account_data_for_user(self, user_id: str) -> None:
         """
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index d6a2df1afe..2d7511d613 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
+    EventsStreamRow,
 )
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_caches_txn(txn):
+        def get_all_updated_caches_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == EventsStream.NAME:
             for row in rows:
                 self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def _process_event_stream_row(self, token, row):
+    def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data
 
         if row.type == EventsStreamEventRow.TypeId:
+            assert isinstance(data, EventsStreamEventRow)
             self._invalidate_caches_for_event(
                 token,
                 data.event_id,
@@ -157,9 +164,8 @@
                 backfilled=False,
             )
         elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
+            assert isinstance(data, EventsStreamCurrentStateRow)
+            self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
 
             if data.type == EventTypes.Member:
                 self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@
 
     def _invalidate_caches_for_event(
         self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
+        stream_ordering: int,
+        event_id: str,
+        room_id: str,
+        etype: str,
+        state_key: Optional[str],
+        redacts: Optional[str],
+        relates_to: Optional[str],
+        backfilled: bool,
+    ) -> None:
         self._invalidate_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
 
@@ -207,7 +213,9 @@
             self.get_thread_summary.invalidate((relates_to,))
             self.get_thread_participated.invalidate((relates_to,))
 
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+    async def invalidate_cache_and_stream(
+        self, cache_name: str, keys: Tuple[Any, ...]
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -227,7 +235,12 @@
             keys,
         )
 
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+    def _invalidate_cache_and_stream(
+        self,
+        txn: LoggingTransaction,
+        cache_func: _CachedFunction,
+        keys: Tuple[Any, ...],
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -238,7 +251,9 @@
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
-    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+    def _invalidate_all_cache_and_stream(
+        self, txn: LoggingTransaction, cache_func: _CachedFunction
+    ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
         """
@@ -279,8 +294,8 @@
     )
 
     def _send_invalidation_to_replication(
-        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
-    ):
+        self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+    ) -> None:
         """Notifies replication that given cache has been invalidated.
 
         Note that this does *not* invalidate the cache locally.
@@ -315,7 +330,7 @@
                 "instance_name": self._instance_name,
                 "cache_func": cache_name,
                 "keys": keys,
-                "invalidation_ts": self.clock.time_msec(),
+                "invalidation_ts": self._clock.time_msec(),
             },
         )
 
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 281cbe4d88..49519eb8f5 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -14,12 +14,7 @@
 import logging
 from typing import Dict, FrozenSet, List, Optional
 
-from synapse.api.constants import (
-    AccountDataTypes,
-    EventTypes,
-    HistoryVisibility,
-    Membership,
-)
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.events import EventBase
 from synapse.events.utils import prune_event
 from synapse.storage import Storage
@@ -87,15 +82,8 @@ async def filter_events_for_client(
         state_filter=StateFilter.from_types(types),
     )
 
-    ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
-        user_id, AccountDataTypes.IGNORED_USER_LIST
-    )
-
-    ignore_list: FrozenSet[str] = frozenset()
-    if ignore_dict_content:
-        ignored_users_dict = ignore_dict_content.get("ignored_users", {})
-        if isinstance(ignored_users_dict, dict):
-            ignore_list = frozenset(ignored_users_dict.keys())
+    # Get the users who are ignored by the requesting user.
+    ignore_list = await storage.main.ignored_users(user_id)
 
     erased_senders = await storage.main.are_users_erased(e.sender for e in events)
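
A common thread in the sync.py, bulk_push_rule_evaluator.py, account_data.py and visibility.py changes above is that callers stop parsing the m.ignored_user_list account-data event themselves and instead use the new cached ignored_users store method. A minimal before/after sketch, condensed from the removed and added code above (the store object and its methods are those shown in the diff, not a standalone API):

    # Before: each caller re-parsed the raw account data.
    ignored_account_data = await store.get_global_account_data_by_type_for_user(
        user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
    )
    ignored_users: FrozenSet[str] = frozenset()
    if ignored_account_data:
        ignored_users_data = ignored_account_data.get("ignored_users", {})
        if isinstance(ignored_users_data, dict):
            ignored_users = frozenset(ignored_users_data.keys())

    # After: a single cached SELECT against the ignored_users table. The cache is
    # invalidated alongside ignored_by in the account_data.py hunk above whenever
    # a user's ignore list changes, and both methods now return frozensets.
    ignored_users = await store.ignored_users(user_id)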