From 4ff0201e6235b8b2efc5ce5a7dc3c479ea96df53 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 1 Oct 2020 08:09:18 -0400 Subject: Enable mypy checking for unreachable code and fix instances. (#8432) --- synapse/replication/tcp/protocol.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 0b0d204e64..a509e599c2 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -51,10 +51,11 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional from prometheus_client import Counter +from twisted.internet import task from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -152,9 +153,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() self.last_sent_command = 0 - self.time_we_closed = None # When we requested the connection be closed + # When we requested the connection be closed + self.time_we_closed = None # type: Optional[int] - self.received_ping = False # Have we reecived a ping from the other side + self.received_ping = False # Have we received a ping from the other side self.state = ConnectionStates.CONNECTING @@ -165,7 +167,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.pending_commands = [] # type: List[Command] # The LoopingCall for sending pings. - self._send_ping_loop = None + self._send_ping_loop = None # type: Optional[task.LoopingCall] # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. -- cgit 1.5.1 From 6c5d5e507e629cf57ae8c1034879e8ffaef33e9f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 2 Oct 2020 09:57:12 +0100 Subject: Add unit test for event persister sharding (#8433) --- changelog.d/8433.misc | 1 + mypy.ini | 3 + stubs/txredisapi.pyi | 20 +- synapse/replication/tcp/handler.py | 6 +- synapse/replication/tcp/redis.py | 40 +++- tests/replication/_base.py | 224 ++++++++++++++++++++-- tests/replication/test_sharded_event_persister.py | 102 ++++++++++ tests/unittest.py | 2 +- 8 files changed, 371 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8433.misc create mode 100644 tests/replication/test_sharded_event_persister.py (limited to 'synapse/replication') diff --git a/changelog.d/8433.misc b/changelog.d/8433.misc new file mode 100644 index 0000000000..05f8b5bbf4 --- /dev/null +++ b/changelog.d/8433.misc @@ -0,0 +1 @@ +Add unit test for event persister sharding. diff --git a/mypy.ini b/mypy.ini index c283f15b21..e84ad04e41 100644 --- a/mypy.ini +++ b/mypy.ini @@ -143,3 +143,6 @@ ignore_missing_imports = True [mypy-nacl.*] ignore_missing_imports = True + +[mypy-hiredis] +ignore_missing_imports = True diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index c66413f003..522244bb57 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -16,7 +16,7 @@ """Contains *incomplete* type hints for txredisapi. """ -from typing import List, Optional, Union +from typing import List, Optional, Union, Type class RedisProtocol: def publish(self, channel: str, message: bytes): ... @@ -42,3 +42,21 @@ def lazyConnection( class SubscriberFactory: def buildProtocol(self, addr): ... + +class ConnectionHandler: ... + +class RedisFactory: + continueTrying: bool + handler: RedisProtocol + def __init__( + self, + uuid: str, + dbid: Optional[int], + poolsize: int, + isLazy: bool = False, + handler: Type = ConnectionHandler, + charset: str = "utf-8", + password: Optional[str] = None, + replyTimeout: Optional[int] = None, + convertNumbers: Optional[int] = True, + ): ... diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b323841f73..e92da7b263 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -251,10 +251,9 @@ class ReplicationCommandHandler: using TCP. """ if hs.config.redis.redis_enabled: - import txredisapi - from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, + lazyConnection, ) logger.info( @@ -271,7 +270,8 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = txredisapi.lazyConnection( + outbound_redis_connection = lazyConnection( + reactor=hs.get_reactor(), host=hs.config.redis_host, port=hs.config.redis_port, password=hs.config.redis.redis_password, diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index f225e533de..de19705c1f 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -15,7 +15,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import txredisapi @@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): p.password = self.password return p + + +def lazyConnection( + reactor, + host: str = "localhost", + port: int = 6379, + dbid: Optional[int] = None, + reconnect: bool = True, + charset: str = "utf-8", + password: Optional[str] = None, + connectTimeout: Optional[int] = None, + replyTimeout: Optional[int] = None, + convertNumbers: bool = True, +) -> txredisapi.RedisProtocol: + """Equivalent to `txredisapi.lazyConnection`, except allows specifying a + reactor. + """ + + isLazy = True + poolsize = 1 + + uuid = "%s:%d" % (host, port) + factory = txredisapi.RedisFactory( + uuid, + dbid, + poolsize, + isLazy, + txredisapi.ConnectionHandler, + charset, + password, + replyTimeout, + convertNumbers, + ) + factory.continueTrying = reconnect + for x in range(poolsize): + reactor.connectTCP(host, port, factory, connectTimeout) + + return factory.handler diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ae60874ec3..81ea985b9f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Any, Callable, List, Optional, Tuple import attr +import hiredis from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime +from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel @@ -27,7 +28,7 @@ from synapse.app.generic_worker import ( GenericWorkerServer, ) from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest +from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol @@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = self.hs.get_replication_streamer() + # Fake in memory Redis server that servers can connect to. + self._redis_server = FakeRedisPubSubServer() + store = self.hs.get_datastore() self.database_pool = store.db_pool self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["localhost"] = "127.0.0.1" + + # A map from a HS instance to the associated HTTP Site to use for + # handling inbound HTTP requests to that instance. + self._hs_to_site = {self.hs: self.site} + + if self.hs.config.redis.redis_enabled: + # Handle attempts to connect to fake redis server. + self.reactor.add_tcp_client_callback( + "localhost", 6379, self.connect_any_redis_attempts, + ) - self._worker_hs_to_resource = {} + self.hs.get_tcp_replication().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't # manually have to go and explicitly set it up each time (plus sometimes # it is impossible to write the handling explicitly in the tests). + # + # Register the master replication listener: self.reactor.add_tcp_client_callback( - "1.2.3.4", 8765, self._handle_http_replication_attempt + "1.2.3.4", + 8765, + lambda: self._handle_http_replication_attempt(self.hs, 8765), ) def create_test_json_resource(self): @@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): **kwargs ) + # If the instance is in the `instance_map` config then workers may try + # and send HTTP requests to it, so we register it with + # `_handle_http_replication_attempt` like we do with the master HS. + instance_name = worker_hs.get_instance_name() + instance_loc = worker_hs.config.worker.instance_map.get(instance_name) + if instance_loc: + # Ensure the host is one that has a fake DNS entry. + if instance_loc.host not in self.reactor.lookups: + raise Exception( + "Host does not have an IP for instance_map[%r].host = %r" + % (instance_name, instance_loc.host,) + ) + + self.reactor.add_tcp_client_callback( + self.reactor.lookups[instance_loc.host], + instance_loc.port, + lambda: self._handle_http_replication_attempt( + worker_hs, instance_loc.port + ), + ) + store = worker_hs.get_datastore() store.db_pool._db_pool = self.database_pool._db_pool - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, - ) - server = self.server_factory.buildProtocol(None) + # Set up TCP replication between master and the new worker if we don't + # have Redis support enabled. + if not worker_hs.config.redis_enabled: + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) # Set up a resource for the worker - resource = ReplicationRestResource(self.hs) + resource = ReplicationRestResource(worker_hs) for servlet in self.servlets: servlet(worker_hs, resource) - self._worker_hs_to_resource[worker_hs] = resource + self._hs_to_site[worker_hs] = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="{}-{}".format( + worker_hs.config.server.server_name, worker_hs.get_instance_name() + ), + config=worker_hs.config.server.listeners[0], + resource=resource, + server_version_string="1", + ) + + if worker_hs.config.redis.redis_enabled: + worker_hs.get_tcp_replication().start_replication(worker_hs) return worker_hs @@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): return config def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): - render(request, self._worker_hs_to_resource[worker_hs], self.reactor) + render(request, self._hs_to_site[worker_hs].resource, self.reactor) def replicate(self): """Tell the master side of replication that something has happened, and then @@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump() - def _handle_http_replication_attempt(self): - """Handles a connection attempt to the master replication HTTP - listener. + def _handle_http_replication_attempt(self, hs, repl_port): + """Handles a connection attempt to the given HS replication HTTP + listener on the given port. """ # We should have at least one outbound connection attempt, where the @@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop() self.assertEqual(host, "1.2.3.4") - self.assertEqual(port, 8765) + self.assertEqual(port, repl_port) # Set up client side protocol client_protocol = client_factory.buildProtocol(None) @@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up the server side protocol channel = _PushHTTPChannel(self.reactor) channel.requestFactory = request_factory - channel.site = self.site + channel.site = self._hs_to_site[hs] # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # inside `connecTCP` before the connection has been passed back to the # code that requested the TCP connection. + def connect_any_redis_attempts(self): + """If redis is enabled we need to deal with workers connecting to a + redis server. We don't want to use a real Redis server so we use a + fake one. + """ + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) + + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) + + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) + + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) + + return client_to_server_transport, server_to_client_transport + class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" @@ -467,3 +547,105 @@ class _PullToPushProducer: pass self.stopProducing() + + +class FakeRedisPubSubServer: + """A fake Redis server for pub/sub. + """ + + def __init__(self): + self._subscribers = set() + + def add_subscriber(self, conn): + """A connection has called SUBSCRIBE + """ + self._subscribers.add(conn) + + def remove_subscriber(self, conn): + """A connection has called UNSUBSCRIBE + """ + self._subscribers.discard(conn) + + def publish(self, conn, channel, msg) -> int: + """A connection want to publish a message to subscribers. + """ + for sub in self._subscribers: + sub.send(["message", channel, msg]) + + return len(self._subscribers) + + def buildProtocol(self, addr): + return FakeRedisPubSubProtocol(self) + + +class FakeRedisPubSubProtocol(Protocol): + """A connection from a client talking to the fake Redis server. + """ + + def __init__(self, server: FakeRedisPubSubServer): + self._server = server + self._reader = hiredis.Reader() + + def dataReceived(self, data): + self._reader.feed(data) + + # We might get multiple messages in one packet. + while True: + msg = self._reader.gets() + + if msg is False: + # No more messages. + return + + if not isinstance(msg, list): + # Inbound commands should always be a list + raise Exception("Expected redis list") + + self.handle_command(msg[0], *msg[1:]) + + def handle_command(self, command, *args): + """Received a Redis command from the client. + """ + + # We currently only support pub/sub. + if command == b"PUBLISH": + channel, message = args + num_subscribers = self._server.publish(self, channel, message) + self.send(num_subscribers) + elif command == b"SUBSCRIBE": + (channel,) = args + self._server.add_subscriber(self) + self.send(["subscribe", channel, 1]) + else: + raise Exception("Unknown command") + + def send(self, msg): + """Send a message back to the client. + """ + raw = self.encode(msg).encode("utf-8") + + self.transport.write(raw) + self.transport.flush() + + def encode(self, obj): + """Encode an object to its Redis format. + + Supports: strings/bytes, integers and list/tuples. + """ + + if isinstance(obj, bytes): + # We assume bytes are just unicode strings. + obj = obj.decode("utf-8") + + if isinstance(obj, str): + return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) + if isinstance(obj, int): + return ":{val}\r\n".format(val=obj) + if isinstance(obj, (list, tuple)): + items = "".join(self.encode(a) for a in obj) + return "*{len}\r\n{items}".format(len=len(obj), items=items) + + raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj) + + def connectionLost(self, reason): + self._server.remove_subscriber(self) diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py new file mode 100644 index 0000000000..6068d14905 --- /dev/null +++ b/tests/replication/test_sharded_event_persister.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.rest import admin +from synapse.rest.client.v1 import login, room + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.utils import USE_POSTGRES_FOR_TESTS + +logger = logging.getLogger(__name__) + + +class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks event persisting sharding works + """ + + # Event persister sharding requires postgres (due to needing + # `MutliWriterIdGenerator`). + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Register a user who sends a message that we'll get notified about + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def default_config(self): + conf = super().default_config() + conf["redis"] = {"enabled": "true"} + conf["stream_writers"] = {"events": ["worker1", "worker2"]} + conf["instance_map"] = { + "worker1": {"host": "testserv", "port": 1001}, + "worker2": {"host": "testserv", "port": 1002}, + } + return conf + + def test_basic(self): + """Simple test to ensure that multiple rooms can be created and joined, + and that different rooms get handled by different instances. + """ + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker1"}, + ) + + self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "worker2"}, + ) + + persisted_on_1 = False + persisted_on_2 = False + + store = self.hs.get_datastore() + + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Keep making new rooms until we see rooms being persisted on both + # workers. + for _ in range(10): + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # The other user joins + self.helper.join( + room=room, user=self.other_user_id, tok=self.other_access_token + ) + + # The other user sends some messages + rseponse = self.helper.send(room, body="Hi!", tok=self.other_access_token) + event_id = rseponse["event_id"] + + # The event position includes which instance persisted the event. + pos = self.get_success(store.get_position_for_event(event_id)) + + persisted_on_1 |= pos.instance_name == "worker1" + persisted_on_2 |= pos.instance_name == "worker2" + + if persisted_on_1 and persisted_on_2: + break + + self.assertTrue(persisted_on_1) + self.assertTrue(persisted_on_2) diff --git a/tests/unittest.py b/tests/unittest.py index e654c0442d..82ede9de34 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase): # create a site to wrap the resource. self.site = SynapseSite( logger_name="synapse.access.http.fake", - site_tag="test", + site_tag=self.hs.config.server.server_name, config=self.hs.config.server.listeners[0], resource=self.resource, server_version_string="1", -- cgit 1.5.1 From c9c0ad5e204f309f2686dbe250382e481e0f82c2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 9 Oct 2020 07:24:34 -0400 Subject: Remove the deprecated Handlers object (#8494) All handlers now available via get_*_handler() methods on the HomeServer. --- changelog.d/8494.misc | 1 + synapse/app/admin_cmd.py | 2 +- synapse/federation/federation_server.py | 7 ++++- synapse/handlers/__init__.py | 33 ----------------------- synapse/handlers/auth.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/message.py | 4 +-- synapse/handlers/pagination.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 6 ++--- synapse/handlers/ui_auth/checkers.py | 2 +- synapse/replication/http/federation.py | 2 +- synapse/replication/http/membership.py | 2 +- synapse/rest/admin/rooms.py | 4 +-- synapse/rest/admin/users.py | 13 +++++---- synapse/rest/client/v1/directory.py | 25 +++++++++-------- synapse/rest/client/v1/login.py | 1 - synapse/rest/client/v1/room.py | 10 +++---- synapse/rest/client/v2_alpha/account.py | 16 +++++------ synapse/rest/client/v2_alpha/register.py | 6 ++--- synapse/server.py | 30 +++++++++++++++++---- tests/api/test_auth.py | 12 +++------ tests/api/test_filtering.py | 5 +--- tests/crypto/test_keyring.py | 6 ++--- tests/handlers/test_admin.py | 2 +- tests/handlers/test_auth.py | 11 ++------ tests/handlers/test_directory.py | 10 +++---- tests/handlers/test_e2e_keys.py | 2 +- tests/handlers/test_e2e_room_keys.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_presence.py | 2 +- tests/handlers/test_profile.py | 7 ----- tests/handlers/test_register.py | 22 ++++++--------- tests/replication/test_federation_sender_shard.py | 2 +- tests/rest/client/test_shadow_banned.py | 2 +- tests/rest/client/v1/test_events.py | 2 +- tests/rest/client/v1/test_rooms.py | 6 ++++- tests/rest/client/v1/test_typing.py | 2 +- tests/test_federation.py | 2 +- 40 files changed, 116 insertions(+), 157 deletions(-) create mode 100644 changelog.d/8494.misc (limited to 'synapse/replication') diff --git a/changelog.d/8494.misc b/changelog.d/8494.misc new file mode 100644 index 0000000000..6e56c6b854 --- /dev/null +++ b/changelog.d/8494.misc @@ -0,0 +1 @@ +Remove the deprecated `Handlers` object. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index f0d65d08d7..b4bd4d8e7a 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -89,7 +89,7 @@ async def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = await hs.get_handlers().admin_handler.export_user_data( + res = await hs.get_admin_handler().export_user_data( user_id, FileExfiltrationWriter(user_id, directory=directory) ) print(res) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 1c7ea886c9..e8039e244c 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -99,10 +99,15 @@ class FederationServer(FederationBase): super().__init__(hs) self.auth = hs.get_auth() - self.handler = hs.get_handlers().federation_handler + self.handler = hs.get_federation_handler() self.state = hs.get_state_handler() self.device_handler = hs.get_device_handler() + + # Ensure the following handlers are loaded since they register callbacks + # with FederationHandlerRegistry. + hs.get_directory_handler() + self._federation_ratelimiter = hs.get_federation_ratelimiter() self._server_linearizer = Linearizer("fed_server") diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 286f0054be..bfebb0f644 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -12,36 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from .admin import AdminHandler -from .directory import DirectoryHandler -from .federation import FederationHandler -from .identity import IdentityHandler -from .search import SearchHandler - - -class Handlers: - - """ Deprecated. A collection of handlers. - - At some point most of the classes whose name ended "Handler" were - accessed through this class. - - However this makes it painful to unit test the handlers and to run cut - down versions of synapse that only use specific handlers because using a - single handler required creating all of the handlers. So some of the - handlers have been lifted out of the Handlers object and are now accessed - directly through the homeserver object itself. - - Any new handlers should follow the new pattern of being accessed through - the homeserver object and should not be added to the Handlers object. - - The remaining handlers should be moved out of the handlers object. - """ - - def __init__(self, hs): - self.federation_handler = FederationHandler(hs) - self.directory_handler = DirectoryHandler(hs) - self.admin_handler = AdminHandler(hs) - self.identity_handler = IdentityHandler(hs) - self.search_handler = SearchHandler(hs) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f6d17c53b1..1d1ddc2245 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1080,7 +1080,7 @@ class AuthHandler(BaseHandler): if medium == "email": address = canonicalise_email(address) - identity_handler = self.hs.get_handlers().identity_handler + identity_handler = self.hs.get_identity_handler() result = await identity_handler.try_unbind_threepid( user_id, {"medium": medium, "address": address, "id_server": id_server} ) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 72a5831531..58c9f12686 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -37,7 +37,7 @@ class DeactivateAccountHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() - self._identity_handler = hs.get_handlers().identity_handler + self._identity_handler = hs.get_identity_handler() self.user_directory_handler = hs.get_user_directory_handler() # Flag that indicates whether the process to part users from rooms is running diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3e9a22e8f3..33d133a4b2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1014,7 +1014,7 @@ class EventCreationHandler: # Check the alias is currently valid (if it has changed). room_alias_str = event.content.get("alias", None) - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() if room_alias_str and room_alias_str != original_alias: await self._validate_canonical_alias( directory_handler, room_alias_str, event.room_id @@ -1040,7 +1040,7 @@ class EventCreationHandler: directory_handler, alias_str, event.room_id ) - federation_handler = self.hs.get_handlers().federation_handler + federation_handler = self.hs.get_federation_handler() if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 2c2a633938..085b685959 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -383,7 +383,7 @@ class PaginationHandler: "room_key", leave_token ) - await self.hs.get_handlers().federation_handler.maybe_backfill( + await self.hs.get_federation_handler().maybe_backfill( room_id, curr_topo, limit=pagin_config.limit, ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 538f4b2a61..a6f1d21674 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -48,7 +48,7 @@ class RegistrationHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() - self.identity_handler = self.hs.get_handlers().identity_handler + self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d0530a446c..1d04d41e98 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -691,7 +691,7 @@ class RoomCreationHandler(BaseHandler): if not allowed_by_third_party_rules: raise SynapseError(403, "Room visibility value not allowed.") - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() if room_alias: await directory_handler.create_association( requester=requester, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fd8114a64d..ffbc62ff44 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -64,9 +64,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.state_handler = hs.get_state_handler() self.config = hs.config - self.federation_handler = hs.get_handlers().federation_handler - self.directory_handler = hs.get_handlers().directory_handler - self.identity_handler = hs.get_handlers().identity_handler + self.federation_handler = hs.get_federation_handler() + self.directory_handler = hs.get_directory_handler() + self.identity_handler = hs.get_identity_handler() self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 9146dc1a3b..3d66bf305e 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -143,7 +143,7 @@ class _BaseThreepidAuthChecker: threepid_creds = authdict["threepid_creds"] - identity_handler = self.hs.get_handlers().identity_handler + identity_handler = self.hs.get_identity_handler() logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,)) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 5393b9a9e7..b4f4a68b5c 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() self.storage = hs.get_storage() self.clock = hs.get_clock() - self.federation_handler = hs.get_handlers().federation_handler + self.federation_handler = hs.get_federation_handler() @staticmethod async def _serialize_payload(store, room_id, event_and_contexts, backfilled): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 30680baee8..e7cc74a5d2 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -47,7 +47,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): def __init__(self, hs): super().__init__(hs) - self.federation_handler = hs.get_handlers().federation_handler + self.federation_handler = hs.get_federation_handler() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 09726d52d6..f5304ff43d 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -138,7 +138,7 @@ class ListRoomRestServlet(RestServlet): def __init__(self, hs): self.store = hs.get_datastore() self.auth = hs.get_auth() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) @@ -273,7 +273,7 @@ class JoinRoomAliasServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.room_member_handler = hs.get_room_member_handler() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() self.state_handler = hs.get_state_handler() async def on_POST(self, request, room_identifier): diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 20dc1d0e05..8efefbc0a0 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -45,7 +45,7 @@ class UsersRestServlet(RestServlet): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() async def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -82,7 +82,7 @@ class UsersRestServletV2(RestServlet): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() async def on_GET(self, request): await assert_requester_is_admin(self.auth, request) @@ -135,7 +135,7 @@ class UserRestServletV2(RestServlet): def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() self.store = hs.get_datastore() self.auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() @@ -448,7 +448,7 @@ class WhoisRestServlet(RestServlet): def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() - self.handlers = hs.get_handlers() + self.admin_handler = hs.get_admin_handler() async def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) @@ -461,7 +461,7 @@ class WhoisRestServlet(RestServlet): if not self.hs.is_mine(target_user): raise SynapseError(400, "Can only whois a local user") - ret = await self.handlers.admin_handler.get_whois(target_user) + ret = await self.admin_handler.get_whois(target_user) return 200, ret @@ -591,7 +591,6 @@ class SearchUsersRestServlet(RestServlet): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() - self.handlers = hs.get_handlers() async def on_GET(self, request, target_user_id): """Get request to search user table for specific users according to @@ -612,7 +611,7 @@ class SearchUsersRestServlet(RestServlet): term = parse_string(request, "term", required=True) logger.info("term: %s ", term) - ret = await self.handlers.store.search_users(term) + ret = await self.store.search_users(term) return 200, ret diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index faabeeb91c..e5af26b176 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -42,14 +42,13 @@ class ClientDirectoryServer(RestServlet): def __init__(self, hs): super().__init__() self.store = hs.get_datastore() - self.handlers = hs.get_handlers() + self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() async def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) - dir_handler = self.handlers.directory_handler - res = await dir_handler.get_association(room_alias) + res = await self.directory_handler.get_association(room_alias) return 200, res @@ -79,19 +78,19 @@ class ClientDirectoryServer(RestServlet): requester = await self.auth.get_user_by_req(request) - await self.handlers.directory_handler.create_association( + await self.directory_handler.create_association( requester, room_alias, room_id, servers ) return 200, {} async def on_DELETE(self, request, room_alias): - dir_handler = self.handlers.directory_handler - try: service = self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - await dir_handler.delete_appservice_association(service, room_alias) + await self.directory_handler.delete_appservice_association( + service, room_alias + ) logger.info( "Application service at %s deleted alias %s", service.url, @@ -107,7 +106,7 @@ class ClientDirectoryServer(RestServlet): room_alias = RoomAlias.from_string(room_alias) - await dir_handler.delete_association(requester, room_alias) + await self.directory_handler.delete_association(requester, room_alias) logger.info( "User %s deleted alias %s", user.to_string(), room_alias.to_string() @@ -122,7 +121,7 @@ class ClientDirectoryListServer(RestServlet): def __init__(self, hs): super().__init__() self.store = hs.get_datastore() - self.handlers = hs.get_handlers() + self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() async def on_GET(self, request, room_id): @@ -138,7 +137,7 @@ class ClientDirectoryListServer(RestServlet): content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - await self.handlers.directory_handler.edit_published_room_list( + await self.directory_handler.edit_published_room_list( requester, room_id, visibility ) @@ -147,7 +146,7 @@ class ClientDirectoryListServer(RestServlet): async def on_DELETE(self, request, room_id): requester = await self.auth.get_user_by_req(request) - await self.handlers.directory_handler.edit_published_room_list( + await self.directory_handler.edit_published_room_list( requester, room_id, "private" ) @@ -162,7 +161,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): def __init__(self, hs): super().__init__() self.store = hs.get_datastore() - self.handlers = hs.get_handlers() + self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() def on_PUT(self, request, network_id, room_id): @@ -180,7 +179,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): 403, "Only appservices can edit the appservice published room list" ) - await self.handlers.directory_handler.edit_published_appservice_room_list( + await self.directory_handler.edit_published_appservice_room_list( requester.app_service.id, network_id, room_id, visibility ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3d1693d7ac..d7deb9300d 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -67,7 +67,6 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - self.handlers = hs.get_handlers() self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index b63389e5fe..00b4397082 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -112,7 +112,6 @@ class RoomCreateRestServlet(TransactionRestServlet): class RoomStateEventRestServlet(TransactionRestServlet): def __init__(self, hs): super().__init__(hs) - self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() @@ -798,7 +797,6 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs): super().__init__(hs) - self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -903,7 +901,7 @@ class RoomAliasListServlet(RestServlet): def __init__(self, hs: "synapse.server.HomeServer"): super().__init__() self.auth = hs.get_auth() - self.directory_handler = hs.get_handlers().directory_handler + self.directory_handler = hs.get_directory_handler() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request) @@ -920,7 +918,7 @@ class SearchRestServlet(RestServlet): def __init__(self, hs): super().__init__() - self.handlers = hs.get_handlers() + self.search_handler = hs.get_search_handler() self.auth = hs.get_auth() async def on_POST(self, request): @@ -929,9 +927,7 @@ class SearchRestServlet(RestServlet): content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = await self.handlers.search_handler.search( - requester.user, content, batch - ) + results = await self.search_handler.search(requester.user, content, batch) return 200, results diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ab5815e7f7..e857cff176 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -56,7 +56,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.hs = hs self.datastore = hs.get_datastore() self.config = hs.config - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( @@ -327,7 +327,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): super().__init__() self.hs = hs self.config = hs.config - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.store = self.hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: @@ -424,7 +424,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.hs = hs super().__init__() self.store = self.hs.get_datastore() - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() async def on_POST(self, request): body = parse_json_object_from_request(request) @@ -574,7 +574,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() async def on_POST(self, request): if not self.config.account_threepid_delegate_msisdn: @@ -604,7 +604,7 @@ class ThreepidRestServlet(RestServlet): def __init__(self, hs): super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() @@ -660,7 +660,7 @@ class ThreepidAddRestServlet(RestServlet): def __init__(self, hs): super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -711,7 +711,7 @@ class ThreepidBindRestServlet(RestServlet): def __init__(self, hs): super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() async def on_POST(self, request): @@ -740,7 +740,7 @@ class ThreepidUnbindRestServlet(RestServlet): def __init__(self, hs): super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index ffa2dfce42..395b6a82a9 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -78,7 +78,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): """ super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.config = hs.config if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: @@ -176,7 +176,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): """ super().__init__() self.hs = hs - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() async def on_POST(self, request): body = parse_json_object_from_request(request) @@ -370,7 +370,7 @@ class RegisterRestServlet(RestServlet): self.store = hs.get_datastore() self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - self.identity_handler = hs.get_handlers().identity_handler + self.identity_handler = hs.get_identity_handler() self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() self.ratelimiter = hs.get_registration_ratelimiter() diff --git a/synapse/server.py b/synapse/server.py index f83dd6148c..e793793cdc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -54,19 +54,22 @@ from synapse.federation.sender import FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler -from synapse.handlers import Handlers from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler +from synapse.handlers.admin import AdminHandler from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.auth import AuthHandler, MacaroonGenerator from synapse.handlers.cas_handler import CasHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler from synapse.handlers.devicemessage import DeviceMessageHandler +from synapse.handlers.directory import DirectoryHandler from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.events import EventHandler, EventStreamHandler +from synapse.handlers.federation import FederationHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler +from synapse.handlers.identity import IdentityHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler from synapse.handlers.pagination import PaginationHandler @@ -84,6 +87,7 @@ from synapse.handlers.room import ( from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler +from synapse.handlers.search import SearchHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler @@ -318,10 +322,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_federation_server(self) -> FederationServer: return FederationServer(self) - @cache_in_self - def get_handlers(self) -> Handlers: - return Handlers(self) - @cache_in_self def get_notifier(self) -> Notifier: return Notifier(self) @@ -408,6 +408,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_device_message_handler(self) -> DeviceMessageHandler: return DeviceMessageHandler(self) + @cache_in_self + def get_directory_handler(self) -> DirectoryHandler: + return DirectoryHandler(self) + @cache_in_self def get_e2e_keys_handler(self) -> E2eKeysHandler: return E2eKeysHandler(self) @@ -420,6 +424,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_acme_handler(self) -> AcmeHandler: return AcmeHandler(self) + @cache_in_self + def get_admin_handler(self) -> AdminHandler: + return AdminHandler(self) + @cache_in_self def get_application_service_api(self) -> ApplicationServiceApi: return ApplicationServiceApi(self) @@ -440,6 +448,14 @@ class HomeServer(metaclass=abc.ABCMeta): def get_event_stream_handler(self) -> EventStreamHandler: return EventStreamHandler(self) + @cache_in_self + def get_federation_handler(self) -> FederationHandler: + return FederationHandler(self) + + @cache_in_self + def get_identity_handler(self) -> IdentityHandler: + return IdentityHandler(self) + @cache_in_self def get_initial_sync_handler(self) -> InitialSyncHandler: return InitialSyncHandler(self) @@ -459,6 +475,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_deactivate_account_handler(self) -> DeactivateAccountHandler: return DeactivateAccountHandler(self) + @cache_in_self + def get_search_handler(self) -> SearchHandler: + return SearchHandler(self) + @cache_in_self def get_set_password_handler(self) -> SetPasswordHandler: return SetPasswordHandler(self) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 8ab56ec94c..cb6f29d670 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -19,7 +19,6 @@ import pymacaroons from twisted.internet import defer -import synapse.handlers.auth from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -36,20 +35,15 @@ from tests import unittest from tests.utils import mock_getRawHeaders, setup_test_homeserver -class TestHandlers: - def __init__(self, hs): - self.auth_handler = synapse.handlers.auth.AuthHandler(hs) - - class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.state_handler = Mock() self.store = Mock() - self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) + self.hs = yield setup_test_homeserver(self.addCleanup) self.hs.get_datastore = Mock(return_value=self.store) - self.hs.handlers = TestHandlers(self.hs) + self.hs.get_auth_handler().store = self.store self.auth = Auth(self.hs) # AuthBlocking reads from the hs' config on initialization. We need to @@ -283,7 +277,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_device = Mock(return_value=defer.succeed(None)) token = yield defer.ensureDeferred( - self.hs.handlers.auth_handler.get_access_token_for_user_id( + self.hs.get_auth_handler().get_access_token_for_user_id( USER_ID, "DEVICE", valid_until_ms=None ) ) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index d2d535d23c..c98ae75974 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -50,10 +50,7 @@ class FilteringTestCase(unittest.TestCase): self.mock_http_client.put_json = DeferredMockCallable() hs = yield setup_test_homeserver( - self.addCleanup, - handlers=None, - http_client=self.mock_http_client, - keyring=Mock(), + self.addCleanup, http_client=self.mock_http_client, keyring=Mock(), ) self.filtering = hs.get_filtering() diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 8ff1460c0d..697916a019 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -315,7 +315,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock() - hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) + hs = self.setup_test_homeserver(http_client=self.http_client) return hs def test_get_keys_from_server(self): @@ -395,9 +395,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): } ] - return self.setup_test_homeserver( - handlers=None, http_client=self.http_client, config=config - ) + return self.setup_test_homeserver(http_client=self.http_client, config=config) def build_perspectives_response( self, server_name: str, signing_key: SigningKey, valid_until_ts: int, diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index fc37c4328c..5c2b4de1a6 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -35,7 +35,7 @@ class ExfiltrateData(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.admin_handler = hs.get_handlers().admin_handler + self.admin_handler = hs.get_admin_handler() self.user1 = self.register_user("user1", "password") self.token1 = self.login("user1", "password") diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 97877c2e42..b5055e018c 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -21,24 +21,17 @@ from twisted.internet import defer import synapse import synapse.api.errors from synapse.api.errors import ResourceLimitError -from synapse.handlers.auth import AuthHandler from tests import unittest from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver -class AuthHandlers: - def __init__(self, hs): - self.auth_handler = AuthHandler(hs) - - class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None) - self.hs.handlers = AuthHandlers(self.hs) - self.auth_handler = self.hs.handlers.auth_handler + self.hs = yield setup_test_homeserver(self.addCleanup) + self.auth_handler = self.hs.get_auth_handler() self.macaroon_generator = self.hs.get_macaroon_generator() # MAU tests diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index bc0c5aefdc..2ce6dc9528 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -48,7 +48,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): federation_registry=self.mock_registry, ) - self.handler = hs.get_handlers().directory_handler + self.handler = hs.get_directory_handler() self.store = hs.get_datastore() @@ -110,7 +110,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - self.handler = hs.get_handlers().directory_handler + self.handler = hs.get_directory_handler() # Create user self.admin_user = self.register_user("admin", "pass", admin=True) @@ -173,7 +173,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.handler = hs.get_handlers().directory_handler + self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() # Create user @@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() - self.handler = hs.get_handlers().directory_handler + self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() # Create user @@ -442,7 +442,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.result) self.room_list_handler = hs.get_room_list_handler() - self.directory_handler = hs.get_handlers().directory_handler + self.directory_handler = hs.get_directory_handler() return hs diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index e79d612f7a..924f29f051 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -38,7 +38,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - self.addCleanup, handlers=None, federation_client=mock.Mock() + self.addCleanup, federation_client=mock.Mock() ) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) self.store = self.hs.get_datastore() diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 7adde9b9de..45f201a399 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -54,7 +54,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield utils.setup_test_homeserver( - self.addCleanup, handlers=None, replication_layer=mock.Mock() + self.addCleanup, replication_layer=mock.Mock() ) self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs) self.local_user = "@boris:" + self.hs.hostname diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 96fea58673..9ef80fe502 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -38,7 +38,7 @@ class FederationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(http_client=None) - self.handler = hs.get_handlers().federation_handler + self.handler = hs.get_federation_handler() self.store = hs.get_datastore() return hs diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 306dcfe944..914c82e7a8 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -470,7 +470,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.federation_sender = hs.get_federation_sender() self.event_builder_factory = hs.get_event_builder_factory() - self.federation_handler = hs.get_handlers().federation_handler + self.federation_handler = hs.get_federation_handler() self.presence_handler = hs.get_presence_handler() # self.event_builder_for_2 = EventBuilderFactory(hs) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 8e95e53d9e..a69fa28b41 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -20,7 +20,6 @@ from twisted.internet import defer import synapse.types from synapse.api.errors import AuthError, SynapseError -from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest @@ -28,11 +27,6 @@ from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver -class ProfileHandlers: - def __init__(self, hs): - self.profile_handler = MasterProfileHandler(hs) - - class ProfileTestCase(unittest.TestCase): """ Tests profile management. """ @@ -51,7 +45,6 @@ class ProfileTestCase(unittest.TestCase): hs = yield setup_test_homeserver( self.addCleanup, http_client=None, - handlers=None, resource_for_federation=Mock(), federation_client=self.mock_federation, federation_server=Mock(), diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 702c6aa089..bdf3d0a8a2 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -18,7 +18,6 @@ from mock import Mock from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError -from synapse.handlers.register import RegistrationHandler from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester @@ -29,11 +28,6 @@ from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers: - def __init__(self, hs): - self.registration_handler = RegistrationHandler(hs) - - class RegistrationTestCase(unittest.HomeserverTestCase): """ Tests the RegistrationHandler. """ @@ -154,7 +148,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): room_alias_str = "#room:test" user_id = self.get_success(self.handler.register_user(localpart="jeff")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -193,7 +187,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) self.get_failure(directory_handler.get_association(room_alias), SynapseError) @@ -205,7 +199,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.store.is_real_user = Mock(return_value=make_awaitable(True)) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -237,7 +231,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="jeff")) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -266,7 +260,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="jeff")) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -304,7 +298,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="jeff")) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -347,7 +341,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) @@ -384,7 +378,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): user_id = self.get_success(self.handler.register_user(localpart="jeff")) # Ensure the room was created. - directory_handler = self.hs.get_handlers().directory_handler + directory_handler = self.hs.get_directory_handler() room_alias = RoomAlias.from_string(room_alias_str) room_id = self.get_success(directory_handler.get_association(room_alias)) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 1d7edee5ba..9c4a9c3563 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -207,7 +207,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def create_room_with_remote_server(self, user, token, remote_server="other_server"): room = self.helper.create_room_as(user, tok=token) store = self.hs.get_datastore() - federation = self.hs.get_handlers().federation_handler + federation = self.hs.get_federation_handler() prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) room_version = self.get_success(store.get_room_version(room)) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index dfe4bf7762..6bb02b9630 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -78,7 +78,7 @@ class RoomTestCase(_ShadowBannedBase): def test_invite_3pid(self): """Ensure that a 3PID invite does not attempt to contact the identity server.""" - identity_handler = self.hs.get_handlers().identity_handler + identity_handler = self.hs.get_identity_handler() identity_handler.lookup_3pid = Mock( side_effect=AssertionError("This should not get called") ) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index f75520877f..3397ba5579 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -42,7 +42,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) - hs.get_handlers().federation_handler = Mock() + hs.get_federation_handler = Mock() return hs diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 0d809d25d5..9ba5f9d943 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -32,6 +32,7 @@ from synapse.types import JsonDict, RoomAlias, UserID from synapse.util.stringutils import random_string from tests import unittest +from tests.test_utils import make_awaitable PATH_PREFIX = b"/_matrix/client/api/v1" @@ -47,7 +48,10 @@ class RoomBase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), ) - self.hs.get_federation_handler = Mock(return_value=Mock()) + self.hs.get_federation_handler = Mock() + self.hs.get_federation_handler.return_value.maybe_backfill = Mock( + return_value=make_awaitable(None) + ) async def _insert_client_ip(*args, **kwargs): return None diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 94d2bf2eb1..cd58ee7792 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -44,7 +44,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.event_source = hs.get_event_sources().sources["typing"] - hs.get_handlers().federation_handler = Mock() + hs.get_federation_handler = Mock() async def get_user_by_access_token(token=None, allow_guest=False): return { diff --git a/tests/test_federation.py b/tests/test_federation.py index 27a7fc9ed7..d39e792580 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,7 +75,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - self.handler = self.homeserver.get_handlers().federation_handler + self.handler = self.homeserver.get_federation_handler() self.handler.do_auth = lambda origin, event, context, auth_events: succeed( context ) -- cgit 1.5.1 From 5009ffcaa45fc3522edc04de2f2b98dc7fe5c59c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 9 Oct 2020 13:10:33 +0100 Subject: Only send RDATA for instance local events. (#8496) When pulling events out of the DB to send over replication we were not filtering by instance name, and so we were sending events for other instances. --- changelog.d/8496.misc | 1 + synapse/replication/tcp/streams/_base.py | 11 ++++++-- synapse/replication/tcp/streams/events.py | 6 ++-- synapse/storage/databases/main/events.py | 12 ++++---- synapse/storage/databases/main/events_worker.py | 32 ++++++++++++---------- .../delta/58/20instance_name_event_tables.sql | 17 ++++++++++++ 6 files changed, 54 insertions(+), 25 deletions(-) create mode 100644 changelog.d/8496.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql (limited to 'synapse/replication') diff --git a/changelog.d/8496.misc b/changelog.d/8496.misc new file mode 100644 index 0000000000..237cb3b311 --- /dev/null +++ b/changelog.d/8496.misc @@ -0,0 +1 @@ +Allow events to be sent to clients sooner when using sharded event persisters. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 54dccd15a6..61b282ab2d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -240,13 +240,18 @@ class BackfillStream(Stream): ROW_TYPE = BackfillStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - current_token_without_instance(store.get_current_backfill_token), - store.get_all_new_backfill_event_rows, + self._current_token, + self.store.get_all_new_backfill_event_rows, ) + def _current_token(self, instance_name: str) -> int: + # The backfill stream over replication operates on *positive* numbers, + # which means we need to negate it. + return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) + class PresenceStream(Stream): PresenceStreamRow = namedtuple( diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index ccc7ca30d8..82e9e0d64e 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -155,7 +155,7 @@ class EventsStream(Stream): # now we fetch up to that many rows from the events table event_rows = await self._store.get_all_new_forward_event_rows( - from_token, current_token, target_row_count + instance_name, from_token, current_token, target_row_count ) # type: List[Tuple] # we rely on get_all_new_forward_event_rows strictly honouring the limit, so @@ -180,7 +180,7 @@ class EventsStream(Stream): upper_limit, state_rows_limited, ) = await self._store.get_all_updated_current_state_deltas( - from_token, upper_limit, target_row_count + instance_name, from_token, upper_limit, target_row_count ) limited = limited or state_rows_limited @@ -189,7 +189,7 @@ class EventsStream(Stream): # not to bother with the limit. ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( - from_token, upper_limit + instance_name, from_token, upper_limit ) # type: List[Tuple] # we now need to turn the raw database rows returned into tuples suitable diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b4abd961b9..b19c424ba9 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -426,12 +426,12 @@ class PersistEventsStore: # so that async background tasks get told what happened. sql = """ INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, room_id, type, state_key, null, event_id + (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, room_id, type, state_key, null, event_id FROM current_state_events WHERE room_id = ? """ - txn.execute(sql, (stream_id, room_id)) + txn.execute(sql, (stream_id, self._instance_name, room_id)) self.db_pool.simple_delete_txn( txn, table="current_state_events", keyvalues={"room_id": room_id}, @@ -452,8 +452,8 @@ class PersistEventsStore: # sql = """ INSERT INTO current_state_delta_stream - (stream_id, room_id, type, state_key, event_id, prev_event_id) - SELECT ?, ?, ?, ?, ?, ( + (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id) + SELECT ?, ?, ?, ?, ?, ?, ( SELECT event_id FROM current_state_events WHERE room_id = ? AND type = ? AND state_key = ? ) @@ -463,6 +463,7 @@ class PersistEventsStore: ( ( stream_id, + self._instance_name, room_id, etype, state_key, @@ -755,6 +756,7 @@ class PersistEventsStore: "event_stream_ordering": stream_order, "event_id": event.event_id, "state_group": state_group_id, + "instance_name": self._instance_name, }, ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index b7ed8ca6ab..4e74fafe43 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1034,16 +1034,12 @@ class EventsWorkerStore(SQLBaseStore): return {"v1": complexity_v1} - def get_current_backfill_token(self): - """The current minimum token that backfilled events have reached""" - return -self._backfill_id_gen.get_current_token() - def get_current_events_token(self): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() async def get_all_new_forward_event_rows( - self, last_id: int, current_id: int, limit: int + self, instance_name: str, last_id: int, current_id: int, limit: int ) -> List[Tuple]: """Returns new events, for the Events replication stream @@ -1067,10 +1063,11 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" + " AND instance_name = ?" " ORDER BY stream_ordering ASC" " LIMIT ?" ) - txn.execute(sql, (last_id, current_id, limit)) + txn.execute(sql, (last_id, current_id, instance_name, limit)) return txn.fetchall() return await self.db_pool.runInteraction( @@ -1078,7 +1075,7 @@ class EventsWorkerStore(SQLBaseStore): ) async def get_ex_outlier_stream_rows( - self, last_id: int, current_id: int + self, instance_name: str, last_id: int, current_id: int ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream @@ -1097,16 +1094,17 @@ class EventsWorkerStore(SQLBaseStore): "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id" " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" + " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" + " AND out.instance_name = ?" " ORDER BY event_stream_ordering ASC" ) - txn.execute(sql, (last_id, current_id)) + txn.execute(sql, (last_id, current_id, instance_name)) return txn.fetchall() return await self.db_pool.runInteraction( @@ -1119,6 +1117,9 @@ class EventsWorkerStore(SQLBaseStore): """Get updates for backfill replication stream, including all new backfilled events and events that have gone from being outliers to not. + NOTE: The IDs given here are from replication, and so should be + *positive*. + Args: instance_name: The writer we want to fetch updates from. Unused here since there is only ever one writer. @@ -1149,10 +1150,11 @@ class EventsWorkerStore(SQLBaseStore): " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > stream_ordering AND stream_ordering >= ?" + " AND instance_name = ?" " ORDER BY stream_ordering ASC" " LIMIT ?" ) - txn.execute(sql, (-last_id, -current_id, limit)) + txn.execute(sql, (-last_id, -current_id, instance_name, limit)) new_event_updates = [(row[0], row[1:]) for row in txn] limited = False @@ -1166,15 +1168,16 @@ class EventsWorkerStore(SQLBaseStore): "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," " state_key, redacts, relates_to_id" " FROM events AS e" - " INNER JOIN ex_outlier_stream USING (event_id)" + " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" " WHERE ? > event_stream_ordering" " AND event_stream_ordering >= ?" + " AND out.instance_name = ?" " ORDER BY event_stream_ordering DESC" ) - txn.execute(sql, (-last_id, -upper_bound)) + txn.execute(sql, (-last_id, -upper_bound, instance_name)) new_event_updates.extend((row[0], row[1:]) for row in txn) if len(new_event_updates) >= limit: @@ -1188,7 +1191,7 @@ class EventsWorkerStore(SQLBaseStore): ) async def get_all_updated_current_state_deltas( - self, from_token: int, to_token: int, target_row_count: int + self, instance_name: str, from_token: int, to_token: int, target_row_count: int ) -> Tuple[List[Tuple], int, bool]: """Fetch updates from current_state_delta_stream @@ -1214,9 +1217,10 @@ class EventsWorkerStore(SQLBaseStore): SELECT stream_id, room_id, type, state_key, event_id FROM current_state_delta_stream WHERE ? < stream_id AND stream_id <= ? + AND instance_name = ? ORDER BY stream_id ASC LIMIT ? """ - txn.execute(sql, (from_token, to_token, target_row_count)) + txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) return txn.fetchall() def get_deltas_for_stream_id_txn(txn, stream_id): diff --git a/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql new file mode 100644 index 0000000000..ad1f481428 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT; +ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT; -- cgit 1.5.1 From 1781bbe319ce24e8e468f0422519dc5823d8d420 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 9 Oct 2020 11:35:11 -0400 Subject: Add type hints to response cache. (#8507) --- changelog.d/8507.misc | 1 + mypy.ini | 1 + synapse/appservice/api.py | 4 +-- synapse/federation/federation_server.py | 8 ++++-- synapse/handlers/initial_sync.py | 10 ++++--- synapse/handlers/room.py | 2 +- synapse/handlers/sync.py | 4 ++- synapse/replication/http/_base.py | 2 +- synapse/util/caches/response_cache.py | 50 ++++++++++++++++++--------------- 9 files changed, 48 insertions(+), 34 deletions(-) create mode 100644 changelog.d/8507.misc (limited to 'synapse/replication') diff --git a/changelog.d/8507.misc b/changelog.d/8507.misc new file mode 100644 index 0000000000..724da8a996 --- /dev/null +++ b/changelog.d/8507.misc @@ -0,0 +1 @@ + Add type hints to various parts of the code base. diff --git a/mypy.ini b/mypy.ini index 19b60f7534..f08fe992a4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -65,6 +65,7 @@ files = synapse/types.py, synapse/util/async_helpers.py, synapse/util/caches/descriptors.py, + synapse/util/caches/response_cache.py, synapse/util/caches/stream_change_cache.py, synapse/util/metrics.py, tests/replication, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index c526c28b93..e8f0793795 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Tuple from prometheus_client import Counter @@ -93,7 +93,7 @@ class ApplicationServiceApi(SimpleHttpClient): self.protocol_meta_cache = ResponseCache( hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS - ) + ) # type: ResponseCache[Tuple[str, str]] async def query_user(self, service, user_id): if service.url is None: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e8039e244c..23278e36b7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -116,7 +116,7 @@ class FederationServer(FederationBase): # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( hs, "fed_txn_handler", timeout_ms=30000 - ) + ) # type: ResponseCache[Tuple[str, str]] self.transaction_actions = TransactionActions(self.store) @@ -124,10 +124,12 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. - self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) + self._state_resp_cache = ResponseCache( + hs, "state_resp", timeout_ms=30000 + ) # type: ResponseCache[Tuple[str, str]] self._state_ids_resp_cache = ResponseCache( hs, "state_ids_resp", timeout_ms=30000 - ) + ) # type: ResponseCache[Tuple[str, str]] self._federation_metrics_domains = ( hs.get_config().federation.federation_metrics_domains diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 39a85801c1..98075f48d2 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from twisted.internet import defer @@ -47,12 +47,14 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") + self.snapshot_cache = ResponseCache( + hs, "initial_sync_cache" + ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state - def snapshot_all_rooms( + async def snapshot_all_rooms( self, user_id: str, pagin_config: PaginationConfig, @@ -84,7 +86,7 @@ class InitialSyncHandler(BaseHandler): include_archived, ) - return self.snapshot_cache.wrap( + return await self.snapshot_cache.wrap( key, self._snapshot_all_rooms, user_id, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 1d04d41e98..93ed51063a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -120,7 +120,7 @@ class RoomCreationHandler(BaseHandler): # subsequent requests self._upgrade_response_cache = ResponseCache( hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS - ) + ) # type: ResponseCache[Tuple[str, str]] self._server_notices_mxid = hs.config.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6fb8332f93..a306631094 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -243,7 +243,9 @@ class SyncHandler: self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache = ResponseCache(hs, "sync") + self.response_cache = ResponseCache( + hs, "sync" + ) # type: ResponseCache[Tuple[Any, ...]] self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 64edadb624..2b3972cb14 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -92,7 +92,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if self.CACHE: self.response_cache = ResponseCache( hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 - ) + ) # type: ResponseCache[str] # We reserve `instance_name` as a parameter to sending requests, so we # assert here that sub classes don't try and use the name. diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index df1a721add..32228f42ee 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar from twisted.internet import defer @@ -20,10 +21,15 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) +T = TypeVar("T") + -class ResponseCache: +class ResponseCache(Generic[T]): """ This caches a deferred response. Until the deferred completes it will be returned from the cache. This means that if the client retries the request @@ -31,8 +37,9 @@ class ResponseCache: used rather than trying to compute a new response. """ - def __init__(self, hs, name, timeout_ms=0): - self.pending_result_cache = {} # Requests that haven't finished yet. + def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): + # Requests that haven't finished yet. + self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] self.clock = hs.get_clock() self.timeout_sec = timeout_ms / 1000.0 @@ -40,13 +47,13 @@ class ResponseCache: self._name = name self._metrics = register_cache("response_cache", name, self, resizable=False) - def size(self): + def size(self) -> int: return len(self.pending_result_cache) - def __len__(self): + def __len__(self) -> int: return self.size() - def get(self, key): + def get(self, key: T) -> Optional[defer.Deferred]: """Look up the given key. Can return either a new Deferred (which also doesn't follow the synapse @@ -58,12 +65,11 @@ class ResponseCache: from an absent cache entry. Args: - key (hashable): + key: key to get/set in the cache Returns: - twisted.internet.defer.Deferred|None|E: None if there is no entry - for this key; otherwise either a deferred result or the result - itself. + None if there is no entry for this key; otherwise a deferred which + resolves to the result. """ result = self.pending_result_cache.get(key) if result is not None: @@ -73,7 +79,7 @@ class ResponseCache: self._metrics.inc_misses() return None - def set(self, key, deferred): + def set(self, key: T, deferred: defer.Deferred) -> defer.Deferred: """Set the entry for the given key to the given deferred. *deferred* should run its callbacks in the sentinel logcontext (ie, @@ -85,12 +91,11 @@ class ResponseCache: result. You will probably want to make_deferred_yieldable the result. Args: - key (hashable): - deferred (twisted.internet.defer.Deferred[T): + key: key to get/set in the cache + deferred: The deferred which resolves to the result. Returns: - twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual - result. + A new deferred which resolves to the actual result. """ result = ObservableDeferred(deferred, consumeErrors=True) self.pending_result_cache[key] = result @@ -107,7 +112,9 @@ class ResponseCache: result.addBoth(remove) return result.observe() - def wrap(self, key, callback, *args, **kwargs): + def wrap( + self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any + ) -> defer.Deferred: """Wrap together a *get* and *set* call, taking care of logcontexts First looks up the key in the cache, and if it is present makes it @@ -118,21 +125,20 @@ class ResponseCache: Example usage: - @defer.inlineCallbacks - def handle_request(request): + async def handle_request(request): # etc return result - result = yield response_cache.wrap( + result = await response_cache.wrap( key, handle_request, request, ) Args: - key (hashable): key to get/set in the cache + key: key to get/set in the cache - callback (callable): function to call if the key is not found in + callback: function to call if the key is not found in the cache *args: positional parameters to pass to the callback, if it is used @@ -140,7 +146,7 @@ class ResponseCache: **kwargs: named parameters to pass to the callback, if it is used Returns: - twisted.internet.defer.Deferred: yieldable result + Deferred which resolves to the result """ result = self.get(key) if not result: -- cgit 1.5.1 From 8de3703d214c814ad637793a0cc2220e20579ffa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 12 Oct 2020 15:51:41 +0100 Subject: Make event persisters periodically announce position over replication. (#8499) Currently background proccesses stream the events stream use the "minimum persisted position" (i.e. `get_current_token()`) rather than the vector clock style tokens. This is broadly fine as it doesn't matter if the background processes lag a small amount. However, in extreme cases (i.e. SyTests) where we only write to one event persister the background processes will never make progress. This PR changes it so that the `MultiWriterIDGenerator` keeps the current position of a given instance as up to date as possible (i.e using the latest token it sees if its not in the process of persisting anything), and then periodically announces that over replication. This then allows the "minimum persisted position" to advance, albeit with a small lag. --- changelog.d/8499.misc | 1 + docs/tcp_replication.md | 13 +++++++--- synapse/replication/tcp/client.py | 4 +++ synapse/replication/tcp/commands.py | 36 +++++++++++++++++++-------- synapse/replication/tcp/handler.py | 24 ++++++++++-------- synapse/replication/tcp/resource.py | 47 ++++++++++++++++++++++++++++++++++- synapse/storage/util/id_generators.py | 10 ++++++++ synapse/storage/util/sequence.py | 2 ++ tests/storage/test_id_generators.py | 25 ++++++++++++------- 9 files changed, 128 insertions(+), 34 deletions(-) create mode 100644 changelog.d/8499.misc (limited to 'synapse/replication') diff --git a/changelog.d/8499.misc b/changelog.d/8499.misc new file mode 100644 index 0000000000..237cb3b311 --- /dev/null +++ b/changelog.d/8499.misc @@ -0,0 +1 @@ +Allow events to be sent to clients sooner when using sharded event persisters. diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index db318baa9d..ad145439b4 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -15,7 +15,7 @@ example flow would be (where '>' indicates master to worker and > SERVER example.com < REPLICATE - > POSITION events master 53 + > POSITION events master 53 53 > RDATA events master 54 ["$foo1:bar.com", ...] > RDATA events master 55 ["$foo4:bar.com", ...] @@ -138,9 +138,9 @@ the wire: < NAME synapse.app.appservice < PING 1490197665618 < REPLICATE - > POSITION events master 1 - > POSITION backfill master 1 - > POSITION caches master 1 + > POSITION events master 1 1 + > POSITION backfill master 1 1 + > POSITION caches master 1 1 > RDATA caches master 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513] > RDATA events master 14 ["$149019767112vOHxz:localhost:8823", "!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null] @@ -185,6 +185,11 @@ client (C): updates via HTTP API, rather than via the DB, then processes should make the request to the appropriate process. + Two positions are included, the "new" position and the last position sent respectively. + This allows servers to tell instances that the positions have advanced but no + data has been written, without clients needlessly checking to see if they + have missed any updates. + #### ERROR (S, C) There was an error diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e165429cad..e27ee216f0 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -191,6 +191,10 @@ class ReplicationDataHandler: async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) + # We poke the generic "replication" notifier to wake anything up that + # may be streaming. + self.notifier.notify_replication() + def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 8cd47770c1..ac532ed588 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -141,15 +141,23 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the server to tell the client the stream position without - needing to send an RDATA. + """Sent by an instance to tell others the stream position without needing to + send an RDATA. + + Two tokens are sent, the new position and the last position sent by the + instance (in an RDATA or other POSITION). The tokens are chosen so that *no* + rows were written by the instance between the `prev_token` and `new_token`. + (If an instance hasn't sent a position before then the new position can be + used for both.) Format:: - POSITION + POSITION - On receipt of a POSITION command clients should check if they have missed - any updates, and if so then fetch them out of band. + On receipt of a POSITION command instances should check if they have missed + any updates, and if so then fetch them out of band. Instances can check this + by comparing their view of the current token for the sending instance with + the included `prev_token`. The `` is the process that sent the command and is the source of the stream. @@ -157,18 +165,26 @@ class PositionCommand(Command): NAME = "POSITION" - def __init__(self, stream_name, instance_name, token): + def __init__(self, stream_name, instance_name, prev_token, new_token): self.stream_name = stream_name self.instance_name = instance_name - self.token = token + self.prev_token = prev_token + self.new_token = new_token @classmethod def from_line(cls, line): - stream_name, instance_name, token = line.split(" ", 2) - return cls(stream_name, instance_name, int(token)) + stream_name, instance_name, prev_token, new_token = line.split(" ", 3) + return cls(stream_name, instance_name, int(prev_token), int(new_token)) def to_line(self): - return " ".join((self.stream_name, self.instance_name, str(self.token))) + return " ".join( + ( + self.stream_name, + self.instance_name, + str(self.prev_token), + str(self.new_token), + ) + ) class ErrorCommand(_SimpleCommand): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e92da7b263..95e5502bf2 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -101,8 +101,9 @@ class ReplicationCommandHandler: self._streams_to_replicate = [] # type: List[Stream] for stream in self._streams.values(): - if stream.NAME == CachesStream.NAME: - # All workers can write to the cache invalidation stream. + if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME: + # All workers can write to the cache invalidation stream when + # using redis. self._streams_to_replicate.append(stream) continue @@ -313,11 +314,14 @@ class ReplicationCommandHandler: # We respond with current position of all streams this instance # replicates. for stream in self.get_streams_to_replicate(): + # Note that we use the current token as the prev token here (rather + # than stream.last_token), as we can't be sure that there have been + # no rows written between last token and the current token (since we + # might be racing with the replication sending bg process). + current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand( - stream.NAME, - self._instance_name, - stream.current_token(self._instance_name), + stream.NAME, self._instance_name, current_token, current_token, ) ) @@ -511,16 +515,16 @@ class ReplicationCommandHandler: # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates # between then and now. - missing_updates = cmd.token != current_token + missing_updates = cmd.prev_token != current_token while missing_updates: logger.info( "Fetching replication rows for '%s' between %i and %i", stream_name, current_token, - cmd.token, + cmd.new_token, ) (updates, current_token, missing_updates) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token + cmd.instance_name, current_token, cmd.new_token ) # TODO: add some tests for this @@ -536,11 +540,11 @@ class ReplicationCommandHandler: [stream.parse_row(row) for row in rows], ) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token) # We've now caught up to position sent to us, notify handler. await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + cmd.stream_name, cmd.instance_name, cmd.new_token ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 687984e7a8..666c13fdb7 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -23,7 +23,9 @@ from prometheus_client import Counter from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.commands import PositionCommand from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol +from synapse.replication.tcp.streams import EventsStream from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -84,6 +86,23 @@ class ReplicationStreamer: # Set of streams to replicate. self.streams = self.command_handler.get_streams_to_replicate() + # If we have streams then we must have redis enabled or on master + assert ( + not self.streams + or hs.config.redis.redis_enabled + or not hs.config.worker.worker_app + ) + + # If we are replicating an event stream we want to periodically check if + # we should send updated POSITIONs. We do this as a looping call rather + # explicitly poking when the position advances (without new data to + # replicate) to reduce replication traffic (otherwise each writer would + # likely send a POSITION for each new event received over replication). + # + # Note that if the position hasn't advanced then we won't send anything. + if any(EventsStream.NAME == s.NAME for s in self.streams): + self.clock.looping_call(self.on_notifier_poke, 1000) + def on_notifier_poke(self): """Checks if there is actually any new data and sends it to the connections if there are. @@ -91,7 +110,7 @@ class ReplicationStreamer: This should get called each time new data is available, even if it is currently being executed, so that nothing gets missed """ - if not self.command_handler.connected(): + if not self.command_handler.connected() or not self.streams: # Don't bother if nothing is listening. We still need to advance # the stream tokens otherwise they'll fall behind forever for stream in self.streams: @@ -136,6 +155,8 @@ class ReplicationStreamer: self._replication_torture_level / 1000.0 ) + last_token = stream.last_token + logger.debug( "Getting stream: %s: %s -> %s", stream.NAME, @@ -159,6 +180,30 @@ class ReplicationStreamer: ) stream_updates_counter.labels(stream.NAME).inc(len(updates)) + else: + # The token has advanced but there is no data to + # send, so we send a `POSITION` to inform other + # workers of the updated position. + if stream.NAME == EventsStream.NAME: + # XXX: We only do this for the EventStream as it + # turns out that e.g. account data streams share + # their "current token" with each other, meaning + # that it is *not* safe to send a POSITION. + logger.info( + "Sending position: %s -> %s", + stream.NAME, + current_token, + ) + self.command_handler.send_command( + PositionCommand( + stream.NAME, + self._instance_name, + last_token, + current_token, + ) + ) + continue + # Some streams return multiple rows with the same stream IDs, # we need to make sure they get sent out in batches. We do # this by setting the current token to all but the last of diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index d7e40aaa8b..3d8da48f2d 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -524,6 +524,16 @@ class MultiWriterIdGenerator: heapq.heappush(self._known_persisted_positions, new_id) + # If we're a writer and we don't have any active writes we update our + # current position to the latest position seen. This allows the instance + # to report a recent position when asked, rather than a potentially old + # one (if this instance hasn't written anything for a while). + our_current_position = self._current_positions.get(self._instance_name) + if our_current_position and not self._unfinished_ids: + self._current_positions[self._instance_name] = max( + our_current_position, new_id + ) + # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index ff2d038ad2..4386b6101e 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -126,6 +126,8 @@ class PostgresSequenceGenerator(SequenceGenerator): if max_stream_id > last_value: logger.warning( "Postgres sequence %s is behind table %s: %d < %d", + self._sequence_name, + table, last_value, max_stream_id, ) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 392b08832b..cc0612cf65 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -199,10 +199,17 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): first_id_gen = self._create_id_generator("first", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"]) - self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) - self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) + # The first ID gen will notice that it can advance its token to 7 as it + # has no in progress writes... + self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) + # ... but the second ID gen doesn't know that. + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. @@ -211,7 +218,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(stream_id, 8) self.assertEqual( - first_id_gen.get_positions(), {"first": 3, "second": 7} + first_id_gen.get_positions(), {"first": 7, "second": 7} ) self.get_success(_get_next_async()) @@ -279,7 +286,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("first", 3) self._insert_row_with_id("second", 5) - id_gen = self._create_id_generator("first", writers=["first", "second"]) + id_gen = self._create_id_generator("worker", writers=["first", "second"]) self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) @@ -319,14 +326,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen = self._create_id_generator("first", writers=["first", "second"]) - self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) + self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5}) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + self.assertEqual(id_gen.get_persisted_upto_position(), 5) async def _get_next_async(): async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 6) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + self.assertEqual(id_gen.get_persisted_upto_position(), 5) self.get_success(_get_next_async()) @@ -388,7 +395,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self._insert_row_with_id("second", 5) # Initial config has two writers - id_gen = self._create_id_generator("first", writers=["first", "second"]) + id_gen = self._create_id_generator("worker", writers=["first", "second"]) self.assertEqual(id_gen.get_persisted_upto_position(), 3) self.assertEqual(id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(id_gen.get_current_token_for_writer("second"), 5) @@ -568,7 +575,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(_get_next_async2()) - self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) -- cgit 1.5.1 From b2486f6656bec2307e62de19d2830994a42b879d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 13 Oct 2020 12:07:56 +0100 Subject: Fix message duplication if something goes wrong after persisting the event (#8476) Should fix #3365. --- changelog.d/8476.bugfix | 1 + synapse/handlers/federation.py | 9 +- synapse/handlers/message.py | 48 +++++-- synapse/handlers/room_member.py | 13 +- synapse/replication/http/send_event.py | 16 ++- synapse/storage/databases/main/events.py | 31 ++++ synapse/storage/databases/main/events_worker.py | 83 ++++++++++- synapse/storage/databases/main/registration.py | 6 +- .../databases/main/schema/delta/58/19txn_id.sql | 40 ++++++ synapse/storage/persist_events.py | 96 +++++++++++-- tests/handlers/test_message.py | 157 +++++++++++++++++++++ tests/rest/client/test_third_party_rules.py | 2 +- tests/unittest.py | 11 +- 13 files changed, 481 insertions(+), 32 deletions(-) create mode 100644 changelog.d/8476.bugfix create mode 100644 synapse/storage/databases/main/schema/delta/58/19txn_id.sql create mode 100644 tests/handlers/test_message.py (limited to 'synapse/replication') diff --git a/changelog.d/8476.bugfix b/changelog.d/8476.bugfix new file mode 100644 index 0000000000..993a269979 --- /dev/null +++ b/changelog.d/8476.bugfix @@ -0,0 +1 @@ +Fix message duplication if something goes wrong after persisting the event. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5ac2fc5656..455acd7669 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2966,17 +2966,20 @@ class FederationHandler(BaseHandler): return result["max_stream_id"] else: assert self.storage.persistence - max_stream_token = await self.storage.persistence.persist_events( + + # Note that this returns the events that were persisted, which may not be + # the same as were passed in if some were deduplicated due to transaction IDs. + events, max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) if self._ephemeral_messages_enabled: - for (event, context) in event_and_contexts: + for event in events: # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) if not backfilled: # Never notify for backfilled events - for event, _ in event_and_contexts: + for event in events: await self._notify_persisted_event(event, max_stream_token) return max_stream_token.stream diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ad0b7bd868..b0da938aa9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -689,7 +689,7 @@ class EventCreationHandler: send this event. Returns: - The event, and its stream ordering (if state event deduplication happened, + The event, and its stream ordering (if deduplication happened, the previous, duplicate event). Raises: @@ -712,6 +712,19 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. with (await self.limiter.queue(event_dict["room_id"])): + if txn_id and requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + event_dict["room_id"], + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + event = await self.store.get_event(existing_event_id) + # we know it was persisted, so must have a stream ordering + assert event.internal_metadata.stream_ordering + return event, event.internal_metadata.stream_ordering + event, context = await self.create_event( requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id ) @@ -913,10 +926,20 @@ class EventCreationHandler: extra_users=extra_users, ) stream_id = result["stream_id"] - event.internal_metadata.stream_ordering = stream_id + event_id = result["event_id"] + if event_id != event.event_id: + # If we get a different event back then it means that its + # been de-duplicated, so we replace the given event with the + # one already persisted. + event = await self.store.get_event(event_id) + else: + # If we newly persisted the event then we need to update its + # stream_ordering entry manually (as it was persisted on + # another worker). + event.internal_metadata.stream_ordering = stream_id return event - stream_id = await self.persist_and_notify_client_event( + event = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) @@ -965,11 +988,16 @@ class EventCreationHandler: context: EventContext, ratelimit: bool = True, extra_users: List[UserID] = [], - ) -> int: + ) -> EventBase: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. This should only be run on the instance in charge of persisting events. + + Returns: + The persisted event. This may be different than the given event if + it was de-duplicated (e.g. because we had already persisted an + event with the same transaction ID.) """ assert self.storage.persistence is not None assert self._events_shard_config.should_handle( @@ -1137,9 +1165,13 @@ class EventCreationHandler: if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_pos, max_stream_token = await self.storage.persistence.persist_event( - event, context=context - ) + # Note that this returns the event that was persisted, which may not be + # the same as we passed in if it was deduplicated due transaction IDs. + ( + event, + event_pos, + max_stream_token, + ) = await self.storage.persistence.persist_event(event, context=context) if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. @@ -1160,7 +1192,7 @@ class EventCreationHandler: # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - return event_pos.stream + return event async def _bump_active_time(self, user: UserID) -> None: try: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ffbc62ff44..0080eeaf8d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -171,6 +171,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if requester.is_guest: content["kind"] = "guest" + # Check if we already have an event with a matching transaction ID. (We + # do this check just before we persist an event as well, but may as well + # do it up front for efficiency.) + if txn_id and requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, requester.user.to_string(), requester.access_token_id, txn_id, + ) + if existing_event_id: + event_pos = await self.store.get_position_for_event(existing_event_id) + return existing_event_id, event_pos.stream + event, context = await self.event_creation_handler.create_event( requester, { @@ -679,7 +690,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - await self.event_creation_handler.handle_new_client_event( + event = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[target_user], ratelimit=ratelimit ) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 9a3a694d5d..fc129dbaa7 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -46,6 +46,12 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "ratelimit": true, "extra_users": [], } + + 200 OK + + { "stream_id": 12345, "event_id": "$abcdef..." } + + The returned event ID may not match the sent event if it was deduplicated. """ NAME = "send_event" @@ -116,11 +122,17 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - stream_id = await self.event_creation_handler.persist_and_notify_client_event( + event = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {"stream_id": stream_id} + return ( + 200, + { + "stream_id": event.internal_metadata.stream_ordering, + "event_id": event.event_id, + }, + ) def register_servlets(hs, http_server): diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b19c424ba9..fdb17745f6 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -361,6 +361,8 @@ class PersistEventsStore: self._store_event_txn(txn, events_and_contexts=events_and_contexts) + self._persist_transaction_ids_txn(txn, events_and_contexts) + # Insert into event_to_state_groups. self._store_event_state_mappings_txn(txn, events_and_contexts) @@ -405,6 +407,35 @@ class PersistEventsStore: # room_memberships, where applicable. self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) + def _persist_transaction_ids_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ): + """Persist the mapping from transaction IDs to event IDs (if defined). + """ + + to_insert = [] + for event, _ in events_and_contexts: + token_id = getattr(event.internal_metadata, "token_id", None) + txn_id = getattr(event.internal_metadata, "txn_id", None) + if token_id and txn_id: + to_insert.append( + { + "event_id": event.event_id, + "room_id": event.room_id, + "user_id": event.sender, + "token_id": token_id, + "txn_id": txn_id, + "inserted_ts": self._clock.time_msec(), + } + ) + + if to_insert: + self.db_pool.simple_insert_many_txn( + txn, table="event_txn_id", values=to_insert, + ) + def _update_current_state_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4e74fafe43..3ec4d1d9c2 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging import threading @@ -137,6 +136,15 @@ class EventsWorkerStore(SQLBaseStore): db_conn, "events", "stream_ordering", step=-1 ) + if not hs.config.worker.worker_app: + # We periodically clean out old transaction ID mappings + self._clock.looping_call( + run_as_background_process, + 5 * 60 * 1000, + "_cleanup_old_transaction_ids", + self._cleanup_old_transaction_ids, + ) + self._get_event_cache = Cache( "*getEvent*", keylen=3, @@ -1308,3 +1316,76 @@ class EventsWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + + async def get_event_id_from_transaction_id( + self, room_id: str, user_id: str, token_id: int, txn_id: str + ) -> Optional[str]: + """Look up if we have already persisted an event for the transaction ID, + returning the event ID if so. + """ + return await self.db_pool.simple_select_one_onecol( + table="event_txn_id", + keyvalues={ + "room_id": room_id, + "user_id": user_id, + "token_id": token_id, + "txn_id": txn_id, + }, + retcol="event_id", + allow_none=True, + desc="get_event_id_from_transaction_id", + ) + + async def get_already_persisted_events( + self, events: Iterable[EventBase] + ) -> Dict[str, str]: + """Look up if we have already persisted an event for the transaction ID, + returning a mapping from event ID in the given list to the event ID of + an existing event. + + Also checks if there are duplicates in the given events, if there are + will map duplicates to the *first* event. + """ + + mapping = {} + txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str] + + for event in events: + token_id = getattr(event.internal_metadata, "token_id", None) + txn_id = getattr(event.internal_metadata, "txn_id", None) + + if token_id and txn_id: + # Check if this is a duplicate of an event in the given events. + existing = txn_id_to_event.get((event.room_id, token_id, txn_id)) + if existing: + mapping[event.event_id] = existing + continue + + # Check if this is a duplicate of an event we've already + # persisted. + existing = await self.get_event_id_from_transaction_id( + event.room_id, event.sender, token_id, txn_id + ) + if existing: + mapping[event.event_id] = existing + txn_id_to_event[(event.room_id, token_id, txn_id)] = existing + else: + txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id + + return mapping + + async def _cleanup_old_transaction_ids(self): + """Cleans out transaction id mappings older than 24hrs. + """ + + def _cleanup_old_transaction_ids_txn(txn): + sql = """ + DELETE FROM event_txn_id + WHERE inserted_ts < ? + """ + one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + txn.execute(sql, (one_day_ago,)) + + return await self.db_pool.runInteraction( + "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, + ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 236d3cdbe3..9a003e30f9 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1003,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): token: str, device_id: Optional[str], valid_until_ms: Optional[int], - ) -> None: + ) -> int: """Adds an access token for the given user. Args: @@ -1013,6 +1013,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): valid_until_ms: when the token is valid until. None for no expiry. Raises: StoreError if there was a problem adding this. + Returns: + The token ID """ next_id = self._access_tokens_id_gen.get_next() @@ -1028,6 +1030,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): desc="add_access_token_to_user", ) + return next_id + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, "access_tokens", {"token": token}, "device_id" diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql new file mode 100644 index 0000000000..b2454121a8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql @@ -0,0 +1,40 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A map of recent events persisted with transaction IDs. Used to deduplicate +-- send event requests with the same transaction ID. +-- +-- Note: transaction IDs are scoped to the room ID/user ID/access token that was +-- used to make the request. +-- +-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the +-- events or access token we don't want to try and de-duplicate the event. +CREATE TABLE IF NOT EXISTS event_txn_id ( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + token_id BIGINT NOT NULL, + txn_id TEXT NOT NULL, + inserted_ts BIGINT NOT NULL, + FOREIGN KEY (event_id) + REFERENCES events (event_id) ON DELETE CASCADE, + FOREIGN KEY (token_id) + REFERENCES access_tokens (id) ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id); +CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id); +CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts); diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 4d2d88d1f0..70e636b0ba 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -96,7 +96,9 @@ class _EventPeristenceQueue: Returns: defer.Deferred: a deferred which will resolve once the events are - persisted. Runs its callbacks *without* a logcontext. + persisted. Runs its callbacks *without* a logcontext. The result + is the same as that returned by the callback passed to + `handle_queue`. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: @@ -199,7 +201,7 @@ class EventsPersistenceStorage: self, events_and_contexts: Iterable[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> RoomStreamToken: + ) -> Tuple[List[EventBase], RoomStreamToken]: """ Write events to the database Args: @@ -209,7 +211,11 @@ class EventsPersistenceStorage: which might update the current state etc. Returns: - the stream ordering of the latest persisted event + List of events persisted, the current position room stream position. + The list of events persisted may not be the same as those passed in + if they were deduplicated due to an event already existing that + matched the transcation ID; the existing event is returned in such + a case. """ partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]] for event, ctx in events_and_contexts: @@ -225,19 +231,41 @@ class EventsPersistenceStorage: for room_id in partitioned: self._maybe_start_persisting(room_id) - await make_deferred_yieldable( + # Each deferred returns a map from event ID to existing event ID if the + # event was deduplicated. (The dict may also include other entries if + # the event was persisted in a batch with other events). + # + # Since we use `defer.gatherResults` we need to merge the returned list + # of dicts into one. + ret_vals = await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) + replaced_events = {} + for d in ret_vals: + replaced_events.update(d) + + events = [] + for event, _ in events_and_contexts: + existing_event_id = replaced_events.get(event.event_id) + if existing_event_id: + events.append(await self.main_store.get_event(existing_event_id)) + else: + events.append(event) - return self.main_store.get_room_max_token() + return ( + events, + self.main_store.get_room_max_token(), + ) async def persist_event( self, event: EventBase, context: EventContext, backfilled: bool = False - ) -> Tuple[PersistedEventPosition, RoomStreamToken]: + ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]: """ Returns: - The stream ordering of `event`, and the stream ordering of the - latest persisted event + The event, stream ordering of `event`, and the stream ordering of the + latest persisted event. The returned event may not match the given + event if it was deduplicated due to an existing event matching the + transaction ID. """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled @@ -245,19 +273,33 @@ class EventsPersistenceStorage: self._maybe_start_persisting(event.room_id) - await make_deferred_yieldable(deferred) + # The deferred returns a map from event ID to existing event ID if the + # event was deduplicated. (The dict may also include other entries if + # the event was persisted in a batch with other events.) + replaced_events = await make_deferred_yieldable(deferred) + replaced_event = replaced_events.get(event.event_id) + if replaced_event: + event = await self.main_store.get_event(replaced_event) event_stream_id = event.internal_metadata.stream_ordering # stream ordering should have been assigned by now assert event_stream_id pos = PersistedEventPosition(self._instance_name, event_stream_id) - return pos, self.main_store.get_room_max_token() + return event, pos, self.main_store.get_room_max_token() def _maybe_start_persisting(self, room_id: str): + """Pokes the `_event_persist_queue` to start handling new items in the + queue, if not already in progress. + + Causes the deferreds returned by `add_to_queue` to resolve with: a + dictionary of event ID to event ID we didn't persist as we already had + another event persisted with the same TXN ID. + """ + async def persisting_queue(item): with Measure(self._clock, "persist_events"): - await self._persist_events( + return await self._persist_events( item.events_and_contexts, backfilled=item.backfilled ) @@ -267,12 +309,38 @@ class EventsPersistenceStorage: self, events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ): + ) -> Dict[str, str]: """Calculates the change to current state and forward extremities, and persists the given events and with those updates. + + Returns: + A dictionary of event ID to event ID we didn't persist as we already + had another event persisted with the same TXN ID. """ + replaced_events = {} # type: Dict[str, str] if not events_and_contexts: - return + return replaced_events + + # Check if any of the events have a transaction ID that has already been + # persisted, and if so we don't persist it again. + # + # We should have checked this a long time before we get here, but it's + # possible that different send event requests race in such a way that + # they both pass the earlier checks. Checking here isn't racey as we can + # have only one `_persist_events` per room being called at a time. + replaced_events = await self.main_store.get_already_persisted_events( + (event for event, _ in events_and_contexts) + ) + + if replaced_events: + events_and_contexts = [ + (e, ctx) + for e, ctx in events_and_contexts + if e.event_id not in replaced_events + ] + + if not events_and_contexts: + return replaced_events chunks = [ events_and_contexts[x : x + 100] @@ -441,6 +509,8 @@ class EventsPersistenceStorage: await self._handle_potentially_left_users(potentially_left_users) + return replaced_events + async def _calculate_new_extremities( self, room_id: str, diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py new file mode 100644 index 0000000000..64e28bc639 --- /dev/null +++ b/tests/handlers/test_message.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Tuple + +from synapse.api.constants import EventTypes +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.types import create_requester +from synapse.util.stringutils import random_string + +from tests import unittest + +logger = logging.getLogger(__name__) + + +class EventCreationTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = self.hs.get_event_creation_handler() + self.persist_event_storage = self.hs.get_storage().persistence + + self.user_id = self.register_user("tester", "foobar") + self.access_token = self.login("tester", "foobar") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) + + self.info = self.get_success( + self.hs.get_datastore().get_user_by_access_token(self.access_token,) + ) + self.token_id = self.info["token_id"] + + self.requester = create_requester(self.user_id, access_token_id=self.token_id) + + def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]: + """Create a new event with the given transaction ID. All events produced + by this method will be considered duplicates. + """ + + # We create a new event with a random body, as otherwise we'll produce + # *exactly* the same event with the same hash, and so same event ID. + return self.get_success( + self.handler.create_event( + self.requester, + { + "type": EventTypes.Message, + "room_id": self.room_id, + "sender": self.requester.user.to_string(), + "content": {"msgtype": "m.text", "body": random_string(5)}, + }, + token_id=self.token_id, + txn_id=txn_id, + ) + ) + + def test_duplicated_txn_id(self): + """Test that attempting to handle/persist an event with a transaction ID + that has already been persisted correctly returns the old event and does + *not* produce duplicate messages. + """ + + txn_id = "something_suitably_random" + + event1, context = self._create_duplicate_event(txn_id) + + ret_event1 = self.get_success( + self.handler.handle_new_client_event(self.requester, event1, context) + ) + stream_id1 = ret_event1.internal_metadata.stream_ordering + + self.assertEqual(event1.event_id, ret_event1.event_id) + + event2, context = self._create_duplicate_event(txn_id) + + # We want to test that the deduplication at the persit event end works, + # so we want to make sure we test with different events. + self.assertNotEqual(event1.event_id, event2.event_id) + + ret_event2 = self.get_success( + self.handler.handle_new_client_event(self.requester, event2, context) + ) + stream_id2 = ret_event2.internal_metadata.stream_ordering + + # Assert that the returned values match those from the initial event + # rather than the new one. + self.assertEqual(ret_event1.event_id, ret_event2.event_id) + self.assertEqual(stream_id1, stream_id2) + + # Let's test that calling `persist_event` directly also does the right + # thing. + event3, context = self._create_duplicate_event(txn_id) + self.assertNotEqual(event1.event_id, event3.event_id) + + ret_event3, event_pos3, _ = self.get_success( + self.persist_event_storage.persist_event(event3, context) + ) + + # Assert that the returned values match those from the initial event + # rather than the new one. + self.assertEqual(ret_event1.event_id, ret_event3.event_id) + self.assertEqual(stream_id1, event_pos3.stream) + + # Let's test that calling `persist_events` directly also does the right + # thing. + event4, context = self._create_duplicate_event(txn_id) + self.assertNotEqual(event1.event_id, event3.event_id) + + events, _ = self.get_success( + self.persist_event_storage.persist_events([(event3, context)]) + ) + ret_event4 = events[0] + + # Assert that the returned values match those from the initial event + # rather than the new one. + self.assertEqual(ret_event1.event_id, ret_event4.event_id) + + def test_duplicated_txn_id_one_call(self): + """Test that we correctly handle duplicates that we try and persist at + the same time. + """ + + txn_id = "something_else_suitably_random" + + # Create two duplicate events to persist at the same time + event1, context1 = self._create_duplicate_event(txn_id) + event2, context2 = self._create_duplicate_event(txn_id) + + # Ensure their event IDs are different to start with + self.assertNotEqual(event1.event_id, event2.event_id) + + events, _ = self.get_success( + self.persist_event_storage.persist_events( + [(event1, context1), (event2, context2)] + ) + ) + + # Check that we've deduplicated the events. + self.assertEqual(len(events), 2) + self.assertEqual(events[0].event_id, events[1].event_id) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index d03e121664..b737625e33 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -107,7 +107,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): request, channel = self.make_request( "PUT", - "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id, + "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id, {}, access_token=self.tok, ) diff --git a/tests/unittest.py b/tests/unittest.py index 5c87f6097e..6c1661c92c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -254,17 +254,24 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: + # We need a valid token ID to satisfy foreign key constraints. + token_id = self.get_success( + self.hs.get_datastore().add_access_token_to_user( + self.helper.auth_user_id, "some_fake_token", None, None, + ) + ) + async def get_user_by_access_token(token=None, allow_guest=False): return { "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, + "token_id": token_id, "is_guest": False, } async def get_user_by_req(request, allow_guest=False, rights="access"): return create_requester( UserID.from_string(self.helper.auth_user_id), - 1, + token_id, False, False, None, -- cgit 1.5.1 From 7eff59ec91e59140c375b43a6dac05b833ab0051 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 19:40:53 +0100 Subject: Add some more type annotations to Cache --- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/util/caches/descriptors.py | 81 ++++++++++++++++++------- synapse/util/caches/lrucache.py | 3 +- 3 files changed, 62 insertions(+), 24 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 1f8dafe7ea..273d627fad 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -26,7 +26,7 @@ class SlavedClientIpStore(BaseSlavedStore): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 - ) + ) # type: Cache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 98b34f2223..14458bc20f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,12 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum import functools import inspect import logging import threading -from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from weakref import WeakValueDictionary from prometheus_client import Gauge @@ -38,6 +49,8 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) +KT = TypeVar("KT") +VT = TypeVar("VT") class _CachedFunction(Generic[F]): @@ -61,13 +74,19 @@ cache_pending_metric = Gauge( ["name"], ) -_CacheSentinel = object() + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() class CacheEntry: __slots__ = ["deferred", "callbacks", "invalidated"] - def __init__(self, deferred, callbacks): + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): self.deferred = deferred self.callbacks = set(callbacks) self.invalidated = False @@ -80,7 +99,13 @@ class CacheEntry: self.callbacks.clear() -class Cache: +class Cache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + may return an ObservableDeferred. + """ + __slots__ = ( "cache", "name", @@ -103,19 +128,23 @@ class Cache: Args: name: The name of the cache max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. tree: Use a TreeCache instead of a dict as the underlying cache type iterable: If True, count each item in the cached object as an entry, rather than each cached object apply_cache_factor_from_config: Whether cache factors specified in the config file affect `max_entries` - - Returns: - Cache """ cache_type = TreeCache if tree else dict - self._pending_deferred_cache = cache_type() + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. self.cache = LruCache( max_size=max_entries, keylen=keylen, @@ -155,7 +184,13 @@ class Cache: "Cache objects can only be accessed from the main thread" ) - def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True): + def get( + self, + key: KT, + default=_Sentinel.sentinel, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ): """Looks the key up in the caches. Args: @@ -166,30 +201,32 @@ class Cache: update_metrics (bool): whether to update the cache hit rate metrics Returns: - Either an ObservableDeferred or the raw result + Either an ObservableDeferred or the result itself """ callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _CacheSentinel) - if val is not _CacheSentinel: + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: val.callbacks.update(callbacks) if update_metrics: self.metrics.inc_hits() return val.deferred - val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) - if val is not _CacheSentinel: + val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) + if val is not _Sentinel.sentinel: self.metrics.inc_hits() return val if update_metrics: self.metrics.inc_misses() - if default is _CacheSentinel: + if default is _Sentinel.sentinel: raise KeyError() else: return default - def set(self, key, value, callback=None): + def set( + self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None + ) -> ObservableDeferred: if not isinstance(value, defer.Deferred): raise TypeError("not a Deferred") @@ -248,7 +285,7 @@ class Cache: observer.addCallbacks(cb, eb) return observable - def prefill(self, key, value, callback=None): + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): callbacks = [callback] if callback else [] self.cache.set(key, value, callbacks=callbacks) @@ -267,7 +304,7 @@ class Cache: if entry: entry.invalidate() - def invalidate_many(self, key): + def invalidate_many(self, key: KT): self.check_thread() if not isinstance(key, tuple): raise TypeError("The cache key must be a tuple not %r" % (type(key),)) @@ -275,7 +312,7 @@ class Cache: # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(key, None) + entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) if entry_dict is not None: for entry in iterate_tree_cache_entry(entry_dict): entry.invalidate() @@ -396,7 +433,7 @@ class CacheDescriptor(_CacheDescriptorBase): keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) + ) # type: Cache[Tuple, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 4bc1a67b58..33eae2b7c4 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -64,7 +64,8 @@ class LruCache: Args: max_size: The maximum amount of entries the cache can hold - keylen: The length of the tuple used as the cache key + keylen: The length of the tuple used as the cache key. Ignored unless + cache_type is `TreeCache`. cache_type (type): type of underlying cache to be used. Typically one of dict -- cgit 1.5.1 From 9f87da0a84f93096e228f01f1139c9b5db8ea3d4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 19:43:37 +0100 Subject: Rename Cache->DeferredCache --- synapse/replication/slave/storage/client_ips.py | 6 +++--- synapse/storage/databases/main/client_ips.py | 4 ++-- synapse/storage/databases/main/devices.py | 4 ++-- synapse/storage/databases/main/events_worker.py | 4 ++-- synapse/util/caches/descriptors.py | 19 ++++++++++++------- tests/storage/test__base.py | 10 +++++----- tests/test_metrics.py | 4 ++-- tests/util/caches/test_descriptors.py | 4 ++-- 8 files changed, 30 insertions(+), 25 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 273d627fad..40ea78a353 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache from ._base import BaseSlavedStore @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = Cache( + self.client_ip_last_seen = DeferredCache( name="client_ip_last_seen", keylen=4, max_entries=50000 - ) # type: Cache[tuple, int] + ) # type: DeferredCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index a25a888443..ad32701d36 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache logger = logging.getLogger(__name__) @@ -410,7 +410,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - self.client_ip_last_seen = Cache( + self.client_ip_last_seen = DeferredCache( name="client_ip_last_seen", keylen=4, max_entries=50000 ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 88fd97e1df..d903155e89 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,7 +34,7 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.descriptors import Cache, cached, cachedList +from synapse.util.caches.descriptors import DeferredCache, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -1004,7 +1004,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = Cache( + self.device_id_exists_cache = DeferredCache( name="device_id_exists", keylen=2, max_entries=10000 ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3ec4d1d9c2..be7f60f2e8 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached +from synapse.util.caches.descriptors import DeferredCache, cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -145,7 +145,7 @@ class EventsWorkerStore(SQLBaseStore): self._cleanup_old_transaction_ids, ) - self._get_event_cache = Cache( + self._get_event_cache = DeferredCache( "*getEvent*", keylen=3, max_entries=hs.config.caches.event_cache_size, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 14458bc20f..7c9fe199bf 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -99,7 +99,7 @@ class CacheEntry: self.callbacks.clear() -class Cache(Generic[KT, VT]): +class DeferredCache(Generic[KT, VT]): """Wraps an LruCache, adding support for Deferred results. It expects that each entry added with set() will be a Deferred; likewise get() @@ -225,7 +225,10 @@ class Cache(Generic[KT, VT]): return default def set( - self, key: KT, value: defer.Deferred, callback: Optional[Callable[[], None]] = None + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, ) -> ObservableDeferred: if not isinstance(value, defer.Deferred): raise TypeError("not a Deferred") @@ -427,13 +430,13 @@ class CacheDescriptor(_CacheDescriptorBase): self.iterable = iterable def __get__(self, obj, owner): - cache = Cache( + cache = DeferredCache( name=self.orig.__name__, max_entries=self.max_entries, keylen=self.num_args, tree=self.tree, iterable=self.iterable, - ) # type: Cache[Tuple, Any] + ) # type: DeferredCache[Tuple, Any] def get_cache_key_gen(args, kwargs): """Given some args/kwargs return a generator that resolves into @@ -677,9 +680,9 @@ class _CacheContext: _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None + def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None self._cache = cache self._cache_key = cache_key @@ -688,7 +691,9 @@ class _CacheContext: self._cache.invalidate(self._cache_key) @classmethod - def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext + def get_instance( + cls, cache, cache_key + ): # type: (DeferredCache, CacheKey) -> _CacheContext """Returns an instance constructed with the given arguments. A new instance is only created if none already exists. diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index f5afed017c..00adcab7b9 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,14 +20,14 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import Cache, cached +from synapse.util.caches.descriptors import DeferredCache, cached from tests import unittest -class CacheTestCase(unittest.HomeserverTestCase): +class DeferredCacheTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, homeserver): - self.cache = Cache("test") + self.cache = DeferredCache("test") def test_empty(self): failed = False @@ -56,7 +56,7 @@ class CacheTestCase(unittest.HomeserverTestCase): self.assertTrue(failed) def test_eviction(self): - cache = Cache("test", max_entries=2) + cache = DeferredCache("test", max_entries=2) cache.prefill(1, "one") cache.prefill(2, "two") @@ -74,7 +74,7 @@ class CacheTestCase(unittest.HomeserverTestCase): cache.get(3) def test_eviction_lru(self): - cache = Cache("test", max_entries=2) + cache = DeferredCache("test", max_entries=2) cache.prefill(1, "one") cache.prefill(2, "two") diff --git a/tests/test_metrics.py b/tests/test_metrics.py index f5f63d8ed6..1c03a52f7c 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import Cache +from synapse.util.caches.descriptors import DeferredCache from tests import unittest @@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase): Caches produce metrics reflecting their state when scraped. """ CACHE_NAME = "cache_metrics_test_fgjkbdfg" - cache = Cache(CACHE_NAME, max_entries=777) + cache = DeferredCache(CACHE_NAME, max_entries=777) items = { x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 677e925477..bd870b4a33 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -42,9 +42,9 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class CacheTestCase(unittest.TestCase): +class DeferredCacheTestCase(unittest.TestCase): def test_invalidate_all(self): - cache = descriptors.Cache("testcache") + cache = descriptors.DeferredCache("testcache") callback_record = [False, False] -- cgit 1.5.1 From 4182bb812f21d7231ff0efc5e93e5f2e88f6605e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 14 Oct 2020 23:25:23 +0100 Subject: move DeferredCache into its own module --- synapse/replication/slave/storage/client_ips.py | 2 +- synapse/storage/databases/main/client_ips.py | 2 +- synapse/storage/databases/main/devices.py | 3 +- synapse/storage/databases/main/events_worker.py | 3 +- synapse/util/caches/deferred_cache.py | 292 ++++++++++++++++++++++++ synapse/util/caches/descriptors.py | 284 +---------------------- tests/storage/test__base.py | 3 +- tests/test_metrics.py | 2 +- tests/util/caches/test_deferred_cache.py | 64 ++++++ tests/util/caches/test_descriptors.py | 44 ---- 10 files changed, 367 insertions(+), 332 deletions(-) create mode 100644 synapse/util/caches/deferred_cache.py create mode 100644 tests/util/caches/test_deferred_cache.py (limited to 'synapse/replication') diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 40ea78a353..4b0ea0cc01 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache from ._base import BaseSlavedStore diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index ad32701d36..9e66e6648a 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d903155e89..e662a20d24 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,7 +34,8 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.descriptors import DeferredCache, cached, cachedList +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index be7f60f2e8..ff150f0be7 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.descriptors import DeferredCache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py new file mode 100644 index 0000000000..f728cd2cf2 --- /dev/null +++ b/synapse/util/caches/deferred_cache.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import threading +from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast + +from prometheus_client import Gauge + +from twisted.internet import defer + +from synapse.util.async_helpers import ObservableDeferred +from synapse.util.caches import register_cache +from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry + +cache_pending_metric = Gauge( + "synapse_util_caches_cache_pending", + "Number of lookups currently pending for this cache", + ["name"], +) + + +KT = TypeVar("KT") +VT = TypeVar("VT") + + +class _Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup. + sentinel = object() + + +class DeferredCache(Generic[KT, VT]): + """Wraps an LruCache, adding support for Deferred results. + + It expects that each entry added with set() will be a Deferred; likewise get() + may return an ObservableDeferred. + """ + + __slots__ = ( + "cache", + "name", + "keylen", + "thread", + "metrics", + "_pending_deferred_cache", + ) + + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key. Ignored unless + `tree` is True. + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + """ + cache_type = TreeCache if tree else dict + + # _pending_deferred_cache maps from the key value to a `CacheEntry` object. + self._pending_deferred_cache = ( + cache_type() + ) # type: MutableMapping[KT, CacheEntry] + + # cache is used for completed results and maps to the result itself, rather than + # a Deferred. + self.cache = LruCache( + max_size=max_entries, + keylen=keylen, + cache_type=cache_type, + size_callback=(lambda d: len(d)) if iterable else None, + evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, + ) + + self.name = name + self.keylen = keylen + self.thread = None # type: Optional[threading.Thread] + self.metrics = register_cache( + "cache", + name, + self.cache, + collect_callback=self._metrics_collection_callback, + ) + + @property + def max_entries(self): + return self.cache.max_size + + def _on_evicted(self, evicted_count): + self.metrics.inc_evictions(evicted_count) + + def _metrics_collection_callback(self): + cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) + + def check_thread(self): + expected_thread = self.thread + if expected_thread is None: + self.thread = threading.current_thread() + else: + if expected_thread is not threading.current_thread(): + raise ValueError( + "Cache objects can only be accessed from the main thread" + ) + + def get( + self, + key: KT, + default=_Sentinel.sentinel, + callback: Optional[Callable[[], None]] = None, + update_metrics: bool = True, + ): + """Looks the key up in the caches. + + Args: + key(tuple) + default: What is returned if key is not in the caches. If not + specified then function throws KeyError instead + callback(fn): Gets called when the entry in the cache is invalidated + update_metrics (bool): whether to update the cache hit rate metrics + + Returns: + Either an ObservableDeferred or the result itself + """ + callbacks = [callback] if callback else [] + val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) + if val is not _Sentinel.sentinel: + val.callbacks.update(callbacks) + if update_metrics: + self.metrics.inc_hits() + return val.deferred + + val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) + if val is not _Sentinel.sentinel: + self.metrics.inc_hits() + return val + + if update_metrics: + self.metrics.inc_misses() + + if default is _Sentinel.sentinel: + raise KeyError() + else: + return default + + def set( + self, + key: KT, + value: defer.Deferred, + callback: Optional[Callable[[], None]] = None, + ) -> ObservableDeferred: + if not isinstance(value, defer.Deferred): + raise TypeError("not a Deferred") + + callbacks = [callback] if callback else [] + self.check_thread() + observable = ObservableDeferred(value, consumeErrors=True) + observer = observable.observe() + entry = CacheEntry(deferred=observable, callbacks=callbacks) + + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry: + existing_entry.invalidate() + + self._pending_deferred_cache[key] = entry + + def compare_and_pop(): + """Check if our entry is still the one in _pending_deferred_cache, and + if so, pop it. + + Returns true if the entries matched. + """ + existing_entry = self._pending_deferred_cache.pop(key, None) + if existing_entry is entry: + return True + + # oops, the _pending_deferred_cache has been updated since + # we started our query, so we are out of date. + # + # Better put back whatever we took out. (We do it this way + # round, rather than peeking into the _pending_deferred_cache + # and then removing on a match, to make the common case faster) + if existing_entry is not None: + self._pending_deferred_cache[key] = existing_entry + + return False + + def cb(result): + if compare_and_pop(): + self.cache.set(key, result, entry.callbacks) + else: + # we're not going to put this entry into the cache, so need + # to make sure that the invalidation callbacks are called. + # That was probably done when _pending_deferred_cache was + # updated, but it's possible that `set` was called without + # `invalidate` being previously called, in which case it may + # not have been. Either way, let's double-check now. + entry.invalidate() + + def eb(_fail): + compare_and_pop() + entry.invalidate() + + # once the deferred completes, we can move the entry from the + # _pending_deferred_cache to the real cache. + # + observer.addCallbacks(cb, eb) + return observable + + def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): + callbacks = [callback] if callback else [] + self.cache.set(key, value, callbacks=callbacks) + + def invalidate(self, key): + self.check_thread() + self.cache.pop(key, None) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, which will (a) stop it being returned + # for future queries and (b) stop it being persisted as a proper entry + # in self.cache. + entry = self._pending_deferred_cache.pop(key, None) + + # run the invalidation callbacks now, rather than waiting for the + # deferred to resolve. + if entry: + entry.invalidate() + + def invalidate_many(self, key: KT): + self.check_thread() + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + self.cache.del_multi(key) + + # if we have a pending lookup for this key, remove it from the + # _pending_deferred_cache, as above + entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) + if entry_dict is not None: + for entry in iterate_tree_cache_entry(entry_dict): + entry.invalidate() + + def invalidate_all(self): + self.check_thread() + self.cache.clear() + for entry in self._pending_deferred_cache.values(): + entry.invalidate() + self._pending_deferred_cache.clear() + + +class CacheEntry: + __slots__ = ["deferred", "callbacks", "invalidated"] + + def __init__( + self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] + ): + self.deferred = deferred + self.callbacks = set(callbacks) + self.invalidated = False + + def invalidate(self): + if not self.invalidated: + self.invalidated = True + for callback in self.callbacks: + callback() + self.callbacks.clear() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 7c9fe199bf..1f43886804 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,44 +13,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import functools import inspect import logging -import threading -from typing import ( - Any, - Callable, - Generic, - Iterable, - MutableMapping, - Optional, - Tuple, - TypeVar, - Union, - cast, -) +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from weakref import WeakValueDictionary -from prometheus_client import Gauge - from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry - -from . import register_cache +from synapse.util.caches.deferred_cache import DeferredCache logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) -KT = TypeVar("KT") -VT = TypeVar("VT") class _CachedFunction(Generic[F]): @@ -68,266 +48,6 @@ class _CachedFunction(Generic[F]): __call__ = None # type: F -cache_pending_metric = Gauge( - "synapse_util_caches_cache_pending", - "Number of lookups currently pending for this cache", - ["name"], -) - - -class _Sentinel(enum.Enum): - # defining a sentinel in this way allows mypy to correctly handle the - # type of a dictionary lookup. - sentinel = object() - - -class CacheEntry: - __slots__ = ["deferred", "callbacks", "invalidated"] - - def __init__( - self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]] - ): - self.deferred = deferred - self.callbacks = set(callbacks) - self.invalidated = False - - def invalidate(self): - if not self.invalidated: - self.invalidated = True - for callback in self.callbacks: - callback() - self.callbacks.clear() - - -class DeferredCache(Generic[KT, VT]): - """Wraps an LruCache, adding support for Deferred results. - - It expects that each entry added with set() will be a Deferred; likewise get() - may return an ObservableDeferred. - """ - - __slots__ = ( - "cache", - "name", - "keylen", - "thread", - "metrics", - "_pending_deferred_cache", - ) - - def __init__( - self, - name: str, - max_entries: int = 1000, - keylen: int = 1, - tree: bool = False, - iterable: bool = False, - apply_cache_factor_from_config: bool = True, - ): - """ - Args: - name: The name of the cache - max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key. Ignored unless - `tree` is True. - tree: Use a TreeCache instead of a dict as the underlying cache type - iterable: If True, count each item in the cached object as an entry, - rather than each cached object - apply_cache_factor_from_config: Whether cache factors specified in the - config file affect `max_entries` - """ - cache_type = TreeCache if tree else dict - - # _pending_deferred_cache maps from the key value to a `CacheEntry` object. - self._pending_deferred_cache = ( - cache_type() - ) # type: MutableMapping[KT, CacheEntry] - - # cache is used for completed results and maps to the result itself, rather than - # a Deferred. - self.cache = LruCache( - max_size=max_entries, - keylen=keylen, - cache_type=cache_type, - size_callback=(lambda d: len(d)) if iterable else None, - evicted_callback=self._on_evicted, - apply_cache_factor_from_config=apply_cache_factor_from_config, - ) - - self.name = name - self.keylen = keylen - self.thread = None # type: Optional[threading.Thread] - self.metrics = register_cache( - "cache", - name, - self.cache, - collect_callback=self._metrics_collection_callback, - ) - - @property - def max_entries(self): - return self.cache.max_size - - def _on_evicted(self, evicted_count): - self.metrics.inc_evictions(evicted_count) - - def _metrics_collection_callback(self): - cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache)) - - def check_thread(self): - expected_thread = self.thread - if expected_thread is None: - self.thread = threading.current_thread() - else: - if expected_thread is not threading.current_thread(): - raise ValueError( - "Cache objects can only be accessed from the main thread" - ) - - def get( - self, - key: KT, - default=_Sentinel.sentinel, - callback: Optional[Callable[[], None]] = None, - update_metrics: bool = True, - ): - """Looks the key up in the caches. - - Args: - key(tuple) - default: What is returned if key is not in the caches. If not - specified then function throws KeyError instead - callback(fn): Gets called when the entry in the cache is invalidated - update_metrics (bool): whether to update the cache hit rate metrics - - Returns: - Either an ObservableDeferred or the result itself - """ - callbacks = [callback] if callback else [] - val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) - if val is not _Sentinel.sentinel: - val.callbacks.update(callbacks) - if update_metrics: - self.metrics.inc_hits() - return val.deferred - - val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) - if val is not _Sentinel.sentinel: - self.metrics.inc_hits() - return val - - if update_metrics: - self.metrics.inc_misses() - - if default is _Sentinel.sentinel: - raise KeyError() - else: - return default - - def set( - self, - key: KT, - value: defer.Deferred, - callback: Optional[Callable[[], None]] = None, - ) -> ObservableDeferred: - if not isinstance(value, defer.Deferred): - raise TypeError("not a Deferred") - - callbacks = [callback] if callback else [] - self.check_thread() - observable = ObservableDeferred(value, consumeErrors=True) - observer = observable.observe() - entry = CacheEntry(deferred=observable, callbacks=callbacks) - - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry: - existing_entry.invalidate() - - self._pending_deferred_cache[key] = entry - - def compare_and_pop(): - """Check if our entry is still the one in _pending_deferred_cache, and - if so, pop it. - - Returns true if the entries matched. - """ - existing_entry = self._pending_deferred_cache.pop(key, None) - if existing_entry is entry: - return True - - # oops, the _pending_deferred_cache has been updated since - # we started our query, so we are out of date. - # - # Better put back whatever we took out. (We do it this way - # round, rather than peeking into the _pending_deferred_cache - # and then removing on a match, to make the common case faster) - if existing_entry is not None: - self._pending_deferred_cache[key] = existing_entry - - return False - - def cb(result): - if compare_and_pop(): - self.cache.set(key, result, entry.callbacks) - else: - # we're not going to put this entry into the cache, so need - # to make sure that the invalidation callbacks are called. - # That was probably done when _pending_deferred_cache was - # updated, but it's possible that `set` was called without - # `invalidate` being previously called, in which case it may - # not have been. Either way, let's double-check now. - entry.invalidate() - - def eb(_fail): - compare_and_pop() - entry.invalidate() - - # once the deferred completes, we can move the entry from the - # _pending_deferred_cache to the real cache. - # - observer.addCallbacks(cb, eb) - return observable - - def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): - callbacks = [callback] if callback else [] - self.cache.set(key, value, callbacks=callbacks) - - def invalidate(self, key): - self.check_thread() - self.cache.pop(key, None) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, which will (a) stop it being returned - # for future queries and (b) stop it being persisted as a proper entry - # in self.cache. - entry = self._pending_deferred_cache.pop(key, None) - - # run the invalidation callbacks now, rather than waiting for the - # deferred to resolve. - if entry: - entry.invalidate() - - def invalidate_many(self, key: KT): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): - entry.invalidate() - - def invalidate_all(self): - self.check_thread() - self.cache.clear() - for entry in self._pending_deferred_cache.values(): - entry.invalidate() - self._pending_deferred_cache.clear() - - class _CacheDescriptorBase: def __init__(self, orig: _CachedFunction, num_args, cache_context=False): self.orig = orig diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 00adcab7b9..2598dbe0a7 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -20,7 +20,8 @@ from mock import Mock from twisted.internet import defer from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches.descriptors import DeferredCache, cached +from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.descriptors import cached from tests import unittest diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 1c03a52f7c..759e4cd048 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest -from synapse.util.caches.descriptors import DeferredCache +from synapse.util.caches.deferred_cache import DeferredCache from tests import unittest diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py new file mode 100644 index 0000000000..9b6acdfc43 --- /dev/null +++ b/tests/util/caches/test_deferred_cache.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial + +from twisted.internet import defer + +import synapse.util.caches.deferred_cache + + +class DeferredCacheTestCase(unittest.TestCase): + def test_invalidate_all(self): + cache = synapse.util.caches.deferred_cache.DeferredCache("testcache") + + callback_record = [False, False] + + def record_callback(idx): + callback_record[idx] = True + + # add a couple of pending entries + d1 = defer.Deferred() + cache.set("key1", d1, partial(record_callback, 0)) + + d2 = defer.Deferred() + cache.set("key2", d2, partial(record_callback, 1)) + + # lookup should return observable deferreds + self.assertFalse(cache.get("key1").has_called()) + self.assertFalse(cache.get("key2").has_called()) + + # let one of the lookups complete + d2.callback("result2") + + # for now at least, the cache will return real results rather than an + # observabledeferred + self.assertEqual(cache.get("key2"), "result2") + + # now do the invalidation + cache.invalidate_all() + + # lookup should return none + self.assertIsNone(cache.get("key1", None)) + self.assertIsNone(cache.get("key2", None)) + + # both callbacks should have been callbacked + self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") + self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") + + # letting the other lookup complete should do nothing + d1.callback("result1") + self.assertIsNone(cache.get("key1", None)) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index bd870b4a33..3d1f960869 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from functools import partial import mock @@ -42,49 +41,6 @@ def run_on_reactor(): return make_deferred_yieldable(d) -class DeferredCacheTestCase(unittest.TestCase): - def test_invalidate_all(self): - cache = descriptors.DeferredCache("testcache") - - callback_record = [False, False] - - def record_callback(idx): - callback_record[idx] = True - - # add a couple of pending entries - d1 = defer.Deferred() - cache.set("key1", d1, partial(record_callback, 0)) - - d2 = defer.Deferred() - cache.set("key2", d2, partial(record_callback, 1)) - - # lookup should return observable deferreds - self.assertFalse(cache.get("key1").has_called()) - self.assertFalse(cache.get("key2").has_called()) - - # let one of the lookups complete - d2.callback("result2") - - # for now at least, the cache will return real results rather than an - # observabledeferred - self.assertEqual(cache.get("key2"), "result2") - - # now do the invalidation - cache.invalidate_all() - - # lookup should return none - self.assertIsNone(cache.get("key1", None)) - self.assertIsNone(cache.get("key2", None)) - - # both callbacks should have been callbacked - self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") - self.assertTrue(callback_record[1], "Invalidation callback for key2 not called") - - # letting the other lookup complete should do nothing - d1.callback("result1") - self.assertIsNone(cache.get("key1", None)) - - class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cache(self): -- cgit 1.5.1 From 97647b33c248f25571bae617365d95434e6a3d5f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 19 Oct 2020 12:20:29 +0100 Subject: Replace DeferredCache with LruCache where possible (#8563) Most of these uses don't need a full-blown DeferredCache; LruCache is lighter and more appropriate. --- changelog.d/8563.misc | 1 + synapse/replication/slave/storage/client_ips.py | 10 +++++----- synapse/storage/_base.py | 12 +++++++----- synapse/storage/databases/main/client_ips.py | 8 ++++---- synapse/storage/databases/main/devices.py | 8 ++++---- synapse/storage/databases/main/events.py | 4 +--- synapse/storage/databases/main/events_worker.py | 11 +++++------ synapse/util/caches/lrucache.py | 3 +++ 8 files changed, 30 insertions(+), 27 deletions(-) create mode 100644 changelog.d/8563.misc (limited to 'synapse/replication') diff --git a/changelog.d/8563.misc b/changelog.d/8563.misc new file mode 100644 index 0000000000..eeba8e5fee --- /dev/null +++ b/changelog.d/8563.misc @@ -0,0 +1 @@ +Replace `DeferredCache` with the lighter-weight `LruCache` where possible. diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 4b0ea0cc01..0f5b7adef7 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,7 @@ from synapse.storage.database import DatabasePool from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache from ._base import BaseSlavedStore @@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.client_ip_last_seen = DeferredCache( - name="client_ip_last_seen", keylen=4, max_entries=50000 - ) # type: DeferredCache[tuple, int] + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 + ) # type: LruCache[tuple, int] async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) @@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self.hs.get_tcp_replication().send_user_ip( user_id, access_token, ip, user_agent, device_id, now diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index ab49d227de..2b196ded1b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta): """ try: - if key is None: - getattr(self, cache_name).invalidate_all() - else: - getattr(self, cache_name).invalidate(tuple(key)) + cache = getattr(self, cache_name) except AttributeError: # We probably haven't pulled in the cache in this worker, # which is fine. - pass + return + + if key is None: + cache.invalidate_all() + else: + cache.invalidate(tuple(key)) def db_to_json(db_content): diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 9e66e6648a..339bd691a4 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_tuple_comparison_clause -from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -410,8 +410,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): class ClientIpStore(ClientIpWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): - self.client_ip_last_seen = DeferredCache( - name="client_ip_last_seen", keylen=4, max_entries=50000 + self.client_ip_last_seen = LruCache( + cache_name="client_ip_last_seen", keylen=4, max_size=50000 ) super().__init__(database, db_conn, hs) @@ -442,7 +442,7 @@ class ClientIpStore(ClientIpWorkerStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return - self.client_ip_last_seen.prefill(key, now) + self.client_ip_last_seen.set(key, now) self._batch_row_update[key] = (user_agent, device_id, now) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e662a20d24..dfb4f87b8f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -34,8 +34,8 @@ from synapse.storage.database import ( ) from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder -from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -1005,8 +1005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. - self.device_id_exists_cache = DeferredCache( - name="device_id_exists", keylen=2, max_entries=10000 + self.device_id_exists_cache = LruCache( + cache_name="device_id_exists", keylen=2, max_size=10000 ) async def store_device( @@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) - self.device_id_exists_cache.prefill(key, True) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: raise diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ba3b1769b0..87808c1483 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1051,9 +1051,7 @@ class PersistEventsStore: def prefill(): for cache_entry in to_prefill: - self.store._get_event_cache.prefill( - (cache_entry[0].event_id,), cache_entry - ) + self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry) txn.call_after(prefill) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 0ad9a19b3d..c342df2a8b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -42,8 +42,8 @@ from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import Collection, get_domain_from_id -from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.descriptors import cached +from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -146,11 +146,10 @@ class EventsWorkerStore(SQLBaseStore): self._cleanup_old_transaction_ids, ) - self._get_event_cache = DeferredCache( - "*getEvent*", + self._get_event_cache = LruCache( + cache_name="*getEvent*", keylen=3, - max_entries=hs.config.caches.event_cache_size, - apply_cache_factor_from_config=False, + max_size=hs.config.caches.event_cache_size, ) self._event_fetch_lock = threading.Condition() @@ -749,7 +748,7 @@ class EventsWorkerStore(SQLBaseStore): event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.prefill((event_id,), cache_entry) + self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry return result_map diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 4e95dd9bf3..3b471d8fd3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -337,6 +337,9 @@ class LruCache(Generic[KT, VT]): self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + # `invalidate` is exposed for consistency with DeferredCache, so that it can be + # invalidated by the cache invalidation replication stream. + self.invalidate = cache_pop if cache_type is TreeCache: self.del_multi = cache_del_multi self.len = synchronized(cache_len) -- cgit 1.5.1 From 2b7c180879e5d62145feed88375ba55f18fc2ae5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 26 Oct 2020 09:30:19 +0000 Subject: Start fewer opentracing spans (#8640) #8567 started a span for every background process. This is good as it means all Synapse code that gets run should be in a span (unless in the sentinel logging context), but it means we generate about 15x the number of spans as we did previously. This PR attempts to reduce that number by a) not starting one for send commands to Redis, and b) deferring starting background processes until after we're sure they're necessary. I don't really know how much this will help. --- changelog.d/8640.misc | 1 + synapse/handlers/appservice.py | 50 +++++++++++++++++++++++---- synapse/logging/opentracing.py | 10 +++--- synapse/metrics/background_process_metrics.py | 12 +++++-- synapse/notifier.py | 34 ++++++------------ synapse/push/pusherpool.py | 18 ++++++++-- synapse/replication/tcp/redis.py | 4 ++- tests/handlers/test_appservice.py | 20 +++++------ 8 files changed, 96 insertions(+), 53 deletions(-) create mode 100644 changelog.d/8640.misc (limited to 'synapse/replication') diff --git a/changelog.d/8640.misc b/changelog.d/8640.misc new file mode 100644 index 0000000000..cf6023f783 --- /dev/null +++ b/changelog.d/8640.misc @@ -0,0 +1 @@ +Reduce number of OpenTracing spans started. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 07240d3a14..7826387e53 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from prometheus_client import Counter @@ -30,7 +30,10 @@ from synapse.metrics import ( event_processing_loop_counter, event_processing_loop_room_count, ) -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import Collection, JsonDict, RoomStreamToken, UserID from synapse.util.metrics import Measure @@ -53,7 +56,7 @@ class ApplicationServicesHandler: self.current_max = 0 self.is_processing = False - async def notify_interested_services(self, max_token: RoomStreamToken): + def notify_interested_services(self, max_token: RoomStreamToken): """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any @@ -72,6 +75,12 @@ class ApplicationServicesHandler: if self.is_processing: return + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). + self._notify_interested_services(max_token) + + @wrap_as_background_process("notify_interested_services") + async def _notify_interested_services(self, max_token: RoomStreamToken): with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: @@ -166,8 +175,11 @@ class ApplicationServicesHandler: finally: self.is_processing = False - async def notify_interested_services_ephemeral( - self, stream_key: str, new_token: Optional[int], users: Collection[UserID] = [], + def notify_interested_services_ephemeral( + self, + stream_key: str, + new_token: Optional[int], + users: Collection[Union[str, UserID]] = [], ): """This is called by the notifier in the background when a ephemeral event handled by the homeserver. @@ -183,13 +195,34 @@ class ApplicationServicesHandler: new_token: The latest stream token users: The user(s) involved with the event. """ + if not self.notify_appservices: + return + + if stream_key not in ("typing_key", "receipt_key", "presence_key"): + return + services = [ service for service in self.store.get_app_services() if service.supports_ephemeral ] - if not services or not self.notify_appservices: + if not services: return + + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). + self._notify_interested_services_ephemeral( + services, stream_key, new_token, users + ) + + @wrap_as_background_process("notify_interested_services_ephemeral") + async def _notify_interested_services_ephemeral( + self, + services: List[ApplicationService], + stream_key: str, + new_token: Optional[int], + users: Collection[Union[str, UserID]], + ): logger.info("Checking interested services for %s" % (stream_key)) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: @@ -237,7 +270,7 @@ class ApplicationServicesHandler: return receipts async def _handle_presence( - self, service: ApplicationService, users: Collection[UserID] + self, service: ApplicationService, users: Collection[Union[str, UserID]] ): events = [] # type: List[JsonDict] presence_source = self.event_sources.sources["presence"] @@ -245,6 +278,9 @@ class ApplicationServicesHandler: service, "presence" ) for user in users: + if isinstance(user, str): + user = UserID.from_string(user) + interested = await service.is_interested_in_presence(user, self.store) if not interested: continue diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index e58850faff..ab586c318c 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -317,7 +317,7 @@ def ensure_active_span(message, ret=None): @contextlib.contextmanager -def _noop_context_manager(*args, **kwargs): +def noop_context_manager(*args, **kwargs): """Does exactly what it says on the tin""" yield @@ -413,7 +413,7 @@ def start_active_span( """ if opentracing is None: - return _noop_context_manager() + return noop_context_manager() return opentracing.tracer.start_active_span( operation_name, @@ -428,7 +428,7 @@ def start_active_span( def start_active_span_follows_from(operation_name, contexts): if opentracing is None: - return _noop_context_manager() + return noop_context_manager() references = [opentracing.follows_from(context) for context in contexts] scope = start_active_span(operation_name, references=references) @@ -459,7 +459,7 @@ def start_active_span_from_request( # Also, twisted uses byte arrays while opentracing expects strings. if opentracing is None: - return _noop_context_manager() + return noop_context_manager() header_dict = { k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() @@ -497,7 +497,7 @@ def start_active_span_from_edu( """ if opentracing is None: - return _noop_context_manager() + return noop_context_manager() carrier = json_decoder.decode(edu_content.get("context", "{}")).get( "opentracing", {} diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 08fbf78eee..658f6ecd72 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -24,7 +24,7 @@ from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext -from synapse.logging.opentracing import start_active_span +from synapse.logging.opentracing import noop_context_manager, start_active_span if TYPE_CHECKING: import resource @@ -167,7 +167,7 @@ class _BackgroundProcess: ) -def run_as_background_process(desc: str, func, *args, **kwargs): +def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs): """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -181,6 +181,9 @@ def run_as_background_process(desc: str, func, *args, **kwargs): Args: desc: a description for this background process type func: a function, which may return a Deferred or a coroutine + bg_start_span: Whether to start an opentracing span. Defaults to True. + Should only be disabled for processes that will not log to or tag + a span. args: positional args for func kwargs: keyword args for func @@ -199,7 +202,10 @@ def run_as_background_process(desc: str, func, *args, **kwargs): with BackgroundProcessLoggingContext(desc) as context: context.request = "%s-%i" % (desc, count) try: - with start_active_span(desc, tags={"request_id": context.request}): + ctx = noop_context_manager() + if bg_start_span: + ctx = start_active_span(desc, tags={"request_id": context.request}) + with ctx: result = func(*args, **kwargs) if inspect.isawaitable(result): diff --git a/synapse/notifier.py b/synapse/notifier.py index 858b487bec..eb56b26f21 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -40,7 +40,6 @@ from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import PreserveLoggingContext from synapse.logging.utils import log_function from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.streams.config import PaginationConfig from synapse.types import ( Collection, @@ -310,44 +309,37 @@ class Notifier: """ # poke any interested application service. - run_as_background_process( - "_notify_app_services", self._notify_app_services, max_room_stream_token - ) - - run_as_background_process( - "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token - ) + self._notify_app_services(max_room_stream_token) + self._notify_pusher_pool(max_room_stream_token) if self.federation_sender: self.federation_sender.notify_new_events(max_room_stream_token) - async def _notify_app_services(self, max_room_stream_token: RoomStreamToken): + def _notify_app_services(self, max_room_stream_token: RoomStreamToken): try: - await self.appservice_handler.notify_interested_services( - max_room_stream_token - ) + self.appservice_handler.notify_interested_services(max_room_stream_token) except Exception: logger.exception("Error notifying application services of event") - async def _notify_app_services_ephemeral( + def _notify_app_services_ephemeral( self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[UserID] = [], + users: Collection[Union[str, UserID]] = [], ): try: stream_token = None if isinstance(new_token, int): stream_token = new_token - await self.appservice_handler.notify_interested_services_ephemeral( + self.appservice_handler.notify_interested_services_ephemeral( stream_key, stream_token, users ) except Exception: logger.exception("Error notifying application services of event") - async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): + def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: - await self._pusher_pool.on_new_notifications(max_room_stream_token) + self._pusher_pool.on_new_notifications(max_room_stream_token) except Exception: logger.exception("Error pusher pool of event") @@ -384,12 +376,8 @@ class Notifier: self.notify_replication() # Notify appservices - run_as_background_process( - "_notify_app_services_ephemeral", - self._notify_app_services_ephemeral, - stream_key, - new_token, - users, + self._notify_app_services_ephemeral( + stream_key, new_token, users, ) def on_new_replication_data(self) -> None: diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0080c68ce2..f325964983 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,7 +19,10 @@ from typing import TYPE_CHECKING, Dict, Union from prometheus_client import Gauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher from synapse.push.httppusher import HttpPusher @@ -187,7 +190,7 @@ class PusherPool: ) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - async def on_new_notifications(self, max_token: RoomStreamToken): + def on_new_notifications(self, max_token: RoomStreamToken): if not self.pushers: # nothing to do here. return @@ -201,6 +204,17 @@ class PusherPool: # Nothing to do return + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). + self._on_new_notifications(max_token) + + @wrap_as_background_process("on_new_notifications") + async def _on_new_notifications(self, max_token: RoomStreamToken): + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + max_stream_id = max_token.stream + prev_stream_id = self._last_room_stream_id_seen self._last_room_stream_id_seen = max_stream_id diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index de19705c1f..bc6ba709a7 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -166,7 +166,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Args: cmd (Command) """ - run_as_background_process("send-cmd", self._async_send_command, cmd) + run_as_background_process( + "send-cmd", self._async_send_command, cmd, bg_start_span=False + ) async def _async_send_command(self, cmd: Command): """Encode a replication command and send it over our outbound connection""" diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ee4f3da31c..53763cd0f9 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -42,7 +42,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() self.handler = ApplicationServicesHandler(hs) - @defer.inlineCallbacks def test_notify_interested_services(self): interested_service = self._mkservice(is_interested=True) services = [ @@ -62,14 +61,12 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event ) - @defer.inlineCallbacks def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -83,12 +80,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) - @defer.inlineCallbacks def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -102,9 +98,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.assertFalse( self.mock_as_api.query_user.called, "query_user called when it shouldn't have been.", -- cgit 1.5.1 From 4215a3acd4a77cb3331144782d43f99635f3d0ed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 27 Oct 2020 17:37:08 +0000 Subject: Don't unnecessarily start bg process in replication sending loop. (#8670) --- changelog.d/8670.misc | 1 + synapse/replication/tcp/resource.py | 10 ++++++++++ 2 files changed, 11 insertions(+) create mode 100644 changelog.d/8670.misc (limited to 'synapse/replication') diff --git a/changelog.d/8670.misc b/changelog.d/8670.misc new file mode 100644 index 0000000000..cf6023f783 --- /dev/null +++ b/changelog.d/8670.misc @@ -0,0 +1 @@ +Reduce number of OpenTracing spans started. diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 666c13fdb7..1d4ceac0f1 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -117,6 +117,16 @@ class ReplicationStreamer: stream.discard_updates_and_advance() return + # We check up front to see if anything has actually changed, as we get + # poked because of changes that happened on other instances. + if all( + stream.last_token == stream.current_token(self._instance_name) + for stream in self.streams + ): + return + + # If there are updates then we need to set this even if we're already + # looping, as the loop needs to know that he might need to loop again. self.pending_updates = True if self.is_looping: -- cgit 1.5.1 From a6ea1a957e8e38ca3f98d4da32ee49a40fcb4807 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 28 Oct 2020 12:11:45 +0000 Subject: Don't pull event from DB when handling replication traffic. (#8669) I was trying to make it so that we didn't have to start a background task when handling RDATA, but that is a bigger job (due to all the code in `generic_worker`). However I still think not pulling the event from the DB may help reduce some DB usage due to replication, even if most workers will simply go and pull that event from the DB later anyway. Co-authored-by: Patrick Cloke --- changelog.d/8669.misc | 1 + synapse/notifier.py | 68 ++++++++++++++++++++----- synapse/replication/tcp/client.py | 20 +++++--- synapse/replication/tcp/streams/events.py | 21 +++++--- synapse/storage/databases/main/events_worker.py | 8 ++- 5 files changed, 87 insertions(+), 31 deletions(-) create mode 100644 changelog.d/8669.misc (limited to 'synapse/replication') diff --git a/changelog.d/8669.misc b/changelog.d/8669.misc new file mode 100644 index 0000000000..5228105cd3 --- /dev/null +++ b/changelog.d/8669.misc @@ -0,0 +1 @@ +Don't pull event from DB when handling replication traffic. diff --git a/synapse/notifier.py b/synapse/notifier.py index eb56b26f21..a17352ef46 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -28,6 +28,7 @@ from typing import ( Union, ) +import attr from prometheus_client import Counter from twisted.internet import defer @@ -173,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): return bool(self.events) +@attr.s(slots=True, frozen=True) +class _PendingRoomEventEntry: + event_pos = attr.ib(type=PersistedEventPosition) + extra_users = attr.ib(type=Collection[UserID]) + + room_id = attr.ib(type=str) + type = attr.ib(type=str) + state_key = attr.ib(type=Optional[str]) + membership = attr.ib(type=Optional[str]) + + class Notifier: """ This class is responsible for notifying any listeners when there are new events available for it. @@ -190,9 +202,7 @@ class Notifier: self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = ( - [] - ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]] + self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -255,7 +265,29 @@ class Notifier: max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): - """ Used by handlers to inform the notifier something has happened + """Unwraps event and calls `on_new_room_event_args`. + """ + self.on_new_room_event_args( + event_pos=event_pos, + room_id=event.room_id, + event_type=event.type, + state_key=event.get("state_key"), + membership=event.content.get("membership"), + max_room_stream_token=max_room_stream_token, + extra_users=extra_users, + ) + + def on_new_room_event_args( + self, + room_id: str, + event_type: str, + state_key: Optional[str], + membership: Optional[str], + event_pos: PersistedEventPosition, + max_room_stream_token: RoomStreamToken, + extra_users: Collection[UserID] = [], + ): + """Used by handlers to inform the notifier something has happened in the room, room event wise. This triggers the notifier to wake up any listeners that are @@ -266,7 +298,16 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append((event_pos, event, extra_users)) + self.pending_new_room_events.append( + _PendingRoomEventEntry( + event_pos=event_pos, + extra_users=extra_users, + room_id=room_id, + type=event_type, + state_key=state_key, + membership=membership, + ) + ) self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() @@ -284,18 +325,19 @@ class Notifier: users = set() # type: Set[UserID] rooms = set() # type: Set[str] - for event_pos, event, extra_users in pending: - if event_pos.persisted_after(max_room_stream_token): - self.pending_new_room_events.append((event_pos, event, extra_users)) + for entry in pending: + if entry.event_pos.persisted_after(max_room_stream_token): + self.pending_new_room_events.append(entry) else: if ( - event.type == EventTypes.Member - and event.membership == Membership.JOIN + entry.type == EventTypes.Member + and entry.membership == Membership.JOIN + and entry.state_key ): - self._user_joined_room(event.state_key, event.room_id) + self._user_joined_room(entry.state_key, entry.room_id) - users.update(extra_users) - rooms.add(event.room_id) + users.update(entry.extra_users) + rooms.add(entry.room_id) if users or rooms: self.on_new_event( diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e27ee216f0..2618eb1e53 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -141,21 +141,25 @@ class ReplicationDataHandler: if row.type != EventsStreamEventRow.TypeId: continue assert isinstance(row, EventsStreamRow) + assert isinstance(row.data, EventsStreamEventRow) - event = await self.store.get_event( - row.data.event_id, allow_rejected=True - ) - if event.rejected_reason: + if row.data.rejected: continue extra_users = () # type: Tuple[UserID, ...] - if event.type == EventTypes.Member: - extra_users = (UserID.from_string(event.state_key),) + if row.data.type == EventTypes.Member and row.data.state_key: + extra_users = (UserID.from_string(row.data.state_key),) max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) - self.notifier.on_new_room_event( - event, event_pos, max_token, extra_users + self.notifier.on_new_room_event_args( + event_pos=event_pos, + max_room_stream_token=max_token, + extra_users=extra_users, + room_id=row.data.room_id, + event_type=row.data.type, + state_key=row.data.state_key, + membership=row.data.membership, ) # Notify any waiting deferreds. The list is ordered by position so we diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 82e9e0d64e..86a62b71eb 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -15,12 +15,15 @@ # limitations under the License. import heapq from collections.abc import Iterable -from typing import List, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import attr from ._base import Stream, StreamUpdateResult, Token +if TYPE_CHECKING: + from synapse.server import HomeServer + """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' @@ -81,12 +84,14 @@ class BaseEventsStreamRow: class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional - relates_to = attr.ib() # str, optional + event_id = attr.ib(type=str) + room_id = attr.ib(type=str) + type = attr.ib(type=str) + state_key = attr.ib(type=Optional[str]) + redacts = attr.ib(type=Optional[str]) + relates_to = attr.ib(type=Optional[str]) + membership = attr.ib(type=Optional[str]) + rejected = attr.ib(type=bool) @attr.s(slots=True, frozen=True) @@ -113,7 +118,7 @@ class EventsStream(Stream): NAME = "events" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index cd1f31aa62..5ae263827d 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1117,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore): def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" " AND instance_name = ?" " ORDER BY stream_ordering ASC" @@ -1152,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore): def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " AND out.instance_name = ?" -- cgit 1.5.1 From f21e24ffc22a5eb01f242f47fa30979321cf20fc Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 29 Oct 2020 15:58:44 +0000 Subject: Add ability for access tokens to belong to one user but grant access to another user. (#8616) We do it this way round so that only the "owner" can delete the access token (i.e. `/logout/all` by the "owner" also deletes that token, but `/logout/all` by the "target user" doesn't). A future PR will add an API for creating such a token. When the target user and authenticated entity are different the `Processed request` log line will be logged with a: `{@admin:server as @bob:server} ...`. I'm not convinced by that format (especially since it adds spaces in there, making it harder to use `cut -d ' '` to chop off the start of log lines). Suggestions welcome. --- changelog.d/8616.misc | 1 + synapse/api/auth.py | 113 +++++++++------------ synapse/appservice/__init__.py | 4 +- synapse/federation/transport/server.py | 2 +- synapse/handlers/auth.py | 8 +- synapse/handlers/register.py | 7 +- synapse/http/site.py | 30 ++++-- synapse/replication/http/membership.py | 6 +- synapse/replication/http/send_event.py | 3 +- synapse/storage/databases/main/registration.py | 48 +++++++-- .../main/schema/delta/58/22puppet_token.sql | 17 ++++ synapse/types.py | 33 ++++-- tests/api/test_auth.py | 29 +++--- tests/api/test_ratelimiting.py | 4 +- tests/appservice/test_appservice.py | 1 + tests/handlers/test_device.py | 2 +- tests/handlers/test_message.py | 2 +- tests/push/test_email.py | 2 +- tests/push/test_http.py | 10 +- tests/replication/test_pusher_shard.py | 2 +- tests/rest/client/v2_alpha/test_register.py | 1 + tests/storage/test_registration.py | 10 +- 22 files changed, 197 insertions(+), 138 deletions(-) create mode 100644 changelog.d/8616.misc create mode 100644 synapse/storage/databases/main/schema/delta/58/22puppet_token.sql (limited to 'synapse/replication') diff --git a/changelog.d/8616.misc b/changelog.d/8616.misc new file mode 100644 index 0000000000..385b14063e --- /dev/null +++ b/changelog.d/8616.misc @@ -0,0 +1 @@ +Change schema to support access tokens belonging to one user but granting access to another. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 526cb58c5f..bfcaf68b2a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging import opentracing as opentracing +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import StateMap, UserID from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -190,10 +191,6 @@ class Auth: user_id, app_service = await self._get_appservice_user_id(request) if user_id: - request.authenticated_entity = user_id - opentracing.set_tag("authenticated_entity", user_id) - opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( user_id=user_id, @@ -203,31 +200,38 @@ class Auth: device_id="dummy-device", # stubbed ) - return synapse.types.create_requester(user_id, app_service=app_service) + requester = synapse.types.create_requester( + user_id, app_service=app_service + ) + + request.requester = user_id + opentracing.set_tag("authenticated_entity", user_id) + opentracing.set_tag("user_id", user_id) + opentracing.set_tag("appservice_id", app_service.id) + + return requester user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) - user = user_info["user"] - token_id = user_info["token_id"] - is_guest = user_info["is_guest"] - shadow_banned = user_info["shadow_banned"] + token_id = user_info.token_id + is_guest = user_info.is_guest + shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: - user_id = user.to_string() - if await self.store.is_account_expired(user_id, self.clock.time_msec()): + if await self.store.is_account_expired( + user_info.user_id, self.clock.time_msec() + ): raise AuthError( 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) - # device_id may not be present if get_user_by_access_token has been - # stubbed out. - device_id = user_info.get("device_id") + device_id = user_info.device_id - if user and access_token and ip_addr: + if access_token and ip_addr: await self.store.insert_client_ip( - user_id=user.to_string(), + user_id=user_info.token_owner, access_token=access_token, ip=ip_addr, user_agent=user_agent, @@ -241,19 +245,23 @@ class Auth: errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) - request.authenticated_entity = user.to_string() - opentracing.set_tag("authenticated_entity", user.to_string()) - if device_id: - opentracing.set_tag("device_id", device_id) - - return synapse.types.create_requester( - user, + requester = synapse.types.create_requester( + user_info.user_id, token_id, is_guest, shadow_banned, device_id, app_service=app_service, + authenticated_entity=user_info.token_owner, ) + + request.requester = requester + opentracing.set_tag("authenticated_entity", user_info.token_owner) + opentracing.set_tag("user_id", user_info.user_id) + if device_id: + opentracing.set_tag("device_id", device_id) + + return requester except KeyError: raise MissingClientTokenError() @@ -284,7 +292,7 @@ class Auth: async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ) -> dict: + ) -> TokenLookupResult: """ Validate access token and get user_id from it Args: @@ -293,13 +301,7 @@ class Auth: allow this allow_expired: If False, raises an InvalidClientTokenError if the token is expired - Returns: - dict that includes: - `user` (UserID) - `is_guest` (bool) - `shadow_banned` (bool) - `token_id` (int|None): access token id. May be None if guest - `device_id` (str|None): device corresponding to access token + Raises: InvalidClientTokenError if a user by that token exists, but the token is expired @@ -309,9 +311,9 @@ class Auth: if rights == "access": # first look in the database - r = await self._look_up_user_by_access_token(token) + r = await self.store.get_user_by_access_token(token) if r: - valid_until_ms = r["valid_until_ms"] + valid_until_ms = r.valid_until_ms if ( not allow_expired and valid_until_ms is not None @@ -328,7 +330,6 @@ class Auth: # otherwise it needs to be a valid macaroon try: user_id, guest = self._parse_and_validate_macaroon(token, rights) - user = UserID.from_string(user_id) if rights == "access": if not guest: @@ -354,23 +355,17 @@ class Auth: raise InvalidClientTokenError( "Guest access token used for regular user" ) - ret = { - "user": user, - "is_guest": True, - "shadow_banned": False, - "token_id": None, + + ret = TokenLookupResult( + user_id=user_id, + is_guest=True, # all guests get the same device id - "device_id": GUEST_DEVICE_ID, - } + device_id=GUEST_DEVICE_ID, + ) elif rights == "delete_pusher": # We don't store these tokens in the database - ret = { - "user": user, - "is_guest": False, - "shadow_banned": False, - "token_id": None, - "device_id": None, - } + + ret = TokenLookupResult(user_id=user_id, is_guest=False) else: raise RuntimeError("Unknown rights setting %s", rights) return ret @@ -479,31 +474,15 @@ class Auth: now = self.hs.get_clock().time_msec() return now < expiry - async def _look_up_user_by_access_token(self, token): - ret = await self.store.get_user_by_access_token(token) - if not ret: - return None - - # we use ret.get() below because *lots* of unit tests stub out - # get_user_by_access_token in a way where it only returns a couple of - # the fields. - user_info = { - "user": UserID.from_string(ret.get("name")), - "token_id": ret.get("token_id", None), - "is_guest": False, - "shadow_banned": ret.get("shadow_banned"), - "device_id": ret.get("device_id"), - "valid_until_ms": ret.get("valid_until_ms"), - } - return user_info - def get_appservice_by_req(self, request): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() - request.authenticated_entity = service.sender + request.requester = synapse.types.create_requester( + service.sender, app_service=service + ) return service async def is_server_admin(self, user: UserID) -> bool: diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 3862d9c08f..66d008d2f4 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -52,11 +52,11 @@ class ApplicationService: self, token, hostname, + id, + sender, url=None, namespaces=None, hs_token=None, - sender=None, - id=None, protocols=None, rate_limited=True, ip_range_whitelist=None, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 3a6b95631e..a0933fae88 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -154,7 +154,7 @@ class Authenticator: ) logger.debug("Request from %s", origin) - request.authenticated_entity = origin + request.requester = origin # If we get a valid signed request from the other side, its probably # alive diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 276594f3d9..ff103cbb92 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -991,17 +991,17 @@ class AuthHandler(BaseHandler): # This might return an awaitable, if it does block the log out # until it completes. result = provider.on_logged_out( - user_id=str(user_info["user"]), - device_id=user_info["device_id"], + user_id=user_info.user_id, + device_id=user_info.device_id, access_token=access_token, ) if inspect.isawaitable(result): await result # delete pushers associated with this access token - if user_info["token_id"] is not None: + if user_info.token_id is not None: await self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"],) + user_info.user_id, (user_info.token_id,) ) async def delete_access_tokens_for_user( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a6f1d21674..ed1ff62599 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler): 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) user_data = await self.auth.get_user_by_access_token(guest_access_token) - if not user_data["is_guest"] or user_data["user"].localpart != localpart: + if ( + not user_data.is_guest + or UserID.from_string(user_data.user_id).localpart != localpart + ): raise AuthError( 403, "Cannot register taken user ID without valid guest " @@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler): # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. user_tuple = await self.store.get_user_by_access_token(token) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id await self.pusher_pool.add_pusher( user_id=user_id, diff --git a/synapse/http/site.py b/synapse/http/site.py index ddb1770b09..5f0581dc3f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Optional +from typing import Optional, Union from twisted.python.failure import Failure from twisted.web.server import Request, Site @@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.types import Requester logger = logging.getLogger(__name__) @@ -54,9 +55,12 @@ class SynapseRequest(Request): Request.__init__(self, channel, *args, **kw) self.site = channel.site self._channel = channel # this is used by the tests - self.authenticated_entity = None self.start_time = 0.0 + # The requester, if authenticated. For federation requests this is the + # server name, for client requests this is the Requester object. + self.requester = None # type: Optional[Union[Requester, str]] + # we can't yet create the logcontext, as we don't know the method. self.logcontext = None # type: Optional[LoggingContext] @@ -271,11 +275,23 @@ class SynapseRequest(Request): # to the client (nb may be negative) response_send_time = self.finish_time - self._processing_finished_time - # need to decode as it could be raw utf-8 bytes - # from a IDN servname in an auth header - authenticated_entity = self.authenticated_entity - if authenticated_entity is not None and isinstance(authenticated_entity, bytes): - authenticated_entity = authenticated_entity.decode("utf-8", "replace") + # Convert the requester into a string that we can log + authenticated_entity = None + if isinstance(self.requester, str): + authenticated_entity = self.requester + elif isinstance(self.requester, Requester): + authenticated_entity = self.requester.authenticated_entity + + # If this is a request where the target user doesn't match the user who + # authenticated (e.g. and admin is puppetting a user) then we log both. + if self.requester.user.to_string() != authenticated_entity: + authenticated_entity = "{},{}".format( + authenticated_entity, self.requester.user.to_string(), + ) + elif self.requester is not None: + # This shouldn't happen, but we log it so we don't lose information + # and can see that we're doing something wrong. + authenticated_entity = repr(self.requester) # type: ignore[unreachable] # ...or could be raw utf-8 bytes in the User-Agent header. # N.B. if you don't do this, the logger explodes cryptically diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index e7cc74a5d2..f0c37eaf5e 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): requester = Requester.deserialize(self.store, content["requester"]) - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester logger.info("remote_join: %s into room: %s", user_id, room_id) @@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): requester = Requester.deserialize(self.store, content["requester"]) - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester # hopefully we're now on the master, so this won't recurse! event_id, stream_id = await self.member_handler.remote_reject_invite( diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index fc129dbaa7..8fa104c8d3 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester logger.info( "Got event to send with ID: %s into room: %s", event.event_id, event.room_id diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e7b17a7385..e5d07ce72a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -18,6 +18,8 @@ import logging import re from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +import attr + from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process @@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) +@attr.s(frozen=True, slots=True) +class TokenLookupResult: + """Result of looking up an access token. + + Attributes: + user_id: The user that this token authenticates as + is_guest + shadow_banned + token_id: The ID of the access token looked up + device_id: The device associated with the token, if any. + valid_until_ms: The timestamp the token expires, if any. + token_owner: The "owner" of the token. This is either the same as the + user, or a server admin who is logged in as the user. + """ + + user_id = attr.ib(type=str) + is_guest = attr.ib(type=bool, default=False) + shadow_banned = attr.ib(type=bool, default=False) + token_id = attr.ib(type=Optional[int], default=None) + device_id = attr.ib(type=Optional[str], default=None) + valid_until_ms = attr.ib(type=Optional[int], default=None) + token_owner = attr.ib(type=str) + + # Make the token owner default to the user ID, which is the common case. + @token_owner.default + def _default_token_owner(self): + return self.user_id + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) @@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return is_trial @cached() - async def get_user_by_access_token(self, token: str) -> Optional[dict]: + async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]: """Get a user from the given access token. Args: token: The access token of a user. Returns: - None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise a `TokenLookupResult` """ return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token @@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) - def _query_for_auth(self, txn, token): + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ - SELECT users.name, + SELECT users.name as user_id, users.is_guest, users.shadow_banned, access_tokens.id as token_id, access_tokens.device_id, - access_tokens.valid_until_ms + access_tokens.valid_until_ms, + access_tokens.user_id as token_owner FROM users - INNER JOIN access_tokens on users.name = access_tokens.user_id + INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) WHERE token = ? """ txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) if rows: - return rows[0] + return TokenLookupResult(**rows[0]) return None diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql new file mode 100644 index 0000000000..00a9431a97 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Whether the access token is an admin token for controlling another user. +ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT; diff --git a/synapse/types.py b/synapse/types.py index 5bde67cc07..66bb5bac8d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -29,6 +29,7 @@ from typing import ( Tuple, Type, TypeVar, + Union, ) import attr @@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError if TYPE_CHECKING: + from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore # define a version of typing.Collection that works on python 3.5 @@ -74,6 +76,7 @@ class Requester( "shadow_banned", "device_id", "app_service", + "authenticated_entity", ], ) ): @@ -104,6 +107,7 @@ class Requester( "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, + "authenticated_entity": self.authenticated_entity, } @staticmethod @@ -129,16 +133,18 @@ class Requester( shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, + authenticated_entity=input["authenticated_entity"], ) def create_requester( - user_id, - access_token_id=None, - is_guest=False, - shadow_banned=False, - device_id=None, - app_service=None, + user_id: Union[str, "UserID"], + access_token_id: Optional[int] = None, + is_guest: Optional[bool] = False, + shadow_banned: Optional[bool] = False, + device_id: Optional[str] = None, + app_service: Optional["ApplicationService"] = None, + authenticated_entity: Optional[str] = None, ): """ Create a new ``Requester`` object @@ -151,14 +157,27 @@ def create_requester( shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. Returns: Requester """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) + + if authenticated_entity is None: + authenticated_entity = user_id.to_string() + return Requester( - user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + user_id, + access_token_id, + is_guest, + shadow_banned, + device_id, + app_service, + authenticated_entity, ) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cb6f29d670..0fd55f428a 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -29,6 +29,7 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from tests import unittest @@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} + user_info = TokenLookupResult( + user_id=self.test_user, token_id=5, device_id="device" + ) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): - user_info = {"name": self.test_user, "token_id": "ditto"} + user_info = TokenLookupResult(user_id=self.test_user, token_id=5) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( return_value=defer.succeed( - {"name": "@baldrick:matrix.org", "device_id": "device"} + TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") ) ) @@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(macaroon.serialize()) ) - user = user_info["user"] - self.assertEqual(UserID.from_string(user_id), user) + self.assertEqual(user_id, user_info.user_id) # TODO: device_id should come from the macaroon, but currently comes # from the db. - self.assertEqual(user_info["device_id"], "device") + self.assertEqual(user_info.device_id, "device") @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): @@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(serialized) ) - user = user_info["user"] - is_guest = user_info["is_guest"] - self.assertEqual(UserID.from_string(user_id), user) - self.assertTrue(is_guest) + self.assertEqual(user_id, user_info.user_id) + self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) @defer.inlineCallbacks @@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase): if token != tok: return defer.succeed(None) return defer.succeed( - { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + TokenLookupResult( + user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", + ) ) self.store.get_user_by_access_token = get_user diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 1e1f30d790..fe504d0869 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=True, + None, "example.com", id="foo", rate_limited=True, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) @@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=False, + None, "example.com", id="foo", rate_limited=False, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 236b608d58..0bffeb1150 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def setUp(self): self.service = ApplicationService( id="unique_identifier", + sender="@as:test", url="some_url", token="some_token", hostname="matrix.org", # only used by get_groups_for_user diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 4512c51311..875aaec2c6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): # make sure that our device ID has changed user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) - self.assertEqual(user_info["device_id"], retrieved_device_id) + self.assertEqual(user_info.device_id, retrieved_device_id) # make sure the device has the display name that was set from the login res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 9f6f21a6e2..2e0fea04af 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.info = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token,) ) - self.token_id = self.info["token_id"] + self.token_id = self.info.token_id self.requester = create_requester(self.user_id, access_token_id=self.token_id) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index d9993e6245..bcdcafa5a9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/push/test_http.py b/tests/push/test_http.py index b567868b02..8571924b29 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 2bdc6edbb1..67c27a089f 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): user_dict = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_dict["token_id"] + token_id = user_dict.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 2fc3a60fc5..98c3887bbf 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", ) self.hs.get_datastore().services_cache.append(appservice) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 6b582771fe..c8c7a90e5d 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store.get_user_by_access_token(self.tokens[1]) ) - self.assertDictContainsSubset( - {"name": self.user_id, "device_id": self.device_id}, result - ) - - self.assertTrue("token_id" in result) + self.assertEqual(result.user_id, self.user_id) + self.assertEqual(result.device_id, self.device_id) + self.assertIsNotNone(result.token_id) @defer.inlineCallbacks def test_user_delete_access_tokens(self): @@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user = yield defer.ensureDeferred( self.store.get_user_by_access_token(self.tokens[0]) ) - self.assertEqual(self.user_id, user["name"]) + self.assertEqual(self.user_id, user.user_id) # now delete the rest yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) -- cgit 1.5.1 From e8d08537394a49f3e66e9cbea3627e3c25818a7d Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 13 Nov 2020 16:24:04 +0000 Subject: Generalise _maybe_store_room_on_invite (#8754) There's a handy function called maybe_store_room_on_invite which allows us to create an entry in the rooms table for a room and its version for which we aren't joined to yet, but we can reference when ingesting events about. This is currently used for invites where we receive some stripped state about the room and pass it down via /sync to the client, without us being in the room yet. There is a similar requirement for knocking, where we will eventually do the same thing, and need an entry in the rooms table as well. Thus, reusing this function works, however its name needs to be generalised a bit. Separated out from #6739. --- changelog.d/8754.misc | 1 + synapse/handlers/federation.py | 10 ++++++---- synapse/replication/http/federation.py | 10 +++++----- synapse/storage/databases/main/room.py | 10 ++++++---- 4 files changed, 18 insertions(+), 13 deletions(-) create mode 100644 changelog.d/8754.misc (limited to 'synapse/replication') diff --git a/changelog.d/8754.misc b/changelog.d/8754.misc new file mode 100644 index 0000000000..0436bb1be7 --- /dev/null +++ b/changelog.d/8754.misc @@ -0,0 +1 @@ +Generalise `RoomStore.maybe_store_room_on_invite` to handle other, non-invite membership events. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c386957706..69bc5ba44d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -67,7 +67,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationFederationSendEventsRestServlet, - ReplicationStoreRoomOnInviteRestServlet, + ReplicationStoreRoomOnOutlierMembershipRestServlet, ) from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -152,12 +152,14 @@ class FederationHandler(BaseHandler): self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( hs ) - self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client( + self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( hs ) else: self._device_list_updater = hs.get_device_handler().device_list_updater - self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite + self._maybe_store_room_on_outlier_membership = ( + self.store.maybe_store_room_on_outlier_membership + ) # When joining a room we need to queue any events for that room up. # For each room, a list of (pdu, origin) tuples. @@ -1617,7 +1619,7 @@ class FederationHandler(BaseHandler): # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). - await self._maybe_store_room_on_invite( + await self._maybe_store_room_on_outlier_membership( room_id=event.room_id, room_version=room_version ) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index b4f4a68b5c..7a0dbb5b1a 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -254,20 +254,20 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): return 200, {} -class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): +class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): """Called to clean up any data in DB for a given room, ready for the server to join the room. Request format: - POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id + POST /_synapse/replication/store_room_on_outlier_membership/:room_id/:txn_id { "room_version": "1", } """ - NAME = "store_room_on_invite" + NAME = "store_room_on_outlier_membership" PATH_ARGS = ("room_id",) def __init__(self, hs): @@ -282,7 +282,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): async def _handle_request(self, request, room_id): content = parse_json_object_from_request(request) room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] - await self.store.maybe_store_room_on_invite(room_id, room_version) + await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) return 200, {} @@ -291,4 +291,4 @@ def register_servlets(hs, http_server): ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) ReplicationCleanRoomRestServlet(hs).register(http_server) - ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server) + ReplicationStoreRoomOnOutlierMembershipRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index dc0c4b5499..6b89db15c9 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1240,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") - async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion): + async def maybe_store_room_on_outlier_membership( + self, room_id: str, room_version: RoomVersion + ): """ - When we receive an invite over federation, store the version of the room if we - don't already know the room version. + When we receive an invite or any other event over federation that may relate to a room + we are not in, store the version of the room if we don't already know the room version. """ await self.db_pool.simple_upsert( - desc="maybe_store_room_on_invite", + desc="maybe_store_room_on_outlier_membership", table="rooms", keyvalues={"room_id": room_id}, values={}, -- cgit 1.5.1 From 5cbe8d93fe08ff347730596475a8a50b977754d7 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 27 Nov 2020 10:49:38 +0000 Subject: Add typing to membership Replication class methods (#8809) This PR grew out of #6739, and adds typing to some method arguments You'll notice that there are a lot of `# type: ignores` in here. This is due to the base methods not matching the overloads here. This is necessary to stop mypy complaining, but a better solution is #8828. --- changelog.d/8809.misc | 1 + synapse/replication/http/membership.py | 66 ++++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 22 deletions(-) create mode 100644 changelog.d/8809.misc (limited to 'synapse/replication') diff --git a/changelog.d/8809.misc b/changelog.d/8809.misc new file mode 100644 index 0000000000..bbf83cf18d --- /dev/null +++ b/changelog.d/8809.misc @@ -0,0 +1 @@ +Remove unnecessary function arguments and add typing to several membership replication classes. \ No newline at end of file diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index f0c37eaf5e..84e002f934 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple + +from twisted.web.http import Request from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint @@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload( - requester, room_id, user_id, remote_room_hosts, content - ): + async def _serialize_payload( # type: ignore + requester: Requester, + room_id: str, + user_id: str, + remote_room_hosts: List[str], + content: JsonDict, + ) -> JsonDict: """ Args: - requester(Requester) - room_id (str) - user_id (str) - remote_room_hosts (list[str]): Servers to try and join via - content(dict): The event content to use for the join event + requester: The user making the request according to the access token + room_id: The ID of the room. + user_id: The ID of the user. + remote_room_hosts: Servers to try and join via + content: The event content to use for the join event + + Returns: + A dict representing the payload of the request. """ return { "requester": requester.serialize(), @@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request(self, request, room_id, user_id): + async def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): txn_id: Optional[str], requester: Requester, content: JsonDict, - ): + ) -> JsonDict: """ Args: - invite_event_id: ID of the invite to be rejected - txn_id: optional transaction ID supplied by the client - requester: user making the rejection request, according to the access token - content: additional content to include in the rejection event. + invite_event_id: The ID of the invite to be rejected. + txn_id: Optional transaction ID supplied by the client + requester: User making the rejection request, according to the access token + content: Additional content to include in the rejection event. Normally an empty dict. + + Returns: + A dict representing the payload of the request. """ return { "txn_id": txn_id, @@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request(self, request, invite_event_id): + async def _handle_request( # type: ignore + self, request: Request, invite_event_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) txn_id = content["txn_id"] @@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): self.distributor = hs.get_distributor() @staticmethod - async def _serialize_payload(room_id, user_id, change): + async def _serialize_payload( # type: ignore + room_id: str, user_id: str, change: str + ) -> JsonDict: """ Args: - room_id (str) - user_id (str) - change (str): "left" + room_id: The ID of the room. + user_id: The ID of the user. + change: "left" + + Returns: + A dict representing the payload of the request. """ assert change == "left" return {} - def _handle_request(self, request, room_id, user_id, change): + def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str, change: str + ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) user = UserID.from_string(user_id) -- cgit 1.5.1