diff --git a/synapse/__init__.py b/synapse/__init__.py
index 870707f476..88aef1889c 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -68,7 +68,7 @@ try:
except ImportError:
pass
-__version__ = "1.55.0rc1"
+__version__ = "1.55.2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e4dc04c0b4..0f75e7b9d4 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -261,7 +261,10 @@ class SynapseHomeServer(HomeServer):
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "metrics" and self.config.metrics.enable_metrics:
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+ metrics_resource: Resource = MetricsResource(RegistryProxy)
+ if compress:
+ metrics_resource = gz_wrap(metrics_resource)
+ resources[METRICS_PREFIX] = metrics_resource
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
@@ -348,6 +351,23 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
if config.server.gc_seconds:
synapse.metrics.MIN_TIME_BETWEEN_GCS = config.server.gc_seconds
+ if (
+ config.registration.enable_registration
+ and not config.registration.enable_registration_without_verification
+ ):
+ if (
+ not config.captcha.enable_registration_captcha
+ and not config.registration.registrations_require_3pid
+ and not config.registration.registration_requires_token
+ ):
+
+ raise ConfigError(
+ "You have enabled open registration without any verification. This is a known vector for "
+ "spam and abuse. If you would like to allow public registration, please consider adding email, "
+ "captcha, or token-based verification. Otherwise this check can be removed by setting the "
+ "`enable_registration_without_verification` config option to `true`."
+ )
+
hs = SynapseHomeServer(
config.server.server_name,
config=config,
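
The check above only fires when open registration is enabled and every verification mechanism is off. As a standalone predicate, the logic is roughly this (a sketch, not the shipped helper; `config` is the parsed root config, and the attribute names mirror the ones used in the check):

```python
def open_registration_is_unverified(config) -> bool:
    """True when setup() would raise the new ConfigError."""
    return (
        config.registration.enable_registration
        and not config.registration.enable_registration_without_verification
        and not config.captcha.enable_registration_captcha
        and not config.registration.registrations_require_3pid
        and not config.registration.registration_requires_token
    )
```

Enabling any one of captcha, 3PID, or token verification, or the explicit opt-out flag, is enough to pass.
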
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 06ccf15cd9..d7f2219f53 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -37,6 +37,12 @@ DEFAULT_CONFIG = """\
# 'txn_limit' gives the maximum number of transactions to run per connection
# before reconnecting. Defaults to 0, which means no limit.
#
+# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
+# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
+# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
+# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
+# https://wiki.postgresql.org/wiki/Locale_data_changes
+#
# 'args' gives options which are passed through to the database engine,
# except for options starting 'cp_', which are used to configure the Twisted
# connection pool. For a reference to valid arguments, see:
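
For context, the locale check that `allow_unsafe_locale` relaxes runs against the connected database at startup; approximately (a sketch under assumed details, not the shipped code, which lives in synapse/storage/engines/postgres.py):

```python
import logging

logger = logging.getLogger(__name__)


def check_locale(cursor, allow_unsafe_locale: bool) -> None:
    # Ask Postgres for the collation/ctype of the current database.
    cursor.execute(
        "SELECT datcollate, datctype FROM pg_database"
        " WHERE datname = current_database()"
    )
    collation, ctype = cursor.fetchone()
    if collation != "C" or ctype != "C":
        msg = (
            "Database has incorrect collation/ctype of %r/%r; non-C locales"
            " can silently corrupt text indexes." % (collation, ctype)
        )
        if not allow_unsafe_locale:
            raise RuntimeError(msg)  # Synapse raises its own setup error here
        logger.warning("%s Continuing because allow_unsafe_locale is set.", msg)
```
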
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ea9b50fe97..40fb329a7f 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -33,6 +33,10 @@ class RegistrationConfig(Config):
str(config["disable_registration"])
)
+ self.enable_registration_without_verification = strtobool(
+ str(config.get("enable_registration_without_verification", False))
+ )
+
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
@@ -207,10 +211,18 @@ class RegistrationConfig(Config):
# Registration can be rate-limited using the parameters in the "Ratelimiting"
# section of this file.
- # Enable registration for new users.
+ # Enable registration for new users. Defaults to 'false'. It is highly recommended that if you enable registration,
+ # you use either captcha, email, or token-based verification to verify that new users are not bots. In order to enable registration
+ # without any verification, you must also set `enable_registration_without_verification`, found below.
#
#enable_registration: false
+ # Enable registration without email or captcha verification. Note: this option is *not* recommended,
+ # as registration without verification is a known vector for spam and abuse. Defaults to false. Has no effect
+ # unless `enable_registration` is also enabled.
+ #
+ #enable_registration_without_verification: true
+
# Time that a user's session remains valid for, after they log in.
#
# Note that this is not currently compatible with guest logins.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 49cd0a4f19..38de4b8000 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -676,6 +676,10 @@ class ServerConfig(Config):
):
raise ConfigError("'custom_template_directory' must be a string")
+ self.use_account_validity_in_account_status: bool = (
+ config.get("use_account_validity_in_account_status") or False
+ )
+
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index a233a9ce03..4c52103b1c 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -25,8 +25,8 @@ logger = logging.getLogger(__name__)
LEGACY_SPAM_CHECKER_WARNING = """
This server is using a spam checker module that is implementing the deprecated spam
checker interface. Please check with the module's maintainer to see if a new version
-supporting Synapse's generic modules system is available.
-For more information, please see https://matrix-org.github.io/synapse/latest/modules.html
+supporting Synapse's generic modules system is available. For more information, please
+see https://matrix-org.github.io/synapse/latest/modules/index.html
---------------------------------------------------------------------------------------"""
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 60904a55f5..cd80fcf9d1 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -21,7 +21,6 @@ from typing import (
Awaitable,
Callable,
Collection,
- Dict,
List,
Optional,
Tuple,
@@ -31,7 +30,7 @@ from typing import (
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
@@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
-CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
@@ -383,7 +382,7 @@ class SpamChecker:
return True
- async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+ async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
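
With the callback typed against `UserProfile` (a TypedDict in synapse.types rather than a bare Dict[str, str]), a module implementation might look like this hypothetical sketch (the user_id/display_name/avatar_url keys are assumed from synapse.types):

```python
from typing import Optional

from synapse.types import UserProfile


async def check_username_for_spam(user_profile: UserProfile) -> bool:
    # Returning True hides the user from user-directory search results.
    display_name: Optional[str] = user_profile.get("display_name")
    return bool(display_name) and "free crypto" in display_name.lower()
```
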
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a0520068e0..7120062127 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase
if TYPE_CHECKING:
+ from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
- from synapse.storage.databases.main.relations import BundledAggregations
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 482bbdd867..c7400c737b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
Callable,
Collection,
Dict,
- Iterable,
List,
Optional,
Tuple,
@@ -577,10 +576,10 @@ class FederationServer(FederationBase):
async def _on_context_state_request_compute(
self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]:
+ pdus: Collection[EventBase]
if event_id:
- pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
- room_id, event_id
- )
+ event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+ pdus = await self.store.get_events_as_list(event_ids)
else:
pdus = (await self.state.get_current_state(room_id)).values()
@@ -1093,7 +1092,7 @@ class FederationServer(FederationBase):
# has started processing).
while True:
async with lock:
- logger.info("handling received PDU: %s", event)
+ logger.info("handling received PDU in room %s: %s", room_id, event)
try:
with nested_logging_context(event.event_id):
await self._federation_event_handler.on_receive_pdu(
diff --git a/synapse/handlers/account.py b/synapse/handlers/account.py
index d5badf635b..c05a14304c 100644
--- a/synapse/handlers/account.py
+++ b/synapse/handlers/account.py
@@ -26,6 +26,10 @@ class AccountHandler:
self._main_store = hs.get_datastores().main
self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client()
+ self._use_account_validity_in_account_status = (
+ hs.config.server.use_account_validity_in_account_status
+ )
+ self._account_validity_handler = hs.get_account_validity_handler()
async def get_account_statuses(
self,
@@ -106,6 +110,13 @@ class AccountHandler:
"deactivated": userinfo.is_deactivated,
}
+ if self._use_account_validity_in_account_status:
+ status[
+ "org.matrix.expired"
+ ] = await self._account_validity_handler.is_user_expired(
+ user_id.to_string()
+ )
+
return status
async def _get_remote_account_statuses(
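
With the experimental `use_account_validity_in_account_status` option enabled, each local account-status entry gains one unstable field. An illustrative entry (values are made up; the `exists`/`deactivated` keys follow MSC3720):

```python
status = {
    "exists": True,
    "deactivated": False,
    # New: populated from AccountValidityHandler.is_user_expired().
    "org.matrix.expired": False,
}
```
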
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index db39aeabde..350ec9c03a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -950,54 +950,35 @@ class FederationHandler:
return event
- async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
- """Returns the state at the event. i.e. not including said event."""
-
- event = await self.store.get_event(event_id, check_room_id=room_id)
-
- state_groups = await self.state_store.get_state_groups(room_id, [event_id])
-
- if state_groups:
- _, state = list(state_groups.items()).pop()
- results = {(e.type, e.state_key): e for e in state}
-
- if event.is_state():
- # Get previous state
- if "replaces_state" in event.unsigned:
- prev_id = event.unsigned["replaces_state"]
- if prev_id != event.event_id:
- prev_event = await self.store.get_event(prev_id)
- results[(event.type, event.state_key)] = prev_event
- else:
- del results[(event.type, event.state_key)]
-
- res = list(results.values())
- return res
- else:
- return []
-
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
+ if event.internal_metadata.outlier:
+ raise NotFoundError("State not known at event %s" % (event_id,))
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
- if state_groups:
- _, state = list(state_groups.items()).pop()
- results = state
+ # get_state_groups_ids should return exactly one result
+ assert len(state_groups) == 1
- if event.is_state():
- # Get previous state
- if "replaces_state" in event.unsigned:
- prev_id = event.unsigned["replaces_state"]
- if prev_id != event.event_id:
- results[(event.type, event.state_key)] = prev_id
- else:
- results.pop((event.type, event.state_key), None)
+ state_map = next(iter(state_groups.values()))
- return list(results.values())
- else:
- return []
+ state_key = event.get_state_key()
+ if state_key is not None:
+ # the event was not rejected (get_event raises a NotFoundError for rejected
+ # events) so the state at the event should include the event itself.
+ assert (
+ state_map.get((event.type, state_key)) == event.event_id
+ ), "State at event did not include event itself"
+
+ # ... but we need the state *before* that event
+ if "replaces_state" in event.unsigned:
+ prev_id = event.unsigned["replaces_state"]
+ state_map[(event.type, state_key)] = prev_id
+ else:
+ del state_map[(event.type, state_key)]
+
+ return list(state_map.values())
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
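
The rewritten `get_state_ids_for_pdu` works purely on event IDs and derives the state *before* the event by backing out the event's own entry. A worked example with hypothetical IDs:

```python
# State at topic event $B (which replaced an earlier topic $A). The map
# includes $B itself, as the new assertion checks.
state_map = {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@alice:example.com"): "$alice",
    ("m.room.topic", ""): "$B",
}

# State *before* $B: unsigned.replaces_state == "$A", so restore the
# displaced entry...
state_map[("m.room.topic", "")] = "$A"
# ...or, had $B introduced the (type, state_key) pair, drop it instead:
# del state_map[("m.room.topic", "")]
```
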
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 848acbee36..a3021d4ada 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -493,6 +493,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
outlier: bool = False,
historical: bool = False,
@@ -527,6 +528,15 @@ class EventCreationHandler:
If non-None, prev_event_ids must also be provided.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is with insertion events which float at
+ the beginning of a historical batch and don't have any `prev_events` to
+ derive from; we add all of these state events as the explicit state so the
+ rest of the historical batch can inherit the same state and state_group.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
+
require_consent: Whether to check if the requester has
consented to the privacy policy.
@@ -612,6 +622,7 @@ class EventCreationHandler:
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
depth=depth,
)
@@ -771,7 +782,7 @@ class EventCreationHandler:
event_dict: dict,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
- auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
ratelimit: bool = True,
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
@@ -795,12 +806,14 @@ class EventCreationHandler:
The event IDs to use as the prev events.
Should normally be left as None to automatically request them
from the database.
- auth_event_ids:
- The event ids to use as the auth_events for the new event.
- Should normally be left as None, which will cause them to be calculated
- based on the room state at the prev_events.
-
- If non-None, prev_event_ids must also be provided.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is with insertion events which float at
+ the beginning of a historical batch and don't have any `prev_events` to
+ derive from; we add all of these state events as the explicit state so the
+ rest of the historical batch can inherit the same state and state_group.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
ratelimit: Whether to rate limit this send.
txn_id: The transaction ID.
ignore_shadow_ban: True if shadow-banned users should be allowed to
@@ -856,8 +869,9 @@ class EventCreationHandler:
requester,
event_dict,
txn_id=txn_id,
+ allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
outlier=outlier,
historical=historical,
depth=depth,
@@ -893,6 +907,7 @@ class EventCreationHandler:
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
depth: Optional[int] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@@ -915,6 +930,15 @@ class EventCreationHandler:
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is with insertion events which float at
+ the beginning of a historical batch and don't have any `prev_events` to
+ derive from; we add all of these state events as the explicit state so the
+ rest of the historical batch can inherit the same state and state_group.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
+
depth: Override the depth used to order the event in the DAG.
Should normally be set to None, which will cause the depth to be calculated
based on the prev_events.
@@ -922,31 +946,26 @@ class EventCreationHandler:
Returns:
Tuple of created event, context
"""
- # Strip down the auth_event_ids to only what we need to auth the event.
+ # Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
- full_state_ids_at_event = None
- if auth_event_ids is not None:
- # If auth events are provided, prev events must be also.
+ if state_event_ids is not None:
+ # Do a quick check to make sure that prev_event_ids is present to
+ # make the type-checking around `builder.build` happy.
# prev_event_ids could be an empty array though.
assert prev_event_ids is not None
- # Copy the full auth state before it stripped down
- full_state_ids_at_event = auth_event_ids.copy()
-
temp_event = await builder.build(
prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
+ auth_event_ids=state_event_ids,
depth=depth,
)
- auth_events = await self.store.get_events_as_list(auth_event_ids)
+ state_events = await self.store.get_events_as_list(state_event_ids)
# Create a StateMap[str]
- auth_event_state_map = {
- (e.type, e.state_key): e.event_id for e in auth_events
- }
- # Actually strip down and use the necessary auth events
+ state_map = {(e.type, e.state_key): e.event_id for e in state_events}
+ # Actually strip down and only use the necessary auth events
auth_event_ids = self._event_auth_handler.compute_auth_events(
event=temp_event,
- current_state_ids=auth_event_state_map,
+ current_state_ids=state_map,
for_verification=False,
)
@@ -989,12 +1008,16 @@ class EventCreationHandler:
context = EventContext.for_outlier()
elif (
event.type == EventTypes.MSC2716_INSERTION
- and full_state_ids_at_event
+ and state_event_ids
and builder.internal_metadata.is_historical()
):
+ # Add explicit state to the insertion event so it has state to derive
+ # from even though it's floating with no `prev_events`. The rest of
+ # the batch can derive from this state and state_group.
+ #
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
- old_state = await self.store.get_events_as_list(full_state_ids_at_event)
+ old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
else:
context = await self.state.compute_event_context(event)
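
To make "strip down the state_event_ids" concrete: the caller supplies the full state, but only the entries needed to auth the event survive. A hypothetical illustration for a plain message sent by @alice:example.com:

```python
# Full state handed in via `state_event_ids`, keyed by (type, state_key):
full_state = {
    ("m.room.create", ""): "$create",
    ("m.room.power_levels", ""): "$power",
    ("m.room.join_rules", ""): "$join_rules",
    ("m.room.member", "@alice:example.com"): "$alice_join",
    ("m.room.member", "@bob:example.com"): "$bob_join",
    ("m.room.topic", ""): "$topic",
}

# compute_auth_events() keeps only what auths @alice's message; join rules,
# @bob's membership, and the topic are stripped out.
expected_auth_event_ids = ["$create", "$power", "$alice_join"]
```
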
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 60059fec3e..876b879483 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
import attr
@@ -134,6 +134,7 @@ class PaginationHandler:
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
+ self._relations_handler = hs.get_relations_handler()
self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
@@ -422,7 +423,7 @@ class PaginationHandler:
pagin_config: PaginationConfig,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
- ) -> Dict[str, Any]:
+ ) -> JsonDict:
"""Get messages in a room.
Args:
@@ -431,6 +432,7 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None
+
Returns:
Pagination API results
"""
@@ -538,7 +540,9 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
- aggregations = await self.store.get_bundled_aggregations(events, user_id)
+ aggregations = await self._relations_handler.get_bundled_aggregations(
+ events, user_id
+ )
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 6554c0d3c2..239b0aa744 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -336,12 +336,18 @@ class ProfileHandler:
"""Check that the size and content type of the avatar at the given MXC URI are
within the configured limits.
+ If the given `mxc` is empty, no checks are performed. (Users are always able to
+ unset their avatar.)
+
Args:
mxc: The MXC URI at which the avatar can be found.
Returns:
A boolean indicating whether the file can be allowed to be set as an avatar.
"""
+ if mxc == "":
+ return True
+
if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
return True
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
new file mode 100644
index 0000000000..73217d135d
--- /dev/null
+++ b/synapse/handlers/relations.py
@@ -0,0 +1,268 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
+
+import attr
+from frozendict import frozendict
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.events import EventBase
+from synapse.types import JsonDict, Requester, StreamToken
+from synapse.visibility import filter_events_for_client
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+ # The latest event in the thread.
+ latest_event: EventBase
+ # The latest edit to the latest event in the thread.
+ latest_edit: Optional[EventBase]
+ # The total number of events in the thread.
+ count: int
+ # True if the current user has sent an event to the thread.
+ current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+ """
+ The bundled aggregations for an event.
+
+ Some values require additional processing during serialization.
+ """
+
+ annotations: Optional[JsonDict] = None
+ references: Optional[JsonDict] = None
+ replace: Optional[EventBase] = None
+ thread: Optional[_ThreadAggregation] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.annotations or self.references or self.replace or self.thread)
+
+
+class RelationsHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._main_store = hs.get_datastores().main
+ self._storage = hs.get_storage()
+ self._auth = hs.get_auth()
+ self._clock = hs.get_clock()
+ self._event_handler = hs.get_event_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+
+ async def get_relations(
+ self,
+ requester: Requester,
+ event_id: str,
+ room_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ aggregation_key: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ TODO Accept a PaginationConfig instead of individual pagination parameters.
+
+ Args:
+ requester: The user requesting the relations.
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ relation_type: Only fetch events with this relation type, if given.
+ event_type: Only fetch events with this event type, if given.
+ aggregation_key: Only fetch events with this aggregation key, if given.
+ limit: Only fetch the most recent `limit` events.
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`).
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ # TODO Properly handle a user leaving a room.
+ (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
+ )
+
+ # This gets the original event and checks that a) the event exists and
+ # b) the user is allowed to view it.
+ event = await self._event_handler.get_event(requester.user, room_id, event_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
+
+ pagination_chunk = await self._main_store.get_relations_for_event(
+ event_id=event_id,
+ event=event,
+ room_id=room_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ aggregation_key=aggregation_key,
+ limit=limit,
+ direction=direction,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = await self._main_store.get_events_as_list(
+ [c["event_id"] for c in pagination_chunk.chunk]
+ )
+
+ events = await filter_events_for_client(
+ self._storage, user_id, events, is_peeking=(member_event_id is None)
+ )
+
+ now = self._clock.time_msec()
+ # Do not bundle aggregations when retrieving the original event because
+ # we want the content before relations are applied to it.
+ original_event = self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=None
+ )
+ # The relations returned for the requested event do include their
+ # bundled aggregations.
+ aggregations = await self.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value = await pagination_chunk.to_dict(self._main_store)
+ return_value["chunk"] = serialized_events
+ return_value["original_event"] = original_event
+
+ return return_value
+
+ async def _get_bundled_aggregation_for_event(
+ self, event: EventBase, user_id: str
+ ) -> Optional[BundledAggregations]:
+ """Generate bundled aggregations for an event.
+
+ Note that this does not use a cache, but depends on cached methods.
+
+ Args:
+ event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ The bundled aggregations for an event, if bundled aggregations are
+ enabled and the event can have bundled aggregations.
+ """
+
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ return None
+
+ event_id = event.event_id
+ room_id = event.room_id
+
+ # The bundled aggregations to include, a mapping of relation type to a
+ # type-specific value. Some types include the direct return type here
+ # while others need more processing during serialization.
+ aggregations = BundledAggregations()
+
+ annotations = await self._main_store.get_aggregation_groups_for_event(
+ event_id, room_id
+ )
+ if annotations.chunk:
+ aggregations.annotations = await annotations.to_dict(self._main_store)
+
+ references = await self._main_store.get_relations_for_event(
+ event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ aggregations.references = await references.to_dict(self._main_store)
+
+ # Return the bundled aggregations for this event, if there are any.
+ return aggregations
+
+ async def get_bundled_aggregations(
+ self, events: Iterable[EventBase], user_id: str
+ ) -> Dict[str, BundledAggregations]:
+ """Generate bundled aggregations for events.
+
+ Args:
+ events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ A map of event ID to the bundled aggregation for the event. Not all
+ events may have bundled aggregations in the results.
+ """
+ # De-duplicate events by ID to handle the same event requested multiple times.
+ #
+ # State events do not get bundled aggregations.
+ events_by_id = {
+ event.event_id: event for event in events if not event.is_state()
+ }
+
+ # event ID -> bundled aggregation in non-serialized form.
+ results: Dict[str, BundledAggregations] = {}
+
+ # Fetch other relations per event.
+ for event in events_by_id.values():
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+ if event_result:
+ results[event.event_id] = event_result
+
+ # Fetch any edits (but not for redacted events).
+ edits = await self._main_store.get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
+ # Fetch thread summaries.
+ summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
+ # Only fetch participated for a limited selection based on what had
+ # summaries.
+ participated = await self._main_store.get_threads_participated(
+ [event_id for event_id, summary in summaries.items() if summary], user_id
+ )
+ for event_id, summary in summaries.items():
+ if summary:
+ thread_count, latest_thread_event, edit = summary
+ results.setdefault(
+ event_id, BundledAggregations()
+ ).thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ latest_edit=edit,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=participated[event_id],
+ )
+
+ return results
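
Call sites in pagination, room context, search, and sync all switch from the store to this handler in this patch; the pattern at each is essentially the following sketch (`hs` is the HomeServer, `now_ms` a clock reading in milliseconds):

```python
async def serialize_with_aggregations(hs, events, requester, now_ms):
    # Bundled aggregations now come from the relations handler, not the store.
    aggregations = await hs.get_relations_handler().get_bundled_aggregations(
        events, requester.user.to_string()
    )
    return hs.get_event_client_serializer().serialize_events(
        events, now_ms, bundle_aggregations=aggregations
    )
```
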
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b9735631fc..092e185c99 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,8 +60,8 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
+from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
@@ -1118,6 +1118,7 @@ class RoomContextHandler:
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self._relations_handler = hs.get_relations_handler()
async def get_event_context(
self,
@@ -1190,7 +1191,7 @@ class RoomContextHandler:
event = filtered[0]
# Fetch the aggregations.
- aggregations = await self.store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index abbf7b7b27..a0255bd143 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -121,12 +121,11 @@ class RoomBatchHandler:
return create_requester(user_id, app_service=app_service)
- async def get_most_recent_auth_event_ids_from_event_id_list(
+ async def get_most_recent_full_state_ids_from_event_id_list(
self, event_ids: List[str]
) -> List[str]:
- """Find the most recent auth event ids (derived from state events) that
- allowed that message to be sent. We will use this as a base
- to auth our historical messages against.
+ """Find the most recent event_id and grab the full state at that event.
+ We will use this as a base to auth our historical messages against.
Args:
event_ids: List of event ID's to look at
@@ -136,38 +135,37 @@ class RoomBatchHandler:
"""
(
- most_recent_prev_event_id,
+ most_recent_event_id,
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event(
- most_recent_prev_event_id
+ most_recent_event_id
)
# List of state event ID's
- prev_state_ids = list(prev_state_map.values())
- auth_event_ids = prev_state_ids
+ full_state_ids = list(prev_state_map.values())
- return auth_event_ids
+ return full_state_ids
async def persist_state_events_at_start(
self,
state_events_at_start: List[JsonDict],
room_id: str,
- initial_auth_event_ids: List[str],
+ initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> List[str]:
"""Takes all `state_events_at_start` event dictionaries and creates/persists
- them as floating state events which don't resolve into the current room state.
- They are floating because they reference a fake prev_event which doesn't connect
- to the normal DAG at all.
+ them in a floating state event chain which doesn't resolve into the current room
+ state. They are floating because they reference no prev_events and are marked
+ as outliers, which disconnects them from the normal DAG.
Args:
state_events_at_start:
room_id: Room where you want the events persisted in.
- initial_auth_event_ids: These will be the auth_events for the first
- state event created. Each event created afterwards will be
- added to the list of auth events for the next state event
- created.
+ initial_state_event_ids:
+ The base set of state for the historical batch which the floating
+ state chain will derive from. This should probably be the state
+ from the `prev_event` defined by `/batch_send?prev_event_id=$abc`.
app_service_requester: The requester of an application service.
Returns:
@@ -176,7 +174,7 @@ class RoomBatchHandler:
assert app_service_requester.app_service
state_event_ids_at_start = []
- auth_event_ids = initial_auth_event_ids.copy()
+ state_event_ids = initial_state_event_ids.copy()
# Make the state events float off on their own by specifying no
# prev_events for the first one in the chain so we don't have a bunch of
@@ -189,9 +187,7 @@ class RoomBatchHandler:
)
logger.debug(
- "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s",
- state_event,
- auth_event_ids,
+ "RoomBatchSendEventRestServlet inserting state_event=%s", state_event
)
event_dict = {
@@ -217,16 +213,26 @@ class RoomBatchHandler:
room_id=room_id,
action=membership,
content=event_dict["content"],
+ # Mark as an outlier to disconnect it from the normal DAG
+ # and not show up between batches of history.
outlier=True,
historical=True,
- # Only the first event in the chain should be floating.
+ # Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
+ # Since each state event is marked as an outlier, the
+ # `EventContext.for_outlier()` won't have any `state_ids`
+ # set and therefore can't derive any state even though the
+ # prev_events are set. Also since the first event in the
+ # state chain is floating with no `prev_events`, it can't
+ # derive state from anywhere automatically. So we need to
+ # set some state explicitly.
+ #
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
- auth_event_ids=auth_event_ids.copy(),
+ state_event_ids=state_event_ids.copy(),
)
else:
# TODO: Add some complement tests that adds state that is not member joins
@@ -240,21 +246,31 @@ class RoomBatchHandler:
state_event["sender"], app_service_requester.app_service
),
event_dict,
+ # Mark as an outlier to disconnect it from the normal DAG
+ # and not show up between batches of history.
outlier=True,
historical=True,
- # Only the first event in the chain should be floating.
+ # Only the first event in the state chain should be floating.
# The rest should hang off each other in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=prev_event_ids_for_state_chain,
+ # Since each state event is marked as an outlier, the
+ # `EventContext.for_outlier()` won't have any `state_ids`
+ # set and therefore can't derive any state even though the
+ # prev_events are set. Also since the first event in the
+ # state chain is floating with no `prev_events`, it can't
+ # derive state from anywhere automatically. So we need to
+ # set some state explicitly.
+ #
# Make sure to use a copy of this list because we modify it
# later in the loop here. Otherwise it will be the same
# reference and also update in the event when we append later.
- auth_event_ids=auth_event_ids.copy(),
+ state_event_ids=state_event_ids.copy(),
)
event_id = event.event_id
state_event_ids_at_start.append(event_id)
- auth_event_ids.append(event_id)
+ state_event_ids.append(event_id)
# Connect all the state in a floating chain
prev_event_ids_for_state_chain = [event_id]
@@ -265,7 +281,7 @@ class RoomBatchHandler:
events_to_create: List[JsonDict],
room_id: str,
inherited_depth: int,
- auth_event_ids: List[str],
+ initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> List[str]:
"""Create and persists all events provided sequentially. Handles the
@@ -281,8 +297,10 @@ class RoomBatchHandler:
room_id: Room where you want the events persisted in.
inherited_depth: The depth to create the events at (you will
probably by calling inherit_depth_from_prev_ids(...)).
- auth_event_ids: Define which events allow you to create the given
- event in the room.
+ initial_state_event_ids:
+ This is used to set explicit state for the insertion event at
+ the start of the historical batch since it's floating with no
+ prev_events to derive state from automatically.
app_service_requester: The requester of an application service.
Returns:
@@ -290,6 +308,11 @@ class RoomBatchHandler:
"""
assert app_service_requester.app_service
+ # We expect the first event in a historical batch to be an insertion event
+ assert events_to_create[0]["type"] == EventTypes.MSC2716_INSERTION
+ # We expect the last event in a historical batch to be a batch event
+ assert events_to_create[-1]["type"] == EventTypes.MSC2716_BATCH
+
# Make the historical event chain float off on its own by specifying no
# prev_events for the first event in the chain which causes the HS to
# ask for the state at the start of the batch later.
@@ -321,11 +344,16 @@ class RoomBatchHandler:
ev["sender"], app_service_requester.app_service
),
event_dict,
- # Only the first event in the chain should be floating.
- # The rest should hang off each other in a chain.
+ # Only the first event (which is the insertion event) in the
+ # chain should be floating. The rest should hang off each other
+ # in a chain.
allow_no_prev_events=index == 0,
prev_event_ids=event_dict.get("prev_events"),
- auth_event_ids=auth_event_ids,
+ # Since the first event (which is the insertion event) in the
+ # chain is floating with no `prev_events`, it can't derive state
+ # from anywhere automatically. So we need to set some state
+ # explicitly.
+ state_event_ids=initial_state_event_ids if index == 0 else None,
historical=True,
depth=inherited_depth,
)
@@ -343,10 +371,9 @@ class RoomBatchHandler:
)
logger.debug(
- "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s",
+ "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s",
event,
prev_event_ids,
- auth_event_ids,
)
events_to_persist.append((event, context))
@@ -376,12 +403,12 @@ class RoomBatchHandler:
room_id: str,
batch_id_to_connect_to: str,
inherited_depth: int,
- auth_event_ids: List[str],
+ initial_state_event_ids: List[str],
app_service_requester: Requester,
) -> Tuple[List[str], str]:
"""
- Handles creating and persisting all of the historical events as well
- as insertion and batch meta events to make the batch navigable in the DAG.
+ Handles creating and persisting all of the historical events as well as
+ insertion and batch meta events to make the batch navigable in the DAG.
Args:
events_to_create: List of historical events to create in JSON
@@ -391,8 +418,13 @@ class RoomBatchHandler:
want this batch to connect to.
inherited_depth: The depth to create the events at (you will
probably by calling inherit_depth_from_prev_ids(...)).
- auth_event_ids: Define which events allow you to create the given
- event in the room.
+ initial_state_event_ids:
+ This is used to set explicit state for the insertion event at
+ the start of the historical batch since it's floating with no
+ prev_events to derive state from automatically. This should
+ probably be the state from the `prev_event` defined by
+ `/batch_send?prev_event_id=$abc` plus the outcome of
+ `persist_state_events_at_start`.
app_service_requester: The requester of an application service.
Returns:
@@ -438,7 +470,7 @@ class RoomBatchHandler:
events_to_create=events_to_create,
room_id=room_id,
inherited_depth=inherited_depth,
- auth_event_ids=auth_event_ids,
+ initial_state_event_ids=initial_state_event_ids,
app_service_requester=app_service_requester,
)
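
The floating state chain this file builds can be sketched end to end as follows (hypothetical; `create_event` stands in for the real membership/event-creation calls above):

```python
async def build_floating_state_chain(
    state_events_at_start, initial_state_event_ids, create_event
):
    # Chain shape: $s1 <- $s2 <- $s3, all outliers, with $s1 floating
    # (no prev_events). Each event carries the accumulated state explicitly
    # because EventContext.for_outlier() supplies none.
    state_event_ids = list(initial_state_event_ids)
    prev_event_ids_for_state_chain = []  # first event floats
    for index, state_event in enumerate(state_events_at_start):
        event_id = await create_event(
            state_event,
            outlier=True,  # disconnect from the normal DAG
            allow_no_prev_events=(index == 0),
            prev_event_ids=prev_event_ids_for_state_chain,
            # Copy: the list keeps growing as each event is appended below.
            state_event_ids=state_event_ids.copy(),
        )
        state_event_ids.append(event_id)
        prev_event_ids_for_state_chain = [event_id]
    return state_event_ids
```
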
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e2ce4a35ef..938965f303 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -271,7 +271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
membership: str,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
- auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
@@ -294,10 +294,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special
cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events
- auth_event_ids:
- The event ids to use as the auth_events for the new event.
- Should normally be left as None, which will cause them to be calculated
- based on the room state at the prev_events.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is the historical `state_events_at_start`;
+ since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
+ have any `state_ids` set and therefore can't derive any state even though the
+ prev_events are set, so we need to supply the state ourselves via this argument.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
txn_id:
ratelimit:
@@ -352,7 +356,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id=txn_id,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
require_consent=require_consent,
outlier=outlier,
historical=historical,
@@ -455,7 +459,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical: bool = False,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
- auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
@@ -483,10 +487,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special
cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events
- auth_event_ids:
- The event ids to use as the auth_events for the new event.
- Should normally be left as None, which will cause them to be calculated
- based on the room state at the prev_events.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is the historical `state_events_at_start`;
+ since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
+ have any `state_ids` set and therefore can't derive any state even though the
+ prev_events are set, so we need to supply the state ourselves via this argument.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -539,7 +547,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical=historical,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
)
return result
@@ -561,7 +569,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical: bool = False,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
- auth_event_ids: Optional[List[str]] = None,
+ state_event_ids: Optional[List[str]] = None,
) -> Tuple[str, int]:
"""Helper for update_membership.
@@ -591,10 +599,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
events should have a prev_event and we should only use this in special
cases like MSC2716.
prev_event_ids: The event IDs to use as the prev events
- auth_event_ids:
- The event ids to use as the auth_events for the new event.
- Should normally be left as None, which will cause them to be calculated
- based on the room state at the prev_events.
+ state_event_ids:
+ The full state at a given event. This is used particularly by the MSC2716
+ /batch_send endpoint. One use case is the historical `state_events_at_start`;
+ since each is marked as an `outlier`, the `EventContext.for_outlier()` won't
+ have any `state_ids` set and therefore can't derive any state even though the
+ prev_events are set, so we need to supply the state ourselves via this argument.
+ This should normally be left as None, which will cause the auth_event_ids
+ to be calculated based on the room state at the prev_events.
Returns:
A tuple of the new event ID and stream ID.
@@ -721,7 +733,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
- auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
content=content,
require_consent=require_consent,
outlier=outlier,
@@ -945,7 +957,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
- auth_event_ids=auth_event_ids,
+ state_event_ids=state_event_ids,
content=content,
require_consent=require_consent,
outlier=outlier,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index aa16e417eb..30eddda65f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -54,6 +54,7 @@ class SearchHandler:
self.clock = hs.get_clock()
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
+ self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.auth = hs.get_auth()
@@ -354,7 +355,7 @@ class SearchHandler:
aggregations = None
if self._msc3666_enabled:
- aggregations = await self.store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0aa3052fd6..6c569cfb1c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,16 +28,16 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
+from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
-from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -269,6 +269,7 @@ class SyncHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
+ self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
@@ -638,8 +639,10 @@ class SyncHandler:
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
- bundled_aggregations = await self.store.get_bundled_aggregations(
- recents, sync_config.user.to_string()
+ bundled_aggregations = (
+ await self._relations_handler.get_bundled_aggregations(
+ recents, sync_config.user.to_string()
+ )
)
return TimelineBatch(
@@ -1601,7 +1604,7 @@ class SyncHandler:
return set(), set(), set(), set()
# 3. Work out which rooms need reporting in the sync response.
- ignored_users = await self._get_ignored_users(user_id)
+ ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@@ -1627,7 +1630,6 @@ class SyncHandler:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
sync_result_builder,
- ignored_users,
room_entry,
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
tags=tags_by_room.get(room_entry.room_id),
@@ -1657,29 +1659,6 @@ class SyncHandler:
newly_left_users,
)
- async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
- """Retrieve the users ignored by the given user from their global account_data.
-
- Returns an empty set if
- - there is no global account_data entry for ignored_users
- - there is such an entry, but it's not a JSON object.
- """
- # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
- ignored_account_data = (
- await self.store.get_global_account_data_by_type_for_user(
- user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
- )
- )
-
- # If there is ignored users account data and it matches the proper type,
- # then use it.
- ignored_users: FrozenSet[str] = frozenset()
- if ignored_account_data:
- ignored_users_data = ignored_account_data.get("ignored_users", {})
- if isinstance(ignored_users_data, dict):
- ignored_users = frozenset(ignored_users_data.keys())
- return ignored_users
-
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
@@ -2022,7 +2001,6 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
- ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
@@ -2051,7 +2029,6 @@ class SyncHandler:
Args:
sync_result_builder
- ignored_users: Set of users ignored by user.
room_builder
ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been
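
The deleted `_get_ignored_users` helper re-parsed the raw account-data blob on every sync; `store.ignored_users` is the dedicated lookup its TODO asked for. Roughly (a sketch of the store method, which lives in the account-data store):

```python
from typing import FrozenSet


async def ignored_users(self, user_id: str) -> FrozenSet[str]:
    """The user IDs that `user_id` has ignored, from the ignored_users table."""
    return frozenset(
        await self.db_pool.simple_select_onecol(
            table="ignored_users",
            keyvalues={"ignorer_user_id": user_id},
            retcol="ignored_user_id",
            desc="ignored_users",
        )
    )
```
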
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d27ed2be6a..048fd4bb82 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,8 +19,8 @@ import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo
-from synapse.types import JsonDict
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def search_users(
self, user_id: str, search_term: str, limit: int
- ) -> JsonDict:
+ ) -> SearchResult:
"""Searches for users in directory
Returns:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d735c1d461..ba9755f08b 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -111,6 +111,7 @@ from synapse.types import (
StateMap,
UserID,
UserInfo,
+ UserProfile,
create_requester,
)
from synapse.util import Clock
@@ -150,6 +151,7 @@ __all__ = [
"EventBase",
"StateMap",
"ProfileInfo",
+ "UserProfile",
]
logger = logging.getLogger(__name__)
@@ -609,15 +611,18 @@ class ModuleApi:
localpart: str,
displayname: Optional[str] = None,
emails: Optional[List[str]] = None,
+ admin: bool = False,
) -> "defer.Deferred[str]":
"""Registers a new user with given localpart and optional displayname, emails.
Added in Synapse v1.2.0.
+ Changed in Synapse v1.56.0: add 'admin' argument to register the user as admin.
Args:
localpart: The localpart of the new user.
displayname: The displayname of the new user.
emails: Emails to bind to the new user.
+ admin: True if the user should be registered as a server admin.
Raises:
SynapseError if there is an error performing the registration. Check the
@@ -631,6 +636,7 @@ class ModuleApi:
localpart=localpart,
default_display_name=displayname,
bind_emails=emails or [],
+ admin=admin,
)
)
@@ -665,7 +671,8 @@ class ModuleApi:
def record_user_external_id(
self, auth_provider_id: str, remote_user_id: str, registered_user_id: str
) -> defer.Deferred:
- """Record a mapping from an external user id to a mxid
+ """Record a mapping between an external user id from a single sign-on provider
+ and an mxid.
Added in Synapse v1.9.0.
@@ -1280,6 +1287,30 @@ class ModuleApi:
"""
await self._registration_handler.check_username(username)
+ async def store_remote_3pid_association(
+ self, user_id: str, medium: str, address: str, id_server: str
+ ) -> None:
+ """Stores an existing association between a user ID and a third-party identifier.
+
+ The association must already exist on the remote identity server.
+
+ Added in Synapse v1.56.0.
+
+ Args:
+ user_id: The user ID that's been associated with the 3PID.
+ medium: The medium of the 3PID (currently supported values are "msisdn" and
+ "email").
+ address: The address of the 3PID.
+ id_server: The identity server the 3PID association has been registered on.
+ This should only be the domain (or IP address, optionally with the port
+ number) for the identity server. This will be used to reach out to the
+ identity server using HTTPS (unless specified otherwise by Synapse's
+ configuration) when attempting to unbind the third-party identifier.
+
+ """
+ await self._store.add_user_bound_threepid(user_id, medium, address, id_server)
+
class PublicRoomListManager:
"""Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8140afcb6b..a402a3e403 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
+from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
@@ -213,7 +214,7 @@ class BulkPushRuleEvaluator:
if not event.is_state():
ignorers = await self.store.ignored_by(event.sender)
else:
- ignorers = set()
+ ignorers = frozenset()
for uid, rules in rules_by_user.items():
if event.sender == uid:
@@ -292,7 +293,7 @@ def _condition_checker(
return True
-MemberMap = Dict[str, Tuple[str, str]]
+MemberMap = Dict[str, Optional[EventIdMembership]]
Rule = Dict[str, dict]
RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
@@ -306,7 +307,7 @@ class RulesForRoomData:
*only* include data, and not references to e.g. the data stores.
"""
- # event_id -> (user_id, state)
+ # event_id -> EventIdMembership
member_map: MemberMap = attr.Factory(dict)
# user_id -> rules
rules_by_user: RulesByUser = attr.Factory(dict)
@@ -447,11 +448,10 @@ class RulesForRoom:
res = self.data.member_map.get(event_id, None)
if res:
- user_id, state = res
- if state == Membership.JOIN:
- rules = self.data.rules_by_user.get(user_id, None)
+ if res.membership == Membership.JOIN:
+ rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
- ret_rules_by_user[user_id] = rules
+ ret_rules_by_user[res.user_id] = rules
continue
# If a user has left a room we remove their push rule. If they
@@ -502,24 +502,26 @@ class RulesForRoom:
"""
sequence = self.data.sequence
- rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
-
- members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
+ members = await self.store.get_membership_from_event_ids(
+ member_event_ids.values()
+ )
- # If the event is a join event then it will be in current state evnts
+ # If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
- members[event_id] = (event.state_key, event.membership)
+ members[event_id] = EventIdMembership(
+ user_id=event.state_key, membership=event.membership
+ )
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
joined_user_ids = {
- user_id
- for user_id, membership in members.values()
- if membership == Membership.JOIN
+ entry.user_id
+ for entry in members.values()
+ if entry and entry.membership == Membership.JOIN
}
logger.debug("Joined: %r", joined_user_ids)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 649a4f49d0..5ccdd88364 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar
import bleach
import jinja2
+from markupsafe import Markup
from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import StoreError
@@ -867,7 +868,7 @@ class Mailer:
)
-def safe_markup(raw_html: str) -> jinja2.Markup:
+def safe_markup(raw_html: str) -> Markup:
"""
Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs.
@@ -877,7 +878,7 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
Returns:
A Markup object ready to safely use in a Jinja template.
"""
- return jinja2.Markup(
+ return Markup(
bleach.linkify(
bleach.clean(
raw_html,
@@ -891,7 +892,7 @@ def safe_markup(raw_html: str) -> jinja2.Markup:
)
-def safe_text(raw_text: str) -> jinja2.Markup:
+def safe_text(raw_text: str) -> Markup:
"""
Sanitise text (escape any HTML tags), and then linkify any bare URLs.
@@ -901,7 +902,7 @@ def safe_text(raw_text: str) -> jinja2.Markup:
Returns:
A Markup object ready to safely use in a Jinja template.
"""
- return jinja2.Markup(
+ return Markup(
bleach.linkify(bleach.clean(raw_text, tags=[], attributes=[], strip=False))
)
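Background on the import swap: Jinja 3.1 removed the long-deprecated jinja2.Markup alias, so the class is now imported from its real home in markupsafe. Behaviour is unchanged; a quick illustration (the tag and attribute allow-lists here are illustrative, not the ones Synapse configures):

import bleach
from markupsafe import Markup

def linkify_safely(raw_html: str) -> Markup:
    # Strip everything outside a small allow-list, then linkify bare URLs.
    cleaned = bleach.clean(raw_html, tags=["a", "b", "i"], attributes={"a": ["href"]})
    return Markup(bleach.linkify(cleaned))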
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 1dd39f06cf..8419ab3aca 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -74,7 +74,10 @@ REQUIREMENTS = [
# Note: 21.1.0 broke `/sync`, see #9936
"attrs>=19.2.0,!=21.1.0",
"netaddr>=0.7.18",
- "Jinja2>=2.9",
+ # Jinja 2.x is incompatible with MarkupSafe>=2.1. To ensure that admins do not
+ # end up with a broken installation (recent MarkupSafe but old Jinja), we
+ # add a lower bound to the Jinja2 dependency.
+ "Jinja2>=3.0",
"bleach>=1.4.3",
# We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0.
"typing-extensions>=3.10.0",
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 762808a571..57c4773edc 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -32,6 +32,7 @@ from synapse.rest.client import (
knock,
login as v1_login,
logout,
+ mutual_rooms,
notifications,
openid,
password_policy,
@@ -49,7 +50,6 @@ from synapse.rest.client import (
room_keys,
room_upgrade_rest_servlet,
sendtodevice,
- shared_rooms,
sync,
tags,
thirdparty,
@@ -132,4 +132,4 @@ class ClientRestResource(JsonResource):
admin.register_servlets_for_client_rest_resource(hs, client_resource)
# unstable
- shared_rooms.register_servlets(hs, client_resource)
+ mutual_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/mutual_rooms.py
index e669fa7890..27bfaf0b29 100644
--- a/synapse/rest/client/shared_rooms.py
+++ b/synapse/rest/client/mutual_rooms.py
@@ -28,13 +28,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class UserSharedRoomsServlet(RestServlet):
+class UserMutualRoomsServlet(RestServlet):
"""
- GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+ GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1
"""
PATTERNS = client_patterns(
- "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+ "/uk.half-shot.msc2666/user/mutual_rooms/(?P<user_id>[^/]*)",
releases=(), # This is an unstable feature
)
@@ -42,17 +42,19 @@ class UserSharedRoomsServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.user_directory_active = hs.config.server.update_user_directory
+ self.user_directory_search_enabled = (
+ hs.config.userdirectory.user_directory_search_enabled
+ )
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
- if not self.user_directory_active:
+ if not self.user_directory_search_enabled:
raise SynapseError(
code=400,
- msg="The user directory is disabled on this server. Cannot determine shared rooms.",
- errcode=Codes.FORBIDDEN,
+ msg="User directory searching is disabled. Cannot determine shared rooms.",
+ errcode=Codes.UNKNOWN,
)
UserID.from_string(user_id)
@@ -64,7 +66,8 @@ class UserSharedRoomsServlet(RestServlet):
msg="You cannot request a list of shared rooms with yourself",
errcode=Codes.FORBIDDEN,
)
- rooms = await self.store.get_shared_rooms_for_users(
+
+ rooms = await self.store.get_mutual_rooms_for_users(
requester.user.to_string(), user_id
)
@@ -72,4 +75,4 @@ class UserSharedRoomsServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- UserSharedRoomsServlet(hs).register(http_server)
+ UserMutualRoomsServlet(hs).register(http_server)
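For reference, the renamed unstable endpoint as a client would call it; the request and response below are illustrative, with the response shape following MSC2666:

# GET /_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms/@bob:example.org
# Authorization: Bearer <access_token>
#
# 200 OK
# {"joined": ["!abc123:example.org", "!def456:example.org"]}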
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index d9a6be43f7..c16078b187 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.clock = hs.get_clock()
- self._event_serializer = hs.get_event_client_serializer()
- self.event_handler = hs.get_event_handler()
+ self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True
- )
-
- # This gets the original event and checks that a) the event exists and
- # b) the user is allowed to view it.
- event = await self.event_handler.get_event(requester.user, room_id, parent_id)
- if event is None:
- raise SynapseError(404, "Unknown parent event.")
-
limit = parse_integer(request, "limit", default=5)
direction = parse_string(
request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
- pagination_chunk = await self.store.get_relations_for_event(
+ result = await self._relations_handler.get_relations(
+ requester=requester,
event_id=parent_id,
- event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token,
)
- events = await self.store.get_events_as_list(
- [c["event_id"] for c in pagination_chunk.chunk]
- )
-
- now = self.clock.time_msec()
- # Do not bundle aggregations when retrieving the original event because
- # we want the content before relations are applied to it.
- original_event = self._event_serializer.serialize_event(
- event, now, bundle_aggregations=None
- )
- # The relations returned for the requested event do include their
- # bundled aggregations.
- aggregations = await self.store.get_bundled_aggregations(
- events, requester.user.to_string()
- )
- serialized_events = self._event_serializer.serialize_events(
- events, now, bundle_aggregations=aggregations
- )
-
- return_value = await pagination_chunk.to_dict(self.store)
- return_value["chunk"] = serialized_events
- return_value["original_event"] = original_event
-
- return 200, return_value
+ return 200, result
class RelationAggregationPaginationServlet(RestServlet):
@@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.clock = hs.get_clock()
- self._event_serializer = hs.get_event_client_serializer()
- self.event_handler = hs.get_event_handler()
+ self._relations_handler = hs.get_relations_handler()
async def on_GET(
self,
@@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- await self.auth.check_user_in_room_or_world_readable(
- room_id,
- requester.user.to_string(),
- allow_departed_users=True,
- )
-
- # This checks that a) the event exists and b) the user is allowed to
- # view it.
- event = await self.event_handler.get_event(requester.user, room_id, parent_id)
- if event is None:
- raise SynapseError(404, "Unknown parent event.")
-
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
- result = await self.store.get_relations_for_event(
+ result = await self._relations_handler.get_relations(
+ requester=requester,
event_id=parent_id,
- event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
@@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)
- events = await self.store.get_events_as_list(
- [c["event_id"] for c in result.chunk]
- )
-
- now = self.clock.time_msec()
- serialized_events = self._event_serializer.serialize_events(events, now)
-
- return_value = await result.to_dict(self.store)
- return_value["chunk"] = serialized_events
-
- return 200, return_value
+ return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
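The permission and serialization logic deleted from both servlets above is what the new handler centralises. An outline inferred from the removed code (not the actual RelationsHandler implementation):

# RelationsHandler.get_relations, in outline:
#   1. check that the requester is in the room or the room is world
#      readable (previously done in each servlet)
#   2. fetch the parent event, raising 404 if it is missing or not visible
#   3. call store.get_relations_for_event(...) for the paginated chunk
#   4. serialize the chunk with bundled aggregations and, for the
#      pagination endpoint, attach "original_event" to the response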
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 8a06ab8c5f..47e152c8cc 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet):
self._store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
+ self._relations_handler = hs.get_relations_handler()
self.auth = hs.get_auth()
async def on_GET(
@@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet):
if event:
# Ensure there are bundled aggregations available.
- aggregations = await self._store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
[event], requester.user.to_string()
)
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 0048973e59..0780485322 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -124,14 +124,14 @@ class RoomBatchSendEventRestServlet(RestServlet):
)
# For the event we are inserting next to (`prev_event_ids_from_query`),
- # find the most recent auth events (derived from state events) that
- # allowed that message to be sent. We will use that as a base
- # to auth our historical messages against.
- auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list(
+ # find the most recent state events that allowed that message to be
+ # sent. We will use that as a base to auth our historical messages
+ # against.
+ state_event_ids = await self.room_batch_handler.get_most_recent_full_state_ids_from_event_id_list(
prev_event_ids_from_query
)
- if not auth_event_ids:
+ if not state_event_ids:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist."
@@ -148,13 +148,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
await self.room_batch_handler.persist_state_events_at_start(
state_events_at_start=body["state_events_at_start"],
room_id=room_id,
- initial_auth_event_ids=auth_event_ids,
+ initial_state_event_ids=state_event_ids,
app_service_requester=requester,
)
)
# Update our ongoing auth event ID list with all of the new state we
# just created
- auth_event_ids.extend(state_event_ids_at_start)
+ state_event_ids.extend(state_event_ids_at_start)
inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids(
prev_event_ids_from_query
@@ -196,7 +196,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
),
base_insertion_event_dict,
prev_event_ids=base_insertion_event_dict.get("prev_events"),
- auth_event_ids=auth_event_ids,
+ # Also set the explicit state here because we want to resolve
+ # any `state_events_at_start` here too. It's not strictly
+ # necessary to accomplish anything but if someone asks for the
+ # state at this point, we probably want to show them the
+ # historical state that was part of this batch.
+ state_event_ids=state_event_ids,
historical=True,
depth=inherited_depth,
)
@@ -212,7 +217,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
room_id=room_id,
batch_id_to_connect_to=batch_id_to_connect_to,
inherited_depth=inherited_depth,
- auth_event_ids=auth_event_ids,
+ initial_state_event_ids=state_event_ids,
app_service_requester=requester,
)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
index a47d9bd01d..116c982ce6 100644
--- a/synapse/rest/client/user_directory.py
+++ b/synapse/rest/client/user_directory.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
+from synapse.types import JsonMapping
from ._base import client_patterns
@@ -38,7 +38,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonMapping]:
"""Searches for users in directory
Returns:
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
index 872a9e72e8..4cc9c66fbe 100644
--- a/synapse/rest/media/v1/preview_html.py
+++ b/synapse/rest/media/v1/preview_html.py
@@ -16,7 +16,6 @@ import itertools
import logging
import re
from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
-from urllib import parse as urlparse
if TYPE_CHECKING:
from lxml import etree
@@ -144,9 +143,7 @@ def decode_body(
return etree.fromstring(body, parser)
-def parse_html_to_open_graph(
- tree: "etree.Element", media_uri: str
-) -> Dict[str, Optional[str]]:
+def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
"""
Parse the HTML document into an Open Graph response.
@@ -155,7 +152,6 @@ def parse_html_to_open_graph(
Args:
tree: The parsed HTML document.
- media_url: The URI used to download the body.
Returns:
The Open Graph response as a dictionary.
@@ -209,7 +205,7 @@ def parse_html_to_open_graph(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
- og["og:image"] = rebase_url(meta_image[0], media_uri)
+ og["og:image"] = meta_image[0]
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
@@ -320,37 +316,6 @@ def _iterate_over_text(
)
-def rebase_url(url: str, base: str) -> str:
- """
- Resolves a potentially relative `url` against an absolute `base` URL.
-
- For example:
-
- >>> rebase_url("subpage", "https://example.com/foo/")
- 'https://example.com/foo/subpage'
- >>> rebase_url("sibling", "https://example.com/foo")
- 'https://example.com/sibling'
- >>> rebase_url("/bar", "https://example.com/foo/")
- 'https://example.com/bar'
- >>> rebase_url("https://alice.com/a/", "https://example.com/foo/")
- 'https://alice.com/a'
- """
- base_parts = urlparse.urlparse(base)
- # Convert the parsed URL to a list for (potential) modification.
- url_parts = list(urlparse.urlparse(url))
- # Add a scheme, if one does not exist.
- if not url_parts[0]:
- url_parts[0] = base_parts.scheme or "http"
- # Fix up the hostname, if this is not a data URL.
- if url_parts[0] != "data" and not url_parts[1]:
- url_parts[1] = base_parts.netloc
- # If the path does not start with a /, nest it under the base path's last
- # directory.
- if not url_parts[2].startswith("/"):
- url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts.path) + url_parts[2]
- return urlparse.urlunparse(url_parts)
-
-
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 14ea88b240..d47af8ead6 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -22,7 +22,7 @@ import shutil
import sys
import traceback
from typing import TYPE_CHECKING, BinaryIO, Iterable, Optional, Tuple
-from urllib import parse as urlparse
+from urllib.parse import urljoin, urlparse, urlsplit
from urllib.request import urlopen
import attr
@@ -44,11 +44,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
-from synapse.rest.media.v1.preview_html import (
- decode_body,
- parse_html_to_open_graph,
- rebase_url,
-)
+from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -187,7 +183,7 @@ class PreviewUrlResource(DirectServeJsonResource):
ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted.
- url_tuple = urlparse.urlsplit(url)
+ url_tuple = urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
@@ -322,7 +318,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# Parse Open Graph information from the HTML in case the oEmbed
# response failed or is incomplete.
- og_from_html = parse_html_to_open_graph(tree, media_info.uri)
+ og_from_html = parse_html_to_open_graph(tree)
# Compile the Open Graph response by using the scraped
# information from the HTML and overlaying any information
@@ -588,12 +584,17 @@ class PreviewUrlResource(DirectServeJsonResource):
if "og:image" not in og or not og["og:image"]:
return
+ # The image URL from the HTML might be relative to the previewed page;
+ # convert it to a URL which can be requested directly.
+ image_url = og["og:image"]
+ url_parts = urlparse(image_url)
+ if url_parts.scheme != "data":
+ image_url = urljoin(media_info.uri, image_url)
+
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
- image_info = await self._handle_url(
- rebase_url(og["og:image"], media_info.uri), user, allow_data_urls=True
- )
+ image_info = await self._handle_url(image_url, user, allow_data_urls=True)
if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images
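The removed rebase_url helper is covered by the standard library; the cases its doctests exercised behave as follows with urljoin:

# >>> from urllib.parse import urljoin
# >>> urljoin("https://example.com/foo/", "subpage")
# 'https://example.com/foo/subpage'
# >>> urljoin("https://example.com/foo", "sibling")
# 'https://example.com/sibling'
# >>> urljoin("https://example.com/foo/", "/bar")
# 'https://example.com/bar'
# >>> urljoin("https://example.com/foo/", "https://alice.com/a/")
# 'https://alice.com/a/'
#
# The scheme check above deliberately leaves data: URLs untouched instead
# of passing them through urljoin.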
diff --git a/synapse/server.py b/synapse/server.py
index 2fcf18a7a6..380369db92 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.register import RegistrationHandler
+from synapse.handlers.relations import RelationsHandler
from synapse.handlers.room import (
RoomContextHandler,
RoomCreationHandler,
@@ -720,6 +721,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return PaginationHandler(self)
@cache_in_self
+ def get_relations_handler(self) -> RelationsHandler:
+ return RelationsHandler(self)
+
+ @cache_in_self
def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..72fef1533f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
+from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -286,13 +288,17 @@ class LoggingTransaction:
"""
if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch # type: ignore
+ from psycopg2.extras import execute_batch
- self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+ self._do_execute(
+ lambda the_sql: execute_batch(self.txn, the_sql, args), sql
+ )
else:
self.executemany(sql, args)
- def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
+ def execute_values(
+ self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
+ ) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
using postgres.
@@ -300,10 +306,11 @@ class LoggingTransaction:
rows (e.g. INSERTs).
"""
assert isinstance(self.database_engine, PostgresEngine)
- from psycopg2.extras import execute_values # type: ignore
+ from psycopg2.extras import execute_values
return self._do_execute(
- lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
+ lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
+ sql,
)
def execute(self, sql: str, *args: Any) -> None:
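With the tightened signatures, callers pass the statement and its parameter rows separately rather than through *args. An illustrative call (the table and columns are made up; Synapse converts the ? placeholder for the Postgres driver internally):

def _insert_examples(txn: "LoggingTransaction") -> None:
    # One statement, many VALUES rows; fetch=False since INSERT returns
    # no rows here.
    txn.execute_values(
        "INSERT INTO example_table (col_a, col_b) VALUES ?",
        [(1, "a"), (2, "b"), (3, "c")],
        fetch=False,
    )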
@@ -732,34 +739,45 @@ class DatabasePool:
Returns:
The result of func
"""
- after_callbacks: List[_CallbackListEntry] = []
- exception_callbacks: List[_CallbackListEntry] = []
- if not current_context():
- logger.warning("Starting db txn '%s' from sentinel context", desc)
+ async def _runInteraction() -> R:
+ after_callbacks: List[_CallbackListEntry] = []
+ exception_callbacks: List[_CallbackListEntry] = []
- try:
- with opentracing.start_active_span(f"db.{desc}"):
- result = await self.runWithConnection(
- self.new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- db_autocommit=db_autocommit,
- isolation_level=isolation_level,
- **kwargs,
- )
+ if not current_context():
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except Exception:
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
+ try:
+ with opentracing.start_active_span(f"db.{desc}"):
+ result = await self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ db_autocommit=db_autocommit,
+ isolation_level=isolation_level,
+ **kwargs,
+ )
- return cast(R, result)
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+
+ return cast(R, result)
+ except Exception:
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ # To handle cancellation, we ensure that `after_callback`s and
+ # `exception_callback`s are always run, since the transaction will complete
+ # on another thread regardless of cancellation.
+ #
+ # We also wait until everything above is done before releasing the
+ # `CancelledError`, so that logging contexts won't get used after they have been
+ # finished.
+ return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
async def runWithConnection(
self,
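For context, the helper this leans on: delay_cancellation (from synapse.util.async_helpers) returns a Deferred whose cancellation does not propagate to the wrapped one; the CancelledError is only released once the wrapped work settles, which is what lets the callback bookkeeping above always run. A simplified sketch of the idea, not the verbatim implementation:

from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python.failure import Failure

def delay_cancellation_sketch(deferred: "defer.Deferred") -> "defer.Deferred":
    def handle_cancel(new_deferred: "defer.Deferred") -> None:
        # Pause so the CancelledError is not delivered yet...
        new_deferred.pause()
        new_deferred.errback(Failure(CancelledError()))
        # ...and release it only once the wrapped deferred has settled.
        deferred.addBoth(lambda _: new_deferred.unpause())

    new_deferred: "defer.Deferred" = defer.Deferred(handle_cancel)
    deferred.chainDeferred(new_deferred)
    return new_deferred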
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 52146aacc8..9af9f4f18e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,17 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ cast,
+)
from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
@cached(max_entries=5000, iterable=True)
- async def ignored_by(self, user_id: str) -> Set[str]:
+ async def ignored_by(self, user_id: str) -> FrozenSet[str]:
"""
Get users which ignore the given user.
@@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
Return:
The user IDs which ignore the given user.
"""
- return set(
+ return frozenset(
await self.db_pool.simple_select_onecol(
table="ignored_users",
keyvalues={"ignored_user_id": user_id},
@@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
)
+ @cached(max_entries=5000, iterable=True)
+ async def ignored_users(self, user_id: str) -> FrozenSet[str]:
+ """
+ Get users which the given user ignores.
+
+ Args:
+ user_id: The user ID which is making the request.
+
+ Return:
+ The user IDs which are ignored by the given user.
+ """
+ return frozenset(
+ await self.db_pool.simple_select_onecol(
+ table="ignored_users",
+ keyvalues={"ignorer_user_id": user_id},
+ retcol="ignored_user_id",
+ desc="ignored_users",
+ )
+ )
+
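The new method mirrors ignored_by, so both directions of the relationship are now cached. Illustrative usage (user ID made up):

async def example(store) -> None:
    # Who ignores @alice, and whom @alice ignores; both are frozensets
    # backed by the ignored_users table via the two cached methods.
    ignorers = await store.ignored_by("@alice:example.org")
    ignorees = await store.ignored_users("@alice:example.org")
    assert isinstance(ignorers, frozenset)
    assert isinstance(ignorees, frozenset)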
def process_replication_rows(
self,
stream_name: str,
@@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
else:
currently_ignored_users = set()
+ # If the data has not changed, nothing to do.
+ if previously_ignored_users == currently_ignored_users:
+ return
+
# Delete entries which are no longer ignored.
self.db_pool.simple_delete_many_txn(
txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+ self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
async def purge_account_data_for_user(self, user_id: str) -> None:
"""
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index d6a2df1afe..dd4e83a2ad 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
+ EventsStreamRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
- def get_all_updated_caches_txn(txn):
+ def get_all_updated_caches_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_all_updated_caches", get_all_updated_caches_txn
)
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
- def _process_event_stream_row(self, token, row):
+ def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
if row.type == EventsStreamEventRow.TypeId:
+ assert isinstance(data, EventsStreamEventRow)
self._invalidate_caches_for_event(
token,
data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
+ assert isinstance(data, EventsStreamCurrentStateRow)
+ self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_caches_for_event(
self,
- stream_ordering,
- event_id,
- room_id,
- etype,
- state_key,
- redacts,
- relates_to,
- backfilled,
- ):
+ stream_ordering: int,
+ event_id: str,
+ room_id: str,
+ etype: str,
+ state_key: Optional[str],
+ redacts: Optional[str],
+ relates_to: Optional[str],
+ backfilled: bool,
+ ) -> None:
self._invalidate_get_event_cache(event_id)
self.have_seen_event.invalidate((room_id, event_id))
@@ -186,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
+ # The `_get_membership_from_event_id` cache is immutable, except for the
+ # case where we look up an event *before* persisting it.
+ self._get_membership_from_event_id.invalidate((event_id,))
+
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
@@ -207,7 +217,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))
- async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+ async def invalidate_cache_and_stream(
+ self, cache_name: str, keys: Tuple[Any, ...]
+ ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -227,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
keys,
)
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ def _invalidate_cache_and_stream(
+ self,
+ txn: LoggingTransaction,
+ cache_func: _CachedFunction,
+ keys: Tuple[Any, ...],
+ ) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -238,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn.call_after(cache_func.invalidate, keys)
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
- def _invalidate_all_cache_and_stream(self, txn, cache_func):
+ def _invalidate_all_cache_and_stream(
+ self, txn: LoggingTransaction, cache_func: _CachedFunction
+ ) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
"""
@@ -279,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
def _send_invalidation_to_replication(
- self, txn, cache_name: str, keys: Optional[Iterable[Any]]
- ):
+ self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+ ) -> None:
"""Notifies replication that given cache has been invalidated.
Note that this does *not* invalidate the cache locally.
@@ -315,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
- "invalidation_ts": self.clock.time_msec(),
+ "invalidation_ts": self._clock.time_msec(),
},
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 277e6422eb..634e19e035 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1073,9 +1073,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
/* Get the depth and stream_ordering of the prev_event_id from the events table */
INNER JOIN events
ON prev_event_id = events.event_id
+
+ /* exclude outliers from the results (we don't have the state, so cannot
+ * verify if the requesting server can see them).
+ */
+ WHERE NOT events.outlier
+
/* Look for an edge which matches the given event_id */
- WHERE event_edges.event_id = ?
- AND event_edges.is_state = ?
+ AND event_edges.event_id = ? AND NOT event_edges.is_state
+
/* Because we can have many events at the same depth,
* we want to also tie-break and sort on stream_ordering */
ORDER BY depth DESC, stream_ordering DESC
@@ -1084,7 +1090,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(
connected_prev_event_query,
- (event_id, False, limit),
+ (event_id, limit),
)
return [
BackfillQueueNavigationItem(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f60aef180..d253243125 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1745,6 +1745,13 @@ class PersistEventsStore:
(event.state_key,),
)
+ # The `_get_membership_from_event_id` cache is immutable, except for the
+ # case where we look up an event *before* persisting it.
+ txn.call_after(
+ self.store._get_membership_from_event_id.invalidate,
+ (event.event_id,),
+ )
+
# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 3f6086050b..0aef121d83 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,13 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import TypedDict
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
) -> List[Dict[str, Any]]:
# TODO: Pagination
- keyvalues = {"group_id": group_id}
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
# TODO: Pagination
- def _get_rooms_in_group_txn(txn):
+ def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
sql = """
SELECT room_id, is_public FROM group_rooms
WHERE group_id = ?
@@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
* "order": int, the sort order of rooms in this category
"""
- def _get_rooms_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
+ def _get_rooms_for_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
- async def get_group_categories(self, group_id):
+ async def get_group_categories(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
@@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- async def get_group_category(self, group_id, category_id):
+ async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return category
- async def get_group_roles(self, group_id):
+ async def get_group_roles(self, group_id: str) -> JsonDict:
rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
@@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- async def get_group_role(self, group_id, role_id):
+ async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room",
)
- async def get_users_for_summary_by_role(self, group_id, include_private=False):
+ async def get_users_for_summary_by_role(
+ self, group_id: str, include_private: bool = False
+ ) -> Tuple[List[JsonDict], JsonDict]:
"""Get the users and roles that should be included in a summary request
Returns:
([users], [roles])
"""
- def _get_users_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
+ def _get_users_for_summary_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], JsonDict]:
+ keyvalues: JsonDict = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
@@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
- async def get_users_membership_info_in_group(self, group_id, user_id):
+ async def get_users_membership_info_in_group(
+ self, group_id: str, user_id: str
+ ) -> JsonDict:
"""Get a dict describing the membership of a user in a group.
Example if joined:
@@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
An empty dict if the user is not join/invite/etc
"""
- def _get_users_membership_in_group_txn(txn):
+ def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
row = self.db_pool.simple_select_one_txn(
txn,
table="group_users",
@@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user",
)
- async def get_attestations_need_renewals(self, valid_until_ms):
+ async def get_attestations_need_renewals(
+ self, valid_until_ms: int
+ ) -> List[Dict[str, Any]]:
"""Get all attestations that need to be renewed until givent time"""
- def _get_attestations_need_renewals_txn(txn):
+ def _get_attestations_need_renewals_txn(
+ txn: LoggingTransaction,
+ ) -> List[Dict[str, Any]]:
sql = """
SELECT group_id, user_id FROM group_attestations_renewals
WHERE valid_until_ms <= ?
@@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
- async def get_remote_attestation(self, group_id, user_id):
+ async def get_remote_attestation(
+ self, group_id: str, user_id: str
+ ) -> Optional[JsonDict]:
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
@@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
- async def get_all_groups_for_user(self, user_id, now_token):
- def _get_all_groups_for_user_txn(txn):
+ async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+ def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, type, membership, u.content
FROM local_group_updates AS u
@@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
- async def get_groups_changes_for_user(self, user_id, from_token, to_token):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_entity_changed(
+ async def get_groups_changes_for_user(
+ self, user_id: str, from_token: int, to_token: int
+ ) -> List[JsonDict]:
+ has_changed = self._group_updates_stream_cache.has_entity_changed( # type: ignore[attr-defined]
user_id, from_token
)
if not has_changed:
return []
- def _get_groups_changes_for_user_txn(txn):
+ def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """
SELECT group_id, membership, type, u.content
FROM local_group_updates AS u
@@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
last_id = int(last_id)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) # type: ignore[attr-defined]
if not has_changed:
return [], current_id, False
- def _get_all_groups_changes_txn(txn):
+ def _get_all_groups_changes_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
@@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [
- (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
- for stream_id, group_id, user_id, gtype, content_json in txn
- ]
+ updates = cast(
+ List[Tuple[int, tuple]],
+ [
+ (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
+ for stream_id, group_id, user_id, gtype, content_json in txn
+ ],
+ )
limited = False
upto_token = current_id
@@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
room_id: str,
- category_id: str,
- order: int,
+ category_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_room_to_summary_txn(
self,
- txn,
+ txn: LoggingTransaction,
group_id: str,
room_id: str,
- category_id: str,
- order: int,
+ category_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
@@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- (order,) = txn.fetchone()
+ (order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
"category_id": category_id,
"room_id": room_id,
},
- values=to_update,
+ updatevalues=to_update,
)
else:
if is_public is None:
@@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_summary(
- self, group_id: str, room_id: str, category_id: str
+ self, group_id: str, room_id: str, category_id: Optional[str]
) -> int:
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
@@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/update room category for group"""
- insertion_values = {}
- update_values = {"category_id": category_id} # This cannot be empty
+ insertion_values: JsonDict = {}
+ update_values: JsonDict = {"category_id": category_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
is_public: Optional[bool],
) -> None:
"""Add/remove user role"""
- insertion_values = {}
- update_values = {"role_id": role_id} # This cannot be empty
+ insertion_values: JsonDict = {}
+ update_values: JsonDict = {"role_id": role_id} # This cannot be empty
if profile is None:
insertion_values["profile"] = "{}"
@@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
self,
group_id: str,
user_id: str,
- role_id: str,
- order: int,
+ role_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
@@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
def _add_user_to_summary_txn(
self,
- txn,
+ txn: LoggingTransaction,
group_id: str,
user_id: str,
- role_id: str,
- order: int,
+ role_id: Optional[str],
+ order: Optional[int],
is_public: Optional[bool],
- ):
+ ) -> None:
"""Add (or update) user's entry in summary.
Args:
@@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- (order,) = txn.fetchone()
+ (order,) = cast(Tuple[int], txn.fetchone())
if existing:
to_update = {}
@@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
"role_id": role_id,
"user_id": user_id,
},
- values=to_update,
+ updatevalues=to_update,
)
else:
if is_public is None:
@@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_user_from_summary(
- self, group_id: str, user_id: str, role_id: str
+ self, group_id: str, user_id: str, role_id: Optional[str]
) -> int:
if role_id is None:
role_id = _DEFAULT_ROLE_ID
@@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
Optional if the user and group are on the same server
"""
- def _add_user_to_group_txn(txn):
+ def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_insert_txn(
txn,
table="group_users",
@@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
- def _remove_user_from_group_txn(txn):
+ def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_users",
@@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
- def _remove_room_from_group_txn(txn):
+ def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn(
txn,
table="group_rooms",
@@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
content = content or {}
- def _register_user_group_membership_txn(txn, next_id):
+ def _register_user_group_membership_txn(
+ txn: LoggingTransaction, next_id: int
+ ) -> int:
# TODO: Upsert?
self.db_pool.simple_delete_txn(
txn,
@@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
),
},
)
- self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+ self._group_updates_stream_cache.entity_has_changed(user_id, next_id) # type: ignore[attr-defined]
# TODO: Insert profile to ensure it comes down stream if its a join.
@@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- async with self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id: # type: ignore[attr-defined]
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
return res
async def create_group(
- self, group_id, user_id, name, avatar_url, short_description, long_description
+ self,
+ group_id: str,
+ user_id: str,
+ name: str,
+ avatar_url: str,
+ short_description: str,
+ long_description: str,
) -> None:
await self.db_pool.simple_insert(
table="groups",
@@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group",
)
- async def update_group_profile(self, group_id, profile):
+ async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
@@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_attestation_renewal",
)
- def get_group_stream_token(self):
- return self._group_updates_id_gen.get_current_token()
+ def get_group_stream_token(self) -> int:
+ return self._group_updates_id_gen.get_current_token() # type: ignore[attr-defined]
async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
@@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id: The group ID to delete.
"""
- def _delete_group_txn(txn):
+ def _delete_group_txn(txn: LoggingTransaction) -> None:
tables = [
"groups",
"group_users",
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index cbba356b4a..322ed05390 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
- self.server_name = hs.hostname
+ self.server_name: str = hs.hostname
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e9a0cdc6be..216622964a 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,15 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
+ LoggingTransaction,
make_in_list_sql_clause,
)
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from synapse.util.caches.descriptors import cached
from synapse.util.threepids import canonicalise_email
@@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
Number of current monthly active users
"""
- def _count_users(txn):
+ def _count_users(txn: LoggingTransaction) -> int:
# Exclude app service users
sql = """
SELECT COUNT(*)
@@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
"""
txn.execute(sql)
- (count,) = txn.fetchone()
+ (count,) = cast(Tuple[int], txn.fetchone())
return count
return await self.db_pool.runInteraction("count_users", _count_users)
@@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
- def _count_users_by_service(txn):
+ def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
sql = """
SELECT COALESCE(appservice_id, 'native'), COUNT(*)
FROM monthly_active_users
@@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
txn.execute(sql)
- result = txn.fetchall()
+ result = cast(List[Tuple[str, int]], txn.fetchall())
return dict(result)
return await self.db_pool.runInteraction(
@@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
@wrap_as_background_process("reap_monthly_active_users")
- async def reap_monthly_active_users(self):
+ async def reap_monthly_active_users(self) -> None:
"""Cleans out monthly active user table to ensure that no stale
entries exist.
"""
- def _reap_users(txn, reserved_users):
+ def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
"""
Args:
reserved_users (tuple): reserved users to preserve
@@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
# is racy.
# Have resolved to invalidate the whole cache for now and do
# something about it if and when the perf becomes significant
- self._invalidate_all_cache_and_stream(
+ self._invalidate_all_cache_and_stream( # type: ignore[attr-defined]
txn, self.user_last_seen_monthly_active
)
- self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) # type: ignore[attr-defined]
reserved_users = await self.get_registered_reserved_users()
await self.db_pool.runInteraction(
@@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
)
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
- def _initialise_reserved_users(self, txn, threepids):
+ def _initialise_reserved_users(
+ self, txn: LoggingTransaction, threepids: List[dict]
+ ) -> None:
"""Ensures that reserved threepids are accounted for in the MAU table, should
be called on start up.
Args:
- txn (cursor):
- threepids (list[dict]): List of threepid dicts to reserve
+ txn:
+ threepids: List of threepid dicts to reserve
"""
# XXX what is this function trying to achieve? It upserts into
@@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
- def upsert_monthly_active_user_txn(self, txn, user_id):
+ def upsert_monthly_active_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> None:
"""Updates or inserts monthly active user member
We consciously do not call is_support_txn from this method because it
@@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
txn, self.user_last_seen_monthly_active, (user_id,)
)
- async def populate_monthly_active_users(self, user_id):
+ async def populate_monthly_active_users(self, user_id: str) -> None:
"""Checks on the state of monthly active user limits and optionally
adds the user to the monthly active tables
@@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = await self.is_guest(user_id)
+ is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
if is_guest:
return
is_trial = await self.is_trial_user(user_id)
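
The `cast` wrapped around `txn.fetchall()` above is a mypy-only device: DB-API cursors are typed as returning rows of `Any`, so the cast records the row shape without changing anything at runtime. A self-contained sketch of the same pattern against sqlite3 (table and data invented for illustration):

import sqlite3
from typing import Dict, List, Tuple, cast

def count_users_by_service(conn: sqlite3.Connection) -> Dict[str, int]:
    cur = conn.execute(
        "SELECT COALESCE(appservice_id, 'native'), COUNT(*)"
        " FROM monthly_active_users"
        " GROUP BY COALESCE(appservice_id, 'native')"
    )
    # fetchall() is typed as returning rows of Any; the cast tells mypy
    # the row shape and has no runtime effect.
    rows = cast(List[Tuple[str, int]], cur.fetchall())
    return dict(rows)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE monthly_active_users (user_id TEXT, appservice_id TEXT)")
conn.execute("INSERT INTO monthly_active_users VALUES ('@a:hs', NULL)")
print(count_users_by_service(conn))  # {'native': 1}
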
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index bf0b903af2..e6f97aeece 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -24,10 +24,9 @@ from typing import (
Optional,
Set,
Tuple,
+ cast,
)
-from twisted.internet import defer
-
from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
@@ -38,7 +37,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdTracker,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
self._instance_name = hs.get_instance_name()
+ self._receipts_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
@@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
" AND user_id = ?"
)
txn.execute(sql, (user_id,))
- return txn.fetchall()
+ return cast(List[Tuple[str, str, int, int]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f
@@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if not rows:
return []
- content = {}
+ content: JsonDict = {}
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
@@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"_get_linearized_receipts_for_rooms", f
)
- results = {}
+ results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"get_linearized_receipts_for_all_rooms", f
)
- results = {}
+ results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
@@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
if last_id == current_id:
- return defer.succeed([])
+ return []
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
sql = """
@@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
- updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
+ updates = cast(
+ List[Tuple[int, list]],
+ [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+ )
limited = False
upper_bound = current_id
@@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
- def process_replication_rows(self, stream_name, instance_name, token, rows):
+ def process_replication_rows(
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ rows: Iterable[Any],
+ ) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
@@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
- self._remove_old_push_actions_before_txn(
+ self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"insert_receipt_conv", graph_to_linear
)
- async with self._receipts_id_gen.get_next() as stream_id:
+ async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
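
The bare annotation `self._receipts_id_gen: AbstractStreamIdTracker` (with no assignment) exists so that mypy gives the attribute a single type even though the Postgres and SQLite branches assign different concrete generators. A minimal sketch of that idiom, with invented class names standing in for Synapse's:

from abc import ABC, abstractmethod

class IdTracker(ABC):
    @abstractmethod
    def advance(self, instance_name: str, new_id: int) -> None:
        ...

class MultiWriterTracker(IdTracker):
    def advance(self, instance_name: str, new_id: int) -> None:
        print(f"multi-writer: {instance_name} -> {new_id}")

class SingleWriterTracker(IdTracker):
    def advance(self, instance_name: str, new_id: int) -> None:
        print(f"single-writer -> {new_id}")

class Store:
    def __init__(self, postgres: bool) -> None:
        # Annotate before branching so mypy types the attribute as the
        # shared base class rather than the first concrete assignment.
        self._id_gen: IdTracker
        if postgres:
            self._id_gen = MultiWriterTracker()
        else:
            self._id_gen = SingleWriterTracker()

Store(postgres=True)._id_gen.advance("worker1", 7)
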
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a698d10cc5..7f3d190e94 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -22,6 +22,7 @@ import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
DatabasePool,
@@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
):
super().__init__(database, db_conn, hs)
- self.config = hs.config
+ self.config: HomeServerConfig = hs.config
# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c4869d64e6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
)
import attr
-from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
- # The latest event in the thread.
- latest_event: EventBase
- # The latest edit to the latest event in the thread.
- latest_edit: Optional[EventBase]
- # The total number of events in the thread.
- count: int
- # True if the current user has sent an event to the thread.
- current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
- """
- The bundled aggregations for an event.
-
- Some values require additional processing during serialization.
- """
-
- annotations: Optional[JsonDict] = None
- references: Optional[JsonDict] = None
- replace: Optional[EventBase] = None
- thread: Optional[_ThreadAggregation] = None
-
- def __bool__(self) -> bool:
- return bool(self.annotations or self.references or self.replace or self.thread)
-
-
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
- async def _get_applicable_edits(
+ async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
@@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
- async def _get_thread_summaries(
+ async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
- latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+ latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
- async def _get_threads_participated(
+ async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
@@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
- async def _get_bundled_aggregation_for_event(
- self, event: EventBase, user_id: str
- ) -> Optional[BundledAggregations]:
- """Generate bundled aggregations for an event.
-
- Note that this does not use a cache, but depends on cached methods.
-
- Args:
- event: The event to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- The bundled aggregations for an event, if bundled aggregations are
- enabled and the event can have bundled aggregations.
- """
-
- # Do not bundle aggregations for an event which represents an edit or an
- # annotation. It does not make sense for them to have related events.
- relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
- relation_type = relates_to.get("rel_type")
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
- return None
-
- event_id = event.event_id
- room_id = event.room_id
-
- # The bundled aggregations to include, a mapping of relation type to a
- # type-specific value. Some types include the direct return type here
- # while others need more processing during serialization.
- aggregations = BundledAggregations()
-
- annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
- if annotations.chunk:
- aggregations.annotations = await annotations.to_dict(
- cast("DataStore", self)
- )
-
- references = await self.get_relations_for_event(
- event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
- )
- if references.chunk:
- aggregations.references = await references.to_dict(cast("DataStore", self))
-
- # Store the bundled aggregations in the event metadata for later use.
- return aggregations
-
- async def get_bundled_aggregations(
- self, events: Iterable[EventBase], user_id: str
- ) -> Dict[str, BundledAggregations]:
- """Generate bundled aggregations for events.
-
- Args:
- events: The iterable of events to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- A map of event ID to the bundled aggregation for the event. Not all
- events may have bundled aggregations in the results.
- """
- # De-duplicate events by ID to handle the same event requested multiple times.
- #
- # State events do not get bundled aggregations.
- events_by_id = {
- event.event_id: event for event in events if not event.is_state()
- }
-
- # event ID -> bundled aggregation in non-serialized form.
- results: Dict[str, BundledAggregations] = {}
-
- # Fetch other relations per event.
- for event in events_by_id.values():
- event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result:
- results[event.event_id] = event_result
-
- # Fetch any edits (but not for redacted events).
- edits = await self._get_applicable_edits(
- [
- event_id
- for event_id, event in events_by_id.items()
- if not event.internal_metadata.is_redacted()
- ]
- )
- for event_id, edit in edits.items():
- results.setdefault(event_id, BundledAggregations()).replace = edit
-
- # Fetch thread summaries.
- summaries = await self._get_thread_summaries(events_by_id.keys())
- # Only fetch participated for a limited selection based on what had
- # summaries.
- participated = await self._get_threads_participated(summaries.keys(), user_id)
- for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event, edit = summary
- results.setdefault(
- event_id, BundledAggregations()
- ).thread = _ThreadAggregation(
- latest_event=latest_thread_event,
- latest_edit=edit,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=participated[event_id],
- )
-
- return results
-
class RelationsStore(RelationsWorkerStore):
pass
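
The deleted `BundledAggregations` machinery is orchestration rather than storage, and the simultaneous renames (`_get_applicable_edits` to `get_applicable_edits`, and so on) make the batched lookups callable from outside the store; the diff alone doesn't show where the orchestration lands, presumably a handler. Its core shape, de-duplicate, batch-fetch, then fold into per-event results, is sketched below with stand-in types:

from typing import Dict, Iterable, List, Optional

class Event:
    def __init__(self, event_id: str, redacted: bool = False) -> None:
        self.event_id = event_id
        self.redacted = redacted

def get_applicable_edits(event_ids: List[str]) -> Dict[str, Optional[str]]:
    # Stand-in for the batched store method made public above.
    return {eid: "$edit-of-" + eid for eid in event_ids}

def bundle_edits(events: Iterable[Event]) -> Dict[str, str]:
    # De-duplicate by event ID so repeated events are processed once.
    events_by_id = {e.event_id: e for e in events}
    # One batched lookup for all events that can have edits; redacted
    # events are skipped, as in the removed store code.
    edits = get_applicable_edits(
        [eid for eid, e in events_by_id.items() if not e.redacted]
    )
    return {eid: edit for eid, edit in edits.items() if edit is not None}

print(bundle_edits([Event("$1"), Event("$1"), Event("$2", redacted=True)]))
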
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 94068940b9..18b1acd9e1 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -34,6 +34,7 @@ import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
):
super().__init__(database, db_conn, hs)
- self.config = hs.config
+ self.config: HomeServerConfig = hs.config
async def store_room(
self,
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bef675b845..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+ """Returned by `get_membership_from_event_ids`"""
+
+ user_id: str
+ membership: str
+
+
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
- desc="_get_membership_from_event_ids",
+ desc="_get_joined_profiles_from_event_ids",
)
return {
@@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
+ @cached(max_entries=5000)
+ async def _get_membership_from_event_id(
+ self, member_event_id: str
+ ) -> Optional[EventIdMembership]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+ )
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
- ) -> List[dict]:
- """Get user_id and membership of a set of event IDs."""
+ ) -> Dict[str, Optional[EventIdMembership]]:
+ """Get user_id and membership of a set of event IDs.
+
+ Returns:
+ Mapping from event ID to `EventIdMembership`; the value is None
+ if the event is not a membership event.
+ """
- return await self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_membership_from_event_ids",
)
+ return {
+ row["event_id"]: EventIdMembership(
+ membership=row["membership"], user_id=row["user_id"]
+ )
+ for row in rows
+ }
+
async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:
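
Note that `_get_membership_from_event_id` is deliberately a stub that raises `NotImplementedError`: with Synapse's caching descriptors, the `@cachedList` method names the per-item `@cached` method via `cached_method_name`, so batch results populate (and are served from) the per-item cache and the stub body never runs. A toy imitation of that behaviour, without the descriptors:

from typing import Dict, Iterable, Optional

class MembershipCache:
    """Toy stand-in for the @cached/@cachedList pairing."""

    def __init__(self) -> None:
        self._cache: Dict[str, Optional[str]] = {}

    def get_many(self, event_ids: Iterable[str]) -> Dict[str, Optional[str]]:
        wanted = list(event_ids)
        missing = [eid for eid in wanted if eid not in self._cache]
        if missing:
            # One batched query for all cache misses; hits are served
            # straight from the cache, mirroring cachedList's behaviour.
            for eid, membership in self._query_db(missing).items():
                self._cache[eid] = membership
        return {eid: self._cache[eid] for eid in wanted}

    def _query_db(self, event_ids: Iterable[str]) -> Dict[str, Optional[str]]:
        # Stand-in for simple_select_many_batch; events that are not
        # membership events map to None.
        return {eid: None for eid in event_ids}

cache = MembershipCache()
print(cache.get_many(["$a", "$b"]))
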
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index bb41beb827..79abe758e6 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
import logging
import re
-from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
import attr
@@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore):
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
- args = (
+ args1 = (
(
entry.event_id,
entry.room_id,
@@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.execute_batch(sql, args)
+ txn.execute_batch(sql, args1)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
- args = (
+ args2 = (
(
entry.event_id,
entry.room_id,
@@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore):
)
for entry in entries
)
- txn.execute_batch(sql, args)
+ txn.execute_batch(sql, args2)
else:
# This should be unreachable.
@@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term)
- args = []
+ args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
- events = await self.get_events_as_list(
+ events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore):
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
- limit,
+ limit: int,
pagination_token: Optional[str] = None,
) -> JsonDict:
"""Performs a full text search over events with given keys.
@@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore):
search_query = _parse_query(self.database_engine, search_term)
- args = []
+ args: List[Any] = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
@@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore):
if pagination_token:
try:
- origin_server_ts, stream = pagination_token.split(",")
- origin_server_ts = int(origin_server_ts)
- stream = int(stream)
+ origin_server_ts_str, stream_str = pagination_token.split(",")
+ origin_server_ts = int(origin_server_ts_str)
+ stream = int(stream_str)
except Exception:
raise SynapseError(400, "Invalid pagination token")
@@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
- events = await self.get_events_as_list(
+ events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
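
The `args1`/`args2` split is for mypy rather than for readers: once `args` is inferred as a generator of one tuple shape, re-binding it to a generator of a different shape is an assignment error. A minimal reproduction (field names invented):

entries = [("e1", "r1", "k", "v")]

# First generator: 4-tuples.
args1 = ((eid, rid, key, val) for eid, rid, key, val in entries)
# Second generator: 2-tuples. Reusing the name args1 here would make
# mypy report an incompatible assignment, since it has already inferred
# a generator of 4-tuples; a fresh name keeps both well-typed.
args2 = ((eid, val) for eid, rid, key, val in entries)

print(list(args1))
print(list(args2))
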
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 417aef1dbc..28460fd364 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +29,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# We delegate to the cached version
return await self.get_current_state_ids(room_id)
- def _get_filtered_current_state_ids_txn(txn):
+ def _get_filtered_current_state_ids_txn(
+ txn: LoggingTransaction,
+ ) -> StateMap[str]:
results = {}
sql = """
SELECT type, state_key, event_id FROM current_state_events
@@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
- return
+ return None
event = await self.get_event(event_id, allow_none=True)
if not event:
- return
+ return None
return event.content.get("canonical_alias")
@@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
- async def _get_state_group_for_events(self, event_ids):
+ async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
@@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
):
super().__init__(database, db_conn, hs)
- self.server_name = hs.hostname
+ self.server_name: str = hs.hostname
self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
self._background_remove_left_rooms,
)
- async def _background_remove_left_rooms(self, progress, batch_size):
+ async def _background_remove_left_rooms(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
"""Background update to delete rows from `current_state_events` and
`event_forward_extremities` tables of rooms that the server is no
longer joined to.
@@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
- def _background_remove_left_rooms_txn(txn):
+ def _background_remove_left_rooms_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[bool, Set[str]]:
# get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 427ae1f649..b95dbef678 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore):
):
super().__init__(database, db_conn, hs)
- self.server_name = hs.hostname
+ self.server_name: str = hs.hostname
self.clock = self.hs.get_clock()
self.stats_enabled = hs.config.stats.stats_enabled
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e7fddd2426..df772d4721 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
cast,
)
+from typing_extensions import TypedDict
+
from synapse.api.errors import StoreError
if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
from synapse.storage.databases.main.state import StateFilter
from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
+from synapse.types import (
+ JsonDict,
+ UserProfile,
+ get_domain_from_id,
+ get_localpart_from_id,
+)
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -61,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
) -> None:
super().__init__(database, db_conn, hs)
- self.server_name = hs.hostname
+ self.server_name: str = hs.hostname
self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables",
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
+class SearchResult(TypedDict):
+ limited: bool
+ results: List[UserProfile]
+
+
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
@@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- async def get_shared_rooms_for_users(
+ async def get_mutual_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
"""
@@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
A set of room ID's that the users share.
"""
- def _get_shared_rooms_for_users_txn(
+ def _get_mutual_rooms_for_users_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
txn.execute(
@@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return rows
rows = await self.db_pool.runInteraction(
- "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+ "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
)
return {row["room_id"] for row in rows}
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
async def search_user_dir(
self, user_id: str, search_term: str, limit: int
- ) -> JsonDict:
+ ) -> SearchResult:
"""Searches for users in directory
Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = await self.db_pool.execute(
- "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+ results = cast(
+ List[UserProfile],
+ await self.db_pool.execute(
+ "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+ ),
)
limited = len(results) > limit
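
Typing `search_user_dir` as `SearchResult` instead of `JsonDict` gives callers key-level checking. A self-contained sketch of how the two new TypedDicts compose (the stub below mirrors the definitions above rather than importing them):

from typing import List, Optional
from typing_extensions import TypedDict

class UserProfile(TypedDict):
    user_id: str
    display_name: Optional[str]
    avatar_url: Optional[str]

class SearchResult(TypedDict):
    limited: bool
    results: List[UserProfile]

def search_user_dir_stub(search_term: str, limit: int) -> SearchResult:
    # Stand-in rows; the real method runs a SQL query and casts the rows.
    rows: List[UserProfile] = [
        {"user_id": "@alice:example.org", "display_name": "Alice", "avatar_url": None}
    ]
    # Fetch up to limit + 1 rows so `limited` can be derived, then truncate.
    return {"limited": len(rows) > limit, "results": rows[:limit]}

result = search_user_dir_stub("ali", limit=10)
# mypy now knows result["results"][0]["user_id"] is a str; a typo such as
# result["limted"] becomes a type error rather than a silent KeyError.
print(result["results"][0]["user_id"])
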
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9abc02046e..afb7d5054d 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine:
if name == "psycopg2":
# Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
- import psycopg2 # type: ignore
+ import psycopg2
return PostgresEngine(psycopg2, database_config)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 808342fafb..e8d29e2870 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
self.default_isolation_level = (
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+ self.config = database_config
@property
def single_threaded(self) -> bool:
return False
+ def get_db_locale(self, txn):
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+ return collation, ctype
+
def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
+ allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 100000:
@@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine):
"See docs/postgres.md for more information." % (rows[0][0],)
)
- txn.execute(
- "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
- )
- collation, ctype = txn.fetchone()
+ collation, ctype = self.get_db_locale(txn)
if collation != "C":
logger.warning(
- "Database has incorrect collation of %r. Should be 'C'\n"
- "See docs/postgres.md for more information.",
+ "Database has incorrect collation of %r. Should be 'C'",
collation,
)
+ if not allow_unsafe_locale:
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect collation of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information. You can override this check by "
+ "setting 'allow_unsafe_locale' to true in the database config."
+ % (collation,)
+ )
if ctype != "C":
- logger.warning(
- "Database has incorrect ctype of %r. Should be 'C'\n"
- "See docs/postgres.md for more information.",
- ctype,
- )
+ logger.warning(
+ "Database has incorrect ctype of %r. Should be 'C'",
+ ctype,
+ )
+ if not allow_unsafe_locale:
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect ctype of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information. You can override this check by "
+ "setting 'allow_unsafe_locale' to true in the database config."
+ % (ctype,)
+ )
def check_new_database(self, txn):
"""Gets called when setting up a brand new database. This allows us to
apply stricter checks on new databases versus existing database.
"""
- txn.execute(
- "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
- )
- collation, ctype = txn.fetchone()
+ collation, ctype = self.get_db_locale(txn)
errors = []
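
Extracting `get_db_locale` lets the new-database and existing-database paths share one query. A condensed sketch of the resulting control flow, with the transaction left untyped and the names simplified:

from typing import Any, Tuple

class IncorrectDatabaseSetup(Exception):
    pass

def get_db_locale(txn: Any) -> Tuple[str, str]:
    txn.execute(
        "SELECT datcollate, datctype FROM pg_database"
        " WHERE datname = current_database()"
    )
    collation, ctype = txn.fetchone()
    return collation, ctype

def check_locale(txn: Any, allow_unsafe_locale: bool) -> None:
    collation, ctype = get_db_locale(txn)
    for name, value in (("collation", collation), ("ctype", ctype)):
        if value != "C":
            # Warn either way; only fatal when the override is not set.
            print(f"warning: database has incorrect {name} of {value!r}")
            if not allow_unsafe_locale:
                raise IncorrectDatabaseSetup(
                    f"Database has incorrect {name} of {value!r}. Should be 'C'"
                )
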
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 7d543fdbe0..b402922817 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
# Check if any of the changes that we don't have events for are joins.
if events_to_check:
- rows = await self.main_store.get_membership_from_event_ids(events_to_check)
- is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+ members = await self.main_store.get_membership_from_event_ids(
+ events_to_check
+ )
+ is_still_joined = any(
+ member and member.membership == Membership.JOIN
+ for member in members.values()
+ )
if is_still_joined:
return True
@@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
- rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+ members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update(
- row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+ member.user_id
+ for member in members.values()
+ if member and member.membership == Membership.JOIN
)
return False
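
The persist_events callers change shape to match: values in the returned mapping may be None for non-membership events, so each check guards on the value before touching its attributes. In miniature:

from typing import Dict, Optional

class EventIdMembership:
    def __init__(self, user_id: str, membership: str) -> None:
        self.user_id = user_id
        self.membership = membership

JOIN = "join"

members: Dict[str, Optional[EventIdMembership]] = {
    "$a": EventIdMembership("@u1:hs", JOIN),
    "$b": None,  # not a membership event
}

# Guard on the value before reading attributes, since it may be None.
is_still_joined = any(
    member is not None and member.membership == JOIN
    for member in members.values()
)
print(is_still_joined)  # True
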
diff --git a/synapse/types.py b/synapse/types.py
index 53be3583a0..5ce2a5b0a5 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import (
import attr
from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes
+from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@@ -63,6 +64,10 @@ MutableStateMap = MutableMapping[StateKey, T]
# JSON types. These could be made stronger, but will do for now.
# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
+# A JSON-serialisable mapping; roughly speaking an immutable JsonDict.
+# Useful when you have a TypedDict which isn't going to be mutated and you don't want
+# to cast to JsonDict everywhere.
+JsonMapping = Mapping[str, Any]
# A JSON-serialisable object.
JsonSerializable = object
@@ -791,3 +796,9 @@ class UserInfo:
is_deactivated: bool
is_guest: bool
is_shadow_banned: bool
+
+
+class UserProfile(TypedDict):
+ user_id: str
+ display_name: Optional[str]
+ avatar_url: Optional[str]
diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 12cd804939..66f1da7502 100644
--- a/synapse/util/check_dependencies.py
+++ b/synapse/util/check_dependencies.py
@@ -128,6 +128,19 @@ def _incorrect_version(
)
+def _no_reported_version(requirement: Requirement, extra: Optional[str] = None) -> str:
+ if extra:
+ return (
+ f"Synapse {VERSION} needs {requirement} for {extra}, "
+ f"but can't determine {requirement.name}'s version"
+ )
+ else:
+ return (
+ f"Synapse {VERSION} needs {requirement}, "
+ f"but can't determine {requirement.name}'s version"
+ )
+
+
def check_requirements(extra: Optional[str] = None) -> None:
"""Check Synapse's dependencies are present and correctly versioned.
@@ -163,8 +176,17 @@ def check_requirements(extra: Optional[str] = None) -> None:
deps_unfulfilled.append(requirement.name)
errors.append(_not_installed(requirement, extra))
else:
+ if dist.version is None:
+ # This shouldn't happen---it suggests a borked virtualenv. (See #12223)
+ # Try to give a vaguely helpful error message anyway.
+ # Type-ignore: the annotations don't reflect reality: see
+ # https://github.com/python/typeshed/issues/7513
+ # https://bugs.python.org/issue47060
+ deps_unfulfilled.append(requirement.name) # type: ignore[unreachable]
+ errors.append(_no_reported_version(requirement, extra))
+
# We specify prereleases=True to allow prereleases such as RCs.
- if not requirement.specifier.contains(dist.version, prereleases=True):
+ elif not requirement.specifier.contains(dist.version, prereleases=True):
deps_unfulfilled.append(requirement.name)
errors.append(_incorrect_version(requirement, dist.version, extra))
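
The new `dist.version is None` branch defends against installs whose metadata lacks a version, which the typeshed stubs claim is impossible but broken virtualenvs produce. A sketch of the same defensive check using importlib.metadata directly:

from importlib import metadata

def describe(package: str) -> str:
    try:
        dist = metadata.distribution(package)
    except metadata.PackageNotFoundError:
        return f"{package} is not installed"
    # The stubs type `version` as str, but a damaged install can yield
    # None here (see python/typeshed#7513); handle it before comparing.
    if dist.version is None:
        return f"{package} is installed but reports no version"
    return f"{package} {dist.version}"

print(describe("setuptools"))
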
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 281cbe4d88..49519eb8f5 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -14,12 +14,7 @@
import logging
from typing import Dict, FrozenSet, List, Optional
-from synapse.api.constants import (
- AccountDataTypes,
- EventTypes,
- HistoryVisibility,
- Membership,
-)
+from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.utils import prune_event
from synapse.storage import Storage
@@ -87,15 +82,8 @@ async def filter_events_for_client(
state_filter=StateFilter.from_types(types),
)
- ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
- user_id, AccountDataTypes.IGNORED_USER_LIST
- )
-
- ignore_list: FrozenSet[str] = frozenset()
- if ignore_dict_content:
- ignored_users_dict = ignore_dict_content.get("ignored_users", {})
- if isinstance(ignored_users_dict, dict):
- ignore_list = frozenset(ignored_users_dict.keys())
+ # Get the users who are ignored by the requesting user.
+ ignore_list = await storage.main.ignored_users(user_id)
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
|