-rw-r--r--  UPGRADE.rst | 13
-rw-r--r--  changelog.d/7785.feature | 1
-rw-r--r--  changelog.d/8170.feature | 1
-rw-r--r--  changelog.d/8213.misc | 1
-rw-r--r--  changelog.d/8225.misc | 1
-rw-r--r--  changelog.d/8226.bugfix | 1
-rw-r--r--  docs/workers.md | 1
-rw-r--r--  synapse/app/admin_cmd.py | 20
-rw-r--r--  synapse/app/homeserver.py | 18
-rw-r--r--  synapse/config/_base.py | 21
-rw-r--r--  synapse/config/_base.pyi | 1
-rw-r--r--  synapse/config/workers.py | 37
-rw-r--r--  synapse/handlers/federation.py | 44
-rw-r--r--  synapse/handlers/message.py | 14
-rw-r--r--  synapse/handlers/room.py | 14
-rw-r--r--  synapse/handlers/room_member.py | 7
-rw-r--r--  synapse/replication/http/federation.py | 12
-rw-r--r--  synapse/replication/tcp/handler.py | 2
-rw-r--r--  synapse/replication/tcp/streams/events.py | 4
-rw-r--r--  synapse/rest/__init__.py | 4
-rw-r--r--  synapse/rest/client/v2_alpha/shared_rooms.py | 68
-rw-r--r--  synapse/rest/client/versions.py | 2
-rw-r--r--  synapse/storage/databases/__init__.py | 2
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 17
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 2
-rw-r--r--  synapse/storage/databases/main/events.py | 4
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 66
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql | 16
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres | 26
-rw-r--r--  synapse/storage/databases/main/stats.py | 34
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 44
-rw-r--r--  synapse/storage/util/id_generators.py | 10
-rw-r--r--  tests/handlers/test_register.py | 6
-rw-r--r--  tests/rest/client/v2_alpha/test_shared_rooms.py | 138
34 files changed, 531 insertions, 121 deletions
diff --git a/UPGRADE.rst b/UPGRADE.rst

index b2069a0d26..188171d7ab 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -1,3 +1,16 @@
+Upgrading to v1.20.0
+====================
+
+Shared rooms endpoint (MSC2666)
+-------------------------------
+
+This release contains a new unstable endpoint `/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*`
+for fetching rooms one user has in common with another. This feature requires the
+`update_user_directory` config flag to be `True`. If you are using a `synapse.app.user_dir`
+worker, requests to this endpoint must be handled by that worker.
+See `docs/workers.md <docs/workers.md>`_ for more details.
+
+
 Upgrading Synapse
 =================

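For illustration, a minimal client-side sketch of calling the new endpoint once this release is deployed. The homeserver URL, access token and user ID below are placeholders, not values taken from this change:

    import requests

    HOMESERVER = "https://matrix.example.org"   # placeholder
    ACCESS_TOKEN = "syt_placeholder_token"      # placeholder

    def get_shared_rooms(other_user_id: str) -> list:
        # Query the unstable MSC2666 endpoint for rooms shared with another user.
        url = (
            f"{HOMESERVER}/_matrix/client/unstable/"
            f"uk.half-shot.msc2666/user/shared_rooms/{other_user_id}"
        )
        resp = requests.get(url, headers={"Authorization": f"Bearer {ACCESS_TOKEN}"})
        resp.raise_for_status()
        return resp.json()["joined"]

    print(get_shared_rooms("@friend:example.org"))

The response body is a JSON object with a single "joined" key listing the shared room IDs, as implemented by the new servlet further down in this diff.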
diff --git a/changelog.d/7785.feature b/changelog.d/7785.feature
new file mode 100644
index 0000000000..c7e51c9320
--- /dev/null
+++ b/changelog.d/7785.feature
@@ -0,0 +1 @@
+Add an endpoint to query your shared rooms with another user as an implementation of [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666).
diff --git a/changelog.d/8170.feature b/changelog.d/8170.feature
new file mode 100644
index 0000000000..b363e929ea
--- /dev/null
+++ b/changelog.d/8170.feature
@@ -0,0 +1 @@
+Add experimental support for sharding event persister.
diff --git a/changelog.d/8213.misc b/changelog.d/8213.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8213.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8225.misc b/changelog.d/8225.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8225.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8226.bugfix b/changelog.d/8226.bugfix
new file mode 100644
index 0000000000..60bdff576d
--- /dev/null
+++ b/changelog.d/8226.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where stats updates could break when unexpected profile data was included in events.
diff --git a/docs/workers.md b/docs/workers.md
index bfec745897..7a8f5c89fc 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -380,6 +380,7 @@ Handles searches in the user directory. It can handle REST endpoints matching
 the following regular expressions:
 
     ^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
+    ^/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*$
 
 When using this worker you must also set `update_user_directory: False` in the
 shared configuration file to stop the main synapse running background
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index a37818fe9a..b6c9085670 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py
@@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer): pass -@defer.inlineCallbacks -def export_data_command(hs, args): +async def export_data_command(hs, args): """Export data for a user. Args: @@ -91,10 +90,8 @@ def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = yield defer.ensureDeferred( - hs.get_handlers().admin_handler.export_user_data( - user_id, FileExfiltrationWriter(user_id, directory=directory) - ) + res = await hs.get_handlers().admin_handler.export_user_data( + user_id, FileExfiltrationWriter(user_id, directory=directory) ) print(res) @@ -232,14 +229,15 @@ def start(config_options): # We also make sure that `_base.start` gets run before we actually run the # command. - @defer.inlineCallbacks - def run(_reactor): + async def run(): with LoggingContext("command"): - yield _base.start(ss, []) - yield args.func(ss, args) + _base.start(ss, []) + await args.func(ss, args) _base.start_worker_reactor( - "synapse-admin-cmd", config, run_command=lambda: task.react(run) + "synapse-admin-cmd", + config, + run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())), ) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 98d0d14a12..6014adc850 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py
@@ -411,26 +411,24 @@ def setup(config_options): return provision - @defer.inlineCallbacks - def reprovision_acme(): + async def reprovision_acme(): """ Provision a certificate from ACME, if required, and reload the TLS certificate if it's renewed. """ - reprovisioned = yield defer.ensureDeferred(do_acme()) + reprovisioned = await do_acme() if reprovisioned: _base.refresh_certificate(hs) - @defer.inlineCallbacks - def start(): + async def start(): try: # Run the ACME provisioning code, if it's enabled. if hs.config.acme_enabled: acme = hs.get_acme_handler() # Start up the webservices which we will respond to ACME # challenges with, and then provision. - yield defer.ensureDeferred(acme.start_listening()) - yield defer.ensureDeferred(do_acme()) + await acme.start_listening() + await do_acme() # Check if it needs to be reprovisioned every day. hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000) @@ -439,8 +437,8 @@ def setup(config_options): if hs.config.oidc_enabled: oidc = hs.get_oidc_handler() # Loading the provider metadata also ensures the provider config is valid. - yield defer.ensureDeferred(oidc.load_metadata()) - yield defer.ensureDeferred(oidc.load_jwks()) + await oidc.load_metadata() + await oidc.load_jwks() _base.start(hs, config.listeners) @@ -456,7 +454,7 @@ def setup(config_options): reactor.stop() sys.exit(1) - reactor.callWhenRunning(start) + reactor.callWhenRunning(lambda: defer.ensureDeferred(start())) return hs diff --git a/synapse/config/_base.py b/synapse/config/_base.py
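Both files above convert `@defer.inlineCallbacks` generators to `async def` coroutines and bridge them back into Twisted with `defer.ensureDeferred`. A self-contained sketch of that pattern (the coroutine body here is a stand-in, not Synapse code):

    from twisted.internet import defer, task

    async def main(reactor):
        # Any async def coroutine; ensureDeferred wraps it in a Deferred that
        # task.react() will wait on before stopping the reactor.
        print("running inside the reactor")

    task.react(lambda reactor: defer.ensureDeferred(main(reactor)))

The homeserver start-up path uses the same trick with `reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))`, since `callWhenRunning` itself has no idea what to do with a bare coroutine.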
index 1477b27326..876bd354ab 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -833,11 +833,26 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key. """ - - # If multiple instances are not defined we always return true. + # If multiple instances are not defined we always return true if not self.instances or len(self.instances) == 1: return True + return self.get_instance(key) == instance_name + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key. + + Note: For things like federation sending the config for which instance + is sending is known only to the sender instance if there is only one. + Therefore `should_handle` should be used where possible. + """ + + if not self.instances: + return "master" + + if len(self.instances) == 1: + return self.instances[0] + # We shard by taking the hash, modulo it by the number of instances and # then checking whether this instance matches the instance at that # index. @@ -847,7 +862,7 @@ class ShardedWorkerHandlingConfig: dest_hash = sha256(key.encode("utf8")).digest() dest_int = int.from_bytes(dest_hash, byteorder="little") remainder = dest_int % (len(self.instances)) - return self.instances[remainder] == instance_name + return self.instances[remainder] __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index eb911e8f9f..b8faafa9bd 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
     instances: List[str]
     def __init__(self, instances: List[str]) -> None: ...
     def should_handle(self, instance_name: str, key: str) -> bool: ...
+    def get_instance(self, key: str) -> str: ...
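The stub above mirrors the new `get_instance` method added to `ShardedWorkerHandlingConfig` earlier in this diff, which routes a key (such as a room ID) to one writer by hashing it. A standalone sketch of that hash-and-modulo routing, with made-up instance names:

    from hashlib import sha256

    def pick_instance(instances, key: str) -> str:
        # Hash the key, read the digest as a little-endian integer, and take it
        # modulo the number of instances -- the same scheme as the config code.
        if not instances:
            return "master"
        if len(instances) == 1:
            return instances[0]
        dest_hash = sha256(key.encode("utf8")).digest()
        dest_int = int.from_bytes(dest_hash, byteorder="little")
        return instances[dest_int % len(instances)]

    # Example: a given room ID always maps to the same (hypothetical) persister.
    print(pick_instance(["event_persister1", "event_persister2"], "!room:example.org"))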
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c784a71508..f23e42cdf9 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -13,12 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Union + import attr from ._base import Config, ConfigError, ShardedWorkerHandlingConfig from .server import ListenerConfig, parse_listener_def +def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: + """Helper for allowing parsing a string or list of strings to a config + option expecting a list of strings. + """ + + if isinstance(obj, str): + return [obj] + return obj + + @attr.s class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication. @@ -33,11 +45,13 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instance that writes to the event and backfill streams. - events: The instance that writes to the typing stream. + events: The instances that write to the event and backfill streams. + typing: The instance that writes to the typing stream. """ - events = attr.ib(default="master", type=str) + events = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter + ) typing = attr.ib(default="master", type=str) @@ -105,15 +119,18 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events and typing also appears in + # Check that the configured writers for events and typing also appears in # `instance_map`. for stream in ("events", "typing"): - instance = getattr(self.writers, stream) - if instance != "master" and instance not in self.instance_map: - raise ConfigError( - "Instance %r is configured to write %s but does not appear in `instance_map` config." - % (instance, stream) - ) + instances = _instance_to_list_converter(getattr(self.writers, stream)) + for instance in instances: + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) + + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
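The hunk above lets `stream_writers.events` be either a single instance name or a list of names, normalised by `_instance_to_list_converter`, so existing single-writer configs keep working. A small sketch mirroring that converter's behaviour:

    from typing import List, Union

    def instance_to_list(obj: Union[str, List[str]]) -> List[str]:
        # A plain string becomes a one-element list; a list passes through as-is.
        if isinstance(obj, str):
            return [obj]
        return obj

    assert instance_to_list("master") == ["master"]
    assert instance_to_list(["event_persister1", "event_persister2"]) == [
        "event_persister1",
        "event_persister2",
    ]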
index b7e23a6072..f67b29cba1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -926,7 +926,8 @@ class FederationHandler(BaseHandler): ) ) - await self._handle_new_events(dest, ev_infos, backfilled=True) + if ev_infos: + await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -1219,7 +1220,7 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, event_infos, + destination, room_id, event_infos, ) def _sanity_check_event(self, ev): @@ -1366,15 +1367,15 @@ class FederationHandler(BaseHandler): ) max_stream_id = await self._persist_auth_tree( - origin, auth_chain, state, event, room_version_obj + origin, room_id, auth_chain, state, event, room_version_obj ) # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.config.worker.writers.events, "events", max_stream_id + self.config.worker.events_shard_config.get_instance(room_id), + "events", + max_stream_id, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1635,7 +1636,7 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify([(event, context)]) + await self.persist_events_and_notify(event.room_id, [(event, context)]) return event @@ -1662,7 +1663,9 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify([(event, context)]) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event, stream_id @@ -1910,7 +1913,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( - [(event, context)], backfilled=backfilled + event.room_id, [(event, context)], backfilled=backfilled ) except Exception: run_in_background( @@ -1923,6 +1926,7 @@ class FederationHandler(BaseHandler): async def _handle_new_events( self, origin: str, + room_id: str, event_infos: Iterable[_NewEventInfo], backfilled: bool = False, ) -> None: @@ -1954,6 +1958,7 @@ class FederationHandler(BaseHandler): ) await self.persist_events_and_notify( + room_id, [ (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) @@ -1964,6 +1969,7 @@ class FederationHandler(BaseHandler): async def _persist_auth_tree( self, origin: str, + room_id: str, auth_events: List[EventBase], state: List[EventBase], event: EventBase, @@ -1978,6 +1984,7 @@ class FederationHandler(BaseHandler): Args: origin: Where the events came from + room_id, auth_events state event @@ -2052,17 +2059,20 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR await self.persist_events_and_notify( + room_id, [ (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) - ] + ], ) new_event_context = await self.state_handler.compute_event_context( event, old_state=state ) - return await self.persist_events_and_notify([(event, new_event_context)]) + return await self.persist_events_and_notify( + room_id, [(event, new_event_context)] + ) async def _prep_event( self, @@ -2913,6 +2923,7 @@ class FederationHandler(BaseHandler): async def persist_events_and_notify( self, + 
room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, ) -> int: @@ -2920,14 +2931,19 @@ class FederationHandler(BaseHandler): necessary. Args: - event_and_contexts: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. backfilled: Whether these events are a result of backfilling or not """ - if self.config.worker.writers.events != self._instance_name: + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: result = await self._send_events( - instance_name=self.config.worker.writers.events, + instance_name=instance, store=self.store, + room_id=room_id, event_and_contexts=event_and_contexts, backfilled=backfilled, ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
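With the events stream sharded, each handler call above first looks up the writer responsible for the room and then either persists locally or forwards the work over replication HTTP. A simplified illustration of that decision; the helper names are stand-ins, not Synapse's actual API:

    async def persist_or_forward(shard_config, my_instance_name, room_id, events,
                                 persist_locally, forward_to):
        # Route every event in a room to the one instance responsible for it.
        writer = shard_config.get_instance(room_id)
        if writer == my_instance_name:
            return await persist_locally(events)
        return await forward_to(writer, room_id, events)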
index 6ab6ab2c34..aa362dade0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -377,9 +377,8 @@ class EventCreationHandler(object): self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases - self._is_event_writer = ( - self.config.worker.writers.events == hs.get_instance_name() - ) + self._events_shard_config = self.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() self.room_invite_state_types = self.hs.config.room_invite_state_types @@ -907,9 +906,10 @@ class EventCreationHandler(object): try: # If we're a worker we need to hit out to the master. - if not self._is_event_writer: + writer_instance = self._events_shard_config.get_instance(event.room_id) + if writer_instance != self._instance_name: result = await self.send_event( - instance_name=self.config.worker.writers.events, + instance_name=writer_instance, event_id=event.event_id, store=self.store, requester=requester, @@ -977,7 +977,9 @@ class EventCreationHandler(object): This should only be run on the instance in charge of persisting events. """ - assert self._is_event_writer + assert self._events_shard_config.should_handle( + self._instance_name, event.room_id + ) if ratelimit: # We check if this is a room admin redacting an event so that we diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a5171af56d..bbf8560ded 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -833,7 +833,9 @@ class RoomCreationHandler(BaseHandler): # Always wait for room creation to progate before returning await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", last_stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + last_stream_id, ) return result, last_stream_id @@ -1290,10 +1292,10 @@ class RoomShutdownHandler(object): # We now wait for the create room to come back in via replication so # that we can assume that all the joins/invites have propogated before # we try and auto join below. - # - # TODO: Currently the events stream is written to from master await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(new_room_id), + "events", + stream_id, ) else: new_room_id = None @@ -1323,7 +1325,9 @@ class RoomShutdownHandler(object): # Wait for leave to come in over replication before trying to forget. await self._replication.wait_for_stream_position( - self.hs.config.worker.writers.events, "events", stream_id + self.hs.config.worker.events_shard_config.get_instance(room_id), + "events", + stream_id, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index eb64b3b939..a3b5f084df 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -83,13 +83,6 @@ class RoomMemberHandler(object): self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles - self._event_stream_writer_instance = hs.config.worker.writers.events - self._is_on_event_persistence_instance = ( - self._event_stream_writer_instance == hs.get_instance_name() - ) - if self._is_on_event_persistence_instance: - self.persist_event_storage = hs.get_storage().persistence - self._join_rate_limiter_local = Ratelimiter( clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 6b56315148..5c8be747e1 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py
@@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - async def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, room_id, event_and_contexts, backfilled): """ Args: store + room_id (str) event_and_contexts (list[tuple[FrozenEvent, EventContext]]) backfilled (bool): Whether or not the events are the result of backfilling @@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): } ) - payload = {"events": event_payloads, "backfilled": backfilled} + payload = { + "events": event_payloads, + "backfilled": backfilled, + "room_id": room_id, + } return payload @@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) + room_id = content["room_id"] backfilled = content["backfilled"] event_payloads = content["events"] @@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) max_stream_id = await self.federation_handler.persist_events_and_notify( - event_and_contexts, backfilled + room_id, event_and_contexts, backfilled ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1c303f3a46..b323841f73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
             if isinstance(stream, (EventsStream, BackfillStream)):
                 # Only add EventStream and BackfillStream as a source on the
                 # instance in charge of event persistence.
-                if hs.config.worker.writers.events == hs.get_instance_name():
+                if hs.get_instance_name() in hs.config.worker.writers.events:
                     self._streams_to_replicate.append(stream)
 
                 continue
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 16c63ff4ec..3705618b4f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance +from ._base import Stream, StreamUpdateResult, Token """Handling of the 'events' replication stream @@ -117,7 +117,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(self._store.get_current_events_token), + self._store._stream_id_gen.get_current_token_for_writer, self._update_function, ) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 2e81eeff65..10ac6fd7dc 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
     room_keys,
     room_upgrade_rest_servlet,
     sendtodevice,
+    shared_rooms,
     sync,
     tags,
     thirdparty,
@@ -126,3 +127,6 @@ class ClientRestResource(JsonResource):
         synapse.rest.admin.register_servlets_for_client_rest_resource(
             hs, client_resource
         )
+
+        # unstable
+        shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
new file mode 100644
index 0000000000..2492634dac
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserSharedRoomsServlet(RestServlet):
+    """
+    GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+    """
+
+    PATTERNS = client_patterns(
+        "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+        releases=(),  # This is an unstable feature
+    )
+
+    def __init__(self, hs):
+        super(UserSharedRoomsServlet, self).__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.user_directory_active = hs.config.update_user_directory
+
+    async def on_GET(self, request, user_id):
+
+        if not self.user_directory_active:
+            raise SynapseError(
+                code=400,
+                msg="The user directory is disabled on this server. Cannot determine shared rooms.",
+                errcode=Codes.FORBIDDEN,
+            )
+
+        UserID.from_string(user_id)
+
+        requester = await self.auth.get_user_by_req(request)
+        if user_id == requester.user.to_string():
+            raise SynapseError(
+                code=400,
+                msg="You cannot request a list of shared rooms with yourself",
+                errcode=Codes.FORBIDDEN,
+            )
+        rooms = await self.store.get_shared_rooms_for_users(
+            requester.user.to_string(), user_id
+        )
+
+        return 200, {"joined": list(rooms)}
+
+
+def register_servlets(hs, http_server):
+    UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index b1999d051b..58ec5a694a 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -63,6 +63,8 @@ class VersionsRestServlet(RestServlet):
                     # Tchap does not currently assume this rule for r0.5.0
                     # XXX: Remove this when it does
                     "m.lazy_load_members": True,
+                    # Implements additional endpoints as described in MSC2666
+                    "uk.half-shot.msc2666": True,
                 },
             },
         )
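Advertising the flag in `unstable_features` lets clients feature-detect the new endpoint before calling it. A hedged sketch of such a check (the homeserver URL is a placeholder):

    import requests

    def supports_msc2666(homeserver: str) -> bool:
        # Look for the unstable feature flag advertised by /_matrix/client/versions.
        resp = requests.get(f"{homeserver}/_matrix/client/versions")
        resp.raise_for_status()
        return resp.json().get("unstable_features", {}).get("uk.half-shot.msc2666", False)

    print(supports_msc2666("https://matrix.example.org"))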
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0ac854aee2..c73d54fb67 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -68,7 +68,7 @@ class Databases(object):
 
                 # If we're on a process that can persist events also
                 # instantiate a `PersistEventsStore`
-                if hs.config.worker.writers.events == hs.get_instance_name():
+                if hs.get_instance_name() in hs.config.worker.writers.events:
                     persist_events = PersistEventsStore(hs, database, main)
 
             if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 449d95f31e..4059701cfd 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -24,7 +24,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import make_in_list_sql_clause from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -58,18 +58,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: (stream_id, devices) """ - return await self.db_pool.runInteraction( - "get_e2e_device_keys_for_federation_query", - self._get_e2e_device_keys_for_federation_query_txn, - user_id, - ) - - def _get_e2e_device_keys_for_federation_query_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: now_stream_id = self.get_device_stream_token() - devices = self._get_e2e_device_keys_and_signatures_txn(txn, [(user_id, None)]) + devices = await self.db_pool.runInteraction( + "get_e2e_device_keys_and_signatures_txn", + self._get_e2e_device_keys_and_signatures_txn, + [(user_id, None)], + ) if devices: user_devices = devices[user_id] diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0b69aa6a94..4c3c162acf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ if stream_ordering <= self.stream_ordering_month_ago: - raise StoreError(400, "stream_ordering too old") + raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ SELECT event_id FROM stream_ordering_to_exterm diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 6313b41eef..46b11e705b 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -97,6 +97,7 @@ class PersistEventsStore: self.store = main_data_store self.database_engine = db.engine self._clock = hs.get_clock() + self._instance_name = hs.get_instance_name() self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -108,7 +109,7 @@ class PersistEventsStore: # This should only exist on instances that are configured to write assert ( - hs.config.worker.writers.events == hs.get_instance_name() + hs.get_instance_name() in hs.config.worker.writers.events ), "Can only instantiate EventsStore on master" async def _persist_events_and_state_updates( @@ -800,6 +801,7 @@ class PersistEventsStore: table="events", values=[ { + "instance_name": self._instance_name, "stream_ordering": event.internal_metadata.stream_ordering, "topological_ordering": event.depth, "depth": event.depth, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a7a73cc3d8..17f5997b89 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id from synapse.util.caches.descriptors import Cache, cached from synapse.util.iterutils import batch_iter @@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) - if hs.config.worker.writers.events == hs.get_instance_name(): - # We are the process in charge of generating stream ids for events, - # so instantiate ID generators based on the database - self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + if isinstance(database.engine, PostgresEngine): + # If we're using Postgres than we can use `MultiWriterIdGenerator` + # regardless of whether this process writes to the streams or not. + self._stream_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_stream_seq", ) - self._backfill_id_gen = StreamIdGenerator( - db_conn, - "events", - "stream_ordering", - step=-1, - extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + self._backfill_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + instance_name=hs.get_instance_name(), + table="events", + instance_column="instance_name", + id_column="stream_ordering", + sequence_name="events_backfill_stream_seq", + positive=False, ) else: - # Another process is in charge of persisting events and generating - # stream IDs: rely on the replication streams to let us know which - # IDs we can process. - self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") - self._backfill_id_gen = SlavedIdTracker( - db_conn, "events", "stream_ordering", step=-1 - ) + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._stream_id_gen = StreamIdGenerator( + db_conn, "events", "stream_ordering", + ) + self._backfill_id_gen = StreamIdGenerator( + db_conn, + "events", + "stream_ordering", + step=-1, + extra_tables=[("ex_outlier_stream", "event_stream_ordering")], + ) + else: + self._stream_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering" + ) + self._backfill_id_gen = SlavedIdTracker( + db_conn, "events", "stream_ordering", step=-1 + ) self._get_event_cache = Cache( "*getEvent*", diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..97c1e6a0c5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+    SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+SELECT setval('events_backfill_stream_seq', (
+    SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+));
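The Postgres-only delta seeds two sequences from the existing `events` table so that every writer can hand out stream orderings without clashing. A hedged sketch of pulling the next ordering from the new sequence, assuming the delta has already been applied (connection details are placeholders):

    import psycopg2

    conn = psycopg2.connect(dbname="synapse", user="synapse")  # placeholders
    with conn, conn.cursor() as cur:
        cur.execute("SELECT nextval('events_stream_seq')")
        next_stream_ordering = cur.fetchone()[0]
        # Greater than any stream_ordering present when the sequence was seeded,
        # whichever worker asks for it.
        print(next_stream_ordering)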
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9b9bc304a8..55a250ef06 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -224,14 +224,32 @@ class StatsStore(StateDeltasStore): ) async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None: - """ + """Update the state of a room. + + fields can contain the following keys with string values: + * join_rules + * history_visibility + * encryption + * name + * topic + * avatar + * canonical_alias + + A is_federatable key can also be included with a boolean value. + Args: - room_id - fields + room_id: The room ID to update the state of. + fields: The fields to update. This can include a partial list of the + above fields to only update some room information. """ - - # For whatever reason some of the fields may contain null bytes, which - # postgres isn't a fan of, so we replace those fields with null. + # Ensure that the values to update are valid, they should be strings and + # not contain any null bytes. + # + # Invalid data gets overwritten with null. + # + # Note that a missing value should not be overwritten (it keeps the + # previous value). + sentinel = object() for col in ( "join_rules", "history_visibility", @@ -241,8 +259,8 @@ class StatsStore(StateDeltasStore): "avatar", "canonical_alias", ): - field = fields.get(col) - if field and "\0" in field: + field = fields.get(col, sentinel) + if field is not sentinel and (not isinstance(field, str) or "\0" in field): fields[col] = None await self.db_pool.simple_upsert( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
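The sentinel trick above is what lets `update_room_state` distinguish "key not supplied" (keep the stored value) from "key supplied but invalid" (overwrite with null). A self-contained sketch of the same pattern:

    _SENTINEL = object()

    def sanitise_fields(fields: dict, allowed: tuple) -> dict:
        # Only touch keys the caller actually supplied; anything that is not a
        # string, or that contains a NUL byte, is replaced with None.
        for col in allowed:
            value = fields.get(col, _SENTINEL)
            if value is not _SENTINEL and (not isinstance(value, str) or "\0" in value):
                fields[col] = None
        return fields

    print(sanitise_fields({"name": "ok", "topic": 42}, ("name", "topic", "avatar")))
    # {'name': 'ok', 'topic': None} -- 'avatar' is untouched because it was never supplied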
index c977db042e..f2f9a5799a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,7 +15,7 @@
 import logging
 import re
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
@@ -675,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
+    @cached()
+    async def get_shared_rooms_for_users(
+        self, user_id: str, other_user_id: str
+    ) -> Set[str]:
+        """
+        Returns the rooms that a local user shares with another local or remote user.
+
+        Args:
+            user_id: The MXID of a local user
+            other_user_id: The MXID of the other user
+
+        Returns:
+            A set of room ID's that the users share.
+        """
+
+        def _get_shared_rooms_for_users_txn(txn):
+            txn.execute(
+                """
+                SELECT p1.room_id
+                FROM users_in_public_rooms as p1
+                INNER JOIN users_in_public_rooms as p2
+                    ON p1.room_id = p2.room_id
+                    AND p1.user_id = ?
+                    AND p2.user_id = ?
+                UNION
+                SELECT room_id
+                FROM users_who_share_private_rooms
+                WHERE
+                    user_id = ?
+                    AND other_user_id = ?
+                """,
+                (user_id, other_user_id, user_id, other_user_id),
+            )
+            rows = self.db_pool.cursor_to_dict(txn)
+            return rows
+
+        rows = await self.db_pool.runInteraction(
+            "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+        )
+
+        return {row["room_id"] for row in rows}
+
     async def get_user_directory_stream_pos(self) -> int:
         return await self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9f3d23f0a5..8fd21c2bf8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -231,8 +231,12 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] @@ -362,9 +366,7 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
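With several writers, `get_current_token` can no longer simply be one writer's position; it now returns the position up to which every writer is known to have persisted. A toy illustration of why that is the safe token to expose (the positions are made up):

    current_positions = {"event_persister1": 110, "event_persister2": 104}

    # Everything up to 104 is complete; 105-110 may still contain gaps from the
    # slower writer, so 104 is the highest token it is safe to hand to readers.
    persisted_upto = min(current_positions.values()) if current_positions else 1
    print(persisted_upto)  # 104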
index f56708417f..02111d5d7f 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py
@@ -576,16 +576,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase): # Mock Synapse's threepid validator get_threepid_validation_session = Mock( - return_value=defer.succeed( + return_value=make_awaitable( {"medium": "email", "address": email, "validated_at": 0} ) ) self.store.get_threepid_validation_session = get_threepid_validation_session - delete_threepid_session = Mock(return_value=defer.succeed(None)) + delete_threepid_session = Mock(return_value=make_awaitable(None)) self.store.delete_threepid_session = delete_threepid_session # Mock Synapse's http json post method to check for the internal bind call - post_json_get_json = Mock(return_value=defer.succeed(None)) + post_json_get_json = Mock(return_value=make_awaitable(None)) self.hs.get_simple_http_client().post_json_get_json = post_json_get_json # Retrieve a UIA session ID diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py new file mode 100644
index 0000000000..5ae72fd008 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse.rest.admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import shared_rooms + +from tests import unittest + + +class UserSharedRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserSharedRoomsServlet. + """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + shared_rooms.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def _get_shared_rooms(self, token, other_user): + request, channel = self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" + % other_user, + access_token=token, + ) + self.render(request) + return request, channel + + def test_shared_room_list_public(self): + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + def test_shared_room_list_private(self): + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + def test_shared_room_list_mixed(self): + """ + The shared room list between two users should contain both public and private + rooms. 
+ """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token) + self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token) + self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token) + self.helper.join(room_public, user=u2, tok=u2_token) + self.helper.join(room_private, user=u1, tok=u1_token) + + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 2) + self.assertTrue(room_public in channel.json_body["joined"]) + self.assertTrue(room_private in channel.json_body["joined"]) + + def test_shared_room_list_after_leave(self): + """ + A room should no longer be considered shared if the other + user has left it. + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) + self.helper.invite(room, src=u1, targ=u2, tok=u1_token) + self.helper.join(room, user=u2, tok=u2_token) + + # Assert user directory is not empty + request, channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room) + + self.helper.leave(room, user=u1, tok=u1_token) + + request, channel = self._get_shared_rooms(u2_token, u1) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 0)