diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index a233a9ce03..4c52103b1c 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
LEGACY_SPAM_CHECKER_WARNING = """
This server is using a spam checker module that is implementing the deprecated spam
checker interface. Please check with the module's maintainer to see if a new version
-supporting Synapse's generic modules system is available.
-For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
+supporting Synapse's generic modules system is available. For more information, please
+see https://matrix-org.github.io/synapse/latest/modules/index.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 60059fec3e..41679f7f86 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
import attr
@@ -422,7 +422,7 @@ class PaginationHandler:
pagin_config: PaginationConfig,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
- ) -> Dict[str, Any]:
+ ) -> JsonDict:
"""Get messages in a room.
Args:
@@ -431,6 +431,7 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None
+
Returns:
Pagination API results
"""
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
new file mode 100644
index 0000000000..8e475475ad
--- /dev/null
+++ b/synapse/handlers/relations.py
@@ -0,0 +1,117 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+from synapse.api.errors import SynapseError
+from synapse.types import JsonDict, Requester, StreamToken
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class RelationsHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._main_store = hs.get_datastores().main
+ self._auth = hs.get_auth()
+ self._clock = hs.get_clock()
+ self._event_handler = hs.get_event_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+
+ async def get_relations(
+ self,
+ requester: Requester,
+ event_id: str,
+ room_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ aggregation_key: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ TODO Accept a PaginationConfig instead of individual pagination parameters.
+
+ Args:
+ requester: The user requesting the relations.
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ relation_type: Only fetch events with this relation type, if given.
+ event_type: Only fetch events with this event type, if given.
+ aggregation_key: Only fetch events with this aggregation key, if given.
+ limit: Only fetch the most recent `limit` events.
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`).
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ await self._auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
+ )
+
+ # This gets the original event and checks that a) the event exists and
+ # b) the user is allowed to view it.
+ event = await self._event_handler.get_event(requester.user, room_id, event_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
+
+ pagination_chunk = await self._main_store.get_relations_for_event(
+ event_id=event_id,
+ event=event,
+ room_id=room_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ aggregation_key=aggregation_key,
+ limit=limit,
+ direction=direction,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = await self._main_store.get_events_as_list(
+ [c["event_id"] for c in pagination_chunk.chunk]
+ )
+
+ now = self._clock.time_msec()
+ # Do not bundle aggregations when retrieving the original event because
+ # we want the content before relations are applied to it.
+ original_event = self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=None
+ )
+ # The relations returned for the requested event do include their
+ # bundled aggregations.
+ aggregations = await self._main_store.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value = await pagination_chunk.to_dict(self._main_store)
+ return_value["chunk"] = serialized_events
+ return_value["original_event"] = original_event
+
+ return return_value
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0aa3052fd6..c9d6a18bd7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,7 +28,7 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -1601,7 +1601,7 @@ class SyncHandler:
return set(), set(), set(), set()
# 3. Work out which rooms need reporting in the sync response.
- ignored_users = await self._get_ignored_users(user_id)
+ ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@@ -1627,7 +1627,6 @@ class SyncHandler:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
sync_result_builder,
- ignored_users,
room_entry,
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
tags=tags_by_room.get(room_entry.room_id),
@@ -1657,29 +1656,6 @@ class SyncHandler:
newly_left_users,
)
- async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
- """Retrieve the users ignored by the given user from their global account_data.
-
- Returns an empty set if
- - there is no global account_data entry for ignored_users
- - there is such an entry, but it's not a JSON object.
- """
- # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
- ignored_account_data = (
- await self.store.get_global_account_data_by_type_for_user(
- user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
- )
- )
-
- # If there is ignored users account data and it matches the proper type,
- # then use it.
- ignored_users: FrozenSet[str] = frozenset()
- if ignored_account_data:
- ignored_users_data = ignored_account_data.get("ignored_users", {})
- if isinstance(ignored_users_data, dict):
- ignored_users = frozenset(ignored_users_data.keys())
- return ignored_users
-
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
@@ -2022,7 +1998,6 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
- ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
@@ -2051,7 +2026,6 @@ class SyncHandler:
Args:
sync_result_builder
- ignored_users: Set of users ignored by user.
room_builder
ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8140afcb6b..030898e4d0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -213,7 +213,7 @@ class BulkPushRuleEvaluator:
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
else:
- ignorers = set()
+ ignorers = frozenset()
for uid, rules in rules_by_user.items():
if event.sender == uid:
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index d9a6be43f7..c16078b187 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.clock = hs.get_clock()
- self._event_serializer = hs.get_event_client_serializer()
- self.event_handler = hs.get_event_handler()
+ self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True
- )
-
- # This gets the original event and checks that a) the event exists and
- # b) the user is allowed to view it.
- event = await self.event_handler.get_event(requester.user, room_id, parent_id)
- if event is None:
- raise SynapseError(404, "Unknown parent event.")
-
limit = parse_integer(request, "limit", default=5)
direction = parse_string(
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
- pagination_chunk = await self.store.get_relations_for_event(
+ result = await self._relations_handler.get_relations(
+ requester=requester,
event_id=parent_id,
- event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token,
)
- events = await self.store.get_events_as_list(
- [c["event_id"] for c in pagination_chunk.chunk]
- )
-
- now = self.clock.time_msec()
- # Do not bundle aggregations when retrieving the original event because
- # we want the content before relations are applied to it.
- original_event = self._event_serializer.serialize_event(
- event, now, bundle_aggregations=None
- )
- # The relations returned for the requested event do include their
- # bundled aggregations.
- aggregations = await self.store.get_bundled_aggregations(
- events, requester.user.to_string()
- )
- serialized_events = self._event_serializer.serialize_events(
- events, now, bundle_aggregations=aggregations
- )
-
- return_value = await pagination_chunk.to_dict(self.store)
- return_value["chunk"] = serialized_events
- return_value["original_event"] = original_event
-
- return 200, return_value
+ return 200, result
class RelationAggregationPaginationServlet(RestServlet):
@@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.clock = hs.get_clock()
- self._event_serializer = hs.get_event_client_serializer()
- self.event_handler = hs.get_event_handler()
+ self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- await self.auth.check_user_in_room_or_world_readable(
- room_id,
- requester.user.to_string(),
- allow_departed_users=True,
- )
-
- # This checks that a) the event exists and b) the user is allowed to
- # view it.
- event = await self.event_handler.get_event(requester.user, room_id, parent_id)
- if event is None:
- raise SynapseError(404, "Unknown parent event.")
-
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
- result = await self.store.get_relations_for_event(
+ result = await self._relations_handler.get_relations(
+ requester=requester,
event_id=parent_id,
- event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)
- events = await self.store.get_events_as_list(
- [c["event_id"] for c in result.chunk]
- )
-
- now = self.clock.time_msec()
- serialized_events = self._event_serializer.serialize_events(events, now)
-
- return_value = await result.to_dict(self.store)
- return_value["chunk"] = serialized_events
-
- return 200, return_value
+ return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index 872a9e72e8..4cc9c66fbe 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -16,7 +16,6 @@ import itertools
import logging
import re
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
-from urllib import parse as urlparse
if TYPE_CHECKING:
from lxml import etree
@@ -144,9 +143,7 @@ def decode_body(
return etree.fromstring(body, parser)
-def parse_html_to_open_graph(
- tree: "etree.Element", media_uri: str
-) -> Dict[str, Optional[str]]:
+def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
"""
Parse the HTML document into an Open Graph response.
@@ -155,7 +152,6 @@ def parse_html_to_open_graph(
Args:
tree: The parsed HTML document.
- media_url: The URI used to download the body.
Returns:
The Open Graph response as a dictionary.
@@ -209,7 +205,7 @@ def parse_html_to_open_graph(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
- og["og:image"] = rebase_url(meta_image[0], media_uri)
+ og["og:image"] = meta_image[0]
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@@ -320,37 +316,6 @@ def _iterate_over_text(
)
-def rebase_url(url: str, base: str) -> str:
- """
- Resolves a potentially relative `url` against an absolute `base` URL.
-
- For example:
-
- >>> rebase_url("subpage", "https://example.com/foo/")
- 'https://example.com/foo/subpage'
- >>> rebase_url("sibling", "https://example.com/foo")
- 'https://example.com/sibling'
- >>> rebase_url("/bar", "https://example.com/foo/")
- 'https://example.com/bar'
- >>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
- 'https://alice.com/a'
- """
- base_parts = urlparse.urlparse(base)
- # Convert the parsed URL to a list for (potential) modification.
- url_parts = list(urlparse.urlparse(url))
- # Add a scheme, if one does not exist.
- if not url_parts[0]:
- url_parts[0] = base_parts.scheme or "http"
- # Fix up the hostname, if this is not a data URL.
- if url_parts[0] != "data" and not url_parts[1]:
- url_parts[1] = base_parts.netloc
- # If the path does not start with a /, nest it under the base path's last
- # directory.
- if not url_parts[2].startswith("/"):
- url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
- return urlparse.urlunparse(url_parts)
-
-
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 14ea88b240..d47af8ead6 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import shutil
import sys
import traceback
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
-from urllib import parse as urlparse
+from urllib.parse import urljoin, urlparse, urlsplit
from urllib.request import urlopen
import attr
@@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.rest.media.v1.preview_html import (
- decode_body,
- parse_html_to_open_graph,
- rebase_url,
-)
+from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted.
- url_tuple = urlparse.urlsplit(url)
+ url_tuple = urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
@@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# Parse Open Graph information from the HTML in case the oEmbed
# response failed or is incomplete.
- og_from_html = parse_html_to_open_graph(tree, media_info.uri)
+ og_from_html = parse_html_to_open_graph(tree)
# Compile the Open Graph response by using the scraped
# information from the HTML and overlaying any information
@@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
if "og:image" not in og or not og["og:image"]:
return
+ # The image URL from the HTML might be relative to the previewed page,
+ # convert it to an URL which can be requested directly.
+ image_url = og["og:image"]
+ url_parts = urlparse(image_url)
+ if url_parts.scheme != "data":
+ image_url = urljoin(media_info.uri, image_url)
+
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
- image_info = await self._handle_url(
- rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
- )
+ image_info = await self._handle_url(image_url, user, allow_data_urls=True)
if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images
diff --git a/synapse/server.py b/synapse/server.py
index 2fcf18a7a6..380369db92 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
+from synapse.handlers.relations import RelationsHandler
from synapse.handlers.room import (
RoomContextHandler,
RoomCreationHandler,
@@ -720,6 +721,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return PaginationHandler(self)
@cache_in_self
+ def get_relations_handler(self) -> RelationsHandler:
+ return RelationsHandler(self)
+
+ @cache_in_self
def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..9749f0c06e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
+from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -732,34 +734,45 @@ class DatabasePool:
Returns:
The result of func
"""
- after_callbacks: List[_CallbackListEntry] = []
- exception_callbacks: List[_CallbackListEntry] = []
- if not current_context():
- logger.warning("Starting db txn '%s' from sentinel context", desc)
+ async def _runInteraction() -> R:
+ after_callbacks: List[_CallbackListEntry] = []
+ exception_callbacks: List[_CallbackListEntry] = []
- try:
- with opentracing.start_active_span(f"db.{desc}"):
- result = await self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- db_autocommit=db_autocommit,
- isolation_level=isolation_level,
- **kwargs,
- )
+ if not current_context():
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except Exception:
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
+ try:
+ with opentracing.start_active_span(f"db.{desc}"):
+ result = await self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ db_autocommit=db_autocommit,
+ isolation_level=isolation_level,
+ **kwargs,
+ )
- return cast(R, result)
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
+ return cast(R, result)
+ except Exception:
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ # To handle cancellation, we ensure that `after_callback`s and
+ # `exception_callback`s are always run, since the transaction will complete
+ # on another thread regardless of cancellation.
+ #
+ # We also wait until everything above is done before releasing the
+ # `CancelledError`, so that logging contexts won't get used after they have been
+ # finished.
+ return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection(
self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 52146aacc8..9af9f4f18e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
@cached(max_entries=5000, iterable=True)
- async def ignored_by(self, user_id: str) -> Set[str]:
+ async def ignored_by(self, user_id: str) -> FrozenSet[str]:
"""
Get users which ignore the given user.
@@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
Return:
The user IDs which ignore the given user.
"""
- return set(
+ return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignored_user_id": user_id},
@@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
)
+ @cached(max_entries=5000, iterable=True)
+ async def ignored_users(self, user_id: str) -> FrozenSet[str]:
+ """
+ Get users which the given user ignores.
+
+ Params:
+ user_id: The user ID which is making the request.
+
+ Return:
+ The user IDs which are ignored by the given user.
+ """
+ return frozenset(
+ await self.db_pool.simple_select_onecol(
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ retcol="ignored_user_id",
+ desc="ignored_users",
+ )
+ )
+
def process_replication_rows(
self,
stream_name: str,
@@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
currently_ignored_users = set()
+ # If the data has not changed, nothing to do.
+ if previously_ignored_users == currently_ignored_users:
+ return
+
# Delete entries which are no longer ignored.
self.db_pool.simple_delete_many_txn(
txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+ self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def purge_account_data_for_user(self, user_id: str) -> None:
"""
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index d6a2df1afe..2d7511d613 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
+ EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_caches_txn(txn):
+ def get_all_updated_caches_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
- def _process_event_stream_row(self, token, row):
+ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
if row.type == EventsStreamEventRow.TypeId:
+ assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
+ assert isinstance(data, EventsStreamCurrentStateRow)
+ self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event(
self,
- stream_ordering,
- event_id,
- room_id,
- etype,
- state_key,
- redacts,
- relates_to,
- backfilled,
- ):
+ stream_ordering: int,
+ event_id: str,
+ room_id: str,
+ etype: str,
+ state_key: Optional[str],
+ redacts: Optional[str],
+ relates_to: Optional[str],
+ backfilled: bool,
+ ) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
@@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))
- async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+ async def invalidate_cache_and_stream(
+ self, cache_name: str, keys: Tuple[Any, ...]
+ ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys,
)
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ def _invalidate_cache_and_stream(
+ self,
+ txn: LoggingTransaction,
+ cache_func: _CachedFunction,
+ keys: Tuple[Any, ...],
+ ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
- def _invalidate_all_cache_and_stream(self, txn, cache_func):
+ def _invalidate_all_cache_and_stream(
+ self, txn: LoggingTransaction, cache_func: _CachedFunction
+ ) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
@@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def _send_invalidation_to_replication(
- self, txn, cache_name: str, keys: Optional[Iterable[Any]]
- ):
+ self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+ ) -> None:
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
@@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
- "invalidation_ts": self.clock.time_msec(),
+ "invalidation_ts": self._clock.time_msec(),
},
)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 281cbe4d88..49519eb8f5 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -14,12 +14,7 @@
import logging
from typing import Dict, FrozenSet, List, Optional
-from synapse.api.constants import (
- AccountDataTypes,
- EventTypes,
- HistoryVisibility,
- Membership,
-)
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.utils import prune_event
from synapse.storage import Storage
@@ -87,15 +82,8 @@ async def filter_events_for_client(
state_filter=StateFilter.from_types(types),
)
- ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
- user_id, AccountDataTypes.IGNORED_USER_LIST
- )
-
- ignore_list: FrozenSet[str] = frozenset()
- if ignore_dict_content:
- ignored_users_dict = ignore_dict_content.get("ignored_users", {})
- if isinstance(ignored_users_dict, dict):
- ignore_list = frozenset(ignored_users_dict.keys())
+ # Get the users who are ignored by the requesting user.
+ ignore_list = await storage.main.ignored_users(user_id)
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|