diff options
Diffstat (limited to 'synapse')
22 files changed, 448 insertions, 399 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index b2c764bfe8..136babe6ce 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -65,12 +65,23 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler -from synapse.replication.tcp.streams._base import ( +from synapse.replication.tcp.streams import ( + AccountDataStream, DeviceListsStream, + GroupServerStream, + PresenceStream, + PushersStream, + PushRulesStream, ReceiptsStream, + TagAccountDataStream, ToDeviceStream, + TypingStream, +) +from synapse.replication.tcp.streams.events import ( + EventsStream, + EventsStreamEventRow, + EventsStreamRow, ) -from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client.v1 import events from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet @@ -626,7 +637,7 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): if self.send_handler: self.send_handler.process_replication_rows(stream_name, token, rows) - if stream_name == "events": + if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. for row in rows: @@ -649,43 +660,44 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler): ) await self.pusher_pool.on_new_notifications(token, token) - elif stream_name == "push_rules": + elif stream_name == PushRulesStream.NAME: self.notifier.on_new_event( "push_rules_key", token, users=[row.user_id for row in rows] ) - elif stream_name in ("account_data", "tag_account_data"): + elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): self.notifier.on_new_event( "account_data_key", token, users=[row.user_id for row in rows] ) - elif stream_name == "receipts": + elif stream_name == ReceiptsStream.NAME: self.notifier.on_new_event( "receipt_key", token, rooms=[row.room_id for row in rows] ) await self.pusher_pool.on_new_receipts( token, token, {row.room_id for row in rows} ) - elif stream_name == "typing": + elif stream_name == TypingStream.NAME: self.typing_handler.process_replication_rows(token, rows) self.notifier.on_new_event( "typing_key", token, rooms=[row.room_id for row in rows] ) - elif stream_name == "to_device": + elif stream_name == ToDeviceStream.NAME: entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: self.notifier.on_new_event("to_device_key", token, users=entities) - elif stream_name == "device_lists": + elif stream_name == DeviceListsStream.NAME: all_room_ids = set() for row in rows: - room_ids = await self.store.get_rooms_for_user(row.user_id) - all_room_ids.update(room_ids) + if row.entity.startswith("@"): + room_ids = await self.store.get_rooms_for_user(row.entity) + all_room_ids.update(room_ids) self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) - elif stream_name == "presence": + elif stream_name == PresenceStream.NAME: await self.presence_handler.process_replication_rows(token, rows) - elif stream_name == "receipts": + elif stream_name == GroupServerStream.NAME: self.notifier.on_new_event( "groups_key", token, users=[row.user_id for row in rows] ) - elif stream_name == "pushers": + elif stream_name == PushersStream.NAME: for row in rows: if row.deleted: self.stop_pusher(row.user_id, row.app_id, row.pushkey) @@ -774,7 +786,10 @@ class FederationSenderHandler(object): # ... as well as device updates and messages elif stream_name == DeviceListsStream.NAME: - hosts = {row.destination for row in rows} + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. + hosts = {row.entity for row in rows if not row.entity.startswith("@")} for host in hosts: self.federation_sender.send_device_messages(host) diff --git a/synapse/config/_base.py b/synapse/config/_base.py index ba846042c4..efe2af5504 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -294,7 +294,6 @@ class RootConfig(object): report_stats=None, open_private_ports=False, listeners=None, - database_conf=None, tls_certificate_path=None, tls_private_key_path=None, acme_domain=None, @@ -367,7 +366,6 @@ class RootConfig(object): report_stats=report_stats, open_private_ports=open_private_ports, listeners=listeners, - database_conf=database_conf, tls_certificate_path=tls_certificate_path, tls_private_key_path=tls_private_key_path, acme_domain=acme_domain, diff --git a/synapse/config/database.py b/synapse/config/database.py index 219b32f670..b8ab2f86ac 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,14 +15,60 @@ # limitations under the License. import logging import os -from textwrap import indent - -import yaml from synapse.config._base import Config, ConfigError logger = logging.getLogger(__name__) +DEFAULT_CONFIG = """\ +## Database ## + +# The 'database' setting defines the database that synapse uses to store all of +# its data. +# +# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or +# 'psycopg2' (for PostgreSQL). +# +# 'args' gives options which are passed through to the database engine, +# except for options starting 'cp_', which are used to configure the Twisted +# connection pool. For a reference to valid arguments, see: +# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS +# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__ +# +# +# Example SQLite configuration: +# +#database: +# name: sqlite3 +# args: +# database: /path/to/homeserver.db +# +# +# Example Postgres configuration: +# +#database: +# name: psycopg2 +# args: +# user: synapse +# password: secretpassword +# database: synapse +# host: localhost +# cp_min: 5 +# cp_max: 10 +# +# For more information on using Synapse with Postgres, see `docs/postgres.md`. +# +database: + name: sqlite3 + args: + database: %(database_path)s + +# Number of events to cache in memory. +# +#event_cache_size: 10K +""" + class DatabaseConnectionConfig: """Contains the connection config for a particular database. @@ -36,10 +83,12 @@ class DatabaseConnectionConfig: """ def __init__(self, name: str, db_config: dict): - if db_config["name"] not in ("sqlite3", "psycopg2"): - raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + db_engine = db_config.get("name", "sqlite3") - if db_config["name"] == "sqlite3": + if db_engine not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_engine,)) + + if db_engine == "sqlite3": db_config.setdefault("args", {}).update( {"cp_min": 1, "cp_max": 1, "check_same_thread": False} ) @@ -97,34 +146,10 @@ class DatabaseConfig(Config): self.set_databasepath(config.get("database_path")) - def generate_config_section(self, data_dir_path, database_conf, **kwargs): - if not database_conf: - database_path = os.path.join(data_dir_path, "homeserver.db") - database_conf = ( - """# The database engine name - name: "sqlite3" - # Arguments to pass to the engine - args: - # Path to the database - database: "%(database_path)s" - """ - % locals() - ) - else: - database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip() - - return ( - """\ - ## Database ## - - database: - %(database_conf)s - # Number of events to cache in memory. - # - #event_cache_size: 10K - """ - % locals() - ) + def generate_config_section(self, data_dir_path, **kwargs): + return DEFAULT_CONFIG % { + "database_path": os.path.join(data_dir_path, "homeserver.db") + } def read_arguments(self, args): self.set_databasepath(args.database_path) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 5c991e5412..b0b0eba41e 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -25,11 +25,7 @@ from twisted.python.failure import Failure from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError -from synapse.api.room_versions import ( - KNOWN_ROOM_VERSIONS, - EventFormatVersions, - RoomVersion, -) +from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.keyring import Keyring from synapse.events import EventBase, make_event_from_dict @@ -55,13 +51,15 @@ class FederationBase(object): self.store = hs.get_datastore() self._clock = hs.get_clock() - def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred: + def _check_sigs_and_hash( + self, room_version: RoomVersion, pdu: EventBase + ) -> Deferred: return make_deferred_yieldable( self._check_sigs_and_hashes(room_version, [pdu])[0] ) def _check_sigs_and_hashes( - self, room_version: str, pdus: List[EventBase] + self, room_version: RoomVersion, pdus: List[EventBase] ) -> List[Deferred]: """Checks that each of the received events is correctly signed by the sending server. @@ -146,7 +144,7 @@ class PduToCheckSig( def _check_sigs_on_pdus( - keyring: Keyring, room_version: str, pdus: Iterable[EventBase] + keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase] ) -> List[Deferred]: """Check that the given events are correctly signed @@ -191,10 +189,6 @@ def _check_sigs_on_pdus( for p in pdus ] - v = KNOWN_ROOM_VERSIONS.get(room_version) - if not v: - raise RuntimeError("Unrecognized room version %s" % (room_version,)) - # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] @@ -204,7 +198,7 @@ def _check_sigs_on_pdus( ( p.sender_domain, p.redacted_pdu_json, - p.pdu.origin_server_ts if v.enforce_key_validity else 0, + p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, p.pdu.event_id, ) for p in pdus_to_check_sender @@ -227,7 +221,7 @@ def _check_sigs_on_pdus( # event id's domain (normally only the case for joins/leaves), and add additional # checks. Only do this if the room version has a concept of event ID domain # (ie, the room version uses old-style non-hash event IDs). - if v.event_format == EventFormatVersions.V1: + if room_version.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ p for p in pdus_to_check @@ -239,7 +233,7 @@ def _check_sigs_on_pdus( ( get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, - p.pdu.origin_server_ts if v.enforce_key_validity else 0, + p.pdu.origin_server_ts if room_version.enforce_key_validity else 0, p.pdu.event_id, ) for p in pdus_to_check_event_id diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 8c6b839478..a0071fec94 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -220,8 +220,7 @@ class FederationClient(FederationBase): # FIXME: We should handle signature failures more gracefully. pdus[:] = await make_deferred_yieldable( defer.gatherResults( - self._check_sigs_and_hashes(room_version.identifier, pdus), - consumeErrors=True, + self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True, ).addErrback(unwrapFirstError) ) @@ -291,9 +290,7 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - signed_pdu = await self._check_sigs_and_hash( - room_version.identifier, pdu - ) + signed_pdu = await self._check_sigs_and_hash(room_version, pdu) break @@ -350,7 +347,7 @@ class FederationClient(FederationBase): self, origin: str, pdus: List[EventBase], - room_version: str, + room_version: RoomVersion, outlier: bool = False, include_none: bool = False, ) -> List[EventBase]: @@ -396,7 +393,7 @@ class FederationClient(FederationBase): self.get_pdu( destinations=[pdu.origin], event_id=pdu.event_id, - room_version=room_version, # type: ignore + room_version=room_version, outlier=outlier, timeout=10000, ) @@ -434,7 +431,7 @@ class FederationClient(FederationBase): ] signed_auth = await self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version.identifier + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -661,7 +658,7 @@ class FederationClient(FederationBase): destination, list(pdus.values()), outlier=True, - room_version=room_version.identifier, + room_version=room_version, ) valid_pdus_map = {p.event_id: p for p in valid_pdus} @@ -756,7 +753,7 @@ class FederationClient(FederationBase): pdu = event_from_pdu_json(pdu_dict, room_version) # Check signatures are correct. - pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) # FIXME: We should handle signature failures more gracefully. @@ -948,7 +945,7 @@ class FederationClient(FederationBase): ] signed_events = await self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version.identifier + destination, events, outlier=False, room_version=room_version ) except HttpResponseException as e: if not e.code == 400: diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 275b9c99d7..89d521bc31 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -409,7 +409,7 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, room_version) origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, pdu.room_id) - pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} @@ -425,7 +425,7 @@ class FederationServer(FederationBase): logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() @@ -455,7 +455,7 @@ class FederationServer(FederationBase): logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = await self._check_sigs_and_hash(room_version.identifier, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) await self.handler.on_send_leave_request(origin, pdu) return {} @@ -611,7 +611,7 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = await self.store.get_room_version_id(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. try: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 5526015ddb..6912165622 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -747,7 +747,7 @@ class PresenceHandler(object): return False - async def get_all_presence_updates(self, last_id, current_id): + async def get_all_presence_updates(self, last_id, current_id, limit): """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -762,7 +762,7 @@ class PresenceHandler(object): """ # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = await self.store.get_all_presence_updates(last_id, current_id) + rows = await self.store.get_all_presence_updates(last_id, current_id, limit) return rows def notify_new_event(self): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 391bceb0c4..c7bc14c623 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import List from twisted.internet import defer @@ -257,7 +258,13 @@ class TypingHandler(object): "typing_key", self._latest_room_serial, rooms=[member.room_id] ) - async def get_all_typing_updates(self, last_id, current_id): + async def get_all_typing_updates( + self, last_id: int, current_id: int, limit: int + ) -> List[dict]: + """Get up to `limit` typing updates between the given tokens, earliest + updates first. + """ + if last_id == current_id: return [] @@ -275,7 +282,7 @@ class TypingHandler(object): typing = self._room_typing[room_id] rows.append((serial, room_id, list(typing))) rows.sort() - return rows + return rows[:limit] def get_current_token(self): return self._latest_room_serial diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 1c77687eea..23b1650e41 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self.hs = hs self._device_list_id_gen = SlavedIdTracker( - db_conn, "device_lists_stream", "stream_id" + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( @@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def process_replication_rows(self, stream_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) - for row in rows: - self._invalidate_caches_for_devices(token, row.user_id, row.destination) + self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) - def _invalidate_caches_for_devices(self, token, user_id, destination): - self._device_list_stream_cache.entity_has_changed(user_id, token) - - if destination: - self._device_list_federation_stream_cache.entity_has_changed( - destination, token - ) + def _invalidate_caches_for_devices(self, token, rows): + for row in rows: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. + if row.entity.startswith("@"): + self._device_list_stream_cache.entity_has_changed(row.entity, token) + self.get_cached_devices_for_user.invalidate((row.entity,)) + self._get_cached_user_device.invalidate_many((row.entity,)) + self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) - self.get_cached_devices_for_user.invalidate((user_id,)) - self._get_cached_user_device.invalidate_many((user_id,)) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) + else: + self._device_list_federation_stream_cache.entity_has_changed( + row.entity, token + ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ce9d1fae12..6e2ebaf614 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -166,11 +166,6 @@ class ReplicationStreamer(object): self.pending_updates = False with Measure(self.clock, "repl.stream.get_updates"): - # First we tell the streams that they should update their - # current tokens. - for stream in self.streams: - stream.advance_current_token() - all_streams = self.streams if self._replication_torture_level is not None: @@ -180,7 +175,7 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.upto_token: + if stream.last_token == stream.current_token(): continue if self._replication_torture_level: @@ -192,7 +187,7 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.upto_token, + stream.current_token(), ) try: updates, current_token = await stream.get_updates() diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 5f52264e84..29199f5b46 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -24,27 +24,61 @@ Each stream is defined by the following information: current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ - -from . import _base, events, federation +from synapse.replication.tcp.streams._base import ( + AccountDataStream, + BackfillStream, + CachesStream, + DeviceListsStream, + GroupServerStream, + PresenceStream, + PublicRoomsStream, + PushersStream, + PushRulesStream, + ReceiptsStream, + TagAccountDataStream, + ToDeviceStream, + TypingStream, + UserSignatureStream, +) +from synapse.replication.tcp.streams.events import EventsStream +from synapse.replication.tcp.streams.federation import FederationStream STREAMS_MAP = { stream.NAME: stream for stream in ( - events.EventsStream, - _base.BackfillStream, - _base.PresenceStream, - _base.TypingStream, - _base.ReceiptsStream, - _base.PushRulesStream, - _base.PushersStream, - _base.CachesStream, - _base.PublicRoomsStream, - _base.DeviceListsStream, - _base.ToDeviceStream, - federation.FederationStream, - _base.TagAccountDataStream, - _base.AccountDataStream, - _base.GroupServerStream, - _base.UserSignatureStream, + EventsStream, + BackfillStream, + PresenceStream, + TypingStream, + ReceiptsStream, + PushRulesStream, + PushersStream, + CachesStream, + PublicRoomsStream, + DeviceListsStream, + ToDeviceStream, + FederationStream, + TagAccountDataStream, + AccountDataStream, + GroupServerStream, + UserSignatureStream, ) } + +__all__ = [ + "STREAMS_MAP", + "BackfillStream", + "PresenceStream", + "TypingStream", + "ReceiptsStream", + "PushRulesStream", + "PushersStream", + "CachesStream", + "PublicRoomsStream", + "DeviceListsStream", + "ToDeviceStream", + "TagAccountDataStream", + "AccountDataStream", + "GroupServerStream", + "UserSignatureStream", +] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 208e8a667b..abf5c6c6a8 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -17,10 +17,12 @@ import itertools import logging from collections import namedtuple -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import attr +from synapse.types import JsonDict + logger = logging.getLogger(__name__) @@ -94,9 +96,13 @@ PublicRoomsStreamRow = namedtuple( "network_id", # str, optional ), ) -DeviceListsStreamRow = namedtuple( - "DeviceListsStreamRow", ("user_id", "destination") # str # str -) + + +@attr.s +class DeviceListsStreamRow: + entity = attr.ib(type=str) + + ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str TagAccountDataStreamRow = namedtuple( "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict @@ -115,13 +121,12 @@ class Stream(object): """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last - time it was called up until the point `advance_current_token` was called. + time it was called. """ NAME = None # type: str # The name of the stream # The type of the row. Used by the default impl of parse_row. ROW_TYPE = None # type: Any - _LIMITED = True # Whether the update function takes a limit @classmethod def parse_row(cls, row): @@ -142,26 +147,15 @@ class Stream(object): # The token from which we last asked for updates self.last_token = self.current_token() - # The token that we will get updates up to - self.upto_token = self.current_token() - - def advance_current_token(self): - """Updates `upto_token` to "now", which updates up until which point - get_updates[_since] will fetch rows till. - """ - self.upto_token = self.current_token() - def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.upto_token = self.current_token() - self.last_token = self.upto_token + self.last_token = self.current_token() async def get_updates(self): """Gets all updates since the last time this function was called (or - since the stream was constructed if it hadn't been called before), - until the `upto_token` + since the stream was constructed if it hadn't been called before). Returns: Deferred[Tuple[List[Tuple[int, Any]], int]: @@ -174,44 +168,45 @@ class Stream(object): return updates, current_token - async def get_updates_since(self, from_token): + async def get_updates_since( + self, from_token: int + ) -> Tuple[List[Tuple[int, JsonDict]], int]: """Like get_updates except allows specifying from when we should stream updates Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + Resolves to a pair `(updates, new_last_token)`, where `updates` is + a list of `(token, row)` entries and `new_last_token` is the new + position in stream. """ + if from_token in ("NOW", "now"): - return [], self.upto_token + return [], self.current_token() - current_token = self.upto_token + current_token = self.current_token() from_token = int(from_token) if from_token == current_token: return [], current_token - logger.info("get_updates_since: %s", self.__class__) - if self._LIMITED: - rows = await self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 - ) + rows = await self.update_function( + from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 + ) - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - else: - rows = await self.update_function(from_token, current_token) + # never turn more than MAX_EVENTS_BEHIND + 1 into updates. + rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) updates = [(row[0], row[1:]) for row in rows] # check we didn't get more rows than the limit. # doing it like this allows the update_function to be a generator. - if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: + if len(updates) >= MAX_EVENTS_BEHIND: raise Exception("stream %s has fallen behind" % (self.NAME)) + # The update function didn't hit the limit, so we must have got all + # the updates to `current_token`, and can return that as our new + # stream position. return updates, current_token def current_token(self): @@ -223,9 +218,8 @@ class Stream(object): """ raise NotImplementedError() - def update_function(self, from_token, current_token, limit=None): - """Get updates between from_token and to_token. If Stream._LIMITED is - True then limit is provided, otherwise it's not. + def update_function(self, from_token, current_token, limit): + """Get updates between from_token and to_token. Returns: Deferred(list(tuple)): the first entry in the tuple is the token for @@ -253,7 +247,6 @@ class BackfillStream(Stream): class PresenceStream(Stream): NAME = "presence" - _LIMITED = False ROW_TYPE = PresenceStreamRow def __init__(self, hs): @@ -268,7 +261,6 @@ class PresenceStream(Stream): class TypingStream(Stream): NAME = "typing" - _LIMITED = False ROW_TYPE = TypingStreamRow def __init__(self, hs): @@ -363,11 +355,11 @@ class PublicRoomsStream(Stream): class DeviceListsStream(Stream): - """Someone added/changed/removed a device + """Either a user has updated their devices or a remote server needs to be + told about a device update. """ NAME = "device_lists" - _LIMITED = False ROW_TYPE = DeviceListsStreamRow def __init__(self, hs): @@ -457,7 +449,6 @@ class UserSignatureStream(Stream): """ NAME = "user_signature" - _LIMITED = False ROW_TYPE = UserSignatureStreamRow def __init__(self, hs): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index d0d4999795..31551524f8 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -28,7 +28,6 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID, map_username_to_mxid_localpart @@ -548,13 +547,6 @@ class SSOAuthHandler(object): self._registration_handler = hs.get_registration_handler() self._macaroon_gen = hs.get_macaroon_generator() - # Load the redirect page HTML template - self._template = load_jinja2_templates( - hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], - )[0] - - self._server_name = hs.config.server_name - # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 50e080673b..85cf5a14c6 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -142,14 +142,6 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { "session": session, @@ -158,17 +150,19 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None else: raise SynapseError(404, "Unknown auth stage type") + # Render the HTML and return. + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + async def on_POST(self, request, stagetype): session = parse_string(request, "session") @@ -196,15 +190,6 @@ class AuthRestServlet(RestServlet): % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), "sitekey": self.hs.config.recaptcha_public_key, } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - - return None elif stagetype == LoginType.TERMS: authdict = {"session": session} @@ -225,17 +210,19 @@ class AuthRestServlet(RestServlet): "myurl": "%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), } - html_bytes = html.encode("utf8") - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) - - request.write(html_bytes) - finish_request(request) - return None else: raise SynapseError(404, "Unknown auth stage type") + # Render the HTML and return. + html_bytes = html.encode("utf8") + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),)) + + request.write(html_bytes) + finish_request(request) + return None + def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 66a01559e1..24d3ae5bbc 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource): b" media-src 'self';" b" object-src 'self';", ) + request.setHeader( + b"Referrer-Policy", b"no-referrer", + ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: await self.media_repo.get_local_media(request, media_id, name) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 490b1b45a8..fd10d42f2f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -24,7 +24,6 @@ from six import iteritems import twisted.internet.error import twisted.web.http -from twisted.internet import defer from twisted.web.resource import Resource from synapse.api.errors import ( @@ -114,15 +113,14 @@ class MediaRepository(object): "update_recently_accessed_media", self._update_recently_accessed ) - @defer.inlineCallbacks - def _update_recently_accessed(self): + async def _update_recently_accessed(self): remote_media = self.recently_accessed_remotes self.recently_accessed_remotes = set() local_media = self.recently_accessed_locals self.recently_accessed_locals = set() - yield self.store.update_cached_last_access_time( + await self.store.update_cached_last_access_time( local_media, remote_media, self.clock.time_msec() ) @@ -138,8 +136,7 @@ class MediaRepository(object): else: self.recently_accessed_locals.add(media_id) - @defer.inlineCallbacks - def create_content( + async def create_content( self, media_type, upload_name, content, content_length, auth_user ): """Store uploaded content for a local user and return the mxc URL @@ -158,11 +155,11 @@ class MediaRepository(object): file_info = FileInfo(server_name=None, file_id=media_id) - fname = yield self.media_storage.store_file(content, file_info) + fname = await self.media_storage.store_file(content, file_info) logger.info("Stored local media in file %r", fname) - yield self.store.store_local_media( + await self.store.store_local_media( media_id=media_id, media_type=media_type, time_now_ms=self.clock.time_msec(), @@ -171,12 +168,11 @@ class MediaRepository(object): user_id=auth_user, ) - yield self._generate_thumbnails(None, media_id, media_id, media_type) + await self._generate_thumbnails(None, media_id, media_id, media_type) return "mxc://%s/%s" % (self.server_name, media_id) - @defer.inlineCallbacks - def get_local_media(self, request, media_id, name): + async def get_local_media(self, request, media_id, name): """Responds to reqests for local media, if exists, or returns 404. Args: @@ -190,7 +186,7 @@ class MediaRepository(object): Deferred: Resolves once a response has successfully been written to request """ - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: respond_404(request) return @@ -204,13 +200,12 @@ class MediaRepository(object): file_info = FileInfo(None, media_id, url_cache=url_cache) - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder( + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder( request, responder, media_type, media_length, upload_name ) - @defer.inlineCallbacks - def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media(self, request, server_name, media_id, name): """Respond to requests for remote media. Args: @@ -236,8 +231,8 @@ class MediaRepository(object): # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (yield self.remote_media_linearizer.queue(key)): - responder, media_info = yield self._get_remote_media_impl( + with (await self.remote_media_linearizer.queue(key)): + responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -246,14 +241,13 @@ class MediaRepository(object): media_type = media_info["media_type"] media_length = media_info["media_length"] upload_name = name if name else media_info["upload_name"] - yield respond_with_responder( + await respond_with_responder( request, responder, media_type, media_length, upload_name ) else: respond_404(request) - @defer.inlineCallbacks - def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name, media_id): """Gets the media info associated with the remote file, downloading if necessary. @@ -274,8 +268,8 @@ class MediaRepository(object): # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (yield self.remote_media_linearizer.queue(key)): - responder, media_info = yield self._get_remote_media_impl( + with (await self.remote_media_linearizer.queue(key)): + responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -286,8 +280,7 @@ class MediaRepository(object): return media_info - @defer.inlineCallbacks - def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl(self, server_name, media_id): """Looks for media in local cache, if not there then attempt to download from remote server. @@ -299,7 +292,7 @@ class MediaRepository(object): Returns: Deferred[(Responder, media_info)] """ - media_info = yield self.store.get_cached_remote_media(server_name, media_id) + media_info = await self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new @@ -317,19 +310,18 @@ class MediaRepository(object): logger.info("Media is quarantined") raise NotFoundError() - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: return responder, media_info # Failed to find the file anywhere, lets download it. - media_info = yield self._download_remote_file(server_name, media_id, file_id) + media_info = await self._download_remote_file(server_name, media_id, file_id) - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) return responder, media_info - @defer.inlineCallbacks - def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file(self, server_name, media_id, file_id): """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -351,7 +343,7 @@ class MediaRepository(object): ("/_matrix/media/v1/download", server_name, media_id) ) try: - length, headers = yield self.client.get_file( + length, headers = await self.client.get_file( server_name, request_path, output_stream=f, @@ -397,7 +389,7 @@ class MediaRepository(object): ) raise SynapseError(502, "Failed to fetch remote media") - yield finish() + await finish() media_type = headers[b"Content-Type"][0].decode("ascii") upload_name = get_filename_from_headers(headers) @@ -405,7 +397,7 @@ class MediaRepository(object): logger.info("Stored remote media in file %r", fname) - yield self.store.store_cached_remote_media( + await self.store.store_cached_remote_media( origin=server_name, media_id=media_id, media_type=media_type, @@ -423,7 +415,7 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_thumbnails(server_name, media_id, file_id, media_type) + await self._generate_thumbnails(server_name, media_id, file_id, media_type) return media_info @@ -458,16 +450,15 @@ class MediaRepository(object): return t_byte_source - @defer.inlineCallbacks - def generate_local_exact_thumbnail( + async def generate_local_exact_thumbnail( self, media_id, t_width, t_height, t_method, t_type, url_cache ): - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -490,7 +481,7 @@ class MediaRepository(object): thumbnail_type=t_type, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -500,22 +491,21 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) - yield self.store.store_local_thumbnail( + await self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) return output_path - @defer.inlineCallbacks - def generate_remote_exact_thumbnail( + async def generate_remote_exact_thumbnail( self, server_name, file_id, media_id, t_width, t_height, t_method, t_type ): - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=False) ) thumbnailer = Thumbnailer(input_path) - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, thumbnailer, @@ -537,7 +527,7 @@ class MediaRepository(object): thumbnail_type=t_type, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -547,7 +537,7 @@ class MediaRepository(object): t_len = os.path.getsize(output_path) - yield self.store.store_remote_media_thumbnail( + await self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -560,8 +550,7 @@ class MediaRepository(object): return output_path - @defer.inlineCallbacks - def _generate_thumbnails( + async def _generate_thumbnails( self, server_name, media_id, file_id, media_type, url_cache=False ): """Generate and store thumbnails for an image. @@ -582,7 +571,7 @@ class MediaRepository(object): if not requirements: return - input_path = yield self.media_storage.ensure_media_is_in_local_cache( + input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) ) @@ -600,7 +589,7 @@ class MediaRepository(object): return if thumbnailer.transpose_method is not None: - m_width, m_height = yield defer_to_thread( + m_width, m_height = await defer_to_thread( self.hs.get_reactor(), thumbnailer.transpose ) @@ -620,11 +609,11 @@ class MediaRepository(object): for (t_width, t_height, t_type), t_method in iteritems(thumbnails): # Generate the thumbnail if t_method == "crop": - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type ) elif t_method == "scale": - t_byte_source = yield defer_to_thread( + t_byte_source = await defer_to_thread( self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type ) else: @@ -646,7 +635,7 @@ class MediaRepository(object): url_cache=url_cache, ) - output_path = yield self.media_storage.store_file( + output_path = await self.media_storage.store_file( t_byte_source, file_info ) finally: @@ -656,7 +645,7 @@ class MediaRepository(object): # Write to database if server_name: - yield self.store.store_remote_media_thumbnail( + await self.store.store_remote_media_thumbnail( server_name, media_id, file_id, @@ -667,15 +656,14 @@ class MediaRepository(object): t_len, ) else: - yield self.store.store_local_thumbnail( + await self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) return {"width": m_width, "height": m_height} - @defer.inlineCallbacks - def delete_old_remote_media(self, before_ts): - old_media = yield self.store.get_remote_media_before(before_ts) + async def delete_old_remote_media(self, before_ts): + old_media = await self.store.get_remote_media_before(before_ts) deleted = 0 @@ -689,7 +677,7 @@ class MediaRepository(object): # TODO: Should we delete from the backup store - with (yield self.remote_media_linearizer.queue(key)): + with (await self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: os.remove(full_path) @@ -705,7 +693,7 @@ class MediaRepository(object): ) shutil.rmtree(thumbnail_dir, ignore_errors=True) - yield self.store.delete_remote_media(origin, media_id) + await self.store.delete_remote_media(origin, media_id) deleted += 1 return {"deleted": deleted} diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 07e395cfd1..c46676f8fc 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -165,8 +165,7 @@ class PreviewUrlResource(DirectServeResource): og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) - @defer.inlineCallbacks - def _do_preview(self, url, user, ts): + async def _do_preview(self, url, user, ts): """Check the db, and download the URL and build a preview Args: @@ -179,7 +178,7 @@ class PreviewUrlResource(DirectServeResource): """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) - cache_result = yield self.store.get_url_cache(url, ts) + cache_result = await self.store.get_url_cache(url, ts) if ( cache_result and cache_result["expires_ts"] > ts @@ -192,13 +191,13 @@ class PreviewUrlResource(DirectServeResource): og = og.encode("utf8") return og - media_info = yield self._download_url(url, user) + media_info = await self._download_url(url, user) logger.debug("got media_info of '%s'", media_info) if _is_media(media_info["media_type"]): file_id = media_info["filesystem_id"] - dims = yield self.media_repo._generate_thumbnails( + dims = await self.media_repo._generate_thumbnails( None, file_id, file_id, media_info["media_type"], url_cache=True ) @@ -248,14 +247,14 @@ class PreviewUrlResource(DirectServeResource): # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. if "og:image" in og and og["og:image"]: - image_info = yield self._download_url( + image_info = await self._download_url( _rebase_url(og["og:image"], media_info["uri"]), user ) if _is_media(image_info["media_type"]): # TODO: make sure we don't choke on white-on-transparent images file_id = image_info["filesystem_id"] - dims = yield self.media_repo._generate_thumbnails( + dims = await self.media_repo._generate_thumbnails( None, file_id, file_id, image_info["media_type"], url_cache=True ) if dims: @@ -293,7 +292,7 @@ class PreviewUrlResource(DirectServeResource): jsonog = json.dumps(og) # store OG in history-aware DB cache - yield self.store.store_url_cache( + await self.store.store_url_cache( url, media_info["response_code"], media_info["etag"], @@ -305,8 +304,7 @@ class PreviewUrlResource(DirectServeResource): return jsonog.encode("utf8") - @defer.inlineCallbacks - def _download_url(self, url, user): + async def _download_url(self, url, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -318,7 +316,7 @@ class PreviewUrlResource(DirectServeResource): with self.media_storage.store_into_file(file_info) as (f, fname, finish): try: logger.debug("Trying to get url '%s'", url) - length, headers, uri, code = yield self.client.get_file( + length, headers, uri, code = await self.client.get_file( url, output_stream=f, max_size=self.max_spider_size ) except SynapseError: @@ -345,7 +343,7 @@ class PreviewUrlResource(DirectServeResource): % (traceback.format_exception_only(sys.exc_info()[0], e),), Codes.UNKNOWN, ) - yield finish() + await finish() try: if b"Content-Type" in headers: @@ -356,7 +354,7 @@ class PreviewUrlResource(DirectServeResource): download_name = get_filename_from_headers(headers) - yield self.store.store_local_media( + await self.store.store_local_media( media_id=file_id, media_type=media_type, time_now_ms=self.clock.time_msec(), @@ -393,8 +391,7 @@ class PreviewUrlResource(DirectServeResource): "expire_url_cache_data", self._expire_url_cache_data ) - @defer.inlineCallbacks - def _expire_url_cache_data(self): + async def _expire_url_cache_data(self): """Clean up expired url cache content, media and thumbnails. """ # TODO: Delete from backup media store @@ -403,12 +400,12 @@ class PreviewUrlResource(DirectServeResource): logger.info("Running url preview cache expiry") - if not (yield self.store.db.updates.has_completed_background_updates()): + if not (await self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return # First we delete expired url cache entries - media_ids = yield self.store.get_expired_url_cache(now) + media_ids = await self.store.get_expired_url_cache(now) removed_media = [] for media_id in media_ids: @@ -430,7 +427,7 @@ class PreviewUrlResource(DirectServeResource): except Exception: pass - yield self.store.delete_url_cache(removed_media) + await self.store.delete_url_cache(removed_media) if removed_media: logger.info("Deleted %d entries from url cache", len(removed_media)) @@ -440,7 +437,7 @@ class PreviewUrlResource(DirectServeResource): # may have a room open with a preview url thing open). # So we wait a couple of days before deleting, just in case. expire_before = now - 2 * 24 * 60 * 60 * 1000 - media_ids = yield self.store.get_url_cache_media_before(expire_before) + media_ids = await self.store.get_url_cache_media_before(expire_before) removed_media = [] for media_id in media_ids: @@ -478,7 +475,7 @@ class PreviewUrlResource(DirectServeResource): except Exception: pass - yield self.store.delete_url_cache_media(removed_media) + await self.store.delete_url_cache_media(removed_media) logger.info("Deleted %d media from url cache", len(removed_media)) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index d57480f761..0b87220234 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.server import ( DirectServeResource, set_cors_headers, @@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource): ) self.media_repo.mark_recently_accessed(server_name, media_id) - @defer.inlineCallbacks - def _respond_local_thumbnail( + async def _respond_local_thumbnail( self, request, media_id, width, height, method, m_type ): - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info: respond_404(request) @@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource): respond_404(request) return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) if thumbnail_infos: thumbnail_info = self._select_thumbnail( @@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder(request, responder, t_type, t_length) + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder(request, responder, t_type, t_length) else: logger.info("Couldn't find any generated thumbnails") respond_404(request) - @defer.inlineCallbacks - def _select_or_generate_local_thumbnail( + async def _select_or_generate_local_thumbnail( self, request, media_id, @@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource): desired_method, desired_type, ): - media_info = yield self.store.get_local_media(media_id) + media_info = await self.store.get_local_media(media_id) if not media_info: respond_404(request) @@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource): respond_404(request) return - thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) + thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) for info in thumbnail_infos: t_w = info["thumbnail_width"] == desired_width t_h = info["thumbnail_height"] == desired_height @@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: - yield respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder(request, responder, t_type, t_length) return logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self.media_repo.generate_local_exact_thumbnail( + file_path = await self.media_repo.generate_local_exact_thumbnail( media_id, desired_width, desired_height, @@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource): ) if file_path: - yield respond_with_file(request, desired_type, file_path) + await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") respond_404(request) - @defer.inlineCallbacks - def _select_or_generate_remote_thumbnail( + async def _select_or_generate_remote_thumbnail( self, request, server_name, @@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource): desired_method, desired_type, ): - media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info(server_name, media_id) - thumbnail_infos = yield self.store.get_remote_media_thumbnails( + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) + responder = await self.media_storage.fetch_media(file_info) if responder: - yield respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder(request, responder, t_type, t_length) return logger.debug("We don't have a thumbnail of that size. Generating") # Okay, so we generate one. - file_path = yield self.media_repo.generate_remote_exact_thumbnail( + file_path = await self.media_repo.generate_remote_exact_thumbnail( server_name, file_id, media_id, @@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource): ) if file_path: - yield respond_with_file(request, desired_type, file_path) + await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") respond_404(request) - @defer.inlineCallbacks - def _respond_remote_thumbnail( + async def _respond_remote_thumbnail( self, request, server_name, media_id, width, height, method, m_type ): # TODO: Don't download the whole remote file # We should proxy the thumbnail from the remote server instead of # downloading the remote file and generating our own thumbnails. - media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) + media_info = await self.media_repo.get_remote_media_info(server_name, media_id) - thumbnail_infos = yield self.store.get_remote_media_thumbnails( + thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) @@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource): t_type = file_info.thumbnail_type t_length = thumbnail_info["thumbnail_length"] - responder = yield self.media_storage.fetch_media(file_info) - yield respond_with_responder(request, responder, t_type, t_length) + responder = await self.media_storage.fetch_media(file_info) + await respond_with_responder(request, responder, t_type, t_length) else: logger.info("Failed to find any generated thumbnails") respond_404(request) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index acca079f23..649e835303 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -144,7 +144,10 @@ class DataStore( db_conn, "device_lists_stream", "stream_id", - extra_tables=[("user_signature_stream", "stream_id")], + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], ) self._cross_signing_id_gen = StreamIdGenerator( db_conn, "e2e_cross_signing_keys", "stream_id" diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index d55733a4cd..2d47cfd131 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import List, Tuple from six import iteritems @@ -31,7 +32,7 @@ from synapse.logging.opentracing import ( ) from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import Database +from synapse.storage.database import Database, LoggingTransaction from synapse.types import Collection, get_verify_key_from_cross_signing_key from synapse.util.caches.descriptors import ( Cache, @@ -112,23 +113,13 @@ class DeviceWorkerStore(SQLBaseStore): if not has_changed: return now_stream_id, [] - # We retrieve n+1 devices from the list of outbound pokes where n is - # our outbound device update limit. We then check if the very last - # device has the same stream_id as the second-to-last device. If so, - # then we ignore all devices with that stream_id and only send the - # devices with a lower stream_id. - # - # If when culling the list we end up with no devices afterwards, we - # consider the device update to be too large, and simply skip the - # stream_id; the rationale being that such a large device list update - # is likely an error. updates = yield self.db.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, from_stream_id, now_stream_id, - limit + 1, + limit, ) # Return an empty list if there are no updates @@ -166,14 +157,6 @@ class DeviceWorkerStore(SQLBaseStore): "device_id": verify_key.version, } - # if we have exceeded the limit, we need to exclude any results with the - # same stream_id as the last row. - if len(updates) > limit: - stream_id_cutoff = updates[-1][2] - now_stream_id = stream_id_cutoff - 1 - else: - stream_id_cutoff = None - # Perform the equivalent of a GROUP BY # # Iterate through the updates list and copy non-duplicate @@ -192,10 +175,6 @@ class DeviceWorkerStore(SQLBaseStore): query_map = {} cross_signing_keys_by_user = {} for user_id, device_id, update_stream_id, update_context in updates: - if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff: - # Stop processing updates - break - if ( user_id in master_key_by_user and device_id == master_key_by_user[user_id]["device_id"] @@ -218,17 +197,6 @@ class DeviceWorkerStore(SQLBaseStore): if update_stream_id > previous_update_stream_id: query_map[key] = (update_stream_id, update_context) - # If we didn't find any updates with a stream_id lower than the cutoff, it - # means that there are more than limit updates all of which have the same - # steam_id. - - # That should only happen if a client is spamming the server with new - # devices, in which case E2E isn't going to work well anyway. We'll just - # skip that stream_id and return an empty list, and continue with the next - # stream_id next time. - if not query_map and not cross_signing_keys_by_user: - return stream_id_cutoff, [] - results = yield self._get_device_update_edus_by_remote( destination, from_stream_id, query_map ) @@ -607,22 +575,33 @@ class DeviceWorkerStore(SQLBaseStore): else: return set() - def get_all_device_list_changes_for_remotes(self, from_key, to_key): - """Return a list of `(stream_id, user_id, destination)` which is the - combined list of changes to devices, and which destinations need to be - poked. `destination` may be None if no destinations need to be poked. + async def get_all_device_list_changes_for_remotes( + self, from_key: int, to_key: int, limit: int, + ) -> List[Tuple[int, str]]: + """Return a list of `(stream_id, entity)` which is the combined list of + changes to devices and which destinations need to be poked. Entity is + either a user ID (starting with '@') or a remote destination. """ - # We do a group by here as there can be a large number of duplicate - # entries, since we throw away device IDs. + + # This query Does The Right Thing where it'll correctly apply the + # bounds to the inner queries. sql = """ - SELECT MAX(stream_id) AS stream_id, user_id, destination - FROM device_lists_stream - LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + SELECT stream_id, entity FROM ( + SELECT stream_id, user_id AS entity FROM device_lists_stream + UNION ALL + SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes + ) AS e WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id, destination + LIMIT ? """ - return self.db.execute( - "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key + + return await self.db.execute( + "get_all_device_list_changes_for_remotes", + None, + sql, + from_key, + to_key, + limit, ) @cached(max_entries=10000) @@ -1017,29 +996,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): """Persist that a user's devices have been updated, and which hosts (if any) should be poked. """ - with self._device_list_id_gen.get_next() as stream_id: + if not device_ids: + return + + with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: + yield self.db.runInteraction( + "add_device_change_to_stream", + self._add_device_change_to_stream_txn, + user_id, + device_ids, + stream_ids, + ) + + if not hosts: + return stream_ids[-1] + + context = get_active_span_text_map() + with self._device_list_id_gen.get_next_mult( + len(hosts) * len(device_ids) + ) as stream_ids: yield self.db.runInteraction( - "add_device_change_to_streams", - self._add_device_change_txn, + "add_device_outbound_poke_to_stream", + self._add_device_outbound_poke_to_stream_txn, user_id, device_ids, hosts, - stream_id, + stream_ids, + context, ) - return stream_id - def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): - now = self._clock.time_msec() + return stream_ids[-1] + def _add_device_change_to_stream_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_ids: Collection[str], + stream_ids: List[str], + ): txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_id + self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], ) - for host in hosts: - txn.call_after( - self._device_list_federation_stream_cache.entity_has_changed, - host, - stream_id, - ) + + min_stream_id = stream_ids[0] # Delete older entries in the table, as we really only care about # when the latest change happened. @@ -1048,7 +1047,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? """, - [(user_id, device_id, stream_id) for device_id in device_ids], + [(user_id, device_id, min_stream_id) for device_id in device_ids], ) self.db.simple_insert_many_txn( @@ -1056,11 +1055,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): table="device_lists_stream", values=[ {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} - for device_id in device_ids + for stream_id, device_id in zip(stream_ids, device_ids) ], ) - context = get_active_span_text_map() + def _add_device_outbound_poke_to_stream_txn( + self, txn, user_id, device_ids, hosts, stream_ids, context, + ): + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, + stream_ids[-1], + ) + + now = self._clock.time_msec() + next_stream_id = iter(stream_ids) self.db.simple_insert_many_txn( txn, @@ -1068,7 +1078,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): values=[ { "destination": destination, - "stream_id": stream_id, + "stream_id": next(next_stream_id), "user_id": user_id, "device_id": device_id, "sent": False, diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 001a53f9b4..bcf746b7ef 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return result - def get_all_user_signature_changes_for_remotes(self, from_key, to_key): + def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their device with their user-signing key, which is not published to other @@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Deferred[list[(int,str)]] a list of `(stream_id, user_id)` """ sql = """ - SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id + SELECT stream_id, from_user_id AS user_id FROM user_signature_stream WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id + ORDER BY stream_id ASC + LIMIT ? """ return self.db.execute( - "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key + "get_all_user_signature_changes_for_remotes", + None, + sql, + from_key, + to_key, + limit, ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index 604c8b7ddd..dab31e0c2d 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore): "status_msg": state.status_msg, "currently_active": state.currently_active, } - for state in presence_states + for stream_id, state in zip(stream_orderings, presence_states) ], ) @@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore): ) txn.execute(sql + clause, [stream_id] + list(args)) - def get_all_presence_updates(self, last_id, current_id): + def get_all_presence_updates(self, last_id, current_id, limit): if last_id == current_id: return defer.succeed([]) def get_all_presence_updates_txn(txn): - sql = ( - "SELECT stream_id, user_id, state, last_active_ts," - " last_federation_update_ts, last_user_sync_ts, status_msg," - " currently_active" - " FROM presence_stream" - " WHERE ? < stream_id AND stream_id <= ?" - ) - txn.execute(sql, (last_id, current_id)) + sql = """ + SELECT stream_id, user_id, state, last_active_ts, + last_federation_update_ts, last_user_sync_ts, + status_msg, + currently_active + FROM presence_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() return self.db.runInteraction( |