diff --git a/UPGRADE.rst b/UPGRADE.rst
index b2069a0d26..188171d7ab 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -1,3 +1,16 @@
+Upgrading to v1.20.0
+====================
+
+Shared rooms endpoint (MSC2666)
+-------------------------------
+
+This release contains a new unstable endpoint `/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*`
+for fetching rooms one user has in common with another. This feature requires the
+`update_user_directory` config flag to be `True`. If you are you are using a `synapse.app.user_dir`
+worker, requests to this endpoint must be handled by that worker.
+See `docs/workers.md <docs/workers.md>`_ for more details.
+
+
Upgrading Synapse
=================
diff --git a/changelog.d/7785.feature b/changelog.d/7785.feature
new file mode 100644
index 0000000000..c7e51c9320
--- /dev/null
+++ b/changelog.d/7785.feature
@@ -0,0 +1 @@
+Add an endpoint to query your shared rooms with another user as an implementation of [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666).
diff --git a/changelog.d/8170.feature b/changelog.d/8170.feature
new file mode 100644
index 0000000000..b363e929ea
--- /dev/null
+++ b/changelog.d/8170.feature
@@ -0,0 +1 @@
+Add experimental support for sharding event persister.
diff --git a/changelog.d/8213.misc b/changelog.d/8213.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8213.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8225.misc b/changelog.d/8225.misc
new file mode 100644
index 0000000000..979c8b227b
--- /dev/null
+++ b/changelog.d/8225.misc
@@ -0,0 +1 @@
+Refactor queries for device keys and cross-signatures.
diff --git a/changelog.d/8226.bugfix b/changelog.d/8226.bugfix
new file mode 100644
index 0000000000..60bdff576d
--- /dev/null
+++ b/changelog.d/8226.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where stats updates could break when unexpected profile data was included in events.
diff --git a/docs/workers.md b/docs/workers.md
index bfec745897..7a8f5c89fc 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -380,6 +380,7 @@ Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions:
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
+ ^/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*$
When using this worker you must also set `update_user_directory: False` in the
shared configuration file to stop the main synapse running background
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index a37818fe9a..b6c9085670 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer):
pass
-@defer.inlineCallbacks
-def export_data_command(hs, args):
+async def export_data_command(hs, args):
"""Export data for a user.
Args:
@@ -91,10 +90,8 @@ def export_data_command(hs, args):
user_id = args.user_id
directory = args.output_directory
- res = yield defer.ensureDeferred(
- hs.get_handlers().admin_handler.export_user_data(
- user_id, FileExfiltrationWriter(user_id, directory=directory)
- )
+ res = await hs.get_handlers().admin_handler.export_user_data(
+ user_id, FileExfiltrationWriter(user_id, directory=directory)
)
print(res)
@@ -232,14 +229,15 @@ def start(config_options):
# We also make sure that `_base.start` gets run before we actually run the
# command.
- @defer.inlineCallbacks
- def run(_reactor):
+ async def run():
with LoggingContext("command"):
- yield _base.start(ss, [])
- yield args.func(ss, args)
+ _base.start(ss, [])
+ await args.func(ss, args)
_base.start_worker_reactor(
- "synapse-admin-cmd", config, run_command=lambda: task.react(run)
+ "synapse-admin-cmd",
+ config,
+ run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())),
)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 98d0d14a12..6014adc850 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -411,26 +411,24 @@ def setup(config_options):
return provision
- @defer.inlineCallbacks
- def reprovision_acme():
+ async def reprovision_acme():
"""
Provision a certificate from ACME, if required, and reload the TLS
certificate if it's renewed.
"""
- reprovisioned = yield defer.ensureDeferred(do_acme())
+ reprovisioned = await do_acme()
if reprovisioned:
_base.refresh_certificate(hs)
- @defer.inlineCallbacks
- def start():
+ async def start():
try:
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
- yield defer.ensureDeferred(acme.start_listening())
- yield defer.ensureDeferred(do_acme())
+ await acme.start_listening()
+ await do_acme()
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@@ -439,8 +437,8 @@ def setup(config_options):
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
- yield defer.ensureDeferred(oidc.load_metadata())
- yield defer.ensureDeferred(oidc.load_jwks())
+ await oidc.load_metadata()
+ await oidc.load_jwks()
_base.start(hs, config.listeners)
@@ -456,7 +454,7 @@ def setup(config_options):
reactor.stop()
sys.exit(1)
- reactor.callWhenRunning(start)
+ reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
return hs
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1477b27326..876bd354ab 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -833,11 +833,26 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
-
- # If multiple instances are not defined we always return true.
+ # If multiple instances are not defined we always return true
if not self.instances or len(self.instances) == 1:
return True
+ return self.get_instance(key) == instance_name
+
+ def get_instance(self, key: str) -> str:
+ """Get the instance responsible for handling the given key.
+
+ Note: For things like federation sending the config for which instance
+ is sending is known only to the sender instance if there is only one.
+ Therefore `should_handle` should be used where possible.
+ """
+
+ if not self.instances:
+ return "master"
+
+ if len(self.instances) == 1:
+ return self.instances[0]
+
# We shard by taking the hash, modulo it by the number of instances and
# then checking whether this instance matches the instance at that
# index.
@@ -847,7 +862,7 @@ class ShardedWorkerHandlingConfig:
dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
- return self.instances[remainder] == instance_name
+ return self.instances[remainder]
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index eb911e8f9f..b8faafa9bd 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
+ def get_instance(self, key: str) -> str: ...
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index c784a71508..f23e42cdf9 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -13,12 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List, Union
+
import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
from .server import ListenerConfig, parse_listener_def
+def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
+ """Helper for allowing parsing a string or list of strings to a config
+ option expecting a list of strings.
+ """
+
+ if isinstance(obj, str):
+ return [obj]
+ return obj
+
+
@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
@@ -33,11 +45,13 @@ class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
- events: The instance that writes to the event and backfill streams.
- events: The instance that writes to the typing stream.
+ events: The instances that write to the event and backfill streams.
+ typing: The instance that writes to the typing stream.
"""
- events = attr.ib(default="master", type=str)
+ events = attr.ib(
+ default=["master"], type=List[str], converter=_instance_to_list_converter
+ )
typing = attr.ib(default="master", type=str)
@@ -105,15 +119,18 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
- # Check that the configured writer for events and typing also appears in
+ # Check that the configured writers for events and typing also appears in
# `instance_map`.
for stream in ("events", "typing"):
- instance = getattr(self.writers, stream)
- if instance != "master" and instance not in self.instance_map:
- raise ConfigError(
- "Instance %r is configured to write %s but does not appear in `instance_map` config."
- % (instance, stream)
- )
+ instances = _instance_to_list_converter(getattr(self.writers, stream))
+ for instance in instances:
+ if instance != "master" and instance not in self.instance_map:
+ raise ConfigError(
+ "Instance %r is configured to write %s but does not appear in `instance_map` config."
+ % (instance, stream)
+ )
+
+ self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b7e23a6072..f67b29cba1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -926,7 +926,8 @@ class FederationHandler(BaseHandler):
)
)
- await self._handle_new_events(dest, ev_infos, backfilled=True)
+ if ev_infos:
+ await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -1219,7 +1220,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
- destination, event_infos,
+ destination, room_id, event_infos,
)
def _sanity_check_event(self, ev):
@@ -1366,15 +1367,15 @@ class FederationHandler(BaseHandler):
)
max_stream_id = await self._persist_auth_tree(
- origin, auth_chain, state, event, room_version_obj
+ origin, room_id, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
- #
- # TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
- self.config.worker.writers.events, "events", max_stream_id
+ self.config.worker.events_shard_config.get_instance(room_id),
+ "events",
+ max_stream_id,
)
# Check whether this room is the result of an upgrade of a room we already know
@@ -1635,7 +1636,7 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
- await self.persist_events_and_notify([(event, context)])
+ await self.persist_events_and_notify(event.room_id, [(event, context)])
return event
@@ -1662,7 +1663,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
- stream_id = await self.persist_events_and_notify([(event, context)])
+ stream_id = await self.persist_events_and_notify(
+ event.room_id, [(event, context)]
+ )
return event, stream_id
@@ -1910,7 +1913,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
- [(event, context)], backfilled=backfilled
+ event.room_id, [(event, context)], backfilled=backfilled
)
except Exception:
run_in_background(
@@ -1923,6 +1926,7 @@ class FederationHandler(BaseHandler):
async def _handle_new_events(
self,
origin: str,
+ room_id: str,
event_infos: Iterable[_NewEventInfo],
backfilled: bool = False,
) -> None:
@@ -1954,6 +1958,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
+ room_id,
[
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
@@ -1964,6 +1969,7 @@ class FederationHandler(BaseHandler):
async def _persist_auth_tree(
self,
origin: str,
+ room_id: str,
auth_events: List[EventBase],
state: List[EventBase],
event: EventBase,
@@ -1978,6 +1984,7 @@ class FederationHandler(BaseHandler):
Args:
origin: Where the events came from
+ room_id,
auth_events
state
event
@@ -2052,17 +2059,20 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
await self.persist_events_and_notify(
+ room_id,
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
- ]
+ ],
)
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
- return await self.persist_events_and_notify([(event, new_event_context)])
+ return await self.persist_events_and_notify(
+ room_id, [(event, new_event_context)]
+ )
async def _prep_event(
self,
@@ -2913,6 +2923,7 @@ class FederationHandler(BaseHandler):
async def persist_events_and_notify(
self,
+ room_id: str,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> int:
@@ -2920,14 +2931,19 @@ class FederationHandler(BaseHandler):
necessary.
Args:
- event_and_contexts:
+ room_id: The room ID of events being persisted.
+ event_and_contexts: Sequence of events with their associated
+ context that should be persisted. All events must belong to
+ the same room.
backfilled: Whether these events are a result of
backfilling or not
"""
- if self.config.worker.writers.events != self._instance_name:
+ instance = self.config.worker.events_shard_config.get_instance(room_id)
+ if instance != self._instance_name:
result = await self._send_events(
- instance_name=self.config.worker.writers.events,
+ instance_name=instance,
store=self.store,
+ room_id=room_id,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 6ab6ab2c34..aa362dade0 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -377,9 +377,8 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
- self._is_event_writer = (
- self.config.worker.writers.events == hs.get_instance_name()
- )
+ self._events_shard_config = self.config.worker.events_shard_config
+ self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.room_invite_state_types
@@ -907,9 +906,10 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
- if not self._is_event_writer:
+ writer_instance = self._events_shard_config.get_instance(event.room_id)
+ if writer_instance != self._instance_name:
result = await self.send_event(
- instance_name=self.config.worker.writers.events,
+ instance_name=writer_instance,
event_id=event.event_id,
store=self.store,
requester=requester,
@@ -977,7 +977,9 @@ class EventCreationHandler(object):
This should only be run on the instance in charge of persisting events.
"""
- assert self._is_event_writer
+ assert self._events_shard_config.should_handle(
+ self._instance_name, event.room_id
+ )
if ratelimit:
# We check if this is a room admin redacting an event so that we
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a5171af56d..bbf8560ded 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -833,7 +833,9 @@ class RoomCreationHandler(BaseHandler):
# Always wait for room creation to progate before returning
await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", last_stream_id
+ self.hs.config.worker.events_shard_config.get_instance(room_id),
+ "events",
+ last_stream_id,
)
return result, last_stream_id
@@ -1290,10 +1292,10 @@ class RoomShutdownHandler(object):
# We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propogated before
# we try and auto join below.
- #
- # TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", stream_id
+ self.hs.config.worker.events_shard_config.get_instance(new_room_id),
+ "events",
+ stream_id,
)
else:
new_room_id = None
@@ -1323,7 +1325,9 @@ class RoomShutdownHandler(object):
# Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", stream_id
+ self.hs.config.worker.events_shard_config.get_instance(room_id),
+ "events",
+ stream_id,
)
await self.room_member_handler.forget(target_requester.user, room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index eb64b3b939..a3b5f084df 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -83,13 +83,6 @@ class RoomMemberHandler(object):
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
- self._event_stream_writer_instance = hs.config.worker.writers.events
- self._is_on_event_persistence_instance = (
- self._event_stream_writer_instance == hs.get_instance_name()
- )
- if self._is_on_event_persistence_instance:
- self.persist_event_storage = hs.get_storage().persistence
-
self._join_rate_limiter_local = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 6b56315148..5c8be747e1 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
- async def _serialize_payload(store, event_and_contexts, backfilled):
+ async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
"""
Args:
store
+ room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
@@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
}
)
- payload = {"events": event_payloads, "backfilled": backfilled}
+ payload = {
+ "events": event_payloads,
+ "backfilled": backfilled,
+ "room_id": room_id,
+ }
return payload
@@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
+ room_id = content["room_id"]
backfilled = content["backfilled"]
event_payloads = content["events"]
@@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
max_stream_id = await self.federation_handler.persist_events_and_notify(
- event_and_contexts, backfilled
+ room_id, event_and_contexts, backfilled
)
return 200, {"max_stream_id": max_stream_id}
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1c303f3a46..b323841f73 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
- if hs.config.worker.writers.events == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.events:
self._streams_to_replicate.append(stream)
continue
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 16c63ff4ec..3705618b4f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type
import attr
-from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
+from ._base import Stream, StreamUpdateResult, Token
"""Handling of the 'events' replication stream
@@ -117,7 +117,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(self._store.get_current_events_token),
+ self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
)
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 2e81eeff65..10ac6fd7dc 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
room_keys,
room_upgrade_rest_servlet,
sendtodevice,
+ shared_rooms,
sync,
tags,
thirdparty,
@@ -126,3 +127,6 @@ class ClientRestResource(JsonResource):
synapse.rest.admin.register_servlets_for_client_rest_resource(
hs, client_resource
)
+
+ # unstable
+ shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
new file mode 100644
index 0000000000..2492634dac
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserSharedRoomsServlet(RestServlet):
+ """
+ GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+ """
+
+ PATTERNS = client_patterns(
+ "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+ releases=(), # This is an unstable feature
+ )
+
+ def __init__(self, hs):
+ super(UserSharedRoomsServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.user_directory_active = hs.config.update_user_directory
+
+ async def on_GET(self, request, user_id):
+
+ if not self.user_directory_active:
+ raise SynapseError(
+ code=400,
+ msg="The user directory is disabled on this server. Cannot determine shared rooms.",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ UserID.from_string(user_id)
+
+ requester = await self.auth.get_user_by_req(request)
+ if user_id == requester.user.to_string():
+ raise SynapseError(
+ code=400,
+ msg="You cannot request a list of shared rooms with yourself",
+ errcode=Codes.FORBIDDEN,
+ )
+ rooms = await self.store.get_shared_rooms_for_users(
+ requester.user.to_string(), user_id
+ )
+
+ return 200, {"joined": list(rooms)}
+
+
+def register_servlets(hs, http_server):
+ UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index b1999d051b..58ec5a694a 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -63,6 +63,8 @@ class VersionsRestServlet(RestServlet):
# Tchap does not currently assume this rule for r0.5.0
# XXX: Remove this when it does
"m.lazy_load_members": True,
+ # Implements additional endpoints as described in MSC2666
+ "uk.half-shot.msc2666": True,
},
},
)
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0ac854aee2..c73d54fb67 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -68,7 +68,7 @@ class Databases(object):
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
- if hs.config.worker.writers.events == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 449d95f31e..4059701cfd 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -24,7 +24,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import make_in_list_sql_clause
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -58,18 +58,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
(stream_id, devices)
"""
- return await self.db_pool.runInteraction(
- "get_e2e_device_keys_for_federation_query",
- self._get_e2e_device_keys_for_federation_query_txn,
- user_id,
- )
-
- def _get_e2e_device_keys_for_federation_query_txn(
- self, txn: LoggingTransaction, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
now_stream_id = self.get_device_stream_token()
- devices = self._get_e2e_device_keys_and_signatures_txn(txn, [(user_id, None)])
+ devices = await self.db_pool.runInteraction(
+ "get_e2e_device_keys_and_signatures_txn",
+ self._get_e2e_device_keys_and_signatures_txn,
+ [(user_id, None)],
+ )
if devices:
user_devices = devices[user_id]
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0b69aa6a94..4c3c162acf 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
- raise StoreError(400, "stream_ordering too old")
+ raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
SELECT event_id FROM stream_ordering_to_exterm
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 6313b41eef..46b11e705b 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -97,6 +97,7 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@@ -108,7 +109,7 @@ class PersistEventsStore:
# This should only exist on instances that are configured to write
assert (
- hs.config.worker.writers.events == hs.get_instance_name()
+ hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@@ -800,6 +801,7 @@ class PersistEventsStore:
table="events",
values=[
{
+ "instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a7a73cc3d8..17f5997b89 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
- if hs.config.worker.writers.events == hs.get_instance_name():
- # We are the process in charge of generating stream ids for events,
- # so instantiate ID generators based on the database
- self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering",
+ if isinstance(database.engine, PostgresEngine):
+ # If we're using Postgres than we can use `MultiWriterIdGenerator`
+ # regardless of whether this process writes to the streams or not.
+ self._stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_stream_seq",
)
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ self._backfill_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_backfill_stream_seq",
+ positive=False,
)
else:
- # Another process is in charge of persisting events and generating
- # stream IDs: rely on the replication streams to let us know which
- # IDs we can process.
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
+ # We shouldn't be running in worker mode with SQLite, but its useful
+ # to support it for unit tests.
+ #
+ # If this process is the writer than we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering",
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ else:
+ self._stream_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering"
+ )
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
self._get_event_cache = Cache(
"*getEvent*",
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..97c1e6a0c5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+ SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+SELECT setval('events_backfill_stream_seq', (
+ SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+));
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9b9bc304a8..55a250ef06 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -224,14 +224,32 @@ class StatsStore(StateDeltasStore):
)
async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
- """
+ """Update the state of a room.
+
+ fields can contain the following keys with string values:
+ * join_rules
+ * history_visibility
+ * encryption
+ * name
+ * topic
+ * avatar
+ * canonical_alias
+
+ A is_federatable key can also be included with a boolean value.
+
Args:
- room_id
- fields
+ room_id: The room ID to update the state of.
+ fields: The fields to update. This can include a partial list of the
+ above fields to only update some room information.
"""
-
- # For whatever reason some of the fields may contain null bytes, which
- # postgres isn't a fan of, so we replace those fields with null.
+ # Ensure that the values to update are valid, they should be strings and
+ # not contain any null bytes.
+ #
+ # Invalid data gets overwritten with null.
+ #
+ # Note that a missing value should not be overwritten (it keeps the
+ # previous value).
+ sentinel = object()
for col in (
"join_rules",
"history_visibility",
@@ -241,8 +259,8 @@ class StatsStore(StateDeltasStore):
"avatar",
"canonical_alias",
):
- field = fields.get(col)
- if field and "\0" in field:
+ field = fields.get(col, sentinel)
+ if field is not sentinel and (not isinstance(field, str) or "\0" in field):
fields[col] = None
await self.db_pool.simple_upsert(
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index c977db042e..f2f9a5799a 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,7 +15,7 @@
import logging
import re
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -675,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
+ @cached()
+ async def get_shared_rooms_for_users(
+ self, user_id: str, other_user_id: str
+ ) -> Set[str]:
+ """
+ Returns the rooms that a local user shares with another local or remote user.
+
+ Args:
+ user_id: The MXID of a local user
+ other_user_id: The MXID of the other user
+
+ Returns:
+ A set of room ID's that the users share.
+ """
+
+ def _get_shared_rooms_for_users_txn(txn):
+ txn.execute(
+ """
+ SELECT p1.room_id
+ FROM users_in_public_rooms as p1
+ INNER JOIN users_in_public_rooms as p2
+ ON p1.room_id = p2.room_id
+ AND p1.user_id = ?
+ AND p2.user_id = ?
+ UNION
+ SELECT room_id
+ FROM users_who_share_private_rooms
+ WHERE
+ user_id = ?
+ AND other_user_id = ?
+ """,
+ (user_id, other_user_id, user_id, other_user_id),
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows
+
+ rows = await self.db_pool.runInteraction(
+ "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+ )
+
+ return {row["room_id"] for row in rows}
+
async def get_user_directory_stream_pos(self) -> int:
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9f3d23f0a5..8fd21c2bf8 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -231,8 +231,12 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
+ #
+ # We start at 1 here as a) the first generated stream ID will be 2, and
+ # b) other parts of the code assume that stream IDs are strictly greater
+ # than 0.
self._persisted_upto_position = (
- min(self._current_positions.values()) if self._current_positions else 0
+ min(self._current_positions.values()) if self._current_positions else 1
)
self._known_persisted_positions = [] # type: List[int]
@@ -362,9 +366,7 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted.
"""
- # Currently we don't support this operation, as it's not obvious how to
- # condense the stream positions of multiple writers into a single int.
- raise NotImplementedError()
+ return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index f56708417f..02111d5d7f 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -576,16 +576,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Mock Synapse's threepid validator
get_threepid_validation_session = Mock(
- return_value=defer.succeed(
+ return_value=make_awaitable(
{"medium": "email", "address": email, "validated_at": 0}
)
)
self.store.get_threepid_validation_session = get_threepid_validation_session
- delete_threepid_session = Mock(return_value=defer.succeed(None))
+ delete_threepid_session = Mock(return_value=make_awaitable(None))
self.store.delete_threepid_session = delete_threepid_session
# Mock Synapse's http json post method to check for the internal bind call
- post_json_get_json = Mock(return_value=defer.succeed(None))
+ post_json_get_json = Mock(return_value=make_awaitable(None))
self.hs.get_simple_http_client().post_json_get_json = post_json_get_json
# Retrieve a UIA session ID
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
new file mode 100644
index 0000000000..5ae72fd008
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse.rest.admin
+from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import shared_rooms
+
+from tests import unittest
+
+
+class UserSharedRoomsTest(unittest.HomeserverTestCase):
+ """
+ Tests the UserSharedRoomsServlet.
+ """
+
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ shared_rooms.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["update_user_directory"] = True
+ return self.setup_test_homeserver(config=config)
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def _get_shared_rooms(self, token, other_user):
+ request, channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s"
+ % other_user,
+ access_token=token,
+ )
+ self.render(request)
+ return request, channel
+
+ def test_shared_room_list_public(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is public.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_private(self):
+ """
+ A room should show up in the shared list of rooms between two users
+ if it is private.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ def test_shared_room_list_mixed(self):
+ """
+ The shared room list between two users should contain both public and private
+ rooms.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
+ self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
+ self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
+ self.helper.join(room_public, user=u2, tok=u2_token)
+ self.helper.join(room_private, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 2)
+ self.assertTrue(room_public in channel.json_body["joined"])
+ self.assertTrue(room_private in channel.json_body["joined"])
+
+ def test_shared_room_list_after_leave(self):
+ """
+ A room should no longer be considered shared if the other
+ user has left it.
+ """
+ u1 = self.register_user("user1", "pass")
+ u1_token = self.login(u1, "pass")
+ u2 = self.register_user("user2", "pass")
+ u2_token = self.login(u2, "pass")
+
+ room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+ self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+ self.helper.join(room, user=u2, tok=u2_token)
+
+ # Assert user directory is not empty
+ request, channel = self._get_shared_rooms(u1_token, u2)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 1)
+ self.assertEquals(channel.json_body["joined"][0], room)
+
+ self.helper.leave(room, user=u1, tok=u1_token)
+
+ request, channel = self._get_shared_rooms(u2_token, u1)
+ self.assertEquals(200, channel.code, channel.result)
+ self.assertEquals(len(channel.json_body["joined"]), 0)
|