diff --git a/changelog.d/6988.doc b/changelog.d/6988.doc
new file mode 100644
index 0000000000..b6f71bb966
--- /dev/null
+++ b/changelog.d/6988.doc
@@ -0,0 +1 @@
+Improve the documentation for database configuration.
diff --git a/changelog.d/7009.feature b/changelog.d/7009.feature
new file mode 100644
index 0000000000..cd2705d5ba
--- /dev/null
+++ b/changelog.d/7009.feature
@@ -0,0 +1 @@
+Set `Referrer-Policy` header to `no-referrer` on media downloads.
diff --git a/changelog.d/7010.misc b/changelog.d/7010.misc
new file mode 100644
index 0000000000..4ba1f6cdf8
--- /dev/null
+++ b/changelog.d/7010.misc
@@ -0,0 +1 @@
+Change device list streams to have one row per ID.
diff --git a/changelog.d/7011.misc b/changelog.d/7011.misc
new file mode 100644
index 0000000000..41c3b37574
--- /dev/null
+++ b/changelog.d/7011.misc
@@ -0,0 +1 @@
+Remove concept of a non-limited stream.
diff --git a/changelog.d/7089.bugfix b/changelog.d/7089.bugfix
new file mode 100644
index 0000000000..f1f440f23a
--- /dev/null
+++ b/changelog.d/7089.bugfix
@@ -0,0 +1 @@
+Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors.
diff --git a/changelog.d/7110.misc b/changelog.d/7110.misc
new file mode 100644
index 0000000000..fac5bc0403
--- /dev/null
+++ b/changelog.d/7110.misc
@@ -0,0 +1 @@
+Convert some of synapse.rest.media to async/await.
diff --git a/changelog.d/7115.misc b/changelog.d/7115.misc
new file mode 100644
index 0000000000..7d4a011e3e
--- /dev/null
+++ b/changelog.d/7115.misc
@@ -0,0 +1 @@
+De-duplicate / remove unused REST code for login and auth.
diff --git a/changelog.d/7117.bugfix b/changelog.d/7117.bugfix
new file mode 100644
index 0000000000..1896d7ad49
--- /dev/null
+++ b/changelog.d/7117.bugfix
@@ -0,0 +1 @@
+Fix a bug which meant that groups updates were not correctly replicated between workers.
diff --git a/docs/postgres.md b/docs/postgres.md
index ca7ef1cf3a..04aa746051 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -104,19 +104,41 @@ of free memory the database host has available.
When you are ready to start using PostgreSQL, edit the `database`
section in your config file to match the following lines:
- database:
- name: psycopg2
- args:
- user: <user>
- password: <pass>
- database: <db>
- host: <host>
- cp_min: 5
- cp_max: 10
+```yaml
+database:
+ name: psycopg2
+ args:
+ user: <user>
+ password: <pass>
+ database: <db>
+ host: <host>
+ cp_min: 5
+ cp_max: 10
+```
All key-value pairs in `args` are passed to the `psycopg2.connect(..)`
function, except keys beginning with `cp_`, which are consumed by the
-twisted adbapi connection pool.
+twisted adbapi connection pool. See the [libpq
+documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS)
+for a list of options which can be passed.
+
+You should consider tuning the `args.keepalives_*` options if there is any danger of
+the connection between your homeserver and the database dropping; otherwise, Synapse
+may block for an extended period while it waits for a response from the
+database server. Example values might be:
+
+```yaml
+# seconds of inactivity after which TCP should send a keepalive message to the server
+keepalives_idle: 10
+
+# the number of seconds after which a TCP keepalive message that is not
+# acknowledged by the server should be retransmitted
+keepalives_interval: 10
+
+# the number of TCP keepalives that can be lost before the client's connection
+# to the server is considered dead
+keepalives_count: 3
+```
## Porting from SQLite
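To make the split concrete, here is a hypothetical Python sketch (not part of Synapse) of how an `args` mapping like the one above is divided: keys beginning with `cp_` configure the Twisted adbapi connection pool, and everything else is handed to `psycopg2.connect(..)` as libpq connection parameters.

```python
# Illustrative only: mirror the `args` section of the config above and split
# it into connection-pool options and psycopg2/libpq connection parameters.
import psycopg2

args = {
    "user": "synapse_user",
    "password": "secretpassword",
    "database": "synapse",
    "host": "localhost",
    "keepalives_idle": 10,
    "keepalives_interval": 10,
    "keepalives_count": 3,
    "cp_min": 5,
    "cp_max": 10,
}

pool_args = {k: v for k, v in args.items() if k.startswith("cp_")}
connect_args = {k: v for k, v in args.items() if not k.startswith("cp_")}

# pool_args would be given to twisted.enterprise.adbapi.ConnectionPool;
# the remaining keys go straight to the database driver.
conn = psycopg2.connect(**connect_args)
```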
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 2ff0dd05a2..276e43b732 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -578,13 +578,46 @@ acme:
## Database ##
+# The 'database' setting defines the database that synapse uses to store all of
+# its data.
+#
+# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
+# 'psycopg2' (for PostgreSQL).
+#
+# 'args' gives options which are passed through to the database engine,
+# except for options starting with 'cp_', which are used to configure the Twisted
+# connection pool. For a reference to valid arguments, see:
+# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
+# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
+# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
+#
+#
+# Example SQLite configuration:
+#
+#database:
+# name: sqlite3
+# args:
+# database: /path/to/homeserver.db
+#
+#
+# Example Postgres configuration:
+#
+#database:
+# name: psycopg2
+# args:
+# user: synapse
+# password: secretpassword
+# database: synapse
+# host: localhost
+# cp_min: 5
+# cp_max: 10
+#
+# For more information on using Synapse with Postgres, see `docs/postgres.md`.
+#
database:
- # The database engine name
- name: "sqlite3"
- # Arguments to pass to the engine
+ name: sqlite3
args:
- # Path to the database
- database: "DATADIR/homeserver.db"
+ database: DATADIR/homeserver.db
# Number of events to cache in memory.
#
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index b2c764bfe8..136babe6ce 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -65,12 +65,23 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams._base import (
+from synapse.replication.tcp.streams import (
+ AccountDataStream,
DeviceListsStream,
+ GroupServerStream,
+ PresenceStream,
+ PushersStream,
+ PushRulesStream,
ReceiptsStream,
+ TagAccountDataStream,
ToDeviceStream,
+ TypingStream,
+)
+from synapse.replication.tcp.streams.events import (
+ EventsStream,
+ EventsStreamEventRow,
+ EventsStreamRow,
)
-from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
@@ -626,7 +637,7 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
if self.send_handler:
self.send_handler.process_replication_rows(stream_name, token, rows)
- if stream_name == "events":
+ if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
for row in rows:
@@ -649,43 +660,44 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
)
await self.pusher_pool.on_new_notifications(token, token)
- elif stream_name == "push_rules":
+ elif stream_name == PushRulesStream.NAME:
self.notifier.on_new_event(
"push_rules_key", token, users=[row.user_id for row in rows]
)
- elif stream_name in ("account_data", "tag_account_data"):
+ elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
self.notifier.on_new_event(
"account_data_key", token, users=[row.user_id for row in rows]
)
- elif stream_name == "receipts":
+ elif stream_name == ReceiptsStream.NAME:
self.notifier.on_new_event(
"receipt_key", token, rooms=[row.room_id for row in rows]
)
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
- elif stream_name == "typing":
+ elif stream_name == TypingStream.NAME:
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
- elif stream_name == "to_device":
+ elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
self.notifier.on_new_event("to_device_key", token, users=entities)
- elif stream_name == "device_lists":
+ elif stream_name == DeviceListsStream.NAME:
all_room_ids = set()
for row in rows:
- room_ids = await self.store.get_rooms_for_user(row.user_id)
- all_room_ids.update(room_ids)
+ if row.entity.startswith("@"):
+ room_ids = await self.store.get_rooms_for_user(row.entity)
+ all_room_ids.update(room_ids)
self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
- elif stream_name == "presence":
+ elif stream_name == PresenceStream.NAME:
await self.presence_handler.process_replication_rows(token, rows)
- elif stream_name == "receipts":
+ elif stream_name == GroupServerStream.NAME:
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows]
)
- elif stream_name == "pushers":
+ elif stream_name == PushersStream.NAME:
for row in rows:
if row.deleted:
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
@@ -774,7 +786,10 @@ class FederationSenderHandler(object):
# ... as well as device updates and messages
elif stream_name == DeviceListsStream.NAME:
- hosts = {row.destination for row in rows}
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ hosts = {row.entity for row in rows if not row.entity.startswith("@")}
for host in hosts:
self.federation_sender.send_device_messages(host)
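The device-lists handling above now keys off a single `entity` field per row, where entities starting with `@` are user IDs whose devices changed and anything else is a remote server to notify. A standalone sketch (hypothetical function name, not Synapse code) of that partitioning:

```python
# Hypothetical helper: split device-list stream entities into local user IDs
# (which always start with '@') and remote server names.
from typing import Iterable, List, Tuple


def split_device_list_entities(entities: Iterable[str]) -> Tuple[List[str], List[str]]:
    user_ids = [e for e in entities if e.startswith("@")]
    hosts = [e for e in entities if not e.startswith("@")]
    return user_ids, hosts


user_ids, hosts = split_device_list_entities(
    ["@alice:example.com", "matrix.example.org", "@bob:example.com"]
)
assert user_ids == ["@alice:example.com", "@bob:example.com"]
assert hosts == ["matrix.example.org"]
```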
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index ba846042c4..efe2af5504 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -294,7 +294,6 @@ class RootConfig(object):
report_stats=None,
open_private_ports=False,
listeners=None,
- database_conf=None,
tls_certificate_path=None,
tls_private_key_path=None,
acme_domain=None,
@@ -367,7 +366,6 @@ class RootConfig(object):
report_stats=report_stats,
open_private_ports=open_private_ports,
listeners=listeners,
- database_conf=database_conf,
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 219b32f670..b8ab2f86ac 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,14 +15,60 @@
# limitations under the License.
import logging
import os
-from textwrap import indent
-
-import yaml
from synapse.config._base import Config, ConfigError
logger = logging.getLogger(__name__)
+DEFAULT_CONFIG = """\
+## Database ##
+
+# The 'database' setting defines the database that synapse uses to store all of
+# its data.
+#
+# 'name' gives the database engine to use: either 'sqlite3' (for SQLite) or
+# 'psycopg2' (for PostgreSQL).
+#
+# 'args' gives options which are passed through to the database engine,
+# except for options starting with 'cp_', which are used to configure the Twisted
+# connection pool. For a reference to valid arguments, see:
+# * for sqlite: https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
+# * for postgres: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
+# * for the connection pool: https://twistedmatrix.com/documents/current/api/twisted.enterprise.adbapi.ConnectionPool.html#__init__
+#
+#
+# Example SQLite configuration:
+#
+#database:
+# name: sqlite3
+# args:
+# database: /path/to/homeserver.db
+#
+#
+# Example Postgres configuration:
+#
+#database:
+# name: psycopg2
+# args:
+# user: synapse
+# password: secretpassword
+# database: synapse
+# host: localhost
+# cp_min: 5
+# cp_max: 10
+#
+# For more information on using Synapse with Postgres, see `docs/postgres.md`.
+#
+database:
+ name: sqlite3
+ args:
+ database: %(database_path)s
+
+# Number of events to cache in memory.
+#
+#event_cache_size: 10K
+"""
+
class DatabaseConnectionConfig:
"""Contains the connection config for a particular database.
@@ -36,10 +83,12 @@ class DatabaseConnectionConfig:
"""
def __init__(self, name: str, db_config: dict):
- if db_config["name"] not in ("sqlite3", "psycopg2"):
- raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+ db_engine = db_config.get("name", "sqlite3")
- if db_config["name"] == "sqlite3":
+ if db_engine not in ("sqlite3", "psycopg2"):
+ raise ConfigError("Unsupported database type %r" % (db_engine,))
+
+ if db_engine == "sqlite3":
db_config.setdefault("args", {}).update(
{"cp_min": 1, "cp_max": 1, "check_same_thread": False}
)
@@ -97,34 +146,10 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path"))
- def generate_config_section(self, data_dir_path, database_conf, **kwargs):
- if not database_conf:
- database_path = os.path.join(data_dir_path, "homeserver.db")
- database_conf = (
- """# The database engine name
- name: "sqlite3"
- # Arguments to pass to the engine
- args:
- # Path to the database
- database: "%(database_path)s"
- """
- % locals()
- )
- else:
- database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip()
-
- return (
- """\
- ## Database ##
-
- database:
- %(database_conf)s
- # Number of events to cache in memory.
- #
- #event_cache_size: 10K
- """
- % locals()
- )
+ def generate_config_section(self, data_dir_path, **kwargs):
+ return DEFAULT_CONFIG % {
+ "database_path": os.path.join(data_dir_path, "homeserver.db")
+ }
def read_arguments(self, args):
self.set_databasepath(args.database_path)
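A minimal sketch (an illustrative re-implementation, not the code above) of the behaviour `DatabaseConnectionConfig` now has: the engine name defaults to `sqlite3` when omitted, unknown engines are rejected, and SQLite is pinned to a single-connection pool.

```python
# Illustrative stand-in for the checks in DatabaseConnectionConfig; this
# ConfigError class stands in for synapse.config._base.ConfigError.
class ConfigError(Exception):
    pass


def normalise_db_config(db_config: dict) -> dict:
    db_engine = db_config.get("name", "sqlite3")

    if db_engine not in ("sqlite3", "psycopg2"):
        raise ConfigError("Unsupported database type %r" % (db_engine,))

    if db_engine == "sqlite3":
        # SQLite supports only one writer, so force a single-connection pool.
        db_config.setdefault("args", {}).update(
            {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
        )

    return db_config


print(normalise_db_config({"args": {"database": "/data/homeserver.db"}}))
```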
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 5c991e5412..b0b0eba41e 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -25,11 +25,7 @@ from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
-from synapse.api.room_versions import (
- KNOWN_ROOM_VERSIONS,
- EventFormatVersions,
- RoomVersion,
-)
+from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.crypto.event_signing import check_event_content_hash
from synapse.crypto.keyring import Keyring
from synapse.events import EventBase, make_event_from_dict
@@ -55,13 +51,15 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
- def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
+ def _check_sigs_and_hash(
+ self, room_version: RoomVersion, pdu: EventBase
+ ) -> Deferred:
return make_deferred_yieldable(
self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(
- self, room_version: str, pdus: List[EventBase]
+ self, room_version: RoomVersion, pdus: List[EventBase]
) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
@@ -146,7 +144,7 @@ class PduToCheckSig(
def _check_sigs_on_pdus(
- keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
+ keyring: Keyring, room_version: RoomVersion, pdus: Iterable[EventBase]
) -> List[Deferred]:
"""Check that the given events are correctly signed
@@ -191,10 +189,6 @@ def _check_sigs_on_pdus(
for p in pdus
]
- v = KNOWN_ROOM_VERSIONS.get(room_version)
- if not v:
- raise RuntimeError("Unrecognized room version %s" % (room_version,))
-
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
@@ -204,7 +198,7 @@ def _check_sigs_on_pdus(
(
p.sender_domain,
p.redacted_pdu_json,
- p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+ p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_sender
@@ -227,7 +221,7 @@ def _check_sigs_on_pdus(
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
- if v.event_format == EventFormatVersions.V1:
+ if room_version.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p
for p in pdus_to_check
@@ -239,7 +233,7 @@ def _check_sigs_on_pdus(
(
get_domain_from_id(p.pdu.event_id),
p.redacted_pdu_json,
- p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+ p.pdu.origin_server_ts if room_version.enforce_key_validity else 0,
p.pdu.event_id,
)
for p in pdus_to_check_event_id
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 8c6b839478..a0071fec94 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -220,8 +220,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = await make_deferred_yieldable(
defer.gatherResults(
- self._check_sigs_and_hashes(room_version.identifier, pdus),
- consumeErrors=True,
+ self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -291,9 +290,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = await self._check_sigs_and_hash(
- room_version.identifier, pdu
- )
+ signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
break
@@ -350,7 +347,7 @@ class FederationClient(FederationBase):
self,
origin: str,
pdus: List[EventBase],
- room_version: str,
+ room_version: RoomVersion,
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
@@ -396,7 +393,7 @@ class FederationClient(FederationBase):
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
- room_version=room_version, # type: ignore
+ room_version=room_version,
outlier=outlier,
timeout=10000,
)
@@ -434,7 +431,7 @@ class FederationClient(FederationBase):
]
signed_auth = await self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True, room_version=room_version.identifier
+ destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)
@@ -661,7 +658,7 @@ class FederationClient(FederationBase):
destination,
list(pdus.values()),
outlier=True,
- room_version=room_version.identifier,
+ room_version=room_version,
)
valid_pdus_map = {p.event_id: p for p in valid_pdus}
@@ -756,7 +753,7 @@ class FederationClient(FederationBase):
pdu = event_from_pdu_json(pdu_dict, room_version)
# Check signatures are correct.
- pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
# FIXME: We should handle signature failures more gracefully.
@@ -948,7 +945,7 @@ class FederationClient(FederationBase):
]
signed_events = await self._check_sigs_and_hash_and_fetch(
- destination, events, outlier=False, room_version=room_version.identifier
+ destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 275b9c99d7..89d521bc31 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -409,7 +409,7 @@ class FederationServer(FederationBase):
pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, pdu.room_id)
- pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
return {"event": ret_pdu.get_pdu_json(time_now)}
@@ -425,7 +425,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
- pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -455,7 +455,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
- pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
return {}
@@ -611,7 +611,7 @@ class FederationServer(FederationBase):
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
- room_version = await self.store.get_room_version_id(pdu.room_id)
+ room_version = await self.store.get_room_version(pdu.room_id)
# Check signature.
try:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5526015ddb..6912165622 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -747,7 +747,7 @@ class PresenceHandler(object):
return False
- async def get_all_presence_updates(self, last_id, current_id):
+ async def get_all_presence_updates(self, last_id, current_id, limit):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -762,7 +762,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
- rows = await self.store.get_all_presence_updates(last_id, current_id)
+ rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
return rows
def notify_new_event(self):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 391bceb0c4..c7bc14c623 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import List
from twisted.internet import defer
@@ -257,7 +258,13 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
- async def get_all_typing_updates(self, last_id, current_id):
+ async def get_all_typing_updates(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[dict]:
+ """Get up to `limit` typing updates between the given tokens, earliest
+ updates first.
+ """
+
if last_id == current_id:
return []
@@ -275,7 +282,7 @@ class TypingHandler(object):
typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing)))
rows.sort()
- return rows
+ return rows[:limit]
def get_current_token(self):
return self._latest_room_serial
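For illustration, a self-contained sketch (hypothetical data structures, not the handler itself) of the contract implied by the new `limit` argument: updates between the two tokens are returned earliest first and capped at `limit` rows.

```python
# Hypothetical sketch of "up to `limit` typing updates, earliest first".
from typing import Dict, List, Tuple


def get_typing_updates(
    room_serials: Dict[str, int],
    room_typing: Dict[str, List[str]],
    last_id: int,
    current_id: int,
    limit: int,
) -> List[Tuple[int, str, List[str]]]:
    rows = [
        (serial, room_id, room_typing.get(room_id, []))
        for room_id, serial in room_serials.items()
        if last_id < serial <= current_id
    ]
    rows.sort()
    return rows[:limit]


updates = get_typing_updates(
    {"!a:hs": 3, "!b:hs": 5}, {"!a:hs": ["@alice:hs"]}, last_id=2, current_id=10, limit=1
)
assert updates == [(3, "!a:hs", ["@alice:hs"])]
```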
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..23b1650e41 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
- db_conn, "device_lists_stream", "stream_id"
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
- for row in rows:
- self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
+ self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
- def _invalidate_caches_for_devices(self, token, user_id, destination):
- self._device_list_stream_cache.entity_has_changed(user_id, token)
-
- if destination:
- self._device_list_federation_stream_cache.entity_has_changed(
- destination, token
- )
+ def _invalidate_caches_for_devices(self, token, rows):
+ for row in rows:
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ if row.entity.startswith("@"):
+ self._device_list_stream_cache.entity_has_changed(row.entity, token)
+ self.get_cached_devices_for_user.invalidate((row.entity,))
+ self._get_cached_user_device.invalidate_many((row.entity,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
- self.get_cached_devices_for_user.invalidate((user_id,))
- self._get_cached_user_device.invalidate_many((user_id,))
- self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+ else:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ row.entity, token
+ )
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce9d1fae12..6e2ebaf614 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -166,11 +166,6 @@ class ReplicationStreamer(object):
self.pending_updates = False
with Measure(self.clock, "repl.stream.get_updates"):
- # First we tell the streams that they should update their
- # current tokens.
- for stream in self.streams:
- stream.advance_current_token()
-
all_streams = self.streams
if self._replication_torture_level is not None:
@@ -180,7 +175,7 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
- if stream.last_token == stream.upto_token:
+ if stream.last_token == stream.current_token():
continue
if self._replication_torture_level:
@@ -192,7 +187,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
- stream.upto_token,
+ stream.current_token(),
)
try:
updates, current_token = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 5f52264e84..29199f5b46 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -24,27 +24,61 @@ Each stream is defined by the following information:
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
-
-from . import _base, events, federation
+from synapse.replication.tcp.streams._base import (
+ AccountDataStream,
+ BackfillStream,
+ CachesStream,
+ DeviceListsStream,
+ GroupServerStream,
+ PresenceStream,
+ PublicRoomsStream,
+ PushersStream,
+ PushRulesStream,
+ ReceiptsStream,
+ TagAccountDataStream,
+ ToDeviceStream,
+ TypingStream,
+ UserSignatureStream,
+)
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.federation import FederationStream
STREAMS_MAP = {
stream.NAME: stream
for stream in (
- events.EventsStream,
- _base.BackfillStream,
- _base.PresenceStream,
- _base.TypingStream,
- _base.ReceiptsStream,
- _base.PushRulesStream,
- _base.PushersStream,
- _base.CachesStream,
- _base.PublicRoomsStream,
- _base.DeviceListsStream,
- _base.ToDeviceStream,
- federation.FederationStream,
- _base.TagAccountDataStream,
- _base.AccountDataStream,
- _base.GroupServerStream,
- _base.UserSignatureStream,
+ EventsStream,
+ BackfillStream,
+ PresenceStream,
+ TypingStream,
+ ReceiptsStream,
+ PushRulesStream,
+ PushersStream,
+ CachesStream,
+ PublicRoomsStream,
+ DeviceListsStream,
+ ToDeviceStream,
+ FederationStream,
+ TagAccountDataStream,
+ AccountDataStream,
+ GroupServerStream,
+ UserSignatureStream,
)
}
+
+__all__ = [
+ "STREAMS_MAP",
+ "BackfillStream",
+ "PresenceStream",
+ "TypingStream",
+ "ReceiptsStream",
+ "PushRulesStream",
+ "PushersStream",
+ "CachesStream",
+ "PublicRoomsStream",
+ "DeviceListsStream",
+ "ToDeviceStream",
+ "TagAccountDataStream",
+ "AccountDataStream",
+ "GroupServerStream",
+ "UserSignatureStream",
+]
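Assuming a working Synapse environment, a small usage sketch of what the flattened exports above enable: replication code can resolve a stream class, and therefore its row parser, directly from the stream name received over the wire.

```python
# Usage sketch (requires synapse to be importable): resolve a stream class by
# its NAME and parse a JSON-decoded row with its parse_row() classmethod.
from synapse.replication.tcp.streams import STREAMS_MAP


def parse_replication_row(stream_name: str, raw_row):
    stream_class = STREAMS_MAP[stream_name]
    return stream_class.parse_row(raw_row)
```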
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..abf5c6c6a8 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -17,10 +17,12 @@
import itertools
import logging
from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
import attr
+from synapse.types import JsonDict
+
logger = logging.getLogger(__name__)
@@ -94,9 +96,13 @@ PublicRoomsStreamRow = namedtuple(
"network_id", # str, optional
),
)
-DeviceListsStreamRow = namedtuple(
- "DeviceListsStreamRow", ("user_id", "destination") # str # str
-)
+
+
+@attr.s
+class DeviceListsStreamRow:
+ entity = attr.ib(type=str)
+
+
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
@@ -115,13 +121,12 @@ class Stream(object):
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
- time it was called up until the point `advance_current_token` was called.
+ time it was called.
"""
NAME = None # type: str # The name of the stream
# The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
- _LIMITED = True # Whether the update function takes a limit
@classmethod
def parse_row(cls, row):
@@ -142,26 +147,15 @@ class Stream(object):
# The token from which we last asked for updates
self.last_token = self.current_token()
- # The token that we will get updates up to
- self.upto_token = self.current_token()
-
- def advance_current_token(self):
- """Updates `upto_token` to "now", which updates up until which point
- get_updates[_since] will fetch rows till.
- """
- self.upto_token = self.current_token()
-
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
- self.upto_token = self.current_token()
- self.last_token = self.upto_token
+ self.last_token = self.current_token()
async def get_updates(self):
"""Gets all updates since the last time this function was called (or
- since the stream was constructed if it hadn't been called before),
- until the `upto_token`
+ since the stream was constructed if it hadn't been called before).
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]:
@@ -174,44 +168,45 @@ class Stream(object):
return updates, current_token
- async def get_updates_since(self, from_token):
+ async def get_updates_since(
+ self, from_token: int
+ ) -> Tuple[List[Tuple[int, JsonDict]], int]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
- Deferred[Tuple[List[Tuple[int, Any]], int]:
- Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
- list of ``(token, row)`` entries. ``row`` will be json-serialised and
- sent over the replication steam.
+ Resolves to a pair `(updates, new_last_token)`, where `updates` is
+ a list of `(token, row)` entries and `new_last_token` is the new
+            position in the stream.
"""
+
if from_token in ("NOW", "now"):
- return [], self.upto_token
+ return [], self.current_token()
- current_token = self.upto_token
+ current_token = self.current_token()
from_token = int(from_token)
if from_token == current_token:
return [], current_token
- logger.info("get_updates_since: %s", self.__class__)
- if self._LIMITED:
- rows = await self.update_function(
- from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
- )
+ rows = await self.update_function(
+ from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+ )
- # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
- rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
- else:
- rows = await self.update_function(from_token, current_token)
+ # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
+ rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
- if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
+ if len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
+ # The update function didn't hit the limit, so we must have got all
+ # the updates to `current_token`, and can return that as our new
+ # stream position.
return updates, current_token
def current_token(self):
@@ -223,9 +218,8 @@ class Stream(object):
"""
raise NotImplementedError()
- def update_function(self, from_token, current_token, limit=None):
- """Get updates between from_token and to_token. If Stream._LIMITED is
- True then limit is provided, otherwise it's not.
+ def update_function(self, from_token, current_token, limit):
+ """Get updates between from_token and to_token.
Returns:
Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -253,7 +247,6 @@ class BackfillStream(Stream):
class PresenceStream(Stream):
NAME = "presence"
- _LIMITED = False
ROW_TYPE = PresenceStreamRow
def __init__(self, hs):
@@ -268,7 +261,6 @@ class PresenceStream(Stream):
class TypingStream(Stream):
NAME = "typing"
- _LIMITED = False
ROW_TYPE = TypingStreamRow
def __init__(self, hs):
@@ -363,11 +355,11 @@ class PublicRoomsStream(Stream):
class DeviceListsStream(Stream):
- """Someone added/changed/removed a device
+ """Either a user has updated their devices or a remote server needs to be
+ told about a device update.
"""
NAME = "device_lists"
- _LIMITED = False
ROW_TYPE = DeviceListsStreamRow
def __init__(self, hs):
@@ -457,7 +449,6 @@ class UserSignatureStream(Stream):
"""
NAME = "user_signature"
- _LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):
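With `_LIMITED` gone, every stream now follows the same limited-fetch contract shown in the file above. A simplified, hypothetical sketch of that contract: request at most `MAX_EVENTS_BEHIND + 1` rows and treat hitting the limit as having fallen behind.

```python
# Simplified sketch of the limited-update contract; `update_function` here is
# a hypothetical coroutine taking (from_token, current_token, limit).
import asyncio
import itertools

MAX_EVENTS_BEHIND = 500000  # illustrative value


async def get_updates_since(update_function, from_token: int, current_token: int):
    if from_token == current_token:
        return [], current_token

    rows = await update_function(from_token, current_token, limit=MAX_EVENTS_BEHIND + 1)
    rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)

    updates = [(row[0], row[1:]) for row in rows]
    if len(updates) >= MAX_EVENTS_BEHIND:
        # More rows than we can safely stream: the caller must resync instead.
        raise Exception("stream has fallen behind")

    return updates, current_token


async def demo_update_function(from_token, current_token, limit):
    return [(from_token + i, "row-%d" % i) for i in range(1, 4)][:limit]


print(asyncio.run(get_updates_since(demo_update_function, 10, 20)))
```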
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d0d4999795..31551524f8 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -28,7 +28,6 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import load_jinja2_templates
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -548,13 +547,6 @@ class SSOAuthHandler(object):
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
- # Load the redirect page HTML template
- self._template = load_jinja2_templates(
- hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
- )[0]
-
- self._server_name = hs.config.server_name
-
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 50e080673b..85cf5a14c6 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -142,14 +142,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
elif stagetype == LoginType.TERMS:
html = TERMS_TEMPLATE % {
"session": session,
@@ -158,17 +150,19 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
else:
raise SynapseError(404, "Unknown auth stage type")
+ # Render the HTML and return.
+ html_bytes = html.encode("utf8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ finish_request(request)
+ return None
+
async def on_POST(self, request, stagetype):
session = parse_string(request, "session")
@@ -196,15 +190,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
-
- return None
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
@@ -225,17 +210,19 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
else:
raise SynapseError(404, "Unknown auth stage type")
+ # Render the HTML and return.
+ html_bytes = html.encode("utf8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ finish_request(request)
+ return None
+
def on_OPTIONS(self, _):
return 200, {}
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 66a01559e1..24d3ae5bbc 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource):
b" media-src 'self';"
b" object-src 'self';",
)
+ request.setHeader(
+ b"Referrer-Policy", b"no-referrer",
+ )
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
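For context on the new header, a tiny standalone Twisted example (illustrative only, not Synapse's resource class): serving a download with `Referrer-Policy: no-referrer` stops browsers from leaking the media URL in the `Referer` header of any outbound requests made from that page.

```python
# Minimal twisted.web resource that sets Referrer-Policy on downloads,
# mirroring the header added above.
from twisted.internet import reactor
from twisted.web import resource, server


class Download(resource.Resource):
    isLeaf = True

    def render_GET(self, request):
        request.setHeader(b"Referrer-Policy", b"no-referrer")
        request.setHeader(b"Content-Type", b"application/octet-stream")
        return b"media bytes"


if __name__ == "__main__":
    reactor.listenTCP(8080, server.Site(Download()))
    reactor.run()
```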
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 490b1b45a8..fd10d42f2f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -24,7 +24,6 @@ from six import iteritems
import twisted.internet.error
import twisted.web.http
-from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -114,15 +113,14 @@ class MediaRepository(object):
"update_recently_accessed_media", self._update_recently_accessed
)
- @defer.inlineCallbacks
- def _update_recently_accessed(self):
+ async def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
- yield self.store.update_cached_last_access_time(
+ await self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)
@@ -138,8 +136,7 @@ class MediaRepository(object):
else:
self.recently_accessed_locals.add(media_id)
- @defer.inlineCallbacks
- def create_content(
+ async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
@@ -158,11 +155,11 @@ class MediaRepository(object):
file_info = FileInfo(server_name=None, file_id=media_id)
- fname = yield self.media_storage.store_file(content, file_info)
+ fname = await self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
- yield self.store.store_local_media(
+ await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -171,12 +168,11 @@ class MediaRepository(object):
user_id=auth_user,
)
- yield self._generate_thumbnails(None, media_id, media_id, media_type)
+ await self._generate_thumbnails(None, media_id, media_id, media_type)
return "mxc://%s/%s" % (self.server_name, media_id)
- @defer.inlineCallbacks
- def get_local_media(self, request, media_id, name):
+ async def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
@@ -190,7 +186,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@@ -204,13 +200,12 @@ class MediaRepository(object):
file_info = FileInfo(None, media_id, url_cache=url_cache)
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
- @defer.inlineCallbacks
- def get_remote_media(self, request, server_name, media_id, name):
+ async def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
@@ -236,8 +231,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
- with (yield self.remote_media_linearizer.queue(key)):
- responder, media_info = yield self._get_remote_media_impl(
+ with (await self.remote_media_linearizer.queue(key)):
+ responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -246,14 +241,13 @@ class MediaRepository(object):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
- yield respond_with_responder(
+ await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
- @defer.inlineCallbacks
- def get_remote_media_info(self, server_name, media_id):
+ async def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -274,8 +268,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
- with (yield self.remote_media_linearizer.queue(key)):
- responder, media_info = yield self._get_remote_media_impl(
+ with (await self.remote_media_linearizer.queue(key)):
+ responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -286,8 +280,7 @@ class MediaRepository(object):
return media_info
- @defer.inlineCallbacks
- def _get_remote_media_impl(self, server_name, media_id):
+ async def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -299,7 +292,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
- media_info = yield self.store.get_cached_remote_media(server_name, media_id)
+ media_info = await self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise generate a new
@@ -317,19 +310,18 @@ class MediaRepository(object):
logger.info("Media is quarantined")
raise NotFoundError()
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
# Failed to find the file anywhere, let's download it.
- media_info = yield self._download_remote_file(server_name, media_id, file_id)
+ media_info = await self._download_remote_file(server_name, media_id, file_id)
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- @defer.inlineCallbacks
- def _download_remote_file(self, server_name, media_id, file_id):
+ async def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -351,7 +343,7 @@ class MediaRepository(object):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
- length, headers = yield self.client.get_file(
+ length, headers = await self.client.get_file(
server_name,
request_path,
output_stream=f,
@@ -397,7 +389,7 @@ class MediaRepository(object):
)
raise SynapseError(502, "Failed to fetch remote media")
- yield finish()
+ await finish()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
@@ -405,7 +397,7 @@ class MediaRepository(object):
logger.info("Stored remote media in file %r", fname)
- yield self.store.store_cached_remote_media(
+ await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
@@ -423,7 +415,7 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
- yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
+ await self._generate_thumbnails(server_name, media_id, file_id, media_type)
return media_info
@@ -458,16 +450,15 @@ class MediaRepository(object):
return t_byte_source
- @defer.inlineCallbacks
- def generate_local_exact_thumbnail(
+ async def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -490,7 +481,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -500,22 +491,21 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
- yield self.store.store_local_thumbnail(
+ await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return output_path
- @defer.inlineCallbacks
- def generate_remote_exact_thumbnail(
+ async def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -537,7 +527,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -547,7 +537,7 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
- yield self.store.store_remote_media_thumbnail(
+ await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -560,8 +550,7 @@ class MediaRepository(object):
return output_path
- @defer.inlineCallbacks
- def _generate_thumbnails(
+ async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
@@ -582,7 +571,7 @@ class MediaRepository(object):
if not requirements:
return
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
@@ -600,7 +589,7 @@ class MediaRepository(object):
return
if thumbnailer.transpose_method is not None:
- m_width, m_height = yield defer_to_thread(
+ m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
@@ -620,11 +609,11 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
@@ -646,7 +635,7 @@ class MediaRepository(object):
url_cache=url_cache,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -656,7 +645,7 @@ class MediaRepository(object):
# Write to database
if server_name:
- yield self.store.store_remote_media_thumbnail(
+ await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -667,15 +656,14 @@ class MediaRepository(object):
t_len,
)
else:
- yield self.store.store_local_thumbnail(
+ await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}
- @defer.inlineCallbacks
- def delete_old_remote_media(self, before_ts):
- old_media = yield self.store.get_remote_media_before(before_ts)
+ async def delete_old_remote_media(self, before_ts):
+ old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -689,7 +677,7 @@ class MediaRepository(object):
# TODO: Should we delete from the backup store
- with (yield self.remote_media_linearizer.queue(key)):
+ with (await self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
@@ -705,7 +693,7 @@ class MediaRepository(object):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
- yield self.store.delete_remote_media(origin, media_id)
+ await self.store.delete_remote_media(origin, media_id)
deleted += 1
return {"deleted": deleted}
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 07e395cfd1..c46676f8fc 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -165,8 +165,7 @@ class PreviewUrlResource(DirectServeResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
- @defer.inlineCallbacks
- def _do_preview(self, url, user, ts):
+ async def _do_preview(self, url, user, ts):
"""Check the db, and download the URL and build a preview
Args:
@@ -179,7 +178,7 @@ class PreviewUrlResource(DirectServeResource):
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
- cache_result = yield self.store.get_url_cache(url, ts)
+ cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
@@ -192,13 +191,13 @@ class PreviewUrlResource(DirectServeResource):
og = og.encode("utf8")
return og
- media_info = yield self._download_url(url, user)
+ media_info = await self._download_url(url, user)
logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
- dims = yield self.media_repo._generate_thumbnails(
+ dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, media_info["media_type"], url_cache=True
)
@@ -248,14 +247,14 @@ class PreviewUrlResource(DirectServeResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if "og:image" in og and og["og:image"]:
- image_info = yield self._download_url(
+ image_info = await self._download_url(
_rebase_url(og["og:image"], media_info["uri"]), user
)
if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
file_id = image_info["filesystem_id"]
- dims = yield self.media_repo._generate_thumbnails(
+ dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
@@ -293,7 +292,7 @@ class PreviewUrlResource(DirectServeResource):
jsonog = json.dumps(og)
# store OG in history-aware DB cache
- yield self.store.store_url_cache(
+ await self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
@@ -305,8 +304,7 @@ class PreviewUrlResource(DirectServeResource):
return jsonog.encode("utf8")
- @defer.inlineCallbacks
- def _download_url(self, url, user):
+ async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -318,7 +316,7 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
logger.debug("Trying to get url '%s'", url)
- length, headers, uri, code = yield self.client.get_file(
+ length, headers, uri, code = await self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size
)
except SynapseError:
@@ -345,7 +343,7 @@ class PreviewUrlResource(DirectServeResource):
% (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
- yield finish()
+ await finish()
try:
if b"Content-Type" in headers:
@@ -356,7 +354,7 @@ class PreviewUrlResource(DirectServeResource):
download_name = get_filename_from_headers(headers)
- yield self.store.store_local_media(
+ await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -393,8 +391,7 @@ class PreviewUrlResource(DirectServeResource):
"expire_url_cache_data", self._expire_url_cache_data
)
- @defer.inlineCallbacks
- def _expire_url_cache_data(self):
+ async def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@@ -403,12 +400,12 @@ class PreviewUrlResource(DirectServeResource):
logger.info("Running url preview cache expiry")
- if not (yield self.store.db.updates.has_completed_background_updates()):
+ if not (await self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries
- media_ids = yield self.store.get_expired_url_cache(now)
+ media_ids = await self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
@@ -430,7 +427,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
- yield self.store.delete_url_cache(removed_media)
+ await self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
@@ -440,7 +437,7 @@ class PreviewUrlResource(DirectServeResource):
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
- media_ids = yield self.store.get_url_cache_media_before(expire_before)
+ media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
@@ -478,7 +475,7 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
- yield self.store.delete_url_cache_media(removed_media)
+ await self.store.delete_url_cache_media(removed_media)
logger.info("Deleted %d media from url cache", len(removed_media))
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d57480f761..0b87220234 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
@@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource):
)
self.media_repo.mark_recently_accessed(server_name, media_id)
- @defer.inlineCallbacks
- def _respond_local_thumbnail(
+ async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
- thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
@@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(request, responder, t_type, t_length)
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
- @defer.inlineCallbacks
- def _select_or_generate_local_thumbnail(
+ async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
@@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
- thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
t_h = info["thumbnail_height"] == desired_height
@@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
- yield respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
- file_path = yield self.media_repo.generate_local_exact_thumbnail(
+ file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
@@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
- yield respond_with_file(request, desired_type, file_path)
+ await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
- @defer.inlineCallbacks
- def _select_or_generate_remote_thumbnail(
+ async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
@@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
- media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
+ media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
- thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
- yield respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
- file_path = yield self.media_repo.generate_remote_exact_thumbnail(
+ file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
@@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
- yield respond_with_file(request, desired_type, file_path)
+ await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
respond_404(request)
- @defer.inlineCallbacks
- def _respond_remote_thumbnail(
+ async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
- media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
+ media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
- thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(request, responder, t_type, t_length)
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index acca079f23..649e835303 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -144,7 +144,10 @@ class DataStore(
db_conn,
"device_lists_stream",
"stream_id",
- extra_tables=[("user_signature_stream", "stream_id")],
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
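With `device_lists_outbound_pokes` listed as an extra table, it and `device_lists_stream` now draw IDs from the same sequence, so on startup the generator has to resume from the highest `stream_id` seen in any of the tracked tables. A rough, self-contained sketch of that idea (an assumption about what `extra_tables` implies, not a copy of `StreamIdGenerator` itself):

```python
import sqlite3

TABLES = (
    "device_lists_stream",
    "user_signature_stream",
    "device_lists_outbound_pokes",
)


def load_current_stream_id(conn: sqlite3.Connection) -> int:
    """Return the highest stream_id used by any table sharing the sequence."""
    current = 1
    for table in TABLES:
        (max_id,) = conn.execute(
            "SELECT COALESCE(MAX(stream_id), 1) FROM %s" % (table,)
        ).fetchone()
        current = max(current, max_id)
    return current


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    for table in TABLES:
        conn.execute("CREATE TABLE %s (stream_id INTEGER)" % (table,))
    conn.execute("INSERT INTO device_lists_outbound_pokes VALUES (42)")
    print(load_current_stream_id(conn))  # -> 42
```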
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index d55733a4cd..2d47cfd131 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple
from six import iteritems
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -112,23 +113,13 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- # We retrieve n+1 devices from the list of outbound pokes where n is
- # our outbound device update limit. We then check if the very last
- # device has the same stream_id as the second-to-last device. If so,
- # then we ignore all devices with that stream_id and only send the
- # devices with a lower stream_id.
- #
- # If when culling the list we end up with no devices afterwards, we
- # consider the device update to be too large, and simply skip the
- # stream_id; the rationale being that such a large device list update
- # is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
- limit + 1,
+ limit,
)
# Return an empty list if there are no updates
@@ -166,14 +157,6 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- # if we have exceeded the limit, we need to exclude any results with the
- # same stream_id as the last row.
- if len(updates) > limit:
- stream_id_cutoff = updates[-1][2]
- now_stream_id = stream_id_cutoff - 1
- else:
- stream_id_cutoff = None
-
# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
@@ -192,10 +175,6 @@ class DeviceWorkerStore(SQLBaseStore):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
- if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
- # Stop processing updates
- break
-
if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +197,6 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- # If we didn't find any updates with a stream_id lower than the cutoff, it
- # means that there are more than limit updates all of which have the same
- # steam_id.
-
- # That should only happen if a client is spamming the server with new
- # devices, in which case E2E isn't going to work well anyway. We'll just
- # skip that stream_id and return an empty list, and continue with the next
- # stream_id next time.
- if not query_map and not cross_signing_keys_by_user:
- return stream_id_cutoff, []
-
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
@@ -607,22 +575,33 @@ class DeviceWorkerStore(SQLBaseStore):
else:
return set()
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
+ async def get_all_device_list_changes_for_remotes(
+ self, from_key: int, to_key: int, limit: int,
+ ) -> List[Tuple[int, str]]:
+        """Return a list of `(stream_id, entity)` rows: the combined stream of
+        device changes and of destinations that need to be poked. `entity` is
+        either a user ID (starting with '@') or a remote destination.
"""
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
+
+        # The stream ID bounds in the outer WHERE clause are correctly applied
+        # to both inner SELECTs of the UNION, so the LIMIT operates on an
+        # already-bounded result set.
sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
+ LIMIT ?
"""
- return self.db.execute(
- "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
+
+ return await self.db.execute(
+ "get_all_device_list_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
@cached(max_entries=10000)
@@ -1017,29 +996,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ if not device_ids:
+ return
+
+ with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ yield self.db.runInteraction(
+ "add_device_change_to_stream",
+ self._add_device_change_to_stream_txn,
+ user_id,
+ device_ids,
+ stream_ids,
+ )
+
+ if not hosts:
+ return stream_ids[-1]
+
+ context = get_active_span_text_map()
+ with self._device_list_id_gen.get_next_mult(
+ len(hosts) * len(device_ids)
+ ) as stream_ids:
yield self.db.runInteraction(
- "add_device_change_to_streams",
- self._add_device_change_txn,
+ "add_device_outbound_poke_to_stream",
+ self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
- stream_id,
+ stream_ids,
+ context,
)
- return stream_id
- def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
- now = self._clock.time_msec()
+ return stream_ids[-1]
+ def _add_device_change_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+        stream_ids: List[int],
+ ):
txn.call_after(
- self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+ self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_id,
- )
+
+ min_stream_id = stream_ids[0]
# Delete older entries in the table, as we really only care about
# when the latest change happened.
@@ -1048,7 +1047,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
- [(user_id, device_id, stream_id) for device_id in device_ids],
+ [(user_id, device_id, min_stream_id) for device_id in device_ids],
)
self.db.simple_insert_many_txn(
@@ -1056,11 +1055,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
- for device_id in device_ids
+ for stream_id, device_id in zip(stream_ids, device_ids)
],
)
- context = get_active_span_text_map()
+ def _add_device_outbound_poke_to_stream_txn(
+ self, txn, user_id, device_ids, hosts, stream_ids, context,
+ ):
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
+
+ now = self._clock.time_msec()
+ next_stream_id = iter(stream_ids)
self.db.simple_insert_many_txn(
txn,
@@ -1068,7 +1078,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
{
"destination": destination,
- "stream_id": stream_id,
+ "stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 001a53f9b4..bcf746b7ef 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -537,7 +537,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
- def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
@@ -552,13 +552,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
"""
sql = """
- SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id
+ ORDER BY stream_id ASC
+ LIMIT ?
"""
return self.db.execute(
- "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ "get_all_user_signature_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
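Like the device list stream, the user signature stream now returns rows in ascending `stream_id` order with an explicit `LIMIT`, rather than collapsing them with `MAX(...) GROUP BY`. A consumer is then expected to page through the stream in bounded batches; a hypothetical helper (not part of Synapse) sketching that loop, assuming each row carries a unique `stream_id`:

```python
def fetch_all_changes(get_changes, from_key, to_key, limit=100):
    """Yield every (stream_id, value) row between from_key (exclusive) and
    to_key (inclusive), calling get_changes(from_key, to_key, limit) per batch.

    Assumes rows come back ordered by stream_id ascending and that stream IDs
    are unique, so resuming from the last seen ID cannot skip rows.
    """
    while True:
        rows = get_changes(from_key, to_key, limit)
        yield from rows
        if len(rows) < limit:
            return
        # Resume strictly after the last row we have already seen.
        from_key = rows[-1][0]


if __name__ == "__main__":
    data = [(i, "@user%d:example.com" % i) for i in range(1, 251)]

    def get_changes(from_key, to_key, limit):
        rows = [r for r in data if from_key < r[0] <= to_key]
        return rows[:limit]

    assert list(fetch_all_changes(get_changes, 0, 250, limit=100)) == data
    print("paged %d rows in batches of 100" % len(data))
```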
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index 604c8b7ddd..dab31e0c2d 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
- for state in presence_states
+ for stream_id, state in zip(stream_orderings, presence_states)
],
)
@@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
)
txn.execute(sql + clause, [stream_id] + list(args))
- def get_all_presence_updates(self, last_id, current_id):
+ def get_all_presence_updates(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_presence_updates_txn(txn):
- sql = (
- "SELECT stream_id, user_id, state, last_active_ts,"
- " last_federation_update_ts, last_user_sync_ts, status_msg,"
- " currently_active"
- " FROM presence_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- )
- txn.execute(sql, (last_id, current_id))
+ sql = """
+ SELECT stream_id, user_id, state, last_active_ts,
+ last_federation_update_ts, last_user_sync_ts,
+ status_msg,
+ currently_active
+ FROM presence_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(
diff --git a/tests/config/test_database.py b/tests/config/test_database.py
index 151d3006ac..f675bde68e 100644
--- a/tests/config/test_database.py
+++ b/tests/config/test_database.py
@@ -21,9 +21,9 @@ from tests import unittest
class DatabaseConfigTestCase(unittest.TestCase):
- def test_database_configured_correctly_no_database_conf_param(self):
+ def test_database_configured_correctly(self):
conf = yaml.safe_load(
- DatabaseConfig().generate_config_section("/data_dir_path", None)
+ DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
)
expected_database_conf = {
@@ -32,21 +32,3 @@ class DatabaseConfigTestCase(unittest.TestCase):
}
self.assertEqual(conf["database"], expected_database_conf)
-
- def test_database_configured_correctly_database_conf_param(self):
-
- database_conf = {
- "name": "my super fast datastore",
- "args": {
- "user": "matrix",
- "password": "synapse_database_password",
- "host": "synapse_database_host",
- "database": "matrix",
- },
- }
-
- conf = yaml.safe_load(
- DatabaseConfig().generate_config_section("/data_dir_path", database_conf)
- )
-
- self.assertEqual(conf["database"], database_conf)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 6f8d990959..c2539b353a 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
- @defer.inlineCallbacks
- def test_get_device_updates_by_remote_limited(self):
- # Test breaking the update limit in 1, 101, and 1 device_id segments
-
- # first add one device
- device_ids1 = ["device_id0"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids1, ["someotherhost"]
- )
-
- # then add 101
- device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids2, ["someotherhost"]
- )
-
- # then one more
- device_ids3 = ["newdevice"]
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids3, ["someotherhost"]
- )
-
- #
- # now read them back.
- #
-
- # first we should get a single update
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", -1, limit=100
- )
- self._check_devices_in_updates(device_ids1, device_updates)
-
- # Then we should get an empty list back as the 101 devices broke the limit
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self.assertEqual(len(device_updates), 0)
-
- # The 101 devices should've been cleared, so we should now just get one device
- # update
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "someotherhost", now_stream_id, limit=100
- )
- self._check_devices_in_updates(device_ids3, device_updates)
-
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
|