From 7d2532be36dc116e130ad226a7462bb0e899aca4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 15 Jun 2020 08:44:54 -0400 Subject: Discard RDATA from already seen positions. (#7648) --- synapse/replication/tcp/commands.py | 4 ++-- synapse/replication/tcp/handler.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 6 deletions(-) (limited to 'synapse/replication/tcp') diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index c04f622816..ea5937a20c 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -149,7 +149,7 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the server to tell the client the stream postition without + """Sent by the server to tell the client the stream position without needing to send an RDATA. Format:: @@ -188,7 +188,7 @@ class ErrorCommand(_SimpleCommand): class PingCommand(_SimpleCommand): - """Sent by either side as a keep alive. The data is arbitary (often timestamp) + """Sent by either side as a keep alive. The data is arbitrary (often timestamp) """ NAME = "PING" diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index cbcf46f3ae..e6a2e2598b 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -112,8 +112,8 @@ class ReplicationCommandHandler: "replication_position", clock=self._clock ) - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. + # Map of stream name to batched updates. See RdataCommand for info on + # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. @@ -123,7 +123,8 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming streams that are coming from that connection + # For each connection, the incoming stream names that are coming from + # that connection. self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] LaterGauge( @@ -310,7 +311,28 @@ class ReplicationCommandHandler: # Check if this is the last of a batch of updates rows = self._pending_batches.pop(stream_name, []) rows.append(row) - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + stream = self._streams.get(stream_name) + if not stream: + logger.error("Got RDATA for unknown stream: %s", stream_name) + return + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. + if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list -- cgit 1.5.1 From f6f7511a4c0548b17bd1cdabebd0ffad9ea73bc7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 16 Jun 2020 17:10:28 +0100 Subject: Refactor getting replication updates from database. (#7636) The aim here is to make it easier to reason about when streams are limited and when they're not, by moving the logic into the database functions themselves. This should mean we can kill of `db_query_to_update_function` function. --- changelog.d/7636.misc | 1 + synapse/handlers/presence.py | 29 +++++++- synapse/handlers/typing.py | 40 ++++++++--- synapse/push/pusherpool.py | 4 +- synapse/replication/tcp/streams/_base.py | 29 +++----- synapse/storage/data_stores/main/events_worker.py | 41 ++++++++++-- synapse/storage/data_stores/main/presence.py | 41 ++++++++++-- synapse/storage/data_stores/main/push_rule.py | 56 ++++++++++++---- synapse/storage/data_stores/main/receipts.py | 82 +++++++++++++++++++---- 9 files changed, 251 insertions(+), 72 deletions(-) create mode 100644 changelog.d/7636.misc (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7636.misc b/changelog.d/7636.misc new file mode 100644 index 0000000000..f93149502e --- /dev/null +++ b/changelog.d/7636.misc @@ -0,0 +1 @@ +Refactor getting replication updates from database. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 2e8914be14..d2f25ae12a 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -25,7 +25,7 @@ The methods that define policy are: import abc import logging from contextlib import contextmanager -from typing import Dict, Iterable, List, Set +from typing import Dict, Iterable, List, Set, Tuple from prometheus_client import Counter from typing_extensions import ContextManager @@ -773,7 +773,9 @@ class PresenceHandler(BasePresenceHandler): return False - async def get_all_presence_updates(self, last_id, current_id, limit): + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: """ Gets a list of presence update rows from between the given stream ids. Each row has: @@ -785,10 +787,31 @@ class PresenceHandler(BasePresenceHandler): - last_user_sync_ts(int) - status_msg(int) - currently_active(int) + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ + # TODO(markjh): replicate the unpersisted changes. # This could use the in-memory stores for recent changes. - rows = await self.store.get_all_presence_updates(last_id, current_id, limit) + rows = await self.store.get_all_presence_updates( + instance_name, last_id, current_id, limit + ) return rows def notify_new_event(self): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c7bc14c623..4330abb9f7 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,7 +15,7 @@ import logging from collections import namedtuple -from typing import List +from typing import List, Tuple from twisted.internet import defer @@ -259,14 +259,31 @@ class TypingHandler(object): ) async def get_all_typing_updates( - self, last_id: int, current_id: int, limit: int - ) -> List[dict]: - """Get up to `limit` typing updates between the given tokens, earliest - updates first. + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for typing replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: - return [] + return [], current_id, False changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id @@ -280,9 +297,16 @@ class TypingHandler(object): serial = self._room_serials[room_id] if last_id < serial <= current_id: typing = self._room_typing[room_id] - rows.append((serial, room_id, list(typing))) + rows.append((serial, [room_id, list(typing)])) rows.sort() - return rows[:limit] + + limited = False + if len(rows) > limit: + rows = rows[:limit] + current_id = rows[-1][0] + limited = True + + return rows, current_id, limited def get_current_token(self): return self._latest_room_serial diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 88d203aa44..f6a5458681 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -215,11 +215,9 @@ class PusherPool: try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive - updated_receipts = yield self.store.get_all_updated_receipts( + users_affected = yield self.store.get_users_sent_receipts_between( min_stream_id - 1, max_stream_id ) - # This returns a tuple, user_id is at index 3 - users_affected = {r[3] for r in updated_receipts} for u in users_affected: if u in self.pushers: diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4acefc8a96..f196eff072 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -264,7 +264,7 @@ class BackfillStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_backfill_token), - db_query_to_update_function(store.get_all_new_backfill_event_rows), + store.get_all_new_backfill_event_rows, ) @@ -291,9 +291,7 @@ class PresenceStream(Stream): if hs.config.worker_app is None: # on the master, query the presence handler presence_handler = hs.get_presence_handler() - update_function = db_query_to_update_function( - presence_handler.get_all_presence_updates - ) + update_function = presence_handler.get_all_presence_updates else: # Query master process update_function = make_http_update_function(hs, self.NAME) @@ -318,9 +316,7 @@ class TypingStream(Stream): if hs.config.worker_app is None: # on the master, query the typing handler - update_function = db_query_to_update_function( - typing_handler.get_all_typing_updates - ) + update_function = typing_handler.get_all_typing_updates else: # Query master process update_function = make_http_update_function(hs, self.NAME) @@ -352,7 +348,7 @@ class ReceiptsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), - db_query_to_update_function(store.get_all_updated_receipts), + store.get_all_updated_receipts, ) @@ -367,26 +363,17 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() + super(PushRulesStream, self).__init__( - hs.get_instance_name(), self._current_token, self._update_function + hs.get_instance_name(), + self._current_token, + self.store.get_all_push_rule_updates, ) def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token - async def _update_function( - self, instance_name: str, from_token: Token, to_token: Token, limit: int - ): - rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - - limited = False - if len(rows) == limit: - to_token = rows[-1][0] - limited = True - - return [(row[0], (row[2],)) for row in rows], to_token, limited - class PushersStream(Stream): """A user has added/changed/removed a pusher diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 213d69100a..a48c7a96ca 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -1077,9 +1077,32 @@ class EventsWorkerStore(SQLBaseStore): "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) - def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + async def get_all_new_backfill_event_rows( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for backfill replication stream, including all new + backfilled events and events that have gone from being outliers to not. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_new_backfill_event_rows(txn): sql = ( @@ -1094,10 +1117,12 @@ class EventsWorkerStore(SQLBaseStore): " LIMIT ?" ) txn.execute(sql, (-last_id, -current_id, limit)) - new_event_updates = txn.fetchall() + new_event_updates = [(row[0], row[1:]) for row in txn] + limited = False if len(new_event_updates) == limit: upper_bound = new_event_updates[-1][0] + limited = True else: upper_bound = current_id @@ -1114,11 +1139,15 @@ class EventsWorkerStore(SQLBaseStore): " ORDER BY event_stream_ordering DESC" ) txn.execute(sql, (-last_id, -upper_bound)) - new_event_updates.extend(txn.fetchall()) + new_event_updates.extend((row[0], row[1:]) for row in txn) - return new_event_updates + if len(new_event_updates) >= limit: + upper_bound = new_event_updates[-1][0] + limited = True - return self.db.runInteraction( + return new_event_updates, upper_bound, limited + + return await self.db.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index dab31e0c2d..7574612619 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple + from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause @@ -73,9 +75,32 @@ class PresenceStore(SQLBaseStore): ) txn.execute(sql + clause, [stream_id] + list(args)) - def get_all_presence_updates(self, last_id, current_id, limit): + async def get_all_presence_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for presence replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_presence_updates_txn(txn): sql = """ @@ -89,9 +114,17 @@ class PresenceStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + updates = [(row[0], row[1:]) for row in txn] + + upper_bound = current_id + limited = False + if len(updates) >= limit: + upper_bound = updates[-1][0] + limited = True + + return updates, upper_bound, limited - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_presence_updates", get_all_presence_updates_txn ) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index ef8f40959f..f6e78ca590 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -16,7 +16,7 @@ import abc import logging -from typing import Union +from typing import List, Tuple, Union from canonicaljson import json @@ -348,23 +348,53 @@ class PushRulesWorkerStore( results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled return results - def get_all_push_rule_updates(self, last_id, current_id, limit): - """Get all the push rules changes that have happend on the server""" + async def get_all_push_rule_updates( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for push_rules replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_push_rule_updates_txn(txn): - sql = ( - "SELECT stream_id, event_stream_ordering, user_id, rule_id," - " op, priority_class, priority, conditions, actions" - " FROM push_rules_stream" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) + sql = """ + SELECT stream_id, user_id + FROM push_rules_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + updates = [(stream_id, (user_id,)) for stream_id, user_id in txn] + + limited = False + upper_bound = current_id + if len(updates) == limit: + limited = True + upper_bound = updates[-1][0] + + return updates, upper_bound, limited - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index d4a7163049..8f5505bd67 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -16,6 +16,7 @@ import abc import logging +from typing import List, Tuple from canonicaljson import json @@ -267,26 +268,79 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results - def get_all_updated_receipts(self, last_id, current_id, limit=None): + def get_users_sent_receipts_between(self, last_id: int, current_id: int): + """Get all users who sent receipts between `last_id` exclusive and + `current_id` inclusive. + + Returns: + Deferred[List[str]] + """ + if last_id == current_id: return defer.succeed([]) - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) + def _get_users_sent_receipts_between_txn(txn): + sql = """ + SELECT DISTINCT user_id FROM receipts_linearized + WHERE ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (last_id, current_id)) - return [r[0:5] + (json.loads(r[5]),) for r in txn] + return [r[0] for r in txn] return self.db.runInteraction( + "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn + ) + + async def get_all_updated_receipts( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, list]], int, bool]: + """Get updates for receipts replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + if last_id == current_id: + return [], current_id, False + + def get_all_updated_receipts_txn(txn): + sql = """ + SELECT stream_id, room_id, receipt_type, user_id, event_id, data + FROM receipts_linearized + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + + updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn] + + limited = False + upper_bound = current_id + + if len(updates) == limit: + limited = True + upper_bound = updates[-1][0] + + return updates, upper_bound, limited + + return await self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) -- cgit 1.5.1 From 62b1ce85398f52e7d6137e77083294d0c90af459 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Sun, 5 Jul 2020 16:32:02 +0100 Subject: isort 5 compatibility (#7786) The CI appears to use the latest version of isort, which is a problem when isort gets a major version bump. Rather than try to pin the version, I've done the necessary to make isort5 happy with synapse. --- changelog.d/7786.misc | 1 + scripts-dev/check_signature.py | 2 +- scripts-dev/lint.sh | 2 +- setup.cfg | 1 - synapse/api/auth.py | 3 +-- synapse/config/__main__.py | 1 + synapse/config/emailconfig.py | 3 +-- synapse/handlers/auth.py | 3 +-- synapse/handlers/cas_handler.py | 3 +-- synapse/logging/opentracing.py | 4 ++-- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/handler.py | 4 ++-- synapse/replication/tcp/streams/events.py | 2 -- synapse/rest/media/v1/thumbnailer.py | 3 +-- synapse/secrets.py | 3 +-- synapse/storage/data_stores/main/events.py | 3 +-- synapse/storage/data_stores/main/ui_auth.py | 2 +- synapse/storage/types.py | 2 -- synapse/types.py | 2 +- tests/handlers/test_e2e_keys.py | 4 +--- tests/rest/media/v1/test_media_storage.py | 4 +--- tests/test_utils/event_injection.py | 2 -- tox.ini | 4 ++-- 23 files changed, 22 insertions(+), 38 deletions(-) create mode 100644 changelog.d/7786.misc (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7786.misc b/changelog.d/7786.misc new file mode 100644 index 0000000000..27af2681dc --- /dev/null +++ b/changelog.d/7786.misc @@ -0,0 +1 @@ +Update linting scripts and codebase to be compatible with `isort` v5. diff --git a/scripts-dev/check_signature.py b/scripts-dev/check_signature.py index ecda103cf7..6755bc5282 100644 --- a/scripts-dev/check_signature.py +++ b/scripts-dev/check_signature.py @@ -2,9 +2,9 @@ import argparse import json import logging import sys -import urllib2 import dns.resolver +import urllib2 from signedjson.key import decode_verify_key_bytes, write_signing_keys from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 6f1ba22931..66b0568858 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -15,7 +15,7 @@ else fi echo "Linting these locations: $files" -isort -y -rc $files +isort $files python3 -m black $files ./scripts-dev/config-lint.sh flake8 $files diff --git a/setup.cfg b/setup.cfg index f2bca272e1..a32278ea8a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,6 @@ ignore=W503,W504,E203,E731,E501 [isort] line_length = 88 -not_skip = __init__.py sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER default_section=THIRDPARTY known_first_party = synapse diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 06ba6604f3..cb22508f4d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Optional @@ -22,7 +21,6 @@ from netaddr import IPAddress from twisted.internet import defer from twisted.web.server import Request -import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth from synapse.api.auth_blocking import AuthBlocking @@ -35,6 +33,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.logging import opentracing as opentracing from synapse.types import StateMap, UserID from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py index fca35b008c..65043d5b5b 100644 --- a/synapse/config/__main__.py +++ b/synapse/config/__main__.py @@ -16,6 +16,7 @@ from synapse.config._base import ConfigError if __name__ == "__main__": import sys + from synapse.config.homeserver import HomeServerConfig action = sys.argv[1] diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index ca61214454..df08bcd1bc 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from __future__ import print_function # This file can't be called email.py because if it is, we cannot: @@ -145,8 +144,8 @@ class EmailConfig(Config): or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL ): # make sure we can import the required deps - import jinja2 import bleach + import jinja2 # prevent unused warnings jinja2 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d713a06bf9..a162392e4c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import time import unicodedata @@ -24,7 +23,6 @@ import attr import bcrypt # type: ignore[import] import pymacaroons -import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( AuthError, @@ -45,6 +43,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID +from synapse.util import stringutils as stringutils from synapse.util.threepids import canonicalise_email from ._base import BaseHandler diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 76f213723a..d79ffefdb5 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import urllib -import xml.etree.ElementTree as ET from typing import Dict, Optional, Tuple +from xml.etree import ElementTree as ET from twisted.web.client import PartialDownloadError diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 1676771ef0..c6c0e623c1 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -164,7 +164,6 @@ Gotchas than one caller? Will all of those calling functions have be in a context with an active span? """ - import contextlib import inspect import logging @@ -180,8 +179,8 @@ from twisted.internet import defer from synapse.config import ConfigError if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.http.site import SynapseRequest + from synapse.server import HomeServer # Helper class @@ -227,6 +226,7 @@ except ImportError: tags = _DummyTagNames try: from jaeger_client import Config as JaegerConfig + from synapse.logging.scopecontextmanager import LogContextScopeManager except ImportError: JaegerConfig = None # type: ignore diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index df29732f51..4985e40b1f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e6a2e2598b..55b3b79008 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar @@ -149,10 +148,11 @@ class ReplicationCommandHandler: using TCP. """ if hs.config.redis.redis_enabled: + import txredisapi + from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, ) - import txredisapi logger.info( "Connecting to redis (host=%r port=%r)", diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index f370390331..bdddb62ad6 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import heapq from collections import Iterable from typing import List, Tuple, Type @@ -22,7 +21,6 @@ import attr from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance - """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index c234ea7421..7126997134 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -12,11 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from io import BytesIO -import PIL.Image as Image +from PIL import Image as Image logger = logging.getLogger(__name__) diff --git a/synapse/secrets.py b/synapse/secrets.py index 0b327a0f82..5f43f81eb0 100644 --- a/synapse/secrets.py +++ b/synapse/secrets.py @@ -19,7 +19,6 @@ Injectable secrets module for Synapse. See https://docs.python.org/3/library/secrets.html#module-secrets for the API used in Python 3.6, and the API emulated in Python 2.7. """ - import sys # secrets is available since python 3.6 @@ -31,8 +30,8 @@ if sys.version_info[0:2] >= (3, 6): else: - import os import binascii + import os class Secrets(object): def token_bytes(self, nbytes=32): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index cfd24d2f06..b7bf3fbd9d 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging from collections import OrderedDict, namedtuple @@ -48,8 +47,8 @@ from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.iterutils import batch_iter if TYPE_CHECKING: - from synapse.storage.data_stores.main import DataStore from synapse.server import HomeServer + from synapse.storage.data_stores.main import DataStore logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py index ec2f38c373..4c044b1a15 100644 --- a/synapse/storage/data_stores/main/ui_auth.py +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -17,10 +17,10 @@ from typing import Any, Dict, Optional, Union import attr -import synapse.util.stringutils as stringutils from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.types import JsonDict +from synapse.util import stringutils as stringutils @attr.s diff --git a/synapse/storage/types.py b/synapse/storage/types.py index daff81c5ee..2d2b560e74 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,12 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Any, Iterable, Iterator, List, Tuple from typing_extensions import Protocol - """ Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ diff --git a/synapse/types.py b/synapse/types.py index acf60baddc..238b938064 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -29,7 +29,7 @@ from synapse.api.errors import Codes, SynapseError if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Sized, Iterable, Container + from typing import Container, Iterable, Sized T_co = TypeVar("T_co", covariant=True) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 6c1dc72bd1..1acf287ca4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,11 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import mock -import signedjson.key as key -import signedjson.sign as sign +from signedjson import key as key, sign as sign from twisted.internet import defer diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2ed9312d56..66fa5978b2 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import os import shutil import tempfile @@ -25,8 +23,8 @@ from urllib import parse from mock import Mock import attr -import PIL.Image as Image from parameterized import parameterized_class +from PIL import Image as Image from twisted.internet.defer import Deferred diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 431e9f8e5e..43297b530c 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Optional, Tuple import synapse.server @@ -25,7 +24,6 @@ from synapse.types import Collection from tests.test_utils import get_awaitable_result - """ Utility functions for poking events into the storage of the server under test. """ diff --git a/tox.ini b/tox.ini index ab6557f15e..1c042cb227 100644 --- a/tox.ini +++ b/tox.ini @@ -131,8 +131,8 @@ commands = [testenv:check_isort] skip_install = True -deps = isort -commands = /bin/sh -c "isort -c -df -sp setup.cfg -rc synapse tests scripts-dev scripts" +deps = isort==5.0.3 +commands = /bin/sh -c "isort -c --df --sp setup.cfg synapse tests scripts-dev scripts" [testenv:check-newsfragment] skip_install = True -- cgit 1.5.1 From 67d7756fcfb43c2b01a83da10b4f36635fa7b441 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 7 Jul 2020 12:11:35 +0100 Subject: Refactor getting replication updates from database v2. (#7740) --- changelog.d/7740.misc | 1 + synapse/handlers/typing.py | 3 + synapse/replication/tcp/streams/_base.py | 56 ++--------- synapse/storage/data_stores/main/cache.py | 36 +++++-- synapse/storage/data_stores/main/deviceinbox.py | 54 ++++++++--- synapse/storage/data_stores/main/devices.py | 70 ++++++++----- .../storage/data_stores/main/end_to_end_keys.py | 65 +++++++++---- synapse/storage/data_stores/main/group_server.py | 52 ++++++++-- synapse/storage/data_stores/main/pusher.py | 108 ++++++++++----------- synapse/storage/data_stores/main/room.py | 41 ++++++-- synapse/storage/data_stores/main/tags.py | 45 ++++++--- 11 files changed, 336 insertions(+), 195 deletions(-) create mode 100644 changelog.d/7740.misc (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7740.misc b/changelog.d/7740.misc new file mode 100644 index 0000000000..f93149502e --- /dev/null +++ b/changelog.d/7740.misc @@ -0,0 +1 @@ +Refactor getting replication updates from database. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 6c7abaa578..879c4c07c6 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -294,6 +294,9 @@ class TypingHandler(object): rows.sort() limited = False + # We, unusually, use a strict limit here as we have all the rows in + # memory rather than pulling them out of the database with a `LIMIT ?` + # clause. if len(rows) > limit: rows = rows[:limit] current_id = rows[-1][0] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f196eff072..9076bbe9f1 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -198,26 +198,6 @@ def current_token_without_instance( return lambda instance_name: current_token() -def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] -) -> UpdateFunction: - """Wraps a db query function which returns a list of rows to make it - suitable for use as an `update_function` for the Stream class - """ - - async def update_function(instance_name, from_token, upto_token, limit): - rows = await query_function(from_token, upto_token, limit) - updates = [(row[0], row[1:]) for row in rows] - limited = False - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return update_function - - def make_http_update_function(hs, stream_name: str) -> UpdateFunction: """Makes a suitable function for use as an `update_function` that queries the master process for updates. @@ -393,7 +373,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_pushers_stream_token), - db_query_to_update_function(store.get_all_updated_pushers_rows), + store.get_all_updated_pushers_rows, ) @@ -421,26 +401,12 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - self.store = hs.get_datastore() + store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_cache_stream_token, - self._update_function, - ) - - async def _update_function( - self, instance_name: str, from_token: int, upto_token: int, limit: int - ): - rows = await self.store.get_all_updated_caches( - instance_name, from_token, upto_token, limit + store.get_cache_stream_token, + store.get_all_updated_caches, ) - updates = [(row[0], row[1:]) for row in rows] - limited = False - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited class PublicRoomsStream(Stream): @@ -465,7 +431,7 @@ class PublicRoomsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_public_room_stream_id), - db_query_to_update_function(store.get_all_new_public_rooms), + store.get_all_new_public_rooms, ) @@ -486,7 +452,7 @@ class DeviceListsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), - db_query_to_update_function(store.get_all_device_list_changes_for_remotes), + store.get_all_device_list_changes_for_remotes, ) @@ -504,7 +470,7 @@ class ToDeviceStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), - db_query_to_update_function(store.get_all_new_device_messages), + store.get_all_new_device_messages, ) @@ -524,7 +490,7 @@ class TagAccountDataStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), - db_query_to_update_function(store.get_all_updated_tags), + store.get_all_updated_tags, ) @@ -612,7 +578,7 @@ class GroupServerStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), - db_query_to_update_function(store.get_all_groups_changes), + store.get_all_groups_changes, ) @@ -630,7 +596,5 @@ class UserSignatureStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), - db_query_to_update_function( - store.get_all_user_signature_changes_for_remotes - ), + store.get_all_user_signature_changes_for_remotes, ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index d30766e543..f39f556c20 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -16,7 +16,7 @@ import itertools import logging -from typing import Any, Iterable, Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple from synapse.api.constants import EventTypes from synapse.replication.tcp.streams import BackfillStream, CachesStream @@ -46,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore): async def get_all_updated_caches( self, instance_name: str, last_id: int, current_id: int, limit: int - ): - """Fetches cache invalidation rows between the two given IDs written - by the given instance. Returns at most `limit` rows. + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for caches replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: - return [] + return [], current_id, False def get_all_updated_caches_txn(txn): # We purposefully don't bound by the current token, as we want to @@ -66,7 +83,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): LIMIT ? """ txn.execute(sql, (last_id, instance_name, limit)) - return txn.fetchall() + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited return await self.db.runInteraction( "get_all_updated_caches", get_all_updated_caches_txn diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 9a1178fb39..d313b9705f 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import List, Tuple from canonicaljson import json @@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore): "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) - def get_all_new_device_messages(self, last_pos, current_pos, limit): - """ + async def get_all_new_device_messages( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for to device replication stream. + Args: - last_pos(int): - current_pos(int): - limit(int): + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + Returns: - A deferred list of rows from the device inbox + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ - if last_pos == current_pos: - return defer.succeed([]) + + if last_id == current_id: + return [], current_id, False def get_all_new_device_messages_txn(txn): # We limit like this as we might have multiple rows per stream_id, and # we want to make sure we always get all entries for any stream_id # we return. - upper_pos = min(current_pos, last_pos + limit) + upper_pos = min(current_id, last_id + limit) sql = ( "SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY user_id" ) - txn.execute(sql, (last_pos, upper_pos)) - rows = txn.fetchall() + txn.execute(sql, (last_id, upper_pos)) + updates = [(row[0], row[1:]) for row in txn] sql = ( "SELECT max(stream_id), destination" @@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore): " WHERE ? < stream_id AND stream_id <= ?" " GROUP BY destination" ) - txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn) + txn.execute(sql, (last_id, upper_pos)) + updates.extend((row[0], row[1:]) for row in txn) # Order by ascending stream ordering - rows.sort() + updates.sort() - return rows + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - return self.db.runInteraction( + return updates, upto_token, limited + + return await self.db.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 0ff0542453..343cf9a2d5 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -582,32 +582,58 @@ class DeviceWorkerStore(SQLBaseStore): return set() async def get_all_device_list_changes_for_remotes( - self, from_key: int, to_key: int, limit: int, - ) -> List[Tuple[int, str]]: - """Return a list of `(stream_id, entity)` which is the combined list of - changes to devices and which destinations need to be poked. Entity is - either a user ID (starting with '@') or a remote destination. - """ + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for device lists replication stream. - # This query Does The Right Thing where it'll correctly apply the - # bounds to the inner queries. - sql = """ - SELECT stream_id, entity FROM ( - SELECT stream_id, user_id AS entity FROM device_lists_stream - UNION ALL - SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes - ) AS e - WHERE ? < stream_id AND stream_id <= ? - LIMIT ? + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ - return await self.db.execute( + if last_id == current_id: + return [], current_id, False + + def _get_all_device_list_changes_for_remotes(txn): + # This query Does The Right Thing where it'll correctly apply the + # bounds to the inner queries. + sql = """ + SELECT stream_id, entity FROM ( + SELECT stream_id, user_id AS entity FROM device_lists_stream + UNION ALL + SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes + ) AS e + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """ + + txn.execute(sql, (last_id, current_id, limit)) + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db.runInteraction( "get_all_device_list_changes_for_remotes", - None, - sql, - from_key, - to_key, - limit, + _get_all_device_list_changes_for_remotes, ) @cached(max_entries=10000) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 1a0842d4b0..6c3cff82e1 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Tuple from canonicaljson import encode_canonical_json, json @@ -479,34 +479,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore): return result - def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit): - """Return a list of changes from the user signature stream to notify remotes. + async def get_all_user_signature_changes_for_remotes( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for groups replication stream. + Note that the user signature stream represents when a user signs their device with their user-signing key, which is not published to other users or servers, so no `destination` is needed in the returned list. However, this is needed to poke workers. Args: - from_key (int): the stream ID to start at (exclusive) - to_key (int): the stream ID to end at (inclusive) + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. Returns: - Deferred[list[(int,str)]] a list of `(stream_id, user_id)` - """ - sql = """ - SELECT stream_id, from_user_id AS user_id - FROM user_signature_stream - WHERE ? < stream_id AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ - return self.db.execute( + + if last_id == current_id: + return [], current_id, False + + def _get_all_user_signature_changes_for_remotes_txn(txn): + sql = """ + SELECT stream_id, from_user_id AS user_id + FROM user_signature_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (last_id, current_id, limit)) + + updates = [(row[0], (row[1:])) for row in txn] + + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db.runInteraction( "get_all_user_signature_changes_for_remotes", - None, - sql, - from_key, - to_key, - limit, + _get_all_user_signature_changes_for_remotes_txn, ) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index fb1361f1c1..4fb9f9850c 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple + from canonicaljson import json from twisted.internet import defer @@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore): "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) - def get_all_groups_changes(self, from_token, to_token, limit): - from_token = int(from_token) - has_changed = self._group_updates_stream_cache.has_any_entity_changed( - from_token - ) + async def get_all_groups_changes( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for groups replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + + last_id = int(last_id) + has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id) + if not has_changed: - return defer.succeed([]) + return [], current_id, False def _get_all_groups_changes_txn(txn): sql = """ @@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? LIMIT ? """ - txn.execute(sql, (from_token, to_token, limit)) - return [ - (stream_id, group_id, user_id, gtype, json.loads(content_json)) + txn.execute(sql, (last_id, current_id, limit)) + updates = [ + (stream_id, (group_id, user_id, gtype, json.loads(content_json))) for stream_id, group_id, user_id, gtype, content_json in txn ] - return self.db.runInteraction( + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited + + return await self.db.runInteraction( "get_all_groups_changes", _get_all_groups_changes_txn ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index 547b9d69cb..5461016240 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Iterable, Iterator +from typing import Iterable, Iterator, List, Tuple from canonicaljson import encode_canonical_json, json @@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore): rows = yield self.db.runInteraction("get_all_pushers", get_pushers) return rows - def get_all_updated_pushers(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed(([], [])) - - def get_all_updated_pushers_txn(txn): - sql = ( - "SELECT id, user_name, access_token, profile_tag, kind," - " app_id, app_display_name, device_display_name, pushkey, ts," - " lang, data" - " FROM pushers" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - updated = txn.fetchall() - - sql = ( - "SELECT stream_id, user_id, app_id, pushkey" - " FROM deleted_pushers" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, current_id, limit)) - deleted = txn.fetchall() + async def get_all_updated_pushers_rows( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for pushers replication stream. - return updated, deleted + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. - return self.db.runInteraction( - "get_all_updated_pushers", get_all_updated_pushers_txn - ) + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. - def get_all_updated_pushers_rows(self, last_id, current_id, limit): - """Get all the pushers that have changed between the given tokens. + The token returned can be used in a subsequent call to this + function to get further updatees. - Returns: - Deferred(list(tuple)): each tuple consists of: - stream_id (str) - user_id (str) - app_id (str) - pushkey (str) - was_deleted (bool): whether the pusher was added/updated (False) - or deleted (True) + The updates are a list of 2-tuples of stream ID and the row data """ if last_id == current_id: - return defer.succeed([]) + return [], current_id, False def get_all_updated_pushers_rows_txn(txn): - sql = ( - "SELECT id, user_name, app_id, pushkey" - " FROM pushers" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC LIMIT ?" - ) + sql = """ + SELECT id, user_name, app_id, pushkey + FROM pushers + WHERE ? < id AND id <= ? + ORDER BY id ASC LIMIT ? + """ txn.execute(sql, (last_id, current_id, limit)) - results = [list(row) + [False] for row in txn] - - sql = ( - "SELECT stream_id, user_id, app_id, pushkey" - " FROM deleted_pushers" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC LIMIT ?" - ) + updates = [ + (stream_id, (user_name, app_id, pushkey, False)) + for stream_id, user_name, app_id, pushkey in txn + ] + + sql = """ + SELECT stream_id, user_id, app_id, pushkey + FROM deleted_pushers + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ txn.execute(sql, (last_id, current_id, limit)) + updates.extend( + (stream_id, (user_name, app_id, pushkey, True)) + for stream_id, user_name, app_id, pushkey in txn + ) + + updates.sort() # Sort so that they're ordered by stream id - results.extend(list(row) + [True] for row in txn) - results.sort() # Sort so that they're ordered by stream id + limited = False + upper_bound = current_id + if len(updates) >= limit: + limited = True + upper_bound = updates[-1][0] - return results + return updates, upper_bound, limited - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 13e366536a..c473cf158f 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -803,7 +803,32 @@ class RoomWorkerStore(SQLBaseStore): return total_media_quarantined - def get_all_new_public_rooms(self, prev_id, current_id, limit): + async def get_all_new_public_rooms( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for public rooms replication stream. + + Args: + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + + Returns: + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data + """ + if last_id == current_id: + return [], current_id, False + def get_all_new_public_rooms(txn): sql = """ SELECT stream_id, room_id, visibility, appservice_id, network_id @@ -813,13 +838,17 @@ class RoomWorkerStore(SQLBaseStore): LIMIT ? """ - txn.execute(sql, (prev_id, current_id, limit)) - return txn.fetchall() + txn.execute(sql, (last_id, current_id, limit)) + updates = [(row[0], row[1:]) for row in txn] + limited = False + upto_token = current_id + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True - if prev_id == current_id: - return defer.succeed([]) + return updates, upto_token, limited - return self.db.runInteraction( + return await self.db.runInteraction( "get_all_new_public_rooms", get_all_new_public_rooms ) diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index f8c776be3f..290317fd94 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import List, Tuple from canonicaljson import json @@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore): return deferred - @defer.inlineCallbacks - def get_all_updated_tags(self, last_id, current_id, limit): - """Get all the client tags that have changed on the server + async def get_all_updated_tags( + self, instance_name: str, last_id: int, current_id: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + """Get updates for tags replication stream. + Args: - last_id(int): The position to fetch from. - current_id(int): The position to fetch up to. + instance_name: The writer we want to fetch updates from. Unused + here since there is only ever one writer. + last_id: The token to fetch updates from. Exclusive. + current_id: The token to fetch updates up to. Inclusive. + limit: The requested limit for the number of rows to return. The + function may return more or fewer rows. + Returns: - A deferred list of tuples of stream_id int, user_id string, - room_id string, tag string and content string. + A tuple consisting of: the updates, a token to use to fetch + subsequent updates, and whether we returned fewer rows than exists + between the requested tokens due to the limit. + + The token returned can be used in a subsequent call to this + function to get further updatees. + + The updates are a list of 2-tuples of stream ID and the row data """ + if last_id == current_id: - return [] + return [], current_id, False def get_all_updated_tags_txn(txn): sql = ( @@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - tag_ids = yield self.db.runInteraction( + tag_ids = await self.db.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) @@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore): for tag, content in txn: tags.append(json.dumps(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" - results.append((stream_id, user_id, room_id, tag_json)) + results.append((stream_id, (user_id, room_id, tag_json))) return results batch_size = 50 results = [] for i in range(0, len(tag_ids), batch_size): - tags = yield self.db.runInteraction( + tags = await self.db.runInteraction( "get_all_updated_tag_content", get_tag_content, tag_ids[i : i + batch_size], ) results.extend(tags) - return results + limited = False + upto_token = current_id + if len(results) >= limit: + upto_token = results[-1][0] + limited = True + + return results, upto_token, limited @defer.inlineCallbacks def get_updated_tags(self, user_id, stream_id): -- cgit 1.5.1 From e7efd8f827be141cdda4355f19ec1b9170c11323 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Jul 2020 07:15:08 -0400 Subject: Do not use simplejson in Synapse. (#7800) --- changelog.d/7800.misc | 1 + synapse/replication/tcp/commands.py | 11 ++--------- synapse/storage/data_stores/main/schema/delta/25/fts.py | 6 ++---- synapse/storage/data_stores/main/schema/delta/27/ts.py | 6 ++---- .../storage/data_stores/main/schema/delta/31/search_update.py | 6 ++---- .../storage/data_stores/main/schema/delta/33/event_fields.py | 6 ++---- 6 files changed, 11 insertions(+), 25 deletions(-) create mode 100644 changelog.d/7800.misc (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7800.misc b/changelog.d/7800.misc new file mode 100644 index 0000000000..ce2346b3d4 --- /dev/null +++ b/changelog.d/7800.misc @@ -0,0 +1 @@ +Switch from simplejson to the standard library json. diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index ea5937a20c..0f453ff0a8 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -18,18 +18,11 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are allowed to be sent by which side. """ import abc +import json import logging -import platform from typing import Tuple, Type -if platform.python_implementation() == "PyPy": - import json - - _json_encoder = json.JSONEncoder() -else: - import simplejson as json # type: ignore[no-redef] # noqa: F821 - - _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821 +_json_encoder = json.JSONEncoder() logger = logging.getLogger(__name__) diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py index 4b2ffd35fd..ee675e71ff 100644 --- a/synapse/storage/data_stores/main/schema/delta/25/fts.py +++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import logging -import simplejson - from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.prepare_database import get_statements @@ -66,7 +64,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = simplejson.dumps(progress) + progress_json = json.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py index 414f9f5aa0..b7972cfa8e 100644 --- a/synapse/storage/data_stores/main/schema/delta/27/ts.py +++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import logging -import simplejson - from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = simplejson.dumps(progress) + progress_json = json.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py index 7d8ca5f93f..63b757ade6 100644 --- a/synapse/storage/data_stores/main/schema/delta/31/search_update.py +++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import logging -import simplejson - from synapse.storage.engines import PostgresEngine from synapse.storage.prepare_database import get_statements @@ -50,7 +48,7 @@ def run_create(cur, database_engine, *args, **kwargs): "rows_inserted": 0, "have_added_indexes": False, } - progress_json = simplejson.dumps(progress) + progress_json = json.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py index bff1256a7b..a3e81eeac7 100644 --- a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py +++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import logging -import simplejson - from synapse.storage.prepare_database import get_statements logger = logging.getLogger(__name__) @@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = simplejson.dumps(progress) + progress_json = json.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" -- cgit 1.5.1 From 38e1fac8861f12b707609da06008695a05aaf21c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Jul 2020 09:52:58 -0400 Subject: Fix some spelling mistakes / typos. (#7811) --- changelog.d/7811.misc | 1 + synapse/api/auth.py | 2 +- synapse/config/emailconfig.py | 2 +- synapse/federation/federation_client.py | 4 ++-- synapse/federation/federation_server.py | 6 +++--- synapse/federation/send_queue.py | 2 +- synapse/federation/sender/per_destination_queue.py | 4 ++-- synapse/federation/transport/client.py | 2 +- synapse/federation/transport/server.py | 4 ++-- synapse/notifier.py | 2 +- synapse/replication/http/_base.py | 4 ++-- synapse/replication/tcp/__init__.py | 2 +- synapse/replication/tcp/commands.py | 2 +- synapse/replication/tcp/protocol.py | 2 +- synapse/replication/tcp/redis.py | 2 +- synapse/replication/tcp/streams/events.py | 2 +- synapse/streams/config.py | 4 ++-- synapse/streams/events.py | 2 +- synapse/util/__init__.py | 2 +- synapse/util/async_helpers.py | 2 +- synapse/util/caches/descriptors.py | 2 +- synapse/util/distributor.py | 2 +- synapse/util/patch_inline_callbacks.py | 2 +- synapse/util/retryutils.py | 4 ++-- synapse/visibility.py | 4 ++-- tests/crypto/test_keyring.py | 2 +- tests/rest/client/test_retention.py | 2 +- tests/rest/client/v1/test_presence.py | 2 +- tests/rest/client/v2_alpha/test_relations.py | 2 +- tests/test_mau.py | 2 +- tests/util/test_logcontext.py | 4 ++-- 31 files changed, 41 insertions(+), 40 deletions(-) create mode 100644 changelog.d/7811.misc (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7811.misc b/changelog.d/7811.misc new file mode 100644 index 0000000000..d907bba4df --- /dev/null +++ b/changelog.d/7811.misc @@ -0,0 +1 @@ +Fix various spelling errors in comments and log lines. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index cb22508f4d..40dc62ef6c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -537,7 +537,7 @@ class Auth(object): # Currently we ignore the `for_verification` flag even though there are # some situations where we can drop particular auth events when adding # to the event's `auth_events` (e.g. joins pointing to previous joins - # when room is publically joinable). Dropping event IDs has the + # when room is publicly joinable). Dropping event IDs has the # advantage that the auth chain for the room grows slower, but we use # the auth chain in state resolution v2 to order events, which means # care must be taken if dropping events to ensure that it doesn't diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index df08bcd1bc..b1dc7ad502 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -72,7 +72,7 @@ class EmailConfig(Config): template_dir = email_config.get("template_dir") # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxilliary templates like mail.css we will need). + # we don't yet know what auxiliary templates like mail.css we will need). # (Note that loading as package_resources with jinja.PackageLoader doesn't # work for the same reason.) if not template_dir: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 07d41ec03f..a37cc9cb4a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -245,7 +245,7 @@ class FederationClient(FederationBase): event_id: event to fetch room_version: version of the room outlier: Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitary point in the context as opposed to part + it's from an arbitrary point in the context as opposed to part of the current block of PDUs. Defaults to `False` timeout: How long to try (in ms) each destination for before moving to the next destination. None indicates no timeout. @@ -351,7 +351,7 @@ class FederationClient(FederationBase): outlier: bool = False, include_none: bool = False, ) -> List[EventBase]: - """Takes a list of PDUs and checks the signatures and hashs of each + """Takes a list of PDUs and checks the signatures and hashes of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of that PDU. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e704cf2f44..86051decd4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -717,7 +717,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # server name is a literal IP allow_ip_literals = acl_event.content.get("allow_ip_literals", True) if not isinstance(allow_ip_literals, bool): - logger.warning("Ignorning non-bool allow_ip_literals flag") + logger.warning("Ignoring non-bool allow_ip_literals flag") allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. @@ -731,7 +731,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # next, check the deny list deny = acl_event.content.get("deny", []) if not isinstance(deny, (list, tuple)): - logger.warning("Ignorning non-list deny ACL %s", deny) + logger.warning("Ignoring non-list deny ACL %s", deny) deny = [] for e in deny: if _acl_entry_matches(server_name, e): @@ -741,7 +741,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: # then the allow list. allow = acl_event.content.get("allow", []) if not isinstance(allow, (list, tuple)): - logger.warning("Ignorning non-list allow ACL %s", allow) + logger.warning("Ignoring non-list allow ACL %s", allow) allow = [] for e in allow: if _acl_entry_matches(server_name, e): diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 6bbd762681..860b03f7b9 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -359,7 +359,7 @@ class BaseFederationRow(object): Specifies how to identify, serialize and deserialize the different types. """ - TypeId = "" # Unique string that ids the type. Must be overriden in sub classes. + TypeId = "" # Unique string that ids the type. Must be overridden in sub classes. @staticmethod def from_data(data): diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 4e698981a4..12966e239b 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -119,7 +119,7 @@ class PerDestinationQueue(object): ) def send_pdu(self, pdu: EventBase, order: int) -> None: - """Add a PDU to the queue, and start the transmission loop if neccessary + """Add a PDU to the queue, and start the transmission loop if necessary Args: pdu: pdu to send @@ -129,7 +129,7 @@ class PerDestinationQueue(object): self.attempt_new_transaction() def send_presence(self, states: Iterable[UserPresenceState]) -> None: - """Add presence updates to the queue. Start the transmission loop if neccessary. + """Add presence updates to the queue. Start the transmission loop if necessary. Args: states: presence to send diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9f99311419..cfdf23d366 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -746,7 +746,7 @@ class TransportLayerClient(object): def remove_user_from_group( self, destination, group_id, requester_user_id, user_id, content ): - """Remove a user fron a group + """Remove a user from a group """ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bfb7831a02..d1bac318e7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -109,7 +109,7 @@ class Authenticator(object): self.server_name = hs.hostname self.store = hs.get_datastore() self.federation_domain_whitelist = hs.config.federation_domain_whitelist - self.notifer = hs.get_notifier() + self.notifier = hs.get_notifier() self.replication_client = None if hs.config.worker.worker_app: @@ -175,7 +175,7 @@ class Authenticator(object): await self.store.set_destination_retry_timings(origin, None, 0, 0) # Inform the relevant places that the remote server is back up. - self.notifer.notify_remote_server_up(origin) + self.notifier.notify_remote_server_up(origin) if self.replication_client: # If we're on a worker we try and inform master about this. The # replication client doesn't hook into the notifier to avoid diff --git a/synapse/notifier.py b/synapse/notifier.py index 87c120a59c..bd41f77852 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -83,7 +83,7 @@ class _NotifierUserStream(object): self.current_token = current_token # The last token for which we should wake up any streams that have a - # token that comes before it. This gets updated everytime we get poked. + # token that comes before it. This gets updated every time we get poked. # We start it at the current token since if we get any streams # that have a token from before we have no idea whether they should be # woken up or not, so lets just wake them up. diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 0843d28d4b..fb0dd04f88 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -92,11 +92,11 @@ class ReplicationEndpoint(object): # assert here that sub classes don't try and use the name. assert ( "instance_name" not in self.PATH_ARGS - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert ( "instance_name" not in signature(self.__class__._serialize_payload).parameters - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert self.METHOD in ("PUT", "POST", "GET") diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py index 523a1358d4..1b8718b11d 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py @@ -25,7 +25,7 @@ Structure of the module: * command.py - the definitions of all the valid commands * protocol.py - the TCP protocol classes * resource.py - handles streaming stream updates to replications - * streams/ - the definitons of all the valid streams + * streams/ - the definitions of all the valid streams The general interaction of the classes are: diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 0f453ff0a8..ccc7f1f0d1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -47,7 +47,7 @@ class Command(metaclass=abc.ABCMeta): @abc.abstractmethod def to_line(self) -> str: - """Serialises the comamnd for the wire. Does not include the command + """Serialises the command for the wire. Does not include the command prefix. """ diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 4198eece71..ca47f5cc88 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -317,7 +317,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index e776b63183..0a7e7f67be 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -177,7 +177,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): Args: hs outbound_redis_connection: A connection to redis that will be used to - send outbound commands (this is seperate to the redis connection + send outbound commands (this is separate to the redis connection used to subscribe). """ diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index bdddb62ad6..1c2a4cce7f 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -62,7 +62,7 @@ class BaseEventsStreamRow(object): Specifies how to identify, serialize and deserialize the different types. """ - # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overridden in sub classes. TypeId = None # type: str @classmethod diff --git a/synapse/streams/config.py b/synapse/streams/config.py index cd56cd91ed..ca7c16ff65 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -68,13 +68,13 @@ class PaginationConfig(object): elif from_tok: from_tok = StreamToken.from_string(from_tok) except Exception: - raise SynapseError(400, "'from' paramater is invalid") + raise SynapseError(400, "'from' parameter is invalid") try: if to_tok: to_tok = StreamToken.from_string(to_tok) except Exception: - raise SynapseError(400, "'to' paramater is invalid") + raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index fcd2aaa9c9..5d3eddcfdc 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -68,7 +68,7 @@ class EventSources(object): The returned token does not have the current values for fields other than `room`, since they are not used during pagination. - Retuns: + Returns: Deferred[StreamToken] """ token = StreamToken( diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 60f0de70f7..c63256d3bd 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -55,7 +55,7 @@ class Clock(object): return self._reactor.seconds() def time_msec(self): - """Returns the current system time in miliseconds since epoch.""" + """Returns the current system time in milliseconds since epoch.""" return int(self.time() * 1000) def looping_call(self, f, msec, *args, **kwargs): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 65abf0846e..f562770922 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -352,7 +352,7 @@ class ReadWriteLock(object): # resolved when they release the lock). # # Read: We know its safe to acquire a read lock when the latest writer has - # been resolved. The new reader is appeneded to the list of latest readers. + # been resolved. The new reader is appended to the list of latest readers. # # Write: We know its safe to acquire the write lock when both the latest # writers and readers have been resolved. The new writer replaces the latest diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 64f35fc288..9b09c08b89 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -516,7 +516,7 @@ class CacheListDescriptor(_CacheDescriptorBase): """ Args: orig (function) - cached_method_name (str): The name of the chached method. + cached_method_name (str): The name of the cached method. list_name (str): Name of the argument which is the bulk lookup list num_args (int): number of positional arguments (excluding ``self``, but including list_name) to use as cache keys. Defaults to all diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 45af8d3eeb..da20523b70 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -39,7 +39,7 @@ class Distributor(object): Signals are named simply by strings. TODO(paul): It would be nice to give signals stronger object identities, - so we can attach metadata, docstrings, detect typoes, etc... But this + so we can attach metadata, docstrings, detect typos, etc... But this model will do for today. """ diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 2605f3c65b..54c046b6e1 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -192,7 +192,7 @@ def _check_yield_points(f: Callable, changes: List[str]): result = yield d except Exception: # this will fish an earlier Failure out of the stack where possible, and - # thus is preferable to passing in an exeception to the Failure + # thus is preferable to passing in an exception to the Failure # constructor, since it results in less stack-mangling. result = Failure() diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index af69587196..8794317caa 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -22,7 +22,7 @@ from synapse.api.errors import CodeMessageException logger = logging.getLogger(__name__) -# the intial backoff, after the first transaction fails +# the initial backoff, after the first transaction fails MIN_RETRY_INTERVAL = 10 * 60 * 1000 # how much we multiply the backoff by after each subsequent fail @@ -174,7 +174,7 @@ class RetryDestinationLimiter(object): # has been decommissioned. # If we get a 401, then we should probably back off since they # won't accept our requests for at least a while. - # 429 is us being aggresively rate limited, so lets rate limit + # 429 is us being aggressively rate limited, so lets rate limit # ourselves. if exc_val.code == 404 and self.backoff_on_404: valid_err_code = False diff --git a/synapse/visibility.py b/synapse/visibility.py index 3dfd4af26c..0f042c5696 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -319,7 +319,7 @@ def filter_events_for_server( return True # Lets check to see if all the events have a history visibility - # of "shared" or "world_readable". If thats the case then we don't + # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). event_to_state_ids = yield storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), @@ -335,7 +335,7 @@ def filter_events_for_server( visibility_ids.add(hist) # If we failed to find any history visibility events then the default - # is "shared" visiblity. + # is "shared" visibility. if not visibility_ids: all_open = True else: diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 70c8e72303..f9ce609923 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") self.failureResultOf(d, SynapseError) - # should suceed on a signed object + # should succeed on a signed object d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") # self.assertFalse(d.called) self.get_success(d) diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 95475bb651..e54ffea150 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -126,7 +126,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): events.append(self.get_success(store.get_event(valid_event_id))) - # Advance the time by anothe 2 days. After this, the first event should be + # Advance the time by another 2 days. After this, the first event should be # outdated but not the second one. self.reactor.advance(one_day_ms * 2 / 1000) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 0fdff79aa7..3c66255dac 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -60,7 +60,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def test_put_presence_disabled(self): """ - PUT to the status endpoint with use_presence disbled will NOT call + PUT to the status endpoint with use_presence disabled will NOT call set_state on the presence handler. """ self.hs.config.use_presence = False diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index fd641a7c2f..99c9f4e928 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -99,7 +99,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(400, channel.code, channel.json_body) def test_basic_paginate_relations(self): - """Tests that calling pagination API corectly the latest relations. + """Tests that calling pagination API correctly the latest relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") self.assertEquals(200, channel.code, channel.json_body) diff --git a/tests/test_mau.py b/tests/test_mau.py index 49667ed7f4..654a6fa42d 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -166,7 +166,7 @@ class TestMauLimit(unittest.HomeserverTestCase): self.do_sync_for_user(token5) self.do_sync_for_user(token6) - # But old user cant + # But old user can't with self.assertRaises(SynapseError) as cm: self.do_sync_for_user(token1) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 95301c013c..58ee918f65 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable(self): - # a function which retuns an incomplete deferred, but doesn't follow + # a function which returns an incomplete deferred, but doesn't follow # the synapse rules. def blocking_function(): d = defer.Deferred() @@ -183,7 +183,7 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_make_deferred_yieldable_with_await(self): - # an async function which retuns an incomplete coroutine, but doesn't + # an async function which returns an incomplete coroutine, but doesn't # follow the synapse rules. async def blocking_function(): -- cgit 1.5.1 From f299441cc67f31dcd47b8fdfda4a218bee9df9ba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 10 Jul 2020 18:26:36 +0100 Subject: Add ability to shard the federation sender (#7798) --- changelog.d/7798.feature | 1 + docs/sample_config.yaml | 65 ++--- synapse/app/generic_worker.py | 59 ++--- synapse/config/federation.py | 129 ++++++++++ synapse/config/homeserver.py | 3 + synapse/config/server.py | 66 ----- synapse/federation/send_queue.py | 14 +- synapse/federation/sender/__init__.py | 48 +++- synapse/federation/sender/per_destination_queue.py | 22 ++ synapse/replication/tcp/commands.py | 10 +- synapse/replication/tcp/handler.py | 4 +- .../delta/58/10federation_pos_instance_name.sql | 22 ++ synapse/storage/data_stores/main/stream.py | 97 ++++++- tests/replication/test_federation_ack.py | 1 + tests/replication/test_federation_sender_shard.py | 286 +++++++++++++++++++++ 15 files changed, 670 insertions(+), 157 deletions(-) create mode 100644 changelog.d/7798.feature create mode 100644 synapse/config/federation.py create mode 100644 synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql create mode 100644 tests/replication/test_federation_sender_shard.py (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7798.feature b/changelog.d/7798.feature new file mode 100644 index 0000000000..56ffaf0d4a --- /dev/null +++ b/changelog.d/7798.feature @@ -0,0 +1 @@ +Add experimental support for running multiple federation sender processes. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 164a104045..1a2d9fb153 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -118,38 +118,6 @@ pid_file: DATADIR/homeserver.pid # #enable_search: false -# Restrict federation to the following whitelist of domains. -# N.B. we recommend also firewalling your federation listener to limit -# inbound federation traffic as early as possible, rather than relying -# purely on this application-layer restriction. If not specified, the -# default is to whitelist everything. -# -#federation_domain_whitelist: -# - lon.example.com -# - nyc.example.com -# - syd.example.com - -# Prevent federation requests from being sent to the following -# blacklist IP address CIDR ranges. If this option is not specified, or -# specified with an empty list, no ip range blacklist will be enforced. -# -# As of Synapse v1.4.0 this option also affects any outbound requests to identity -# servers provided by user input. -# -# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly -# listed here, since they correspond to unroutable addresses.) -# -federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -608,6 +576,39 @@ acme: +# Restrict federation to the following whitelist of domains. +# N.B. we recommend also firewalling your federation listener to limit +# inbound federation traffic as early as possible, rather than relying +# purely on this application-layer restriction. If not specified, the +# default is to whitelist everything. +# +#federation_domain_whitelist: +# - lon.example.com +# - nyc.example.com +# - syd.example.com + +# Prevent federation requests from being sent to the following +# blacklist IP address CIDR ranges. If this option is not specified, or +# specified with an empty list, no ip range blacklist will be enforced. +# +# As of Synapse v1.4.0 this option also affects any outbound requests to identity +# servers provided by user input. +# +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + + ## Caching ## # Caching can be configured through the following options. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index f6792d9fc8..e90695f026 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -511,25 +511,7 @@ class GenericWorkerSlavedStore( SearchWorkerStore, BaseSlavedStore, ): - def __init__(self, database, db_conn, hs): - super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs) - - # We pull out the current federation stream position now so that we - # always have a known value for the federation position in memory so - # that we don't have to bounce via a deferred once when we start the - # replication streams. - self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) - - def _get_federation_out_pos(self, db_conn): - sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?" - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, ("federation",)) - rows = txn.fetchall() - txn.close() - - return rows[0][0] if rows else -1 + pass class GenericWorkerServer(HomeServer): @@ -812,19 +794,11 @@ class FederationSenderHandler(object): self.federation_sender = hs.get_federation_sender() self._hs = hs - # if the worker is restarted, we want to pick up where we left off in - # the replication stream, so load the position from the database. - # - # XXX is this actually worthwhile? Whenever the master is restarted, we'll - # drop some rows anyway (which is mostly fine because we're only dropping - # typing and presence notifications). If the replication stream is - # unreliable, why do we do all this hoop-jumping to store the position in the - # database? See also https://github.com/matrix-org/synapse/issues/7535. - # - self.federation_position = self.store.federation_out_pos_startup + # Stores the latest position in the federation stream we've gotten up + # to. This is always set before we use it. + self.federation_position = None self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - self._last_ack = self.federation_position def on_start(self): # There may be some events that are persisted but haven't been sent, @@ -932,7 +906,6 @@ class FederationSenderHandler(object): # We ACK this token over replication so that the master can drop # its in memory queues self._hs.get_tcp_replication().send_federation_ack(current_position) - self._last_ack = current_position except Exception: logger.exception("Error updating federation stream position") @@ -960,7 +933,7 @@ def start(config_options): ) if config.worker_app == "synapse.app.appservice": - if config.notify_appservices: + if config.appservice.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -970,13 +943,13 @@ def start(config_options): sys.exit(1) # Force the appservice to start since they will be disabled in the main config - config.notify_appservices = True + config.appservice.notify_appservices = True else: # For other worker types we force this to off. - config.notify_appservices = False + config.appservice.notify_appservices = False if config.worker_app == "synapse.app.pusher": - if config.start_pushers: + if config.server.start_pushers: sys.stderr.write( "\nThe pushers must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -986,13 +959,13 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.start_pushers = True + config.server.start_pushers = True else: # For other worker types we force this to off. - config.start_pushers = False + config.server.start_pushers = False if config.worker_app == "synapse.app.user_dir": - if config.update_user_directory: + if config.server.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -1002,13 +975,13 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.update_user_directory = True + config.server.update_user_directory = True else: # For other worker types we force this to off. - config.update_user_directory = False + config.server.update_user_directory = False if config.worker_app == "synapse.app.federation_sender": - if config.send_federation: + if config.federation.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" "\nbefore they can be run in a separate worker." @@ -1018,10 +991,10 @@ def start(config_options): sys.exit(1) # Force the pushers to start since they will be disabled in the main config - config.send_federation = True + config.federation.send_federation = True else: # For other worker types we force this to off. - config.send_federation = False + config.federation.send_federation = False synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/config/federation.py b/synapse/config/federation.py new file mode 100644 index 0000000000..7782ab4c9d --- /dev/null +++ b/synapse/config/federation.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hashlib import sha256 +from typing import List, Optional + +import attr +from netaddr import IPSet + +from ._base import Config, ConfigError + + +@attr.s +class ShardedFederationSendingConfig: + """Algorithm for choosing which federation sender instance is responsible + for which destionation host. + """ + + instances = attr.ib(type=List[str]) + + def should_send_to(self, instance_name: str, destination: str) -> bool: + """Whether this instance is responsible for sending transcations for + the given host. + """ + + # If multiple federation senders are not defined we always return true. + if not self.instances or len(self.instances) == 1: + return True + + # We shard by taking the hash, modulo it by the number of federation + # senders and then checking whether this instance matches the instance + # at that index. + # + # (Technically this introduces some bias and is not entirely uniform, but + # since the hash is so large the bias is ridiculously small). + dest_hash = sha256(destination.encode("utf8")).digest() + dest_int = int.from_bytes(dest_hash, byteorder="little") + remainder = dest_int % (len(self.instances)) + return self.instances[remainder] == instance_name + + +class FederationConfig(Config): + section = "federation" + + def read_config(self, config, **kwargs): + # Whether to send federation traffic out in this process. This only + # applies to some federation traffic, and so shouldn't be used to + # "disable" federation + self.send_federation = config.get("send_federation", True) + + federation_sender_instances = config.get("federation_sender_instances") or [] + self.federation_shard_config = ShardedFederationSendingConfig( + federation_sender_instances + ) + + # FIXME: federation_domain_whitelist needs sytests + self.federation_domain_whitelist = None # type: Optional[dict] + federation_domain_whitelist = config.get("federation_domain_whitelist", None) + + if federation_domain_whitelist is not None: + # turn the whitelist into a hash for speed of lookup + self.federation_domain_whitelist = {} + + for domain in federation_domain_whitelist: + self.federation_domain_whitelist[domain] = True + + self.federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", [] + ) + + # Attempt to create an IPSet from the given ranges + try: + self.federation_ip_range_blacklist = IPSet( + self.federation_ip_range_blacklist + ) + + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e + ) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + # Restrict federation to the following whitelist of domains. + # N.B. we recommend also firewalling your federation listener to limit + # inbound federation traffic as early as possible, rather than relying + # purely on this application-layer restriction. If not specified, the + # default is to whitelist everything. + # + #federation_domain_whitelist: + # - lon.example.com + # - nyc.example.com + # - syd.example.com + + # Prevent federation requests from being sent to the following + # blacklist IP address CIDR ranges. If this option is not specified, or + # specified with an empty list, no ip range blacklist will be enforced. + # + # As of Synapse v1.4.0 this option also affects any outbound requests to identity + # servers provided by user input. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 264c274c52..8e93d31394 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -23,6 +23,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig from .key import KeyConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + FederationConfig, CacheConfig, DatabaseConfig, LoggingConfig, @@ -90,4 +92,5 @@ class HomeServerConfig(RootConfig): ThirdPartyRulesConfig, TracerConfig, RedisConfig, + FederationConfig, ] diff --git a/synapse/config/server.py b/synapse/config/server.py index 8204664883..b6afa642ca 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -23,7 +23,6 @@ from typing import Any, Dict, Iterable, List, Optional import attr import yaml -from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name @@ -136,11 +135,6 @@ class ServerConfig(Config): self.use_frozen_dicts = config.get("use_frozen_dicts", False) self.public_baseurl = config.get("public_baseurl") - # Whether to send federation traffic out in this process. This only - # applies to some federation traffic, and so shouldn't be used to - # "disable" federation - self.send_federation = config.get("send_federation", True) - # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -263,34 +257,6 @@ class ServerConfig(Config): # due to resource constraints self.admin_contact = config.get("admin_contact", None) - # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None # type: Optional[dict] - federation_domain_whitelist = config.get("federation_domain_whitelist", None) - - if federation_domain_whitelist is not None: - # turn the whitelist into a hash for speed of lookup - self.federation_domain_whitelist = {} - - for domain in federation_domain_whitelist: - self.federation_domain_whitelist[domain] = True - - self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [] - ) - - # Attempt to create an IPSet from the given ranges - try: - self.federation_ip_range_blacklist = IPSet( - self.federation_ip_range_blacklist - ) - - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e - ) - if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" @@ -743,38 +709,6 @@ class ServerConfig(Config): # #enable_search: false - # Restrict federation to the following whitelist of domains. - # N.B. we recommend also firewalling your federation listener to limit - # inbound federation traffic as early as possible, rather than relying - # purely on this application-layer restriction. If not specified, the - # default is to whitelist everything. - # - #federation_domain_whitelist: - # - lon.example.com - # - nyc.example.com - # - syd.example.com - - # Prevent federation requests from being sent to the following - # blacklist IP address CIDR ranges. If this option is not specified, or - # specified with an empty list, no ip range blacklist will be enforced. - # - # As of Synapse v1.4.0 this option also affects any outbound requests to identity - # servers provided by user input. - # - # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly - # listed here, since they correspond to unroutable addresses.) - # - federation_ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' - # List of ports that Synapse should listen on, their purpose and their # configuration. # diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 860b03f7b9..4fc9ff92e5 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -55,6 +55,11 @@ class FederationRemoteSendQueue(object): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id + # We may have multiple federation sender instances, so we need to track + # their positions separately. + self._sender_instances = hs.config.federation.federation_shard_config.instances + self._sender_positions = {} + # Pending presence map user_id -> UserPresenceState self.presence_map = {} # type: Dict[str, UserPresenceState] @@ -261,7 +266,14 @@ class FederationRemoteSendQueue(object): def get_current_token(self): return self.pos - 1 - def federation_ack(self, token): + def federation_ack(self, instance_name, token): + if self._sender_instances: + # If we have configured multiple federation sender instances we need + # to track their positions separately, and only clear the queue up + # to the token all instances have acked. + self._sender_positions[instance_name] = token + token = min(self._sender_positions.values()) + self._clear_queue_before_pos(token) async def get_replication_rows( diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 464d7a41de..4b63a0755f 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -69,6 +69,9 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) + self._instance_name = hs.get_instance_name() + self._federation_shard_config = hs.config.federation.federation_shard_config + # map from destination to PerDestinationQueue self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] @@ -191,7 +194,13 @@ class FederationSender(object): ) return - destinations = set(destinations) + destinations = { + d + for d in destinations + if self._federation_shard_config.should_send_to( + self._instance_name, d + ) + } if send_on_behalf_of is not None: # If we are sending the event on behalf of another server @@ -322,7 +331,12 @@ class FederationSender(object): # Work out which remote servers should be poked and poke them. domains = yield self.state.get_current_hosts_in_room(room_id) - domains = [d for d in domains if d != self.server_name] + domains = [ + d + for d in domains + if d != self.server_name + and self._federation_shard_config.should_send_to(self._instance_name, d) + ] if not domains: return @@ -427,6 +441,10 @@ class FederationSender(object): for destination in destinations: if destination == self.server_name: continue + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + continue self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") @@ -441,6 +459,12 @@ class FederationSender(object): for destination in destinations: if destination == self.server_name: continue + + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + continue + self._get_per_destination_queue(destination).send_presence(states) def build_and_send_edu( @@ -462,6 +486,11 @@ class FederationSender(object): logger.info("Not sending EDU to ourselves") return + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + return + edu = Edu( origin=self.server_name, destination=destination, @@ -478,6 +507,11 @@ class FederationSender(object): edu: edu to send key: clobbering key for this edu """ + if not self._federation_shard_config.should_send_to( + self._instance_name, edu.destination + ): + return + queue = self._get_per_destination_queue(edu.destination) if key: queue.send_keyed_edu(edu, key) @@ -489,6 +523,11 @@ class FederationSender(object): logger.warning("Not sending device update to ourselves") return + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + return + self._get_per_destination_queue(destination).attempt_new_transaction() def wake_destination(self, destination: str): @@ -502,6 +541,11 @@ class FederationSender(object): logger.warning("Not waking up ourselves") return + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + return + self._get_per_destination_queue(destination).attempt_new_transaction() @staticmethod diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 12966e239b..6402136e8a 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -74,6 +74,20 @@ class PerDestinationQueue(object): self._clock = hs.get_clock() self._store = hs.get_datastore() self._transaction_manager = transaction_manager + self._instance_name = hs.get_instance_name() + self._federation_shard_config = hs.config.federation.federation_shard_config + + self._should_send_on_this_instance = True + if not self._federation_shard_config.should_send_to( + self._instance_name, destination + ): + # We don't raise an exception here to avoid taking out any other + # processing. We have a guard in `attempt_new_transaction` that + # ensure we don't start sending stuff. + logger.error( + "Create a per destination queue for %s on wrong worker", destination, + ) + self._should_send_on_this_instance = False self._destination = destination self.transmission_loop_running = False @@ -180,6 +194,14 @@ class PerDestinationQueue(object): logger.debug("TX [%s] Transaction already in progress", self._destination) return + if not self._should_send_on_this_instance: + # We don't raise an exception here to avoid taking out any other + # processing. + logger.error( + "Trying to start a transaction to %s on wrong worker", self._destination + ) + return + logger.debug("TX [%s] Starting transaction loop", self._destination) run_as_background_process( diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index ccc7f1f0d1..f33801f883 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -293,20 +293,22 @@ class FederationAckCommand(Command): Format:: - FEDERATION_ACK + FEDERATION_ACK """ NAME = "FEDERATION_ACK" - def __init__(self, token): + def __init__(self, instance_name, token): + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - return cls(int(line)) + instance_name, token = line.split(" ") + return cls(instance_name, int(token)) def to_line(self): - return str(self.token) + return "%s %s" % (self.instance_name, self.token) class RemovePusherCommand(Command): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 55b3b79008..80f5df60f9 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -238,7 +238,7 @@ class ReplicationCommandHandler: federation_ack_counter.inc() if self._federation_sender: - self._federation_sender.federation_ack(cmd.token) + self._federation_sender.federation_ack(cmd.instance_name, cmd.token) async def on_REMOVE_PUSHER( self, conn: AbstractConnection, cmd: RemovePusherCommand @@ -527,7 +527,7 @@ class ReplicationCommandHandler: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. """ - self.send_command(FederationAckCommand(token)) + self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql new file mode 100644 index 0000000000..1cc2633aad --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql @@ -0,0 +1,22 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- We need to store the stream positions by instance in a sharded config world. +-- +-- We default to master as we want the column to be NOT NULL and we correctly +-- reset the instance name to match the config each time we start up. +ALTER TABLE federation_stream_position ADD COLUMN instance_name TEXT NOT NULL DEFAULT 'master'; + +CREATE UNIQUE INDEX federation_stream_position_instance ON federation_stream_position(type, instance_name); diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 379d758b5d..5e32c7aa1e 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -45,7 +45,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import Database, make_in_list_sql_clause from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -253,6 +253,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(StreamWorkerStore, self).__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._send_federation = hs.should_send_federation() + self._federation_shard_config = hs.config.federation.federation_shard_config + + # If we're a process that sends federation we may need to reset the + # `federation_stream_position` table to match the current sharding + # config. We don't do this now as otherwise two processes could conflict + # during startup which would cause one to die. + self._need_to_reset_federation_stream_positions = self._send_federation + events_max = self.get_room_max_stream_ordering() event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, @@ -793,22 +803,95 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events - def get_federation_out_pos(self, typ): - return self.db.simple_select_one_onecol( + async def get_federation_out_pos(self, typ: str) -> int: + if self._need_to_reset_federation_stream_positions: + await self.db.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, desc="get_federation_out_pos", ) - def update_federation_out_pos(self, typ, stream_id): - return self.db.simple_update_one( + async def update_federation_out_pos(self, typ, stream_id): + if self._need_to_reset_federation_stream_positions: + await self.db.runInteraction( + "_reset_federation_positions_txn", self._reset_federation_positions_txn + ) + self._need_to_reset_federation_stream_positions = False + + return await self.db.simple_update_one( table="federation_stream_position", - keyvalues={"type": typ}, + keyvalues={"type": typ, "instance_name": self._instance_name}, updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) + def _reset_federation_positions_txn(self, txn): + """Fiddles with the `federation_stream_position` table to make it match + the configured federation sender instances during start up. + """ + + # The federation sender instances may have changed, so we need to + # massage the `federation_stream_position` table to have a row per type + # per instance sending federation. If there is a mismatch we update the + # table with the correct rows using the *minimum* stream ID seen. This + # may result in resending of events/EDUs to remote servers, but that is + # preferable to dropping them. + + if not self._send_federation: + return + + # Pull out the configured instances. If we don't have a shard config then + # we assume that we're the only instance sending. + configured_instances = self._federation_shard_config.instances + if not configured_instances: + configured_instances = [self._instance_name] + elif self._instance_name not in configured_instances: + return + + instances_in_table = self.db.simple_select_onecol_txn( + txn, + table="federation_stream_position", + keyvalues={}, + retcol="instance_name", + ) + + if set(instances_in_table) == set(configured_instances): + # Nothing to do + return + + sql = """ + SELECT type, MIN(stream_id) FROM federation_stream_position + GROUP BY type + """ + txn.execute(sql) + min_positions = dict(txn) # Map from type -> min position + + # Ensure we do actually have some values here + assert set(min_positions) == {"federation", "events"} + + sql = """ + DELETE FROM federation_stream_position + WHERE NOT (%s) + """ + clause, args = make_in_list_sql_clause( + txn.database_engine, "instance_name", configured_instances + ) + txn.execute(sql % (clause,), args) + + for typ, stream_id in min_positions.items(): + self.db.simple_upsert_txn( + txn, + table="federation_stream_position", + keyvalues={"type": typ, "instance_name": self._instance_name}, + values={"stream_id": stream_id}, + ) + def has_room_changed_since(self, room_id, stream_id): return self._events_stream_cache.has_entity_changed(room_id, stream_id) diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 5448d9f0dc..23be1167a3 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer) + return hs def test_federation_ack_sent(self): diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py new file mode 100644 index 0000000000..519a2dc510 --- /dev/null +++ b/tests/replication/test_federation_sender_shard.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from mock import Mock + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.app.generic_worker import GenericWorkerServer +from synapse.events.builder import EventBuilderFactory +from synapse.replication.http import streams +from synapse.replication.tcp.handler import ReplicationCommandHandler +from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.rest.admin import register_servlets_for_client_rest_resource +from synapse.rest.client.v1 import login, room +from synapse.types import UserID + +from tests import unittest +from tests.server import FakeTransport + +logger = logging.getLogger(__name__) + + +class BaseStreamTestCase(unittest.HomeserverTestCase): + """Base class for tests of the replication streams""" + + servlets = [ + streams.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # build a replication server + self.server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() + + store = hs.get_datastore() + self.database = store.db + + self.reactor.lookups["testserv"] = "1.2.3.4" + + def default_config(self): + conf = super().default_config() + conf["send_federation"] = False + return conf + + def make_worker_hs(self, extra_config={}): + config = self._get_worker_hs_config() + config.update(extra_config) + + mock_federation_client = Mock(spec=["put_json"]) + mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + + worker_hs = self.setup_test_homeserver( + http_client=mock_federation_client, + homeserverToUse=GenericWorkerServer, + config=config, + reactor=self.reactor, + ) + + store = worker_hs.get_datastore() + store.db._db_pool = self.database._db_pool + + repl_handler = ReplicationCommandHandler(worker_hs) + client = ClientReplicationStreamProtocol( + worker_hs, "client", "test", self.clock, repl_handler, + ) + server = self.server_factory.buildProtocol(None) + + client_transport = FakeTransport(server, self.reactor) + client.makeConnection(client_transport) + + server_transport = FakeTransport(client, self.reactor) + server.makeConnection(server_transport) + + return worker_hs + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.federation_sender" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + + def replicate(self): + """Tell the master side of replication that something has happened, and then + wait for the replication to occur. + """ + self.streamer.on_notifier_poke() + self.pump() + + def create_room_with_remote_server(self, user, token, remote_server="other_server"): + room = self.helper.create_room_as(user, tok=token) + store = self.hs.get_datastore() + federation = self.hs.get_handlers().federation_handler + + prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) + room_version = self.get_success(store.get_room_version(room)) + + factory = EventBuilderFactory(self.hs) + factory.hostname = remote_server + + user_id = UserID("user", remote_server).to_string() + + event_dict = { + "type": EventTypes.Member, + "state_key": user_id, + "content": {"membership": Membership.JOIN}, + "sender": user_id, + "room_id": room, + } + + builder = factory.for_room_version(room_version, event_dict) + join_event = self.get_success(builder.build(prev_event_ids)) + + self.get_success(federation.on_send_join_request(remote_server, join_event)) + self.replicate() + + return room + + +class FederationSenderTestCase(BaseStreamTestCase): + servlets = [ + login.register_servlets, + register_servlets_for_client_rest_resource, + room.register_servlets, + ] + + def test_send_event_single_sender(self): + """Test that using a single federation sender worker correctly sends a + new event. + """ + worker_hs = self.make_worker_hs({"send_federation": True}) + mock_client = worker_hs.get_http_client() + + user = self.register_user("user", "pass") + token = self.login("user", "pass") + + room = self.create_room_with_remote_server(user, token) + + mock_client.put_json.reset_mock() + + self.create_and_send_event(room, UserID.from_string(user)) + self.replicate() + + # Assert that the event was sent out over federation. + mock_client.put_json.assert_called() + self.assertEqual(mock_client.put_json.call_args[0][0], "other_server") + self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus")) + + def test_send_event_sharded(self): + """Test that using two federation sender workers correctly sends + new events. + """ + worker1 = self.make_worker_hs( + { + "send_federation": True, + "worker_name": "sender1", + "federation_sender_instances": ["sender1", "sender2"], + } + ) + mock_client1 = worker1.get_http_client() + + worker2 = self.make_worker_hs( + { + "send_federation": True, + "worker_name": "sender2", + "federation_sender_instances": ["sender1", "sender2"], + } + ) + mock_client2 = worker2.get_http_client() + + user = self.register_user("user2", "pass") + token = self.login("user2", "pass") + + sent_on_1 = False + sent_on_2 = False + for i in range(20): + server_name = "other_server_%d" % (i,) + room = self.create_room_with_remote_server(user, token, server_name) + mock_client1.reset_mock() + mock_client2.reset_mock() + + self.create_and_send_event(room, UserID.from_string(user)) + self.replicate() + + if mock_client1.put_json.called: + sent_on_1 = True + mock_client2.put_json.assert_not_called() + self.assertEqual(mock_client1.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus")) + elif mock_client2.put_json.called: + sent_on_2 = True + mock_client1.put_json.assert_not_called() + self.assertEqual(mock_client2.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus")) + else: + raise AssertionError( + "Expected send transaction from one or the other sender" + ) + + if sent_on_1 and sent_on_2: + break + + self.assertTrue(sent_on_1) + self.assertTrue(sent_on_2) + + def test_send_typing_sharded(self): + """Test that using two federation sender workers correctly sends + new typing EDUs. + """ + worker1 = self.make_worker_hs( + { + "send_federation": True, + "worker_name": "sender1", + "federation_sender_instances": ["sender1", "sender2"], + } + ) + mock_client1 = worker1.get_http_client() + + worker2 = self.make_worker_hs( + { + "send_federation": True, + "worker_name": "sender2", + "federation_sender_instances": ["sender1", "sender2"], + } + ) + mock_client2 = worker2.get_http_client() + + user = self.register_user("user3", "pass") + token = self.login("user3", "pass") + + typing_handler = self.hs.get_typing_handler() + + sent_on_1 = False + sent_on_2 = False + for i in range(20): + server_name = "other_server_%d" % (i,) + room = self.create_room_with_remote_server(user, token, server_name) + mock_client1.reset_mock() + mock_client2.reset_mock() + + self.get_success( + typing_handler.started_typing( + target_user=UserID.from_string(user), + auth_user=UserID.from_string(user), + room_id=room, + timeout=20000, + ) + ) + + self.replicate() + + if mock_client1.put_json.called: + sent_on_1 = True + mock_client2.put_json.assert_not_called() + self.assertEqual(mock_client1.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus")) + elif mock_client2.put_json.called: + sent_on_2 = True + mock_client1.put_json.assert_not_called() + self.assertEqual(mock_client2.put_json.call_args[0][0], server_name) + self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus")) + else: + raise AssertionError( + "Expected send transaction from one or the other sender" + ) + + if sent_on_1 and sent_on_2: + break + + self.assertTrue(sent_on_1) + self.assertTrue(sent_on_2) -- cgit 1.5.1 From f2e38ca86711a8f80cf45d3182e426ed8967fc81 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 16 Jul 2020 15:12:54 +0100 Subject: Allow moving typing off master (#7869) --- changelog.d/7869.feature | 1 + synapse/app/generic_worker.py | 36 +---- synapse/config/workers.py | 19 +-- synapse/federation/federation_server.py | 125 +++++++++------- synapse/handlers/typing.py | 241 +++++++++++++++++++++---------- synapse/replication/tcp/handler.py | 9 ++ synapse/replication/tcp/streams/_base.py | 7 +- synapse/rest/client/v1/room.py | 9 ++ synapse/server.py | 13 +- synapse/server.pyi | 2 + 10 files changed, 284 insertions(+), 178 deletions(-) create mode 100644 changelog.d/7869.feature (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7869.feature b/changelog.d/7869.feature new file mode 100644 index 0000000000..1982049a52 --- /dev/null +++ b/changelog.d/7869.feature @@ -0,0 +1 @@ +Add experimental support for moving typing off master. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e90695f026..c0853eef22 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -111,6 +111,7 @@ from synapse.rest.client.v1.room import ( RoomSendEventRestServlet, RoomStateEventRestServlet, RoomStateRestServlet, + RoomTypingRestServlet, ) from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v2_alpha import groups, sync, user_directory @@ -451,37 +452,6 @@ class GenericWorkerPresence(BasePresenceHandler): await self._bump_active_client(user_id=user_id) -class GenericWorkerTyping(object): - def __init__(self, hs): - self._latest_room_serial = 0 - self._reset() - - def _reset(self): - """ - Reset the typing handler's data caches. - """ - # map room IDs to serial numbers - self._room_serials = {} - # map room IDs to sets of users currently typing - self._room_typing = {} - - def process_replication_rows(self, token, rows): - if self._latest_room_serial > token: - # The master has gone backwards. To prevent inconsistent data, just - # clear everything. - self._reset() - - # Set the latest serial token to whatever the server gave us. - self._latest_room_serial = token - - for row in rows: - self._room_serials[row.room_id] = token - self._room_typing[row.room_id] = row.user_ids - - def get_current_token(self) -> int: - return self._latest_room_serial - - class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. @@ -558,6 +528,7 @@ class GenericWorkerServer(HomeServer): KeyUploadServlet(self).register(resource) AccountDataServlet(self).register(resource) RoomAccountDataServlet(self).register(resource) + RoomTypingRestServlet(self).register(resource) sync.register_servlets(self, resource) events.register_servlets(self, resource) @@ -669,9 +640,6 @@ class GenericWorkerServer(HomeServer): def build_presence_handler(self): return GenericWorkerPresence(self) - def build_typing_handler(self): - return GenericWorkerTyping(self) - class GenericWorkerReplicationHandler(ReplicationDataHandler): def __init__(self, hs): diff --git a/synapse/config/workers.py b/synapse/config/workers.py index dbc661630c..2574cd3aa1 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -34,9 +34,11 @@ class WriterLocations: Attributes: events: The instance that writes to the event and backfill streams. + events: The instance that writes to the typing stream. """ events = attr.ib(default="master", type=str) + typing = attr.ib(default="master", type=str) class WorkerConfig(Config): @@ -93,16 +95,15 @@ class WorkerConfig(Config): writers = config.get("stream_writers") or {} self.writers = WriterLocations(**writers) - # Check that the configured writer for events also appears in + # Check that the configured writer for events and typing also appears in # `instance_map`. - if ( - self.writers.events != "master" - and self.writers.events not in self.instance_map - ): - raise ConfigError( - "Instance %r is configured to write events but does not appear in `instance_map` config." - % (self.writers.events,) - ) + for stream in ("events", "typing"): + instance = getattr(self.writers, stream) + if instance != "master" and instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured to write %s but does not appear in `instance_map` config." + % (instance, stream) + ) def read_arguments(self, args): # We support a bunch of command line arguments that override options in diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8c53330c49..23625ba995 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -15,7 +15,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Match, + Optional, + Tuple, + Union, +) from canonicaljson import json from prometheus_client import Counter, Histogram @@ -56,6 +67,9 @@ from synapse.util import glob_to_regex, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +if TYPE_CHECKING: + from synapse.server import HomeServer + # when processing incoming transactions, we try to handle multiple rooms in # parallel, up to this limit. TRANSACTION_CONCURRENCY_LIMIT = 10 @@ -768,11 +782,30 @@ class FederationHandlerRegistry(object): query type for incoming federation traffic. """ - def __init__(self): - self.edu_handlers = {} - self.query_handlers = {} + def __init__(self, hs: "HomeServer"): + self.config = hs.config + self.http_client = hs.get_simple_http_client() + self.clock = hs.get_clock() + self._instance_name = hs.get_instance_name() - def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]): + # These are safe to load in monolith mode, but will explode if we try + # and use them. However we have guards before we use them to ensure that + # we don't route to ourselves, and in monolith mode that will always be + # the case. + self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) + self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) + + self.edu_handlers = ( + {} + ) # type: Dict[str, Callable[[str, dict], Awaitable[None]]] + self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]] + + # Map from type to instance name that we should route EDU handling to. + self._edu_type_to_instance = {} # type: Dict[str, str] + + def register_edu_handler( + self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]] + ): """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. @@ -809,66 +842,56 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler + def register_instance_for_edu(self, edu_type: str, instance_name: str): + """Register that the EDU handler is on a different instance than master. + """ + self._edu_type_to_instance[edu_type] = instance_name + async def on_edu(self, edu_type: str, origin: str, content: dict): + if not self.config.use_presence and edu_type == "m.presence": + return + + # Check if we have a handler on this instance handler = self.edu_handlers.get(edu_type) - if not handler: - logger.warning("No handler registered for EDU type %s", edu_type) + if handler: + with start_active_span_from_edu(content, "handle_edu"): + try: + await handler(origin, content) + except SynapseError as e: + logger.info("Failed to handle edu %r: %r", edu_type, e) + except Exception: + logger.exception("Failed to handle edu %r", edu_type) return - with start_active_span_from_edu(content, "handle_edu"): + # Check if we can route it somewhere else that isn't us + route_to = self._edu_type_to_instance.get(edu_type, "master") + if route_to != self._instance_name: try: - await handler(origin, content) + await self._send_edu( + instance_name=route_to, + edu_type=edu_type, + origin=origin, + content=content, + ) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: logger.exception("Failed to handle edu %r", edu_type) - - def on_query(self, query_type: str, args: dict) -> defer.Deferred: - handler = self.query_handlers.get(query_type) - if not handler: - logger.warning("No handler registered for query type %s", query_type) - raise NotFoundError("No handler for Query type '%s'" % (query_type,)) - - return handler(args) - - -class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): - """A FederationHandlerRegistry for worker processes. - - When receiving EDU or queries it will check if an appropriate handler has - been registered on the worker, if there isn't one then it calls off to the - master process. - """ - - def __init__(self, hs): - self.config = hs.config - self.http_client = hs.get_simple_http_client() - self.clock = hs.get_clock() - - self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) - self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) - - super(ReplicationFederationHandlerRegistry, self).__init__() - - async def on_edu(self, edu_type: str, origin: str, content: dict): - """Overrides FederationHandlerRegistry - """ - if not self.config.use_presence and edu_type == "m.presence": return - handler = self.edu_handlers.get(edu_type) - if handler: - return await super(ReplicationFederationHandlerRegistry, self).on_edu( - edu_type, origin, content - ) - - return await self._send_edu(edu_type=edu_type, origin=origin, content=content) + # Oh well, let's just log and move on. + logger.warning("No handler registered for EDU type %s", edu_type) async def on_query(self, query_type: str, args: dict): - """Overrides FederationHandlerRegistry - """ handler = self.query_handlers.get(query_type) if handler: return await handler(args) - return await self._get_query_client(query_type=query_type, args=args) + # Check if we can route it somewhere else that isn't us + if self._instance_name == "master": + return await self._get_query_client(query_type=query_type, args=args) + + # Uh oh, no handler! Let's raise an exception so the request returns an + # error. + logger.warning("No handler registered for query type %s", query_type) + raise NotFoundError("No handler for Query type '%s'" % (query_type,)) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 846ddbdc6c..a86ac0150e 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,15 +15,19 @@ import logging from collections import namedtuple -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Set, Tuple from synapse.api.errors import AuthError, SynapseError -from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.tcp.streams import TypingStream from synapse.types import UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -39,48 +43,48 @@ FEDERATION_TIMEOUT = 60 * 1000 FEDERATION_PING_INTERVAL = 40 * 1000 -class TypingHandler(object): - def __init__(self, hs): +class FollowerTypingHandler: + """A typing handler on a different process than the writer that is updated + via replication. + """ + + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.server_name = hs.config.server_name - self.auth = hs.get_auth() - self.is_mine_id = hs.is_mine_id - self.notifier = hs.get_notifier() - self.state = hs.get_state_handler() - - self.hs = hs - self.clock = hs.get_clock() - self.wheel_timer = WheelTimer(bucket_size=5000) + self.is_mine_id = hs.is_mine_id - self.federation = hs.get_federation_sender() + self.federation = None + if hs.should_send_federation(): + self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + if hs.config.worker.writers.typing != hs.get_instance_name(): + hs.get_federation_registry().register_instance_for_edu( + "m.typing", hs.config.worker.writers.typing, + ) - hs.get_distributor().observe("user_left_room", self.user_left_room) + # map room IDs to serial numbers + self._room_serials = {} + # map room IDs to sets of users currently typing + self._room_typing = {} - self._member_typing_until = {} # clock time we expect to stop self._member_last_federation_poke = {} - + self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 - self._reset() - - # caches which room_ids changed at which serials - self._typing_stream_change_cache = StreamChangeCache( - "TypingStreamChangeCache", self._latest_room_serial - ) self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self): - """ - Reset the typing handler's data caches. + """Reset the typing handler's data caches. """ # map room IDs to serial numbers self._room_serials = {} # map room IDs to sets of users currently typing self._room_typing = {} + self._member_last_federation_poke = {} + self.wheel_timer = WheelTimer(bucket_size=5000) + def _handle_timeouts(self): logger.debug("Checking for typing timeouts") @@ -89,30 +93,140 @@ class TypingHandler(object): members = set(self.wheel_timer.fetch(now)) for member in members: - if not self.is_typing(member): - # Nothing to do if they're no longer typing - continue - - until = self._member_typing_until.get(member, None) - if not until or until <= now: - logger.info("Timing out typing for: %s", member.user_id) - self._stopped_typing(member) - continue - - # Check if we need to resend a keep alive over federation for this - # user. - if self.hs.is_mine_id(member.user_id): - last_fed_poke = self._member_last_federation_poke.get(member, None) - if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_in_background(self._push_remote, member=member, typing=True) - - # Add a paranoia timer to ensure that we always have a timer for - # each person typing. - self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) + self._handle_timeout_for_member(now, member) + + def _handle_timeout_for_member(self, now: int, member: RoomMember): + if not self.is_typing(member): + # Nothing to do if they're no longer typing + return + + # Check if we need to resend a keep alive over federation for this + # user. + if self.federation and self.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: + run_as_background_process( + "typing._push_remote", self._push_remote, member=member, typing=True + ) + + # Add a paranoia timer to ensure that we always have a timer for + # each person typing. + self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) def is_typing(self, member): return member.user_id in self._room_typing.get(member.room_id, []) + async def _push_remote(self, member, typing): + if not self.federation: + return + + try: + users = await self.store.get_users_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() + + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, obj=member, then=now + FEDERATION_PING_INTERVAL + ) + + for domain in {get_domain_from_id(u) for u in users}: + if domain != self.server_name: + logger.debug("sending typing update to %s", domain) + self.federation.build_and_send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) + except Exception: + logger.exception("Error pushing typing notif to remotes") + + def process_replication_rows( + self, token: int, rows: List[TypingStream.TypingStreamRow] + ): + """Should be called whenever we receive updates for typing stream. + """ + + if self._latest_room_serial > token: + # The master has gone backwards. To prevent inconsistent data, just + # clear everything. + self._reset() + + # Set the latest serial token to whatever the server gave us. + self._latest_room_serial = token + + for row in rows: + self._room_serials[row.room_id] = token + + prev_typing = set(self._room_typing.get(row.room_id, [])) + now_typing = set(row.user_ids) + self._room_typing[row.room_id] = row.user_ids + + run_as_background_process( + "_handle_change_in_typing", + self._handle_change_in_typing, + row.room_id, + prev_typing, + now_typing, + ) + + async def _handle_change_in_typing( + self, room_id: str, prev_typing: Set[str], now_typing: Set[str] + ): + """Process a change in typing of a room from replication, sending EDUs + for any local users. + """ + for user_id in now_typing - prev_typing: + if self.is_mine_id(user_id): + await self._push_remote(RoomMember(room_id, user_id), True) + + for user_id in prev_typing - now_typing: + if self.is_mine_id(user_id): + await self._push_remote(RoomMember(room_id, user_id), False) + + def get_current_token(self): + return self._latest_room_serial + + +class TypingWriterHandler(FollowerTypingHandler): + def __init__(self, hs): + super().__init__(hs) + + assert hs.config.worker.writers.typing == hs.get_instance_name() + + self.auth = hs.get_auth() + self.notifier = hs.get_notifier() + + self.hs = hs + + hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + + hs.get_distributor().observe("user_left_room", self.user_left_room) + + self._member_typing_until = {} # clock time we expect to stop + + # caches which room_ids changed at which serials + self._typing_stream_change_cache = StreamChangeCache( + "TypingStreamChangeCache", self._latest_room_serial + ) + + def _handle_timeout_for_member(self, now: int, member: RoomMember): + super()._handle_timeout_for_member(now, member) + + if not self.is_typing(member): + # Nothing to do if they're no longer typing + return + + until = self._member_typing_until.get(member, None) + if not until or until <= now: + logger.info("Timing out typing for: %s", member.user_id) + self._stopped_typing(member) + return + async def started_typing(self, target_user, auth_user, room_id, timeout): target_user_id = target_user.to_string() auth_user_id = auth_user.to_string() @@ -179,35 +293,11 @@ class TypingHandler(object): def _push_update(self, member, typing): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - run_in_background(self._push_remote, member, typing) - - self._push_update_local(member=member, typing=typing) - - async def _push_remote(self, member, typing): - try: - users = await self.store.get_users_in_room(member.room_id) - self._member_last_federation_poke[member] = self.clock.time_msec() - - now = self.clock.time_msec() - self.wheel_timer.insert( - now=now, obj=member, then=now + FEDERATION_PING_INTERVAL + run_as_background_process( + "typing._push_remote", self._push_remote, member, typing ) - for domain in {get_domain_from_id(u) for u in users}: - if domain != self.server_name: - logger.debug("sending typing update to %s", domain) - self.federation.build_and_send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": member.room_id, - "user_id": member.user_id, - "typing": typing, - }, - key=member, - ) - except Exception: - logger.exception("Error pushing typing notif to remotes") + self._push_update_local(member=member, typing=typing) async def _recv_edu(self, origin, content): room_id = content["room_id"] @@ -304,8 +394,11 @@ class TypingHandler(object): return rows, current_id, limited - def get_current_token(self): - return self._latest_room_serial + def process_replication_rows( + self, token: int, rows: List[TypingStream.TypingStreamRow] + ): + # The writing process should never get updates from replication. + raise Exception("Typing writer instance got typing info over replication") class TypingNotificationEventSource(object): diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 80f5df60f9..30d8de48fa 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -42,6 +42,7 @@ from synapse.replication.tcp.streams import ( EventsStream, FederationStream, Stream, + TypingStream, ) from synapse.util.async_helpers import Linearizer @@ -96,6 +97,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, TypingStream): + # Only add TypingStream as a source on the instance in charge of + # typing. + if hs.config.worker.writers.typing == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9076bbe9f1..7a42de3f7d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -294,11 +294,12 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - if hs.config.worker_app is None: - # on the master, query the typing handler + writer_instance = hs.config.worker.writers.typing + if writer_instance == hs.get_instance_name(): + # On the writer, query the typing handler update_function = typing_handler.get_all_typing_updates else: - # Query master process + # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index ea5912d4e4..26d5a51cb2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -817,9 +817,18 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() + # If we're not on the typing writer instance we should scream if we get + # requests. + self._is_typing_writer = ( + hs.config.worker.writers.typing == hs.get_instance_name() + ) + async def on_PUT(self, request, room_id, user_id): requester = await self.auth.get_user_by_req(request) + if not self._is_typing_writer: + raise Exception("Got /typing request on instance that is not typing writer") + room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) diff --git a/synapse/server.py b/synapse/server.py index 0e6ea96b33..8e41112530 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -44,7 +44,6 @@ from synapse.federation.federation_client import FederationClient from synapse.federation.federation_server import ( FederationHandlerRegistry, FederationServer, - ReplicationFederationHandlerRegistry, ) from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.sender import FederationSender @@ -84,7 +83,7 @@ from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler -from synapse.handlers.typing import TypingHandler +from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient @@ -378,7 +377,10 @@ class HomeServer(object): return PresenceHandler(self) def build_typing_handler(self): - return TypingHandler(self) + if self.config.worker.writers.typing == self.get_instance_name(): + return TypingWriterHandler(self) + else: + return FollowerTypingHandler(self) def build_sync_handler(self): return SyncHandler(self) @@ -534,10 +536,7 @@ class HomeServer(object): return RoomMemberMasterHandler(self) def build_federation_registry(self): - if self.config.worker_app: - return ReplicationFederationHandlerRegistry(self) - else: - return FederationHandlerRegistry() + return FederationHandlerRegistry(self) def build_server_notices_manager(self): if self.config.worker_app: diff --git a/synapse/server.pyi b/synapse/server.pyi index cd50c721b8..90a673778f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -148,3 +148,5 @@ class HomeServer(object): self, ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient: pass + def should_send_federation(self) -> bool: + pass -- cgit 1.5.1 From e5300063ede787414e23295767e3279097d7befa Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 16 Jul 2020 15:49:37 +0100 Subject: Optimise queueing of inbound replication commands (#7861) When we get behind on replication, we tend to stack up background processes behind a linearizer. Bg processes are heavy (particularly with respect to prometheus metrics) and linearizers aren't terribly efficient once the queue gets long either. A better approach is to maintain a queue of requests to be processed, and nominate a single process to work its way through the queue. Fixes: #7444 --- changelog.d/7861.misc | 1 + synapse/replication/tcp/handler.py | 331 ++++++++++++++++++++++++------------- 2 files changed, 216 insertions(+), 116 deletions(-) create mode 100644 changelog.d/7861.misc (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7861.misc b/changelog.d/7861.misc new file mode 100644 index 0000000000..ada616c62f --- /dev/null +++ b/changelog.d/7861.misc @@ -0,0 +1 @@ +Optimise queueing of inbound replication commands. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 30d8de48fa..f88e0a2e40 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -14,9 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter +from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory @@ -44,7 +56,6 @@ from synapse.replication.tcp.streams import ( Stream, TypingStream, ) -from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -62,6 +73,12 @@ invalidate_cache_counter = Counter( user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") +# the type of the entries in _command_queues_by_stream +_StreamCommandQueue = Deque[ + Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] +] + + class ReplicationCommandHandler: """Handles incoming commands from replication as well as sending commands back out to connections. @@ -116,10 +133,6 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) - self._position_linearizer = Linearizer( - "replication_position", clock=self._clock - ) - # Map of stream name to batched updates. See RdataCommand for info on # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] @@ -131,10 +144,6 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming stream names that are coming from - # that connection. - self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] - LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -142,6 +151,32 @@ class ReplicationCommandHandler: lambda: len(self._connections), ) + # When POSITION or RDATA commands arrive, we stick them in a queue and process + # them in order in a separate background process. + + # the streams which are currently being processed by _unsafe_process_stream + self._processing_streams = set() # type: Set[str] + + # for each stream, a queue of commands that are awaiting processing, and the + # connection that they arrived on. + self._command_queues_by_stream = { + stream_name: _StreamCommandQueue() for stream_name in self._streams + } + + # For each connection, the incoming stream names that have received a POSITION + # from that connection. + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + + LaterGauge( + "synapse_replication_tcp_command_queue", + "Number of inbound RDATA/POSITION commands queued for processing", + ["stream_name"], + lambda: { + (stream_name,): len(queue) + for stream_name, queue in self._command_queues_by_stream.items() + }, + ) + self._is_master = hs.config.worker_app is None self._federation_sender = None @@ -152,6 +187,64 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + async def _add_command_to_stream_queue( + self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + ) -> None: + """Queue the given received command for processing + + Adds the given command to the per-stream queue, and processes the queue if + necessary + """ + stream_name = cmd.stream_name + queue = self._command_queues_by_stream.get(stream_name) + if queue is None: + logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) + return + + # if we're already processing this stream, stick the new command in the + # queue, and we're done. + if stream_name in self._processing_streams: + queue.append((cmd, conn)) + return + + # otherwise, process the new command. + + # arguably we should start off a new background process here, but nothing + # will be too upset if we don't return for ages, so let's save the overhead + # and use the existing logcontext. + + self._processing_streams.add(stream_name) + try: + # might as well skip the queue for this one, since it must be empty + assert not queue + await self._process_command(cmd, conn, stream_name) + + # now process any other commands that have built up while we were + # dealing with that one. + while queue: + cmd, conn = queue.popleft() + try: + await self._process_command(cmd, conn, stream_name) + except Exception: + logger.exception("Failed to handle command %s", cmd) + + finally: + self._processing_streams.discard(stream_name) + + async def _process_command( + self, + cmd: Union[PositionCommand, RdataCommand], + conn: AbstractConnection, + stream_name: str, + ) -> None: + if isinstance(cmd, PositionCommand): + await self._process_position(stream_name, conn, cmd) + elif isinstance(cmd, RdataCommand): + await self._process_rdata(stream_name, conn, cmd) + else: + # This shouldn't be possible + raise Exception("Unrecognised command %s in stream queue", cmd.NAME) + def start_replication(self, hs): """Helper method to start a replication connection to the remote server using TCP. @@ -285,63 +378,71 @@ class ReplicationCommandHandler: stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) - raise - - # We linearize here for two reasons: + # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the # same stream, and # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. - with await self._position_linearizer.queue(cmd.stream_name): - # make sure that we've processed a POSITION for this stream *on this - # connection*. (A POSITION on another connection is no good, as there - # is no guarantee that we have seen all the intermediate updates.) - sbc = self._streams_by_connection.get(conn) - if not sbc or stream_name not in sbc: - # Let's drop the row for now, on the assumption we'll receive a - # `POSITION` soon and we'll catch up correctly then. - logger.debug( - "Discarding RDATA for unconnected stream %s -> %s", - stream_name, - cmd.token, - ) - return - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token). - self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got RDATA for unknown stream: %s", stream_name) - return - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # Discard this data if this token is earlier than the current - # position. Note that streams can be reset (in which case you - # expect an earlier token), but that must be preceded by a - # POSITION command. - if cmd.token <= current_token: - logger.debug( - "Discarding RDATA from stream %s at position %s before previous position %s", - stream_name, - cmd.token, - current_token, - ) - else: - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + await self._add_command_to_stream_queue(conn, cmd) + + async def _process_rdata( + self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + ) -> None: + """Process an RDATA command + + Called after the command has been popped off the queue of inbound commands + """ + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception as e: + raise Exception( + "Failed to parse RDATA: %r %r" % (stream_name, cmd.row) + ) from e + + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.debug( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + return + + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + + stream = self._streams[stream_name] + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. + if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -367,67 +468,65 @@ class ReplicationCommandHandler: logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) - stream_name = cmd.stream_name - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", stream_name) - return + await self._add_command_to_stream_queue(conn, cmd) - # We protect catching up with a linearizer in case the replication - # connection reconnects under us. - with await self._position_linearizer.queue(stream_name): - # We're about to go and catch up with the stream, so remove from set - # of connected streams. - for streams in self._streams_by_connection.values(): - streams.discard(stream_name) - - # We clear the pending batches for the stream as the fetching of the - # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(stream_name, []) - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # If the position token matches our current token then we're up to - # date and there's nothing to do. Otherwise, fetch all updates - # between then and now. - missing_updates = cmd.token != current_token - while missing_updates: - logger.info( - "Fetching replication rows for '%s' between %i and %i", - stream_name, - current_token, - cmd.token, - ) - ( - updates, - current_token, - missing_updates, - ) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token - ) + async def _process_position( + self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + ) -> None: + """Process a POSITION command - # TODO: add some tests for this + Called after the command has been popped off the queue of inbound commands + """ + stream = self._streams[stream_name] - # Some streams return multiple rows with the same stream IDs, - # which need to be processed in batches. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) - for token, rows in _batch_updates(updates): - await self.on_rdata( - stream_name, - cmd.instance_name, - token, - [stream.parse_row(row) for row in rows], - ) + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(stream_name, []) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) - # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) + (updates, current_token, missing_updates) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token ) - self._streams_by_connection.setdefault(conn, set()).add(stream_name) + # TODO: add some tests for this + + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): + await self.on_rdata( + stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], + ) + + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) + + self._streams_by_connection.setdefault(conn, set()).add(stream_name) async def on_REMOTE_SERVER_UP( self, conn: AbstractConnection, cmd: RemoteServerUpCommand -- cgit 1.5.1 From a7b06a81f02ed97975f45e0abd70b731c686fc86 Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Mon, 20 Jul 2020 23:03:04 +0530 Subject: Fix deprecation warning: import ABC from collections.abc (#7892) --- changelog.d/7892.misc | 1 + synapse/events/utils.py | 6 +++--- synapse/handlers/federation.py | 2 +- synapse/replication/tcp/streams/events.py | 2 +- synapse/util/stringutils.py | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 changelog.d/7892.misc (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7892.misc b/changelog.d/7892.misc new file mode 100644 index 0000000000..ef4cfa04fd --- /dev/null +++ b/changelog.d/7892.misc @@ -0,0 +1 @@ +Import ABC from `collections.abc` for Python 3.10 compatibility. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f6b507977f..11f0d34ec8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections +import collections.abc import re from typing import Any, Mapping, Union @@ -424,7 +424,7 @@ def copy_power_levels_contents( Raises: TypeError if the input does not look like a valid power levels event content """ - if not isinstance(old_power_levels, collections.Mapping): + if not isinstance(old_power_levels, collections.abc.Mapping): raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) power_levels = {} @@ -434,7 +434,7 @@ def copy_power_levels_contents( power_levels[k] = v continue - if isinstance(v, collections.Mapping): + if isinstance(v, collections.abc.Mapping): power_levels[k] = h = {} for k1, v1 in v.items(): # we should only have one level of nesting diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index df885e45e8..71ac5dca99 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,7 +19,7 @@ import itertools import logging -from collections import Container +from collections.abc import Container from http import HTTPStatus from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 1c2a4cce7f..16c63ff4ec 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import heapq -from collections import Iterable +from collections.abc import Iterable from typing import List, Tuple, Type import attr diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 08c86e92b8..2e2b40a426 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -17,7 +17,7 @@ import itertools import random import re import string -from collections import Iterable +from collections.abc import Iterable from synapse.api.errors import Codes, SynapseError -- cgit 1.5.1 From 05060e02234eec533c357cbc9eb4347976c1617d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Jul 2020 00:40:42 +0100 Subject: Track command processing as a background process (#7879) I'm going to be doing more stuff synchronously, and I don't want to lose the CPU metrics down the sofa. --- changelog.d/7879.feature | 1 + stubs/txredisapi.pyi | 1 + synapse/replication/tcp/protocol.py | 19 ++++++++++++++++++- synapse/replication/tcp/redis.py | 22 ++++++++++++++++++++-- 4 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 changelog.d/7879.feature (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7879.feature b/changelog.d/7879.feature new file mode 100644 index 0000000000..c89655f000 --- /dev/null +++ b/changelog.d/7879.feature @@ -0,0 +1 @@ +Report CPU metrics to prometheus for time spent processing replication commands. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index cac689d4f3..c66413f003 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -22,6 +22,7 @@ class RedisProtocol: def publish(self, channel: str, message: bytes): ... class SubscriberProtocol: + def __init__(self, *args, **kwargs): ... password: Optional[str] def subscribe(self, channels: Union[str, List[str]]): ... def connectionMade(self): ... diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index ca47f5cc88..23191e3218 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -57,8 +57,12 @@ from prometheus_client import Counter from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure +from synapse.logging.context import PreserveLoggingContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, @@ -160,6 +164,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler-%s" % self.conn_id + ) + def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -210,6 +220,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def lineReceived(self, line: bytes): """Called when we've received a line """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_line(line) + + def _parse_and_dispatch_line(self, line: bytes): if line.strip() == "": # Ignore blank lines return @@ -397,6 +411,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: self.transport.unregisterProducer() + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def __str__(self): addr = None if self.transport: diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 0a7e7f67be..b5c533a607 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -18,8 +18,11 @@ from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( Command, ReplicateCommand, @@ -66,6 +69,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): stream_name = None # type: str outbound_redis_connection = None # type: txredisapi.RedisProtocol + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler" + ) + def connectionMade(self): logger.info("Connected to redis") super().connectionMade() @@ -92,7 +104,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_message(message) + def _parse_and_dispatch_message(self, message: str): if message.strip() == "": # Ignore blank lines return @@ -145,6 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): super().connectionLost(reason) self.handler.lost_connection(self) + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def send_command(self, cmd: Command): """Send a command if connection has been established. -- cgit 1.5.1 From 931b02684481fb6b5daefd9218baf6a4b0b941f6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 22 Jul 2020 00:40:55 +0100 Subject: Remove an unused prometheus metric (#7878) --- changelog.d/7878.removal | 1 + synapse/replication/tcp/handler.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) create mode 100644 changelog.d/7878.removal (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7878.removal b/changelog.d/7878.removal new file mode 100644 index 0000000000..d5a4066624 --- /dev/null +++ b/changelog.d/7878.removal @@ -0,0 +1 @@ +Remove unused `synapse_replication_tcp_resource_invalidate_cache` prometheus metric. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index f88e0a2e40..1de590bba2 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -67,9 +67,7 @@ inbound_rdata_count = Counter( user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) + user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") -- cgit 1.5.1 From 84d099ae1192af0f38d26f9a32e38bd4c0ad304e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 27 Jul 2020 14:10:53 +0100 Subject: Fix typing replication not being handled on master (#7959) Handling of incoming typing stream updates from replication was not hooked up on master, effecting set ups where typing was handled on a different worker. This is really only a problem if the master process is also handling sync requests, which is unlikely for those that are at the stage of moving typing off. The other observable effect is that if a worker restarts or a replication connect drops then the typing worker will issue a `POSITION typing`, triggering master process to try and stream *all* typing updates from position 0. Fixes #7907 --- changelog.d/7959.bugfix | 1 + synapse/app/generic_worker.py | 7 ------- synapse/replication/tcp/client.py | 8 ++++++++ synapse/server.pyi | 3 +++ 4 files changed, 12 insertions(+), 7 deletions(-) create mode 100644 changelog.d/7959.bugfix (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7959.bugfix b/changelog.d/7959.bugfix new file mode 100644 index 0000000000..1982049a52 --- /dev/null +++ b/changelog.d/7959.bugfix @@ -0,0 +1 @@ +Add experimental support for moving typing off master. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index c1b76d827b..ec0dbddb8c 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -87,7 +87,6 @@ from synapse.replication.tcp.streams import ( ReceiptsStream, TagAccountDataStream, ToDeviceStream, - TypingStream, ) from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client.v1 import events @@ -644,7 +643,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): super(GenericWorkerReplicationHandler, self).__init__(hs) self.store = hs.get_datastore() - self.typing_handler = hs.get_typing_handler() self.presence_handler = hs.get_presence_handler() # type: GenericWorkerPresence self.notifier = hs.get_notifier() @@ -681,11 +679,6 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler): await self.pusher_pool.on_new_receipts( token, token, {row.room_id for row in rows} ) - elif stream_name == TypingStream.NAME: - self.typing_handler.process_replication_rows(token, rows) - self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows] - ) elif stream_name == ToDeviceStream.NAME: entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 4985e40b1f..fcf8ebf1e7 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -24,6 +24,7 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -104,6 +105,7 @@ class ReplicationDataHandler: self._clock = hs.get_clock() self._streams = hs.get_replication_streams() self._instance_name = hs.get_instance_name() + self._typing_handler = hs.get_typing_handler() # Map from stream to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. @@ -127,6 +129,12 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + if stream_name == TypingStream.NAME: + self._typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows] + ) + if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. diff --git a/synapse/server.pyi b/synapse/server.pyi index 90a673778f..1aba408c21 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -31,6 +31,7 @@ import synapse.server_notices.server_notices_sender import synapse.state import synapse.storage from synapse.events.builder import EventBuilderFactory +from synapse.handlers.typing import FollowerTypingHandler from synapse.replication.tcp.streams import Stream class HomeServer(object): @@ -150,3 +151,5 @@ class HomeServer(object): pass def should_send_federation(self) -> bool: pass + def get_typing_handler(self) -> FollowerTypingHandler: + pass -- cgit 1.5.1 From f57b99af22de874b11f44ef32c1f1425ec1344b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 27 Jul 2020 18:54:43 +0100 Subject: Handle replication commands synchronously where possible (#7876) Most of the stuff we do for replication commands can be done synchronously. There's no point spinning up background processes if we're not going to need them. --- changelog.d/7876.bugfix | 1 + changelog.d/7876.misc | 1 + synapse/replication/tcp/handler.py | 115 +++++++++++++++++++++--------------- synapse/replication/tcp/protocol.py | 45 ++++++++------ synapse/replication/tcp/redis.py | 37 ++++++------ 5 files changed, 113 insertions(+), 86 deletions(-) create mode 100644 changelog.d/7876.bugfix create mode 100644 changelog.d/7876.misc (limited to 'synapse/replication/tcp') diff --git a/changelog.d/7876.bugfix b/changelog.d/7876.bugfix new file mode 100644 index 0000000000..4ba2fadd58 --- /dev/null +++ b/changelog.d/7876.bugfix @@ -0,0 +1 @@ +Fix an `AssertionError` exception introduced in v1.18.0rc1. diff --git a/changelog.d/7876.misc b/changelog.d/7876.misc new file mode 100644 index 0000000000..5c78a158cd --- /dev/null +++ b/changelog.d/7876.misc @@ -0,0 +1 @@ +Further optimise queueing of inbound replication commands. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1de590bba2..1c303f3a46 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -16,6 +16,7 @@ import logging from typing import ( Any, + Awaitable, Dict, Iterable, Iterator, @@ -33,6 +34,7 @@ from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, @@ -152,7 +154,7 @@ class ReplicationCommandHandler: # When POSITION or RDATA commands arrive, we stick them in a queue and process # them in order in a separate background process. - # the streams which are currently being processed by _unsafe_process_stream + # the streams which are currently being processed by _unsafe_process_queue self._processing_streams = set() # type: Set[str] # for each stream, a queue of commands that are awaiting processing, and the @@ -185,7 +187,7 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() - async def _add_command_to_stream_queue( + def _add_command_to_stream_queue( self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] ) -> None: """Queue the given received command for processing @@ -199,33 +201,34 @@ class ReplicationCommandHandler: logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) return - # if we're already processing this stream, stick the new command in the - # queue, and we're done. + queue.append((cmd, conn)) + + # if we're already processing this stream, there's nothing more to do: + # the new entry on the queue will get picked up in due course if stream_name in self._processing_streams: - queue.append((cmd, conn)) return - # otherwise, process the new command. + # fire off a background process to start processing the queue. + run_as_background_process( + "process-replication-data", self._unsafe_process_queue, stream_name + ) - # arguably we should start off a new background process here, but nothing - # will be too upset if we don't return for ages, so let's save the overhead - # and use the existing logcontext. + async def _unsafe_process_queue(self, stream_name: str): + """Processes the command queue for the given stream, until it is empty + + Does not check if there is already a thread processing the queue, hence "unsafe" + """ + assert stream_name not in self._processing_streams self._processing_streams.add(stream_name) try: - # might as well skip the queue for this one, since it must be empty - assert not queue - await self._process_command(cmd, conn, stream_name) - - # now process any other commands that have built up while we were - # dealing with that one. + queue = self._command_queues_by_stream.get(stream_name) while queue: cmd, conn = queue.popleft() try: await self._process_command(cmd, conn, stream_name) except Exception: logger.exception("Failed to handle command %s", cmd) - finally: self._processing_streams.discard(stream_name) @@ -299,7 +302,7 @@ class ReplicationCommandHandler: """ return self._streams_to_replicate - async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): self.send_positions_to_connection(conn) def send_positions_to_connection(self, conn: AbstractConnection): @@ -318,57 +321,73 @@ class ReplicationCommandHandler: ) ) - async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): + def on_USER_SYNC( + self, conn: AbstractConnection, cmd: UserSyncCommand + ) -> Optional[Awaitable[None]]: user_sync_counter.inc() if self._is_master: - await self._presence_handler.update_external_syncs_row( + return self._presence_handler.update_external_syncs_row( cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + else: + return None - async def on_CLEAR_USER_SYNC( + def on_CLEAR_USER_SYNC( self, conn: AbstractConnection, cmd: ClearUserSyncsCommand - ): + ) -> Optional[Awaitable[None]]: if self._is_master: - await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + return self._presence_handler.update_external_syncs_clear(cmd.instance_id) + else: + return None - async def on_FEDERATION_ACK( - self, conn: AbstractConnection, cmd: FederationAckCommand - ): + def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): federation_ack_counter.inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.instance_name, cmd.token) - async def on_REMOVE_PUSHER( + def on_REMOVE_PUSHER( self, conn: AbstractConnection, cmd: RemovePusherCommand - ): + ) -> Optional[Awaitable[None]]: remove_pusher_counter.inc() if self._is_master: - await self._store.delete_pusher_by_app_id_pushkey_user_id( - app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id - ) + return self._handle_remove_pusher(cmd) + else: + return None + + async def _handle_remove_pusher(self, cmd: RemovePusherCommand): + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) - self._notifier.on_new_replication_data() + self._notifier.on_new_replication_data() - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): + def on_USER_IP( + self, conn: AbstractConnection, cmd: UserIpCommand + ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() if self._is_master: - await self._store.insert_client_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) + return self._handle_user_ip(cmd) + else: + return None + + async def _handle_user_ip(self, cmd: UserIpCommand): + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) - if self._server_notices_sender: - await self._server_notices_sender.on_user_ip(cmd.user_id) + assert self._server_notices_sender is not None + await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -382,7 +401,7 @@ class ReplicationCommandHandler: # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. - await self._add_command_to_stream_queue(conn, cmd) + self._add_command_to_stream_queue(conn, cmd) async def _process_rdata( self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand @@ -459,14 +478,14 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) - await self._add_command_to_stream_queue(conn, cmd) + self._add_command_to_stream_queue(conn, cmd) async def _process_position( self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand @@ -526,9 +545,7 @@ class ReplicationCommandHandler: self._streams_by_connection.setdefault(conn, set()).add(stream_name) - async def on_REMOTE_SERVER_UP( - self, conn: AbstractConnection, cmd: RemoteServerUpCommand - ): + def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 23191e3218..0350923898 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -50,6 +50,7 @@ import abc import fcntl import logging import struct +from inspect import isawaitable from typing import TYPE_CHECKING, List from prometheus_client import Counter @@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): On receiving a new command it calls `on_` with the parsed command before delegating to `ReplicationCommandHandler.on_`. + `ReplicationCommandHandler.on_` can optionally return a coroutine; + if so, that will get run as a background process. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. - self._logging_context = BackgroundProcessLoggingContext( - "replication_command_handler-%s" % self.conn_id - ) + ctx_name = "replication-conn-%s" % self.conn_id + self._logging_context = BackgroundProcessLoggingContext(ctx_name) + self._logging_context.request = ctx_name def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() - # Now lets try and call on_ function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. First calls `self.on_` if it exists, then calls - `self.command_handler.on_` if it exists. This allows for - protocol level handling of commands (e.g. PINGs), before delegating to - the handler. + `self.command_handler.on_` if it exists (which can optionally + return an Awaitable). + + This allows for protocol level handling of commands (e.g. PINGs), before + delegating to the handler. Args: cmd: received command @@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # specific handling. cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + cmd_func(cmd) handled = True # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(self, cmd) + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) + handled = True if not handled: @@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - async def on_PING(self, line): + def on_PING(self, line): self.received_ping = True - async def on_ERROR(self, cmd): + def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ServerCommand(self.server_name)) super().connectionMade() - async def on_NAME(self, cmd): + def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data @@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Once we've connected subscribe to the necessary streams self.replicate() - async def on_SERVER(self, cmd): + def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index b5c533a607..f225e533de 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from inspect import isawaitable from typing import TYPE_CHECKING import txredisapi @@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() - # Now lets try and call on_ function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. - By default delegates to on_, which should return an awaitable. + Delegates to `self.handler.on_` (which can optionally return an + Awaitable). Args: cmd: received command """ - handled = False - - # First call any command handlers on this instance. These are for redis - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(self, cmd) - handled = True - - if not handled: + if not cmd_func: logger.warning("Unhandled command: %r", cmd) + return + + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) def connectionLost(self, reason): logger.info("Lost connection to redis") -- cgit 1.5.1