From 313987187ee04dce5e70db17c1ab9377f283be7e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Feb 2019 15:04:34 +0000 Subject: Fix tightloop over connecting to replication server If the client failed to process incoming commands during the initial set up of the replication connection it would immediately disconnect and reconnect, resulting in a tightloop. This can happen, for example, when subscribing to a stream that has a row that is too long in the backlog. The fix here is to not consider the connection successfully set up until the client has succesfully subscribed and caught up with the streams. This ensures that the retry logic timers aren't reset until then, meaning that if an error does happen during start up the client will continue backing off before retrying again. --- synapse/replication/tcp/client.py | 38 ++++++++++++++++++++++++++++++++++--- synapse/replication/tcp/commands.py | 5 ++++- 2 files changed, 39 insertions(+), 4 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 586dddb40b..914cd24b55 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -54,7 +54,6 @@ class ReplicationClientFactory(ReconnectingClientFactory): def buildProtocol(self, addr): logger.info("Connected to replication: %r", addr) - self.resetDelay() return ClientReplicationStreamProtocol( self.client_name, self.server_name, self._clock, self.handler ) @@ -90,15 +89,23 @@ class ReplicationClientHandler(object): # Used for tests. self.awaiting_syncs = {} + # Set of stream names that have been subscribe to, but haven't yet + # caught up with. This is used to track when the client has been fully + # connected to the remote. + self.streams_connecting = None + + # The factory used to create connections. + self.factory = None + def start_replication(self, hs): """Helper method to start a replication connection to the remote server using TCP. """ client_name = hs.config.worker_name - factory = ReplicationClientFactory(hs, client_name, self) + self.factory = ReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, factory) + hs.get_reactor().connectTCP(host, port, self.factory) def on_rdata(self, stream_name, token, rows): """Called when we get new replication data. By default this just pokes @@ -115,6 +122,12 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ + # When we get a `POSITION` command it means we've finished getting + # missing updates for the given stream, and are now up to date. + self.streams_connecting.discard(stream_name) + if not self.streams_connecting: + self.finished_connecting() + return self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): @@ -140,6 +153,10 @@ class ReplicationClientHandler(object): args["account_data"] = user_account_data elif room_account_data: args["account_data"] = room_account_data + + # Record which streams we're in the process of subscribing to + self.streams_connecting = set(args.keys()) + return args def get_currently_syncing_users(self): @@ -204,3 +221,18 @@ class ReplicationClientHandler(object): for cmd in self.pending_commands: connection.send_command(cmd) self.pending_commands = [] + + # This will happen if we don't actually subscribe to any streams + if not self.streams_connecting: + self.finished_connecting() + + def finished_connecting(self): + """Called when we have successfully subscribed and caught up to all + streams we're interested in. + """ + logger.info("Finished connecting to server") + + # We don't reset the delay any earlier as otherwise if there is a + # problem during start up we'll end up tight looping connecting to the + # server. + self.factory.resetDelay() diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 327556f6a1..2098c32a77 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -127,8 +127,11 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the client to tell the client the stream postition without + """Sent by the server to tell the client the stream postition without needing to send an RDATA. + + Sent to the client after all missing updates for a stream have been sent + to the client and they're now up to date. """ NAME = "POSITION" -- cgit 1.5.1 From 25814921f1900e92b50830d6616762b174e82773 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Feb 2019 15:12:29 +0000 Subject: Increase the max delay between retry attempts Otherwise if you have many workers they can easily take out master with their connection attempts --- synapse/replication/tcp/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 914cd24b55..51f90655d0 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -39,7 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory): Accepts a handler that will be called when new data is available or data is required. """ - maxDelay = 5 # Try at least once every N seconds + maxDelay = 30 # Try at least once every N seconds def __init__(self, hs, client_name, handler): self.client_name = client_name -- cgit 1.5.1 From 6870fc496ff3da5075fec74e40515c03c929915f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 27 Feb 2019 10:22:52 +0000 Subject: Move connecting logic into ClientReplicationStreamProtocol --- synapse/replication/tcp/client.py | 18 ------------------ synapse/replication/tcp/protocol.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 18 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 51f90655d0..e558f90e1a 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -89,11 +89,6 @@ class ReplicationClientHandler(object): # Used for tests. self.awaiting_syncs = {} - # Set of stream names that have been subscribe to, but haven't yet - # caught up with. This is used to track when the client has been fully - # connected to the remote. - self.streams_connecting = None - # The factory used to create connections. self.factory = None @@ -122,12 +117,6 @@ class ReplicationClientHandler(object): Can be overriden in subclasses to handle more. """ - # When we get a `POSITION` command it means we've finished getting - # missing updates for the given stream, and are now up to date. - self.streams_connecting.discard(stream_name) - if not self.streams_connecting: - self.finished_connecting() - return self.store.process_replication_rows(stream_name, token, []) def on_sync(self, data): @@ -154,9 +143,6 @@ class ReplicationClientHandler(object): elif room_account_data: args["account_data"] = room_account_data - # Record which streams we're in the process of subscribing to - self.streams_connecting = set(args.keys()) - return args def get_currently_syncing_users(self): @@ -222,10 +208,6 @@ class ReplicationClientHandler(object): connection.send_command(cmd) self.pending_commands = [] - # This will happen if we don't actually subscribe to any streams - if not self.streams_connecting: - self.finished_connecting() - def finished_connecting(self): """Called when we have successfully subscribed and caught up to all streams we're interested in. diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 0b3fe6cbf5..6123c995b9 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -511,6 +511,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name self.handler = handler + # Set of stream names that have been subscribe to, but haven't yet + # caught up with. This is used to track when the client has been fully + # connected to the remote. + self.streams_connecting = set() + # Map of stream to batched updates. See RdataCommand for info on how # batching works. self.pending_batches = {} @@ -533,6 +538,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # We've now finished connecting to so inform the client handler self.handler.update_connection(self) + # This will happen if we don't actually subscribe to any streams + if not self.streams_connecting: + self.handler.finished_connecting() + def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) @@ -562,6 +571,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): return self.handler.on_rdata(stream_name, cmd.token, rows) def on_POSITION(self, cmd): + # When we get a `POSITION` command it means we've finished getting + # missing updates for the given stream, and are now up to date. + self.streams_connecting.discard(cmd.stream_name) + if not self.streams_connecting: + self.handler.finished_connecting() + return self.handler.on_position(cmd.stream_name, cmd.token) def on_SYNC(self, cmd): @@ -578,6 +593,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.id(), stream_name, token ) + self.streams_connecting.add(stream_name) + self.send_command(ReplicateCommand(stream_name, token)) def on_connection_closed(self): -- cgit 1.5.1 From 1e315017d3c0dcde20e781fb3c87bc0e54c53cd1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 27 Feb 2019 13:53:46 +0000 Subject: When presence is enabled don't send over replication --- synapse/federation/federation_server.py | 3 +++ synapse/replication/slave/storage/presence.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 569eb277a9..81f3b4b1ff 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -886,6 +886,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): def on_edu(self, edu_type, origin, content): """Overrides FederationHandlerRegistry """ + if not self.config.use_presence and edu_type == "m.presence": + return + handler = self.edu_handlers.get(edu_type) if handler: return super(ReplicationFederationHandlerRegistry, self).on_edu( diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 92447b00d4..9e530defe0 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -54,8 +54,11 @@ class SlavedPresenceStore(BaseSlavedStore): def stream_positions(self): result = super(SlavedPresenceStore, self).stream_positions() - position = self._presence_id_gen.get_current_token() - result["presence"] = position + + if self.hs.config.use_presence: + position = self._presence_id_gen.get_current_token() + result["presence"] = position + return result def process_replication_rows(self, stream_name, token, rows): -- cgit 1.5.1 From 5f0c449dd50fa84ff741e09f34cad5330c6e4745 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 4 Mar 2019 13:56:49 +0000 Subject: Prevent replication wedging --- synapse/replication/tcp/protocol.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 49ae5b3355..a6df04d851 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -451,7 +451,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): @defer.inlineCallbacks def subscribe_to_stream(self, stream_name, token): - """Subscribe the remote to a streams. + """Subscribe the remote to a stream. This invloves checking if they've missed anything and sending those updates down if they have. During that time new updates for the stream @@ -478,10 +478,30 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): # Now we can send any updates that came in while we were subscribing pending_rdata = self.pending_rdata.pop(stream_name, []) + batch_updates = [] for token, update in pending_rdata: - # Only send updates newer than the current token - if token > current_token: - self.send_command(RdataCommand(stream_name, token, update)) + # If the token is null, it is part of a batch update. Batches + # are multiple updates that share a single token. To denote + # this, the token is set to None for all tokens in the batch + # except for the last. If we find a None token, we keep looking + # through tokens until we find one that is not None and then + # process all previous updates in the batch as if they had the + # final token. + if not token or len(batch_updates) > 0: + batch_updates.append(update) + if token and not token > current_token: + # This batch is older than current_token, dismiss + batch_updates = [] + continue + if token: + # Send all updates that are part of this batch with the + # found token + for update in batch_updates: + self.send_command(RdataCommand(stream_name, token, update)) + else: + # Only send updates newer than the current token + if token > current_token: + self.send_command(RdataCommand(stream_name, token, update)) # They're now fully subscribed self.replication_streams.add(stream_name) -- cgit 1.5.1 From 9f7cdf3da16e4e6c29229dcc80d9cf060cd64584 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 4 Mar 2019 14:36:52 +0000 Subject: Clearer branching, fix missing list clear --- synapse/replication/tcp/protocol.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index a6df04d851..53615b7ee3 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -488,16 +488,23 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): # process all previous updates in the batch as if they had the # final token. if not token or len(batch_updates) > 0: - batch_updates.append(update) - if token and not token > current_token: + if token is None: + # Store this update as part of the batch + batch_updates.append(update) + elif current_token <= current_token: # This batch is older than current_token, dismiss batch_updates = [] - continue - if token: + else: + # Append final update of this batch before sending + batch_updates.append(update) + # Send all updates that are part of this batch with the # found token for update in batch_updates: self.send_command(RdataCommand(stream_name, token, update)) + + # Clear saved batch updates + batch_updates = [] else: # Only send updates newer than the current token if token > current_token: -- cgit 1.5.1 From fe7bd23a85988c5251fe17e78589b69f92f21dd7 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 4 Mar 2019 15:08:15 +0000 Subject: Clean up logic and add comments --- synapse/replication/tcp/protocol.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 53615b7ee3..dac4fbeef7 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -487,15 +487,19 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): # through tokens until we find one that is not None and then # process all previous updates in the batch as if they had the # final token. - if not token or len(batch_updates) > 0: - if token is None: - # Store this update as part of the batch - batch_updates.append(update) - elif current_token <= current_token: - # This batch is older than current_token, dismiss + if token is None: + # Store this update as part of a batch + batch_updates.append(update) + continue + + if len(batch_updates) > 0: + # There is an ongoing batch and this is the end + if current_token <= current_token: + # This batch is older than current_token, dismiss it batch_updates = [] else: - # Append final update of this batch before sending + # This is the end of the batch. Append final update of + # this batch before sending batch_updates.append(update) # Send all updates that are part of this batch with the @@ -505,10 +509,13 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): # Clear saved batch updates batch_updates = [] - else: - # Only send updates newer than the current token - if token > current_token: - self.send_command(RdataCommand(stream_name, token, update)) + continue + + # This is an update that's not part of a batch. + # + # Only send updates newer than the current token + if token > current_token: + self.send_command(RdataCommand(stream_name, token, update)) # They're now fully subscribed self.replication_streams.add(stream_name) -- cgit 1.5.1 From a84b8d56c2cfec62c87a13771c817b3205b5ec4b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 4 Mar 2019 18:03:29 +0000 Subject: Fixup slave stores --- synapse/replication/slave/storage/deviceinbox.py | 15 +- synapse/replication/slave/storage/devices.py | 45 +- synapse/replication/slave/storage/push_rule.py | 2 +- synapse/storage/deviceinbox.py | 324 ++++---- synapse/storage/devices.py | 987 ++++++++++++----------- synapse/storage/end_to_end_keys.py | 88 +- 6 files changed, 728 insertions(+), 733 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 4f19fd35aa..4d59778863 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ -from ._slaved_id_tracker import SlavedIdTracker - -class SlavedDeviceInboxStore(BaseSlavedStore): +class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( @@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token) - get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device) - get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote) - delete_messages_for_device = __func__(DataStore.delete_messages_for_device) - delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote) - def stream_positions(self): result = super(SlavedDeviceInboxStore, self).stream_positions() result["to_device"] = self._device_inbox_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index ec2fd561cc..16c9a162c5 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -13,15 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage import DataStore -from synapse.storage.end_to_end_keys import EndToEndKeyStore +from synapse.replication.slave.storage._base import BaseSlavedStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.storage.devices import DeviceWorkerStore +from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ -from ._slaved_id_tracker import SlavedIdTracker - -class SlavedDeviceStore(BaseSlavedStore): +class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): def __init__(self, db_conn, hs): super(SlavedDeviceStore, self).__init__(db_conn, hs) @@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore): "DeviceListFederationStreamChangeCache", device_list_max, ) - get_device_stream_token = __func__(DataStore.get_device_stream_token) - get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed) - get_devices_by_remote = __func__(DataStore.get_devices_by_remote) - _get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn) - _get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn) - mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote) - _mark_as_sent_devices_by_remote_txn = ( - __func__(DataStore._mark_as_sent_devices_by_remote_txn) - ) - count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"] - def stream_positions(self): result = super(SlavedDeviceStore, self).stream_positions() result["device_lists"] = self._device_list_id_gen.get_current_token() @@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore): if stream_name == "device_lists": self._device_list_id_gen.advance(token) for row in rows: - self._device_list_stream_cache.entity_has_changed( - row.user_id, token + self._invalidate_caches_for_devices( + token, row.user_id, row.destination, ) - - if row.destination: - self._device_list_federation_stream_cache.entity_has_changed( - row.destination, token - ) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) + + def _invalidate_caches_for_devices(self, token, user_id, destination): + self._device_list_stream_cache.entity_has_changed( + user_id, token + ) + + if destination: + self._device_list_federation_stream_cache.entity_has_changed( + destination, token + ) + + self._get_cached_devices_for_user.invalidate((user_id,)) + self._get_cached_user_device.invalidate_many((user_id,)) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index f0200c1e98..45fc913c52 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore -class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): +class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def __init__(self, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index e06b0bc56d..e6a42a53bb 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -19,14 +19,174 @@ from canonicaljson import json from twisted.internet import defer +from synapse.storage._base import SQLBaseStore +from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.expiringcache import ExpiringCache -from .background_updates import BackgroundUpdateStore - logger = logging.getLogger(__name__) -class DeviceInboxStore(BackgroundUpdateStore): +class DeviceInboxWorkerStore(SQLBaseStore): + def get_to_device_stream_token(self): + return self._device_inbox_id_gen.get_current_token() + + def get_new_messages_for_device( + self, user_id, device_id, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + current_stream_id(int): The current position of the to device + message stream. + Returns: + Deferred ([dict], int): List of messages for the device and where + in the stream the messages got to. + """ + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_stream_id + ) + if not has_changed: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_device_txn(txn): + sql = ( + "SELECT stream_id, message_json FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + user_id, device_id, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(json.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_messages_for_device", get_new_messages_for_device_txn, + ) + + @defer.inlineCallbacks + def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): + """ + Args: + user_id(str): The recipient user_id. + device_id(str): The recipient device_id. + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves to the number of messages deleted. + """ + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + defer.returnValue(0) + + def delete_messages_for_device_txn(txn): + sql = ( + "DELETE FROM device_inbox" + " WHERE user_id = ? AND device_id = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (user_id, device_id, up_to_stream_id)) + return txn.rowcount + + count = yield self.runInteraction( + "delete_messages_for_device", delete_messages_for_device_txn + ) + + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + defer.returnValue(count) + + def get_new_device_msgs_for_remote( + self, destination, last_stream_id, current_stream_id, limit=100 + ): + """ + Args: + destination(str): The name of the remote server. + last_stream_id(int|long): The last position of the device message stream + that the server sent up to. + current_stream_id(int|long): The current position of the device + message stream. + Returns: + Deferred ([dict], int|long): List of messages for the device and where + in the stream the messages got to. + """ + + has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( + destination, last_stream_id + ) + if not has_changed or last_stream_id == current_stream_id: + return defer.succeed(([], current_stream_id)) + + def get_new_messages_for_remote_destination_txn(txn): + sql = ( + "SELECT stream_id, messages_json FROM device_federation_outbox" + " WHERE destination = ?" + " AND ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + " LIMIT ?" + ) + txn.execute(sql, ( + destination, last_stream_id, current_stream_id, limit + )) + messages = [] + for row in txn: + stream_pos = row[0] + messages.append(json.loads(row[1])) + if len(messages) < limit: + stream_pos = current_stream_id + return (messages, stream_pos) + + return self.runInteraction( + "get_new_device_msgs_for_remote", + get_new_messages_for_remote_destination_txn, + ) + + def delete_device_msgs_for_remote(self, destination, up_to_stream_id): + """Used to delete messages when the remote destination acknowledges + their receipt. + + Args: + destination(str): The destination server_name + up_to_stream_id(int): Where to delete messages up to. + Returns: + A deferred that resolves when the messages have been deleted. + """ + def delete_messages_for_remote_destination_txn(txn): + sql = ( + "DELETE FROM device_federation_outbox" + " WHERE destination = ?" + " AND stream_id <= ?" + ) + txn.execute(sql, (destination, up_to_stream_id)) + + return self.runInteraction( + "delete_device_msgs_for_remote", + delete_messages_for_remote_destination_txn + ) + + +class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" def __init__(self, db_conn, hs): @@ -220,93 +380,6 @@ class DeviceInboxStore(BackgroundUpdateStore): txn.executemany(sql, rows) - def get_new_messages_for_device( - self, user_id, device_id, last_stream_id, current_stream_id, limit=100 - ): - """ - Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - current_stream_id(int): The current position of the to device - message stream. - Returns: - Deferred ([dict], int): List of messages for the device and where - in the stream the messages got to. - """ - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_stream_id - ) - if not has_changed: - return defer.succeed(([], current_stream_id)) - - def get_new_messages_for_device_txn(txn): - sql = ( - "SELECT stream_id, message_json FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" - ) - txn.execute(sql, ( - user_id, device_id, last_stream_id, current_stream_id, limit - )) - messages = [] - for row in txn: - stream_pos = row[0] - messages.append(json.loads(row[1])) - if len(messages) < limit: - stream_pos = current_stream_id - return (messages, stream_pos) - - return self.runInteraction( - "get_new_messages_for_device", get_new_messages_for_device_txn, - ) - - @defer.inlineCallbacks - def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): - """ - Args: - user_id(str): The recipient user_id. - device_id(str): The recipient device_id. - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves to the number of messages deleted. - """ - # If we have cached the last stream id we've deleted up to, we can - # check if there is likely to be anything that needs deleting - last_deleted_stream_id = self._last_device_delete_cache.get( - (user_id, device_id), None - ) - if last_deleted_stream_id: - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_deleted_stream_id - ) - if not has_changed: - defer.returnValue(0) - - def delete_messages_for_device_txn(txn): - sql = ( - "DELETE FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND stream_id <= ?" - ) - txn.execute(sql, (user_id, device_id, up_to_stream_id)) - return txn.rowcount - - count = yield self.runInteraction( - "delete_messages_for_device", delete_messages_for_device_txn - ) - - # Update the cache, ensuring that we only ever increase the value - last_deleted_stream_id = self._last_device_delete_cache.get( - (user_id, device_id), 0 - ) - self._last_device_delete_cache[(user_id, device_id)] = max( - last_deleted_stream_id, up_to_stream_id - ) - - defer.returnValue(count) - def get_all_new_device_messages(self, last_pos, current_pos, limit): """ Args: @@ -351,77 +424,6 @@ class DeviceInboxStore(BackgroundUpdateStore): "get_all_new_device_messages", get_all_new_device_messages_txn ) - def get_to_device_stream_token(self): - return self._device_inbox_id_gen.get_current_token() - - def get_new_device_msgs_for_remote( - self, destination, last_stream_id, current_stream_id, limit=100 - ): - """ - Args: - destination(str): The name of the remote server. - last_stream_id(int|long): The last position of the device message stream - that the server sent up to. - current_stream_id(int|long): The current position of the device - message stream. - Returns: - Deferred ([dict], int|long): List of messages for the device and where - in the stream the messages got to. - """ - - has_changed = self._device_federation_outbox_stream_cache.has_entity_changed( - destination, last_stream_id - ) - if not has_changed or last_stream_id == current_stream_id: - return defer.succeed(([], current_stream_id)) - - def get_new_messages_for_remote_destination_txn(txn): - sql = ( - "SELECT stream_id, messages_json FROM device_federation_outbox" - " WHERE destination = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" - ) - txn.execute(sql, ( - destination, last_stream_id, current_stream_id, limit - )) - messages = [] - for row in txn: - stream_pos = row[0] - messages.append(json.loads(row[1])) - if len(messages) < limit: - stream_pos = current_stream_id - return (messages, stream_pos) - - return self.runInteraction( - "get_new_device_msgs_for_remote", - get_new_messages_for_remote_destination_txn, - ) - - def delete_device_msgs_for_remote(self, destination, up_to_stream_id): - """Used to delete messages when the remote destination acknowledges - their receipt. - - Args: - destination(str): The destination server_name - up_to_stream_id(int): Where to delete messages up to. - Returns: - A deferred that resolves when the messages have been deleted. - """ - def delete_messages_for_remote_destination_txn(txn): - sql = ( - "DELETE FROM device_federation_outbox" - " WHERE destination = ?" - " AND stream_id <= ?" - ) - txn.execute(sql, (destination, up_to_stream_id)) - - return self.runInteraction( - "delete_device_msgs_for_remote", - delete_messages_for_remote_destination_txn - ) - @defer.inlineCallbacks def _background_drop_index_device_inbox(self, progress, batch_size): def reindex_txn(conn): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index ecdab34e7d..e716dc1437 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -22,11 +22,10 @@ from twisted.internet import defer from synapse.api.errors import StoreError from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage.background_updates import BackgroundUpdateStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList -from ._base import Cache, db_to_json - logger = logging.getLogger(__name__) DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( @@ -34,93 +33,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( ) -class DeviceStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) - - # Map of (user_id, device_id) -> bool. If there is an entry that implies - # the device exists. - self.device_id_exists_cache = Cache( - name="device_id_exists", - keylen=2, - max_entries=10000, - ) - - self._clock.looping_call( - self._prune_old_outbound_device_pokes, 60 * 60 * 1000 - ) - - self.register_background_index_update( - "device_lists_stream_idx", - index_name="device_lists_stream_user_id", - table="device_lists_stream", - columns=["user_id", "device_id"], - ) - - # create a unique index on device_lists_remote_cache - self.register_background_index_update( - "device_lists_remote_cache_unique_idx", - index_name="device_lists_remote_cache_unique_id", - table="device_lists_remote_cache", - columns=["user_id", "device_id"], - unique=True, - ) - - # And one on device_lists_remote_extremeties - self.register_background_index_update( - "device_lists_remote_extremeties_unique_idx", - index_name="device_lists_remote_extremeties_unique_idx", - table="device_lists_remote_extremeties", - columns=["user_id"], - unique=True, - ) - - # once they complete, we can remove the old non-unique indexes. - self.register_background_update_handler( - DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, - self._drop_device_list_streams_non_unique_indexes, - ) - - @defer.inlineCallbacks - def store_device(self, user_id, device_id, - initial_device_display_name): - """Ensure the given device is known; add it to the store if not - - Args: - user_id (str): id of user associated with the device - device_id (str): id of device - initial_device_display_name (str): initial displayname of the - device. Ignored if device exists. - Returns: - defer.Deferred: boolean whether the device was inserted or an - existing device existed with that ID. - """ - key = (user_id, device_id) - if self.device_id_exists_cache.get(key, None): - defer.returnValue(False) - - try: - inserted = yield self._simple_insert( - "devices", - values={ - "user_id": user_id, - "device_id": device_id, - "display_name": initial_device_display_name - }, - desc="store_device", - or_ignore=True, - ) - self.device_id_exists_cache.prefill(key, True) - defer.returnValue(inserted) - except Exception as e: - logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" - " display_name=%s(%r) failed: %s", - type(device_id).__name__, device_id, - type(user_id).__name__, user_id, - type(initial_device_display_name).__name__, - initial_device_display_name, e) - raise StoreError(500, "Problem storing device.") - +class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): """Retrieve a device. @@ -139,69 +52,6 @@ class DeviceStore(BackgroundUpdateStore): desc="get_device", ) - @defer.inlineCallbacks - def delete_device(self, user_id, device_id): - """Delete a device. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to delete - Returns: - defer.Deferred - """ - yield self._simple_delete_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_device", - ) - - self.device_id_exists_cache.invalidate((user_id, device_id)) - - @defer.inlineCallbacks - def delete_devices(self, user_id, device_ids): - """Deletes several devices. - - Args: - user_id (str): The ID of the user which owns the devices - device_ids (list): The IDs of the devices to delete - Returns: - defer.Deferred - """ - yield self._simple_delete_many( - table="devices", - column="device_id", - iterable=device_ids, - keyvalues={"user_id": user_id}, - desc="delete_devices", - ) - for device_id in device_ids: - self.device_id_exists_cache.invalidate((user_id, device_id)) - - def update_device(self, user_id, device_id, new_display_name=None): - """Update a device. - - Args: - user_id (str): The ID of the user which owns the device - device_id (str): The ID of the device to update - new_display_name (str|None): new displayname for device; None - to leave unchanged - Raises: - StoreError: if the device is not found - Returns: - defer.Deferred - """ - updates = {} - if new_display_name is not None: - updates["display_name"] = new_display_name - if not updates: - return defer.succeed(None) - return self._simple_update_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - updatevalues=updates, - desc="update_device", - ) - @defer.inlineCallbacks def get_devices_by_user(self, user_id): """Retrieve all of a user's registered devices. @@ -222,449 +72,603 @@ class DeviceStore(BackgroundUpdateStore): defer.returnValue({d["device_id"]: d for d in devices}) - @cached(max_entries=10000) - def get_device_list_last_stream_id_for_remote(self, user_id): - """Get the last stream_id we got for a user. May be None if we haven't - got any information for them. + def get_devices_by_remote(self, destination, from_stream_id): + """Get stream of updates to send to remote servers + + Returns: + (int, list[dict]): current stream id and list of updates """ - return self._simple_select_one_onecol( - table="device_lists_remote_extremeties", - keyvalues={"user_id": user_id}, - retcol="stream_id", - desc="get_device_list_remote_extremity", - allow_none=True, - ) + now_stream_id = self._device_list_id_gen.get_current_token() - @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", - list_name="user_ids", inlineCallbacks=True) - def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( - table="device_lists_remote_extremeties", - column="user_id", - iterable=user_ids, - retcols=("user_id", "stream_id",), - desc="get_user_devices_from_cache", + has_changed = self._device_list_federation_stream_cache.has_entity_changed( + destination, int(from_stream_id) ) + if not has_changed: + return (now_stream_id, []) - results = {user_id: None for user_id in user_ids} - results.update({ - row["user_id"]: row["stream_id"] for row in rows - }) - - defer.returnValue(results) + return self.runInteraction( + "get_devices_by_remote", self._get_devices_by_remote_txn, + destination, from_stream_id, now_stream_id, + ) - @defer.inlineCallbacks - def mark_remote_user_device_list_as_unsubscribed(self, user_id): - """Mark that we no longer track device lists for remote user. + def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, + now_stream_id): + sql = """ + SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + GROUP BY user_id, device_id + LIMIT 20 """ - yield self._simple_delete( - table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, - desc="mark_remote_user_device_list_as_unsubscribed", + txn.execute( + sql, (destination, from_stream_id, now_stream_id, False) ) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - def update_remote_device_list_cache_entry(self, user_id, device_id, content, - stream_id): - """Updates a single device in the cache of a remote user's devicelist. + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in txn} + if not query_map: + return (now_stream_id, []) - Note: assumes that we are the only thread that can be updating this user's - device list. + if len(query_map) >= 20: + now_stream_id = max(stream_id for stream_id in itervalues(query_map)) - Args: - user_id (str): User to update device list for - device_id (str): ID of decivice being updated - content (dict): new data on this device - stream_id (int): the version of the device list + devices = self._get_e2e_device_keys_txn( + txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True + ) - Returns: - Deferred[None] + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? AND stream_id <= ? """ - return self.runInteraction( - "update_remote_device_list_cache_entry", - self._update_remote_device_list_cache_entry_txn, - user_id, device_id, content, stream_id, - ) - def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, - content, stream_id): - if content.get("deleted"): - self._simple_delete_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ + results = [] + for user_id, user_devices in iteritems(devices): + # The prev_id for the first row is always the last row before + # `from_stream_id` + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + rows = txn.fetchall() + prev_id = rows[0][0] + for device_id, device in iteritems(user_devices): + stream_id = query_map[(user_id, device_id)] + result = { "user_id": user_id, "device_id": device_id, - }, - ) + "prev_id": [prev_id] if prev_id else [], + "stream_id": stream_id, + } - txn.call_after( - self.device_id_exists_cache.invalidate, (user_id, device_id,) - ) - else: - self._simple_upsert_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "content": json.dumps(content), - }, + prev_id = stream_id - # we don't need to lock, because we assume we are the only thread - # updating this user's devices. - lock=False, - ) + if device is not None: + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + else: + result["deleted"] = True - txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) - txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + results.append(result) + + return (now_stream_id, results) + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + """Mark that updates have successfully been sent to the destination. + """ + return self.runInteraction( + "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, + destination, stream_id, ) - self._simple_upsert_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, - values={ - "stream_id": stream_id, - }, + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + # We update the device_lists_outbound_last_success with the successfully + # poked users. We do the join to see which users need to be inserted and + # which updated. + sql = """ + SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + FROM device_lists_outbound_pokes as o + LEFT JOIN device_lists_outbound_last_success as s + USING (destination, user_id) + WHERE destination = ? AND o.stream_id <= ? + GROUP BY user_id + """ + txn.execute(sql, (destination, stream_id,)) + rows = txn.fetchall() - # again, we can assume we are the only thread updating this user's - # extremity. - lock=False, + sql = """ + UPDATE device_lists_outbound_last_success + SET stream_id = ? + WHERE destination = ? AND user_id = ? + """ + txn.executemany( + sql, ((row[1], destination, row[0],) for row in rows if row[2]) ) - def update_remote_device_list_cache(self, user_id, devices, stream_id): - """Replace the entire cache of the remote user's devices. + sql = """ + INSERT INTO device_lists_outbound_last_success + (destination, user_id, stream_id) VALUES (?, ?, ?) + """ + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows if not row[2]) + ) - Note: assumes that we are the only thread that can be updating this user's - device list. + # Delete all sent outbound pokes + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (destination, stream_id,)) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() + + @defer.inlineCallbacks + def get_user_devices_from_cache(self, query_list): + """Get the devices (and keys if any) for remote users from the cache. Args: - user_id (str): User to update device list for - devices (list[dict]): list of device objects supplied over federation - stream_id (int): the version of the device list + query_list(list): List of (user_id, device_ids), if device_ids is + falsey then return all device ids for that user. Returns: - Deferred[None] + (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is + a set of user_ids and results_map is a mapping of + user_id -> device_id -> device_info """ - return self.runInteraction( - "update_remote_device_list_cache", - self._update_remote_device_list_cache_txn, - user_id, devices, stream_id, + user_ids = set(user_id for user_id, _ in query_list) + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_ids_in_cache = set( + user_id for user_id, stream_id in user_map.items() if stream_id ) + user_ids_not_in_cache = user_ids - user_ids_in_cache - def _update_remote_device_list_cache_txn(self, txn, user_id, devices, - stream_id): - self._simple_delete_txn( - txn, + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device + else: + results[user_id] = yield self._get_cached_devices_for_user(user_id) + + defer.returnValue((user_ids_not_in_cache, results)) + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = yield self._simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={ "user_id": user_id, + "device_id": device_id, }, + retcol="content", + desc="_get_cached_user_device", ) + defer.returnValue(db_to_json(content)) - self._simple_insert_many_txn( - txn, + @cachedInlineCallbacks() + def _get_cached_devices_for_user(self, user_id): + devices = yield self._simple_select_list( table="device_lists_remote_cache", - values=[ - { - "user_id": user_id, - "device_id": content["device_id"], - "content": json.dumps(content), - } - for content in devices - ] - ) - - txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) - txn.call_after( - self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) - ) - - self._simple_upsert_txn( - txn, - table="device_lists_remote_extremeties", keyvalues={ "user_id": user_id, }, - values={ - "stream_id": stream_id, - }, - - # we don't need to lock, because we can assume we are the only thread - # updating this user's extremity. - lock=False, + retcols=("device_id", "content"), + desc="_get_cached_devices_for_user", ) + defer.returnValue({ + device["device_id"]: db_to_json(device["content"]) + for device in devices + }) - def get_devices_by_remote(self, destination, from_stream_id): - """Get stream of updates to send to remote servers + def get_devices_with_keys_by_user(self, user_id): + """Get all devices (with any device keys) for a user Returns: - (int, list[dict]): current stream id and list of updates + (stream_id, devices) """ + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): now_stream_id = self._device_list_id_gen.get_current_token() - has_changed = self._device_list_federation_stream_cache.has_entity_changed( - destination, int(from_stream_id) + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True ) - if not has_changed: - return (now_stream_id, []) - return self.runInteraction( - "get_devices_by_remote", self._get_devices_by_remote_txn, - destination, from_stream_id, now_stream_id, - ) + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in iteritems(user_devices): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + """Get set of users whose devices have changed since `from_key`. + """ + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) - def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, - now_stream_id): sql = """ - SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes - WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? - GROUP BY user_id, device_id - LIMIT 20 + SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? """ - txn.execute( - sql, (destination, from_stream_id, now_stream_id, False) + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row[0] for row in rows)) + + def get_all_device_list_changes_for_remotes(self, from_key, to_key): + """Return a list of `(stream_id, user_id, destination)` which is the + combined list of changes to devices, and which destinations need to be + poked. `destination` may be None if no destinations need to be poked. + """ + # We do a group by here as there can be a large number of duplicate + # entries, since we throw away device IDs. + sql = """ + SELECT MAX(stream_id) AS stream_id, user_id, destination + FROM device_lists_stream + LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + WHERE ? < stream_id AND stream_id <= ? + GROUP BY user_id, destination + """ + return self._execute( + "get_all_device_list_changes_for_remotes", None, + sql, from_key, to_key ) - # maps (user_id, device_id) -> stream_id - query_map = {(r[0], r[1]): r[2] for r in txn} - if not query_map: - return (now_stream_id, []) + @cached(max_entries=10000) + def get_device_list_last_stream_id_for_remote(self, user_id): + """Get the last stream_id we got for a user. May be None if we haven't + got any information for them. + """ + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_last_stream_id_for_remote", + allow_none=True, + ) - if len(query_map) >= 20: - now_stream_id = max(stream_id for stream_id in itervalues(query_map)) + @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", inlineCallbacks=True) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self._simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id",), + desc="get_device_list_last_stream_id_for_remotes", + ) - devices = self._get_e2e_device_keys_txn( - txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True + results = {user_id: None for user_id in user_ids} + results.update({ + row["user_id"]: row["stream_id"] for row in rows + }) + + defer.returnValue(results) + + +class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(DeviceStore, self).__init__(db_conn, hs) + + # Map of (user_id, device_id) -> bool. If there is an entry that implies + # the device exists. + self.device_id_exists_cache = Cache( + name="device_id_exists", + keylen=2, + max_entries=10000, ) - prev_sent_id_sql = """ - SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_last_success - WHERE destination = ? AND user_id = ? AND stream_id <= ? + self._clock.looping_call( + self._prune_old_outbound_device_pokes, 60 * 60 * 1000 + ) + + self.register_background_index_update( + "device_lists_stream_idx", + index_name="device_lists_stream_user_id", + table="device_lists_stream", + columns=["user_id", "device_id"], + ) + + # create a unique index on device_lists_remote_cache + self.register_background_index_update( + "device_lists_remote_cache_unique_idx", + index_name="device_lists_remote_cache_unique_id", + table="device_lists_remote_cache", + columns=["user_id", "device_id"], + unique=True, + ) + + # And one on device_lists_remote_extremeties + self.register_background_index_update( + "device_lists_remote_extremeties_unique_idx", + index_name="device_lists_remote_extremeties_unique_idx", + table="device_lists_remote_extremeties", + columns=["user_id"], + unique=True, + ) + + # once they complete, we can remove the old non-unique indexes. + self.register_background_update_handler( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, + self._drop_device_list_streams_non_unique_indexes, + ) + + @defer.inlineCallbacks + def store_device(self, user_id, device_id, + initial_device_display_name): + """Ensure the given device is known; add it to the store if not + + Args: + user_id (str): id of user associated with the device + device_id (str): id of device + initial_device_display_name (str): initial displayname of the + device. Ignored if device exists. + Returns: + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. """ + key = (user_id, device_id) + if self.device_id_exists_cache.get(key, None): + defer.returnValue(False) - results = [] - for user_id, user_devices in iteritems(devices): - # The prev_id for the first row is always the last row before - # `from_stream_id` - txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) - rows = txn.fetchall() - prev_id = rows[0][0] - for device_id, device in iteritems(user_devices): - stream_id = query_map[(user_id, device_id)] - result = { + try: + inserted = yield self._simple_insert( + "devices", + values={ "user_id": user_id, "device_id": device_id, - "prev_id": [prev_id] if prev_id else [], - "stream_id": stream_id, - } - - prev_id = stream_id + "display_name": initial_device_display_name + }, + desc="store_device", + or_ignore=True, + ) + self.device_id_exists_cache.prefill(key, True) + defer.returnValue(inserted) + except Exception as e: + logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" + " display_name=%s(%r) failed: %s", + type(device_id).__name__, device_id, + type(user_id).__name__, user_id, + type(initial_device_display_name).__name__, + initial_device_display_name, e) + raise StoreError(500, "Problem storing device.") - if device is not None: - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - else: - result["deleted"] = True + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """Delete a device. - results.append(result) + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to delete + Returns: + defer.Deferred + """ + yield self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + desc="delete_device", + ) - return (now_stream_id, results) + self.device_id_exists_cache.invalidate((user_id, device_id)) @defer.inlineCallbacks - def get_user_devices_from_cache(self, query_list): - """Get the devices (and keys if any) for remote users from the cache. + def delete_devices(self, user_id, device_ids): + """Deletes several devices. Args: - query_list(list): List of (user_id, device_ids), if device_ids is - falsey then return all device ids for that user. + user_id (str): The ID of the user which owns the devices + device_ids (list): The IDs of the devices to delete + Returns: + defer.Deferred + """ + yield self._simple_delete_many( + table="devices", + column="device_id", + iterable=device_ids, + keyvalues={"user_id": user_id}, + desc="delete_devices", + ) + for device_id in device_ids: + self.device_id_exists_cache.invalidate((user_id, device_id)) + def update_device(self, user_id, device_id, new_display_name=None): + """Update a device. + + Args: + user_id (str): The ID of the user which owns the device + device_id (str): The ID of the device to update + new_display_name (str|None): new displayname for device; None + to leave unchanged + Raises: + StoreError: if the device is not found Returns: - (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is - a set of user_ids and results_map is a mapping of - user_id -> device_id -> device_info + defer.Deferred """ - user_ids = set(user_id for user_id, _ in query_list) - user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) - user_ids_in_cache = set( - user_id for user_id, stream_id in user_map.items() if stream_id + updates = {} + if new_display_name is not None: + updates["display_name"] = new_display_name + if not updates: + return defer.succeed(None) + return self._simple_update_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + updatevalues=updates, + desc="update_device", ) - user_ids_not_in_cache = user_ids - user_ids_in_cache - - results = {} - for user_id, device_id in query_list: - if user_id not in user_ids_in_cache: - continue - - if device_id: - device = yield self._get_cached_user_device(user_id, device_id) - results.setdefault(user_id, {})[device_id] = device - else: - results[user_id] = yield self._get_cached_devices_for_user(user_id) - - defer.returnValue((user_ids_not_in_cache, results)) - @cachedInlineCallbacks(num_args=2, tree=True) - def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( - table="device_lists_remote_cache", + @defer.inlineCallbacks + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + yield self._simple_delete( + table="device_lists_remote_extremeties", keyvalues={ "user_id": user_id, - "device_id": device_id, }, - retcol="content", - desc="_get_cached_user_device", + desc="mark_remote_user_device_list_as_unsubscribed", ) - defer.returnValue(db_to_json(content)) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) - @cachedInlineCallbacks() - def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - }, - retcols=("device_id", "content"), - desc="_get_cached_devices_for_user", - ) - defer.returnValue({ - device["device_id"]: db_to_json(device["content"]) - for device in devices - }) + def update_remote_device_list_cache_entry(self, user_id, device_id, content, + stream_id): + """Updates a single device in the cache of a remote user's devicelist. - def get_devices_with_keys_by_user(self, user_id): - """Get all devices (with any device keys) for a user + Note: assumes that we are the only thread that can be updating this user's + device list. + + Args: + user_id (str): User to update device list for + device_id (str): ID of decivice being updated + content (dict): new data on this device + stream_id (int): the version of the device list Returns: - (stream_id, devices) + Deferred[None] """ return self.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, user_id, + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, device_id, content, stream_id, ) - def _get_devices_with_keys_by_user_txn(self, txn, user_id): - now_stream_id = self._device_list_id_gen.get_current_token() + def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, + content, stream_id): + if content.get("deleted"): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + ) - devices = self._get_e2e_device_keys_txn( - txn, [(user_id, None)], include_all_devices=True + txn.call_after( + self.device_id_exists_cache.invalidate, (user_id, device_id,) + ) + else: + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + }, + + # we don't need to lock, because we assume we are the only thread + # updating this user's devices. + lock=False, + ) + + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in iteritems(user_devices): - result = { - "device_id": device_id, - } + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + }, - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name + # again, we can assume we are the only thread updating this user's + # extremity. + lock=False, + ) - results.append(result) + def update_remote_device_list_cache(self, user_id, devices, stream_id): + """Replace the entire cache of the remote user's devices. - return now_stream_id, results + Note: assumes that we are the only thread that can be updating this user's + device list. - return now_stream_id, [] + Args: + user_id (str): User to update device list for + devices (list[dict]): list of device objects supplied over federation + stream_id (int): the version of the device list - def mark_as_sent_devices_by_remote(self, destination, stream_id): - """Mark that updates have successfully been sent to the destination. + Returns: + Deferred[None] """ return self.runInteraction( - "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, - destination, stream_id, + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, devices, stream_id, ) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): - # We update the device_lists_outbound_last_success with the successfully - # poked users. We do the join to see which users need to be inserted and - # which updated. - sql = """ - SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) - FROM device_lists_outbound_pokes as o - LEFT JOIN device_lists_outbound_last_success as s - USING (destination, user_id) - WHERE destination = ? AND o.stream_id <= ? - GROUP BY user_id - """ - txn.execute(sql, (destination, stream_id,)) - rows = txn.fetchall() - - sql = """ - UPDATE device_lists_outbound_last_success - SET stream_id = ? - WHERE destination = ? AND user_id = ? - """ - txn.executemany( - sql, ((row[1], destination, row[0],) for row in rows if row[2]) + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, + stream_id): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, ) - sql = """ - INSERT INTO device_lists_outbound_last_success - (destination, user_id, stream_id) VALUES (?, ?, ?) - """ - txn.executemany( - sql, ((destination, row[0], row[1],) for row in rows if not row[2]) + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ] ) - # Delete all sent outbound pokes - sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? - """ - txn.execute(sql, (destination, stream_id,)) - - @defer.inlineCallbacks - def get_user_whose_devices_changed(self, from_key): - """Get set of users whose devices have changed since `from_key`. - """ - from_key = int(from_key) - changed = self._device_list_stream_cache.get_all_entities_changed(from_key) - if changed is not None: - defer.returnValue(set(changed)) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) - sql = """ - SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? - """ - rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) - defer.returnValue(set(row[0] for row in rows)) + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + }, - def get_all_device_list_changes_for_remotes(self, from_key, to_key): - """Return a list of `(stream_id, user_id, destination)` which is the - combined list of changes to devices, and which destinations need to be - poked. `destination` may be None if no destinations need to be poked. - """ - # We do a group by here as there can be a large number of duplicate - # entries, since we throw away device IDs. - sql = """ - SELECT MAX(stream_id) AS stream_id, user_id, destination - FROM device_lists_stream - LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) - WHERE ? < stream_id AND stream_id <= ? - GROUP BY user_id, destination - """ - return self._execute( - "get_all_device_list_changes_for_remotes", None, - sql, from_key, to_key + # we don't need to lock, because we can assume we are the only thread + # updating this user's extremity. + lock=False, ) @defer.inlineCallbacks @@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore): ] ) - def get_device_stream_token(self): - return self._device_list_id_gen.get_current_token() - def _prune_old_outbound_device_pokes(self): """Delete old entries out of the device_lists_outbound_pokes to ensure that we don't fill up due to dead servers. We keep one entry per diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2a0f6cfca9..e381e472a2 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -23,49 +23,7 @@ from synapse.util.caches.descriptors import cached from ._base import SQLBaseStore, db_to_json -class EndToEndKeyStore(SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): - """Stores device keys for a device. Returns whether there was a change - or the keys were already in the database. - """ - def _set_e2e_device_keys_txn(txn): - old_key_json = self._simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - retcol="key_json", - allow_none=True, - ) - - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") - - if old_key_json == new_key_json: - return False - - self._simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "ts_added_ms": time_now, - "key_json": new_key_json, - } - ) - - return True - - return self.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn - ) - +class EndToEndKeyWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_e2e_device_keys( self, query_list, include_all_devices=False, @@ -238,6 +196,50 @@ class EndToEndKeyStore(SQLBaseStore): "count_e2e_one_time_keys", _count_e2e_one_time_keys ) + +class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + def _set_e2e_device_keys_txn(txn): + old_key_json = self._simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="key_json", + allow_none=True, + ) + + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") + + if old_key_json == new_key_json: + return False + + self._simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": new_key_json, + } + ) + + return True + + return self.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn + ) + def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" def _claim_e2e_one_time_keys(txn): -- cgit 1.5.1 From b9f61630927752422fb80cf7ece083741aefd399 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Tue, 5 Mar 2019 13:58:30 +0000 Subject: Simplify token replication logic --- synapse/replication/tcp/protocol.py | 37 ++++++++++++++----------------------- 1 file changed, 14 insertions(+), 23 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index dac4fbeef7..55630ba9a7 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -478,7 +478,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): # Now we can send any updates that came in while we were subscribing pending_rdata = self.pending_rdata.pop(stream_name, []) - batch_updates = [] + updates = [] for token, update in pending_rdata: # If the token is null, it is part of a batch update. Batches # are multiple updates that share a single token. To denote @@ -489,34 +489,25 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): # final token. if token is None: # Store this update as part of a batch - batch_updates.append(update) + updates.append(update) continue - if len(batch_updates) > 0: - # There is an ongoing batch and this is the end - if current_token <= current_token: - # This batch is older than current_token, dismiss it - batch_updates = [] - else: - # This is the end of the batch. Append final update of - # this batch before sending - batch_updates.append(update) - - # Send all updates that are part of this batch with the - # found token - for update in batch_updates: - self.send_command(RdataCommand(stream_name, token, update)) - - # Clear saved batch updates - batch_updates = [] + if token <= current_token: + # This update or batch of updates is older than + # current_token, dismiss it + updates = [] continue - # This is an update that's not part of a batch. - # - # Only send updates newer than the current token - if token > current_token: + updates.append(update) + + # Send all updates that are part of this batch with the + # found token + for update in updates: self.send_command(RdataCommand(stream_name, token, update)) + # Clear stored updates + updates = [] + # They're now fully subscribed self.replication_streams.add(stream_name) except Exception as e: -- cgit 1.5.1 From a4c3a361b70bc02d65104240bef1b3cbb110bf22 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 5 Mar 2019 14:25:33 +0000 Subject: Add rate-limiting on registration (#4735) * Rate-limiting for registration * Add unit test for registration rate limiting * Add config parameters for rate limiting on auth endpoints * Doc * Fix doc of rate limiting function Co-Authored-By: babolivier * Incorporate review * Fix config parsing * Fix linting errors * Set default config for auth rate limiting * Fix tests * Add changelog * Advance reactor instead of mocked clock * Move parameters to registration specific config and give them more sensible default values * Remove unused config options * Don't mock the rate limiter un MAU tests * Rename _register_with_store into register_with_store * Make CI happy * Remove unused import * Update sample config * Fix ratelimiting test for py2 * Add non-guest test --- changelog.d/4735.feature | 1 + docs/sample_config.yaml | 11 +++++++ synapse/api/ratelimiting.py | 31 ++++++++++--------- synapse/config/registration.py | 18 +++++++++++ synapse/handlers/_base.py | 4 +-- synapse/handlers/register.py | 39 ++++++++++++++++++----- synapse/replication/http/register.py | 8 +++-- synapse/rest/client/v2_alpha/register.py | 33 +++++++++++++++++--- tests/api/test_ratelimiting.py | 20 ++++++------ tests/handlers/test_profile.py | 4 +-- tests/replication/slave/storage/_base.py | 4 +-- tests/rest/client/v1/test_events.py | 4 +-- tests/rest/client/v1/test_rooms.py | 6 ++-- tests/rest/client/v1/test_typing.py | 4 +-- tests/rest/client/v2_alpha/test_register.py | 48 +++++++++++++++++++++++++++++ tests/test_mau.py | 3 +- tests/utils.py | 2 ++ 17 files changed, 186 insertions(+), 54 deletions(-) create mode 100644 changelog.d/4735.feature (limited to 'synapse/replication') diff --git a/changelog.d/4735.feature b/changelog.d/4735.feature new file mode 100644 index 0000000000..a4c0b196f6 --- /dev/null +++ b/changelog.d/4735.feature @@ -0,0 +1 @@ +Add configurable rate limiting to the /register endpoint. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7cf58d2182..e0140003fd 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -657,6 +657,17 @@ trusted_third_party_id_servers: # autocreate_auto_join_rooms: true +# Number of registration requests a client can send per second. +# Defaults to 1/minute (0.17). +# +#rc_registration_requests_per_second: 0.17 + +# Number of registration requests a client can send before being +# throttled. +# Defaults to 3. +# +#rc_registration_request_burst_count: 3.0 + ## Metrics ### diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 3bb5b3da37..ad68079eeb 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -23,12 +23,13 @@ class Ratelimiter(object): def __init__(self): self.message_counts = collections.OrderedDict() - def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True): - """Can the user send a message? + def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + """Can the entity (e.g. user or IP address) perform the action? Args: - user_id: The user sending a message. + key: The key we should use when rate limiting. Can be a user ID + (when sending events), an IP address, etc. time_now_s: The time now. - msg_rate_hz: The long term number of messages a user can send in a + rate_hz: The long term number of messages a user can send in a second. burst_count: How many messages the user can send before being limited. @@ -41,10 +42,10 @@ class Ratelimiter(object): """ self.prune_message_counts(time_now_s) message_count, time_start, _ignored = self.message_counts.get( - user_id, (0., time_now_s, None), + key, (0., time_now_s, None), ) time_delta = time_now_s - time_start - sent_count = message_count - time_delta * msg_rate_hz + sent_count = message_count - time_delta * rate_hz if sent_count < 0: allowed = True time_start = time_now_s @@ -56,13 +57,13 @@ class Ratelimiter(object): message_count += 1 if update: - self.message_counts[user_id] = ( - message_count, time_start, msg_rate_hz + self.message_counts[key] = ( + message_count, time_start, rate_hz ) - if msg_rate_hz > 0: + if rate_hz > 0: time_allowed = ( - time_start + (message_count - burst_count + 1) / msg_rate_hz + time_start + (message_count - burst_count + 1) / rate_hz ) if time_allowed < time_now_s: time_allowed = time_now_s @@ -72,12 +73,12 @@ class Ratelimiter(object): return allowed, time_allowed def prune_message_counts(self, time_now_s): - for user_id in list(self.message_counts.keys()): - message_count, time_start, msg_rate_hz = ( - self.message_counts[user_id] + for key in list(self.message_counts.keys()): + message_count, time_start, rate_hz = ( + self.message_counts[key] ) time_delta = time_now_s - time_start - if message_count - time_delta * msg_rate_hz > 0: + if message_count - time_delta * rate_hz > 0: break else: - del self.message_counts[user_id] + del self.message_counts[key] diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 2881482f96..d32f6fff73 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -54,6 +54,13 @@ class RegistrationConfig(Config): config.get("disable_msisdn_registration", False) ) + self.rc_registration_requests_per_second = config.get( + "rc_registration_requests_per_second", 0.17, + ) + self.rc_registration_request_burst_count = config.get( + "rc_registration_request_burst_count", 3, + ) + def default_config(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -140,6 +147,17 @@ class RegistrationConfig(Config): # users cannot be auto-joined since they do not exist. # autocreate_auto_join_rooms: true + + # Number of registration requests a client can send per second. + # Defaults to 1/minute (0.17). + # + #rc_registration_requests_per_second: 0.17 + + # Number of registration requests a client can send before being + # throttled. + # Defaults to 3. + # + #rc_registration_request_burst_count: 3.0 """ % locals() def add_arguments(self, parser): diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 594754cfd8..d8d86d6ff3 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -93,9 +93,9 @@ class BaseHandler(object): messages_per_second = self.hs.config.rc_messages_per_second burst_count = self.hs.config.rc_message_burst_count - allowed, time_allowed = self.ratelimiter.send_message( + allowed, time_allowed = self.ratelimiter.can_do_action( user_id, time_now, - msg_rate_hz=messages_per_second, + rate_hz=messages_per_second, burst_count=burst_count, update=update, ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index c0e06929bd..47d5e276f8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -24,6 +24,7 @@ from synapse.api.errors import ( AuthError, Codes, InvalidCaptchaError, + LimitExceededError, RegistrationError, SynapseError, ) @@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() self.captcha_client = CaptchaServerHttpClient(hs) self.identity_handler = self.hs.get_handlers().identity_handler + self.ratelimiter = hs.get_ratelimiter() self._next_generated_user_id = None @@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler): threepid=None, user_type=None, default_display_name=None, + address=None, ): """Registers a new client on the server. @@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler): api.constants.UserTypes, or None for a normal user. default_display_name (unicode|None): if set, the new user's displayname will be set to this. Defaults to 'localpart'. + address (str|None): the IP address used to perform the regitration. Returns: A tuple of (user_id, access_token). Raises: @@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler): token = None if generate_token: token = self.macaroon_gen.generate_access_token(user_id) - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, @@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=default_display_name, admin=admin, user_type=user_type, + address=address, ) if self.hs.config.user_directory_search_all_users: @@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler): if default_display_name is None: default_display_name = localpart try: - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, make_guest=make_guest, create_profile_with_displayname=default_display_name, + address=address, ) except SynapseError: # if user id is taken, just generate another @@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, password_hash="", appservice_id=service_id, @@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler): token = self.macaroon_gen.generate_access_token(user_id) if need_register: - yield self._register_with_store( + yield self.register_with_store( user_id=user_id, token=token, password_hash=password_hash, @@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler): ratelimit=False, ) - def _register_with_store(self, user_id, token=None, password_hash=None, - was_guest=False, make_guest=False, appservice_id=None, - create_profile_with_displayname=None, admin=False, - user_type=None): + def register_with_store(self, user_id, token=None, password_hash=None, + was_guest=False, make_guest=False, appservice_id=None, + create_profile_with_displayname=None, admin=False, + user_type=None, address=None): """Register user in the datastore. Args: @@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + address (str|None): the IP address used to perform the regitration. Returns: Deferred """ + # Don't rate limit for app services + if appservice_id is None and address is not None: + time_now = self.clock.time() + + allowed, time_allowed = self.ratelimiter.can_do_action( + address, time_now_s=time_now, + rate_hz=self.hs.config.rc_registration_requests_per_second, + burst_count=self.hs.config.rc_registration_request_burst_count, + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + if self.hs.config.worker_app: return self._register_client( user_id=user_id, @@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler): create_profile_with_displayname=create_profile_with_displayname, admin=admin, user_type=user_type, + address=address, ) else: return self.store.register( diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 1d27c9221f..912a5ac341 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint): def __init__(self, hs): super(ReplicationRegisterServlet, self).__init__(hs) self.store = hs.get_datastore() + self.registration_handler = hs.get_registration_handler() @staticmethod def _serialize_payload( user_id, token, password_hash, was_guest, make_guest, appservice_id, - create_profile_with_displayname, admin, user_type, + create_profile_with_displayname, admin, user_type, address, ): """ Args: @@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin (boolean): is an admin user? user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. + address (str|None): the IP address used to perform the regitration. """ return { "token": token, @@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "create_profile_with_displayname": create_profile_with_displayname, "admin": admin, "user_type": user_type, + "address": address, } @defer.inlineCallbacks def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - yield self.store.register( + yield self.registration_handler.register_with_store( user_id=user_id, token=content["token"], password_hash=content["password_hash"], @@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): create_profile_with_displayname=content["create_profile_with_displayname"], admin=content["admin"], user_type=content["user_type"], + address=content["address"] ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 94cbba4303..b7f354570c 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -25,7 +25,12 @@ from twisted.internet import defer import synapse import synapse.types from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError +from synapse.api.errors import ( + Codes, + LimitExceededError, + SynapseError, + UnrecognizedRequestError, +) from synapse.config.server import is_threepid_reserved from synapse.http.servlet import ( RestServlet, @@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() + self.ratelimiter = hs.get_ratelimiter() + self.clock = hs.get_clock() @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) + client_addr = request.getClientIP() + + time_now = self.clock.time() + + allowed, time_allowed = self.ratelimiter.can_do_action( + client_addr, time_now_s=time_now, + rate_hz=self.hs.config.rc_registration_requests_per_second, + burst_count=self.hs.config.rc_registration_request_burst_count, + update=False, + ) + + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + kind = b"user" if b"kind" in request.args: kind = request.args[b"kind"][0] if kind == b"guest": - ret = yield self._do_guest_registration(body) + ret = yield self._do_guest_registration(body, address=client_addr) defer.returnValue(ret) return elif kind != b"user": @@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet): guest_access_token=guest_access_token, generate_token=False, threepid=threepid, + address=client_addr, ) # Necessary due to auth checks prior to the threepid being # written to the db @@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet): defer.returnValue(result) @defer.inlineCallbacks - def _do_guest_registration(self, params): + def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( generate_token=False, - make_guest=True + make_guest=True, + address=address, ) # we don't allow guests to specify their own device_id, because diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 8933fe3b72..30a255d441 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -6,34 +6,34 @@ from tests import unittest class TestRatelimiter(unittest.TestCase): def test_allowed(self): limiter = Ratelimiter() - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) self.assertEquals(10., time_allowed) - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 ) self.assertFalse(allowed) self.assertEquals(10., time_allowed) - allowed, time_allowed = limiter.send_message( - user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertTrue(allowed) self.assertEquals(20., time_allowed) def test_pruning(self): limiter = Ratelimiter() - allowed, time_allowed = limiter.send_message( - user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1 ) self.assertIn("test_id_1", limiter.message_counts) - allowed, time_allowed = limiter.send_message( - user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1 + allowed, time_allowed = limiter.can_do_action( + key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1 ) self.assertNotIn("test_id_1", limiter.message_counts) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 80da1c8954..d60c124eec 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) self.store = hs.get_datastore() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 9e9fbbfe93..524af4f8d1 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver( "blue", federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) - hs.get_ratelimiter().send_message.return_value = (True, 0) + hs.get_ratelimiter().can_do_action.return_value = (True, 0) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 483bebc832..36d8547275 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): config.auto_join_rooms = [] hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["send_message"]) + config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) ) self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index a824be9a62..015c144248 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) @@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase): # auth as user_id now self.helper.auth_user_id = self.user_id - def test_send_message(self): + def test_can_do_action(self): msg_content = b'{"msgtype":"m.text","body":"hello"}' seq = iter(range(100)) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 0ad814c5e5..30fb77bac8 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), + ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) self.event_source = hs.get_event_sources().sources["typing"] self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.send_message.return_value = (True, 0) + self.ratelimiter.can_do_action.return_value = (True, 0) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 906b348d3e..3600434858 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + + def test_POST_ratelimiting_guest(self): + self.hs.config.rc_registration_request_burst_count = 5 + + for i in range(0, 6): + url = self.url + b"?kind=guest" + request, channel = self.make_request(b"POST", url, b"{}") + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.) + + request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_POST_ratelimiting(self): + self.hs.config.rc_registration_request_burst_count = 5 + + for i in range(0, 6): + params = { + "username": "kermit" + str(i), + "password": "monkey", + "device_id": "frogfone", + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + request, channel = self.make_request(b"POST", self.url, request_data) + self.render(request) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.) + + request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + self.render(request) + + self.assertEquals(channel.result["code"], b"200", channel.result) diff --git a/tests/test_mau.py b/tests/test_mau.py index 04f95c942f..00be1a8c21 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,7 +17,7 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from synapse.api.constants import LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError @@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase): "red", http_client=None, federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["send_message"]), ) self.store = self.hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index ee272157aa..e4c42f9fa8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -150,6 +150,8 @@ def default_config(name): config.admin_contact = None config.rc_messages_per_second = 10000 config.rc_message_burst_count = 10000 + config.rc_registration_request_burst_count = 3.0 + config.rc_registration_requests_per_second = 0.17 config.saml2_enabled = False config.public_baseurl = None config.default_identity_server = None -- cgit 1.5.1 From face0c5b3c8ed6d0f29f7eaa3a2f9fd19eb99540 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 6 Mar 2019 17:39:32 +0000 Subject: Prefill client IPs cache on workers --- synapse/replication/slave/storage/client_ips.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse/replication') diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index 60641f1a49..5b8521c770 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -43,6 +43,8 @@ class SlavedClientIpStore(BaseSlavedStore): if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: return + self.client_ip_last_seen.prefill(key, now) + self.hs.get_tcp_replication().send_user_ip( user_id, access_token, ip, user_agent, device_id, now ) -- cgit 1.5.1 From cdb803616195c76306b30328af9da7cb2b32c960 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Mar 2019 16:04:35 +0000 Subject: Add a config option for torture-testing worker replication. (#4902) Setting this to 50 or so makes a bunch of sytests fail in worker mode. --- changelog.d/4902.misc | 1 + synapse/config/server.py | 5 +++++ synapse/replication/tcp/resource.py | 18 +++++++++++++++++- 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 changelog.d/4902.misc (limited to 'synapse/replication') diff --git a/changelog.d/4902.misc b/changelog.d/4902.misc new file mode 100644 index 0000000000..fecc06a6e8 --- /dev/null +++ b/changelog.d/4902.misc @@ -0,0 +1 @@ +Add a config option for torture-testing worker replication. diff --git a/synapse/config/server.py b/synapse/config/server.py index 499eb30bea..08e4e45482 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -126,6 +126,11 @@ class ServerConfig(Config): self.public_baseurl += '/' self.start_pushers = config.get("start_pushers", True) + # (undocumented) option for torturing the worker-mode replication a bit, + # for testing. The value defines the number of milliseconds to pause before + # sending out any replication updates. + self.replication_torture_level = config.get("replication_torture_level") + self.listeners = [] for listener in config.get("listeners", []): if not isinstance(listener.get("port", None), int): diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index fd59f1595f..47cdf30bd3 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -16,6 +16,7 @@ """ import logging +import random from six import itervalues @@ -74,6 +75,8 @@ class ReplicationStreamer(object): self.notifier = hs.get_notifier() self._server_notices_sender = hs.get_server_notices_sender() + self._replication_torture_level = hs.config.replication_torture_level + # Current connections. self.connections = [] @@ -157,10 +160,23 @@ class ReplicationStreamer(object): for stream in self.streams: stream.advance_current_token() - for stream in self.streams: + all_streams = self.streams + + if self._replication_torture_level is not None: + # there is no guarantee about ordering between the streams, + # so let's shuffle them around a bit when we are in torture mode. + all_streams = list(all_streams) + random.shuffle(all_streams) + + for stream in all_streams: if stream.last_token == stream.upto_token: continue + if self._replication_torture_level: + yield self.clock.sleep( + self._replication_torture_level / 1000.0 + ) + logger.debug( "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, stream.upto_token -- cgit 1.5.1